CUDA FP16 example

Is converting a model to FP16 with target = “cuda” supported? If so, is there an example pass I could look at to convert my model?

cc @vinx13 @Hzfengsy

Thanks!


also cc @anijain2305 @xyzhou

Unfortunately, I’m afraid there isn’t. We usually use native fp16 models directly.

Got it. Is there any plan to do this in the future?

@jonso here’s a little Relay pass I’ve been using to downcast a model from FP32 to FP16. I don’t think the target really matters, since you apply this pass before compiling to CUDA.

from tvm.ir import IRModule
from tvm.relay import transform as _transform
from tvm.relay import cast
from tvm.relay import Call, Constant, Function, TupleGetItem, Var
from tvm.relay.expr_functor import ExprMutator

def downcast_fp16(func):
    # pylint: disable=line-too-long
    """Downcast a Relay function to fp16.

    Parameters
    ----------
    func: Function
        The original graph.

    Returns
    -------
    The graph after downcasting to half-precision floating point.
    """
    # get_valid_counts and non_max_suppression do not support fp16, so keep them in fp32
    filter_list = ['vision.get_valid_counts', 'vision.non_max_suppression']
    class DowncastMutator(ExprMutator):
        """Downcast to fp16 mutator"""
        def visit_call(self, call):
            dtype = 'float32' if call.op.name in filter_list else 'float16'
            new_fn = self.visit(call.op)
            # Collect the original dtypes
            type_list = []
            if call.op.name in filter_list:
                # For non_max_suppression: record the dtypes of the tuple fields it consumes
                for arg in call.args:
                    if isinstance(arg, TupleGetItem) and isinstance(arg.tuple_value, Call):
                        tuple_types = arg.tuple_value.checked_type.fields
                        type_list.append(tuple_types[arg.index].dtype)
                if call.op.name == 'vision.get_valid_counts':
                    tuple_types = call.checked_type.fields
                    for cur_type in tuple_types:
                        type_list.append(cur_type.dtype)

            args = [self.visit(arg) for arg in call.args]
            new_args = list()
            arg_idx = 0
            for arg in args:
                if isinstance(arg, (Var, Constant)):
                    # Cast graph inputs and weights at the leaves
                    new_args.append(cast(arg, dtype=dtype))
                else:
                    if call.op.name in filter_list:
                        # Leave int32 tuple fields (counts, indices) untouched
                        if isinstance(arg, TupleGetItem) and type_list[arg_idx] == 'int32':
                            new_args.append(arg)
                        else:
                            new_args.append(cast(arg, dtype=dtype))
                    else:
                        new_args.append(arg)
                arg_idx += 1
            if call.op.name in filter_list and call.op.name != 'vision.get_valid_counts':
                # Return to fp16 for the ops downstream of the fp32 op
                return cast(Call(new_fn, new_args, call.attrs), dtype='float16')
            return Call(new_fn, new_args, call.attrs)

    class UpcastMutator(ExprMutator):
        """Upcast the output back to fp32"""
        def visit_call(self, call):
            # Children are not visited, so only the output call(s) get wrapped in a cast
            return cast(call, dtype='float32')

    def infer_type(expr):
        """A method to infer the type of an intermediate node in the relay graph"""
        mod = IRModule.from_expr(expr)
        mod = _transform.InferType()(mod)
        entry = mod["main"]
        return entry if isinstance(expr, Function) else entry.body

    func = infer_type(func)
    downcast_pass = DowncastMutator()
    func = downcast_pass.visit(func)
    upcast_pass = UpcastMutator()
    func = upcast_pass.visit(func)
    func = infer_type(func)
    return func
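To make the structure of the pass easier to follow, here is a minimal, stdlib-only sketch of the same idea on a hypothetical toy IR. The `Var`, `Const`, `Call`, and `Cast` classes and `downcast_fp16_toy` are made up for illustration; they are not TVM's Relay classes. Leaves feeding ordinary ops are cast to fp16, ops in the filter list get fp32 inputs and have their result cast back to fp16, and the final output is cast to fp32. It deliberately omits the int32 tuple-field and `get_valid_counts` special cases.

```python
from dataclasses import dataclass
from typing import Tuple, Union

# Hypothetical toy IR (NOT TVM's Relay classes), used only to illustrate the idea.
@dataclass(frozen=True)
class Var:
    name: str

@dataclass(frozen=True)
class Const:
    value: float

@dataclass(frozen=True)
class Call:
    op: str
    args: Tuple["Expr", ...]

@dataclass(frozen=True)
class Cast:
    expr: "Expr"
    dtype: str

Expr = Union[Var, Const, Call, Cast]

# Ops that must stay in fp32 (mirrors filter_list in the pass above).
FILTER_LIST = {"vision.get_valid_counts", "vision.non_max_suppression"}

def downcast(expr: Expr) -> Expr:
    """Insert casts so ops run in fp16, except ops in FILTER_LIST."""
    if not isinstance(expr, Call):
        return expr
    dtype = "float32" if expr.op in FILTER_LIST else "float16"
    new_args = []
    for arg in expr.args:
        if isinstance(arg, (Var, Const)):
            # Cast graph inputs and weights at the leaves.
            new_args.append(Cast(arg, dtype))
        elif expr.op in FILTER_LIST:
            # Feed fp32 into the filtered op.
            new_args.append(Cast(downcast(arg), dtype))
        else:
            new_args.append(downcast(arg))
    new_call = Call(expr.op, tuple(new_args))
    if expr.op in FILTER_LIST:
        # Go back to fp16 for whatever consumes this op.
        return Cast(new_call, "float16")
    return new_call

def downcast_fp16_toy(body: Expr) -> Expr:
    """Downcast the graph, then upcast the final output back to fp32."""
    return Cast(downcast(body), "float32")

if __name__ == "__main__":
    x, w = Var("x"), Const(0.5)
    graph = Call("nn.relu", (Call("multiply", (x, w)),))
    print(downcast_fp16_toy(graph))
```

Because casts are only inserted at the leaves and around the filtered ops, most intermediate ops simply compute in whatever dtype their inputs carry, which is why the target does not matter at this stage.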

Awesome, thanks a lot @jwfromm! Do you have any experience with how this impacts accuracy? For example, I know that CUDA’s __float2half function has a decent amount of conversion logic.

Unfortunately, I haven’t really tested the impact on accuracy; I’ve been using this pass primarily for performance measurements. I think it should be fairly easy to modify it to improve the conversion logic, though.
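For a rough feel for what FP16 rounding does to individual values, the sketch below uses only Python's standard library: `struct`'s `'e'` format packs a float to IEEE 754 half precision. `to_fp16` and `fp16_relative_error` are illustrative helpers, not part of TVM or CUDA, and this is only an approximation of an on-device conversion: it rounds from a Python (64-bit) float rather than from fp32, and it shows per-value error, not how error accumulates through a model like BERT.

```python
import struct

def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE 754 binary16 value and return it as a float.

    struct's 'e' format packs to half precision and raises OverflowError
    when x is too large to represent.
    """
    return struct.unpack("<e", struct.pack("<e", x))[0]

def fp16_relative_error(x: float) -> float:
    """Relative error of a round trip through half precision."""
    return abs(to_fp16(x) - x) / abs(x)

if __name__ == "__main__":
    for value in (0.1, 3.14159, 1234.567, 65504.0):
        print(f"{value!r:>10} -> {to_fp16(value)!r:<18} rel. error {fp16_relative_error(value):.2e}")
    # Very small values underflow to zero in fp16.
    print("1e-8 becomes", to_fp16(1e-8))
    # Values above the fp16 maximum (65504) cannot be represented.
    try:
        to_fp16(1e5)
    except OverflowError as err:
        print("1e5 overflows:", err)
```

For normal (non-subnormal) values the relative error stays at or below 2^-11 (about 4.9e-4); the bigger risks are overflow above 65504 and underflow of very small values, both of which are worth checking for in weights and activations.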

I’m trying to find the actual device-specific implementation of Cast, but I’m having trouble locating it. @jwfromm, do you know where it is?

Assuming you’re looking for the low-level code, you can find the CUDA cast generator in tvm/src/target/source/codegen_cuda.cc, in the VisitExpr_(const CastNode* op, std::ostream& os) function. However, you probably want to do the casting ahead of time in Relay rather than on device. If you use a pass like the one I posted above, the operations in the graph are converted before they’re compiled.

Thanks a lot. I’ve been playing around with this on a BERT model, but I’m hitting some issues when calling relay.build with opt level 3. The target is cuda. The error message looks like this:

unresolved intrinsic sqrt with return type float16x4

It comes from codegen_c.cc. Does this mean that sqrt isn’t supported with float16?

Edit: @Hzfengsy or @trevor-m do you have any insight here? My target is cuda, and the part of the code this is complaining about is outside of my external codegen (it’s compiled by TVM).

Actually, why is codegen_c being called at all here? It seems like only codegen_cuda should be called.