A SPIRV codegen bug or TIR issue?

When I compile an fp16 model to SPIRV, it hits a fatal error at codegen_spirv.cc:453: "Only aligned continuous vector access is allowed in SPIRV".

This happens because a TIR LoadNode is neither a scalarized load nor a vectorized load with a valid RampNode index. The error occurs in the "fused_image_resize_kernel0" layer. I dumped the layer's TIR, but I'm not sure whether it's a TIR problem or a SPIRV codegen bug. Can someone tell from the following TIR dump?

fused_image_resize_kernel0:

#[version = "0.0.5"] primfn(resize: Pointer(float16x4), placeholder: Pointer(float16x4)) -> () attr = {"target": meta[Target][0], "tir.noalias": 1, "global_symbol": "fused_image_resize_kernel0", "tir.device_thread_axis": [IterVar(blockIdx.x: int32, (nullptr), "ThreadIndex", "blockIdx.x"), IterVar(threadIdx.x: int32, (nullptr), "ThreadIndex", "threadIdx.x")], "calling_conv": 2} { attr [IterVar(blockIdx.x, (nullptr), "ThreadIndex", "blockIdx.x")] "thread_extent" = 256; attr [IterVar(threadIdx.x, (nullptr), "ThreadIndex", "threadIdx.x")] "thread_extent" = 256; for (i0.i1.fused.i2.fused.i3.fused.outer.outer: int32, 0, 8) { resize[ramp(((((i0.i1.fused.i2.fused.i3.fused.outer.outer*262144) + (@tir.shift_right(((blockIdx.x*1024) + (threadIdx.x*4)), 15, dtype=int32)*32768)) + (@tir.bitwise_and(((blockIdx.x*4) + @tir.shift_right(threadIdx.x, 6, dtype=int32)), 127, dtype=int32)*256)) + (@tir.bitwise_and(threadIdx.x, 63, dtype=int32)*4)), 1, 4)] = cast(float16x4, cast(float32x4, (float16x4*)placeholder[((broadcast(((i0.i1.fused.i2.fused.i3.fused.outer.outer*65536) + (@tir.shift_right(((blockIdx.x*1024) + (threadIdx.x*4)), 15, dtype=int32)*8192)), 4) + (max(min(cast(int32x4, @tir.call_spirv_pure_glsl450(8u32, ((broadcast(0.5f32, 4)*cast(float32x4, broadcast(@tir.bitwise_and(((blockIdx.x*4) + @tir.shift_right(threadIdx.x, 6, dtype=int32)), 127, dtype=int32), 4))) + broadcast(1e-05f32, 4)), dtype=float32x4)), broadcast(63, 4)), broadcast(0, 4))*broadcast(128, 4))) + max(min(cast(int32x4, @tir.call_spirv_pure_glsl450(8u32, ((broadcast(0.5f32, 4)*cast(float32x4, ramp((@tir.bitwise_and(threadIdx.x, 63, dtype=int32)*4), 1, 4))) + broadcast(1e-05f32, 4)), dtype=float32x4)), broadcast(127, 4)), broadcast(0, 4)))])) } }

Here's the same TIR dump in a more readable format. The error occurs when compiling the load on the right-hand side, placeholder[...], which is neither a scalarized load nor a vectorized load with a valid RampNode index. Unlike the resize[...] store on the left-hand side, the index of placeholder[...] cannot be cast to a RampNode, so it hits the fatal error.

resize[ramp((	(  ((i0.i1.fused.i2.fused.i3.fused.outer.outer*262144) + (@tir.shift_right(((blockIdx.x*1024) + (threadIdx.x*4)), 15, dtype=int32)*32768))
				  + (@tir.bitwise_and(((blockIdx.x*4) + @tir.shift_right(threadIdx.x, 6, dtype=int32)), 127, dtype=int32)*256)  ) 
			  + (@tir.bitwise_and(threadIdx.x, 63, dtype=int32)*4)), 
			1, 
			4)] 
=  cast(float16x4, 
	 cast(float32x4, 
		  (float16x4*)placeholder[ (  ( broadcast(((i0.i1.fused.i2.fused.i3.fused.outer.outer*65536) + (@tir.shift_right(((blockIdx.x*1024) + (threadIdx.x*4)), 15, dtype=int32)*8192)), 4)
									    + (  max(  min( cast( int32x4, 
															 @tir.call_spirv_pure_glsl450( 8u32, 
																						   (  (	 broadcast(0.5f32, 4)
																							    * cast(float32x4, 
																									   broadcast( @tir.bitwise_and( ( (blockIdx.x*4) + @tir.shift_right(threadIdx.x, 6, dtype=int32)),
																																	  127, 
																																	  dtype=int32), 
																												  4)
																									   )
																							   ) 
																							   + broadcast(1e-05f32, 4)
																							),
																						   dtype=float32x4
																						 )
															), 
														broadcast(63, 4)
													), 
													broadcast(0, 4)
											 ) 
											 * broadcast(128, 4)
										  )
									  ) 
									  + max( min( cast( int32x4, 
														@tir.call_spirv_pure_glsl450( 8u32, 
																					  (  (broadcast(0.5f32, 4)*cast(float32x4, ramp((@tir.bitwise_and(threadIdx.x, 63, dtype=int32)*4), 1, 4))) 
																						+ broadcast(1e-05f32, 4)), 
																					   dtype=float32x4)
																					 ), 
														broadcast(127, 4)), 
											broadcast(0, 4))
									)
								]
	)
)
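
To make the distinction concrete, here is a minimal sketch (with made-up index expressions, not the ones from the dump above) of the two index shapes involved: the resize[...] store uses a plain stride-1 ramp, while the placeholder[...] load index is a general int32x4 expression built from broadcast/min/max terms, which is not a RampNode.

```python
import tvm
from tvm import tir

base = tir.Var("base", "int32")

# Store side: contiguous lanes, expressible as a stride-1 Ramp -> accepted.
store_index = tir.Ramp(base, 1, 4)

# Load side: a per-lane gather (broadcast base plus clamped lane offsets) ->
# a generic int32x4 expression, not a Ramp, so the SPIRV codegen rejects it.
lane_offsets = tir.Max(tir.Broadcast(tir.IntImm("int32", 0), 4),
                       tir.Ramp(tir.IntImm("int32", 0), 128, 4))
gather_index = tir.Add(tir.Broadcast(base, 4), lane_offsets)

print(isinstance(store_index, tir.Ramp))   # True
print(isinstance(gather_index, tir.Ramp))  # False
```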

cc @Lunderberg I've seen this before; fp16 is not supported by our vulkan backend yet.

This bug appears when an fp16 multiply is vectorized by "schedule_injective_from_existing" in topi/cuda/injective.py and one operand's dimensions differ from the other's. In that case the vectorizer inserts a broadcast instead of creating a Ramp. A possible fix might be to disable vectorization for injective ops whose operands do not have the same dimensions when targeting vulkan. Is this reasonable? A rough sketch of what I have in mind is below.
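
The helper name and structure here are mine, not the actual topi/cuda/injective.py code, and it assumes constant shapes; the idea is just to fall back to scalar code whenever the vectorizer would otherwise have to broadcast one operand inside a vector load.

```python
def _fp16_vector_width(out, inputs, target):
    """Hypothetical helper: pick the vectorization factor for an injective op."""
    same_shape = all(
        [int(d) for d in t.shape] == [int(d) for d in out.shape] for t in inputs
    )
    # On vulkan, fall back to scalar code when operands have mismatched shapes,
    # since the broadcast operand would otherwise become a non-Ramp vector load.
    if out.dtype == "float16" and (same_shape or target.kind.name != "vulkan"):
        return 4
    return 1
```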

Interesting, what happens if we simply set vector_width = 1 in topi/cuda/injective.py (apache/tvm, commit 813136401a11a49d6c15e6013c34dd822a5c4ff6)?

That works. Actually, setting vector_width=2 also works. This is because for vector_width=2 we get something like the following, where the broadcast is outside of the placeholder load:

T_multiply_2[ramp(((blockIdx.x*2048) + (threadIdx.x*2)), 1, 2)] = (broadcast((float16*)placeholder_5[blockIdx.x], 2)*(float16x2*)placeholder_4[ramp(((blockIdx.x*2048) + (threadIdx.x*2)), 1, 2)])

When vector_width=4, we have:

T_multiply_7[ramp(((blockIdx.x_16*4096) + (threadIdx.x_34*4)), 1, 4)] = ((float16x4*)placeholder_39[broadcast(((blockIdx.x_16*2) + @tir.shift_right(threadIdx.x_34, 9, dtype=int32)), 4)]*(float16x4*)placeholder_40[ramp(((blockIdx.x_16*4096) + (threadIdx.x_34*4)), 1, 4)])
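
For anyone who wants to poke at this without the full model, here is a small standalone sketch (the shapes, split factor, and tensor names are made up, and it uses the classic TE schedule API) of an fp16 multiply with an operand that gets broadcast along the inner axis; lowering it lets you inspect how the vectorized load of the smaller operand is rewritten.

```python
import tvm
from tvm import te

n, c = 2048, 512
A = te.placeholder((n, c), dtype="float16", name="A")
B = te.placeholder((n, 1), dtype="float16", name="B")  # operand broadcast along the inner axis
C = te.compute((n, c), lambda i, j: A[i, j] * B[i, 0], name="C")

s = te.create_schedule(C.op)
fused = s[C].fuse(*C.op.axis)
xo, xi = s[C].split(fused, factor=4)
s[C].vectorize(xi)

# Inspect the lowered TIR: with factor=4 the load of B inside the vectorized
# body does not use a simple stride-1 ramp index.
print(tvm.lower(s, [A, B, C], simple_mode=True))
```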

Does this issue still occur after PR#8528? I ran into similar issues where a broadcast node used as an index caused incorrect rewriting of array types.