When I run the following:
```python
dtype = "float32"
target = tvm.target.Target("metal")
dev = tvm.metal()
evaluator = func.time_evaluator(func.entry_name, dev, number=5)
# a and b are NumPy arrays; no device is passed to tvm.nd.array here
print("Strassen: %f" % evaluator(tvm.nd.array(a), tvm.nd.array(b), c).mean)
```
it fails with:

```
TVMError: Assert fail: T.tvm_struct_get(A, 0, 10, "int32") == 8, Argument mmult.A.device_type has an unsatisfied constraint: 8 == T.tvm_struct_get(A, 0, 10, "int32")
```
This error message is hard to understand. However, when I pass the device explicitly:
```python
dtype = "float32"
target = tvm.target.Target("metal")
dev = tvm.metal()
evaluator = func.time_evaluator(func.entry_name, dev, number=5)
# same call, but a and b are now allocated on dev
print("Strassen: %f" % evaluator(tvm.nd.array(a, dev), tvm.nd.array(b, dev), c).mean)
```
it runs correctly.
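As far as I can tell, the assertion reads the `device_type` field of argument `A` (field 10 in the `tvm_struct_get` call) and requires it to equal 8, which is `kDLMetal` in DLPack's `DLDeviceType` enum. `tvm.nd.array(a)` without a device argument allocates on the CPU (device type 1), so the Metal-compiled function rejects it; passing `dev` puts the array on Metal and satisfies the check. The sketch below decodes such a message into plain language. The helper `explain_device_mismatch` and the regex are hypothetical, written only against the message format shown above, and the table lists just a subset of DLPack codes.

```python
import re

# Subset of DLPack's DLDeviceType codes (see dlpack.h); not exhaustive.
DL_DEVICE_TYPES = {
    1: "CPU (kDLCPU)",
    2: "CUDA (kDLCUDA)",
    4: "OpenCL (kDLOpenCL)",
    7: "Vulkan (kDLVulkan)",
    8: "Metal (kDLMetal)",
    10: "ROCm (kDLROCM)",
}

# Hypothetical pattern, assuming messages shaped like:
# "Argument mmult.A.device_type has an unsatisfied constraint: 8 == ..."
_PATTERN = re.compile(
    r"Argument (?P<arg>[\w.]+)\.device_type has an unsatisfied constraint: (?P<code>\d+) =="
)


def explain_device_mismatch(message):
    """Hypothetical helper: return a readable explanation, or None if no match."""
    m = _PATTERN.search(message)
    if m is None:
        return None
    code = int(m.group("code"))
    device = DL_DEVICE_TYPES.get(code, "unknown device type %d" % code)
    return "%s must live on %s; allocate it there, e.g. tvm.nd.array(x, dev)" % (
        m.group("arg"),
        device,
    )


msg = ('TVMError: Assert fail: T.tvm_struct_get(A, 0, 10, "int32") == 8, '
       'Argument mmult.A.device_type has an unsatisfied constraint: '
       '8 == T.tvm_struct_get(A, 0, 10, "int32")')
print(explain_device_mismatch(msg))
```

For the error above this reports that `mmult.A` must be on Metal, which matches the fix of calling `tvm.nd.array(a, dev)`.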