Skip to content

Commit

Permalink
adopt Adopt
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 26, 2024
1 parent 308c470 commit 0791dfe
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
7 changes: 6 additions & 1 deletion alphafold3_pytorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

from lion_pytorch.foreach import Lion
from adam_atan2_pytorch.foreach import AdamAtan2
from adam_atan2_pytorch.adopt_atan2 import AdoptAtan2

from ema_pytorch import EMA

Expand Down Expand Up @@ -180,6 +181,7 @@ def __init__(
ema_on_cpu = False,
ema_update_model_with_ema_every: int | None = None,
use_adam_atan2: bool = False,
use_adopt_atan2: bool = False,
use_lion: bool = False,
use_torch_compile: bool = False
):
Expand Down Expand Up @@ -247,11 +249,14 @@ def __init__(
if not exists(optimizer):
optimizer_klass = Adam

assert at_most_one_of(use_adam_atan2, use_lion)
assert at_most_one_of(use_adam_atan2, use_adopt_atan2, use_lion)

if use_adam_atan2:
default_adam_kwargs.pop('eps', None)
optimizer_klass = AdamAtan2
elif use_adopt_atan2:
default_adam_kwargs.pop('eps', None)
optimizer_klass = AdoptAtan2
elif use_lion:
default_adam_kwargs.pop('eps', None)
optimizer_klass = Lion
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "alphafold3-pytorch"
version = "0.6.7"
version = "0.6.8"
description = "Alphafold 3 - Pytorch"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" },
Expand All @@ -24,7 +24,7 @@ classifiers=[
]

dependencies = [
"adam-atan2-pytorch>=0.0.8",
"adam-atan2-pytorch>=0.1.12",
"awscliv2>=2.3.1",
"beartype",
"biopython>=1.83",
Expand Down

0 comments on commit 0791dfe

Please # to comment.