Hello all,
I have run into a problem that I am unsure how to solve, so I thought I would ask the community. Hopefully it is not too trivial.
We know that the scheduling primitive some_tvm_schedule[some_tvm_stage].split(some_tvm_axis, factor=some_factor) creates two axes, some_tvm_axis.outer and some_tvm_axis.inner, whose domains are (to my understanding):
some_tvm_axis.outer.dom = [0, ceil(previous_extent / some_factor))
some_tvm_axis.inner.dom = [0, some_factor)
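For concreteness, here is a minimal standalone sketch (1-D, with hypothetical names X/Y) of the split described above:

import tvm

n = 64
X = tvm.placeholder((n,), name="X")
Y = tvm.compute((n,), lambda i: X[i] * 2, name="Y")

s = tvm.create_schedule(Y.op)
# factor=16 -> i.outer should get extent ceil(64/16) = 4, i.inner extent 16
i_outer, i_inner = s[Y].split(Y.op.axis[0], factor=16)

print(tvm.lower(s, [X, Y], simple_mode=True))
# Expected loop nest (roughly):
#   for (i.outer, 0, 4)
#     for (i.inner, 0, 16)
#       Y[((i.outer*16) + i.inner)] = ...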
This is all clear and easy to understand… until the schedule has more than a couple of stages.
Why? With several stages, my impression is that schedule.normalize() and schedule.InferBound(...) produce larger domains for the stages farther away from the output (a minimal sketch of what I mean follows below).
- Is this impression correct, or am I seeing this expansion because of some other problem?
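For reference, a minimal two-stage sketch (not my real workload; the names A/B/C and the 3x3 window are just placeholders) that reproduces the kind of expansion I mean, i.e. a consumer reading a shifted window of an intermediate stage:

import tvm

H, W = 256, 512
A = tvm.placeholder((H, W), name="A")
# Intermediate stage, declared over [0, 256) x [0, 512)
B = tvm.compute((H, W), lambda h, w: A[h, w] * 2, name="B")
# Consumer reads B at h-1..h+1, w-1..w+1 (a "same"-size 3x3 window), so the
# region of B required by C is larger than B's declared domain.
# (The border taps fall outside B's declared shape; this is only meant to
# illustrate the bound expansion.)
rh = tvm.reduce_axis((0, 3), name="rh")
rw = tvm.reduce_axis((0, 3), name="rw")
C = tvm.compute((H, W),
                lambda h, w: tvm.sum(B[h + rh - 1, w + rw - 1], axis=[rh, rw]),
                name="C")

s = tvm.create_schedule(C.op)
print(tvm.lower(s, [A, B, C], simple_mode=True))
# The produce of B should now iterate over roughly [0, 258) x [0, 514) rather
# than 256 x 512, with if(likely(...)) guards around the stores - the same
# pattern as case #2 below.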
Suppose now that you want to split some axis and you want the split factors to be integer divisors of the original domains (e.g. if some_tvm_axis.dom.extent % 16 == 0, then some_factor = 4). Because the methods mentioned above (i.e. InferBound) change the domains of the axes of all stages, you can no longer expect the inferred bounds to still be integer multiples of the factors. But at the point where split(...) is called there is no information about the consumer stages expanding the original domains; in other words, you define all your split(...) calls assuming that the domains won't change.
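Concretely, the kind of factor selection I have in mind looks roughly like this (continuing the two-stage sketch above; the candidate list is arbitrary):

# Pick a factor that divides the extent visible at scheduling time.
# This assumes the extent is a compile-time constant (IntImm).
def pick_factor(iter_var, candidates=(512, 256, 64, 16, 8, 4)):
    extent = iter_var.dom.extent.value
    for f in candidates:
        if extent % f == 0:
            return f
    return 1

h_axis, w_axis = B.op.axis            # declared extents 256 and 512
ho, hi = s[B].split(h_axis, factor=pick_factor(h_axis))
wo, wi = s[B].split(w_axis, factor=pick_factor(w_axis))
# The chosen factors divide 256 and 512, but not the 258/514 that InferBound
# later assigns to this stage - which is exactly the problem described above.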
- Is there any way to track the changes of the domains before schedule.ScheduleOps(...), or am I again misunderstanding something?
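To make the question concrete, what I have in mind is something like the following, which (as far as I understand) mirrors what tvm.lower does internally, i.e. calling InferBound myself and inspecting the ranges before ScheduleOps:

norm = s.normalize()                            # still the schedule from the sketch above
bounds = tvm.schedule.InferBound(norm)          # Map: IterVar -> Range
for stage in norm.stages:
    for iv in stage.leaf_iter_vars:
        print(stage.op.name, iv.var.name, bounds[iv])
stmt = tvm.schedule.ScheduleOps(norm, bounds)   # the statement tvm.lower builds next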
I know that ir_pass.LoopPartition (if enabled) can deal with some of the cases where the domain gets expanded (a bunch of if(likely(...)) statements get simplified; how I enable it is sketched below), but I still see ir_pass.LoopPartition generate loops like for (variable_name, 0, 0), i.e. a loop that should never execute.
- Can all code inside the scope of a for (variable_name, 0, 0) be skipped? Always?
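For reference, this is roughly how I am enabling the pass, through the build config flag (assuming partition_const_loop is indeed the relevant knob; the output for my real case #2 is in the PS below):

# Enable partitioning of constant-extent loops, so the likely() guards of
# case #2 get split into the piecewise loops shown in the PS.
with tvm.build_config(partition_const_loop=True):
    print(tvm.lower(s, [A, B, C], simple_mode=True))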
FYI, here is the output of print(tvm.lower(..., simple_mode=True)) for two cases.
In both:
- Original domains of the axes: c=[0,64), h=[0,256) and w=[0,512)
- Splitting factors are 64, 256 and 512 respectively (i.e. perfect splits)
//First case taskDataOutput_T0_L1 is the output of the schedule
//Without Splitting
produce taskDataOutput_T0_L1 {
for (c, 0, 64) {
for (h, 0, 256) {
for (w, 0, 512) {
taskDataOutput_T0_L1[((((c*256) + h)*512) + w)] = previousStageTensor[((((c*256) + h)*512) + w)]
}
}
}
}
//With Splitting
produce taskDataOutput_T0_L1 {
for (c.inner, 0, 64) {
for (h.inner, 0, 256) {
for (w.inner, 0, 512) {
taskDataOutput_T0_L1[((((c.inner*256) + h.inner)*512) + w.inner)] = previousStageTensor[((((c.inner*256) + h.inner)*512) + w.inner)]
}
}
}
}
//Second case taskDataOutput_T0_L1 is an intermediate stage of the schedule
//Without Splitting
produce taskDataOutput_T0_L1 {
for (c, 0, 64) {
for (h, 0, 258) {//was originally 256
for (w, 0, 514) { //was originally 512
if (likely((1 <= h))) {
if (likely((h < 257))) {
if (likely((1 <= w))) {
if (likely((w < 513))) {
taskDataOutput_T0_L1[(((((c*256) + h)*512) + w) + -513)] = previousStageTensor[(((((c*256) + h)*512) + w) + -513)]
}
}
}
}
}
}
}
}
//With Splitting
produce taskDataOutput_T0_L1 {
for (h.outer, 0, 2) { // because ceil(258/256) == 2
for (w.outer, 0, 2) { // because ceil(514/512) == 2
for (c.inner, 0, 64) {
for (h.inner, 0, 256) {
for (w.inner, 0, 512) {
if (likely(((h.outer*256) < (258 - h.inner)))) {
if (likely(((w.outer*512) < (514 - w.inner)))) {
if (likely(((1 - h.inner) <= (h.outer*256)))) {
if (likely(((h.outer*256) < (257 - h.inner)))) {
if (likely(((1 - w.inner) <= (w.outer*512)))) {
if (likely(((w.outer*512) < (513 - w.inner)))) {
taskDataOutput_T0_L1[(((((((c.inner + h.outer)*256) + h.inner) + w.outer)*512) + w.inner) + -513)] = previousStageTensor[(((((((c.inner + h.outer)*256) + h.inner) + w.outer)*512) + w.inner) + -513)]
}
}
}
}
}
}
}
}
}
}
}
}
- Suggestions?
Thanks
PS:
The printed code after enabling ir_pass.LoopPartition in case #2 with splitting (which, in all honesty, looks awful but is probably correct):
produce taskDataOutput_T0_L1 {
for (h.outer, 0, 2) {
for (w.outer, 0, 2) {
for (c.inner, 0, 64) {
for (h.inner, 0, (1 - min((h.outer*256), 1))) {
for (w.inner, 0, 512) {
if (((h.outer*256) < (258 - h.inner))) {
if (((w.outer*512) < (514 - w.inner))) {
if (((1 - h.inner) <= (h.outer*256))) {
if (((h.outer*256) < (257 - h.inner))) {
if (((1 - w.inner) <= (w.outer*512))) {
if (((w.outer*512) < (513 - w.inner))) {
taskDataOutput_T0_L1[(((((((c.inner + h.outer)*256) + h.inner) + w.outer)*512) + w.inner) + -513)] = previousStageTensor[(((((((c.inner + h.outer)*256) + h.inner) + w.outer)*512) + w.inner) + -513)]
}
}
}
}
}
}
}
}
for (h.inner, 0, ((min((h.outer*256), 1) - max((h.outer*256), 2)) + 256)) {
for (w.inner, 0, (1 - min((w.outer*512), 1))) {
if (((w.outer*512) < (514 - w.inner))) {
if (((1 - w.inner) <= (w.outer*512))) {
if (((w.outer*512) < (513 - w.inner))) {
taskDataOutput_T0_L1[(((((((c.inner + h.outer)*256) + (h.inner - min((h.outer*256), 1))) + w.outer)*512) + w.inner) + -1)] = previousStageTensor[(((((((c.inner + h.outer)*256) + (h.inner - min((h.outer*256), 1))) + w.outer)*512) + w.inner) + -1)]
}
}
}
}
for (w.inner, 0, ((min((w.outer*512), 1) - max((w.outer*512), 2)) + 512)) {
taskDataOutput_T0_L1[((((((c.inner + h.outer)*256) + (h.inner - min((h.outer*256), 1))) + w.outer)*512) + (w.inner - min((w.outer*512), 1)))] = previousStageTensor[((((((c.inner + h.outer)*256) + (h.inner - min((h.outer*256), 1))) + w.outer)*512) + (w.inner - min((w.outer*512), 1)))]
}
for (w.inner, 0, (max((w.outer*512), 2) + -1)) {
if (((w.outer*512) < ((max((w.outer*512), 2) - w.inner) + 1))) {
if ((((max((w.outer*512), 2) - w.inner) + -512) <= (w.outer*512))) {
if (((w.outer*512) < (max((w.outer*512), 2) - w.inner))) {
taskDataOutput_T0_L1[(((((((c.inner + h.outer)*256) + (h.inner - min((h.outer*256), 1))) + w.outer)*512) + (w.inner - max((w.outer*512), 2))) + 512)] = previousStageTensor[(((((((c.inner + h.outer)*256) + (h.inner - min((h.outer*256), 1))) + w.outer)*512) + (w.inner - max((w.outer*512), 2))) + 512)]
}
}
}
}
}
for (h.inner, 0, (max((h.outer*256), 2) + -1)) {
for (w.inner, 0, 512) {
if (((h.outer*256) < ((max((h.outer*256), 2) - h.inner) + 1))) {
if (((w.outer*512) < (514 - w.inner))) {
if ((((max((h.outer*256), 2) - h.inner) + -256) <= (h.outer*256))) {
if (((h.outer*256) < (max((h.outer*256), 2) - h.inner))) {
if (((1 - w.inner) <= (w.outer*512))) {
if (((w.outer*512) < (513 - w.inner))) {
taskDataOutput_T0_L1[(((((((c.inner + h.outer)*256) + (h.inner - max((h.outer*256), 2))) + w.outer)*512) + w.inner) + 131071)] = previousStageTensor[(((((((c.inner + h.outer)*256) + (h.inner - max((h.outer*256), 2))) + w.outer)*512) + w.inner) + 131071)]
}
}
}
}
}
}
}
}
}
}
}
}