diff --git a/run_alphafold.py b/run_alphafold.py index 53ec054..86d7ff1 100644 --- a/run_alphafold.py +++ b/run_alphafold.py @@ -396,7 +396,7 @@ def predict_structure( ) -> Sequence[ResultsForSeed]: """Runs the full inference pipeline to predict structures for each seed.""" - print(f'Featurising data for seeds {fold_input.rng_seeds}...') + print(f'Featurising data with {len(fold_input.rng_seeds)} seed(s)...') featurisation_start_time = time.time() ccd = chemical_components.cached_ccd(user_ccd=fold_input.user_ccd) featurised_examples = featurisation.featurise_input( @@ -407,28 +407,32 @@ def predict_structure( conformer_max_iterations=conformer_max_iterations, ) print( - f'Featurising data for seeds {fold_input.rng_seeds} took' + f'Featurising data with {len(fold_input.rng_seeds)} seed(s) took' f' {time.time() - featurisation_start_time:.2f} seconds.' ) + print( + 'Running model inference and extracting output structure samples with' + f' {len(fold_input.rng_seeds)} seed(s)...' + ) all_inference_start_time = time.time() all_inference_results = [] for seed, example in zip(fold_input.rng_seeds, featurised_examples): - print(f'Running model inference for seed {seed}...') + print(f'Running model inference with seed {seed}...') inference_start_time = time.time() rng_key = jax.random.PRNGKey(seed) result = model_runner.run_inference(example, rng_key) print( - f'Running model inference for seed {seed} took' + f'Running model inference with seed {seed} took' f' {time.time() - inference_start_time:.2f} seconds.' ) - print(f'Extracting output structures (one per sample) for seed {seed}...') + print(f'Extracting output structure samples with seed {seed}...') extract_structures = time.time() inference_results = model_runner.extract_structures( batch=example, result=result, target_name=fold_input.name ) print( - f'Extracting output structures (one per sample) for seed {seed} took' - f' {time.time() - extract_structures:.2f} seconds.' + f'Extracting {len(inference_results)} output structure samples with' + f' seed {seed} took {time.time() - extract_structures:.2f} seconds.' ) embeddings = model_runner.extract_embeddings(result) @@ -441,13 +445,9 @@ def predict_structure( embeddings=embeddings, ) ) - print( - 'Running model inference and extracting output structures for seed' - f' {seed} took {time.time() - inference_start_time:.2f} seconds.' - ) print( - 'Running model inference and extracting output structures for seeds' - f' {fold_input.rng_seeds} took' + 'Running model inference and extracting output structures with' + f' {len(fold_input.rng_seeds)} seed(s) took' f' {time.time() - all_inference_start_time:.2f} seconds.' ) return all_inference_results @@ -585,7 +585,7 @@ def process_fold_input( Raises: ValueError: If the fold input has no chains. """ - print(f'Processing fold input {fold_input.name}') + print(f'\nRunning fold job {fold_input.name}...') if not fold_input.chains: raise ValueError('Fold input has no chains.') @@ -595,16 +595,12 @@ def process_fold_input( f'{output_dir}_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}' ) print( - f'Output directory {output_dir} exists and is non-empty, using instead' - f' {new_output_dir}.' + f'Output will be written in {new_output_dir} since {output_dir} is' + ' non-empty.' ) output_dir = new_output_dir - - if model_runner is not None: - # If we're running inference, check we can load the model parameters before - # (possibly) launching the data pipeline. - print('Checking we can load the model parameters...') - _ = model_runner.model_params + else: + print(f'Output will be written in {output_dir}') if data_pipeline_config is None: print('Skipping data pipeline...') @@ -612,15 +608,14 @@ def process_fold_input( print('Running data pipeline...') fold_input = pipeline.DataPipeline(data_pipeline_config).process(fold_input) - print(f'Output directory: {output_dir}') write_fold_input_json(fold_input, output_dir) if model_runner is None: - print('Skipping inference...') + print('Skipping model inference...') output = fold_input else: print( - f'Predicting 3D structure for {fold_input.name} for seed(s)' - f' {fold_input.rng_seeds}...' + f'Predicting 3D structure for {fold_input.name} with' + f' {len(fold_input.rng_seeds)} seed(s)...' ) all_inference_results = predict_structure( fold_input=fold_input, @@ -628,10 +623,7 @@ def process_fold_input( buckets=buckets, conformer_max_iterations=conformer_max_iterations, ) - print( - f'Writing outputs for {fold_input.name} for seed(s)' - f' {fold_input.rng_seeds}...' - ) + print(f'Writing outputs with {len(fold_input.rng_seeds)} seed(s)...') write_outputs( all_inference_results=all_inference_results, output_dir=output_dir, @@ -639,7 +631,7 @@ def process_fold_input( ) output = all_inference_results - print(f'Done processing fold input {fold_input.name}.') + print(f'Fold job {fold_input.name} done.\n') return output @@ -768,17 +760,16 @@ def main(_): device=devices[_GPU_DEVICE.value], model_dir=pathlib.Path(MODEL_DIR.value), ) + # Check we can load the model parameters before launching anything. + print('Checking that model parameters can be loaded...') + _ = model_runner.model_params else: - print('Skipping running model inference.') model_runner = None - print('Processing fold inputs.') num_fold_inputs = 0 for fold_input in fold_inputs: if _NUM_SEEDS.value is not None: - print( - f'Expanding fold input {fold_input.name} to {_NUM_SEEDS.value} seeds' - ) + print(f'Expanding fold job {fold_input.name} to {_NUM_SEEDS.value} seeds') fold_input = fold_input.with_multiple_seeds(_NUM_SEEDS.value) process_fold_input( fold_input=fold_input, @@ -790,11 +781,9 @@ def main(_): ) num_fold_inputs += 1 - print(f'Done processing {num_fold_inputs} fold inputs.') + print(f'Done running {num_fold_inputs} fold jobs.') if __name__ == '__main__': - flags.mark_flags_as_required([ - 'output_dir', - ]) + flags.mark_flags_as_required(['output_dir']) app.run(main) diff --git a/src/alphafold3/data/featurisation.py b/src/alphafold3/data/featurisation.py index a5d41fa..1ea5258 100644 --- a/src/alphafold3/data/featurisation.py +++ b/src/alphafold3/data/featurisation.py @@ -77,7 +77,7 @@ def featurise_input( for rng_seed in fold_input.rng_seeds: featurisation_start_time = time.time() if verbose: - print(f'Featurising {fold_input.name} with rng_seed {rng_seed}.') + print(f'Featurising data with seed {rng_seed}.') batch = data_pipeline.process_item( fold_input=fold_input, ccd=ccd, @@ -86,8 +86,8 @@ def featurise_input( ) if verbose: print( - f'Featurising {fold_input.name} with rng_seed {rng_seed} ' - f'took {time.time() - featurisation_start_time:.2f} seconds.' + f'Featurising data with seed {rng_seed} took' + f' {time.time() - featurisation_start_time:.2f} seconds.' ) batches.append(batch)