Preface

Strassen Algorithm 發展到後期開始有所變型,諸如 

(1) 將展開方式做為調整
(2) 運算順序做為調整
(3) 加入多行緒  
(4) n 使用閥值
(5) 改善必需為 2 的整數次羃方陣之限制
(6) 展開更多項次 

本文只講原型,其他有興趣可翻翻文獻。以往被認為最快的 Coppersmith–Winograd_algorithm ,今年 (2012) 有篇文獻似乎挑戰其 O(n2.3737),Multiplying matrices faster than coppersmith-winograd,也是基於 Coppersminth-Winograd 便是。

 

Basic

 

拿 C[2][2] = A[2][2] * B[2][2] 來說,展開來共用了 8 個乘法、4 個加法;經過化簡後變成 11 個加法、7個減法、7個乘法, 關鍵是拿多出來 14 個 cache hit 較高的加減法,換取 1 個乘法 。假設

 

| m11 m12 |   | a11 a12 |   | b11 b12 |
|         | = |         | * |         |
| m21 m22 |   | a21 a22 |   | b21 b22 |

 

原本公式是

m11 = a11 * b11 + a12 * b21
m12 = a11 * b12 + a12 * b22
m21 = a21 * b11 + a22 * b21
m22 = a21 * b12 + a22 * b22

使用了 8 個乘法與 4 個加法,重新表達一下

p1 = (a12 - a22) * (b21 + b22)    
p2 = (a11 + a22) * (b11 + b22)   
p3 = (a11 - a21) * (b11 + b12)  
p4 = (a11 + a12) * b22
p5 = a11 * (b12 - b22) 
p6 = a22 * (b21 - b11) 
p7 = (a21 + a22) * b11 

m11 = p1 + p2 - p4 + p6
m12 = p4 + p5
m21 = p6 + p7
m22 = p2 - p3 + p5 - p7 

變成使用了 7 個乘法,18 個加減法 (重點是少了一個乘法),示意碼如下

Code Snippet
  1. double ** mat_strassen_2(double **rst, double **a, double **b)
  2. {
  3.     double p1 = (a[0][1] - a[1][1]) * ( b[1][0] + b[1][1] );
  4.     double p2 = (a[0][0] + a[1][1]) * ( b[0][0] + b[1][1] );
  5.     double p3 = (a[0][0] - a[1][0]) * ( b[0][0] + b[0][1] );
  6.     double p4 = (a[0][0] + a[0][1]) * b[1][1];
  7.     double p5 = a[0][0] * (b[0][1] - b[1][1]);
  8.     double p6 = a[1][1] * (b[1][0] - b[0][0]);
  9.     double p7 = (a[1][0] + a[1][1]) * b[0][0];
  10.  
  11.     rst[0][0] = p1 + p2 - p4 + p6;
  12.     rst[0][1] = p4 + p5;
  13.     rst[1][0] = p6 + p7;
  14.     rst[1][1] = p2 - p3 + p5 - p7;
  15.  
  16.     return rst;
  17. }

 

這種模型剛好適合分而治之做法,只是過程中記憶體會不斷 allocate、release,一份較完整又落落長之程式碼如下

 

Code Snippet
  1. //////////////////////////////////////////////////////////////////////////
  2. // Source Strassen Algorithm
  3. double ** mat_strassen_n(
  4.     double **rst, double **a, double **b, int n)
  5. {
  6.     int i, j, n2;
  7.     double **p1, **p2, **p3, **p4, **p5, **p6, **p7;
  8.     double **a11, **a12, **a21, **a22;
  9.     double **b11, **b12, **b21, **b22;
  10.     double **m11, **m12, **m21, **m22;
  11.     double **arst, **brst;
  12.  
  13.     if(n==2) return mat_strassen_2(rst, a, b);
  14.  
  15.     n2 = n/2;  
  16.  
  17.     //
  18.     // allocate memory    
  19.     //
  20.     p1 = mat_new(n2, n2), p2 = mat_new(n2, n2), p3 = mat_new(n2, n2);
  21.     p4 = mat_new(n2, n2), p5 = mat_new(n2, n2), p6 = mat_new(n2, n2);
  22.     p7 = mat_new(n2, n2);
  23.  
  24.     a11 = mat_new(n2, n2), a12 = mat_new(n2, n2);
  25.     a21 = mat_new(n2, n2), a22 = mat_new(n2, n2);
  26.     b11 = mat_new(n2, n2), b12 = mat_new(n2, n2);
  27.     b21 = mat_new(n2, n2), b22 = mat_new(n2, n2);    
  28.     m11 = mat_new(n2, n2), m12 = mat_new(n2, n2);
  29.     m21 = mat_new(n2, n2), m22 = mat_new(n2, n2);    
  30.     arst = mat_new(n2, n2), brst = mat_new(n2, n2);
  31.  
  32.     //
  33.     // divide matrix
  34.     //     
  35.     for(i=0 ; i<n2; ++i){
  36.         for(j=0 ; j<n2; ++j){
  37.             a11[i][j] = a[i][j];
  38.             a12[i][j] = a[i][n2+j];
  39.             a21[i][j] = a[i+n2][j];
  40.             a22[i][j] = a[i+n2][j+n2];
  41.  
  42.             b11[i][j] = b[i][j];
  43.             b12[i][j] = b[i][n2+j];
  44.             b21[i][j] = b[i+n2][j];
  45.             b22[i][j] = b[i+n2][j+n2];
  46.         }
  47.     }
  48.  
  49.     //
  50.     // calculate p1~p7 , m11,m12,m21,m22
  51.     //     
  52.  
  53.     //  p1 = (a12 - a22) * (b21 + b22)
  54.     mat_strassen_n(p1,
  55.         mat_sub(arst, a12, a22,n2,n2),
  56.         mat_add(brst, b21, b22,n2,n2),
  57.         n2);
  58.  
  59.     //     p2 = (a11 + a22) * (b11 + b22)
  60.     mat_strassen_n(p2,
  61.         mat_add(arst, a11, a22,n2,n2),
  62.         mat_add(brst, b11, b22,n2,n2),
  63.         n2);
  64.  
  65.     //     p3 = (a11 - a21) * (b11 + b12)
  66.     mat_strassen_n(p3,
  67.         mat_sub(arst, a11, a21,n2,n2),
  68.         mat_add(brst, b11, b12,n2,n2),
  69.         n2);
  70.  
  71.     //     p4 = (a11 + a12) * b22
  72.     mat_strassen_n(p4,
  73.         mat_add(arst, a11, a12,n2,n2),
  74.         b22,
  75.         n2);
  76.  
  77.     // p5 = a11 * (b12 - b22)
  78.     mat_strassen_n(p5,
  79.         a11,
  80.         mat_sub(brst, b12, b22,n2,n2),
  81.         n2);
  82.  
  83.  
  84.     //  p6 = a22 * (b21 - b11)
  85.     mat_strassen_n(p6,
  86.         a22,
  87.         mat_sub(brst, b21,b11,n2,n2),
  88.         n2);
  89.  
  90.     
  91.     //  p7 = (a21 + a22) * b11
  92.     mat_strassen_n(p7,
  93.         mat_add(arst,a21,a22,n2,n2),
  94.         b11,
  95.         n2);
  96.  
  97.     //     m11 = p1 + p2 - p4 + p6
  98.     //  m11 = p1 + p2 - (p4 - p6)
  99.     mat_sub(
  100.         m11,
  101.         mat_add(arst, p1, p2, n2,n2),
  102.         mat_sub(brst, p4, p6, n2,n2),
  103.         n2, n2);
  104.  
  105.     //     m12 = p4 + p5
  106.     mat_add( m12, p4, p5, n2, n2);
  107.  
  108.     //     m21 = p6 + p7
  109.     mat_add(m21, p6, p7, n2, n2);
  110.  
  111.     //     m22 = p2 - p3 + p5 - p7
  112.     mat_add(
  113.         m22,
  114.         mat_sub(arst, p2, p3, n2, n2),
  115.         mat_sub(brst, p5, p7, n2, n2),
  116.         n2,n2);
  117.  
  118.     //
  119.     // record result
  120.     //     
  121.     for(i=0; i<n2; ++i){
  122.         for(j=0; j<n2; ++j){
  123.             rst[i][j]       = m11[i][j];
  124.             rst[i][j+n2]    = m12[i][j];
  125.             rst[i+n2][j]    = m21[i][j];
  126.             rst[i+n2][j+n2] = m22[i][j];
  127.         }
  128.     }
  129.  
  130.     //
  131.     // release memory
  132.     //
  133.     free(p1), free(p2), free(p3), free(p4), free(p5);
  134.     free(p6), free(p7), free(arst), free(brst);
  135.     free(m11), free(m12), free(m21), free(m22);
  136.     return rst;
  137. }

 

欲測試效能,於是再寫一份普通的矩陣乘法做為比較

 

Code Snippet
  1. double ** mat_mul(
  2.     double **rst, double **a, double **b,
  3.     int m, int n, int p)
  4. {
  5.     int i, j, k;
  6.     double ta;
  7.     memset((void*)*rst, 0, sizeof(rst[0][0])*m*n);
  8.  
  9.     for(k=0; k<n; ++k){
  10.         for(i=0; i<m; ++i){
  11.             ta = a[i][k];            
  12.             for(j=0; j<p; ++j)
  13.                 rst[i][j] += ta * b[k][j];
  14.         }
  15.     }
  16.     return rst;
  17. }

 

很遺憾的是,拿 N = 256 來測時,發現上述的 mat_strassen 跑得慢非常多!要跑出不錯效果時,只需將一開頭稍加改過即可。 

Code Snippet
  1. double ** mat_strassen_n(
  2.     double **rst, double **a, double **b, int n)
  3. {
  4.     /* some declare */
  5.  
  6.     // if(n==2) return mat_strassen_2(rst, a, b); // comment this!
  7.     if(n <= 128) return mat_mul(rst, a, b, n, n, n); // 設閥值
  8.  
  9.     /* others */
  10. }

在維度小於 128 時,直接呼叫內部之 mat_mul 進行 ( 所以一開始的 mat_strassen_2 其實是白寫的 ),閥值要設多少是個可探討的問題,此處不深入。

重新測過,這時候使用 mat_strassen_n 確實會比 mat_mul 快,但實際上提昇有限,甚至沒以 純 cache hit 方式改善 來得明顯,因一般真的拿來用的話不會用上面程式碼撰之,它有幾項缺點

 

[1] memory 使用量大

[2] allocate、release  動作量大 < 可改善 >

[3] 只限於 2^n 維度之方陣 < 可改善 > 

 

另也有人提出,將原本的 10 加法 8 減法 7 乘法,換成 7 加法8 減法 7 乘法,類似的研究應也不少。針對 strassen Algorithm ,本文點到為止,其他的待有機會時再研究。  

 

Other issues

 

Matrix 乘法議題較常探討的圍繞在五個議題

 

[1] cache hit , TLB hit

[2] Strassen’s algorithm  (1969)

[3] Winogra algorithm (1980) : 使用 FFT。

[4] Coppersmith–Winograd Algotihm (1990) : 使用 FFT

[5] 平行化問題 : 平行化處理筆者接觸不深,故此處不予探討。可參考 openmp

 

 

Reference

 

some Algorithms Introduction  

[1] Matrix multiplication in wiki

[2] Strassen Algorithm in wiki

[3] Coppersmith–Winograd in wiki

[4] Strassen Algorithm

 

About Cache

[1] StackOverflow : Programmatically get the cache line size ?

[2] msdn : GetLogicalProcessorInformation function

[3] wiki : Cache(computing)

[4] wiki : Cache algorithms

[5] wiki : CPU cache

[6] wiki : CPU 快取 ( 簡中 )

[7] wiki : compiler optimization

[8] A MapReduce Algorithm for Matrix Multiplication

 

Others

[1] openmp 

[2] OpenBLAS

[3]  GSL

[4] MIT leture : High Performace Computing , Performace Evaluation, and Trends

[5] Pdf :A Case Study on High Performance Matrix Multiplication 

[6] Origin2000™ and Onyx2™ Performance Tuning and Optimization Guide

 [7] Pdf : The Cache Performance and Optimizations of Blocked Algorithms 

edisonx 發表在 痞客邦 PIXNET 留言(0) 人氣()