
Commit

Update tie_word_embeddings
gabrielmbmb committed Sep 3, 2024
1 parent edd8641 · commit 7dba0b8
Showing 1 changed file with 7 additions and 8 deletions.
candle-holder-models/src/models/llama/modeling.rs (15 changes: 7 additions & 8 deletions)
@@ -3,8 +3,8 @@ use std::sync::Arc;
use candle_core::{DType, Device, Module, Tensor};
use candle_holder::{Error, Result};
use candle_nn::{
-    embedding, init::DEFAULT_KAIMING_NORMAL, linear_no_bias, ops::softmax_last_dim, rms_norm,
-    rotary_emb::rope, Dropout, Embedding, Linear, RmsNorm, VarBuilder,
+    embedding, linear_no_bias, ops::softmax_last_dim, rms_norm, rotary_emb::rope, Dropout,
+    Embedding, Linear, RmsNorm, VarBuilder,
};

use crate::{
@@ -417,12 +417,11 @@ pub struct LlamaForCausalLM {
impl LlamaForCausalLM {
    fn load_lm_head(vb: VarBuilder, config: &LlamaConfig) -> Result<Linear> {
        let lm_head = if config.tie_word_embeddings.unwrap_or(false) {
-            let init_ws = DEFAULT_KAIMING_NORMAL;
-            let ws = vb
-                .pp("model.embed_tokens")
-                .get_with_hints((config.vocab_size, config.hidden_size), "weight", init_ws)?
-                .t()?;
-            Linear::new(ws, None)
+            linear_no_bias(
+                config.hidden_size,
+                config.vocab_size,
+                vb.pp("model.embed_tokens"),
+            )?
        } else {
            linear_no_bias(config.hidden_size, config.vocab_size, vb.pp("lm_head"))?
        };
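With this change, the tied branch loads the LM head through linear_no_bias from the model.embed_tokens prefix, so the head and the token embedding read the same model.embed_tokens.weight checkpoint entry. No transpose or init hint is needed because the embedding matrix already has shape (vocab_size, hidden_size), which is the (out_features, in_features) layout candle_nn::Linear stores. The snippet below is a minimal, hypothetical sketch of the same tying idea, not the candle-holder code: it uses toy dimensions, a VarMap standing in for checkpoint weights, and an explicit Linear::new over the embedding tensor instead of the commit's linear_no_bias call.

// Minimal, hypothetical sketch of weight tying with candle-nn (toy sizes,
// VarMap in place of real checkpoint weights; not the candle-holder code).
use candle_core::{DType, Device, Module, Result, Tensor};
use candle_nn::{Linear, VarBuilder, VarMap};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let (vocab_size, hidden_size) = (32, 8); // assumed toy dimensions

    // Stand-in for weights that would normally come from a checkpoint.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    let embed = candle_nn::embedding(vocab_size, hidden_size, vb.pp("model.embed_tokens"))?;

    // Tie the LM head to the embedding: the embedding matrix already has the
    // (out_features, in_features) = (vocab_size, hidden_size) layout Linear
    // stores, so it can be reused as-is, without a transpose.
    let lm_head = Linear::new(embed.embeddings().clone(), None);

    // Hidden states (batch, seq, hidden_size) -> logits (batch, seq, vocab_size).
    let hidden = Tensor::zeros((1, 4, hidden_size), DType::F32, &device)?;
    let logits = lm_head.forward(&hidden)?;
    assert_eq!(logits.dims(), &[1, 4, vocab_size]);
    Ok(())
}

Loading the head from the same model.embed_tokens prefix, as the commit does, is functionally equivalent for inference: both modules end up backed by the one checkpoint entry, which is why tied checkpoints can ship without a separate lm_head.weight.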

