Device error report

When I run the following:

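import tvm

dtype = "float32"
target = tvm.target.Target("metal")
dev = tvm.metal()
# func (the built module), a, b (NumPy inputs) and c (output) are defined earlier
evaluator = func.time_evaluator(func.entry_name, dev, number=5)
# a and b are wrapped without a device argument
print("Strassen: %f" % evaluator(tvm.nd.array(a), tvm.nd.array(b), c).mean)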

I get this error:

TVMError: Assert fail: T.tvm_struct_get(A, 0, 10, "int32") == 8, Argument mmult.A.device_type has an unsatisfied constraint: 8 == T.tvm_struct_get(A, 0, 10, "int32")

The message is hard to interpret. When I change the call to

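import tvm

dtype = "float32"
target = tvm.target.Target("metal")
dev = tvm.metal()
# func, a, b and c as before; c is assumed to already live on dev
evaluator = func.time_evaluator(func.entry_name, dev, number=5)
# passing dev copies a and b into Metal device memory
print("Strassen: %f" % evaluator(tvm.nd.array(a, dev), tvm.nd.array(b, dev), c).mean)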

it works correctly.

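This is similar to PyTorch, where the parameters and the model must be on the same device. The kernel compiled for the Metal target checks the device_type of every argument; 8 is the DLPack code for Metal, so an array in host memory (device type 1, CPU) fails the assertion. tvm.nd.array(a, dev) copies the tensor to the device, whereas tvm.nd.array(a) without a device leaves it in host (CPU) memory.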
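To read messages like the one above, it can help to map the numeric code back to a device name. The following is a minimal plain-Python sketch, not part of TVM: the names DL_DEVICE_TYPES and explain_device_assert are hypothetical, the table lists only a subset of DLPack's DLDeviceType codes, and the parsing assumes the message format shown in the question.

```python
import re

# Subset of DLPack DLDeviceType codes (see dlpack.h); not an exhaustive list.
DL_DEVICE_TYPES = {
    1: "CPU",
    2: "CUDA",
    4: "OpenCL",
    7: "Vulkan",
    8: "Metal",
    10: "ROCm",
}


def explain_device_assert(message: str) -> str:
    """Hypothetical helper: turn a device_type assertion into a short hint.

    Assumes the message contains 'Argument <name>.device_type' and
    'constraint: <code> ==', as in the TVMError quoted above.
    """
    arg = re.search(r"Argument (\S+)\.device_type", message)
    code = re.search(r"constraint: (\d+) ==", message)
    if arg is None or code is None:
        return "not a device_type assertion"
    expected_code = int(code.group(1))
    # Fall back to the raw number for codes missing from the subset above.
    expected = DL_DEVICE_TYPES.get(expected_code, f"device type {expected_code}")
    return f"{arg.group(1)} must be allocated on {expected}"
```

Applied to the error in the question, it reports that mmult.A must be allocated on Metal, which is exactly what passing dev to tvm.nd.array achieves.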


Thank you. Yes, you are right.