From 5f20a7d61d11a7e51c2215b4b645cd9ccdf5e15a Mon Sep 17 00:00:00 2001 From: Hualong Gervais Date: Fri, 17 Jun 2022 12:31:07 -0700 Subject: [PATCH] Enable automatic type deduction for columns built from lists of dataframes. Summary: Enable type inference functions to automatically deduce the row types when a column is built from a list of dataframe objects without supplying the "dtype" parameter. REMARKS: 1. The same issue occurs when supplying a list of columns as input rather than a list of dataframes. This will be addressed in a subsequent commit once this diff is reviewed. 2. The "torcharrow.velox_rt.dataframe_cpu.DataFrameCpu" import is placed in an awkward location because of a circular import issue. Reviewed By: OswinC Differential Revision: D37240169 fbshipit-source-id: 29e6156de8ed178af655d6380677b08f8bb63da0 --- torcharrow/dtypes.py | 12 ++++++++++++ torcharrow/test/test_list_column.py | 13 +++++++++++++ torcharrow/test/test_list_column_cpu.py | 3 +++ 3 files changed, 28 insertions(+) diff --git a/torcharrow/dtypes.py b/torcharrow/dtypes.py index 35ad16f9e..0bd4c0574 100644 --- a/torcharrow/dtypes.py +++ b/torcharrow/dtypes.py @@ -624,6 +624,12 @@ def infer_dtype_from_value(value): for t in value: dtypes.append(infer_dtype_from_value(t)) return prt(value, Tuple(dtypes)) + + from torcharrow.velox_rt.dataframe_cpu import DataFrameCpu + + if isinstance(value, DataFrameCpu): + return prt(value, List(value.dtype)) + raise AssertionError(f"unexpected case {value} of type {type(value)}") @@ -729,8 +735,14 @@ def common_dtype(l: DType, r: DType) -> ty.Optional[DType]: if is_list(l) and is_list(r): k = common_dtype(l.item_dtype, r.item_dtype) return List(k).with_null(l.nullable or r.nullable) if k is not None else None + if is_struct(l) and is_struct(r): + if l.fields == r.fields: + return Struct(l.fields, l.nullable or r.nullable) + else: + return None if l.with_null() == r.with_null(): return l if l.nullable else r + return None diff --git 
a/torcharrow/test/test_list_column.py b/torcharrow/test/test_list_column.py index c81c1efa7..0524d67a2 100644 --- a/torcharrow/test/test_list_column.py +++ b/torcharrow/test/test_list_column.py @@ -208,6 +208,19 @@ def base_test_fixed_size_list(self): f"Unexpected failure reason: {str(ex.exception)}", ) + def base_test_column_from_dataframe_list(self): + a = ta.dataframe({"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}) + b = ta.column([a, a]) + self.assertEqual( + list(b), + [[(1, 5), (2, 6), (3, 7), (4, 8)], [(1, 5), (2, 6), (3, 7), (4, 8)]], + ) + self.assertEqual( + b.dtype, + dt.List(dt.Struct([dt.Field("a", dt.int64), dt.Field("b", dt.int64)])), + ) + self.assertTrue(isinstance(b, ta.velox_rt.list_column_cpu.ListColumnCpu)) + if __name__ == "__main__": unittest.main() diff --git a/torcharrow/test/test_list_column_cpu.py b/torcharrow/test/test_list_column_cpu.py index bd774bc7f..5e2c50704 100644 --- a/torcharrow/test/test_list_column_cpu.py +++ b/torcharrow/test/test_list_column_cpu.py @@ -46,6 +46,9 @@ def test_map_reduce_etc(self): def test_fixed_size_list(self): self.base_test_fixed_size_list() + def test_column_from_dataframe_list(self): + self.base_test_column_from_dataframe_list() + if __name__ == "__main__": unittest.main()