@@ -326,8 +326,8 @@ def create_quantized_state_dict(self):
         for fqn, mod in self.mod.named_modules():
             if isinstance(mod, torch.nn.Linear):
                 int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8)
-                cur_state_dict[f"{fqn}.weight"] = int8_weight
-                cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
+                cur_state_dict[f"{fqn}.weight"] = int8_weight.to('cpu')
+                cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype).to('cpu')
 
         return cur_state_dict
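
The int8 path above quantizes each nn.Linear weight per output channel and, with this change, stores both tensors on CPU so the saved state dict loads on any device. Below is a minimal sketch of symmetric per-channel int8 quantization, assuming dynamically_quantize_per_channel behaves roughly like this (the real helper also returns zero points, which the caller above discards):

    import torch

    def quantize_per_channel_sketch(w, qmin=-128, qmax=127):
        # One scale per output channel (row), chosen so the row's max
        # magnitude maps to the edge of the int8 range.
        max_abs = w.abs().amax(dim=1, keepdim=True)
        scales = (max_abs / qmax).clamp(min=1e-8)
        q = torch.clamp(torch.round(w / scales), qmin, qmax).to(torch.int8)
        return q, scales.squeeze(1)

    w = torch.randn(8, 32)
    q, scales = quantize_per_channel_sketch(w)
    dequant = q.float() * scales[:, None]
    # Rounding error is at most half a quantization step per element.
    assert (dequant - w).abs().max() <= scales.max() / 2 + 1e-6
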
@@ -376,21 +376,21 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
 def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
     return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
 
-def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding, use_cuda):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
             if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False,
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False, use_cuda=use_cuda
                 ))
             elif padding:
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True,
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True, use_cuda=use_cuda
                 ))
         else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding, use_cuda)
 
 
 class WeightOnlyInt4QuantHandler:
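
replace_linear_int4 threads the new use_cuda flag through a straightforward recursive module swap: eligible nn.Linear leaves are replaced in place via setattr, everything else is recursed into. A self-contained toy sketch of that traversal pattern (nn.Identity stands in for WeightOnlyInt4Linear):

    import torch.nn as nn

    def replace_linear_sketch(module, should_swap):
        # Same shape as replace_linear_int4: swap eligible Linear leaves,
        # recurse into any other child module.
        for name, child in module.named_children():
            if isinstance(child, nn.Linear):
                if should_swap(child):
                    setattr(module, name, nn.Identity())  # toy stand-in
            else:
                replace_linear_sketch(child, should_swap)

    model = nn.Sequential(nn.Linear(64, 64), nn.Sequential(nn.Linear(64, 64)))
    replace_linear_sketch(model, lambda lin: lin.in_features % 16 == 0)
    print(model)  # both Linear layers replaced, including the nested one
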
@@ -403,12 +403,7 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
         assert inner_k_tiles in [2, 4, 8]
 
     @torch.no_grad()
-    def create_quantized_state_dict(self, use_cuda = True):
-        if use_cuda:
-            device = "cuda"
-        else:
-            device = "cpu"
-
+    def create_quantized_state_dict(self):
         cur_state_dict = self.mod.state_dict()
         for fqn, mod in self.mod.named_modules():
             if isinstance(mod, torch.nn.Linear):
@@ -431,15 +426,15 @@ def create_quantized_state_dict(self, use_cuda = True):
                             "and that groupsize and inner_k_tiles*16 evenly divide into it")
                         continue
                 weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
-                    weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles
+                    weight.to(torch.bfloat16), self.groupsize, self.inner_k_tiles
                 )
                 cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
                 cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')
 
         return cur_state_dict
 
-    def convert_for_runtime(self):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
+    def convert_for_runtime(self, use_cuda):
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
         return self.mod
 
 class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
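
Taken together, the handler now prepacks on whatever device the model's weights already live on, and the runtime device choice moves into convert_for_runtime. A hedged usage sketch of the resulting flow, assuming `model` is an already-built module and a PyTorch version whose load_state_dict accepts assign=True:

    handler = WeightOnlyInt4QuantHandler(model, groupsize=128)
    quantized_state_dict = handler.create_quantized_state_dict()  # prepacking happens here
    model = handler.convert_for_runtime(use_cuda=False)           # swap in CPU-layout int4 linears
    model.load_state_dict(quantized_state_dict, assign=True)
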
@@ -476,8 +471,8 @@ def make_names_and_values_dict_func(q, qparams):
         super().__init__()
 
 
-    def convert_for_runtime(self):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
+    def convert_for_runtime(self, use_cuda):
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
         return self.mod
 
 class WeightOnlyInt4Linear(torch.nn.Module):
@@ -488,7 +483,7 @@ class WeightOnlyInt4Linear(torch.nn.Module):
 
     def __init__(
             self, in_features: int, out_features: int,
-            bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True,
+            bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True, use_cuda=True,
     ) -> None:
         super().__init__()
         self.padding = padding
@@ -505,10 +500,16 @@ def __init__(
 
         assert out_features % 8 == 0, "require out_features % 8 == 0"
         assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
-        self.register_buffer(
-            "weight",
-            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
-        )
+        if use_cuda:
+            self.register_buffer(
+                "weight",
+                torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
+            )
+        else:
+            self.register_buffer(
+                "weight",
+                torch.empty((out_features, in_features // 2), dtype=torch.uint8)
+            )
         self.register_buffer(
             "scales_and_zeros",
             torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
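
The two weight buffers pack the same 4-bit payload: the int32 layout is the tiled format the existing GPU int4 path consumes, while the CPU layout packs two 4-bit values per uint8 byte. A quick shape check with assumed example sizes:

    out_features, in_features, inner_k_tiles = 4096, 4096, 8  # assumed example sizes

    cuda_shape = (out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2)
    cpu_shape = (out_features, in_features // 2)

    nibbles_cuda = 8 * cuda_shape[0] * cuda_shape[1] * cuda_shape[2] * cuda_shape[3]  # 8 int4 per int32
    nibbles_cpu = 2 * cpu_shape[0] * cpu_shape[1]                                     # 2 int4 per uint8
    assert nibbles_cuda == nibbles_cpu == out_features * in_features                  # 16,777,216 values
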
@@ -538,10 +539,10 @@ def quantize(
     percdamp: float = .01,
     blocksize: int = 128,
     label: str = '',
+    device: str = 'cuda',
 ) -> None:
     assert checkpoint_path.is_file(), checkpoint_path
 
-    device = 'cpu'
     precision = torch.bfloat16
 
     print("Loading model ...")
@@ -565,12 +566,13 @@ def quantize(
 
     elif mode == 'int4':
         print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization")
+        print(f"Prepacking model weights in {device} optimal layout")
         quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
         quantized_state_dict = quant_handler.create_quantized_state_dict()
 
         dir_name = checkpoint_path.parent
         base_name = checkpoint_path.name
-        new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.pth")
+        new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.{device}.pth")
 
     elif mode == 'int4-gptq':
         print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ...")
@@ -617,6 +619,7 @@ def quantize(
     parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening')
     parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq')
     parser.add_argument('--label', type=str, default='_', help='label to add to output filename')
+    parser.add_argument('--device', type=str, default='cuda', help='device to use')
 
     args = parser.parse_args()
-    quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label)
+    quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label, args.device)
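
A hypothetical invocation of the new CPU path (script and checkpoint names are assumed; --checkpoint_path and --mode are implied by the args.* references above but not shown in this hunk):

    python quantize.py --checkpoint_path checkpoints/model.pth --mode int4 --device cpu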