Skip to content

Commit 542480c

Browse files
authored
Fix static_llama to read some previously hardcoded options from ModelArgs
Differential Revision: D70414663
Pull Request resolved: #8846
1 parent 19a3002 commit 542480c

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

examples/qualcomm/oss_scripts/llama/model/static_llama.py

+5 -3
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
         super().__init__()
         self.dim = config.dim
         self.n_heads = config.n_heads
-        self.head_dim = config.dim // config.n_heads
+        self.head_dim = config.head_dim
         self.n_kv_heads = config.n_kv_heads
         self.num_key_value_groups = config.n_heads // self.n_kv_heads
         self.max_seq_len = config.max_seq_len
@@ -304,7 +304,7 @@ def __init__(
     ):
         super().__init__()
         self.dim = config.dim
-        self.head_dim = config.dim // config.n_heads
+        self.head_dim = config.head_dim
         self.max_batch_size = config.max_batch_size
         self.max_seq_len = config.max_seq_len
         self.n_heads = config.n_heads
@@ -328,9 +328,11 @@ def __init__(
         self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
         self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
         freqs_cos, freqs_sin = precompute_freqs_cis(
-            config.dim // config.n_heads,
+            config.head_dim,
             config.max_seq_len,
             config.rope_freq_base,
+            config.use_scaled_rope,
+            config.rope_scale_factor,
         )
         self.register_buffer("freqs_cos", freqs_cos, persistent=False)
         self.register_buffer("freqs_sin", freqs_sin, persistent=False)

0 commit comments

Comments
 (0)