if (row != col)
throw (std::logic_error("matrix::inverse(): matrix not square"));
- // NOTE: the Gauss-Jordan elimination used here can in principle be
- // replaced by two clever calls to gauss_elimination() and some to
- // transpose(). Wouldn't be more efficient (maybe less?), just more
- // orthogonal.
- matrix tmp(row,col);
- // set tmp to the unit matrix
- for (unsigned i=0; i<col; ++i)
- tmp.m[i*col+i] = _ex1();
+ // This routine actually doesn't do anything fancy at all. We compute the
+ // inverse of the matrix A by solving the system A * A^{-1} == Id.
- // create a copy of this matrix
- matrix cpy(*this);
- for (unsigned r1=0; r1<row; ++r1) {
- int indx = cpy.pivot(r1, r1);
- if (indx == -1) {
+ // First populate the identity matrix supposed to become the right hand side.
+ matrix identity(row,col);
+ for (unsigned i=0; i<row; ++i)
+ identity.set(i,i,_ex1());
+
+ // Populate a dummy matrix of variables, just because of compatibility with
+ // matrix::solve() which wants this (for compatibility with under-determined
+ // systems of equations).
+ matrix vars(row,col);
+ for (unsigned r=0; r<row; ++r)
+ for (unsigned c=0; c<col; ++c)
+ vars.set(r,c,symbol());
+
+ matrix sol(row,col);
+ try {
+ sol = this->solve(vars,identity);
+ } catch (const std::runtime_error & e) {
+ if (e.what()==std::string("matrix::solve(): inconsistent linear system"))
throw (std::runtime_error("matrix::inverse(): singular matrix"));
- }
- if (indx != 0) { // swap rows r and indx of matrix tmp
- for (unsigned i=0; i<col; ++i)
- tmp.m[r1*col+i].swap(tmp.m[indx*col+i]);
- }
- ex a1 = cpy.m[r1*col+r1];
- for (unsigned c=0; c<col; ++c) {
- cpy.m[r1*col+c] /= a1;
- tmp.m[r1*col+c] /= a1;
- }
- for (unsigned r2=0; r2<row; ++r2) {
- if (r2 != r1) {
- if (!cpy.m[r2*col+r1].is_zero()) {
- ex a2 = cpy.m[r2*col+r1];
- // yes, there is something to do in this column
- for (unsigned c=0; c<col; ++c) {
- cpy.m[r2*col+c] -= a2 * cpy.m[r1*col+c];
- if (!cpy.m[r2*col+c].info(info_flags::numeric))
- cpy.m[r2*col+c] = cpy.m[r2*col+c].normal();
- tmp.m[r2*col+c] -= a2 * tmp.m[r1*col+c];
- if (!tmp.m[r2*col+c].info(info_flags::numeric))
- tmp.m[r2*col+c] = tmp.m[r2*col+c].normal();
- }
- }
- }
- }
+ else
+ throw;
}
-
- return tmp;
+ return sol;
}
switch(algo) {
case solve_algo::gauss:
aug.gauss_elimination();
+ break;
case solve_algo::divfree:
aug.division_free_elimination();
+ break;
case solve_algo::bareiss:
default:
aug.fraction_free_elimination();