From 2d0d831ae653912bdf2d026d94a3cdaec9aab7fd Mon Sep 17 00:00:00 2001 From: Louis Date: Wed, 24 Jan 2024 15:47:00 +0000 Subject: [PATCH] Fix torch error on misc.CovarianceMatrix (#147) --- .github/workflows/test.yml | 2 +- deepdow/layers/misc.py | 2 +- docs/source/changelog.rst | 9 +++++++++ examples/layers/warp.py | 4 ++-- 4 files changed, 13 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e41053d..5016e13 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -41,7 +41,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python-version: [3.7, 3.8, 3.9] + python-version: ['3.10', '3.11'] steps: - name: Checkout most recent commit uses: actions/checkout@v3 diff --git a/deepdow/layers/misc.py b/deepdow/layers/misc.py index 1869363..3b4cd58 100644 --- a/deepdow/layers/misc.py +++ b/deepdow/layers/misc.py @@ -141,7 +141,7 @@ def compute_covariance(m, shrinkage_strategy=None, shrinkage_coef=0.5): """ fact = 1.0 / (m.size(1) - 1) - m -= torch.mean(m, dim=1, keepdim=True) # !!!!!!!!!!! INPLACE + m = m - torch.mean(m, dim=1, keepdim=True) mt = m.t() s = fact * m.matmul(mt) # sample covariance matrix diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst index 81e7b91..b10fb71 100644 --- a/docs/source/changelog.rst +++ b/docs/source/changelog.rst @@ -1,6 +1,15 @@ Changelog ========= +Unreleased +---------- +Fixes +***** +- Fix PyTorch in-place issue and matplotib plotting problem - #147 + +v0.2.2 +------ + v0.2.1 ------ Added diff --git a/examples/layers/warp.py b/examples/layers/warp.py index 4c33940..6d1e77e 100644 --- a/examples/layers/warp.py +++ b/examples/layers/warp.py @@ -52,5 +52,5 @@ axs[i, 0].plot(tform.numpy().squeeze(), linewidth=3, color='red') axs[i, 1].plot(x_warped.numpy().squeeze(), linewidth=3, color='blue') - axs[i, 0].set_title(r'$\bf{}$ tform'.format(tform_name)) - axs[i, 1].set_title(r'$\bf{}$ warped'.format(tform_name)) + axs[i, 0].set_title('{} tform'.format(tform_name)) + axs[i, 1].set_title('{} warped'.format(tform_name))