I found the problem, but I am unsure how to solve it, because it seems to be a fundamental issue in the implementation of `Optional<T>`.
cc @tqchen @junrushao @Hzfengsy @mousius @manupa-arm
I ran your code and see this output from `print("JSON", x)`:
```
JSON {
  "root": 1,
  "nodes": [
    {
      "type_key": ""
    },
    { // node 1
      "type_key": "IRModule",
      "attrs": {
        "attrs": "0",
        "functions": "2",
        "global_type_var_map_": "97",
        "global_var_map_": "96",
        "source_map": "98",
        "type_definitions": "95"
      }
    },
    { // node 2
      "type_key": "Map",
      "data": [3, 5]
    },
    { // node 3
      "type_key": "GlobalVar",
      "attrs": {
        "_checked_type_": "0",
        "name_hint": "4",
        "span": "0",
        "virtual_device_": "0"
      }
    },
    { // node 4
      "type_key": "runtime.String",
      "repr_str": "main"
    },
    { // node 5
      "type_key": "tir.PrimFunc",
      "attrs": {
        "_checked_type_": "91",
        "attrs": "85",
        "body": "12",
        "buffer_map": "83",
        "params": "6",
        "preflattened_buffer_map": "84",
        "ret_type": "81",
        "span": "90"
      }
    },
    // ...
    { // node 12
      "type_key": "tir.AllocateConst",
      "attrs": {
        "annotations": "79",
        "body": "24",
        "buffer_var": "13",
        "data": "20",
        "dtype": "int32",
        "extents": "21",
        "irmod_storage_idx": "0",
        "span": "80"
      }
    },
    // ...
    { // node 20 <-- BUG is here
      "type_key": "runtime.NDArray"
    },
```
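The problem is evident in node 20, which should not exist: the NDArray shows up as an empty node with only a `type_key`, instead of having its contents stored in `b64ndarrays`. Node 20 is produced when `NodeIndexer` visits the `AllocateConst` node via `VisitAttrs`, so the bug may be inside `VisitAttrs`: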
```cpp
class AllocateConstNode : public StmtNode {
 public:
  /*! \brief The buffer variable. */
  Var buffer_var;
  /*! \brief The optional data associated to the constant.
   */
  Optional<runtime::NDArray> data;
  /*! \brief If the PrimFunc containing the Stmt is added to IRModule,
       this is an optional index to indicate the index within
       "Constants" attribute, that is a Array<NDArray> of IRModule.
   */
  Optional<Integer> irmod_storage_idx;
  /*! \brief The type of the buffer. */
  DataType dtype;
  /*! \brief The extents of the buffer. */
  Array<PrimExpr> extents;
  /*! \brief The body to be executed. */
  Stmt body;
  /*!
   * \brief Additional annotations about the allocation.
   *
   *  These annotations can be used as auxiliary hint
   *  to future transformations.
   */
  Map<String, ObjectRef> annotations;

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("buffer_var", &buffer_var);
    // BUG: &data (an Optional<runtime::NDArray>*) converts implicitly to ObjectRef*,
    // so this calls Visit(const char* key, ObjectRef* value). Reflection expects
    // Visit(const char* key, runtime::NDArray* value) to be invoked instead.
    v->Visit("data", &data);
    v->Visit("irmod_storage_idx", &irmod_storage_idx);
    v->Visit("dtype", &dtype);
    v->Visit("extents", &extents);
    v->Visit("body", &body);
    v->Visit("annotations", &annotations);
    v->Visit("span", &span);
  }
  // ...
};
```
Unfortunately, because `Optional<T>` is a template, `T` is not available inside the implementation of `NodeIndexer::Visit(const char*, ObjectRef*)`. This means there is no way to tell the difference between an `ObjectRef` that isn't defined and an `ObjectRef` that is an `Optional<T>` (and, in the latter case, what `T` is).
I made an attempt to fix this, but the unit test fails because `load_json` doesn't know that it needs to look up the value of `"data"` in `b64ndarrays`. I'm not sure how to fix that; we could either modify the serialization format or introduce an explicit `OptionalNode` that actually holds the NDArray and defines `T`.