Skip to content

Commit

Permalink
Make AlphaFold logging more consistent, logical, and brief
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 720097581
Change-Id: Ic6d489d742aaf4d3bbdb7fcfa39910f4ee682e35
  • Loading branch information
Augustin-Zidek authored and copybara-github committed Jan 27, 2025
1 parent a5c1185 commit 7acc88f
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 43 deletions.
69 changes: 29 additions & 40 deletions run_alphafold.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def predict_structure(
) -> Sequence[ResultsForSeed]:
"""Runs the full inference pipeline to predict structures for each seed."""

print(f'Featurising data for seeds {fold_input.rng_seeds}...')
print(f'Featurising data with {len(fold_input.rng_seeds)} seed(s)...')
featurisation_start_time = time.time()
ccd = chemical_components.cached_ccd(user_ccd=fold_input.user_ccd)
featurised_examples = featurisation.featurise_input(
Expand All @@ -407,28 +407,32 @@ def predict_structure(
conformer_max_iterations=conformer_max_iterations,
)
print(
f'Featurising data for seeds {fold_input.rng_seeds} took'
f'Featurising data with {len(fold_input.rng_seeds)} seed(s) took'
f' {time.time() - featurisation_start_time:.2f} seconds.'
)
print(
'Running model inference and extracting output structure samples with'
f' {len(fold_input.rng_seeds)} seed(s)...'
)
all_inference_start_time = time.time()
all_inference_results = []
for seed, example in zip(fold_input.rng_seeds, featurised_examples):
print(f'Running model inference for seed {seed}...')
print(f'Running model inference with seed {seed}...')
inference_start_time = time.time()
rng_key = jax.random.PRNGKey(seed)
result = model_runner.run_inference(example, rng_key)
print(
f'Running model inference for seed {seed} took'
f'Running model inference with seed {seed} took'
f' {time.time() - inference_start_time:.2f} seconds.'
)
print(f'Extracting output structures (one per sample) for seed {seed}...')
print(f'Extracting output structure samples with seed {seed}...')
extract_structures = time.time()
inference_results = model_runner.extract_structures(
batch=example, result=result, target_name=fold_input.name
)
print(
f'Extracting output structures (one per sample) for seed {seed} took'
f' {time.time() - extract_structures:.2f} seconds.'
f'Extracting {len(inference_results)} output structure samples with'
f' seed {seed} took {time.time() - extract_structures:.2f} seconds.'
)

embeddings = model_runner.extract_embeddings(result)
Expand All @@ -441,13 +445,9 @@ def predict_structure(
embeddings=embeddings,
)
)
print(
'Running model inference and extracting output structures for seed'
f' {seed} took {time.time() - inference_start_time:.2f} seconds.'
)
print(
'Running model inference and extracting output structures for seeds'
f' {fold_input.rng_seeds} took'
'Running model inference and extracting output structures with'
f' {len(fold_input.rng_seeds)} seed(s) took'
f' {time.time() - all_inference_start_time:.2f} seconds.'
)
return all_inference_results
Expand Down Expand Up @@ -585,7 +585,7 @@ def process_fold_input(
Raises:
ValueError: If the fold input has no chains.
"""
print(f'Processing fold input {fold_input.name}')
print(f'\nRunning fold job {fold_input.name}...')

if not fold_input.chains:
raise ValueError('Fold input has no chains.')
Expand All @@ -595,51 +595,43 @@ def process_fold_input(
f'{output_dir}_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}'
)
print(
f'Output directory {output_dir} exists and is non-empty, using instead'
f' {new_output_dir}.'
f'Output will be written in {new_output_dir} since {output_dir} is'
' non-empty.'
)
output_dir = new_output_dir

if model_runner is not None:
# If we're running inference, check we can load the model parameters before
# (possibly) launching the data pipeline.
print('Checking we can load the model parameters...')
_ = model_runner.model_params
else:
print(f'Output will be written in {output_dir}')

if data_pipeline_config is None:
print('Skipping data pipeline...')
else:
print('Running data pipeline...')
fold_input = pipeline.DataPipeline(data_pipeline_config).process(fold_input)

print(f'Output directory: {output_dir}')
write_fold_input_json(fold_input, output_dir)
if model_runner is None:
print('Skipping inference...')
print('Skipping model inference...')
output = fold_input
else:
print(
f'Predicting 3D structure for {fold_input.name} for seed(s)'
f' {fold_input.rng_seeds}...'
f'Predicting 3D structure for {fold_input.name} with'
f' {len(fold_input.rng_seeds)} seed(s)...'
)
all_inference_results = predict_structure(
fold_input=fold_input,
model_runner=model_runner,
buckets=buckets,
conformer_max_iterations=conformer_max_iterations,
)
print(
f'Writing outputs for {fold_input.name} for seed(s)'
f' {fold_input.rng_seeds}...'
)
print(f'Writing outputs with {len(fold_input.rng_seeds)} seed(s)...')
write_outputs(
all_inference_results=all_inference_results,
output_dir=output_dir,
job_name=fold_input.sanitised_name(),
)
output = all_inference_results

print(f'Done processing fold input {fold_input.name}.')
print(f'Fold job {fold_input.name} done.\n')
return output


Expand Down Expand Up @@ -768,17 +760,16 @@ def main(_):
device=devices[_GPU_DEVICE.value],
model_dir=pathlib.Path(MODEL_DIR.value),
)
# Check we can load the model parameters before launching anything.
print('Checking that model parameters can be loaded...')
_ = model_runner.model_params
else:
print('Skipping running model inference.')
model_runner = None

print('Processing fold inputs.')
num_fold_inputs = 0
for fold_input in fold_inputs:
if _NUM_SEEDS.value is not None:
print(
f'Expanding fold input {fold_input.name} to {_NUM_SEEDS.value} seeds'
)
print(f'Expanding fold job {fold_input.name} to {_NUM_SEEDS.value} seeds')
fold_input = fold_input.with_multiple_seeds(_NUM_SEEDS.value)
process_fold_input(
fold_input=fold_input,
Expand All @@ -790,11 +781,9 @@ def main(_):
)
num_fold_inputs += 1

print(f'Done processing {num_fold_inputs} fold inputs.')
print(f'Done running {num_fold_inputs} fold jobs.')


if __name__ == '__main__':
flags.mark_flags_as_required([
'output_dir',
])
flags.mark_flags_as_required(['output_dir'])
app.run(main)
6 changes: 3 additions & 3 deletions src/alphafold3/data/featurisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def featurise_input(
for rng_seed in fold_input.rng_seeds:
featurisation_start_time = time.time()
if verbose:
print(f'Featurising {fold_input.name} with rng_seed {rng_seed}.')
print(f'Featurising data with seed {rng_seed}.')
batch = data_pipeline.process_item(
fold_input=fold_input,
ccd=ccd,
Expand All @@ -86,8 +86,8 @@ def featurise_input(
)
if verbose:
print(
f'Featurising {fold_input.name} with rng_seed {rng_seed} '
f'took {time.time() - featurisation_start_time:.2f} seconds.'
f'Featurising data with seed {rng_seed} took'
f' {time.time() - featurisation_start_time:.2f} seconds.'
)
batches.append(batch)

Expand Down

0 comments on commit 7acc88f

Please # to comment.