diff --git a/candle-holder-models/src/models/llama/modeling.rs b/candle-holder-models/src/models/llama/modeling.rs
index 8c7814e..151e423 100644
--- a/candle-holder-models/src/models/llama/modeling.rs
+++ b/candle-holder-models/src/models/llama/modeling.rs
@@ -3,8 +3,8 @@ use std::sync::Arc;
 use candle_core::{DType, Device, Module, Tensor};
 use candle_holder::{Error, Result};
 use candle_nn::{
-    embedding, init::DEFAULT_KAIMING_NORMAL, linear_no_bias, ops::softmax_last_dim, rms_norm,
-    rotary_emb::rope, Dropout, Embedding, Linear, RmsNorm, VarBuilder,
+    embedding, linear_no_bias, ops::softmax_last_dim, rms_norm, rotary_emb::rope, Dropout,
+    Embedding, Linear, RmsNorm, VarBuilder,
 };
 
 use crate::{
@@ -417,12 +417,11 @@ pub struct LlamaForCausalLM {
 impl LlamaForCausalLM {
     fn load_lm_head(vb: VarBuilder, config: &LlamaConfig) -> Result<Linear> {
         let lm_head = if config.tie_word_embeddings.unwrap_or(false) {
-            let init_ws = DEFAULT_KAIMING_NORMAL;
-            let ws = vb
-                .pp("model.embed_tokens")
-                .get_with_hints((config.vocab_size, config.hidden_size), "weight", init_ws)?
-                .t()?;
-            Linear::new(ws, None)
+            linear_no_bias(
+                config.hidden_size,
+                config.vocab_size,
+                vb.pp("model.embed_tokens"),
+            )?
         } else {
            linear_no_bias(config.hidden_size, config.vocab_size, vb.pp("lm_head"))?
        };
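
Not part of the patch: a minimal, self-contained sketch of what the new tied-embeddings branch relies on. `candle_nn::linear_no_bias(in_dim, out_dim, vb)` loads a `(out_dim, in_dim)` weight named `weight` under the given prefix, which is exactly the `(vocab_size, hidden_size)` layout of `model.embed_tokens.weight`, so the manual `get_with_hints(...).t()?` / `Linear::new(ws, None)` path and the `DEFAULT_KAIMING_NORMAL` import are no longer needed. The dimensions, the `model.embed_tokens` prefix, and the `VarMap`-backed builder (which creates missing tensors on demand instead of reading a checkpoint) are illustrative stand-ins.

```rust
use candle_core::{DType, Device, Module, Result, Tensor};
use candle_nn::{linear_no_bias, VarBuilder, VarMap};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Hypothetical sizes standing in for config.hidden_size / config.vocab_size.
    let (hidden_size, vocab_size) = (16, 32);

    // A VarMap-backed VarBuilder creates missing tensors on demand, so this runs
    // without real checkpoint files.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);

    // Same call shape as the patched tied-embeddings branch: the LM head reads its
    // weight from the embedding table's prefix instead of a separate "lm_head" tensor.
    let lm_head = linear_no_bias(hidden_size, vocab_size, vb.pp("model.embed_tokens"))?;

    // The loaded weight is (vocab_size, hidden_size), matching the embedding layout,
    // so no transpose is required.
    assert_eq!(lm_head.weight().dims(), &[vocab_size, hidden_size]);

    // Projecting (batch, seq, hidden) hidden states yields (batch, seq, vocab) logits.
    let hidden_states = Tensor::zeros((1, 4, hidden_size), DType::F32, &device)?;
    let logits = lm_head.forward(&hidden_states)?;
    assert_eq!(logits.dims(), &[1, 4, vocab_size]);
    Ok(())
}
```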