Hi all,
I have to write custom memory-buffer creation logic in TVM, for which I need the names of the arguments of a CallNode whose op is nn.dense.
The function signature looks something like this:
vector<string> Dense(const CallNode* call) {
  // Does not compile: args[i] is an Expr, and Expr has no name_hint().
  string input_name = call->args[0]->name_hint();
  string weight_name = call->args[1]->name_hint();
  // do something with names.
}
This is what I want to achieve, but it doesn't work because I can't figure out how to extract a VarNode from call->args[0].
Any help would be highly appreciated.
Thanks.
I actually tried to do the same thing, but I failed to get (or set) the name, because name_hint is not a member of Expr; it only exists on Var and GlobalVar. In my case, the arguments are usually other CallNodes, because they are the outputs of other Relay ops.
I hope someone who knows how to identify ops by name will reply. I have gotten used to working without Relay operator names, but they would have been very helpful when I first started with TVM.
Anyway, what is your use case that requires the operator's name?
Hi @gangmul12,
thanks for your reply.
You can create a Relay program like this:
import tvm
from tvm import relay

a = relay.var("input1", shape=(10,), dtype="float32")
b = relay.var("input2", shape=(10,), dtype="float32")
c = relay.add(a, b)
func = relay.Function([a, b], c)
mod = tvm.IRModule.from_expr(func)
mod = relay.transform.InferType()(mod)
Here, both args of the add CallNode are VarNodes.
Taking this example as a base, I would like to get the names of both VarNodes that serve as args of the CallNode.
My use case: I would like to distinguish between the input and weight arguments of relay.nn.dense and build a map from each arg name to its role (input or weight). I use this map later in another part of the code.
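To make the goal concrete, here is a minimal, TVM-free sketch of such a role map. It assumes toy stand-in classes: Var, Constant, Call and the helper collect_dense_roles are hypothetical and are not the tvm.relay API. It treats the first argument of nn.dense as the input and the second as the weight, matching the argument order of relay.nn.dense(data, weight). Arguments that are not Vars, such as the outputs of other ops, are skipped because they carry no name.

```python
# Toy stand-ins for Relay's Var, Constant and Call nodes (hypothetical classes,
# NOT the real tvm.relay API); only the fields needed for this sketch exist.
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union


@dataclass
class Var:
    name_hint: str


@dataclass
class Constant:
    value: float


@dataclass
class Call:
    op: str  # operator name, e.g. "nn.dense"
    args: List["Expr"] = field(default_factory=list)


Expr = Union[Var, Constant, Call]

# nn.dense takes (data, weight) in this order.
DENSE_ROLES = ("input", "weight")


def collect_dense_roles(expr: Expr, roles: Optional[Dict[str, str]] = None) -> Dict[str, str]:
    """Walk the expression tree and map each Var feeding nn.dense to its role."""
    if roles is None:
        roles = {}
    if isinstance(expr, Call):
        if expr.op == "nn.dense":
            for role, arg in zip(DENSE_ROLES, expr.args):
                # Only Vars carry a name_hint; Calls and Constants are skipped.
                if isinstance(arg, Var):
                    roles[arg.name_hint] = role
        # Recurse into the arguments to find nested nn.dense calls.
        for arg in expr.args:
            collect_dense_roles(arg, roles)
    return roles


# dense(relu(dense(input1, w1)), w2): the outer dense's input is a Call, not a Var.
x, w1, w2 = Var("input1"), Var("w1"), Var("w2")
graph = Call("nn.dense", [Call("nn.relu", [Call("nn.dense", [x, w1])]), w2])
print(collect_dense_roles(graph))
```

With the real Relay Python API, the analogous checks would be isinstance(arg, relay.Var) and call.op.name == "nn.dense", with arg.name_hint giving the name.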
You can downcast the object from Expr to a VarNode using .as<VarNode>().
Based on your example, it should work something like this:
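vector<string> Dense(const CallNode* call) {
  const auto input = call->args[0];
  const auto weight = call->args[1];
  string input_name, weight_name;
  // as<VarNode>() returns nullptr unless the Expr really is a VarNode.
  if (const auto* var = input.as<VarNode>()) {
    input_name = var->name_hint();
  }
  if (const auto* var = weight.as<VarNode>()) {
    weight_name = var->name_hint();
  }
  // do something with names.
  return {input_name, weight_name};
}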
The if condition ensures that the downcast is valid, i.e. that the Expr really refers to a VarNode rather than some other expression such as a ConstantNode or a CallNode, because .as<VarNode>() returns nullptr otherwise.
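The null-on-mismatch behavior can be modeled outside TVM. The sketch below is a toy model, not TVM code: Node, VarNode, ConstantNode, ExprRef, as_ and arg_name are hypothetical Python stand-ins that imitate how ObjectRef::as<T>() returns the node when its runtime type matches and nullptr (here None) when it does not. The None check plays the same role as the if guard in the C++ answer.

```python
from typing import Optional, Type, TypeVar


class Node:
    """Hypothetical base class standing in for tvm::runtime::Object."""


T = TypeVar("T", bound=Node)


class VarNode(Node):
    def __init__(self, name_hint: str):
        self.name_hint = name_hint


class ConstantNode(Node):
    def __init__(self, value: float):
        self.value = value


class ExprRef:
    """Hypothetical stand-in for an ObjectRef such as Expr; wraps a node."""

    def __init__(self, node: Node):
        self._node = node

    def as_(self, node_type: Type[T]) -> Optional[T]:
        # Imitates ObjectRef::as<T>(): the node if it is a T (or subclass),
        # otherwise None (nullptr in C++).
        return self._node if isinstance(self._node, node_type) else None


def arg_name(arg: ExprRef) -> Optional[str]:
    var = arg.as_(VarNode)
    if var is not None:  # same role as the `if` guard in the C++ code
        return var.name_hint
    return None


print(arg_name(ExprRef(VarNode("input1"))))
print(arg_name(ExprRef(ConstantNode(0.5))))
```

Skipping the check and dereferencing directly would fail on the ConstantNode case, which is the Python analogue of dereferencing a nullptr.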