diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2773300caa3a..2a42432a7d98 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: "v0.2.1"
+    rev: v0.9.2
     hooks:
       - id: ruff
         language_version: python3
@@ -16,7 +16,7 @@ repos:
     hooks:
       - id: vermin
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.9.0
+    rev: v1.14.1
     hooks:
       - id: mypy
         additional_dependencies:
diff --git a/docs/myproject/tasks.py b/docs/myproject/tasks.py
index 40df99086bde..595403f33d4f 100644
--- a/docs/myproject/tasks.py
+++ b/docs/myproject/tasks.py
@@ -2,5 +2,4 @@


 @task
-def my_background_task(name: str):
-    ...
+def my_background_task(name: str): ...
diff --git a/flows/check_output_of_interrupted_serve.py b/flows/check_output_of_interrupted_serve.py
index 90212be7af55..8bd85877b6b8 100644
--- a/flows/check_output_of_interrupted_serve.py
+++ b/flows/check_output_of_interrupted_serve.py
@@ -35,14 +35,14 @@ async def main():

     # Check if each expected message is in the corresponding line
     for expected in expected_messages:
-        assert (
-            expected in stderr_output
-        ), f"Expected '{expected}' not found in '{stderr_output}'"
+        assert expected in stderr_output, (
+            f"Expected '{expected}' not found in '{stderr_output}'"
+        )

     for unexpected in unexpected_messages:
-        assert (
-            unexpected not in stderr_output
-        ), f"Unexpected '{unexpected}' found in '{stderr_output}'"
+        assert unexpected not in stderr_output, (
+            f"Unexpected '{unexpected}' found in '{stderr_output}'"
+        )

     print("All expected log messages were found")

diff --git a/flows/load_flows_concurrently.py b/flows/load_flows_concurrently.py
index 1c1925e994c5..849cc130e50f 100644
--- a/flows/load_flows_concurrently.py
+++ b/flows/load_flows_concurrently.py
@@ -28,9 +28,9 @@ async def run_stress_test():
     for i in range(10):  # Run 10 iterations
         try:
             count = await test_iteration()
-            print(f"Iteration {i+1}: Successfully loaded {count} flows")
+            print(f"Iteration {i + 1}: Successfully loaded {count} flows")
         except Exception as e:
-            print(f"Iteration {i+1}: Failed with error: {str(e)}")
+            print(f"Iteration {i + 1}: Failed with error: {str(e)}")
             return False
     return True

diff --git a/flows/serve_a_flow.py b/flows/serve_a_flow.py
index c11a0e2f1980..5f40707a0c65 100644
--- a/flows/serve_a_flow.py
+++ b/flows/serve_a_flow.py
@@ -48,8 +48,8 @@ def count_runs(counter_dir: Path):

     actual_run_count = count_runs(counter_dir)

-    assert (
-        actual_run_count >= MINIMUM_EXPECTED_N_FLOW_RUNS
-    ), f"Expected at least {MINIMUM_EXPECTED_N_FLOW_RUNS} flow runs, got {actual_run_count}"
+    assert actual_run_count >= MINIMUM_EXPECTED_N_FLOW_RUNS, (
+        f"Expected at least {MINIMUM_EXPECTED_N_FLOW_RUNS} flow runs, got {actual_run_count}"
+    )

     print(f"Successfully completed and audited {actual_run_count} flow runs")
diff --git a/flows/worker.py b/flows/worker.py
index 9c8e109a7a0d..954f4f8f1b15 100644
--- a/flows/worker.py
+++ b/flows/worker.py
@@ -52,9 +52,9 @@ def main():
     except subprocess.CalledProcessError as e:
         # Check that the error message contains kubernetes worker type
         for type in ["process", "kubernetes"]:
-            assert type in str(
-                e.output
-            ), f"Worker type {type!r} missing from output {e.output}"
+            assert type in str(e.output), (
+                f"Worker type {type!r} missing from output {e.output}"
+            )

     subprocess.check_call(
         ["prefect", "work-pool", "create", "test-worker-pool", "-t", "kubernetes"],
@@ -87,9 +87,9 @@ def main():
     )

     worker_events = [e for e in events if e.event.startswith("prefect.worker.")]
-    assert (
-        len(worker_events) == 2
-    ), f"Expected 2 worker events, got {len(worker_events)}"
+    assert len(worker_events) == 2, (
+        f"Expected 2 worker events, got {len(worker_events)}"
+    )

     start_events = [e for e in worker_events if e.event == "prefect.worker.started"]
     stop_events = [e for e in worker_events if e.event == "prefect.worker.stopped"]
@@ -99,9 +99,9 @@ def main():

     print("Captured expected worker start and stop events!")

-    assert (
-        stop_events[0].follows == start_events[0].id
-    ), "Stop event should follow start event"
+    assert stop_events[0].follows == start_events[0].id, (
+        "Stop event should follow start event"
+    )


 if __name__ == "__main__":
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 72dfc684e4ee..cde265278b55 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,10 +1,8 @@
-ruff
 cairosvg
 codespell>=2.2.6
 ipython
 jinja2
 moto >= 5
-mypy >= 1.9.0
 numpy
 pillow
 pre-commit
@@ -23,11 +21,17 @@ redis>=5.0.1
 setuptools
 uv>=0.4.5
 vale
-vermin
 virtualenv
 watchfiles
 respx

+# Linters and dev tools that are also in .pre-commit-config.yaml, which
+# should usually be updated together.
+codespell==2.2.6
+ruff==0.9.2
+mypy==1.14.1
+vermin==1.6.0
+
 # type stubs
 types-cachetools
 types-PyYAML
diff --git a/scripts/generate-lower-bounds.py b/scripts/generate-lower-bounds.py
index b2e113324d7e..4b97fbd9d2eb 100755
--- a/scripts/generate-lower-bounds.py
+++ b/scripts/generate-lower-bounds.py
@@ -27,6 +27,7 @@
     pip install $(generate-lower-bounds.py | tr "\n" " ")

 """
+
 import re
 import sys

diff --git a/scripts/generate_settings_schema.py b/scripts/generate_settings_schema.py
index cb366fb1df3b..e4cad0deb60e 100644
--- a/scripts/generate_settings_schema.py
+++ b/scripts/generate_settings_schema.py
@@ -11,9 +11,9 @@ def generate(self, schema, mode="validation"):
         json_schema = super().generate(schema, mode=mode)
         json_schema["title"] = "Prefect Settings"
         json_schema["$schema"] = self.schema_dialect
-        json_schema[
-            "$id"
-        ] = "https://github.com/PrefectHQ/prefect/schemas/settings.schema.json"
+        json_schema["$id"] = (
+            "https://github.com/PrefectHQ/prefect/schemas/settings.schema.json"
+        )
         return json_schema


diff --git a/scripts/wait-for-server.py b/scripts/wait-for-server.py
index cb08f116711f..23b415ffca8b 100755
--- a/scripts/wait-for-server.py
+++ b/scripts/wait-for-server.py
@@ -13,7 +13,6 @@
     PREFECT_API_URL="http://localhost:4200" ./scripts/wait-for-server.py
 """
-
 import sys

 import anyio

diff --git a/setup.cfg b/setup.cfg
index f22b738e3aef..9b50a4e5f7e4 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -79,6 +79,7 @@ plugins=
 ignore_missing_imports = True
 follow_imports = skip
+python_version = 3.9

 [mypy-ruamel]
 ignore_missing_imports = True

diff --git a/src/integrations/prefect-aws/prefect_aws/secrets_manager.py b/src/integrations/prefect-aws/prefect_aws/secrets_manager.py
index c13404b3246f..b21e7256836f 100644
--- a/src/integrations/prefect-aws/prefect_aws/secrets_manager.py
+++ b/src/integrations/prefect-aws/prefect_aws/secrets_manager.py
@@ -335,9 +335,9 @@ def example_delete_secret_with_recovery_window():
     delete_secret_kwargs: Dict[str, Union[str, int, bool]] = dict(SecretId=secret_name)

     if force_delete_without_recovery:
-        delete_secret_kwargs[
-            "ForceDeleteWithoutRecovery"
-        ] = force_delete_without_recovery
+        delete_secret_kwargs["ForceDeleteWithoutRecovery"] = (
+            force_delete_without_recovery
+        )
     else:
         delete_secret_kwargs["RecoveryWindowInDays"] = recovery_window_in_days

diff --git a/src/integrations/prefect-aws/prefect_aws/workers/ecs_worker.py b/src/integrations/prefect-aws/prefect_aws/workers/ecs_worker.py
index 54cf62958350..338e3015aad7 100644
--- a/src/integrations/prefect-aws/prefect_aws/workers/ecs_worker.py
+++ b/src/integrations/prefect-aws/prefect_aws/workers/ecs_worker.py
@@ -902,9 +902,9 @@ def _watch_task_and_get_exit_code(

         # Check the status code of the Prefect container
         container = _get_container(task["containers"], container_name)
-        assert (
-            container is not None
-        ), f"'{container_name}' container missing from task: {task}"
+        assert container is not None, (
+            f"'{container_name}' container missing from task: {task}"
+        )

         status_code = container.get("exitCode")
         self._report_container_status_code(logger, container_name, status_code)
@@ -1552,12 +1552,12 @@ def _prepare_task_run_request(
             and configuration.network_configuration
             and configuration.vpc_id
         ):
-            task_run_request[
-                "networkConfiguration"
-            ] = self._custom_network_configuration(
-                configuration.vpc_id,
-                configuration.network_configuration,
-                configuration,
+            task_run_request["networkConfiguration"] = (
+                self._custom_network_configuration(
+                    configuration.vpc_id,
+                    configuration.network_configuration,
+                    configuration,
+                )
             )

         # Ensure the container name is set if not provided at template time
diff --git a/src/integrations/prefect-aws/tests/test_client_parameters.py b/src/integrations/prefect-aws/tests/test_client_parameters.py
index 6435c81e7d9f..690b7aecc2b8 100644
--- a/src/integrations/prefect-aws/tests/test_client_parameters.py
+++ b/src/integrations/prefect-aws/tests/test_client_parameters.py
@@ -131,9 +131,9 @@ def test_get_params_override_with_both_cert_path(self, tmp_path):
     def test_get_params_override_with_default_verify(self):
         params = AwsClientParameters()
         override_params = params.get_params_override()
-        assert (
-            "verify" not in override_params
-        ), "verify should not be in params_override when not explicitly set"
+        assert "verify" not in override_params, (
+            "verify should not be in params_override when not explicitly set"
+        )

     def test_get_params_override_with_explicit_verify(self):
         params_true = AwsClientParameters(verify=True)
@@ -142,12 +142,12 @@ def test_get_params_override_with_explicit_verify(self):

         override_params_true = params_true.get_params_override()
         override_params_false = params_false.get_params_override()

-        assert (
-            "verify" in override_params_true
-        ), "verify should be in params_override when explicitly set to True"
+        assert "verify" in override_params_true, (
+            "verify should be in params_override when explicitly set to True"
+        )
         assert override_params_true["verify"] is True
-        assert (
-            "verify" in override_params_false
-        ), "verify should be in params_override when explicitly set to False"
+        assert "verify" in override_params_false, (
+            "verify should be in params_override when explicitly set to False"
+        )
         assert override_params_false["verify"] is False
diff --git a/src/integrations/prefect-aws/tests/test_credentials.py b/src/integrations/prefect-aws/tests/test_credentials.py
index 32e6c1c2a812..7c2043c74158 100644
--- a/src/integrations/prefect-aws/tests/test_credentials.py
+++ b/src/integrations/prefect-aws/tests/test_credentials.py
@@ -97,9 +97,9 @@ def test_aws_credentials_change_causes_cache_miss(client_type):

     new_client = credentials.get_client(client_type)

-    assert (
-        initial_client is not new_client
-    ), "Client should be different after configuration change"
+    assert initial_client is not new_client, (
+        "Client should be different after configuration change"
+    )

     assert _get_client_cached.cache_info().misses == 2, "Cache should miss twice"

@@ -125,9 +125,9 @@ def test_minio_credentials_change_causes_cache_miss(client_type):

     new_client = credentials.get_client(client_type)

-    assert (
-        initial_client is not new_client
-    ), "Client should be different after configuration change"
+    assert initial_client is not new_client, (
+        "Client should be different after configuration change"
+    )

     assert _get_client_cached.cache_info().misses == 2, "Cache should miss twice"

diff --git a/src/integrations/prefect-aws/tests/test_s3.py b/src/integrations/prefect-aws/tests/test_s3.py
index 69121d21a807..3fd2ad632aef 100644
--- a/src/integrations/prefect-aws/tests/test_s3.py
+++ b/src/integrations/prefect-aws/tests/test_s3.py
@@ -1138,9 +1138,9 @@ def test_round_trip_default_credentials(self):
         # https://github.com/PrefectHQ/prefect/issues/13349
         S3Bucket(bucket_name="round-trip-bucket").save("round-tripper")
         loaded = S3Bucket.load("round-tripper")
-        assert hasattr(
-            loaded.credentials, "aws_access_key_id"
-        ), "`credentials` were not properly initialized"
+        assert hasattr(loaded.credentials, "aws_access_key_id"), (
+            "`credentials` were not properly initialized"
+        )

     @pytest.mark.parametrize(
         "client_parameters",
diff --git a/src/integrations/prefect-aws/tests/test_utilities.py b/src/integrations/prefect-aws/tests/test_utilities.py
index cdff07b8458a..9cabe2e1f29c 100644
--- a/src/integrations/prefect-aws/tests/test_utilities.py
+++ b/src/integrations/prefect-aws/tests/test_utilities.py
@@ -9,15 +9,15 @@ class TestHashCollection:

     def test_simple_dict(self):
         simple_dict = {"key1": "value1", "key2": "value2"}
-        assert hash_collection(simple_dict) == hash_collection(
-            simple_dict
-        ), "Simple dictionary hashing failed"
+        assert hash_collection(simple_dict) == hash_collection(simple_dict), (
+            "Simple dictionary hashing failed"
+        )

     def test_nested_dict(self):
         nested_dict = {"key1": {"subkey1": "subvalue1"}, "key2": "value2"}
-        assert hash_collection(nested_dict) == hash_collection(
-            nested_dict
-        ), "Nested dictionary hashing failed"
+        assert hash_collection(nested_dict) == hash_collection(nested_dict), (
+            "Nested dictionary hashing failed"
+        )

     def test_complex_structure(self):
         complex_structure = {
@@ -57,9 +57,9 @@ def test_existing_path(self):
         doc = {"key1": {"subkey1": "value1"}}
         path = ["key1", "subkey1"]
         ensure_path_exists(doc, path)
-        assert doc == {
-            "key1": {"subkey1": "value1"}
-        }, "Existing path modification failed"
+        assert doc == {"key1": {"subkey1": "value1"}}, (
+            "Existing path modification failed"
+        )

     def test_new_path_object(self):
         doc = {}
@@ -77,14 +77,14 @@ def test_existing_path_array(self):
         doc = {"key1": [{"subkey1": "value1"}]}
         path = ["key1", "0", "subkey1"]
         ensure_path_exists(doc, path)
-        assert doc == {
-            "key1": [{"subkey1": "value1"}]
-        }, "Existing path modification for array failed"
+        assert doc == {"key1": [{"subkey1": "value1"}]}, (
+            "Existing path modification for array failed"
+        )

     def test_existing_path_array_index_out_of_range(self):
         doc = {"key1": []}
         path = ["key1", "0", "subkey1"]
         ensure_path_exists(doc, path)
-        assert doc == {
-            "key1": [{"subkey1": {}}]
-        }, "Existing path modification for array index out of range failed"
+        assert doc == {"key1": [{"subkey1": {}}]}, (
+            "Existing path modification for array index out of range failed"
+        )
diff --git a/src/integrations/prefect-aws/tests/workers/test_ecs_worker.py b/src/integrations/prefect-aws/tests/workers/test_ecs_worker.py
index 7958590f0416..452dcc1efb9b 100644
--- a/src/integrations/prefect-aws/tests/workers/test_ecs_worker.py
+++ b/src/integrations/prefect-aws/tests/workers/test_ecs_worker.py
@@ -643,9 +643,9 @@ async def test_task_definition_arn(aws_credentials: AwsCredentials, flow_run: Fl

     task = describe_task(ecs_client, task_arn)
     print(task)
-    assert (
-        task["taskDefinitionArn"] == task_definition_arn
-    ), "The task definition should be used without registering a new one"
+    assert task["taskDefinitionArn"] == task_definition_arn, (
+        "The task definition should be used without registering a new one"
+    )


 @pytest.mark.usefixtures("ecs_mocks")
@@ -684,9 +684,9 @@ async def test_task_definition_arn_with_variables_that_are_ignored(
     _, task_arn = parse_identifier(result.identifier)
     task = describe_task(ecs_client, task_arn)

-    assert (
-        task["taskDefinitionArn"] == task_definition_arn
-    ), "A new task definition should not be registered"
+    assert task["taskDefinitionArn"] == task_definition_arn, (
+        "A new task definition should not be registered"
+    )

     # TODO: Add logging for this case
     # assert (
@@ -721,9 +721,9 @@ async def test_environment_variables(
     prefect_container_definition = _get_container(
         task_definition["containerDefinitions"], ECS_DEFAULT_CONTAINER_NAME
     )
-    assert not prefect_container_definition[
-        "environment"
-    ], "Variables should not be passed until runtime"
+    assert not prefect_container_definition["environment"], (
+        "Variables should not be passed until runtime"
+    )

     prefect_container_overrides = _get_container(
         task["overrides"]["containerOverrides"], ECS_DEFAULT_CONTAINER_NAME
@@ -828,9 +828,9 @@ async def test_slugified_labels(

     # Check if the slugified tags are as expected
     for key, value in expected_tags.items():
-        assert (
-            actual_tags.get(key) == value
-        ), f"Failed for key: {key} with expected value: {value}, but got {actual_tags.get(key)}"
+        assert actual_tags.get(key) == value, (
+            f"Failed for key: {key} with expected value: {value}, but got {actual_tags.get(key)}"
+        )


 @pytest.mark.usefixtures("ecs_mocks")
@@ -1690,9 +1690,9 @@ async def test_worker_cache_miss_for_registered_task_definitions_clears_from_cac
     task_2 = describe_task(ecs_client, task_arn_2)

     assert task_1["taskDefinitionArn"] != task_2["taskDefinitionArn"]
-    assert (
-        task_1["taskDefinitionArn"] not in _TASK_DEFINITION_CACHE.values()
-    ), _TASK_DEFINITION_CACHE
+    assert task_1["taskDefinitionArn"] not in _TASK_DEFINITION_CACHE.values(), (
+        _TASK_DEFINITION_CACHE
+    )


 @pytest.mark.usefixtures("ecs_mocks")
@@ -1852,9 +1852,9 @@ async def test_worker_task_definition_cache_hit_on_config_changes(
     _, task_arn_2 = parse_identifier(result_2.identifier)
     task_2 = describe_task(ecs_client, task_arn_2)

-    assert (
-        task_1["taskDefinitionArn"] == task_2["taskDefinitionArn"]
-    ), "The existing task definition should be used"
+    assert task_1["taskDefinitionArn"] == task_2["taskDefinitionArn"], (
+        "The existing task definition should be used"
+    )


 @pytest.mark.usefixtures("ecs_mocks")
@@ -1973,12 +1973,12 @@ async def test_user_defined_container_in_task_definition_template(
     default_container_overrides = _get_container(
         container_overrides, ECS_DEFAULT_CONTAINER_NAME
     )
-    assert (
-        user_container_overrides
-    ), "The user defined container should be included in overrides"
-    assert (
-        default_container_overrides is None
-    ), "The default container should not be in overrides"
+    assert user_container_overrides, (
+        "The user defined container should be included in overrides"
+    )
+    assert default_container_overrides is None, (
+        "The default container should not be in overrides"
+    )


 @pytest.mark.usefixtures("ecs_mocks")
@@ -2016,9 +2016,9 @@ async def test_user_defined_container_image_in_task_definition_template(
     prefect_container = _get_container(
         task_definition["containerDefinitions"], ECS_DEFAULT_CONTAINER_NAME
     )
-    assert (
-        prefect_container["image"] == "use-this-image"
-    ), "The image from the task definition should be used"
+    assert prefect_container["image"] == "use-this-image", (
+        "The image from the task definition should be used"
+    )


 @pytest.mark.usefixtures("ecs_mocks")
@@ -2206,9 +2206,9 @@ async def test_user_defined_environment_variables_in_task_run_request_template(
         task_definition["containerDefinitions"], ECS_DEFAULT_CONTAINER_NAME
     )

-    assert (
-        prefect_container_definition["environment"] == []
-    ), "No environment variables in the task definition"
+    assert prefect_container_definition["environment"] == [], (
+        "No environment variables in the task definition"
+    )

     prefect_container_overrides = _get_container(
         task["overrides"]["containerOverrides"], ECS_DEFAULT_CONTAINER_NAME
diff --git a/src/integrations/prefect-azure/tests/test_aci_worker.py b/src/integrations/prefect-azure/tests/test_aci_worker.py
index ff9bf1fe8476..b4bfedb404d9 100644
--- a/src/integrations/prefect-azure/tests/test_aci_worker.py
+++ b/src/integrations/prefect-azure/tests/test_aci_worker.py
@@ -1061,6 +1061,6 @@ async def test_consistent_container_group_naming(
     name_without_prefix_and_id = container_group_name[8:-37]

     assert name_without_prefix_and_id.replace("-", " ").lower() in flow_name.lower()
-    assert (
-        len(container_group_name) <= max_length
-    ), f"Length: {len(container_group_name)}, Max: {max_length}"
+    assert len(container_group_name) <= max_length, (
+        f"Length: {len(container_group_name)}, Max: {max_length}"
+    )
diff --git a/src/integrations/prefect-dask/prefect_dask/task_runners.py b/src/integrations/prefect-dask/prefect_dask/task_runners.py
index 60066faaf870..3778acf0764a 100644
--- a/src/integrations/prefect-dask/prefect_dask/task_runners.py
+++ b/src/integrations/prefect-dask/prefect_dask/task_runners.py
@@ -321,8 +321,7 @@ def submit(
         parameters: Dict[str, Any],
         wait_for: Optional[Iterable[PrefectDaskFuture[R]]] = None,
         dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
-    ) -> PrefectDaskFuture[R]:
-        ...
+    ) -> PrefectDaskFuture[R]: ...

     @overload
     def submit(
@@ -331,8 +330,7 @@ def submit(
         parameters: Dict[str, Any],
         wait_for: Optional[Iterable[PrefectDaskFuture[R]]] = None,
         dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
-    ) -> PrefectDaskFuture[R]:
-        ...
+    ) -> PrefectDaskFuture[R]: ...

     def submit(
         self,
@@ -367,8 +365,7 @@ def map(
         task: "Task[P, Coroutine[Any, Any, R]]",
         parameters: Dict[str, Any],
         wait_for: Optional[Iterable[PrefectFuture]] = None,
-    ) -> PrefectFutureList[PrefectDaskFuture[R]]:
-        ...
+    ) -> PrefectFutureList[PrefectDaskFuture[R]]: ...

     @overload
     def map(
@@ -376,8 +373,7 @@ def map(
         self,
         task: "Task[Any, R]",
         parameters: Dict[str, Any],
         wait_for: Optional[Iterable[PrefectFuture]] = None,
-    ) -> PrefectFutureList[PrefectDaskFuture[R]]:
-        ...
+    ) -> PrefectFutureList[PrefectDaskFuture[R]]: ...
def map( self, diff --git a/src/integrations/prefect-databricks/prefect_databricks/flows.py b/src/integrations/prefect-databricks/prefect_databricks/flows.py index f49a4098abff..4ffa3bfe99f2 100644 --- a/src/integrations/prefect-databricks/prefect_databricks/flows.py +++ b/src/integrations/prefect-databricks/prefect_databricks/flows.py @@ -57,8 +57,7 @@ class DatabricksJobRunTimedOut(Exception): @flow( name="Submit jobs runs and wait for completion", description=( - "Triggers a Databricks jobs runs and waits for the " - "triggered runs to complete." + "Triggers a Databricks jobs runs and waits for the triggered runs to complete." ), ) async def jobs_runs_submit_and_wait_for_completion( @@ -493,8 +492,7 @@ def jobs_runs_wait_for_completion_flow(): @flow( name="Submit existing job runs and wait for completion", description=( - "Triggers a Databricks jobs runs and waits for the " - "triggered runs to complete." + "Triggers a Databricks jobs runs and waits for the triggered runs to complete." ), ) async def jobs_runs_submit_by_id_and_wait_for_completion( diff --git a/src/integrations/prefect-databricks/tests/test_flows.py b/src/integrations/prefect-databricks/tests/test_flows.py index b121e1960b13..3851cbc74b55 100644 --- a/src/integrations/prefect-databricks/tests/test_flows.py +++ b/src/integrations/prefect-databricks/tests/test_flows.py @@ -277,8 +277,7 @@ async def test_run_skipped( ) match = re.escape( # escape to handle the parentheses - "Databricks Jobs Runs Submit (prefect-job ID 36108) " - "was skipped: testing." + "Databricks Jobs Runs Submit (prefect-job ID 36108) was skipped: testing." ) with pytest.raises(DatabricksJobSkipped, match=match): await jobs_runs_submit_and_wait_for_completion( diff --git a/src/integrations/prefect-dbt/prefect_dbt/cloud/jobs.py b/src/integrations/prefect-dbt/prefect_dbt/cloud/jobs.py index b493cc1e3bd9..f56bdddb7b63 100644 --- a/src/integrations/prefect-dbt/prefect_dbt/cloud/jobs.py +++ b/src/integrations/prefect-dbt/prefect_dbt/cloud/jobs.py @@ -689,8 +689,7 @@ async def _wait_until_state( elapsed_time_seconds = time.time() - start_time if elapsed_time_seconds > timeout_seconds: raise DbtCloudJobRunTimedOut( - f"Max wait time of {timeout_seconds} " - "seconds exceeded while waiting" + f"Max wait time of {timeout_seconds} seconds exceeded while waiting" ) await asyncio.sleep(interval_seconds) @@ -752,7 +751,9 @@ async def fetch_result(self, step: Optional[int] = None) -> Dict[str, Any]: run_status = DbtCloudJobRunStatus(run_data.get("status")) if run_status == DbtCloudJobRunStatus.SUCCESS: try: - async with self._dbt_cloud_credentials.get_administrative_client() as client: # noqa + async with ( + self._dbt_cloud_credentials.get_administrative_client() as client + ): # noqa response = await client.list_run_artifacts( run_id=self.run_id, step=step ) @@ -1127,8 +1128,7 @@ def run_dbt_cloud_job_flow(): return result except DbtCloudJobRunFailed: logger.info( - f"Retrying job run with ID: {run.run_id} " - f"{targeted_retries} more times" + f"Retrying job run with ID: {run.run_id} {targeted_retries} more times" ) run = await task(run.retry_failed_steps.aio)(run) targeted_retries -= 1 diff --git a/src/integrations/prefect-dbt/prefect_dbt/utilities.py b/src/integrations/prefect-dbt/prefect_dbt/utilities.py index 9430e869ef0a..9811e7f019ec 100644 --- a/src/integrations/prefect-dbt/prefect_dbt/utilities.py +++ b/src/integrations/prefect-dbt/prefect_dbt/utilities.py @@ -1,6 +1,7 @@ """ Utility functions for prefect-dbt """ + import os from typing import Any, Dict, 
Optional diff --git a/src/integrations/prefect-gcp/prefect_gcp/workers/cloud_run_v2.py b/src/integrations/prefect-gcp/prefect_gcp/workers/cloud_run_v2.py index 348ad84fa35a..48377de4e146 100644 --- a/src/integrations/prefect-gcp/prefect_gcp/workers/cloud_run_v2.py +++ b/src/integrations/prefect-gcp/prefect_gcp/workers/cloud_run_v2.py @@ -242,9 +242,9 @@ def _populate_image_if_not_present(self): Populates the job body with the image if not present. """ if "image" not in self.job_body["template"]["template"]["containers"][0]: - self.job_body["template"]["template"]["containers"][0][ - "image" - ] = f"docker.io/{get_prefect_image_name()}" + self.job_body["template"]["template"]["containers"][0]["image"] = ( + f"docker.io/{get_prefect_image_name()}" + ) def _populate_or_format_command(self): """ @@ -253,13 +253,13 @@ def _populate_or_format_command(self): command = self.job_body["template"]["template"]["containers"][0].get("command") if command is None: - self.job_body["template"]["template"]["containers"][0][ - "command" - ] = shlex.split(self._base_flow_run_command()) + self.job_body["template"]["template"]["containers"][0]["command"] = ( + shlex.split(self._base_flow_run_command()) + ) elif isinstance(command, str): - self.job_body["template"]["template"]["containers"][0][ - "command" - ] = shlex.split(command) + self.job_body["template"]["template"]["containers"][0]["command"] = ( + shlex.split(command) + ) def _format_args_if_present(self): """ @@ -268,9 +268,9 @@ def _format_args_if_present(self): args = self.job_body["template"]["template"]["containers"][0].get("args") if args is not None and isinstance(args, str): - self.job_body["template"]["template"]["containers"][0][ - "args" - ] = shlex.split(args) + self.job_body["template"]["template"]["containers"][0]["args"] = ( + shlex.split(args) + ) def _remove_vpc_access_if_unset(self): """ @@ -751,8 +751,7 @@ def _watch_job_execution_and_get_result( ) except Exception as exc: logger.critical( - f"Encountered an exception while waiting for job run completion - " - f"{exc}" + f"Encountered an exception while waiting for job run completion - {exc}" ) raise diff --git a/src/integrations/prefect-gcp/prefect_gcp/workers/vertex.py b/src/integrations/prefect-gcp/prefect_gcp/workers/vertex.py index 4f92c7a195c1..645abf5a4ab2 100644 --- a/src/integrations/prefect-gcp/prefect_gcp/workers/vertex.py +++ b/src/integrations/prefect-gcp/prefect_gcp/workers/vertex.py @@ -525,8 +525,7 @@ async def _create_and_begin_job( ) logger.info( - f"Job {job_name!r} created. " - f"The full job name is {custom_job_run.name!r}" + f"Job {job_name!r} created. 
The full job name is {custom_job_run.name!r}" ) return custom_job_run diff --git a/src/integrations/prefect-gcp/tests/test_cloud_run_worker_v2.py b/src/integrations/prefect-gcp/tests/test_cloud_run_worker_v2.py index 5b422df5c6be..5d115fbe6ca1 100644 --- a/src/integrations/prefect-gcp/tests/test_cloud_run_worker_v2.py +++ b/src/integrations/prefect-gcp/tests/test_cloud_run_worker_v2.py @@ -171,9 +171,9 @@ def test_format_args_if_present(self, cloud_run_worker_v2_job_config): def test_remove_vpc_access_if_connector_unset( self, cloud_run_worker_v2_job_config, vpc_access ): - cloud_run_worker_v2_job_config.job_body["template"]["template"][ - "vpcAccess" - ] = vpc_access + cloud_run_worker_v2_job_config.job_body["template"]["template"]["vpcAccess"] = ( + vpc_access + ) cloud_run_worker_v2_job_config._remove_vpc_access_if_unset() diff --git a/src/integrations/prefect-gcp/tests/test_credentials.py b/src/integrations/prefect-gcp/tests/test_credentials.py index 66770d967f58..fd35be01c1d1 100644 --- a/src/integrations/prefect-gcp/tests/test_credentials.py +++ b/src/integrations/prefect-gcp/tests/test_credentials.py @@ -187,9 +187,9 @@ async def test_get_job_service_async_client_cached( project=project, ) - assert ( - _get_job_service_async_client_cached.cache_info().hits == 0 - ), "Initial call count should be 0" + assert _get_job_service_async_client_cached.cache_info().hits == 0, ( + "Initial call count should be 0" + ) credentials.get_job_service_async_client(client_options={}) credentials.get_job_service_async_client(client_options={}) @@ -214,9 +214,9 @@ async def test_get_job_service_async_client_cached_from_file( project=project, ) - assert ( - _get_job_service_async_client_cached.cache_info().hits == 0 - ), "Initial call count should be 0" + assert _get_job_service_async_client_cached.cache_info().hits == 0, ( + "Initial call count should be 0" + ) credentials.get_job_service_async_client(client_options={}) credentials.get_job_service_async_client(client_options={}) diff --git a/src/integrations/prefect-gcp/tests/test_vertex_worker.py b/src/integrations/prefect-gcp/tests/test_vertex_worker.py index 3d299b81ecc8..ff16762ca1d4 100644 --- a/src/integrations/prefect-gcp/tests/test_vertex_worker.py +++ b/src/integrations/prefect-gcp/tests/test_vertex_worker.py @@ -124,9 +124,9 @@ def test_valid_command_formatting( "worker_pool_specs" ][0]["container_spec"]["command"] - job_config.job_spec["worker_pool_specs"][0]["container_spec"][ - "command" - ] = "echo -n hello" + job_config.job_spec["worker_pool_specs"][0]["container_spec"]["command"] = ( + "echo -n hello" + ) job_config.prepare_for_flow_run(flow_run, None, None) assert ["echo", "-n", "hello"] == job_config.job_spec["worker_pool_specs"][0][ "container_spec" diff --git a/src/integrations/prefect-kubernetes/prefect_kubernetes/events.py b/src/integrations/prefect-kubernetes/prefect_kubernetes/events.py index 48ee0d01d50d..43db19cc7f99 100644 --- a/src/integrations/prefect-kubernetes/prefect_kubernetes/events.py +++ b/src/integrations/prefect-kubernetes/prefect_kubernetes/events.py @@ -108,9 +108,9 @@ async def _emit_pod_event( for container_status in pod.status.container_statuses: if container_status.state.terminated.reason in EVICTED_REASONS: pod_phase = "evicted" - resource[ - "kubernetes.reason" - ] = container_status.state.terminated.reason + resource["kubernetes.reason"] = ( + container_status.state.terminated.reason + ) break return emit_event( diff --git a/src/integrations/prefect-kubernetes/prefect_kubernetes/jobs.py 
b/src/integrations/prefect-kubernetes/prefect_kubernetes/jobs.py index 5650ca551f7b..2cc63eb5798b 100644 --- a/src/integrations/prefect-kubernetes/prefect_kubernetes/jobs.py +++ b/src/integrations/prefect-kubernetes/prefect_kubernetes/jobs.py @@ -386,7 +386,7 @@ async def _cleanup(self): namespace=self._kubernetes_job.namespace, **self._kubernetes_job.api_kwargs, ) - self.logger.info(f"Job {job_name} deleted " f"with {deleted_v1_job.status!r}.") + self.logger.info(f"Job {job_name} deleted with {deleted_v1_job.status!r}.") @sync_compatible async def wait_for_completion(self, print_func: Optional[Callable] = None): diff --git a/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py b/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py index 69157e39465d..5f50bc7d5d59 100644 --- a/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py +++ b/src/integrations/prefect-kubernetes/prefect_kubernetes/worker.py @@ -404,9 +404,9 @@ def _populate_env_in_manifest(self): # a list of dicts. Might be able to improve this in the future with a better # default `env` value and better typing. else: - self.job_manifest["spec"]["template"]["spec"]["containers"][0][ - "env" - ] = transformed_env + self.job_manifest["spec"]["template"]["spec"]["containers"][0]["env"] = ( + transformed_env + ) def _update_prefect_api_url_if_local_server(self): """If the API URL has been set by the base environment rather than the by the @@ -704,9 +704,9 @@ async def _replace_api_key_with_secret( ) # Store configuration so that we can delete the secret when the worker shuts # down - self._created_secrets[ - (secret.metadata.name, secret.metadata.namespace) - ] = configuration + self._created_secrets[(secret.metadata.name, secret.metadata.namespace)] = ( + configuration + ) if secret_name: new_api_env_entry = { "name": "PREFECT_API_KEY", diff --git a/src/integrations/prefect-ray/prefect_ray/task_runners.py b/src/integrations/prefect-ray/prefect_ray/task_runners.py index d8bef2ac5dcc..145bf42945d0 100644 --- a/src/integrations/prefect-ray/prefect_ray/task_runners.py +++ b/src/integrations/prefect-ray/prefect_ray/task_runners.py @@ -70,6 +70,7 @@ def count_to(highest_number): #9 ``` """ + from __future__ import annotations import asyncio # noqa: I001 @@ -249,8 +250,7 @@ def submit( parameters: dict[str, Any], wait_for: Iterable[PrefectFuture[Any]] | None = None, dependencies: dict[str, set[TaskRunInput]] | None = None, - ) -> PrefectRayFuture[R]: - ... + ) -> PrefectRayFuture[R]: ... @overload def submit( @@ -259,8 +259,7 @@ def submit( parameters: dict[str, Any], wait_for: Iterable[PrefectFuture[Any]] | None = None, dependencies: dict[str, set[TaskRunInput]] | None = None, - ) -> PrefectRayFuture[R]: - ... + ) -> PrefectRayFuture[R]: ... def submit( self, @@ -307,8 +306,7 @@ def map( task: "Task[P, Coroutine[Any, Any, R]]", parameters: dict[str, Any], wait_for: Iterable[PrefectFuture[Any]] | None = None, - ) -> PrefectFutureList[PrefectRayFuture[R]]: - ... + ) -> PrefectFutureList[PrefectRayFuture[R]]: ... @overload def map( @@ -316,8 +314,7 @@ def map( task: "Task[Any, R]", parameters: dict[str, Any], wait_for: Iterable[PrefectFuture[Any]] | None = None, - ) -> PrefectFutureList[PrefectRayFuture[R]]: - ... + ) -> PrefectFutureList[PrefectRayFuture[R]]: ... 
def map( self, diff --git a/src/integrations/prefect-redis/tests/test_tasks.py b/src/integrations/prefect-redis/tests/test_tasks.py index d417af996a04..02567155e22a 100644 --- a/src/integrations/prefect-redis/tests/test_tasks.py +++ b/src/integrations/prefect-redis/tests/test_tasks.py @@ -122,7 +122,7 @@ async def test_set_obj(redis_credentials: RedisDatabase, random_key: str): await redis_set.fn(redis_credentials, random_key, ref_obj, ex=60) test_value = await redis_get.fn(redis_credentials, random_key) - assert type(ref_obj) == type(test_value) + assert type(ref_obj) is type(test_value) assert len(ref_obj) == len(test_value) assert ref_obj[0] == test_value[0] diff --git a/src/integrations/prefect-shell/prefect_shell/commands.py b/src/integrations/prefect-shell/prefect_shell/commands.py index 50360595730b..e40f7e7939e9 100644 --- a/src/integrations/prefect-shell/prefect_shell/commands.py +++ b/src/integrations/prefect-shell/prefect_shell/commands.py @@ -109,9 +109,7 @@ def example_shell_run_command_flow(): ) if not stderr and lines: stderr = f"{lines[-1]}\n" - msg = ( - f"Command failed with exit code {process.returncode}:\n" f"{stderr}" - ) + msg = f"Command failed with exit code {process.returncode}:\n{stderr}" raise RuntimeError(msg) finally: if os.path.exists(tmp.name): diff --git a/src/integrations/prefect-snowflake/tests/test_credentials.py b/src/integrations/prefect-snowflake/tests/test_credentials.py index 28e660f67b47..b1541b0707e8 100644 --- a/src/integrations/prefect-snowflake/tests/test_credentials.py +++ b/src/integrations/prefect-snowflake/tests/test_credentials.py @@ -136,9 +136,9 @@ def test_snowflake_credentials_validate_private_key_password( def test_snowflake_credentials_validate_private_key_passphrase( private_credentials_params, ): - private_credentials_params[ - "private_key_passphrase" - ] = private_credentials_params.pop("password") + private_credentials_params["private_key_passphrase"] = ( + private_credentials_params.pop("password") + ) credentials_params_missing = private_credentials_params.copy() password = credentials_params_missing.pop("private_key_passphrase") private_key = credentials_params_missing.pop("private_key") @@ -155,9 +155,9 @@ def test_snowflake_credentials_validate_private_key_path( private_key_path = tmp_path / "private_key.pem" private_key_path.write_bytes(private_credentials_params.pop("private_key")) private_credentials_params["private_key_path"] = private_key_path - private_credentials_params[ - "private_key_passphrase" - ] = private_credentials_params.pop("password") + private_credentials_params["private_key_passphrase"] = ( + private_credentials_params.pop("password") + ) credentials = SnowflakeCredentials(**private_credentials_params) assert credentials.resolve_private_key() is not None diff --git a/src/prefect/_internal/compatibility/migration.py b/src/prefect/_internal/compatibility/migration.py index 0478be7f4ca5..ebd7d647d4c4 100644 --- a/src/prefect/_internal/compatibility/migration.py +++ b/src/prefect/_internal/compatibility/migration.py @@ -34,7 +34,7 @@ ```python # at top from prefect._internal.compatibility.migration import getattr_migration - + # at bottom __getattr__ = getattr_migration(__name__) ``` diff --git a/src/prefect/_internal/concurrency/calls.py b/src/prefect/_internal/concurrency/calls.py index 4a715ac90491..15de9506dc18 100644 --- a/src/prefect/_internal/concurrency/calls.py +++ b/src/prefect/_internal/concurrency/calls.py @@ -155,8 +155,7 @@ def cancel(self) -> bool: if TYPE_CHECKING: - def __get_result(self) -> 
T: - ... + def __get_result(self) -> T: ... def result(self, timeout: Optional[float] = None) -> T: """Return the result of the call that the future represents. diff --git a/src/prefect/_internal/concurrency/cancellation.py b/src/prefect/_internal/concurrency/cancellation.py index 57564ba4c9f5..443aac08f63c 100644 --- a/src/prefect/_internal/concurrency/cancellation.py +++ b/src/prefect/_internal/concurrency/cancellation.py @@ -473,13 +473,11 @@ def cancel(self, throw: bool = True): @overload -def get_deadline(timeout: float) -> float: - ... +def get_deadline(timeout: float) -> float: ... @overload -def get_deadline(timeout: None) -> None: - ... +def get_deadline(timeout: None) -> None: ... def get_deadline(timeout: Optional[float]) -> Optional[float]: diff --git a/src/prefect/_internal/concurrency/threads.py b/src/prefect/_internal/concurrency/threads.py index b94e6127b2ad..af32f9b9ddca 100644 --- a/src/prefect/_internal/concurrency/threads.py +++ b/src/prefect/_internal/concurrency/threads.py @@ -135,9 +135,9 @@ def __init__( self.thread = threading.Thread( name=name, daemon=daemon, target=self._entrypoint ) - self._ready_future: concurrent.futures.Future[ - bool - ] = concurrent.futures.Future() + self._ready_future: concurrent.futures.Future[bool] = ( + concurrent.futures.Future() + ) self._loop: Optional[asyncio.AbstractEventLoop] = None self._shutdown_event: Event = Event() self._run_once: bool = run_once diff --git a/src/prefect/_internal/schemas/validators.py b/src/prefect/_internal/schemas/validators.py index 879774fb297e..dadfb79dae4d 100644 --- a/src/prefect/_internal/schemas/validators.py +++ b/src/prefect/_internal/schemas/validators.py @@ -40,13 +40,15 @@ @overload -def raise_on_name_alphanumeric_dashes_only(value: str, field_name: str = ...) -> str: - ... +def raise_on_name_alphanumeric_dashes_only( + value: str, field_name: str = ... +) -> str: ... @overload -def raise_on_name_alphanumeric_dashes_only(value: None, field_name: str = ...) -> None: - ... +def raise_on_name_alphanumeric_dashes_only( + value: None, field_name: str = ... +) -> None: ... def raise_on_name_alphanumeric_dashes_only( @@ -64,15 +66,13 @@ def raise_on_name_alphanumeric_dashes_only( @overload def raise_on_name_alphanumeric_underscores_only( value: str, field_name: str = ... -) -> str: - ... +) -> str: ... @overload def raise_on_name_alphanumeric_underscores_only( value: None, field_name: str = ... -) -> None: - ... +) -> None: ... def raise_on_name_alphanumeric_underscores_only( @@ -149,13 +149,13 @@ def validate_parameters_conform_to_schema( @overload -def validate_parameter_openapi_schema(schema: M, values: Mapping[str, Any]) -> M: - ... +def validate_parameter_openapi_schema(schema: M, values: Mapping[str, Any]) -> M: ... @overload -def validate_parameter_openapi_schema(schema: None, values: Mapping[str, Any]) -> None: - ... +def validate_parameter_openapi_schema( + schema: None, values: Mapping[str, Any] +) -> None: ... def validate_parameter_openapi_schema( @@ -198,13 +198,11 @@ def reconcile_schedules_runner(values: MM) -> MM: @overload -def validate_schedule_max_scheduled_runs(v: int, limit: int) -> int: - ... +def validate_schedule_max_scheduled_runs(v: int, limit: int) -> int: ... @overload -def validate_schedule_max_scheduled_runs(v: None, limit: int) -> None: - ... +def validate_schedule_max_scheduled_runs(v: None, limit: int) -> None: ... 
def validate_schedule_max_scheduled_runs(v: Optional[int], limit: int) -> Optional[int]: @@ -260,15 +258,13 @@ def default_anchor_date(v: pendulum.DateTime) -> pendulum.DateTime: @overload -def default_timezone(v: str, values: Optional[Mapping[str, Any]] = ...) -> str: - ... +def default_timezone(v: str, values: Optional[Mapping[str, Any]] = ...) -> str: ... @overload def default_timezone( v: None, values: Optional[Mapping[str, Any]] = ... -) -> Optional[str]: - ... +) -> Optional[str]: ... def default_timezone( @@ -411,13 +407,11 @@ def validate_load_kwargs(value: M) -> M: @overload -def cast_type_names_to_serializers(value: str) -> "Serializer[Any]": - ... +def cast_type_names_to_serializers(value: str) -> "Serializer[Any]": ... @overload -def cast_type_names_to_serializers(value: "Serializer[T]") -> "Serializer[T]": - ... +def cast_type_names_to_serializers(value: "Serializer[T]") -> "Serializer[T]": ... def cast_type_names_to_serializers( @@ -457,13 +451,11 @@ def validate_compressionlib(value: str) -> str: # TODO: if we use this elsewhere we can change the error message to be more generic @overload -def list_length_50_or_less(v: list[float]) -> list[float]: - ... +def list_length_50_or_less(v: list[float]) -> list[float]: ... @overload -def list_length_50_or_less(v: None) -> None: - ... +def list_length_50_or_less(v: None) -> None: ... def list_length_50_or_less(v: Optional[list[float]]) -> Optional[list[float]]: @@ -474,13 +466,11 @@ def list_length_50_or_less(v: Optional[list[float]]) -> Optional[list[float]]: # TODO: if we use this elsewhere we can change the error message to be more generic @overload -def validate_not_negative(v: float) -> float: - ... +def validate_not_negative(v: float) -> float: ... @overload -def validate_not_negative(v: None) -> None: - ... +def validate_not_negative(v: None) -> None: ... def validate_not_negative(v: Optional[float]) -> Optional[float]: @@ -490,13 +480,11 @@ def validate_not_negative(v: Optional[float]) -> Optional[float]: @overload -def validate_message_template_variables(v: str) -> str: - ... +def validate_message_template_variables(v: str) -> str: ... @overload -def validate_message_template_variables(v: None) -> None: - ... +def validate_message_template_variables(v: None) -> None: ... def validate_message_template_variables(v: Optional[str]) -> Optional[str]: @@ -521,13 +509,11 @@ def validate_default_queue_id_not_none(v: Optional[UUID]) -> UUID: @overload -def validate_max_metadata_length(v: MM) -> MM: - ... +def validate_max_metadata_length(v: MM) -> MM: ... @overload -def validate_max_metadata_length(v: None) -> None: - ... +def validate_max_metadata_length(v: None) -> None: ... def validate_max_metadata_length(v: Optional[MM]) -> Optional[MM]: @@ -544,13 +530,11 @@ def validate_max_metadata_length(v: Optional[MM]) -> Optional[MM]: @overload -def validate_cache_key_length(cache_key: str) -> str: - ... +def validate_cache_key_length(cache_key: str) -> str: ... @overload -def validate_cache_key_length(cache_key: None) -> None: - ... +def validate_cache_key_length(cache_key: None) -> None: ... def validate_cache_key_length(cache_key: Optional[str]) -> Optional[str]: @@ -587,13 +571,11 @@ def set_run_policy_deprecated_fields(values: MM) -> MM: @overload -def return_v_or_none(v: str) -> str: - ... +def return_v_or_none(v: str) -> str: ... @overload -def return_v_or_none(v: None) -> None: - ... +def return_v_or_none(v: None) -> None: ... 
def return_v_or_none(v: Optional[str]) -> Optional[str]: @@ -629,13 +611,11 @@ def validate_name_present_on_nonanonymous_blocks(values: M) -> M: @overload -def validate_working_dir(v: str) -> Path: - ... +def validate_working_dir(v: str) -> Path: ... @overload -def validate_working_dir(v: None) -> None: - ... +def validate_working_dir(v: None) -> None: ... def validate_working_dir(v: Optional[Path | str]) -> Optional[Path]: @@ -652,13 +632,11 @@ def validate_working_dir(v: Optional[Path | str]) -> Optional[Path]: @overload -def validate_block_document_name(value: str) -> str: - ... +def validate_block_document_name(value: str) -> str: ... @overload -def validate_block_document_name(value: None) -> None: - ... +def validate_block_document_name(value: None) -> None: ... def validate_block_document_name(value: Optional[str]) -> Optional[str]: @@ -673,13 +651,11 @@ def validate_artifact_key(value: str) -> str: @overload -def validate_variable_name(value: str) -> str: - ... +def validate_variable_name(value: str) -> str: ... @overload -def validate_variable_name(value: None) -> None: - ... +def validate_variable_name(value: None) -> None: ... def validate_variable_name(value: Optional[str]) -> Optional[str]: diff --git a/src/prefect/agent.py b/src/prefect/agent.py index 7d5d3c1309c2..e0d5ba07c514 100644 --- a/src/prefect/agent.py +++ b/src/prefect/agent.py @@ -1,6 +1,7 @@ """ 2024-06-27: This surfaces an actionable error message for moved or removed objects in Prefect 3.0 upgrade. """ + from typing import Any, Callable from prefect._internal.compatibility.migration import getattr_migration diff --git a/src/prefect/automations.py b/src/prefect/automations.py index 6ec5192a4354..115491dab7c3 100644 --- a/src/prefect/automations.py +++ b/src/prefect/automations.py @@ -179,13 +179,11 @@ def update(self: Self): @overload @classmethod - async def aread(cls, id: UUID, name: Optional[str] = ...) -> Self: - ... + async def aread(cls, id: UUID, name: Optional[str] = ...) -> Self: ... @overload @classmethod - async def aread(cls, id: None = None, name: str = ...) -> Self: - ... + async def aread(cls, id: None = None, name: str = ...) -> Self: ... @classmethod async def aread(cls, id: Optional[UUID] = None, name: Optional[str] = None) -> Self: @@ -227,13 +225,11 @@ async def aread(cls, id: Optional[UUID] = None, name: Optional[str] = None) -> S @overload @classmethod - async def read(cls, id: UUID, name: Optional[str] = ...) -> Self: - ... + async def read(cls, id: UUID, name: Optional[str] = ...) -> Self: ... @overload @classmethod - async def read(cls, id: None = None, name: str = ...) -> Self: - ... + async def read(cls, id: None = None, name: str = ...) -> Self: ... @classmethod @async_dispatch(aread) diff --git a/src/prefect/blocks/notifications.py b/src/prefect/blocks/notifications.py index ae74bf05f05e..5dbda64f472c 100644 --- a/src/prefect/blocks/notifications.py +++ b/src/prefect/blocks/notifications.py @@ -22,14 +22,14 @@ class AbstractAppriseNotificationBlock(NotificationBlock, ABC): An abstract class for sending notifications using Apprise. """ - notify_type: Literal[ - "prefect_default", "info", "success", "warning", "failure" - ] = Field( - default=PREFECT_NOTIFY_TYPE_DEFAULT, - description=( - "The type of notification being performed; the prefect_default " - "is a plain notification that does not attach an image." 
- ), + notify_type: Literal["prefect_default", "info", "success", "warning", "failure"] = ( + Field( + default=PREFECT_NOTIFY_TYPE_DEFAULT, + description=( + "The type of notification being performed; the prefect_default " + "is a plain notification that does not attach an image." + ), + ) ) def __init__(self, *args: Any, **kwargs: Any): diff --git a/src/prefect/cli/_prompts.py b/src/prefect/cli/_prompts.py index fcd516b215b7..c22cfa55be9f 100644 --- a/src/prefect/cli/_prompts.py +++ b/src/prefect/cli/_prompts.py @@ -217,8 +217,7 @@ def prompt_select_from_table( table_kwargs: dict[str, Any] | None = None, opt_out_message: None = None, opt_out_response: Any = None, -) -> dict[str, T]: - ... +) -> dict[str, T]: ... @overload @@ -230,8 +229,7 @@ def prompt_select_from_table( table_kwargs: dict[str, Any] | None = None, opt_out_message: str = "", opt_out_response: Any = None, -) -> dict[str, T] | None: - ... +) -> dict[str, T] | None: ... def prompt_select_from_table( @@ -665,9 +663,9 @@ async def prompt_push_custom_docker_image( import prefect_docker credentials_block = prefect_docker.DockerRegistryCredentials - push_step[ - "credentials" - ] = "{{ prefect_docker.docker-registry-credentials.docker_registry_creds_name }}" + push_step["credentials"] = ( + "{{ prefect_docker.docker-registry-credentials.docker_registry_creds_name }}" + ) docker_registry_creds_name = f"deployment-{slugify(deployment_config['name'])}-{slugify(deployment_config['work_pool']['name'])}-registry-creds" create_new_block = False try: @@ -981,7 +979,7 @@ async def prompt_select_blob_storage_credentials( url = urls.url_for(new_block_document) if url: console.print( - "\nView/Edit your new credentials block in the UI:" f"\n[blue]{url}[/]\n", + f"\nView/Edit your new credentials block in the UI:\n[blue]{url}[/]\n", soft_wrap=True, ) return f"{{{{ prefect.blocks.{creds_block_type_slug}.{new_block_document.name} }}}}" diff --git a/src/prefect/cli/_utilities.py b/src/prefect/cli/_utilities.py index 2e09c86909ef..25c794ee733a 100644 --- a/src/prefect/cli/_utilities.py +++ b/src/prefect/cli/_utilities.py @@ -1,6 +1,7 @@ """ Utilities for Prefect CLI commands """ + from __future__ import annotations import functools diff --git a/src/prefect/cli/cloud/__init__.py b/src/prefect/cli/cloud/__init__.py index d106832d17f6..5c584b86bc79 100644 --- a/src/prefect/cli/cloud/__init__.py +++ b/src/prefect/cli/cloud/__init__.py @@ -1,6 +1,7 @@ """ Command line interface for interacting with Prefect Cloud """ + from __future__ import annotations import os @@ -170,15 +171,15 @@ def get_current_workspace(workspaces: Iterable[Workspace]) -> Workspace | None: @overload -def prompt_select_from_list(console: Console, prompt: str, options: list[str]) -> str: - ... +def prompt_select_from_list( + console: Console, prompt: str, options: list[str] +) -> str: ... @overload def prompt_select_from_list( console: Console, prompt: str, options: list[tuple[T, str]] -) -> T: - ... +) -> T: ... 
def prompt_select_from_list( diff --git a/src/prefect/cli/cloud/webhook.py b/src/prefect/cli/cloud/webhook.py index 734a2e7edf35..138a4c1aa413 100644 --- a/src/prefect/cli/cloud/webhook.py +++ b/src/prefect/cli/cloud/webhook.py @@ -104,7 +104,7 @@ async def create( "template": template, }, ) - app.console.print(f'Successfully created webhook {response["name"]}') + app.console.print(f"Successfully created webhook {response['name']}") @webhook_app.command() @@ -124,7 +124,7 @@ async def rotate(webhook_id: UUID): # The /webhooks API lives inside the /accounts/{id}/workspaces/{id} routing tree async with get_cloud_client(host=PREFECT_API_URL.value()) as client: response = await client.request("POST", f"/webhooks/{webhook_id}/rotate") - app.console.print(f'Successfully rotated webhook URL to {response["slug"]}') + app.console.print(f"Successfully rotated webhook URL to {response['slug']}") @webhook_app.command() diff --git a/src/prefect/client/base.py b/src/prefect/client/base.py index 7eed8a92a497..8b17d7594099 100644 --- a/src/prefect/client/base.py +++ b/src/prefect/client/base.py @@ -55,8 +55,7 @@ @runtime_checkable class ASGIApp(Protocol): - async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: - ... + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ... @asynccontextmanager @@ -301,9 +300,9 @@ async def _send_with_retry( ) await anyio.sleep(retry_seconds) - assert ( - response is not None - ), "Retry handling ended without response or exception" + assert response is not None, ( + "Retry handling ended without response or exception" + ) # We ran out of retries, return the failed response return response @@ -520,9 +519,9 @@ def _send_with_retry( ) time.sleep(retry_seconds) - assert ( - response is not None - ), "Retry handling ended without response or exception" + assert response is not None, ( + "Retry handling ended without response or exception" + ) # We ran out of retries, return the failed response return response diff --git a/src/prefect/client/collections.py b/src/prefect/client/collections.py index e5bd79f04325..02ae6c13b353 100644 --- a/src/prefect/client/collections.py +++ b/src/prefect/client/collections.py @@ -7,14 +7,11 @@ class CollectionsMetadataClient(Protocol): - async def read_worker_metadata(self) -> Dict[str, Any]: - ... + async def read_worker_metadata(self) -> Dict[str, Any]: ... - async def __aenter__(self) -> "CollectionsMetadataClient": - ... + async def __aenter__(self) -> "CollectionsMetadataClient": ... - async def __aexit__(self, *exc_info: Any) -> Any: - ... + async def __aexit__(self, *exc_info: Any) -> Any: ... def get_collections_metadata_client( diff --git a/src/prefect/client/orchestration/__init__.py b/src/prefect/client/orchestration/__init__.py index 5735f01e06d9..54d00099bda9 100644 --- a/src/prefect/client/orchestration/__init__.py +++ b/src/prefect/client/orchestration/__init__.py @@ -157,15 +157,13 @@ def get_client( *, httpx_settings: Optional[dict[str, Any]] = ..., sync_client: Literal[False] = False, -) -> "PrefectClient": - ... +) -> "PrefectClient": ... @overload def get_client( *, httpx_settings: Optional[dict[str, Any]] = ..., sync_client: Literal[True] = ... -) -> "SyncPrefectClient": - ... +) -> "SyncPrefectClient": ... 
def get_client( diff --git a/src/prefect/client/schemas/objects.py b/src/prefect/client/schemas/objects.py index a54a9f470659..bd6aa31c5f38 100644 --- a/src/prefect/client/schemas/objects.py +++ b/src/prefect/client/schemas/objects.py @@ -221,8 +221,7 @@ def result( raise_on_failure: Literal[True] = ..., fetch: bool = ..., retry_result_failure: bool = ..., - ) -> R: - ... + ) -> R: ... @overload def result( @@ -230,8 +229,7 @@ def result( raise_on_failure: Literal[False] = False, fetch: bool = ..., retry_result_failure: bool = ..., - ) -> Union[R, Exception]: - ... + ) -> Union[R, Exception]: ... @overload def result( @@ -239,8 +237,7 @@ def result( raise_on_failure: bool = ..., fetch: bool = ..., retry_result_failure: bool = ..., - ) -> Union[R, Exception]: - ... + ) -> Union[R, Exception]: ... @deprecated.deprecated_parameter( "fetch", diff --git a/src/prefect/client/schemas/schedules.py b/src/prefect/client/schemas/schedules.py index 674957ad9d8b..cc26822a05e4 100644 --- a/src/prefect/client/schemas/schedules.py +++ b/src/prefect/client/schemas/schedules.py @@ -103,8 +103,7 @@ def __init__( Union[pendulum.DateTime, datetime.datetime, str] ] = None, timezone: Optional[str] = None, - ) -> None: - ... + ) -> None: ... class CronSchedule(PrefectBaseModel): diff --git a/src/prefect/context.py b/src/prefect/context.py index b2852501ed9a..0a24544b46de 100644 --- a/src/prefect/context.py +++ b/src/prefect/context.py @@ -123,8 +123,7 @@ class ContextModel(BaseModel): if TYPE_CHECKING: # subclasses can pass through keyword arguments to the pydantic base model - def __init__(self, **kwargs: Any) -> None: - ... + def __init__(self, **kwargs: Any) -> None: ... # The context variable for storing data must be defined by the child class __var__: ClassVar[ContextVar[Self]] diff --git a/src/prefect/events/schemas/automations.py b/src/prefect/events/schemas/automations.py index 2eeea40214c6..6376c2d90cb2 100644 --- a/src/prefect/events/schemas/automations.py +++ b/src/prefect/events/schemas/automations.py @@ -198,7 +198,9 @@ def enforce_minimum_within_for_proactive_triggers( "10 seconds" ) - return data | {"within": within} if within else data + if within: + data = {**data, "within": within} + return data def describe_for_cli(self, indent: int = 0) -> str: """Return a human-readable description of this trigger for the CLI""" @@ -248,7 +250,7 @@ class MetricTriggerQuery(PrefectBaseModel): threshold: float = Field( ..., description=( - "The threshold value against which we'll compare " "the query result." + "The threshold value against which we'll compare the query result." ), ) operator: MetricTriggerOperator = Field( diff --git a/src/prefect/exceptions.py b/src/prefect/exceptions.py index e99e0e52477e..04b361e609c3 100644 --- a/src/prefect/exceptions.py +++ b/src/prefect/exceptions.py @@ -161,7 +161,7 @@ def __init__(self, msg: str): @classmethod def from_validation_error(cls, exc: ValidationError) -> Self: bad_params = [ - f'{".".join(str(item) for item in err["loc"])}: {err["msg"]}' + f"{'.'.join(str(item) for item in err['loc'])}: {err['msg']}" for err in exc.errors() ] msg = "Flow run received invalid parameters:\n - " + "\n - ".join(bad_params) diff --git a/src/prefect/flow_runs.py b/src/prefect/flow_runs.py index 526f0793d6b1..4e1ec2481f81 100644 --- a/src/prefect/flow_runs.py +++ b/src/prefect/flow_runs.py @@ -139,8 +139,7 @@ async def pause_flow_run( timeout: int = 3600, poll_interval: int = 10, key: Optional[str] = None, -) -> None: - ... +) -> None: ... 
@overload @@ -149,8 +148,7 @@ async def pause_flow_run( timeout: int = 3600, poll_interval: int = 10, key: Optional[str] = None, -) -> T: - ... +) -> T: ... @sync_compatible @@ -308,8 +306,7 @@ async def suspend_flow_run( timeout: Optional[int] = 3600, key: Optional[str] = None, client: Optional[PrefectClient] = None, -) -> None: - ... +) -> None: ... @overload @@ -319,8 +316,7 @@ async def suspend_flow_run( timeout: Optional[int] = 3600, key: Optional[str] = None, client: Optional[PrefectClient] = None, -) -> T: - ... +) -> T: ... @sync_compatible diff --git a/src/prefect/flows.py b/src/prefect/flows.py index 0b435ef2a8d2..78618f27f997 100644 --- a/src/prefect/flows.py +++ b/src/prefect/flows.py @@ -120,8 +120,7 @@ class FlowStateHook(Protocol, Generic[P, R]): def __call__( self, flow: Flow[P, R], flow_run: FlowRun, state: State - ) -> Awaitable[None] | None: - ... + ) -> Awaitable[None] | None: ... if TYPE_CHECKING: @@ -1535,16 +1534,14 @@ def __call__(self: "Flow[P, NoReturn]", *args: P.args, **kwargs: P.kwargs) -> No @overload def __call__( self: "Flow[P, Coroutine[Any, Any, T]]", *args: P.args, **kwargs: P.kwargs - ) -> Coroutine[Any, Any, T]: - ... + ) -> Coroutine[Any, Any, T]: ... @overload def __call__( self: "Flow[P, T]", *args: P.args, **kwargs: P.kwargs, - ) -> T: - ... + ) -> T: ... @overload def __call__( @@ -1552,8 +1549,7 @@ def __call__( *args: P.args, return_state: Literal[True], **kwargs: P.kwargs, - ) -> Awaitable[State[T]]: - ... + ) -> Awaitable[State[T]]: ... @overload def __call__( @@ -1561,8 +1557,7 @@ def __call__( *args: P.args, return_state: Literal[True], **kwargs: P.kwargs, - ) -> State[T]: - ... + ) -> State[T]: ... def __call__( self, @@ -1701,8 +1696,7 @@ async def visualize(self, *args: "P.args", **kwargs: "P.kwargs"): class FlowDecorator: @overload - def __call__(self, __fn: Callable[P, R]) -> Flow[P, R]: - ... + def __call__(self, __fn: Callable[P, R]) -> Flow[P, R]: ... @overload def __call__( @@ -1728,8 +1722,7 @@ def __call__( on_cancellation: Optional[list[FlowStateHook[..., Any]]] = None, on_crashed: Optional[list[FlowStateHook[..., Any]]] = None, on_running: Optional[list[FlowStateHook[..., Any]]] = None, - ) -> Callable[[Callable[P, R]], Flow[P, R]]: - ... + ) -> Callable[[Callable[P, R]], Flow[P, R]]: ... @overload def __call__( @@ -1755,8 +1748,7 @@ def __call__( on_cancellation: Optional[list[FlowStateHook[..., Any]]] = None, on_crashed: Optional[list[FlowStateHook[..., Any]]] = None, on_running: Optional[list[FlowStateHook[..., Any]]] = None, - ) -> Callable[[Callable[P, R]], Flow[P, R]]: - ... + ) -> Callable[[Callable[P, R]], Flow[P, R]]: ... def __call__( self, @@ -1950,8 +1942,7 @@ def __call__( def from_source( source: Union[str, "RunnerStorage", ReadableDeploymentStorage], entrypoint: str, - ) -> Union["Flow[..., Any]", Coroutine[Any, Any, "Flow[..., Any]"]]: - ... + ) -> Union["Flow[..., Any]", Coroutine[Any, Any, "Flow[..., Any]"]]: ... 
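The `Flow.__call__` overloads above drive the advertised return type off a `Literal[True]` keyword, so `return_state=True` callers are typed as receiving a `State` while everyone else gets the plain result. The technique, pared down to a hypothetical `run` helper (not the real Prefect signature):

    from typing import Literal, Union, overload

    @overload
    def run(x: int, return_state: Literal[True]) -> str: ...
    @overload
    def run(x: int, return_state: Literal[False] = False) -> int: ...
    def run(x: int, return_state: bool = False) -> Union[int, str]:
        # Type checkers pick the overload from the literal value of
        # return_state; the runtime implementation is shared.
        return f"Completed({x})" if return_state else x

    result = run(1)                    # checkers infer: int
    state = run(1, return_state=True)  # checkers infer: str
    print(result, state)
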
flow: FlowDecorator = FlowDecorator() @@ -2223,8 +2214,7 @@ def _display_serve_start_message(*args: "RunnerDeployment"): from rich.table import Table help_message_top = ( - "[green]Your deployments are being served and polling for" - " scheduled runs!\n[/]" + "[green]Your deployments are being served and polling for scheduled runs!\n[/]" ) table = Table(title="Deployments", show_header=False) diff --git a/src/prefect/infrastructure/__init__.py b/src/prefect/infrastructure/__init__.py index 7d5d3c1309c2..e0d5ba07c514 100644 --- a/src/prefect/infrastructure/__init__.py +++ b/src/prefect/infrastructure/__init__.py @@ -1,6 +1,7 @@ """ 2024-06-27: This surfaces an actionable error message for moved or removed objects in Prefect 3.0 upgrade. """ + from typing import Any, Callable from prefect._internal.compatibility.migration import getattr_migration diff --git a/src/prefect/infrastructure/base.py b/src/prefect/infrastructure/base.py index 7d5d3c1309c2..e0d5ba07c514 100644 --- a/src/prefect/infrastructure/base.py +++ b/src/prefect/infrastructure/base.py @@ -1,6 +1,7 @@ """ 2024-06-27: This surfaces an actionable error message for moved or removed objects in Prefect 3.0 upgrade. """ + from typing import Any, Callable from prefect._internal.compatibility.migration import getattr_migration diff --git a/src/prefect/infrastructure/provisioners/__init__.py b/src/prefect/infrastructure/provisioners/__init__.py index bb8270360841..df3c3c3d0051 100644 --- a/src/prefect/infrastructure/provisioners/__init__.py +++ b/src/prefect/infrastructure/provisioners/__init__.py @@ -22,20 +22,17 @@ class Provisioner(Protocol): @property - def console(self) -> rich.console.Console: - ... + def console(self) -> rich.console.Console: ... @console.setter - def console(self, value: rich.console.Console) -> None: - ... + def console(self, value: rich.console.Console) -> None: ... async def provision( self, work_pool_name: str, base_job_template: Dict[str, Any], client: Optional["PrefectClient"] = None, - ) -> Dict[str, Any]: - ... + ) -> Dict[str, Any]: ... def get_infrastructure_provisioner_for_work_pool_type( diff --git a/src/prefect/infrastructure/provisioners/coiled.py b/src/prefect/infrastructure/provisioners/coiled.py index f0adb5d7ce50..9e92ece2b4bf 100644 --- a/src/prefect/infrastructure/provisioners/coiled.py +++ b/src/prefect/infrastructure/provisioners/coiled.py @@ -118,9 +118,9 @@ async def _create_coiled_credentials_block( block_type_id=credentials_block_type.id ) ) - assert ( - credentials_block_schema is not None - ), f"Unable to find schema for block type {credentials_block_type.slug}" + assert credentials_block_schema is not None, ( + f"Unable to find schema for block type {credentials_block_type.slug}" + ) block_doc = await client.create_block_document( block_document=BlockDocumentCreate( diff --git a/src/prefect/infrastructure/provisioners/container_instance.py b/src/prefect/infrastructure/provisioners/container_instance.py index 4bbace63fa27..b364359912c5 100644 --- a/src/prefect/infrastructure/provisioners/container_instance.py +++ b/src/prefect/infrastructure/provisioners/container_instance.py @@ -10,6 +10,7 @@ ContainerInstancePushProvisioner: A class for provisioning infrastructure using Azure Container Instances. 
""" + from __future__ import annotations import json diff --git a/src/prefect/infrastructure/provisioners/ecs.py b/src/prefect/infrastructure/provisioners/ecs.py index 2f6cb1460f87..1f11ae314176 100644 --- a/src/prefect/infrastructure/provisioners/ecs.py +++ b/src/prefect/infrastructure/provisioners/ecs.py @@ -341,9 +341,9 @@ async def provision( block_type_id=credentials_block_type.id ) ) - assert ( - credentials_block_schema is not None - ), f"Unable to find schema for block type {credentials_block_type.slug}" + assert credentials_block_schema is not None, ( + f"Unable to find schema for block type {credentials_block_type.slug}" + ) block_doc = await client.create_block_document( block_document=BlockDocumentCreate( @@ -597,9 +597,9 @@ async def provision( ) advance() - base_job_template["variables"]["properties"]["cluster"][ - "default" - ] = self._cluster_name + base_job_template["variables"]["properties"]["cluster"]["default"] = ( + self._cluster_name + ) @property def next_steps(self) -> list[str]: diff --git a/src/prefect/infrastructure/provisioners/modal.py b/src/prefect/infrastructure/provisioners/modal.py index 04960d9c5d89..68e4a1506ee5 100644 --- a/src/prefect/infrastructure/provisioners/modal.py +++ b/src/prefect/infrastructure/provisioners/modal.py @@ -124,9 +124,9 @@ async def _create_modal_credentials_block( block_type_id=credentials_block_type.id ) ) - assert ( - credentials_block_schema is not None - ), f"Unable to find schema for block type {credentials_block_type.slug}" + assert credentials_block_schema is not None, ( + f"Unable to find schema for block type {credentials_block_type.slug}" + ) block_doc = await client.create_block_document( block_document=BlockDocumentCreate( diff --git a/src/prefect/input/run_input.py b/src/prefect/input/run_input.py index 1567b0d2ff72..caf6efe024ec 100644 --- a/src/prefect/input/run_input.py +++ b/src/prefect/input/run_input.py @@ -664,8 +664,7 @@ def receive_input( # type: ignore[overload-overlap] key_prefix: Optional[str] = None, flow_run_id: Optional[UUID] = None, with_metadata: bool = False, -) -> GetInputHandler[R]: - ... +) -> GetInputHandler[R]: ... @overload @@ -678,8 +677,7 @@ def receive_input( key_prefix: Optional[str] = None, flow_run_id: Optional[UUID] = None, with_metadata: bool = False, -) -> GetAutomaticInputHandler[T]: - ... +) -> GetAutomaticInputHandler[T]: ... def receive_input( @@ -697,9 +695,9 @@ def receive_input( # the signature is the same as here: # Union[Type[R], Type[T], pydantic.BaseModel], # Seems like a possible mypy bug, so we'll ignore the type check here. 
- input_cls: Union[ - Type[AutomaticRunInput[T]], Type[R] - ] = run_input_subclass_from_type(input_type) # type: ignore[arg-type] + input_cls: Union[Type[AutomaticRunInput[T]], Type[R]] = ( + run_input_subclass_from_type(input_type) + ) # type: ignore[arg-type] if issubclass(input_cls, AutomaticRunInput): return input_cls.receive( diff --git a/src/prefect/results.py b/src/prefect/results.py index 845917a18363..52f14c50d243 100644 --- a/src/prefect/results.py +++ b/src/prefect/results.py @@ -575,19 +575,19 @@ async def _read(self, key: str, holder: str) -> "ResultRecord[Any]": {}, ) metadata = ResultRecordMetadata.load_bytes(metadata_content) - assert ( - metadata.storage_key is not None - ), "Did not find storage key in metadata" + assert metadata.storage_key is not None, ( + "Did not find storage key in metadata" + ) result_content = await _call_explicitly_async_block_method( self.result_storage, "read_path", (metadata.storage_key,), {}, ) - result_record: ResultRecord[ - Any - ] = ResultRecord.deserialize_from_result_and_metadata( - result=result_content, metadata=metadata_content + result_record: ResultRecord[Any] = ( + ResultRecord.deserialize_from_result_and_metadata( + result=result_content, metadata=metadata_content + ) ) await emit_result_read_event(self, resolved_key_path) else: @@ -739,9 +739,9 @@ async def _persist_result_record( result_record: The result record to persist. holder: The holder of the lock if a lock was set on the record. """ - assert ( - result_record.metadata.storage_key is not None - ), "Storage key is required on result record" + assert result_record.metadata.storage_key is not None, ( + "Storage key is required on result record" + ) key = result_record.metadata.storage_key if result_record.metadata.storage_block_id is None: diff --git a/src/prefect/runner/server.py b/src/prefect/runner/server.py index b58cbf02e4d5..97c84fa01f78 100644 --- a/src/prefect/runner/server.py +++ b/src/prefect/runner/server.py @@ -139,9 +139,9 @@ async def get_deployment_router( ) # Used for updating the route schemas later on - schemas[ - f"{deployment.name}-{deployment_id}" - ] = deployment.parameter_openapi_schema + schemas[f"{deployment.name}-{deployment_id}"] = ( + deployment.parameter_openapi_schema + ) schemas[deployment_id] = deployment.name return router, schemas diff --git a/src/prefect/runner/storage.py b/src/prefect/runner/storage.py index 031ab1a05139..9f9cf53778c7 100644 --- a/src/prefect/runner/storage.py +++ b/src/prefect/runner/storage.py @@ -347,13 +347,13 @@ def to_pull_step(self) -> dict[str, Any]: } } if self._include_submodules: - pull_step["prefect.deployments.steps.git_clone"][ - "include_submodules" - ] = self._include_submodules + pull_step["prefect.deployments.steps.git_clone"]["include_submodules"] = ( + self._include_submodules + ) if isinstance(self._credentials, Block): - pull_step["prefect.deployments.steps.git_clone"][ - "credentials" - ] = f"{{{{ {self._credentials.get_block_placeholder()} }}}}" + pull_step["prefect.deployments.steps.git_clone"]["credentials"] = ( + f"{{{{ {self._credentials.get_block_placeholder()} }}}}" + ) elif isinstance(self._credentials, dict): if isinstance(self._credentials.get("access_token"), Secret): pull_step["prefect.deployments.steps.git_clone"]["credentials"] = { @@ -546,9 +546,9 @@ def replace_block_with_placeholder(obj: Any) -> Any: } } if required_package: - step["prefect.deployments.steps.pull_from_remote_storage"][ - "requires" - ] = required_package + 
step["prefect.deployments.steps.pull_from_remote_storage"]["requires"] = ( + required_package + ) return step def __eq__(self, __value: Any) -> bool: diff --git a/src/prefect/runner/submit.py b/src/prefect/runner/submit.py index 56c14b392e98..b3d592bc2a63 100644 --- a/src/prefect/runner/submit.py +++ b/src/prefect/runner/submit.py @@ -101,8 +101,7 @@ def submit_to_runner( prefect_callable: Union[Flow[Any, Any], Task[Any, Any]], parameters: Dict[str, Any], retry_failed_submissions: bool = True, -) -> FlowRun: - ... +) -> FlowRun: ... @overload @@ -110,8 +109,7 @@ def submit_to_runner( prefect_callable: Union[Flow[Any, Any], Task[Any, Any]], parameters: list[dict[str, Any]], retry_failed_submissions: bool = True, -) -> list[FlowRun]: - ... +) -> list[FlowRun]: ... @sync_compatible diff --git a/src/prefect/serializers.py b/src/prefect/serializers.py index 02e4b6f1801d..de49230f00c9 100644 --- a/src/prefect/serializers.py +++ b/src/prefect/serializers.py @@ -75,17 +75,17 @@ class Serializer(BaseModel, Generic[D]): """ def __init__(self, **data: Any) -> None: - type_string = get_dispatch_key(self) if type(self) != Serializer else "__base__" + type_string = ( + get_dispatch_key(self) if type(self) is not Serializer else "__base__" + ) data.setdefault("type", type_string) super().__init__(**data) @overload - def __new__(cls, *, type: str, **kwargs: Any) -> "Serializer[Any]": - ... + def __new__(cls, *, type: str, **kwargs: Any) -> "Serializer[Any]": ... @overload - def __new__(cls, *, type: None = ..., **kwargs: Any) -> Self: - ... + def __new__(cls, *, type: None = ..., **kwargs: Any) -> Self: ... def __new__(cls, **kwargs: Any) -> Union[Self, "Serializer[Any]"]: if type_ := kwargs.get("type"): diff --git a/src/prefect/server/api/dependencies.py b/src/prefect/server/api/dependencies.py index b20db69cf127..eadf1d134c77 100644 --- a/src/prefect/server/api/dependencies.py +++ b/src/prefect/server/api/dependencies.py @@ -1,6 +1,7 @@ """ Utilities for injecting FastAPI dependencies. 
""" + from __future__ import annotations import logging diff --git a/src/prefect/server/database/_migrations/versions/postgresql/2024_08_14_150111_97429116795e_add_deployment_concurrency_limit.py b/src/prefect/server/database/_migrations/versions/postgresql/2024_08_14_150111_97429116795e_add_deployment_concurrency_limit.py index 86984afb8563..dd6b8bf45e3b 100644 --- a/src/prefect/server/database/_migrations/versions/postgresql/2024_08_14_150111_97429116795e_add_deployment_concurrency_limit.py +++ b/src/prefect/server/database/_migrations/versions/postgresql/2024_08_14_150111_97429116795e_add_deployment_concurrency_limit.py @@ -5,6 +5,7 @@ Create Date: 2024-08-14 15:01:11.152219 """ + import sqlalchemy as sa from alembic import op diff --git a/src/prefect/server/database/_migrations/versions/postgresql/2024_09_11_090317_555ed31b284d_add_concurrency_options.py b/src/prefect/server/database/_migrations/versions/postgresql/2024_09_11_090317_555ed31b284d_add_concurrency_options.py index 35cd5e051300..5949d9eebece 100644 --- a/src/prefect/server/database/_migrations/versions/postgresql/2024_09_11_090317_555ed31b284d_add_concurrency_options.py +++ b/src/prefect/server/database/_migrations/versions/postgresql/2024_09_11_090317_555ed31b284d_add_concurrency_options.py @@ -5,6 +5,7 @@ Create Date: 2024-09-11 09:03:17.744587 """ + import sqlalchemy as sa from alembic import op diff --git a/src/prefect/server/database/_migrations/versions/postgresql/2024_09_16_152051_eaec5004771f_add_deployment_to_global_concurrency_.py b/src/prefect/server/database/_migrations/versions/postgresql/2024_09_16_152051_eaec5004771f_add_deployment_to_global_concurrency_.py index 6f23b07842ce..343cd6713568 100644 --- a/src/prefect/server/database/_migrations/versions/postgresql/2024_09_16_152051_eaec5004771f_add_deployment_to_global_concurrency_.py +++ b/src/prefect/server/database/_migrations/versions/postgresql/2024_09_16_152051_eaec5004771f_add_deployment_to_global_concurrency_.py @@ -5,6 +5,7 @@ Create Date: 2024-09-16 15:20:51.582204 """ + import sqlalchemy as sa from alembic import op diff --git a/src/prefect/server/database/_migrations/versions/sqlite/2024_08_14_145052_f93e1439f022_add_deployment_concurrency_limit.py b/src/prefect/server/database/_migrations/versions/sqlite/2024_08_14_145052_f93e1439f022_add_deployment_concurrency_limit.py index 1a1aae909d94..5f23c0b7da08 100644 --- a/src/prefect/server/database/_migrations/versions/sqlite/2024_08_14_145052_f93e1439f022_add_deployment_concurrency_limit.py +++ b/src/prefect/server/database/_migrations/versions/sqlite/2024_08_14_145052_f93e1439f022_add_deployment_concurrency_limit.py @@ -5,6 +5,7 @@ Create Date: 2024-08-14 14:50:52.420436 """ + import sqlalchemy as sa from alembic import op diff --git a/src/prefect/server/database/_migrations/versions/sqlite/2024_09_11_090106_7d6350aea855_add_concurrency_options.py b/src/prefect/server/database/_migrations/versions/sqlite/2024_09_11_090106_7d6350aea855_add_concurrency_options.py index aad95debdf64..9f6bbd958f23 100644 --- a/src/prefect/server/database/_migrations/versions/sqlite/2024_09_11_090106_7d6350aea855_add_concurrency_options.py +++ b/src/prefect/server/database/_migrations/versions/sqlite/2024_09_11_090106_7d6350aea855_add_concurrency_options.py @@ -5,6 +5,7 @@ Create Date: 2024-09-11 09:01:06.678866 """ + import sqlalchemy as sa from alembic import op diff --git a/src/prefect/server/database/configurations.py b/src/prefect/server/database/configurations.py index 5ff2fc8089fd..251cb4457ed3 100644 --- 
a/src/prefect/server/database/configurations.py +++ b/src/prefect/server/database/configurations.py @@ -206,10 +206,10 @@ async def engine(self) -> AsyncEngine: self.timeout, ) if cache_key not in ENGINES: - kwargs: dict[ - str, Any - ] = get_current_settings().server.database.sqlalchemy.model_dump( - mode="json", + kwargs: dict[str, Any] = ( + get_current_settings().server.database.sqlalchemy.model_dump( + mode="json", + ) ) connect_args: dict[str, Any] = kwargs.pop("connect_args") app_name = connect_args.pop("application_name", None) diff --git a/src/prefect/server/database/dependencies.py b/src/prefect/server/database/dependencies.py index 6dc3082df930..cc68a3611ad3 100644 --- a/src/prefect/server/database/dependencies.py +++ b/src/prefect/server/database/dependencies.py @@ -168,13 +168,11 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: @overload -def db_injector(func: _DBMethod[T, P, R]) -> _Method[T, P, R]: - ... +def db_injector(func: _DBMethod[T, P, R]) -> _Method[T, P, R]: ... @overload -def db_injector(func: _DBFunction[P, R]) -> _Function[P, R]: - ... +def db_injector(func: _DBFunction[P, R]) -> _Function[P, R]: ... def db_injector( @@ -265,17 +263,14 @@ class DBInjector( if TYPE_CHECKING: @overload - def __new__(cls, func: _DBMethod[T, P, R]) -> "DBInjector[T, P, R]": - ... + def __new__(cls, func: _DBMethod[T, P, R]) -> "DBInjector[T, P, R]": ... @overload - def __new__(cls, func: _DBFunction[P, R]) -> "DBInjector[None, P, R]": - ... + def __new__(cls, func: _DBFunction[P, R]) -> "DBInjector[None, P, R]": ... def __new__( cls, func: Union[_DBMethod[T, P, R], _DBFunction[P, R]] - ) -> Union["DBInjector[T, P, R]", "DBInjector[None, P, R]"]: - ... + ) -> Union["DBInjector[T, P, R]", "DBInjector[None, P, R]"]: ... def __init__(self, func: Union[_DBMethod[T, P, R], _DBFunction[P, R]]) -> None: super().__init__(cast(Callable[P, R], func)) @@ -289,18 +284,15 @@ def __set_name__(self, owner: type[T], name: str) -> None: object.__setattr__(self, "__name__", name) @overload - def __get__(self, instance: None, owner: type[T]) -> Self: - ... + def __get__(self, instance: None, owner: type[T]) -> Self: ... @overload def __get__( self, instance: T, owner: Optional[type[T]] = None - ) -> "_DBInjectorMethod[T, P, R]": - ... + ) -> "_DBInjectorMethod[T, P, R]": ... @overload - def __get__(self, instance: None, owner: None) -> Never: - ... + def __get__(self, instance: None, owner: None) -> Never: ... 
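The assignments reformatted in `configurations.py` above show the other recurring change in this upgrade: when an annotated assignment runs long, the right-hand side is wrapped in parentheses rather than the annotation's subscript being split. Both forms parse identically; a sketch with made-up names:

    from typing import Any

    def settings_dump() -> dict[str, Any]:  # hypothetical helper
        return {"mode": "json"}

    # Previously the annotation was broken up:
    # kwargs: dict[
    #     str, Any
    # ] = settings_dump()

    # Now the target stays on one line and the value is parenthesized:
    kwargs: dict[str, Any] = (
        settings_dump()
    )
    assert kwargs["mode"] == "json"
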
def __get__( self, instance: Optional[T], owner: Optional[type[T]] = None diff --git a/src/prefect/server/database/interface.py b/src/prefect/server/database/interface.py index 239064b7fe14..f56f08512370 100644 --- a/src/prefect/server/database/interface.py +++ b/src/prefect/server/database/interface.py @@ -21,9 +21,9 @@ class DBSingleton(type): """Ensures that only one database interface is created per unique key""" - _instances: dict[ - tuple[str, _UniqueKey, _UniqueKey, _UniqueKey], "DBSingleton" - ] = dict() + _instances: dict[tuple[str, _UniqueKey, _UniqueKey, _UniqueKey], "DBSingleton"] = ( + dict() + ) def __call__( cls, diff --git a/src/prefect/server/database/orm_models.py b/src/prefect/server/database/orm_models.py index ef504438ca08..0202c28c0a74 100644 --- a/src/prefect/server/database/orm_models.py +++ b/src/prefect/server/database/orm_models.py @@ -828,18 +828,18 @@ def job_variables(self) -> Mapped[dict[str, Any]]: concurrency_limit_id: Mapped[Optional[uuid.UUID]] = mapped_column( sa.ForeignKey("concurrency_limit_v2.id", ondelete="SET NULL"), ) - global_concurrency_limit: Mapped[ - Optional["ConcurrencyLimitV2"] - ] = sa.orm.relationship( - lazy="selectin", + global_concurrency_limit: Mapped[Optional["ConcurrencyLimitV2"]] = ( + sa.orm.relationship( + lazy="selectin", + ) ) - concurrency_options: Mapped[ - Optional[schemas.core.ConcurrencyOptions] - ] = mapped_column( - Pydantic(schemas.core.ConcurrencyOptions), - server_default=None, - nullable=True, - default=None, + concurrency_options: Mapped[Optional[schemas.core.ConcurrencyOptions]] = ( + mapped_column( + Pydantic(schemas.core.ConcurrencyOptions), + server_default=None, + nullable=True, + default=None, + ) ) tags: Mapped[list[str]] = mapped_column(JSON, server_default="[]", default=list) diff --git a/src/prefect/server/database/query_components.py b/src/prefect/server/database/query_components.py index 0f4fe13fa446..e0a493e6819f 100644 --- a/src/prefect/server/database/query_components.py +++ b/src/prefect/server/database/query_components.py @@ -129,16 +129,14 @@ def make_timestamp_intervals( start_time: pendulum.DateTime, end_time: pendulum.DateTime, interval: datetime.timedelta, - ) -> sa.Select[tuple[pendulum.DateTime, pendulum.DateTime]]: - ... + ) -> sa.Select[tuple[pendulum.DateTime, pendulum.DateTime]]: ... @abstractmethod def set_state_id_on_inserted_flow_runs_statement( self, inserted_flow_run_ids: Sequence[UUID], insert_flow_run_states: Iterable[dict[str, Any]], - ) -> sa.Update: - ... + ) -> sa.Update: ... @abstractmethod async def get_flow_run_notifications_from_queue( diff --git a/src/prefect/server/events/clients.py b/src/prefect/server/events/clients.py index b8d90ab40d56..72a145cda643 100644 --- a/src/prefect/server/events/clients.py +++ b/src/prefect/server/events/clients.py @@ -38,8 +38,7 @@ class EventsClient(abc.ABC): """The abstract interface for a Prefect Events client""" @abc.abstractmethod - async def emit(self, event: Event) -> Optional[Event]: - ... + async def emit(self, event: Event) -> Optional[Event]: ... 
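`DBSingleton` above is a metaclass that caches one database interface per unique key tuple. Roughly, the shape is the following (simplified to a single string key; `KeyedSingleton` and `Interface` are illustrative stand-ins, not the real classes):

    class KeyedSingleton(type):
        _instances: dict[tuple[type, str], object] = {}

        def __call__(cls, key: str = "default"):
            # One instance per (class, key); repeat calls with the same
            # key return the cached object instead of constructing anew.
            cache_key = (cls, key)
            if cache_key not in cls._instances:
                cls._instances[cache_key] = super().__call__(key)
            return cls._instances[cache_key]

    class Interface(metaclass=KeyedSingleton):
        def __init__(self, key: str = "default") -> None:
            self.key = key

    assert Interface("a") is Interface("a")
    assert Interface("a") is not Interface("b")
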
async def __aenter__(self) -> Self: return self @@ -113,9 +112,9 @@ def emitted_events_count(cls) -> int: def assert_emitted_event_count(cls, count: int) -> None: """Assert that the given number of events were emitted.""" total_num_events = cls.emitted_events_count() - assert ( - total_num_events == count - ), f"The number of emitted events did not match the expected count: {total_num_events=} != {count=}" + assert total_num_events == count, ( + f"The number of emitted events did not match the expected count: {total_num_events=} != {count=}" + ) @classmethod def assert_emitted_event_with( diff --git a/src/prefect/server/events/ordering.py b/src/prefect/server/events/ordering.py index 3c20f4b76a27..1e8cea139a00 100644 --- a/src/prefect/server/events/ordering.py +++ b/src/prefect/server/events/ordering.py @@ -54,8 +54,9 @@ def __init__(self, event: ReceivedEvent): class event_handler(Protocol): - async def __call__(self, event: ReceivedEvent, depth: int = 0) -> None: - ... # pragma: no cover + async def __call__( + self, event: ReceivedEvent, depth: int = 0 + ) -> None: ... # pragma: no cover class CausalOrdering: diff --git a/src/prefect/server/events/schemas/automations.py b/src/prefect/server/events/schemas/automations.py index ad59d57a0c23..a008ca3e65b5 100644 --- a/src/prefect/server/events/schemas/automations.py +++ b/src/prefect/server/events/schemas/automations.py @@ -109,8 +109,7 @@ def all_triggers(self) -> Sequence[Trigger]: @abc.abstractmethod def create_automation_state_change_event( self, firing: "Firing", trigger_state: TriggerState - ) -> ReceivedEvent: - ... + ) -> ReceivedEvent: ... class CompositeTrigger(Trigger, abc.ABC): @@ -174,8 +173,7 @@ def num_expected_firings(self) -> int: return len(self.triggers) @abc.abstractmethod - def ready_to_fire(self, firings: Sequence["Firing"]) -> bool: - ... + def ready_to_fire(self, firings: Sequence["Firing"]) -> bool: ... 
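`event_handler` above is a callable Protocol rather than a `Callable[...]` alias; the payoff is that `__call__` can declare a keyword default like `depth: int = 0`, which `Callable` cannot express. A small synchronous sketch (hypothetical handler, not the Prefect one):

    from typing import Protocol

    class Handler(Protocol):
        # Callable[...] cannot describe a parameter with a default;
        # a __call__ Protocol can.
        def __call__(self, event: str, depth: int = 0) -> None: ...

    def log_handler(event: str, depth: int = 0) -> None:
        print(" " * depth + event)

    handler: Handler = log_handler  # structurally compatible
    handler("flow.run.started")
    handler("task.run.completed", depth=2)
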
class CompoundTrigger(CompositeTrigger): diff --git a/src/prefect/server/models/block_documents.py b/src/prefect/server/models/block_documents.py index 638d2cb9183e..710108c39938 100644 --- a/src/prefect/server/models/block_documents.py +++ b/src/prefect/server/models/block_documents.py @@ -563,9 +563,9 @@ async def update_block_document( proposed_block_schema = await session.get( db.BlockSchema, proposed_block_schema_id ) - assert ( - proposed_block_schema - ), f"Block schema {proposed_block_schema_id} not found" + assert proposed_block_schema, ( + f"Block schema {proposed_block_schema_id} not found" + ) # make sure the proposed schema is of the same block type as the current document if ( diff --git a/src/prefect/server/models/block_registration.py b/src/prefect/server/models/block_registration.py index 76cdb66b897d..d801e4165b26 100644 --- a/src/prefect/server/models/block_registration.py +++ b/src/prefect/server/models/block_registration.py @@ -44,9 +44,9 @@ async def _install_protected_system_blocks(session: AsyncSession) -> None: orm_block_type = await models.block_types.create_block_type( session=session, block_type=server_block_type, override=True ) - assert ( - orm_block_type is not None - ), f"Failed to create block type {block_type}" + assert orm_block_type is not None, ( + f"Failed to create block type {block_type}" + ) await models.block_schemas.create_block_schema( session=session, @@ -174,7 +174,7 @@ async def _register_collection_blocks(session: AsyncSession) -> None: # due to schema reference dependencies, we need to register all block types first # and then register all block schemas - block_schemas: dict[str, dict] = {} + block_schemas: dict[str, dict[str, Any]] = {} async with session.begin(): for block_type in block_types: diff --git a/src/prefect/server/models/block_schemas.py b/src/prefect/server/models/block_schemas.py index dee1cb34f4c9..a3d9e30cf73e 100644 --- a/src/prefect/server/models/block_schemas.py +++ b/src/prefect/server/models/block_schemas.py @@ -97,10 +97,10 @@ async def create_block_schema( insert_values["fields"], definitions ) if non_block_definitions: - insert_values["fields"][ - "definitions" - ] = _get_non_block_reference_definitions( - insert_values["fields"], definitions + insert_values["fields"]["definitions"] = ( + _get_non_block_reference_definitions( + insert_values["fields"], definitions + ) ) else: # Prevent storing definitions for blocks. Those are reconstructed on read. @@ -157,7 +157,7 @@ async def _register_nested_block_schemas( db: PrefectDBInterface, session: AsyncSession, parent_block_schema_id: UUID, - block_schema_references: Dict[str, Union[Dict[str, str], List[Dict[str, str]]]], + block_schema_references: dict[str, Union[dict[str, str], List[dict[str, str]]]], base_fields: Dict, definitions: Optional[Dict], override: bool = False, @@ -248,7 +248,7 @@ def _get_fields_for_child_schema( base_fields: Dict, reference_name: str, reference_block_type: orm_models.BlockType, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Returns the field definitions for a child schema. The fields definitions are pulled from the provided `definitions` dictionary based on the information extracted from `base_fields` using the `reference_name`. 
`reference_block_type` @@ -447,7 +447,7 @@ def _construct_block_schema_spec_definitions( block_schemas_with_references: List[ Tuple[BlockSchema, Optional[str], Optional[UUID]] ], -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Constructs field definitions for a block schema based on the nested block schemas as defined in the block_schemas_with_references list. @@ -498,7 +498,7 @@ def _find_block_schema_via_checksum( def _add_block_schemas_fields_to_definitions( definitions: Dict, child_block_schema: BlockSchema -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Returns a new definitions dict with the fields of a block schema and it's child block schemas added to the existing definitions. @@ -522,7 +522,7 @@ def _construct_block_schema_fields_with_block_references( block_schemas_with_references: List[ Tuple[BlockSchema, Optional[str], Optional[UUID]] ], -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Constructs the block_schema_references in a block schema's fields attributes. Returns a copy of the block schema with block_schema_references added. @@ -548,9 +548,9 @@ def _construct_block_schema_fields_with_block_references( parent_block_schema_id, ) in block_schemas_with_references: if parent_block_schema_id == parent_block_schema.id: - assert ( - nested_block_schema.block_type - ), f"{nested_block_schema} has no block type" + assert nested_block_schema.block_type, ( + f"{nested_block_schema} has no block type" + ) new_block_schema_reference = { "block_schema_checksum": nested_block_schema.checksum, @@ -558,9 +558,9 @@ def _construct_block_schema_fields_with_block_references( } # A block reference for this key does not yet exist if name not in block_schema_fields_copy["block_schema_references"]: - block_schema_fields_copy["block_schema_references"][ - name - ] = new_block_schema_reference + block_schema_fields_copy["block_schema_references"][name] = ( + new_block_schema_reference + ) else: # List of block references for this key already exist and the block # reference that we are attempting add isn't present diff --git a/src/prefect/server/orchestration/core_policy.py b/src/prefect/server/orchestration/core_policy.py index fd99ab14a095..aa9583ae1b29 100644 --- a/src/prefect/server/orchestration/core_policy.py +++ b/src/prefect/server/orchestration/core_policy.py @@ -57,14 +57,12 @@ class CoreFlowPolicyWithoutDeploymentConcurrency(FlowRunOrchestrationPolicy): """ @staticmethod - def priority() -> ( - list[ - Union[ - type[BaseUniversalTransform[orm_models.FlowRun, core.FlowRunPolicy],], - type[BaseOrchestrationRule[orm_models.FlowRun, core.FlowRunPolicy]], - ] + def priority() -> list[ + Union[ + type[BaseUniversalTransform[orm_models.FlowRun, core.FlowRunPolicy],], + type[BaseOrchestrationRule[orm_models.FlowRun, core.FlowRunPolicy]], ] - ): + ]: return cast( list[ Union[ @@ -97,14 +95,12 @@ class CoreFlowPolicy(FlowRunOrchestrationPolicy): """ @staticmethod - def priority() -> ( - list[ - Union[ - type[BaseUniversalTransform[orm_models.FlowRun, core.FlowRunPolicy]], - type[BaseOrchestrationRule[orm_models.FlowRun, core.FlowRunPolicy]], - ] + def priority() -> list[ + Union[ + type[BaseUniversalTransform[orm_models.FlowRun, core.FlowRunPolicy]], + type[BaseOrchestrationRule[orm_models.FlowRun, core.FlowRunPolicy]], ] - ): + ]: return cast( list[ Union[ @@ -139,14 +135,12 @@ class CoreTaskPolicy(TaskRunOrchestrationPolicy): """ @staticmethod - def priority() -> ( - list[ - Union[ - type[BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy]], - 
type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]], - ] + def priority() -> list[ + Union[ + type[BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy]], + type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]], ] - ): + ]: return cast( list[ Union[ @@ -179,14 +173,12 @@ class ClientSideTaskOrchestrationPolicy(TaskRunOrchestrationPolicy): """ @staticmethod - def priority() -> ( - list[ - Union[ - type[BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy]], - type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]], - ] + def priority() -> list[ + Union[ + type[BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy]], + type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]], ] - ): + ]: return cast( list[ Union[ @@ -217,12 +209,10 @@ class BackgroundTaskPolicy(TaskRunOrchestrationPolicy): """ @staticmethod - def priority() -> ( - list[ - type[BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy]] - | type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]] - ] - ): + def priority() -> list[ + type[BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy]] + | type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]] + ]: return cast( list[ Union[ @@ -252,14 +242,12 @@ def priority() -> ( class MinimalFlowPolicy(FlowRunOrchestrationPolicy): @staticmethod - def priority() -> ( - list[ - Union[ - type[BaseUniversalTransform[orm_models.FlowRun, core.FlowRunPolicy]], - type[BaseOrchestrationRule[orm_models.FlowRun, core.FlowRunPolicy]], - ] + def priority() -> list[ + Union[ + type[BaseUniversalTransform[orm_models.FlowRun, core.FlowRunPolicy]], + type[BaseOrchestrationRule[orm_models.FlowRun, core.FlowRunPolicy]], ] - ): + ]: return [ BypassCancellingFlowRunsWithNoInfra, # cancel scheduled or suspended runs from the UI InstrumentFlowRunStateTransitions, @@ -269,14 +257,12 @@ def priority() -> ( class MarkLateRunsPolicy(FlowRunOrchestrationPolicy): @staticmethod - def priority() -> ( - list[ - Union[ - type[BaseUniversalTransform[orm_models.FlowRun, core.FlowRunPolicy]], - type[BaseOrchestrationRule[orm_models.FlowRun, core.FlowRunPolicy]], - ] + def priority() -> list[ + Union[ + type[BaseUniversalTransform[orm_models.FlowRun, core.FlowRunPolicy]], + type[BaseOrchestrationRule[orm_models.FlowRun, core.FlowRunPolicy]], ] - ): + ]: return [ EnsureOnlyScheduledFlowsMarkedLate, InstrumentFlowRunStateTransitions, @@ -285,14 +271,12 @@ def priority() -> ( class MinimalTaskPolicy(TaskRunOrchestrationPolicy): @staticmethod - def priority() -> ( - list[ - Union[ - type[BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy]], - type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]], - ] + def priority() -> list[ + Union[ + type[BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy]], + type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]], ] - ): + ]: return [ ReleaseTaskConcurrencySlots, # always release concurrency slots ] @@ -1220,9 +1204,9 @@ async def before_transition( if initial_state is None or proposed_state is None: return - self.original_flow_policy: dict[ - str, Any - ] = context.run.empirical_policy.model_dump() + self.original_flow_policy: dict[str, Any] = ( + context.run.empirical_policy.model_dump() + ) # Do not allow runs to be marked as crashed, paused, or cancelling if already terminal if proposed_state.type in { diff --git a/src/prefect/server/orchestration/dependencies.py b/src/prefect/server/orchestration/dependencies.py index 
a96808989b69..d0877b3f1a9b 100644 --- a/src/prefect/server/orchestration/dependencies.py +++ b/src/prefect/server/orchestration/dependencies.py @@ -137,14 +137,14 @@ async def parameter_lambda(): return tmp_orchestration_parameters try: - ORCHESTRATION_DEPENDENCIES[ - "task_orchestration_parameters_provider" - ] = parameter_lambda + ORCHESTRATION_DEPENDENCIES["task_orchestration_parameters_provider"] = ( + parameter_lambda + ) yield finally: - ORCHESTRATION_DEPENDENCIES[ - "task_orchestration_parameters_provider" - ] = starting_task_orchestration_parameters + ORCHESTRATION_DEPENDENCIES["task_orchestration_parameters_provider"] = ( + starting_task_orchestration_parameters + ) @contextmanager @@ -159,11 +159,11 @@ async def parameter_lambda(): return tmp_orchestration_parameters try: - ORCHESTRATION_DEPENDENCIES[ - "flow_orchestration_parameters_provider" - ] = parameter_lambda + ORCHESTRATION_DEPENDENCIES["flow_orchestration_parameters_provider"] = ( + parameter_lambda + ) yield finally: - ORCHESTRATION_DEPENDENCIES[ - "flow_orchestration_parameters_provider" - ] = starting_flow_orchestration_parameters + ORCHESTRATION_DEPENDENCIES["flow_orchestration_parameters_provider"] = ( + starting_flow_orchestration_parameters + ) diff --git a/src/prefect/server/orchestration/global_policy.py b/src/prefect/server/orchestration/global_policy.py index 786d2ca6a916..a2e59ba38849 100644 --- a/src/prefect/server/orchestration/global_policy.py +++ b/src/prefect/server/orchestration/global_policy.py @@ -32,15 +32,13 @@ from prefect.server.schemas.core import FlowRunPolicy -def COMMON_GLOBAL_TRANSFORMS() -> ( - list[ - type[ - BaseUniversalTransform[ - orm_models.Run, Union[core.FlowRunPolicy, core.TaskRunPolicy] - ] +def COMMON_GLOBAL_TRANSFORMS() -> list[ + type[ + BaseUniversalTransform[ + orm_models.Run, Union[core.FlowRunPolicy, core.TaskRunPolicy] ] ] -): +]: return [ SetRunStateType, SetRunStateName, @@ -63,14 +61,12 @@ class GlobalFlowPolicy(BaseOrchestrationPolicy[orm_models.FlowRun, core.FlowRunP """ @staticmethod - def priority() -> ( - list[ - Union[ - type[BaseUniversalTransform[orm_models.FlowRun, core.FlowRunPolicy]], - type[BaseOrchestrationRule[orm_models.FlowRun, core.FlowRunPolicy]], - ] + def priority() -> list[ + Union[ + type[BaseUniversalTransform[orm_models.FlowRun, core.FlowRunPolicy]], + type[BaseOrchestrationRule[orm_models.FlowRun, core.FlowRunPolicy]], ] - ): + ]: return cast( list[ Union[ @@ -98,14 +94,12 @@ class GlobalTaskPolicy(BaseOrchestrationPolicy[orm_models.TaskRun, core.TaskRunP """ @staticmethod - def priority() -> ( - list[ - Union[ - type[BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy]], - type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]], - ] + def priority() -> list[ + Union[ + type[BaseUniversalTransform[orm_models.TaskRun, core.TaskRunPolicy]], + type[BaseOrchestrationRule[orm_models.TaskRun, core.TaskRunPolicy]], ] - ): + ]: return cast( list[ Union[ diff --git a/src/prefect/server/orchestration/policies.py b/src/prefect/server/orchestration/policies.py index d590bc0df11a..ebd3e2d0fb83 100644 --- a/src/prefect/server/orchestration/policies.py +++ b/src/prefect/server/orchestration/policies.py @@ -39,9 +39,9 @@ class BaseOrchestrationPolicy(ABC, Generic[T, RP]): @staticmethod @abstractmethod - def priority() -> ( - list[type[BaseUniversalTransform[T, RP] | BaseOrchestrationRule[T, RP]]] - ): + def priority() -> list[ + type[BaseUniversalTransform[T, RP] | BaseOrchestrationRule[T, RP]] + ]: """ A list of orchestration rules in 
priority order. """ diff --git a/src/prefect/server/schemas/core.py b/src/prefect/server/schemas/core.py index 029455c668d8..a771d1988a0c 100644 --- a/src/prefect/server/schemas/core.py +++ b/src/prefect/server/schemas/core.py @@ -1,6 +1,7 @@ """ Full schemas of Prefect REST API objects. """ + from __future__ import annotations import datetime diff --git a/src/prefect/server/schemas/schedules.py b/src/prefect/server/schemas/schedules.py index 438c07d8b6db..cfcabbea42f1 100644 --- a/src/prefect/server/schemas/schedules.py +++ b/src/prefect/server/schemas/schedules.py @@ -1,6 +1,7 @@ """ Schedule schemas """ + from __future__ import annotations import datetime diff --git a/src/prefect/server/schemas/states.py b/src/prefect/server/schemas/states.py index d0ac008f04f0..1407ebd3d403 100644 --- a/src/prefect/server/schemas/states.py +++ b/src/prefect/server/schemas/states.py @@ -225,20 +225,19 @@ def fresh_copy(self, **kwargs: Any) -> Self: ) @overload - def result(self, raise_on_failure: Literal[True] = ..., fetch: bool = ...) -> Any: - ... + def result( + self, raise_on_failure: Literal[True] = ..., fetch: bool = ... + ) -> Any: ... @overload def result( self, raise_on_failure: Literal[False] = False, fetch: bool = ... - ) -> Union[Any, Exception]: - ... + ) -> Union[Any, Exception]: ... @overload def result( self, raise_on_failure: bool = ..., fetch: bool = ... - ) -> Union[Any, Exception]: - ... + ) -> Union[Any, Exception]: ... def result( self, raise_on_failure: bool = True, fetch: bool = True diff --git a/src/prefect/server/services/flow_run_notifications.py b/src/prefect/server/services/flow_run_notifications.py index bbe8fe2ebd6e..2bbabd19a634 100644 --- a/src/prefect/server/services/flow_run_notifications.py +++ b/src/prefect/server/services/flow_run_notifications.py @@ -54,9 +54,9 @@ async def run_once(self, db: PrefectDBInterface) -> None: # all retrieved notifications are deleted, assert that we only got one # since we only send the first notification returned - assert ( - len(notifications) == 1 - ), "Expected one notification; query limit not respected." + assert len(notifications) == 1, ( + "Expected one notification; query limit not respected." + ) try: await self.send_flow_run_notification( diff --git a/src/prefect/server/utilities/database.py b/src/prefect/server/utilities/database.py index d133037f7bad..6fa18d107cd6 100644 --- a/src/prefect/server/utilities/database.py +++ b/src/prefect/server/utilities/database.py @@ -194,9 +194,9 @@ class JSON(TypeDecorator[Any]): to SQL compilation """ - impl: type[postgresql.JSONB] | type[TypeEngine[Any]] | TypeEngine[ - Any - ] = postgresql.JSONB + impl: type[postgresql.JSONB] | type[TypeEngine[Any]] | TypeEngine[Any] = ( + postgresql.JSONB + ) cache_ok: bool | None = True def load_dialect_impl(self, dialect: sa.Dialect) -> TypeEngine[Any]: @@ -242,8 +242,7 @@ def __init__( self, pydantic_type: type[T], sa_column_type: Optional[Union[type[TypeEngine[Any]], TypeEngine[Any]]] = None, - ) -> None: - ... + ) -> None: ... # This overload is needed to allow for typing special forms (e.g. # Union[...], etc.) as these can't be married with `type[...]`. Also see @@ -253,8 +252,7 @@ def __init__( self: "Pydantic[Any]", pydantic_type: Any, sa_column_type: Optional[Union[type[TypeEngine[Any]], TypeEngine[Any]]] = None, - ) -> None: - ... + ) -> None: ... 
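One detail worth noting in the assert rewrites that run through this whole changeset: only the message is parenthesized, never the `condition, message` pair. Parenthesizing the pair would create a non-empty tuple, which is always truthy, so the assert could never fail. For instance:

    notifications = []

    # Correct: condition, comma, parenthesized message. Fails as intended.
    try:
        assert len(notifications) == 1, (
            "Expected one notification; query limit not respected."
        )
    except AssertionError:
        print("assert fired, as expected")

    # Buggy variant: a (condition, message) tuple is always truthy, so
    # this can never fail; CPython emits a SyntaxWarning about it.
    assert (len(notifications) == 1, "never fails")  # noqa: F631
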
def __init__( self, @@ -264,9 +262,9 @@ def __init__( super().__init__() self._pydantic_type = pydantic_type if sa_column_type is not None: - self.impl: type[JSON] | type[TypeEngine[Any]] | TypeEngine[ - Any - ] = sa_column_type + self.impl: type[JSON] | type[TypeEngine[Any]] | TypeEngine[Any] = ( + sa_column_type + ) def process_bind_param( self, value: Optional[T], dialect: sa.Dialect diff --git a/src/prefect/server/utilities/messaging/__init__.py b/src/prefect/server/utilities/messaging/__init__.py index 57e6fe760193..1a593adfc947 100644 --- a/src/prefect/server/utilities/messaging/__init__.py +++ b/src/prefect/server/utilities/messaging/__init__.py @@ -35,30 +35,25 @@ class Message(Protocol): """ @property - def data(self) -> Union[str, bytes]: - ... + def data(self) -> Union[str, bytes]: ... @property - def attributes(self) -> Mapping[str, Any]: - ... + def attributes(self) -> Mapping[str, Any]: ... class Cache(abc.ABC): @abc.abstractmethod - async def clear_recently_seen_messages(self) -> None: - ... + async def clear_recently_seen_messages(self) -> None: ... @abc.abstractmethod async def without_duplicates( self, attribute: str, messages: Iterable[M] - ) -> list[M]: - ... + ) -> list[M]: ... @abc.abstractmethod async def forget_duplicates( self, attribute: str, messages: Iterable[Message] - ) -> None: - ... + ) -> None: ... class Publisher(AbstractAsyncContextManager["Publisher"], abc.ABC): @@ -67,16 +62,15 @@ def __init__( topic: str, cache: Optional[Cache] = None, deduplicate_by: Optional[str] = None, - ) -> None: - ... + ) -> None: ... @abc.abstractmethod - async def publish_data(self, data: bytes, attributes: Mapping[str, str]) -> None: - ... + async def publish_data( + self, data: bytes, attributes: Mapping[str, str] + ) -> None: ... @abc.abstractmethod - async def __aenter__(self) -> Self: - ... + async def __aenter__(self) -> Self: ... @abc.abstractmethod async def __aexit__( @@ -84,8 +78,7 @@ async def __aexit__( exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], - ) -> None: - ... + ) -> None: ... 
@dataclass diff --git a/src/prefect/settings/base.py b/src/prefect/settings/base.py index 6d3ff587e4ca..57ebbbd06ed9 100644 --- a/src/prefect/settings/base.py +++ b/src/prefect/settings/base.py @@ -108,9 +108,9 @@ def to_environment_variables( ) env_variables.update(child_env) elif (value := env.get(key)) is not None: - env_variables[ - f"{self.model_config.get('env_prefix')}{key.upper()}" - ] = _to_environment_variable_value(value) + env_variables[f"{self.model_config.get('env_prefix')}{key.upper()}"] = ( + _to_environment_variable_value(value) + ) return env_variables @model_serializer( diff --git a/src/prefect/settings/sources.py b/src/prefect/settings/sources.py index ae855307f4fd..81f0ebdda611 100644 --- a/src/prefect/settings/sources.py +++ b/src/prefect/settings/sources.py @@ -188,7 +188,7 @@ def get_field_value( self.field_is_complex(field), ) - name = f"{self.config.get('env_prefix','')}{field_name.upper()}" + name = f"{self.config.get('env_prefix', '')}{field_name.upper()}" value = self.profile_settings.get(name) return value, field_name, self.field_is_complex(field) @@ -266,9 +266,9 @@ def __init__( settings_cls: Type[BaseSettings], ): super().__init__(settings_cls) - self.toml_file_path: Path | str | Sequence[ - Path | str - ] | None = settings_cls.model_config.get("toml_file", DEFAULT_PREFECT_TOML_PATH) + self.toml_file_path: Path | str | Sequence[Path | str] | None = ( + settings_cls.model_config.get("toml_file", DEFAULT_PREFECT_TOML_PATH) + ) self.toml_data: dict[str, Any] = self._read_files(self.toml_file_path) self.toml_table_header: tuple[str, ...] = settings_cls.model_config.get( "prefect_toml_table_header", tuple() diff --git a/src/prefect/task_runners.py b/src/prefect/task_runners.py index 9f42a8724cab..def7dd1fd7b5 100644 --- a/src/prefect/task_runners.py +++ b/src/prefect/task_runners.py @@ -84,8 +84,7 @@ def submit( parameters: dict[str, Any], wait_for: Iterable[PrefectFuture[Any]] | None = None, dependencies: dict[str, set[TaskRunInput]] | None = None, - ) -> F: - ... + ) -> F: ... @overload @abc.abstractmethod @@ -95,8 +94,7 @@ def submit( parameters: dict[str, Any], wait_for: Iterable[PrefectFuture[Any]] | None = None, dependencies: dict[str, set[TaskRunInput]] | None = None, - ) -> F: - ... + ) -> F: ... @abc.abstractmethod def submit( @@ -105,8 +103,7 @@ def submit( parameters: dict[str, Any], wait_for: Iterable[PrefectFuture[Any]] | None = None, dependencies: dict[str, set[TaskRunInput]] | None = None, - ) -> F: - ... + ) -> F: ... def map( self, @@ -251,8 +248,7 @@ def submit( parameters: dict[str, Any], wait_for: Iterable[PrefectFuture[Any]] | None = None, dependencies: dict[str, set[TaskRunInput]] | None = None, - ) -> PrefectConcurrentFuture[R]: - ... + ) -> PrefectConcurrentFuture[R]: ... @overload def submit( @@ -261,8 +257,7 @@ def submit( parameters: dict[str, Any], wait_for: Iterable[PrefectFuture[Any]] | None = None, dependencies: dict[str, set[TaskRunInput]] | None = None, - ) -> PrefectConcurrentFuture[R]: - ... + ) -> PrefectConcurrentFuture[R]: ... def submit( self, @@ -337,8 +332,7 @@ def map( task: "Task[P, Coroutine[Any, Any, R]]", parameters: dict[str, Any], wait_for: Iterable[PrefectFuture[Any]] | None = None, - ) -> PrefectFutureList[PrefectConcurrentFuture[R]]: - ... + ) -> PrefectFutureList[PrefectConcurrentFuture[R]]: ... @overload def map( @@ -346,8 +340,7 @@ def map( task: "Task[Any, R]", parameters: dict[str, Any], wait_for: Iterable[PrefectFuture[Any]] | None = None, - ) -> PrefectFutureList[PrefectConcurrentFuture[R]]: - ... 
+ ) -> PrefectFutureList[PrefectConcurrentFuture[R]]: ... def map( self, @@ -402,8 +395,7 @@ def submit( parameters: dict[str, Any], wait_for: Iterable[PrefectFuture[Any]] | None = None, dependencies: dict[str, set[TaskRunInput]] | None = None, - ) -> PrefectDistributedFuture[R]: - ... + ) -> PrefectDistributedFuture[R]: ... @overload def submit( @@ -412,8 +404,7 @@ def submit( parameters: dict[str, Any], wait_for: Iterable[PrefectFuture[Any]] | None = None, dependencies: dict[str, set[TaskRunInput]] | None = None, - ) -> PrefectDistributedFuture[R]: - ... + ) -> PrefectDistributedFuture[R]: ... def submit( self, @@ -458,8 +449,7 @@ def map( task: "Task[P, Coroutine[Any, Any, R]]", parameters: dict[str, Any], wait_for: Iterable[PrefectFuture[Any]] | None = None, - ) -> PrefectFutureList[PrefectDistributedFuture[R]]: - ... + ) -> PrefectFutureList[PrefectDistributedFuture[R]]: ... @overload def map( @@ -467,8 +457,7 @@ def map( task: "Task[Any, R]", parameters: dict[str, Any], wait_for: Iterable[PrefectFuture[Any]] | None = None, - ) -> PrefectFutureList[PrefectDistributedFuture[R]]: - ... + ) -> PrefectFutureList[PrefectDistributedFuture[R]]: ... def map( self, diff --git a/src/prefect/task_worker.py b/src/prefect/task_worker.py index a781fc61503d..3ae218916716 100644 --- a/src/prefect/task_worker.py +++ b/src/prefect/task_worker.py @@ -248,9 +248,9 @@ async def _subscribe_to_task_scheduling(self): token_acquired = await self._acquire_token(task_run.id) if token_acquired: - assert ( - self._runs_task_group is not None - ), "Task group was not initialized" + assert self._runs_task_group is not None, ( + "Task group was not initialized" + ) self._runs_task_group.start_soon( self._safe_submit_scheduled_task_run, task_run ) diff --git a/src/prefect/tasks.py b/src/prefect/tasks.py index 9b291f7a6686..3dc18357dbe9 100644 --- a/src/prefect/tasks.py +++ b/src/prefect/tasks.py @@ -226,8 +226,7 @@ def is_callback_with_parameters(cls, callable: Callable[..., str]) -> TypeIs[Sel sig = inspect.signature(callable) return "parameters" in sig.parameters - def __call__(self, parameters: dict[str, Any]) -> str: - ... + def __call__(self, parameters: dict[str, Any]) -> str: ... StateHookCallable: TypeAlias = Callable[ @@ -483,9 +482,9 @@ def __init__( f"Invalid `retry_delay_seconds` provided; must be an int, float, list or callable. Received type {type(retry_delay_seconds)}" ) else: - self.retry_delay_seconds: Union[ - float, int, list[float], None - ] = retry_delay_seconds + self.retry_delay_seconds: Union[float, int, list[float], None] = ( + retry_delay_seconds + ) if isinstance(self.retry_delay_seconds, list) and ( len(self.retry_delay_seconds) > 50 @@ -978,8 +977,7 @@ def __call__( self: "Task[P, R]", *args: P.args, **kwargs: P.kwargs, - ) -> R: - ... + ) -> R: ... # Keyword parameters `return_state` and `wait_for` aren't allowed after the # ParamSpec `*args` parameter, so we lose return type typing when either of @@ -992,8 +990,7 @@ def __call__( return_state: Literal[True] = True, wait_for: Optional[OneOrManyFutureOrResult[Any]] = None, **kwargs: P.kwargs, - ) -> State[R]: - ... + ) -> State[R]: ... @overload def __call__( @@ -1002,8 +999,7 @@ def __call__( return_state: Literal[False] = False, wait_for: Optional[OneOrManyFutureOrResult[Any]] = None, **kwargs: P.kwargs, - ) -> R: - ... + ) -> R: ... def __call__( self: "Union[Task[P, R], Task[P, NoReturn]]", @@ -1046,8 +1042,7 @@ def submit( self: "Task[P, R]", *args: P.args, **kwargs: P.kwargs, - ) -> PrefectFuture[R]: - ... 
+ ) -> PrefectFuture[R]: ... @overload def submit( @@ -1056,8 +1051,7 @@ def submit( return_state: Literal[False], wait_for: Optional[OneOrManyFutureOrResult[Any]] = None, **kwargs: P.kwargs, - ) -> PrefectFuture[R]: - ... + ) -> PrefectFuture[R]: ... @overload def submit( @@ -1066,8 +1060,7 @@ def submit( return_state: Literal[False], wait_for: Optional[OneOrManyFutureOrResult[Any]] = None, **kwargs: P.kwargs, - ) -> PrefectFuture[R]: - ... + ) -> PrefectFuture[R]: ... @overload def submit( @@ -1076,8 +1069,7 @@ def submit( return_state: Literal[True], wait_for: Optional[OneOrManyFutureOrResult[Any]] = None, **kwargs: P.kwargs, - ) -> State[R]: - ... + ) -> State[R]: ... @overload def submit( @@ -1086,8 +1078,7 @@ def submit( return_state: Literal[True], wait_for: Optional[OneOrManyFutureOrResult[Any]] = None, **kwargs: P.kwargs, - ) -> State[R]: - ... + ) -> State[R]: ... def submit( self: "Union[Task[P, R], Task[P, Coroutine[Any, Any, R]]]", @@ -1221,8 +1212,7 @@ def map( wait_for: Optional[Iterable[Union[PrefectFuture[R], R]]] = ..., deferred: bool = ..., **kwargs: Any, - ) -> list[State[R]]: - ... + ) -> list[State[R]]: ... @overload def map( @@ -1231,8 +1221,7 @@ def map( wait_for: Optional[Iterable[Union[PrefectFuture[R], R]]] = ..., deferred: bool = ..., **kwargs: Any, - ) -> PrefectFutureList[R]: - ... + ) -> PrefectFutureList[R]: ... @overload def map( @@ -1242,8 +1231,7 @@ def map( wait_for: Optional[Iterable[Union[PrefectFuture[R], R]]] = ..., deferred: bool = ..., **kwargs: Any, - ) -> list[State[R]]: - ... + ) -> list[State[R]]: ... @overload def map( @@ -1252,8 +1240,7 @@ def map( wait_for: Optional[Iterable[Union[PrefectFuture[R], R]]] = ..., deferred: bool = ..., **kwargs: Any, - ) -> PrefectFutureList[R]: - ... + ) -> PrefectFutureList[R]: ... @overload def map( @@ -1263,8 +1250,7 @@ def map( wait_for: Optional[Iterable[Union[PrefectFuture[R], R]]] = ..., deferred: bool = ..., **kwargs: Any, - ) -> list[State[R]]: - ... + ) -> list[State[R]]: ... @overload def map( @@ -1274,8 +1260,7 @@ def map( wait_for: Optional[Iterable[Union[PrefectFuture[R], R]]] = ..., deferred: bool = ..., **kwargs: Any, - ) -> PrefectFutureList[R]: - ... + ) -> PrefectFutureList[R]: ... def map( self, @@ -1610,8 +1595,7 @@ async def serve(self) -> NoReturn: @overload -def task(__fn: Callable[P, R]) -> Task[P, R]: - ... +def task(__fn: Callable[P, R]) -> Task[P, R]: ... # see https://github.com/PrefectHQ/prefect/issues/16380 @@ -1646,8 +1630,7 @@ def task( on_failure: Optional[list[StateHookCallable]] = None, retry_condition_fn: Literal[None] = None, viz_return_value: Any = None, -) -> Callable[[Callable[P, R]], Task[P, R]]: - ... +) -> Callable[[Callable[P, R]], Task[P, R]]: ... # see https://github.com/PrefectHQ/prefect/issues/16380 @@ -1682,8 +1665,7 @@ def task( on_failure: Optional[list[StateHookCallable]] = None, retry_condition_fn: Optional[Callable[[Task[P, R], TaskRun, State], bool]] = None, viz_return_value: Any = None, -) -> Callable[[Callable[P, R]], Task[P, R]]: - ... +) -> Callable[[Callable[P, R]], Task[P, R]]: ... @overload # TODO: do we need this overload? @@ -1719,8 +1701,7 @@ def task( on_failure: Optional[list[StateHookCallable]] = None, retry_condition_fn: Optional[Callable[[Task[P, Any], TaskRun, State], bool]] = None, viz_return_value: Any = None, -) -> Callable[[Callable[P, R]], Task[P, R]]: - ... +) -> Callable[[Callable[P, R]], Task[P, R]]: ... 
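The `task` overloads above support both the bare `@task` form and the parameterized `@task(...)` form. The general shape of that dual-form decorator, reduced to a single hypothetical `trace` option (not Prefect's implementation):

    from functools import wraps
    from typing import Callable, Optional, TypeVar, overload

    F = TypeVar("F", bound=Callable[..., object])

    @overload
    def trace(fn: F) -> F: ...
    @overload
    def trace(*, label: str = "") -> Callable[[F], F]: ...
    def trace(fn: Optional[F] = None, *, label: str = ""):
        # Bare use (@trace) receives the function directly; the
        # parameterized form (@trace(label=...)) returns a decorator.
        def wrap(f: F) -> F:
            @wraps(f)
            def inner(*args, **kwargs):
                print(f"[{label or f.__name__}] called")
                return f(*args, **kwargs)
            return inner  # type: ignore[return-value]
        return wrap if fn is None else wrap(fn)

    @trace
    def a() -> None: ...

    @trace(label="b")
    def b() -> None: ...

    a()
    b()
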
def task( diff --git a/src/prefect/telemetry/bootstrap.py b/src/prefect/telemetry/bootstrap.py index bbd3c95a78ad..49ce8272b78d 100644 --- a/src/prefect/telemetry/bootstrap.py +++ b/src/prefect/telemetry/bootstrap.py @@ -15,12 +15,10 @@ from opentelemetry.sdk.trace import TracerProvider -def setup_telemetry() -> ( - Union[ - tuple["TracerProvider", "MeterProvider", "LoggerProvider"], - tuple[None, None, None], - ] -): +def setup_telemetry() -> Union[ + tuple["TracerProvider", "MeterProvider", "LoggerProvider"], + tuple[None, None, None], +]: settings = prefect.settings.get_current_settings() server_type = determine_server_type() diff --git a/src/prefect/telemetry/services.py b/src/prefect/telemetry/services.py index e2fcd928af69..6865a849356a 100644 --- a/src/prefect/telemetry/services.py +++ b/src/prefect/telemetry/services.py @@ -15,11 +15,9 @@ class OTLPExporter(Protocol[T_contra]): - def export(self, __items: Sequence[T_contra]) -> Any: - ... + def export(self, __items: Sequence[T_contra]) -> Any: ... - def shutdown(self) -> Any: - ... + def shutdown(self) -> Any: ... class BaseQueueingExporter(BatchedQueueService[BatchItem]): diff --git a/src/prefect/testing/cli.py b/src/prefect/testing/cli.py index 03c893527336..59c6fe212bd4 100644 --- a/src/prefect/testing/cli.py +++ b/src/prefect/testing/cli.py @@ -39,13 +39,13 @@ def check_contains(cli_result: Result, content: str, should_contain: bool) -> No display_content = content if should_contain: - assert ( - content in output - ), f"Desired contents {display_content!r} not found in CLI output" + assert content in output, ( + f"Desired contents {display_content!r} not found in CLI output" + ) else: - assert ( - content not in output - ), f"Undesired contents {display_content!r} found in CLI output" + assert content not in output, ( + f"Undesired contents {display_content!r} found in CLI output" + ) def invoke_and_assert( diff --git a/src/prefect/testing/standard_test_suites/blocks.py b/src/prefect/testing/standard_test_suites/blocks.py index f71f80c1244e..046fde7823fb 100644 --- a/src/prefect/testing/standard_test_suites/blocks.py +++ b/src/prefect/testing/standard_test_suites/blocks.py @@ -27,12 +27,12 @@ def test_all_fields_have_a_description(self, block: type[Block]) -> None: # block fields are currently excluded from this test. Once block field # descriptions are supported by the UI, remove this clause. continue - assert ( - field.description - ), f"{block.__name__} is missing a description on {name}" - assert field.description.endswith( - "." 
- ), f"{name} description on {block.__name__} does not end with a period" + assert field.description, ( + f"{block.__name__} is missing a description on {name}" + ) + assert field.description.endswith("."), ( + f"{name} description on {block.__name__} does not end with a period" + ) def test_has_a_valid_code_example(self, block: type[Block]) -> None: code_example = block.get_code_example() @@ -57,11 +57,11 @@ def test_has_a_valid_code_example(self, block: type[Block]) -> None: def test_has_a_valid_image(self, block: type[Block]) -> None: logo_url = block._logo_url - assert ( - logo_url is not None - ), f"{block.__name__} is missing a value for _logo_url" + assert logo_url is not None, ( + f"{block.__name__} is missing a value for _logo_url" + ) img = Image.open(urlopen(str(logo_url))) assert img.width == img.height, "Logo should be a square image" - assert ( - 1000 > img.width > 45 - ), f"Logo should be between 200px and 1000px wid, but is {img.width}px wide" + assert 1000 > img.width > 45, ( + f"Logo should be between 200px and 1000px wid, but is {img.width}px wide" + ) diff --git a/src/prefect/testing/utilities.py b/src/prefect/testing/utilities.py index 1fa270da6b38..2ed0b99a4b6f 100644 --- a/src/prefect/testing/utilities.py +++ b/src/prefect/testing/utilities.py @@ -45,7 +45,7 @@ def exceptions_equal(a: Exception, b: Exception) -> bool: """ if a == b: return True - return type(a) == type(b) and getattr(a, "args", None) == getattr(b, "args", None) + return type(a) is type(b) and getattr(a, "args", None) == getattr(b, "args", None) # AsyncMock has a new import path in Python 3.9+ @@ -202,9 +202,9 @@ async def get_most_recent_flow_run( def assert_blocks_equal( found: Block, expected: Block, exclude_private: bool = True, **kwargs: Any ) -> None: - assert isinstance( - found, type(expected) - ), f"Unexpected type {type(found).__name__}, expected {type(expected).__name__}" + assert isinstance(found, type(expected)), ( + f"Unexpected type {type(found).__name__}, expected {type(expected).__name__}" + ) if exclude_private: exclude = set(kwargs.pop("exclude", set())) diff --git a/src/prefect/utilities/_engine.py b/src/prefect/utilities/_engine.py index c3a99676dcc4..624e44845ee3 100644 --- a/src/prefect/utilities/_engine.py +++ b/src/prefect/utilities/_engine.py @@ -1,6 +1,5 @@ """Internal engine utilities""" - from collections.abc import Callable from functools import partial from typing import TYPE_CHECKING, Any, Union diff --git a/src/prefect/utilities/asyncutils.py b/src/prefect/utilities/asyncutils.py index ef3b017af9d1..7d2f7f9534e8 100644 --- a/src/prefect/utilities/asyncutils.py +++ b/src/prefect/utilities/asyncutils.py @@ -129,8 +129,7 @@ def run_coro_as_sync( *, force_new_thread: bool = ..., wait_for_result: Literal[True] = ..., -) -> R: - ... +) -> R: ... @overload @@ -139,8 +138,7 @@ def run_coro_as_sync( *, force_new_thread: bool = ..., wait_for_result: Literal[False] = False, -) -> R: - ... +) -> R: ... def run_coro_as_sync( @@ -366,15 +364,13 @@ async def ctx_call(): @overload def asyncnullcontext( value: None = None, *args: Any, **kwargs: Any -) -> AbstractAsyncContextManager[None, None]: - ... +) -> AbstractAsyncContextManager[None, None]: ... @overload def asyncnullcontext( value: R, *args: Any, **kwargs: Any -) -> AbstractAsyncContextManager[R, None]: - ... +) -> AbstractAsyncContextManager[R, None]: ... 
@asynccontextmanager diff --git a/src/prefect/utilities/collections.py b/src/prefect/utilities/collections.py index 3588b6f48a73..0502db91b05b 100644 --- a/src/prefect/utilities/collections.py +++ b/src/prefect/utilities/collections.py @@ -253,8 +253,7 @@ def visit_collection( context: dict[str, VT] = ..., remove_annotations: bool = ..., _seen: Optional[set[int]] = ..., -) -> Any: - ... +) -> Any: ... @overload @@ -267,8 +266,7 @@ def visit_collection( context: None = None, remove_annotations: bool = ..., _seen: Optional[set[int]] = ..., -) -> Any: - ... +) -> Any: ... @overload @@ -281,8 +279,7 @@ def visit_collection( context: dict[str, VT] = ..., remove_annotations: bool = ..., _seen: Optional[set[int]] = ..., -) -> Optional[Any]: - ... +) -> Optional[Any]: ... @overload @@ -295,8 +292,7 @@ def visit_collection( context: None = None, remove_annotations: bool = ..., _seen: Optional[set[int]] = ..., -) -> Optional[Any]: - ... +) -> Optional[Any]: ... @overload @@ -309,8 +305,7 @@ def visit_collection( context: dict[str, VT] = ..., remove_annotations: bool = ..., _seen: Optional[set[int]] = ..., -) -> None: - ... +) -> None: ... def visit_collection( @@ -545,13 +540,11 @@ def visit_expression(expr: Any) -> Any: @overload def remove_nested_keys( keys_to_remove: list[HashableT], obj: NestedDict[HashableT, VT] -) -> NestedDict[HashableT, VT]: - ... +) -> NestedDict[HashableT, VT]: ... @overload -def remove_nested_keys(keys_to_remove: list[HashableT], obj: Any) -> Any: - ... +def remove_nested_keys(keys_to_remove: list[HashableT], obj: Any) -> Any: ... def remove_nested_keys( @@ -579,13 +572,13 @@ def remove_nested_keys( @overload -def distinct(iterable: Iterable[HashableT], key: None = None) -> Iterator[HashableT]: - ... +def distinct( + iterable: Iterable[HashableT], key: None = None +) -> Iterator[HashableT]: ... @overload -def distinct(iterable: Iterable[T], key: Callable[[T], Hashable]) -> Iterator[T]: - ... +def distinct(iterable: Iterable[T], key: Callable[[T], Hashable]) -> Iterator[T]: ... def distinct( @@ -609,15 +602,13 @@ def _key(__i: Any) -> Hashable: @overload def get_from_dict( dct: NestedDict[str, VT], keys: Union[str, list[str]], default: None = None -) -> Optional[VT]: - ... +) -> Optional[VT]: ... @overload def get_from_dict( dct: NestedDict[str, VT], keys: Union[str, list[str]], default: R -) -> Union[VT, R]: - ... +) -> Union[VT, R]: ... def get_from_dict( diff --git a/src/prefect/utilities/dispatch.py b/src/prefect/utilities/dispatch.py index 603c24aecbd1..74ce0665c06e 100644 --- a/src/prefect/utilities/dispatch.py +++ b/src/prefect/utilities/dispatch.py @@ -45,15 +45,13 @@ def get_registry_for_type(cls: T) -> Optional[dict[str, T]]: @overload def get_dispatch_key( cls_or_instance: Any, allow_missing: Literal[False] = False -) -> str: - ... +) -> str: ... @overload def get_dispatch_key( cls_or_instance: Any, allow_missing: Literal[True] = ... -) -> Optional[str]: - ... +) -> Optional[str]: ... 
def get_dispatch_key( diff --git a/src/prefect/utilities/dockerutils.py b/src/prefect/utilities/dockerutils.py index eb6bd18b024b..c9df261362e3 100644 --- a/src/prefect/utilities/dockerutils.py +++ b/src/prefect/utilities/dockerutils.py @@ -331,7 +331,7 @@ def build( def assert_has_line(self, line: str) -> None: """Asserts that the given line is in the Dockerfile""" all_lines = "\n".join( - [f" {i+1:>3}: {line}" for i, line in enumerate(self.dockerfile_lines)] + [f" {i + 1:>3}: {line}" for i, line in enumerate(self.dockerfile_lines)] ) message = ( f"Expected {line!r} not found in Dockerfile. Dockerfile:\n{all_lines}" @@ -347,12 +347,12 @@ def assert_line_absent(self, line: str) -> None: surrounding_lines = "\n".join( [ - f" {i+1:>3}: {line}" + f" {i + 1:>3}: {line}" for i, line in enumerate(self.dockerfile_lines[i - 2 : i + 2]) ] ) message = ( - f"Unexpected {line!r} found in Dockerfile at line {i+1}. " + f"Unexpected {line!r} found in Dockerfile at line {i + 1}. " f"Surrounding lines:\n{surrounding_lines}" ) @@ -368,7 +368,7 @@ def assert_line_before(self, first: str, second: str) -> None: surrounding_lines = "\n".join( [ - f" {i+1:>3}: {line}" + f" {i + 1:>3}: {line}" for i, line in enumerate( self.dockerfile_lines[second_index - 2 : first_index + 2] ) @@ -377,8 +377,8 @@ def assert_line_before(self, first: str, second: str) -> None: message = ( f"Expected {first!r} to appear before {second!r} in the Dockerfile, but " - f"{first!r} was at line {first_index+1} and {second!r} as at line " - f"{second_index+1}. Surrounding lines:\n{surrounding_lines}" + f"{first!r} was at line {first_index + 1} and {second!r} was at line " + f"{second_index + 1}. Surrounding lines:\n{surrounding_lines}" ) assert first_index < second_index, message diff --git a/src/prefect/utilities/names.py b/src/prefect/utilities/names.py index baeb2b1b6475..31832aee27c4 100644 --- a/src/prefect/utilities/names.py +++ b/src/prefect/utilities/names.py @@ -68,5 +68,5 @@ def obfuscate_string(s: str, show_tail: bool = False) -> str: # take up to 4 characters, but only after the 10th character suffix = s[10:][-4:] if suffix and show_tail: - result = f"{result[:-len(suffix)]}{suffix}" + result = f"{result[: -len(suffix)]}{suffix}" return result diff --git a/src/prefect/utilities/processutils.py b/src/prefect/utilities/processutils.py index 7951fac4cae5..2bcf33096149 100644 --- a/src/prefect/utilities/processutils.py +++ b/src/prefect/utilities/processutils.py @@ -256,8 +256,7 @@ async def run_process( task_status: anyio.abc.TaskStatus[T] = ..., task_status_handler: Callable[[anyio.abc.Process], T] = ..., **kwargs: Any, -) -> anyio.abc.Process: - ... +) -> anyio.abc.Process: ... @overload @@ -270,8 +269,7 @@ async def run_process( task_status: Optional[anyio.abc.TaskStatus[int]] = ..., task_status_handler: None = None, **kwargs: Any, -) -> anyio.abc.Process: - ... +) -> anyio.abc.Process: ... @overload @@ -284,8 +282,7 @@ async def run_process( task_status: Optional[anyio.abc.TaskStatus[T]] = None, task_status_handler: Optional[Callable[[anyio.abc.Process], T]] = None, **kwargs: Any, -) -> anyio.abc.Process: - ... +) -> anyio.abc.Process: ... async def run_process( diff --git a/src/prefect/utilities/pydantic.py b/src/prefect/utilities/pydantic.py index 381456ca23bc..4918bca4c360 100644 --- a/src/prefect/utilities/pydantic.py +++ b/src/prefect/utilities/pydantic.py @@ -52,15 +52,13 @@ def _unreduce_model(model_name: str, json: str) -> Any: @overload -def add_cloudpickle_reduction(__model_cls: type[M]) -> type[M]: - ...
+def add_cloudpickle_reduction(__model_cls: type[M]) -> type[M]: ... @overload def add_cloudpickle_reduction( __model_cls: None = None, **kwargs: Any -) -> Callable[[type[M]], type[M]]: - ... +) -> Callable[[type[M]], type[M]]: ... def add_cloudpickle_reduction( @@ -144,7 +142,7 @@ def add_type_dispatch(model_cls: type[M]) -> type[M]: elif not defines_dispatch_key and defines_type_field: field_type_annotation = model_cls.model_fields["type"].annotation - if field_type_annotation != str and field_type_annotation is not None: + if field_type_annotation is not str and field_type_annotation is not None: raise TypeError( f"Model class {model_cls.__name__!r} defines a 'type' field with " f"type {field_type_annotation.__name__!r} but it must be 'str'." @@ -169,7 +167,7 @@ def dispatch_key_from_type_field(cls: type[M]) -> str: def __init__(__pydantic_self__: M, **data: Any) -> None: type_string = ( get_dispatch_key(__pydantic_self__) - if type(__pydantic_self__) != model_cls + if type(__pydantic_self__) is not model_cls else "__base__" ) data.setdefault("type", type_string) diff --git a/src/prefect/utilities/templating.py b/src/prefect/utilities/templating.py index a009e7d03366..d9e86fc5c7ab 100644 --- a/src/prefect/utilities/templating.py +++ b/src/prefect/utilities/templating.py @@ -92,22 +92,19 @@ def find_placeholders(template: T) -> set[Placeholder]: @overload def apply_values( template: T, values: dict[str, Any], remove_notset: Literal[True] = True -) -> T: - ... +) -> T: ... @overload def apply_values( template: T, values: dict[str, Any], remove_notset: Literal[False] = False -) -> Union[T, type[NotSet]]: - ... +) -> Union[T, type[NotSet]]: ... @overload def apply_values( template: T, values: dict[str, Any], remove_notset: bool = False -) -> Union[T, type[NotSet]]: - ... +) -> Union[T, type[NotSet]]: ... def apply_values( diff --git a/src/prefect/utilities/visualization.py b/src/prefect/utilities/visualization.py index b149fa42806e..8f29827c3285 100644 --- a/src/prefect/utilities/visualization.py +++ b/src/prefect/utilities/visualization.py @@ -42,8 +42,7 @@ def track_viz_task( task_name: str, parameters: dict[str, Any], viz_return_value: Optional[Any] = None, -) -> Coroutine[Any, Any, Any]: - ... +) -> Coroutine[Any, Any, Any]: ... @overload @@ -52,8 +51,7 @@ def track_viz_task( task_name: str, parameters: dict[str, Any], viz_return_value: Optional[Any] = None, -) -> Any: - ... +) -> Any: ... def track_viz_task( diff --git a/src/prefect/workers/base.py b/src/prefect/workers/base.py index 16b0d4534945..3110a0b1f598 100644 --- a/src/prefect/workers/base.py +++ b/src/prefect/workers/base.py @@ -939,9 +939,9 @@ async def _check_flow_run(self, flow_run: "FlowRun") -> None: was created from a deployment with a storage block. """ if flow_run.deployment_id: - assert ( - self._client and self._client._started - ), "Client must be started to check flow run deployment." + assert self._client and self._client._started, ( + "Client must be started to check flow run deployment." + ) deployment = await self._client.read_deployment(flow_run.deployment_id) if deployment.storage_document_id: raise ValueError( diff --git a/src/prefect/workers/block.py b/src/prefect/workers/block.py index 7d5d3c1309c2..e0d5ba07c514 100644 --- a/src/prefect/workers/block.py +++ b/src/prefect/workers/block.py @@ -1,6 +1,7 @@ """ 2024-06-27: This surfaces an actionable error message for moved or removed objects in Prefect 3.0 upgrade. 
""" + from typing import Any, Callable from prefect._internal.compatibility.migration import getattr_migration diff --git a/src/prefect/workers/cloud.py b/src/prefect/workers/cloud.py index 7d5d3c1309c2..e0d5ba07c514 100644 --- a/src/prefect/workers/cloud.py +++ b/src/prefect/workers/cloud.py @@ -1,6 +1,7 @@ """ 2024-06-27: This surfaces an actionable error message for moved or removed objects in Prefect 3.0 upgrade. """ + from typing import Any, Callable from prefect._internal.compatibility.migration import getattr_migration diff --git a/src/prefect/workers/process.py b/src/prefect/workers/process.py index e0059d185395..c65e54354e36 100644 --- a/src/prefect/workers/process.py +++ b/src/prefect/workers/process.py @@ -13,6 +13,7 @@ For more information about work pools and workers, checkout out the [Prefect docs](/concepts/work-pools/). """ + from __future__ import annotations import contextlib diff --git a/tests/_internal/compatibility/test_async_dispatch.py b/tests/_internal/compatibility/test_async_dispatch.py index c2427a728d3e..6d76e0f9feea 100644 --- a/tests/_internal/compatibility/test_async_dispatch.py +++ b/tests/_internal/compatibility/test_async_dispatch.py @@ -207,15 +207,15 @@ def check_context() -> None: try: loop.call_soon(check_context) loop.run_forever() - assert ( - result is True - ), "the result we captured while loop was running should be True" + assert result is True, ( + "the result we captured while loop was running should be True" + ) finally: loop.close() asyncio.set_event_loop(None) - assert ( - is_in_async_context() is False - ), "the loop should be closed and not considered an async context" + assert is_in_async_context() is False, ( + "the loop should be closed and not considered an async context" + ) class TestIsInARunContext: diff --git a/tests/_internal/concurrency/test_services.py b/tests/_internal/concurrency/test_services.py index 15cea7996acc..48bc432417ba 100644 --- a/tests/_internal/concurrency/test_services.py +++ b/tests/_internal/concurrency/test_services.py @@ -483,8 +483,7 @@ def test_queue_service_start_failure_contains_traceback_only_at_debug( class ExceptionOnHandleService(QueueService[int]): exception_msg = "Oh no!" - async def _handle(self): - ... + async def _handle(self): ... 
async def _main_loop(self): raise Exception(self.exception_msg) diff --git a/tests/_internal/concurrency/test_waiters.py b/tests/_internal/concurrency/test_waiters.py index 894da9040b14..7233795cc0d4 100644 --- a/tests/_internal/concurrency/test_waiters.py +++ b/tests/_internal/concurrency/test_waiters.py @@ -129,9 +129,9 @@ def test_sync_waiter_timeout_in_worker_thread(): assert t1 - t0 < 2 assert call.cancelled() - assert ( - done_callback.result(timeout=0) == 1 - ), "The done callback should still be called on cancel" + assert done_callback.result(timeout=0) == 1, ( + "The done callback should still be called on cancel" + ) @pytest.mark.skip(reason="This test is flaky and should be rewritten") @@ -170,9 +170,9 @@ def on_worker_thread(): assert t1 - t0 < 2 assert waiting_callback.cancelled() assert call.cancelled() - assert ( - done_callback.result(timeout=0) == 1 - ), "The done callback should still be called on cancel" + assert done_callback.result(timeout=0) == 1, ( + "The done callback should still be called on cancel" + ) async def test_async_waiter_timeout_in_worker_thread(): @@ -196,9 +196,9 @@ async def test_async_waiter_timeout_in_worker_thread(): call.result() assert call.cancelled() - assert ( - done_callback.result(timeout=0) == 1 - ), "The done callback should still be called on cancel" + assert done_callback.result(timeout=0) == 1, ( + "The done callback should still be called on cancel" + ) @pytest.mark.skip(reason="This test hangs and should be rewritten") @@ -232,9 +232,9 @@ def on_worker_thread(): assert t1 - t0 < 2 assert call.cancelled() assert waiting_callback.cancelled() - assert ( - done_callback.result(timeout=0) == 1 - ), "The done callback should still be called on cancel" + assert done_callback.result(timeout=0) == 1, ( + "The done callback should still be called on cancel" + ) async def test_async_waiter_timeout_in_worker_thread_mixed_sleeps(): @@ -283,9 +283,9 @@ async def test_async_waiter_base_exception_in_worker_thread(exception_cls, raise with pytest.raises(exception_cls, match="test"): call.result() - assert ( - done_callback.result(timeout=0) == 1 - ), "The done callback should still be called on exception" + assert done_callback.result(timeout=0) == 1, ( + "The done callback should still be called on exception" + ) @pytest.mark.parametrize("raise_fn", [raises, araises], ids=["sync", "async"]) @@ -318,9 +318,9 @@ def on_worker_thread(): with pytest.raises(exception_cls, match="test"): callback.result() - assert ( - done_callback.result(timeout=0) == 1 - ), "The done callback should still be called on exception" + assert done_callback.result(timeout=0) == 1, ( + "The done callback should still be called on exception" + ) @pytest.mark.parametrize("raise_fn", [raises, araises], ids=["sync", "async"]) @@ -343,9 +343,9 @@ def test_sync_waiter_base_exception_in_worker_thread(exception_cls, raise_fn): with pytest.raises(exception_cls, match="test"): call.result() - assert ( - done_callback.result(timeout=0) == 1 - ), "The done callback should still be called on exception" + assert done_callback.result(timeout=0) == 1, ( + "The done callback should still be called on exception" + ) @pytest.mark.parametrize("raise_fn", [raises, araises], ids=["sync", "async"]) @@ -376,6 +376,6 @@ def on_worker_thread(): # to the main thread should have the error with pytest.raises(exception_cls, match="test"): callback.result() - assert ( - done_callback.result(timeout=0) == 1 - ), "The done callback should still be called on exception" + assert done_callback.result(timeout=0) 
== 1, ( + "The done callback should still be called on exception" + ) diff --git a/tests/_internal/test_integrations.py b/tests/_internal/test_integrations.py index e44f28e4eeaf..871992db39ee 100644 --- a/tests/_internal/test_integrations.py +++ b/tests/_internal/test_integrations.py @@ -42,6 +42,6 @@ def test_known_extras_for_packages(): # Check each entry in known_extras for package, extra in KNOWN_EXTRAS_FOR_PACKAGES.items(): extra_name = extra.split("[")[1][:-1] - assert ( - extra_name in extras_require - ), f"Extra '{extra_name}' for package '{package}' not found in setup.py" + assert extra_name in extras_require, ( + f"Extra '{extra_name}' for package '{package}' not found in setup.py" + ) diff --git a/tests/blocks/test_core.py b/tests/blocks/test_core.py index d70c6d12cc16..88524bf5f26e 100644 --- a/tests/blocks/test_core.py +++ b/tests/blocks/test_core.py @@ -666,7 +666,7 @@ def test_from_block_document(self, block_type_x): ) block = Block._from_block_document(api_block) - assert type(block) == self.MyRegisteredBlock + assert type(block) is self.MyRegisteredBlock assert block.x == "x" assert block._block_schema_id == block_schema_id assert block._block_document_id == api_block.id @@ -683,7 +683,7 @@ def test_from_block_document_anonymous(self, block_type_x): ) block = Block._from_block_document(api_block) - assert type(block) == self.MyRegisteredBlock + assert type(block) is self.MyRegisteredBlock assert block.x == "x" assert block._block_schema_id == block_schema_id assert block._block_document_id == api_block.id @@ -704,7 +704,7 @@ class BlockyMcBlock(Block): ) block = BlockyMcBlock._from_block_document(api_block) - assert type(block) == BlockyMcBlock + assert type(block) is BlockyMcBlock assert block.fizz == "buzz" assert block._block_schema_id == block_schema_id assert block._block_document_id == api_block.id @@ -763,9 +763,9 @@ def test_create_block_schema_from_block_without_capabilities( assert block_schema.checksum == test_block._calculate_schema_checksum() assert block_schema.fields == test_block.model_json_schema() - assert ( - block_schema.capabilities == [] - ), "No capabilities should be defined for this Block and defaults to []" + assert block_schema.capabilities == [], ( + "No capabilities should be defined for this Block and defaults to []" + ) assert block_schema.version == DEFAULT_BLOCK_SCHEMA_VERSION def test_create_block_schema_from_block_with_capabilities( @@ -775,9 +775,9 @@ def test_create_block_schema_from_block_with_capabilities( assert block_schema.checksum == test_block._calculate_schema_checksum() assert block_schema.fields == test_block.model_json_schema() - assert ( - block_schema.capabilities == [] - ), "No capabilities should be defined for this Block and defaults to []" + assert block_schema.capabilities == [], ( + "No capabilities should be defined for this Block and defaults to []" + ) assert block_schema.version == DEFAULT_BLOCK_SCHEMA_VERSION def test_create_block_schema_with_no_version_specified( @@ -2514,24 +2514,24 @@ async def test_block_type_slug_excluded_from_document(self, prefect_client): def test_base_parse_works_for_base_instance(self): block = BaseBlock.model_validate(BaseBlock().model_dump()) - assert type(block) == BaseBlock + assert type(block) is BaseBlock block = BaseBlock.model_validate(BaseBlock().model_dump()) - assert type(block) == BaseBlock + assert type(block) is BaseBlock def test_base_parse_creates_child_instance_from_dict(self): block = BaseBlock.model_validate(AChildBlock().model_dump()) - assert type(block) == 
AChildBlock + assert type(block) is AChildBlock block = BaseBlock.model_validate(BChildBlock().model_dump()) - assert type(block) == BChildBlock + assert type(block) is BChildBlock def test_base_parse_creates_child_instance_from_json(self): block = BaseBlock.model_validate_json(AChildBlock().model_dump_json()) - assert type(block) == AChildBlock + assert type(block) is AChildBlock block = BaseBlock.model_validate_json(BChildBlock().model_dump_json()) - assert type(block) == BChildBlock + assert type(block) is BChildBlock def test_base_parse_retains_default_attributes(self): block = BaseBlock.model_validate(AChildBlock().model_dump()) @@ -2550,17 +2550,17 @@ def test_base_parse_retains_set_base_attributes(self): def test_base_field_creates_child_instance_from_object(self): model = ParentModel(block=AChildBlock()) - assert type(model.block) == AChildBlock + assert type(model.block) is AChildBlock model = ParentModel(block=BChildBlock()) - assert type(model.block) == BChildBlock + assert type(model.block) is BChildBlock def test_base_field_creates_child_instance_from_dict(self): model = ParentModel(block=AChildBlock().model_dump()) - assert type(model.block) == AChildBlock + assert type(model.block) is AChildBlock model = ParentModel(block=BChildBlock().model_dump()) - assert type(model.block) == BChildBlock + assert type(model.block) is BChildBlock def test_created_block_has_pydantic_attributes(self): block = BaseBlock.model_validate(AChildBlock().model_dump()) @@ -2602,30 +2602,30 @@ class UnionParentModel(BaseModel): block: Union[AChildBlock, BChildBlock] model = UnionParentModel(block=AChildBlock(a=3).model_dump()) - assert type(model.block) == AChildBlock + assert type(model.block) is AChildBlock # Assignment with a copy works still model.block = model.block.model_copy() - assert type(model.block) == AChildBlock + assert type(model.block) is AChildBlock assert model.block model = UnionParentModel(block=BChildBlock(b=4).model_dump()) - assert type(model.block) == BChildBlock + assert type(model.block) is BChildBlock def test_base_field_creates_child_instance_with_assignment_validation(self): class AssignmentParentModel(BaseModel, validate_assignment=True): block: BaseBlock model = AssignmentParentModel(block=AChildBlock(a=3).model_dump()) - assert type(model.block) == AChildBlock + assert type(model.block) is AChildBlock assert model.block.a == 3 model.block = model.block.model_copy() - assert type(model.block) == AChildBlock + assert type(model.block) is AChildBlock assert model.block.a == 3 model.block = BChildBlock(b=4).model_dump() - assert type(model.block) == BChildBlock + assert type(model.block) is BChildBlock assert model.block.b == 4 diff --git a/tests/cli/test_config.py b/tests/cli/test_config.py index 76ce8b8e9e17..f134ac0b6136 100644 --- a/tests/cli/test_config.py +++ b/tests/cli/test_config.py @@ -129,9 +129,9 @@ def test_set_with_invalid_value_type(): ) profiles = load_profiles() - assert ( - PREFECT_API_DATABASE_TIMEOUT not in profiles["foo"].settings - ), "The setting should not be saved" + assert PREFECT_API_DATABASE_TIMEOUT not in profiles["foo"].settings, ( + "The setting should not be saved" + ) def test_set_with_unparsable_setting(): @@ -363,9 +363,9 @@ def test_view_excludes_unset_settings_without_show_defaults_flag(monkeypatch): lines = res.stdout.splitlines() assert "PREFECT_PROFILE='foo'" in lines - assert len(expected) < len( - _get_valid_setting_names(prefect.settings.Settings) - ), "All settings were not expected; we should only have a subset." 
+ assert len(expected) < len(_get_valid_setting_names(prefect.settings.Settings)), ( + "All settings were not expected; we should only have a subset." + ) @pytest.mark.skip("TODO") @@ -380,9 +380,9 @@ def test_view_includes_unset_settings_with_show_defaults(): printed_settings = {} for line in lines[1:]: setting, value = line.split("=", maxsplit=1) - assert ( - setting not in printed_settings - ), f"Setting displayed multiple times: {setting}" + assert setting not in printed_settings, ( + f"Setting displayed multiple times: {setting}" + ) printed_settings[setting] = value assert printed_settings.keys() == _get_valid_setting_names( @@ -395,13 +395,12 @@ def test_view_includes_unset_settings_with_show_defaults(): "REST OF SECRETS", ): # TODO: clean this up continue - assert ( - value - == ( - expected_value - := f"'{expected_settings[prefect.settings.env_var_to_accessor(key)]}'" - ) - ), f"Displayed setting does not match set value: {key} = {value} != {expected_value}" + assert value == ( + expected_value + := f"'{expected_settings[prefect.settings.env_var_to_accessor(key)]}'" + ), ( + f"Displayed setting does not match set value: {key} = {value} != {expected_value}" + ) @pytest.mark.parametrize( @@ -435,9 +434,9 @@ def test_view_shows_setting_sources(monkeypatch, command): for line in lines[i + 1 :]: # Assert that each line ends with a source - assert any( - line.endswith(s) for s in [FROM_DEFAULT, FROM_PROFILE, FROM_ENV] - ), f"Source missing from line: {line}" + assert any(line.endswith(s) for s in [FROM_DEFAULT, FROM_PROFILE, FROM_ENV]), ( + f"Source missing from line: {line}" + ) # Assert that sources are correct assert f"PREFECT_API_DATABASE_TIMEOUT='2.0' {FROM_PROFILE}" in lines diff --git a/tests/cli/test_deploy.py b/tests/cli/test_deploy.py index 361cbbc56b84..773407e302e5 100644 --- a/tests/cli/test_deploy.py +++ b/tests/cli/test_deploy.py @@ -2501,9 +2501,9 @@ async def test_rrule_deployment_yaml( with prefect_file.open(mode="r") as f: deploy_config = yaml.safe_load(f) - deploy_config["deployments"][0]["schedule"][ - "rrule" - ] = "DTSTART:20220910T110000\nRRULE:FREQ=HOURLY;BYDAY=MO,TU,WE,TH,FR,SA;BYHOUR=9,10,11,12,13,14,15,16,17" + deploy_config["deployments"][0]["schedule"]["rrule"] = ( + "DTSTART:20220910T110000\nRRULE:FREQ=HOURLY;BYDAY=MO,TU,WE,TH,FR,SA;BYHOUR=9,10,11,12,13,14,15,16,17" + ) with prefect_file.open(mode="w") as f: yaml.safe_dump(deploy_config, f) diff --git a/tests/cli/test_global_concurrency_limit.py b/tests/cli/test_global_concurrency_limit.py index 02da8f5e4054..f4fdac50cad3 100644 --- a/tests/cli/test_global_concurrency_limit.py +++ b/tests/cli/test_global_concurrency_limit.py @@ -366,9 +366,9 @@ async def test_update_gcl_active_slots( name=global_concurrency_limit.name ) - assert ( - client_res.active_slots == 10 - ), f"Expected active slots to be 10, got {client_res.active_slots}" + assert client_res.active_slots == 10, ( + f"Expected active slots to be 10, got {client_res.active_slots}" + ) async def test_update_gcl_slot_decay_per_second( @@ -393,9 +393,9 @@ async def test_update_gcl_slot_decay_per_second( name=global_concurrency_limit.name ) - assert ( - client_res.slot_decay_per_second == 0.5 - ), f"Expected slot decay per second to be 0.5, got {client_res.slot_decay_per_second}" + assert client_res.slot_decay_per_second == 0.5, ( + f"Expected slot decay per second to be 0.5, got {client_res.slot_decay_per_second}" + ) async def test_update_gcl_multiple_fields( @@ -424,12 +424,12 @@ async def test_update_gcl_multiple_fields( 
name=global_concurrency_limit.name ) - assert ( - client_res.active_slots == 10 - ), f"Expected active slots to be 10, got {client_res.active_slots}" - assert ( - client_res.slot_decay_per_second == 0.5 - ), f"Expected slot decay per second to be 0.5, got {client_res.slot_decay_per_second}" + assert client_res.active_slots == 10, ( + f"Expected active slots to be 10, got {client_res.active_slots}" + ) + assert client_res.slot_decay_per_second == 0.5, ( + f"Expected slot decay per second to be 0.5, got {client_res.slot_decay_per_second}" + ) async def test_update_gcl_to_inactive( @@ -453,9 +453,9 @@ async def test_update_gcl_to_inactive( name=global_concurrency_limit.name ) - assert ( - client_res.active is False - ), f"Expected active to be False, got {client_res.active}" + assert client_res.active is False, ( + f"Expected active to be False, got {client_res.active}" + ) async def test_update_gcl_to_active( @@ -479,9 +479,9 @@ async def test_update_gcl_to_active( name=global_concurrency_limit.name ) - assert ( - client_res.active is True - ), f"Expected active to be True, got {client_res.active}" + assert client_res.active is True, ( + f"Expected active to be True, got {client_res.active}" + ) def test_update_gcl_not_found(): @@ -522,16 +522,16 @@ async def test_create_gcl( client_res = await prefect_client.read_global_concurrency_limit_by_name(name="test") - assert ( - client_res.name == "test" - ), f"Expected name to be 'test', got {client_res.name}" + assert client_res.name == "test", ( + f"Expected name to be 'test', got {client_res.name}" + ) assert client_res.limit == 10, f"Expected limit to be 10, got {client_res.limit}" - assert ( - client_res.active_slots == 10 - ), f"Expected active slots to be 10, got {client_res.active_slots}" - assert ( - client_res.slot_decay_per_second == 0.5 - ), f"Expected slot decay per second to be 0.5, got {client_res.slot_decay_per_second}" + assert client_res.active_slots == 10, ( + f"Expected active slots to be 10, got {client_res.active_slots}" + ) + assert client_res.slot_decay_per_second == 0.5, ( + f"Expected slot decay per second to be 0.5, got {client_res.slot_decay_per_second}" + ) async def test_create_gcl_no_fields(): @@ -640,13 +640,13 @@ async def test_create_gcl_succeeds( client_res = await prefect_client.read_global_concurrency_limit_by_name(name="test") - assert ( - client_res.name == "test" - ), f"Expected name to be 'test', got {client_res.name}" + assert client_res.name == "test", ( + f"Expected name to be 'test', got {client_res.name}" + ) assert client_res.limit == 10, f"Expected limit to be 10, got {client_res.limit}" - assert ( - client_res.active_slots == 10 - ), f"Expected active slots to be 10, got {client_res.active_slots}" - assert ( - client_res.slot_decay_per_second == 0.5 - ), f"Expected slot decay per second to be 0.5, got {client_res.slot_decay_per_second}" + assert client_res.active_slots == 10, ( + f"Expected active slots to be 10, got {client_res.active_slots}" + ) + assert client_res.slot_decay_per_second == 0.5, ( + f"Expected slot decay per second to be 0.5, got {client_res.slot_decay_per_second}" + ) diff --git a/tests/cli/test_profile.py b/tests/cli/test_profile.py index 76df203b10cd..e0cb7b89bf57 100644 --- a/tests/cli/test_profile.py +++ b/tests/cli/test_profile.py @@ -451,9 +451,9 @@ def test_rename_profile_renames_profile(): profiles = load_profiles() assert "foo" not in profiles, "The original profile should not exist anymore" - assert profiles["bar"].settings == { - PREFECT_API_KEY: "foo" - }, "Settings 
should be retained" + assert profiles["bar"].settings == {PREFECT_API_KEY: "foo"}, ( + "Settings should be retained" + ) assert profiles.active_name != "bar", "The active profile should not be changed" @@ -500,9 +500,9 @@ def test_rename_profile_warns_on_environment_variable_active_profile(monkeypatch ) profiles = load_profiles() - assert ( - profiles.active_name != "foo" - ), "The active profile should not be updated in the file" + assert profiles.active_name != "foo", ( + "The active profile should not be updated in the file" + ) def test_inspect_profile_unknown_name(): diff --git a/tests/cli/test_start_server.py b/tests/cli/test_start_server.py index 20541ee4d061..18416bcda663 100644 --- a/tests/cli/test_start_server.py +++ b/tests/cli/test_start_server.py @@ -131,9 +131,9 @@ def test_start_and_stop_background_server(self, unused_tcp_port: int): expected_code=0, ) - assert not ( - PREFECT_HOME.value() / "server.pid" - ).exists(), "Server PID file exists" + assert not (PREFECT_HOME.value() / "server.pid").exists(), ( + "Server PID file exists" + ) def test_start_duplicate_background_server( self, unused_tcp_port_factory: Callable[[], int] @@ -220,9 +220,9 @@ def test_stop_stale_pid_file(self, unused_tcp_port: int): expected_code=0, ) - assert not ( - PREFECT_HOME.value() / "server.pid" - ).exists(), "Server PID file exists" + assert not (PREFECT_HOME.value() / "server.pid").exists(), ( + "Server PID file exists" + ) @pytest.mark.service("process") @@ -237,9 +237,9 @@ async def test_sigint_shutsdown_cleanly(self): with anyio.fail_after(SHUTDOWN_TIMEOUT): exit_code = await server_process.wait() - assert ( - exit_code == 0 - ), "After one sigint, the process should exit successfully" + assert exit_code == 0, ( + "After one sigint, the process should exit successfully" + ) server_process.out.seek(0) out = server_process.out.read().decode() @@ -259,9 +259,9 @@ async def test_sigterm_shutsdown_cleanly(self): with anyio.fail_after(SHUTDOWN_TIMEOUT): exit_code = await server_process.wait() - assert ( - exit_code == -signal.SIGTERM - ), "After a sigterm, the server process should indicate it was terminated" + assert exit_code == -signal.SIGTERM, ( + "After a sigterm, the server process should indicate it was terminated" + ) server_process.out.seek(0) out = server_process.out.read().decode() @@ -281,9 +281,9 @@ async def test_ctrl_break_shutsdown_cleanly(self): with anyio.fail_after(SHUTDOWN_TIMEOUT): exit_code = await server_process.wait() - assert ( - exit_code == 0 - ), "After a ctrl-break, the process should exit successfully" + assert exit_code == 0, ( + "After a ctrl-break, the process should exit successfully" + ) server_process.out.seek(0) out = server_process.out.read().decode() diff --git a/tests/cli/test_version.py b/tests/cli/test_version.py index 5be9cbaf15b4..a572b34137bd 100644 --- a/tests/cli/test_version.py +++ b/tests/cli/test_version.py @@ -77,7 +77,7 @@ def test_correct_output_ephemeral_sqlite(monkeypatch, disable_hosted_api_server) Version: {prefect.__version__} API version: {SERVER_API_VERSION} Python version: {platform.python_version()} - Git commit: {version_info['full-revisionid'][:8]} + Git commit: {version_info["full-revisionid"][:8]} Built: {built.to_day_datetime_string()} OS/Arch: {sys.platform}/{platform.machine()} Profile: {profile.name} @@ -112,7 +112,7 @@ def test_correct_output_ephemeral_postgres(monkeypatch, disable_hosted_api_serve Version: {prefect.__version__} API version: {SERVER_API_VERSION} Python version: {platform.python_version()} - Git commit: 
{version_info['full-revisionid'][:8]} + Git commit: {version_info["full-revisionid"][:8]} Built: {built.to_day_datetime_string()} OS/Arch: {sys.platform}/{platform.machine()} Profile: {profile.name} @@ -136,7 +136,7 @@ def test_correct_output_non_ephemeral_server_type(): expected_output=f"""Version: {prefect.__version__} API version: {SERVER_API_VERSION} Python version: {platform.python_version()} -Git commit: {version_info['full-revisionid'][:8]} +Git commit: {version_info["full-revisionid"][:8]} Built: {built.to_day_datetime_string()} OS/Arch: {sys.platform}/{platform.machine()} Profile: {profile.name} diff --git a/tests/client/test_client_routes.py b/tests/client/test_client_routes.py index 9e8528c3b190..304307266fda 100644 --- a/tests/client/test_client_routes.py +++ b/tests/client/test_client_routes.py @@ -26,6 +26,6 @@ def test_server_routes_match_openapi_schema(): print() # Add blank line between routes for readability # Assert ServerRoutes are subset of OpenAPI paths - assert ( - not missing_routes - ), f"{len(missing_routes)} routes are missing from OpenAPI schema" + assert not missing_routes, ( + f"{len(missing_routes)} routes are missing from OpenAPI schema" + ) diff --git a/tests/client/test_prefect_client.py b/tests/client/test_prefect_client.py index 2be38114d9ad..40f03d92b141 100644 --- a/tests/client/test_prefect_client.py +++ b/tests/client/test_prefect_client.py @@ -2396,9 +2396,9 @@ async def test_read_automations_by_name_multiple_same_name( ) assert read_route.called - assert ( - len(read_automation) == 2 - ), "Expected two automations with the same name" + assert len(read_automation) == 2, ( + "Expected two automations with the same name" + ) assert all( [ automation.name == created_automation["name"] diff --git a/tests/conftest.py b/tests/conftest.py index 1851a585212d..ae3b94ed7e5b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -392,13 +392,13 @@ def cleanup(drain_log_workers, drain_events_workers): @pytest.fixture(scope="session", autouse=True) def safety_check_settings(): # Safety check for connection to an external API - assert ( - PREFECT_API_URL.value() is None - ), "Tests should not be run connected to an external API." + assert PREFECT_API_URL.value() is None, ( + "Tests should not be run connected to an external API." 
+ ) # Safety check for home directory - assert ( - str(PREFECT_HOME.value()) == TEST_PREFECT_HOME - ), "Tests should use the temporary test directory" + assert str(PREFECT_HOME.value()) == TEST_PREFECT_HOME, ( + "Tests should use the temporary test directory" + ) @pytest.fixture(scope="session", autouse=True) diff --git a/tests/events/jinja_filters/test_ui_url.py b/tests/events/jinja_filters/test_ui_url.py index 8e8149fa487f..f8f3e680c2f1 100644 --- a/tests/events/jinja_filters/test_ui_url.py +++ b/tests/events/jinja_filters/test_ui_url.py @@ -76,7 +76,7 @@ def test_automation_url(chonk_party: Automation): rendered = template.render({"automation": chonk_party}) assert rendered == ( - "http://localhost:3000" f"/automations/automation/{chonk_party.id}" + f"http://localhost:3000/automations/automation/{chonk_party.id}" ) @@ -109,7 +109,7 @@ def test_flow_resource_url(chonk_party: Automation): } ) - assert rendered == ("http://localhost:3000" f"/flows/flow/{flow_id}") + assert rendered == (f"http://localhost:3000/flows/flow/{flow_id}") def test_flow_run_resource_url(chonk_party: Automation): diff --git a/tests/events/server/actions/test_jinja_templated_action.py b/tests/events/server/actions/test_jinja_templated_action.py index 793c91a8ebb8..0c2c1c07aed7 100644 --- a/tests/events/server/actions/test_jinja_templated_action.py +++ b/tests/events/server/actions/test_jinja_templated_action.py @@ -1324,8 +1324,8 @@ async def test_work_pool_is_available_to_templates( assert ( rendered == f""" - Name: { work_queue.name } - Pool: { work_pool.name } + Name: {work_queue.name} + Pool: {work_pool.name} """ ) diff --git a/tests/events/server/test_events_queries.py b/tests/events/server/test_events_queries.py index 8e8371018f9e..62b66c27b2f4 100644 --- a/tests/events/server/test_events_queries.py +++ b/tests/events/server/test_events_queries.py @@ -179,7 +179,9 @@ def assert_events_ordered_descending(events: List[ReceivedEvent]): last = events[0] for i, event in enumerate(events[1:]): - assert event.occurred <= last.occurred, f"Event at index {i+1} is out of order" + assert event.occurred <= last.occurred, ( + f"Event at index {i + 1} is out of order" + ) last = event @@ -189,7 +191,9 @@ def assert_events_ordered_ascending(events: List[ReceivedEvent]): last = events[0] for i, event in enumerate(events[1:]): - assert event.occurred >= last.occurred, f"Event at index {i+1} is out of order" + assert event.occurred >= last.occurred, ( + f"Event at index {i + 1} is out of order" + ) last = event diff --git a/tests/experimental/test_lineage.py b/tests/experimental/test_lineage.py index a474743fe887..027cb808b3c7 100644 --- a/tests/experimental/test_lineage.py +++ b/tests/experimental/test_lineage.py @@ -577,7 +577,7 @@ async def test_emit_external_resource_lineage_with_context_resources( for i, call in enumerate(context_calls): assert call.kwargs["event"] == "prefect.lineage.upstream-interaction" assert call.kwargs["resource"] == { - "prefect.resource.id": f"context{i+1}", + "prefect.resource.id": f"context{i + 1}", "prefect.resource.role": "flow-run" if i == 0 else "flow", "prefect.resource.lineage-group": "global", } @@ -747,7 +747,7 @@ async def test_emit_result_read_event_with_downstream_resources( assert resource_uri is not None assert call.kwargs["event"] == "prefect.result.read" assert call.kwargs["resource"] == { - "prefect.resource.id": f"downstream{i+1}", + "prefect.resource.id": f"downstream{i + 1}", "prefect.resource.lineage-group": "global", } assert call.kwargs["related"] == [ diff --git 
a/tests/infrastructure/provisioners/test_coiled.py b/tests/infrastructure/provisioners/test_coiled.py index eca5cdceb11f..a83f688e099b 100644 --- a/tests/infrastructure/provisioners/test_coiled.py +++ b/tests/infrastructure/provisioners/test_coiled.py @@ -89,12 +89,6 @@ async def test_provision( mock_coiled, mock_coiled, ] - # simulate coiled token creation - mock_coiled.config.Config.return_value.get.side_effect = [ - None, - None, - "mock_token", - ] work_pool_name = "work-pool-name" base_job_template = {"variables": {"properties": {"credentials": {}}}} @@ -108,7 +102,7 @@ async def test_provision( "work-pool-name-coiled-credentials", "coiled-credentials" ) - assert block_document.data["api_token"], str == "mock_token" + assert block_document.data["api_token"] == "local-api-token-from-dask-config" # Check if the base job template was updated assert result["variables"]["properties"]["credentials"] == { @@ -160,9 +154,6 @@ async def test_provision_existing_coiled_credentials( mock_coiled, mock_coiled, ] # coiled is already installed - mock_coiled.config.Config.return_value.get.side_effect = [ - "mock_token", - ] # coiled config exists work_pool_name = "work-pool-name" base_job_template = {"variables": {"properties": {"credentials": {}}}} @@ -176,7 +167,7 @@ async def test_provision_existing_coiled_credentials( "work-pool-name-coiled-credentials", "coiled-credentials" ) - assert block_document.data["api_token"], str == "mock_token" + assert block_document.data["api_token"] == "local-api-token-from-dask-config" # Check if the base job template was updated assert result["variables"]["properties"]["credentials"] == { diff --git a/tests/infrastructure/provisioners/test_container_instance.py b/tests/infrastructure/provisioners/test_container_instance.py index 840b3f679a87..a6ef966e8bed 100644 --- a/tests/infrastructure/provisioners/test_container_instance.py +++ b/tests/infrastructure/provisioners/test_container_instance.py @@ -1537,9 +1537,9 @@ async def test_aci_provision_no_existing_credentials_block( "client_secret": "", } - new_base_job_template["variables"]["properties"]["subscription_id"][ - "default" - ] = "12345678-1234-1234-1234-123456789012" + new_base_job_template["variables"]["properties"]["subscription_id"]["default"] = ( + "12345678-1234-1234-1234-123456789012" + ) new_base_job_template["variables"]["properties"]["resource_group_name"][ "default" @@ -1839,13 +1839,13 @@ async def test_aci_provision_existing_credentials_block( ignore_if_exists=True, return_json=True, ) - assert ( - unexpected_call not in provisioner.azure_cli.run_command.mock_calls - ), "Unexpected call made: {call}" + assert unexpected_call not in provisioner.azure_cli.run_command.mock_calls, ( + f"Unexpected call made: {unexpected_call}" + ) - new_base_job_template["variables"]["properties"]["subscription_id"][ - "default" - ] = "12345678-1234-1234-1234-123456789012" + new_base_job_template["variables"]["properties"]["subscription_id"]["default"] = ( + "12345678-1234-1234-1234-123456789012" + ) new_base_job_template["variables"]["properties"]["resource_group_name"][ "default" @@ -2189,9 +2189,9 @@ async def test_aci_provision_interactive_default_provisioning( "client_secret": "", } - new_base_job_template["variables"]["properties"]["subscription_id"][ - "default" - ] = "12345678-1234-1234-1234-123456789012" + new_base_job_template["variables"]["properties"]["subscription_id"]["default"] = ( + "12345678-1234-1234-1234-123456789012" + ) new_base_job_template["variables"]["properties"]["resource_group_name"][ "default" @@ 
-2537,9 +2537,9 @@ def prompt_mocks(*args, **kwargs): "client_secret": "", } - new_base_job_template["variables"]["properties"]["subscription_id"][ - "default" - ] = "12345678-1234-1234-1234-123456789012" + new_base_job_template["variables"]["properties"]["subscription_id"]["default"] = ( + "12345678-1234-1234-1234-123456789012" + ) new_base_job_template["variables"]["properties"]["resource_group_name"][ "default" diff --git a/tests/infrastructure/provisioners/test_modal.py b/tests/infrastructure/provisioners/test_modal.py index 5ea24d580eee..b9ad382580b5 100644 --- a/tests/infrastructure/provisioners/test_modal.py +++ b/tests/infrastructure/provisioners/test_modal.py @@ -99,7 +99,7 @@ async def test_provision( ) assert block_document.data["token_id"] == "mock_id" - assert block_document.data["token_secret"], str == "mock_secret" + assert block_document.data["token_secret"] == "mock_secret" # Check if the base job template was updated assert result["variables"]["properties"]["modal_credentials"] == { @@ -177,7 +177,7 @@ async def test_provision_existing_modal_credentials( ) assert block_document.data["token_id"] == "mock_id" - assert block_document.data["token_secret"], str == "mock_secret" + assert block_document.data["token_secret"] == "mock_secret" # Check if the base job template was updated assert result["variables"]["properties"]["modal_credentials"] == { diff --git a/tests/public/flows/test_flow_with_mapped_tasks.py b/tests/public/flows/test_flow_with_mapped_tasks.py index 064f52364cbb..e9d4dbf8c9bd 100644 --- a/tests/public/flows/test_flow_with_mapped_tasks.py +++ b/tests/public/flows/test_flow_with_mapped_tasks.py @@ -10,7 +10,7 @@ def generate_task_run_name(parameters: dict) -> str: - names.append(f'{task_run.task_name} - input: {parameters["input"]["number"]}') + names.append(f"{task_run.task_name} - input: {parameters['input']['number']}") return names[-1] diff --git a/tests/results/test_flow_results.py b/tests/results/test_flow_results.py index 89f708ca24bf..3042a9afefad 100644 --- a/tests/results/test_flow_results.py +++ b/tests/results/test_flow_results.py @@ -384,7 +384,7 @@ async def foo(): assert result == {"foo": "bar"} local_storage = await LocalFileSystem.load("my-result-storage") - result_bytes = await local_storage.read_path(f"{tmp_path/'my-result.pkl'}") + result_bytes = await local_storage.read_path(f"{tmp_path / 'my-result.pkl'}") saved_python_result = ResultRecord.deserialize(result_bytes).result assert saved_python_result == {"foo": "bar"} diff --git a/tests/runner/test_storage.py b/tests/runner/test_storage.py index b72773d0f5d1..3911a1574502 100644 --- a/tests/runner/test_storage.py +++ b/tests/runner/test_storage.py @@ -221,9 +221,9 @@ async def test_clone_repo_sparse(self, mock_run_process: AsyncMock, monkeypatch) ] mock_run_process.assert_has_awaits(expected_calls) - assert ( - mock_run_process.await_args_list == expected_calls - ), f"Unexpected calls: {mock_run_process.await_args_list}" + assert mock_run_process.await_args_list == expected_calls, ( + f"Unexpected calls: {mock_run_process.await_args_list}" + ) async def test_clone_existing_repo_sparse( self, mock_run_process: AsyncMock, monkeypatch @@ -256,9 +256,9 @@ async def test_clone_existing_repo_sparse( ] mock_run_process.assert_has_awaits(expected_calls) - assert ( - mock_run_process.await_args_list == expected_calls - ), f"Unexpected calls: {mock_run_process.await_args_list}" + assert mock_run_process.await_args_list == expected_calls, ( + f"Unexpected calls: {mock_run_process.await_args_list}" + ) 
async def test_pull_code_with_username_and_password( self, diff --git a/tests/runner/test_webserver.py b/tests/runner/test_webserver.py index f2c3ce1e13ef..8d4aa084e87e 100644 --- a/tests/runner/test_webserver.py +++ b/tests/runner/test_webserver.py @@ -151,9 +151,10 @@ async def test_runners_deployment_run_route_execs_flow_run(self, runner: Runner) webserver = await build_server(runner) client = TestClient(webserver) - with mock.patch( - "prefect.runner.server.get_client", new=mock_get_client - ), mock.patch.object(runner, "execute_in_background"): + with ( + mock.patch("prefect.runner.server.get_client", new=mock_get_client), + mock.patch.object(runner, "execute_in_background"), + ): with client: response = client.post(f"/deployment/{deployment_id}/run") assert response.status_code == 201, response.json() diff --git a/tests/scripts/test_generate_lower_bounds.py b/tests/scripts/test_generate_lower_bounds.py index 53d5bc9d31ba..a4617cf0f101 100644 --- a/tests/scripts/test_generate_lower_bounds.py +++ b/tests/scripts/test_generate_lower_bounds.py @@ -1,6 +1,7 @@ -"""" +""" Tests scripts/generate-lower-bounds.py """ + import runpy import pytest diff --git a/tests/server/api/test_server.py b/tests/server/api/test_server.py index c86135c4aa33..67248b9e8a38 100644 --- a/tests/server/api/test_server.py +++ b/tests/server/api/test_server.py @@ -252,9 +252,9 @@ async def test_runs_wrapped_function_on_missing_key( self, current_block_registry_hash ): assert not PREFECT_MEMO_STORE_PATH.value().exists() - assert ( - PREFECT_MEMOIZE_BLOCK_AUTO_REGISTRATION.value() - ), "Memoization is not enabled" + assert PREFECT_MEMOIZE_BLOCK_AUTO_REGISTRATION.value(), ( + "Memoization is not enabled" + ) test_func = AsyncMock() @@ -277,9 +277,9 @@ async def test_runs_wrapped_function_on_mismatched_key( memo_store_with_mismatched_key, current_block_registry_hash, ): - assert ( - PREFECT_MEMOIZE_BLOCK_AUTO_REGISTRATION.value() - ), "Memoization is not enabled" + assert PREFECT_MEMOIZE_BLOCK_AUTO_REGISTRATION.value(), ( + "Memoization is not enabled" + ) test_func = AsyncMock() diff --git a/tests/server/database/test_dependencies.py b/tests/server/database/test_dependencies.py index 22804778e7e8..669be67d817d 100644 --- a/tests/server/database/test_dependencies.py +++ b/tests/server/database/test_dependencies.py @@ -30,34 +30,28 @@ async def test_injecting_an_existing_database_database_config(ConnectionConfig): with dependencies.temporary_database_config(ConnectionConfig(None)): db = dependencies.provide_database_interface() - assert type(db.database_config) == ConnectionConfig + assert type(db.database_config) is ConnectionConfig async def test_injecting_a_really_dumb_database_database_config(): class UselessConfiguration(BaseDatabaseConfiguration): - async def engine(self): - ... + async def engine(self): ... - async def session(self, engine): - ... + async def session(self, engine): ... - async def create_db(self, connection, base_metadata): - ... + async def create_db(self, connection, base_metadata): ... - async def drop_db(self, connection, base_metadata): - ... + async def drop_db(self, connection, base_metadata): ... - def is_inmemory(self, engine): - ... + def is_inmemory(self, engine): ... - async def begin_transaction(self, session, locked): - ... + async def begin_transaction(self, session, locked): ...
with dependencies.temporary_database_config( UselessConfiguration(connection_url=None) ): db = dependencies.provide_database_interface() - assert type(db.database_config) == UselessConfiguration + assert type(db.database_config) is UselessConfiguration @pytest.mark.parametrize( @@ -66,35 +60,28 @@ async def begin_transaction(self, session, locked): async def test_injecting_existing_query_components(QueryComponents): with dependencies.temporary_query_components(QueryComponents()): db = dependencies.provide_database_interface() - assert type(db.queries) == QueryComponents + assert type(db.queries) is QueryComponents async def test_injecting_really_dumb_query_components(): class ReallyBrokenQueries(BaseQueryComponents): # --- dialect-specific SqlAlchemy bindings - def insert(self, obj): - ... + def insert(self, obj): ... - def greatest(self, *values): - ... + def greatest(self, *values): ... - def least(self, *values): - ... + def least(self, *values): ... # --- dialect-specific JSON handling - def uses_json_strings(self) -> bool: - ... + def uses_json_strings(self) -> bool: ... - def cast_to_json(self, json_obj): - ... + def cast_to_json(self, json_obj): ... - def build_json_object(self, *args): - ... + def build_json_object(self, *args): ... - def json_arr_agg(self, json_array): - ... + def json_arr_agg(self, json_array): ... # --- dialect-optimized subqueries @@ -103,8 +90,7 @@ def make_timestamp_intervals( start_time, end_time, interval, - ): - ... + ): ... def set_state_id_on_inserted_flow_runs_statement( self, @@ -112,22 +98,18 @@ def set_state_id_on_inserted_flow_runs_statement( frs_model, inserted_flow_run_ids, insert_flow_run_states, - ): - ... + ): ... async def get_flow_run_notifications_from_queue(self, session, limit): pass def get_scheduled_flow_runs_from_work_queues( self, limit_per_queue, work_queue_ids, scheduled_before - ): - ... + ): ... - def _get_scheduled_flow_runs_from_work_pool_template_path(self): - ... + def _get_scheduled_flow_runs_from_work_pool_template_path(self): ... - def _build_flow_run_graph_v2_query(self): - ... + def _build_flow_run_graph_v2_query(self): ... 
async def flow_run_graph_v2( self, @@ -140,7 +122,7 @@ async def flow_run_graph_v2( with dependencies.temporary_query_components(ReallyBrokenQueries()): db = dependencies.provide_database_interface() - assert type(db.queries) == ReallyBrokenQueries + assert type(db.queries) is ReallyBrokenQueries @pytest.mark.parametrize( @@ -149,7 +131,7 @@ async def flow_run_graph_v2( async def test_injecting_existing_orm_configs(ORMConfig): with dependencies.temporary_orm_config(ORMConfig()): db = dependencies.provide_database_interface() - assert type(db.orm) == ORMConfig + assert type(db.orm) is ORMConfig async def test_inject_interface_class(): diff --git a/tests/server/database/test_migrations.py b/tests/server/database/test_migrations.py index 3045d6cf3514..57a465a65de8 100644 --- a/tests/server/database/test_migrations.py +++ b/tests/server/database/test_migrations.py @@ -182,9 +182,9 @@ async def test_backfill_state_name(db, flow): (str(flow_run_id), "foo", "My Custom Name"), (str(null_state_flow_run_id), "null state", None), ] - assert ( - expected_flow_runs == flow_runs - ), "state_name is backfilled for flow runs" + assert expected_flow_runs == flow_runs, ( + "state_name is backfilled for flow runs" + ) task_runs = [ (str(tr[0]), tr[1], tr[2]) @@ -200,9 +200,9 @@ async def test_backfill_state_name(db, flow): (str(task_run_id), "foo-task", "My Custom Name"), (str(null_state_task_run_id), "null-state-task", None), ] - assert ( - expected_task_runs == task_runs - ), "state_name is backfilled for task runs" + assert expected_task_runs == task_runs, ( + "state_name is backfilled for task runs" + ) finally: await run_sync_in_worker_thread(alembic_upgrade) @@ -323,9 +323,9 @@ async def test_backfill_artifacts(db): artifact["type"], artifact["description"], ) - assert ( - result == expected_result - ), "data migration populates artifact_collection table" + assert result == expected_result, ( + "data migration populates artifact_collection table" + ) finally: await run_sync_in_worker_thread(alembic_upgrade) diff --git a/tests/server/models/test_block_registration.py b/tests/server/models/test_block_registration.py index 9d9ea09a8512..5a017e6b6c45 100644 --- a/tests/server/models/test_block_registration.py +++ b/tests/server/models/test_block_registration.py @@ -50,9 +50,9 @@ async def test_full_registration_with_empty_database( assert len(registered_blocks) == expected_number_of_registered_block_types registered_block_slugs = {b.slug for b in registered_blocks} - assert PROTECTED_BLOCKS.issubset( - registered_block_slugs - ), "When changing protected blocks, edit PROTECTED_BLOCKS defined above" + assert PROTECTED_BLOCKS.issubset(registered_block_slugs), ( + "When changing protected blocks, edit PROTECTED_BLOCKS defined above" + ) assert sum(b.is_protected for b in registered_blocks) == len( PROTECTED_BLOCKS ), "When changing protected blocks, edit PROTECTED_BLOCKS defined above" diff --git a/tests/server/orchestration/api/test_deployments.py b/tests/server/orchestration/api/test_deployments.py index 35b23578ec9d..7266caedde0b 100644 --- a/tests/server/orchestration/api/test_deployments.py +++ b/tests/server/orchestration/api/test_deployments.py @@ -29,9 +29,9 @@ def assert_status_events(deployment_name: str, events: List[str]): if event.resource.name == deployment_name ] - assert len(events) == len( - deployment_specific_events - ), f"Expected events {events}, but found {deployment_specific_events}" + assert len(events) == len(deployment_specific_events), ( + f"Expected events {events}, but found 
{deployment_specific_events}" + ) for i, event in enumerate(events): assert event == deployment_specific_events[i].event @@ -1037,9 +1037,9 @@ async def test_create_deployment_with_concurrency_limit( assert response.status_code == status.HTTP_201_CREATED json_response = response.json() - assert ( - json_response["concurrency_limit"] is None - ), "Deprecated int-only field should be None for backwards-compatibility" + assert json_response["concurrency_limit"] is None, ( + "Deprecated int-only field should be None for backwards-compatibility" + ) global_concurrency_limit = json_response.get("global_concurrency_limit") assert global_concurrency_limit is not None @@ -1146,9 +1146,9 @@ async def test_read_deployment_with_concurrency_limit( assert response.status_code == status.HTTP_200_OK json_response = response.json() - assert ( - json_response["concurrency_limit"] is None - ), "Deprecated int-only field should be None for backwards-compatibility" + assert json_response["concurrency_limit"] is None, ( + "Deprecated int-only field should be None for backwards-compatibility" + ) global_concurrency_limit = json_response.get("global_concurrency_limit") assert global_concurrency_limit is not None diff --git a/tests/server/orchestration/api/test_flow_run_graph_v2.py b/tests/server/orchestration/api/test_flow_run_graph_v2.py index 38b4baf51ff0..5e3e74157140 100644 --- a/tests/server/orchestration/api/test_flow_run_graph_v2.py +++ b/tests/server/orchestration/api/test_flow_run_graph_v2.py @@ -1167,9 +1167,9 @@ async def flow_run_task_artifacts( flow_run, # db.FlowRun, flat_tasks, # list[db.TaskRun], ): # -> list[db.Artifact]: - assert ( - len(flat_tasks) >= 5 - ), "Setup error - this fixture expects to use at least 5 tasks" + assert len(flat_tasks) >= 5, ( + "Setup error - this fixture expects to use at least 5 tasks" + ) task_artifact = db.Artifact( flow_run_id=flow_run.id, @@ -1279,21 +1279,20 @@ async def test_reading_graph_for_flow_run_with_artifacts( flow_run_id=flow_run.id, ) - assert ( - graph.artifacts - == [ - GraphArtifact( - id=expected_top_level_artifact.id, - created=expected_top_level_artifact.created, - key=expected_top_level_artifact.key, - type=expected_top_level_artifact.type, - data=expected_top_level_artifact.data - if expected_top_level_artifact.type == "progress" - else None, - is_latest=True, - ) - ] - ), "Expected artifacts associated with the flow run but not with a task to be included at the roof of the graph." + assert graph.artifacts == [ + GraphArtifact( + id=expected_top_level_artifact.id, + created=expected_top_level_artifact.created, + key=expected_top_level_artifact.key, + type=expected_top_level_artifact.type, + data=expected_top_level_artifact.data + if expected_top_level_artifact.type == "progress" + else None, + is_latest=True, + ) + ], ( + "Expected artifacts associated with the flow run but not with a task to be included at the root of the graph."
+ ) expected_graph_artifacts = defaultdict(list) for task_artifact in flow_run_task_artifacts: @@ -1328,9 +1327,9 @@ async def test_artifacts_on_flow_run_graph_limited_by_setting( flow_run_task_artifacts, # List[db.Artifact], ): test_max_artifacts_setting = 2 - assert ( - len(flow_run_task_artifacts) > test_max_artifacts_setting - ), "Setup error - expected total # of graph artifacts to be greater than the limit being used for testing" + assert len(flow_run_task_artifacts) > test_max_artifacts_setting, ( + "Setup error - expected total # of graph artifacts to be greater than the limit being used for testing" + ) with temporary_settings( {PREFECT_API_MAX_FLOW_RUN_GRAPH_ARTIFACTS: test_max_artifacts_setting} diff --git a/tests/server/orchestration/api/test_flow_runs.py b/tests/server/orchestration/api/test_flow_runs.py index afe4c4da5fc8..ac22d3ddd485 100644 --- a/tests/server/orchestration/api/test_flow_runs.py +++ b/tests/server/orchestration/api/test_flow_runs.py @@ -1892,9 +1892,9 @@ async def test_set_flow_run_state_uses_deployment_concurrency_orchestration_for_ "headers": {}, } if client_version: - post_kwargs["headers"][ - "User-Agent" - ] = f"prefect/{client_version} (API 2.19.3)" + post_kwargs["headers"]["User-Agent"] = ( + f"prefect/{client_version} (API 2.19.3)" + ) response = await client.post( f"/flow_runs/{flow_run_with_concurrency_limit.id}/set_state", **post_kwargs, @@ -1998,9 +1998,9 @@ async def test_history_interval_must_be_one_second_or_larger(self, client): history_interval_seconds=0.9, ), ) - assert ( - response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - ), response.text + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, ( + response.text + ) assert b"History interval must not be less than 1 second" in response.content @@ -2895,6 +2895,6 @@ async def test_download_flow_run_logs( # number of logs generated plus 1 for the header row expected_line_count = len(flow_run_1_logs) + 1 - assert ( - line_count == expected_line_count - ), f"Expected {expected_line_count} lines, got {line_count}" + assert line_count == expected_line_count, ( + f"Expected {expected_line_count} lines, got {line_count}" + ) diff --git a/tests/server/orchestration/api/test_task_run_subscriptions.py b/tests/server/orchestration/api/test_task_run_subscriptions.py index 08d99137259e..af0fb9cf9d80 100644 --- a/tests/server/orchestration/api/test_task_run_subscriptions.py +++ b/tests/server/orchestration/api/test_task_run_subscriptions.py @@ -326,8 +326,9 @@ async def test_task_queue_scheduled_size_limit(self): ) await queue.put(task_run) - with patch("asyncio.sleep", return_value=None), pytest.raises( - asyncio.TimeoutError + with ( + patch("asyncio.sleep", return_value=None), + pytest.raises(asyncio.TimeoutError), ): extra_task_run = ServerTaskRun( id=uuid4(), @@ -337,9 +338,9 @@ async def test_task_queue_scheduled_size_limit(self): ) await asyncio.wait_for(queue.put(extra_task_run), timeout=0.01) - assert ( - queue._scheduled_queue.qsize() == max_scheduled_size - ), "Queue size should be at its configured limit" + assert queue._scheduled_queue.qsize() == max_scheduled_size, ( + "Queue size should be at its configured limit" + ) async def test_task_queue_retry_size_limit(self): task_key = "test_retry_limit" @@ -356,8 +357,9 @@ async def test_task_queue_retry_size_limit(self): ) await queue.retry(task_run) - with patch("asyncio.sleep", return_value=None), pytest.raises( - asyncio.TimeoutError + with ( + patch("asyncio.sleep", return_value=None), + 
pytest.raises(asyncio.TimeoutError), ): extra_task_run = ServerTaskRun( id=uuid4(), @@ -367,9 +369,9 @@ async def test_task_queue_retry_size_limit(self): ) await asyncio.wait_for(queue.retry(extra_task_run), timeout=0.01) - assert ( - queue._retry_queue.qsize() == max_retry_size - ), "Retry queue size should be at its configured limit" + assert queue._retry_queue.qsize() == max_retry_size, ( + "Retry queue size should be at its configured limit" + ) @pytest.fixture diff --git a/tests/server/orchestration/api/test_workers.py b/tests/server/orchestration/api/test_workers.py index ead9fd698004..07b86976a502 100644 --- a/tests/server/orchestration/api/test_workers.py +++ b/tests/server/orchestration/api/test_workers.py @@ -144,16 +144,16 @@ async def test_create_duplicate_work_pool(self, client, work_pool): @pytest.mark.parametrize("name", ["", "hi/there", "hi%there"]) async def test_create_work_pool_with_invalid_name(self, client, name): response = await client.post("/work_pools/", json=dict(name=name)) - assert ( - response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - ), response.text + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, ( + response.text + ) @pytest.mark.parametrize("name", ["''", " ", "' ' "]) async def test_create_work_pool_with_emptyish_name(self, client, name): response = await client.post("/work_pools/", json=dict(name=name)) - assert ( - response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - ), response.text + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, ( + response.text + ) assert "name cannot be an empty string" in response.content.decode() @pytest.mark.parametrize("type", ["PROCESS", "K8S", "AGENT"]) @@ -176,9 +176,9 @@ async def test_create_work_pool_template_validation_missing_keys(self, client): "/work_pools/", json=dict(name="Pool 1", base_job_template={"foo": "bar", "x": ["y"]}), ) - assert ( - response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - ), response.text + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, ( + response.text + ) assert ( "The `base_job_template` must contain both a `job_configuration` key and a" " `variables` key." in response.json()["exception_detail"][0]["msg"] @@ -205,9 +205,9 @@ async def test_create_work_pool_template_validation_missing_variables(self, clie "/work_pools/", json=dict(name="Pool 1", base_job_template=missing_variable_template), ) - assert ( - response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - ), response.text + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, ( + response.text + ) assert ( "The variables specified in the job configuration template must be " "present as properties in the variables schema. " @@ -241,9 +241,9 @@ async def test_create_work_pool_template_validation_missing_nested_variables( "/work_pools/", json=dict(name="Pool 1", base_job_template=missing_variable_template), ) - assert ( - response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - ), response.text + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, ( + response.text + ) assert ( "The variables specified in the job configuration template must be " "present as properties in the variables schema. 
" @@ -499,9 +499,9 @@ async def test_update_work_pool_invalid_concurrency( f"/work_pools/{work_pool.name}", json=dict(concurrency_limit=-5), ) - assert ( - response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - ), response.text + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, ( + response.text + ) session.expunge_all() result = await models.workers.read_work_pool( @@ -588,9 +588,9 @@ async def test_update_work_pool_template_validation_missing_keys( f"/work_pools/{name}", json=dict(name=name, base_job_template={"foo": "bar", "x": ["y"]}), ) - assert ( - response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - ), response.text + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, ( + response.text + ) assert ( "The `base_job_template` must contain both a `job_configuration` key and a" " `variables` key." in response.json()["exception_detail"][0]["msg"] @@ -628,9 +628,9 @@ async def test_update_work_pool_template_validation_missing_variables( f"/work_pools/{name}", json=dict(name=name, base_job_template=missing_variable_template), ) - assert ( - response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - ), response.text + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, ( + response.text + ) assert ( "The variables specified in the job configuration template must be " "present as properties in the variables schema. " @@ -674,9 +674,9 @@ async def test_update_work_pool_template_validation_missing_nested_variables( f"/work_pools/{name}", json=dict(name="Pool 1", base_job_template=missing_variable_template), ) - assert ( - response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - ), response.text + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, ( + response.text + ) assert ( "The variables specified in the job configuration template must be " "present as properties in the variables schema. 
" @@ -1408,9 +1408,9 @@ async def test_worker_heartbeat_does_not_updates_work_pool_status_if_paused( async def test_heartbeat_worker_requires_name(self, client, work_pool): response = await client.post(f"/work_pools/{work_pool.name}/workers/heartbeat") - assert ( - response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY - ), response.text + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY, ( + response.text + ) assert b'"missing","loc":["body","name"]' in response.content async def test_heartbeat_worker_upserts_for_same_name(self, client, work_pool): @@ -1904,9 +1904,9 @@ async def test_updates_last_polled_on_a_full_work_pool( work_queues = parse_obj_as(List[WorkQueue], work_queues_response.json()) for work_queue in work_queues: - assert ( - work_queue.last_polled is not None - ), "Work queue should have updated last_polled" + assert work_queue.last_polled is not None, ( + "Work queue should have updated last_polled" + ) assert work_queue.last_polled > now async def test_updates_statuses_on_a_full_work_pool( diff --git a/tests/server/orchestration/test_core_policy.py b/tests/server/orchestration/test_core_policy.py index 7e13055c8871..c53cab46d5c7 100644 --- a/tests/server/orchestration/test_core_policy.py +++ b/tests/server/orchestration/test_core_policy.py @@ -89,15 +89,15 @@ async def assert_deployment_concurrency_limit( """ await session.refresh(deployment) limit = deployment.global_concurrency_limit - assert ( - limit is not None - ), f"No concurrency limit found for deployment {deployment.id}" - assert ( - limit.limit == expected_limit - ), f"Expected concurrency limit {expected_limit}, but got {limit.limit}" - assert ( - limit.active_slots == expected_active_slots - ), f"Expected {expected_active_slots} active slots, but got {limit.active_slots}" + assert limit is not None, ( + f"No concurrency limit found for deployment {deployment.id}" + ) + assert limit.limit == expected_limit, ( + f"Expected concurrency limit {expected_limit}, but got {limit.limit}" + ) + assert limit.active_slots == expected_active_slots, ( + f"Expected {expected_active_slots} active slots, but got {limit.active_slots}" + ) @pytest.fixture @@ -782,9 +782,9 @@ async def missing_flow_run(self): missing_flow_run, ) - assert ( - ctx.run.flow_run_run_count == 1 - ), "The run count should not be updated if the flow run is missing" + assert ctx.run.flow_run_run_count == 1, ( + "The run count should not be updated if the flow run is missing" + ) class TestPermitRerunningFailedTaskRuns: @@ -824,9 +824,9 @@ async def test_bypasses_terminal_state_rule_if_flow_is_retrying( assert ctx.response_status == SetStateStatus.ACCEPT assert ctx.run.run_count == 0 assert ctx.proposed_state.name == "Retrying" - assert ( - ctx.run.flow_run_run_count == 4 - ), "Orchestration should update the flow run run count tracker" + assert ctx.run.flow_run_run_count == 4, ( + "Orchestration should update the flow run run count tracker" + ) async def test_can_run_again_even_if_exceeding_flow_runs_count( self, @@ -894,9 +894,9 @@ async def test_bypasses_terminal_state_rule_if_configured_automatic_retries_is_e assert ctx.response_status == SetStateStatus.ACCEPT assert ctx.run.run_count == 0 assert ctx.proposed_state.name == "Retrying" - assert ( - ctx.run.flow_run_run_count == 4 - ), "Orchestration should update the flow run run count tracker" + assert ctx.run.flow_run_run_count == 4, ( + "Orchestration should update the flow run run count tracker" + ) async def test_cleans_up_after_invalid_transition( self, diff --git 
a/tests/server/orchestration/test_rules.py b/tests/server/orchestration/test_rules.py index 16fff38f965f..79a6dbc66bd4 100644 --- a/tests/server/orchestration/test_rules.py +++ b/tests/server/orchestration/test_rules.py @@ -616,21 +616,21 @@ async def cleanup(self, initial_state, validated_state, context): assert await minimal_rule.fizzled() is True - assert ( - await raising_rule.invalid() is False - ), "Rules that error on entry should be fizzled so they can try and clean up" + assert await raising_rule.invalid() is False, ( + "Rules that error on entry should be fizzled so they can try and clean up" + ) assert await raising_rule.fizzled() is True assert outer_before_transition_hook.call_count == 1 assert outer_after_transition_hook.call_count == 0 - assert ( - outer_cleanup_step.call_count == 1 - ), "All rules should clean up side effects" + assert outer_cleanup_step.call_count == 1, ( + "All rules should clean up side effects" + ) assert before_transition_hook.call_count == 1 - assert ( - after_transition_hook.call_count == 0 - ), "The after-transition hook should not run" + assert after_transition_hook.call_count == 0, ( + "The after-transition hook should not run" + ) assert cleanup_step.call_count == 1, "All rules should clean up side effects" assert isinstance(ctx.orchestration_error, RuntimeError) @@ -1201,9 +1201,9 @@ async def after_transition(self, context): assert side_effect == 1 assert before_hook.call_count == 1 - assert ( - after_hook.call_count == 0 - ), "after_transition should not be called if orchestration encountered errors." + assert after_hook.call_count == 0, ( + "after_transition should not be called if orchestration encountered errors." + ) @pytest.mark.parametrize("run_type", ["task", "flow"]) @@ -1465,9 +1465,9 @@ async def test_context_validation_writes_result_artifact( await ctx.validate_proposed_state() assert ctx.run.state.id == ctx.validated_state.id assert ctx.validated_state.id == ctx.proposed_state.id - assert ctx.validated_state.data == { - "value": "some special data" - }, "result data should be attached to the validated state" + assert ctx.validated_state.data == {"value": "some special data"}, ( + "result data should be attached to the validated state" + ) # an artifact should be created with the result data as well if run_type == "task": @@ -1503,9 +1503,9 @@ async def test_context_validation_writes_result_artifact_with_metadata( await ctx.validate_proposed_state() assert ctx.run.state.id == ctx.validated_state.id assert ctx.validated_state.id == ctx.proposed_state.id - assert ctx.validated_state.data == { - "value": "some special data" - }, "sanitized result data should be attached to the validated state" + assert ctx.validated_state.data == {"value": "some special data"}, ( + "sanitized result data should be attached to the validated state" + ) # an artifact should be created with the result data as well if run_type == "task": @@ -1533,9 +1533,9 @@ async def test_context_validation_does_not_write_artifact_when_no_result( await ctx.validate_proposed_state() assert ctx.run.state.id == ctx.validated_state.id assert ctx.validated_state.id == ctx.proposed_state.id - assert ( - ctx.validated_state.data is None - ), "this validated state should have no result" + assert ctx.validated_state.data is None, ( + "this validated state should have no result" + ) # an artifact should be created with the result data as well if run_type == "task": diff --git a/tests/server/services/test_late_runs.py b/tests/server/services/test_late_runs.py index 
b6f8dc348149..73a167ec50db 100644 --- a/tests/server/services/test_late_runs.py +++ b/tests/server/services/test_late_runs.py @@ -87,9 +87,9 @@ async def now_run(session, flow): async def test_marks_late_run(session, late_run): assert late_run.state.name == "Scheduled" st = late_run.state.state_details.scheduled_time - assert ( - late_run.next_scheduled_start_time == st - ), "Next scheduled time is set by orchestration rules correctly" + assert late_run.next_scheduled_start_time == st, ( + "Next scheduled time is set by orchestration rules correctly" + ) await MarkLateRuns().start(loops=1) @@ -102,9 +102,9 @@ async def test_marks_late_run(session, late_run): async def test_marks_late_run_at_buffer(session, late_run): assert late_run.state.name == "Scheduled" st = late_run.state.state_details.scheduled_time - assert ( - late_run.next_scheduled_start_time == st - ), "Next scheduled time is set by orchestration rules correctly" + assert late_run.next_scheduled_start_time == st, ( + "Next scheduled time is set by orchestration rules correctly" + ) with temporary_settings(updates={PREFECT_API_SERVICES_LATE_RUNS_AFTER_SECONDS: 60}): await MarkLateRuns().start(loops=1) @@ -118,9 +118,9 @@ async def test_marks_late_run_at_buffer(session, late_run): async def test_does_not_mark_run_late_if_within_buffer(session, late_run): assert late_run.state.name == "Scheduled" st = late_run.state.state_details.scheduled_time - assert ( - late_run.next_scheduled_start_time == st - ), "Next scheduled time is set by orchestration rules correctly" + assert late_run.next_scheduled_start_time == st, ( + "Next scheduled time is set by orchestration rules correctly" + ) with temporary_settings(updates={PREFECT_API_SERVICES_LATE_RUNS_AFTER_SECONDS: 61}): await MarkLateRuns().start(loops=1) @@ -134,9 +134,9 @@ async def test_does_not_mark_run_late_if_within_buffer(session, late_run): async def test_does_not_mark_run_late_if_in_future(session, future_run): assert future_run.state.name == "Scheduled" st = future_run.state.state_details.scheduled_time - assert ( - future_run.next_scheduled_start_time == st - ), "Next scheduled time is set by orchestration rules correctly" + assert future_run.next_scheduled_start_time == st, ( + "Next scheduled time is set by orchestration rules correctly" + ) await MarkLateRuns().start(loops=1) @@ -151,9 +151,9 @@ async def test_does_not_mark_run_late_if_now(session, now_run): # run, but it should still be within the 'mark late after' buffer. 
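The dominant change throughout this diff is Ruff 0.9's new assert formatting: the condition stays on the `assert` line and only the failure message is parenthesized for wrapping. A minimal before/after sketch, standalone and with an illustrative variable:

value = 1

# Old style: the condition was wrapped in parentheses to fit the line.
assert (
    value == 1
), "value should be 1"

# New style: only the message is parenthesized.
assert value == 1, (
    "value should be 1"
)

# Why the parentheses never wrap both operands: `(condition, "message")`
# is a two-element tuple, which is always truthy, so the line below could
# never fail (CPython emits a SyntaxWarning if you write it).
# assert (value == 2, "this would silently pass")
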
assert now_run.state.name == "Scheduled" st = now_run.state.state_details.scheduled_time - assert ( - now_run.next_scheduled_start_time == st - ), "Next scheduled time is set by orchestration rules correctly" + assert now_run.next_scheduled_start_time == st, ( + "Next scheduled time is set by orchestration rules correctly" + ) await MarkLateRuns().start(loops=1) diff --git a/tests/server/services/test_scheduler.py b/tests/server/services/test_scheduler.py index b9b67211ef0b..0b7c424bf311 100644 --- a/tests/server/services/test_scheduler.py +++ b/tests/server/services/test_scheduler.py @@ -119,9 +119,9 @@ async def test_create_schedules_from_deployment( expected_dates.update(await schedule.get_dates(service.min_runs)) assert set(expected_dates) == {r.state.state_details.scheduled_time for r in runs} - assert all( - [r.state_name == "Scheduled" for r in runs] - ), "Scheduler sets flow_run.state_name" + assert all([r.state_name == "Scheduled" for r in runs]), ( + "Scheduler sets flow_run.state_name" + ) async def test_create_schedule_respects_max_future_time(flow, session): diff --git a/tests/server/services/test_telemetry.py b/tests/server/services/test_telemetry.py index a5e117c92724..5e8f49385970 100644 --- a/tests/server/services/test_telemetry.py +++ b/tests/server/services/test_telemetry.py @@ -94,10 +94,10 @@ async def test_errors_shutdown_service(error_sens_o_matic_mock, caplog): ] assert len(records) == 1, "An error level log should be emitted" - assert ( - "Failed to send telemetry" in records[0].message - ), "Should inform the user of the failure" + assert "Failed to send telemetry" in records[0].message, ( + "Should inform the user of the failure" + ) - assert ( - "Server error '500 Internal Server Error' for url" in records[0].message - ), "Should include a short version of the exception" + assert "Server error '500 Internal Server Error' for url" in records[0].message, ( + "Should include a short version of the exception" + ) diff --git a/tests/server/utilities/test_connection_leak_warnings.py b/tests/server/utilities/test_connection_leak_warnings.py index 2175a77674cd..94112ca2373d 100644 --- a/tests/server/utilities/test_connection_leak_warnings.py +++ b/tests/server/utilities/test_connection_leak_warnings.py @@ -6,6 +6,7 @@ TEST_CONNECTION_LEAK=true pytest tests/server/utilities/test_connection_leak_warnings.py """ + import os import pytest diff --git a/tests/server/utilities/test_messaging.py b/tests/server/utilities/test_messaging.py index 5664c5a9a5d7..bd912a7950bb 100644 --- a/tests/server/utilities/test_messaging.py +++ b/tests/server/utilities/test_messaging.py @@ -581,15 +581,15 @@ async def handler(message: Message): # Verify all consumers processed equal number of messages messages_per_consumer = num_messages // concurrency for consumer_id in range(1, concurrency + 1): - assert ( - len(processed_by_consumer[consumer_id]) == messages_per_consumer - ), f"Consumer {consumer_id} should process exactly {messages_per_consumer} messages" + assert len(processed_by_consumer[consumer_id]) == messages_per_consumer, ( + f"Consumer {consumer_id} should process exactly {messages_per_consumer} messages" + ) # Verify messages were processed in round-robin order expected_order = [(i % concurrency) + 1 for i in range(num_messages)] - assert ( - processing_order == expected_order - ), "Messages should be distributed in round-robin fashion" + assert processing_order == expected_order, ( + "Messages should be distributed in round-robin fashion" + ) # Verify each consumer got the correct 
messages for consumer_id in range(1, concurrency + 1): @@ -597,6 +597,6 @@ async def handler(message: Message): actual_indices = [ int(msg.attributes["index"]) for msg in processed_by_consumer[consumer_id] ] - assert ( - actual_indices == expected_indices - ), f"Consumer {consumer_id} should process messages {expected_indices}" + assert actual_indices == expected_indices, ( + f"Consumer {consumer_id} should process messages {expected_indices}" + ) diff --git a/tests/telemetry/instrumentation_tester.py b/tests/telemetry/instrumentation_tester.py index f2df6dbe8281..16742de7861e 100644 --- a/tests/telemetry/instrumentation_tester.py +++ b/tests/telemetry/instrumentation_tester.py @@ -52,8 +52,7 @@ def create_meter_provider(**kwargs) -> Tuple[MeterProvider, InMemoryMetricReader class HasAttributesViaProperty(Protocol): @property - def attributes(self) -> Attributes: - ... + def attributes(self) -> Attributes: ... class HasAttributesViaAttr(Protocol): diff --git a/tests/test_flow_engine.py b/tests/test_flow_engine.py index 0d1fd4496f07..169d9aa2a8a8 100644 --- a/tests/test_flow_engine.py +++ b/tests/test_flow_engine.py @@ -929,9 +929,9 @@ def parent_flow(): assert parent_flow() == "hello" assert flow_run_count == 2, "Parent flow should exhaust retries" - assert ( - child_flow_run_count == 4 - ), "Child flow should run 2 times for each parent run" + assert child_flow_run_count == 4, ( + "Child flow should run 2 times for each parent run" + ) class TestFlowCrashDetection: diff --git a/tests/test_flows.py b/tests/test_flows.py index 1ccb3c2e058c..87adb77e376e 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -1303,13 +1303,13 @@ def parent(): ) assert parent_flow_run_task.task_version == "inner" - assert ( - parent_flow_run_id != child_flow_run_id - ), "The subflow run and parent flow run are distinct" + assert parent_flow_run_id != child_flow_run_id, ( + "The subflow run and parent flow run are distinct" + ) - assert ( - child_state.state_details.task_run_id == parent_flow_run_task.id - ), "The client subflow run state links to the parent task" + assert child_state.state_details.task_run_id == parent_flow_run_task.id, ( + "The client subflow run state links to the parent task" + ) assert all( state.state_details.task_run_id == parent_flow_run_task.id @@ -1325,13 +1325,13 @@ def parent(): parent_flow_run_task.state.state_details.flow_run_id == parent_flow_run_id ), "The parent task belongs to the parent flow" - assert ( - child_flow_run.parent_task_run_id == parent_flow_run_task.id - ), "The server subflow run links to the parent task" + assert child_flow_run.parent_task_run_id == parent_flow_run_task.id, ( + "The server subflow run links to the parent task" + ) - assert ( - child_flow_run.id == child_flow_run_id - ), "The server subflow run id matches the client" + assert child_flow_run.id == child_flow_run_id, ( + "The server subflow run id matches the client" + ) @pytest.mark.skip(reason="Fails with new engine, passed on old engine") async def test_sync_flow_with_async_subflow_and_task_that_awaits_result(self): @@ -2076,9 +2076,9 @@ def my_flow(): log_messages = [log.message for log in logs] assert all([log.task_run_id is None for log in logs]) assert "Hello world!" 
in log_messages, "Parent log message is present" - assert ( - logs[log_messages.index("Hello world!")].flow_run_id == flow_run_id - ), "Parent log message has correct id" + assert logs[log_messages.index("Hello world!")].flow_run_id == flow_run_id, ( + "Parent log message has correct id" + ) assert "Hello smaller world!" in log_messages, "Child log message is present" assert ( logs[log_messages.index("Hello smaller world!")].flow_run_id @@ -2541,9 +2541,9 @@ def parent_flow(): assert parent_flow() == "hello" assert flow_run_count == 2, "Parent flow should exhaust retries" - assert ( - child_flow_run_count == 4 - ), "Child flow should run 2 times for each parent run" + assert child_flow_run_count == 4, ( + "Child flow should run 2 times for each parent run" + ) def test_global_retry_config(self): with temporary_settings(updates={PREFECT_FLOW_DEFAULT_RETRIES: "1"}): diff --git a/tests/test_logging.py b/tests/test_logging.py index cc359af50cd6..0d34331977ec 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -275,12 +275,12 @@ def test_setup_logging_extra_loggers_does_not_modify_external_logger_level( external_logger = logging.getLogger(ext_name) assert external_logger.level == ext_level, "External logger level was not preserved" if ext_level > logging.NOTSET: - assert external_logger.isEnabledFor( - ext_level - ), "External effective level was not preserved" - assert ( - external_logger.propagate == ext_propagate - ), "External logger propagate was not preserved" + assert external_logger.isEnabledFor(ext_level), ( + "External effective level was not preserved" + ) + assert external_logger.propagate == ext_propagate, ( + "External logger propagate was not preserved" + ) @pytest.fixture diff --git a/tests/test_settings.py b/tests/test_settings.py index af8412df9ab1..356de4d54efa 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -551,9 +551,9 @@ def test_settings_copy_with_update(self): updates={PREFECT_CLIENT_RETRY_EXTRA_CODES: "400,500"}, set_defaults={PREFECT_UNIT_TEST_MODE: False, PREFECT_API_KEY: "TEST"}, ) - assert ( - new_settings.testing.unit_test_mode is True - ), "Not changed, existing value was not default" + assert new_settings.testing.unit_test_mode is True, ( + "Not changed, existing value was not default" + ) assert ( new_settings.api.key is not None and new_settings.api.key.get_secret_value() == "TEST" @@ -750,9 +750,11 @@ def test_loads_when_profile_path_is_not_a_toml_file( assert Settings().testing.test_setting == "FOO" def test_valid_setting_names_matches_supported_settings(self): - assert ( - set(_get_valid_setting_names(Settings)) == set(SUPPORTED_SETTINGS.keys()) - ), "valid_setting_names output did not match supported settings. Please update SUPPORTED_SETTINGS if you are adding or removing a setting." + assert set(_get_valid_setting_names(Settings)) == set( + SUPPORTED_SETTINGS.keys() + ), ( + "valid_setting_names output did not match supported settings. Please update SUPPORTED_SETTINGS if you are adding or removing a setting." 
+ ) class TestSettingAccess: @@ -1228,9 +1230,9 @@ class TestTemporarySettings: def test_temporary_settings(self): assert PREFECT_TEST_MODE.value() is True with temporary_settings(updates={PREFECT_TEST_MODE: False}) as new_settings: - assert ( - PREFECT_TEST_MODE.value_from(new_settings) is False - ), "Yields the new settings" + assert PREFECT_TEST_MODE.value_from(new_settings) is False, ( + "Yields the new settings" + ) assert PREFECT_TEST_MODE.value() is False assert PREFECT_TEST_MODE.value() is True @@ -1259,9 +1261,9 @@ def test_temporary_settings_restores_on_error(self): with temporary_settings(updates={PREFECT_TEST_MODE: False}): raise ValueError() - assert ( - os.environ["PREFECT_TESTING_TEST_MODE"] == "1" - ), "Does not alter os environ." + assert os.environ["PREFECT_TESTING_TEST_MODE"] == "1", ( + "Does not alter os environ." + ) assert PREFECT_TEST_MODE.value() is True @@ -2101,21 +2103,21 @@ def test_equality(self): assert ProfilesCollection( profiles=[foo, bar], active=None - ) == ProfilesCollection( - profiles=[foo, bar] - ), "Explicit and implicit null active should be equal" + ) == ProfilesCollection(profiles=[foo, bar]), ( + "Explicit and implicit null active should be equal" + ) assert ProfilesCollection( profiles=[foo, bar], active="foo" - ) != ProfilesCollection( - profiles=[foo, bar] - ), "One null active should be inequal" + ) != ProfilesCollection(profiles=[foo, bar]), ( + "One null active should be inequal" + ) assert ProfilesCollection( profiles=[foo, bar], active="foo" - ) != ProfilesCollection( - profiles=[foo, bar], active="bar" - ), "Different active should be inequal" + ) != ProfilesCollection(profiles=[foo, bar], active="bar"), ( + "Different active should be inequal" + ) assert ProfilesCollection(profiles=[foo, bar]) == ProfilesCollection( profiles=[ diff --git a/tests/test_task_engine.py b/tests/test_task_engine.py index bf4aacee026a..384c7fb6c530 100644 --- a/tests/test_task_engine.py +++ b/tests/test_task_engine.py @@ -1172,9 +1172,9 @@ async def test_flow(): if i > 0: last_start_time = start_times[i - 1] - assert ( - last_start_time < start_times[i] - ), "Timestamps should be increasing" + assert last_start_time < start_times[i], ( + "Timestamps should be increasing" + ) async def test_global_task_retry_config(self): with temporary_settings(updates={PREFECT_TASK_DEFAULT_RETRIES: "1"}): @@ -1961,9 +1961,9 @@ async def async_task(): second_state = await async_task(return_state=True) assert second_state.is_completed() - assert ( - await first_state.result() != await second_state.result() - ), "Cache did not expire" + assert await first_state.result() != await second_state.result(), ( + "Cache did not expire" + ) async def test_none_policy_with_persist_result_false(self, prefect_client): @task(cache_policy=None, result_storage_key=None, persist_result=False) diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 2440b3a9b295..0377845b9481 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -1285,9 +1285,9 @@ def test_flow(): if i > 0: last_start_time = start_times[i - 1] - assert ( - last_start_time < start_times[i] - ), "Timestamps should be increasing" + assert last_start_time < start_times[i], ( + "Timestamps should be increasing" + ) async def test_global_task_retry_config(self): with temporary_settings(updates={PREFECT_TASK_DEFAULT_RETRIES: "1"}): @@ -4810,8 +4810,7 @@ async def test_task_condition_fn_raises_when_not_a_callable(self): with pytest.raises(TypeError): @task(retry_condition_fn="not a callable") - def my_task(): - ... 
+ def my_task(): ... class TestNestedTasks: diff --git a/tests/utilities/test_callables.py b/tests/utilities/test_callables.py index 943a6bfd1109..b64d094f2518 100644 --- a/tests/utilities/test_callables.py +++ b/tests/utilities/test_callables.py @@ -315,8 +315,7 @@ def test_function_with_pydantic_model_default_across_v1_and_v2(self): class Foo(pydantic.BaseModel): bar: str - def f(foo: Foo = Foo(bar="baz")): - ... + def f(foo: Foo = Foo(bar="baz")): ... schema = callables.parameter_schema(f) assert schema.model_dump_for_openapi() == { @@ -365,8 +364,7 @@ def f( pdate: pendulum.Date = pendulum.date(2025, 1, 1), pduration: pendulum.Duration = pendulum.duration(seconds=5), c: Color = Color.BLUE, - ): - ... + ): ... datetime_schema = { "title": "pdt", diff --git a/tests/utilities/test_collections.py b/tests/utilities/test_collections.py index 7b30c2d6fa81..aab9c8520d00 100644 --- a/tests/utilities/test_collections.py +++ b/tests/utilities/test_collections.py @@ -343,9 +343,9 @@ class RandomPydantic(pydantic.BaseModel): input_model, visit_fn=visit_even_numbers, return_data=True ) - assert ( - output_model.val == input_model.val - ), "The fields value should be used, not the default factory" + assert output_model.val == input_model.val, ( + "The fields value should be used, not the default factory" + ) @pytest.mark.parametrize( "input", @@ -361,9 +361,9 @@ def test_visit_collection_remembers_unset_pydantic_fields(self, input: dict): output_model = visit_collection( input_model, visit_fn=visit_even_numbers, return_data=True ) - assert ( - output_model.model_dump(exclude_unset=True) == input - ), "Unset fields values should be remembered and preserved" + assert output_model.model_dump(exclude_unset=True) == input, ( + "Unset fields values should be remembered and preserved" + ) @pytest.mark.parametrize("immutable", [True, False]) def test_visit_collection_mutation_with_private_pydantic_attributes( @@ -388,9 +388,9 @@ def test_visit_collection_mutation_with_private_pydantic_attributes( # Verify fields set indirectly by checking the expected fields are still set for field in model_instance.model_fields_set: - assert hasattr( - result, field - ), f"The field '{field}' should be set in the result" + assert hasattr(result, field), ( + f"The field '{field}' should be set in the result" + ) def test_visit_collection_recursive_1(self): obj = dict() diff --git a/tests/utilities/test_importtools.py b/tests/utilities/test_importtools.py index d95ebc3773d4..4039a8946859 100644 --- a/tests/utilities/test_importtools.py +++ b/tests/utilities/test_importtools.py @@ -83,9 +83,9 @@ def test_lazy_import_allows_deferred_failure_for_missing_module(): assert isinstance(module, ModuleType) with pytest.raises(ModuleNotFoundError, match="No module named 'flibbidy'") as exc: module.foo - assert ( - "No module named 'flibbidy'" in exc.exconly() - ), "Exception should contain error message" + assert "No module named 'flibbidy'" in exc.exconly(), ( + "Exception should contain error message" + ) def test_lazy_import_includes_help_message_for_missing_modules(): diff --git a/tests/utilities/test_math.py b/tests/utilities/test_math.py index d8475268ae1e..ef05031216b2 100644 --- a/tests/utilities/test_math.py +++ b/tests/utilities/test_math.py @@ -21,6 +21,6 @@ def test_clamped_poisson_intervals(clamping_factor): assert expected_average * 0.97 < observed_average < expected_average * 1.03 - assert max(bunch_of_intervals) < expected_average * ( - 1 + clamping_factor - ), "no intervals should exceed the upper clamp limit" + assert 
max(bunch_of_intervals) < expected_average * (1 + clamping_factor), ( + "no intervals should exceed the upper clamp limit" + ) diff --git a/tests/utilities/test_pydantic.py b/tests/utilities/test_pydantic.py index 10bed971b300..016c7a4b5cd2 100644 --- a/tests/utilities/test_pydantic.py +++ b/tests/utilities/test_pydantic.py @@ -70,9 +70,9 @@ def test_add_cloudpickle_reduction_with_kwargs(self): # field instead result = cloudpickle.loads(cloudpickle.dumps(model)) - assert ( - result.x == 0 - ), "'x' should return to the default value since it was excluded" + assert result.x == 0, ( + "'x' should return to the default value since it was excluded" + ) assert result.y == "test" diff --git a/tests/utilities/test_timeout.py b/tests/utilities/test_timeout.py index c2c83d3d9c2f..7adfda53433a 100644 --- a/tests/utilities/test_timeout.py +++ b/tests/utilities/test_timeout.py @@ -6,8 +6,7 @@ from prefect.utilities.timeout import timeout, timeout_async -class CustomTimeoutError(TimeoutError): - ... +class CustomTimeoutError(TimeoutError): ... def test_timeout_raises_custom_error_type_sync(): diff --git a/tests/utilities/test_visualization.py b/tests/utilities/test_visualization.py index de3a700d9055..1d4b5131e996 100644 --- a/tests/utilities/test_visualization.py +++ b/tests/utilities/test_visualization.py @@ -347,6 +347,6 @@ async def test_visualize_graph_contents( actual_nodes = set(graph.body) - assert ( - actual_nodes == expected_nodes - ), f"Expected nodes {expected_nodes} but found {actual_nodes}" + assert actual_nodes == expected_nodes, ( + f"Expected nodes {expected_nodes} but found {actual_nodes}" + ) diff --git a/versioneer.py b/versioneer.py index 3d59af5e673d..87256509a7f2 100644 --- a/versioneer.py +++ b/versioneer.py @@ -1833,9 +1833,9 @@ def get_versions(verbose: bool = False) -> Dict[str, Any]: handlers = HANDLERS.get(cfg.VCS) assert handlers, "unrecognized VCS '%s'" % cfg.VCS verbose = verbose or bool(cfg.verbose) # `bool()` used to avoid `None` - assert ( - cfg.versionfile_source is not None - ), "please set versioneer.versionfile_source" + assert cfg.versionfile_source is not None, ( + "please set versioneer.versionfile_source" + ) assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" versionfile_abs = os.path.join(root, cfg.versionfile_source)
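
One last pattern worth calling out, from the test_task_run_subscriptions.py hunks earlier in this diff: multiple context managers are now grouped in a single parenthesized `with` list instead of leaning on one call's own parentheses for line continuation. A self-contained sketch, with hypothetical managers standing in for patch(...) and pytest.raises(...):

from contextlib import contextmanager


@contextmanager
def managed(name: str):
    # Toy context manager so the example runs on its own.
    print(f"enter {name}")
    try:
        yield name
    finally:
        print(f"exit {name}")


# Old style: the line break relied on the second call's parentheses.
with managed("sleep-patch"), managed("timeout-guard"):
    pass

# New style: the whole manager list is parenthesized, one item per line.
# Parenthesized with-items are documented from Python 3.10; CPython 3.9's
# PEG parser already accepts the syntax.
with (
    managed("sleep-patch"),
    managed("timeout-guard"),
):
    pass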