Skip to content

Commit

Permalink
fix [BUG] 'DeepSpeedGPTInference' object has no attribute 'dtype' for… (
Browse files Browse the repository at this point in the history
  • Loading branch information
jxysoft authored Dec 15, 2023
1 parent 84eaf5a commit 4a6e0c0
Showing 1 changed file with 1 addition and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def forward(

if (self.config.dtype in [torch.float16, torch.bfloat16, torch.int8]) \
and input.dtype == torch.float:
target_dtype = torch.half if self.dtype == torch.int8 else self.dtype
target_dtype = torch.half if self.config.dtype == torch.int8 else self.config.dtype
input = input.to(target_dtype)

with torch.no_grad():
Expand Down

0 comments on commit 4a6e0c0

Please # to comment.