
Commit ee31659
Fix tests
Committed May 5, 2022
1 parent: 0a7931c

File tree: 4 files changed, +86 -54 lines

  mars/dataframe/groupby/aggregation.py
  mars/dataframe/groupby/tests/test_groupby.py
  mars/dataframe/groupby/tests/test_groupby_execution.py
  mars/dataframe/reduction/nunique.py

mars/dataframe/groupby/aggregation.py (+36 -20)

@@ -701,34 +701,36 @@ def _do_custom_agg(op, custom_reduction, *input_objs):
     out = op.outputs[0]
     for group_key in input_objs[0].groups.keys():
         group_objs = [o.get_group(group_key) for o in input_objs]
+
         agg_done = False
         if op.stage == OperandStage.map:
-            result = custom_reduction.pre(group_objs[0])
+            res_tuple = custom_reduction.pre(group_objs[0])
             agg_done = custom_reduction.pre_with_agg
-            if not isinstance(result, tuple):
-                result = (result,)
+            if not isinstance(res_tuple, tuple):
+                res_tuple = (res_tuple,)
         else:
-            result = group_objs
+            res_tuple = group_objs

         if not agg_done:
-            result = custom_reduction.agg(*result)
-            if not isinstance(result, tuple):
-                result = (result,)
+            res_tuple = custom_reduction.agg(*res_tuple)
+            if not isinstance(res_tuple, tuple):
+                res_tuple = (res_tuple,)

         if op.stage == OperandStage.agg:
-            result = custom_reduction.post(*result)
-            if not isinstance(result, tuple):
-                result = (result,)
-
-        if out.ndim == 2:
-            if result[0].ndim == 1:
-                result = tuple(r.to_frame().T for r in result)
-            if op.stage == OperandStage.agg:
-                result = tuple(r.astype(out.dtypes) for r in result)
-        else:
-            result = tuple(xdf.Series(r) for r in result)
+            res_tuple = custom_reduction.post(*res_tuple)
+            if not isinstance(res_tuple, tuple):
+                res_tuple = (res_tuple,)
+
+        new_res_list = []
+        for r in res_tuple:
+            if out.ndim == 2 and r.ndim == 1:
+                r = r.to_frame().T
+            elif out.ndim < 2:
+                if getattr(r, "ndim", 0) == 2:
+                    r = r.iloc[0, :]
+                else:
+                    r = xdf.Series(r)

-        for r in result:
             if len(input_objs[0].grouper.names) == 1:
                 r.index = xdf.Index(
                     [group_key], name=input_objs[0].grouper.names[0]
@@ -737,7 +739,21 @@ def _do_custom_agg(op, custom_reduction, *input_objs):
                 r.index = xdf.MultiIndex.from_tuples(
                     [group_key], names=input_objs[0].grouper.names
                 )
-        results.append(result)
+
+            if op.groupby_params.get("selection"):
+                # correct columns for groupby-selection-agg paradigms
+                selection = op.groupby_params["selection"]
+                r.columns = [selection] if input_objs[0].ndim == 1 else selection
+
+            if out.ndim == 2 and op.stage == OperandStage.agg:
+                dtype_cols = set(out.dtypes.index) & set(r.columns)
+                conv_dtypes = {
+                    k: v for k, v in out.dtypes.items() if k in dtype_cols
+                }
+                r = r.astype(conv_dtypes)
+            new_res_list.append(r)
+
+        results.append(tuple(new_res_list))
     if not results and op.stage == OperandStage.agg:
         empty_df = pd.DataFrame(
             [], columns=out.dtypes.index, index=out.index_value.to_pandas()[:0]
mars/dataframe/groupby/tests/test_groupby.py (+21)

@@ -476,3 +476,24 @@ def test_groupby_fill():
     assert len(r.chunks) == 4
     assert r.shape == (len(s1),)
     assert r.chunks[0].shape == (np.nan,)
+
+
+def test_groupby_nunique():
+    df1 = pd.DataFrame(
+        [
+            [1, 1, 10],
+            [1, 1, np.nan],
+            [1, 1, np.nan],
+            [1, 2, np.nan],
+            [1, 2, 20],
+            [1, 2, np.nan],
+            [1, 3, np.nan],
+            [1, 3, np.nan],
+        ],
+        columns=["one", "two", "three"],
+    )
+    mdf = md.DataFrame(df1, chunk_size=3)
+
+    r = tile(mdf.groupby(["one", "two"]).nunique())
+    assert len(r.chunks) == 1
+    assert isinstance(r.chunks[0].op, DataFrameGroupByAgg)
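
Note: for reference, the plain pandas result this tile-level test case is built around. With pandas' default dropna=True for nunique, NaN is not counted, so the (1, 3) group has zero unique values in "three":

    import numpy as np
    import pandas as pd

    df1 = pd.DataFrame(
        [
            [1, 1, 10],
            [1, 1, np.nan],
            [1, 1, np.nan],
            [1, 2, np.nan],
            [1, 2, 20],
            [1, 2, np.nan],
            [1, 3, np.nan],
            [1, 3, np.nan],
        ],
        columns=["one", "two", "three"],
    )

    # Unique non-null counts of "three" per ("one", "two") group:
    #   (1, 1) -> 1, (1, 2) -> 1, (1, 3) -> 0
    print(df1.groupby(["one", "two"]).nunique())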

mars/dataframe/groupby/tests/test_groupby_execution.py (+11 -8)

@@ -1241,13 +1241,16 @@ def test_groupby_nunique(setup):
     # test with as_index=False
     mdf = md.DataFrame(df1, chunk_size=13)
     if _agg_size_as_frame:
+        res = mdf.groupby("b", as_index=False)["a"].nunique().execute().fetch()
+        expected = df1.groupby("b", as_index=False)["a"].nunique()
         pd.testing.assert_frame_equal(
-            mdf.groupby("b", as_index=False)["a"]
-            .nunique()
-            .execute()
-            .fetch()
-            .sort_values(by="b", ignore_index=True),
-            df1.groupby("b", as_index=False)["a"]
-            .nunique()
-            .sort_values(by="b", ignore_index=True),
+            res.sort_values(by="b", ignore_index=True),
+            expected.sort_values(by="b", ignore_index=True),
+        )
+
+        res = mdf.groupby("b", as_index=False)[["a", "c"]].nunique().execute().fetch()
+        expected = df1.groupby("b", as_index=False)[["a", "c"]].nunique()
+        pd.testing.assert_frame_equal(
+            res.sort_values(by="b", ignore_index=True),
+            expected.sort_values(by="b", ignore_index=True),
         )
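
Note: both sides are sorted on the group key with ignore_index=True before the comparison, so the assertion does not depend on the row order in which groups come back from chunked execution. A toy illustration of that comparison pattern (the data here is made up, not from the test suite):

    import pandas as pd

    expected = pd.DataFrame({"b": [1, 2], "a": [3, 4]})
    res = expected.iloc[[1, 0]]  # same content, different row order

    # Sorting by the key column and rebuilding the index makes the frames
    # compare equal even though the rows arrived in a different order.
    pd.testing.assert_frame_equal(
        res.sort_values(by="b", ignore_index=True),
        expected.sort_values(by="b", ignore_index=True),
    )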

mars/dataframe/reduction/nunique.py (+18 -26)

@@ -41,61 +41,53 @@ def __init__(
         self._dropna = dropna
         self._use_arrow_dtype = use_arrow_dtype

-    @staticmethod
-    def _drop_duplicates_to_arrow(v, explode=False):
+    def _drop_duplicates(self, xdf, value, explode=False):
         if explode:
-            v = v.explode()
-        try:
-            return ArrowListArray([v.drop_duplicates().to_numpy()])
-        except pa.ArrowInvalid:
-            # fallback due to diverse dtypes
-            return [v.drop_duplicates().to_list()]
+            value = value.explode()
+
+        if not self._use_arrow_dtype or xdf is cudf:
+            return [value.drop_duplicates().to_numpy()]
+        else:
+            try:
+                return ArrowListArray([value.drop_duplicates().to_numpy()])
+            except pa.ArrowInvalid:
+                # fallback due to diverse dtypes
+                return [value.drop_duplicates().to_numpy()]

     def pre(self, in_data):  # noqa: W0221  # pylint: disable=arguments-differ
         xdf = cudf if self.is_gpu() else pd
         if isinstance(in_data, xdf.Series):
-            unique_values = in_data.drop_duplicates()
+            unique_values = self._drop_duplicates(xdf, in_data)
             return xdf.Series(unique_values, name=in_data.name)
         else:
             if self._axis == 0:
                 data = dict()
                 for d, v in in_data.iteritems():
-                    if not self._use_arrow_dtype or xdf is cudf:
-                        data[d] = [v.drop_duplicates().to_list()]
-                    else:
-                        data[d] = self._drop_duplicates_to_arrow(v)
+                    data[d] = self._drop_duplicates(xdf, v)
                 df = xdf.DataFrame(data)
             else:
                 df = xdf.DataFrame(columns=[0])
                 for d, v in in_data.iterrows():
-                    if not self._use_arrow_dtype or xdf is cudf:
-                        df.loc[d] = [v.drop_duplicates().to_list()]
-                    else:
-                        df.loc[d] = self._drop_duplicates_to_arrow(v)
+                    df.loc[d] = self._drop_duplicates(xdf, v)
             return df

     def agg(self, in_data):  # noqa: W0221  # pylint: disable=arguments-differ
         xdf = cudf if self.is_gpu() else pd
         if isinstance(in_data, xdf.Series):
-            unique_values = in_data.explode().drop_duplicates()
+            unique_values = self._drop_duplicates(xdf, in_data, explode=True)
             return xdf.Series(unique_values, name=in_data.name)
         else:
             if self._axis == 0:
                 data = dict()
                 for d, v in in_data.iteritems():
-                    if not self._use_arrow_dtype or xdf is cudf:
-                        data[d] = [v.explode().drop_duplicates().to_list()]
-                    else:
+                    if self._use_arrow_dtype and xdf is not cudf:
                         v = pd.Series(v.to_numpy())
-                        data[d] = self._drop_duplicates_to_arrow(v, explode=True)
+                    data[d] = self._drop_duplicates(xdf, v, explode=True)
                 df = xdf.DataFrame(data)
             else:
                 df = xdf.DataFrame(columns=[0])
                 for d, v in in_data.iterrows():
-                    if not self._use_arrow_dtype or xdf is cudf:
-                        df.loc[d] = [v.explode().drop_duplicates().to_list()]
-                    else:
-                        df.loc[d] = self._drop_duplicates_to_arrow(v, explode=True)
+                    df.loc[d] = self._drop_duplicates(xdf, v, explode=True)
             return df

     def post(self, in_data):  # noqa: W0221  # pylint: disable=arguments-differ
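
Note: a standalone sketch of the branching inside the consolidated helper, with a plain pyarrow list array standing in for Mars' ArrowListArray extension array; the function name and signature below are illustrative only:

    import pandas as pd
    import pyarrow as pa

    def drop_duplicates_cell(value, use_arrow_dtype, explode=False):
        # Optionally flatten list-like rows first, then deduplicate once.
        if explode:
            value = value.explode()
        deduped = value.drop_duplicates().to_numpy()

        if not use_arrow_dtype:
            # Plain storage: a single cell holding the array of unique values.
            return [deduped]
        try:
            # Arrow list storage keeps the intermediate result compact.
            return pa.array([deduped.tolist()])
        except pa.ArrowInvalid:
            # Fall back when Arrow cannot represent mixed dtypes.
            return [deduped]

    s = pd.Series([1, 1, 2, None, 2])
    print(drop_duplicates_cell(s, use_arrow_dtype=False))  # one cell: array([1., 2., nan])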
