AttributeError: <class 'tvm.te.tensor.ExternOp'> has no attribute axis

Hello guys. I encounter an error when I enable the Relay fused-op pass, as follows:

  3: tvm::relay::tec::TECompilerImpl::LowerInternal(tvm::relay::tec::CCacheKey const&, tvm::GlobalVarSupply)
        at /cpfs01/user/hezeyu/code/tvm/src/relay/backend/te_compiler.cc:395
  2: tvm::relay::tec::PrimFuncFor(tvm::relay::Function const&, tvm::Target const&, tvm::GlobalVarSupply, tvm::NameSupply)
        at /cpfs01/user/hezeyu/code/tvm/src/relay/backend/te_compiler_cache.cc:756
  1: tvm::relay::tec::ScheduleBuilder::Create(tvm::relay::Function const&, tvm::GlobalVarSupply, tvm::NameSupply)
        at /cpfs01/user/hezeyu/code/tvm/src/relay/backend/te_compiler_cache.cc:692
  0: tvm::relay::OpImplementation::Schedule(tvm::Attrs const&, tvm::runtime::Array<tvm::te::Tensor, void> const&, tvm::Target const&)
        at /cpfs01/user/hezeyu/code/tvm/src/relay/ir/op_strategy.cc:41
  File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
  File "/cpfs01/user/hezeyu/code/tvm/python/tvm/relay/op/strategy/cuda.py", line 36, in schedule_injective_cuda
    return topi.cuda.schedule_injective(outs)
  File "/cpfs01/user/hezeyu/code/tvm/python/tvm/topi/cuda/injective.py", line 136, in schedule_injective
    schedule_injective_from_existing(s, out)
  File "/cpfs01/user/hezeyu/code/tvm/python/tvm/topi/cuda/injective.py", line 50, in schedule_injective_from_existing
    fused = sch[out].fuse(*sch[out].op.axis)
  File "/cpfs01/user/hezeyu/code/tvm/python/tvm/runtime/object.py", line 75, in __getattr__
    raise AttributeError(f"{type(self)} has no attribute {name}") from None
AttributeError: <class 'tvm.te.tensor.ExternOp'> has no attribute axis

This is my ops:

    fuse_ops: def @main(%xyz_sampled: Tensor[(32768, 128, 3), float32] /* ty=Tensor[(32768, 128, 3), float32] */, %ray_valid: Tensor[(32768, 128), int8] /* ty=Tensor[(32768, 128), int8] */, %output: Tensor[(32768, 128, 4), float32] /* ty=Tensor[(32768, 128, 4), float32] */, %masks: Tensor[(24, 32768, 128), int8] /* ty=Tensor[(24, 32768, 128), int8] */) -> Tensor[(32768, 128, 4), float32] {
      %1 = fn (%p0: Tensor[(32768, 128, 3), float32] /* ty=Tensor[(32768, 128, 3), float32] */, %p1: Tensor[(32768, 128), int8] /* ty=Tensor[(32768, 128), int8] */, %p2: Tensor[(32768, 128, 4), float32] /* ty=Tensor[(32768, 128, 4), float32] */, %p3: Tensor[(24, 32768, 128), int8] /* ty=Tensor[(24, 32768, 128), int8] */, Primitive=1) -> Tensor[(32768, 128, 4), float32] {
        %0 = add(%p0, 1f /* ty=float32 */) /* ty=Tensor[(32768, 128, 3), float32] */;
        assign_blocks_to_samples(%0, %p1, %p2, %p3, plane_x=4, plane_y=6) /* ty=Tensor[(32768, 128, 4), float32] */
      } /* ty=fn (Tensor[(32768, 128, 3), float32], Tensor[(32768, 128), int8], Tensor[(32768, 128, 4), float32], Tensor[(24, 32768, 128), int8]) -> Tensor[(32768, 128, 4), float32] */;
      %1(%xyz_sampled, %ray_valid, %output, %masks) /* ty=Tensor[(32768, 128, 4), float32] */
    }

“assign_blocks_to_samples” is a newly registered elementwise Relay op.

I found that it works if the ops are ordered like this instead:

fuse_ops: def @main(%xyz_sampled: Tensor[(32768, 128, 3), float32] /* ty=Tensor[(32768, 128, 3), float32] */, %ray_valid: Tensor[(32768, 128), int8] /* ty=Tensor[(32768, 128), int8] */, %output: Tensor[(32768, 128, 4), float32] /* ty=Tensor[(32768, 128, 4), float32] */, %masks: Tensor[(24, 32768, 128), int8] /* ty=Tensor[(24, 32768, 128), int8] */) -> Tensor[(32768, 128, 4), float32] {
  %1 = fn (%p0: Tensor[(32768, 128, 3), float32] /* ty=Tensor[(32768, 128, 3), float32] */, %p1: Tensor[(32768, 128), int8] /* ty=Tensor[(32768, 128), int8] */, %p2: Tensor[(32768, 128, 4), float32] /* ty=Tensor[(32768, 128, 4), float32] */, %p3: Tensor[(24, 32768, 128), int8] /* ty=Tensor[(24, 32768, 128), int8] */, Primitive=1) -> Tensor[(32768, 128, 4), float32] {
    %0 = assign_blocks_to_samples(%p0, %p1, %p2, %p3, plane_x=4, plane_y=6) /* ty=Tensor[(32768, 128, 4), float32] */;
    add(%0, 1f /* ty=float32 */) /* ty=Tensor[(32768, 128, 4), float32] */
  } /* ty=fn (Tensor[(32768, 128, 3), float32], Tensor[(32768, 128), int8], Tensor[(32768, 128, 4), float32], Tensor[(24, 32768, 128), int8]) -> Tensor[(32768, 128, 4), float32] */;
  %1(%xyz_sampled, %ray_valid, %output, %masks) /* ty=Tensor[(32768, 128, 4), float32] */
}

Why can my newly registered operator not be fused when it comes after relay.add?

Here is my test code:

# Minimal reproduction: build a Relay function that adds 1.0 to the input and
# then feeds the result into the custom extern op, then compile it for CUDA.
xyz_sampled_relay = relay.var('xyz_sampled', shape=(32768, 128, 3), dtype="float32")
ray_valid_relay = relay.var('ray_valid', shape=(32768, 128), dtype="int8")
output_relay = relay.var('output', shape=(32768, 128, 4), dtype="float32")
masks_relay = relay.var('masks', shape=(24, 32768, 128), dtype="int8")

# relay.add is an injective op; FuseOps will try to fuse the custom op into
# the same primitive function and schedule it with schedule_injective, which
# expects a ComputeOp (with .axis) — an ExternOp has no axis, hence the error.
aaaaa = relay.add(xyz_sampled_relay, relay.const(1.0))
abts = relay.assign_blocks_to_samples(aaaaa, ray_valid_relay, output_relay, masks_relay, 4, 6)

relay_func = relay.Function([xyz_sampled_relay, ray_valid_relay, output_relay, masks_relay], abts)
mod = tvm.IRModule.from_expr(relay_func)
target = tvm.target.Target(target="cuda")
device = tvm.device(target.kind.name, 0)  # NOTE: unused below; kept for completeness

# relay.build_config is deprecated — use tvm.transform.PassContext instead.
with tvm.transform.PassContext(opt_level=1):
    lib = relay.build_module.build(mod, target, params=None)  # error occurs here

Any help would be appreciated.

I also encounter a related problem: `AttributeError: <class 'tvm.te.tensor.PlaceholderOp'> has no attribute axis`.