Skip to content

Commit

Permalink
verifier needed for inference time scaling
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 22, 2025
1 parent 8617eaf commit a36b554
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion alphafold3_pytorch/alphafold3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2835,6 +2835,7 @@ def sample(
use_tqdm_pbar = True,
tqdm_pbar_title = 'sampling time step',
return_all_timesteps = False,
verifier: Module | None = None,
**network_condition_kwargs
) -> Float['b m 3'] | Float['ts b m 3']:

Expand Down Expand Up @@ -6770,6 +6771,7 @@ def forward(
num_recycling_steps: int = 1,
diffusion_add_bond_loss: bool = False,
diffusion_add_smooth_lddt_loss: bool = False,
diffusion_verifier: Module | None = None,
distogram_atom_indices: Int['b n'] | None = None,
molecule_atom_indices: Int['b n'] | None = None, # the 'token centre atoms' mentioned in the paper, unsure where it is used in the architecture
num_sample_steps: int | None = None,
Expand Down Expand Up @@ -7187,7 +7189,8 @@ def forward(
pairwise_trunk = pairwise,
pairwise_rel_pos_feats = relative_position_encoding,
molecule_atom_lens = molecule_atom_lens,
return_all_timesteps = return_all_diffused_atom_pos
return_all_timesteps = return_all_diffused_atom_pos,
verifier = diffusion_verifier
)

if exists(atom_mask):
Expand Down

0 comments on commit a36b554

Please # to comment.