[RFC] TensorIR: A schedulable IR for TVM

Amazing progress here! So great!

I have a short question, though. When trying out the Blitz tutorial, I couldn’t figure out how to use constants to define, e.g., shapes or types:

import tvm
from tvm.script import tir as T

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, (8,), dtype="float32")
        B = T.match_buffer(b, (8,), dtype="float32")
        ...

I’d be interested in how to do something like this:

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle, x: int, dt: str):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, (x,), dtype=dt)
        B = T.match_buffer(b, (x,), dtype=dt)
        if x < cond1:
            ...
        else:
            ...

        ...

I get an error message whenever I use a construct similar to this.

Is this possible? If so, it would be great if you could give a short example.

CC: @crazyjackie1993 @Hzfengsy @junrushao

I’m not sure if there’s an easy way to represent string types, but integers can be represented as T.int32 or T.int64, etc. So in your example, you can probably use the syntax below:

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle, x: T.int32):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, (x,), dtype="float32")  # dtype must be a string literal
        B = T.match_buffer(b, (x,), dtype="float32")
        ...

Note that TVMScript is a representation of TensorIR (TVM IR) rather than runnable Python code. That means we can only use TVM data structures (e.g. T.int32, T.handle) rather than Python types (e.g. int, str).

Unfortunately, TVM does not have a type annotation for data types themselves (str is not a TVM type), so we cannot use constant parameters to define buffer dtypes. We may support it in the future if it is needed.
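
For completeness, here is a version of the above that parses end to end (a minimal sketch; the copy body and the module name are mine):

import tvm
from tvm.script import tir as T

@tvm.script.ir_module
class SymbolicModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle, x: T.int32):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, (x,), dtype="float32")
        B = T.match_buffer(b, (x,), dtype="float32")
        for i in T.serial(x):
            with T.block("copy"):
                vi = T.axis.spatial(x, i)
                B[vi] = A[vi]

# x stays a symbolic tir.Var inside the parsed module
sch = tvm.tir.Schedule(SymbolicModule)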

Thanks for the replies! I’m not 100% sure I understood everything correctly.

The feature I’m looking for is to replace TE with TensorIR. Here is an example (from the TVM repository) with many parameters defined as Python constants.

conv = te.compute(
    (1, ofm_height, ofm_width, ofm_channels),
    lambda nn, hh, ww, cc: te.sum(
        dmaed_ifm(
            nn, hh * stride_h + rh * dilation_h, ww * stride_w + rw * dilation_w, rc
        ).astype(ifm.dtype)
        * weight[cc, rh, rw, rc].astype(ifm.dtype)
        + (scale_bias[cc, 0] * scale_bias[cc, 9]).astype(ifm.dtype),
        axis=[rh, rw, rc],
    ),
    name="ethosu_conv2d",
    attrs=conv2d_attrs,
)

The way I interpret the figure from the Blitz tutorial is that I can replace TE with TVMScript (which represents TensorIR). However, I don’t see a way to use (Python) variables to influence the schedule here.

You have indicated that it might not be possible yet using TVMScript. Is there another way of bringing the shapes (or types, or other constants) of an operator into the TensorIR AST? At least from the tutorial, this was not immediately clear to me.
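
To make the question concrete, here is a trimmed-down sketch of the pattern I rely on today (the make_add helper and its parameters are mine, not from the repository):

import tvm
from tvm import te

def make_add(n: int, dtype: str):
    # n and dtype are plain Python values; they are baked into the TE
    # expression at construction time, before any TensorIR exists.
    A = te.placeholder((n,), dtype=dtype, name="A")
    B = te.placeholder((n,), dtype=dtype, name="B")
    C = te.compute((n,), lambda i: A[i] + B[i], name="C")
    return te.create_prim_func([A, B, C])

prim_func = make_add(8, "float32")  # constants are fixed here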

IIUC, TVMScript cannot handle such a case. This is also my question, so I tried to modify TVMScript. The final Python snippet is similar to TorchScript:

class ScriptModule(object):
    def __init__(self, x: ty.int32):
        self.x = x

    def main(self, a: ty.handle, b: ty.handle, c: ty.handle) -> None:
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, (self.x,), dtype=dt)
        B = T.match_buffer(b, (self.x,), dtype=dt)

I recommend that you read the TorchScript code.

I see your pain. Currently we cannot programmatically construct (meta-program) TVMScript from Python. See also the discussion in [RFC] Hybrid Script Support for TIR - #34 by masahi

IMHO, we can do the same thing as TorchScript. But as you mentioned before, how to modularize the function may be the main obstacle. Cross-function calls may affect codegen, splitting device and host code, etc.

I agree that native meta-programming is not supported for now. But we have another API called specialize, which can do similar things.
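
For reference, here is a sketch of how specialize binds the symbolic parameters of a PrimFunc to concrete values (the mem_copy function is only an illustration, written from my understanding of the API):

import tvm
from tvm.script import tir as T

@T.prim_func
def mem_copy(a: T.handle, b: T.handle, m: T.int32, n: T.int32) -> None:
    A = T.match_buffer(a, (m, n), "float32")
    B = T.match_buffer(b, (m, n), "float32")
    for i, j in T.grid(m, n):
        with T.block("copy"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj]

# Specializing the buffer param `a` to a concrete 16x16 buffer also binds
# the symbolic sizes m and n, which are inferred from the buffer shape.
a, _, m, n = mem_copy.params
specialized = mem_copy.specialize({a: tvm.tir.decl_buffer((16, 16))})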

I know that the Blitz tutorial does not show enough information about TensorIR. But it’s a very simple tutorial for new TVM users, showing how to play with IRModule and TensorIR. More advanced docs and tutorials will come soon after the v0.8 release.

I’m still a little confused about TensorIR. The relay.build function will convert the model to TensorIR without TE, do I understand that right?

The current Relay build still uses TE; TensorIR is only an experimental feature in v0.8.

Is there currently any (even experimental) support for autotvm/auto-scheduler atop TensorIR?

Thanks!

Yep @jkosaian, and that’s meta schedule :slight_smile:

Great! Thank you for your help!

Could you point us to a link with an example of using meta schedule with TensorIR? It doesn’t need to be a full tutorial; a unit test or experimental code would also be sufficient.

Hi TVM.

I’m very interested in TensorIR, and I’m now testing a simple matrix multiplication where I want to bind an axis to a GPU thread. But there is an error saying that the child block is neither a complete block nor a reduction block. Could you please help me address this issue?

My code is shown below.

import tvm
from tvm.script import tir as T

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, [128, 128])
        B = T.match_buffer(b, [128, 128])
        C = T.match_buffer(c, [128, 128])

        for i, j, k in T.grid(128, 128, 128):
            with T.block("update"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

ir_module = MyModule
sch = tvm.tir.Schedule(ir_module)
block_b = sch.get_block("update")
i, j, k = sch.get_loops(block_b)
sch.bind(i, "threadIdx.x")

@RenyuanLIU You probably forgot the init block. CC: @Hzfengsy

Hi @RenyuanLIU. A reduction block contains both an init part and an update part. The full block should be:

with T.block("update"):
    vi, vj, vk = T.axis.remap("SSR", [i, j, k])
    with T.init():
        C[vi, vj] = 0.0
    C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]
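
Putting it together, here is a sketch of the fixed module plus a binding that should now pass the check (the blockIdx.x/threadIdx.x split is just one option):

import tvm
from tvm.script import tir as T

@tvm.script.ir_module
class MyModuleFixed:
    @T.prim_func
    def main(a: T.handle, b: T.handle, c: T.handle):
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        A = T.match_buffer(a, [128, 128])
        B = T.match_buffer(b, [128, 128])
        C = T.match_buffer(c, [128, 128])
        for i, j, k in T.grid(128, 128, 128):
            with T.block("update"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

sch = tvm.tir.Schedule(MyModuleFixed)
update = sch.get_block("update")
i, j, k = sch.get_loops(update)
# With the init statement present, "update" is a proper reduction block,
# so its spatial loops can legally be bound to GPU threads.
sch.bind(i, "blockIdx.x")
sch.bind(j, "threadIdx.x")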

These seem to be great lessons that we can use to enhance the tutorials and docs!

Thank you very much!