Hi all,
I find that programming in TVM results in an extremely large number of unscoped variables. The main problem is that axes and tensors are not grouped, so simple mistakes produce extremely verbose low-level errors. 90% of my mistakes come from not keeping tensors and axes grouped together.
I'm curious what people think of a less low-level scheduling language. I generally write my code in the style below, which is much less verbose, avoids double splitting, and prevents errors from mixing up which axis belongs to which tensor.
ll, nn = s.axes(C)
reduce_axis = s.reduce_axis(C)
ll = ll.split(TPB)
nn = nn.split(TPB)
mm = reduce_axis.split(TPB)
s.reorder(C, (ll.outer, nn.outer, ll.inner, nn.inner, mm.outer, mm.inner))
# Bind blocks and threads to C
ll.outer.bind(te.thread_axis("blockIdx.x"))
nn.outer.bind(te.thread_axis("blockIdx.y"))
ll.inner.bind(tx)
nn.inner.bind(ty)
# Set up Caching
ll_A, mm_A = s.axes(AA)
ll_A = ll_A.split(TPB)
mm_A = mm_A.split(TPB)
s.reorder(AA, (ll_A.outer, mm_A.outer, ll_A.inner, mm_A.inner))
mm.outer.compute_at(AA)
ll_A.inner.bind(tx)
mm_A.inner.bind(ty)
Do people have any other tricks? Ideally there would be a really nice way to group together the splitting of two tensors in the same way (in this case ll_A mirrors ll, so why are they separate?)
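For anyone who wants to try this style, here is a rough sketch of how such a wrapper could be layered on top of the existing schedule object. It assumes the classic `te` schedule API, where `s[T]` returns a stage exposing `op.axis`, `op.reduce_axis`, `split(var, factor=...)`, `bind(var, thread)` and `reorder(*vars)`. The names `Axis`, `axes`, `reduce_axes`, `reorder` and `split_together` are made up for illustration and are not part of TVM.

```python
class Axis:
    """An iteration variable that remembers which stage (e.g. s[C]) owns it."""

    def __init__(self, stage, var):
        self.stage = stage
        self.var = var
        self.outer = None  # filled in by split()
        self.inner = None

    def split(self, factor):
        # Refuse to split the same axis twice.
        if self.outer is not None:
            raise ValueError("axis has already been split")
        outer, inner = self.stage.split(self.var, factor=factor)
        self.outer = Axis(self.stage, outer)
        self.inner = Axis(self.stage, inner)
        return self

    def bind(self, thread):
        self.stage.bind(self.var, thread)
        return self


def axes(s, tensor):
    """Wrap the spatial axes of `tensor` (hypothetical helper)."""
    stage = s[tensor]
    return [Axis(stage, v) for v in stage.op.axis]


def reduce_axes(s, tensor):
    """Wrap the reduction axes of `tensor` (hypothetical helper)."""
    stage = s[tensor]
    return [Axis(stage, v) for v in stage.op.reduce_axis]


def reorder(*order):
    """Reorder axes, rejecting any mix of axes from different tensors."""
    stage = order[0].stage
    if any(a.stage is not stage for a in order):
        raise ValueError("cannot reorder axes from different tensors")
    stage.reorder(*[a.var for a in order])


def split_together(factor, *group):
    """Split several mirrored axes (e.g. ll and ll_A) with one factor."""
    return [a.split(factor) for a in group]
```

With this, the mirrored splits could be written once, for example `split_together(TPB, ll, ll_A)`, and mixing an axis of C into a reorder of AA fails immediately with a short Python error instead of a low-level one.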