During a conversation with @slyubomirsky, it came up that Relax functions have two distinct ways to express scalar values. Both `PrimValue` (struct info of `R.Prim`) and scalar tensors (struct info of `R.Tensor(ndim=0)`) can store a single scalar value. However, the allowed usages of `PrimValue` and 0-d tensors are very different.
- A relax tensor can be scaled by a 0-d tensor, but not by a `PrimValue`.
- A `PrimValue` can be used to define a symbolic shape variable, but a 0-d tensor cannot.
- Some operations require a 0-d tensor, and cannot accept the equivalent `PrimValue` (e.g. the fill value for `R.full`).
- Some operations require a `PrimValue`, and cannot accept the equivalent 0-d tensor (e.g. the offset parameter for `R.tril`).
- A TIR function with primitive TIR parameters can be called through `R.call_tir` with `PrimValue` arguments, but not with 0-d tensors.
In this gist, I show several attempts to work around these limitations. Setting aside a few bugs that need to be fixed, I was able to convert int64 scalars between `PrimValue` and 0-d tensors, but doing so requires a roundabout series of conversions. These methods use `R.shape` as an intermediate, and can therefore only be applied to int64 values.
I think we need methods that can convert between `PrimValue` and 0-d tensors in general. Ideally, the conversion would be transparent to the end user, similar to how operations between scalar values and arrays are handled in numpy or pytorch. This would have a number of benefits.
- User-friendliness
  - Consistency with other packages: Pytorch and numpy both allow scalar values to be broadcast in tensor operations, and do not require a user to explicitly convert the scalar to a 0-d tensor first.
  - Consistency within TVM: TE allows scalar values in expressions. Multiplying a TE tensor by a 0-d tensor produces the expression `A[ax0, ax1] * B[()]`, and multiplying a TE tensor by a scalar TIR variable produces the expression `A[ax0, ax1] * scalar`.
  - Ease of use for compiled functions. Currently, relax scalar parameters must be wrapped in either `tvm.runtime.ShapeTuple` or `tvm.nd.array`, depending on whether they are used as symbolic variables or as a relax tensor within the function. A user shouldn't need to know how a scalar is going to be used internally in order to provide a scalar to the function. Allowing `R.Prim` to be accepted by the function and easily used in relax would remove this requirement.
- More efficient representations when lowered to TIR
  - (Possibly) more efficient kernel API. A 0-d tensor is passed by pointer, whereas a scalar is passed by value. The scalar may be passed in a single register, while the tensor requires dereferencing a pointer-to-scalar on use.
  - (Possibly) more efficient kernel optimizations. At the TIR level, a `ScalarBuf[()]` access may be modified between uses, whereas a `ScalarVar` has a single known value. Optimizing multiple accesses of `ScalarBuf[()]` requires dataflow tracking to verify that the buffer hasn't changed.
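For reference, the "consistency with other packages" point above can be seen in a small numpy-only sketch (no TVM required): numpy accepts a Python scalar, a numpy scalar, and a 0-d array interchangeably in tensor operations, with no explicit conversion by the user.

```python
import numpy as np

A = np.arange(6, dtype="float32").reshape(2, 3)

# All three scalar representations broadcast transparently; the user
# never converts between them explicitly.
python_scalar = 2.0
numpy_scalar = np.float32(2.0)
zero_d_array = np.array(2.0, dtype="float32")

expected = A * 2.0
assert np.array_equal(A * python_scalar, expected)
assert np.array_equal(A * numpy_scalar, expected)
assert np.array_equal(A * zero_d_array, expected)
```

This is the user experience the proposal aims for: the distinction between a scalar and a 0-d tensor is an internal representation detail, not something the caller must know.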