[TOPI][CUDA] "scatter_nd" has very poor performance on the CUDA backend (>1000x slower than hand-written CUDA code)

Problem Statement

The existing CUDA “scatter_nd” op (written in TIR) has two problems that block me from deploying it to real-world GPU devices:

  1. There is an integer overflow bug in its TIR implementation, for which I have proposed a fix: [BugFix][TOPI] Fix the integer overflow problem of the scatter_nd op. by zhuwenxi · Pull Request #8415 · apache/tvm · GitHub
  2. Its performance on GPU is very poor. In my case, it is almost 1000x slower than my naive hand-written CUDA implementation.

Code to Reproduce:

import tvm
import numpy as np
import tvm.relay as relay

dev = tvm.cuda()
target = tvm.target.Target("cuda")

# input data:
data_np = np.zeros((32, 128, 128, 256)).astype(np.float32)
indices_np = np.random.uniform(1,5,(32, 600, 3)).astype(np.int64)
updates_np = np.random.rand(32, 600, 256).astype(np.float32)

# Construct relay input nodes:
data = relay.var("data", shape=data_np.shape, dtype=str(data_np.dtype))
indices = relay.var("indices", shape=indices_np.shape, dtype=str(indices_np.dtype))
updates = relay.var("updates", shape=updates_np.shape, dtype=str(updates_np.dtype))

# Transpose indices from (32, 600, 3) to (3, 32, 600): scatter_nd expects the index tuples along the first axis.
indices_dim = len(indices_np.shape)
axes = list(range(indices_dim))
indices_t = relay.transpose(indices, axes[-1:] + axes[:-1])

# Construct relay scatter_nd op:
out = relay.op.scatter_nd(data, indices_t, updates, "update")
func = relay.Function([data, indices, updates], out)

# Execute scatter_nd:
intrp = relay.create_executor("debug", device=dev, target=target)
op_res = intrp.evaluate(func)(data_np, indices_np, updates_np)

Result

The script above takes 2.89 s on an NVIDIA T4. In comparison, a very naive CUDA implementation I wrote takes only 2.27 ms.

Root cause

The scatter_nd algorithm consists of two stages:

  1. Initialize the output tensor from the input data (here the data is all zeros, so every element is set to 0);
  2. Write the given update values into the output at the positions specified by the indices.
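To make the two stages concrete, here is a minimal pure-Python reference of scatter_nd in "update" mode on flat row-major buffers. It is a sketch for checking kernels against, not TVM's implementation, and `scatter_nd_reference` is a hypothetical name. It assumes the layout relay's scatter_nd uses: indices of shape (M, Y_0, ..., Y_{K-1}) and updates of shape (Y_0, ..., Y_{K-1}, X_M, ..., X_{N-1}), which is why the reproduction script transposes the indices.

```python
from math import prod


def scatter_nd_reference(data, data_shape, indices, indices_shape, updates):
    """Hypothetical reference for scatter_nd ("update" mode) on flat buffers."""
    m = indices_shape[0]                   # M: length of each index tuple
    num_tuples = prod(indices_shape[1:])   # Y_0 * ... * Y_{K-1}
    inner = prod(data_shape[m:])           # X_M * ... * X_{N-1}

    # Stage 1: the output starts as a copy of the input data.
    out = list(data)

    # Stage 2: copy each update slice to the location its index tuple names.
    # Tuples are applied in order, so for repeated indices the last one wins.
    for t in range(num_tuples):
        # Component j of tuple t is stored at j * num_tuples + t.
        base = 0
        for j in range(m):
            base = base * data_shape[j] + indices[j * num_tuples + t]
        for e in range(inner):
            out[base * inner + e] = updates[t * inner + e]
    return out


# data: shape (2, 3, 2); indices: shape (2, 2) holding tuples (0, 1) and (1, 2).
print(scatter_nd_reference([0] * 12, (2, 3, 2), [0, 1, 1, 2], (2, 2), [10, 11, 20, 21]))
# [0, 0, 10, 11, 0, 0, 0, 0, 0, 0, 20, 21]
```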

Because the output tensor and the updates tensor have different shapes, the two stages need different thread/block configurations to achieve the best performance on GPU. This leaves two ways to implement the op:

  1. Implement it with two CUDA kernels, one for init and one for update, each with its own thread/block configuration for best performance;
  2. Implement it with a single CUDA kernel that does both init and update.
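The structural difference between the two approaches can be modeled with a small pure-Python simulation in which a "kernel launch" calls a body once per (block, thread) pair. This only models the launch structure and is not GPU code; `launch`, `flat_base`, `two_kernel_scatter` and `one_kernel_scatter` are hypothetical names, and index tuples are passed already decoded. The one-kernel model follows the idea of a single block whose threads each loop over a share of both stages.

```python
from math import prod


def launch(grid, block, body):
    """Model a 1-D kernel launch: call body(block_id, thread_id) per thread."""
    for b in range(grid):
        for t in range(block):
            body(b, t)
    return grid * block  # number of threads launched


def flat_base(tup, data_shape):
    """Row-major offset (in slices) of the slice selected by index tuple `tup`."""
    base = 0
    for c, extent in zip(tup, data_shape):
        base = base * extent + c
    return base


def two_kernel_scatter(data, data_shape, tuples, updates, block=256):
    """Approach 1: separate init and update launches, each with its own config."""
    m = len(tuples[0])
    inner = prod(data_shape[m:])
    n = len(data)
    out = [None] * n

    def init(b, t):  # one thread per output element
        gid = b * block + t
        if gid < n:
            out[gid] = data[gid]

    def update(b, t):  # one block per index tuple, one thread per inner element
        out[flat_base(tuples[b], data_shape) * inner + t] = updates[b * inner + t]

    threads = launch((n + block - 1) // block, block, init)
    threads += launch(len(tuples), inner, update)
    return out, threads


def one_kernel_scatter(data, data_shape, tuples, updates):
    """Approach 2: one block of `inner` threads; thread t owns the elements whose
    offset is congruent to t modulo `inner` and loops over both stages."""
    m = len(tuples[0])
    inner = prod(data_shape[m:])
    out = [None] * len(data)

    def body(_, t):
        for i in range(t, len(data), inner):   # init, strided over the output
            out[i] = data[i]
        for k, tup in enumerate(tuples):       # update, column t of every slice
            out[flat_base(tup, data_shape) * inner + t] = updates[k * inner + t]

    threads = launch(1, inner, body)
    return out, threads


shape, tuples, upd = (2, 3, 2), [(0, 1), (1, 2)], [10, 11, 20, 21]
print(two_kernel_scatter([0] * 12, shape, tuples, upd, block=4))
# ([0, 0, 10, 11, 0, 0, 0, 0, 0, 0, 20, 21], 16)
print(one_kernel_scatter([0] * 12, shape, tuples, upd))
# ([0, 0, 10, 11, 0, 0, 0, 0, 0, 0, 20, 21], 2)
```

Both produce the same output; the difference is that the one-kernel model never launches more threads than the size of the last dimension, however large the tensors are.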

There is a trade-off here: approach 1 achieves better hardware utilization, while approach 2 avoids the cost of an extra kernel launch. The existing scatter_nd op takes the latter approach and implements a single kernel in TIR.

Unfortunately, approach 2 performs very poorly in real-world cases. For this example, the generated CUDA kernel is launched as follows:

Input tensor: 32 * 128 * 128 * 256
Updates tensor: 32 * 600 * 256
CUDA kernel config: grid(1, 1, 1), block(256, 1, 1)

Every thread therefore has to initialize 32 * 128 * 128 elements and apply 32 * 600 updates.
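The per-thread workload follows directly from the shapes. The arithmetic below also counts the threads a one-thread-per-element launch would use instead. Since a thread block always executes on a single SM, a grid of one block leaves all but one of the T4's 40 SMs idle, no matter how the block is sized.

```python
from math import prod

# Shapes and launch configuration from the example above.
data_shape = (32, 128, 128, 256)
updates_shape = (32, 600, 256)
grid, block = 1, 256
threads = grid * block

# Work per thread in the single-kernel launch.
init_per_thread = prod(data_shape) // threads        # 32 * 128 * 128
update_per_thread = prod(updates_shape) // threads   # 32 * 600

# A launch with one thread per element spreads the same work over this many
# threads (one element each).
init_threads = prod(data_shape)
update_threads = prod(updates_shape)

print(init_per_thread, update_per_thread)   # 524288 19200
print(init_threads, update_threads)         # 134217728 4915200
```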

Now the problem is clear: the whole op runs in a single block of 256 threads, which is far too little parallelism to fully utilize the GPU's SMs. That is exactly why it is so slow on GPU.

My questions

  1. How should this scatter_nd performance problem be addressed? Perhaps by adopting the two-kernel approach? I am not sure.
  2. What is the long-term plan for ops implemented in TIR, such as scatter_nd? I noticed that a new feature, AutoTIR, is in progress. Will it be able to solve this kind of problem?

Your script is not accurately timing the scatter_nd CUDA kernel. Use this instead:

from tvm.contrib import graph_executor
lib = relay.build(tvm.IRModule.from_expr(func), target=target)
g = graph_executor.GraphModule(lib["default"](dev))
# "run" ignores its arguments, so bind the inputs with set_input first.
g.set_input(data=data_np, indices=indices_np, updates=updates_np)
r = g.module.time_evaluator("run", dev)()
print(r)

On my machine (GTX 1070), the kernel takes 17ms.

Scatter_nd could certainly be optimized further, though, and we welcome patches. However, I’m not sure that splitting it into two kernels will necessarily make it faster. Right now we do not parallelize over the indices in the update because handling multiple updates to the same location is hard. This is why you are seeing a low number of threads. If you want to improve the parallelism, you’ll have to find a correct way to handle multiple updates to the same location. The two approaches I can think of are: 1. use atomicAdd, and 2. sort the indices and then make sure that all updates to the same index are handled by one thread.
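The sort-based option can be sketched in pure Python: stable-sort the update ids by their flattened target location, then give each run of equal locations to a single worker. The names `group_updates_by_location` and `scatter_sorted` are hypothetical, each update is a single scalar for brevity, and the loop over groups stands in for threads that could run in parallel because no two groups share a location. For "update" mode the sketch assumes the last update in input order should win, which is what a sequential loop produces. For "add" mode the order does not matter, which is also why atomicAdd works there but does not by itself pick a winner in "update" mode.

```python
def group_updates_by_location(locations):
    """Stable-sort update ids by target location and group equal locations.

    Returns [(location, [update ids in input order]), ...]; each group is
    meant to be owned by exactly one GPU thread.
    """
    order = sorted(range(len(locations)), key=lambda i: locations[i])  # stable
    groups = []
    for i in order:
        if groups and groups[-1][0] == locations[i]:
            groups[-1][1].append(i)
        else:
            groups.append((locations[i], [i]))
    return groups


def scatter_sorted(out, locations, values, mode="update"):
    """Apply scalar updates group by group; groups never share a location,
    so on a GPU they could be processed by different threads without races."""
    for loc, ids in group_updates_by_location(locations):
        if mode == "update":
            out[loc] = values[ids[-1]]  # assumption: last update in input order wins
        elif mode == "add":
            out[loc] += sum(values[i] for i in ids)
        else:
            raise ValueError(f"unsupported mode: {mode}")
    return out


print(scatter_sorted([0] * 4, [2, 0, 2, 3], [5, 6, 7, 8]))              # [6, 0, 7, 8]
print(scatter_sorted([0] * 4, [2, 0, 2, 3], [5, 6, 7, 8], mode="add"))  # [6, 0, 12, 8]
```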

@tkonolige I did a few experiments and found out where the problem is.

Your fix that “reduces the large bound” not only resolves the integer overflow issue but also makes it 100x faster. That’s great!

Although 20 ms sounds like a good number, I have to say it is still roughly 10x slower than the two-kernel implementation.

The TIR implementation is listed below:

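import tvm
from tvm import te
from tvm.topi.scatter import _verify_scatter_nd_inputs
from tvm.topi.utils import ceil_div


def scatter_nd_init(data, indices, updates, mode):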
    def gen_ir_nv(data_ptr, indices_ptr, updates_ptr, out_ptr):
        ib = tvm.tir.ir_builder.create()

        data = ib.buffer_ptr(data_ptr)
        indices = ib.buffer_ptr(indices_ptr)
        updates = ib.buffer_ptr(updates_ptr)
        out = ib.buffer_ptr(out_ptr)

        # We combine all the indices dimensions but the first one into a single
        # dimension so we can iterate it in single loop instead of an arbitrary
        # number of loops. We do the same thing for all the update dimensions.
        fused_indices_dimension = 1
        for i in indices_ptr.shape[1:]:
            fused_indices_dimension *= i

        fused_updates_dimension = 1
        for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]:
            fused_updates_dimension *= i

        fused_shape = 1
        for i in data_ptr.shape:
            fused_shape *= i

        # The init kernel touches every output element exactly once, so it is
        # parallelized over the whole fused output shape.
        bidx = te.thread_axis("blockIdx.x")
        bidy = te.thread_axis("blockIdx.y")
        tidx = te.thread_axis("threadIdx.x")

        # InitValue kernel
        blockDim = data_ptr.shape[-1]
        gridDim = int(ceil_div(fused_shape, blockDim))  # enough blocks to cover every element

        ib.scope_attr(bidx, "thread_extent", gridDim)
        ib.scope_attr(tidx, "thread_extent", blockDim)

        gid = bidx * blockDim + tidx
        with ib.if_scope(gid < fused_shape):
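            # NOTE: zero-fill matches scatter_nd only because the data is all zeros
            # in this example; in general the output starts as a copy (out[gid] = data[gid]).
            out[gid] = 0.0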

        return ib.get()

    out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf")
    return te.extern(
        [data.shape],
        [data, indices, updates],
        lambda ins, outs: gen_ir_nv(ins[0], ins[1], ins[2], outs[0]),
        dtype=data.dtype,
        out_buffers=[out_buf],
        name="scatter_nd_cuda",
        tag="scatter_nd_cuda",
    )


def scatter_nd_update(data, indices, updates, mode):
    _verify_scatter_nd_inputs(data, indices, updates)

    def gen_ir_nv(data_ptr, indices_ptr, updates_ptr, out_ptr):
        ib = tvm.tir.ir_builder.create()

        data = ib.buffer_ptr(data_ptr)
        indices = ib.buffer_ptr(indices_ptr)
        updates = ib.buffer_ptr(updates_ptr)
        out = ib.buffer_ptr(out_ptr)

        # We combine all the indices dimensions but the first one into a single
        # dimension so we can iterate it in single loop instead of an arbitrary
        # number of loops. We do the same thing for all the update dimensions.
        fused_indices_dimension = 1
        for i in indices_ptr.shape[1:]:
            fused_indices_dimension *= i

        fused_updates_dimension = 1
        for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]:
            fused_updates_dimension *= i

        fused_shape = 1
        for i in data_ptr.shape:
            fused_shape *= i

        # Unlike the single-kernel version, this kernel also parallelizes over
        # the index tuples (blockIdx.x / blockIdx.y), not only over
        # X_M .. X_{N-1} (threadIdx.x). Repeated index tuples therefore write
        # the same locations concurrently, and which update lands is unspecified.
        bidx = te.thread_axis("blockIdx.x")
        bidy = te.thread_axis("blockIdx.y")
        tidx = te.thread_axis("threadIdx.x")

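        # Real ScatterND kernel, specialized for M = 3 and a rank-4 data tensor.
        # The trailing comments give the values for the example shapes.
        ldIndices0 = indices_ptr.shape[-1]  # 600
        ldIndices1 = indices_ptr.shape[0]   # 3 (= M, length of each index tuple)
        ldUpdates0 = updates_ptr.shape[1]   # 600
        ldUpdates1 = updates_ptr.shape[-1]  # 256
        map_bs_max = data_ptr.shape[0]      # 32
        map_r_max  = data_ptr.shape[1]      # 128
        map_c_max  = data_ptr.shape[2]      # 128

        # One block per index tuple (600 x 32), one thread per element of the
        # last dimension (256).
        ib.scope_attr(bidx, "thread_extent", ldUpdates0)
        ib.scope_attr(bidy, "thread_extent", updates_ptr.shape[0])
        ib.scope_attr(tidx, "thread_extent", ldUpdates1)

        # indices has shape (M, 32, 600), so component m of index tuple
        # (bidy, bidx) is stored at m * fused_indices_dimension + tupleID.
        tupleID = bidy * ldIndices0 + bidx

        bs_idx = indices[0 * fused_indices_dimension + tupleID]
        r__idx = indices[1 * fused_indices_dimension + tupleID]
        c__idx = indices[2 * fused_indices_dimension + tupleID]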

        with ib.if_scope(tvm.tir.all(bs_idx < map_bs_max, r__idx < map_r_max, c__idx < map_c_max)):

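            outputIdx = (bs_idx * map_r_max * map_c_max * ldUpdates1
                         + r__idx * map_c_max * ldUpdates1
                         + c__idx * ldUpdates1
                         + tidx)

            inputIdx = (bidy * ldUpdates0 * ldUpdates1
                        + bidx * ldUpdates1
                        + tidx)

            # Only the scattered positions are written here; every other element
            # of out must already hold the input data (i.e. the init kernel must
            # write the same output buffer).
            out[outputIdx] = updates[inputIdx]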
        return ib.get()

    out_buf = tvm.tir.decl_buffer(data.shape, data.dtype, "out_buf")
    return te.extern(
        [data.shape],
        [data, indices, updates],
        lambda ins, outs: gen_ir_nv(ins[0], ins[1], ins[2], outs[0]),
        dtype=data.dtype,
        out_buffers=[out_buf],
        name="scatter_nd_cuda",
        tag="scatter_nd_cuda",
    )

According to my measurements, scatter_nd_init takes 2.36 ms and scatter_nd_update takes 0.274 ms, for a total of 2.63 ms for scatter_nd.
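A back-of-the-envelope check suggests why init dominates: it writes every element of a 512 MiB float32 output, while the update kernel only moves the roughly 20 MB updates tensor plus the indices. The sketch below turns the reported 2.36 ms into achieved write bandwidth, using NVIDIA's published 320 GB/s figure for the T4 as the peak (an assumption of this estimate; a copy from the input data instead of a zero-fill would roughly double the traffic).

```python
from math import prod

BYTES_PER_FLOAT32 = 4
T4_PEAK_GBPS = 320          # NVIDIA's published T4 memory bandwidth

# scatter_nd_init writes every output element once.
init_bytes = prod((32, 128, 128, 256)) * BYTES_PER_FLOAT32
init_seconds = 2.36e-3      # measured time reported above

achieved_gbps = init_bytes / init_seconds / 1e9
print(init_bytes)                               # 536870912
print(round(achieved_gbps, 1))                  # 227.5
print(round(achieved_gbps / T4_PEAK_GBPS, 2))   # 0.71
```

At about 71% of peak, the init kernel is already largely bandwidth-bound, so further gains would have to come from moving fewer bytes rather than from more threads.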

If the copy is taking most of the time, then we should certainly split it into its own kernel to parallelize it. Can you submit a PR for this?

Note that you don’t need to write two separate TIR functions; you can split the kernels within one function by using ib.new_scope().

Great, I will submit a PR soon.

By the way, I’m curious about the long-term plan for ops that are implemented in TIR. As far as I know, unlike ordinary ops implemented in TE, TIR ops can’t be auto-tuned, so they cannot be expected to achieve optimal performance across different hardware architectures.

So what is the long-term plan for these TIR ops? Will TVM extend TE’s expressive power to support “data-dependent indexing” so that these ops can be migrated to TE, or will the new AutoTIR feature help address the problem (in particular for scatter_nd, by choosing the better of the one-kernel and two-kernel implementations)?

I believe the meta-scheduler (AutoTIR) should be able to auto-tune these ops. I don’t think the plan will be to migrate these ops to TE. Instead, we will just have TIR code that does the computation, and the meta-scheduler will optimize it.

Great, thank you for the reply @tkonolige .