diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_type_checking/quote.py b/crates/ruff_linter/resources/test/fixtures/flake8_type_checking/quote.py index 67332e3010f24..c7540c3ecce1d 100644 --- a/crates/ruff_linter/resources/test/fixtures/flake8_type_checking/quote.py +++ b/crates/ruff_linter/resources/test/fixtures/flake8_type_checking/quote.py @@ -65,3 +65,10 @@ def f(): def func(value: DataFrame): ... + + +def f(): + from pandas import DataFrame, Series + + def baz() -> DataFrame | Series: + ... diff --git a/crates/ruff_linter/src/rules/flake8_type_checking/rules/runtime_import_in_type_checking_block.rs b/crates/ruff_linter/src/rules/flake8_type_checking/rules/runtime_import_in_type_checking_block.rs index 5eee0365e12fa..377b4ca7effa2 100644 --- a/crates/ruff_linter/src/rules/flake8_type_checking/rules/runtime_import_in_type_checking_block.rs +++ b/crates/ruff_linter/src/rules/flake8_type_checking/rules/runtime_import_in_type_checking_block.rs @@ -1,6 +1,7 @@ use std::borrow::Cow; use anyhow::Result; +use itertools::Itertools; use rustc_hash::FxHashMap; use ruff_diagnostics::{Diagnostic, Fix, FixAvailability, Violation}; @@ -262,7 +263,7 @@ pub(crate) fn runtime_import_in_type_checking_block( /// Generate a [`Fix`] to quote runtime usages for imports in a type-checking block. fn quote_imports(checker: &Checker, node_id: NodeId, imports: &[ImportBinding]) -> Result { - let mut quote_reference_edits = imports + let quote_reference_edits = imports .iter() .flat_map(|ImportBinding { binding, .. }| { binding.references.iter().filter_map(|reference_id| { @@ -280,14 +281,12 @@ fn quote_imports(checker: &Checker, node_id: NodeId, imports: &[ImportBinding]) }) }) .collect::>>()?; - let quote_reference_edit = quote_reference_edits - .pop() - .expect("Expected at least one reference"); - Ok( - Fix::unsafe_edits(quote_reference_edit, quote_reference_edits).isolate(Checker::isolation( - checker.semantic().parent_statement_id(node_id), - )), - ) + + let mut rest = quote_reference_edits.into_iter().dedup(); + let head = rest.next().expect("Expected at least one reference"); + Ok(Fix::unsafe_edits(head, rest).isolate(Checker::isolation( + checker.semantic().parent_statement_id(node_id), + ))) } /// Generate a [`Fix`] to remove runtime imports from a type-checking block. diff --git a/crates/ruff_linter/src/rules/flake8_type_checking/rules/typing_only_runtime_import.rs b/crates/ruff_linter/src/rules/flake8_type_checking/rules/typing_only_runtime_import.rs index cb3adaed08368..e868418fec4ec 100644 --- a/crates/ruff_linter/src/rules/flake8_type_checking/rules/typing_only_runtime_import.rs +++ b/crates/ruff_linter/src/rules/flake8_type_checking/rules/typing_only_runtime_import.rs @@ -1,6 +1,7 @@ use std::borrow::Cow; use anyhow::Result; +use itertools::Itertools; use rustc_hash::FxHashMap; use ruff_diagnostics::{Diagnostic, DiagnosticKind, Fix, FixAvailability, Violation}; @@ -506,7 +507,7 @@ fn fix_imports(checker: &Checker, node_id: NodeId, imports: &[ImportBinding]) -> add_import_edit .into_edits() .into_iter() - .chain(quote_reference_edits), + .chain(quote_reference_edits.into_iter().dedup()), ) .isolate(Checker::isolation( checker.semantic().parent_statement_id(node_id), diff --git a/crates/ruff_linter/src/rules/flake8_type_checking/snapshots/ruff_linter__rules__flake8_type_checking__tests__quote_runtime-import-in-type-checking-block_quote.py.snap b/crates/ruff_linter/src/rules/flake8_type_checking/snapshots/ruff_linter__rules__flake8_type_checking__tests__quote_runtime-import-in-type-checking-block_quote.py.snap index 0baeba9f62ec1..ae71c56c8195a 100644 --- a/crates/ruff_linter/src/rules/flake8_type_checking/snapshots/ruff_linter__rules__flake8_type_checking__tests__quote_runtime-import-in-type-checking-block_quote.py.snap +++ b/crates/ruff_linter/src/rules/flake8_type_checking/snapshots/ruff_linter__rules__flake8_type_checking__tests__quote_runtime-import-in-type-checking-block_quote.py.snap @@ -18,5 +18,7 @@ quote.py:64:28: TCH004 [*] Quote references to `pandas.DataFrame`. Import is in 66 |- def func(value: DataFrame): 66 |+ def func(value: "DataFrame"): 67 67 | ... +68 68 | +69 69 | diff --git a/crates/ruff_linter/src/rules/flake8_type_checking/snapshots/ruff_linter__rules__flake8_type_checking__tests__quote_typing-only-third-party-import_quote.py.snap b/crates/ruff_linter/src/rules/flake8_type_checking/snapshots/ruff_linter__rules__flake8_type_checking__tests__quote_typing-only-third-party-import_quote.py.snap index 9c43070c0f5b8..eb208bedce285 100644 --- a/crates/ruff_linter/src/rules/flake8_type_checking/snapshots/ruff_linter__rules__flake8_type_checking__tests__quote_typing-only-third-party-import_quote.py.snap +++ b/crates/ruff_linter/src/rules/flake8_type_checking/snapshots/ruff_linter__rules__flake8_type_checking__tests__quote_typing-only-third-party-import_quote.py.snap @@ -196,4 +196,60 @@ quote.py:54:24: TCH002 Move third-party import `pandas.DataFrame` into a type-ch | = help: Move into type-checking block +quote.py:71:24: TCH002 [*] Move third-party import `pandas.DataFrame` into a type-checking block + | +70 | def f(): +71 | from pandas import DataFrame, Series + | ^^^^^^^^^ TCH002 +72 | +73 | def baz() -> DataFrame | Series: + | + = help: Move into type-checking block + +ℹ Unsafe fix + 1 |+from typing import TYPE_CHECKING + 2 |+ + 3 |+if TYPE_CHECKING: + 4 |+ from pandas import DataFrame, Series +1 5 | def f(): +2 6 | from pandas import DataFrame +3 7 | +-------------------------------------------------------------------------------- +68 72 | +69 73 | +70 74 | def f(): +71 |- from pandas import DataFrame, Series +72 75 | +73 |- def baz() -> DataFrame | Series: + 76 |+ def baz() -> "DataFrame | Series": +74 77 | ... + +quote.py:71:35: TCH002 [*] Move third-party import `pandas.Series` into a type-checking block + | +70 | def f(): +71 | from pandas import DataFrame, Series + | ^^^^^^ TCH002 +72 | +73 | def baz() -> DataFrame | Series: + | + = help: Move into type-checking block + +ℹ Unsafe fix + 1 |+from typing import TYPE_CHECKING + 2 |+ + 3 |+if TYPE_CHECKING: + 4 |+ from pandas import DataFrame, Series +1 5 | def f(): +2 6 | from pandas import DataFrame +3 7 | +-------------------------------------------------------------------------------- +68 72 | +69 73 | +70 74 | def f(): +71 |- from pandas import DataFrame, Series +72 75 | +73 |- def baz() -> DataFrame | Series: + 76 |+ def baz() -> "DataFrame | Series": +74 77 | ... +