]> www.ginac.de Git - ginac.git/commitdiff
Explicit derivation of functions.
authorVladimir V. Kisil <kisilv@maths.leeds.ac.uk>
Sun, 8 Feb 2015 19:50:51 +0000 (20:50 +0100)
committerRichard Kreckel <kreckel@ginac.de>
Sun, 8 Feb 2015 19:52:57 +0000 (20:52 +0100)
Some function cannot be cleanly differentiated through the chain rule.
For example, it is natural to define derivative of the absolute value as

(abs(f))'=(f'*f.conjugate()+f*f'.conjugate())/2/abs(f)

This patch adds a possibility to define derivatives of functions in this way.
In particular the derivative of abs(), Order(), real_part(), imag_part() and
conjugate() are defined.

For example, conjugate of a derivative with respect of a real symbol
If x is real then U.diff(x)-I*V.diff(x) represents both
conjugate(U+I*V).diff(x) and conjugate((U+I*V).diff(x))
Thus in this patch we use the rule

conjugate(f)'=conjugate(f')

for a derivative with respect to the real symbol.

Signed-off-by: Vladimir V. Kisil <kisilv@maths.leeds.ac.uk>
check/exam_inifcns.cpp
doc/tutorial/ginac.texi
ginac/function.cppy
ginac/function.hppy
ginac/function.py
ginac/inifcns.cpp

index 19ad1b98def75b8b334b6a4d79fb29da166158c9..a0acb2de8af183d1f1094b06c7d9f506e5c37b6b 100644 (file)
@@ -343,6 +343,71 @@ static unsigned inifcns_consist_various()
        return result;
 }
 
+/* Several tests for derivetives */
+static unsigned inifcns_consist_derivatives()
+{
+       unsigned result = 0;
+       symbol z, w;
+       realsymbol x;
+       ex e, e1;
+
+       e=pow(x,z).conjugate().diff(x);
+       e1=pow(x,z).conjugate()*z.conjugate()/x;
+       if (! (e-e1).normal().is_zero() ) {
+               clog << "ERROR: pow(x,z).conjugate().diff(x) " << e << " != " << e1 << endl;
+               ++result;
+       }
+
+       e=pow(w,z).conjugate().diff(w);
+       e1=pow(w,z).conjugate()*z.conjugate()/w;
+       if ( (e-e1).normal().is_zero() ) {
+               clog << "ERROR: pow(w,z).conjugate().diff(w) " << e << " = " << e1 << endl;
+               ++result;
+       }
+
+       e=atanh(x).imag_part().diff(x);
+       if (! e.is_zero() ) {
+               clog << "ERROR: atanh(x).imag_part().diff(x) " << e << " != 0" << endl;
+               ++result;
+       }
+
+       e=atanh(w).imag_part().diff(w);
+       if ( e.is_zero() ) {
+               clog << "ERROR: atanh(w).imag_part().diff(w) " << e << " = 0" << endl;
+               ++result;
+       }
+
+       e=atanh(x).real_part().diff(x);
+       e1=pow(1-x*x,-1);
+       if (! (e-e1).normal().is_zero() ) {
+               clog << "ERROR: atanh(x).real_part().diff(x) " << e << " != " << e1 << endl;
+               ++result;
+       }
+
+       e=atanh(w).real_part().diff(w);
+       e1=pow(1-w*w,-1);
+       if ( (e-e1).normal().is_zero() ) {
+               clog << "ERROR: atanh(w).real_part().diff(w) " << e << " = " << e1 << endl;
+               ++result;
+       }
+
+       e=abs(log(z)).diff(z);
+       e1=(conjugate(log(z))/z+log(z)/conjugate(z))/abs(log(z))/2;
+       if (! (e-e1).normal().is_zero() ) {
+               clog << "ERROR: abs(log(z)).diff(z) " << e << " != " << e1 << endl;
+               ++result;
+       }
+
+       e=Order(pow(x,4)).diff(x);
+       e1=Order(pow(x,3));
+       if (! (e-e1).normal().is_zero() ) {
+               clog << "ERROR: Order(pow(x,4)).diff(x) " << e << " != " << e1 << endl;
+               ++result;
+       }
+
+       return result;
+}
+
 unsigned exam_inifcns()
 {
        unsigned result = 0;
@@ -357,6 +422,7 @@ unsigned exam_inifcns()
        result += inifcns_consist_exp();  cout << '.' << flush;
        result += inifcns_consist_log();  cout << '.' << flush;
        result += inifcns_consist_various();  cout << '.' << flush;
+       result += inifcns_consist_derivatives();  cout << '.' << flush;
        
        return result;
 }
index 21e31b2c98501fc7581ceed6c3013e978dbbc3f1..3ac53981fe9e5602d1ddb1b11b923b66fedaa78d 100644 (file)
@@ -7103,6 +7103,25 @@ specifies which parameter to differentiate in a partial derivative in
 case the function has more than one parameter, and its main application
 is for correct handling of the chain rule.
 
+Derivatives of some functions, for example @code{abs()} and
+@code{Order()}, could not be evaluated through the chain rule. In such
+cases the full derivative may be specified as shown for @code{Order()}:
+
+@example
+static ex Order_expl_derivative(const ex & arg, const symbol & s)
+@{
+       return Order(arg.diff(s));
+@}
+@end example
+
+That is, we need to supply a procedure, which returns the expression of
+derivative with respect to the variable @code{s} for the argument
+@code{arg}. This procedure need to be registered with the function
+through the option @code{expl_derivative_func} (see the next
+Subsection). In contrast, a partial derivative, e.g. as was defined for
+@code{cos()} above, needs to be registered through the option
+@code{derivative_func}. 
+
 An implementation of the series expansion is not needed for @code{cos()} as
 it doesn't have any poles and GiNaC can do Taylor expansion by itself (as
 long as it knows what the derivative of @code{cos()} is). @code{tan()}, on
@@ -7138,14 +7157,15 @@ functions without any special options.
 eval_func(<C++ function>)
 evalf_func(<C++ function>)
 derivative_func(<C++ function>)
+expl_derivative_func(<C++ function>)
 series_func(<C++ function>)
 conjugate_func(<C++ function>)
 @end example
 
 These specify the C++ functions that implement symbolic evaluation,
-numeric evaluation, partial derivatives, and series expansion, respectively.
-They correspond to the GiNaC methods @code{eval()}, @code{evalf()},
-@code{diff()} and @code{series()}.
+numeric evaluation, partial derivatives, explicit derivative, and series
+expansion, respectively.  They correspond to the GiNaC methods
+@code{eval()}, @code{evalf()}, @code{diff()} and @code{series()}.
 
 The @code{eval_func()} function needs to use @code{.hold()} if no further
 automatic evaluation is desired or possible.
index d8a261f6ca3321e2d8f20ee7b586c3b0bf0372fe..dba9f4e0f096b8234234e540c109484ebdf260f8 100644 (file)
@@ -79,7 +79,7 @@ void function_options::initialize()
        set_name("unnamed_function", "\\\\mbox{unnamed}");
        nparams = 0;
        eval_f = evalf_f = real_part_f = imag_part_f = conjugate_f = expand_f
-               = derivative_f = power_f = series_f = 0;
+               = derivative_f = expl_derivative_f = power_f = series_f = 0;
        info_f = 0;
        evalf_params_first = true;
        use_return_type = false;
@@ -90,6 +90,7 @@ void function_options::initialize()
        imag_part_use_exvector_args = false;
        expand_use_exvector_args = false;
        derivative_use_exvector_args = false;
+       expl_derivative_use_exvector_args = false;
        power_use_exvector_args = false;
        series_use_exvector_args = false;
        print_use_exvector_args = false;
@@ -630,10 +631,10 @@ ex function::derivative(const symbol & s) const
 {
        ex result;
 
-       if (serial == Order_SERIAL::serial) {
-               // Order Term function only differentiates the argument
-               return Order(seq[0].diff(s));
-       } else {
+       try {
+               // Explicit derivation
+               result = expl_derivative(s);
+       } catch (...) {
                // Chain rule
                ex arg_diff;
                size_t num = seq.size();
@@ -752,6 +753,28 @@ ex function::pderivative(unsigned diff_param) const // partial differentiation
        throw(std::logic_error("function::pderivative(): no diff function defined"));
 }
 
+ex function::expl_derivative(const symbol & s) const // explicit differentiation
+{
+       GINAC_ASSERT(serial<registered_functions().size());
+       const function_options &opt = registered_functions()[serial];
+
+       // No explicit derivative defined? Then this function shall not be called!
+       if (opt.expl_derivative_f == NULL)
+               throw(std::logic_error("function::expl_derivative(): explicit derivation is called, but no such function defined"));
+
+       current_serial = serial;
+       if (opt.expl_derivative_use_exvector_args)
+               return ((expl_derivative_funcp_exvector)(opt.expl_derivative_f))(seq, s);
+       switch (opt.nparams) {
+               // the following lines have been generated for max. @maxargs@ parameters
++++ for N in range(1, maxargs + 1):
+               case @N@:
+                       return ((expl_derivative_funcp_@N@)(opt.expl_derivative_f))(@seq('seq[%(n)d]', N, 0)@, s);
+---
+               // end of generated lines
+       }
+}
+
 ex function::power(const ex & power_param) const // power of function
 {
        GINAC_ASSERT(serial<registered_functions().size());
index 6259d7a6f40fa25903484cec92c7fbfea0594641..971786d23a57efde2701899fb0de2523e21feabf 100644 (file)
@@ -59,6 +59,7 @@ typedef ex (* real_part_funcp)();
 typedef ex (* imag_part_funcp)();
 typedef ex (* expand_funcp)();
 typedef ex (* derivative_funcp)();
+typedef ex (* expl_derivative_funcp)();
 typedef ex (* power_funcp)();
 typedef ex (* series_funcp)();
 typedef void (* print_funcp)();
@@ -73,6 +74,7 @@ typedef ex (* real_part_funcp_@N@)( @args@ );
 typedef ex (* imag_part_funcp_@N@)( @args@ );
 typedef ex (* expand_funcp_@N@)( @args@, unsigned );
 typedef ex (* derivative_funcp_@N@)( @args@, unsigned );
+typedef ex (* expl_derivative_funcp_@N@)( @args@, const symbol & );
 typedef ex (* power_funcp_@N@)( @args@, const ex & );
 typedef ex (* series_funcp_@N@)( @args@, const relational &, int, unsigned );
 typedef void (* print_funcp_@N@)( @args@, const print_context & );
@@ -87,6 +89,7 @@ typedef ex (* @fp@_funcp_exvector)(const exvector &);
 ---
 typedef ex (* expand_funcp_exvector)(const exvector &, unsigned);
 typedef ex (* derivative_funcp_exvector)(const exvector &, unsigned);
+typedef ex (* expl_derivative_funcp_exvector)(const exvector &, const symbol &);
 typedef ex (* power_funcp_exvector)(const exvector &, const ex &);
 typedef ex (* series_funcp_exvector)(const exvector &, const relational &, int, unsigned);
 typedef void (* print_funcp_exvector)(const exvector &, const print_context &);
@@ -159,6 +162,7 @@ protected:
        imag_part_funcp imag_part_f;
        expand_funcp expand_f;
        derivative_funcp derivative_f;
+       expl_derivative_funcp expl_derivative_f;
        power_funcp power_f;
        series_funcp series_f;
        std::vector<print_funcp> print_dispatch_table;
@@ -182,6 +186,7 @@ protected:
        bool imag_part_use_exvector_args;
        bool expand_use_exvector_args;
        bool derivative_use_exvector_args;
+       bool expl_derivative_use_exvector_args;
        bool power_use_exvector_args;
        bool series_use_exvector_args;
        bool print_use_exvector_args;
@@ -251,6 +256,7 @@ protected:
        // non-virtual functions in this class
 protected:
        ex pderivative(unsigned diff_param) const; // partial differentiation
+       ex expl_derivative(const symbol & s) const; // partial differentiation
        static std::vector<function_options> & registered_functions();
        bool lookup_remember_table(ex & result) const;
        void store_remember_table(ex const & result) const;
index 3f5e54eb6028d95fef63e72c880c1cc6d17f2ab7..465976b349c28b650afc34e698d92234ff70f2c5 100755 (executable)
@@ -2,7 +2,7 @@
 # encoding: utf-8
 
 maxargs = 14
-methods = "eval evalf conjugate real_part imag_part expand derivative power series info print".split()
+methods = "eval evalf conjugate real_part imag_part expand derivative expl_derivative power series info print".split()
 
 import sys, os, optparse
 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'scripts'))
index 84cd2852ddd6bdb23160bf6d9ccf11e4352d0b72..28d54fed3e79924dffbd95c457add5ff2bbd662a 100644 (file)
@@ -24,6 +24,7 @@
 #include "ex.h"
 #include "constant.h"
 #include "lst.h"
+#include "fderivative.h"
 #include "matrix.h"
 #include "mul.h"
 #include "power.h"
@@ -66,6 +67,19 @@ static ex conjugate_conjugate(const ex & arg)
        return arg;
 }
 
+// If x is real then U.diff(x)-I*V.diff(x) represents both conjugate(U+I*V).diff(x) 
+// and conjugate((U+I*V).diff(x))
+static ex conjugate_expl_derivative(const ex & arg, const symbol & s)
+{
+       if (s.info(info_flags::real))
+               return conjugate(arg.diff(s));
+       else {
+               exvector vec_arg;
+               vec_arg.push_back(arg);
+               return fderivative(ex_to<function>(conjugate(arg)).get_serial(),0,vec_arg).hold()*arg.diff(s);
+       }
+}
+
 static ex conjugate_real_part(const ex & arg)
 {
        return arg.real_part();
@@ -115,6 +129,7 @@ static bool conjugate_info(const ex & arg, unsigned inf)
 
 REGISTER_FUNCTION(conjugate_function, eval_func(conjugate_eval).
                                       evalf_func(conjugate_evalf).
+                                      expl_derivative_func(conjugate_expl_derivative).
                                       info_func(conjugate_info).
                                       print_func<print_latex>(conjugate_print_latex).
                                       conjugate_func(conjugate_conjugate).
@@ -159,8 +174,21 @@ static ex real_part_imag_part(const ex & arg)
        return 0;
 }
 
+// If x is real then Re(e).diff(x) is equal to Re(e.diff(x)) 
+static ex real_part_expl_derivative(const ex & arg, const symbol & s)
+{
+       if (s.info(info_flags::real))
+               return real_part_function(arg.diff(s));
+       else {
+               exvector vec_arg;
+               vec_arg.push_back(arg);
+               return fderivative(ex_to<function>(real_part(arg)).get_serial(),0,vec_arg).hold()*arg.diff(s);
+       }
+}
+
 REGISTER_FUNCTION(real_part_function, eval_func(real_part_eval).
                                       evalf_func(real_part_evalf).
+                                      expl_derivative_func(real_part_expl_derivative).
                                       print_func<print_latex>(real_part_print_latex).
                                       conjugate_func(real_part_conjugate).
                                       real_part_func(real_part_real_part).
@@ -204,8 +232,21 @@ static ex imag_part_imag_part(const ex & arg)
        return 0;
 }
 
+// If x is real then Im(e).diff(x) is equal to Im(e.diff(x)) 
+static ex imag_part_expl_derivative(const ex & arg, const symbol & s)
+{
+       if (s.info(info_flags::real))
+               return imag_part_function(arg.diff(s));
+       else {
+               exvector vec_arg;
+               vec_arg.push_back(arg);
+               return fderivative(ex_to<function>(imag_part(arg)).get_serial(),0,vec_arg).hold()*arg.diff(s);
+       }
+}
+
 REGISTER_FUNCTION(imag_part_function, eval_func(imag_part_eval).
                                       evalf_func(imag_part_evalf).
+                                      expl_derivative_func(imag_part_expl_derivative).
                                       print_func<print_latex>(imag_part_print_latex).
                                       conjugate_func(imag_part_conjugate).
                                       real_part_func(imag_part_real_part).
@@ -275,6 +316,12 @@ static ex abs_expand(const ex & arg, unsigned options)
                return abs(arg).hold();
 }
 
+static ex abs_expl_derivative(const ex & arg, const symbol & s)
+{
+       ex diff_arg = arg.diff(s);
+       return (diff_arg*arg.conjugate()+arg*diff_arg.conjugate())/2/abs(arg);
+}
+
 static void abs_print_latex(const ex & arg, const print_context & c)
 {
        c.s << "{|"; arg.print(c); c.s << "|}";
@@ -341,6 +388,7 @@ bool abs_info(const ex & arg, unsigned inf)
 REGISTER_FUNCTION(abs, eval_func(abs_eval).
                        evalf_func(abs_evalf).
                        expand_func(abs_expand).
+                       expl_derivative_func(abs_expl_derivative).
                        info_func(abs_info).
                        print_func<print_latex>(abs_print_latex).
                        print_func<print_csrc_float>(abs_print_csrc_float).
@@ -977,11 +1025,15 @@ static ex Order_imag_part(const ex & x)
        return Order(x).hold();
 }
 
-// Differentiation is handled in function::derivative because of its special requirements
+static ex Order_expl_derivative(const ex & arg, const symbol & s)
+{
+       return Order(arg.diff(s));
+}
 
 REGISTER_FUNCTION(Order, eval_func(Order_eval).
                          series_func(Order_series).
                          latex_name("\\mathcal{O}").
+                         expl_derivative_func(Order_expl_derivative).
                          conjugate_func(Order_conjugate).
                          real_part_func(Order_real_part).
                          imag_part_func(Order_imag_part));