Hi,
I tried to implement the graph annotation as required for the BYOC flow. I basically copied over most of the DNNL example and extended it, but now, when testing the graph annotation with
mod = transform.AnnotateTarget("dla")(mod)
mod = transform.PartitionGraph()(mod)
the graph does not change. What did I miss?
The graph:
def @main(%x: Tensor[(10, 10), int8], %y: Tensor[(10, 10), int8]) -> Tensor[(10, 10), int8] { %0 = multiply(%y, %y) /* ty=Tensor[(10, 10), int8] */; %1 = add(%x, %x) /* ty=Tensor[(10, 10), int8] */; subtract(%0, %1) /* ty=Tensor[(10, 10), int8] */ }
the annotation rules:
def _register_external_op_helper(op_name, supported=True):
    """Mark a Relay operator as (un)supported on the "dla" codegen target.

    Registers a ``target.dla`` attribute on *op_name* whose value is a
    predicate; AnnotateTarget queries it to decide whether the op may be
    offloaded.

    Parameters
    ----------
    op_name : str
        Name of the Relay operator, e.g. ``"add"``.
    supported : bool
        Fixed answer the predicate returns for every call site.

    Returns
    -------
    callable
        The registered predicate (useful mainly for testing).
    """

    @tvm.ir.register_op_attr(op_name, "target.dla")
    def _checker(attrs, args):
        # The decision is fixed at registration time; attrs/args are ignored.
        return supported

    return _checker
# Offload these elementwise ops to the "dla" target unconditionally.
for _op_name in ("add", "subtract", "multiply"):
    _register_external_op_helper(_op_name)
@register_pattern_table("dla")
def dla_pattern_table():
    """Return the composite-pattern table for the "dla" codegen target.

    Returns
    -------
    list of (str, DFPattern, callable)
        One ``(name, pattern, check)`` triple per composite function that
        MergeComposite/PartitionGraph should form for this target.
    """

    def qnn_conv_pattern():
        """Create a quantized convolution pattern.

        Returns
        -------
        pattern : dataflow_pattern.DFPattern
            Denotes the convolution pattern.
        """
        # Relay handles padding as a separate operation, but the simulator
        # inlines it (because it uses the TF implementation), so accept an
        # optional nn.pad in front of the convolution input.
        pattern = is_op('nn.pad')(wildcard()) | wildcard()
        pattern = is_op('qnn.conv2d')(
            pattern, is_constant(), is_constant(), is_constant(), is_constant(), is_constant())
        pattern = pattern.optional(lambda x: is_op('nn.bias_add')(x, is_constant()))
        pattern = is_op('qnn.requantize')(
            pattern, wildcard(), wildcard(), is_constant(), is_constant())
        return pattern

    def check_qnn_conv(extract):
        """Check that the matched qnn.conv2d pattern is supported by dla."""
        # `extract` is the outermost call of the match (qnn.requantize).
        if extract.attrs.out_dtype != "int8":
            return False
        # Walk down the first argument chain to the qnn.conv2d call.
        call = extract
        while call.op.name != "qnn.conv2d":
            call = call.args[0]
        # NOTE(review): `qnn_conv2d` is not defined in this snippet — it is
        # presumably a per-op predicate defined elsewhere; confirm it exists,
        # otherwise this raises NameError the first time the pattern matches.
        return qnn_conv2d(call.attrs, call.args)

    def qnn_dense_pattern():
        """Create a quantized dense pattern (qnn.dense -> bias_add? -> requantize).

        Returns
        -------
        pattern : dataflow_pattern.DFPattern
            Denotes the dense pattern.
        """
        # BUG FIX: the original passed `pattern` as the data input before
        # `pattern` was ever assigned (UnboundLocalError); the data input
        # must be a wildcard.
        pattern = is_op('qnn.dense')(
            wildcard(), is_constant(), is_constant(), is_constant(), is_constant(), is_constant())
        pattern = pattern.optional(lambda x: is_op('nn.bias_add')(x, is_constant()))
        pattern = is_op('qnn.requantize')(
            pattern, wildcard(), wildcard(), is_constant(), is_constant())
        return pattern

    def check_qnn_dense(extract):
        """Check if the matched qnn.dense pattern is supported by dla."""
        if extract.attrs.out_dtype != "int8":
            return False
        call = extract
        while call.op.name != "qnn.dense":
            call = call.args[0]
        # NOTE(review): `qnn_dense` is not defined in this snippet — see the
        # matching note in check_qnn_conv; confirm it exists.
        return qnn_dense(call.attrs, call.args)

    # BUG FIX: the original return list was missing the opening parenthesis
    # of each tuple — a SyntaxError. The module therefore never imported
    # cleanly, no patterns/attributes were registered, and
    # AnnotateTarget("dla") + PartitionGraph() left the graph unchanged.
    return [
        ('dla.qnn_conv2d', qnn_conv_pattern(), check_qnn_conv),
        ('dla.qnn_dense', qnn_dense_pattern(), check_qnn_dense),
    ]
...