Preface
As the Strassen algorithm matured, many variants appeared, such as
(1) adjusting how the products are expanded
(2) adjusting the order of operations
(3) adding multithreading
(4) using a cutoff threshold for n
(5) removing the restriction to square matrices whose dimension is a power of 2
(6) expanding into more terms
This article covers only the original form; if you are interested in the rest, browse the literature. As for the Coppersmith–Winograd algorithm, long considered the fastest, a paper this year (2012), "Multiplying matrices faster than Coppersmith-Winograd", appears to challenge its O(n^2.3737) bound; it too builds on Coppersmith–Winograd.
Basic
Take C[2][2] = A[2][2] * B[2][2] as an example. Expanded directly it uses 8 multiplications and 4 additions; after rearrangement it becomes 11 additions, 7 subtractions, and 7 multiplications. The key is trading 14 extra additions/subtractions, which have a high cache-hit rate, for 1 fewer multiplication. Suppose
| m11 m12 | | a11 a12 | | b11 b12 |
| | = | | * | |
| m21 m22 | | a21 a22 | | b21 b22 |
The original formulas are
m11 = a11 * b11 + a12 * b21
m12 = a11 * b12 + a12 * b22
m21 = a21 * b11 + a22 * b21
m22 = a21 * b12 + a22 * b22
which uses 8 multiplications and 4 additions. Rewriting:
p1 = (a12 - a22) * (b21 + b22)
p2 = (a11 + a22) * (b11 + b22)
p3 = (a11 - a21) * (b11 + b12)
p4 = (a11 + a12) * b22
p5 = a11 * (b12 - b22)
p6 = a22 * (b21 - b11)
p7 = (a21 + a22) * b11
m11 = p1 + p2 - p4 + p6
m12 = p4 + p5
m21 = p6 + p7
m22 = p2 - p3 + p5 - p7
This now uses 7 multiplications and 18 additions/subtractions (the point being one fewer multiplication). Sketch code:
double ** mat_strassen_2(double **rst, double **a, double **b)
{
    double p1 = (a[0][1] - a[1][1]) * ( b[1][0] + b[1][1] );
    double p2 = (a[0][0] + a[1][1]) * ( b[0][0] + b[1][1] );
    double p3 = (a[0][0] - a[1][0]) * ( b[0][0] + b[0][1] );
    double p4 = (a[0][0] + a[0][1]) * b[1][1];
    double p5 = a[0][0] * (b[0][1] - b[1][1]);
    double p6 = a[1][1] * (b[1][0] - b[0][0]);
    double p7 = (a[1][0] + a[1][1]) * b[0][0];
    rst[0][0] = p1 + p2 - p4 + p6;
    rst[0][1] = p4 + p5;
    rst[1][0] = p6 + p7;
    rst[1][1] = p2 - p3 + p5 - p7;
    return rst;
}
This structure is a natural fit for divide and conquer, except that memory gets allocated and released over and over along the way. A fuller (and rather long) listing:
//////////////////////////////////////////////////////////////////////////
// Source Strassen Algorithm
double ** mat_strassen_n(
    double **rst, double **a, double **b, int n)
{
    int i, j, n2;
    double **p1, **p2, **p3, **p4, **p5, **p6, **p7;
    double **a11, **a12, **a21, **a22;
    double **b11, **b12, **b21, **b22;
    double **m11, **m12, **m21, **m22;
    double **arst, **brst;
    if(n==2) return mat_strassen_2(rst, a, b);
    n2 = n/2;
    //
    // allocate memory
    //
    p1 = mat_new(n2, n2), p2 = mat_new(n2, n2), p3 = mat_new(n2, n2);
    p4 = mat_new(n2, n2), p5 = mat_new(n2, n2), p6 = mat_new(n2, n2);
    p7 = mat_new(n2, n2);
    a11 = mat_new(n2, n2), a12 = mat_new(n2, n2);
    a21 = mat_new(n2, n2), a22 = mat_new(n2, n2);
    b11 = mat_new(n2, n2), b12 = mat_new(n2, n2);
    b21 = mat_new(n2, n2), b22 = mat_new(n2, n2);
    m11 = mat_new(n2, n2), m12 = mat_new(n2, n2);
    m21 = mat_new(n2, n2), m22 = mat_new(n2, n2);
    arst = mat_new(n2, n2), brst = mat_new(n2, n2);
    //
    // divide matrix
    //
    for(i=0 ; i<n2; ++i){
        for(j=0 ; j<n2; ++j){
            a11[i][j] = a[i][j];
            a12[i][j] = a[i][n2+j];
            a21[i][j] = a[i+n2][j];
            a22[i][j] = a[i+n2][j+n2];
            b11[i][j] = b[i][j];
            b12[i][j] = b[i][n2+j];
            b21[i][j] = b[i+n2][j];
            b22[i][j] = b[i+n2][j+n2];
        }
    }
    //
    // calculate p1~p7 , m11,m12,m21,m22
    //
    // p1 = (a12 - a22) * (b21 + b22)
    mat_strassen_n(p1,
        mat_sub(arst, a12, a22, n2, n2),
        mat_add(brst, b21, b22, n2, n2),
        n2);
    // p2 = (a11 + a22) * (b11 + b22)
    mat_strassen_n(p2,
        mat_add(arst, a11, a22, n2, n2),
        mat_add(brst, b11, b22, n2, n2),
        n2);
    // p3 = (a11 - a21) * (b11 + b12)
    mat_strassen_n(p3,
        mat_sub(arst, a11, a21, n2, n2),
        mat_add(brst, b11, b12, n2, n2),
        n2);
    // p4 = (a11 + a12) * b22
    mat_strassen_n(p4,
        mat_add(arst, a11, a12, n2, n2),
        b22,
        n2);
    // p5 = a11 * (b12 - b22)
    mat_strassen_n(p5,
        a11,
        mat_sub(brst, b12, b22, n2, n2),
        n2);
    // p6 = a22 * (b21 - b11)
    mat_strassen_n(p6,
        a22,
        mat_sub(brst, b21, b11, n2, n2),
        n2);
    // p7 = (a21 + a22) * b11
    mat_strassen_n(p7,
        mat_add(arst, a21, a22, n2, n2),
        b11,
        n2);
    // m11 = p1 + p2 - p4 + p6
    //     = p1 + p2 - (p4 - p6)
    mat_sub(
        m11,
        mat_add(arst, p1, p2, n2, n2),
        mat_sub(brst, p4, p6, n2, n2),
        n2, n2);
    // m12 = p4 + p5
    mat_add(m12, p4, p5, n2, n2);
    // m21 = p6 + p7
    mat_add(m21, p6, p7, n2, n2);
    // m22 = p2 - p3 + p5 - p7
    mat_add(
        m22,
        mat_sub(arst, p2, p3, n2, n2),
        mat_sub(brst, p5, p7, n2, n2),
        n2, n2);
    //
    // record result
    //
    for(i=0; i<n2; ++i){
        for(j=0; j<n2; ++j){
            rst[i][j]       = m11[i][j];
            rst[i][j+n2]    = m12[i][j];
            rst[i+n2][j]    = m21[i][j];
            rst[i+n2][j+n2] = m22[i][j];
        }
    }
    //
    // release memory (a single free() per matrix works only if
    // mat_new packs each matrix into one allocation)
    //
    free(p1), free(p2), free(p3), free(p4), free(p5);
    free(p6), free(p7), free(arst), free(brst);
    free(a11), free(a12), free(a21), free(a22);
    free(b11), free(b12), free(b21), free(b22);
    free(m11), free(m12), free(m21), free(m22);
    return rst;
}
To test performance, I also wrote an ordinary matrix multiplication for comparison:
double ** mat_mul(
    double **rst, double **a, double **b,
    int m, int n, int p)
{
    int i, j, k;
    double ta;
    memset((void*)*rst, 0, sizeof(rst[0][0])*m*p);  /* rst is m x p */
    for(k=0; k<n; ++k){
        for(i=0; i<m; ++i){
            ta = a[i][k];
            for(j=0; j<p; ++j)
                rst[i][j] += ta * b[k][j];
        }
    }
    return rst;
}
Disappointingly, testing with N = 256, the mat_strassen above runs much slower! To get decent results, only the beginning needs a small change:
double ** mat_strassen_n(
    double **rst, double **a, double **b, int n)
{
    /* declarations as before */
    // if(n==2) return mat_strassen_2(rst, a, b); // comment this out!
    if(n <= 128) return mat_mul(rst, a, b, n, n, n); // cutoff threshold
    /* rest as before */
}
Below dimension 128 it calls the plain mat_mul directly (so the mat_strassen_2 written at the start turns out to be wasted effort). What the threshold should be is a question worth exploring; I won't go into it here.
Measuring again, mat_strassen_n is now indeed faster than mat_mul, but the gain is limited, not even as pronounced as a pure cache-hit-oriented improvement. In practice nobody would use the code above as written; it has several drawbacks:
[1] heavy memory use
[2] a large number of allocate/release operations <fixable>
[3] restricted to square matrices of dimension 2^n <fixable>
Others have also proposed replacing the original 10 additions, 8 subtractions, and 7 multiplications with 7 additions, 8 subtractions, and 7 multiplications (18 add/sub operations down to 15); there must be plenty of similar work. As far as the Strassen algorithm goes, this article stops here; the rest will wait for another occasion.
Other issues
Discussion of matrix multiplication usually revolves around five topics:
[1] cache hit, TLB hit
[2] Strassen's algorithm (1969)
[3] Winograd's algorithm (1980): uses FFT.
[4] Coppersmith–Winograd algorithm (1990): uses FFT
[5] parallelization: I have little experience with parallel processing, so it is not covered here; see OpenMP.
Reference
Some algorithm introductions
[1] Matrix multiplication in wiki
[2] Strassen Algorithm in wiki
[3] Coppersmith–Winograd in wiki
About cache
[1] StackOverflow: Programmatically get the cache line size?
[2] msdn: GetLogicalProcessorInformation function
[3] wiki: CPU cache
[4] wiki: compiler optimization
[5] A MapReduce Algorithm for Matrix Multiplication
Others
[1] OpenMP
[2] OpenBLAS
[3] GSL
[4] MIT lecture: High Performance Computing, Performance Evaluation, and Trends
[5] PDF: A Case Study on High Performance Matrix Multiplication
[6] Origin2000™ and Onyx2™ Performance Tuning and Optimization Guide
[7] PDF: The Cache Performance and Optimizations of Blocked Algorithms