Skip to content

Commit

Permalink
Fix matryoshka norm loss (PaddlePaddle#9773)
Browse files Browse the repository at this point in the history
* [Trainer] update sequence parallel (PaddlePaddle#9757)

* update emb doc

* update register_sequence_parallel_allreduce_hooks

* update fuse_sequence_parallel_allreduce

* fix matryoshka
  • Loading branch information
DesmonDay authored Jan 13, 2025
1 parent 2e62501 commit bd2d9d0
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions paddlenlp/transformers/contrastive_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ def forward(self, q_reps, p_reps):
if len(self.embedding_matryoshka_dims) > 0:
loss = 0.0
for dim in self.embedding_matryoshka_dims:
reduced_q_reps = q_reps[:, :dim]
reduced_q_reps = q_reps[:, :dim].astype("float32")
reduced_q_reps = nn.functional.normalize(reduced_q_reps, axis=-1)

reduced_p_reps = p_reps[:, :dim]
reduced_p_reps = p_reps[:, :dim].astype("float32")
reduced_p_reps = nn.functional.normalize(reduced_p_reps, axis=-1)

dim_loss = self.loss_fn(reduced_q_reps, reduced_p_reps)
Expand Down

0 comments on commit bd2d9d0

Please # to comment.