From 6795d20b55fb47d11841df79f56b2b937395cfa3 Mon Sep 17 00:00:00 2001 From: Angela Zhang Date: Wed, 3 Mar 2021 19:00:46 -0500 Subject: [PATCH] Update unetSeg.py --- source/modules/nphprediction/source/unetSeg.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/source/modules/nphprediction/source/unetSeg.py b/source/modules/nphprediction/source/unetSeg.py index b776e41..cd18e2d 100644 --- a/source/modules/nphprediction/source/unetSeg.py +++ b/source/modules/nphprediction/source/unetSeg.py @@ -205,9 +205,9 @@ def forward(self, x, prev): net.seg3 = nn.Conv3d(128, num_classes, kernel_size=(1,1,1), stride=(1,1,1)) net.seg2 = nn.Conv3d(64, num_classes, kernel_size=(1,1,1), stride=(1,1,1)) net.seg1 = nn.Conv3d(32, num_classes, kernel_size=(1,1,1), stride=(1,1,1)) - net.seg3.weight = nn.Parameter(unet.seg3.weight) - net.seg2.weight = nn.Parameter(unet.seg2.weight) - net.seg1.weight = nn.Parameter(unet.seg1.weight) + net.seg3.weight = nn.Parameter(net.seg3.weight) + net.seg2.weight = nn.Parameter(net.seg2.weight) + net.seg1.weight = nn.Parameter(net.seg1.weight) net.cpu() if gpu: