How do I make ssd.multibox_prior under topi.cuda be called when the device is cuda?

Currently the generic ssd implementation for the llvm backend lives under topi.vision. I want to write an ir_builder program for GPUs, so I created an ssd module under topi.cuda, mirroring topi.vision. I registered the multibox_prior function for cuda devices (@multibox_prior.register("cuda")) and wrote the corresponding __init__.py files. But when I call ssd.multibox_prior with the device set to cuda as follows:
with tvm.target.create(device):
    out = ssd.multibox_prior(data, sizes, ratios, steps, offsets, clip)
    s = topi.generic.schedule_multibox_prior(out)

it still calls the generic multibox_prior function under topi.vision instead of the one under topi.cuda. Is there anything else I need to do so that topi.cuda.ssd.multibox_prior is called automatically whenever I set the compile target to cuda/gpu?

Thanks in advance!

Usually I like to check explicitly in these situations to make sure the imports are happening. Can you verify that your file is actually imported, e.g., by adding a raise Exception or a print at the top of it?

You can also verify that the generic_func dispatch is firing by instrumenting it here: https://github.com/dmlc/tvm/blob/d7d44f80b2d38086699b2c1edf29e103435f7b16/python/tvm/target.py#L269
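To see why the import check matters, here is a toy, tvm-free sketch of how a generic_func-style dispatcher behaves (the names here are mine, not TVM's actual API): registration is a side effect of importing the module that calls register(), so if the cuda module is never imported, dispatch silently falls back to the generic implementation.

```python
# Toy stand-in for a generic_func-style registry (not TVM's real API).
# Registering a backend is a side effect of running the decorator, which
# only happens when the module containing it is imported.

_registry = {}

def register(key):
    def deco(fn):
        _registry[key] = fn
        return fn
    return deco

def dispatch(target, *args):
    # Fall back to the generic implementation if no backend is registered.
    fn = _registry.get(target, _registry["generic"])
    return fn(*args)

@register("generic")
def multibox_prior(x):
    return "generic:" + x

# Before the cuda module is imported, "cuda" falls back to generic.
before = dispatch("cuda", "data")

# Importing the cuda module would run its @register("cuda") decorator:
@register("cuda")
def multibox_prior_cuda(x):
    return "cuda:" + x

after = dispatch("cuda", "data")
```

If `before` prints `generic:data` while `after` prints `cuda:data`, the analogous fix in your setup is to make sure the __init__.py chain actually imports the file containing the `@multibox_prior.register("cuda")` decorator before dispatch happens.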

This is what I have in the multibox_prior_ir function:

import math
import tvm

ib = tvm.ir_builder.create()

max_threads = 2
bx = tvm.thread_axis("blockIdx.x")
tx = tvm.thread_axis("threadIdx.x")

p_out = ib.buffer_ptr(out)
in_height = data.shape[2]
in_width = data.shape[3]
num_sizes = len(sizes)
num_ratios = len(ratios)
size_ratio_concat = sizes + ratios
steps_h = steps[0] if steps[0] > 0 else 1.0 / in_height
steps_w = steps[1] if steps[1] > 0 else 1.0 / in_width
offset_h = offsets[0]
offset_w = offsets[1]

ib.scope_attr(bx, "thread_extent", (in_height+max_threads-1) // max_threads)
ib.scope_attr(tx, "thread_extent", max_threads)
i = bx.var * max_threads + tx.var

# with ib.for_range(0, in_height, for_type="parallel", name="i") as i:                      
with ib.if_scope(ib.likely(i < in_height)):
    center_h = (i + offset_h) * steps_h
    with ib.for_range(0, in_width, name="j") as j:
        center_w = (j + offset_w) * steps_w
        for k in range(num_sizes + num_ratios - 1):
            w = tvm.select(k < num_sizes,
                           size_ratio_concat[k] * in_height / in_width / 2.0,
                           size_ratio_concat[0] * in_height / in_width *
                           math.sqrt(size_ratio_concat[k + 1]) / 2.0)
            h = tvm.select(k < num_sizes, size_ratio_concat[k] / 2.0,
                           size_ratio_concat[0] / math.sqrt(size_ratio_concat[k + 1]) / 2.0)
            count = (i * in_width * (num_sizes + num_ratios - 1) +
                     j * (num_sizes + num_ratios - 1) + k) * 4
            p_out[count] = center_w - w
            p_out[count + 1] = center_h - h
            p_out[count + 2] = center_w + w
            p_out[count + 3] = center_h + h

where, as an experiment, I set the thread_extent explicitly in the ir_builder.
Is there any way to do this in a schedule instead?
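As a sanity check on the flat indexing above, the same anchor computation can be written as a plain-Python reference (my own paraphrase of the IR-builder loop, not TVM code), producing `(num_sizes + num_ratios - 1) * 4` values per pixel:

```python
import math

def multibox_prior_ref(in_height, in_width, sizes, ratios,
                       steps=(-1.0, -1.0), offsets=(0.5, 0.5)):
    """Plain-Python mirror of the IR-builder loop above (a sketch)."""
    num_sizes, num_ratios = len(sizes), len(ratios)
    num_boxes = num_sizes + num_ratios - 1
    steps_h = steps[0] if steps[0] > 0 else 1.0 / in_height
    steps_w = steps[1] if steps[1] > 0 else 1.0 / in_width
    offset_h, offset_w = offsets
    concat = list(sizes) + list(ratios)
    out = [0.0] * (in_height * in_width * num_boxes * 4)
    for i in range(in_height):
        center_h = (i + offset_h) * steps_h
        for j in range(in_width):
            center_w = (j + offset_w) * steps_w
            for k in range(num_boxes):
                if k < num_sizes:
                    # First num_sizes boxes vary the size at ratio 1.
                    w = concat[k] * in_height / in_width / 2.0
                    h = concat[k] / 2.0
                else:
                    # Remaining boxes vary the ratio at the first size.
                    w = (concat[0] * in_height / in_width
                         * math.sqrt(concat[k + 1]) / 2.0)
                    h = concat[0] / math.sqrt(concat[k + 1]) / 2.0
                count = (i * in_width * num_boxes + j * num_boxes + k) * 4
                out[count:count + 4] = [center_w - w, center_h - h,
                                        center_w + w, center_h + h]
    return out
```

For example, `multibox_prior_ref(2, 2, [1.0], [1.0])` yields one box per pixel, 16 floats total, with the top-left pixel's box centered at (0.25, 0.25).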

Hi, I also bind threads using the ir_builder. It doesn't seem straightforward to implement with a schedule. I'll upload the code once I finish testing the multibox detection part of SSD.

I think it may be possible to create another ir_builder in the schedule and bind threads to outs[0].op.body.body, since they are mostly For statements. Then replace the for statement with a smaller for loop, or just an if statement.
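Whichever way the binding is done, the index mapping used in the snippet above (block * max_threads + thread, guarded by ib.likely(i < in_height)) can be checked with a tvm-free sketch: every index in [0, n) is visited exactly once, and the tail threads of the last block are masked out by the guard.

```python
def covered_indices(n, max_threads):
    """Enumerate (block, thread) pairs the way the IR binds them."""
    num_blocks = (n + max_threads - 1) // max_threads  # ceil division
    visited = []
    for bx in range(num_blocks):
        for tx in range(max_threads):
            i = bx * max_threads + tx
            if i < n:  # mirrors the ib.likely(i < in_height) guard
                visited.append(i)
    return visited
```

With n = 5 and max_threads = 2 this enumerates three blocks, and the last block contributes only index 4, exactly as the guard intends.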

Hi Laura, can I run the SSD model with target "opencl" once you finish your implementation?

Hi Laura,

I'd like to run the "opencl" target with an SSD deep neural network. As far as I know, the _contrib_MultiBoxPrior and _contrib_MultiBoxDetection layers are currently supported only on the CPU target. Do you have a plan for OpenCL-target SSD networks? If I wanted to contribute to the TVM project by implementing OpenCL for these two layers, could you give me some pointers, such as the coding structure of tvm and topi and where the source code should go?

Regards

Yeah, I've finished implementing those two layers and they should be ready soon.


Hi Laura,

Did you complete the OpenCL implementation for the two layers? Can you point me to the git branch for this feature, if possible? I'd like to run an SSD network with the OpenCL target.

Thank you