diff --git a/model.py b/model.py index b9ee468..136971c 100644 --- a/model.py +++ b/model.py @@ -26,7 +26,7 @@ class ModelArgs: dim: int = 4096 intermediate_size: int = None n_local_heads: int = -1 - head_dim: int = 64 + head_dim: int = None rope_base: float = 10000 norm_eps: float = 1e-5 @@ -37,7 +37,8 @@ def __post_init__(self): hidden_dim = 4 * self.dim n_hidden = int(2 * hidden_dim / 3) self.intermediate_size = find_multiple(n_hidden, 256) - self.head_dim = self.dim // self.n_head + if self.head_dim is None: + self.head_dim = self.dim // self.n_head @classmethod def from_name(cls, name: str): @@ -51,6 +52,7 @@ def from_name(cls, name: str): transformer_configs = { "gemma-2b": dict(dim=2048, vocab_size=256000, n_layer=18, n_head=8, n_local_heads=1, intermediate_size=16384), + "gemma-7b": dict(dim=3072, vocab_size=256000, n_layer=28, n_head=16, n_local_heads=16, intermediate_size=24576, head_dim=256), "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim = 4096, rope_base=1000000), "7B": dict(n_layer=32, n_head=32, dim=4096), "13B": dict(n_layer=40, n_head=40, dim=5120), @@ -95,14 +97,13 @@ def __init__(self, config: ModelArgs) -> None: def setup_caches(self, max_batch_size, max_seq_length): if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: return - head_dim = self.config.dim // self.config.n_head max_seq_length = find_multiple(max_seq_length, 8) self.max_seq_length = max_seq_length self.max_batch_size = max_batch_size for b in self.layers: - b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim) + b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, self.config.head_dim) - self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base) + self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim, self.config.rope_base) self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: @@ -145,7 +146,7 @@ def __init__(self, config: ModelArgs): total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim # key, query, value projections for all heads, but in a batch self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) - self.wo = nn.Linear(config.dim, config.dim, bias=False) + self.wo = nn.Linear(config.n_head * config.head_dim, config.dim, bias=False) self.kv_cache = None self.n_head = config.n_head @@ -165,7 +166,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona bsz, seqlen, _ = x.shape kv_size = self.n_local_heads * self.head_dim - q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + q, k, v = self.wqkv(x).split([self.n_head * self.head_dim, kv_size, kv_size], dim=-1) q = q.view(bsz, seqlen, self.n_head, self.head_dim) k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) @@ -183,7 +184,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) - y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.n_head * self.head_dim) y = self.wo(y) return y @@ -197,7 +198,7 @@ def __init__(self, config: ModelArgs) -> None: self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False) def forward(self, x: Tensor) -> Tensor: - return self.w2(F.gelu(self.w1(x)) * self.w3(x)) + return self.w2(F.gelu(self.w1(x)) * self.w3(x), approximate="tanh") class RMSNorm(nn.Module):