diff --git a/src/mavedb/lib/score_sets.py b/src/mavedb/lib/score_sets.py
index 775e067d..0bd68e76 100644
--- a/src/mavedb/lib/score_sets.py
+++ b/src/mavedb/lib/score_sets.py
@@ -7,7 +7,7 @@
 import numpy as np
 import pandas as pd
 from pandas.testing import assert_index_equal
-from sqlalchemy import Integer, cast, func, or_, select
+from sqlalchemy import Integer, cast, func, null, or_, select
 from sqlalchemy.orm import Session, aliased, contains_eager, joinedload, selectinload
 
 from mavedb.lib.exceptions import ValidationError
@@ -21,6 +21,7 @@
 )
 from mavedb.lib.mave.utils import is_csv_null
 from mavedb.lib.validation.constants.general import null_values_list
+from mavedb.models.clinvar_variant import ClinvarVariant
 from mavedb.models.contributor import Contributor
 from mavedb.models.controlled_keyword import ControlledKeyword
 from mavedb.models.doi_identifier import DoiIdentifier
@@ -37,6 +38,7 @@
 from mavedb.models.score_set_publication_identifier import (
     ScoreSetPublicationIdentifierAssociation,
 )
+from mavedb.models.mapped_variant import MappedVariant
 from mavedb.models.target_accession import TargetAccession
 from mavedb.models.target_gene import TargetGene
 from mavedb.models.target_sequence import TargetSequence
@@ -314,11 +316,14 @@ def get_score_set_counts_as_csv(
 ) -> str:
     assert type(score_set.dataset_columns) is dict
     count_columns = [str(x) for x in list(score_set.dataset_columns.get("count_columns", []))]
-    columns = ["accession", "hgvs_nt", "hgvs_splice", "hgvs_pro"] + count_columns
+    # HACK
+    columns = (
+        ["accession", "hgvs_nt", "hgvs_splice", "hgvs_pro"] + count_columns + ["mavedb_clinsig", "mavedb_reviewstat"]
+    )
     type_column = "count_data"
 
     variants_query = (
-        select(Variant)
+        select(Variant, null(), null())
         .where(Variant.score_set_id == score_set.id)
         .order_by(cast(func.split_part(Variant.urn, "#", 2), Integer))
     )
@@ -326,9 +331,10 @@ def get_score_set_counts_as_csv(
         variants_query = variants_query.offset(start)
     if limit:
         variants_query = variants_query.limit(limit)
-    variants = db.scalars(variants_query).all()
+    variants = db.execute(variants_query).all()
 
-    rows_data = variants_to_csv_rows(variants, columns=columns, dtype=type_column)
+    # HACK: Hideous hack for expediency...
+    rows_data = variants_to_csv_rows(variants, columns=columns, dtype=type_column)  # type: ignore
     stream = io.StringIO()
     writer = csv.DictWriter(stream, fieldnames=columns, quoting=csv.QUOTE_MINIMAL)
     writer.writeheader()
@@ -344,11 +350,21 @@ def get_score_set_scores_as_csv(
 ) -> str:
     assert type(score_set.dataset_columns) is dict
     score_columns = [str(x) for x in list(score_set.dataset_columns.get("score_columns", []))]
-    columns = ["accession", "hgvs_nt", "hgvs_splice", "hgvs_pro"] + score_columns
+    # HACK
+    columns = (
+        ["accession", "hgvs_nt", "hgvs_splice", "hgvs_pro"] + score_columns + ["mavedb_clinsig", "mavedb_reviewstat"]
+    )
     type_column = "score_data"
 
+    # HACK: This is a poorly tested and very temporary solution to surface clinical significance and
+    # clinical review status within the CSV export in a way our front end can handle and display.
+    current_mapped_variants_subquery = db.query(MappedVariant).filter(MappedVariant.current.is_(True)).subquery()
     variants_query = (
-        select(Variant)
+        select(Variant, ClinvarVariant.clinical_significance, ClinvarVariant.clinical_review_status)
+        .join(
+            current_mapped_variants_subquery, Variant.id == current_mapped_variants_subquery.c.variant_id, isouter=True
+        )
+        .join(ClinvarVariant, current_mapped_variants_subquery.c.clinvar_variant_id == ClinvarVariant.id, isouter=True)
         .where(Variant.score_set_id == score_set.id)
         .order_by(cast(func.split_part(Variant.urn, "#", 2), Integer))
     )
@@ -356,9 +372,9 @@ def get_score_set_scores_as_csv(
         variants_query = variants_query.offset(start)
     if limit:
         variants_query = variants_query.limit(limit)
-    variants = db.scalars(variants_query).all()
+    variants = db.execute(variants_query).all()
 
-    rows_data = variants_to_csv_rows(variants, columns=columns, dtype=type_column)
+    rows_data = variants_to_csv_rows(variants, columns=columns, dtype=type_column)  # type: ignore
     stream = io.StringIO()
     writer = csv.DictWriter(stream, fieldnames=columns, quoting=csv.QUOTE_MINIMAL)
     writer.writeheader()
@@ -375,7 +391,9 @@ def is_null(value):
     return null_values_re.fullmatch(value) or not value
 
 
-def variant_to_csv_row(variant: Variant, columns: list[str], dtype: str, na_rep="NA") -> dict[str, Any]:
+def variant_to_csv_row(
+    variant: tuple[Variant, str, str], columns: list[str], dtype: str, na_rep="NA"
+) -> dict[str, Any]:
     """
     Format a variant into a containing the keys specified in `columns`.
 
@@ -397,24 +415,29 @@ def variant_to_csv_row(variant: Variant, columns: list[str], dtype: str, na_rep=
     row = {}
     for column_key in columns:
         if column_key == "hgvs_nt":
-            value = str(variant.hgvs_nt)
+            value = str(variant[0].hgvs_nt)
         elif column_key == "hgvs_pro":
-            value = str(variant.hgvs_pro)
+            value = str(variant[0].hgvs_pro)
         elif column_key == "hgvs_splice":
-            value = str(variant.hgvs_splice)
+            value = str(variant[0].hgvs_splice)
         elif column_key == "accession":
-            value = str(variant.urn)
+            value = str(variant[0].urn)
         else:
-            parent = variant.data.get(dtype) if variant.data else None
+            parent = variant[0].data.get(dtype) if variant[0].data else None
             value = str(parent.get(column_key)) if parent else na_rep
         if is_null(value):
            value = na_rep
         row[column_key] = value
+
+    # HACK: Overwrite any potential values of ClinVar fields present in the data
+    # object with db results from the tuple directly.
+    row["mavedb_clinsig"] = variant[1]
+    row["mavedb_reviewstat"] = variant[2]
     return row
 
 
 def variants_to_csv_rows(
-    variants: Sequence[Variant], columns: list[str], dtype: str, na_rep="NA"
+    variants: Sequence[tuple[Variant, str, str]], columns: list[str], dtype: str, na_rep="NA"
 ) -> Iterable[dict[str, Any]]:
     """
     Format each variant into a dictionary row containing the keys specified in `columns`.
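Because both export queries now select extra columns, db.execute(...) returns Row tuples rather than bare Variant instances, which is why variant_to_csv_row indexes into variant[0], variant[1], and variant[2]. The sketch below is a minimal, self-contained illustration of that query shape, not the mavedb schema: the stand-in models (VariantStub, MappedVariantStub, ClinvarVariantStub), their table names, and the demo data are all hypothetical, and the 2.0-style select() subquery is assumed to be equivalent to the patch's db.query(...).subquery().

# Sketch only: hypothetical stub models mirroring the column names used in the
# patch (variant_id, clinvar_variant_id, current, clinical_significance,
# clinical_review_status). Shows why each result is a three-element Row and how
# the two LEFT OUTER JOINs behave for variants without a current ClinVar mapping.
from typing import Optional

from sqlalchemy import ForeignKey, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class VariantStub(Base):
    __tablename__ = "variant"
    id: Mapped[int] = mapped_column(primary_key=True)
    urn: Mapped[str]


class ClinvarVariantStub(Base):
    __tablename__ = "clinvar_variant"
    id: Mapped[int] = mapped_column(primary_key=True)
    clinical_significance: Mapped[str]
    clinical_review_status: Mapped[str]


class MappedVariantStub(Base):
    __tablename__ = "mapped_variant"
    id: Mapped[int] = mapped_column(primary_key=True)
    variant_id: Mapped[int] = mapped_column(ForeignKey("variant.id"))
    clinvar_variant_id: Mapped[Optional[int]] = mapped_column(ForeignKey("clinvar_variant.id"))
    current: Mapped[bool]


def scores_rows(db: Session):
    # Only the "current" mapping for each variant contributes ClinVar data;
    # this is the select()-style spelling of the subquery built in the patch.
    current_mapped = select(MappedVariantStub).where(MappedVariantStub.current.is_(True)).subquery()
    stmt = (
        select(
            VariantStub,
            ClinvarVariantStub.clinical_significance,
            ClinvarVariantStub.clinical_review_status,
        )
        .join(current_mapped, VariantStub.id == current_mapped.c.variant_id, isouter=True)
        .join(ClinvarVariantStub, current_mapped.c.clinvar_variant_id == ClinvarVariantStub.id, isouter=True)
    )
    # Each result is a Row: row[0] is the ORM entity, row[1] and row[2] are the
    # scalar ClinVar columns, or None when the outer joins found no match.
    return db.execute(stmt).all()


if __name__ == "__main__":
    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        mapped_variant = VariantStub(urn="urn:demo:score-set#1")
        unmapped_variant = VariantStub(urn="urn:demo:score-set#2")
        clinvar = ClinvarVariantStub(
            clinical_significance="Pathogenic",
            clinical_review_status="criteria provided, single submitter",
        )
        session.add_all([mapped_variant, unmapped_variant, clinvar])
        session.flush()
        session.add(MappedVariantStub(variant_id=mapped_variant.id, clinvar_variant_id=clinvar.id, current=True))
        session.commit()

        for variant, clinsig, review_status in scores_rows(session):
            # The unmapped variant prints None for both ClinVar columns.
            print(variant.urn, clinsig, review_status)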
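The user-visible effect is that both CSV exports gain two trailing columns after the dataset's own columns. A tiny, self-contained illustration of the resulting header and one row; the "score" column and all row values here are made up for the example.

import csv
import io

# Hypothetical score set with a single "score" column; the two ClinVar-derived
# columns are appended last, matching the widened `columns` list in the patch.
columns = ["accession", "hgvs_nt", "hgvs_splice", "hgvs_pro", "score", "mavedb_clinsig", "mavedb_reviewstat"]

stream = io.StringIO()
writer = csv.DictWriter(stream, fieldnames=columns, quoting=csv.QUOTE_MINIMAL)
writer.writeheader()
writer.writerow(
    {
        "accession": "urn:demo:score-set#1",
        "hgvs_nt": "NA",
        "hgvs_splice": "NA",
        "hgvs_pro": "p.Gly12Asp",
        "score": "0.42",
        "mavedb_clinsig": "Pathogenic",
        "mavedb_reviewstat": "criteria provided, single submitter",
    }
)
print(stream.getvalue())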