Add diffusion model implementation #408

Draft
wants to merge 47 commits into
base: dev
Conversation

Collaborator

@vpratz vpratz commented Apr 13, 2025

This PR adds a diffusion model implementation for use as an inference network, as discussed in #403. It implements the design introduced as "EDM" in [1]. The overall structure is taken from the FlowMatching class.

@arrjon @niels-leif-bracher I would appreciate it if you could take a look and suggest how we can incorporate the other diffusion model variants as well. For now, I chose to expose only the sigma_data parameter to the end user and keep everything else private. This should let us change the internals later on and add new functionality incrementally.

Please let me know how we want to proceed and how much capacity you have to move this forward, so that we can decide whether to include the additional options before we merge or to merge early and add to it incrementally. I have placed the class in the experimental module for now, so that we have some freedom to change things in the future as we see fit.

[1] https://arxiv.org/abs/2206.00364
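
For readers unfamiliar with the role of sigma_data, here is a minimal sketch of the preconditioning coefficients defined in [1]. It is not the code in this PR: the function names `edm_preconditioning` and `denoise` are hypothetical, the scalar inputs are a simplification, and the default `sigma_data=0.5` is the value used for images in [1], assumed here only for illustration.

```python
import math


def edm_preconditioning(sigma: float, sigma_data: float = 0.5) -> dict:
    """Return the EDM preconditioning coefficients for one noise level.

    Hypothetical helper for illustration; formulas follow Table 1 of [1].
    """
    if sigma <= 0 or sigma_data <= 0:
        raise ValueError("sigma and sigma_data must be positive")
    total_var = sigma**2 + sigma_data**2
    return {
        # weight of the skip connection from the noisy input
        "c_skip": sigma_data**2 / total_var,
        # scale applied to the raw network output
        "c_out": sigma * sigma_data / math.sqrt(total_var),
        # scale applied to the network input
        "c_in": 1.0 / math.sqrt(total_var),
        # noise-level conditioning passed to the network
        "c_noise": 0.25 * math.log(sigma),
        # loss weight chosen so that loss_weight * c_out**2 == 1
        "loss_weight": total_var / (sigma * sigma_data) ** 2,
    }


def denoise(network, x_noisy: float, sigma: float, sigma_data: float = 0.5) -> float:
    """Wrap a raw network F(x, c_noise) into the preconditioned denoiser D."""
    c = edm_preconditioning(sigma, sigma_data)
    return c["c_skip"] * x_noisy + c["c_out"] * network(c["c_in"] * x_noisy, c["c_noise"])
```

With this parameterization, sigma_data is the only quantity the user has to supply; the other coefficients follow from it and the current noise level.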

Preliminary implementation, to be extended with other variants as well.
@vpratz vpratz added the feature New feature or request label Apr 13, 2025

codecov bot commented Apr 13, 2025

Codecov Report

Attention: Patch coverage is 53.65239% with 184 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| bayesflow/experimental/diffusion_model.py | 60.28% | 137 Missing ⚠️ |
| bayesflow/utils/integrate.py | 6.00% | 47 Missing ⚠️ |

| Files with missing lines | Coverage Δ |
|---|---|
| `bayesflow/experimental/__init__.py` | 100.00% <100.00%> (ø) |
| `bayesflow/utils/__init__.py` | 100.00% <100.00%> (ø) |
| `bayesflow/utils/optimal_transport/log_sinkhorn.py` | 90.47% <ø> (-9.53%) ⬇️ |
| `bayesflow/utils/integrate.py` | 39.89% <6.00%> (-12.19%) ⬇️ |
| `bayesflow/experimental/diffusion_model.py` | 60.28% <60.28%> (ø) |

... and 10 files with indirect coverage changes

@arrjon arrjon self-assigned this Apr 14, 2025
Collaborator

arrjon commented Apr 14, 2025

Thanks @vpratz for the implementation! sigma_data is already specific to the EDM variant. I think we have to make the interface more general, so that additional arguments can be passed depending on the type of schedule one wants to use.

I plan to add further schedules and samplers on top of this by the end of the week.

Collaborator Author

vpratz commented Apr 14, 2025

Thanks for taking a look. Do you know whether your implementation would benefit from the pre-conditioning discussed in "Elucidating the Design Space of Diffusion-Based Generative Models" [1], and whether we can combine them in one joint framework?

Collaborator

arrjon commented Apr 14, 2025

Part of the pre-conditioning can be expressed as a special kind of weighting function: see appendix D.1 in here.

So yes, the aim would be to have one nice framework!
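
For the EDM case, the link between pre-conditioning and loss weighting can be written out directly. The following is a sketch in the notation of [1] (not necessarily that of the appendix referenced above), with clean data $y$, noise $n \sim \mathcal{N}(0, \sigma^2 I)$, raw network $F_\theta$ and denoiser $D_\theta$:

```latex
\begin{aligned}
D_\theta(x; \sigma) &= c_\text{skip}(\sigma)\, x
  + c_\text{out}(\sigma)\, F_\theta\!\big(c_\text{in}(\sigma)\, x;\ c_\text{noise}(\sigma)\big), \\
\mathbb{E}\Big[\lambda(\sigma)\, \big\| D_\theta(y + n; \sigma) - y \big\|_2^2\Big]
  &= \mathbb{E}\Big[\lambda(\sigma)\, c_\text{out}(\sigma)^2\,
     \big\| F_\theta\!\big(c_\text{in}(\sigma)(y + n);\ c_\text{noise}(\sigma)\big)
     - \tfrac{1}{c_\text{out}(\sigma)}\big(y - c_\text{skip}(\sigma)(y + n)\big) \big\|_2^2\Big].
\end{aligned}
```

The output scaling therefore enters the objective only through the effective weight $\lambda(\sigma)\, c_\text{out}(\sigma)^2$, which [1] sets to 1 by choosing $\lambda(\sigma) = 1 / c_\text{out}(\sigma)^2$.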

Collaborator

arrjon commented Apr 16, 2025

I added some more noise schedules and started to make the implementation more general. This is just a first draft, so that you, @vpratz, get an idea of how we could do it. We should then discuss it and how to move forward.

* fix optimal transport config (#429)

* run linter

* [skip-ci] bump version to 2.0.1
Base automatically changed from dev to main April 22, 2025 14:37
LarsKue and others added 3 commits April 22, 2025 14:48
Optimal transport hot fixes, consistency models test, and fast imports
Collaborator

arrjon commented Apr 23, 2025

I added a NoiseSchedule class and several schedules, so it should now be easy to add more schedules if necessary. Since EDM uses a specific sampling scheme for inference, that scheme is now also defined in the noise schedule. Therefore, we no longer have to specify the sampling step sizes explicitly.

The next step would be to add stochastic samplers as well.
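
To illustrate the idea of a schedule that also owns its sampling discretization, here is a rough sketch. It does not reproduce the NoiseSchedule class from this PR: the class names `NoiseScheduleSketch` and `EDMScheduleSketch` and the method `sigma_steps` are hypothetical. The step formula and the defaults (`sigma_min=0.002`, `sigma_max=80`, `rho=7`) are the ones given in [1].

```python
import math
from abc import ABC, abstractmethod


class NoiseScheduleSketch(ABC):
    """Hypothetical interface; not the NoiseSchedule class added in this PR."""

    @abstractmethod
    def sigma_steps(self, num_steps: int) -> list[float]:
        """Return the decreasing noise levels to visit during sampling."""


class EDMScheduleSketch(NoiseScheduleSketch):
    """EDM noise levels, spaced uniformly in sigma**(1/rho) as in [1]."""

    def __init__(self, sigma_min: float = 0.002, sigma_max: float = 80.0, rho: float = 7.0):
        if not 0 < sigma_min < sigma_max:
            raise ValueError("require 0 < sigma_min < sigma_max")
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

    def sigma_steps(self, num_steps: int) -> list[float]:
        if num_steps < 2:
            raise ValueError("num_steps must be at least 2")
        hi = self.sigma_max ** (1.0 / self.rho)
        lo = self.sigma_min ** (1.0 / self.rho)
        # interpolate linearly in sigma**(1/rho), then map back to sigma
        return [
            (hi + i / (num_steps - 1) * (lo - hi)) ** self.rho
            for i in range(num_steps)
        ]
```

Because the schedule returns the noise levels itself, a sampler only needs the number of steps, for example `EDMScheduleSketch().sigma_steps(18)`, rather than user-supplied step sizes.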

@vpratz vpratz changed the base branch from main to dev April 24, 2025 07:17
Collaborator

arrjon commented Apr 24, 2025

I added a simple stochastic solver: Euler-Maruyama.
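
As a reminder of what the Euler-Maruyama scheme does, here is a minimal scalar sketch for an SDE of the form dx = f(x, t) dt + g(x, t) dW. It is not the solver added in this PR: the function name `euler_maruyama` and its signature are hypothetical, and it works on plain floats rather than tensors.

```python
import math
import random


def euler_maruyama(drift, diffusion, x0, t0, t1, num_steps, rng=None):
    """Integrate dx = drift(x, t) dt + diffusion(x, t) dW from t0 to t1.

    Hypothetical scalar sketch; t1 < t0 integrates backwards in time.
    """
    if num_steps < 1:
        raise ValueError("num_steps must be at least 1")
    if rng is None:
        rng = random.Random()
    dt = (t1 - t0) / num_steps
    # Brownian increments have standard deviation sqrt(|dt|)
    sqrt_dt = math.sqrt(abs(dt))
    x, t = x0, t0
    for _ in range(num_steps):
        noise = rng.gauss(0.0, 1.0)
        x = x + drift(x, t) * dt + diffusion(x, t) * sqrt_dt * noise
        t = t + dt
    return x
```

With the diffusion term set to zero, the update reduces to the ordinary explicit Euler method, which is a convenient way to check the deterministic part of such a solver.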
