
Commit 46bcef7

itholic authored and Hyukjin Kwon committed
[SPARK-36438][PYTHON] Support list-like Python objects for Series comparison
### What changes were proposed in this pull request?

This PR proposes to implement `Series` comparison with list-like Python objects.

Currently `Series` does not support comparison with list-like Python objects such as `list`, `tuple`, `dict`, and `set`.

**Before**

```python
>>> psser
0    1
1    2
2    3
dtype: int64

>>> psser == [3, 2, 1]
Traceback (most recent call last):
...
TypeError: The operation can not be applied to list.
...
```

**After**

```python
>>> psser
0    1
1    2
2    3
dtype: int64

>>> psser == [3, 2, 1]
0    False
1     True
2    False
dtype: bool
```

This was originally proposed in databricks/koalas#2022, and all reviews in the original PR have been resolved.

### Why are the changes needed?

To follow pandas' behavior.

### Does this PR introduce _any_ user-facing change?

Yes, `Series` comparison with list-like Python objects is now possible.

### How was this patch tested?

Unit tests.

Closes #34114 from itholic/SPARK-36438.

Authored-by: itholic <haejoon.lee@databricks.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
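For reference (grounded in the new tests below): tuples take the same code path as lists, and `dict`/`set` operands compare all-False to match pandas:

```python
>>> psser == (3, 2, 1)     # tuples behave like lists
0    False
1     True
2    False
dtype: bool

>>> psser == {3, 2, 1}     # dict/set: pandas semantics, always all False
0    False
1    False
2    False
dtype: bool
```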
1 parent f678c75 commit 46bcef7

5 files changed (+184, −6)


python/pyspark/pandas/base.py (+5, −1)

```diff
@@ -394,7 +394,11 @@ def __abs__(self: IndexOpsLike) -> IndexOpsLike:
 
     # comparison operators
     def __eq__(self, other: Any) -> SeriesOrIndex:  # type: ignore[override]
-        return self._dtype_op.eq(self, other)
+        # pandas always returns False for all items with dict and set.
+        if isinstance(other, (dict, set)):
+            return self != self
+        else:
+            return self._dtype_op.eq(self, other)
 
     def __ne__(self, other: Any) -> SeriesOrIndex:  # type: ignore[override]
         return self._dtype_op.ne(self, other)
```
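A note on the `self != self` shortcut: it is simply a cheap way to build an all-False boolean mask with the right index, mirroring how pandas treats `dict` and `set` operands (each element is compared against the container object itself, which never matches). A minimal pandas session, assuming a pandas version in the range this patch targets:

```python
import pandas as pd

pser = pd.Series([1, 2, 3])

# Even though every element is a member of the set, pandas compares each
# element against the set object itself, so the result is all False.
print(pser == {1, 2, 3})
# 0    False
# 1    False
# 2    False
# dtype: bool
```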

python/pyspark/pandas/data_type_ops/base.py (+91, −4)

```diff
@@ -376,11 +376,98 @@ def ge(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         raise TypeError(">= can not be applied to %s." % self.pretty_name)
 
     def eq(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
-        from pyspark.pandas.base import column_op
-
-        _sanitize_list_like(right)
+        if isinstance(right, (list, tuple)):
+            from pyspark.pandas.series import first_series, scol_for
+            from pyspark.pandas.frame import DataFrame
+            from pyspark.pandas.internal import NATURAL_ORDER_COLUMN_NAME, InternalField
+
+            len_right = len(right)
+            if len(left) != len_right:
+                raise ValueError("Lengths must be equal")
+
+            sdf = left._internal.spark_frame
+            structed_scol = F.struct(
+                sdf[NATURAL_ORDER_COLUMN_NAME],
+                *left._internal.index_spark_columns,
+                left.spark.column
+            )
+            # The size of the list is expected to be small.
+            collected_structed_scol = F.collect_list(structed_scol)
+            # Sort the array by NATURAL_ORDER_COLUMN so that we can guarantee the order.
+            collected_structed_scol = F.array_sort(collected_structed_scol)
+            right_values_scol = F.array([F.lit(x) for x in right])  # type: ignore
+            index_scol_names = left._internal.index_spark_column_names
+            scol_name = left._internal.spark_column_name_for(left._internal.column_labels[0])
+            # Compare the values of left and right by using the zip_with function.
+            cond = F.zip_with(
+                collected_structed_scol,
+                right_values_scol,
+                lambda x, y: F.struct(
+                    *[
+                        x[index_scol_name].alias(index_scol_name)
+                        for index_scol_name in index_scol_names
+                    ],
+                    F.when(x[scol_name].isNull() | y.isNull(), False)
+                    .otherwise(
+                        x[scol_name] == y,
+                    )
+                    .alias(scol_name)
+                ),
+            ).alias(scol_name)
+            # 1. `sdf_new` here looks like the below (the first field of each struct is the index):
+            # +----------------------------------------------------------+
+            # |0                                                         |
+            # +----------------------------------------------------------+
+            # |[{0, false}, {1, true}, {2, false}, {3, true}, {4, false}]|
+            # +----------------------------------------------------------+
+            sdf_new = sdf.select(cond)
+            # 2. `sdf_new` after the explode looks like the below:
+            # +----------+
+            # |       col|
+            # +----------+
+            # |{0, false}|
+            # | {1, true}|
+            # |{2, false}|
+            # | {3, true}|
+            # |{4, false}|
+            # +----------+
+            sdf_new = sdf_new.select(F.explode(scol_name))
+            # 3. Here, the final `sdf_new` looks like the below:
+            # +-----------------+-----+
+            # |__index_level_0__|    0|
+            # +-----------------+-----+
+            # |                0|false|
+            # |                1| true|
+            # |                2|false|
+            # |                3| true|
+            # |                4|false|
+            # +-----------------+-----+
+            sdf_new = sdf_new.select("col.*")
+
+            index_spark_columns = [
+                scol_for(sdf_new, index_scol_name) for index_scol_name in index_scol_names
+            ]
+            data_spark_columns = [scol_for(sdf_new, scol_name)]
+
+            internal = left._internal.copy(
+                spark_frame=sdf_new,
+                index_spark_columns=index_spark_columns,
+                data_spark_columns=data_spark_columns,
+                index_fields=[
+                    InternalField.from_struct_field(index_field)
+                    for index_field in sdf_new.select(index_spark_columns).schema.fields
+                ],
+                data_fields=[
+                    InternalField.from_struct_field(
+                        sdf_new.select(data_spark_columns).schema.fields[0]
+                    )
+                ],
+            )
+            return first_series(DataFrame(internal))
+        else:
+            from pyspark.pandas.base import column_op
 
-        return column_op(Column.__eq__)(left, right)
+            return column_op(Column.__eq__)(left, right)
 
     def ne(self, left: IndexOpsLike, right: Any) -> SeriesOrIndex:
         from pyspark.pandas.base import column_op
```
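The Spark-side trick above is dense, so here is a self-contained sketch of the same pipeline: collect all rows into one array, fix the row order with `array_sort`, pair it element-by-element with a literal array via `zip_with`, then `explode` back to rows. All names here (`order`, `value`, `zipped`, `eq`) are illustrative stand-ins, not identifiers from the patch:

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

# A stand-in frame: `order` plays the role of NATURAL_ORDER_COLUMN_NAME,
# `value` the role of the Series column. Values are 1, 2, 3.
sdf = spark.range(3).select(
    F.col("id").alias("order"), (F.col("id") + 1).alias("value")
)
right = [3, 2, 1]

# Pack each row into a struct, gather all rows into a single array, and
# sort by the leading `order` field so the row order is deterministic.
packed = F.array_sort(F.collect_list(F.struct("order", "value")))

# zip_with pairs packed[i] with right[i]; a null on either side (a missing
# value, or padding when lengths differ) yields False rather than null.
cond = F.zip_with(
    packed,
    F.array(*[F.lit(x) for x in right]),
    lambda x, y: F.when(x["value"].isNull() | y.isNull(), False).otherwise(
        x["value"] == y
    ),
)

# Aggregate to one row holding the boolean array, then explode back to rows.
exploded = sdf.select(cond.alias("zipped")).select(F.explode("zipped").alias("eq"))
exploded.show()
# +-----+
# |   eq|
# +-----+
# |false|
# | true|
# |false|
# +-----+
```

The struct packing matters: sorting an array of structs orders by the leading field, so prepending the natural-order column is what guarantees the collected array matches the original row order.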

python/pyspark/pandas/series.py (+1, −1)

```diff
@@ -675,7 +675,7 @@ def rfloordiv(self, other: Any) -> "Series":
     koalas = CachedAccessor("koalas", PandasOnSparkSeriesMethods)
 
     # Comparison Operators
-    def eq(self, other: Any) -> bool:
+    def eq(self, other: Any) -> "Series":
         """
         Compare if the current value is equal to the other.
```
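For context on the annotation fix: `eq` is element-wise, so it returns a boolean `Series`, never a plain `bool`. A quick sketch:

```python
import pyspark.pandas as ps

psser = ps.Series([1, 2, 3])

# Element-wise comparison yields a boolean Series aligned to the index,
# which is what the corrected `-> "Series"` annotation describes.
print(psser.eq(2))
# 0    False
# 1     True
# 2    False
# dtype: bool
```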

python/pyspark/pandas/tests/test_ops_on_diff_frames.py (+37)

```diff
@@ -1845,6 +1845,29 @@ def _test_cov(self, pser1, pser2):
         pscov = psser1.cov(psser2, min_periods=3)
         self.assert_eq(pcov, pscov, almost=True)
 
+    def test_series_eq(self):
+        pser = pd.Series([1, 2, 3, 4, 5, 6], name="x")
+        psser = ps.from_pandas(pser)
+
+        # other = Series
+        pandas_other = pd.Series([np.nan, 1, 3, 4, np.nan, 6], name="x")
+        pandas_on_spark_other = ps.from_pandas(pandas_other)
+        self.assert_eq(pser.eq(pandas_other), psser.eq(pandas_on_spark_other).sort_index())
+        self.assert_eq(pser == pandas_other, (psser == pandas_on_spark_other).sort_index())
+
+        # other = Series with different Index
+        pandas_other = pd.Series(
+            [np.nan, 1, 3, 4, np.nan, 6], index=[10, 20, 30, 40, 50, 60], name="x"
+        )
+        pandas_on_spark_other = ps.from_pandas(pandas_other)
+        self.assert_eq(pser.eq(pandas_other), psser.eq(pandas_on_spark_other).sort_index())
+
+        # other = Index
+        pandas_other = pd.Index([np.nan, 1, 3, 4, np.nan, 6], name="x")
+        pandas_on_spark_other = ps.from_pandas(pandas_other)
+        self.assert_eq(pser.eq(pandas_other), psser.eq(pandas_on_spark_other).sort_index())
+        self.assert_eq(pser == pandas_other, (psser == pandas_on_spark_other).sort_index())
+
 
 class OpsOnDiffFramesDisabledTest(PandasOnSparkTestCase, SQLTestUtils):
     @classmethod
@@ -2039,6 +2062,20 @@ def test_combine_first(self):
         with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
             psdf1.combine_first(psdf2)
 
+    def test_series_eq(self):
+        pser = pd.Series([1, 2, 3, 4, 5, 6], name="x")
+        psser = ps.from_pandas(pser)
+
+        others = (
+            ps.Series([np.nan, 1, 3, 4, np.nan, 6], name="x"),
+            ps.Index([np.nan, 1, 3, 4, np.nan, 6], name="x"),
+        )
+        for other in others:
+            with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
+                psser.eq(other)
+            with self.assertRaisesRegex(ValueError, "Cannot combine the series or dataframe"):
+                psser == other
+
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.test_ops_on_diff_frames import *  # noqa: F401
```
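These assertions come from the `OpsOnDiffFrames*` harnesses, which flip the `compute.ops_on_diff_frames` option: with it off (the default), combining two pandas-on-Spark objects backed by different frames raises the "Cannot combine the series or dataframe" error asserted above. A sketch of the same comparison outside the harness, assuming a default session:

```python
import numpy as np
import pyspark.pandas as ps

psser = ps.Series([1, 2, 3, 4, 5, 6], name="x")
other = ps.Series([np.nan, 1, 3, 4, np.nan, 6], name="x")

# With the option enabled, the comparison is allowed at the cost of an
# implicit join; row order is not guaranteed, hence sort_index().
ps.set_option("compute.ops_on_diff_frames", True)
print((psser == other).sort_index())
ps.reset_option("compute.ops_on_diff_frames")
```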

python/pyspark/pandas/tests/test_series.py (+50)

```diff
@@ -3071,6 +3071,56 @@ def _test_cov(self, pdf):
         pscov = psdf["s1"].cov(psdf["s2"], min_periods=4)
         self.assert_eq(pcov, pscov, almost=True)
 
+    def test_eq(self):
+        pser = pd.Series([1, 2, 3, 4, 5, 6], name="x")
+        psser = ps.from_pandas(pser)
+
+        # other = Series
+        self.assert_eq(pser.eq(pser), psser.eq(psser))
+        self.assert_eq(pser == pser, psser == psser)
+
+        # other = dict
+        other = {1: None, 2: None, 3: None, 4: None, np.nan: None, 6: None}
+        self.assert_eq(pser.eq(other), psser.eq(other))
+        self.assert_eq(pser == other, psser == other)
+
+        # other = set
+        other = {1, 2, 3, 4, np.nan, 6}
+        self.assert_eq(pser.eq(other), psser.eq(other))
+        self.assert_eq(pser == other, psser == other)
+
+        # other = list
+        other = [np.nan, 1, 3, 4, np.nan, 6]
+        if LooseVersion(pd.__version__) >= LooseVersion("1.2"):
+            self.assert_eq(pser.eq(other), psser.eq(other).sort_index())
+            self.assert_eq(pser == other, (psser == other).sort_index())
+        else:
+            self.assert_eq(pser.eq(other).rename("x"), psser.eq(other).sort_index())
+            self.assert_eq((pser == other).rename("x"), (psser == other).sort_index())
+
+        # other = tuple
+        other = (np.nan, 1, 3, 4, np.nan, 6)
+        if LooseVersion(pd.__version__) >= LooseVersion("1.2"):
+            self.assert_eq(pser.eq(other), psser.eq(other).sort_index())
+            self.assert_eq(pser == other, (psser == other).sort_index())
+        else:
+            self.assert_eq(pser.eq(other).rename("x"), psser.eq(other).sort_index())
+            self.assert_eq((pser == other).rename("x"), (psser == other).sort_index())
+
+        # other = list with a different length
+        other = [np.nan, 1, 3, 4, np.nan]
+        with self.assertRaisesRegex(ValueError, "Lengths must be equal"):
+            psser.eq(other)
+        with self.assertRaisesRegex(ValueError, "Lengths must be equal"):
+            psser == other
+
+        # other = tuple with a different length
+        other = (np.nan, 1, 3, 4, np.nan)
+        with self.assertRaisesRegex(ValueError, "Lengths must be equal"):
+            psser.eq(other)
+        with self.assertRaisesRegex(ValueError, "Lengths must be equal"):
+            psser == other
+
 
 if __name__ == "__main__":
     from pyspark.pandas.tests.test_series import *  # noqa: F401
```
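About the `LooseVersion` gate in `test_eq`: it encodes a pandas behavior change in name propagation. On pandas >= 1.2, the boolean result of comparing a named Series to a list or tuple keeps the name; older pandas dropped it, which is why the expected result is renamed back to "x" in the else branch. A small illustration:

```python
import pandas as pd

pser = pd.Series([1, 2, 3], name="x")
result = pser == [1, 2, 3]

# "x" on pandas >= 1.2; None on older versions (per the test's else branch).
print(result.name)
```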
