Add data2vec loss class, rename modules #16

Merged · 6 commits · Sep 27, 2024
4 changes: 2 additions & 2 deletions README.md
@@ -103,8 +103,8 @@ mmlearn_run 'hydra.searchpath=[pkg://path.to.config.directory]' +experiment=<nam
Hydra will compose the experiment configuration from all the configurations in the specified directory as well as all the
configurations in the `mmlearn` package. *Note the dot-separated path to the directory containing the experiment configuration
files.*
One can add a path to `hydra.searchpath` either as a package (`pkg://path.to.config.directory`) or as a file system path
(`file://path/to/config/directory`). However, new configs in `mmlearn` are added to hydra's external store inside
`path/to/config/directory/__init__.py`, which is only interpreted when the config directory is added as a package.
Hence, please refrain from using the `file://` notation.
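As a concrete sketch of the package-style invocation described above (the package name `my_project.configs` and the experiment name `my_experiment` are hypothetical placeholders, not part of this repository):

```bash
# Hypothetical example: the config directory is added as a *package*, so its
# __init__.py is executed and the configs are registered with hydra's external store.
mmlearn_run 'hydra.searchpath=[pkg://my_project.configs]' +experiment=my_experiment
```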

4 changes: 2 additions & 2 deletions mmlearn/modules/encoders/__init__.py
@@ -1,13 +1,13 @@
"""Encoders."""

-from mmlearn.modules.encoders.clip_encoders import (
+from mmlearn.modules.encoders.clip import (
    HFCLIPTextEncoder,
    HFCLIPTextEncoderWithProjection,
    HFCLIPVisionEncoder,
    HFCLIPVisionEncoderWithProjection,
    PubMedBERTForCLIPTextEncoding,
)
-from mmlearn.modules.encoders.hf_text_encoders import HFTextEncoder
+from mmlearn.modules.encoders.text import HFTextEncoder


__all__ = [
5 changes: 3 additions & 2 deletions mmlearn/modules/losses/__init__.py
@@ -1,6 +1,7 @@
"""Loss functions."""

-from mmlearn.modules.losses.contrastive_loss import CLIPLoss
+from mmlearn.modules.losses.contrastive import CLIPLoss
+from mmlearn.modules.losses.data2vec import Data2VecLoss


__all__ = ["CLIPLoss"]
__all__ = ["CLIPLoss", "Data2VecLoss"]
84 changes: 84 additions & 0 deletions mmlearn/modules/losses/data2vec.py
@@ -0,0 +1,84 @@
"""Implementation of Data2vec loss function."""

import math
from typing import Optional

import torch
from hydra_zen import store
from torch import nn
from torch.nn.functional import mse_loss, smooth_l1_loss


@store(group="modules/losses", provider="mmlearn")
class Data2VecLoss(nn.Module):
"""Data2Vec loss function.

Parameters
----------
beta : float, optional, default=0
Specifies the beta parameter for smooth L1 loss. If 0, MSE loss is used.
loss_scale : float or None, optional, default=None
Scaling factor for the loss. If None, uses 1 / sqrt(embedding_dim).
reduction : str, optional, default='none'
Specifies the reduction to apply to the output:
``'none'`` | ``'mean'`` | ``'sum'``.
"""

def __init__(
self,
beta: float = 0,
loss_scale: Optional[float] = None,
reduction: str = "none",
) -> None:
"""Initialize the loss."""
super().__init__()
self.beta = beta
self.loss_scale = loss_scale
if reduction not in ["none", "mean", "sum"]:
raise ValueError(f"Unsupported reduction mode: {reduction}")
self.reduction = reduction

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
"""Compute the Data2Vec loss.

Parameters
----------
x : torch.Tensor of shape (batch_size, num_patches, embedding_dim)
Predicted embeddings.
y : torch.Tensor of shape (batch_size, num_patches, embedding_dim)
Target embeddings.

Returns
-------
torch.Tensor
Data2Vec loss value.

Raises
------
ValueError
If the shapes of x and y do not match.
"""
if x.shape != y.shape:
raise ValueError(f"Shape mismatch: x: {x.shape}, y: {y.shape}")

x = x.view(-1, x.size(-1)).float()
y = y.view(-1, y.size(-1))

if self.beta == 0:
loss = mse_loss(x, y, reduction="none")
else:
loss = smooth_l1_loss(x, y, reduction="none", beta=self.beta)

if self.loss_scale is not None:
scale = self.loss_scale
else:
scale = 1 / math.sqrt(x.size(-1))

loss = loss * scale

if self.reduction == "mean":
return loss.mean()
if self.reduction == "sum":
return loss.sum()
# 'none'
return loss.view(x.size(0), -1).sum(1)
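For readers trying out the new loss, a small usage sketch based on the class as defined above; the import path follows this diff, while the batch/patch/embedding sizes (8, 196, 768) are purely illustrative:

```python
import torch

from mmlearn.modules.losses.data2vec import Data2VecLoss

# Predicted and target embeddings: (batch_size, num_patches, embedding_dim).
predictions = torch.randn(8, 196, 768)
targets = torch.randn(8, 196, 768)

# beta=0 selects MSE; with reduction="none" the loss is summed over the
# embedding dimension, giving one value per patch: shape (batch_size * num_patches,).
loss_fn = Data2VecLoss(beta=0.0, loss_scale=None, reduction="none")
per_patch_loss = loss_fn(predictions, targets)
print(per_patch_loss.shape)  # torch.Size([1568])

# beta>0 switches to smooth L1; reduction="mean" returns a single scalar.
scalar_loss = Data2VecLoss(beta=2.0, reduction="mean")(predictions, targets)
```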
2 changes: 1 addition & 1 deletion projects/bioscan_clip/configs/__init__.py
@@ -6,7 +6,7 @@
import torch

from mmlearn.conf import external_store
-from mmlearn.modules.encoders.hf_text_encoders import HFTextEncoder
+from mmlearn.modules.encoders.text import HFTextEncoder
from mmlearn.modules.encoders.vision import TimmViT

from projects.bioscan_clip.encoders import BarcodeBERT