Error while running a single-layer avg_pool2d model using CLML on a Qualcomm device

Additions for clml:

if not local_demo and enable_clml:
    print("partition clml")
    print(clml.is_clml_runtime_enabled())
    mod = clml.preprocess_module(mod)
    mod = clml.partition_for_clml(mod, params)

desired_layouts = {'nn.conv2d': ['NCHW', 'default'],
                   'nn.dense': ['NCHW', 'default'],
                   'nn.avg_pool2d': ['NCHW', 'default']}
seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
                                relay.transform.ConvertLayout(desired_layouts)])

@srkreddy1238 I’ve created a single-layer average pooling 2D TFLite model with NHWC layout, converted it to the NCHW layout supported by CLML in TVM, compiled it to a tvm.so, and ran it on a Qualcomm device. I got the following error.

Output:

RPCError: Error caught from RPC call:

[22:59:47] /home/code/RJ/pytorch/tvm_os/src/runtime/contrib/clml/clml_runtime.cc:1053: InternalError: Check failed: (op && result == CL_SUCCESS) is false: Pooling Error:-30!
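For reference, a single-layer model like this can be generated roughly as follows (an illustrative sketch assuming Keras/TensorFlow 2.x; the layer and shape choices are mine, not the exact script used):

import tensorflow as tf

# Illustrative only: a single AveragePooling2D layer with NHWC input,
# exported as a TFLite flatbuffer.
model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(28, 28, 3)),  # NHWC
    tf.keras.layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2)),
])

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open("avg_pool2d.tflite", "wb") as f:
    f.write(tflite_model)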

@dabhinav10 let me try reproducing this scenario. btw, which device are you trying this on?

S24 Ultra with Qualcomm chipset.

CLML supports only the NCHW format. To deploy a model to the CLML path, here is the correct way to do it:

mod = tvm.relay.nn.avg_pool2d(data, pool_size=(2, 2), strides=(2, 2), layout='NHWC')

  • It needs to be converted to NCHW format before applying the CLML partition:

desired_layouts = {'nn.conv2d': ['NCHW', 'default'],
                   'nn.dense': ['NCHW', 'default'],
                   'nn.avg_pool2d': ['NCHW', 'default']}

seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
                                relay.transform.ConvertLayout(desired_layouts)])

mod = seq(mod)

if not local_demo and enable_clml:
    print("partition clml")
    print(clml.is_clml_runtime_enabled())
    mod = clml.partition_for_clml(mod, params)

@kvegiraj Here’s the sequence of my script:

mod, params = relay.frontend.from_tflite(
    tflite_model, shape_dict={input_tensor: input_shape}, dtype_dict={input_tensor: input_dtype}
)

desired_layouts = {'nn.conv2d': ['NCHW', 'default'],
                   'nn.dense': ['NCHW', 'default'],
                   'nn.avg_pool2d': ['NCHW', 'default']}

seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
                                relay.transform.ConvertLayout(desired_layouts)])

mod = seq(mod)
if not local_demo and enable_clml:
    print("partition clml")
    print(clml.is_clml_runtime_enabled())
    mod = clml.preprocess_module(mod)
    mod = clml.partition_for_clml(mod, params)


target = tvm.target.Target(test_target, host=target)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

Still facing the same issue: InternalError: Check failed: (op && result == CL_SUCCESS) is false: Pooling Error:-30

"mod = clml.preprocess_module(mod)" is not required; remove it and then try again, it should work.
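In other words, the partition block reduces to this (a minimal sketch; variable names follow the script above):

if not local_demo and enable_clml:
    print("partition clml")
    print(clml.is_clml_runtime_enabled())
    # No clml.preprocess_module(mod) call here
    mod = clml.partition_for_clml(mod, params)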

@kvegiraj @srkreddy1238 I was able to fix this issue by doing the layout transform and CLML partitioning under the tvm.transform.PassContext(opt_level=3) optimization pass.

with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
    print("----------ir after layout change------------")
    print(mod)
    print("----------------------")

    if not local_demo and enable_clml:
        print("partition clml")
        print(clml.is_clml_runtime_enabled())
        mod = clml.preprocess_module(mod)
        mod = clml.partition_for_clml(mod, params)


    target = tvm.target.Target(test_target, host=target)

    #mod = seq(mod)
    lib = relay.build(mod, target=target, params=params)

@srkreddy1238 @kvegiraj However, I’m getting a different error with another model, which is related to the kernel.

I’ve converted a TFLite model with conv2d followed by a padding layer (ZeroPadding2D) in NHWC format, and also a single model with only a padding layer in NHWC format, from Keras. Again I tried converting to NCHW for CLML, compiled to tvm.so, and ran on the Qualcomm S24 Ultra.

Got following error:

InternalError: Check failed: (result == CL_SUCCESS) is false: clEnqueueCopyMLTensorDataQCOM:-48

Model Used:
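(The model itself was shared as an image; a rough Keras sketch of the kind of model described above, inferred from the Relay IR further down and not the exact model, would be:)

import tensorflow as tf

# Illustrative only: Conv2D + ReLU followed by ZeroPadding2D, NHWC layout,
# exported as a TFLite flatbuffer.
model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(28, 28, 3)),
    tf.keras.layers.Conv2D(16, (3, 3), activation="relu"),
    tf.keras.layers.ZeroPadding2D(padding=(1, 1)),
])

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open("conv_padded_model.tflite", "wb") as f:
    f.write(tflite_model)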

Used this script:

desired_layouts = {'nn.conv2d': ['NCHW', 'default'],
                   'nn.dense': ['NCHW', 'default'],
                   'nn.avg_pool2d': ['NCHW', 'default']}

seq = tvm.transform.Sequential([relay.transform.RemoveUnusedFunctions(),
                                relay.transform.ConvertLayout(desired_layouts)])

with tvm.transform.PassContext(opt_level=3):
    mod = seq(mod)
    print("----------ir after layout change------------")
    print(mod)
    print("----------------------")

    if not local_demo and enable_clml:
        print("partition clml")
        print(clml.is_clml_runtime_enabled())
        mod = clml.preprocess_module(mod)
        mod = clml.partition_for_clml(mod, params)


    target = tvm.target.Target(test_target, host=target)

    #mod = seq(mod)
    lib = relay.build(mod, target=target, params=params)
    lib_fname = "dummy_model.tvm.so"
    if run_on_device:
        lib_fname="conv_padded_model.tflite.tvm.so"
 
print(ndk)
print (ndk.create_shared)
fcompile = ndk.create_shared if run_on_device else None
lib.export_library(lib_fname, fcompile)

What is the padding (4x2)… Should I assume it is pad = ((0,0), (1,1), (1,1), (0,0))? Here is the generated mod after the transform, and it is working with TVM and CLML.

def @main(%a: Tensor[(1, 28, 28, 3), float32] /* ty=Tensor[(1, 28, 28, 3), float32] */) -> Tensor[(1, 28, 28, 16), float32] {
  %0 = layout_transform(%a, src_layout="NHWC", dst_layout="NCHW") /* ty=Tensor[(1, 3, 28, 28), float32] */;
  %1 = layout_transform(meta[relay.Constant][0] /* ty=Tensor[(3, 3, 3, 16), float32] */, src_layout="HWIO", dst_layout="OIHW") /* ty=Tensor[(16, 3, 3, 3), float32] */;
  %2 = expand_dims(meta[relay.Constant][1] /* ty=Tensor[(16), float32] */, axis=0, num_newaxis=3) /* ty=Tensor[(1, 1, 1, 16), float32] */;
  %3 = nn.conv2d(%0, %1, padding=[0, 0, 0, 0], channels=16, kernel_size=[3, 3], out_dtype="float32") /* ty=Tensor[(1, 16, 26, 26), float32] */;
  %4 = layout_transform(%2, src_layout="NHWC", dst_layout="NCHW") /* ty=Tensor[(1, 16, 1, 1), float32] */;
  %5 = add(%3, %4) /* ty=Tensor[(1, 16, 26, 26), float32] */;
  %6 = nn.relu(%5) /* ty=Tensor[(1, 16, 26, 26), float32] */;
  %7 = nn.pad(%6, 0 /* ty=int32 */, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]]) /* ty=Tensor[(1, 16, 28, 28), float32] */;
  layout_transform(%7, src_layout="NCHW", dst_layout="NHWC") /* ty=Tensor[(1, 28, 28, 16), float32] */
}

Let me know if the above mod definition doesn’t match your mod.

@kvegiraj The generated mod for me is:

def @main(%serving_default_conv2d_input:0: Tensor[(1, 28, 28, 3), float32] /* ty=Tensor[(1, 28, 28, 3), float32] span=serving_default_conv2d_input:0:0:0 */, %v_param_1: Tensor[(3, 3, 3, 16), float32] /* ty=Tensor[(3, 3, 3, 16), float32] span=sequential/conv2d/Conv2D:0:0 */, %v_param_2: Tensor[(16), float32] /* ty=Tensor[(16), float32] span=sequential/conv2d/BiasAdd/ReadVariableOp:0:0 */, output_tensor_names=["StatefulPartitionedCall_0"]) -> Tensor[(1, 28, 28, 16), float32] {
  %0 = layout_transform(%serving_default_conv2d_input:0, src_layout="NHWC", dst_layout="NCHW") /* ty=Tensor[(1, 3, 28, 28), float32] */;
  %1 = layout_transform(%v_param_1, src_layout="HWIO", dst_layout="OIHW") /* ty=Tensor[(16, 3, 3, 3), float32] */;
  %2 = expand_dims(%v_param_2, axis=0, num_newaxis=3) /* ty=Tensor[(1, 1, 1, 16), float32] */;
  %3 = nn.conv2d(%0, %1, padding=[0, 0, 0, 0], channels=16, kernel_size=[3, 3]) /* ty=Tensor[(1, 16, 26, 26), float32] span=sequential/conv2d/Relu;sequential/conv2d/BiasAdd;sequential/conv2d/BiasAdd/ReadVariableOp;sequential/conv2d/Conv2D:0:0 */;
  %4 = layout_transform(%2, src_layout="NHWC", dst_layout="NCHW") /* ty=Tensor[(1, 16, 1, 1), float32] */;
  %5 = add(%3, %4) /* ty=Tensor[(1, 16, 26, 26), float32] */;
  %6 = nn.relu(%5) /* ty=Tensor[(1, 16, 26, 26), float32] span=sequential/conv2d/Relu;sequential/conv2d/BiasAdd;sequential/conv2d/BiasAdd/ReadVariableOp;sequential/conv2d/Conv2D:0:0 */;
  %7 = nn.pad(%6, 0 /* ty=int32 span=StatefulPartitionedCall:0:0:0 */, pad_width=[[0, 0], [0, 0], [1, 1], [1, 1]]) /* ty=Tensor[(1, 16, 28, 28), float32] span=StatefulPartitionedCall:0:0:0 */;
  layout_transform(%7, src_layout="NCHW", dst_layout="NHWC") /* ty=Tensor[(1, 28, 28, 16), float32] */
}

I’m running this using the code below to get the mean inference time, and I’m getting InternalError: Check failed: (result == CL_SUCCESS) is false: clEnqueueCopyMLTensorDataQCOM:-48

import tvm
import numpy as np
from tvm import te
from tvm.contrib import graph_executor as runtime

ctx = remote.cl(0)
#ctx = remote.cpu(0)

# Transfer the model lib to remote device
remote.upload(lib_fname)
# Load the remote module
rlib = remote.load_module(lib_fname)

# Create a runtime executor module
module = runtime.GraphModule(rlib["default"](ctx))

# Run
module.run()

# Benchmark the performance

ftime = module.module.time_evaluator("run", ctx, number=1, repeat=10)
prof_res = np.array(ftime().results) * 1000
print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res)))

Which version of the CLML SDK was used for your test sample?

@kvegiraj I’m using CLML 3.0.

Hi @srkreddy1238 @kvegiraj ,

I am getting the same clEnqueueCopyMLTensorDataQCOM:-48 error while running a model with CLML 3.0. How can I overcome this issue?

Thanks, Rahul Jha

Hi @dabhinav10 @rjharahul @srkreddy1238

We have a CLML fix for the clEnqueueCopyMLTensorDataQCOM:-48 error. Please use the patch below to resolve this issue: [CLML][RUNTIME] Fix CLML EnqueCopy issue · krishnaraj36/tvm_mainline@e3ac6f0 (github.com)

Thanks @kvegiraj, with this patch I am now able to fix the clEnqueueCopyMLTensorDataQCOM:-48 error.

With some models I am getting this error:

terminating with uncaught exception of type tvm::runtime::InternalError: [16:06:19] /tvm/src/runtime/contrib/clml/clml_runtime.cc:1572: InternalError: Check failed: (op && result == CL_SUCCESS) is false: minimum Node Error:-1102

What does this error mean, and how can I resolve it?

@rjharahul The issue is in the minimum op. Can you share the minimum op definition in the model, or the model representation?
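(If it helps, one way to extract the op definition is to print the Relay module text and look at the lines mentioning the op; a minimal sketch, assuming the module is available as mod:)

# Sketch: dump the Relay IR and show the lines that mention the minimum op
ir_text = mod.astext(show_meta_data=False)
for line in ir_text.splitlines():
    if "minimum" in line:
        print(line)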

@kvegiraj I have added an image of the minimum layer; you can get it through this link:

https://github.com/rjharahul21/Models/blob/main/Screenshot%20(1).png

@rjharahul: A PR has been raised for this CLML issue: [CLML] Fix in clml pattern check condition by krishnaraj36 · Pull Request #16933 · apache/tvm (github.com)

Please apply the patch from this PR to resolve the above issue.

The PR has been merged to mainline: [CLML] Fix in clml pattern check condition by krishnaraj36 · Pull Request #16933 · apache/tvm (github.com)