From 572e64456b6917b4ed648eb504f25c51a6b9379b Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Fri, 18 Nov 2022 18:56:50 +0000 Subject: [PATCH 1/4] Add configs to run int4 inference --- bloom-inference-scripts/bloom-ds-inference.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/bloom-inference-scripts/bloom-ds-inference.py b/bloom-inference-scripts/bloom-ds-inference.py index 4bed6a2..c9191e2 100644 --- a/bloom-inference-scripts/bloom-ds-inference.py +++ b/bloom-inference-scripts/bloom-ds-inference.py @@ -44,7 +44,7 @@ parser = ArgumentParser() parser.add_argument("--name", required=True, type=str, help="model_name") -parser.add_argument("--dtype", type=str, help="float16 or int8", choices=["int8", "float16"], default="float16") +parser.add_argument("--dtype", type=str, help="float16 or int8 or int4", choices=["int8", "float16", "int4"], default="float16") parser.add_argument("--local_rank", required=False, type=int, help="used by dist launchers") parser.add_argument("--batch_size", default=1, type=int, help="batch size") parser.add_argument("--benchmark", action="store_true", help="additionally run benchmark") @@ -100,7 +100,7 @@ def get_checkpoint_files(model_name_or_path): model_name = args.name -infer_dtype = args.dtype +infer_dtype = args.dtype if args.dtype != 'int4' else 'int8' tp_presharded_mode = True if model_name in tp_presharded_models else False @@ -191,6 +191,7 @@ def write_checkponts_json(): mp_size=world_size, base_dir=repo_root, dtype=getattr(torch, infer_dtype), + quantization_bits=8 if args.dtype == 'int8' else 4, checkpoint=checkpoints_json, **kwargs, ) @@ -227,7 +228,7 @@ def write_checkponts_json(): # dynamically extend to support larger bs by repetition input_sentences *= math.ceil(args.batch_size / len(input_sentences)) -generate_kwargs = dict(max_new_tokens=num_tokens, do_sample=False) +generate_kwargs = dict(max_new_tokens=num_tokens, do_sample=True) print_rank0(f"Generate args {generate_kwargs}") From 
132d99db5ebca42c786fe5f7cc054eec85e4316f Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Fri, 18 Nov 2022 21:50:15 +0000 Subject: [PATCH 2/4] fix quantization-bit config & turn off ds_sample --- bloom-inference-scripts/bloom-ds-inference.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/bloom-inference-scripts/bloom-ds-inference.py b/bloom-inference-scripts/bloom-ds-inference.py index c9191e2..0a8aa42 100644 --- a/bloom-inference-scripts/bloom-ds-inference.py +++ b/bloom-inference-scripts/bloom-ds-inference.py @@ -172,6 +172,9 @@ def write_checkponts_json(): if kernel_inject: kwargs = dict(replace_with_kernel_inject=True) + # specify number of bits to choose between in4/int8 + if args.dtype == 'int8' or args.dtype == 'int4': + kwargs.update({'quantization_bits': 8 if args.dtype == 'int8' else 4}) else: kwargs = dict(injection_policy={BloomBlock: ("self_attention.dense", "mlp.dense_4h_to_h")}) @@ -191,7 +194,6 @@ def write_checkponts_json(): mp_size=world_size, base_dir=repo_root, dtype=getattr(torch, infer_dtype), - quantization_bits=8 if args.dtype == 'int8' else 4, checkpoint=checkpoints_json, **kwargs, ) @@ -228,7 +230,7 @@ def write_checkponts_json(): # dynamically extend to support larger bs by repetition input_sentences *= math.ceil(args.batch_size / len(input_sentences)) -generate_kwargs = dict(max_new_tokens=num_tokens, do_sample=True) +generate_kwargs = dict(max_new_tokens=num_tokens, do_sample=False) print_rank0(f"Generate args {generate_kwargs}") From 99cd7c9d8b2f0228145d0134d0a9570a6ac8cf71 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Sat, 19 Nov 2022 00:37:59 +0000 Subject: [PATCH 3/4] change the quantization config format to work with the new style at DeepSpeed --- bloom-inference-scripts/bloom-ds-inference.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bloom-inference-scripts/bloom-ds-inference.py b/bloom-inference-scripts/bloom-ds-inference.py index 0a8aa42..241b5c8 100644 --- 
a/bloom-inference-scripts/bloom-ds-inference.py +++ b/bloom-inference-scripts/bloom-ds-inference.py @@ -174,7 +174,10 @@ def write_checkponts_json(): kwargs = dict(replace_with_kernel_inject=True) # specify number of bits to choose between in4/int8 if args.dtype == 'int8' or args.dtype == 'int4': - kwargs.update({'quantization_bits': 8 if args.dtype == 'int8' else 4}) + quant_config = "{'quant': {'enabled':True, 'weight':{'num_bits': 8}}}" + kwargs.update(eval(quant_config)) + if args.dtype == 'int4': + kwargs['quant']['weight']['num_bits'] = 4 else: kwargs = dict(injection_policy={BloomBlock: ("self_attention.dense", "mlp.dense_4h_to_h")}) From 32779e8193787778b80b4a1cdcbd5d158c1b0c9f Mon Sep 17 00:00:00 2001 From: Ammar Ahmad Awan Date: Mon, 21 Nov 2022 10:05:37 -0800 Subject: [PATCH 4/4] Update bloom-ds-inference.py --- bloom-inference-scripts/bloom-ds-inference.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/bloom-inference-scripts/bloom-ds-inference.py b/bloom-inference-scripts/bloom-ds-inference.py index 241b5c8..3031426 100644 --- a/bloom-inference-scripts/bloom-ds-inference.py +++ b/bloom-inference-scripts/bloom-ds-inference.py @@ -171,13 +171,19 @@ def write_checkponts_json(): deepspeed.runtime.utils.see_memory_usage("pre-ds-inference-init", force=True) if kernel_inject: - kwargs = dict(replace_with_kernel_inject=True) - # specify number of bits to choose between in4/int8 - if args.dtype == 'int8' or args.dtype == 'int4': - quant_config = "{'quant': {'enabled':True, 'weight':{'num_bits': 8}}}" - kwargs.update(eval(quant_config)) - if args.dtype == 'int4': - kwargs['quant']['weight']['num_bits'] = 4 + if args.dtype == 'int8': + bits = 8 + if args.dtype == 'int4': + bits = 4 + ds_config = { "replace_with_kernel_inject" : True, "quant" : { "enabled" : True, "weight" : { "num_bits" : bits } } } else: kwargs = dict(injection_policy={BloomBlock: ("self_attention.dense", "mlp.dense_4h_to_h")}) @@ -194,6
+200,7 @@ def write_checkponts_json(): # checkpoints_json=None model = deepspeed.init_inference( model, + config=ds_config, mp_size=world_size, base_dir=repo_root, dtype=getattr(torch, infer_dtype),