From 48b5a460bc0cd3e6ae70655b1bf4b5b4c20ff79e Mon Sep 17 00:00:00 2001 From: zhujian <505169307@qq.com> Date: Tue, 28 May 2019 14:30:43 +0800 Subject: [PATCH] =?UTF-8?q?fix(LeNet5):=20=E6=89=B9=E9=87=8F=E5=A4=A7?= =?UTF-8?q?=E5=B0=8F=E4=B8=BA1=E6=97=B6LeNet5=E5=89=8D=E5=90=91=E6=93=8D?= =?UTF-8?q?=E4=BD=9C=E4=B8=ADC5->F6=E6=AD=A5=E9=AA=A4=E5=87=BA=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nn/nets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nn/nets.py b/nn/nets.py index 07dfe11..6c9ef83 100644 --- a/nn/nets.py +++ b/nn/nets.py @@ -109,7 +109,7 @@ def forward(self, inputs): x = self.maxPool2(x) x = self.relu3(self.conv3(x)) # (N, C, 1, 1) -> (N, C) - x = x.squeeze() + x = x.reshape(x.shape[0], -1) x = self.relu4(self.fc1(x)) x = self.fc2(x)