[TVM 0.8] Check failed: Expected Array[Array[PrimExpr]], but got Array[index 0: Array[index 0: Array]]

This case works in TVM 0.6, but fails with a "Check failed" error in TVM 0.8.

The steps to reproduce the problem in TVM 0.8 are similar to the following:

  1. Add a method to `Stage`, for example:

    // include/tvm/te/schedule.h
    /*!
     * \brief Debug helper: print every element of a nested PrimExpr array.
     * \param input Nested array whose inner elements must all be PrimExpr.
     * \return *this, to allow call chaining.
     */
    TVM_DLL Stage& TestArrayArrayPrimExpr(Array<Array<PrimExpr>> input);

    // src/te/schedule/schedule_lang.cc
    Stage& Stage::TestArrayArrayPrimExpr(Array<Array<PrimExpr>> input) {
      // Bind by const reference instead of copying each element on every iteration.
      for (const auto& out : input) {
        for (const auto& in : out) {
          // '\n' rather than std::endl: no stream flush per printed element.
          std::cout << __FILE__ << ": " << __LINE__ << ": " << __FUNCTION__ << ": in is " << in
                    << '\n';
        }
      }
      return *this;
    }
    

The parameter type is a nested Array: `Array<Array<PrimExpr>>`.

  2. Then register the method as a global function so it can be called from Python, for example:

    // Expose Stage::TestArrayArrayPrimExpr to Python under the name "te.TestArrayArrayPrimExpr".
    // set_body_method turns the member function into a packed function whose argument 0 is the
    // Stage and argument 1 is the Array<Array<PrimExpr>> input.
    TVM_REGISTER_GLOBAL("te.TestArrayArrayPrimExpr").set_body_method(&Stage::TestArrayArrayPrimExpr);
    
  3. Then wrap the function as a method of `tvm.te.schedule.Stage` (in `tvm/te/schedule.py`), for example:

    def test_aa(self, *args):
        """Call the packed function ``TestArrayArrayPrimExpr`` with this Stage first.

        Every positional argument is forwarded unchanged after ``self``. The C++
        side expects an ``Array<Array<PrimExpr>>``. In TVM 0.8, Python ``str``
        elements are converted to ``runtime.String`` rather than a PrimExpr, so
        they fail the argument type check.

        Returns None; the Stage returned by the C++ function is discarded.
        """
        packed_args = (self,) + args
        _ffi_api.TestArrayArrayPrimExpr(*packed_args)
    
  4. Create a Python test that calls the new method, for example:

    def test_nest_array_case(self):
        """Reproduce the nested-array argument check failure seen in TVM 0.8.

        In TVM 0.8 the tensor-expression API (placeholder, compute,
        create_schedule) is reached through ``tvm.te`` rather than the
        top-level ``tvm`` namespace used in TVM 0.6.
        """
        data = tvm.te.placeholder((32, 32), name="data", dtype="float16")
        out = tvm.te.compute((32, 32), lambda *indice: data(*indice) + 1, name="out")
        # schedule
        s = tvm.te.create_schedule(out.op)
        # Only the side effect on the schedule matters here; the returned tensor is unused.
        s.cache_read(data, "shared", [out])
        outL = s.cache_write(out, "local")
        # The str elements are converted to runtime.String, not PrimExpr, which is
        # what triggers "Expected Array[Array[PrimExpr]]" in TVM 0.8.
        s[outL].test_aa([['a', 'b'], ['c', 'd'], ['e', 123]])
    
  5. Running the test produces the following stack trace and check-failure message:

    tvm._ffi.base.TVMError: Traceback (most recent call last):
      1: TVMFuncCall
      0: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::te::Stage& (tvm::te::Stage, tvm::runtime::Array<tvm::runtime::Array<tvm::PrimExpr, void>, void>)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::te::Stage, tvm::te::Stage&, tvm::runtime::Array<tvm::runtime::Array<tvm::PrimExpr, void>, void> >(tvm::te::Stage& (tvm::te::Stage::*)(tvm::runtime::Array<tvm::runtime::Array<tvm::PrimExpr, void>, void>))::{lambda(tvm::te::Stage, tvm::runtime::Array<tvm::runtime::Array<tvm::PrimExpr, void>, void>)#1}>(tvm::runtime::Registry::set_body_method<tvm::te::Stage, tvm::te::Stage&, tvm::runtime::Array<tvm::runtime::Array<tvm::PrimExpr, void>, void> >(tvm::te::Stage& (tvm::te::Stage::*)(tvm::runtime::Array<tvm::runtime::Array<tvm::PrimExpr, void>, void>))::{lambda(tvm::te::Stage, tvm::runtime::Array<tvm::runtime::Array<tvm::PrimExpr, void>, void>)#1}, std::string)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
      2: TVMFuncCall
      1: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::te::Stage& (tvm::te::Stage, tvm::runtime::Array<tvm::runtime::Array<tvm::PrimExpr, void>, void>)>::AssignTypedLambda<tvm::runtime::Registry::set_body_method<tvm::te::Stage, tvm::te::Stage&, tvm::runtime::Array<tvm::runtime::Array<tvm::PrimExpr, void>, void> >(tvm::te::Stage& (tvm::te::Stage::*)(tvm::runtime::Array<tvm::runtime::Array<tvm::PrimExpr, void>, void>))::{lambda(tvm::te::Stage, tvm::runtime::Array<tvm::runtime::Array<tvm::PrimExpr, void>, void>)#1}>(tvm::runtime::Registry::set_body_method<tvm::te::Stage, tvm::te::Stage&, tvm::runtime::Array<tvm::runtime::Array<tvm::PrimExpr, void>, void> >(tvm::te::Stage& (tvm::te::Stage::*)(tvm::runtime::Array<tvm::runtime::Array<tvm::PrimExpr, void>, void>))::{lambda(tvm::te::Stage, tvm::runtime::Array<tvm::runtime::Array<tvm::PrimExpr, void>, void>)#1}, std::string)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
      0: tvm::runtime::Array<tvm::runtime::Array<tvm::PrimExpr, void>, void> tvm::runtime::TVMPODValue_::AsObjectRef<tvm::runtime::Array<tvm::runtime::Array<tvm::PrimExpr, void>, void> >() const
      File "compiler_depend.ts", line 716
    TVMError: In function te.TestArrayArrayPrimExpr: error while converting argument 1: [21:04:30] compiler_depend.ts:1593: 
    ---------------------------------------------------------------
    An error occurred during the execution of TVM.
    For more information, please see: https://tvm.apache.org/docs/errors.html
    ---------------------------------------------------------------
      Check failed: (!checked_type.defined()) is false: Expected Array[Array[PrimExpr]], but got Array[index 0: Array[index 0: Array]]
    

    The key part of the check-failure message is:

    Expected Array[Array[PrimExpr]], but got Array[index 0: Array[index 0: Array]]

    The check is performed in `TVMPODValue_::AsObjectRef` in `include/tvm/runtime/packed_func.h`:

    if (type_code_ == kTVMObjectHandle) {
        // normal object type check.
        Object* ptr = static_cast<Object*>(value_.v_handle);
        Optional<String> checked_type = ObjectTypeChecker<TObjectRef>::CheckAndGetMismatch(ptr);
        ICHECK(!checked_type.defined()) << "Expected " << ObjectTypeChecker<TObjectRef>::TypeName()
                                        << ", but got " << checked_type.value();
        return TObjectRef(GetObjectPtr<Object>(ptr));
      }
    

    and the mismatch string comes from `CheckAndGetMismatch` in the `ObjectTypeChecker<Array<T>>` specialization:

    struct ObjectTypeChecker<Array<T>> {
      static Optional<String> CheckAndGetMismatch(const Object* ptr) {
        if (ptr == nullptr) {
          return NullOpt;
        }
        if (!ptr->IsInstance<ArrayNode>()) {
          return String(ptr->GetTypeKey());
        }
        const ArrayNode* n = static_cast<const ArrayNode*>(ptr);
        for (size_t i = 0; i < n->size(); i++) {
          const ObjectRef& p = (*n)[i];
          Optional<String> check_subtype = ObjectTypeChecker<T>::CheckAndGetMismatch(p.get());
          if (check_subtype.defined()) {
            return String("Array[index " + std::to_string(i) + ": " + check_subtype.value() + "]");
          }
        }
        return NullOpt;
      }
    
      blablabla....
    };
    

    I do not fully understand how C++ and Python interact here.

    Why does `CheckAndGetMismatch` wrap the returned string in `Array[index i: ...]` when `check_subtype.defined()` is true?

    How can I pass an argument of type `Array<Array<PrimExpr>>` from Python to a C++ function in TVM 0.8, the way I could pass `Array<Array<Expr>>` in TVM 0.6?

    Thanks a lot.

I found the cause of this "Check failed" error.

The function `convert_to_object` (in `tvm/runtime/object_generic.py`) changed the type it returns in its `if isinstance(value, string_types)` branch.
In version 0.6, this branch returns a `StringImm`, whose container type is an `ExprNode`.
In version 0.8, this branch returns a `runtime.String`, whose container type is `StringObj`.
Because `StringObj` is not an instance of `PrimExprNode`, `CheckAndGetMismatch` returns the actual type, `runtime.String`. This causes the check in `TVMPODValue_::AsObjectRef()` to fail. As a result, Python strings such as `'a'` in the test above can no longer be passed where a `PrimExpr` element is expected.