]> www.ginac.de Git - ginac.git/blob - ginac/add.cpp
* Fix CLN output of -2*a-2*b. (Chris Dams)
[ginac.git] / ginac / add.cpp
1 /** @file add.cpp
2  *
3  *  Implementation of GiNaC's sums of expressions. */
4
5 /*
6  *  GiNaC Copyright (C) 1999-2002 Johannes Gutenberg University Mainz, Germany
7  *
8  *  This program is free software; you can redistribute it and/or modify
9  *  it under the terms of the GNU General Public License as published by
10  *  the Free Software Foundation; either version 2 of the License, or
11  *  (at your option) any later version.
12  *
13  *  This program is distributed in the hope that it will be useful,
14  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
15  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16  *  GNU General Public License for more details.
17  *
18  *  You should have received a copy of the GNU General Public License
19  *  along with this program; if not, write to the Free Software
20  *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
21  */
22
23 #include <iostream>
24 #include <stdexcept>
25
26 #include "add.h"
27 #include "mul.h"
28 #include "archive.h"
29 #include "operators.h"
30 #include "matrix.h"
31 #include "utils.h"
32
33 namespace GiNaC {
34
// Class registration boilerplate for add, generated by a GiNaC macro;
// the parent class is expairseq.
GINAC_IMPLEMENT_REGISTERED_CLASS(add, expairseq)
36
37 //////////
38 // default ctor, dtor, copy ctor, assignment operator and helpers
39 //////////
40
41 add::add()
42 {
43         tinfo_key = TINFO_add;
44 }
45
// Copy and destroy helpers generated by GiNaC's DEFAULT_COPY and
// DEFAULT_DESTROY macros (presumably defined in utils.h — verify).
DEFAULT_COPY(add)
DEFAULT_DESTROY(add)
48
49 //////////
50 // other constructors
51 //////////
52
53 // public
54
55 add::add(const ex & lh, const ex & rh)
56 {
57         tinfo_key = TINFO_add;
58         overall_coeff = _ex0;
59         construct_from_2_ex(lh,rh);
60         GINAC_ASSERT(is_canonical());
61 }
62
63 add::add(const exvector & v)
64 {
65         tinfo_key = TINFO_add;
66         overall_coeff = _ex0;
67         construct_from_exvector(v);
68         GINAC_ASSERT(is_canonical());
69 }
70
71 add::add(const epvector & v)
72 {
73         tinfo_key = TINFO_add;
74         overall_coeff = _ex0;
75         construct_from_epvector(v);
76         GINAC_ASSERT(is_canonical());
77 }
78
79 add::add(const epvector & v, const ex & oc)
80 {
81         tinfo_key = TINFO_add;
82         overall_coeff = oc;
83         construct_from_epvector(v);
84         GINAC_ASSERT(is_canonical());
85 }
86
87 add::add(epvector * vp, const ex & oc)
88 {
89         tinfo_key = TINFO_add;
90         GINAC_ASSERT(vp!=0);
91         overall_coeff = oc;
92         construct_from_epvector(*vp);
93         delete vp;
94         GINAC_ASSERT(is_canonical());
95 }
96
97 //////////
98 // archiving
99 //////////
100
// Archiving (serialization) support generated by the DEFAULT_ARCHIVING macro.
DEFAULT_ARCHIVING(add)
102
103 //////////
104 // functions overriding virtual functions from base classes
105 //////////
106
107 // public
108
109 void add::print(const print_context & c, unsigned level) const
110 {
111         if (is_a<print_tree>(c)) {
112
113                 inherited::print(c, level);
114
115         } else if (is_a<print_csrc>(c)) {
116
117                 if (precedence() <= level)
118                         c.s << "(";
119         
120                 // Print arguments, separated by "+"
121                 epvector::const_iterator it = seq.begin(), itend = seq.end();
122                 while (it != itend) {
123                 
124                         // If the coefficient is -1, it is replaced by a single minus sign
125                         if (it->coeff.is_equal(_ex1)) {
126                                 it->rest.print(c, precedence());
127                         } else if (it->coeff.is_equal(_ex_1)) {
128                                 c.s << "-";
129                                 it->rest.print(c, precedence());
130                         } else if (ex_to<numeric>(it->coeff).numer().is_equal(_num1)) {
131                                 it->rest.print(c, precedence());
132                                 c.s << "/";
133                                 ex_to<numeric>(it->coeff).denom().print(c, precedence());
134                         } else if (ex_to<numeric>(it->coeff).numer().is_equal(_num_1)) {
135                                 c.s << "-";
136                                 it->rest.print(c, precedence());
137                                 c.s << "/";
138                                 ex_to<numeric>(it->coeff).denom().print(c, precedence());
139                         } else {
140                                 it->coeff.print(c, precedence());
141                                 c.s << "*";
142                                 it->rest.print(c, precedence());
143                         }
144                 
145                         // Separator is "+", except if the following expression would have a leading minus sign or the sign is sitting in parenthesis (as in a ctor)
146                         ++it;
147                         if (it != itend
148                                 && (is_a<print_csrc_cl_N>(c)  // sign inside ctor arguments
149                                         || !(it->coeff.info(info_flags::negative) || (it->coeff.is_equal(_num1) && is_exactly_a<numeric>(it->rest) && it->rest.info(info_flags::negative)))))
150                                 c.s << "+";
151                 }
152         
153                 if (!overall_coeff.is_zero()) {
154                         if (overall_coeff.info(info_flags::positive))
155                                 c.s << '+';
156                         overall_coeff.print(c, precedence());
157                 }
158                 
159                 if (precedence() <= level)
160                         c.s << ")";
161
162         } else if (is_a<print_python_repr>(c)) {
163
164                 c.s << class_name() << '(';
165                 op(0).print(c);
166                 for (unsigned i=1; i<nops(); ++i) {
167                         c.s << ',';
168                         op(i).print(c);
169                 }
170                 c.s << ')';
171
172         } else {
173
174                 if (precedence() <= level) {
175                         if (is_a<print_latex>(c))
176                                 c.s << "{(";
177                         else
178                                 c.s << "(";
179                 }
180
181                 numeric coeff;
182                 bool first = true;
183
184                 // First print the overall numeric coefficient, if present
185                 if (!overall_coeff.is_zero()) {
186                         if (!is_a<print_tree>(c))
187                                 overall_coeff.print(c, 0);
188                         else
189                                 overall_coeff.print(c, precedence());
190                         first = false;
191                 }
192
193                 // Then proceed with the remaining factors
194                 epvector::const_iterator it = seq.begin(), itend = seq.end();
195                 while (it != itend) {
196                         coeff = ex_to<numeric>(it->coeff);
197                         if (!first) {
198                                 if (coeff.csgn() == -1) c.s << '-'; else c.s << '+';
199                         } else {
200                                 if (coeff.csgn() == -1) c.s << '-';
201                                 first = false;
202                         }
203                         if (!coeff.is_equal(_num1) &&
204                             !coeff.is_equal(_num_1)) {
205                                 if (coeff.is_rational()) {
206                                         if (coeff.is_negative())
207                                                 (-coeff).print(c);
208                                         else
209                                                 coeff.print(c);
210                                 } else {
211                                         if (coeff.csgn() == -1)
212                                                 (-coeff).print(c, precedence());
213                                         else
214                                                 coeff.print(c, precedence());
215                                 }
216                                 if (is_a<print_latex>(c))
217                                         c.s << ' ';
218                                 else
219                                         c.s << '*';
220                         }
221                         it->rest.print(c, precedence());
222                         ++it;
223                 }
224
225                 if (precedence() <= level) {
226                         if (is_a<print_latex>(c))
227                                 c.s << ")}";
228                         else
229                                 c.s << ")";
230                 }
231         }
232 }
233
234 bool add::info(unsigned inf) const
235 {
236         switch (inf) {
237                 case info_flags::polynomial:
238                 case info_flags::integer_polynomial:
239                 case info_flags::cinteger_polynomial:
240                 case info_flags::rational_polynomial:
241                 case info_flags::crational_polynomial:
242                 case info_flags::rational_function: {
243                         epvector::const_iterator i = seq.begin(), end = seq.end();
244                         while (i != end) {
245                                 if (!(recombine_pair_to_ex(*i).info(inf)))
246                                         return false;
247                                 ++i;
248                         }
249                         return overall_coeff.info(inf);
250                 }
251                 case info_flags::algebraic: {
252                         epvector::const_iterator i = seq.begin(), end = seq.end();
253                         while (i != end) {
254                                 if ((recombine_pair_to_ex(*i).info(inf)))
255                                         return true;
256                                 ++i;
257                         }
258                         return false;
259                 }
260         }
261         return inherited::info(inf);
262 }
263
264 int add::degree(const ex & s) const
265 {
266         int deg = INT_MIN;
267         if (!overall_coeff.is_zero())
268                 deg = 0;
269         
270         // Find maximum of degrees of individual terms
271         epvector::const_iterator i = seq.begin(), end = seq.end();
272         while (i != end) {
273                 int cur_deg = i->rest.degree(s);
274                 if (cur_deg > deg)
275                         deg = cur_deg;
276                 ++i;
277         }
278         return deg;
279 }
280
281 int add::ldegree(const ex & s) const
282 {
283         int deg = INT_MAX;
284         if (!overall_coeff.is_zero())
285                 deg = 0;
286         
287         // Find minimum of degrees of individual terms
288         epvector::const_iterator i = seq.begin(), end = seq.end();
289         while (i != end) {
290                 int cur_deg = i->rest.ldegree(s);
291                 if (cur_deg < deg)
292                         deg = cur_deg;
293                 ++i;
294         }
295         return deg;
296 }
297
298 ex add::coeff(const ex & s, int n) const
299 {
300         epvector *coeffseq = new epvector();
301
302         // Calculate sum of coefficients in each term
303         epvector::const_iterator i = seq.begin(), end = seq.end();
304         while (i != end) {
305                 ex restcoeff = i->rest.coeff(s, n);
306                 if (!restcoeff.is_zero())
307                         coeffseq->push_back(combine_ex_with_coeff_to_pair(restcoeff, i->coeff));
308                 ++i;
309         }
310
311         return (new add(coeffseq, n==0 ? overall_coeff : _ex0))->setflag(status_flags::dynallocated);
312 }
313
314 /** Perform automatic term rewriting rules in this class.  In the following
315  *  x stands for a symbolic variables of type ex and c stands for such
316  *  an expression that contain a plain number.
317  *  - +(;c) -> c
318  *  - +(x;1) -> x
319  *
320  *  @param level cut-off in recursive evaluation */
321 ex add::eval(int level) const
322 {
323         epvector *evaled_seqp = evalchildren(level);
324         if (evaled_seqp) {
325                 // do more evaluation later
326                 return (new add(evaled_seqp, overall_coeff))->
327                        setflag(status_flags::dynallocated);
328         }
329         
330 #ifdef DO_GINAC_ASSERT
331         epvector::const_iterator i = seq.begin(), end = seq.end();
332         while (i != end) {
333                 GINAC_ASSERT(!is_exactly_a<add>(i->rest));
334                 if (is_exactly_a<numeric>(i->rest))
335                         dbgprint();
336                 GINAC_ASSERT(!is_exactly_a<numeric>(i->rest));
337                 ++i;
338         }
339 #endif // def DO_GINAC_ASSERT
340         
341         if (flags & status_flags::evaluated) {
342                 GINAC_ASSERT(seq.size()>0);
343                 GINAC_ASSERT(seq.size()>1 || !overall_coeff.is_zero());
344                 return *this;
345         }
346         
347         int seq_size = seq.size();
348         if (seq_size == 0) {
349                 // +(;c) -> c
350                 return overall_coeff;
351         } else if (seq_size == 1 && overall_coeff.is_zero()) {
352                 // +(x;0) -> x
353                 return recombine_pair_to_ex(*(seq.begin()));
354         } else if (!overall_coeff.is_zero() && seq[0].rest.return_type() != return_types::commutative) {
355                 throw (std::logic_error("add::eval(): sum of non-commutative objects has non-zero numeric term"));
356         }
357         return this->hold();
358 }
359
360 ex add::evalm(void) const
361 {
362         // Evaluate children first and add up all matrices. Stop if there's one
363         // term that is not a matrix.
364         epvector *s = new epvector;
365         s->reserve(seq.size());
366
367         bool all_matrices = true;
368         bool first_term = true;
369         matrix sum;
370
371         epvector::const_iterator it = seq.begin(), itend = seq.end();
372         while (it != itend) {
373                 const ex &m = recombine_pair_to_ex(*it).evalm();
374                 s->push_back(split_ex_to_pair(m));
375                 if (is_a<matrix>(m)) {
376                         if (first_term) {
377                                 sum = ex_to<matrix>(m);
378                                 first_term = false;
379                         } else
380                                 sum = sum.add(ex_to<matrix>(m));
381                 } else
382                         all_matrices = false;
383                 ++it;
384         }
385
386         if (all_matrices) {
387                 delete s;
388                 return sum + overall_coeff;
389         } else
390                 return (new add(s, overall_coeff))->setflag(status_flags::dynallocated);
391 }
392
393 ex add::simplify_ncmul(const exvector & v) const
394 {
395         if (seq.empty())
396                 return inherited::simplify_ncmul(v);
397         else
398                 return seq.begin()->rest.simplify_ncmul(v);
399 }    
400
401 // protected
402
403 /** Implementation of ex::diff() for a sum. It differentiates each term.
404  *  @see ex::diff */
405 ex add::derivative(const symbol & y) const
406 {
407         epvector *s = new epvector();
408         s->reserve(seq.size());
409         
410         // Only differentiate the "rest" parts of the expairs. This is faster
411         // than the default implementation in basic::derivative() although
412         // if performs the same function (differentiate each term).
413         epvector::const_iterator i = seq.begin(), end = seq.end();
414         while (i != end) {
415                 s->push_back(combine_ex_with_coeff_to_pair(i->rest.diff(y), i->coeff));
416                 ++i;
417         }
418         return (new add(s, _ex0))->setflag(status_flags::dynallocated);
419 }
420
421 int add::compare_same_type(const basic & other) const
422 {
423         return inherited::compare_same_type(other);
424 }
425
426 unsigned add::return_type(void) const
427 {
428         if (seq.empty())
429                 return return_types::commutative;
430         else
431                 return seq.begin()->rest.return_type();
432 }
433    
434 unsigned add::return_type_tinfo(void) const
435 {
436         if (seq.empty())
437                 return tinfo_key;
438         else
439                 return seq.begin()->rest.return_type_tinfo();
440 }
441
442 ex add::thisexpairseq(const epvector & v, const ex & oc) const
443 {
444         return (new add(v,oc))->setflag(status_flags::dynallocated);
445 }
446
447 ex add::thisexpairseq(epvector * vp, const ex & oc) const
448 {
449         return (new add(vp,oc))->setflag(status_flags::dynallocated);
450 }
451
452 expair add::split_ex_to_pair(const ex & e) const
453 {
454         if (is_exactly_a<mul>(e)) {
455                 const mul &mulref(ex_to<mul>(e));
456                 const ex &numfactor = mulref.overall_coeff;
457                 mul *mulcopyp = new mul(mulref);
458                 mulcopyp->overall_coeff = _ex1;
459                 mulcopyp->clearflag(status_flags::evaluated);
460                 mulcopyp->clearflag(status_flags::hash_calculated);
461                 mulcopyp->setflag(status_flags::dynallocated);
462                 return expair(*mulcopyp,numfactor);
463         }
464         return expair(e,_ex1);
465 }
466
467 expair add::combine_ex_with_coeff_to_pair(const ex & e,
468                                                                                   const ex & c) const
469 {
470         GINAC_ASSERT(is_exactly_a<numeric>(c));
471         if (is_exactly_a<mul>(e)) {
472                 const mul &mulref(ex_to<mul>(e));
473                 const ex &numfactor = mulref.overall_coeff;
474                 mul *mulcopyp = new mul(mulref);
475                 mulcopyp->overall_coeff = _ex1;
476                 mulcopyp->clearflag(status_flags::evaluated);
477                 mulcopyp->clearflag(status_flags::hash_calculated);
478                 mulcopyp->setflag(status_flags::dynallocated);
479                 if (c.is_equal(_ex1))
480                         return expair(*mulcopyp, numfactor);
481                 else if (numfactor.is_equal(_ex1))
482                         return expair(*mulcopyp, c);
483                 else
484                         return expair(*mulcopyp, ex_to<numeric>(numfactor).mul_dyn(ex_to<numeric>(c)));
485         } else if (is_exactly_a<numeric>(e)) {
486                 if (c.is_equal(_ex1))
487                         return expair(e, _ex1);
488                 return expair(ex_to<numeric>(e).mul_dyn(ex_to<numeric>(c)), _ex1);
489         }
490         return expair(e, c);
491 }
492
493 expair add::combine_pair_with_coeff_to_pair(const expair & p,
494                                                                                         const ex & c) const
495 {
496         GINAC_ASSERT(is_exactly_a<numeric>(p.coeff));
497         GINAC_ASSERT(is_exactly_a<numeric>(c));
498
499         if (is_exactly_a<numeric>(p.rest)) {
500                 GINAC_ASSERT(ex_to<numeric>(p.coeff).is_equal(_num1)); // should be normalized
501                 return expair(ex_to<numeric>(p.rest).mul_dyn(ex_to<numeric>(c)),_ex1);
502         }
503
504         return expair(p.rest,ex_to<numeric>(p.coeff).mul_dyn(ex_to<numeric>(c)));
505 }
506         
507 ex add::recombine_pair_to_ex(const expair & p) const
508 {
509         if (ex_to<numeric>(p.coeff).is_equal(_num1))
510                 return p.rest;
511         else
512                 return (new mul(p.rest,p.coeff))->setflag(status_flags::dynallocated);
513 }
514
515 ex add::expand(unsigned options) const
516 {
517         epvector *vp = expandchildren(options);
518         if (vp == NULL) {
519                 // the terms have not changed, so it is safe to declare this expanded
520                 return (options == 0) ? setflag(status_flags::expanded) : *this;
521         }
522         
523         return (new add(vp, overall_coeff))->setflag(status_flags::dynallocated | (options == 0 ? status_flags::expanded : 0));
524 }
525
526 } // namespace GiNaC