Relay op argwhere with dynamic shapes fails on OpenCL devices that do not support double precision

The OpenCL extension cl_khr_fp64 is not supported by some OpenCL devices. tvm.topi.ceil_log2 is used by argwhere, and when the shape is dynamic, ceil_log2 returns cast(tvm.tir.ceil(tvm.tir.log2(cast(x, "float64"))), x.dtype), which requires fp64 support on the device.

So I added a condition in python/tvm/topi/math.py that makes ceil_log2 return cast(tvm.tir.ceil(tvm.tir.log2(cast(x, "float32"))), x.dtype) on OpenCL instead, and it worked:

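target = tvm.target.Target.current()  # None when no target scope is active
if target is not None and "opencl" in target.kind.name:
    if isinstance(x, tvm.tir.expr.Var):
        if "any_dim" in x.name and "int32" in x.dtype:
            # Fall back to float32 because the device lacks cl_khr_fp64
            return cast(tvm.tir.ceil(tvm.tir.log2(cast(x, "float32"))), x.dtype)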

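The default data type of any_dim is int32, but float32 has only a 24-bit significand, so int32 values above 2^24 (16,777,216) may be rounded when cast to float32. I think this means some large values, up to 2,147,483,647, would produce wrong results.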
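The concern can be checked in plain Python without TVM. The sketch below (helper names are hypothetical, for illustration only) rounds the input to float32 with the standard struct module and compares the float path against an exact integer ceil(log2(x)). Note that math.log2 runs in double precision here, so this models only the rounding of the input to float32, not the accuracy of OpenCL's float32 log2.

```python
import math
import struct


def to_float32(v: int) -> float:
    """Round an integer to the nearest IEEE-754 float32 value."""
    return struct.unpack("<f", struct.pack("<f", float(v)))[0]


def ceil_log2_float32(x: int) -> int:
    """Models cast(ceil(log2(cast(x, "float32"))), "int32")."""
    return math.ceil(math.log2(to_float32(x)))


def ceil_log2_exact(x: int) -> int:
    """Exact ceil(log2(x)) for x >= 1, using only integer operations."""
    return (x - 1).bit_length()


if __name__ == "__main__":
    for x in (1000, 2**24, 2**24 + 1, 2**30 + 1, 2**31 - 1):
        print(x, ceil_log2_float32(x), ceil_log2_exact(x))
```

The mismatches appear for values just above a power of two once x exceeds 2^24: for example 2^24 + 1 rounds down to 2^24, so the float32 path gives 24 instead of 25. 2^31 - 1 itself still comes out right, because it rounds up to 2^31 and ceil(log2) is 31 either way.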

In this case, what reasonable changes should I make?

I ran into a similar issue with Vulkan. My solution was to compute the integer log2 entirely in integer arithmetic, without casting to fp64. See https://github.com/apache/tvm/blob/e082ef5f9db89d3f0ff20340717b9d7547a25eaf/python/tvm/topi/math.py#L747-L750
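An integer-only ceil(log2(x)) based on count-leading-zeros can be modeled in plain Python to check it against the exact answer. This is only a sketch: `clz` here is a hypothetical stand-in for the hardware count-leading-zeros instruction, built from int.bit_length(), and inputs are assumed to satisfy 1 <= x < 2**bits.

```python
def clz(x: int, bits: int = 32) -> int:
    """Count leading zero bits of x in a `bits`-wide word (0 <= x < 2**bits)."""
    return bits - x.bit_length()


def ceil_log2_clz(x: int, bits: int = 32) -> int:
    """Integer-only ceil(log2(x)) for 1 <= x < 2**bits, mirroring
    if_then_else(x & (x - 1) == 0, bits - clz - 1, bits - clz)."""
    if x < 1 or x >= 2**bits:
        raise ValueError("x must satisfy 1 <= x < 2**bits")
    lz = clz(x, bits)
    if x & (x - 1) == 0:  # x is an exact power of two
        return bits - lz - 1
    return bits - lz


if __name__ == "__main__":
    for x in (1, 2, 3, 1000, 2**24 + 1, 2**31 - 1):
        print(x, ceil_log2_clz(x))
```

Because no floating-point conversion happens, large int32 values such as 2^24 + 1 get the correct result (25).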

You need to add support for computing clz via an OpenCL intrinsic. See https://github.com/apache/tvm/pull/7825


Hi masahi,

I just added

TVM_REGISTER_OP("tir.clz")
    .set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", DispatchPureExtern<Direct>);

to tvm/src/target/source/intrin_rule_opencl.cc

and used the same computation rule as Vulkan:

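target = tvm.target.Target.current()  # None when no target scope is active
if target is not None and "opencl" in target.kind.name:
    if isinstance(x, tvm.tir.expr.Var):
        if "any_dim" in x.name and "int" in x.dtype:
            # Integer-only ceil(log2(x)), same rule as the Vulkan path
            clz = tvm.tir.clz(x)
            bits = tvm.DataType(x.dtype).bits  # bit width, e.g. 32 for int32
            res = tvm.tir.if_then_else(x & (x - 1) == 0, bits - clz - 1, bits - clz)

            if res.dtype != x.dtype:
                return cast(res, x.dtype)

            return res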

With this change, the generated OpenCL kernel source looks like this:

start[(0)] = (((long)2 << ((((long)(((any_dim & (any_dim - 1)) == 0) ? (31 - clz(any_dim)) : (32 - clz(any_dim)))) - ((long)j)) - (long)1)) * ((((long)((int)get_group_id(0))) * (long)256) + ((long)((int)get_local_id(0)))));
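For reference, here is a short derivation of why the two branches in the generated kernel, 31 - clz(any_dim) and 32 - clz(any_dim), equal ceil(log2(any_dim)). It assumes a b-bit integer x with 1 ≤ x < 2^b (b = 32 here); x = 0 is excluded.

$$
\mathrm{clz}(x) = b - 1 - \lfloor \log_2 x \rfloor, \qquad 1 \le x < 2^b
$$

$$
\lceil \log_2 x \rceil =
\begin{cases}
\lfloor \log_2 x \rfloor = b - \mathrm{clz}(x) - 1, & x \,\&\, (x - 1) = 0 \ \text{(power of two)} \\
\lfloor \log_2 x \rfloor + 1 = b - \mathrm{clz}(x), & \text{otherwise}
\end{cases}
$$

With b = 32 the two cases are exactly 31 - clz(x) and 32 - clz(x), matching the ternary in the kernel.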

It works on OpenCL, too. Thank you so much for your suggestion.