cgv
Loading...
Searching...
No Matches
lin_solve.h
1#pragma once
2#include <cgv/math/vec.h>
3#include <cgv/math/mat.h>
4#include <cgv/math/perm_mat.h>
5#include <cgv/math/diag_mat.h>
6#include <cgv/math/tri_diag_mat.h>
7#include <cgv/math/up_tri_mat.h>
8#include <cgv/math/low_tri_mat.h>
9#include <cgv/math/lu.h>
10#include <cgv/math/svd.h>
11#include <cgv/math/qr.h>
12
13
14namespace cgv {
15 namespace math {
16
19template<typename T>
20bool solve(const up_tri_mat<T>& a,const vec<T>&b, vec<T>&x)
21{
22 assert(a.nrows() == a.ncols());
23
24 int N = a.nrows();
25 x.resize(N);
26 T sum;
27 for(int i = N-1; i >= 0;i--)
28 {
29 sum =0;
30 for(int j = i+1;j < N; j++)
31 sum += a(i,j)*x(j);
32 if(a(i,i) == 0)
33 return false;
34 x[i] = (b[i] - sum)/a(i,i);
35 }
36 return true;
37}
38
43template<typename T>
44bool solve(const up_tri_mat<T>& a,const mat<T>&b,mat<T>&x)
45{
46 assert(b.nrows() == a.ncols());
47 vec<T> xcol;
48 x.resize(b.nrows(),b.ncols());
49 for(unsigned i = 0; i < b.ncols();i++)
50 {
51 if(!solve(a,b.col(i),xcol))
52 return false;
53 x.set_col(i,xcol);
54 }
55 return true;
56}
57
60template<typename T>
61bool solve(const low_tri_mat<T>& a, const vec<T>&b, vec<T>&x)
62{
63
64 int N = a.nrows();
65 x.resize(N);
66 T sum;
67 for(int i = 0; i < N;i++)
68 {
69 sum =0;
70 for(int j = 0;j < i; j++)
71 sum += a(i,j)*x(j);
72 if(a(i,i) == 0)
73 return false;
74 x[i] = (b[i] - sum)/a(i,i);
75 }
76 return true;
77
78}
79
84template<typename T>
85bool solve(const low_tri_mat<T>& a,const mat<T>&b,mat<T>&x)
86{
87 assert(b.nrows() == a.ncols());
88 vec<T> xcol;
89 x.resize(b.nrows(),b.ncols());
90 for(unsigned i = 0; i < b.ncols();i++)
91 {
92 if(!solve(a,b.col(i),xcol))
93 return false;
94 x.set_col(i,xcol);
95 }
96 return true;
97}
98
101template<typename T>
102bool solve(const diag_mat<T>& a, const vec<T>&b, vec<T>&x)
103{
104 int N = a.ncols();
105 x.resize(N);
106 for(int i = 0; i < N;i++)
107 {
108 if(a(i) == 0)
109 return false;
110 x(i) = (T)b(i)/a(i);
111 }
112 return true;
113}
114
117template <typename T>
118bool solve(const tri_diag_mat<T>& a, const vec<T>& b, vec<T>& x)
119{
120 x.resize(b.dim());
121 int i;
122 vec<T> aa = a.band(-1);
123 vec<T> bb = a.band(0);
124 vec<T> cc = a.band(1);
125 vec<T> dd = b;
126
127 int n = b.dim();
128
129 if(bb(0) == 0)
130 return false;
131 cc(0) /= bb(0);
132 dd(0) /= bb(0);
133 for(i = 1; i < n; i++)
134 {
135 T id = (bb(i) - cc(i-1) * aa(i));
136 if(id == 0)
137 return false;
138 cc(i) /= id;
139 dd(i) = (dd(i) - dd(i-1) * aa(i))/id;
140 }
141
142 x(n - 1) = dd(n - 1);
143 for(i = n - 2; i >= 0; i--)
144 x(i) = dd(i) - cc(i) * x(i + 1);
145 return true;
146}
147
152template<typename T>
153bool solve(const diag_mat<T>& a, const mat<T>&b, mat<T>&x)
154{
155 assert(b.nrows() == a.ncols());
156 vec<T> xcol;
157 x.resize(b.nrows(),b.ncols());
158 for(unsigned i = 0; i < b.ncols();i++)
159 {
160 if(!solve(a,b.col(i),xcol))
161 return false;
162 x.set_col(i,xcol);
163 }
164 return true;
165}
166
169template<typename T>
170bool solve(const perm_mat &a, const vec<T> &b, vec<T>&x)
171{
172
173 x.resize(a.nrows());
174 x=transpose(a)*b;
175 return true;
176}
177
182template<typename T>
183bool solve(const perm_mat& a,const mat<T>&b,mat<T>&x)
184{
185 assert(a.nrows() == a.ncols());
186 vec<T> xcol;
187 x.resize(b.nrows(),b.ncols());
188 for(unsigned i = 0; i < b.ncols();i++)
189 {
190 if(!solve(a,b.col(i),xcol))
191 return false;
192 x.set_col(i,xcol);
193 }
194 return true;
195}
196
197
198
199
202template<typename T>
203bool lu_solve(const mat<T>& a, const vec<T>&b, vec<T>&x)
204{
205 assert(a.nrows() == a.ncols());
206 x.resize(a.nrows());
207 vec<T> temp1,temp2;
208 low_tri_mat<T> L;
209 up_tri_mat<T> U;
210 perm_mat P;
211 if(!lu(a,P,L,U))
212 return false;
213 if(!solve(P,b,temp1))
214 return false;
215 if(!solve(L,temp1,temp2))
216 return false;
217 return solve(U,temp2,x);
218}
219
222template<typename T>
223bool solve(const mat<T>& a, const vec<T>&b, vec<T>&x)
224{
225 return lu_solve( a, b, x) ;
226}
227
232template<typename T>
233bool lu_solve(const mat<T>& a,const mat<T>&b,mat<T>&x)
234{
235 assert(a.nrows() == a.ncols());
236 x.resize(b.nrows(),b.ncols());
237 mat<T> temp1,temp2;
238 low_tri_mat<T> L;
239 up_tri_mat<T> U;
240 perm_mat P;
241 if(!lu(a,P,L,U))
242 return false;
243 if(!solve(P,b,temp1))
244 return false;
245 if(!solve(L,temp1,temp2))
246 return false;
247 return solve(U,temp2,x);
248
249}
250
255template<typename T>
256bool solve(const mat<T>& a, const mat<T>&b, mat<T>&x)
257{
258 return svd_solve( a, b, x) ;
259}
260
261
262
264template<typename T>
265bool qr_solve(const mat<T>& a, const vec<T>&b, vec<T>&x)
266{
267 x.resize(a.nrows());
268 vec<T> temp;
269 mat<T> q;
270 up_tri_mat<T> r;
271 if(!qr(a,q,r))
272 return false;
273 Atx(q,b,temp);
274 return solve(r,temp,x);
275}
276
277
282template<typename T>
283bool qr_solve(const mat<T>& a, const mat<T>&b, mat<T>&x)
284{
285 assert(a.nrows() == a.ncols());
286 x.resize(b.nrows(),b.ncols());
287 mat<T> temp;
288 mat<T> q;
289 up_tri_mat<T> r;
290 if(!qr(a,q,r))
291 return false;
292 AtB(q,b,temp);
293 return solve(r,temp,x);
294}
295
296
297
299template<typename T>
300bool svd_solve(const mat<T>& a, const vec<T>&b, vec<T>&x)
301{
302 x.resize(a.nrows());
303 vec<T> temp;
304 mat<T> u,v;
305 diag_mat<T> d;
306 if(!svd(a,u,d,v))
307 return false;
308
309 Atx(u,b,temp);
310 if(!solve(d,temp,x))
311 return false;
312 x=v*x;
313 return true;
314}
315
317template<typename T>
318bool svd_solve(const mat<T>& a, const mat<T>&b, mat<T>&x)
319{
320 assert(a.nrows() == a.ncols());
321 x.resize(b.nrows(),b.ncols());
322
323 mat<T> temp;
324 mat<T> u,v;
325 diag_mat<T> d;
326 if(!svd(a,u,d,v))
327 return false;
328 AtB(u,b,temp);
329 if(!solve(d,temp,x))
330 return false;
331 x=v*x;
332 return true;
333}
334
335
336
337
338
339
340}
341
342
343}
the cgv namespace
Definition print.h:11