Introducing TY-NNP backend with end2end TensorIR integration

Thanks for your comments:)

Perhaps I can walk through a mocked-up Conv2d example to describe it:

fn (%arg0: Tensor[(1, 3, 224, 224), int8], %nn.conv2d_arg: Tensor[(32, 3, 7, 7), int8]) {
  %conv_fn = fn (%data: Tensor[(1, 3, 224, 224), int8], %weight: Tensor[(32, 3, 7, 7), int8], Primitive=1) {
    nn.conv2d(%data, %weight, padding=[1, 1, 1, 1],  kernel_size=[7, 7], out_dtype="int32")
  };
  %conv_fn(%arg0, %nn.conv2d_arg)
}

and the corresponding PrimFunc for the primitive call %conv_fn would look like:

@T.prim_func
def main(x: T.Buffer[...], weight: T.Buffer[(32, 3, 7, 7), "int8"], y: T.Buffer[...]) -> None:
     # body

Assume that to utilize the specific hardware, we want to arrange the I/O channels into 4*4 tiles. There are two extra notes:

  • We do not know the “best” weight layout until a TIR schedule/tuning is done.
  • The required layout is out of the scope of common layout representations like “OIHW”, “OHWI”, etc.

The TIR schedule part would perform the following transformations on the weight:

o, i, h, w = s.get_read_buffer_axes(conv_block)
o_outer, o_inner = s.buffer_split(o, factor=4)  # [32, 3, 7, 7] -> [8, 4, 3, 7, 7]
i_outer, i_inner = s.buffer_split(i, factor=4)  # [8, 4, 3, 7, 7] -> [8, 4, 1, 4, 7, 7]
s.buffer_reorder(o_outer, i_outer, o_inner, i_inner, h, w)  #  [8, 4, 1, 4, 7, 7] -> [8, 1, 4, 4, 7, 7]
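
The shape arithmetic behind these split steps can be sketched in plain Python. `buffer_split_shape` is a hypothetical helper for illustration, not one of the actual primitives:

```python
import math

def buffer_split_shape(shape, axis, factor):
    # Split `axis` into (outer, inner=factor); the outer extent is rounded
    # up, which is where the padding in the i-split comes from (3 -> 1 * 4).
    outer = math.ceil(shape[axis] / factor)
    return shape[:axis] + [outer, factor] + shape[axis + 1:]

shape = [32, 3, 7, 7]
shape = buffer_split_shape(shape, 0, 4)  # split o -> [8, 4, 3, 7, 7]
shape = buffer_split_shape(shape, 2, 4)  # split i -> [8, 4, 1, 4, 7, 7]
```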

Above we use a set of extended TensorIR primitives, but they can simply be seen as syntactic sugar over the upcoming schedule primitive transform_layout.

The point is that they are not arbitrary index remappings (compared to a general transform_layout). We ensure that every such schedule step has an exactly equivalent relay transformation.

In the TIR schedule phase, we trace every buffer layout change on the function's param buffers (we can do that since the primitives are ours to implement), generate the transform (and reverse transform) in relay for each step, and finally compose them into a single layout-transform (and reverse-transform) function in relay.
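
As a rough sketch of that bookkeeping (all names here are hypothetical, not the actual implementation): each traced step carries a forward and a reverse callback, and composition just chains them, with the reverse chain applied in the opposite order:

```python
class LayoutTraceStep:
    """One recorded buffer layout change plus its relay-level equivalents."""
    def __init__(self, transform, reverse):
        self.transform = transform  # e.g. x -> relay.reshape(x, ...)
        self.reverse = reverse      # the inverse relay transformation

def compose_trace(steps):
    """Fold per-step transforms into single forward/reverse functions."""
    def forward(x):
        for step in steps:
            x = step.transform(x)
        return x
    def reverse(x):
        for step in reversed(steps):  # inverses compose in reverse order
            x = step.reverse(x)
        return x
    return forward, reverse

# Toy usage with scalars standing in for relay expressions:
fwd, rev = compose_trace([
    LayoutTraceStep(lambda x: x + 1, lambda x: x - 1),
    LayoutTraceStep(lambda x: x * 2, lambda x: x // 2),
])
assert rev(fwd(3)) == 3
```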

For the example above, it would be:

  • s.buffer_split(o, factor=4)

    • x → relay.reshape(x, [-1, 4, 3, 7, 7])
    • (reverse) x → relay.reshape(x, [32, 3, 7, 7])
  • s.buffer_split(i, factor=4)

    • x → relay.reshape(relay.nn.pad(x, […, (0, 1), …]), [8, 4, -1, 4, 7, 7])
    • (reverse) x → relay.strided_slice(relay.reshape(x, [8, 4, 4, 7, 7]), begin=…, end=…)
  • s.buffer_reorder(...)

    • x → relay.transpose(x, […])
    • (reverse) x → relay.transpose(x, […])
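
The three steps and their inverses can be checked end to end with NumPy standing in for the relay ops; this is only a sanity-check sketch, not the real relay functions:

```python
import numpy as np

# A stand-in weight tensor in the original [32, 3, 7, 7] layout.
w = (np.arange(32 * 3 * 7 * 7) % 127).astype(np.int8).reshape(32, 3, 7, 7)

# Forward: reshape / pad / reshape / transpose, mirroring the relay transform.
x = w.reshape(8, 4, 3, 7, 7)                             # split o: 32 -> 8 * 4
x = np.pad(x, [(0, 0), (0, 0), (0, 1), (0, 0), (0, 0)])  # pad i: 3 -> 4
x = x.reshape(8, 4, 1, 4, 7, 7)                          # split i: 4 -> 1 * 4
x = x.transpose(0, 2, 1, 3, 4, 5)                        # reorder to o.o, i.o, o.i, i.i, h, w
assert x.shape == (8, 1, 4, 4, 7, 7)

# Reverse: transpose / reshape / slice / reshape recovers the original weight.
y = x.transpose(0, 2, 1, 3, 4, 5)
y = y.reshape(8, 4, 4, 7, 7)
y = y[:, :, :3, :, :]          # strided_slice drops the zero-padded channel
y = y.reshape(32, 3, 7, 7)
assert (y == w).all()
```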

Finally, all transforms (and reverse transforms) are composed into two relay.Function objects that rewrite the relay-level layouts. Each accepts the original relay params and returns the updated params as a tuple:

fn (%p0: Tensor[..., int8], %p1: Tensor[(32, 3, 7, 7), int8]) {
  %0 = reshape(%p1, newshape=[...]);
  %1 = nn.pad(%0, pad_width=[...]);
  %2 = reshape(%1, newshape=[...]);
  %3 = transpose(%2, axes=[...]);
  (%p0, %3)
}

and the reverse direction is:

fn (%p0: Tensor[..., int8], %p1: Tensor[(8, 1, 4, 4, 7, 7), int8]) {
  %0 = transpose(%p1, axes=[...]);
  %1 = reshape(%0, newshape=[...]);
  %2 = strided_slice(%1, begin=[...], end=[...], strides=[...]);
  %3 = reshape(%2, newshape=[32, 3, 7, 7]);
  (%p0, %3)
}

A relay pass can now perform a “pre”-schedule for each primitive function, fetch the layout-transform functions from the schedule result, and perform the relay-level layout update. Finally, an extra FoldConstant pass can typically eliminate all the extra transformations outside of primitive calls.

fn (%arg0: Tensor[(1, 3, 224, 224), int8], %nn.conv2d_arg: Tensor[(32, 3, 7, 7), int8]) {
  %0 = reshape(%nn.conv2d_arg, newshape=[...]);
  %1 = nn.pad(%0, pad_width=[...]);
  %2 = reshape(%1, newshape=[...]);
  %3 = transpose(%2, axes=[...]);
  %conv_fn = fn (%data: Tensor[(1, 3, 224, 224), int8], %weight: Tensor[(8, 1, 4, 4, 7, 7), int8], Primitive=1, DevicePrimFuncKey=873487) {
   %4 = transpose(%weight, axes=[...]);
   %5 = reshape(%4, newshape=[...]);
   %6 = strided_slice(%5, begin=[...], end=[...], strides=[...]);
   %7 = reshape(%6, newshape=[32, 3, 7, 7]); 
   nn.conv2d(%data, %7, padding=[1, 1, 1, 1], kernel_size=[7, 7], out_dtype="int32");
  };
  %conv_fn(%arg0, %3)
}

The actual params are transformed before the call into %conv_fn, and the formal params are reverse-transformed within %conv_fn’s body. The reason we need the reverse transforms is that we currently cannot represent a “lowered” function call in relay (correct me if I am wrong). It is a workaround that keeps the primitive function body valid; that is, the relay module after the pass can still be safely evaluated on a CPU.

Everything described above targets only weights (free tensors) for now. We check that a tensor produced or consumed by other relay calls does not get transformed. For input and output layouts, we find that relay's ConvertLayout can cover the current demands. However, I think there is no essential difference between “applicable functions to transform layout” and a simple tag like “NCHW” on an input/output, so it should be possible to rewrite inputs/outputs with the same mechanism.

One remaining issue is that we have to hack the CompileEngine (now TECompiler) to cache and reuse the previously scheduled PrimFuncs. I would be very glad to know if existing mechanisms (like relay_to_tir?) can help us :slight_smile: cc @areusch