During a conversation with @slyubomirsky, it came up that Relax functions have two distinct ways to express scalar values. Both PrimValue (struct info of `R.Prim`) and scalar tensors (struct info of `R.Tensor(ndim=0)`) can store a single scalar value. However, the allowed usages of PrimValue and 0-d tensors are very different.
- A relax tensor can be scaled by a 0-d tensor, but not by a PrimValue.
- A PrimValue can be used to define a symbolic shape variable, but a 0-d tensor cannot.
- Some operations require a 0-d tensor, and cannot accept the equivalent PrimValue (e.g. the fill value for `R.full`).
- Some operations require a PrimValue, and cannot accept the equivalent 0-d tensor (e.g. the offset parameter for `R.tril`).
- A TIR function with primitive TIR parameters can be called through `R.call_tir` with PrimValue arguments, but not with 0-d tensors.
In this gist, I show several attempts to work around these limitations. Setting aside a few bugs that need to be fixed, I was able to convert `int64` scalars between PrimValue and 0-d tensors, but it requires a roundabout series of conversions. These methods use `R.shape` as an intermediate, and can therefore only be applied to `int64` values.
I think we need a general way to convert between PrimValue and 0-d tensors. Ideally, this would be transparent to the end user, similar to how operations between scalar values and arrays are handled in numpy or pytorch. This would have a number of benefits.
- User-friendliness
  - Consistency with other packages: PyTorch and numpy both allow scalar values to be broadcast in tensor operations, and do not require a user to explicitly convert the scalar to a 0-d tensor first.
  - Consistency within TVM: TE allows scalar values in expressions. Multiplying a TE tensor by a 0-d tensor produces the expression `A[ax0, ax1] * B[()]`, and multiplying a TE tensor by a scalar TIR variable produces the expression `A[ax0, ax1] * scalar`.
  - Ease of use for compiled functions: Currently, relax scalar parameters must be wrapped in either `tvm.runtime.ShapeTuple` or `tvm.nd.array`, depending on whether they are used as symbolic variables or as relax tensors within the function. A user shouldn't need to know how a scalar is going to be used internally in order to provide it to the function. Allowing `R.Prim` arguments to be accepted by the function and used directly in relax would remove this requirement.
- More efficient representations when lowered to TIR
  - (Possibly) more efficient kernel API: A 0-d tensor is passed by pointer, while a scalar is passed by value. The scalar may fit in a single register, whereas the tensor requires dereferencing a pointer-to-scalar on each use.
  - (Possibly) more efficient kernel optimizations: At the TIR level, a `ScalarBuf[()]` access may be modified between uses, while a `ScalarVar` has a single known value. Optimizing multiple accesses of `ScalarBuf[()]` requires dataflow tracking to verify that the buffer hasn't changed.