diff --git a/nn/nets.py b/nn/nets.py index 07dfe11..6c9ef83 100644 --- a/nn/nets.py +++ b/nn/nets.py @@ -109,7 +109,7 @@ def forward(self, inputs): x = self.maxPool2(x) x = self.relu3(self.conv3(x)) # (N, C, 1, 1) -> (N, C) - x = x.squeeze() + x = x.reshape(x.shape[0], -1) x = self.relu4(self.fc1(x)) x = self.fc2(x)