From 30c2e62d0d2cba00aba95bb5a98d654aac117987 Mon Sep 17 00:00:00 2001 From: Hao Jiang Date: Sun, 31 Mar 2024 16:55:46 +0800 Subject: [PATCH] Fix fused models for tf >= 4.39 --- awq/modules/fused/model.py | 8 ++++++++ setup.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/awq/modules/fused/model.py b/awq/modules/fused/model.py index c1ba2c1e..8733722b 100644 --- a/awq/modules/fused/model.py +++ b/awq/modules/fused/model.py @@ -83,6 +83,14 @@ def __init__(self, vocab_size, blocks, embedding, norm): self.blocks: List[LlamaLikeBlock] = nn.ModuleList(blocks) self.norm = norm self.last_forward_num_tokens = 0 + + @property + def embed_tokens(self): + return self.embedding + + @property + def layers(self): + return self.blocks @torch.inference_mode() def forward( diff --git a/setup.py b/setup.py index c02ff895..8ef3a96f 100644 --- a/setup.py +++ b/setup.py @@ -86,7 +86,7 @@ def get_kernels_whl_url( requirements = [ "torch>=2.0.1", - "transformers>=4.35.0,<=4.38.2", + "transformers>=4.35.0", "tokenizers>=0.12.1", "typing_extensions>=4.8.0", "accelerate",