Finally, texture scoping realization works in two stages:

- Stage 1 is generic and straightforward: the `convert_layout` pass transforms the shapes as well as injecting `layout_transform` ops as needed (a minimal invocation sketch follows this list).
- Stage 2 is the pass responsible for injecting the appropriate `VDevice` into `StructInfo` and adding any copies if there is a conflict between producer and consumer scopes.
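For reference, a minimal sketch of how Stage 1 can be invoked, assuming `relax.transform.ConvertLayout` with a hand-written desired-layout map (the actual pipeline may derive these layouts per target and op):

```python
from tvm import relax

# `mod` is the Relax IRModule being compiled.
# Assumed desired layouts for relax.nn.conv2d: NCHW4c activations and
# OIHW4o weights, matching the 5D shapes in the module shown below.
mod = relax.transform.ConvertLayout(
    {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}
)(mod)
```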
After `convert_layout`, the mod looks like below:

```python
@I.ir_module
class Module:
@R.function
def main(
x: R.Tensor((2, 64, 56, 56), dtype="float32"),
w: R.Tensor((32, 64, 3, 3), dtype="float32")
) -> R.Tensor((2, 32, 54, 54), dtype="float32"):
with R.dataflow():
lv: R.Tensor((2, 16, 56, 56, 4), dtype="float32") = R.layout_transform(
x,
index_map=T.index_map(
lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4)))
lv1: R.Tensor((8, 64, 3, 3, 4), dtype="float32") = R.layout_transform(
w,
index_map=T.index_map(
lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4)))
lv2: R.Tensor((2, 8, 54, 54, 4), dtype="float32") = R.nn.conv2d(
lv,
lv1,
data_layout="NCHW4c",
kernel_layout="OIHW4o",
out_layout="NCHW4c",
out_dtype="float32"
)
gv: R.Tensor((2, 32, 54, 54), dtype="float32") = R.layout_transform(
lv2,
index_map=T.index_map(
lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3)))
R.output(gv)
return gv
```
Here, the param layout transforms are injected properly and the conv2d op operates on 5D shapes.
Now, the scope annotation decisions are made as follows (see the sketch after this list):

- For `op_pattern < kCommReduce` we just look for the shape being 5D with the inner dimension equal to 4.
- For `op_pattern > kCommReduce` we make decisions selectively. Currently we enable texture scope for `Conv2D` and `Pool` ops.
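A hypothetical helper illustrating the first rule (the name `is_texture_compatible` is illustrative, not the actual pass code):

```python
# A 5D tensor whose innermost dimension is 4 maps naturally onto the
# 4-channel vector width used by texture memory.
def is_texture_compatible(shape) -> bool:
    return len(shape) == 5 and int(shape[-1]) == 4
```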
The trick here is that while this pass is in action we need op_pattern information for ops that are
below kCommReduce, as well as op attributes for selective ops like Conv2D, Pool ops, etc.
op_pattern is available after legalization, once the AnnotateTIROpPattern pass does its analysis. However,
op-specific attributes no longer exist after legalization.
To solve this issue, we perform legalization in parts.
At first, we call legalization while skipping the list of ops we do not want legalized yet.
`LegalizeOps` is enhanced to accept a `skip_ops` list for this purpose, roughly as sketched below.
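A sketch of this first, partial legalization step, assuming the enhanced `LegalizeOps(skip_ops=...)` interface described above (the stock upstream pass does not take this argument) and an illustrative skip list:

```python
import tvm
from tvm import relax

# Keep conv2d/pooling as Relax ops so their attributes stay visible to the
# scoping pass; legalize everything else and annotate op_pattern on the
# resulting PrimFuncs.
skip_ops = ["relax.nn.conv2d", "relax.nn.max_pool2d"]  # illustrative list

mod = tvm.transform.Sequential(
    [
        relax.transform.LegalizeOps(skip_ops=skip_ops),  # proposed enhancement
        relax.transform.AnnotateTIROpPattern(),
    ]
)(mod)
```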
After legalization and AnnotateTIROpPattern applied this way, the mod looks like:

```python
class Module:
@R.function
def main(
x: R.Tensor((2, 64, 56, 56), dtype="float32"),
w: R.Tensor((32, 64, 3, 3), dtype="float32")
) -> R.Tensor((2, 32, 54, 54), dtype="float32"):
with R.dataflow():
lv = R.call_tir(cls.te_layout_transform, (x,),
out_sinfo=R.Tensor((2, 16, 56, 56, 4), dtype="float32")
)
lv1 = R.call_tir(cls.te_layout_transform1, (w,),
out_sinfo=R.Tensor((8, 64, 3, 3, 4), dtype="float32")
)
lv2: R.Tensor((2, 8, 54, 54, 4), dtype="float32") = R.nn.conv2d(
lv,
lv1,
data_layout="NCHW4c",
kernel_layout="OIHW4o",
out_layout="NCHW4c",
out_dtype="float32"
)
gv = R.call_tir(cls.te_layout_transform2, (lv2,),
out_sinfo=R.Tensor((2, 32, 54, 54), dtype="float32")
)
R.output(gv)
return gv
```
Here, the legalized prim functions do have the op_pattern attribute, and we now have what we need to run this pass.
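For instance, the annotation can be read back roughly as follows; the attribute key matches what `AnnotateTIROpPattern` attaches, and the `kCommReduce` constant follows the `OpPatternKind` enumeration:

```python
import tvm

kCommReduce = 3  # OpPatternKind::kCommReduce

# Collect PrimFuncs whose pattern is below kCommReduce: these are the ones
# the scoping pass checks with the 5D / inner-dim-4 rule above.
light_funcs = {}
for gvar, func in mod.functions.items():
    if isinstance(func, tvm.tir.PrimFunc) and func.attrs is not None:
        if "op_pattern" in func.attrs and int(func.attrs["op_pattern"]) < kCommReduce:
            light_funcs[gvar] = int(func.attrs["op_pattern"])
```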
In principle, this pass does scope annotation based on consumer priority, i.e., for any tensor object we try to assign the scope based on the consumer requirement. Conflicts and multiple consumers for the same tensor are handled by injecting appropriate copies (a small illustration of this rule follows the list below). The pass is built from the following components:
- `CollectConsumerScopeInfo`: a visitor that collects the scope demand of every consumer for each input.
- `CollectProducerScopeInfo`: a visitor that finalizes the scope for each input and output based on the consumer scope information, evaluating multiple-consumer cases and conflicts.
- `DefineVDevice`: a pass that injects `hint_on_device` for each argument. It also tries to update the output `StructInfo` with `VDevice` information. For TIR calls this update is straightforward, as `sinfo_args` in `CallNode` is meant for this purpose. For other calls `sinfo_args` is by design unused, since their output struct info is derived via `FInferStructInfo`. Another issue with `FInferStructInfo` is that it cannot decide this memory scope information, which this pass determines from consumer demand. Hence we use `sinfo_args` to carry the hint: the pass populates `sinfo_args` for regular calls too, and the `FInferStructInfo` implementations pick up the `VDevice` information from this hint. This also solves the issue of mixed `VDevice` among the arguments of an op.
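A toy illustration of the consumer-priority rule described above (hypothetical helper, not the actual pass code):

```python
# Pick the producer's scope from its consumers' demands; every consumer that
# wants a different scope is served by an injected copy.
def resolve_scope(consumer_scopes):
    unique = list(dict.fromkeys(consumer_scopes))  # dedupe, keep order
    primary = unique[0]      # scope assigned to the producer
    needs_copy = unique[1:]  # scopes that require an explicit copy
    return primary, needs_copy

# Two consumers, one wanting texture and one a plain buffer:
# resolve_scope(["global.texture", "global"]) -> ("global.texture", ["global"])
```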
After these steps, the mod looks like:

```python
class Module:
@R.function
def main(
x: R.Tensor((2, 64, 56, 56), dtype="float32"),
w: R.Tensor((32, 64, 3, 3), dtype="float32")
) -> R.Tensor((2, 32, 54, 54), dtype="float32"):
with R.dataflow():
lv: R.Tensor((2, 64, 56, 56), dtype="float32") = R.hint_on_device(
x, R.device(dev_type=4, dev_id=0), "global"
)
lv_1 = R.call_tir(cls.te_layout_transform, (lv,),
out_sinfo=R.Tensor((2, 16, 56, 56, 4), dtype="float32",
vdevice="opencl:0:global.texture-nhwc"
)
)
lv1: R.Tensor((32, 64, 3, 3), dtype="float32") = R.hint_on_device(
w, R.device(dev_type=4, dev_id=0), "global"
)
lv1_1 = R.call_tir(cls.te_layout_transform1, (lv1,),
out_sinfo=R.Tensor((8, 64, 3, 3, 4), dtype="float32",
vdevice="opencl:2:global.texture-weight"
)
)
lv2: R.Tensor((2, 16, 56, 56, 4), dtype="float32",
vdevice="opencl:0:global.texture-nhwc"
) = R.hint_on_device(lv_1, R.device(dev_type=4, dev_id=0), "global.texture-nhwc")
lv3: R.Tensor((8, 64, 3, 3, 4), dtype="float32",
vdevice="opencl:2:global.texture-weight"
) = R.hint_on_device(lv1_1, R.device(dev_type=4, dev_id=0), "global.texture-weight")
lv2_1: R.Tensor((2, 8, 54, 54, 4), dtype="float32",
vdevice="opencl:1:global"
) = R.nn.conv2d(
lv2, lv3,
data_layout="NCHW4c", kernel_layout="OIHW4o",
out_layout="NCHW4c", out_dtype="float32",
sinfo_args=(R.Tensor((2, 8, 54, 54, 4), dtype="float32",
vdevice="opencl:1:global"),
)
)
lv4: R.Tensor((2, 8, 54, 54, 4), dtype="float32",
vdevice="opencl:1:global"
) = R.hint_on_device(lv2_1, R.device(dev_type=4, dev_id=0), "global")
gv = R.call_tir(cls.te_layout_transform2, (lv4,),
out_sinfo=R.Tensor((2, 32, 54, 54), dtype="float32", vdevice="opencl:1:global")
)
R.output(gv)
return gv
```
What we have above are the `hint_on_device` injections and `out_sinfo` for all calls.
Now, we apply `RealizeVDevice` to formalize the hints. After that we also call
`RemoveRedundantAssignments`, which removes redundant assignments like
lv: R.Tensor((2, 64, 56, 56), dtype="float32", vdevice="opencl:1:global") = x
lv1: R.Tensor((32, 64, 3, 3), dtype="float32", vdevice="opencl:1:global") = w
These assignments are the result of `hint_on_device` not realizing any copy when the consumer and producer have the same memory scope or vdevice. If left in place, they impact operator fusion later, which is why they are removed here (a sketch of this step follows).
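A sketch of this step; `RealizeVDevice` is an existing Relax pass, while `RemoveRedundantAssignments` is assumed here to be exposed the same way as part of the proposed pipeline:

```python
import tvm
from tvm import relax

# Formalize the hint_on_device hints into VDevice-annotated bindings, then
# drop the no-op aliases they leave behind so fusion is not disturbed later.
mod = tvm.transform.Sequential(
    [
        relax.transform.RealizeVDevice(),
        relax.transform.RemoveRedundantAssignments(),  # proposed pass
    ]
)(mod)
```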
Now the mod looks like:

```python
class Module:
@R.function
def main(
x: R.Tensor((2, 64, 56, 56), dtype="float32"),
w: R.Tensor((32, 64, 3, 3), dtype="float32")
) -> R.Tensor((2, 32, 54, 54), dtype="float32"):
with R.dataflow():
lv = R.call_tir(cls.te_layout_transform, (x,),
out_sinfo=R.Tensor((2, 16, 56, 56, 4), dtype="float32",
vdevice="opencl:0:global.texture-nhwc"
)
)
lv1 = R.call_tir(cls.te_layout_transform1, (w,),
out_sinfo=R.Tensor((8, 64, 3, 3, 4), dtype="float32",
vdevice="opencl:2:global.texture-weight"
)
)
lv2: R.Tensor((2, 8, 54, 54, 4), dtype="float32",
vdevice="opencl:1:global"
) = R.nn.conv2d(
lv, lv1,
data_layout="NCHW4c", kernel_layout="OIHW4o",
out_layout="NCHW4c", out_dtype="float32",
sinfo_args=(R.Tensor((2, 8, 54, 54, 4), dtype="float32",
vdevice="opencl:1:global"),
)
)
gv = R.call_tir(cls.te_layout_transform2, (lv2,),
out_sinfo=R.Tensor((2, 32, 54, 54), dtype="float32", vdevice="opencl:1:global")
)
R.output(gv)
return gv
```
Following this, the compilation pipeline calls (see the sketch after this list):

- Legalization of the remaining ops: this legalization forwards the annotated `out_sinfo` `VDevice` information to the `tir_calls`.
- `AnnotateTIROpPattern`: TIR op patterns for the newly legalized ops.
- Fusion.
- `OptimizeToVDeviceForScopeChange`: there exist some `ToVDevice` copies from texture to buffer; this pass removes those copies and updates the producer scope to global.
- `SpecializePrimFuncBasedOnCallSite`: finally, we update the buffer var maps according to the `VDevice` scopes.
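A hedged sketch of this tail of the pipeline; the fusion step is written here as the standard `FuseOps`/`FuseTIR` pair (an assumption), and the last two passes are the ones proposed in this post rather than stock upstream passes:

```python
import tvm
from tvm import relax

mod = tvm.transform.Sequential(
    [
        relax.transform.LegalizeOps(),                        # remaining ops
        relax.transform.AnnotateTIROpPattern(),               # patterns for new PrimFuncs
        relax.transform.FuseOps(),                            # assumed fusion step
        relax.transform.FuseTIR(),
        relax.transform.OptimizeToVDeviceForScopeChange(),    # proposed pass
        relax.transform.SpecializePrimFuncBasedOnCallSite(),  # proposed pass
    ]
)(mod)
```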
To review:

The `FInferStructInfo` enhancement for mixed scopes is handled by hinting the required scope info via `sinfo_args` in `CallNode`.
Ref. tvm/python/tvm/relax/utils.py at cc03780b1cee0a06a26161181aede98b3a39d00f · apache/tvm · GitHub
The output `VDevice` is assumed to be the same as the inputs' `VDevice`, and the hinted `VDevice` info does not get forwarded there. Hence, we update the `VDevice` info in the legalization pass post-processing by looking at the `CallNode`'s hinted `StructInfo`.