diff --git a/crossfire/__init__.py b/crossfire/__init__.py index c7d6aaa..5bbd55a 100644 --- a/crossfire/__init__.py +++ b/crossfire/__init__.py @@ -48,9 +48,12 @@ def flatten(data, nested_columns=None): nested_columns = set(nested_columns or NESTED_COLUMNS) if not nested_columns.issubset(NESTED_COLUMNS): raise NestedColumnError(nested_columns) - if isinstance(data, dict): - keys = set(data.keys()) & nested_columns - for key in keys: - data.update({f"{key}_{k}": v for k, v in data.get(key).items()}) - data.pop(key) + if isinstance(data, list): + if not data: return data + keys = set(data[0].keys()) & nested_columns + for item in data: + for key in keys: + item.update({f"{key}_{k}": v for k, v in item.get(key).items()}) + item.pop(key) + return data diff --git a/tests/test_flatten.py b/tests/test_flatten.py index c905daa..298442c 100644 --- a/tests/test_flatten.py +++ b/tests/test_flatten.py @@ -5,10 +5,12 @@ from crossfire import NestedColumnError, flatten -DICT_DATA = { - "answer": 42, - "contextInfo": {"context1": "info1", "context2": "info2"}, -} +DICT_DATA = [ + { + "answer": 42, + "contextInfo": {"context1": "info1", "context2": "info2"}, + } +] PD_DATA = DataFrame([DICT_DATA]) GEOMETRY = [Point(4, 2)] GEOPD_DATA = GeoDataFrame([DICT_DATA], crs="EPSG:4326", geometry=GEOMETRY) @@ -19,13 +21,19 @@ def teste_flatten_wrong_nested_columns_value_error(): flatten(DICT_DATA, nested_columns=["wrong"]) +def teste_flatten_with_emptylist(): + assert flatten([]) == [] + + # test the flatten function with a dictionary mocking it to assert _flatten_dict function is being called def test_flatten_dict(): flattened_dict = flatten( DICT_DATA, nested_columns=["contextInfo", "transports"] ) - assert flattened_dict == { - "answer": 42, - "contextInfo_context1": "info1", - "contextInfo_context2": "info2", - } + assert flattened_dict == [ + { + "answer": 42, + "contextInfo_context1": "info1", + "contextInfo_context2": "info2", + } + ]