Finally, texture scope realization works in two stages:

- Stage 1 is generic and straightforward: it uses the ConvertLayout pass, which transforms the shapes and injects layout_transform ops as needed (a sketch follows this list).
- Stage 2 injects the appropriate VDevice into StructInfo and adds copies wherever producer and consumer scopes conflict.
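For reference, a minimal sketch of invoking Stage 1 (assuming the standard relax ConvertLayout pass accepts the sub-indexed layouts used here; the desired-layout map is illustrative):

```python
from tvm import relax

# Hedged sketch: ask ConvertLayout to move conv2d to the 4-channel
# blocked layouts (NCHW4c activations, OIHW4o weights) that map onto
# RGBA texels; layout_transform ops are injected around it as needed.
mod = relax.transform.ConvertLayout(
    {"relax.nn.conv2d": ["NCHW4c", "OIHW4o"]}
)(mod)
```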
After ConvertLayout, the module looks like below:
```python
@I.ir_module
class Module:
@R.function
def main(
x: R.Tensor((2, 64, 56, 56), dtype="float32"),
w: R.Tensor((32, 64, 3, 3), dtype="float32")
) -> R.Tensor((2, 32, 54, 54), dtype="float32"):
with R.dataflow():
lv: R.Tensor((2, 16, 56, 56, 4), dtype="float32") = R.layout_transform(
x,
index_map=T.index_map(
lambda i0, i1, i2, i3: (i0, i1 // 4, i2, i3, i1 % 4)))
lv1: R.Tensor((8, 64, 3, 3, 4), dtype="float32") = R.layout_transform(
w,
index_map=T.index_map(
lambda i0, i1, i2, i3: (i0 // 4, i1, i2, i3, i0 % 4)))
lv2: R.Tensor((2, 8, 54, 54, 4), dtype="float32") = R.nn.conv2d(
lv,
lv1,
data_layout="NCHW4c",
kernel_layout="OIHW4o",
out_layout="NCHW4c",
out_dtype="float32"
)
gv: R.Tensor((2, 32, 54, 54), dtype="float32") = R.layout_transform(
lv2,
index_map=T.index_map(
lambda i0, i1, i2, i3, i4: (i0, i1 * 4 + i4, i2, i3)))
R.output(gv)
return gv
```
Here, the parameter layout transforms are injected properly, and the conv2d op now operates on 5D shapes.
Now, the scope annotation decisions are made as follows:

- For ops with op_pattern < kCommReduce, we just check that the shape is 5D with an innermost dimension of 4 (see the sketch after this list).
- For ops with op_pattern > kCommReduce, we make decisions selectively. Currently we enable texture scope for Conv2D and the pooling ops.
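A hedged sketch of the eligibility check for the first case (the helper name is illustrative, not the actual pass code):

```python
# For injective/broadcast ops (op_pattern < kCommReduce), texture scope
# is applicable when the tensor is 5D and the innermost (packed)
# dimension is 4, matching an RGBA texel.
def is_texture_compatible(shape) -> bool:
    return len(shape) == 5 and int(shape[-1]) == 4
```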
The trick here is that while this pass is in action we need op_pattern information for ops below kCommReduce, as well as op attributes for selected ops like Conv2D, the pooling ops, etc. op_pattern is available after legalization, once the AnnotateTIROpPattern pass has done its analysis. However, op-specific attributes no longer exist after legalization.
To solve this issue, we perform legalization in parts. First, we call legalization while skipping the list of ops we do not want legalized yet. LegalizeOps is enhanced to accept a skip_ops list for this purpose, as sketched below.
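A minimal sketch of this partial legalization step (the skip_ops parameter is the enhancement proposed here, not part of the released LegalizeOps API; the skipped op list is illustrative):

```python
from tvm import relax

# Hedged sketch: legalize everything except the ops whose attributes the
# scoping pass still needs, then run the TIR op-pattern analysis.
mod = relax.transform.LegalizeOps(skip_ops=["relax.nn.conv2d"])(mod)
mod = relax.transform.AnnotateTIROpPattern()(mod)
```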
After legalization and AnnotateTIROpPattern applied this way, the module looks like:
```python
@I.ir_module
class Module:
@R.function
def main(
x: R.Tensor((2, 64, 56, 56), dtype="float32"),
w: R.Tensor((32, 64, 3, 3), dtype="float32")
) -> R.Tensor((2, 32, 54, 54), dtype="float32"):
with R.dataflow():
lv = R.call_tir(cls.te_layout_transform, (x,),
out_sinfo=R.Tensor((2, 16, 56, 56, 4), dtype="float32")
)
lv1 = R.call_tir(cls.te_layout_transform1, (w,),
out_sinfo=R.Tensor((8, 64, 3, 3, 4), dtype="float32")
)
lv2: R.Tensor((2, 8, 54, 54, 4), dtype="float32") = R.nn.conv2d(
lv,
lv1,
data_layout="NCHW4c",
kernel_layout="OIHW4o",
out_layout="NCHW4c",
out_dtype="float32"
)
gv = R.call_tir(cls.te_layout_transform2, (lv2,),
out_sinfo=R.Tensor((2, 32, 54, 54), dtype="float32")
)
R.output(gv)
return gv
```
Here, the legalized prim functions do carry the op_pattern attribute, so we now have everything we need to run this pass. In principle, this pass does scope annotation based on consumer priority, i.e. for each tensor object we try to assign a scope based on the consumer's requirement. Conflicts and multiple consumers of the same tensor are handled by injecting appropriate copies.
- CollectConsumerScopeInfo: a visitor that collects the scope demand of every consumer for each input.
- CollectProducerScopeInfo: a visitor that finalizes the scope for each input and output based on the consumer scope information, evaluating multiple-consumer cases and conflicts (see the sketch after this list).
- DefineVDevice: a pass that injects hint_on_device for each argument. It also tries to update the output StructInfo with VDevice information. For TIR calls this update is straightforward, since sinfo_args in CallNode is meant for exactly this purpose. For other calls sinfo_args is by design unused, as their StructInfo is derived via "FInferStructInfo". A further issue with "FInferStructInfo" is that it cannot decide the memory scope, which this pass determines from consumer demand. Hence we use sinfo_args to carry this information: the pass populates sinfo_args for regular calls too, and the FInferStructInfo implementations take the VDevice information from this hint. This also solves the issue of mixed VDevice across an op's arguments.
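A simplified, hypothetical sketch of the consumer-priority rule that CollectProducerScopeInfo applies (the helper name is illustrative, not the actual pass code):

```python
# If every consumer of a tensor demands the same scope, the producer is
# assigned that scope; on a conflict we fall back to "global", and
# DefineVDevice later injects copies for the texture-demanding consumers.
def finalize_scope(consumer_scopes):
    unique = set(consumer_scopes)
    return unique.pop() if len(unique) == 1 else "global"
```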
After these steps, the module looks like:
```python
@I.ir_module
class Module:
@R.function
def main(
x: R.Tensor((2, 64, 56, 56), dtype="float32"),
w: R.Tensor((32, 64, 3, 3), dtype="float32")
) -> R.Tensor((2, 32, 54, 54), dtype="float32"):
with R.dataflow():
lv: R.Tensor((2, 64, 56, 56), dtype="float32") = R.hint_on_device(
x, R.device(dev_type=4, dev_id=0), "global"
)
lv_1 = R.call_tir(cls.te_layout_transform, (lv,),
out_sinfo=R.Tensor((2, 16, 56, 56, 4), dtype="float32",
vdevice="opencl:0:global.texture-nhwc"
)
)
lv1: R.Tensor((32, 64, 3, 3), dtype="float32") = R.hint_on_device(
w, R.device(dev_type=4, dev_id=0), "global"
)
lv1_1 = R.call_tir(cls.te_layout_transform1, (lv1,),
out_sinfo=R.Tensor((8, 64, 3, 3, 4), dtype="float32",
vdevice="opencl:2:global.texture-weight"
)
)
lv2: R.Tensor((2, 16, 56, 56, 4), dtype="float32",
vdevice="opencl:0:global.texture-nhwc"
) = R.hint_on_device(lv_1, R.device(dev_type=4, dev_id=0), "global.texture-nhwc")
lv3: R.Tensor((8, 64, 3, 3, 4), dtype="float32",
vdevice="opencl:2:global.texture-weight"
) = R.hint_on_device(lv1_1, R.device(dev_type=4, dev_id=0), "global.texture-weight")
lv2_1: R.Tensor((2, 8, 54, 54, 4), dtype="float32",
vdevice="opencl:1:global"
) = R.nn.conv2d(
lv2, lv3,
data_layout="NCHW4c", kernel_layout="OIHW4o",
out_layout="NCHW4c", out_dtype="float32",
sinfo_args=(R.Tensor((2, 8, 54, 54, 4), dtype="float32",
vdevice="opencl:1:global"),
)
)
lv4: R.Tensor((2, 8, 54, 54, 4), dtype="float32",
vdevice="opencl:1:global"
) = R.hint_on_device(lv2_1, R.device(dev_type=4, dev_id=0), "global")
gv = R.call_tir(cls.te_layout_transform2, (lv4,),
out_sinfo=R.Tensor((2, 32, 54, 54), dtype="float32", vdevice="opencl:1:global")
)
R.output(gv)
return gv
```
What we have above is hint_on_device injections and out_sinfo on every call. Now we apply RealizeVDevice to formalize the hints, followed by RemoveRedundantAssignments, which removes redundant assignments like:
```python
lv: R.Tensor((2, 64, 56, 56), dtype="float32", vdevice="opencl:1:global") = x
lv1: R.Tensor((32, 64, 3, 3), dtype="float32", vdevice="opencl:1:global") = w
```
These assignments result from hint_on_device not realizing any copy when the consumer and producer share the same memory scope or VDevice; left in place, they would impact operator fusion later.
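A minimal sketch of this cleanup step (RealizeVDevice is the existing relax pass; RemoveRedundantAssignments is the new pass introduced by this work):

```python
from tvm import relax

# Formalize hint_on_device annotations into VDevice-typed bindings, then
# drop the no-op rebindings shown above.
mod = relax.transform.RealizeVDevice()(mod)
mod = relax.transform.RemoveRedundantAssignments()(mod)
```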
Now the module looks like:
```python
@I.ir_module
class Module:
@R.function
def main(
x: R.Tensor((2, 64, 56, 56), dtype="float32"),
w: R.Tensor((32, 64, 3, 3), dtype="float32")
) -> R.Tensor((2, 32, 54, 54), dtype="float32"):
with R.dataflow():
lv = R.call_tir(cls.te_layout_transform, (x,),
out_sinfo=R.Tensor((2, 16, 56, 56, 4), dtype="float32",
vdevice="opencl:0:global.texture-nhwc"
)
)
lv1 = R.call_tir(cls.te_layout_transform1, (w,),
out_sinfo=R.Tensor((8, 64, 3, 3, 4), dtype="float32",
vdevice="opencl:2:global.texture-weight"
)
)
lv2: R.Tensor((2, 8, 54, 54, 4), dtype="float32",
vdevice="opencl:1:global"
) = R.nn.conv2d(
lv, lv1,
data_layout="NCHW4c", kernel_layout="OIHW4o",
out_layout="NCHW4c", out_dtype="float32",
sinfo_args=(R.Tensor((2, 8, 54, 54, 4), dtype="float32",
vdevice="opencl:1:global"),
)
)
gv = R.call_tir(cls.te_layout_transform2, (lv2,),
out_sinfo=R.Tensor((2, 32, 54, 54), dtype="float32", vdevice="opencl:1:global")
)
R.output(gv)
return gv
```
After this, the compilation pipeline runs:

- Legalization of the remaining ops: this legalization forwards the annotated out_sinfo VDevice information to the TIR calls.
- AnnotateTIROpPattern: annotates TIR op patterns for the newly legalized ops.
- Fusion.
- OptimizeToVDeviceForScopeChange: some ToVDevice copies from texture to buffer remain; this pass removes these copies and updates the producer scope to global.
- SpecializePrimFuncBasedOnCallSite: finally, we update the buffer var maps according to the VDevice scopes.
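A hedged sketch of this tail of the pipeline (OptimizeToVDeviceForScopeChange and SpecializePrimFuncBasedOnCallSite are the new passes described above; representing "Fusion" with the standard relax FuseOps/FuseTIR passes is an assumption):

```python
import tvm
from tvm import relax

seq = tvm.transform.Sequential([
    relax.transform.LegalizeOps(),           # legalize the remaining ops
    relax.transform.AnnotateTIROpPattern(),  # patterns for newly legalized ops
    relax.transform.FuseOps(),               # fusion
    relax.transform.FuseTIR(),
    relax.transform.OptimizeToVDeviceForScopeChange(),
    relax.transform.SpecializePrimFuncBasedOnCallSite(),
])
mod = seq(mod)
```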
To review:

- The FInferStructInfo enhancement for mixed scopes is handled by hinting the required scope info via sinfo_args in CallNode.
  Ref. tvm/python/tvm/relax/utils.py at cc03780b1cee0a06a26161181aede98b3a39d00f · apache/tvm · GitHub
- The output VDevice is assumed to be the same as the inputs' VDevice, and the hinted VDevice info is not forwarded here. Hence, we update the VDevice info in legalization post-processing by looking at the hinted StructInfo on the CallNode.