Skip to content

Commit 087fdc3

Browse files
authored
feat: Add union_by_name, union_by_name_distinct to DataFrame api (#15489)
* Add DataFrame `union_by_name` functions.
* Update tests.
* Run cargo fmt.
* Fix docs.
* Fix `union_by_name` docs.
1 parent 7850cef commit 087fdc3

File tree

2 files changed

+171
-0
lines changed

2 files changed

+171
-0
lines changed

datafusion/core/src/dataframe/mod.rs

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -685,6 +685,46 @@ impl DataFrame {
685685
})
686686
}
687687

688+
/// Calculate the union of two [`DataFrame`]s using column names, preserving duplicate rows.
689+
///
690+
/// The two [`DataFrame`]s are combined using column names rather than position,
691+
/// filling missing columns with null.
692+
///
693+
///
694+
/// # Example
695+
/// ```
696+
/// # use datafusion::prelude::*;
697+
/// # use datafusion::error::Result;
698+
/// # use datafusion_common::assert_batches_sorted_eq;
699+
/// # #[tokio::main]
700+
/// # async fn main() -> Result<()> {
701+
/// let ctx = SessionContext::new();
702+
/// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?;
703+
/// let d2 = df.clone().select_columns(&["b", "c", "a"])?.with_column("d", lit("77"))?;
704+
/// let df = df.union_by_name(d2)?;
705+
/// let expected = vec![
706+
/// "+---+---+---+----+",
707+
/// "| a | b | c | d |",
708+
/// "+---+---+---+----+",
709+
/// "| 1 | 2 | 3 | |",
710+
/// "| 1 | 2 | 3 | 77 |",
711+
/// "+---+---+---+----+"
712+
/// ];
713+
/// # assert_batches_sorted_eq!(expected, &df.collect().await?);
714+
/// # Ok(())
715+
/// # }
716+
/// ```
717+
pub fn union_by_name(self, dataframe: DataFrame) -> Result<DataFrame> {
718+
let plan = LogicalPlanBuilder::from(self.plan)
719+
.union_by_name(dataframe.plan)?
720+
.build()?;
721+
Ok(DataFrame {
722+
session_state: self.session_state,
723+
plan,
724+
projection_requires_validation: true,
725+
})
726+
}
727+
688728
/// Calculate the distinct union of two [`DataFrame`]s.
689729
///
690730
/// The two [`DataFrame`]s must have exactly the same schema. Any duplicate
@@ -724,6 +764,45 @@ impl DataFrame {
724764
})
725765
}
726766

767+
/// Calculate the union of two [`DataFrame`]s using column names with all duplicated rows removed.
768+
///
769+
/// The two [`DataFrame`]s are combined using column names rather than position,
770+
/// filling missing columns with null.
771+
///
772+
///
773+
/// # Example
774+
/// ```
775+
/// # use datafusion::prelude::*;
776+
/// # use datafusion::error::Result;
777+
/// # use datafusion_common::assert_batches_sorted_eq;
778+
/// # #[tokio::main]
779+
/// # async fn main() -> Result<()> {
780+
/// let ctx = SessionContext::new();
781+
/// let df = ctx.read_csv("tests/data/example.csv", CsvReadOptions::new()).await?;
782+
/// let d2 = df.clone().select_columns(&["b", "c", "a"])?;
783+
/// let df = df.union_by_name_distinct(d2)?;
784+
/// let expected = vec![
785+
/// "+---+---+---+",
786+
/// "| a | b | c |",
787+
/// "+---+---+---+",
788+
/// "| 1 | 2 | 3 |",
789+
/// "+---+---+---+"
790+
/// ];
791+
/// # assert_batches_sorted_eq!(expected, &df.collect().await?);
792+
/// # Ok(())
793+
/// # }
794+
/// ```
795+
pub fn union_by_name_distinct(self, dataframe: DataFrame) -> Result<DataFrame> {
796+
let plan = LogicalPlanBuilder::from(self.plan)
797+
.union_by_name_distinct(dataframe.plan)?
798+
.build()?;
799+
Ok(DataFrame {
800+
session_state: self.session_state,
801+
plan,
802+
projection_requires_validation: true,
803+
})
804+
}
805+
727806
/// Return a new `DataFrame` with all duplicated rows removed.
728807
///
729808
/// # Example

datafusion/core/tests/dataframe/mod.rs

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5206,6 +5206,40 @@ fn union_fields() -> UnionFields {
52065206
.collect()
52075207
}
52085208

5209+
#[tokio::test]
5210+
async fn union_literal_is_null_and_not_null() -> Result<()> {
5211+
let str_array_1 = StringArray::from(vec![None::<String>]);
5212+
let str_array_2 = StringArray::from(vec![Some("a")]);
5213+
5214+
let batch_1 =
5215+
RecordBatch::try_from_iter(vec![("arr", Arc::new(str_array_1) as ArrayRef)])?;
5216+
let batch_2 =
5217+
RecordBatch::try_from_iter(vec![("arr", Arc::new(str_array_2) as ArrayRef)])?;
5218+
5219+
let ctx = SessionContext::new();
5220+
ctx.register_batch("union_batch_1", batch_1)?;
5221+
ctx.register_batch("union_batch_2", batch_2)?;
5222+
5223+
let df1 = ctx.table("union_batch_1").await?;
5224+
let df2 = ctx.table("union_batch_2").await?;
5225+
5226+
let batches = df1.union(df2)?.collect().await?;
5227+
let schema = batches[0].schema();
5228+
5229+
for batch in batches {
5230+
// Verify schema is the same for all batches
5231+
if !schema.contains(&batch.schema()) {
5232+
return Err(DataFusionError::Internal(format!(
5233+
"Schema mismatch. Previously had\n{:#?}\n\nGot:\n{:#?}",
5234+
&schema,
5235+
batch.schema()
5236+
)));
5237+
}
5238+
}
5239+
5240+
Ok(())
5241+
}
5242+
52095243
#[tokio::test]
52105244
async fn sparse_union_is_null() {
52115245
// union of [{A=1}, {A=}, {B=3.2}, {B=}, {C="a"}, {C=}]
@@ -5477,6 +5511,64 @@ async fn boolean_dictionary_as_filter() {
54775511
);
54785512
}
54795513

5514+
#[tokio::test]
5515+
async fn test_union_by_name() -> Result<()> {
5516+
let df = create_test_table("test")
5517+
.await?
5518+
.select(vec![col("a"), col("b"), lit(1).alias("c")])?
5519+
.alias("table_alias")?;
5520+
5521+
let df2 = df.clone().select_columns(&["c", "b", "a"])?;
5522+
let result = df.union_by_name(df2)?.sort_by(vec![col("a"), col("b")])?;
5523+
5524+
assert_snapshot!(
5525+
batches_to_sort_string(&result.collect().await?),
5526+
@r"
5527+
+-----------+-----+---+
5528+
| a | b | c |
5529+
+-----------+-----+---+
5530+
| 123AbcDef | 100 | 1 |
5531+
| 123AbcDef | 100 | 1 |
5532+
| CBAdef | 10 | 1 |
5533+
| CBAdef | 10 | 1 |
5534+
| abc123 | 10 | 1 |
5535+
| abc123 | 10 | 1 |
5536+
| abcDEF | 1 | 1 |
5537+
| abcDEF | 1 | 1 |
5538+
+-----------+-----+---+
5539+
"
5540+
);
5541+
Ok(())
5542+
}
5543+
5544+
#[tokio::test]
5545+
async fn test_union_by_name_distinct() -> Result<()> {
5546+
let df = create_test_table("test")
5547+
.await?
5548+
.select(vec![col("a"), col("b"), lit(1).alias("c")])?
5549+
.alias("table_alias")?;
5550+
5551+
let df2 = df.clone().select_columns(&["c", "b", "a"])?;
5552+
let result = df
5553+
.union_by_name_distinct(df2)?
5554+
.sort_by(vec![col("a"), col("b")])?;
5555+
5556+
assert_snapshot!(
5557+
batches_to_sort_string(&result.collect().await?),
5558+
@r"
5559+
+-----------+-----+---+
5560+
| a | b | c |
5561+
+-----------+-----+---+
5562+
| 123AbcDef | 100 | 1 |
5563+
| CBAdef | 10 | 1 |
5564+
| abc123 | 10 | 1 |
5565+
| abcDEF | 1 | 1 |
5566+
+-----------+-----+---+
5567+
"
5568+
);
5569+
Ok(())
5570+
}
5571+
54805572
#[tokio::test]
54815573
async fn test_alias() -> Result<()> {
54825574
let df = create_test_table("test")

0 commit comments

Comments
 (0)