Commit f68e81e

Merge pull request #123 from mingfeima/pr_weight_only_quantization_cpu
Add weight only quantization support for cpu device
2 parents: 1c23b94 + fba5d25

3 files changed (+44, -34)

README.md (+9, -4)

@@ -123,6 +123,11 @@ python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth
To squeeze out a little bit more performance, you can also compile the prefill with `--compile_prefill`. This will increase compilation times though.

## Quantization
+Choose device to use by
+```bash
+# The current support devices: cuda, cpu
+export DEVICE=cuda
+```
### Int8 Weight-Only Quantization
To generate this version of the model
```bash
@@ -131,19 +136,19 @@ python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode in
```
To run with int8, just pass the int8 checkpoint to generate.py.
```bash
-python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth
+python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --device $DEVICE
```

### Int4 Weight-Only Quantization
To generate int4 version of model
```bash
-# Spits out model at checkpoints/$MODEL_REPO/model_int4.g32.pth
-python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4 --groupsize 32
+# Spits out model at checkpoints/$MODEL_REPO/model_int4.g32.$DEVICE.pth
+python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4 --groupsize 32 --device $DEVICE
```

To run with int4, just pass the int4 checkpoint to generate.py.
```bash
-python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile
+python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.$DEVICE.pth --compile --device $DEVICE
```

## Speculative Sampling
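As a quick reference for the naming scheme the README commands above rely on (not part of the diff): a minimal Python sketch that builds the int4 checkpoint path `quantize.py` writes for a given `DEVICE` and group size. The helper name `int4_checkpoint_path` and the `MODEL_REPO` fallback are illustrative only.

```python
import os
from pathlib import Path

def int4_checkpoint_path(model_repo: str, groupsize: int = 32) -> Path:
    # Illustrative helper: mirrors the model_int4.g<groupsize>.<device>.pth naming above.
    device = os.environ.get("DEVICE", "cuda")  # supported values per this PR: cuda, cpu
    return Path("checkpoints") / model_repo / f"model_int4.g{groupsize}.{device}.pth"

# e.g. DEVICE=cpu -> checkpoints/<MODEL_REPO>/model_int4.g32.cpu.pth
print(int4_checkpoint_path(os.environ.get("MODEL_REPO", "<MODEL_REPO>")))
```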

generate.py (+7, -5)

@@ -213,6 +213,7 @@ def encode_tokens(tokenizer, string, bos=True, device='cuda'):
    return torch.tensor(tokens, dtype=torch.int, device=device)

def _load_model(checkpoint_path, device, precision, use_tp):
+    use_cuda = 'cuda' in device
    with torch.device('meta'):
        model = Transformer.from_name(checkpoint_path.parent.name)

@@ -223,13 +224,14 @@ def _load_model(checkpoint_path, device, precision, use_tp):
        model = simple_quantizer.convert_for_runtime()

    if "int4" in str(checkpoint_path):
-        print("Using int4 quantization!")
+        print("Using int4 weight-only quantization!")
        path_comps = checkpoint_path.name.split(".")
-        assert path_comps[-2].startswith("g")
-        groupsize = int(path_comps[-2][1:])
+        assert path_comps[-3].startswith("g")
+        assert path_comps[-2] in device, "weight packed format mismatch, please rerun quantize.py!"
+        groupsize = int(path_comps[-3][1:])
        from quantize import WeightOnlyInt4QuantHandler
        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
-        model = simple_quantizer.convert_for_runtime()
+        model = simple_quantizer.convert_for_runtime(use_cuda)

    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
    model.load_state_dict(checkpoint, assign=True)
@@ -412,7 +414,7 @@ def callback(x):
    parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
    parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
    parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
-    parser.add_argument('--device', type=str, default="cuda", help='device to use')
+    parser.add_argument('--device', type=str, default="cuda", help='Device to use')

    args = parser.parse_args()
    main(
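For context, a standalone sketch of the filename validation that the `_load_model` hunk above introduces: with the new `model_int4.g<groupsize>.<device>.pth` naming, splitting the name on `.` puts the group size at index -3 and the packing device at index -2. The function name `parse_int4_checkpoint` is illustrative, not part of the diff.

```python
from pathlib import Path

def parse_int4_checkpoint(checkpoint_path: Path, device: str) -> int:
    # "model_int4.g32.cpu.pth" -> ["model_int4", "g32", "cpu", "pth"]
    path_comps = checkpoint_path.name.split(".")
    assert path_comps[-3].startswith("g"), "expected a g<groupsize> name component"
    # int4 weights are pre-packed for one device; refuse to load a mismatched checkpoint
    assert path_comps[-2] in device, "weight packed format mismatch, please rerun quantize.py!"
    return int(path_comps[-3][1:])  # groupsize

print(parse_int4_checkpoint(Path("model_int4.g32.cpu.pth"), device="cpu"))  # 32
```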

quantize.py (+28, -25)

@@ -326,8 +326,8 @@ def create_quantized_state_dict(self):
        for fqn, mod in self.mod.named_modules():
            if isinstance(mod, torch.nn.Linear):
                int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8)
-                cur_state_dict[f"{fqn}.weight"] = int8_weight
-                cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
+                cur_state_dict[f"{fqn}.weight"] = int8_weight.to('cpu')
+                cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype).to('cpu')

        return cur_state_dict

@@ -376,21 +376,21 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, grou
def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1):
    return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0

-def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding, use_cuda):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
                setattr(module, name, WeightOnlyInt4Linear(
                    child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False,
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False, use_cuda=use_cuda
                ))
            elif padding:
                setattr(module, name, WeightOnlyInt4Linear(
                    child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True,
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True, use_cuda=use_cuda
                ))
        else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding, use_cuda)


class WeightOnlyInt4QuantHandler:
@@ -403,12 +403,7 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
        assert inner_k_tiles in [2, 4, 8]

    @torch.no_grad()
-    def create_quantized_state_dict(self, use_cuda = True):
-        if use_cuda:
-            device="cuda"
-        else:
-            device="cpu"
-
+    def create_quantized_state_dict(self):
        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if isinstance(mod, torch.nn.Linear):
@@ -431,15 +426,15 @@ def create_quantized_state_dict(self, use_cuda = True):
                            "and that groupsize and inner_k_tiles*16 evenly divide into it")
                        continue
                weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
-                    weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles
+                    weight.to(torch.bfloat16), self.groupsize, self.inner_k_tiles
                )
                cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')

        return cur_state_dict

-    def convert_for_runtime(self):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
+    def convert_for_runtime(self, use_cuda):
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
        return self.mod

class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
@@ -476,8 +471,8 @@ def make_names_and_values_dict_func(q, qparams):
        super().__init__()


-    def convert_for_runtime(self):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
+    def convert_for_runtime(self, use_cuda):
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
        return self.mod

class WeightOnlyInt4Linear(torch.nn.Module):
@@ -488,7 +483,7 @@ class WeightOnlyInt4Linear(torch.nn.Module):

    def __init__(
        self, in_features: int, out_features: int,
-        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True,
+        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True, use_cuda=True,
    ) -> None:
        super().__init__()
        self.padding = padding
@@ -505,10 +500,16 @@ def __init__(

        assert out_features % 8 == 0, "require out_features % 8 == 0"
        assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
-        self.register_buffer(
-            "weight",
-            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
-        )
+        if use_cuda:
+            self.register_buffer(
+                "weight",
+                torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
+            )
+        else:
+            self.register_buffer(
+                "weight",
+                torch.empty((out_features, in_features // 2), dtype=torch.uint8)
+            )
        self.register_buffer(
            "scales_and_zeros",
            torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
@@ -538,10 +539,10 @@ def quantize(
    percdamp: float = .01,
    blocksize: int = 128,
    label: str = '',
+    device: str = 'cuda',
) -> None:
    assert checkpoint_path.is_file(), checkpoint_path

-    device = 'cpu'
    precision = torch.bfloat16

    print("Loading model ...")
@@ -565,12 +566,13 @@ def quantize(

    elif mode == 'int4':
        print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization")
+        print(f"Prepacking model weights in {device} optimal layout")
        quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
        quantized_state_dict = quant_handler.create_quantized_state_dict()

        dir_name = checkpoint_path.parent
        base_name = checkpoint_path.name
-        new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.pth")
+        new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.{device}.pth")

    elif mode == 'int4-gptq':
        print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ...")
@@ -617,6 +619,7 @@ def quantize(
    parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening')
    parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq')
    parser.add_argument('--label', type=str, default='_', help='label to add to output filename')
+    parser.add_argument('--device', type=str, default='cuda', help='device to use')

    args = parser.parse_args()
-    quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label)
+    quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label, args.device)
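To make the CUDA/CPU difference in `WeightOnlyInt4Linear.__init__` concrete, here is a small sketch (illustrative only; the buffer shapes are copied from the hunk above, and the 4096x4096 layer size is an arbitrary example) of the two weight buffers that get registered:

```python
import torch

out_features, in_features, inner_k_tiles = 4096, 4096, 8

# CUDA path: int4 weights pre-packed into the tiled int32 layout used by the CUDA int4 kernel
cuda_weight = torch.empty(
    (out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2),
    dtype=torch.int32,
)

# CPU path: a flat uint8 buffer, presumably two int4 values packed per byte
cpu_weight = torch.empty((out_features, in_features // 2), dtype=torch.uint8)

print(cuda_weight.shape)  # torch.Size([512, 32, 32, 4])
print(cpu_weight.shape)   # torch.Size([4096, 2048])
```

Because the two layouts differ in shape and dtype, a checkpoint packed for one device cannot be loaded when running on the other; that is what the device suffix in the checkpoint name and the corresponding assert in generate.py guard against.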
