-
-
Notifications
You must be signed in to change notification settings - Fork 55
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
use-chain-from-iterable
check (#293)
- Loading branch information
Showing
4 changed files
with
217 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
from dataclasses import dataclass | ||
|
||
from mypy.nodes import ( | ||
ArgKind, | ||
CallExpr, | ||
GeneratorExpr, | ||
ListExpr, | ||
NameExpr, | ||
RefExpr, | ||
) | ||
|
||
from refurb.error import Error | ||
|
||
|
||
@dataclass
class ErrorInfo(Error):
    """
    When flattening a list of lists, use the `chain.from_iterable` function
    from the `itertools` stdlib package. This function is faster than native
    list/generator comprehensions or using `sum()` with a list default.

    Bad:

    ```
    from itertools import chain

    rows = [[1, 2], [3, 4]]

    # using list comprehension
    flat = [col for row in rows for col in row]

    # using sum()
    flat = sum(rows, [])

    # using chain(*x)
    flat = chain(*rows)
    ```

    Good:

    ```
    from itertools import chain

    rows = [[1, 2], [3, 4]]

    flat = chain.from_iterable(rows)
    ```

    Note: `chain(*x)` may be marginally faster/slower depending on the length
    of `x`. Since `*` might potentially expand to a lot of arguments, it is
    better to use `chain.from_iterable()` when you are unsure.
    """

    # Identifier of this check.
    name = "use-chain-from-iterable"
    # Categories this check is filed under.
    categories = ("itertools", "performance", "readability")
    # Numeric code of this check (reported as FURB179).
    code = 179
|
||
|
||
def check(node: GeneratorExpr | CallExpr, errors: list[Error]) -> None: | ||
match node: | ||
case GeneratorExpr( | ||
left_expr=RefExpr(fullname=expr), | ||
sequences=[_, RefExpr(fullname=inner_source)], | ||
indices=[RefExpr(fullname=outer), RefExpr(fullname=inner)], | ||
is_async=[False, False], | ||
condlists=[[], []], | ||
) if expr == inner and inner_source == outer: | ||
old = "... for ... in x for ... in ..." | ||
new = "chain.from_iterable(x)" | ||
|
||
msg = f"Replace `{old}` with `{new}`" | ||
|
||
errors.append(ErrorInfo.from_node(node, msg)) | ||
|
||
case CallExpr( | ||
callee=RefExpr(fullname="builtins.sum"), | ||
args=[_, ListExpr(items=[])], | ||
): | ||
old = "sum(x, [])" | ||
new = "chain.from_iterable(x)" | ||
|
||
msg = f"Replace `{old}` with `{new}`" | ||
|
||
errors.append(ErrorInfo.from_node(node, msg)) | ||
|
||
case CallExpr( | ||
callee=RefExpr(fullname="itertools.chain") as callee, | ||
args=[_], | ||
arg_kinds=[ArgKind.ARG_STAR], | ||
): | ||
chain = ( | ||
"chain" if isinstance(callee, NameExpr) else "itertools.chain" | ||
) | ||
|
||
old = f"{chain}(*x)" | ||
new = f"{chain}.from_iterable(x)" | ||
|
||
msg = f"Replace `{old}` with `{new}`" | ||
|
||
errors.append(ErrorInfo.from_node(node, msg)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from itertools import chain | ||
import itertools | ||
|
||
|
||
rows = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] | ||
|
||
def f(): | ||
return rows | ||
|
||
# these should match | ||
|
||
def flatten_via_generator(rows): | ||
return (col for row in rows for col in row) | ||
|
||
def flatten_via_list_comp(rows): | ||
return [col for row in rows for col in row] | ||
|
||
def flatten_via_set_comp(rows): | ||
return {col for row in rows for col in row} | ||
|
||
def flatten_with_function_source(): | ||
return (col for row in f() for col in row) | ||
|
||
def flatten_via_sum(rows): | ||
return sum(rows, []) | ||
|
||
def flatten_via_chain_splat(rows): | ||
return chain(*rows) | ||
|
||
def flatten_via_chain_splat_2(rows): | ||
return itertools.chain(*rows) | ||
|
||
|
||
# these should not | ||
|
||
def flatten_via_generator_modified(rows): | ||
return (col + 1 for row in rows for col in row) | ||
|
||
def flatten_via_generator_modified_2(rows): | ||
return (col for [row] in rows for col in row) | ||
|
||
def flatten_via_generator_modified_3(rows): | ||
return (col for row in rows for [col] in row) | ||
|
||
def flatten_via_generator_with_if(rows): | ||
return (col for row in rows for col in row if col) | ||
|
||
def flatten_via_generator_with_if_2(rows): | ||
return (col for row in rows if row for col in row) | ||
|
||
def flatten_via_dict_comp(rows): | ||
return {col: "" for row in rows for col in row} | ||
|
||
async def flatten_async_generator(rows): | ||
return (col async for row in rows for col in row) | ||
|
||
async def flatten_async_generator_2(rows): | ||
return (col for row in rows async for col in row) | ||
|
||
async def flatten_async_generator_3(rows): | ||
return (col async for row in rows async for col in row) | ||
|
||
def flatten_via_sum_with_default(rows): | ||
return sum(rows, [1]) | ||
|
||
def flatten_via_chain_without_splat(rows): | ||
return chain(rows) | ||
|
||
def flatten_via_chain_from_iterable(rows): | ||
return chain.from_iterable(rows) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
test/data/err_179.py:13:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)` | ||
test/data/err_179.py:16:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)` | ||
test/data/err_179.py:19:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)` | ||
test/data/err_179.py:22:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)` | ||
test/data/err_179.py:25:12 [FURB179]: Replace `sum(x, [])` with `chain.from_iterable(x)` | ||
test/data/err_179.py:28:12 [FURB179]: Replace `chain(*x)` with `chain.from_iterable(x)` | ||
test/data/err_179.py:31:12 [FURB179]: Replace `itertools.chain(*x)` with `itertools.chain.from_iterable(x)` |