Nvcc fatal : Value 'sm_86' is not defined for option 'gpu-architecture'

I met this error when trying to compile an MXNet model. The Docker image is built from nvidia/cuda:11.0-cudnn8-devel-ubuntu18.04.

  • driver: 465.31
  • CUDA: 11.0
  • GPU: RTX3090
  • tvm commit: 34570f27e

The test script is as below:

import tvm
from tvm import relay
import mxnet as mx
from mxnet.gluon.model_zoo.vision import get_model

block = get_model("resnet18_v2", pretrained=True)
shape_dict = {"data": (1, 3, 224, 224)}
mod, params = relay.frontend.from_mxnet(block, shape_dict)

target = 'cuda'
# target = 'cuda -arch=sm_80'

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target, params=params)

The tail of the error output looks like this:

.....
inverse[(3)] = (inverse[(3)] + (bgemm[((((((int)blockIdx.x) * 128) + ((int)threadIdx.x)) + 57344))] * -1.000000e+00f));
  inverse[(3)] = (inverse[(3)] + (bgemm[((((((int)blockIdx.x) * 128) + ((int)threadIdx.x)) + 73728))] * -1.000000e+00f));
  inverse[(3)] = (inverse[(3)] + bgemm[((((((int)blockIdx.x) * 128) + ((int)threadIdx.x)) + 81920))]);
  inverse[(3)] = (inverse[(3)] + bgemm[((((((int)blockIdx.x) * 128) + ((int)threadIdx.x)) + 90112))]);
  inverse[(3)] = (inverse[(3)] + (bgemm[((((((int)blockIdx.x) * 128) + ((int)threadIdx.x)) + 106496))] * -1.000000e+00f));
  inverse[(3)] = (inverse[(3)] + bgemm[((((((int)blockIdx.x) * 128) + ((int)threadIdx.x)) + 114688))]);
  inverse[(3)] = (inverse[(3)] + bgemm[((((((int)blockIdx.x) * 128) + ((int)threadIdx.x)) + 122880))]);
  for (int ax2_inner = 0; ax2_inner < 2; ++ax2_inner) {
    for (int ax3_inner = 0; ax3_inner < 2; ++ax3_inner) {
      if (((((((int)threadIdx.x) & 15) >> 2) * 2) + ax2_inner) < 7) {
        if ((((((int)threadIdx.x) & 3) * 2) + ax3_inner) < 7) {
          T_relu[(((((((((int)blockIdx.x) * 392) + ((((int)threadIdx.x) >> 4) * 49)) + (((((int)threadIdx.x) & 15) >> 2) * 14)) + (ax2_inner * 7)) + ((((int)threadIdx.x) & 3) * 2)) + ax3_inner))] = max((inverse[(((ax2_inner * 2) + ax3_inner))] + placeholder[(((((int)blockIdx.x) * 8) + (((int)threadIdx.x) >> 4)))]), 0.000000e+00f);
        }
      }
    }
  }
}


Compilation error:
nvcc fatal   : Value 'sm_86' is not defined for option 'gpu-architecture'

BTW, setting the target to "cuda -arch=sm_80" doesn't help.

I have two questions here:

  • Why does TVM specify sm_86 instead of sm_80? If I remember correctly, sm_86 is only supported from CUDA 11.1?
  • Setting the CUDA arch via target = 'cuda -arch=sm_80' doesn't help at all; the error is still Value 'sm_86' is not defined for option 'gpu-architecture'. Did I misunderstand the usage of target here?
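For reference, sm_86 support did indeed first appear in CUDA 11.1, so an 11.0 toolkit rejects it. A minimal sketch of the constraint (the version table below is hard-coded from NVIDIA's release notes, not queried from nvcc):

```python
# Newest SM architecture each CUDA toolkit release can compile for
# (abridged; values taken from NVIDIA's release notes).
MAX_SM = {
    "10.2": 75,  # Turing
    "11.0": 80,  # Ampere GA100
    "11.1": 86,  # Ampere GA10x (RTX 30 series)
}

def toolkit_supports(cuda_version, sm):
    """Return True if the given toolkit can compile for the given sm arch."""
    return sm <= MAX_SM[cuda_version]

print(toolkit_supports("11.0", 86))  # False: sm_86 needs CUDA 11.1+
print(toolkit_supports("11.1", 86))  # True
```

This is exactly the mismatch in the error above: the RTX 3090 reports compute capability 8.6, but the CUDA 11.0 nvcc in the container only knows archs up to sm_80.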

A similar problem can be found here: Nvcc fatal : Value 'sm_75' is not defined for option 'gpu-architecture'

Adding the code below resolves the error, though it might not be a good solution:

@tvm.register_func("tvm_callback_cuda_compile", override=True)
def tvm_callback_cuda_compile(code):
    """Use nvcc to generate ptx code for better optimization."""
    from tvm.autotvm.env import AutotvmGlobalScope
    from tvm.contrib import nvcc

    curr_cuda_target_arch = AutotvmGlobalScope.current.cuda_target_arch
    target = "fatbin" if isinstance(curr_cuda_target_arch, list) else "ptx"
    ptx = nvcc.compile_cuda(
        code,
        target=target,
        # arch=AutotvmGlobalScope.current.cuda_target_arch
        arch=["-gencode", "arch=compute_80,code=sm_80"],
    )
    return ptx

Because your GPU is an sm_86 device, we probably have some code that uses sm_86 by default. TVM cannot tell which version of CUDA introduced support for sm_86.
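One way to avoid this mismatch (a hypothetical sketch of what a fallback could look like, not current TVM code) is to clamp the capability reported by the driver to the newest arch the installed toolkit accepts:

```python
def pick_arch(device_sm, toolkit_max_sm):
    """Choose the newest arch both the GPU and the toolkit understand.

    device_sm: compute capability reported by the driver, e.g. 86 for an
    RTX 3090. toolkit_max_sm: newest arch the installed nvcc accepts,
    e.g. 80 for CUDA 11.0. Code built for an older arch of the same major
    version can still run on the newer GPU, provided PTX is embedded so
    the driver can JIT-compile it.
    """
    return "sm_%d" % min(device_sm, toolkit_max_sm)

print(pick_arch(86, 80))  # RTX 3090 with CUDA 11.0 -> 'sm_80'
```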

Is it desired behavior that setting target = 'cuda -arch=sm_80' doesn't actually change the CUDA arch when compiling the model?

I don't know if we support the -arch option in codegen, but yes, that sounds reasonable.

We do support -arch in the NVCC utility, but currently it is passed in a roundabout way through a global variable. As a result, target = "cuda -arch=xx" is not effective. Instead, you need to set the global variable as follows:

from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
set_cuda_target_arch('sm_80')
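For reference, the -arch value such a PR would need to honor can be pulled out of the target string with a few lines. This is a hypothetical parser to illustrate the idea, not TVM's actual implementation:

```python
import re

def parse_cuda_arch(target):
    """Extract the value of -arch=... from a target string such as
    'cuda -arch=sm_80'; return None if no -arch option is present."""
    m = re.search(r"-arch=(\S+)", target)
    return m.group(1) if m else None

print(parse_cuda_arch("cuda -arch=sm_80"))  # sm_80
print(parse_cuda_arch("cuda"))              # None
```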

"cuda -arch=xx" is a better way to assign a specific sm arch than the roundabout global variable above. Would it make sense to add this API to TVM? I can try to make a PR. :smiley:


Agree with @pzq. Setting the CUDA arch has nothing to do with autotvm; it would look much better if we made cuda -arch=xx effective.

@comaniac @masahi @nicklhy PR #9544 may solve the problem effectively.

Can you help me review it?


Hi, I tried to build and run a TVM operator on a platform with a 3080 Ti and CUDA 10.2 (max arch sm_75). I set arch=sm_75 when building, and it built successfully. But when I try to run the built function, I always get an error: CUDAError: cuModuleLoadData(&(module_[device_id]), data_.c_str()) failed with error: CUDA_ERROR_NO_BINARY_FOR_GPU. This error disappears when I use CUDA 11.1+, whose default arch is sm_86, the same as TVM's default selection. There are no args for me to change the runtime device arch, so I wonder how I can run the sm_75 code on the 3080 Ti.
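The CUDA_ERROR_NO_BINARY_FOR_GPU above is what binary compatibility predicts: SASS (cubin) built for sm_75 cannot load on an sm_86 device, because cubins only run on devices with the same major compute capability. A small sketch of the rule (my own helper illustrating NVIDIA's documented compatibility policy, not a CUDA API):

```python
def cubin_runs_on(built_sm, device_sm):
    """CUDA binary (SASS) compatibility: a cubin built for X.y runs only
    on devices of compute capability X.z with z >= y (same major version,
    equal or newer minor version)."""
    return (built_sm // 10 == device_sm // 10
            and built_sm % 10 <= device_sm % 10)

print(cubin_runs_on(75, 86))  # sm_75 cubin on a 3080 Ti (sm_86): False
print(cubin_runs_on(80, 86))  # sm_80 cubin on sm_86: True
```

So with a CUDA 10.2 toolchain the options are limited: compile to PTX (for example via the tvm_callback_cuda_compile workaround above with target "ptx") so the driver can JIT it for sm_86, or move to a CUDA 11.x toolkit that can emit sm_80/sm_86 binaries directly.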