Skip to content

Commit dee7a47

Browse files
committed
MultiIndex.difference not working with PyArrow timestamps(pandas-dev#61382)
1 parent e55d907 commit dee7a47

16 files changed

+170
-100
lines changed

pandas/core/indexes/multi.py

+54
Original file line numberDiff line numberDiff line change
@@ -3891,6 +3891,60 @@ def equal_levels(self, other: MultiIndex) -> bool:
38913891
# --------------------------------------------------------------------
38923892
# Set Methods
38933893

3894+
def difference(self, other, sort=None):
3895+
"""
3896+
Return a new MultiIndex with elements from the index not in `other`.
3897+
3898+
Parameters
3899+
----------
3900+
other : MultiIndex or array-like
3901+
sort : bool or None, default None
3902+
Whether to sort the resulting index.
3903+
3904+
Returns
3905+
-------
3906+
MultiIndex
3907+
"""
3908+
if not isinstance(other, MultiIndex):
3909+
other = MultiIndex.from_tuples(other, names=self.names)
3910+
3911+
# Convert 'other' to codes using self's levels
3912+
other_codes = []
3913+
for i, (lev, name) in enumerate(zip(self.levels, self.names)):
3914+
level_vals = other.get_level_values(i)
3915+
other_code = lev.get_indexer(level_vals)
3916+
other_codes.append(other_code)
3917+
3918+
# Create mask for elements not in 'other'
3919+
n = len(self)
3920+
mask = np.ones(n, dtype=bool)
3921+
engine = self._engine
3922+
for codes in zip(*other_codes):
3923+
try:
3924+
loc = engine.get_loc(tuple(codes))
3925+
if isinstance(loc, slice):
3926+
mask[loc] = False
3927+
elif isinstance(loc, np.ndarray):
3928+
mask &= ~loc
3929+
else:
3930+
mask[loc] = False
3931+
except KeyError:
3932+
pass
3933+
3934+
new_codes = [code[mask] for code in self.codes]
3935+
result = MultiIndex(
3936+
levels=self.levels,
3937+
codes=new_codes,
3938+
names=self.names,
3939+
verify_integrity=False,
3940+
)
3941+
if sort is None or sort is True:
3942+
try:
3943+
return result.sort_values()
3944+
except TypeError:
3945+
pass
3946+
return result
3947+
38943948
def _union(self, other, sort) -> MultiIndex:
38953949
other, result_names = self._convert_can_do_setop(other)
38963950
if other.has_duplicates:

pandas/tests/frame/test_query_eval.py

+6-16
Original file line numberDiff line numberDiff line change
@@ -160,21 +160,13 @@ def test_query_empty_string(self):
160160
df.query("")
161161

162162
def test_query_duplicate_column_name(self, engine, parser):
163-
df = DataFrame(
164-
{
165-
"A": range(3),
166-
"B": range(3),
167-
"C": range(3)
168-
}
169-
).rename(columns={"B": "A"})
163+
df = DataFrame({"A": range(3), "B": range(3), "C": range(3)}).rename(
164+
columns={"B": "A"}
165+
)
170166

171-
res = df.query('C == 1', engine=engine, parser=parser)
167+
res = df.query("C == 1", engine=engine, parser=parser)
172168

173-
expect = DataFrame(
174-
[[1, 1, 1]],
175-
columns=["A", "A", "C"],
176-
index=[1]
177-
)
169+
expect = DataFrame([[1, 1, 1]], columns=["A", "A", "C"], index=[1])
178170

179171
tm.assert_frame_equal(res, expect)
180172

@@ -1140,9 +1132,7 @@ def test_query_with_nested_special_character(self, parser, engine):
11401132
[">=", operator.ge],
11411133
],
11421134
)
1143-
def test_query_lex_compare_strings(
1144-
self, parser, engine, op, func
1145-
):
1135+
def test_query_lex_compare_strings(self, parser, engine, op, func):
11461136
a = Series(np.random.default_rng(2).choice(list("abcde"), 20))
11471137
b = Series(np.arange(a.size))
11481138
df = DataFrame({"X": a, "Y": b})

pandas/tests/indexes/multi/test_setops.py

+33
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,39 @@ def test_difference(idx, sort):
195195
first.difference([1, 2, 3, 4, 5], sort=sort)
196196

197197

198+
def test_multiindex_difference_pyarrow_timestamp():
199+
pa = pytest.importorskip("pyarrow")
200+
201+
df = (
202+
DataFrame(
203+
[(1, "1900-01-01", "a"), (2, "1900-01-01", "b")],
204+
columns=["id", "date", "val"],
205+
)
206+
.astype(
207+
{
208+
"id": "int64[pyarrow]",
209+
"date": "timestamp[ns][pyarrow]",
210+
"val": "string[pyarrow]",
211+
}
212+
)
213+
.set_index(["id", "date"])
214+
)
215+
216+
idx = df.index
217+
idx_val = idx[0]
218+
219+
# Assert the value exists in the original index
220+
assert idx_val in idx
221+
222+
# Remove idx_val using difference()
223+
new_idx = idx.difference([idx_val])
224+
225+
# Verify the result
226+
assert len(new_idx) == 1
227+
assert idx_val not in new_idx
228+
assert new_idx.equals(MultiIndex.from_tuples([(2, pd.Timestamp("1900-01-01"))]))
229+
230+
198231
def test_difference_sort_special():
199232
# GH-24959
200233
idx = MultiIndex.from_product([[1, 0], ["a", "b"]])

scripts/check_for_inconsistent_pandas_namespace.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@
3030
from typing import NamedTuple
3131

3232
ERROR_MESSAGE = (
33-
"{path}:{lineno}:{col_offset}: "
34-
"Found both '{prefix}.{name}' and '{name}' in {path}"
33+
"{path}:{lineno}:{col_offset}: Found both '{prefix}.{name}' and '{name}' in {path}"
3534
)
3635

3736

scripts/check_test_naming.py

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
NOTE: if this finds a false positive, you can add the comment `# not a test` to the
99
class or function definition. Though hopefully that shouldn't be necessary.
1010
"""
11+
1112
from __future__ import annotations
1213

1314
import argparse

scripts/generate_pip_deps_from_conda.py

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
generated with this script:
1313
$ python scripts/generate_pip_deps_from_conda.py --compare
1414
"""
15+
1516
import argparse
1617
import pathlib
1718
import re

scripts/pandas_errors_documented.py

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
77
pre-commit run pandas-errors-documented --all-files
88
"""
9+
910
from __future__ import annotations
1011

1112
import argparse

scripts/sort_whatsnew_note.py

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
2424
pre-commit run sort-whatsnew-items --all-files
2525
"""
26+
2627
from __future__ import annotations
2728

2829
import argparse

scripts/tests/test_check_test_naming.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,7 @@
2424
0,
2525
),
2626
(
27-
"class Foo: # not a test\n"
28-
" pass\n"
29-
"def test_foo():\n"
30-
" Class.foo()\n",
27+
"class Foo: # not a test\n pass\ndef test_foo():\n Class.foo()\n",
3128
"",
3229
0,
3330
),

scripts/tests/test_inconsistent_namespace_check.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,10 @@
55
)
66

77
BAD_FILE_0 = (
8-
"from pandas import Categorical\n"
9-
"cat_0 = Categorical()\n"
10-
"cat_1 = pd.Categorical()"
8+
"from pandas import Categorical\ncat_0 = Categorical()\ncat_1 = pd.Categorical()"
119
)
1210
BAD_FILE_1 = (
13-
"from pandas import Categorical\n"
14-
"cat_0 = pd.Categorical()\n"
15-
"cat_1 = Categorical()"
11+
"from pandas import Categorical\ncat_0 = pd.Categorical()\ncat_1 = Categorical()"
1612
)
1713
BAD_FILE_2 = (
1814
"from pandas import Categorical\n"

scripts/tests/test_validate_docstrings.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ def redundant_import(self, paramx=None, paramy=None) -> None:
3434
--------
3535
>>> import numpy as np
3636
>>> import pandas as pd
37-
>>> df = pd.DataFrame(np.ones((3, 3)),
38-
... columns=('a', 'b', 'c'))
37+
>>> df = pd.DataFrame(np.ones((3, 3)), columns=("a", "b", "c"))
3938
>>> df.all(axis=1)
4039
0 True
4140
1 True
@@ -50,14 +49,14 @@ def unused_import(self) -> None:
5049
Examples
5150
--------
5251
>>> import pandas as pdf
53-
>>> df = pd.DataFrame(np.ones((3, 3)), columns=('a', 'b', 'c'))
52+
>>> df = pd.DataFrame(np.ones((3, 3)), columns=("a", "b", "c"))
5453
"""
5554

5655
def missing_whitespace_around_arithmetic_operator(self) -> None:
5756
"""
5857
Examples
5958
--------
60-
>>> 2+5
59+
>>> 2 + 5
6160
7
6261
"""
6362

@@ -66,14 +65,14 @@ def indentation_is_not_a_multiple_of_four(self) -> None:
6665
Examples
6766
--------
6867
>>> if 2 + 5:
69-
... pass
68+
... pass
7069
"""
7170

7271
def missing_whitespace_after_comma(self) -> None:
7372
"""
7473
Examples
7574
--------
76-
>>> df = pd.DataFrame(np.ones((3,3)),columns=('a','b', 'c'))
75+
>>> df = pd.DataFrame(np.ones((3, 3)), columns=("a", "b", "c"))
7776
"""
7877

7978
def write_array_like_with_hyphen_not_underscore(self) -> None:
@@ -227,13 +226,13 @@ def test_validate_all_ignore_errors(self, monkeypatch):
227226
"errors": [
228227
("ER01", "err desc"),
229228
("ER02", "err desc"),
230-
("ER03", "err desc")
229+
("ER03", "err desc"),
231230
],
232231
"warnings": [],
233232
"examples_errors": "",
234233
"deprecated": True,
235234
"file": "file1",
236-
"file_line": "file_line1"
235+
"file_line": "file_line1",
237236
},
238237
)
239238
monkeypatch.setattr(
@@ -272,14 +271,13 @@ def test_validate_all_ignore_errors(self, monkeypatch):
272271
None: {"ER03"},
273272
"pandas.DataFrame.align": {"ER01"},
274273
# ignoring an error that is not requested should be of no effect
275-
"pandas.Index.all": {"ER03"}
276-
}
274+
"pandas.Index.all": {"ER03"},
275+
},
277276
)
278277
# two functions * two not global ignored errors - one function ignored error
279278
assert exit_status == 2 * 2 - 1
280279

281280

282-
283281
class TestApiItems:
284282
@property
285283
def api_doc(self):

0 commit comments

Comments
 (0)