Hi there,
I am trying to run MetaSchedule tuning on a ResNet-50 workload, targeting CUDA Tensor Cores. The workload is created as follows:
mod, params = testing.resnet.get_workload(
    batch_size=batch_size, image_shape=input_image_shape, layout="NHWC", dtype="float16")
The parameters are:
batch_size = 1
num_class = 1000
input_image_shape = (224,224,3)
data_shape = (batch_size,) + input_image_shape
output_shape = (batch_size, num_class)
dtype = "float16"
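For reference, here is a minimal sketch of a sanity check that the layout string and the image shape agree before the workload is built. The helper `make_data_shape` and the constant `SUPPORTED_LAYOUTS` are hypothetical names used only for illustration. They are not part of TVM, and the check assumes a 3-channel input and only the two layouts discussed in this post.

```python
# Hypothetical helper, not part of TVM: verifies that the layout string and
# the per-image shape are consistent before the workload is created.
SUPPORTED_LAYOUTS = ("NCHW", "NHWC")  # assumption: only the layouts discussed here


def make_data_shape(batch_size, image_shape, layout, channels=3):
    """Return (batch_size,) + image_shape after checking it against layout."""
    if layout not in SUPPORTED_LAYOUTS:
        raise ValueError(
            f"unsupported layout {layout!r}; expected one of {SUPPORTED_LAYOUTS}"
        )
    if len(image_shape) != 3:
        raise ValueError(f"image_shape must have 3 dimensions, got {image_shape}")
    # Position of the channel axis inside image_shape (layout without the leading N).
    channel_axis = layout.index("C") - 1
    if image_shape[channel_axis] != channels:
        raise ValueError(
            f"image_shape {image_shape} does not match layout {layout!r}: "
            f"expected {channels} channels on axis {channel_axis}"
        )
    return (batch_size,) + tuple(image_shape)


data_shape = make_data_shape(1, (224, 224, 3), "NHWC")
print(data_shape)  # (1, 224, 224, 3)
```

A transposed layout string such as "NHCW", or an NHWC image shape passed together with layout="NCHW", raises a ValueError here instead of surfacing later during tuning.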
However, when I start tuning with MetaSchedule, I encounter this error:
I checked this issue, but it has been stagnant for a while and has received no contributions. However, the error doesn’t appear when I use the workload with the NCHW layout and the float32 datatype; in that case tuning proceeds as usual.
What causes this error for ResNet specifically with the NHWC layout and float16? Please help me resolve it. I have put up the full code here: resnet_ms
@zxybazh @junrushao Kindly share your thoughts on what might be going wrong here.