- @pytest.mark.parametrize(
- "x_shape, y_shape, transpose_y, epilogue, residual_block",
- [
- # Regular
- ((32, 6), (6, 16), False, "none", "none"),
- ((_vars["a"], 6), (6, 16), False, "bias", "none"),
- # Transposed
- ((4, 16), (16, 128), True, "relu", "none"),
- ((35, 8), (8, 8), True, "gelu", "none"),
- # 3D x 3D
- ((6, 32, 8), (6, 8, 10), False, "bias", "none"),
- ((6, 32, 8), (6, 8, 10), True, "none", "none"),
- ((_vars["a"], 32, 8), (_vars["a"], 8, 10), True, "gelu", "none"),
- # 3D x 2D
- ((6, 32, 8), (8, 10), False, "none", "none"),
- ((_vars["a"], 32, 8), (8, 10), False, "bias", "none"),
- ((10, 16, 8), (8, 10), True, "relu", "none"),
- # 2D x 3D
- ((32, 8), (10, 8, 10), False, "relu", "none"),
- ((32, 8), (_vars["a"], 8, 10), True, "gelu", "none"),
This file has been truncated. Show original.