Skip to content

Commit

Permalink
Solve #230 All-NA hgvs columns in input problem. Remove NA columns fr…
Browse files Browse the repository at this point in the history
…om downloading file. Add some related tests. Failed tests due to a bug.
  • Loading branch information
EstelleDa committed Feb 10, 2025
1 parent af5606f commit 9733b91
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 2 deletions.
31 changes: 31 additions & 0 deletions src/mavedb/lib/score_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from mavedb.lib.mave.utils import is_csv_null
from mavedb.lib.validation.constants.general import null_values_list
from mavedb.lib.validation.utilities import is_null as validate_is_null
from mavedb.models.contributor import Contributor
from mavedb.models.controlled_keyword import ControlledKeyword
from mavedb.models.doi_identifier import DoiIdentifier
Expand Down Expand Up @@ -311,6 +312,7 @@ def get_score_set_counts_as_csv(
score_set: ScoreSet,
start: Optional[int] = None,
limit: Optional[int] = None,
download: Optional[bool] = None,
) -> str:
assert type(score_set.dataset_columns) is dict
count_columns = [str(x) for x in list(score_set.dataset_columns.get("count_columns", []))]
Expand All @@ -329,6 +331,9 @@ def get_score_set_counts_as_csv(
variants = db.scalars(variants_query).all()

rows_data = variants_to_csv_rows(variants, columns=columns, dtype=type_column)
if download:
rows_data, columns = process_downloadable_data(rows_data, columns)

stream = io.StringIO()
writer = csv.DictWriter(stream, fieldnames=columns, quoting=csv.QUOTE_MINIMAL)
writer.writeheader()
Expand All @@ -341,6 +346,7 @@ def get_score_set_scores_as_csv(
score_set: ScoreSet,
start: Optional[int] = None,
limit: Optional[int] = None,
download: Optional[bool] = None,
) -> str:
assert type(score_set.dataset_columns) is dict
score_columns = [str(x) for x in list(score_set.dataset_columns.get("score_columns", []))]
Expand All @@ -359,13 +365,38 @@ def get_score_set_scores_as_csv(
variants = db.scalars(variants_query).all()

rows_data = variants_to_csv_rows(variants, columns=columns, dtype=type_column)
if download:
rows_data, columns = process_downloadable_data(rows_data, columns)

stream = io.StringIO()
writer = csv.DictWriter(stream, fieldnames=columns, quoting=csv.QUOTE_MINIMAL)
writer.writeheader()
writer.writerows(rows_data)
return stream.getvalue()


def process_downloadable_data(
    rows_data: Iterable[dict[str, Any]],
    columns: list[str],
) -> tuple[list[dict[str, Any]], list[str]]:
    """Prepare rows for a downloadable CSV by dropping all-null HGVS columns.

    Any of the columns ``hgvs_nt``, ``hgvs_splice``, and ``hgvs_pro`` whose
    every value is null (``None``, ``"NA"``, etc., as judged by
    ``validate_is_null``) is removed from both the row dicts and the header
    list, so the downloaded file omits entirely-empty variant columns.

    :param rows_data: Iterable of CSV row dicts (may be a lazy ``map`` object).
    :param columns: Ordered list of CSV header names.
    :return: A ``(rows, columns)`` pair with empty HGVS columns removed.
    """
    # Materialize the iterable so it can be scanned once to detect empty
    # columns and again to strip them.
    rows = list(rows_data)
    hgvs_columns = ["hgvs_nt", "hgvs_splice", "hgvs_pro"]
    empty_columns: list[str] = []

    # A column is "empty" when every row's value for it is null. Guard the
    # no-rows case: with zero rows `all(...)` is vacuously True and would
    # strip every HGVS column from the header of an empty download.
    if rows:
        for col in hgvs_columns:
            # row.get(col) tolerates rows that lack the key entirely
            # (row[col] would raise KeyError).
            if all(validate_is_null(row.get(col)) for row in rows):
                empty_columns.append(col)
                for row in rows:
                    row.pop(col, None)  # Remove the column from each row.

    # Drop the empty columns from the header list, preserving order.
    columns = [col for col in columns if col not in empty_columns]
    return rows, columns


null_values_re = re.compile(r"\s+|none|nan|na|undefined|n/a|null|nil", flags=re.IGNORECASE)


Expand Down
6 changes: 4 additions & 2 deletions src/mavedb/routers/score_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def get_score_set_scores_csv(
urn: str,
start: int = Query(default=None, description="Start index for pagination"),
limit: int = Query(default=None, description="Number of variants to return"),
download: Optional[bool] = None,
db: Session = Depends(deps.get_db),
user_data: Optional[UserData] = Depends(get_current_user),
) -> Any:
Expand Down Expand Up @@ -214,7 +215,7 @@ def get_score_set_scores_csv(

assert_permission(user_data, score_set, Action.READ)

csv_str = get_score_set_scores_as_csv(db, score_set, start, limit)
csv_str = get_score_set_scores_as_csv(db, score_set, start, limit, download)
return StreamingResponse(iter([csv_str]), media_type="text/csv")


Expand All @@ -234,6 +235,7 @@ async def get_score_set_counts_csv(
urn: str,
start: int = Query(default=None, description="Start index for pagination"),
limit: int = Query(default=None, description="Number of variants to return"),
download: Optional[bool] = None,
db: Session = Depends(deps.get_db),
user_data: Optional[UserData] = Depends(get_current_user),
) -> Any:
Expand Down Expand Up @@ -268,7 +270,7 @@ async def get_score_set_counts_csv(

assert_permission(user_data, score_set, Action.READ)

csv_str = get_score_set_counts_as_csv(db, score_set, start, limit)
csv_str = get_score_set_counts_as_csv(db, score_set, start, limit, download)
return StreamingResponse(iter([csv_str]), media_type="text/csv")


Expand Down
47 changes: 47 additions & 0 deletions tests/routers/test_score_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -1749,3 +1749,50 @@ def test_score_set_not_found_for_non_existent_score_set_when_adding_score_calibr

assert response.status_code == 404
assert "score_calibrations" not in response_data


########################################################################################################################
# Score set download files
########################################################################################################################

# Test file doesn't have hgvs_splice so its values are all NA.
def test_download_scores_file(session, data_provider, client, setup_router_db, data_files):
    """Downloaded scores CSV omits hgvs_splice, which is all-NA in the test file."""
    experiment = create_experiment(client)
    score_set = create_seq_score_set_with_variants(
        client, session, data_provider, experiment["urn"], data_files / "scores.csv"
    )

    publish_score_set_response = client.post(f"/api/v1/score-sets/{score_set['urn']}/publish")
    assert publish_score_set_response.status_code == 200
    publish_score_set = publish_score_set_response.json()

    download_scores_csv_response = client.get(f"/api/v1/score-sets/{publish_score_set['urn']}/scores?download=true")
    assert download_scores_csv_response.status_code == 200
    download_scores_csv = download_scores_csv_response.text
    csv_header = download_scores_csv.split("\n")[0]
    columns = csv_header.split(",")
    assert "hgvs_nt" in columns
    assert "hgvs_pro" in columns
    assert "hgvs_splice" not in columns


def test_download_counts_file(session, data_provider, client, setup_router_db, data_files):
    """Downloaded counts CSV omits hgvs_splice, which is all-NA in the test file."""
    experiment = create_experiment(client)
    score_set = create_seq_score_set_with_variants(
        client,
        session,
        data_provider,
        experiment["urn"],
        scores_csv_path=data_files / "scores.csv",
        counts_csv_path=data_files / "counts.csv",
    )

    publish_score_set_response = client.post(f"/api/v1/score-sets/{score_set['urn']}/publish")
    assert publish_score_set_response.status_code == 200
    publish_score_set = publish_score_set_response.json()

    download_counts_csv_response = client.get(f"/api/v1/score-sets/{publish_score_set['urn']}/counts?download=true")
    assert download_counts_csv_response.status_code == 200
    download_counts_csv = download_counts_csv_response.text
    csv_header = download_counts_csv.split("\n")[0]
    columns = csv_header.split(",")
    assert "hgvs_nt" in columns
    assert "hgvs_pro" in columns
    assert "hgvs_splice" not in columns

0 comments on commit 9733b91

Please # to comment.