From bd2d9d0ab3c2f53cf95a42c2413dad84fe585af3 Mon Sep 17 00:00:00 2001
From: Siming Dai <908660116@qq.com>
Date: Mon, 13 Jan 2025 16:31:01 +0800
Subject: [PATCH] Fix matryoshka norm loss (#9773)

* [Trainer] update sequence parallel (#9757)

* update emb doc

* update register_sequence_parallel_allreduce_hooks

* update fuse_sequence_parallel_allreduce

* fix matryoshka

---
 paddlenlp/transformers/contrastive_loss.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/paddlenlp/transformers/contrastive_loss.py b/paddlenlp/transformers/contrastive_loss.py
index 0252c0712a27..9c61cc0cf4db 100644
--- a/paddlenlp/transformers/contrastive_loss.py
+++ b/paddlenlp/transformers/contrastive_loss.py
@@ -52,10 +52,10 @@ def forward(self, q_reps, p_reps):
         if len(self.embedding_matryoshka_dims) > 0:
             loss = 0.0
             for dim in self.embedding_matryoshka_dims:
-                reduced_q_reps = q_reps[:, :dim]
+                reduced_q_reps = q_reps[:, :dim].astype("float32")
                 reduced_q_reps = nn.functional.normalize(reduced_q_reps, axis=-1)

-                reduced_p_reps = p_reps[:, :dim]
+                reduced_p_reps = p_reps[:, :dim].astype("float32")
                 reduced_p_reps = nn.functional.normalize(reduced_p_reps, axis=-1)

                 dim_loss = self.loss_fn(reduced_q_reps, reduced_p_reps)