diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py
index e173a4c72174..bdb8920a399e 100644
--- a/tests/quantization/bnb/test_4bit.py
+++ b/tests/quantization/bnb/test_4bit.py
@@ -96,6 +96,14 @@ class Base4bitTests(unittest.TestCase):
     num_inference_steps = 10
     seed = 0
 
+    @classmethod
+    def setUpClass(cls):
+        torch.use_deterministic_algorithms(True)
+
+    @classmethod
+    def tearDownClass(cls):
+        torch.use_deterministic_algorithms(False)
+
     def get_dummy_inputs(self):
         prompt_embeds = load_pt(
             "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
@@ -480,7 +488,6 @@ def test_generate_quality_dequantize(self):
         r"""
         Test that loading the model and unquantize it produce correct results.
         """
-        torch.use_deterministic_algorithms(True)
         self.pipeline_4bit.transformer.dequantize()
         output = self.pipeline_4bit(
             prompt=self.prompt,
diff --git a/tests/quantization/bnb/test_mixed_int8.py b/tests/quantization/bnb/test_mixed_int8.py
index bb7b12de60fe..d048b0b7db46 100644
--- a/tests/quantization/bnb/test_mixed_int8.py
+++ b/tests/quantization/bnb/test_mixed_int8.py
@@ -97,6 +97,14 @@ class Base8bitTests(unittest.TestCase):
     num_inference_steps = 10
     seed = 0
 
+    @classmethod
+    def setUpClass(cls):
+        torch.use_deterministic_algorithms(True)
+
+    @classmethod
+    def tearDownClass(cls):
+        torch.use_deterministic_algorithms(False)
+
     def get_dummy_inputs(self):
         prompt_embeds = load_pt(
             "https://huggingface.co/datasets/hf-internal-testing/bnb-diffusers-testing-artifacts/resolve/main/prompt_embeds.pt",
@@ -485,7 +493,6 @@ def test_generate_quality_dequantize(self):
         r"""
         Test that loading the model and unquantize it produce correct results.
         """
-        torch.use_deterministic_algorithms(True)
         self.pipeline_8bit.transformer.dequantize()
         output = self.pipeline_8bit(
             prompt=self.prompt,