這次實作主要是熟悉一些 stl 使用,語法盡可能使用 C++,挑用兩種不同資料結構,結果感到有些意外。

第一種資料結構是單純用指標配置一維 heap,做 index 轉換 ( 就是數值分析習慣用的一維模擬二維) ;

另一種資料結構採用 vector< vararray<T> >。先把意外的結論寫下來。

 

valarray 在 vc 底下之操作並沒顯得優勢 ,反之所費時間約為一般陣列存取之兩倍;而 g++ 下之 valarray 確實有操作上之優勢

 

另外大陸的行(row)列(column),與台灣的行(column)列(row) 在中文上是相反的,感到困惑網友請注意這點。

 

[1] 原理篇

 

這裡用的求反矩陣,必須有一份暫存矩陣,而不是原地做。高斯消去法求反矩陣不熟請參考其他相關資料,這裡講的是一些小細節注意地方。假設欲求矩陣 src 之反矩 dst,一開始 dst 先設單位矩陣。在簡化過程之第 k 個迭代時

 

a. 列交換( row swap, Rij)

若要將 src 第 i 列 / 第 j 列互換時,由於第 i 列與第 j 列前面 k 個元素都一定被消成 0 了,所以做列互換時只需行第 k 個元素交換即可;但 dst 一定要所有元素都經過交換。

 

b. 行交換 (column swap, Cij)

將 src 第 i 行 / 第 j 行互換時,這沒話說,不論是 src 還是 dst 必將所有元素全都交換

 

c. 第 i 列乘以 v 加到第 j 列  (Rij (v) )

和列交換相仿,src 由於前面 k 個元素必被消成 0 ,所以做這動作時,只需從第 k 個元素開始執行即可;而 dst 則必須所有元素都進行這步驟。

 

 

[2] 程式碼與執行結果

 

這份程式碼頗長,由於是拿來做測時比較,故上述的列交換、Rij(v) 等都沒經過化簡,程式碼如下。

 

Code Snippet
  1. #include <iostream>
  2. #include<algorithm>
  3. #include <cmath>
  4. #include <vector>
  5. #include <valarray>
  6. #include <ctime>
  7. using namespace std;
  8.  
  9. template<typename T>
  10. bool minv1(const T * src, T * dst, const size_t n, const T eps)
  11. {
  12.     const size_t n2=n*n;
  13.     size_t i, j, k, is, js;
  14.     T pivot, tmp;
  15.  
  16.     T * mbk = new T [n2];
  17.     copy(src, src+n2, mbk); // copy src to mbk
  18.     fill(dst, dst+n2, 0); // set dst to unit matrix
  19.     for(i=0; i<n; ++i) dst[i*n+i]=1;
  20.  
  21.     for(k=0; k<n; ++k){ // execute k times
  22.         pivot = 0; // find pivot
  23.         for(i=k; i<n; ++i) for(j=k; j<n; ++j)
  24.             if(pivot < (tmp = abs(mbk[i*n+k])))
  25.                 pivot=tmp, is=i, js=j;
  26.  
  27.         if(pivot<=eps){ // has no solution
  28.             delete [] mbk;
  29.             return false;
  30.         }
  31.         if(is!=k){ // swap row(is,k)
  32.             swap_ranges(mbk+k*n, mbk+k*n+n, mbk+is*n);
  33.             swap_ranges(dst+k*n, dst+k*n+n, dst+is*n);
  34.         }
  35.         if(js!=k){ // swap col(js,k)
  36.             for(i=0; i<n; ++i){
  37.                 swap(mbk[i*n+k], mbk[i*n+js]);
  38.                 swap(dst[i*n+k], dst[i*n+js]);
  39.             }
  40.         }
  41.         pivot = mbk[k*n+k]; // adjust pivot to 1
  42.         for(j=0; j<n; ++j){
  43.             mbk[k*n+j]/=pivot;
  44.             dst[k*n+j]/=pivot;
  45.         }
  46.         // elimination
  47.         for(i=0; i<n; ++i){
  48.             if(i!=k){
  49.                 tmp = -mbk[i*n+k];
  50.                 for(j=0; j<n; ++j){
  51.                     mbk[i*n+j]+=tmp*mbk[k*n+j];
  52.                     dst[i*n+j]+=tmp*dst[k*n+j];
  53.                 }
  54.             }
  55.         }
  56.     }
  57.     delete mbk;
  58.     return true;
  59. }
  60.  
  61. template<typename T>
  62. bool minv2(vector< valarray<T> > src, vector<valarray<T>> & dst, const size_t n, const T eps)
  63. {
  64.     const size_t n2=n*n;
  65.     size_t i, j, k, is, js;
  66.     T pivot, tmp;
  67.     valarray<T> row_tmp(n);
  68.         
  69.     for(i=0; i<n; ++i) dst[i][i]=1;
  70.  
  71.     for(k=0; k<n; ++k){ // execute k times
  72.         pivot = 0; // find pivot
  73.         for(i=k; i<n; ++i) for(j=k; j<n; ++j)
  74.             if(pivot < (tmp = abs(src[i][k])))
  75.                 pivot=tmp, is=i, js=j;
  76.  
  77.         if(pivot<=eps) // has no solution
  78.             return false;
  79.  
  80.         if(is!=k){ // swap row(is,k)
  81.             row_tmp = src[is], src[is] = src[k], src[k]=row_tmp;
  82.             row_tmp = dst[is], dst[is] = dst[k], dst[k]=row_tmp;
  83.         }
  84.         if(js!=k){ // swap col(js,k)
  85.             for(i=0; i<n; ++i){
  86.                 swap(src[i][k], src[i][js]);
  87.                 swap(dst[i][k], dst[i][js]);
  88.             }
  89.         }
  90.         pivot = src[k][k]; // adjust pivot to 1
  91.         src[k]/=pivot, dst[k]/=pivot;
  92.         
  93.         // elimination
  94.         for(i=0; i<n; ++i){
  95.             if(i!=k){
  96.                 tmp = -src[i][k];
  97.                 src[i]+=tmp*src[k];
  98.                 dst[i]+=tmp*dst[k];
  99.             }
  100.         }
  101.     }
  102.     return true;
  103. }
  104.  
  105. template<typename T>
  106. void gen(T *& arr, const size_t n){
  107.     const size_t n2=n*n;
  108.     for(size_t i=0; i<n2; ++i)
  109.         arr[i] = rand() % 10;
  110. }
  111.  
  112. int main()
  113. {
  114.     const size_t n = 1000;
  115.     const double eps=1e-9;
  116.     double *x = new double[n*n];
  117.     double *y = new double[n*n];    
  118.     vector< valarray<double> > vva(n, valarray<double>(n));
  119.     vector< valarray<double> > vvb(n, valarray<double>(n));
  120.     clock_t ck;
  121.  
  122.     gen(x,n);
  123.     for(size_t i=0; i<n; ++i) for(size_t j=0 ; j<n; ++j)
  124.         vva[i][j] = x[i*n+j];
  125.  
  126.     ck = clock();
  127.     minv1(x, y, n, eps);
  128.     cout << "minv1  : " << clock()-ck << endl;
  129.  
  130.     ck = clock();
  131.     minv2(vva, vvb, n,eps);
  132.     cout << "minv2  : " << clock()-ck << endl;
  133.  
  134.     delete [] x;
  135.     delete [] y;
  136.     system("pause");
  137.     return 0;
  138. }

 

 上面要做輸出還是幹嘛的自己再加,這裡不附。

 

 vc2010 , release 預設參數(含O2) 執行結果如下

minv1 : 7328
minv2 : 14276

Code::Blocks 10.5 (gcc4.4.1) ,-O2,執行結果如下

minv1 : 8423
minv2 : 7143

 

針對這結果其實蠻意外的,valarray 在 vc 底下之操作並沒顯得優勢 ,反之所費時間約為一般陣列存取之兩倍(其實網路上找得到有人針對 vc 下之 valarray 存取做處理,但速度也只與 pointer + heap 相仿);而 g++ 下之 valarray 確實有操作上之優勢

假設變數是

double tmp = 2.0;

valarray<double> a(10), b(10);

但 valarray 很尷尬的是,我們只能寫成

b = tmp * sin(a);

這樣上述的精簡技巧就完全不能用,因實際要做的可能只會是

b[2:9] = tmp * sin(a [2:9] )

 

至於迴圈精簡後的程式碼,下一份供參考。

 

Code Snippet
  1. template<typename T>
  2. bool minv(const T* src, T* dst, const size_t n, const T eps)
  3. {
  4.     const size_t n2=n*n;
  5.     size_t i, j, k, is;
  6.     T pivot, tmp;
  7.  
  8.     T * mbk = new T [n2];
  9.     copy(src, src+n2, mbk); // copy src to mbk
  10.     fill(dst, dst+n2, 0); // set dst to unit matrix
  11.     for(i=0; i<n; ++i) dst[i*n+i]=1;
  12.  
  13.     for(k=0; k<n; ++k){ // execute k times
  14.  
  15.         pivot= abs(mbk[k*n+k]), is=k;
  16.         for(i=k+1; i<n; ++i)  // find pivot
  17.             if(pivot < (tmp = abs(mbk[i*n+k])))
  18.                 pivot=tmp, is=i;
  19.  
  20.         if(pivot<=eps){ // has no solution
  21.             delete [] mbk;
  22.             return false;
  23.         }
  24.  
  25.         if(is!=k){ // swap row(is,k)
  26.             swap_ranges(mbk+k*n+k, mbk+k*n+n, mbk+is*n+k);
  27.             swap_ranges(dst+k*n, dst+k*n+n, dst+is*n);
  28.         }
  29.  
  30.         pivot = mbk[k*n+k]; // adjust pivot to 1
  31.         for(j=k; j<n; ++j)    mbk[k*n+j]/=pivot;
  32.         for(j=0; j<n; ++j)    dst[k*n+j]/=pivot; 
  33.  
  34.         // elimination
  35.         for(i=0; i<n; ++i){
  36.             if(i!=k){
  37.                 tmp = -mbk[i*n+k];
  38.                 for(j=0; j<n; ++j) dst[i*n+j]+=tmp*dst[k*n+j];
  39.                 for(j=k; j<n; ++j) mbk[i*n+j]+=tmp*mbk[k*n+j];                
  40.             }
  41.         }
  42.     }
  43.     delete mbk;
  44.     return true;
  45. }

 

一些更細節的問題,像是回圈合併,這個就不再做撰碼討論。

arrow
arrow
    全站熱搜

    edisonx 發表在 痞客邦 留言(2) 人氣()