Skip to content

Commit d66d554

Browse files
authored
Add tearDown method to LoRA tests. (#6660)
* update * update
1 parent c7df846 commit d66d554

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

Diff for: tests/lora/test_lora_layers_old_backend.py

+6
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import copy
16+
import gc
1617
import os
1718
import random
1819
import tempfile
@@ -1662,6 +1663,11 @@ def test_lora_at_different_scales(self):
16621663
@deprecate_after_peft_backend
16631664
@require_torch_gpu
16641665
class LoraIntegrationTests(unittest.TestCase):
1666+
def tearDown(self):
1667+
super().tearDown()
1668+
gc.collect()
1669+
torch.cuda.empty_cache()
1670+
16651671
def test_dreambooth_old_format(self):
16661672
generator = torch.Generator("cpu").manual_seed(0)
16671673

Diff for: tests/lora/test_lora_layers_peft.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
import copy
16+
import gc
1617
import importlib
1718
import os
1819
import tempfile
@@ -1205,6 +1206,11 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
12051206
"latent_channels": 4,
12061207
}
12071208

1209+
def tearDown(self):
1210+
super().tearDown()
1211+
gc.collect()
1212+
torch.cuda.empty_cache()
1213+
12081214
@slow
12091215
@require_torch_gpu
12101216
def test_integration_move_lora_cpu(self):
@@ -1434,6 +1440,11 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
14341440
"sample_size": 128,
14351441
}
14361442

1443+
def tearDown(self):
1444+
super().tearDown()
1445+
gc.collect()
1446+
torch.cuda.empty_cache()
1447+
14371448

14381449
@slow
14391450
@require_torch_gpu
@@ -1468,11 +1479,9 @@ class LoraIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
14681479
}
14691480

14701481
def tearDown(self):
1471-
import gc
1472-
1482+
super().tearDown()
14731483
gc.collect()
14741484
torch.cuda.empty_cache()
1475-
gc.collect()
14761485

14771486
def test_dreambooth_old_format(self):
14781487
generator = torch.Generator("cpu").manual_seed(0)
@@ -1757,11 +1766,9 @@ class LoraSDXLIntegrationTests(PeftLoraLoaderMixinTests, unittest.TestCase):
17571766
}
17581767

17591768
def tearDown(self):
1760-
import gc
1761-
1769+
super().tearDown()
17621770
gc.collect()
17631771
torch.cuda.empty_cache()
1764-
gc.collect()
17651772

17661773
def test_sdxl_0_9_lora_one(self):
17671774
generator = torch.Generator().manual_seed(0)

0 commit comments

Comments
 (0)