2#include <cgv/math/vec.h>
3#include <cgv/math/mat.h>
4#include <cgv/math/perm_mat.h>
5#include <cgv/math/diag_mat.h>
6#include <cgv/math/tri_diag_mat.h>
7#include <cgv/math/up_tri_mat.h>
8#include <cgv/math/low_tri_mat.h>
9#include <cgv/math/lu.h>
10#include <cgv/math/svd.h>
11#include <cgv/math/qr.h>
/// Solves a*x = b for x, where a is upper triangular, by back substitution.
/// Rows are processed from the last one (N-1) down to the first, so every
/// x[j] with j > i is already known when x[i] is computed.
/// NOTE(review): this excerpt does not show the template header, the
/// declarations of N and sum, or the return statements; N is presumably
/// a.nrows() and the bool result presumably reports success — confirm.
20bool solve(
const up_tri_mat<T>& a,
const vec<T>&b, vec<T>&x)
// a must be square.
22 assert(a.nrows() == a.ncols());
// Signed loop index so the loop can run down to row 0 and stop.
27 for(
int i = N-1; i >= 0;i--)
// Inner loop over the already-solved columns j > i; presumably accumulates
// a(i,j)*x[j] into sum (the accumulation statement is not visible here).
30 for(
int j = i+1;j < N; j++)
// NOTE(review): divides by the diagonal entry a(i,i); no zero check is
// visible in this excerpt — confirm how a singular diagonal is handled.
34 x[i] = (b[i] - sum)/a(i,i);
/// Solves a*X = B for X, where a is upper triangular, by running the vector
/// overload on each column of B in turn.
/// NOTE(review): the declaration of xcol, the copy of xcol into column i of
/// x, and the return statements are not visible in this excerpt; presumably
/// the function returns false as soon as one column solve fails — confirm.
44bool solve(
const up_tri_mat<T>& a,
const mat<T>&b,mat<T>&x)
// The row count of b must match the column count of a.
46 assert(b.nrows() == a.ncols());
// x takes the same shape as b.
48 x.resize(b.nrows(),b.ncols());
49 for(
unsigned i = 0; i < b.ncols();i++)
// Solve for column i of b; the failure branch is not visible here.
51 if(!solve(a,b.col(i),xcol))
/// Solves a*x = b for x, where a is lower triangular, by forward substitution.
/// Rows are processed from the first one to the last (N-1), so every x[j]
/// with j < i is already known when x[i] is computed.
/// NOTE(review): this excerpt does not show the template header, a square
/// assert, the declarations of N and sum, or the return statements; N is
/// presumably a.nrows() — confirm.
61bool solve(
const low_tri_mat<T>& a,
const vec<T>&b, vec<T>&x)
67 for(
int i = 0; i < N;i++)
// Inner loop over the already-solved columns j < i; presumably accumulates
// a(i,j)*x[j] into sum (the accumulation statement is not visible here).
70 for(
int j = 0;j < i; j++)
// NOTE(review): divides by the diagonal entry a(i,i); no zero check is
// visible in this excerpt — confirm how a singular diagonal is handled.
74 x[i] = (b[i] - sum)/a(i,i);
/// Solves a*X = B for X, where a is lower triangular, by running the vector
/// overload on each column of B in turn.
/// NOTE(review): the declaration of xcol, the copy of xcol into column i of
/// x, and the return statements are not visible in this excerpt; presumably
/// the function returns false as soon as one column solve fails — confirm.
85bool solve(
const low_tri_mat<T>& a,
const mat<T>&b,mat<T>&x)
// The row count of b must match the column count of a.
87 assert(b.nrows() == a.ncols());
// x takes the same shape as b.
89 x.resize(b.nrows(),b.ncols());
90 for(
unsigned i = 0; i < b.ncols();i++)
// Solve for column i of b; the failure branch is not visible here.
92 if(!solve(a,b.col(i),xcol))
/// Solves a*x = b for x, where a is a diagonal matrix, with one pass over the
/// rows.
/// NOTE(review): the loop body, the declaration of N and the return
/// statements are not visible in this excerpt; the body presumably divides
/// each b[i] by the i-th diagonal entry — confirm, including how a zero
/// diagonal entry is handled.
102bool solve(
const diag_mat<T>& a,
const vec<T>&b, vec<T>&x)
106 for(
int i = 0; i < N;i++)
/// Solves a*x = b for x, where a is tridiagonal.
/// The visible steps match the Thomas algorithm (tridiagonal Gaussian
/// elimination without pivoting):
/// - a forward sweep over rows 1..n-1 eliminates the sub-diagonal;
/// - back substitution then runs from row n-1 down to row 0.
/// NOTE(review): this excerpt does not show the declarations of n, i and dd,
/// the handling of row 0, the update of cc(i) inside the sweep, or the return
/// statements. dd is presumably a copy of b and cc is presumably scaled in
/// place — confirm against the full source.
118bool solve(
const tri_diag_mat<T>& a,
const vec<T>& b, vec<T>& x)
// Work on copies of the three bands. band(-1), band(0) and band(1) are
// presumably the sub-, main- and super-diagonal respectively.
122 vec<T> aa = a.band(-1);
123 vec<T> bb = a.band(0);
124 vec<T> cc = a.band(1);
// Forward sweep.
// NOTE(review): aa(i) is read for i up to n-1. That is in bounds only if
// band(-1) returns n entries, with row i's sub-diagonal entry at index i.
// Confirm against tri_diag_mat::band.
133 for(i = 1; i < n; i++)
// Pivot of row i after eliminating its sub-diagonal entry.
135 T
id = (bb(i) - cc(i-1) * aa(i));
// Eliminated right-hand side of row i.
// NOTE(review): id is used as a divisor without a visible zero check. With
// no pivoting, a zero or tiny id breaks the solve — confirm the intended
// precondition (e.g. diagonal dominance).
139 dd(i) = (dd(i) - dd(i-1) * aa(i))/
id;
// Back substitution, starting from the last unknown.
142 x(n - 1) = dd(n - 1);
// NOTE(review): this loop terminates only if i is a signed integer type.
// Its declaration is not visible here — confirm.
143 for(i = n - 2; i >= 0; i--)
144 x(i) = dd(i) - cc(i) * x(i + 1);
/// Solves a*X = B for X, where a is diagonal, by running the vector overload
/// on each column of B in turn.
/// NOTE(review): the declaration of xcol, the copy of xcol into column i of
/// x, and the return statements are not visible in this excerpt; presumably
/// the function returns false as soon as one column solve fails — confirm.
153bool solve(
const diag_mat<T>& a,
const mat<T>&b, mat<T>&x)
// The row count of b must match the column count of a.
155 assert(b.nrows() == a.ncols());
// x takes the same shape as b.
157 x.resize(b.nrows(),b.ncols());
158 for(
unsigned i = 0; i < b.ncols();i++)
// Solve for column i of b; the failure branch is not visible here.
160 if(!solve(a,b.col(i),xcol))
/// Vector overload of solve for a permutation matrix a. It presumably solves
/// a*x = b by applying the inverse permutation to b.
/// NOTE(review): the body is not visible in this excerpt — confirm the
/// direction of the permutation and the meaning of the bool result.
170bool solve(
const perm_mat &a,
const vec<T> &b, vec<T>&x)
/// Solves a*X = B for X, where a is a permutation matrix, by running the
/// vector overload on each column of B in turn.
/// NOTE(review): the declaration of xcol, the copy of xcol into column i of
/// x, and the return statements are not visible in this excerpt — confirm.
183bool solve(
const perm_mat& a,
const mat<T>&b,mat<T>&x)
// NOTE(review): this overload asserts that a is square. The sibling matrix
// overloads instead assert b.nrows() == a.ncols(), so a b with a mismatched
// row count is not caught here — confirm whether that check should be
// added.
185 assert(a.nrows() == a.ncols());
// x takes the same shape as b.
187 x.resize(b.nrows(),b.ncols());
188 for(
unsigned i = 0; i < b.ncols();i++)
// Solve for column i of b; the failure branch is not visible here.
190 if(!solve(a,b.col(i),xcol))
/// Solves a*x = b for a square dense matrix a via an LU-type factorization.
/// It applies three structured solves in sequence: the permutation P, then
/// the lower-triangular L, then the upper-triangular U.
/// NOTE(review): the computation of P, L, U and the declarations of temp1 and
/// temp2 are not visible in this excerpt. They presumably come from the
/// decomposition declared in lu.h — confirm the factorization convention
/// (e.g. P*a = L*U).
203bool lu_solve(
const mat<T>& a,
const vec<T>&b, vec<T>&x)
// LU solve requires a square matrix.
205 assert(a.nrows() == a.ncols());
// Step 1: undo the row permutation; the failure branch is not visible here.
213 if(!solve(P,b,temp1))
// Step 2: forward substitution with L; the failure branch is not visible here.
215 if(!solve(L,temp1,temp2))
// Step 3: back substitution with U; its result is the overall result.
217 return solve(U,temp2,x);
/// General dense solve a*x = b for a vector right-hand side.
/// It forwards directly to lu_solve, so a must be square (lu_solve asserts
/// this) and the result is whatever lu_solve returns.
223bool solve(
const mat<T>& a,
const vec<T>&b, vec<T>&x)
225 return lu_solve( a, b, x) ;
/// Solves a*X = B for a square dense matrix a and a matrix right-hand side.
/// It uses the same permutation / lower / upper solve chain as the vector
/// overload, applied to whole matrices.
/// NOTE(review): the computation of P, L, U and the declarations of temp1 and
/// temp2 are not visible in this excerpt — confirm.
233bool lu_solve(
const mat<T>& a,
const mat<T>&b,mat<T>&x)
// LU solve requires a square matrix.
235 assert(a.nrows() == a.ncols());
// x takes the same shape as b (valid because a is square).
236 x.resize(b.nrows(),b.ncols());
// Step 1: undo the row permutation; the failure branch is not visible here.
243 if(!solve(P,b,temp1))
// Step 2: forward substitution with L; the failure branch is not visible here.
245 if(!solve(L,temp1,temp2))
// Step 3: back substitution with U; its result is the overall result.
247 return solve(U,temp2,x);
/// General dense solve a*X = B for a matrix right-hand side.
/// It forwards directly to svd_solve and returns its result.
/// NOTE(review): the vector overload solve(mat, vec) forwards to lu_solve,
/// but this matrix overload uses svd_solve. Confirm the difference is
/// intentional (e.g. robustness versus cost) rather than an inconsistency.
256bool solve(
const mat<T>& a,
const mat<T>&b, mat<T>&x)
258 return svd_solve( a, b, x) ;
/// Solves a*x = b using a QR decomposition of a. It finishes with a
/// structured solve against r.
/// NOTE(review): the decomposition and the construction of r and temp are not
/// visible in this excerpt. r is presumably the upper-triangular factor R
/// and temp presumably Q^T * b — confirm.
265bool qr_solve(
const mat<T>& a,
const vec<T>&b, vec<T>&x)
// Final back substitution; its result is the overall result.
274 return solve(r,temp,x);
/// Solves a*X = B for a square dense matrix a and a matrix right-hand side,
/// using a QR decomposition of a. It finishes with a structured solve
/// against r.
/// NOTE(review): the decomposition and the construction of r and temp are not
/// visible in this excerpt. r is presumably the upper-triangular factor R
/// and temp presumably Q^T * B — confirm.
283bool qr_solve(
const mat<T>& a,
const mat<T>&b, mat<T>&x)
// Only square systems are accepted by this overload.
285 assert(a.nrows() == a.ncols());
// x takes the same shape as b (valid because a is square).
286 x.resize(b.nrows(),b.ncols());
// Final back substitution; its result is the overall result.
293 return solve(r,temp,x);
/// Vector overload of svd_solve. It presumably solves a*x = b via a singular
/// value decomposition of a (see svd.h).
/// NOTE(review): the body is not visible in this excerpt. Confirm how small
/// singular values are treated and what the bool result means.
300bool svd_solve(
const mat<T>& a,
const vec<T>&b, vec<T>&x)
318bool svd_solve(
const mat<T>& a,
const mat<T>&b, mat<T>&x)
320 assert(a.nrows() == a.ncols());
321 x.resize(b.nrows(),b.ncols());