17
17
18
18
//! Logical Expressions: [`Expr`]
19
19
20
- use std:: collections:: HashSet ;
20
+ use std:: collections:: { HashMap , HashSet } ;
21
21
use std:: fmt:: { self , Display , Formatter , Write } ;
22
22
use std:: hash:: { Hash , Hasher } ;
23
23
use std:: mem;
@@ -1380,7 +1380,7 @@ impl Expr {
1380
1380
/// // refs contains "a" and "b"
1381
1381
/// assert_eq!(refs.len(), 2);
1382
1382
/// assert!(refs.contains(&Column::new_unqualified("a")));
1383
- /// assert!(refs.contains(&Column::new_unqualified("b")));
1383
+ /// assert!(refs.contains(&Column::new_unqualified("b")));
1384
1384
/// ```
1385
1385
pub fn column_refs ( & self ) -> HashSet < & Column > {
1386
1386
let mut using_columns = HashSet :: new ( ) ;
@@ -1401,6 +1401,41 @@ impl Expr {
1401
1401
. expect ( "traversal is infallable" ) ;
1402
1402
}
1403
1403
1404
+ /// Return all references to columns and their occurrence counts in the expression.
1405
+ ///
1406
+ /// # Example
1407
+ /// ```
1408
+ /// # use std::collections::HashMap;
1409
+ /// # use datafusion_common::Column;
1410
+ /// # use datafusion_expr::col;
1411
+ /// // For an expression `a + (b * a)`
1412
+ /// let expr = col("a") + (col("b") * col("a"));
1413
+ /// let mut refs = expr.column_refs_counts();
1414
+ /// // refs contains "a" and "b"
1415
+ /// assert_eq!(refs.len(), 2);
1416
+ /// assert_eq!(*refs.get(&Column::new_unqualified("a")).unwrap(), 2);
1417
+ /// assert_eq!(*refs.get(&Column::new_unqualified("b")).unwrap(), 1);
1418
+ /// ```
1419
+ pub fn column_refs_counts ( & self ) -> HashMap < & Column , usize > {
1420
+ let mut map = HashMap :: new ( ) ;
1421
+ self . add_column_ref_counts ( & mut map) ;
1422
+ map
1423
+ }
1424
+
1425
+ /// Adds references to all columns and their occurrence counts in the expression to
1426
+ /// the map.
1427
+ ///
1428
+ /// See [`Self::column_refs`] for details
1429
+ pub fn add_column_ref_counts < ' a > ( & ' a self , map : & mut HashMap < & ' a Column , usize > ) {
1430
+ self . apply ( |expr| {
1431
+ if let Expr :: Column ( col) = expr {
1432
+ * map. entry ( col) . or_default ( ) += 1 ;
1433
+ }
1434
+ Ok ( TreeNodeRecursion :: Continue )
1435
+ } )
1436
+ . expect ( "traversal is infallable" ) ;
1437
+ }
1438
+
1404
1439
/// Returns true if there are any column references in this Expr
1405
1440
pub fn any_column_refs ( & self ) -> bool {
1406
1441
self . exists ( |expr| Ok ( matches ! ( expr, Expr :: Column ( _) ) ) )
0 commit comments