From 4a6e0c06240b45185709ac4a2902ec42518049d2 Mon Sep 17 00:00:00 2001 From: jxysoft Date: Sat, 16 Dec 2023 05:01:12 +0800 Subject: [PATCH] =?UTF-8?q?fix=20[BUG]=20'DeepSpeedGPTInference'=20object?= =?UTF-8?q?=20has=20no=20attribute=20'dtype'=20for=E2=80=A6=20(#4814)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- deepspeed/model_implementations/transformers/ds_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepspeed/model_implementations/transformers/ds_transformer.py b/deepspeed/model_implementations/transformers/ds_transformer.py index a41df58ad059..d87d0de997b5 100644 --- a/deepspeed/model_implementations/transformers/ds_transformer.py +++ b/deepspeed/model_implementations/transformers/ds_transformer.py @@ -163,7 +163,7 @@ def forward( if (self.config.dtype in [torch.float16, torch.bfloat16, torch.int8]) \ and input.dtype == torch.float: - target_dtype = torch.half if self.dtype == torch.int8 else self.dtype + target_dtype = torch.half if self.config.dtype == torch.int8 else self.config.dtype input = input.to(target_dtype) with torch.no_grad():