Strassen Strassen算法

问题

用Strassen算法计算两个nn阶矩阵相乘C=A×BC = A \times B

行列式

nn阶行列式(Determinant):

a11a12a1na21a22a2nan1an2ann\begin{vmatrix} a_{11} & a_{12} & \cdots & a_{1n} \\ a_{21} & a_{22} & \cdots & a_{2n} \\ \vdots & \vdots & & \vdots \\ a_{n1} & a_{n2} & \cdots & a_{nn} \end{vmatrix}

表示nn个元素的乘积

a1j1a2j2anjna_{1 j_{1}} a_{2 j_{2}} \cdots a_{n j_{n}}

的代数和。其中j1,j2,,jnj_{1}, j_{2}, \dots, j_{n}1,2,,n1, 2, \dots, n的一个排列。当j1,j2,,jnj_{1}, j_{2}, \dots, j_{n}是奇排列时该项带负号,当j1,j2,,jnj_{1}, j_{2}, \dots, j_{n}是偶排列时该项带正号。对于元素ajpjqa_{j_{p} j_{q}}下标的两个数字,若1p<qn1 \leq p \lt q \leq njp>jqj_{p} \gt j_{q}则称这两个有序的数[jp,jq][ j_{p}, j_{q} ]是逆序对。逆序对的数量称为逆序对数。若a1j1a2j2anjna_{1 j_{1}} a_{2 j_{2}} \cdots a_{n j_{n}}中所有元素下标的逆序对数为偶数,则称排列j1,j2,,jnj_{1}, j_{2}, \dots, j_{n}为偶排列;否则为奇排列。

a11a12a1na21a22a2nan1an2ann=j1j2jn(1)τ(j1j2jn)a1j1a2j2anjn\begin{vmatrix} a_{11} & a_{12} & \cdots & a_{1n} \\ a_{21} & a_{22} & \cdots & a_{2n} \\ \vdots & \vdots & \quad & \vdots \\ a_{n1} & a_{n2} & \cdots & a_{nn} \end{vmatrix} = \sum_{j_{1} j_{2} \cdots j_{n}} (-1)^{ \tau (j_{1} j_{2} \cdots j_{n}) } a_{1 j_{1}} a_{2 j_{2}} \cdots a_{n j_{n}}

其中τ(j1j2jn)\tau (j_{1} j_{2} \cdots j_{n})是行列式的逆序数,j1j2jn\sum_{j_{1} j_{2} \cdots j_{n}}表示对所有nn阶排列求和,该式称为nn阶行列式的完全展开式。

特别的22阶行列式和33阶行列式的完全展开式分别为

abcd=adbc\begin{vmatrix} a & b \\ c & d \end{vmatrix} = a \cdot d - b \cdot c
a11a12a13a21a22a23a31a32a33=a11a22a33+a12a23a31+a13a21a32a13a22a31a12a21a33a11a23a32\begin{vmatrix} a_{11} & a_{12} & a_{13} \\ a_{21} & a_{22} & a_{23} \\ a_{31} & a_{32} & a_{33} \end{vmatrix} = a_{11} a_{22} a_{33} + a_{12} a_{23} a_{31} + a_{13} a_{21} a_{32} - a_{13} a_{22} a_{31} - a_{12} a_{21} a_{33} - a_{11} a_{23} a_{32}

行列式操作和特性:

(1)(1) 经过转置行列式的值不变,即AT=A\begin{vmatrix} A^T \end{vmatrix} = \begin{vmatrix} A \end{vmatrix}。行列式的转置是将AA的行和列交换,得到ATA^{T},转置行列式的任意元素满足aijt=ajia_{ij}^{t} = a_{ji};例如

A33=a11a12a13a21a22a23a31a32a33A33T=a11a21a31a12a22a32a13a23a33A_{33} = \begin{vmatrix} a_{11} & a_{12} & a_{13} \\ a_{21} & a_{22} & a_{23} \\ a_{31} & a_{32} & a_{33} \end{vmatrix} \quad A_{33}^{T} = \begin{vmatrix} a_{11} & a_{21} & a_{31} \\ a_{12} & a_{22} & a_{32} \\ a_{13} & a_{23} & a_{33} \end{vmatrix}

(2)(2) 行列式中的任意两行/列交换位置,行列式的值不变;例如

a11a12a13a21a22a23a31a32a33=a21a22a23a11a12a13a31a32a33=a22a21a23a12a11a13a32a31a33\begin{vmatrix} a_{11} & a_{12} & a_{13} \\ a_{21} & a_{22} & a_{23} \\ a_{31} & a_{32} & a_{33} \end{vmatrix} = \begin{vmatrix} a_{21} & a_{22} & a_{23} \\ a_{11} & a_{12} & a_{13} \\ a_{31} & a_{32} & a_{33} \end{vmatrix} = \begin{vmatrix} a_{22} & a_{21} & a_{23} \\ a_{12} & a_{11} & a_{13} \\ a_{32} & a_{31} & a_{33} \end{vmatrix}

特别的,当两行/列相同时,该行列式的值为00

(3)(3) 某行/列中所有元素若存在公因子kk,则可以将kk提到行列式外;例如

ka11ka12ka1na21a22a2nan1an2ann=ka11a12a1na21a22a2nan1an2ann\begin{vmatrix} k \cdot a_{11} & k \cdot a_{12} & \cdots & k \cdot a_{1n} \\ a_{21} & a_{22} & \cdots & a_{2n} \\ \vdots & \vdots & & \vdots \\ a_{n1} & a_{n2} & \cdots & a_{nn} \end{vmatrix} = k \cdot \begin{vmatrix} a_{11} & a_{12} & \cdots & a_{1n} \\ a_{21} & a_{22} & \cdots & a_{2n} \\ \vdots & \vdots & & \vdots \\ a_{n1} & a_{n2} & \cdots & a_{nn} \end{vmatrix}

特别的,某行/列的值全为00,该行列式的值为00;某两行/列的元素对应成比例,行列式的值为00

(4)(4) 某行/列的每个元素是两个元素之和,则可以把行列式拆分为两个行列式之和;例如

a1+b1a2+b2a3+b3c1c2c3d1d2d33=a1a2a3c1c2c3d1d2d33+b1b2b3c1c2c3d1d2d33\begin{vmatrix} a_{1} + b_1 & a_{2} + b_2 & a_{3} + b_3 \\ c_{1} & c_{2} & c_{3} \\ d_{1} & d_{2} & d_{33} \end{vmatrix} = \begin{vmatrix} a_{1} & a_{2} & a_{3} \\ c_{1} & c_{2} & c_{3} \\ d_{1} & d_{2} & d_{33} \end{vmatrix} + \begin{vmatrix} b_{1} & b_{2} & b_{3} \\ c_{1} & c_{2} & c_{3} \\ d_{1} & d_{2} & d_{33} \end{vmatrix}

(5)(5) 把某行/列的kk倍加到另一行/列上,行列式的值不变;例如

a1a2a3b1b2b3c1c2c33=a1a2a3b1+ka1b2+ka2b3+ka3c1c2c33\begin{vmatrix} a_{1} & a_{2} & a_{3} \\ b_{1} & b_{2} & b_{3} \\ c_{1} & c_{2} & c_{33} \end{vmatrix} = \begin{vmatrix} a_{1} & a_{2} & a_{3} \\ b_{1} + k \cdot a_{1} & b_{2} + k \cdot a_{2} & b_{3} + k \cdot a_{3} \\ c_{1} & c_{2} & c_{33} \end{vmatrix}

矩阵

矩阵(Matrix):n×mn \times m个数字组成的nnmm列的表格AnmA_{nm}

Anm=[a11a12a1ma21a22a2man1an2anm]A_{nm} = \begin{bmatrix} a_{11} & a_{12} & \cdots & a_{1m} \\ a_{21} & a_{22} & \cdots & a_{2m} \\ \vdots & \vdots & \ddots & \vdots \\ a_{n1} & a_{n2} & \cdots & a_{nm} \end{bmatrix}

其中第ii行第jj列的元素为aija_{ij}1in,1jm1 \leq i \leq n, 1 \leq j \leq m)。特别的当n=mn = m时称矩阵AAnn阶矩阵或nn阶方阵。

矩阵操作和特性:

(1)(1) 零矩阵:所有元素都为00的矩阵

[000000000]\begin{bmatrix} 0 & 0 & \cdots & 0 \\ 0 & 0 & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & 0 \end{bmatrix}

零矩阵记为OO

(2)(2) 若两矩阵AnmA_{nm}BstB_{st}的行和列数量相等,即n=s,m=tn = s, m = t,称。的两矩阵称为同型矩阵。若同型矩阵的所有对应元素也想等,则两矩阵相等。

(3)(3) nn阶方阵AA构成的行列式称为AA的行列式,记作A\begin{vmatrix} A \end{vmatrix}。注意矩阵是一个表格,而行列式经过计算后是一个值。

(4)(4) 矩阵加法:两个同型矩阵可以相加,即Cnm=Anm+BnmC_{nm} = A_{nm} + B_{nm},任意元素满足cij=aij+bijc_{ij} = a_{ij} + b_{ij}1in,1jm1 \leq i \leq n, 1 \leq j \leq m)。矩阵加法满足特性

A+B=B+A(A+B)+C=A+(B+C)A+O=A,AO=A\begin{matrix} A + B = B + A \\ (A + B ) + C = A + (B + C) \\ A + O = A, A - O = A \end{matrix}

(5)(5) 矩阵数乘:矩阵与数可以相乘,即Bnm=kAnmB_{nm} = k \cdot A_{nm},任意元素满足bij=kaijb_{ij} = k \cdot a_{ij}1in,1jm1 \leq i \leq n, 1 \leq j \leq m)。矩阵数乘满足特性

k(mA)=(km)A=m(kA)(k+m)A=kA+mAk(A+B)=kA+kB1A=A0A=O\begin{matrix} k(mA) = (km)A = m(kA) \\ (k + m)A = kA + mA \\ k(A + B) = kA + kB \\ 1 \cdot A = A \\ 0 \cdot A = O \end{matrix}

(6)(6) 矩阵乘法:两个矩阵Anm,BstA_{nm}, B_{st}相乘必须满足条件m=sm = s,即Cnt=Anm×BstC_{nt} = A_{nm} \times B_{st}m=sm = s),任意元素满足cij=k=1maikbkjc_{ij} = \sum_{k=1}^{m} a_{ik} \cdot b_{kj}1in,1jt,1km1 \leq i \leq n, 1 \leq j \leq t, 1 \leq k \leq m)。特别的,若AAnn阶方阵,则kk个矩阵AA相乘的结果记为i=1kA=Ak\prod_{i=1}^{k} A = A^{k},称为AAkk次幂。矩阵乘法满足特性

(AB)C=A(BC)A(B+C)=AB+AC(B+C)A=BA+CA\begin{matrix} (AB)C = A(BC) \\ A(B + C) = AB + AC \\ (B + C)A = BA + CA \end{matrix}

注意一般情况下ABBAAB \neq BA

(7)(7) 矩阵转置:将矩阵AnmA_{nm}的行和列交换得到矩阵AmnTA_{mn}^{T},任意元素满足aijt=ajia_{ij}^{t} = a_{ji}。称AmnTA_{mn}^{T}AA的转置矩阵。矩阵转置满足特性

(A+B)T=AT+BT(kA)T=kAT(AB)T=BTAT(AT)T=A\begin{matrix} (A + B)^{T} = A^{T} + B^{T} \\ (k \cdot A)^{T} = k \cdot A^{T} \\ (A \cdot B)^{T} = B^{T} \cdot A^{T} \\ (A^{T})^{T} = A \end{matrix}

(8)(8) 单位矩阵:nn阶矩阵中,主对角线上的元素都是11,其余元素都是00,称为nn阶单位矩阵,简写作EEEnE_{n}II。即aii=1,aij=0a_{ii} = 1, a_{ij} = 0(其中iji \neq j)。

Ann=[11101201n02112202n0n10n21nn]A_{nn} = \begin{bmatrix} 1_{11} & 0_{12} & \cdots & 0_{1n} \\ 0_{21} & 1_{22} & \cdots & 0_{2n} \\ \vdots & \vdots & \ddots & \vdots \\ 0_{n1} & 0_{n2} & \cdots & 1_{nn} \end{bmatrix}

(9)(9) 数量矩阵:数字kk与单位矩阵EE的积kEk \cdot E称为数量矩阵。

(10)(10) nn阶矩阵的主对角线:即矩阵AnnA_{nn}上的所有元素aiia_{ii}(其中1in1 \leq i \leq n)。所有元素aiia_{ii}连起来称为矩阵的对角线,其中的元素称为对角元素。

(11)(11) 对角矩阵:非对角元素全为00nn阶方阵称为对角矩阵。

(12)(12) 上/下三角矩阵:主对角线以下/上(不包括主对角线)元素都为00nn阶矩阵,即aij=0,i>ja_{ij} = 0, i \gt j(上三角矩阵),aij=0,i<ja_{ij} = 0, i \lt j(下三角矩阵)。

(13)(13) 对称矩阵/反对称矩阵:满足AT=AA^{T} = A(即aijt=ajia_{ij}^{t} = a_{ji})的矩阵为对称矩阵,满足AT=AA^{T} = -A(即aijt=ajia_{ij}^{t} = - a_{ji})的矩阵称为反对称矩阵。

Strassen算法

根据数学定义,计算两个nn阶矩阵相乘,由于cij=k=1naikbkjc_{ij} = \sum_{k=1}^{n} a_{ik} \cdot b_{kj},计算CC中的一个元素的时间复杂度为O(n)O(n)CC中有n2n^2个元素,因此时间复杂度为O(n3)O(n^3)。Strassen算法的时间复杂度比平凡算法更低。

对于nn阶矩阵乘法C=A×BC = A \times B,设nn为偶数,则可以将A,B,CA, B, C三个矩阵划分为44n/2n/2的矩阵,C=A×BC = A \times B转化为

[rstu]=[abcd]×[efgh]\begin{bmatrix} r & s \\ t & u \end{bmatrix} = \begin{bmatrix} a & b \\ c & d \end{bmatrix} \times \begin{bmatrix} e & f \\ g & h \end{bmatrix}

按照矩阵乘法计算方法可知

r=a×e+b×gs=a×f+b×ht=c×e+d×gu=c×f+d×h\begin{matrix} r = a \times e + b \times g \\ s = a \times f + b \times h \\ t = c \times e + d \times g \\ u = c \times f + d \times h \end{matrix}

上面计算中设两个nn阶矩阵相乘的时间复杂度为T(n)T(n),则88次矩阵相乘的时间复杂度为8T(n2)8 \cdot T( \frac{n}{2} )nn阶方阵的加法需要分别计算n2n^2次两个元素之和,因此时间复杂度为O(n2)O(n^2)。由此可知

T(n)=8T(n2)+O(n2)T(n) = 8 \cdot T(\frac{n}{2}) + O(n^2)

通过时间复杂度的推导方法,可以得出T(n)=O(n3)T(n) = O(n^3)。因此分治法的时间复杂度与平凡算法相同。

Strassen算法在分治法基础上设置77个中间矩阵,将上式转化为

p1=a×fa×h=a×(fh)p2=a×h+b×h=(a+b)×hp3=c×e+d×e=(c+d)×ep4=d×gd×e=d×(ge)p5=a×e+a×h+d×e+d×h=(a+d)×(e+h)p6=b×g+b×hd×gd×h=(bd)×(g+h)p7=a×e+a×fc×ec×f=(ac)×(e+f)\begin{matrix} p_1 = a \times f - a \times h = a \times (f - h) \\ p_2 = a \times h + b \times h = (a + b) \times h \\ p_3 = c \times e + d \times e = (c + d) \times e \\ p_4 = d \times g - d \times e = d \times (g - e) \\ p_5 = a \times e + a \times h + d \times e + d \times h = (a + d) \times (e + h) \\ p_6 = b \times g + b \times h - d \times g - d \times h = (b - d) \times (g + h) \\ p_7 = a \times e + a \times f - c \times e - c \times f = (a - c) \times (e + f) \end{matrix}

可得

r=p5+p4p2+p6s=p1+p2t=p3+p4u=p1+p5p3p7\begin{matrix} r = p_5 + p_4 - p_2 + p_6 \\ s = p_1 + p_2 \\ t = p_3 + p_4 \\ u = p_1 + p_5 - p_3 - p_7 \end{matrix}

这样计算矩阵相乘时,只需要计算77次矩阵相乘运算,矩阵间的加减运算的时间复杂度仍然是O(n2)O(n^2)。即有

T(n)=7T(n2)+O(n2)T(n) = 7 \cdot T( \frac{n} {2} ) + O(n^2)

最终推导可得,Strassen算法的时间复杂度为O(nlog27)O(n2.81)O(n^{log_2 7}) \approx O(n^{2.81})

数学复习全书(2013年李永乐李正元考研数学 数学一) - 第二篇 线性代数

源码

Strassen.h

Strassen.cpp

测试

StrassenTest.cpp

Last updated