Skip to content

Commit

Permalink
perf: Generalize the arg_sort fast path onto Column (#20437)
Browse files Browse the repository at this point in the history
  • Loading branch information
coastalwhite authored Dec 25, 2024
1 parent 96b7a9a commit a6fffd4
Show file tree
Hide file tree
Showing 7 changed files with 226 additions and 49 deletions.
168 changes: 165 additions & 3 deletions crates/polars-core/src/frame/column/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -894,8 +894,167 @@ impl Column {
}

pub fn arg_sort(&self, options: SortOptions) -> IdxCa {
if self.is_empty() {
return IdxCa::from_vec(self.name().clone(), Vec::new());
}

if self.null_count() == self.len() {
// We might need to maintain order so just respect the descending parameter.
let values = if options.descending {
(0..self.len() as IdxSize).rev().collect()
} else {
(0..self.len() as IdxSize).collect()
};

return IdxCa::from_vec(self.name().clone(), values);
}

let is_sorted = Some(self.is_sorted_flag());
let Some(is_sorted) = is_sorted.filter(|v| !matches!(v, IsSorted::Not)) else {
return self.as_materialized_series().arg_sort(options);
};

// Fast path: the data is sorted.
let is_sorted_dsc = matches!(is_sorted, IsSorted::Descending);
let invert = options.descending != is_sorted_dsc;

let mut values = Vec::with_capacity(self.len());

#[inline(never)]
fn extend(
start: IdxSize,
end: IdxSize,
slf: &Column,
values: &mut Vec<IdxSize>,
is_only_nulls: bool,
invert: bool,
maintain_order: bool,
) {
debug_assert!(start <= end);
debug_assert!(start as usize <= slf.len());
debug_assert!(end as usize <= slf.len());

if !invert || is_only_nulls {
values.extend(start..end);
return;
}

// If we don't have to maintain order but we have to invert. Just flip it around.
if !maintain_order {
values.extend((start..end).rev());
return;
}

// If we want to maintain order but we also needs to invert, we need to invert
// per group of items.
//
// @NOTE: Since the column is sorted, arg_unique can also take a fast path and
// just do a single traversal.
let arg_unique = slf
.slice(start as i64, (end - start) as usize)
.arg_unique()
.unwrap();

assert!(!arg_unique.has_nulls());

let num_unique = arg_unique.len();

// Fast path: all items are unique.
if num_unique == (end - start) as usize {
values.extend((start..end).rev());
return;
}

if num_unique == 1 {
values.extend(start..end);
return;
}

let mut prev_idx = end - start;
for chunk in arg_unique.downcast_iter() {
for &idx in chunk.values().as_slice().iter().rev() {
values.extend(start + idx..start + prev_idx);
prev_idx = idx;
}
}
}
macro_rules! extend {
($start:expr, $end:expr) => {
extend!($start, $end, is_only_nulls = false);
};
($start:expr, $end:expr, is_only_nulls = $is_only_nulls:expr) => {
extend(
$start,
$end,
self,
&mut values,
$is_only_nulls,
invert,
options.maintain_order,
);
};
}

let length = self.len() as IdxSize;
let null_count = self.null_count() as IdxSize;

if null_count == 0 {
extend!(0, length);
} else {
let has_nulls_last = self.get(self.len() - 1).unwrap().is_null();
match (options.nulls_last, has_nulls_last) {
(true, true) => {
// Current: Nulls last, Wanted: Nulls last
extend!(0, length - null_count);
extend!(length - null_count, length, is_only_nulls = true);
},
(true, false) => {
// Current: Nulls first, Wanted: Nulls last
extend!(null_count, length);
extend!(0, null_count, is_only_nulls = true);
},
(false, true) => {
// Current: Nulls last, Wanted: Nulls first
extend!(length - null_count, length, is_only_nulls = true);
extend!(0, length - null_count);
},
(false, false) => {
// Current: Nulls first, Wanted: Nulls first
extend!(0, null_count, is_only_nulls = true);
extend!(null_count, length);
},
}
}

// @NOTE: This can theoretically be pushed into the previous operation but it is really
// worth it... probably not...
if let Some((limit, limit_dsc)) = options.limit {
let limit = limit.min(length);

if limit_dsc {
values = values.drain((length - limit) as usize..).collect();
} else {
values.truncate(limit as usize);
}
}

IdxCa::from_vec(self.name().clone(), values)
}

pub fn arg_sort_multiple(
&self,
by: &[Column],
options: &SortMultipleOptions,
) -> PolarsResult<IdxCa> {
// @scalar-opt
self.as_materialized_series().arg_sort(options)
self.as_materialized_series().arg_sort_multiple(by, options)
}

pub fn arg_unique(&self) -> PolarsResult<IdxCa> {
match self {
Column::Scalar(s) => Ok(IdxCa::new_vec(s.name().clone(), vec![0])),
_ => self.as_materialized_series().arg_unique(),
}
}

pub fn bit_repr(&self) -> Option<BitRepr> {
Expand Down Expand Up @@ -986,8 +1145,11 @@ impl Column {
}

pub fn is_sorted_flag(&self) -> IsSorted {
// @scalar-opt
self.as_materialized_series().is_sorted_flag()
match self {
Column::Series(s) => s.is_sorted_flag(),
Column::Partitioned(s) => s.partitions().is_sorted_flag(),
Column::Scalar(_) => IsSorted::Ascending,
}
}

pub fn unique(&self) -> PolarsResult<Column> {
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/dataframe/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,7 +1014,7 @@ def test_multiple_column_sort() -> None:
pl.DataFrame({"a": [3, 2, 1], "b": ["b", "a", "a"]}),
)
assert_frame_equal(
df.sort("b", descending=True),
df.sort("b", descending=True, maintain_order=True),
pl.DataFrame({"a": [3, 1, 2], "b": ["b", "a", "a"]}),
)
assert_frame_equal(
Expand Down
2 changes: 2 additions & 0 deletions py-polars/tests/unit/datatypes/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,8 @@ def test_cat_preserve_lexical_ordering_on_concat() -> None:
assert df2["x"].dtype == dtype


# TODO: Bug see: https://github.com/pola-rs/polars/issues/20440
@pytest.mark.may_fail_auto_streaming
def test_cat_append_lexical_sorted_flag() -> None:
df = pl.DataFrame({"x": [0, 1, 1], "y": ["B", "B", "A"]}).with_columns(
pl.col("y").cast(pl.Categorical(ordering="lexical"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def test_global_and_local(
yield


# @TODO: Bug, see https://github.com/pola-rs/polars/issues/20440
@pytest.mark.may_fail_auto_streaming
def test_categorical_lexical_sort() -> None:
df = pl.DataFrame(
{"cats": ["z", "z", "k", "a", "b"], "vals": [3, 1, 2, 2, 3]}
Expand Down
10 changes: 4 additions & 6 deletions py-polars/tests/unit/operations/test_interpolate_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,17 +105,15 @@ def test_interpolate_by_leading_nulls() -> None:
}
)
result = df.select(pl.col("values").interpolate_by("times"))
expected = pl.DataFrame(
{"values": [None, None, None, 1.0, 1.7999999999999998, 4.6, 5.0]}
)
expected = pl.DataFrame({"values": [None, None, None, 1.0, 1.8, 4.6, 5.0]})
assert_frame_equal(result, expected)
result = (
df.sort("times", descending=True)
df.sort("times", maintain_order=True, descending=True)
.with_columns(pl.col("values").interpolate_by("times"))
.sort("times")
.sort("times", maintain_order=True)
.drop("times")
)
assert_frame_equal(result, expected)
assert_frame_equal(result, expected, check_exact=False)


@pytest.mark.parametrize("dataset", ["floats", "dates"])
Expand Down
17 changes: 12 additions & 5 deletions py-polars/tests/unit/operations/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,11 +285,18 @@ def test_join_on_cast() -> None:

df_b = pl.DataFrame({"a": [-2, -3, 3, 10]})

assert df_a.join(df_b, on=pl.col("a").cast(pl.Int64)).to_dict(as_series=False) == {
"index": [1, 2, 3, 5],
"a": [-2, 3, 3, 10],
"a_right": [-2, 3, 3, 10],
}
assert_frame_equal(
df_a.join(df_b, on=pl.col("a").cast(pl.Int64)),
pl.DataFrame(
{
"index": [1, 2, 3, 5],
"a": [-2, 3, 3, 10],
"a_right": [-2, 3, 3, 10],
}
),
check_row_order=False,
check_dtypes=False,
)
assert df_a.lazy().join(
df_b.lazy(), on=pl.col("a").cast(pl.Int64)
).collect().to_dict(as_series=False) == {
Expand Down
Loading

0 comments on commit a6fffd4

Please # to comment.