Skip to content

Commit

Permalink
Fix BF16 cannot be used with metal
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmbmb committed Sep 3, 2024
1 parent a7d643d commit edd8641
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions candle-holder-models/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,12 +296,15 @@ macro_rules! impl_from_pretrained_method {
.get_config()
.expect("Model config not found. Cannot load the model.")
.clone();
let dtype = dtype.unwrap_or(
let mut dtype = dtype.unwrap_or(
serde_json::from_value::<PretrainedConfig>(config.clone())
.expect("Could not parse model config.")
.get_dtype()
.unwrap_or($default_dtype),
);
if dtype == DType::BF16 && device.is_metal() {
dtype = DType::F16;
}
let vb = model_info.get_var_builder(dtype, device)?;
if $load_generation_config {
Self::load_with_generation_config(
Expand Down Expand Up @@ -350,12 +353,15 @@ macro_rules! impl_auto_model_from_pretrained_method {
let model: Result<Box<dyn PreTrainedModel>> = match model_type {
$(
$model_type => {
let dtype = dtype.unwrap_or(
let mut dtype = dtype.unwrap_or(
serde_json::from_value::<PretrainedConfig>(config.clone())
.expect("Could not parse model config.")
.get_dtype()
.unwrap_or($default_dtype),
);
if dtype == DType::BF16 && device.is_metal() {
dtype = DType::F16;
}
let vb = model_info.get_var_builder(dtype, device)?;
if $load_generation_config {
Ok(Box::new($model_struct::load_with_generation_config(vb, config, model_info.get_generation_config().cloned())?))
Expand Down

0 comments on commit edd8641

Please # to comment.