[example] Added gemma support #115

Open · wants to merge 2 commits into main
model.py (23 changes: 13 additions & 10 deletions)
@@ -26,7 +26,7 @@ class ModelArgs:
dim: int = 4096
intermediate_size: int = None
n_local_heads: int = -1
- head_dim: int = 64
+ head_dim: int = None
rope_base: float = 10000
norm_eps: float = 1e-5

@@ -37,7 +37,8 @@ def __post_init__(self):
hidden_dim = 4 * self.dim
n_hidden = int(2 * hidden_dim / 3)
self.intermediate_size = find_multiple(n_hidden, 256)
- self.head_dim = self.dim // self.n_head
+ if self.head_dim is None:
+     self.head_dim = self.dim // self.n_head

@classmethod
def from_name(cls, name: str):
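With head_dim now an explicit ModelArgs field, __post_init__ only derives it as dim // n_head when the config leaves it unset. That matters for gemma-7b, where head_dim is 256 even though dim // n_head would give 3072 // 16 = 192. A minimal standalone sketch of the resulting default logic (not the PR's exact class):

from dataclasses import dataclass

@dataclass
class Args:
    dim: int = 4096
    n_head: int = 32
    head_dim: int = None  # optional override, as in this PR

    def __post_init__(self):
        if self.head_dim is None:
            self.head_dim = self.dim // self.n_head

print(Args(dim=3072, n_head=16).head_dim)                # 192, derived
print(Args(dim=3072, n_head=16, head_dim=256).head_dim)  # 256, explicit (gemma-7b)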
@@ -50,6 +51,8 @@ def from_name(cls, name: str):


transformer_configs = {
"gemma-2b": dict(dim=2048, vocab_size=256000, n_layer=18, n_head=8, n_local_heads=1, intermediate_size=16384),
"gemma-7b": dict(dim=3072, vocab_size=256000, n_layer=28, n_head=16, n_local_heads=16, intermediate_size=24576, head_dim=256),
"CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim = 4096, rope_base=1000000),
"7B": dict(n_layer=32, n_head=32, dim=4096),
"13B": dict(n_layer=40, n_head=40, dim=5120),
@@ -94,21 +97,21 @@ def __init__(self, config: ModelArgs) -> None:
def setup_caches(self, max_batch_size, max_seq_length):
if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
return
- head_dim = self.config.dim // self.config.n_head
max_seq_length = find_multiple(max_seq_length, 8)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
for b in self.layers:
- b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim)
+ b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, self.config.head_dim)

- self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base)
+ self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim, self.config.rope_base)
self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
assert self.freqs_cis is not None, "Caches must be initialized first"
mask = self.causal_mask[None, None, input_pos]
freqs_cis = self.freqs_cis[input_pos]
x = self.tok_embeddings(idx)
+ x = (self.config.dim ** 0.5) * x

for i, layer in enumerate(self.layers):
x = layer(x, input_pos, freqs_cis, mask)
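setup_caches and the rotary-embedding precompute now read head_dim from the config instead of recomputing dim // n_head, which would be wrong whenever the two disagree (gemma-7b). The added line in forward is Gemma's input normalization: token embeddings are multiplied by sqrt(dim) before the first layer. A small sketch of that scaling step, with illustrative values:

import torch
import torch.nn as nn

dim = 2048                             # gemma-2b hidden size
tok_embeddings = nn.Embedding(256000, dim)
idx = torch.tensor([[2, 651, 2121]])   # hypothetical token ids
x = tok_embeddings(idx)
x = (dim ** 0.5) * x                   # Gemma scales embeddings by sqrt(dim)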
@@ -143,7 +146,7 @@ def __init__(self, config: ModelArgs):
total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
# key, query, value projections for all heads, but in a batch
self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
- self.wo = nn.Linear(config.dim, config.dim, bias=False)
+ self.wo = nn.Linear(config.n_head * config.head_dim, config.dim, bias=False)
self.kv_cache = None

self.n_head = config.n_head
@@ -163,7 +166,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
bsz, seqlen, _ = x.shape

kv_size = self.n_local_heads * self.head_dim
- q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+ q, k, v = self.wqkv(x).split([self.n_head * self.head_dim, kv_size, kv_size], dim=-1)

q = q.view(bsz, seqlen, self.n_head, self.head_dim)
k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
@@ -181,7 +184,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

- y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.n_head * self.head_dim)

y = self.wo(y)
return y
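These three attention changes exist because n_head * head_dim no longer has to equal dim once head_dim is explicit: for gemma-7b it is 16 * 256 = 4096 while dim is 3072. The QKV split, the post-attention reshape, and the wo output projection must therefore all use n_head * head_dim. A small shape check under those gemma-7b sizes:

import torch
import torch.nn as nn

dim, n_head, n_local_heads, head_dim = 3072, 16, 16, 256   # gemma-7b
total_head_dim = (n_head + 2 * n_local_heads) * head_dim

wqkv = nn.Linear(dim, total_head_dim, bias=False)
wo = nn.Linear(n_head * head_dim, dim, bias=False)   # input is 4096, not dim

x = torch.randn(1, 5, dim)
kv_size = n_local_heads * head_dim
q, k, v = wqkv(x).split([n_head * head_dim, kv_size, kv_size], dim=-1)
print(q.shape, k.shape, v.shape)   # each (1, 5, 4096)
print(wo(q).shape)                 # back to (1, 5, 3072)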
@@ -195,7 +198,7 @@ def __init__(self, config: ModelArgs) -> None:
self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

def forward(self, x: Tensor) -> Tensor:
- return self.w2(F.silu(self.w1(x)) * self.w3(x))
+ return self.w2(F.gelu(self.w1(x), approximate="tanh") * self.w3(x))


class RMSNorm(nn.Module):
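The FeedForward change above swaps Llama's SwiGLU for Gemma's GeGLU: the gate branch goes through tanh-approximated GELU instead of SiLU. A minimal sketch of the gated MLP with illustrative gemma-2b sizes:

import torch
import torch.nn as nn
import torch.nn.functional as F

dim, intermediate = 2048, 16384                 # gemma-2b sizes
w1 = nn.Linear(dim, intermediate, bias=False)   # gate projection
w3 = nn.Linear(dim, intermediate, bias=False)   # up projection
w2 = nn.Linear(intermediate, dim, bias=False)   # down projection

x = torch.randn(1, 4, dim)
gemma_out = w2(F.gelu(w1(x), approximate="tanh") * w3(x))   # GeGLU, as in this PR
llama_out = w2(F.silu(w1(x)) * w3(x))                       # SwiGLU, the original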
@@ -209,7 +212,7 @@ def _norm(self, x):

def forward(self, x: Tensor) -> Tensor:
output = self._norm(x.float()).type_as(x)
- return output * self.weight
+ return output * (1 + self.weight)


def precompute_freqs_cis(
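The RMSNorm change reflects Gemma's parameterization of the scale: the checkpoint stores the weight as an offset from 1 (it is zero-initialized), so the normalized activations are multiplied by (1 + weight) rather than weight. A standalone sketch of that behaviour:

import torch

def gemma_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Normalize in float32 for stability, then apply Gemma's (1 + weight) scale.
    norm = x.float() * torch.rsqrt(torch.mean(x.float() * x.float(), dim=-1, keepdim=True) + eps)
    return norm.type_as(x) * (1 + weight)

x = torch.randn(2, 4, 2048)
weight = torch.zeros(2048)   # zero-initialized Gemma norm weight means a scale of exactly 1
y = gemma_rmsnorm(x, weight)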
scripts/convert_hf_checkpoint.py (10 changes: 8 additions & 2 deletions)
@@ -30,8 +30,10 @@ def convert_hf_checkpoint(
config = ModelArgs.from_name(model_name)
print(f"Model config {config.__dict__}")

+ from safetensors import safe_open
+
# Load the json file containing weight mapping
model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
model_map_json = checkpoint_dir / "model.safetensors.index.json"

assert model_map_json.is_file()
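Gemma checkpoints on the Hugging Face Hub ship as safetensors shards, so the weight map lives in model.safetensors.index.json rather than pytorch_model.bin.index.json. A sketch of what the index contains (the local path is illustrative):

import json
from pathlib import Path

checkpoint_dir = Path("checkpoints/google/gemma-2b")   # hypothetical download location
index = json.loads((checkpoint_dir / "model.safetensors.index.json").read_text())
# "weight_map" maps every parameter name to the shard file that stores it.
shard_files = sorted({checkpoint_dir / shard for shard in index["weight_map"].values()})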

@@ -65,7 +67,8 @@ def permute(w, n_head):

merged_result = {}
for file in sorted(bin_files):
- state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
+ state_dict = safe_open(str(file), framework="pt", device='cpu')
+ state_dict = {k: state_dict.get_tensor(k) for k in state_dict.keys()}
merged_result.update(state_dict)
final_result = {}
for key, value in merged_result.items():
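safetensors.safe_open returns a lazy handle rather than a dict, so each tensor has to be materialized with get_tensor before it can be merged into one state dict. The same pattern in isolation (the shard name is illustrative):

from safetensors import safe_open

tensors = {}
with safe_open("model-00001-of-00002.safetensors", framework="pt", device="cpu") as f:
    for name in f.keys():              # parameter names stored in this shard
        tensors[name] = f.get_tensor(name)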
@@ -92,6 +95,9 @@ def permute(w, n_head):
del final_result[key]
del final_result[key.replace("wq", "wk")]
del final_result[key.replace("wq", "wv")]
if "output.weight" not in final_result:
final_result["output.weight"] = final_result["tok_embeddings.weight"]

print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
torch.save(final_result, checkpoint_dir / "model.pth")
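Gemma ties the LM head to the input embeddings and does not store a separate output.weight, so the converter reuses tok_embeddings.weight when that slot is missing. A quick check on a converted checkpoint (path is illustrative):

import torch

ckpt = torch.load("checkpoints/google/gemma-2b/model.pth", map_location="cpu")
# With tied weights, the LM head and the embedding table hold identical values.
assert torch.equal(ckpt["output.weight"], ckpt["tok_embeddings.weight"])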
