Strassen Algorithm for Dense

Alibaba’s MNN uses the Strassen algorithm for its Dense op. Could we leverage it? Has anybody benchmarked it?

One issue is that we currently don’t have many popular models implemented that rely heavily on dense ops. Vision models may not be a good reference benchmark in this case.

However, I think we can do something here if Strassen proves useful, because many custom models use Dense heavily.

@FrozenGene, do you have any updates on this topic?

I ran some tests, and the result is that TVM cannot beat MNN.

There could be many reasons why the performance falls short, but I don’t think the Strassen algorithm is the key factor. @jcf94 has done some experiments on this.

@FrozenGene Thank you for your reply. I based my experiments on this tutorial, which seems to have been written by you. I tried [1024, 1024] * [1024, 1024], [2048, 2048] * [2048, 2048], [256, 256] * [256, 256], [128, 128] * [128, 128], etc., and TVM does not beat the Strassen implementation in any of them. I think TVM should be able to outperform it, but I don’t know how to get there. Do you have any ideas? Thanks a lot.

@jcf94 Can you briefly describe the results of your experiments? Thanks a lot

Hi, it’s nice to see Strassen attracting attention again. Which hardware did you use, and how many cores?

Actually, it’s easy to implement Strassen in TVM, and I have tested the algorithm with two different implementations.

TE version:

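import tvm  # uses the legacy top-level tvm.compute API (tvm.te.compute in newer releases)

# Module-level state used by gemm() and strassen() below.
GEMM_COUNT = 0
GEMM_LEVEL = 0
DIRECT_SIZE = 512  # assumed cutoff; 512 is the direct_size reported for the 1024 case

def strassen_gemm(N):
    def gemm(A, B, N, name=""):
        global GEMM_COUNT
        if name != "":
            name += "G%d_" % GEMM_COUNT
            GEMM_COUNT += 1
        # Recurse with Strassen above the cutoff, otherwise do a plain GEMM.
        if (N > DIRECT_SIZE):
            return strassen(A, B, N, name)
        else:
            return direct(A, B, N, name)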

    def direct(A, B, N, name):
        k = tvm.reduce_axis((0, N))
        C = tvm.compute(A.shape, lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=k))
        return C

    def split(A, new_n, ori_name="Matrix"):
        A11 = tvm.compute((new_n, new_n),
            lambda i, j: A[i][j], name=ori_name+"11")
        A12 = tvm.compute((new_n, new_n),
            lambda i, j: A[i][j+new_n], name=ori_name+"12")
        A21 = tvm.compute((new_n, new_n),
            lambda i, j: A[i+new_n][j], name=ori_name+"21")
        A22 = tvm.compute((new_n, new_n),
            lambda i, j: A[i+new_n][j+new_n], name=ori_name+"22")
        return A11, A12, A21, A22

    def sub(A, B, N, name):
        C = tvm.compute((N, N),
            lambda i, j: A[i][j] - B[i][j], name=name)
        return C

    def add(A, B, N, name):
        C = tvm.compute((N, N),
            lambda i, j: A[i][j] + B[i][j], name=name)
        return C

    def strassen(A, B, N, name):
        global GEMM_LEVEL
        new_n = int(N / 2)

        A11, A12, A21, A22 = split(A, new_n, name+"A")
        B11, B12, B21, B22 = split(B, new_n, name+"B")

        S1 = sub(B12, B22, new_n, name+"S1")
        S2 = add(A11, A12, new_n, name+"S2")
        S3 = add(A21, A22, new_n, name+"S3")
        S4 = sub(B21, B11, new_n, name+"S4")
        S5 = add(A11, A22, new_n, name+"S5")
        S6 = add(B11, B22, new_n, name+"S6")
        S7 = sub(A12, A22, new_n, name+"S7")
        S8 = add(B21, B22, new_n, name+"S8")
        S9 = sub(A11, A21, new_n, name+"S9")
        S10 = add(B11, B12, new_n, name+"S10")

        level = GEMM_LEVEL
        GEMM_LEVEL += 1
        P1 = gemm(A11, S1, new_n, name+"L%d_"%level)
        P2 = gemm(S2, B22, new_n, name+"L%d_"%level)
        P3 = gemm(S3, B11, new_n, name+"L%d_"%level)
        P4 = gemm(A22, S4, new_n, name+"L%d_"%level)
        P5 = gemm(S5, S6, new_n, name+"L%d_"%level)
        P6 = gemm(S7, S8, new_n, name+"L%d_"%level)
        P7 = gemm(S9, S10, new_n, name+"L%d_"%level)

        C11 = tvm.compute((new_n, new_n),
                lambda i, j: P5[i][j] + P4[i][j] - P2[i][j] + P6[i][j], name=name+"C11")
        C12 = add(P1, P2, new_n, name+"C12")
        C21 = add(P3, P4, new_n, name+"C21")
        C22 = tvm.compute((new_n, new_n),
                lambda i, j: P5[i][j] + P1[i][j] - P3[i][j] - P7[i][j], name=name+"C22")

        C = tvm.compute((N, N),
                lambda i, j: tvm.if_then_else(i < new_n,
                    tvm.if_then_else(j < new_n, C11[i][j], C12[i][j-new_n]),
                    tvm.if_then_else(j < new_n, C21[i-new_n][j], C22[i-new_n][j-new_n])))
        return C

    A = tvm.placeholder((N, N), name="A")
    B = tvm.placeholder((N, N), name="B")
    C = gemm(A, B, N)
    sch = tvm.create_schedule(C.op)
    return sch, [A, B, C]

Relay version (I also tried merging the 7 sub-matrix GEMMs into a single batch_matmul):

from tvm import relay

def strassen_gemm(N, K, M, max_level=1):
    # A [N, K]
    # B [K, M]
    # C [N, M]
    def gemm(A, B, N, K, M, level):
        # Recurse only while all dimensions can still be halved evenly.
        if (level < max_level and N % 2 == 0 and
                K % 2 == 0 and M % 2 == 0):
            return strassen(A, B, N, K, M, level)
        else:
            return direct(A, B, N, K, M)

    def direct(A, B, N, K, M):
        C = relay.nn.dense(A, relay.transpose(B, [1, 0]))
        return C

    def split(A, new_x, new_y):
        A11 = relay.strided_slice(A, [0, 0], [new_x, new_y])
        A12 = relay.strided_slice(A, [0, new_y], [new_x, new_y*2])
        A21 = relay.strided_slice(A, [new_x, 0], [new_x*2, new_y])
        A22 = relay.strided_slice(A, [new_x, new_y], [new_x*2, new_y*2])
        return A11, A12, A21, A22

    def strassen(A, B, N, K, M, level):
        new_n = int(N / 2)
        new_k = int(K / 2)
        new_m = int(M / 2)

        A11, A12, A21, A22 = split(A, new_n, new_k)
        B11, B12, B21, B22 = split(B, new_k, new_m)

        S1 = B12 - B22
        P1 = gemm(A11, S1, new_n, new_k, new_m, level+1)
        S2 = A11 + A12
        P2 = gemm(S2, B22, new_n, new_k, new_m, level+1)
        C12 = P1 + P2

        S3 = A21 + A22
        P3 = gemm(S3, B11, new_n, new_k, new_m, level+1)
        S4 = B21 - B11
        P4 = gemm(A22, S4, new_n, new_k, new_m, level+1)
        C21 = P3 + P4

        S5 = A11 + A22
        S6 = B11 + B22
        P5 = gemm(S5, S6, new_n, new_k, new_m, level+1)
        S7 = A12 - A22
        S8 = B21 + B22
        P6 = gemm(S7, S8, new_n, new_k, new_m, level+1)
        C11 = P5 + P4 - P2 + P6

        S9 = A11 - A21
        S10 = B11 + B12
        P7 = gemm(S9, S10, new_n, new_k, new_m, level+1)
        C22 = P5 + P1 - P3 - P7

        C1 = relay.concatenate([C11, C12], 1)
        C2 = relay.concatenate([C21, C22], 1)
        C = relay.concatenate([C1, C2], 0)
        return C

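    # Square-only variant that batches the 7 products into one batch_matmul.
    # It is not called by gemm() above and expects split(A, new_n),
    # gemm(A, B, N) and a direct_size threshold from the enclosing scope.
    def strassen_merge(A, B, N):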
        new_n = int(N / 2)

        A11, A12, A21, A22 = split(A, new_n)
        B11, B12, B21, B22 = split(B, new_n)

        S1 = B12 - B22
        S2 = A11 + A12
        S3 = A21 + A22
        S4 = B21 - B11
        S5 = A11 + A22
        S6 = B11 + B22
        S7 = A12 - A22
        S8 = B21 + B22
        S9 = A11 - A21
        S10 = B11 + B12

        if new_n > direct_size:
            P1 = gemm(A11, S1, new_n)
            P2 = gemm(S2, B22, new_n)
            P3 = gemm(S3, B11, new_n)
            P4 = gemm(A22, S4, new_n)
            P5 = gemm(S5, S6, new_n)
            P6 = gemm(S7, S8, new_n)
            P7 = gemm(S9, S10, new_n)
        else:
            # Stack the 7 left and right operands and multiply them in one batch.
            Merge_A = []
            for a in [A11, S2, S3, A22, S5, S7, S9]:
                Merge_A.append(relay.expand_dims(a, 0))
            Merge_A = relay.concatenate(Merge_A, 0)

            Merge_B = []
            for b in [S1, B22, B11, S4, S6, S8, S10]:
                Merge_B.append(relay.expand_dims(b, 0))
            Merge_B = relay.concatenate(Merge_B, 0)

            Merge_C = relay.nn.batch_matmul(Merge_A, relay.transpose(Merge_B, [0, 2, 1]))
            ss = relay.split(Merge_C, 7)
            P1 = relay.reshape(ss[0], [new_n, new_n])
            P2 = relay.reshape(ss[1], [new_n, new_n])
            P3 = relay.reshape(ss[2], [new_n, new_n])
            P4 = relay.reshape(ss[3], [new_n, new_n])
            P5 = relay.reshape(ss[4], [new_n, new_n])
            P6 = relay.reshape(ss[5], [new_n, new_n])
            P7 = relay.reshape(ss[6], [new_n, new_n])

        C11 = P5 + P4 - P2 + P6
        C12 = P1 + P2
        C21 = P3 + P4
        C22 = P5 + P1 - P3 - P7

        C1 = relay.concatenate([C11, C12], 1)
        C2 = relay.concatenate([C21, C22], 1)
        C = relay.concatenate([C1, C2], 0)
        return C

    A = relay.var("A", shape=(N, K))
    B = relay.var("B", shape=(K, M))
    C = gemm(A, B, N, K, M, 0)
    return A, B, C
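
As a sanity check on the S1..S10 / P1..P7 formulas shared by both listings, here is a minimal plain-Python sketch of one Strassen level, compared against a naive product. It is not TVM code; the helper names (`matmul`, `split`, `strassen_one_level`) are made up for illustration, and it assumes square matrices with an even size.

```python
import random

def matmul(A, B):
    """Naive reference product of two square list-of-lists matrices."""
    n = len(A)
    return [[sum(A[i][k] * B[k][j] for k in range(n)) for j in range(n)]
            for i in range(n)]

def add(A, B):
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def sub(A, B):
    return [[a - b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def split(A):
    """Return the quadrants A11, A12, A21, A22 of an even-sized matrix."""
    h = len(A) // 2
    return ([r[:h] for r in A[:h]], [r[h:] for r in A[:h]],
            [r[:h] for r in A[h:]], [r[h:] for r in A[h:]])

def strassen_one_level(A, B):
    """One Strassen step: 7 half-size products instead of 8."""
    A11, A12, A21, A22 = split(A)
    B11, B12, B21, B22 = split(B)
    P1 = matmul(A11, sub(B12, B22))
    P2 = matmul(add(A11, A12), B22)
    P3 = matmul(add(A21, A22), B11)
    P4 = matmul(A22, sub(B21, B11))
    P5 = matmul(add(A11, A22), add(B11, B22))
    P6 = matmul(sub(A12, A22), add(B21, B22))
    P7 = matmul(sub(A11, A21), add(B11, B12))
    C11 = add(sub(add(P5, P4), P2), P6)
    C12 = add(P1, P2)
    C21 = add(P3, P4)
    C22 = sub(sub(add(P5, P1), P3), P7)
    # Stitch the quadrants back together row by row.
    top = [left + right for left, right in zip(C11, C12)]
    bottom = [left + right for left, right in zip(C21, C22)]
    return top + bottom

random.seed(0)
n = 8
A = [[random.randint(-9, 9) for _ in range(n)] for _ in range(n)]
B = [[random.randint(-9, 9) for _ in range(n)] for _ in range(n)]
assert strassen_one_level(A, B) == matmul(A, B)
```

Integer inputs keep the comparison exact; with floating-point data the two results would only agree up to rounding error.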

The measured performance was not great in the end. Only in one case, a 1024 * 1024 * 1024 GEMM on 4 cores with direct_size = 512, did I get better performance with Strassen.

I think there are several reasons for this:

  1. The TE version contains too many stages, which makes it hard to schedule, even with the auto-scheduling tool Ansor.
  2. The Relay version contains some unnatural slice and concat ops, which are not friendly to memory access.
  3. Ops tend to perform better on larger GEMMs. When we split a single GEMM into 7 sub-matrix GEMMs, the smaller GEMMs are likely to reach lower GFLOPS.
  4. MNN manages its memory access and compute threads well. It can even run the 7 sub-matrix GEMMs in parallel, while TVM does not support inter-op parallelism.
  5. As for the Strassen algorithm itself, in my understanding it does save computation in single-threaded runs (in theory it reduces the complexity from O(n^3) to about O(n^2.81)), but in a multi-threaded setting I don’t think it will be as beneficial.
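
To put rough numbers on the last point, the following sketch counts scalar multiplications and the extra element-wise additions for a square GEMM with a recursion cutoff. It is a back-of-the-envelope model only: the function names are illustrative, it assumes power-of-two sizes, and it ignores memory traffic, which the list above suggests is what matters in practice.

```python
import math

def strassen_mults(n, cutoff):
    """Scalar multiplications when recursing until n <= cutoff."""
    if n <= cutoff:
        return n ** 3
    return 7 * strassen_mults(n // 2, cutoff)

def strassen_extra_adds(n, cutoff):
    """Element-wise additions/subtractions added by the S and C matrices."""
    if n <= cutoff:
        return 0
    half = n // 2
    # 10 S-matrices plus 8 add/sub operations to form C11, C12, C21, C22.
    return 18 * half * half + 7 * strassen_extra_adds(half, cutoff)

n, cutoff = 1024, 512                      # the case reported in the post
direct = n ** 3
one_level = strassen_mults(n, cutoff)
print(one_level / direct)                  # 0.875
print(strassen_extra_adds(n, cutoff))      # 4718592
print(math.log2(7))                        # exponent of the fully recursive algorithm
```

With a single level, as in the 1024 case with direct_size = 512, the saving is only one eighth of the multiplications, which leaves little room to absorb the cost of the extra stages and copies.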

So my conclusion is:

  1. Strassen should be more useful with few CPU cores, e.g. on an ARM CPU with only 4 or 8 cores, which is exactly MNN’s target hardware. On an Intel CPU with more cores, I don’t think we can benefit from Strassen.
  2. MNN does have better memory/thread management since it is written directly in C++. TVM does not seem able to do the same thing with codegen.

@jcf94 has explained the Strassen algorithm very well. I wrote the tutorial you linked. However, note that the tutorial is not meant to show the best performance TVM can achieve, only how easily TVM can reach reasonable performance (better than NumPy).

If we want to improve performance further, there is still room to dig, for example by adding an auto_unroll configuration, more split levels and so on. However, I think this should be handled by our AutoTVM v2.0 (Auto Scheduler). You could try our auto scheduler. Simple matmul using topi should be fully upstreamed, right? cc @jcf94

Thank you very much for your reply.

The hardware I use is an AArch64 CPU with 8 cores. I followed this tutorial to deploy TVM: the C++ thread that loads and uses the TVM library is bound to the 3 medium-frequency CPUs, and TVM_NUM_THREADS is set to 1. (One thing confuses me: the larger TVM_NUM_THREADS is, the worse the performance, so I set it to 1. I have not figured out why the optimal TVM_NUM_THREADS is not 3 or 8.)

Given your conclusion, I think it will not be easy for TVM to beat MNN on my current hardware (ARM CPU, 8 cores, and I don’t want to occupy all of them), which is a little frustrating. But I will keep trying, for example with Ansor. If there is any progress, I will be happy to discuss it further with you.

Thank you very much for your reply.

Ansor looks great and I am very interested; I will try it. If there is any progress, I will be happy to discuss it further with you.

I don’t think you should set TVM_NUM_THREADS on ARM, because of ARM’s big.LITTLE architecture. I think you should call runtime.config_thread_pool to do the core binding. Also, we shouldn’t let TVM worker threads run on CPUs with different frequencies (i.e. one worker thread on a big core and another on a LITTLE core), as this hurts performance.

Thank you very much for your reply.

As I said before, I followed this tutorial to deploy TVM: I first export the function as a library, then load and call it from C++.

Following your suggestion, I set the CPU affinity like this before calling tvm::runtime::Module::LoadFromFile:

tvm::runtime::threading::ThreadGroup::AffinityMode mode =
    static_cast<tvm::runtime::threading::ThreadGroup::AffinityMode>(static_cast<int>(-1));
tvm::runtime::ThreadPool::ThreadLocal()->UpdateWorkerConfiguration(mode, 4);

The frequency of each of my CPU cores is shown below:

index: 7  freqs: 3130000
index: 4  freqs: 2544000
index: 5  freqs: 2544000
index: 6  freqs: 2544000
index: 0  freqs: 2045000
index: 1  freqs: 2045000
index: 2  freqs: 2045000
index: 3  freqs: 2045000

Then I unset TVM_NUM_THREADS and ran the test many times. Compared with before (TVM_NUM_THREADS=1), the performance is indeed better. However, the run time fluctuates considerably: for 256 * 256 * 256, the minimum is 1745 us and the maximum is 10971 us.

In your case, the current code will use 4 cores (id 0 ~ 3), so the parallelism gives you better performance.
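
To make the core selection concrete, here is a small sketch that groups the cores listed above by frequency, so that worker threads can be kept inside one homogeneous cluster. It only processes the numbers from the post; `clusters_by_freq` is a made-up helper, not a TVM API, and how TVM itself orders cores is not shown here.

```python
from collections import defaultdict

# (core index, frequency) pairs exactly as reported in the post above.
cores = [(7, 3130000), (4, 2544000), (5, 2544000), (6, 2544000),
         (0, 2045000), (1, 2045000), (2, 2045000), (3, 2045000)]

def clusters_by_freq(cores):
    """Group core indices by frequency, fastest cluster first."""
    groups = defaultdict(list)
    for idx, freq in cores:
        groups[freq].append(idx)
    return [(freq, sorted(ids))
            for freq, ids in sorted(groups.items(), reverse=True)]

clusters = clusters_by_freq(cores)
slowest_freq, little_cores = clusters[-1]
print(clusters)
print(little_cores)  # [0, 1, 2, 3]
```

On this device the grouping gives three clusters (one core, three cores, four cores); the slowest cluster is the set of cores 0 to 3 mentioned above.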

As for the run-time fluctuation: do you use AutoTVM? If so, the CPU that TVM uses by default is the big core (index 7). If you decide to use the 4 little cores, you should make AutoTVM use those 4 little cores too. An elegant solution would be a thread_mod option that users can set (see link: autotvm.RPCRunner and TVM_NUM_THREADS). The current workaround is to temporarily disable cores 4, 5, 6 and 7 on the device. (We do need to provide an interface that lets users control big/little cores during tuning.)

Thank you for your reply.

Regarding the run-time fluctuation, I did not make myself clear. After AutoTVM tuning completed, I picked the best record for timing tests, and its run time fluctuates significantly. I take the difference between the start and end timestamps as the run time:

#include <sys/time.h>  // gettimeofday

struct timeval curTime1;
gettimeofday(&curTime1, NULL);
// Timestamps are in microseconds: tv_sec * 1000000 + tv_usec.
size_t usec_start = curTime1.tv_sec*1000000 + curTime1.tv_usec;

tvm::runtime::TVMRetValue ret = f(x, y, z);

struct timeval curTime2;
gettimeofday(&curTime2, NULL);
size_t usec_end = curTime2.tv_sec*1000000 + curTime2.tv_usec;

size_t run_time = usec_end - usec_start;  // elapsed time in microseconds

However, the run time of the Strassen implementation does not fluctuate significantly. So I am curious whether the fluctuation is related to TVM or simply caused by CPU load changes (after all, the CPU is not dedicated).

For a more robust measurement, you should run it many times and take the average, for example 1000 runs.
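
A minimal sketch of such a repeated measurement, written in Python for brevity: it warms up, times many calls, and reports min, median, mean and max so that outliers like the 10971 us run are visible next to the typical case. The `workload` function is a hypothetical stand-in for the compiled call `f(x, y, z)`, and the warm-up and run counts are arbitrary choices.

```python
import statistics
import time

def benchmark(fn, warmup=10, runs=1000):
    """Call fn repeatedly and return per-call timings in microseconds."""
    for _ in range(warmup):
        fn()  # warm caches and thread pools before measuring
    samples = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1e6)
    return samples

def summarize(samples):
    """Summary statistics; the median is less sensitive to load spikes."""
    return {
        "min": min(samples),
        "median": statistics.median(samples),
        "mean": statistics.mean(samples),
        "max": max(samples),
    }

def workload():
    # Hypothetical stand-in for the compiled function call f(x, y, z).
    sum(i * i for i in range(1000))

stats = summarize(benchmark(workload, warmup=5, runs=200))
print(stats)
```

A large gap between the median and the maximum points to interference from other processes or core migration rather than to the kernel itself.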