Skip to content

Commit

Permalink
Add use-chain-from-iterable check (#293)
Browse files Browse the repository at this point in the history
  • Loading branch information
dosisod authored Oct 4, 2023
1 parent 14cfc38 commit d43b6c4
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 1 deletion.
41 changes: 40 additions & 1 deletion docs/checks.md
Original file line number Diff line number Diff line change
Expand Up @@ -2046,4 +2046,43 @@ Good:
args = ["hello", "world!"]

cmd = shlex.join(args)
```
```

## FURB179: `use-chain-from-iterable`

Categories: `itertools` `performance` `readability`

When flattening an list of lists, use the `chain.from_iterable` function
from the `itertools` stdlib package. This function is faster than native
list/generator comprehensions or using `sum()` with a list default.

Bad:

```python
from itertools import chain

rows = [[1, 2], [3, 4]]

# using list comprehension
flat = [col for row in rows for col in row]

# using sum()
flat = sum(rows, [])

# using chain(*x)
flat = chain(*rows)
```

Good:

```python
from itertools import chain

rows = [[1, 2], [3, 4]]

flat = chain.from_iterable(*rows)
```

Note: `chain(*x)` may be marginally faster/slower depending on the length
of `x`. Since `*` might potentially expand to a lot of arguments, it is
better to use `chain.from_iterable()` when you are unsure.
100 changes: 100 additions & 0 deletions refurb/checks/itertools/use_chain_from_iterable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from dataclasses import dataclass

from mypy.nodes import (
ArgKind,
CallExpr,
GeneratorExpr,
ListExpr,
NameExpr,
RefExpr,
)

from refurb.error import Error


@dataclass
class ErrorInfo(Error):
"""
When flattening an list of lists, use the `chain.from_iterable` function
from the `itertools` stdlib package. This function is faster than native
list/generator comprehensions or using `sum()` with a list default.
Bad:
```
from itertools import chain
rows = [[1, 2], [3, 4]]
# using list comprehension
flat = [col for row in rows for col in row]
# using sum()
flat = sum(rows, [])
# using chain(*x)
flat = chain(*rows)
```
Good:
```
from itertools import chain
rows = [[1, 2], [3, 4]]
flat = chain.from_iterable(*rows)
```
Note: `chain(*x)` may be marginally faster/slower depending on the length
of `x`. Since `*` might potentially expand to a lot of arguments, it is
better to use `chain.from_iterable()` when you are unsure.
"""

name = "use-chain-from-iterable"
categories = ("itertools", "performance", "readability")
code = 179


def check(node: GeneratorExpr | CallExpr, errors: list[Error]) -> None:
match node:
case GeneratorExpr(
left_expr=RefExpr(fullname=expr),
sequences=[_, RefExpr(fullname=inner_source)],
indices=[RefExpr(fullname=outer), RefExpr(fullname=inner)],
is_async=[False, False],
condlists=[[], []],
) if expr == inner and inner_source == outer:
old = "... for ... in x for ... in ..."
new = "chain.from_iterable(x)"

msg = f"Replace `{old}` with `{new}`"

errors.append(ErrorInfo.from_node(node, msg))

case CallExpr(
callee=RefExpr(fullname="builtins.sum"),
args=[_, ListExpr(items=[])],
):
old = "sum(x, [])"
new = "chain.from_iterable(x)"

msg = f"Replace `{old}` with `{new}`"

errors.append(ErrorInfo.from_node(node, msg))

case CallExpr(
callee=RefExpr(fullname="itertools.chain") as callee,
args=[_],
arg_kinds=[ArgKind.ARG_STAR],
):
chain = (
"chain" if isinstance(callee, NameExpr) else "itertools.chain"
)

old = f"{chain}(*x)"
new = f"{chain}.from_iterable(x)"

msg = f"Replace `{old}` with `{new}`"

errors.append(ErrorInfo.from_node(node, msg))
70 changes: 70 additions & 0 deletions test/data/err_179.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from itertools import chain
import itertools


rows = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]

def f():
return rows

# these should match

def flatten_via_generator(rows):
return (col for row in rows for col in row)

def flatten_via_list_comp(rows):
return [col for row in rows for col in row]

def flatten_via_set_comp(rows):
return {col for row in rows for col in row}

def flatten_with_function_source():
return (col for row in f() for col in row)

def flatten_via_sum(rows):
return sum(rows, [])

def flatten_via_chain_splat(rows):
return chain(*rows)

def flatten_via_chain_splat_2(rows):
return itertools.chain(*rows)


# these should not

def flatten_via_generator_modified(rows):
return (col + 1 for row in rows for col in row)

def flatten_via_generator_modified_2(rows):
return (col for [row] in rows for col in row)

def flatten_via_generator_modified_3(rows):
return (col for row in rows for [col] in row)

def flatten_via_generator_with_if(rows):
return (col for row in rows for col in row if col)

def flatten_via_generator_with_if_2(rows):
return (col for row in rows if row for col in row)

def flatten_via_dict_comp(rows):
return {col: "" for row in rows for col in row}

async def flatten_async_generator(rows):
return (col async for row in rows for col in row)

async def flatten_async_generator_2(rows):
return (col for row in rows async for col in row)

async def flatten_async_generator_3(rows):
return (col async for row in rows async for col in row)

def flatten_via_sum_with_default(rows):
return sum(rows, [1])

def flatten_via_chain_without_splat(rows):
return chain(rows)

def flatten_via_chain_from_iterable(rows):
return chain.from_iterable(rows)
7 changes: 7 additions & 0 deletions test/data/err_179.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
test/data/err_179.py:13:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)`
test/data/err_179.py:16:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)`
test/data/err_179.py:19:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)`
test/data/err_179.py:22:12 [FURB179]: Replace `... for ... in x for ... in ...` with `chain.from_iterable(x)`
test/data/err_179.py:25:12 [FURB179]: Replace `sum(x, [])` with `chain.from_iterable(x)`
test/data/err_179.py:28:12 [FURB179]: Replace `chain(*x)` with `chain.from_iterable(x)`
test/data/err_179.py:31:12 [FURB179]: Replace `itertools.chain(*x)` with `itertools.chain.from_iterable(x)`

0 comments on commit d43b6c4

Please # to comment.