How can a cache_write() stage be computed at the "k" axis in GEMM?

It looks like the current TVM release (v0.7) does not allow a write cache to be computed at the "k" axis in GEMM. Please correct me if I'm wrong.

Say I want to create a write cache for matrix C in GEMM and make "k" the outermost axis of the loop nest. The schedule code I wrote looks like this:

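```python
# Create a write cache for C in global memory
CC = s.cache_write(C, "global")

# k is the reduce axis of the original GEMM compute (e.g. created with te.reduce_axis)
# The following line crashes:
s[CC].compute_at(s[C], k)
```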

I understand this is expected behavior: after cache_write(), the stage s[C] no longer has the reduce axis "k" (the reduction is moved into the cache stage CC, and C becomes a plain copy of CC), so computing at "k" is not allowed. That is why computing at "x" or "y" works, while computing at "k" does not.

My question is: does this mean that the cache "CC" can never be computed in the loop order [k, x, y] or [k, y, x], i.e. whenever "k" is the outermost axis? If it can, how?
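To make the loop order in question concrete, here is a plain-Python model of what the computation would do with "k" outermost in CC's loop nest, followed by the separate copy from CC into C that cache_write() introduces. This is only a sketch of the loop semantics under those assumptions, not TVM code, and the helper name `gemm_kxy_with_write_cache` is hypothetical.

```python
def gemm_kxy_with_write_cache(A, B):
    """Hypothetical helper modelling the intended schedule (not a TVM API).

    A is n x m and B is m x p, both given as lists of lists. The write cache
    CC accumulates the reduction with k as the outermost loop, i.e. loop
    order [k, x, y]. C is then produced by a separate copy from CC,
    mirroring the copy stage that cache_write() leaves in place of C.
    """
    n, m, p = len(A), len(B), len(B[0])
    # Write cache, zero-initialized before the reduction starts
    CC = [[0] * p for _ in range(n)]
    # Reduction stage with k outermost: loop order [k, x, y]
    for k in range(m):
        for x in range(n):
            for y in range(p):
                CC[x][y] += A[x][k] * B[k][y]
    # Copy stage: C only reads the finished cache, so it has no k loop
    C = [[CC[x][y] for y in range(p)] for x in range(n)]
    return C


A = [[1, 2], [3, 4]]
B = [[5, 6], [7, 8]]
print(gemm_kxy_with_write_cache(A, B))  # [[19, 22], [43, 50]]
```

In this model, each entry of CC is only final after the last k iteration, so the copy into C can only run once the whole k loop has finished. That matches the observation that C no longer carries a "k" axis for CC to be attached to.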