Hi. I’d like to have an ‘argmax’ operation in NNVM. I found that it already exists in TVM and only NNVM is missing it, so I tried to add it using the existing code for ‘max’ as a template.
I’ve tested topi::argmax with a small C++ program, and it seems to work correctly. Unfortunately, when the same function is called from NNVM, the information about the return type seems to get lost at some point: where the correct answer is 1, sym.argmax returns 1e-45, which is probably the int32 result being treated as float32.
I’ve published the current patch as https://github.com/dmlc/tvm/pull/1462. Could you please suggest which module may need additional debugging? In particular, did I write the FInferType-related code correctly?
Regards