[RFC][SYCL] SYCL backend for TVM

Summary

Add a new backend language, SYCL, to enhance TVM's compatibility and portability across different types of accelerators.

How to use? Similar to other backends, you only need to specify target='sycl'. See the User Guide for details.

Motivation

What is SYCL?

SYCL is a cross-platform programming model targeting heterogeneous computing architectures in which a host is connected to various accelerators. In implementation, SYCL is a high-level abstraction layer that wraps low-level APIs such as OpenCL, CUDA, Level Zero, HIP, XRT, and Vulkan. Compared to the cross-platform OpenCL, SYCL provides a higher-level programming model based on modern C++ and broader device support.

SYCL emerged in 2015 as a high-level abstraction layer for OpenCL. Since the SYCL 2020 specification, OpenCL is no longer the only low-level backend for SYCL. Although SYCL is relatively young, it has received steady attention from industry. SYCL is a standard with several implementations, such as Intel® oneAPI DPC++, ComputeCpp, hipSYCL, neoSYCL, and triSYCL.

Based on this background, we propose this RFC to add the SYCL backend, enhancing the compatibility and portability of TVM across different types of accelerators.

Guide-level explanation

How to use?

Similar to other backends such as cuda, specify target='sycl' in the corresponding TVM API.

tgt = tvm.target.Target(target='sycl')  # Target
……
lib = relay.build(mod, target='sycl', params=params)  # model build
……
dev = tvm.device('sycl', 0)  # device that supports sycl
input = tvm.nd.array(data, device=dev)  # model input

The following sample code runs a GEMM operator with the CUDA and SYCL backends respectively, and compares whether the results of the two backends are consistent.

import numpy as np
import tvm.relay as relay
from tvm.contrib import graph_executor
import tvm.testing
import tvm

# define GEMM
M = 1024
N = 1024
data_shape = (M, N)
dtype = 'float32'
X1 = relay.var("X1", shape=data_shape, dtype=dtype)
X2 = relay.var("X2", shape=data_shape, dtype=dtype)
func = relay.nn.dense(X1, X2)
mod = tvm.IRModule.from_expr(func)
# initialize input
X1_np = np.random.uniform(size=data_shape).astype(dtype)
X2_np = np.random.uniform(size=data_shape).astype(dtype)

def build(target:str):
    # model build
    tgt = tvm.target.Target(target=target, host="llvm")
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=tgt, params=None)
    # print CUDA/SYCL source code
    # print(lib.get_lib().imported_modules[0].get_source()) 
    dev = tvm.device(target, 0)
    module = graph_executor.GraphModule(lib["default"](dev))
    module.set_input("X1", X1_np)
    module.set_input("X2", X2_np)
    module.run()
    tvm_output = module.get_output(0).numpy()
    return tvm_output
    
cuda_output = build(target="cuda")
sycl_output = build(target="sycl")
tvm.testing.assert_allclose(cuda_output, sycl_output, rtol=1e-5, atol=1e-5)

In addition, the SYCL backend supports performance optimization using auto-scheduling. For sample code, refer to the tutorial "Auto-scheduling a Neural Network for NVIDIA GPU" in the TVM documentation and simply specify target='sycl', as sketched below.
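
A minimal sketch of that flow, assuming the mod from the GEMM example above and this RFC's sycl target; the log-file name and trial count are illustrative only:

import tvm
import tvm.relay as relay
from tvm import auto_scheduler

log_file = "gemm_sycl_tuning.json"  # illustrative file name
target = tvm.target.Target("sycl", host="llvm")

# extract tuning tasks from the Relay module
tasks, task_weights = auto_scheduler.extract_tasks(mod["main"], {}, target)

# tune the tasks and record the results to the log file
measure_ctx = auto_scheduler.LocalRPCMeasureContext(repeat=1, min_repeat_ms=300, timeout=10)
tune_option = auto_scheduler.TuningOptions(
    num_measure_trials=200,  # small value, for demonstration only
    runner=measure_ctx.runner,
    measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
)
auto_scheduler.TaskScheduler(tasks, task_weights).tune(tune_option)

# compile with the tuned schedules applied
with auto_scheduler.ApplyHistoryBest(log_file):
    with tvm.transform.PassContext(opt_level=3, config={"relay.backend.use_auto_scheduler": True}):
        lib = relay.build(mod, target=target, params=None)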

Currently supported devices:

  • NVIDIA GPU
  • AMD GPU
  • Intel GPU

Reference-level explanation

This RFC only adds the SYCL backend to TVM, no other features will be affected.

The added code for SYCL backend mainly includes:

  • SYCL codegen, from TIR to SYCL kernel code. The input of SYCL codegen is the abstract syntax tree of TIR; the codegen traverses this tree and emits the corresponding SYCL kernel code.
  • SYCL runtime, covering SYCL host operations such as memory copy, device information query, and kernel submission (see the sketch after this list).
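
For example, the device information query part of the runtime would back the standard TVM device attributes. A minimal sketch, assuming TVM was built with the proposed SYCL runtime enabled:

import tvm

# query SYCL device information through the standard TVM runtime API
dev = tvm.device("sycl", 0)
print("device exists:", dev.exist)
print("device name:", dev.device_name)
print("max threads per block:", dev.max_threads_per_block)
print("max shared memory per block:", dev.max_shared_memory_per_block)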

There are several SYCL-aware compilers, such as DPC++, hipSYCL, and ComputeCpp. This RFC uses the open-source DPC++, which is built on LLVM, uses the Clang front end, and implements the SYCL 2020 standard.
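
The sketch below illustrates, with placeholder file names and flags, how generated SYCL kernel code could be compiled with DPC++ (clang++ with -fsycl) into a shared library; the backend's actual invocation may differ.

import subprocess

# illustrative only: compile codegen output into a shared library with DPC++
# "tvm_sycl_kernels.cpp" / ".so" are placeholder names
subprocess.run(
    [
        "clang++",
        "-fsycl",            # enable SYCL compilation in DPC++
        "-fPIC", "-shared",  # produce a dynamic library the runtime can load
        "tvm_sycl_kernels.cpp",
        "-o", "tvm_sycl_kernels.so",
    ],
    check=True,
)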

Drawbacks

SYCL does not currently support runtime compilation in the way NVRTC does for CUDA, where generated kernel code can be compiled directly into an executable kernel at runtime. To make the SYCL backend compatible with the TVM runtime framework, this RFC compiles the SYCL kernel code into a dynamic link library during the TVM build. As a result, TVM build time (for example, relay.build) increases when target='sycl' because of the overhead of compiling this library. If there are any problems, please let me know.
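
One way to amortize this overhead, assuming the SYCL module supports TVM's standard export and load path, is to build once, export the compiled module, and reload it later without rebuilding ("gemm_sycl.so" is an illustrative file name):

import tvm
import tvm.relay as relay
from tvm.contrib import graph_executor

# build once and export (including the compiled SYCL kernel library)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="sycl", params=None)
lib.export_library("gemm_sycl.so")

# later (or in another process): load without paying the build cost again
loaded_lib = tvm.runtime.load_module("gemm_sycl.so")
dev = tvm.device("sycl", 0)
module = graph_executor.GraphModule(loaded_lib["default"](dev))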

Rationale and alternatives

Prior art

Unresolved questions

Future possibilities

  • support TVM meta schedule and TVM unity
  • add additional optimizations for specific hardware types
  • support more types of accelerator

Awesome work! Thanks for the great RFC. Would be great to send it to https://github.com/apache/tvm-rfcs at the same time :slight_smile:

BTW, I'm curious about the performance gap between SYCL and the vendor-provided language, i.e. if we have the same schedule and TIR, what's the performance of the CUDA target vs. the SYCL target on the same device?

Also, I only see the sycl target in the example. Is it used across different GPU kinds (ROCm, CUDA, etc.)?

SYCL supports different GPU kinds (NVIDIA, AMD, etc.). The following are two ways to specify the GPU kind:

  • Set the GPU kind when compiling TVM. For example, set(SYCL_GPU "nvidia") in TVM's config.cmake, then set target="sycl" in user code.

  • Set the GPU kind via the Target. For example, set target="sycl -gpu=nvidia" in user code (see the sketch below).
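
A minimal sketch of the two options from the user-code side (both depend on this proposal; neither the sycl target nor the -gpu option exists in upstream TVM yet):

import tvm

# Option 1: GPU kind fixed at TVM build time via config.cmake,
#   set(SYCL_GPU "nvidia")
# so user code only needs the plain target string:
tgt = tvm.target.Target("sycl")

# Option 2: GPU kind chosen per target string (proposed -gpu option):
tgt = tvm.target.Target("sycl -gpu=nvidia")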

Which way is better?

OK :grinning_face_with_smiling_eyes:, a detailed performance comparison between SYCL and the vendor-provided language after sufficient auto-tuning optimization will be published in a few days.

:thinking: Due to the different underlying compilers (for example, nvcc vs. clang++), the optimal schedule differs between backend languages. A schedule may perform well for CUDA, but not necessarily for SYCL.

This is definitely interesting! I'd love to learn a bit more about the potential approach to GPU codegen - to accommodate SYCL, which passes do you expect to be updated, and which do you expect to be added? A high-level description of those changes would be much appreciated to help us understand the system design!

Thanks! “passes” refers to compiler passes, such as operator fusion and constant folding, right?

The input of codegen is the optimized tensor IR. All optimization passes are completed before this. Therefore, the SYCL backend does not need to update or add any passes.

From a performance optimization perspective, hardware-related passes should be selected based on the specific hardware type.

The input of codegen is the optimized tensor IR. All optimization passes are completed before this. Therefore, the SYCL backend does not need to update or add any passes.

This is great to hear that existing optimizations have already been sufficient to generate code for SYCL. Usually I would expect some hardware-related changes, but it seems a non-issue for SYCL, which is great!

From a performance optimization perspective, hardware-related passes should be selected based on the specific hardware type.

Could you elaborate in this proposal on which passes are required? This will give a better picture of the subsequent engineering effort needed.

My high-level take is that we mainly need a codegen backend and runtime, just like the OpenCL codegen and runtime. This should be compatible with the existing TensorIR infra.

Good question. The hardware-related passes for SYCL are currently only a rough idea. Further experimentation is needed to determine which passes should be added for each specific hardware type.

You are right. In fact, this repo has a basic implementation of SYCL codegen and runtime.