Skip to content
This repository was archived by the owner on Sep 13, 2023. It is now read-only.

Commit

Permalink
Fix stata saving (#624)
Browse files Browse the repository at this point in the history
close #618
  • Loading branch information
aguschin authored Mar 6, 2023
1 parent c08038a commit 92782f5
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
18 changes: 17 additions & 1 deletion mlem/contrib/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,11 @@ def has_index(df: pd.DataFrame):
return not isinstance(df.index, pd.RangeIndex)


def has_named_index(df: pd.DataFrame) -> bool:
    """Tell whether every level of the frame's index has a non-empty name.

    Args:
        df: frame whose index is inspected.

    Returns:
        True if `df.index.name` is truthy, or if every entry of
        `df.index.names` is truthy; False otherwise (for example when the
        name is None or an empty string).
    """
    # Coerce to bool so callers get True/False rather than the index name
    # itself (or None), which is what the bare `or` expression evaluates to.
    return bool(df.index.name or all(df.index.names))


def _reset_index(df: pd.DataFrame):
"""Transforms indexes to columns"""
index_name = df.index.name or "" # save it for future renaming
Expand Down Expand Up @@ -459,8 +464,19 @@ def write(
write_kwargs.update(self.write_args)
write_kwargs.update(kwargs)

# sometimes index may be consumed by model or used at feature engineering step,
# so we keep it instead of dropping if it's non-trivial
if has_index(df):
df = reset_index(df)
if PANDAS_FORMATS[
"stata"
].write_func == self.write_func and not has_named_index(df):
logging.info(
"Stata format doesn't allow saving columns with empty names, so you must name the index."
"Use `df.index.name = 'index'` to name it or df.reset_index(drop=True) to drop it instead."
)
df = df.reset_index(drop=True)
else:
df = reset_index(df)

with storage.open(path) as (f, art):
if self.string_buffer:
Expand Down
38 changes: 38 additions & 0 deletions tests/contrib/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pytest
from fsspec.implementations.local import LocalFileSystem
from pydantic import parse_obj_as
from pytest_lazyfixture import lazy_fixture
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

Expand All @@ -27,6 +28,7 @@
PandasWriter,
SeriesType,
get_pandas_batch_formats,
has_index,
pd_type_from_string,
python_type_from_pd_string_repr,
python_type_from_pd_type,
Expand Down Expand Up @@ -510,6 +512,20 @@ def test_import_data_csv(tmpdir, write_csv, file_ext, type_, data):
_check_data(meta, target_path)


def test_import_data_stata(tmpdir, data):
    """Write `data` to a Stata file, import it as `pandas[stata]`, compare values."""
    stata_file = str(tmpdir / "mydata.stata")
    data.to_stata(stata_file, write_index=False)

    imported = import_object(
        stata_file,
        target=stata_file,
        type_="pandas[stata]",
        copy_data=True,
    )

    # TODO: int32 converts to int64 for some reason for stata
    expected = data.astype("int32")
    actual = imported.get_value()
    pandas_assert(expected, actual)


@long
def test_import_data_csv_remote(s3_tmp_path, s3_storage_fs, write_csv):
project_path = s3_tmp_path("test_csv_import")
Expand Down Expand Up @@ -619,6 +635,28 @@ def f(x):
assert set(get_object_requirements(sig).modules) == {"pandas"}


@pytest.mark.parametrize(
    "df", [lazy_fixture("data"), lazy_fixture("data2")]
)
def test_does_not_have_index(df):
    """`has_index` must be falsy for each of the `data` / `data2` fixtures."""
    index_detected = has_index(df)
    assert not index_detected


@pytest.mark.parametrize(
    "df", [PD_DATA_FRAME_INDEX, PD_DATA_FRAME_MULTIINDEX]
)
def test_has_index(df):
    """`has_index` must be truthy for the indexed and multi-indexed sample frames."""
    index_detected = has_index(df)
    assert index_detected


# Copyright 2019 Zyfra
# Copyright 2021 Iterative
#
Expand Down

0 comments on commit 92782f5

Please sign in to comment.