Add required for enqueue_runs
kaloster committed Feb 5, 2025
1 parent 53012ae commit 4f9fdce
Showing 4 changed files with 509 additions and 0 deletions.
212 changes: 212 additions & 0 deletions ingestion_tools/scripts/db_import.py
@@ -0,0 +1,212 @@
import logging

import boto3
import click
from botocore import UNSIGNED
from botocore.config import Config

# from importers.db.annotation import AnnotationAuthorDBImporter, AnnotationDBImporter, StaleAnnotationDeletionDBImporter
from importers.db.base_importer import DBImportConfig
from importers.db.dataset import DatasetAuthorDBImporter, DatasetDBImporter, DatasetFundingDBImporter
from importers.db.deposition import DepositionAuthorDBImporter, DepositionDBImporter

# from importers.db.run import RunDBImporter, StaleRunDeletionDBImporter
# from importers.db.tiltseries import StaleTiltSeriesDeletionDBImporter, TiltSeriesDBImporter
# from importers.db.tomogram import StaleTomogramDeletionDBImporter, TomogramAuthorDBImporter, TomogramDBImporter
# from importers.db.voxel_spacing import StaleVoxelSpacingDeletionDBImporter, TomogramVoxelSpacingDBImporter
from s3fs import S3FileSystem

from common import db_models

logger = logging.getLogger("db_import")
logging.basicConfig(level=logging.INFO)


@click.group()
def cli():
pass


def db_import_options(func):
options = []
options.append(click.option("--import-alignments", is_flag=True, default=False))
options.append(click.option("--import-annotations", is_flag=True, default=False))
options.append(click.option("--import-annotation-authors", is_flag=True, default=False))
options.append(click.option("--import-dataset-authors", is_flag=True, default=False))
options.append(click.option("--import-dataset-funding", is_flag=True, default=False))
options.append(click.option("--import-depositions", is_flag=True, default=False))
options.append(click.option("--import-runs", is_flag=True, default=False))
options.append(click.option("--import-tiltseries", is_flag=True, default=False))
options.append(click.option("--import-tomograms", is_flag=True, default=False))
options.append(click.option("--import-tomogram-authors", is_flag=True, default=False))
options.append(click.option("--import-tomogram-voxel-spacing", is_flag=True, default=False))
options.append(click.option("--import-everything", is_flag=True, default=False))
options.append(click.option("--deposition-id", type=str, default=None, multiple=True))
options.append(
click.option(
"--anonymous",
is_flag=True,
required=True,
default=False,
type=bool,
help="Use anonymous access to S3",
),
)
for option in options:
func = option(func)
return func


@cli.command()
@click.argument("s3_bucket", required=True, type=str)
@click.argument("https_prefix", required=True, type=str)
@click.argument("postgres_url", required=True, type=str)
@click.option("--filter-dataset", type=str, default=None, multiple=True)
@click.option("--s3-prefix", required=True, default="", type=str)
@click.option("--endpoint-url", type=str, default=None)
@click.option(
"--debug/--no-debug",
is_flag=True,
required=True,
default=False,
type=bool,
help="Print DB Queries",
)
@db_import_options
def load(
s3_bucket: str,
https_prefix: str,
postgres_url: str,
s3_prefix: str,
anonymous: bool,
debug: bool,
filter_dataset: list[str],
import_alignments: bool, # noqa
import_annotations: bool,
import_annotation_authors: bool,
import_dataset_authors: bool,
import_dataset_funding: bool,
import_depositions: bool,
import_runs: bool,
import_tiltseries: bool,
import_tomograms: bool,
import_tomogram_authors: bool,
import_tomogram_voxel_spacing: bool,
import_everything: bool,
deposition_id: list[str],
endpoint_url: str,
):
db_models.db.init(postgres_url)
if debug:
peewee_logger = logging.getLogger("peewee")
peewee_logger.addHandler(logging.StreamHandler())
peewee_logger.setLevel(logging.DEBUG)
logger.setLevel(logging.DEBUG)

if import_everything:
import_annotations = True
import_annotation_authors = True
import_dataset_authors = True
import_dataset_funding = True
import_depositions = True
import_runs = True
import_tiltseries = True
import_tomograms = True
import_tomogram_authors = True
import_tomogram_voxel_spacing = True
else:
import_annotations = max(import_annotations, import_annotation_authors)
import_tomograms = max(import_tomograms, import_tomogram_authors)
import_tomogram_voxel_spacing = max(import_annotations, import_tomograms, import_tomogram_voxel_spacing)
import_runs = max(import_runs, import_tiltseries, import_tomogram_voxel_spacing)

s3_config = Config(signature_version=UNSIGNED) if anonymous else None
s3_client = boto3.client("s3", endpoint_url=endpoint_url, config=s3_config)
s3fs = S3FileSystem(client_kwargs={"endpoint_url": endpoint_url})
config = DBImportConfig(s3_client, s3fs, s3_bucket, https_prefix)

if import_depositions and deposition_id:
for dep_id in deposition_id:
for deposition_importer in DepositionDBImporter.get_items(config, dep_id):
deposition_obj = deposition_importer.import_to_db()
deposition_authors = DepositionAuthorDBImporter.get_item(deposition_obj, deposition_importer, config)
deposition_authors.import_to_db()

for dataset in DatasetDBImporter.get_items(config, s3_prefix):
if filter_dataset and dataset.dir_prefix not in filter_dataset:
logger.info("Skipping %s...", dataset.dir_prefix)
continue

dataset_obj = dataset.import_to_db()
dataset_id = dataset_obj.id

if import_dataset_authors:
dataset_authors = DatasetAuthorDBImporter.get_item(dataset_id, dataset, config)
dataset_authors.import_to_db()

if import_dataset_funding:
funding = DatasetFundingDBImporter.get_item(dataset_id, dataset, config)
funding.import_to_db()

if not import_runs:
continue

# run_cleaner = StaleRunDeletionDBImporter(dataset_id, config)
# for run in RunDBImporter.get_item(dataset_id, dataset, config):
# logger.info("Processing Run with prefix %s", run.dir_prefix)
# run_obj = run.import_to_db()
# run_id = run_obj.id
# run_cleaner.mark_as_active(run_obj)

# if import_tiltseries:
# tiltseries_cleaner = StaleTiltSeriesDeletionDBImporter(run_id, config)
# tiltseries = TiltSeriesDBImporter.get_item(run_id, run, config)
# if tiltseries:
# tiltseries_obj = tiltseries.import_to_db()
# tiltseries_cleaner.mark_as_active(tiltseries_obj)
# tiltseries_cleaner.remove_stale_objects()

# if not import_tomogram_voxel_spacing:
# continue

# voxel_spacing_cleaner = StaleVoxelSpacingDeletionDBImporter(run_id, config)
# for voxel_spacing in TomogramVoxelSpacingDBImporter.get_items(run_id, run, config):
# voxel_spacing_obj = voxel_spacing.import_to_db()

# if import_tomograms:
# tomogram_cleaner = StaleTomogramDeletionDBImporter(voxel_spacing_obj.id, config)
# for tomogram in TomogramDBImporter.get_item(
# voxel_spacing_obj.id, voxel_spacing, dataset_obj, config,
# ):
# tomogram_obj = tomogram.import_to_db()
# tomogram_cleaner.mark_as_active(tomogram_obj)

# if import_tomogram_authors:
# tomogram_authors = TomogramAuthorDBImporter.get_item(tomogram_obj.id, tomogram, config)
# tomogram_authors.import_to_db()
# tomogram_cleaner.remove_stale_objects()

# if import_annotations:
# annotation_cleaner = StaleAnnotationDeletionDBImporter(voxel_spacing_obj.id, config)
# for annotation in AnnotationDBImporter.get_item(voxel_spacing_obj.id, voxel_spacing, config):
# annotation_obj = annotation.import_to_db()
# annotation_cleaner.mark_as_active(annotation_obj)

# if import_annotation_authors:
# annotation_authors = AnnotationAuthorDBImporter.get_item(
# annotation_obj.id,
# annotation,
# config,
# )
# annotation_authors.import_to_db()
# annotation_cleaner.remove_stale_objects()

# voxel_spacing_cleaner.mark_as_active(voxel_spacing_obj)

# voxel_spacing_cleaner.remove_stale_objects()

# run_cleaner.remove_stale_objects()


if __name__ == "__main__":
cli()
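
For reference, the new load command can be exercised end to end with click's test runner. The sketch below is not part of this commit; the bucket, HTTPS prefix, and Postgres URL are placeholders.

# Hypothetical usage sketch -- bucket, prefix, and database URL are placeholders,
# not values taken from this commit.
from click.testing import CliRunner

from db_import import cli

runner = CliRunner()
result = runner.invoke(
    cli,
    [
        "load",
        "example-source-bucket",                     # s3_bucket (placeholder)
        "https://files.example.org",                 # https_prefix (placeholder)
        "postgresql://user:pass@localhost/cryoet",   # postgres_url (placeholder)
        "--s3-prefix", "",
        "--anonymous",
        "--import-dataset-authors",
        "--import-dataset-funding",
    ],
)
print(result.exit_code)
print(result.output)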
4 changes: 4 additions & 0 deletions ingestion_tools/scripts/enqueue_runs.py
@@ -12,8 +12,11 @@
from boto3 import Session
from botocore import UNSIGNED
from botocore.config import Config
from db_import import db_import_options
from importers.dataset import DatasetImporter
from importers.db.base_importer import DBImportConfig
from importers.db.dataset import DatasetDBImporter
from importers.db.deposition import DepositionDBImporter
from importers.deposition import DepositionImporter
from importers.run import RunImporter
from importers.utils import IMPORTERS
@@ -239,6 +242,7 @@ def to_args(**kwargs) -> list[str]:
default="db_import-v0.0.2.wdl",
help="Specify wdl key for custom workload",
)
@db_import_options
@enqueue_common_options
@click.pass_context
def db_import(
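
Only the new imports and the @db_import_options decorator are visible in this hunk; the body of the enqueue command and the existing to_args helper are collapsed. Below is a minimal sketch of what to_args presumably does with the decorator's keyword arguments -- the real implementation is not shown in this diff.

# Hypothetical sketch of to_args -- the actual implementation is collapsed in
# this diff. The idea: turn click kwargs back into CLI flags so the enqueue
# command can forward them to the batch db_import job.
def to_args(**kwargs) -> list[str]:
    args: list[str] = []
    for name, value in kwargs.items():
        flag = "--" + name.replace("_", "-")
        if isinstance(value, bool):
            if value:
                args.append(flag)              # boolean flags: emit only when set
        elif isinstance(value, (list, tuple)):
            for item in value:                 # multiple=True options repeat the flag
                args.extend([flag, str(item)])
        elif value is not None:
            args.extend([flag, str(value)])
    return args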
156 changes: 156 additions & 0 deletions ingestion_tools/scripts/importers/db/dataset.py
@@ -0,0 +1,156 @@
from typing import Any, Iterable

import importers.db.deposition
from importers.db.base_importer import (
AuthorsStaleDeletionDBImporter,
BaseDBImporter,
DBImportConfig,
StaleDeletionDBImporter,
)

from common import db_models
from common.db_models import BaseModel


class DatasetDBImporter(BaseDBImporter):
def __init__(self, dir_prefix: str, config: DBImportConfig):
self.config = config
self.dir_prefix = dir_prefix
self.parent = None
self.metadata = config.load_key_json(self.get_metadata_file_path())

def get_metadata_file_path(self) -> str:
return self.join_path(self.dir_prefix, "dataset_metadata.json")

def get_data_map(self) -> dict[str, Any]:
return {**self.get_direct_mapped_fields(), **self.get_computed_fields()}

@classmethod
def get_id_fields(cls) -> list[str]:
return ["id"]

@classmethod
def get_db_model_class(cls) -> type[BaseModel]:
return db_models.Dataset

@classmethod
def get_direct_mapped_fields(cls) -> dict[str, Any]:
return {
"id": ["dataset_identifier"],
"title": ["dataset_title"],
"description": ["dataset_description"],
"deposition_date": ["dates", "deposition_date"],
"release_date": ["dates", "release_date"],
"last_modified_date": ["dates", "last_modified_date"],
"related_database_entries": ["cross_references", "related_database_entries"],
"related_database_links": ["cross_references", "related_database_links"],
"dataset_publications": ["cross_references", "publications"],
"dataset_citations": ["cross_references", "dataset_citations"],
"sample_type": ["sample_type"],
"organism_name": ["organism", "name"],
"organism_taxid": ["organism", "taxonomy_id"],
"tissue_name": ["tissue", "name"],
"tissue_id": ["tissue", "id"],
"cell_name": ["cell_type", "name"],
"cell_type_id": ["cell_type", "id"],
"cell_strain_name": ["cell_strain", "name"],
"cell_strain_id": ["cell_strain", "id"],
"cell_component_name": ["cell_component", "name"],
"cell_component_id": ["cell_component", "id"],
"sample_preparation": ["sample_preparation"],
"grid_preparation": ["grid_preparation"],
"other_setup": ["other_setup"],
}

def get_computed_fields(self) -> dict[str, Any]:
extra_data = {
"s3_prefix": self.get_s3_url(self.dir_prefix),
"https_prefix": self.get_https_url(self.dir_prefix),
"key_photo_url": None,
"key_photo_thumbnail_url": None,
}
if database_publications := self.metadata.get("cross_references", {}).get("database_publications"):
extra_data["dataset_publications"] = database_publications

key_photos = self.metadata.get("key_photos", {})
if snapshot_path := key_photos.get("snapshot"):
extra_data["key_photo_url"] = self.get_https_url(snapshot_path)
if thumbnail_path := key_photos.get("thumbnail"):
extra_data["key_photo_thumbnail_url"] = self.get_https_url(thumbnail_path)

deposition = importers.db.deposition.get_deposition(self.config, self.metadata.get("deposition_id"))
extra_data["deposition_id"] = deposition.id
return extra_data

@classmethod
def get_items(cls, config: DBImportConfig, prefix: str) -> Iterable["DatasetDBImporter"]:
return [
cls(dataset_id, config) for dataset_id in config.find_subdirs_with_files(prefix, "dataset_metadata.json")
]
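
The values in get_direct_mapped_fields are key paths into dataset_metadata.json, so ["dates", "deposition_date"] points at metadata["dates"]["deposition_date"] when get_data_map is applied. The traversal itself lives in the base importer, which is not part of this diff; the following is a minimal sketch of the assumed lookup.

# Minimal sketch of the assumed nested-key lookup -- the real logic lives in
# importers.db.base_importer and is not part of this commit.
from typing import Any

def resolve_path(metadata: dict[str, Any], path: list[str], default: Any = None) -> Any:
    value: Any = metadata
    for key in path:
        if not isinstance(value, dict):
            return default
        value = value.get(key, default)
    return value

# e.g. resolve_path(dataset_metadata, ["dates", "deposition_date"]) -> "2023-01-01"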


class DatasetAuthorDBImporter(AuthorsStaleDeletionDBImporter):
def __init__(self, dataset_id: int, parent: DatasetDBImporter, config: DBImportConfig):
self.dataset_id = dataset_id
self.parent = parent
self.config = config
self.metadata = parent.metadata.get("authors", [])

def get_data_map(self) -> dict[str, Any]:
return {
"dataset_id": self.dataset_id,
"orcid": ["ORCID"],
"name": ["name"],
"primary_author_status": ["primary_author_status"],
"corresponding_author_status": ["corresponding_author_status"],
"email": ["email"],
"affiliation_name": ["affiliation_name"],
"affiliation_address": ["affiliation_address"],
"affiliation_identifier": ["affiliation_identifier"],
"author_list_order": ["author_list_order"],
}

@classmethod
def get_id_fields(cls) -> list[str]:
return ["dataset_id", "name"]

@classmethod
def get_db_model_class(cls) -> type[BaseModel]:
return db_models.DatasetAuthor

def get_filters(self) -> dict[str, Any]:
return {"dataset_id": self.dataset_id}

@classmethod
def get_item(cls, dataset_id: int, parent: DatasetDBImporter, config: DBImportConfig) -> "DatasetAuthorDBImporter":
return cls(dataset_id, parent, config)


class DatasetFundingDBImporter(StaleDeletionDBImporter):
def __init__(self, dataset_id: int, parent: DatasetDBImporter, config: DBImportConfig):
self.dataset_id = dataset_id
self.parent = parent
self.config = config
self.metadata = parent.metadata.get("funding", [])

def get_data_map(self) -> dict[str, Any]:
return {
"dataset_id": self.dataset_id,
"funding_agency_name": ["funding_agency_name"],
"grant_id": ["grant_id"],
}

@classmethod
def get_id_fields(cls) -> list[str]:
return ["dataset_id", "funding_agency_name", "grant_id"]

@classmethod
def get_db_model_class(cls) -> type[BaseModel]:
return db_models.DatasetFunding

def get_filters(self) -> dict[str, Any]:
return {"dataset_id": self.dataset_id}

@classmethod
def get_item(cls, dataset_id: int, parent: DatasetDBImporter, config: DBImportConfig) -> "DatasetFundingDBImporter":
return cls(dataset_id, parent, config)
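
get_id_fields and get_filters are the hooks the stale-deletion base classes rely on: the id fields act as the natural key for upserts, and the filters scope which existing rows a cleanup pass may touch. Those base classes are not part of this commit; below is a rough sketch of the assumed upsert, using the peewee models from common.db_models.

# Hypothetical sketch -- the real StaleDeletionDBImporter base class is not
# shown in this commit. Upsert one mapped record by its natural key.
def upsert(importer, record: dict) -> None:
    model = importer.get_db_model_class()
    key = {field: record[field] for field in importer.get_id_fields()}
    obj, created = model.get_or_create(**key, defaults=record)
    if not created:
        # Row already exists: refresh the remaining fields in place.
        model.update(**record).where(
            *[getattr(model, field) == value for field, value in key.items()],
        ).execute()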
