diff --git a/alphafold3_pytorch/trainer.py b/alphafold3_pytorch/trainer.py index 21e9b4bc..9a9676e5 100644 --- a/alphafold3_pytorch/trainer.py +++ b/alphafold3_pytorch/trainer.py @@ -40,6 +40,7 @@ from lion_pytorch.foreach import Lion from adam_atan2_pytorch.foreach import AdamAtan2 +from adam_atan2_pytorch.adopt_atan2 import AdoptAtan2 from ema_pytorch import EMA @@ -180,6 +181,7 @@ def __init__( ema_on_cpu = False, ema_update_model_with_ema_every: int | None = None, use_adam_atan2: bool = False, + use_adopt_atan2: bool = False, use_lion: bool = False, use_torch_compile: bool = False ): @@ -247,11 +249,14 @@ def __init__( if not exists(optimizer): optimizer_klass = Adam - assert at_most_one_of(use_adam_atan2, use_lion) + assert at_most_one_of(use_adam_atan2, use_adopt_atan2, use_lion) if use_adam_atan2: default_adam_kwargs.pop('eps', None) optimizer_klass = AdamAtan2 + elif use_adopt_atan2: + default_adam_kwargs.pop('eps', None) + optimizer_klass = AdoptAtan2 elif use_lion: default_adam_kwargs.pop('eps', None) optimizer_klass = Lion diff --git a/pyproject.toml b/pyproject.toml index f7ff43b8..2084cd59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "alphafold3-pytorch" -version = "0.6.7" +version = "0.6.8" description = "Alphafold 3 - Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }, @@ -24,7 +24,7 @@ classifiers=[ ] dependencies = [ - "adam-atan2-pytorch>=0.0.8", + "adam-atan2-pytorch>=0.1.12", "awscliv2>=2.3.1", "beartype", "biopython>=1.83",