From 912b3705030b9583e4dc7c48e31198de6a1081e1 Mon Sep 17 00:00:00 2001 From: Trey Pendragon Date: Fri, 27 Jun 2025 12:25:17 -0700 Subject: [PATCH 1/3] Convert attention_bias to the right dtype. --- backends/candle/src/models/qwen3.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/candle/src/models/qwen3.rs b/backends/candle/src/models/qwen3.rs index 1913174e..950e8e42 100644 --- a/backends/candle/src/models/qwen3.rs +++ b/backends/candle/src/models/qwen3.rs @@ -514,7 +514,7 @@ impl Qwen3Model { let attention_bias = if masking { let attention_bias = - Tensor::from_vec(attention_bias, (batch_size, 1, 1, max_length), &self.device)?; + Tensor::from_vec(attention_bias, (batch_size, 1, 1, max_length), &self.device)?.to_dtype(self.dtype)?; // Broadcast once instead of at every layer let attention_bias = attention_bias .broadcast_as((batch_size, self.num_attention_heads, max_length, max_length))? From 3e322466547d2b5c5a9ce3dcc95677f5e2ead194 Mon Sep 17 00:00:00 2001 From: Trey Pendragon Date: Fri, 27 Jun 2025 13:55:34 -0700 Subject: [PATCH 2/3] Fix min value. 
--- backends/candle/src/models/qwen3.rs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/backends/candle/src/models/qwen3.rs b/backends/candle/src/models/qwen3.rs index 950e8e42..064ab6db 100644 --- a/backends/candle/src/models/qwen3.rs +++ b/backends/candle/src/models/qwen3.rs @@ -455,8 +455,13 @@ impl Qwen3Model { let causal_mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?; let causal_mask = causal_mask.expand(&[bs, dim, seq_len, seq_len])?; + let min_value = match self.dtype { + DType::F32 => f32::MIN, + _ => -65504.0, // f16 minimum value + }; + let negatives = - Tensor::full(f32::MIN, attention_bias.shape(), device)?.to_dtype(self.dtype)?; + Tensor::full(min_value, attention_bias.shape(), device)?.to_dtype(self.dtype)?; let zeros = Tensor::zeros_like(&attention_bias)?.to_dtype(self.dtype)?; let causal_mask = causal_mask From dc9f696db90ac46f7ba75e1abaf28fe5ca2dfac8 Mon Sep 17 00:00:00 2001 From: Trey Pendragon Date: Wed, 2 Jul 2025 09:04:40 -0700 Subject: [PATCH 3/3] Run linting. --- backends/candle/src/models/qwen3.rs | 3 ++- backends/src/lib.rs | 7 ++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/backends/candle/src/models/qwen3.rs b/backends/candle/src/models/qwen3.rs index 064ab6db..13309927 100644 --- a/backends/candle/src/models/qwen3.rs +++ b/backends/candle/src/models/qwen3.rs @@ -519,7 +519,8 @@ impl Qwen3Model { let attention_bias = if masking { let attention_bias = - Tensor::from_vec(attention_bias, (batch_size, 1, 1, max_length), &self.device)?.to_dtype(self.dtype)?; + Tensor::from_vec(attention_bias, (batch_size, 1, 1, max_length), &self.device)? + .to_dtype(self.dtype)?; // Broadcast once instead of at every layer let attention_bias = attention_bias .broadcast_as((batch_size, self.num_attention_heads, max_length, max_length))? 
diff --git a/backends/src/lib.rs b/backends/src/lib.rs index d333951c..be40b09b 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -150,11 +150,8 @@ impl Backend { } max_input_length = std::cmp::min(max_input_length, max_warmup_length); - let mut seq_lengths: Vec<usize> = generate_bucket_sizes( - seq_bucket_size, - max_input_length, - seq_len_exp_base, - ); + let mut seq_lengths: Vec<usize> = - generate_bucket_sizes(seq_bucket_size, max_input_length, seq_len_exp_base); if let Some(&last) = seq_lengths.last() { if last < max_input_length { seq_lengths.push(max_input_length);