[Solved][C++ API] Help defining CUDA module

Hi all. Could anyone familiar with GPU codegen internals check my minimal vecadd program?

#include <random>
#include <iomanip>
#include <array>
#include <exception>
#include <iostream>       /* for std::cerr */
#include <unordered_map>  /* for the binds map */

#include <tvm/tvm.h>
#include <tvm/operation.h>
#include <tvm/tensor.h>
#include <tvm/build_module.h>
#include <topi/broadcast.h>

using namespace std;

int main(int argc, char **argv)
{
  /* Shape variable */
  auto n = tvm::var("n");
  tvm::Array<tvm::Expr> shape = {n};
  tvm::Tensor A = tvm::placeholder(shape, tvm::Float(32), "A");
  tvm::Tensor B = tvm::placeholder(shape, tvm::Float(32), "B");

  /* Build a graph for computing A + B */
  tvm::Tensor C = tvm::compute(shape, tvm::FCompute([=](auto i){ return A(i) + B(i); }));

  /* Prepare a function `vecadd` with no optimizations */
  tvm::Schedule s = tvm::create_schedule({C->op});
  tvm::BuildConfig config = tvm::build_config();
  std::unordered_map<tvm::Tensor, tvm::Buffer> binds;
  auto lowered = tvm::lower(s, {A,B,C}, "vecadd", binds, config);

  /* Output IR dump to stderr */
  cerr << lowered[0]->body << endl;

  tvm::IterVar block_idx = tvm::thread_axis(tvm::Range(), "blockIdx.x");
  tvm::IterVar thread_idx = tvm::thread_axis(tvm::Range(), "threadIdx.x");

  tvm::IterVar i, j;
  s[C].split(C->op->root_iter_vars()[0], 64, &i, &j);
  s[C].bind(i, block_idx);
  s[C].bind(j, thread_idx);

  /* Build the module for the CUDA target */
  tvm::Target target = tvm::target::cuda();
  tvm::Target target_host = tvm::target::llvm();
  tvm::runtime::Module mod = tvm::build(lowered, target, target_host, config);
  mod->SaveToFile(std::string(argv[0]) + ".cuda", "cuda");
  return 0;
}

The error raised is:

build_module.cc:422: 
Check failed: ir::VerifyMemory(x, target->device_type)
Direct host side access to device memory is detected in vecadd. Did you forget to bind?

No, I didn’t forget to bind. My guess is that blockIdx/threadIdx IterVars are created incorrectly. Do you have any suggestions?

I ran the following Python code, and it worked fine:

import tvm
import tvm.contrib.util  # tempdir lives in a submodule; import it explicitly
import time
import numpy as np

device = "cuda"
suffix = "ptx"

n = tvm.var("n")
A = tvm.placeholder((n,), name='A', dtype="float32")
B = tvm.placeholder((n,), name='B', dtype="float32")
C = tvm.compute(A.shape, lambda *i: A(*i) + B(*i), name='C')
s = tvm.create_schedule(C.op)

bx, tx = s[C].split(C.op.axis[0], factor=64)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))

module = tvm.build(s, [A, B, C], device, target_host="llvm")

temp = tvm.contrib.util.tempdir()
module.save(temp.relpath("myadd.o"))

So the problem does not seem to be in the IterVar creation itself. The C++ issue is still unsolved.

Maybe you should call tvm::lower after bind. Right now you lower the schedule before splitting and binding, so the lowered function never sees the thread binding.


Bingo! You are right; how could I miss that? Thanks. The initial code contained a few more errors as well; below is the corrected version.

A notable difference: in Python, tvm.build accepts the schedule object directly, but in C++ it is tvm::lower that accepts the schedule, and tvm::build then takes the resulting LoweredFunc array.
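Since others may hit the same VerifyMemory check, here is my reconstruction of the fixed program. This is a sketch against the same 0.4-era C++ API used above, reusing the calls from the original post; the only substantive change is the ordering: split and bind are applied to the schedule first, and tvm::lower runs afterwards, so the LoweredFunc passed to tvm::build already carries the blockIdx.x/threadIdx.x binding.

```cpp
#include <iostream>
#include <string>
#include <unordered_map>

#include <tvm/tvm.h>
#include <tvm/operation.h>
#include <tvm/tensor.h>
#include <tvm/build_module.h>

int main(int argc, char **argv)
{
  /* Shape variable and placeholders, as in the original post */
  auto n = tvm::var("n");
  tvm::Array<tvm::Expr> shape = {n};
  tvm::Tensor A = tvm::placeholder(shape, tvm::Float(32), "A");
  tvm::Tensor B = tvm::placeholder(shape, tvm::Float(32), "B");
  tvm::Tensor C = tvm::compute(shape, tvm::FCompute([=](auto i){ return A(i) + B(i); }));

  tvm::Schedule s = tvm::create_schedule({C->op});

  /* Split and bind FIRST, so the schedule carries the GPU thread mapping */
  tvm::IterVar block_idx = tvm::thread_axis(tvm::Range(), "blockIdx.x");
  tvm::IterVar thread_idx = tvm::thread_axis(tvm::Range(), "threadIdx.x");
  tvm::IterVar i, j;
  s[C].split(C->op->root_iter_vars()[0], 64, &i, &j);
  s[C].bind(i, block_idx);
  s[C].bind(j, thread_idx);

  /* Lower AFTER binding, then build the module for the CUDA target */
  tvm::BuildConfig config = tvm::build_config();
  std::unordered_map<tvm::Tensor, tvm::Buffer> binds;
  auto lowered = tvm::lower(s, {A, B, C}, "vecadd", binds, config);
  std::cerr << lowered[0]->body << std::endl;

  tvm::runtime::Module mod = tvm::build(lowered, tvm::target::cuda(),
                                        tvm::target::llvm(), config);
  mod->SaveToFile(std::string(argv[0]) + ".cuda", "cuda");
  return 0;
}
```

Note that this mirrors the Python flow exactly: schedule, split/bind, then lower/build.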