Sparse OpenCL error: scheduling sparse computations that use tir.ir_builder

I am working with custom sparse ops, and though it all works for the LLVM backend, I am having some issues generating code for OpenCL.

The simplest op I’ve made is a sparse GEMM (general matrix multiplication), though this issue also affects other ops I’ve made.

The sparse part of the computation uses a lightly adapted version of TVM’s csrmm_default (found in python/tvm/topi/sparse/csrmm.py). I’ve made a reproducible example of OpenCL code being generated for sparse matrix multiplication in TVM, available as this gist.
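For reference, the loop structure of csrmm_default that shows up in the generated TIR (a parallel loop over rows, a per-row dot accumulator, and an inner loop over that row's nonzeros) can be mirrored in plain Python. This is a minimal sketch for checking results on small inputs, not the TOPI code itself; the function name csrmm_reference and the flat row-major layout of the dense operand are assumptions.

```python
def csrmm_reference(data, indices, indptr, dense, n_cols):
    """Multiply a CSR matrix (data, indices, indptr) by a dense matrix.

    dense is a flat row-major list of shape (K, n_cols); the result is a flat
    row-major list of shape (M, n_cols), where M = len(indptr) - 1.
    """
    n_rows = len(indptr) - 1
    out = [0.0] * (n_rows * n_cols)
    for row in range(n_rows):              # 'parallel (row, ...)' in the TIR
        dot = [0.0] * n_cols               # the per-row 'dot' accumulator
        start, end = indptr[row], indptr[row + 1]
        for idx in range(end - start):     # loop over this row's nonzeros
            val = data[start + idx]
            col = indices[start + idx]     # which row of 'dense' to scale
            for j in range(n_cols):
                dot[j] += val * dense[col * n_cols + j]
        for j in range(n_cols):
            out[row * n_cols + j] += dot[j]
    return out
```

For the CSR matrix [[1, 0, 2], [0, 0, 3]] (data=[1, 2, 3], indices=[0, 2, 2], indptr=[0, 2, 3]) times the 3×2 matrix [[1, 2], [3, 4], [5, 6]], this produces the flat result [11, 14, 15, 18].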

When I combine this sparse GEMM computation into a larger TOPI op for GEMM convolution (with im2col and padding), it works for LLVM. However, when I try to generate OpenCL code for it, I get this error:


  Did you forget to bind?
    Variable `out` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `out` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `placeholder` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `placeholder` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `placeholder` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `placeholder` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `placeholder` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `placeholder` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `out` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
    Variable `placeholder` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
  File "../src/tir/analysis/verify_memory.cc", line 202
RuntimeError: Memory verification failed with the following errors:
PrimFunc([placeholder, placeholder, placeholder, placeholder, out]) attrs={"global_symbol": "fused_nn_conv2d_sparse", "tir.noalias": (bool)1, "target": opencl -keys=mali,opencl,gpu -device=mali -max_num_threads=256 -model=unknown} {
  // attr [data_im2col] storage_scope = "global"
  allocate data_im2col[float32 * 10368]
  for (k, 0, 128) {
    for (m, 0, 81) {
      data_im2col[((k*81) + m)] = placeholder[(((((floordiv(k, 16)*64) + (floordiv(m, 9)*8)) + (floordiv(floormod(k, 16), 4)*8)) + floormod(m, 9)) + floormod(k, 4))]
    }
  }
  // attr [0] extern_scope = 0
  parallel (row, 0, 8) {
    // attr [dot] storage_scope = "local"
    allocate dot[float32x81 * 1]
    out[ramp((row*81), 1, 81)] = x81(0f)
    dot[ramp(0, 1, 81)] = x81(0f)
    for (idx, 0, (placeholder[(row + 1)] - placeholder[row])) {
      dot[ramp(0, 1, 81)] = (dot[ramp(0, 1, 81)] + (x81(placeholder[(placeholder[row] + idx)])*data_im2col[ramp((placeholder[(placeholder[row] + idx)]*81), 1, 81)]))
    }
    out[ramp((row*81), 1, 81)] = (out[ramp((row*81), 1, 81)] + dot[ramp(0, 1, 81)])
  }
}

Other folks have had similar issues (e.g. this issue here), and @vinx13 suggested binding the axes with a schedule.

However, since this operation uses the tir.ir_builder system rather than standard TVM Relay operation building, writing a schedule (e.g. getting the loop axes with s[last].op.axis) fails with:

AttributeError: <class 'tvm.te.tensor.ExternOp'> has no attribute axis

This issue is described in an earlier question, where @tqchen said: “extern op cannot be scheduled because we get no control over the internal implementation of the op.”

Even so, since these ops are still built with TVM, is there a way to schedule ir_builder operations?

I’ve looked through @ziheng’s RFC on the matter, and am still not sure. Is making a schedule that binds an axis even the right way to fix this OpenCL codegen issue?

I don’t think we can bind these operations with a schedule. One workaround is to use thread axes in the IR builder directly (make a copy of the CPU version and modify it); see the example here: https://github.com/apache/tvm/blob/main/python/tvm/topi/cuda/nms.py
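As a rough illustration of that workaround, the column loop of the CPU kernel becomes a global thread index (blockIdx.x * threads_per_block + threadIdx.x) plus a bounds check. The sketch simulates the launch in plain Python by iterating over every (block, thread) pair; the function name, the one-thread-per-output-column mapping, and keeping the row loop inside each thread are assumptions for illustration, not what nms.py does.

```python
def csrmm_thread_per_column(data, indices, indptr, dense, n_cols, max_threads=256):
    """Simulate a GPU-style CSR x dense kernel with one thread per output column."""
    n_rows = len(indptr) - 1
    out = [0.0] * (n_rows * n_cols)
    n_blocks = (n_cols + max_threads - 1) // max_threads  # ceil(n_cols / max_threads)
    for bx in range(n_blocks):              # stands in for blockIdx.x
        for tx in range(max_threads):       # stands in for threadIdx.x
            j = bx * max_threads + tx       # global thread index -> output column
            if j >= n_cols:                 # surplus threads in the last block do nothing
                continue
            for row in range(n_rows):       # each thread walks every sparse row
                acc = 0.0                   # per-thread scalar accumulator
                for k in range(indptr[row], indptr[row + 1]):
                    acc += data[k] * dense[indices[k] * n_cols + j]
                out[row * n_cols + j] = acc
    return out
```

In real ir_builder code the two outer loops disappear: every GPU thread runs the body once with its own bx and tx, which is why the bounds check matters.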

Thanks, this is a really useful workaround!

I’ve added an initial binding, and all but one of my binding errors have disappeared:

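import tvm
from tvm import te

# irb is the tvm.tir.ir_builder instance that builds this kernel
max_threads = int(tvm.target.Target.current(allow_none=False).max_num_threads)
nthread_tx = max_threads  # threads per block
nthread_bx = max_threads  # number of blocks
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
irb.scope_attr(tx, "thread_extent", nthread_tx)
irb.scope_attr(bx, "thread_extent", nthread_bx)
# global thread index, used as the output column
n = bx * max_threads + tx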
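One detail worth checking in the snippet: nthread_bx = max_threads launches 256 × 256 = 65,536 threads for this layer, while the guard in the IR printed next only lets indices below 81 through, so 65,455 of them do nothing. A common choice is to keep max_threads threads per block and use ceiling division for the block count. The helper below is a hypothetical plain-Python calculation, assuming one thread per output element:

```python
def launch_config(n_items, max_threads):
    """Return (threads_per_block, num_blocks, idle_threads) for one thread per item."""
    nthread_tx = max_threads
    # Ceiling division: just enough blocks to cover every item.
    nthread_bx = (n_items + max_threads - 1) // max_threads
    idle = nthread_tx * nthread_bx - n_items  # threads that fail the bounds check
    return nthread_tx, nthread_bx, idle


print(launch_config(81, 256))
```

For the 81 columns here this gives a single block of 256 threads; the bounds check is still needed for the 175 surplus threads. For a static column count, ordinary integer ceiling division is enough.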

I’m still throwing this single error:

  Did you forget to bind?
    Variable `placeholder` is directly accessed by host memory (it is not contained in a thread environment or in the function arguments.
  File "../src/tir/analysis/verify_memory.cc", line 202
RuntimeError: Memory verification failed with the following errors:
PrimFunc([placeholder, placeholder, placeholder, placeholder, out]) attrs={"global_symbol": "fused_nn_conv2d_sparse", "tir.noalias": (bool)1, "target": opencl -keys=mali,opencl,gpu -device=mali -max_num_threads=256 -model=unknown} {
  // attr [data_im2col] storage_scope = "global"
  allocate data_im2col[float32 * 10368]
  for (k, 0, 128) {
    for (m, 0, 81) {
      data_im2col[((k*81) + m)] = placeholder[(((((floordiv(k, 16)*64) + (floordiv(m, 9)*8)) + (floordiv(floormod(k, 16), 4)*8)) + floormod(m, 9)) + floormod(k, 4))]
    }
  }
  // attr [0] extern_scope = 0
  // attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 256
  // attr [dot] storage_scope = "local"
  allocate dot[float32 * 1]
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 256
  for (batch, 0, 1) {
    parallel (row, 0, 8) {
      if ((((blockIdx.x*256) + threadIdx.x) < 81)) {
        out[(((blockIdx.x*256) + (row*81)) + threadIdx.x)] = 0f
        dot[0] = 0f
        for (idx, 0, (placeholder[(row + 1)] - placeholder[row])) {
          dot[0] = (dot[0] + (placeholder[(placeholder[row] + idx)]*data_im2col[(((blockIdx.x*256) + (placeholder[(placeholder[row] + idx)]*81)) + threadIdx.x)]))
        }
        out[(((blockIdx.x*256) + (row*81)) + threadIdx.x)] = (out[(((blockIdx.x*256) + (row*81)) + threadIdx.x)] + dot[0])
      }
    }
  }
}

I believe this error comes from the Relay operations: in this case, the im2col function is not getting a thread binding.

This seems like a case for the standard Relay schedule builder, binding the axes in the usual way.

However, every tensor I try to bind is either a PlaceholderOp or the IRBuilder’s ExternOp, whereas I would expect them to be ComputeOps.

For example, in my GEMM convolution I have:

conv_out = op.output(0)
data_col = conv_out.op.input_tensors[0]
_, N, M = s[data_col].op.axis

This fails with:

AttributeError: <class 'tvm.te.tensor.PlaceholderOp'> has no attribute axis

Interestingly, the Relay expressions generate PlaceholderOps instead of the usual ComputeOps when the TOPI operation contains an IRBuilder expression (unless I’ve done something wrong here).

Is there a way to bind PlaceholderOp operations, for example with a schedule?

Hi @Wheest. The issue you are hitting in your first post is that operator implementations are often targeted at a specific device type (CPU or GPU), because parallelism works differently on each. You were trying to use a CPU implementation, which will not work on the GPU. You can find a sparse matrix multiplication for the GPU here: https://github.com/apache/tvm/blob/main/python/tvm/topi/cuda/sparse.py#L134.

Note that these implementations use TIR directly. The normal way to implement ops is to use a topi compute definition plus a schedule (https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html): you have one compute definition and then a schedule for each type of hardware. However, this does not work for sparse because topi cannot express the computation. Instead, we have to write TIR directly using ir_builder, and for each device type we want to support, we need to write a separate computation. Also, with ir_builder there is no concept of scheduling: you write the loops directly.

Regarding the second error you are getting: this looks like a scheduling error. Relay sits at a higher level than topi + scheduling or ir_builder. It is not responsible for how each operator is implemented. If you post the code, I can help you debug it.

Thanks @tkonolige.

My understanding, as concisely as I can put it, is:

In TVM, we can write expressions in a few ways. The most common is the Relay API. There is also the IRBuilder API, which is what we write sparse computations in.

We can have several expressions combined into a single operation, called a TOPI operation.

E.g. for GEMM convolution we could have three Relay expressions: im2col, padding, and GEMM.

Together, they make one TOPI operation.

Now, for TVM GPU code generation, it’s important that the compiler knows which loops to bind to (e.g. threadIdx and blockIdx). I think these can be worked out automatically for Relay, but you can also set them explicitly using a schedule?

However, when a TOPI operation mixes Relay and IRBuilder expressions, binding these GPU threads becomes difficult. You can explicitly write the binding for the IRBuilder parts, but the expressions written in Relay behave strangely and can’t be bound (see my second post in this thread).

You can see the CPU code I’m working with in this pull request I’ve just made. Ideally, I want to reuse all of this for the GPU, even if it’s just copied and with bindings added.

Reimplementing, say, im2col in IRBuilder would be one solution, but an annoying one.

This isn’t quite correct. Relay is a high-level description of a model, in which you compose different operators. Each operator can have multiple implementations (also called strategies), each of which is either 1. a combination of a TE expression (I incorrectly called this TOPI in the previous post) and a schedule, or 2. handwritten TIR code (written using IRBuilder).

TOPI is just a collection of operator implementations.

In your GEMM example, relay.nn.dense(A, B) would be the Relay expression. dense_cublas.cuda and dense_nopack.x86 are examples of different implementations. Within one of these implementations, you may have TE compute subexpressions called im2col and padding.
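To make the im2col decomposition concrete: a stride-1, unpadded convolution becomes one matrix multiplication between the flattened weights and the im2col matrix. In the sparse case it is the weight matrix that is stored in CSR, as in the IR earlier in the thread, where the 8 output rows index into the 128 × 81 data_im2col buffer. The plain-Python sketch uses dense nested lists; its row ordering (c, kh, kw) and column ordering (oh, ow) mirror the index arithmetic of the data_im2col loop printed earlier, and the function names are hypothetical.

```python
def im2col(x, C, H, W, KH, KW):
    """x is [C][H][W]; returns a (C*KH*KW) x (OH*OW) matrix, stride 1, no padding."""
    OH, OW = H - KH + 1, W - KW + 1
    cols = []
    for c in range(C):
        for kh in range(KH):
            for kw in range(KW):
                # One row per (c, kh, kw); one column per output pixel (oh, ow).
                cols.append([x[c][oh + kh][ow + kw] for oh in range(OH) for ow in range(OW)])
    return cols


def matmul(a, b):
    """Dense (M x K) @ (K x N) on nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def conv2d_im2col(x, w, C, H, W, KH, KW):
    """w is [F][C][KH][KW]; returns [F][OH*OW] via im2col followed by GEMM."""
    w_rows = [[w[f][c][kh][kw] for c in range(C) for kh in range(KH) for kw in range(KW)]
              for f in range(len(w))]
    return matmul(w_rows, im2col(x, C, H, W, KH, KW))


def conv2d_direct(x, w, C, H, W, KH, KW):
    """Direct convolution with the same layout, used to check conv2d_im2col."""
    OH, OW = H - KH + 1, W - KW + 1
    return [[sum(w[f][c][kh][kw] * x[c][oh + kh][ow + kw]
                 for c in range(C) for kh in range(KH) for kw in range(KW))
             for oh in range(OH) for ow in range(OW)]
            for f in range(len(w))]
```

With a 3×3 input holding 1..9 and the single 2×2 filter [[1, 0], [0, 1]], both functions return [[6, 8, 12, 14]].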

Relay knows nothing about the implementation of an operator, so I think you mean TE here. However, TE will not automatically bind loops for you. You must provide your own schedule that binds these threads. (You can also generate a schedule automatically using the auto-scheduler.)

(Once again I am going to assume you mean TE instead of Relay.)

You can mix TE and IRBuilder. I believe you are just grabbing the wrong tensor. You can try printing out the tensor to see if you’ve selected the correct one to grab the axis from. If you provide a simple example (and a way to run it), I can help you debug.
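One way to do that check is to list each input's producing op and select the tensor by name rather than by position. The helpers are hypothetical (not part of TVM) and rely only on Tensor.op, the op's name, and Operation.input_tensors. hasattr returns False for a PlaceholderOp because accessing .axis raises the AttributeError shown earlier in the thread.

```python
def describe_inputs(op):
    """List (name, has_axes) for each input tensor of op.

    has_axes is True when the producing op exposes .axis (e.g. a ComputeOp);
    a PlaceholderOp raises AttributeError on .axis, so hasattr returns False.
    """
    return [(t.op.name, hasattr(t.op, "axis")) for t in op.input_tensors]


def find_input(op, name):
    """Return the input tensor of op whose producing op is called name."""
    for t in op.input_tensors:
        if t.op.name == name:
            return t
    names = [t.op.name for t in op.input_tensors]
    raise KeyError(f"no input named {name!r}; inputs are {names}")
```

For example, print(describe_inputs(conv_out.op)) shows which inputs can be scheduled, and data_col = find_input(conv_out.op, "data_im2col") followed by s[data_col].op.axis picks the im2col stage, assuming it feeds the extern op directly and keeps the name data_im2col seen in the printed IR.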

Thanks for correcting my misunderstanding of terms, I’ve got a lot of updating of my notes to do.

You’re right that I wasn’t grabbing the right tensor for the schedule. I was naïvely assuming that the order that the tensors were passed to my external expression would match the order they would be accessed in conv_out.op.input_tensors. Printing the tensors in the schedule to check this is a debugging technique I should have considered.

I’ve now got my single-layer test successfully generating OpenCL code without the above errors. The output is incorrect, as I was a bit quick and dirty with how I split the axes for binding. Nothing a bit of whiteboarding and fresh air won’t solve.

Hi, I’ve been working on generating CUDA code for sparse matrix multiplication with TVM to save myself the laborious work of designing and tuning the kernel. Since I’m new to TVM, I can’t figure out why topi cannot express the computation. In this code, it seems that the compute and schedule approach can also be used for SpMM. I suppose I could also use cache_read or the other techniques described in this tutorial to get higher performance? I’ve also read the code you mentioned above, and I’m wondering whether there is a shorter way to express such an operation. Thanks!

You can express sparse-dense matrix multiplication in TE, but you cannot apply certain scheduling operations. Namely, you cannot cache tensors that are used as indices into other tensors. This is a critical step for well-performing GPU kernels, so I wrote them in TIR so that I could do this caching. If you do not need caching of the indices, you could try TE + scheduling (as described in the tutorial).
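To illustrate what caching the indices means at the loop level: each row's slice of indices (and values) is staged once into a local buffer, standing in for GPU shared or local memory, and then reused by every output column instead of being re-read from the global index array inside the innermost loop. In plain Python the copy buys nothing; the sketch only shows the access pattern, and its name and per-row staging granularity are assumptions rather than how the kernels in topi/cuda/sparse.py are organized.

```python
def csrmm_cached_indices(data, indices, indptr, dense, n_cols):
    """CSR x dense product that stages each row's indices/values before reuse."""
    n_rows = len(indptr) - 1
    out = [0.0] * (n_rows * n_cols)
    for row in range(n_rows):
        start, end = indptr[row], indptr[row + 1]
        # Stage this row's sparsity pattern once; on a GPU this copy would go to
        # shared/local memory so later reads avoid the global index array.
        idx_cache = indices[start:end]
        val_cache = data[start:end]
        for j in range(n_cols):  # every output column reuses the staged slice
            acc = 0.0
            for val, col in zip(val_cache, idx_cache):
                acc += val * dense[col * n_cols + j]
            out[row * n_cols + j] = acc
    return out
```

The staged slices are read n_cols times per row, which is the reuse that makes caching them worthwhile on real hardware.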

Note that we do have some fairly optimized sparse matrix kernels for CUDA already in TVM. Is there a workload where the ones currently in TVM are not fast enough?

Thanks! Sorry for the late reply. I haven’t found a workload that slows down the current kernels yet :). Also, I noticed that AutoTIR was merged into main recently. Since I haven’t found any docs for it, I’m curious about what it does. Specifically, does it include any optimization for the handwritten kernels?

Meta scheduler (AutoTIR) has not been fully merged yet. I believe it will work on sparse workloads, but I don’t know how well it compares to the handwritten kernels.

Got it! Thanks for your reply!