Hi, it’s nice to see Strassen attracting attention again. May I ask which hardware you used, and with how many cores?
Actually, it’s easy to implement Strassen in TVM, and I have tested the algorithm with two different implementations.
TE version:
```python
import tvm

# Global counters used to generate unique stage names; the recursion
# falls back to a direct GEMM once the size drops to DIRECT_SIZE.
GEMM_COUNT = 0
GEMM_LEVEL = 0
DIRECT_SIZE = 512

def strassen_gemm(N):
    def gemm(A, B, N, name=""):
        global GEMM_COUNT
        if name != "":
            name += "G%d_" % GEMM_COUNT
            GEMM_COUNT += 1
        if N > DIRECT_SIZE:
            return strassen(A, B, N, name)
        else:
            return direct(A, B, N, name)

    def direct(A, B, N, name):
        k = tvm.reduce_axis((0, N), name=name + "k")
        C = tvm.compute(A.shape,
                        lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=k),
                        name=name + "C")
        return C

    def split(A, new_n, ori_name="Matrix"):
        # Extract the four quadrants of A as separate compute stages.
        A11 = tvm.compute((new_n, new_n),
                          lambda i, j: A[i][j], name=ori_name + "11")
        A12 = tvm.compute((new_n, new_n),
                          lambda i, j: A[i][j + new_n], name=ori_name + "12")
        A21 = tvm.compute((new_n, new_n),
                          lambda i, j: A[i + new_n][j], name=ori_name + "21")
        A22 = tvm.compute((new_n, new_n),
                          lambda i, j: A[i + new_n][j + new_n], name=ori_name + "22")
        return A11, A12, A21, A22

    def sub(A, B, N, name):
        C = tvm.compute((N, N),
                        lambda i, j: A[i][j] - B[i][j], name=name)
        return C

    def add(A, B, N, name):
        C = tvm.compute((N, N),
                        lambda i, j: A[i][j] + B[i][j], name=name)
        return C

    def strassen(A, B, N, name):
        global GEMM_LEVEL
        new_n = N // 2
        A11, A12, A21, A22 = split(A, new_n, name + "A")
        B11, B12, B21, B22 = split(B, new_n, name + "B")
        S1 = sub(B12, B22, new_n, name + "S1")
        S2 = add(A11, A12, new_n, name + "S2")
        S3 = add(A21, A22, new_n, name + "S3")
        S4 = sub(B21, B11, new_n, name + "S4")
        S5 = add(A11, A22, new_n, name + "S5")
        S6 = add(B11, B22, new_n, name + "S6")
        S7 = sub(A12, A22, new_n, name + "S7")
        S8 = add(B21, B22, new_n, name + "S8")
        S9 = sub(A11, A21, new_n, name + "S9")
        S10 = add(B11, B12, new_n, name + "S10")
        level = GEMM_LEVEL
        GEMM_LEVEL += 1
        # The 7 Strassen products, each possibly recursing further.
        P1 = gemm(A11, S1, new_n, name + "L%d_" % level)
        P2 = gemm(S2, B22, new_n, name + "L%d_" % level)
        P3 = gemm(S3, B11, new_n, name + "L%d_" % level)
        P4 = gemm(A22, S4, new_n, name + "L%d_" % level)
        P5 = gemm(S5, S6, new_n, name + "L%d_" % level)
        P6 = gemm(S7, S8, new_n, name + "L%d_" % level)
        P7 = gemm(S9, S10, new_n, name + "L%d_" % level)
        C11 = tvm.compute((new_n, new_n),
                          lambda i, j: P5[i][j] + P4[i][j] - P2[i][j] + P6[i][j],
                          name=name + "C11")
        C12 = add(P1, P2, new_n, name + "C12")
        C21 = add(P3, P4, new_n, name + "C21")
        C22 = tvm.compute((new_n, new_n),
                          lambda i, j: P5[i][j] + P1[i][j] - P3[i][j] - P7[i][j],
                          name=name + "C22")
        # Stitch the four quadrants back into one N x N result.
        C = tvm.compute((N, N),
                        lambda i, j: tvm.if_then_else(
                            i < new_n,
                            tvm.if_then_else(j < new_n, C11[i][j], C12[i][j - new_n]),
                            tvm.if_then_else(j < new_n, C21[i - new_n][j],
                                             C22[i - new_n][j - new_n])),
                        name=name + "C")
        return C

    A = tvm.placeholder((N, N), name="A")
    B = tvm.placeholder((N, N), name="B")
    C = gemm(A, B, N)
    sch = tvm.create_schedule(C.op)
    return sch, [A, B, C]
```
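For completeness, this is roughly how I build and sanity-check the TE version (an untested sketch against the pre-0.7 `tvm.*` API used above; the loose `rtol` is because Strassen accumulates more floating-point error than a direct GEMM):

```python
import numpy as np

N = 1024  # with DIRECT_SIZE = 512 this recurses exactly one level
sch, args = strassen_gemm(N)
func = tvm.build(sch, args, target="llvm")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.rand(N, N).astype("float32"), ctx)
b = tvm.nd.array(np.random.rand(N, N).astype("float32"), ctx)
c = tvm.nd.array(np.zeros((N, N), dtype="float32"), ctx)
func(a, b, c)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() @ b.asnumpy(), rtol=1e-3)
```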
Relay version (I also tried a variant that merges the 7 sub-matrix GEMMs into a single batch_matmul):
```python
import tvm
from tvm import relay

# Threshold below which the 7 sub-GEMMs are merged into one batch_matmul
# (used by the strassen_merge variant below).
direct_size = 512

def strassen_gemm(N, K, M, max_level=1):
    # A: [N, K], B: [K, M], C: [N, M]
    def gemm(A, B, N, K, M, level):
        # Recurse only within the level budget and while all dims are even.
        if (level < max_level and N % 2 == 0 and
                K % 2 == 0 and M % 2 == 0):
            return strassen(A, B, N, K, M, level)
        else:
            return direct(A, B, N, K, M)

    def direct(A, B, N, K, M):
        # relay.nn.dense computes A * W^T, so pass B transposed.
        C = relay.nn.dense(A, relay.transpose(B, [1, 0]))
        return C

    def split(A, new_x, new_y):
        A11 = relay.strided_slice(A, [0, 0], [new_x, new_y])
        A12 = relay.strided_slice(A, [0, new_y], [new_x, new_y * 2])
        A21 = relay.strided_slice(A, [new_x, 0], [new_x * 2, new_y])
        A22 = relay.strided_slice(A, [new_x, new_y], [new_x * 2, new_y * 2])
        return A11, A12, A21, A22

    def strassen(A, B, N, K, M, level):
        new_n = N // 2
        new_k = K // 2
        new_m = M // 2
        A11, A12, A21, A22 = split(A, new_n, new_k)
        B11, B12, B21, B22 = split(B, new_k, new_m)
        S1 = B12 - B22
        P1 = gemm(A11, S1, new_n, new_k, new_m, level + 1)
        S2 = A11 + A12
        P2 = gemm(S2, B22, new_n, new_k, new_m, level + 1)
        C12 = P1 + P2
        S3 = A21 + A22
        P3 = gemm(S3, B11, new_n, new_k, new_m, level + 1)
        S4 = B21 - B11
        P4 = gemm(A22, S4, new_n, new_k, new_m, level + 1)
        C21 = P3 + P4
        S5 = A11 + A22
        S6 = B11 + B22
        P5 = gemm(S5, S6, new_n, new_k, new_m, level + 1)
        S7 = A12 - A22
        S8 = B21 + B22
        P6 = gemm(S7, S8, new_n, new_k, new_m, level + 1)
        C11 = P5 + P4 - P2 + P6
        S9 = A11 - A21
        S10 = B11 + B12
        P7 = gemm(S9, S10, new_n, new_k, new_m, level + 1)
        C22 = P5 + P1 - P3 - P7
        C1 = relay.concatenate([C11, C12], 1)
        C2 = relay.concatenate([C21, C22], 1)
        C = relay.concatenate([C1, C2], 0)
        return C

    # Variant for square matrices: once the blocks are small enough, the 7
    # products are merged into a single batch_matmul instead of 7 GEMMs.
    def strassen_merge(A, B, N, level):
        new_n = N // 2
        A11, A12, A21, A22 = split(A, new_n, new_n)
        B11, B12, B21, B22 = split(B, new_n, new_n)
        S1 = B12 - B22
        S2 = A11 + A12
        S3 = A21 + A22
        S4 = B21 - B11
        S5 = A11 + A22
        S6 = B11 + B22
        S7 = A12 - A22
        S8 = B21 + B22
        S9 = A11 - A21
        S10 = B11 + B12
        if new_n > direct_size:
            P1 = gemm(A11, S1, new_n, new_n, new_n, level + 1)
            P2 = gemm(S2, B22, new_n, new_n, new_n, level + 1)
            P3 = gemm(S3, B11, new_n, new_n, new_n, level + 1)
            P4 = gemm(A22, S4, new_n, new_n, new_n, level + 1)
            P5 = gemm(S5, S6, new_n, new_n, new_n, level + 1)
            P6 = gemm(S7, S8, new_n, new_n, new_n, level + 1)
            P7 = gemm(S9, S10, new_n, new_n, new_n, level + 1)
        else:
            # Stack the 7 left/right operands and do one batched matmul.
            Merge_A = []
            for a in [A11, S2, S3, A22, S5, S7, S9]:
                Merge_A.append(relay.expand_dims(a, 0))
            Merge_A = relay.concatenate(Merge_A, 0)
            Merge_B = []
            for b in [S1, B22, B11, S4, S6, S8, S10]:
                Merge_B.append(relay.expand_dims(b, 0))
            Merge_B = relay.concatenate(Merge_B, 0)
            # batch_matmul computes A * B^T, so B is pre-transposed.
            Merge_C = relay.nn.batch_matmul(Merge_A,
                                            relay.transpose(Merge_B, [0, 2, 1]))
            ss = relay.split(Merge_C, 7)
            P1 = relay.reshape(ss[0], [new_n, new_n])
            P2 = relay.reshape(ss[1], [new_n, new_n])
            P3 = relay.reshape(ss[2], [new_n, new_n])
            P4 = relay.reshape(ss[3], [new_n, new_n])
            P5 = relay.reshape(ss[4], [new_n, new_n])
            P6 = relay.reshape(ss[5], [new_n, new_n])
            P7 = relay.reshape(ss[6], [new_n, new_n])
        C11 = P5 + P4 - P2 + P6
        C12 = P1 + P2
        C21 = P3 + P4
        C22 = P5 + P1 - P3 - P7
        C1 = relay.concatenate([C11, C12], 1)
        C2 = relay.concatenate([C21, C22], 1)
        C = relay.concatenate([C1, C2], 0)
        return C

    A = relay.var("A", shape=(N, K))
    B = relay.var("B", shape=(K, M))
    C = gemm(A, B, N, K, M, 0)
    return A, B, C
```
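And this is roughly how I run the Relay version (again an untested sketch; I use the ~0.7-era `graph_runtime` API, which newer TVM releases rename to `graph_executor`):

```python
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_runtime

N = K = M = 1024
A, B, C = strassen_gemm(N, K, M, max_level=1)
mod = tvm.IRModule.from_expr(relay.Function([A, B], C))

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm")

m = graph_runtime.GraphModule(lib["default"](tvm.cpu(0)))
a = np.random.rand(N, K).astype("float32")
b = np.random.rand(K, M).astype("float32")
m.set_input("A", a)
m.set_input("B", b)
m.run()
np.testing.assert_allclose(m.get_output(0).asnumpy(), a @ b, rtol=1e-3)
```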
In the end, the measured performance is not good. Only in one case, 1024×1024×1024 on 4 cores with direct_size = 512, does Strassen give me better performance.
I think there are several reasons for this:
- The TE version contains too many stages, which makes it hard to schedule, even with the auto-scheduling tool Ansor.
- The Relay version contains some unnatural slice and concatenate ops, which are not friendly to memory access.
- GEMM tends to perform better at larger sizes. When we split a single GEMM into 7 sub-matrix GEMMs, those smaller GEMMs are likely to reach lower GFLOPS.
- MNN manages its memory access and compute threads well. It can even run the 7 sub-matrix GEMMs in parallel, while TVM does not support inter-op parallelism.
- As for the Strassen algorithm itself, my understanding is that it does save computation when running single-threaded (in theory from O(n^3) down to O(n^(log2 7)) ≈ O(n^2.81)), but in a multi-threaded setting I think it is much less beneficial; see the quick count after this list.
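To make the last point concrete, here is a quick back-of-the-envelope count of scalar multiplications (additions ignored), matching the cutoff scheme used above:

```python
def mults(n, cutoff):
    """Scalar multiplications for Strassen with a direct-GEMM cutoff."""
    if n <= cutoff:
        return n ** 3                 # direct GEMM costs n^3 multiplications
    return 7 * mults(n // 2, cutoff)  # Strassen: 7 half-size products

n = 1024
print(mults(n, n))    # 1073741824: pure direct GEMM, n^3
print(mults(n, 512))  # 939524096: one Strassen level saves only 1/8
print(mults(n, 1))    # 282475249: full recursion, n^(log2 7) ~= n^2.81
```

A single level only trims 12.5% of the multiplications, while adding O(n^2) additions and worse locality, which is consistent with the small wins I measured.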
So my conclusions are:
- Strassen should be more powerful on CPUs with few cores, e.g. an ARM CPU with only 4 or 8 cores, which is exactly the target hardware of MNN. On an Intel CPU with more cores, I don’t think we can benefit from Strassen.
- MNN does have better memory/thread management, since it is written directly in C. TVM does not seem able to achieve the same through codegen.