diff --git a/model.py b/model.py index 130ab6b..ab8ba94 100644 --- a/model.py +++ b/model.py @@ -26,7 +26,7 @@ class ModelArgs: dim: int = 4096 intermediate_size: int = None n_local_heads: int = -1 - head_dim: int = 64 + head_dim: int = None rope_base: float = 10000 norm_eps: float = 1e-5 @@ -37,7 +37,8 @@ def __post_init__(self): hidden_dim = 4 * self.dim n_hidden = int(2 * hidden_dim / 3) self.intermediate_size = find_multiple(n_hidden, 256) - self.head_dim = self.dim // self.n_head + if self.head_dim is None: + self.head_dim = self.dim // self.n_head @classmethod def from_name(cls, name: str): @@ -50,6 +51,8 @@ def from_name(cls, name: str): transformer_configs = { + "gemma-2b": dict(dim=2048, vocab_size=256000, n_layer=18, n_head=8, n_local_heads=1, intermediate_size=16384), + "gemma-7b": dict(dim=3072, vocab_size=256000, n_layer=28, n_head=16, n_local_heads=16, intermediate_size=24576, head_dim=256), "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim = 4096, rope_base=1000000), "7B": dict(n_layer=32, n_head=32, dim=4096), "13B": dict(n_layer=40, n_head=40, dim=5120), @@ -94,14 +97,13 @@ def __init__(self, config: ModelArgs) -> None: def setup_caches(self, max_batch_size, max_seq_length): if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: return - head_dim = self.config.dim // self.config.n_head max_seq_length = find_multiple(max_seq_length, 8) self.max_seq_length = max_seq_length self.max_batch_size = max_batch_size for b in self.layers: - b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim) + b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, self.config.head_dim) - self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base) + self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim, self.config.rope_base) self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: @@ -109,6 +111,7 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: mask = self.causal_mask[None, None, input_pos] freqs_cis = self.freqs_cis[input_pos] x = self.tok_embeddings(idx) + x = (self.config.dim ** 0.5) * x for i, layer in enumerate(self.layers): x = layer(x, input_pos, freqs_cis, mask) @@ -143,7 +146,7 @@ def __init__(self, config: ModelArgs): total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim # key, query, value projections for all heads, but in a batch self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) - self.wo = nn.Linear(config.dim, config.dim, bias=False) + self.wo = nn.Linear(config.n_head * config.head_dim, config.dim, bias=False) self.kv_cache = None self.n_head = config.n_head @@ -163,7 +166,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona bsz, seqlen, _ = x.shape kv_size = self.n_local_heads * self.head_dim - q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + q, k, v = self.wqkv(x).split([self.n_head * self.head_dim, kv_size, kv_size], dim=-1) q = q.view(bsz, seqlen, self.n_head, self.head_dim) k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) @@ -181,7 +184,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) - y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.n_head * self.head_dim) y = self.wo(y) return y @@ -195,7 +198,7 @@ def __init__(self, config: ModelArgs) -> None: self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) def forward(self, x: Tensor) -> Tensor: - return self.w2(F.silu(self.w1(x)) * self.w3(x)) + return self.w2(F.gelu(self.w1(x), approximate="tanh") * self.w3(x)) class RMSNorm(nn.Module): @@ -209,7 +212,7 @@ def _norm(self, x): def forward(self, x: Tensor) -> Tensor: output = self._norm(x.float()).type_as(x) - return output * self.weight + return output * (1 + self.weight) def precompute_freqs_cis( diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index b92114c..44fb618 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -30,8 +30,10 @@ def convert_hf_checkpoint( config = ModelArgs.from_name(model_name) print(f"Model config {config.__dict__}") + from safetensors import safe_open + # Load the json file containing weight mapping - model_map_json = checkpoint_dir / "pytorch_model.bin.index.json" + model_map_json = checkpoint_dir / "model.safetensors.index.json" assert model_map_json.is_file() @@ -65,7 +67,8 @@ def permute(w, n_head): merged_result = {} for file in sorted(bin_files): - state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True) + state_dict = safe_open(str(file), framework="pt", device='cpu') + state_dict = {k: state_dict.get_tensor(k) for k in state_dict.keys()} merged_result.update(state_dict) final_result = {} for key, value in merged_result.items(): @@ -92,6 +95,9 @@ def permute(w, n_head): del final_result[key] del final_result[key.replace("wq", "wk")] del final_result[key.replace("wq", "wv")] + if "output.weight" not in final_result: + final_result["output.weight"] = final_result["tok_embeddings.weight"] + print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") torch.save(final_result, checkpoint_dir / "model.pth")