File tree 1 file changed +1
-10
lines changed
1 file changed +1
-10
lines changed Original file line number Diff line number Diff line change 55
55
import copy
56
56
import tempfile
57
57
import gc
58
- import time
59
58
from torch .testing ._internal .common_utils import TestCase
60
59
61
60
@@ -692,25 +691,17 @@ def reset_memory():
692
691
693
692
reset_memory ()
694
693
m = ToyLinearModel ()
695
- time0 = time .perf_counter ()
696
- m .to (device = "cuda" )
697
- quantize_ (m , int8_weight_only ())
698
- torch .cuda .synchronize ()
699
- time_baseline = time .perf_counter () - time0
694
+ quantize_ (m .to (device = "cuda" ), int8_weight_only ())
700
695
memory_baseline = torch .cuda .max_memory_allocated ()
701
696
702
697
del m
703
698
reset_memory ()
704
699
m = ToyLinearModel ()
705
- time0 = time .perf_counter ()
706
700
quantize_ (m , int8_weight_only (), device = "cuda" )
707
- torch .cuda .synchronize ()
708
- time_streaming = time .perf_counter () - time0
709
701
memory_streaming = torch .cuda .max_memory_allocated ()
710
702
711
703
for param in m .parameters ():
712
704
assert param .is_cuda
713
- self .assertLess (time_streaming , time_baseline * 1.5 )
714
705
self .assertLess (memory_streaming , memory_baseline )
715
706
716
707
You can’t perform that action at this time.
0 commit comments