[Unity] PrimValue and 0-d tensors in Relax functions

During a conversation with @slyubomirsky, it came up that Relax functions have two distinct ways to express scalar values. Both PrimValue (struct info of R.Prim) and scalar tensors (struct info of R.Tensor(ndim=0)) can store a single scalar value. However, the allowed usage for PrimValue and 0-d tensors is very different.

  • A relax tensor can be scaled by a 0-d tensor, but not by a PrimValue.
  • A PrimValue can be used to define a symbolic shape variable, but a 0-d tensor cannot.
  • Some operations require a 0-d tensor and cannot accept the equivalent PrimValue (e.g. the fill value for R.full).
  • Some operations require a PrimValue and cannot accept the equivalent 0-d tensor (e.g. the offset parameter for R.tril).
  • A TIR function with primitive TIR parameters can be called through R.call_tir with PrimValue arguments, but not with 0-d tensors.

In this gist, I show several attempts to work around these limitations. Setting aside a few bugs that need to be fixed, I was able to convert int64 scalars between PrimValue and 0-d tensors, but it requires a roundabout series of conversions. These methods use R.shape as an intermediate, and can therefore only be applied to int64 values.

I think we need methods that could convert between PrimValue and 0-d tensors in general. Ideally, this could be transparent to the end-user, similar to how operations between scalar values and arrays are handled in numpy or pytorch. This would have a number of benefits.

  • User-friendliness

    • Consistency with other packages: PyTorch and NumPy both allow scalar values to be broadcast in tensor operations, and neither requires the user to explicitly convert the scalar to a 0-d tensor first.
    • Consistency within TVM: TE allows scalar values in expressions. Multiplying a TE tensor by a 0-d tensor produces the expression A[ax0, ax1] * B[()], and multiplying a TE tensor by a scalar TIR variable produces the expression A[ax0, ax1] * scalar.
    • Ease of use for compiled functions. Currently, Relax scalar parameters must be wrapped in either tvm.runtime.ShapeTuple or tvm.nd.array, depending on whether they are used as symbolic variables or as Relax tensors within the function. A user shouldn’t need to know how a scalar is used internally in order to provide it to the function. Allowing the function to accept an R.Prim, and allowing that value to be used easily within Relax, would remove this burden.
  • More efficient representations when lowered to TIR

    • (Possibly) more efficient kernel API. A 0-d tensor is passed by pointer, whereas a scalar is passed by value. The scalar may be passed in a single register, whereas the tensor requires dereferencing a pointer-to-scalar on each use.
    • (Possibly) more efficient kernel optimizations. At the TIR level, the value behind a ScalarBuf[()] access may be modified between uses, whereas a ScalarVar has a single known value. Optimizing multiple accesses of ScalarBuf[()] requires dataflow tracking to verify that the buffer hasn’t changed.
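For comparison, here is a minimal NumPy sketch of the scalar handling mentioned in the list above. A Python scalar and a 0-d array can be used interchangeably to scale an array, and converting a 0-d array back to a Python scalar is an explicit .item() call. This illustrates NumPy's semantics only. It does not show any Relax API.

```python
import numpy as np

a = np.arange(6, dtype="float32").reshape(2, 3)

# Scaling by a plain Python scalar: NumPy broadcasts it with no explicit conversion.
by_scalar = a * 2.0

# Scaling by a 0-d array holding the same value gives the same result.
zero_d = np.array(2.0, dtype="float32")
assert zero_d.ndim == 0
by_zero_d = a * zero_d

assert np.array_equal(by_scalar, by_zero_d)
# The Python float does not upcast the float32 array.
assert by_scalar.dtype == np.float32

# Going back from a 0-d array to a Python scalar is an explicit call.
value = zero_d.item()
assert isinstance(value, float) and value == 2.0
```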

0-d tensors and PrimValues are two different things, because they reside in different places: on the host (in the case of a PrimValue) or on the device (a 0-d tensor).

In general, we could enhance certain cases to work directly with PrimValues, without having to convert them to 0-d tensors. Converting into a 0-d tensor is fine, since it usually involves just a kernel that then gets inlined.

Converting back requires more care, because the computation of the 0-d tensor is asynchronous, so there are more performance implications (including a sync in the pipeline). As such, being explicit is likely better.

I believe what we said when we first added PrimValues is that we wanted scalars used for arithmetic to be 0-dimensional tensors and to use PrimValues for metadata. What prompted the discussion between Eric and me is that the existence of the shape-to-tensor, tensor-to-shape, and shape-to-PrimValue conversions means that there can be multiple ways to express the same value. It’s a little confusing. Personally, I like drawing a line between PrimValues and 0-dimensional tensors, but Eric makes the very good point that TIR supports scalar inputs (which are distinct from 0-dimensional tensors), which is relevant for Relax’s direct support for PrimFuncs.

0-d tensors and PrimValues are two different things, because they reside in different places: on the host (in the case of a PrimValue) or on the device (a 0-d tensor).

Good point on the residency. Thinking on it, I think I’d describe the distinction slightly differently: A tensor may be expensive to copy, and has a specified location. A PrimValue is cheap to copy, and has an unspecified location.

I don’t think it would be accurate to say that a PrimValue is always on the host side, because it may appear within a fused Relax function marked with attr::kPrimitive. Specifying PrimValues as cheap to copy would capture the semantics that they may occur on the host side (user-provided parameters), on the device side (inside primitive functions), or at the interface between host and device (scalar kernel parameters). The exact lowered form of a PrimValue would be up to the compiler, and wouldn’t be part of the user-facing Relax semantics.

In general, we could enhance certain cases to work directly with PrimValues, without having to convert them to 0-d tensors. Converting into a 0-d tensor is fine, since it usually involves just a kernel that then gets inlined.

I think that would cover the majority of the use cases I’ve run into. Having a conversion from PrimValue to 0-d tensor would solve the usability issues with operators like R.full and R.mul(tensor, prim_value).

Converting back requires more care, because the computation of the 0-d tensor is asynchronous, so there are more performance implications (including a sync in the pipeline). As such, being explicit is likely better.

Good point. For cases where the PrimValue would be generated and consumed within a single fused kernel, that overhead wouldn’t apply, but a PrimValue generated on the device and used on the host would require synchronization.

I think the 0-d-tensor-to-scalar conversion would be much rarer in practice, and can be tabled for now.

After @tqchen’s comment, I think I’m in favor of there being a distinction between the two, with tensors being able to indicate where they are located. The gaps in expressibility (e.g. scaling by a PrimValue) and user-friendliness (e.g. passing a python float instead of tvm.nd.array(value, dtype='float32')) should be solvable with automatic wrappers.
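Here is a minimal sketch of what such an automatic wrapper might do. It is plain Python, and it rests on two stand-ins: NumPy 0-d arrays take the place of tvm.nd.array, and Python ints take the place of values used as symbolic shape variables. The decorator wrap_scalar_args, its tensor_params and shape_params arguments, and the example function scale_prefix are all hypothetical and not part of TVM. A real wrapper would presumably read this information from the function's struct info instead of having it passed by hand.

```python
import functools
import inspect

import numpy as np


def wrap_scalar_args(tensor_params=None, shape_params=()):
    """Hypothetical decorator: convert plain Python scalars based on how each
    parameter is used inside the (stand-in) compiled function.

    tensor_params: maps parameter name -> dtype; the scalar becomes a 0-d array.
    shape_params: parameter names used as symbolic shape values; kept as int.
    """
    tensor_params = dict(tensor_params or {})
    shape_params = set(shape_params)

    def decorator(func):
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            for name, value in list(bound.arguments.items()):
                if name in tensor_params and not isinstance(value, np.ndarray):
                    # Scalar used as a tensor: wrap it in a 0-d array.
                    bound.arguments[name] = np.array(value, dtype=tensor_params[name])
                elif name in shape_params:
                    # Scalar used as a shape variable: keep it as a plain int.
                    bound.arguments[name] = int(value)
            return func(*bound.args, **bound.kwargs)

        return wrapper

    return decorator


@wrap_scalar_args(tensor_params={"scale": "float32"}, shape_params={"n"})
def scale_prefix(x, scale, n):
    # Stand-in for a compiled function that needs `scale` as a 0-d tensor
    # and `n` as an integer shape value.
    assert isinstance(scale, np.ndarray) and scale.ndim == 0
    assert isinstance(n, int)
    return x[:n] * scale


x = np.ones(4, dtype="float32")
out = scale_prefix(x, 3.0, 2)  # the caller passes plain Python scalars
```

With this design, the caller never decides between the two wrappers. The function's declared usage makes that decision.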

For performance differences, I did a bit more digging, and I think we’re okay. The tir.noalias attribute, set by default for most kernels, is used to emit the __restrict__ qualifier for kernel pointer arguments. My main worry was that optimization would be worse for void scale(float16* buf, size_t n, float16* pointer_to_scalar) than for void scale(float16* buf, size_t n, float16 scalar), in case buf <= pointer_to_scalar < buf+n (i.e. the scalar aliases an element of buf). With the __restrict__ qualifier, we should be fine there, but I’d want to run a performance test before saying that conclusively.
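To show why the aliasing case matters, here is a small Python model of the two C signatures above. A (list, index) pair stands in for float16* pointer_to_scalar, and both function names are hypothetical. The model only captures the semantics a compiler must preserve. It says nothing about generated code or actual performance. Reading the scalar once before the loop is what a by-value scalar gets for free, but for a pointer argument it is only correct when the pointer cannot point into buf, which is the guarantee __restrict__ provides.

```python
def scale_reload(buf, n, scalar_ref):
    """Re-read the scalar on every iteration, as a compiler must when
    scalar_ref might alias buf (no __restrict__ guarantee)."""
    src, idx = scalar_ref
    for i in range(n):
        buf[i] = buf[i] * src[idx]


def scale_hoisted(buf, n, scalar_ref):
    """Read the scalar once before the loop; only valid if scalar_ref
    does not point into buf (what __restrict__ promises)."""
    src, idx = scalar_ref
    scalar = src[idx]
    for i in range(n):
        buf[i] = buf[i] * scalar


# Without aliasing, both versions agree.
a, b = [1, 2, 3, 4], [1, 2, 3, 4]
s = [2]
scale_reload(a, 4, (s, 0))
scale_hoisted(b, 4, (s, 0))
assert a == b == [2, 4, 6, 8]

# With aliasing (pointer_to_scalar == &buf[0]), they diverge.
c, d = [2, 3, 4, 5], [2, 3, 4, 5]
scale_reload(c, 4, (c, 0))
scale_hoisted(d, 4, (d, 0))
print(c, d)  # [4, 12, 16, 20] [4, 6, 8, 10]
```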