"""Test embedding an ONNX-converted PyTorch module inside a Caffe2 net.

A small super-resolution network is wrapped between two Caffe2 ``Sigmoid``
ops via ``PyTorchModule``. The Caffe2 result is then compared against the
same ``sigmoid -> model -> sigmoid`` pipeline evaluated directly in PyTorch.
"""

# Some standard imports
import unittest

import numpy as np
from torch import nn
import torch.onnx
import torch.nn.init as init
from caffe2.python.model_helper import ModelHelper
from pytorch_helper import PyTorchModule
from caffe2.python.core import workspace
from test_pytorch_common import skipIfNoLapack


class TestCaffe2Backend(unittest.TestCase):
    """Tests for running PyTorch modules inside Caffe2 nets."""

    @skipIfNoLapack
    @unittest.skip("test broken because Lapack was always missing.")
    def test_helper(self):
        """Check that a PyTorch net embedded via PyTorchModule matches PyTorch.

        The Caffe2 net computes ``sigmoid(model(sigmoid(x)))``. The fetched
        output must agree with the direct PyTorch evaluation to 3 decimals.
        """

        class SuperResolutionNet(nn.Module):
            """Four 'same'-padded convolutions followed by a pixel shuffle."""

            def __init__(self, upscale_factor, inplace=False):
                super(SuperResolutionNet, self).__init__()

                self.relu = nn.ReLU(inplace=inplace)
                self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
                self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
                self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
                # upscale_factor**2 output channels are rearranged by
                # PixelShuffle into one channel at upscale_factor x the
                # spatial resolution.
                self.conv4 = nn.Conv2d(
                    32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1)
                )
                self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

                self._initialize_weights()

            def forward(self, x):
                x = self.relu(self.conv1(x))
                x = self.relu(self.conv2(x))
                x = self.relu(self.conv3(x))
                x = self.pixel_shuffle(self.conv4(x))
                return x

            def _initialize_weights(self):
                # Use the in-place ``orthogonal_``: the non-underscore
                # ``init.orthogonal`` is a deprecated alias that warns.
                init.orthogonal_(self.conv1.weight, init.calculate_gain("relu"))
                init.orthogonal_(self.conv2.weight, init.calculate_gain("relu"))
                init.orthogonal_(self.conv3.weight, init.calculate_gain("relu"))
                # conv4 feeds the pixel shuffle directly (no ReLU), so it uses
                # the default gain.
                init.orthogonal_(self.conv4.weight)

        torch_model = SuperResolutionNet(upscale_factor=3)

        fake_input = torch.randn(1, 1, 224, 224, requires_grad=True)

        # Use ModelHelper to create a C2 net that applies a sigmoid to the
        # blob named "the_input"...
        helper = ModelHelper(name="test_model")
        start = helper.Sigmoid(["the_input"])
        # ...embeds the ONNX-converted PyTorch net, fed by that sigmoid
        # (the unpacking requires exactly one output)...
        (toutput,) = PyTorchModule(helper, torch_model, (fake_input,), [start])
        # ...and applies a final sigmoid to the embedded net's output.
        output = helper.Sigmoid(toutput)

        # Run the init net once, feed the input, then run the main net.
        workspace.RunNetOnce(helper.InitProto())
        # detach() rather than .data: fake_input requires grad.
        workspace.FeedBlob("the_input", fake_input.detach().numpy())
        workspace.RunNetOnce(helper.Proto())
        c2_out = workspace.FetchBlob(str(output))

        # Reference: the same sigmoid -> model -> sigmoid pipeline in PyTorch.
        torch_out = torch.sigmoid(torch_model(torch.sigmoid(fake_input)))

        np.testing.assert_almost_equal(
            torch_out.detach().cpu().numpy(), c2_out, decimal=3
        )


if __name__ == "__main__":
    unittest.main()