About nn.sparse_conv2d and Conv2dToSparse Pass

I use Conv2dToSparse Pass to replace nn.conv2d with nn.sparse_conv2d. After changing the stride value of the convolution operation, the shape of the calculation result remains unchanged, which is different from using nn.conv2d.The tvm version is 0.8

Here is the simplified code

 def conv2d(data,  weight=None, **kwargs):
  name = kwargs.get("name")
  kwargs.pop("name")
  if not weight:
      weight = relay.var(name + "_weight")
  return relay.nn.conv2d(data, weight, **kwargs)


def simple_net(data, name, channels, kernel_size=(3, 3), strides=(1, 1),
              padding=(1, 1), epsilon=1e-5):
  conv = conv2d(
      data=data,
      channels=channels,
      kernel_size=kernel_size,
      strides=strides,
      padding=padding,
      data_layout='NCHW',
      name=name+'_conv')
  return conv

def test_relay():
  bs_r, bs_c = 1, 1
  sparsity = 0
  layout = "NCHW"
  kernel_size = 3

  data_shape = (1, 3, 224, 224)
  kernel_shape = (1, 3, 3, 3)
  dtype = "float32"

  data = relay.var("data", shape=data_shape, dtype=dtype)
  y = simple_net(data, "graph", 1, kernel_size=(3, 3), strides=(3, 3))
  func = relay.Function(relay.analysis.free_vars(y), y)

  params = {
    "graph_conv_weight": tvm.nd.array(np.random.uniform(-1, 1, kernel_shape).astype(dtype)),
  }
  print("before pass:", func)
  func, params = relay.data_dep_optimization.bsr_conv2d.convert(func, params, (bs_r, bs_c), sparsity, layout, kernel_size)

  print("after pass:", func)

  with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(func, "llvm", params=params)

  dev = tvm.cpu(0)
  dtype = "float32"
  m = graph_executor.GraphModule(lib["default"](dev))
  # set inputs
  np_data = np.random.uniform(-1, 1, (1, 3, 224, 224))
  m.set_input("data", tvm.nd.array(np_data.astype(dtype)))
  # execute
  m.run()
  # get outputs
  tvm_output = m.get_output(0)
  print("output shape: ", tvm_output.shape)

the output result

before pass: fn (%data: Tensor[(1, 3, 224, 224), float32], %graph_conv_weight) {
  nn.conv2d(%data, %graph_conv_weight, strides=[3, 3], padding=[1, 1, 1, 1], channels=1, kernel_size=[3, 3])
}
after pass: fn (%data: Tensor[(1, 3, 224, 224), float32], %graph_conv_weight.data: Tensor[(27, 1, 1), float32], %graph_conv_weight.indices: Tensor[(27), int32], %graph_conv_weight.indptr: Tensor[(2), int32]) -> Tensor[(1, 1, 224, 224), float32] {
  nn.sparse_conv2d(%data, %graph_conv_weight.data, %graph_conv_weight.indices, %graph_conv_weight.indptr, layout="NCHW", kernel_size=[3, 3]) /* ty=Tensor[(1, 1, 224, 224), float32] */
}
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
output shape:  (1, 1, 224, 224)

If you don’t use Conv2dToSparse by delete convert function, the result is

fn (%data: Tensor[(1, 3, 224, 224), float32], %graph_conv_weight) {
  nn.conv2d(%data, %graph_conv_weight, strides=[3, 3], padding=[1, 1, 1, 1], channels=1, kernel_size=[3, 3])
}
One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
output shape:  (1, 1, 75, 75)

the relay.data_dep_optimization.bsr_conv2d.convert is defined as follow:

def convert(func, params, blocksize, sparsity_threshold, layout="NHWC", kernel_size=1):
    """Convert a conv2d func and according parameters to block sparse

    Parameters
    ----------
    func : relay.Expr
        Expr will be optimized to sparse operation
    params : Dict[Srting, tvm.nd.array]
        Parameters of the Expr
    blocksize : Tuple(int, int)
        Blocksize for BSR matrix
    sparsity_threshold : float
        Minimal sparsity requirement for converting.
        If weight sparsity is lower than this threshold,
        the dense operation will be kept.
    layout : str
        layout of network

    Returns
    -------
    new_func: relay.Expr
        Mutated Expr with sparse operations

    params: Dict[Srting, tvm.nd.array]
        New params with BSR matrix for mutated Expr
    """
    weight_info = process_params(func, params, blocksize, sparsity_threshold, layout, kernel_size)
    new_func = _run_opt_pass(
        func,
        relay.transform.Conv2dToSparse(
            weight_info.weight_name, weight_info.weight_shape, layout, kernel_size
        ),
    )

    return new_func, params

btw, how to find where is the problem for new as me?:slight_smile: