From 7d65b032e65af2737e0d7b5d52b5d0530c95e5aa Mon Sep 17 00:00:00 2001 From: gabrielmbmb Date: Tue, 3 Sep 2024 00:31:57 +0200 Subject: [PATCH] Try loading dtype from `config.json` --- candle-holder-models/src/config.rs | 16 ++++++++++++++++ candle-holder-models/src/model.rs | 14 ++++++++++++-- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/candle-holder-models/src/config.rs b/candle-holder-models/src/config.rs index ba7e38e..08369ee 100644 --- a/candle-holder-models/src/config.rs +++ b/candle-holder-models/src/config.rs @@ -1,3 +1,4 @@ +use candle_core::DType; use candle_holder::utils::serde::deserialize_single_or_vec; use serde::{Deserialize, Deserializer, Serialize}; use std::collections::HashMap; @@ -38,6 +39,8 @@ pub struct PretrainedConfig { /// The IDs of the EOS tokens. #[serde(default, deserialize_with = "deserialize_single_or_vec")] eos_token_id: Option>, + #[serde(default)] + torch_dtype: Option, } fn deserialize_id2label<'de, D>(deserializer: D) -> Result>, D::Error> @@ -75,6 +78,18 @@ impl PretrainedConfig { pub fn get_eos_token_id(&self) -> Option<&Vec> { self.eos_token_id.as_ref() } + + pub fn get_dtype(&self) -> Option { + match &self.torch_dtype { + Some(dtype) => match dtype.as_str() { + "float16" => Some(DType::F16), + "float32" => Some(DType::F32), + "bfloat16" => Some(DType::BF16), + _ => None, + }, + None => None, + } + } } impl PretrainedConfig { @@ -96,6 +111,7 @@ impl Default for PretrainedConfig { pad_token_id: None, bos_token_id: None, eos_token_id: None, + torch_dtype: None, } } } diff --git a/candle-holder-models/src/model.rs b/candle-holder-models/src/model.rs index e963f95..62156e8 100644 --- a/candle-holder-models/src/model.rs +++ b/candle-holder-models/src/model.rs @@ -296,7 +296,12 @@ macro_rules! impl_from_pretrained_method { .get_config() .expect("Model config not found. Cannot load the model.") .clone(); - let dtype = dtype.unwrap_or($default_dtype); + let dtype = dtype.unwrap_or( + serde_json::from_value::(config.clone()) + .expect("Could not parse model config.") + .get_dtype() + .unwrap_or($default_dtype), + ); let vb = model_info.get_var_builder(dtype, device)?; if $load_generation_config { Self::load_with_generation_config( @@ -345,7 +350,12 @@ macro_rules! impl_auto_model_from_pretrained_method { let model: Result> = match model_type { $( $model_type => { - let dtype = dtype.unwrap_or($default_dtype); + let dtype = dtype.unwrap_or( + serde_json::from_value::(config.clone()) + .expect("Could not parse model config.") + .get_dtype() + .unwrap_or($default_dtype), + ); let vb = model_info.get_var_builder(dtype, device)?; if $load_generation_config { Ok(Box::new($model_struct::load_with_generation_config(vb, config, model_info.get_generation_config().cloned())?))