Skip to content

Commit

Permalink
Try loading dtype from config.json
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmbmb committed Sep 2, 2024
1 parent 8709023 commit 7d65b03
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
16 changes: 16 additions & 0 deletions candle-holder-models/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use candle_core::DType;
use candle_holder::utils::serde::deserialize_single_or_vec;
use serde::{Deserialize, Deserializer, Serialize};
use std::collections::HashMap;
Expand Down Expand Up @@ -38,6 +39,8 @@ pub struct PretrainedConfig {
/// The IDs of the EOS tokens.
#[serde(default, deserialize_with = "deserialize_single_or_vec")]
eos_token_id: Option<Vec<u32>>,
#[serde(default)]
torch_dtype: Option<String>,
}

fn deserialize_id2label<'de, D>(deserializer: D) -> Result<Option<HashMap<usize, String>>, D::Error>
Expand Down Expand Up @@ -75,6 +78,18 @@ impl PretrainedConfig {
/// Returns a borrowed view of the EOS token IDs stored in this config,
/// or `None` when no EOS token IDs are set.
pub fn get_eos_token_id(&self) -> Option<&Vec<u32>> {
    let eos_ids = &self.eos_token_id;
    eos_ids.as_ref()
}

/// Maps the config's `torch_dtype` string to the matching [`DType`].
///
/// Returns `None` when `torch_dtype` is not set, or when its value is
/// anything other than `"float16"`, `"float32"` or `"bfloat16"`, so callers
/// can fall back to a dtype of their own choosing.
pub fn get_dtype(&self) -> Option<DType> {
    // `as_deref` borrows the inner `String` as `&str`; `?` short-circuits on
    // a missing field so only the recognised names need to be matched.
    match self.torch_dtype.as_deref()? {
        "float16" => Some(DType::F16),
        "float32" => Some(DType::F32),
        "bfloat16" => Some(DType::BF16),
        _ => None,
    }
}
}

impl PretrainedConfig {
Expand All @@ -96,6 +111,7 @@ impl Default for PretrainedConfig {
pad_token_id: None,
bos_token_id: None,
eos_token_id: None,
torch_dtype: None,
}
}
}
14 changes: 12 additions & 2 deletions candle-holder-models/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,12 @@ macro_rules! impl_from_pretrained_method {
.get_config()
.expect("Model config not found. Cannot load the model.")
.clone();
let dtype = dtype.unwrap_or($default_dtype);
let dtype = dtype.unwrap_or(
serde_json::from_value::<PretrainedConfig>(config.clone())
.expect("Could not parse model config.")
.get_dtype()
.unwrap_or($default_dtype),
);
let vb = model_info.get_var_builder(dtype, device)?;
if $load_generation_config {
Self::load_with_generation_config(
Expand Down Expand Up @@ -345,7 +350,12 @@ macro_rules! impl_auto_model_from_pretrained_method {
let model: Result<Box<dyn PreTrainedModel>> = match model_type {
$(
$model_type => {
let dtype = dtype.unwrap_or($default_dtype);
let dtype = dtype.unwrap_or(
serde_json::from_value::<PretrainedConfig>(config.clone())
.expect("Could not parse model config.")
.get_dtype()
.unwrap_or($default_dtype),
);
let vb = model_info.get_var_builder(dtype, device)?;
if $load_generation_config {
Ok(Box::new($model_struct::load_with_generation_config(vb, config, model_info.get_generation_config().cloned())?))
Expand Down

0 comments on commit 7d65b03

Please sign in to comment.