Arm fp16 instruction support

I’m wondering whether TVM’s codegen supports the Armv8.2-A FP16 instructions. For example, in the following code snippet I don’t see any FMLA vD.8h, vN.8h, vM.8h instructions in conv.s.

I’m using LLVM 8.0.
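
As a side note, this is a minimal sketch of the check I run on the saved assembly; it only assumes the conv.s file written by func.save('conv.s') in the script below:

import re

# Count assembly lines that use fmla on .8h operands, i.e. packed fp16 FMAs.
with open('conv.s') as f:
    asm = f.read()
fp16_fmla = [l for l in asm.splitlines()
             if re.search(r'\bfmla\b', l) and '.8h' in l]
print('fp16 FMLA lines found:', len(fp16_fmla))

For the build below this finds no matches, which is what prompted the question.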

import tvm
import topi
from topi.util import get_const_tuple
import numpy as np

from topi.nn.pad import pad

target = 'llvm -device=arm_cpu -target=aarch64-linux-gnu -mattr=+v8.2a,+fullfp16,+fp-armv8,+dotprod,+crc,+crypto,+neon'
dtype = 'float16'

def get_fp32_len():
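    # Despite the name, this returns the fp16 lane count: a 128-bit NEON
    # register holds 8 half-precision values, which the blocking below targets.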
    return 8

def _fallback_schedule(in_channel, height, width, num_filter, filter_height, filter_width, padding, strides):
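    # Pick blocking factors: oc_bn is the largest divisor of num_filter up to
    # the vector width, ic_bn the largest divisor of in_channel up to oc_bn,
    # and reg_n the output-width register tile (largest divisor up to 31).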
    HPAD, WPAD = padding
    HSTR, WSTR = strides

    simd_width = get_fp32_len()
    out_width = (width + 2 * WPAD - filter_width) // WSTR + 1

    oc_bn = 1
    for bn in range(simd_width, 0, -1):
        if num_filter % bn == 0:
            oc_bn = bn
            break

    ic_bn = 1
    for bn in range(oc_bn, 0, -1):
        if in_channel % bn == 0:
            ic_bn = bn
            break

    reg_n = 1
    for n in range(31, 0, -1):
        if out_width % n == 0:
            reg_n = n
            break

    return ic_bn, oc_bn, reg_n, False


def conv_compute(data, kernel, in_channel, height, width, num_filter, filter_height, filter_width, padding, strides):
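    # NCHWc convolution: pad the input, pack data and kernel into blocked
    # layouts, run the blocked convolution, then unpack back to NCHW.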
    out_dtype = data.dtype

    dilation_h, dilation_w = 1, 1

    HPAD, WPAD = padding
    HSTR, WSTR = strides

    batch_size, in_channel, in_height, in_width = get_const_tuple(data.shape)
    num_filter, _, kernel_height, kernel_width = get_const_tuple(kernel.shape)

    pad_height = in_height + 2 * HPAD
    pad_width = in_width + 2 * WPAD

    dilated_kernel_h = (kernel_height - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_width - 1) * dilation_w + 1
    out_height = (in_height + 2 * HPAD - dilated_kernel_h) // HSTR + 1
    out_width = (in_width + 2 * WPAD - dilated_kernel_w) // WSTR + 1

    # pack data
    DOPAD = (HPAD != 0 or WPAD != 0)
    if DOPAD:
        data_pad = pad(data, (0, 0, HPAD, WPAD), name="data_pad")
    else:
        data_pad = data

    # fetch schedule
    ic_bn, oc_bn, reg_n, unroll_kw = _fallback_schedule(in_channel, height, width, num_filter,
                                                        filter_height, filter_width, padding, strides)

    shape = (batch_size, in_channel // ic_bn, pad_height, ic_bn, pad_width)
    data_vec = tvm.compute(shape,
                           lambda n, C, h, c, w: data_pad[n, C * ic_bn + c, h, w],
                           name='data_vec')

    # pack kernel
    shape = (num_filter//oc_bn, in_channel//ic_bn,
             kernel_height, kernel_width, ic_bn, oc_bn)
    kernel_vec = tvm.compute(shape,
                             lambda CO, CI, h, w, ci, co:
                             kernel[CO * oc_bn + co, CI * ic_bn + ci, h, w],
                             name='kernel_vec')

    # convolution
    oshape = (batch_size, num_filter//oc_bn, out_height, out_width, oc_bn)
    unpack_shape = (batch_size, num_filter, out_height, out_width)

    ic = tvm.reduce_axis((0, in_channel), name='ic')
    kh = tvm.reduce_axis((0, kernel_height), name='kh')
    kw = tvm.reduce_axis((0, kernel_width), name='kw')

    conv = tvm.compute(oshape, lambda n, oc_chunk, oh, ow, oc_block:
                       tvm.sum(data_vec[n, ic//ic_bn, oh*HSTR+kh*dilation_h, ic%ic_bn,
                                        ow*WSTR+kw*dilation_w].astype(out_dtype) *
                               kernel_vec[oc_chunk, ic//ic_bn, kh, kw, ic%ic_bn,
                                          oc_block].astype(out_dtype),
                               axis=[ic, kh, kw]), name='conv')

    unpack = tvm.compute(unpack_shape,
                         lambda n, c, h, w: conv[n, c // oc_bn, h, w, c % oc_bn]
                         .astype(out_dtype),
                         name='output_unpack',
                         tag='conv2d_nchw')
    return unpack


def conv_schedule(C, in_channel, height, width, num_filter, filter_height, filter_width, padding, strides):
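    # Schedule: inline the pad stage, parallelize the packing stages, tile the
    # output width by reg_n, and vectorize the oc_block axis so the innermost
    # loop can map to vector fp16 FMAs.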
    s = tvm.create_schedule(C.op)
    op = C.op
    output = op.output(0)
    conv_out = op.input_tensors[0]
    kernel_vec = conv_out.op.input_tensors[1]
    kernel = kernel_vec.op.input_tensors[0]
    if isinstance(kernel.op, tvm.tensor.ComputeOp) and "dilate" in kernel.op.tag:
        s[kernel].compute_inline()
    data_vec = conv_out.op.input_tensors[0]
    data = data_vec.op.input_tensors[0]
    data_pad = None
    if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
        data_pad = data
        data = data_pad.op.input_tensors[0]

    _, _, kh, kw = get_const_tuple(kernel.shape)

    # fetch schedule
    ic_bn, oc_bn, reg_n, unroll_kw = _fallback_schedule(in_channel, height, width, num_filter,
                                                        filter_height, filter_width, padding, strides)

    # padding tells us whether a separate pad stage was created
    HPAD, WPAD = padding
    DOPAD = (HPAD != 0 or WPAD != 0)

    A, W = data, kernel_vec
    A0, A1 = data_pad, data_vec

    # schedule data
    if DOPAD:
        s[A0].compute_inline()
    batch, ic_chunk, ih, ic_block, iw = s[A1].op.axis
    parallel_axis = s[A1].fuse(ic_chunk, ih)
    s[A1].parallel(parallel_axis)

    # schedule kernel pack
    oc_chunk, ic_chunk, oh, ow, ic_block, oc_block = s[W].op.axis
    s[W].reorder(oc_chunk, oh, ic_chunk, ow, ic_block, oc_block)
    if oc_bn > 1:
        s[W].vectorize(oc_block)
    parallel_axis = s[W].fuse(oc_chunk, oh)
    s[W].parallel(parallel_axis)

    # schedule conv
    C, O0 = conv_out, output
    CC = s.cache_write(C, 'global')

    _, oc_chunk, oh, ow, oc_block = s[C].op.axis
    ow_chunk, ow_block = s[C].split(ow, factor=reg_n)
    s[C].reorder(oc_chunk, oh, ow_chunk, ow_block, oc_block)
    s[C].fuse(oc_chunk, oh)
    s[C].vectorize(oc_block)

    s[CC].compute_at(s[C], ow_chunk)
    _, oc_chunk, oh, ow, oc_block = s[CC].op.axis
    ic, kh, kw = s[CC].op.reduce_axis

    ow_chunk, ow_block = s[CC].split(ow, factor=reg_n)
    ic_chunk, ic_block = s[CC].split(ic, factor=ic_bn)

    if unroll_kw:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, ic_block, kw, ow_block, oc_block)
        s[CC].unroll(kw)
    else:
        s[CC].reorder(oc_chunk, oh, ow_chunk, ic_chunk, kh, kw, ic_block, ow_block, oc_block)

    s[CC].fuse(oc_chunk, oh)
    s[CC].vectorize(oc_block)
    s[CC].unroll(ow_block)

    return s


def run_conv2d(batch_size, in_channel, height, width, num_filter, filter_height, filter_width, padding, strides):
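    # Build the operator, print the lowered IR, save the generated assembly to
    # conv.s, then upload the library to a remote aarch64 board over RPC and time it.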
    A = tvm.placeholder((batch_size, in_channel, height, width), name='A', dtype=dtype)
    W = tvm.placeholder((num_filter, in_channel, filter_height, filter_width), name='W', dtype=dtype)

    a_shape = get_const_tuple(A.shape)
    w_shape = get_const_tuple(W.shape)

    def get_ref_data():
        a_np = np.random.uniform(size=a_shape).astype(dtype)
        w_np = np.random.uniform(size=w_shape).astype(dtype)
        from topi.testing.conv2d_nchw_python import conv2d_nchw_python
        conv_np = conv2d_nchw_python(a_np, w_np, stride=strides, padding=padding)
        return a_np, w_np, conv_np

    a_np, w_np, conv_np = get_ref_data()

    C = conv_compute(A, W, in_channel, height, width, num_filter, filter_height, filter_width, padding, strides)
    s = conv_schedule(C, in_channel, height, width, num_filter, filter_height, filter_width, padding, strides)
    # s = tvm.create_schedule(C.op)
    print(tvm.lower(s, [A, W, C], simple_mode=True))

    from tvm import rpc
    from tvm.contrib import util
    host = '0.0.0.0'
    port = 8499
    remote = rpc.connect(host, port)
    ctx = remote.cpu()
    # ctx = tvm.cpu()
    func = tvm.build(s, [A, W, C], target=target)
    func.save('conv.s')
    temp = util.tempdir()
    path = temp.relpath('lib.tar')
    func.export_library(path)
    remote.upload(path)
    func = remote.load_module('lib.tar')

    conv = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=dtype), ctx)
    time_f = func.time_evaluator(func.entry_name, ctx, number=50)
    cost_conv = time_f(tvm.nd.array(a_np, ctx), tvm.nd.array(w_np, ctx), conv).mean
    print('conv: %g ms/op' % (cost_conv * 1000.0))

    # np.testing.assert_allclose(conv.asnumpy(), conv_np, rtol=1e-5)


if __name__ == "__main__":
    run_conv2d(batch_size=1, in_channel=64, height=56, width=56, num_filter=64, filter_height=3, filter_width=3,
               padding=(1, 1), strides=(1, 1))

@yzhliu, may I know if there is any update on this topic? I ran into a similar issue: https://discuss.tvm.apache.org/t/arm-fp16-instrin-support-in-m1-chip/12381