@@ -37,7 +37,7 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
37
37
super ().__init__ ()
38
38
self .dim = config .dim
39
39
self .n_heads = config .n_heads
40
- self .head_dim = config .dim // config . n_heads
40
+ self .head_dim = config .head_dim
41
41
self .n_kv_heads = config .n_kv_heads
42
42
self .num_key_value_groups = config .n_heads // self .n_kv_heads
43
43
self .max_seq_len = config .max_seq_len
@@ -304,7 +304,7 @@ def __init__(
304
304
):
305
305
super ().__init__ ()
306
306
self .dim = config .dim
307
- self .head_dim = config .dim // config . n_heads
307
+ self .head_dim = config .head_dim
308
308
self .max_batch_size = config .max_batch_size
309
309
self .max_seq_len = config .max_seq_len
310
310
self .n_heads = config .n_heads
@@ -328,9 +328,11 @@ def __init__(
328
328
self .output = nn .Linear (config .dim , config .vocab_size , bias = False )
329
329
self .tok_embeddings = nn .Embedding (config .vocab_size , config .dim )
330
330
freqs_cos , freqs_sin = precompute_freqs_cis (
331
- config .dim // config . n_heads ,
331
+ config .head_dim ,
332
332
config .max_seq_len ,
333
333
config .rope_freq_base ,
334
+ config .use_scaled_rope ,
335
+ config .rope_scale_factor ,
334
336
)
335
337
self .register_buffer ("freqs_cos" , freqs_cos , persistent = False )
336
338
self .register_buffer ("freqs_sin" , freqs_sin , persistent = False )
0 commit comments