I’m trying to fix the current Reshape problem in the Keras frontend but haven't succeeded yet. This is the test case I can't pass:
```python
import keras

def test_forward_reshape():
    data = keras.layers.Input(shape=(32, 32, 3))
    x = keras.layers.Reshape(target_shape=(32, 32, 3))(data)
    x = keras.layers.GlobalAveragePooling2D()(x)
    keras_model = keras.models.Model(data, x)
    # verify_keras_frontend is the helper defined in TVM's Keras frontend tests
    verify_keras_frontend(keras_model)
```
The problem is that the target_shape parameter of Reshape is given in HWC (channels-last) order, without the batch dimension, while the input tensor is in (N)CHW format.
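One possible approach, sketched here in NumPy rather than in the frontend itself: transpose the NCHW tensor back to NHWC, apply target_shape exactly as Keras would (prepending the batch dimension), and transpose the result back to NCHW if it is 4-D again. The function name reshape_nchw_like_keras is hypothetical, and the sketch assumes a 4-D NCHW input; in the frontend the same three steps would be expressed with the corresponding transpose and reshape operators.

```python
import numpy as np

def reshape_nchw_like_keras(x_nchw, target_shape):
    """Hypothetical helper: apply a Keras-style Reshape (channels-last
    target_shape, batch dimension excluded) to a 4-D tensor stored as NCHW.
    A 4-D result is returned as NCHW so channels-first consumers keep working."""
    n = x_nchw.shape[0]
    # 1. Go back to the layout Keras assumes (NHWC).
    x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))
    # 2. Reshape exactly as Keras would, prepending the batch dimension.
    y = np.reshape(x_nhwc, (n,) + tuple(target_shape))
    # 3. If the output is an image again (N, H, W, C), convert it to NCHW.
    if y.ndim == 4:
        y = np.transpose(y, (0, 3, 1, 2))
    return y

rng = np.random.default_rng(0)
x_nhwc = rng.standard_normal((1, 4, 4, 3))
x_nchw = np.transpose(x_nhwc, (0, 3, 1, 2))

# Reference: what Keras computes on its channels-last tensor.
ref = np.reshape(x_nhwc, (1, 4, 2, 6))
out = reshape_nchw_like_keras(x_nchw, (4, 2, 6))
print(np.allclose(np.transpose(out, (0, 2, 3, 1)), ref))  # True

# Naive approach: reshape the NCHW tensor directly with target_shape
# reordered to CHW. Because the channel count changes, elements get scrambled.
naive = np.reshape(x_nchw, (1, 6, 4, 2))
print(np.allclose(naive, out))  # False
```

When the target keeps the channel count as its last dimension (as in the identity reshape in the test case), the direct CHW reshape happens to give the same result, but in general the transposes are needed.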
Is there a good way to solve this problem?
My ongoing work is here: