onnx - Serbipunk/notes GitHub Wiki

Troubleshooting

Please fix either the inputs or the model.

nn.Unflatten

https://github.com/pytorch/pytorch/issues/49538#issuecomment-1137847439

class View(nn.Module):
    """Expand one dimension of a tensor into several via ``Tensor.view``.

    This file assigns the class to ``nn.Unflatten`` so it can stand in for
    that layer.

    Args:
        dim: Index of the dimension to expand. Negative values count from
            the last dimension (``-1`` is the last one).
        shape: Sequence of sizes that replaces dimension ``dim``. The values
            are passed through to ``Tensor.view`` unchanged.
    """

    def __init__(self, dim: int, shape) -> None:
        super().__init__()
        self.dim = dim
        self.shape = shape

    def forward(self, input):
        """Return ``input`` viewed with dimension ``dim`` replaced by ``shape``.

        Raises:
            IndexError: If a negative ``dim`` reaches before the first
                dimension of ``input``.
        """
        sizes = list(input.shape)

        # Normalize a negative index first: slicing with ``dim + 1`` is only
        # correct for non-negative ``dim`` (for ``dim == -1`` the tail slice
        # would otherwise start at 0 and re-append the whole original shape).
        dim = self.dim
        if dim < 0:
            dim += len(sizes)
        if dim < 0:
            raise IndexError(
                f"Dimension out of range: got dim={self.dim} "
                f"for a tensor with {len(sizes)} dimension(s)"
            )

        new_shape = sizes[:dim] + list(self.shape) + sizes[dim + 1:]
        return input.view(*new_shape)


# Monkey-patch: rebind the ``Unflatten`` attribute on ``nn`` to ``View`` so
# that later lookups of ``nn.Unflatten`` construct a ``View`` instead.
# NOTE(review): presumably this must run before the model is built, since
# already-created layers are not affected by rebinding the name — confirm.
nn.Unflatten = View