Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

fix: Validate asof join by args in IR resolving phase #20473

Merged
merged 1 commit into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ pub fn resolve_join(
}

let owned = Arc::unwrap_or_clone;
let mut input_left = input_left.map_right(Ok).right_or_else(|input| {
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join left)))
})?;
let mut input_right = input_right.map_right(Ok).right_or_else(|input| {
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join right)))
})?;

let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);
let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena);

if options.args.how.is_cross() {
polars_ensure!(left_on.len() + right_on.len() == 0, InvalidOperation: "a 'cross' join doesn't expect any join keys");
} else {
Expand All @@ -65,6 +75,21 @@ pub fn resolve_join(

options.args.validation.is_valid_join(&options.args.how)?;

#[cfg(feature = "asof_join")]
if let JoinType::AsOf(opt) = &options.args.how {
match (&opt.left_by, &opt.right_by) {
(None, None) => {},
(Some(l), Some(r)) => {
polars_ensure!(l.len() == r.len(), InvalidOperation: "expected equal number of columns in 'by_left' and 'by_right' in 'asof_join'");
validate_columns_in_input(l, &schema_left, "asof_join")?;
validate_columns_in_input(r, &schema_right, "asof_join")?;
},
_ => {
polars_bail!(InvalidOperation: "expected both 'by_left' and 'by_right' to be set in 'asof_join'")
},
}
}

polars_ensure!(
left_on.len() == right_on.len(),
InvalidOperation:
Expand All @@ -76,16 +101,6 @@ pub fn resolve_join(
);
}

let mut input_left = input_left.map_right(Ok).right_or_else(|input| {
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join left)))
})?;
let mut input_right = input_right.map_right(Ok).right_or_else(|input| {
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join right)))
})?;

let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);
let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena);

let schema = det_join_schema(&schema_left, &schema_right, &left_on, &right_on, &options)
.map_err(|e| e.context(failed_here!(join schema resolving)))?;

Expand Down Expand Up @@ -120,6 +135,7 @@ pub fn resolve_join(
.coerce_types(ctxt.expr_arena, ctxt.lp_arena, input_right)
.map_err(|e| e.context("'join' failed".into()))?;

// Re-evaluate because of mutable borrows earlier.
let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);
let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena);

Expand Down
7 changes: 4 additions & 3 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4568,9 +4568,10 @@ def join_asof(
if by is not None:
by_left_ = [by] if isinstance(by, str) else by
by_right_ = by_left_
elif (by_left is not None) and (by_right is not None):
by_left_ = [by_left] if isinstance(by_left, str) else by_left
by_right_ = [by_right] if isinstance(by_right, str) else by_right
elif (by_left is not None) or (by_right is not None):
by_left_ = [by_left] if isinstance(by_left, str) else by_left # type: ignore[assignment]
by_right_ = [by_right] if isinstance(by_right, str) else by_right # type: ignore[assignment]

else:
# no by
by_left_ = None
Expand Down
18 changes: 18 additions & 0 deletions py-polars/tests/unit/operations/test_join_asof.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,3 +1196,21 @@ def test_asof_join_by_schema() -> None:
)

assert q.collect_schema() == q.collect().schema


def test_raise_invalid_by_arg_13020() -> None:
df1 = pl.DataFrame({"asOfDate": [date(2020, 1, 1)]})
df2 = pl.DataFrame(
{
"endityId": [date(2020, 1, 1)],
"eventDate": ["A"],
}
)
with pytest.raises(pl.exceptions.InvalidOperationError, match="expected both"):
df1.sort("asOfDate").join_asof(
df2.sort("eventDate"),
left_on="asOfDate",
right_on="eventDate",
by_left=None,
by_right=["entityId"],
)
Loading