Casting during reduction

I have implemented a correct-for-large-downscale-increments bilinear filter:

https://github.com/tech-ascent/tvm-clj/blob/master/src/tvm_clj/image/bilinear_reduce.clj#L122

I had to do this in 2 steps. The first does the reduction in floating point space and saves the result to a temp image. The second casts the temp image back to the original datatype.

Because of accumulated rounding error, I did not want to attempt the reduction completely in uint8 space. I realize that with enough care this should be possible, but let’s ignore that for the moment. Looking at the code, it should be possible to do the reduction into a floating point scalar register, then cast that scalar and store it into the dest image. I think this would be the fastest general way to implement it, especially combined with tiling.
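To make the idea concrete, here is a minimal plain-Python sketch of the fused version, not TVM code. The function name `downscale_row_u8` and the 1-D box average are assumptions for illustration only; the real filter is 2-D and bilinear. The point is just that the accumulator lives in a wider type and the cast happens once, right before the store.

```python
def downscale_row_u8(src, factor):
    """Average each run of `factor` uint8 samples into one uint8 sample.

    Hypothetical helper for illustration; a 1-D box average stands in
    for the real 2-D bilinear reduction.
    """
    if factor <= 0 or len(src) % factor != 0:
        raise ValueError("len(src) must be a positive multiple of factor")
    dst = bytearray(len(src) // factor)
    for i in range(len(dst)):
        acc = 0.0  # float accumulator, standing in for a scalar register
        for k in range(factor):
            acc += src[i * factor + k]
        # Round, clamp, and narrow exactly once, just before the store.
        value = int(acc / factor + 0.5)
        dst[i] = min(255, max(0, value))
    return dst


row = bytes([250, 252, 254, 255, 0, 1, 2, 3])
print(list(downscale_row_u8(row, 4)))  # prints [253, 2]
```

No temp image is needed here because each output element is finished in `acc` before anything is written to `dst`.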

I was unable to figure out how to do this with TVM. Reductions have to be top level nodes and I do not think that I can specify the datatype of the reduction operation to be different than the datatype of the underlying storage.

Ideally I could parameterize my algorithm on at least 2 parameters, regardless of execution environment: the storage datatype and the reduction datatype. In this case specifically, I know that the reduction can be done completely in a register before being written back into the result.

It is almost as if the reduce operator needs an optional 3rd argument: the operation to apply just before storing the result into the destination, which would default to identity.
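As a rough sketch of what I mean, in plain Python rather than TVM: `reduce_with_finalize` and its `finalize` parameter are hypothetical names, not an existing TVM API. The third piece is applied once to the accumulated value before it would be stored, and defaults to identity.

```python
from functools import reduce as fold


def reduce_with_finalize(combine, values, init, finalize=lambda acc: acc):
    """Fold `values` with `combine`, then apply `finalize` once to the result.

    Hypothetical helper: `finalize` models the "operation to apply just
    before storing" and defaults to identity.
    """
    return finalize(fold(combine, values, init))


pixels = [200, 100, 50, 25]

# Default finalize is identity, so the float accumulator is returned as is.
total = reduce_with_finalize(lambda a, b: a + b, pixels, 0.0)

# With a finalize step: average, round, and clamp to the uint8 range.
mean_u8 = reduce_with_finalize(
    lambda a, b: a + b,
    pixels,
    0.0,
    finalize=lambda acc: min(255, max(0, int(acc / len(pixels) + 0.5))),
)

print(total, mean_u8)  # prints 375.0 94
```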

  1. Can I do this type of operation in 1 step? Reduction in a different numerical space than the storage?

This is easy.

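import tvm
from tvm.contrib.util import get_lower_ir

n = 10

# Input is stored as int8.
A = tvm.placeholder((n, n), dtype='int8', name='A')
k = tvm.reduce_axis((0, n), name='k')
# Stage 1: widen each element to int32 and reduce in int32.
B = tvm.compute((n,), lambda i: tvm.sum(A[i, k].astype('int32'), k), name='reg')
# Stage 2: cast the reduced value back to the storage datatype.
C = tvm.compute((n,), lambda i: B[i].astype('int8'), name='true_storage')

s = tvm.create_schedule([C.op])
# Compute the reduction inside C's loop so B shrinks to a single element.
s[B].compute_at(s[C], s[C].op.axis[0])
print(get_lower_ir(s))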

Output:

// attr [true_storage] storage_scope = "global"
allocate true_storage[int8 * 10]
// attr [reg] storage_scope = "global"
allocate reg[int32 * 1]
produce true_storage {
  for (i, 0, 10) {
    produce reg {
      reg[0] = 0
      for (k, 0, 10) {
        reg[0] = (reg[0] + int32(A[((i*10) + k)]))
      }
    }
    true_storage[i] = int8(reg[0])
  }
}
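For reference, the loop nest above can be mirrored in plain Python. This is only a model, not TVM output: `lowered_model` is a hypothetical name, and it assumes the final `int8(...)` cast wraps the way a C narrowing cast typically does, so a sum outside the int8 range is not clamped.

```python
import ctypes


def lowered_model(A, n):
    """Plain-Python model of the lowered loop nest (hypothetical helper)."""
    true_storage = [0] * n
    for i in range(n):
        reg = 0  # Python int standing in for the one-element int32 buffer
        for k in range(n):
            reg += A[i * n + k]
        # Assumption: the narrowing cast wraps modulo 256 instead of clamping.
        true_storage[i] = ctypes.c_int8(reg).value
    return true_storage


n = 3
A = [1, 2, 3,
     100, 100, 100,
     -5, -6, -7]
print(lowered_model(A, n))  # prints [6, 44, -18]
```

The middle row sums to 300, which does not fit in int8, so under this assumption any rounding or clamping would have to be written into the second compute stage before the cast.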

That worked great, thanks!