[AutoTVM] How to verify the correctness of different schedules and tile sizes in autotvm?

In my understanding, autotvm implements operators based on an operation and a schedule definition. For example, there are pack and nopack implementations in x86/dense.py. During autotuning, the implementation-defined configs are fed into the actual measurement, but I could not find any correctness check similar to the tvm.testing.assert_allclose() interface anywhere in the tuning process. So I extracted the pack and nopack implementations from x86/dense.py, checked the results with tvm.testing.assert_allclose(), and found:

  1. If the split axis is defined inside the operation itself, the tile size must divide that axis evenly, otherwise the result is not guaranteed to be correct (see the small numpy sketch after this list);
  2. Axes that are not pre-defined in the operation do not have to be evenly divisible;
  3. Next, I compared the lowered code of the two cases and found that the code generated with the non-divisible tile size from point 1 is wrong; for K = 512 and Tk = 10 the k loop only runs 51 iterations, so only 510 of the 512 reduction elements are accumulated;
  4. So I think this should be a code generation bug.
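
To make point 1 concrete, here is a minimal numpy-only sketch (my own illustration, not TVM code) of what the nopack definition below effectively reduces over when the reduce axis is created as tvm.reduce_axis((0, K // vec)): with K = 512 and Tk = 10 only 51 * 10 = 510 of the 512 elements enter the sum, and the result comes out slightly smaller than numpy.dot, just like x vs. y in the assertion error below.

import numpy as np

K, Tk = 512, 10                       # the non-divisible case reported below
a = np.random.rand(K).astype("float32")
b = np.random.rand(K).astype("float32")

full = np.dot(a, b)                   # reduction over all K elements
cut = (K // Tk) * Tk                  # 510: what a reduce axis of extent K // Tk covers
truncated = np.dot(a[:cut], b[:cut])

print(full, truncated)                # truncated is consistently a bit smaller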

For example, testing x86/dense.py (the nopack implementation) with M N K = 1 1000 512 and Tm Tn Tk = 1 100 10 gives:

produce compute {
  parallel (y.outer.x.outer.fused, 0, 10) {
    produce compute {
      for (z.y.fused.init, 0, 100) {
        compute[ramp(((y.outer.x.outer.fused*1000) + (z.y.fused.init*10)), 1, 10)] = x10(0f)
      }
      for (k, 0, 51) {
        for (z.y.fused, 0, 100) {
          compute[ramp(((y.outer.x.outer.fused*1000) + (z.y.fused*10)), 1, 10)] = (compute[ramp(((y.outer.x.outer.fused*1000) + (z.y.fused*10)), 1, 10)] + (data[ramp((k*10), 1, 10)]*weight[ramp((((y.outer.x.outer.fused*51200) + (z.y.fused*512)) + (k*10)), 1, 10)]))
        }
      }
    }
    for (x.inner, 0, 100) {
      compute[((y.outer.x.outer.fused*100) + x.inner)] = 0f
      for (kk, 0, 10) {
        compute[((y.outer.x.outer.fused*100) + x.inner)] = (compute[((y.outer.x.outer.fused*100) + x.inner)] + compute[(((y.outer.x.outer.fused*1000) + (x.inner*10)) + kk)])
      }
    }
  }
}

Traceback (most recent call last):

  File "OpsGemm/gemm_v3_scheduling.py", line 388, in <module>
    buildandevaluation(s, data, weight, out, a, bt, ct, ctx, ct_np)

  File "OpsGemm/gemm_v3_scheduling.py", line 44, in buildandevaluation
    tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)

  File "/root/tvm/python/tvm/testing.py", line 29, in assert_allclose
    np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol, verbose=True)

  File "/anaconda3/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 1452, in assert_allclose
    verbose=verbose, header=header, equal_nan=equal_nan)

  File "/anaconda3/lib/python3.7/site-packages/numpy/testing/_private/utils.py", line 789, in assert_array_compare
    raise AssertionError(msg)

AssertionError:
Not equal to tolerance rtol=1e-05, atol=1e-07

(mismatch 100.0%)
 x: array([128.66817 , 119.130806, 129.70555 , 126.0419  , 126.232285,
       129.39488 , 128.2362  , 124.842926, 128.9357  , 126.89033 ,
       132.58101 , 128.5313  , 129.82468 , 129.89973 , 125.16623 ,...
 y: array([128.91081 , 119.22253 , 130.08371 , 126.27693 , 126.305466,
       129.55656 , 128.48164 , 124.97862 , 129.08452 , 127.26959 ,
       132.87903 , 128.80084 , 129.95262 , 130.14275 , 125.569626,...

With M N K = 1 1000 512 and Tm Tn Tk = 1 100 16:

produce compute {
  parallel (y.outer.x.outer.fused, 0, 10) {
    produce compute {
      for (z.y.fused.init, 0, 100) {
        compute[ramp(((y.outer.x.outer.fused*1600) + (z.y.fused.init*16)), 1, 16)] = x16(0f)
      }
      for (k, 0, 32) {
        for (z.y.fused, 0, 100) {
          compute[ramp(((y.outer.x.outer.fused*1600) + (z.y.fused*16)), 1, 16)] = (compute[ramp(((y.outer.x.outer.fused*1600) + (z.y.fused*16)), 1, 16)] + (data[ramp((k*16), 1, 16)]*weight[ramp((((y.outer.x.outer.fused*51200) + (z.y.fused*512)) + (k*16)), 1, 16)]))
        }
      }
    }
    for (x.inner, 0, 100) {
      compute[((y.outer.x.outer.fused*100) + x.inner)] = 0f
      for (kk, 0, 16) {
        compute[((y.outer.x.outer.fused*100) + x.inner)] = (compute[((y.outer.x.outer.fused*100) + x.inner)] + compute[(((y.outer.x.outer.fused*1600) + (x.inner*16)) + kk)])
      }
    }
  }
}

time: 0.000062

With M N K = 1 1000 512 and Tm Tn Tk = 1 23 16:

produce compute {
  parallel (y.outer.x.outer.fused, 0, 44) {
    produce compute {
      for (z.y.fused.init, 0, 23) {
        compute[ramp(((y.outer.x.outer.fused*368) + (z.y.fused.init*16)), 1, 16)] = x16(0f)
      }
      for (k, 0, 32) {
        for (z.y.fused, 0, 23) {
          if (likely((((y.outer.x.outer.fused*23) + z.y.fused) < 1000))) {
            compute[ramp(((y.outer.x.outer.fused*368) + (z.y.fused*16)), 1, 16)] = (compute[ramp(((y.outer.x.outer.fused*368) + (z.y.fused*16)), 1, 16)] + (data[ramp((k*16), 1, 16)]*weight[ramp((((y.outer.x.outer.fused*11776) + (z.y.fused*512)) + (k*16)), 1, 16)]))
          }
        }
      }
    }
    for (x.inner, 0, 23) {
      if (likely((((y.outer.x.outer.fused*23) + x.inner) < 1000))) {
        compute[((y.outer.x.outer.fused*23) + x.inner)] = 0f
      }
      for (kk, 0, 16) {
        if (likely((((y.outer.x.outer.fused*23) + x.inner) < 1000))) {
          if (likely((((y.outer.x.outer.fused*23) + x.inner) < 1000))) {
            compute[((y.outer.x.outer.fused*23) + x.inner)] = (compute[((y.outer.x.outer.fused*23) + x.inner)] + compute[(((y.outer.x.outer.fused*368) + (x.inner*16)) + kk)])
          }
        }
      }
    }
  }
}

time: 0.000958

(PS: to make the lowered code easier to read, I disabled all unroll operations.)

Is there something I have missed? And how does autotvm ensure the correctness of the results during the autotuning process?

The verification code:

import logging
import numpy as np
import tvm
import random
import sys
import math
import timeit
from tvm import relay
from tvm import autotvm

def numpyBaseline(M,K,N):
    np_repeat = 100
    np_running_time = timeit.timeit(setup='import numpy\n'
                                         'M = ' + str(M) + '\n'
                                        'K = ' + str(K) + '\n'
                                        'N = ' + str(N) + '\n'
                                        'dtype = "float32"\n'
                                        'a = numpy.random.rand(M, K).astype(dtype)\n'
                                       'b = numpy.random.rand(K, N).astype(dtype)\n',
                                   stmt='answer = numpy.dot(a, b)',
                                   number=np_repeat)
    print("Numpy running time: %f" % (np_runing_time / np_repeat))

def buildandevaluation(s,A,B,C,a,b,c,ctx,c_np):
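    # build the scheduled op, run it on real inputs, check against the numpy reference, then time it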
    with relay.build_config(opt_level=3):
        func = tvm.build(s, [A, B, C], target=target, name='gemm')
    assert func
    func(a, b, c)
    # print(func)
    # #print(func.get_source())
    #
    # print(func.get_function('gemm'))
    # print(func.get_source())
    # with open("gemm.ll", "w", encoding='utf-8') as f:
    #     f.write(str(func.get_source()))
    #     f.close()
    # from tvm.contrib import util
    # temp = util.tempdir()
    # path_dso = temp.relpath("temp.so")
    # path = temp.relpath('lib.tar')
    # func.export_library(path_dso)
    # m = tvm.module.load(path_dso)
    # print(m.get_source())
    tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
    evaluator = func.time_evaluator(func.entry_name, ctx, number=100)
    print('time: %f' % evaluator(a, b, c).mean)
    #print(tvm.lower(s, [A, B, C], simple_mode=True))



###########################################################################################################







def schedule_defination_gemm_dense_default_nopack(M, K, N, dtype, kts):
    '''e2e dense nopack(default) schedule'''
    data = tvm.placeholder((M, K), name='data', dtype=dtype)
    weight = tvm.placeholder((N, K), name='weight', dtype=dtype)
    # create tuning space
    cfg = autotvm.get_config()
    cfg.define_split("tile_y",M,num_outputs=2,policy="oracle")
    cfg.define_split("tile_x",N,num_outputs=2,policy="oracle")
    cfg.define_split("tile_k",K,num_outputs=2,policy="oracle")
    #vec = cfg["tile_k"].size[-1]
    vec = kts
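    # NOTE: with extent K // vec, the last K % vec elements of the reduction are dropped when K is not divisible by vec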
    k = tvm.reduce_axis((0, K // vec), "k")
    #k = tvm.reduce_axis((0, math.ceil(K / vec)), "k")

    CC = tvm.compute((M, N, vec),lambda z, y, x:
        tvm.sum(data[z, k * vec + x].astype(dtype)
                *weight[y, k * vec + x].astype(dtype), axis=k))

    kk = tvm.reduce_axis((0, vec), "kk")

    C = tvm.compute((M, N),lambda y, x: tvm.sum(CC[y, x, kk], axis=kk))
    s = tvm.create_schedule(C.op)
    return s, [data, weight, C]



def schedule_defination_gemm_dense_pack_default(M, K, N, dtype, bn):
    '''e2e dense pack schedule'''
    data = tvm.placeholder((M, K), name='data', dtype=dtype)
    weight = tvm.placeholder((N, K), name='weight', dtype=dtype)
    # create tuning space
    cfg = autotvm.get_config()
    cfg.define_split("tile_y", M, num_outputs=3, policy="verbose")
    cfg.define_split("tile_x", N, num_outputs=3, policy="verbose")
    cfg.define_split("tile_k", K, num_outputs=2, policy="verbose")
    # packw_bn = cfg["tile_x"].size[-1]
    packw_bn = bn
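    # NOTE: N // packw_bn holds only a whole number of packs; the tail columns are not covered when N is not divisible by packw_bn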
    packw_shape = (N // packw_bn, K, packw_bn)
    packw = tvm.compute(packw_shape,
                        lambda z, y, x: weight[z * packw_bn + x, y], name="packed_weight")

    k = tvm.reduce_axis((0, K), name="k")
    C = tvm.compute((M, N),
                    lambda y, x: tvm.sum(
                        data[y, k].astype(dtype) *
                        packw[tvm.indexdiv(x, packw_bn), k, tvm.indexmod(x, packw_bn)].astype(dtype),
                        axis=k))

    s = tvm.create_schedule(C.op)
    return s, [data, weight, C]



###########################################################################################################


def schedule_optimization_dense_default_nopack(s, C, mts, nts):
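    '''nopack schedule: tile and parallelize C, compute CC at the fused outer axis, vectorize the inner vector axis'''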
    kk, = s[C].op.reduce_axis
    # yo, yi = cfg["tile_y"].apply(s, C, y)
    # xo, xi = cfg["tile_x"].apply(s, C, x)
    yo, xo, yi, xi = s[C].tile(C.op.axis[0], C.op.axis[1], mts, nts)
    s[C].reorder(yo, xo, yi, xi)
    xyo = s[C].fuse(yo, xo)
    s[C].parallel(xyo)
    #s[C].unroll(kk)
    CC, = s[C].op.input_tensors
    s[CC].compute_at(s[C], xyo)
    z, y, x = s[CC].op.axis
    k, = s[CC].op.reduce_axis
    yz = s[CC].fuse(z, y)
    s[CC].reorder(k, yz, x)
    #s[CC].unroll(yz)
    s[CC].vectorize(x)
    data, weight, = s[CC].op.input_tensors
    print(tvm.lower(s, [data, weight,CC, C], simple_mode=True))

def schedule_optimization_dense_pack_default(s,C,mts,kts,nts):
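    '''pack schedule: parallelize/vectorize the weight packing, tile C twice, cache-write CC, split the reduction axis, vectorize the inner axes'''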
    A, packedB = s[C].op.input_tensors

    z, y, x = s[packedB].op.axis
    s[packedB].reorder(z, x, y)
    s[packedB].parallel(z)
    s[packedB].vectorize(x)

    CC = s.cache_write(C, "global")
    k, = s[CC].op.reduce_axis
    # yo, yi = cfg["tile_y"].apply(s, C, y)
    # xo, xi = cfg["tile_x"].apply(s, C, x)
    yto, yi = s[C].split(C.op.axis[0], factor=mts)
    xto, xi = s[C].split(C.op.axis[1], factor=nts)
    yt, yo = s[C].split(yto, factor=4)
    xt, xo = s[C].split(xto, factor=2)
    #yo, xo, yi, xi = s[C].tile(C.op.axis[0], C.op.axis[1], mts, nts)
    s[C].reorder(yt, xt, yo, xo, yi, xi)
    yxt = s[C].fuse(yt, xt)
    s[C].parallel(yxt)
    xyo = s[C].fuse(yo, xo)
    #s[C].unroll(yi)
    s[C].vectorize(xi)
    s[CC].compute_at(s[C], xyo)
    y, x = s[CC].op.axis
    # ko, ki = cfg["tile_k"].apply(s, CC, k)
    ko, ki = s[CC].split(k, factor=kts)
    s[CC].reorder(ko, ki, y, x)
    s[CC].vectorize(x)
    #s[CC].unroll(y)
    #s[CC].unroll(ki)
    weight, = s[packedB].op.input_tensors
    print(tvm.lower(s, [A, weight, packedB, CC, C], simple_mode=True))


###########################################################################################################




def dense_nopack_0_T(M, K, N, dtype, mts, kts, nts):
    s, [data, weight, C] = schedule_defination_gemm_dense_default_nopack(M, K, N, dtype, kts)
    schedule_optimization_dense_default_nopack(s, C, mts, nts)
    return s, [data, weight, C]

def dense_pack_default(M, K, N, dtype, bn, mts, kts, nts):
    s, [data, weight, C] = schedule_defination_gemm_dense_pack_default(M, K, N, dtype, bn)
    schedule_optimization_dense_pack_default(s, C, mts, kts, nts)
    return s, [data, weight, C]


###########################################################################################################


if __name__ == '__main__':
    M = sys.argv[1]
    K = sys.argv[2]
    N = sys.argv[3]
    M_TS = sys.argv[4]
    K_TS = sys.argv[5]
    N_TS = sys.argv[6]

    M = int(M)
    K = int(K)
    N = int(N)
    mts = int(M_TS)
    kts = int(K_TS)
    nts = int(N_TS)

    random.seed(30)
    target = 'llvm -mcpu=core-avx2'
    dtype = 'float32'
    ctx = tvm.context(target, 0)

    k = tvm.reduce_axis((0, K), 'k')
    A = tvm.placeholder((M, K), name='A')
    B = tvm.placeholder((K, N), name='B')
    BT = tvm.placeholder((N, K), name='BT')
    C = tvm.compute((M, N),lambda x, y: tvm.sum(A[x, k] * B[k, y], axis=k),name='C')
    CT = tvm.compute((M, N),lambda x, y: tvm.sum(A[x, k] * BT[y, k], axis=k),name='CT')

    a_np = np.random.rand(M,K).astype(dtype)
    b_np = np.random.rand(K,N).astype(dtype)
    bt_np = np.random.rand(N,K).astype(dtype)

    c_np = a_np.dot(b_np)
    ct_np = a_np.dot(bt_np.T)

    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(b_np, ctx)
    bt = tvm.nd.array(bt_np, ctx)
    c = tvm.nd.array(c_np, ctx)
    ct = tvm.nd.array(ct_np, ctx)


    # numpyBaseline(M,K,N)

    s = tvm.create_schedule(C.op)
    st = tvm.create_schedule(CT.op)


    print("dense_nopack_0_T")
    s, [data, weight, out] = dense_nopack_0_T(M, K, N, dtype, mts, kts, nts)
    buildandevaluation(s, data, weight, out, a, bt, ct, ctx, ct_np)
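    # invoked as: python gemm_v3_scheduling.py M K N M_TS K_TS N_TS; the non-divisible case above corresponds to 1 512 1000 1 10 100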