[RFC] Annotate Custom Scope layout Relax pass for Adreno GPU

Finally, texture scope realization works in two stages:

  • Stage 1 is generic and straightforward: it uses the convert_layout pass, which transforms the tensor shapes and injects layout_transform ops as needed (see the sketch after this list).

  • Stage 2 is this pass itself, which is responsible for injecting the appropriate VDevice into StructInfo and adding copies wherever producer and consumer scopes conflict.
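
A minimal sketch of how Stage 1 might be driven, assuming the layouts from the example below; ConvertLayout and Normalize are existing relax passes, while the exact pipeline wiring here is illustrative only:

```python
# Stage 1 sketch: convert conv2d to the blocked NCHW4c/OIHW4o layouts.
# The desired_layouts mapping below matches the example module in this RFC.
import tvm
from tvm import relax


def run_stage1(mod: tvm.IRModule) -> tvm.IRModule:
    seq = tvm.transform.Sequential(
        [
            relax.transform.ConvertLayout({"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}),
            relax.transform.Normalize(),  # normalize after layout rewriting
        ]
    )
    return seq(mod)
```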

After convert_layout, the module looks like below:

```python
@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((2, 64, 56, 56), dtype="float32"),
        w: R.Tensor((32, 64, 3, 3), dtype="float32")
    ) -> R.Tensor((2, 32, 54, 54), dtype="float32"):
        with R.dataflow():
            lv: R.Tensor((2, 16, 56, 56, 4), dtype="float32") = R.layout_transform(
                x,
                index_map=T.index_map(
                    lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4)))
            lv1: R.Tensor((8, 64, 3, 3, 4), dtype="float32") = R.layout_transform(
                w,
                index_map=T.index_map(
                    lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4)))
            lv2: R.Tensor((2, 8, 54, 54, 4), dtype="float32") = R.nn.conv2d(
                lv,
                lv1,
                data_layout="NCHW4c",
                kernel_layout="OIHW4o",
                out_layout="NCHW4c",
                out_dtype="float32"
            )
            gv: R.Tensor((2, 32, 54, 54), dtype="float32") = R.layout_transform(
                lv2,
                index_map=T.index_map(
                    lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3)))
            R.output(gv)
        return gv
```

Here, the parameter layout transforms are injected properly and the conv2d op now operates on 5D shapes.

Now, the scope annotation decisions are made as follows:

  • For ops with op_pattern < kCommReduce we just check that the shape is 5D with an inner dimension of 4.
  • For ops with op_pattern > kCommReduce we decide selectively. Currently we enable texture scope for Conv2D and pool ops (see the sketch after this list).
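
These rules can be summarized with a small helper; this is an illustrative sketch only (names are hypothetical, not the actual pass code), using the fact that kCommReduce equals 3 in TVM's OpPatternKind:

```python
# Hypothetical decision helper mirroring the two rules above.
K_COMM_REDUCE = 3  # OpPatternKind::kCommReduce
TEXTURE_SELECTIVE_OPS = {"relax.nn.conv2d", "relax.nn.max_pool2d", "relax.nn.avg_pool2d"}


def prefers_texture(op_name: str, op_pattern: int, shape) -> bool:
    if op_pattern < K_COMM_REDUCE:
        # Lightweight (elemwise/broadcast/injective) ops: any 5D tensor
        # with an inner block of 4 qualifies for texture scope.
        return len(shape) == 5 and int(shape[-1]) == 4
    # Heavier ops: opt in selectively (Conv2D and pool ops for now).
    return op_name in TEXTURE_SELECTIVE_OPS
```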

The trick here is that while this pass runs we need op_pattern information for ops below kCommReduce, as well as op attributes for selected ops like Conv2D, pool ops, etc. op_pattern becomes available only after legalization, when the AnnotateTIROpPattern pass runs its analysis. However, op-specific attributes no longer exist after legalization.

To solve this issue, we do legalization in parts.

First, we call legalization while skipping the list of ops we do not want to legalize yet. LegalizeOps is enhanced to accept a skip_ops list for this purpose. After legalization and AnnotateTIROpPattern run this way, the module looks like:

```python
@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((2, 64, 56, 56), dtype="float32"),
        w: R.Tensor((32, 64, 3, 3), dtype="float32")
    ) -> R.Tensor((2, 32, 54, 54), dtype="float32"):
        with R.dataflow():
            lv = R.call_tir(cls.te_layout_transform, (x,),
                out_sinfo=R.Tensor((2, 16, 56, 56, 4), dtype="float32")
            )
            lv1 = R.call_tir(cls.te_layout_transform1, (w,),
                out_sinfo=R.Tensor((8, 64, 3, 3, 4), dtype="float32")
            )
            lv2: R.Tensor((2, 8, 54, 54, 4), dtype="float32") = R.nn.conv2d(
                lv,
                lv1,
                data_layout="NCHW4c",
                kernel_layout="OIHW4o",
                out_layout="NCHW4c",
                out_dtype="float32"
            )
            gv = R.call_tir(cls.te_layout_transform2, (lv2,),
                out_sinfo=R.Tensor((2, 32, 54, 54), dtype="float32")
            )
            R.output(gv)
        return gv
```

Here, the legalized prim functions carry the op_pattern attribute. We now have everything needed to run this pass.
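
A minimal sketch of this partial legalization step; the skip_ops argument is the enhancement proposed in this RFC (not part of the current upstream LegalizeOps signature), and the ops listed are only examples:

```python
# Partial legalization: keep selected ops un-legalized so their attributes
# survive, then annotate op_pattern on the already-legalized prim functions.
import tvm
from tvm import relax

SKIP_OPS = ["relax.nn.conv2d", "relax.nn.max_pool2d"]  # example skip list


def partial_legalize(mod: tvm.IRModule) -> tvm.IRModule:
    seq = tvm.transform.Sequential(
        [
            relax.transform.LegalizeOps(skip_ops=SKIP_OPS),  # RFC enhancement
            relax.transform.AnnotateTIROpPattern(),
        ]
    )
    return seq(mod)
```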

In principle, this pass performs scope annotation based on consumer priority, i.e. for any tensor object we try to assign the scope based on the consumer's requirement. Conflicts and multiple consumers of the same tensor are handled by injecting appropriate copies. The pass is built from the following components (a simplified sketch of the resolution idea follows the list):

  • CollectConsumerScopeInfo: a visitor that collects the scope demanded by each consumer for every input.
  • CollectProducerScopeInfo: a visitor that finalizes the scope for each input and output based on the consumer scope information, evaluating multiple-consumer cases and conflicts.
  • DefineVDevice: a pass that injects hint_on_device for each argument. It also tries to update the output StructInfo with VDevice information. This update is straightforward for tir calls, since sinfo_args in CallNode is meant for exactly this purpose. For other calls, sinfo_args is by design unused, because their StructInfo is derived via FInferStructInfo. Another issue with FInferStructInfo is that it cannot decide the memory scope, which this pass determines based on consumer demand. Hence, we use sinfo_args to convey this information: the pass sets sinfo_args for regular calls too, and the FInferStructInfo implementations take the VDevice information from this hint. This also solves the issue of mixed VDevices across the arguments of an op.
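
The consumer-priority resolution can be pictured with the following simplified sketch (plain Python, not the actual pass implementation): every tensor adopts the scope its consumers demand, and disagreeing consumers get an explicit copy.

```python
# Simplified model of CollectConsumerScopeInfo/CollectProducerScopeInfo:
# consumer_demands maps a tensor name to the scopes its consumers request.
def resolve_scopes(consumer_demands):
    final_scope, copies = {}, []
    for tensor, demands in consumer_demands.items():
        scopes = {scope for _, scope in demands}
        # Unanimous demand: adopt it; otherwise fall back to global.
        chosen = scopes.pop() if len(scopes) == 1 else "global"
        final_scope[tensor] = chosen
        # Consumers that want a different scope get a copy injected.
        for consumer, scope in demands:
            if scope != chosen:
                copies.append((tensor, consumer, scope))
    return final_scope, copies


demands = {"lv": [("conv2d", "global.texture-nhwc")],
           "lv1": [("conv2d", "global.texture-weight")]}
print(resolve_scopes(demands))
```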

After these steps the module looks like:

```python
@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((2, 64, 56, 56), dtype="float32"),
        w: R.Tensor((32, 64, 3, 3), dtype="float32")
    ) -> R.Tensor((2, 32, 54, 54), dtype="float32"):
        with R.dataflow():
            lv: R.Tensor((2, 64, 56, 56), dtype="float32") = R.hint_on_device(
                x, R.device(dev_type=4, dev_id=0), "global"
            )
            lv_1 = R.call_tir(cls.te_layout_transform, (lv,),
                out_sinfo=R.Tensor((2, 16, 56, 56, 4), dtype="float32",
                    vdevice="opencl:0:global.texture-nhwc"
                )
            )
            lv1: R.Tensor((32, 64, 3, 3), dtype="float32") = R.hint_on_device(
                w, R.device(dev_type=4, dev_id=0), "global"
            )
            lv1_1 = R.call_tir(cls.te_layout_transform1, (lv1,),
                out_sinfo=R.Tensor((8, 64, 3, 3, 4), dtype="float32",
                    vdevice="opencl:2:global.texture-weight"
                )
            )
            lv2: R.Tensor((2, 16, 56, 56, 4), dtype="float32",
                vdevice="opencl:0:global.texture-nhwc"
            ) = R.hint_on_device(lv_1, R.device(dev_type=4, dev_id=0), "global.texture-nhwc")
            lv3: R.Tensor((8, 64, 3, 3, 4), dtype="float32",
                vdevice="opencl:2:global.texture-weight"
            ) = R.hint_on_device(lv1_1, R.device(dev_type=4, dev_id=0), "global.texture-weight")
            lv2_1: R.Tensor((2, 8, 54, 54, 4), dtype="float32",
                vdevice="opencl:1:global"
            ) = R.nn.conv2d(
                lv2, lv3,
                data_layout="NCHW4c", kernel_layout="OIHW4o",
                out_layout="NCHW4c", out_dtype="float32",
                sinfo_args=(R.Tensor((2, 8, 54, 54, 4), dtype="float32",
                    vdevice="opencl:1:global"),
                )
            )
            lv4: R.Tensor((2, 8, 54, 54, 4), dtype="float32",
                vdevice="opencl:1:global"
            ) = R.hint_on_device(lv2_1, R.device(dev_type=4, dev_id=0), "global")
            gv = R.call_tir(cls.te_layout_transform2, (lv4,),
                out_sinfo=R.Tensor((2, 32, 54, 54), dtype="float32", vdevice="opencl:1:global")
            )
            R.output(gv)
        return gv
```

What we have above are the hint_on_device injections and out_sinfo annotations for all calls.

Now we apply RealizeVDevice to formalize the hints. We then call RemoveRedundantAssignments, which removes redundant assignments like:

lv: R.Tensor((2, 64, 56, 56), dtype="float32", vdevice="opencl:1:global") = x
lv1: R.Tensor((32, 64, 3, 3), dtype="float32", vdevice="opencl:1:global") = w

These assignments result from hint_on_device not realizing any copy when the consumer and producer share the same memory scope or VDevice. If left in place, they impact operator fusion later.
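
A sketch of this cleanup step; RealizeVDevice already exists in relax.transform, while RemoveRedundantAssignments is the new pass introduced by this RFC:

```python
# Realize the hint_on_device hints, then drop the trivial `y = x` bindings
# left behind when producer and consumer already share the same scope.
import tvm
from tvm import relax


def realize_and_cleanup(mod: tvm.IRModule) -> tvm.IRModule:
    mod = relax.transform.RealizeVDevice()(mod)
    mod = relax.transform.RemoveRedundantAssignments()(mod)  # RFC pass
    return mod
```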

Now the module looks like:

```python
@I.ir_module
class Module:
    @R.function
    def main(
        x: R.Tensor((2, 64, 56, 56), dtype="float32"),
        w: R.Tensor((32, 64, 3, 3), dtype="float32")
    ) -> R.Tensor((2, 32, 54, 54), dtype="float32"):
        with R.dataflow():
            lv = R.call_tir(cls.te_layout_transform, (x,),
                out_sinfo=R.Tensor((2, 16, 56, 56, 4), dtype="float32",
                    vdevice="opencl:0:global.texture-nhwc"
                )
            )
            lv1 = R.call_tir(cls.te_layout_transform1, (w,),
                out_sinfo=R.Tensor((8, 64, 3, 3, 4), dtype="float32",
                    vdevice="opencl:2:global.texture-weight"
                )
            )
            lv2: R.Tensor((2, 8, 54, 54, 4), dtype="float32",
                vdevice="opencl:1:global"
            ) = R.nn.conv2d(
                lv, lv1,
                data_layout="NCHW4c", kernel_layout="OIHW4o",
                out_layout="NCHW4c", out_dtype="float32",
                sinfo_args=(R.Tensor((2, 8, 54, 54, 4), dtype="float32",
                    vdevice="opencl:1:global"),
                )
            )
            gv = R.call_tir(cls.te_layout_transform2, (lv2,),
                out_sinfo=R.Tensor((2, 32, 54, 54), dtype="float32", vdevice="opencl:1:global")
            )
            R.output(gv)
        return gv
```

Subsequently, the compilation pipeline calls the following passes (a sketch of this tail of the pipeline follows the list):

  • Legalization of the remaining ops: this legalization forwards the annotated out_sinfo VDevice information to the tir calls.
  • AnnotateTIROpPattern: computes TIR op patterns for the newly legalized ops.
  • Fusion
  • OptimizeToVDeviceForScopeChange: there exist some to_vdevice copies from texture to buffer; this pass removes these copies and updates the producer scope to global.
  • SpecializePrimFuncBasedOnCallSite: finally, we update the buffer var maps according to the VDevice scopes.
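
A hedged sketch of this tail of the pipeline; LegalizeOps, AnnotateTIROpPattern, FuseOps and FuseTIR already exist upstream, while OptimizeToVDeviceForScopeChange and SpecializePrimFuncBasedOnCallSite are the passes proposed in this RFC:

```python
import tvm
from tvm import relax

pipeline_tail = tvm.transform.Sequential(
    [
        relax.transform.LegalizeOps(),               # legalize the remaining ops
        relax.transform.AnnotateTIROpPattern(),      # patterns for the new prim funcs
        relax.transform.FuseOps(),
        relax.transform.FuseTIR(),
        relax.transform.OptimizeToVDeviceForScopeChange(),    # RFC pass
        relax.transform.SpecializePrimFuncBasedOnCallSite(),  # RFC pass
    ]
)
```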

To Review:

FInferStructInfo enhancement: handling of mixed scopes is done by hinting the required scope info through sinfo_args in the CallNode.

Ref: tvm/python/tvm/relax/utils.py at commit cc03780b1cee0a06a26161181aede98b3a39d00f (apache/tvm on GitHub).

The output VDevice is assumed to be the same as the inputs' VDevice, and the hinted VDevice info is not forwarded there. Hence, we update the VDevice info in the legalization pass post-processing by looking at the CallNode's hinted StructInfo.
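
The idea in this note can be sketched as follows; this is illustrative only (the helper name and wiring are hypothetical, not the real FInferStructInfo registration):

```python
# If the CallNode carries a hinted StructInfo in sinfo_args, prefer its
# VDevice over the one plain input-based inference would produce.
from tvm import relax


def apply_vdevice_hint(call: relax.Call, inferred: relax.TensorStructInfo):
    if call.sinfo_args:  # hint injected by the scope-annotation pass
        hinted = call.sinfo_args[0]
        if isinstance(hinted, relax.TensorStructInfo) and hinted.vdevice is not None:
            return relax.TensorStructInfo(
                inferred.shape, inferred.dtype, vdevice=hinted.vdevice
            )
    return inferred
```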