
Added gemma-7b performance
Chillee committed Mar 7, 2024
1 parent ef055fc commit fc64185
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions model.py
@@ -26,7 +26,7 @@ class ModelArgs:
     dim: int = 4096
     intermediate_size: int = None
     n_local_heads: int = -1
-    head_dim: int = 64
+    head_dim: int = None
     rope_base: float = 10000
     norm_eps: float = 1e-5

@@ -37,7 +37,8 @@ def __post_init__(self):
             hidden_dim = 4 * self.dim
             n_hidden = int(2 * hidden_dim / 3)
             self.intermediate_size = find_multiple(n_hidden, 256)
-        self.head_dim = self.dim // self.n_head
+        if self.head_dim is None:
+            self.head_dim = self.dim // self.n_head

     @classmethod
     def from_name(cls, name: str):
@@ -51,6 +52,7 @@ def from_name(cls, name: str):

 transformer_configs = {
     "gemma-2b": dict(dim=2048, vocab_size=256000, n_layer=18, n_head=8, n_local_heads=1, intermediate_size=16384),
+    "gemma-7b": dict(dim=3072, vocab_size=256000, n_layer=28, n_head=16, n_local_heads=16, intermediate_size=24576, head_dim=256),
     "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim = 4096, rope_base=1000000),
     "7B": dict(n_layer=32, n_head=32, dim=4096),
     "13B": dict(n_layer=40, n_head=40, dim=5120),
@@ -95,14 +97,13 @@ def __init__(self, config: ModelArgs) -> None:
     def setup_caches(self, max_batch_size, max_seq_length):
         if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
             return
-        head_dim = self.config.dim // self.config.n_head
         max_seq_length = find_multiple(max_seq_length, 8)
         self.max_seq_length = max_seq_length
         self.max_batch_size = max_batch_size
         for b in self.layers:
-            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim)
+            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, self.config.head_dim)

-        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base)
+        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim, self.config.rope_base)
         self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
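
With head_dim stored on ModelArgs, setup_caches stops re-deriving it, so the KV cache and the rotary-embedding table are sized with the real per-head width (256 for gemma-7b rather than 3072 // 16 = 192). A minimal sketch of how the field resolves, assuming this model.py is importable as `model`:

# Minimal sketch (assumes the diff's model.py is on the path as `model`).
from model import ModelArgs

gemma = ModelArgs.from_name("gemma-7b")
llama = ModelArgs.from_name("7B")
print(gemma.head_dim)  # 256: taken from the explicit config entry
print(llama.head_dim)  # 128: falls back to dim // n_head = 4096 // 32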
@@ -145,7 +146,7 @@ def __init__(self, config: ModelArgs):
         total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
         # key, query, value projections for all heads, but in a batch
         self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
-        self.wo = nn.Linear(config.dim, config.dim, bias=False)
+        self.wo = nn.Linear(config.n_head * config.head_dim, config.dim, bias=False)
         self.kv_cache = None

         self.n_head = config.n_head
@@ -165,7 +166,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         bsz, seqlen, _ = x.shape

         kv_size = self.n_local_heads * self.head_dim
-        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+        q, k, v = self.wqkv(x).split([self.n_head * self.head_dim, kv_size, kv_size], dim=-1)

         q = q.view(bsz, seqlen, self.n_head, self.head_dim)
         k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
@@ -183,7 +184,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

-        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.n_head * self.head_dim)

         y = self.wo(y)
         return y
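
The three Attention changes are one and the same fix: once head_dim is decoupled from dim, the query slice, the reshaped attention output, and the wo input are all n_head * head_dim wide rather than dim wide. An illustrative walk-through with the gemma-7b numbers (not part of the commit):

# Illustrative shape walk-through with the gemma-7b numbers.
dim, n_head, n_local_heads, head_dim = 3072, 16, 16, 256

q_size  = n_head * head_dim            # 4096: query width; the old code assumed this equals dim
kv_size = n_local_heads * head_dim     # 4096: key/value width
total_head_dim = q_size + 2 * kv_size  # 12288: wqkv output width

assert q_size != dim                   # exactly the case the changes above handle
# wqkv:  (bsz, seqlen, dim) -> split into q_size, kv_size, kv_size
# attn:  (bsz, n_head, seqlen, head_dim) -> reshaped to (bsz, seqlen, q_size)
# wo:    nn.Linear(q_size, dim) maps the result back to the residual width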
@@ -197,7 +198,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

     def forward(self, x: Tensor) -> Tensor:
-        return self.w2(F.gelu(self.w1(x)) * self.w3(x))
+        return self.w2(F.gelu(self.w1(x), approximate="tanh") * self.w3(x))


 class RMSNorm(nn.Module):
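
The activation change selects the tanh approximation of GELU via F.gelu's approximate argument. A small, self-contained check of the formula this option computes (assumes PyTorch 1.12+, where the argument was added):

# Sanity check of the tanh-approximate GELU formula.
import math
import torch
import torch.nn.functional as F

x = torch.randn(4)
ref = 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x**3)))
assert torch.allclose(F.gelu(x, approximate="tanh"), ref, atol=1e-6)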
