From c0498d0ffdfecaa3c8930765f30aa6b5f5faee9b Mon Sep 17 00:00:00 2001 From: antonylebechec Date: Thu, 8 Jun 2023 22:39:39 +0200 Subject: [PATCH] fix #44, #61 annotate parquet Database, #62 export --- howard/objects/database.py | 151 +++++- howard/objects/variants.py | 739 +++-------------------------- tests/test_commons.py | 2 +- tests/test_objects_database.py | 167 ++++++- tests/test_objects_variants.py | 160 ++++--- tests/test_tools_annotation.py | 14 +- tests/test_tools_calculation.py | 14 +- tests/test_tools_convert.py | 14 +- tests/test_tools_prioritization.py | 11 +- tests/test_tools_process.py | 41 +- tests/test_tools_query.py | 19 +- 11 files changed, 580 insertions(+), 752 deletions(-) diff --git a/howard/objects/database.py b/howard/objects/database.py index 257b79a..e2adea4 100644 --- a/howard/objects/database.py +++ b/howard/objects/database.py @@ -1110,7 +1110,123 @@ def is_compressed(self, database:str = None) -> bool: if not database: database = self.get_database() - return get_file_compressed(database) + if type(database) == duckdb.DuckDBPyConnection: + return False + else: + return get_file_compressed(database) + + + def get_header_infos_list(self, database:str = None) -> list: + """ + This function returns a list of header information for a given database or the current database + if none is specified. + + :param database: The `database` parameter is a string that represents the name of the database + from which the header information is to be retrieved. If no database name is provided, the + method will use the default database name obtained from the `get_database()` method + :type database: str + :return: A list of header information from a database, or an empty list if the database header + is not available. + """ + + if not database: + database = self.get_database() + + # Database header + database_header = self.get_header(database=database) + + # Init + database_header_infos_list = [] + + if database_header: + database_header_infos_list = list(database_header.infos) + + return database_header_infos_list + + + def find_column(self, database:str = None, table:str = None, column:str = "INFO", prefixes:list = ["INFO/"]) -> str: + """ + This function finds a specific column in a database table, with the option to search for a + column with a specific prefix or within the INFO column header. + + :param database: The name of the database to search for the column in. If not provided, it will + use the current database that the code is connected to + :type database: str + :param table: The name of the table in the database where the column is located + :type table: str + :param column: The default value for the "column" parameter is "INFO", but it can be changed to + search for a specific column name, defaults to INFO + :type column: str (optional) + :param prefixes: The prefixes parameter is a list of strings that are used to search for a + column with a specific prefix in the database. For example, if the prefixes list contains "DP/", + the function will search for a column named "DP/INFO" in addition to the default "INFO" column + :type prefixes: list + :return: a string that represents the name of the column found in the database, based on the + input parameters. If the column is found, it returns the column name. If the column is not + found, it returns None. 
+ """ + + if not database: + database = self.get_database() + + # Database columns + database_columns = self.get_columns(database=database, table=table) + + # Init + column_found = None + + # Column exists + if column in database_columns: + column_found = column + + # Column with prefix + elif prefixes: + for prefix in prefixes: + if prefix + column in database_columns: + column_found = prefix + column + break + + # Column in INFO column (test if in header) + if not column_found and "INFO" in database_columns : + database_header_infos = self.get_header_infos_list(database=database) + if column in database_header_infos: + column_found = "INFO" + + return column_found + + + def map_columns(self, database:str = None, table:str = None, columns:list = [], prefixes:list = ["INFO/"]) -> dict: + """ + This function maps columns in a database table to their corresponding columns with specified + prefixes. + + :param database: The name of the database to search for columns in. If no database is specified, + the method will use the default database set in the connection + :type database: str + :param table: The name of the table in the database that you want to map the columns for + :type table: str + :param columns: A list of column names that you want to map to their corresponding column names + in the database + :type columns: list + :param prefixes: The `prefixes` parameter is a list of strings that are used to filter the + columns that are searched for. Only columns that start with one of the prefixes in the list will + be considered. In the code above, the default value for `prefixes` is `["INFO/"]`, which + :type prefixes: list + :return: a dictionary that maps the input columns to their corresponding columns found in the + specified database and table, with the specified prefixes. + """ + + if not database: + database = self.get_database() + + # Init + columns_mapping = {} + + for column in columns: + column_found = self.find_column(database=database, table=table, column=column, prefixes=prefixes) + columns_mapping[column] = column_found + + return columns_mapping def get_columns(self, database:str = None, table:str = None) -> list: @@ -1252,7 +1368,7 @@ def get_conn(self): return self.conn - def export(self, output_database:str, output_header:str = None, database:str = None) -> bool: + def export(self, output_database:str, output_header:str = None, database:str = None, table:str = "variants") -> bool: """ This function exports data from a database to a specified output format and compresses it if necessary. 
@@ -1335,6 +1451,31 @@ def export(self, output_database:str, output_header:str = None, database:str = N query_export_format = f"FORMAT CSV, DELIMITER '{delimiter}', HEADER" include_header = True + # duckDB + elif output_type in ["duckdb"]: + + # Export database as Parquet + database_export_parquet_file = f"""{output_database}.{random_tmp}.database_export.parquet""" + self.export(database=database, output_database=database_export_parquet_file) + + # Create database and connexion + output_database_conn = duckdb.connect(output_database) + + # Create table in database connexion with Parquet file + query_copy = f""" + CREATE TABLE {table} + AS {self.get_sql_database_link(database=database_export_parquet_file)} + """ + output_database_conn.execute(query_copy) + + # Close connexion + output_database_conn.close() + + # remove tmp + remove_if_exists([database_export_parquet_file]) + + return os.path.exists(output_database) + # else: # log.debug("Not available") @@ -1366,7 +1507,7 @@ def export(self, output_database:str, output_header:str = None, database:str = N query_output_database_tmp = output_database else: query_output_database_tmp = f"""{output_database}.{random_tmp}""" - + query_copy = f""" COPY ( SELECT {query_export_columns} @@ -1390,10 +1531,14 @@ def export(self, output_database:str, output_header:str = None, database:str = N concat_file(input_files=[query_output_header_tmp, query_output_database_tmp], output_file=query_output_database_header_tmp) # move file shutil.move(query_output_database_header_tmp, query_output_database_tmp) + # remove tmp + remove_if_exists([query_output_header_tmp]) # Compress if compressed: compress_file(input_file=query_output_database_tmp, output_file=output_database) + # remove tmp + remove_if_exists([query_output_database_tmp]) else: shutil.move(query_output_database_tmp, output_database) diff --git a/howard/objects/variants.py b/howard/objects/variants.py index 7063c5b..c79c4ff 100644 --- a/howard/objects/variants.py +++ b/howard/objects/variants.py @@ -441,16 +441,16 @@ def get_output_format(self, output_file: str = None) -> str: return output_format - def get_output_compressed(self, output_file: str = None) -> str: - """ - It returns the format of the input variable. - :return: The format is being returned. - """ - if not output_file: - output_file = self.get_output() - output_compressed = get_file_compressed(output_file) + # def get_output_compressed(self, output_file: str = None) -> str: + # """ + # It returns the format of the input variable. + # :return: The format is being returned. + # """ + # if not output_file: + # output_file = self.get_output() + # output_compressed = get_file_compressed(output_file) - return output_compressed + # return output_compressed def get_config(self) -> dict: @@ -553,10 +553,7 @@ def get_header(self, type: str = "vcf"): :param type: the type of header you want to get, defaults to vcf (optional) :return: The header of the vcf file. 
""" - # vcf_required = [ - # '##fileformat=VCFv4.2', - # '#CHROM POS ID REF ALT QUAL FILTER INFO' - # ] + if self.header_vcf: if type == "vcf": return self.header_vcf @@ -1057,7 +1054,7 @@ def execute_query(self, query: str): return None - def export_output(self, export_header: bool = True, output_file: str = None, query: str = None) -> None: + def export_output(self, output_file: str = None, output_header: str = None, export_header: bool = True, query: str = None) -> bool: """ This function exports data from a VCF file to a specified output file in various formats, including VCF, CSV, TSV, PSV, and Parquet. @@ -1082,161 +1079,55 @@ def export_output(self, export_header: bool = True, output_file: str = None, que if not output_file: output_file = self.get_output() - # Connexion format - connexion_format = self.get_connexion_format() - - # Export header - if export_header: - header_name = self.export_header(output_file=output_file) - else: - # Header - tmp_header = NamedTemporaryFile( - prefix=self.get_prefix(), dir=self.get_tmp_dir()) - tmp_header_name = tmp_header.name - f = open(tmp_header_name, 'w') - vcf.Writer(f, self.header_vcf) - f.close() - header_name = tmp_header_name - - if output_file: - - sql_columns = self.get_header_columns_as_sql() - table_variants = self.get_table_variants() - sql_query_hard = "" - sql_query_sort = "" - sql_query_limit = "" - - # output_format - output_format = self.get_output_format(output_file=output_file) - output_compressed = self.get_output_compressed( - output_file=output_file) - - # delimiter - delimiter = file_format_delimiters.get(output_format, "\t") - - # Threads - threads = self.get_threads() - - log.debug(f"Export file: {output_file}") - - # Extra columns - sql_extra_columns = "" - if self.get_param().get("export_extra_infos", None): - sql_extra_columns = ", " + self.get_extra_infos_sql() - - log.debug(f"Export extra columns: {sql_extra_columns}") - - sql_query_export_subquery = None - sql_query_export_to = None - sql_query_export_format = None - commands = [] - - sqlite_options = {} - - output_file_tmp = output_file + ".tmp" - - if output_format in ["duckdb", "db"]: + # Auto header name with extension + if export_header or output_header: + if not output_header: + output_header = f"{output_file}.hdr" + # Export header + self.export_header(output_file=output_file) - # Remove output if exists - remove_if_exists(output_file) - - # Export parquet - sql_query_export_subquery = f""" - SELECT {sql_columns} {sql_extra_columns} FROM {table_variants} WHERE 1 {sql_query_hard} {sql_query_sort} {sql_query_limit} - """ - sql_query_export = f"COPY ({sql_query_export_subquery}) TO '{output_file_tmp}' WITH (FORMAT PARQUET)" - self.conn.execute(sql_query_export) + # Database + database_source=self.get_connexion() - # Export in duckdb - conn = duckdb.connect(output_file) - conn.execute(f"CREATE TABLE IF NOT EXISTS variants AS SELECT * FROM read_parquet('{output_file_tmp}')") - conn.close() - - # Remove tmp parquet file - remove_if_exists(output_file_tmp) + # Connexion format + connexion_format = self.get_connexion_format() - elif output_format in ["parquet"]: + # Tmp files to remove + tmp_to_remove = [] - # Export parquet - sql_query_export_subquery = f""" - SELECT {sql_columns} {sql_extra_columns} FROM {table_variants} WHERE 1 {sql_query_hard} {sql_query_sort} {sql_query_limit} - """ - sql_query_export_to = output_file_tmp - sql_query_export_format = "FORMAT PARQUET" + if connexion_format in ["sqlite"] or query: - sqlite_options = { - "format": "parquet" - } + # 
Export in Parquet + random_tmp = ''.join(random.choice(string.ascii_lowercase) for i in range(10)) + database_source = f"""{output_file}.{random_tmp}.database_export.parquet""" + tmp_to_remove.append(database_source) - elif output_format in ["tsv", "csv", "psv"]: + # Table Variants + table_variants = self.get_table_variants() - # Export TSV/CSV + # Create export query + if query: sql_query_export_subquery = f""" - SELECT {sql_columns} {sql_extra_columns} FROM {table_variants} WHERE 1 {sql_query_hard} {sql_query_sort} {sql_query_limit} + SELECT * FROM ({query}) """ - sql_query_export_to = output_file_tmp - sql_query_export_format = f"FORMAT CSV, DELIMITER '{delimiter}', HEADER" - - sqlite_options = { - "format": "csv", - "sep": delimiter, - "quotechar": "'" - } - - elif output_format in ["vcf"]: - - # Extract VCF - tmp_variants = NamedTemporaryFile(prefix=self.get_prefix( - ), dir=self.get_tmp_dir(), suffix=".vcf", delete=False) - tmp_variants_name = tmp_variants.name + elif connexion_format in ["sqlite"]: sql_query_export_subquery = f""" - SELECT {sql_columns} FROM {table_variants} WHERE 1 {sql_query_hard} {sql_query_sort} {sql_query_limit} + SELECT * FROM {table_variants} """ - sql_query_export_to = tmp_variants_name - sql_query_export_format = f"FORMAT CSV, DELIMITER '\t', HEADER, QUOTE ''" - sqlite_options = { - "format": "csv", - "sep": delimiter, - "quotechar": None - } - - # VCF - commands.append( - f"grep '^#CHROM' -v {header_name} > {output_file_tmp}; cat {tmp_variants_name} >> {output_file_tmp}") + # Write source file + fp.write(database_source, self.get_query_to_df(sql_query_export_subquery)) - # Query - if query: - sql_query_export_subquery = query + # Create database + database = Database(database=database_source, table="variants", header_file=output_header) - # Export from table variants - if sql_query_export_subquery and sql_query_export_to and sql_query_export_format: + # Export file + database.export(output_database=output_file, output_header=output_header) - # Export data - if connexion_format in ["duckdb"]: - sql_query_export = f"COPY ({sql_query_export_subquery}) TO '{sql_query_export_to}' WITH ({sql_query_export_format})" - self.conn.execute(sql_query_export) - elif connexion_format in ["sqlite"]: - if sqlite_options.get("format","csv") in ["csv"]: - if sqlite_options.get("quotechar",None): - self.get_query_to_df(sql_query_export_subquery).to_csv(sql_query_export_to, index=False, sep=sqlite_options.get("sep",";"), compression='infer', quotechar=sqlite_options.get("quotechar","'")) - else: - self.get_query_to_df(sql_query_export_subquery).to_csv(sql_query_export_to, index=False, sep=sqlite_options.get("sep",";"), compression='infer') - elif sqlite_options.get("format","csv") in ["parquet"]: - fp.write(sql_query_export_to, self.get_query_to_df(sql_query_export_subquery)) - - # Compression - if output_compressed: - bgzip_command = get_bgzip(threads=threads) - commands.append( - f""" {bgzip_command} {output_file_tmp} > {output_file} && rm {output_file_tmp} """) - else: - commands.append( - f""" mv {output_file_tmp} {output_file} """) + # Remove + remove_if_exists(tmp_to_remove) - # Commands - if commands: - run_parallel_commands(commands=commands, threads=1) + return (os.path.exists(output_file) or None) and (os.path.exists(output_file) or None) def get_extra_infos(self, table: str = None) -> list: @@ -2629,432 +2520,7 @@ def annotation_annovar(self, threads: int = None) -> None: log.debug(f"Annotation - cleaning command: {clean_command}") 
run_parallel_commands([clean_command], 1) - - # def annotation_parquet(self, threads: int = None) -> None: - # """ - # It takes a VCF file, and annotates it with a parquet file - - # :param threads: number of threads to use for the annotation - # :return: the value of the variable "result". - # """ - - # # DEBUG - # log.debug("Start annotation with parquet databases") - - # # Threads - # if not threads: - # threads = self.get_threads() - # log.debug("Threads: "+str(threads)) - - # # DEBUG - # delete_tmp = True - # if self.get_config().get("verbosity", "warning") in ["debug"]: - # delete_tmp = False - # log.debug("Delete tmp files/folders: "+str(delete_tmp)) - - # # Config - # databases_folders = self.get_config().get("folders", {}).get( - # "databases", {}).get("parquet", ["."]) - # log.debug("Databases annotations: " + str(databases_folders)) - - # # Param - # annotations = self.get_param().get("annotation", {}).get( - # "parquet", {}).get("annotations", None) - # log.debug("Annotations: " + str(annotations)) - - # # Assembly - # assembly = self.get_param().get("assembly", "hg19") - - # # Data - # table_variants = self.get_table_variants() - - # # Check if not empty - # log.debug("Check if not empty") - # sql_query_chromosomes_df = self.get_query_to_df( - # f"""SELECT count(*) as count FROM {table_variants} as table_variants LIMIT 1""") - # if not sql_query_chromosomes_df["count"][0]: - # log.info(f"VCF empty") - # return - - # # VCF header - # vcf_reader = self.get_header() - # log.debug("Initial header: " + str(vcf_reader.infos)) - - # # Nb Variants POS - # log.debug("NB Variants Start") - # nb_variants = self.conn.execute( - # f"SELECT count(*) AS count FROM variants").fetchdf()["count"][0] - # log.debug("NB Variants Stop") - - # # Existing annotations - # for vcf_annotation in self.get_header().infos: - - # vcf_annotation_line = self.get_header().infos.get(vcf_annotation) - # log.debug( - # f"Existing annotations in VCF: {vcf_annotation} [{vcf_annotation_line}]") - - # # explode infos - # self.explode_infos(prefix=self.get_param().get("explode_infos", None)) - - # # drop indexes - # log.debug(f"Drop indexes...") - # self.drop_indexes() - - # if annotations: - - # for annotation in annotations: - # annotation_fields = annotations[annotation] - - # if not annotation_fields: - # annotation_fields = {"INFO": None} - - # log.debug(f"Annotation '{annotation}'") - # log.debug( - # f"Annotation '{annotation}' - fields: {annotation_fields}") - - # # DEVEL - # annotation = "nci60.parquet" - # #databases_folders = [".", "/databases/annotations/current/hg19", "/tools/howard/devel/tests/data/annotations"] - # #databases_folders = [".", "/databases/annotations/current"] - # #databases_folders = [".", "/tools/howard/devel/tests/data/annotations"] - # print() - # print("DEVEL START") - # print(f"assembly: {assembly}") - # print(f"annotation: {annotation}") - # print(f"databases_folders: {databases_folders}") - # database = Database(database=annotation, databases_folders=databases_folders, assembly=assembly) - - # print(f"get database: " + database.get_database()) - # print(f"get header file: " + str(database.get_header_file())) - # parquet_file = database.get_database() - # parquet_hdr_file = database.get_header_file() - # #print("DEVEL END") - - - # # Find parquet file and header file - # # parquet_file = None - # # parquet_hdr_file = None - # # for databases_folder in databases_folders: - # # parquet_file = None - # # parquet_hdr_file = None - # # log.debug("Annotation file check: " + annotation + 
- # # " or " + str(databases_folder+"/"+annotation+".parquet")) - - # # # Parquet .parquet - # # if os.path.exists(annotation): - # # parquet_file = annotation - # # elif os.path.exists(databases_folder+"/"+annotation+".parquet"): - # # parquet_file = databases_folder+"/"+annotation+".parquet" - # # if not parquet_file: - # # continue - - # # # Header .hdr - # # if os.path.exists(parquet_file+".hdr"): - # # parquet_hdr_file = parquet_file+".hdr" - - # # # parquet and hdr found - # # if parquet_file and parquet_hdr_file: - # # break - - # if not parquet_file or not parquet_hdr_file: - # log.error("Annotation failed: file not found") - # raise ValueError("Annotation failed: file not found") - # else: - - # # Get parquet connexion - # parquet_sql_attach = database.get_sql_database_attach(output="query") - # if parquet_sql_attach: - # self.conn.execute(parquet_sql_attach) - # parquet_file_link = database.get_sql_database_link() - # # Log - # log.debug(f"Annotation '{annotation}' - file: " + - # str(parquet_file) + " and " + str(parquet_hdr_file)) - - # # Load header as VCF object - # # parquet_hdr_vcf = Variants(input=parquet_hdr_file) - # # print(parquet_hdr_vcf) - # # print(parquet_hdr_vcf.get_header().infos) - # parquet_hdr_vcf = database.get_header() - # # print(parquet_hdr_vcf) - # # print(parquet_hdr_vcf.infos) - # #parquet_hdr_vcf_header_infos = parquet_hdr_vcf.get_header().infos - # parquet_hdr_vcf_header_infos = parquet_hdr_vcf.infos - # # Log - # log.debug("Annotation database header: " + - # str(parquet_hdr_vcf_header_infos)) - - - # # get extra infos - # parquet_columns = self.get_extra_infos( - # table=parquet_file_link) - - # # For all fields in database - # annotation_fields_ALL = False - # if "ALL" in annotation_fields or "INFO" in annotation_fields: - # annotation_fields_ALL = True - # annotation_fields = { - # key: key for key in parquet_hdr_vcf_header_infos} - # log.debug( - # "Annotation database header - All annotations added: " + str(annotation_fields)) - - # # List of annotation fields to use - # sql_query_annotation_update_info_sets = [] - - # # Number of fields - # nb_annotation_field = 0 - - # # Annotation fields processed - # annotation_fields_processed = [] - - # for annotation_field in annotation_fields: - - # # annotation_field_column - # if annotation_field in parquet_columns: - # annotation_field_column = annotation_field - # elif "INFO/" + annotation_field in parquet_columns: - # annotation_field_column = "INFO/" + annotation_field - # else: - # annotation_field_column = "INFO" - - # # field new name, if parametered - # annotation_fields_new_name = annotation_fields.get( - # annotation_field, annotation_field) - # if not annotation_fields_new_name: - # annotation_fields_new_name = annotation_field - - # # check annotation field in data - # annotation_field_exists_on_variants = 0 - # if annotation_fields_new_name not in self.get_header().infos: - # sampling_annotation_field_exists_on_variants = 10000 - # sql_query_chromosomes = f""" - # SELECT 1 AS count - # FROM (SELECT * FROM {table_variants} as table_variants LIMIT {sampling_annotation_field_exists_on_variants}) - # WHERE ';' || INFO LIKE '%;{annotation_fields_new_name}=%' - # LIMIT 1 - # """ - # annotation_field_exists_on_variants = len( - # self.conn.execute(f"{sql_query_chromosomes}").df()["count"]) - # log.debug(f"Annotation field {annotation_fields_new_name} found in variants: " + str( - # annotation_field_exists_on_variants)) - - # # To annotate - # force_update_annotation = False - # if 
annotation_field in parquet_hdr_vcf_header_infos and (force_update_annotation or (annotation_fields_new_name not in self.get_header().infos and not annotation_field_exists_on_variants)): - - # # Add field to annotation to process list - # annotation_fields_processed.append( - # annotation_fields_new_name) - - # # Sep between fields in INFO - # nb_annotation_field += 1 - # if nb_annotation_field > 1: - # annotation_field_sep = ";" - # else: - # annotation_field_sep = "" - - # log.info( - # f"Annotation '{annotation}' - '{annotation_field}' -> 'INFO/{annotation_fields_new_name}'") - - # # Add INFO field to header - # parquet_hdr_vcf_header_infos_number = parquet_hdr_vcf_header_infos[ - # annotation_field].num or "." - # parquet_hdr_vcf_header_infos_type = parquet_hdr_vcf_header_infos[ - # annotation_field].type or "String" - # parquet_hdr_vcf_header_infos_description = parquet_hdr_vcf_header_infos[ - # annotation_field].desc or f"{annotation_field} description" - # parquet_hdr_vcf_header_infos_source = parquet_hdr_vcf_header_infos[ - # annotation_field].source or "unknown" - # parquet_hdr_vcf_header_infos_version = parquet_hdr_vcf_header_infos[ - # annotation_field].version or "unknown" - - # vcf_reader.infos[annotation_fields_new_name] = vcf.parser._Info( - # annotation_fields_new_name, - # parquet_hdr_vcf_header_infos_number, - # parquet_hdr_vcf_header_infos_type, - # parquet_hdr_vcf_header_infos_description, - # parquet_hdr_vcf_header_infos_source, - # parquet_hdr_vcf_header_infos_version, - # self.code_type_map[parquet_hdr_vcf_header_infos_type] - # ) - - # # Annotation/Update query fields - # # Found in INFO column - # if annotation_field_column == "INFO": - # sql_query_annotation_update_info_sets.append(f""" - # || CASE WHEN REGEXP_EXTRACT(';' || table_parquet.INFO, ';{annotation_field}=([^;]*)',1) NOT IN ('','.') - # THEN '{annotation_field_sep}' || '{annotation_fields_new_name}=' || REGEXP_EXTRACT(';' || table_parquet.INFO, ';{annotation_field}=([^;]*)',1) - # ELSE '' - # END - # """) - # # Found in a specific column - # else: - # sql_query_annotation_update_info_sets.append(f""" - # || CASE WHEN table_parquet."{annotation_field_column}" NOT IN ('','.') - # THEN '{annotation_field_sep}' || '{annotation_fields_new_name}=' || table_parquet."{annotation_field_column}" - # ELSE '' - # END - # """) - - # # Not to annotate - # else: - - # if force_update_annotation: - # annotation_message = "forced" - # else: - # annotation_message = "skipped" - - # if annotation_field not in parquet_hdr_vcf_header_infos: - # log.warning( - # f"Annotation '{annotation}' - '{annotation_field}' [{nb_annotation_field}] - not available in parquet file") - # if annotation_fields_new_name in self.get_header().infos: - # log.warning( - # f"Annotation '{annotation}' - '{annotation_fields_new_name}' [{nb_annotation_field}] - already exists in header ({annotation_message})") - # if annotation_field_exists_on_variants: - # log.warning( - # f"Annotation '{annotation}' - '{annotation_fields_new_name}' [{nb_annotation_field}] - already exists in variants ({annotation_message})") - - # # Check if ALL fields have to be annotated. 
Thus concat all INFO field - # allow_annotation_full_info = True - # if allow_annotation_full_info and nb_annotation_field == len(annotation_fields) and annotation_fields_ALL: - # sql_query_annotation_update_info_sets = [] - # sql_query_annotation_update_info_sets.append( - # f"|| table_parquet.INFO ") - - # if sql_query_annotation_update_info_sets: - - # # Annotate - # log.info(f"Annotation '{annotation}' - Annotation...") - - # # Join query annotation update info sets for SQL - # sql_query_annotation_update_info_sets_sql = " ".join( - # sql_query_annotation_update_info_sets) - - # # Check chromosomes list (and variant max position) - # sql_query_chromosomes_max_pos = f""" SELECT table_variants."#CHROM" as CHROM, MAX(table_variants."POS") as MAX_POS, MIN(table_variants."POS")-1 as MIN_POS FROM {table_variants} as table_variants GROUP BY table_variants."#CHROM" """ - # sql_query_chromosomes_max_pos_df = self.conn.execute( - # sql_query_chromosomes_max_pos).df() - - # # Create dictionnary with chromosomes (and max position) - # sql_query_chromosomes_max_pos_dictionary = sql_query_chromosomes_max_pos_df.groupby('CHROM').apply( - # lambda x: {'max_pos': x['MAX_POS'].max(), 'min_pos': x['MIN_POS'].min()}).to_dict() - - # # Affichage du dictionnaire - # log.debug("Chromosomes max pos found: " + - # str(sql_query_chromosomes_max_pos_dictionary)) - - # # nb_of_variant_annotated - # nb_of_query = 0 - # nb_of_variant_annotated = 0 - # query_dict = {} - - # for chrom in sql_query_chromosomes_max_pos_dictionary: - - # # nb_of_variant_annotated_by_chrom - # nb_of_variant_annotated_by_chrom = 0 - - # # Get position of the farthest variant (max position) in the chromosome - # sql_query_chromosomes_max_pos_dictionary_max_pos = sql_query_chromosomes_max_pos_dictionary.get( - # chrom, {}).get("max_pos") - # sql_query_chromosomes_max_pos_dictionary_min_pos = sql_query_chromosomes_max_pos_dictionary.get( - # chrom, {}).get("min_pos") - - # # Autodetect range of bases to split/chunk - # log.debug( - # f"Annotation '{annotation}' - Chromosome '{chrom}' - Start Autodetection Intervals...") - - # batch_annotation_databases_step = None - # batch_annotation_databases_ncuts = 1 - - # # Create intervals from 0 to max position variant, with the batch window previously defined - # sql_query_intervals = split_interval( - # sql_query_chromosomes_max_pos_dictionary_min_pos, sql_query_chromosomes_max_pos_dictionary_max_pos, step=batch_annotation_databases_step, ncuts=batch_annotation_databases_ncuts) - - # log.debug( - # f"Annotation '{annotation}' - Chromosome '{chrom}' - Stop Autodetection Intervals") - - # # Interval Start/Stop - # sql_query_interval_start = sql_query_intervals[0] - - # # For each interval - # for i in sql_query_intervals[1:]: - - # # Interval Start/Stop - # sql_query_interval_stop = i - - # log.debug( - # f"Annotation '{annotation}' - Chromosome '{chrom}' - Interval [{sql_query_interval_start}-{sql_query_interval_stop}] ...") - - # log.debug( - # f"Annotation '{annotation}' - Chromosome '{chrom}' - Interval [{sql_query_interval_start}-{sql_query_interval_stop}] - Start detecting regions...") - - # regions = [ - # (chrom, sql_query_interval_start, sql_query_interval_stop)] - - # log.debug( - # f"Annotation '{annotation}' - Chromosome '{chrom}' - Interval [{sql_query_interval_start}-{sql_query_interval_stop}] - Stop detecting regions") - - # # Fusion des régions chevauchantes - # if regions: - - # # Number of regions - # nb_regions = len(regions) - - # # create where caluse on regions - # 
clause_where_regions_variants = create_where_clause( - # regions, table="table_variants") - # clause_where_regions_parquet = create_where_clause( - # regions, table="table_parquet") - - # log.debug( - # f"Annotation '{annotation}' - Chromosome '{chrom}' - Interval [{sql_query_interval_start}-{sql_query_interval_stop}] - {nb_regions} regions...") - - # sql_query_annotation_chrom_interval_pos = f""" - # UPDATE {table_variants} as table_variants - # SET INFO = CASE WHEN table_variants.INFO NOT IN ('','.') THEN table_variants.INFO ELSE '' END || CASE WHEN table_variants.INFO NOT IN ('','.') AND ('' {sql_query_annotation_update_info_sets_sql}) NOT IN ('','.') THEN ';' ELSE '' END {sql_query_annotation_update_info_sets_sql} - # FROM {parquet_file_link} as table_parquet - # WHERE ( {clause_where_regions_parquet} ) - # AND table_parquet.\"#CHROM\" = table_variants.\"#CHROM\" - # AND table_parquet.\"POS\" = table_variants.\"POS\" - # AND table_parquet.\"ALT\" = table_variants.\"ALT\" - # AND table_parquet.\"REF\" = table_variants.\"REF\"; - # """ - # query_dict[f"{chrom}:{sql_query_interval_start}-{sql_query_interval_stop}"] = sql_query_annotation_chrom_interval_pos - - # log.debug( - # "Create SQL query: " + str(sql_query_annotation_chrom_interval_pos)) - - # # Interval Start/Stop - # sql_query_interval_start = sql_query_interval_stop - - # # nb_of_variant_annotated - # nb_of_variant_annotated += nb_of_variant_annotated_by_chrom - - # nb_of_query = len(query_dict) - # num_query = 0 - # for query_name in query_dict: - # query = query_dict[query_name] - # num_query += 1 - # log.info( - # f"Annotation '{annotation}' - Annotation - Query [{num_query}/{nb_of_query}] {query_name}...") - # result = self.conn.execute(query) - # nb_of_variant_annotated_by_query = result.df()[ - # "Count"][0] - # nb_of_variant_annotated += nb_of_variant_annotated_by_query - # log.info( - # f"Annotation '{annotation}' - Annotation - Query [{num_query}/{nb_of_query}] {query_name} - {nb_of_variant_annotated_by_query} variants annotated") - - # log.info( - # f"Annotation '{annotation}' - Annotation of {nb_of_variant_annotated} variants out of {nb_variants} (with {nb_of_query} queries)") - - # else: - - # log.info( - # f"Annotation '{annotation}' - No Annotations available") - - # log.debug("Final header: " + str(vcf_reader.infos)) - - - + # NEW def def annotation_parquet(self, threads: int = None) -> None: """ It takes a VCF file, and annotates it with a parquet file @@ -3078,15 +2544,18 @@ def annotation_parquet(self, threads: int = None) -> None: log.debug("Delete tmp files/folders: "+str(delete_tmp)) # Config - databases_folders = self.config.get("folders", {}).get( + databases_folders = self.get_config().get("folders", {}).get( "databases", {}).get("parquet", ["."]) log.debug("Databases annotations: " + str(databases_folders)) # Param - annotations = self.param.get("annotation", {}).get( + annotations = self.get_param().get("annotation", {}).get( "parquet", {}).get("annotations", None) log.debug("Annotations: " + str(annotations)) + # Assembly + assembly = self.get_param().get("assembly", "hg19") + # Data table_variants = self.get_table_variants() @@ -3134,73 +2603,39 @@ def annotation_parquet(self, threads: int = None) -> None: log.debug( f"Annotation '{annotation}' - fields: {annotation_fields}") - # Find parquet file and header file - parquet_file = None - parquet_hdr_file = None - for databases_folder in databases_folders: - parquet_file = None - parquet_hdr_file = None - log.debug("Annotation file check: " + annotation 
+ - " or " + str(databases_folder+"/"+annotation+".parquet")) + # Create Database + database = Database(database=annotation, databases_folders=databases_folders, assembly=assembly) - # Parquet .parquet - if os.path.exists(annotation): - parquet_file = annotation - elif os.path.exists(databases_folder+"/"+annotation+".parquet"): - parquet_file = databases_folder+"/"+annotation+".parquet" - if not parquet_file: - continue - - # Header .hdr - if os.path.exists(parquet_file+".hdr"): - parquet_hdr_file = parquet_file+".hdr" - - # parquet and hdr found - if parquet_file and parquet_hdr_file: - break + # Find files + parquet_file = database.get_database() + parquet_hdr_file = database.get_header_file() + # Check if files exists if not parquet_file or not parquet_hdr_file: log.error("Annotation failed: file not found") raise ValueError("Annotation failed: file not found") else: - parquet_file_link = f"'{parquet_file}'" - - # Database format and type - parquet_file_name, parquet_file_extension = os.path.splitext( - parquet_file) - parquet_file_basename = os.path.basename(parquet_file) - parquet_file_format = parquet_file_extension.replace( - ".", "") - parquet_file_type = parquet_file_format - - if parquet_file_format in ["db", "duckdb", "sqlite"]: - parquet_file_as_duckdb_name = parquet_file_basename.replace( - ".", "_") - if parquet_file_format in ["sqlite"]: - parquet_file_format_attached_type = ", TYPE SQLITE" - else: - parquet_file_format_attached_type = "" - log.debug( - f"Annotation '{annotation}' - attach database : " + str(parquet_file)) - self.conn.execute( - f"ATTACH DATABASE '{parquet_file}' AS {parquet_file_as_duckdb_name} (READ_ONLY{parquet_file_format_attached_type})") - parquet_file_link = f"{parquet_file_as_duckdb_name}.variants" - elif parquet_file_format in ["parquet"]: - parquet_file_link = f"'{parquet_file}'" - + # Get parquet connexion + parquet_sql_attach = database.get_sql_database_attach(output="query") + if parquet_sql_attach: + self.conn.execute(parquet_sql_attach) + parquet_file_link = database.get_sql_database_link() + # Log log.debug(f"Annotation '{annotation}' - file: " + str(parquet_file) + " and " + str(parquet_hdr_file)) # Load header as VCF object - parquet_hdr_vcf = Variants(input=parquet_hdr_file) - parquet_hdr_vcf_header_infos = parquet_hdr_vcf.get_header().infos + parquet_hdr_vcf_header_infos = database.get_header().infos + # Log log.debug("Annotation database header: " + str(parquet_hdr_vcf_header_infos)) - # get extra infos - parquet_columns = self.get_extra_infos( - table=parquet_file_link) + # Get extra infos + parquet_columns = database.get_extra_columns() + # Log + log.debug("Annotation database Columns: " + + str(parquet_columns)) # For all fields in database annotation_fields_ALL = False @@ -3211,6 +2646,8 @@ def annotation_parquet(self, threads: int = None) -> None: log.debug( "Annotation database header - All annotations added: " + str(annotation_fields)) + # Init + # List of annotation fields to use sql_query_annotation_update_info_sets = [] @@ -3220,15 +2657,14 @@ def annotation_parquet(self, threads: int = None) -> None: # Annotation fields processed annotation_fields_processed = [] + # Columns mapping + map_columns = database.map_columns(columns=annotation_fields, prefixes=["INFO/"]) + + # Fetch Anotation fields for annotation_field in annotation_fields: # annotation_field_column - if annotation_field in parquet_columns: - annotation_field_column = annotation_field - elif "INFO/" + annotation_field in parquet_columns: - annotation_field_column = 
"INFO/" + annotation_field - else: - annotation_field_column = "INFO" + annotation_field_column = map_columns.get(annotation_field, "INFO") # field new name, if parametered annotation_fields_new_name = annotation_fields.get( @@ -3236,24 +2672,9 @@ def annotation_parquet(self, threads: int = None) -> None: if not annotation_fields_new_name: annotation_fields_new_name = annotation_field - # check annotation field in data - annotation_field_exists_on_variants = 0 - if annotation_fields_new_name not in self.get_header().infos: - sampling_annotation_field_exists_on_variants = 10000 - sql_query_chromosomes = f""" - SELECT 1 AS count - FROM (SELECT * FROM {table_variants} as table_variants LIMIT {sampling_annotation_field_exists_on_variants}) - WHERE ';' || INFO LIKE '%;{annotation_fields_new_name}=%' - LIMIT 1 - """ - annotation_field_exists_on_variants = len( - self.conn.execute(f"{sql_query_chromosomes}").df()["count"]) - log.debug(f"Annotation field {annotation_fields_new_name} found in variants: " + str( - annotation_field_exists_on_variants)) - # To annotate force_update_annotation = False - if annotation_field in parquet_hdr_vcf.get_header().infos and (force_update_annotation or (annotation_fields_new_name not in self.get_header().infos and not annotation_field_exists_on_variants)): + if annotation_field in parquet_hdr_vcf_header_infos and (force_update_annotation or (annotation_fields_new_name not in self.get_header().infos)): # Add field to annotation to process list annotation_fields_processed.append( @@ -3317,15 +2738,12 @@ def annotation_parquet(self, threads: int = None) -> None: else: annotation_message = "skipped" - if annotation_field not in parquet_hdr_vcf.get_header().infos: + if annotation_field not in parquet_hdr_vcf_header_infos: log.warning( f"Annotation '{annotation}' - '{annotation_field}' [{nb_annotation_field}] - not available in parquet file") if annotation_fields_new_name in self.get_header().infos: log.warning( f"Annotation '{annotation}' - '{annotation_fields_new_name}' [{nb_annotation_field}] - already exists in header ({annotation_message})") - if annotation_field_exists_on_variants: - log.warning( - f"Annotation '{annotation}' - '{annotation_fields_new_name}' [{nb_annotation_field}] - already exists in variants ({annotation_message})") # Check if ALL fields have to be annotated. 
Thus concat all INFO field allow_annotation_full_info = True @@ -3468,7 +2886,6 @@ def annotation_parquet(self, threads: int = None) -> None: log.debug("Final header: " + str(vcf_reader.infos)) - ### # Prioritization ### diff --git a/tests/test_commons.py b/tests/test_commons.py index b9ac44a..43d5dc9 100644 --- a/tests/test_commons.py +++ b/tests/test_commons.py @@ -38,7 +38,7 @@ def test_get_file_compressed(): assert get_file_compressed("testfile.gz") == True # Test pour un fichier compressé .bcf - assert get_file_compressed("testfile.bcf") == True + assert get_file_compressed("testfile.bcf") == False # Test pour un fichier non compressé .vcf assert get_file_compressed("testfile.vcf") == False diff --git a/tests/test_objects_database.py b/tests/test_objects_database.py index 2374020..f297d78 100644 --- a/tests/test_objects_database.py +++ b/tests/test_objects_database.py @@ -90,6 +90,8 @@ def test_database_as_conn(): # Check get_database_basename assert database.get_database_basename() == None + # Check is compressed + assert not database.is_compressed() # Check export output_database = "/tmp/output_database.vcf" @@ -100,7 +102,10 @@ def test_database_as_conn(): except: assert False - #assert False + # Check export duckdb + output_database = "/tmp/output_database.duckdb" + remove_if_exists([output_database]) + assert database.export(output_database=output_database) def test_empty_database(): @@ -246,6 +251,32 @@ def test_get_header_from_file(): assert list(database_header.infos) == [] +def test_get_header_infos_list(): + """ + This function tests the `get_header` method of the `Database` class in Python. + """ + + # Init files + + # Create object + database = Database() + + # Header None + database_header_infos_list = database.get_header_infos_list() + assert database_header_infos_list == [] + + # Header parquet + database_header_infos_list = database.get_header_infos_list(database=database_files.get("parquet")) + assert database_header_infos_list == ["nci60"] + + # Create conn with variants in a table by loading a Parquet with Variants object + # Header list is columns of the table + variants = Variants(input=database_files.get("parquet"), load=True) + database = Database(conn=variants.get_connexion()) + database_header_infos_list = database.get_header_infos_list(database=database.get_conn()) + assert database_header_infos_list == ["INFO/nci60"] + + def test_get_header(): """ This function tests the `get_header` method of the `Database` class in Python. 
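Note: a minimal sketch of the new DuckDB export path exercised by the test above. It is not part of the patch; the output path is a placeholder, the module path is assumed from the repository layout, and the "variants" table name comes from the default table parameter of Database.export().

    import duckdb
    from howard.objects.database import Database

    # Export a supported input (here a Parquet annotation database) to a DuckDB file
    database = Database(database="tests/data/annotations/nci60.parquet")
    assert database.export(output_database="/tmp/output_database.duckdb")

    # The export creates a "variants" table inside the DuckDB file,
    # which can then be queried with a plain DuckDB connection.
    conn = duckdb.connect("/tmp/output_database.duckdb")
    print(conn.execute("SELECT count(*) FROM variants").fetchone())
    conn.close()
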
@@ -711,6 +742,53 @@ def test_is_vcf(): assert not database.is_vcf(None) +def test_is_compressed(): + """ + This function test is a database is a vcf (contains all needed columns) + """ + + # Create object + database = Database(database_files.get("vcf")) + + # Check duckdb + assert not database.is_compressed() + + # Create object + database = Database() + + # Check parquet + assert not database.is_compressed(database_files.get("parquet")) + + # Check duckdb + assert not database.is_compressed(database_files.get("duckdb")) + + # Check vcf + assert not database.is_compressed(database_files.get("vcf")) + assert database.is_compressed(database_files.get("vcf_gz")) + + # Check tsv + assert not database.is_compressed(database_files.get("tsv")) + assert database.is_compressed(database_files.get("tsv_gz")) + assert not database.is_compressed(database_files.get("tsv_alternative_columns")) + assert not database.is_compressed(database_files.get("tsv_failed_columns")) + assert not database.is_compressed(database_files.get("tsv_lower_columns")) + + # Check csv + assert not database.is_compressed(database_files.get("csv")) + assert database.is_compressed(database_files.get("csv_gz")) + + # Check json + assert not database.is_compressed(database_files.get("json")) + assert database.is_compressed(database_files.get("json_gz")) + + # Check bed + assert not database.is_compressed(database_files.get("bed")) + assert database.is_compressed(database_files.get("bed_gz")) + + # Check None + assert not database.is_compressed(None) + + def test_get_database_tables(): """ This function list tables in a duckdb database @@ -819,7 +897,6 @@ def test_get_database_table(): assert database.get_database_table(None) is None - def test_get_sql_from(): """ This function get sql from section from a database @@ -872,7 +949,6 @@ def test_get_sql_from(): assert database.get_sql_from(None) == None - def test_get_sql_database_link(): """ This function get sql_database_link section from a database @@ -1088,6 +1164,85 @@ def test_get_extra_colums(): assert database.get_extra_columns(None) == [] +def test_find_column(): + """ + This is a test function for a Python program that checks the functionality of a method called + "find_column" in a database. 
+ """ + + # Check duckdb + database = Database(database=database_files.get("parquet")) + + # find column + assert database.find_column() == "INFO" + + assert database.find_column(column="nci60") == "INFO/nci60" + assert database.find_column(column="nci60", prefixes=["INFO/"]) == "INFO/nci60" + assert database.find_column(column="nci60", prefixes=["PREFIX/"]) == "INFO" + + assert database.find_column(column="INFO/nci60") == "INFO/nci60" + assert database.find_column(column="INFO/nci60", prefixes=["INFO/"]) == "INFO/nci60" + assert database.find_column(column="INFO/nci60", prefixes=["PREFIX/"]) == "INFO/nci60" + + assert database.find_column(column="OTHER/nci60", prefixes=["PREFIX/"]) == None + + + assert database.find_column(column="column") == None + assert database.find_column(column="column", prefixes=["INFO/"]) == None + assert database.find_column(column="column", prefixes=["PREFIX/"]) == None + + # Empty Database + database = Database() + assert database.find_column(database=database_files.get("parquet"), column="nci60") == "INFO/nci60" + assert database.find_column(database=database_files.get("parquet"), column="INFO/nci60") == "INFO/nci60" + assert database.find_column(database=database_files.get("parquet"), column="nci60", prefixes=["INFO/"]) == "INFO/nci60" + assert database.find_column(database=database_files.get("parquet"), column="nci60", prefixes=["OTHER/"]) == "INFO" + assert database.find_column(database=database_files.get("parquet"), column="column") == None + + +def test_map_columns(): + """ + The function tests the `map_columns` method of a `Database` class in Python. + """ + + # Check duckdb + database = Database(database=database_files.get("parquet")) + + # find column + assert database.map_columns() == {} + + assert database.map_columns(columns=["nci60"]) == {"nci60": "INFO/nci60"} + assert database.map_columns(columns=["nci60"], prefixes=["INFO/"]) == {"nci60": "INFO/nci60"} + assert database.map_columns(columns=["nci60"], prefixes=["PREFIX/"]) == {"nci60": "INFO"} + + assert database.map_columns(columns=["nci60", "QUAL"]) == {"nci60": "INFO/nci60", "QUAL": "QUAL"} + assert database.map_columns(columns=["nci60", "QUAL"], prefixes=["INFO/"]) == {"nci60": "INFO/nci60", "QUAL": "QUAL"} + assert database.map_columns(columns=["nci60", "QUAL"], prefixes=["PREFIX/"]) == {"nci60": "INFO", "QUAL": "QUAL"} + + + assert database.map_columns(columns=["INFO/nci60"]) == {"INFO/nci60": "INFO/nci60"} + assert database.map_columns(columns=["INFO/nci60"], prefixes=["INFO/"]) == {"INFO/nci60": "INFO/nci60"} + assert database.map_columns(columns=["INFO/nci60"], prefixes=["PREFIX/"]) == {"INFO/nci60": "INFO/nci60"} + + assert database.map_columns(columns=["column"]) == {"column": None} + assert database.map_columns(columns=["column"], prefixes=["INFO/"]) == {"column": None} + assert database.map_columns(columns=["column"], prefixes=["PREFIX/"]) == {"column": None} + + assert database.map_columns(columns=["column", "QUAL"]) == {"column": None, "QUAL": "QUAL"} + assert database.map_columns(columns=["column", "QUAL"], prefixes=["INFO/"]) == {"column": None, "QUAL": "QUAL"} + assert database.map_columns(columns=["column", "QUAL"], prefixes=["PREFIX/"]) == {"column": None, "QUAL": "QUAL"} + + # Empty Database + database = Database() + assert database.map_columns(database=database_files.get("parquet"), columns=["nci60"]) == {"nci60": "INFO/nci60"} + assert database.map_columns(database=database_files.get("parquet"), columns=["INFO/nci60"]) == {"INFO/nci60": "INFO/nci60"} + assert 
database.map_columns(database=database_files.get("parquet"), columns=["nci60"], prefixes=["INFO/"]) == {"nci60": "INFO/nci60"} + assert database.map_columns(database=database_files.get("parquet"), columns=["nci60"], prefixes=["OTHER/"]) == {"nci60": "INFO"} + assert database.map_columns(database=database_files.get("parquet"), columns=["column"]) == {"column": None} + + assert database.map_columns(database=database_files.get("parquet"), columns=["nci60", "QUAL"]) == {"nci60": "INFO/nci60", "QUAL": "QUAL"} + + def test_get_header_from_columns(): """ The function tests the `get_header_from_columns` method of the `Database` class for different database types and @@ -1114,8 +1269,6 @@ def test_get_header_from_columns(): # VCF without external header assert list(database.get_header_from_columns(database=database_files.get("vcf_without_header")).infos) == [] - #assert False - def test_get_annotations(): """ @@ -1287,7 +1440,7 @@ def test_export(): # database input/format for database_input_index in ["bed", "parquet", "vcf", "vcf_gz", "tsv", "csv", "tbl", "tsv_alternative_columns", "tsv_variants", "json", "example_vcf"]: - for database_output_format in ["parquet", "vcf", "vcf.gz", "tsv", "csv", "tbl", "json", "bed"]: + for database_output_format in ["duckdb", "parquet", "vcf", "vcf.gz", "tsv", "csv", "tbl", "json", "bed"]: input_database = database_files.get(database_input_index) database = Database(database_files.get(database_input_index)) output_database=f"/tmp/output_database.{database_output_format}" @@ -1296,6 +1449,8 @@ def test_export(): if not (database.get_format(input_database) == "bed" and database.get_format(output_database) == "vcf"): try: assert database.export(output_database=output_database, output_header=output_header) + if database.get_sql_database_attach(database=output_database): + database.query(database=output_database, query=f"""{database.get_sql_database_attach(database=output_database)}""") assert database.query(database=output_database, query=f"""{database.get_sql_database_link(database=output_database)}""") except: assert False diff --git a/tests/test_objects_variants.py b/tests/test_objects_variants.py index c9251a6..a161c79 100644 --- a/tests/test_objects_variants.py +++ b/tests/test_objects_variants.py @@ -26,6 +26,8 @@ # Main tests folder tests_folder = os.path.dirname(__file__) +tests_data_folder = tests_folder + "/data" +tests_annotations_folder = tests_folder + "/data/annotations" # Tools folder tests_tools = "/tools" @@ -54,56 +56,89 @@ } } +# Annotation databases +database_files = { + "parquet" : tests_annotations_folder + "/nci60.parquet", + "parquet_without_header" : tests_annotations_folder + "/nci60.without_header.parquet", + "duckdb" : tests_annotations_folder + "/nci60.duckdb", + "duckdb_no_annotation_table" : tests_annotations_folder + "/nci60.no_annotation_table.duckdb", + "sqlite" : tests_annotations_folder + "/nci60.sqlite", + "vcf" : tests_annotations_folder + "/nci60.vcf", + "vcf_gz" : tests_annotations_folder + "/nci60.vcf.gz", + "vcf_without_header" : tests_annotations_folder + "/nci60.without_header.vcf", + "vcf_gz_without_header" : tests_annotations_folder + "/nci60.without_header.vcf.gz", + "tsv" : tests_annotations_folder + "/nci60.tsv", + "tsv_alternative_columns" : tests_annotations_folder + "/nci60.alternative_columns.tsv", + "tsv_failed_columns" : tests_annotations_folder + "/nci60.failed_columns.tsv", + "tsv_lower_columns" : tests_annotations_folder + "/nci60.lower_columns.tsv", + "tsv_without_header" : tests_annotations_folder + 
"/nci60.without_header.tsv", + "tsv_variants" : tests_annotations_folder + "/nci60.variants.tsv", + "tsv_gz" : tests_annotations_folder + "/nci60.tsv.gz", + "csv" : tests_annotations_folder + "/nci60.csv", + "csv_gz" : tests_annotations_folder + "/nci60.csv.gz", + "tbl" : tests_annotations_folder + "/nci60.tbl", + "tbl_gz" : tests_annotations_folder + "/nci60.tbl.gz", + "json" : tests_annotations_folder + "/nci60.json", + "json_gz" : tests_annotations_folder + "/nci60.json.gz", + "bed" : tests_annotations_folder + "/annotation_regions.bed", + "bed_gz" : tests_annotations_folder + "/annotation_regions.bed.gz", + "example_vcf" : tests_data_folder + "/example.vcf", +} -# def test_DEVEL_annotation_parquet(): -# """ -# Tests the `annotation()` method of the `Variants` class using a Parquet file as annotation source. - -# The function creates a `Variants` object with an input VCF file and an output VCF file, and a parameter dictionary specifying that the Parquet file should be used as the annotation source with the "INFO" field. The `annotation()` method is then called to annotate the variants, and the resulting VCF file is checked for correctness using PyVCF. - -# Returns: -# None -# """ -# # Init files -# input_vcf = tests_folder + "/data/example.vcf.gz" -# #annotation_parquet = tests_folder + "/data/annotations/nci60.parquet" -# annotation_parquet = "nci60.parquet" -# databses_parquet = tests_folder + "/data/annotations" -# output_vcf = "/tmp/output.vcf.gz" - -# # Construct config dict -# config = {"folders": {"databases": {"parquet": [databses_parquet]}}} - -# # Construct param dict -# param = {"annotation": {"parquet": {"annotations": {annotation_parquet: {"INFO": None}}}}} -# # Create object -# variants = Variants(conn=None, input=input_vcf, output=output_vcf, config=config, param=param, load=True) -# # Remove if output file exists -# remove_if_exists([output_vcf]) - -# # Annotation -# variants.annotation() +def test_export_query(): + """ + This is a test function for exporting data from a VCF file to a TSV file using SQL queries. + """ -# # query annotated variant -# result = variants.get_query_to_df("SELECT 1 AS count FROM variants WHERE \"#CHROM\" = 'chr7' AND POS = 55249063 AND REF = 'G' AND ALT = 'A' AND INFO = 'DP=125;nci60=0.66'") -# length = len(result) - -# assert length == 1 + # Init files + input_vcf = tests_folder + "/data/example.vcf.gz" + output_tsv = "/tmp/example.tsv" -# # Check if VCF is in correct format with pyVCF -# variants.export_output() -# try: -# vcf.Reader(filename=output_vcf) -# except: -# assert False + # remove if exists + remove_if_exists([output_tsv]) -# assert False + # Create object + variants = Variants(input=input_vcf, output=output_tsv, load=True) + # Check get_output + query = 'SELECT "#CHROM", POS, REF, ALT, INFO FROM variants' + variants.export_output(query=query) + assert os.path.exists(output_tsv) + # Check get_output without header + output_header = output_tsv + ".hdr" + remove_if_exists([output_tsv, output_tsv + ".hdr"]) + variants.export_output(output_header=output_header, query=query) + assert os.path.exists(output_tsv) and os.path.exists(output_header) + + +def test_export(): + """ + The function tests the export functionality of a database for various input and output formats. 
+ """ + + # database input/format + for database_input_index in ["parquet", "vcf", "vcf_gz", "tsv", "csv", "tsv_alternative_columns", "example_vcf"]: + for database_output_format in ["parquet", "vcf", "vcf.gz", "tsv", "csv", "json", "bed"]: + input_database = database_files.get(database_input_index) + output_database=f"/tmp/output_database.{database_output_format}" + output_header=output_database+".hdr" + variants = Variants(input=input_database, output=output_database, load=True) + remove_if_exists([output_database,output_header]) + try: + assert variants.export_output(output_file=output_database, output_header=output_header) + if database_output_format == "vcf": + try: + vcf.Reader(filename=output_database) + except: + assert False + except: + assert False + def test_set_get_input(): """ @@ -556,25 +591,25 @@ def test_load_tsv(): assert nb_variant_in_database == expected_number_of_variants -def test_load_psv(): - """ - This function tests if a PSV file can be loaded into a Variants object and if the expected number of - variants is present in the database. - """ +# def test_load_psv(): +# """ +# This function tests if a PSV file can be loaded into a Variants object and if the expected number of +# variants is present in the database. +# """ - # Init files - input_vcf = tests_folder + "/data/example.psv" +# # Init files +# input_vcf = tests_folder + "/data/example.psv" - # Create object - variants = Variants(input=input_vcf, load=True) +# # Create object +# variants = Variants(input=input_vcf, load=True) - # Check data loaded - result = variants.get_query_to_df("SELECT count(*) AS count FROM variants") - nb_variant_in_database = result["count"][0] +# # Check data loaded +# result = variants.get_query_to_df("SELECT count(*) AS count FROM variants") +# nb_variant_in_database = result["count"][0] - expected_number_of_variants = 7 +# expected_number_of_variants = 7 - assert nb_variant_in_database == expected_number_of_variants +# assert nb_variant_in_database == expected_number_of_variants def test_load_duckdb(): @@ -983,21 +1018,24 @@ def test_export_output_duckdb(): # Init files input_vcf = tests_folder + "/data/example.vcf.gz" - output_vcf = "/tmp/example.duckdb" + output_duckdb = "/tmp/example.duckdb" # remove if exists - remove_if_exists([output_vcf]) + remove_if_exists([output_duckdb]) # Create object - variants = Variants(input=input_vcf, output=output_vcf, load=True) + variants = Variants(input=input_vcf, output=output_duckdb, load=True) # Check get_output variants.export_output() - assert os.path.exists(output_vcf) + assert os.path.exists(output_duckdb) + + # remove if exists + remove_if_exists([output_duckdb]) # Check get_output without header variants.export_output(export_header=False) - assert os.path.exists(output_vcf) and os.path.exists(output_vcf + ".hdr") + assert os.path.exists(output_duckdb) and os.path.exists(output_duckdb + ".hdr") def test_export_output_tsv(): @@ -1078,14 +1116,14 @@ def test_export_output_csv(): assert os.path.exists(output_vcf) and os.path.exists(output_vcf + ".hdr") -def test_export_output_psv(): +def test_export_output_tbl(): """ This function tests the export_output method of the Variants class in Python. 
""" # Init files input_vcf = tests_folder + "/data/example.vcf.gz" - output_vcf = "/tmp/example.psv" + output_vcf = "/tmp/example.tbl" # remove if exists remove_if_exists([output_vcf]) diff --git a/tests/test_tools_annotation.py b/tests/test_tools_annotation.py index 026d79e..c27f111 100644 --- a/tests/test_tools_annotation.py +++ b/tests/test_tools_annotation.py @@ -52,13 +52,23 @@ def test_annotation(): # Query annotation(args) + # Check output file exists + assert os.path.exists(output_vcf) + # read the contents of the actual output file with open(output_vcf, 'r') as f: - result_output_nb_lines = len(f.readlines()) + result_output_nb_lines = 0 + result_output_nb_variants = 0 + for line in f: + result_output_nb_lines += 1 + if not line.startswith("#"): + result_output_nb_variants += 1 # Expected result - expected_result_nb_lines = 8 + expected_result_nb_lines = 61 + expected_result_nb_variants = 7 # Compare assert result_output_nb_lines == expected_result_nb_lines + assert result_output_nb_variants == expected_result_nb_variants diff --git a/tests/test_tools_calculation.py b/tests/test_tools_calculation.py index 9f8e049..6e8482f 100644 --- a/tests/test_tools_calculation.py +++ b/tests/test_tools_calculation.py @@ -56,13 +56,23 @@ def test_calculation(): # Query calculation(args) + # Check output file exists + assert os.path.exists(output_vcf) + # read the contents of the actual output file with open(output_vcf, 'r') as f: - result_output_nb_lines = len(f.readlines()) + result_output_nb_lines = 0 + result_output_nb_variants = 0 + for line in f: + result_output_nb_lines += 1 + if not line.startswith("#"): + result_output_nb_variants += 1 # Expected result - expected_result_nb_lines = 8 + expected_result_nb_lines = 72 + expected_result_nb_variants = 7 # Compare assert result_output_nb_lines == expected_result_nb_lines + assert result_output_nb_variants == expected_result_nb_variants diff --git a/tests/test_tools_convert.py b/tests/test_tools_convert.py index 450e534..00e4a1e 100644 --- a/tests/test_tools_convert.py +++ b/tests/test_tools_convert.py @@ -56,12 +56,22 @@ def test_convert(): # Query convert(args) + # Check output file exists + assert os.path.exists(output_vcf) + # read the contents of the actual output file with open(output_vcf, 'r') as f: - result_output_nb_lines = len(f.readlines()) + result_output_nb_lines = 0 + result_output_nb_variants = 0 + for line in f: + result_output_nb_lines += 1 + if not line.startswith("#"): + result_output_nb_variants += 1 # Expected result - expected_result_nb_lines = 8 + expected_result_nb_lines = 60 + expected_result_nb_variants = 7 # Compare assert result_output_nb_lines == expected_result_nb_lines + assert result_output_nb_variants == expected_result_nb_variants diff --git a/tests/test_tools_prioritization.py b/tests/test_tools_prioritization.py index f0f2db1..e101d76 100644 --- a/tests/test_tools_prioritization.py +++ b/tests/test_tools_prioritization.py @@ -58,11 +58,18 @@ def test_prioritization(): # read the contents of the actual output file with open(output_vcf, 'r') as f: - result_output_nb_lines = len(f.readlines()) + result_output_nb_lines = 0 + result_output_nb_variants = 0 + for line in f: + result_output_nb_lines += 1 + if not line.startswith("#"): + result_output_nb_variants += 1 # Expected result - expected_result_nb_lines = 8 + expected_result_nb_lines = 66 + expected_result_nb_variants = 7 # Compare assert result_output_nb_lines == expected_result_nb_lines + assert result_output_nb_variants == expected_result_nb_variants diff 
index 2c71ac0..21c3aa2 100644
--- a/tests/test_tools_process.py
+++ b/tests/test_tools_process.py
@@ -62,13 +62,20 @@ def test_process():
     # read the contents of the actual output file
     with open(output_vcf, 'r') as f:
-        result_output_nb_lines = len(f.readlines())
+        result_output_nb_lines = 0
+        result_output_nb_variants = 0
+        for line in f:
+            result_output_nb_lines += 1
+            if not line.startswith("#"):
+                result_output_nb_variants += 1
 
     # Expected result
-    expected_result_nb_lines = 8
+    expected_result_nb_lines = 68
+    expected_result_nb_variants = 7
 
     # Compare
     assert result_output_nb_lines == expected_result_nb_lines
+    assert result_output_nb_variants == expected_result_nb_variants
 
 
 def test_process_with_param_file():
@@ -103,13 +110,20 @@ def test_process_with_param_file():
     # read the contents of the actual output file
     with open(output_vcf, 'r') as f:
-        result_output_nb_lines = len(f.readlines())
+        result_output_nb_lines = 0
+        result_output_nb_variants = 0
+        for line in f:
+            result_output_nb_lines += 1
+            if not line.startswith("#"):
+                result_output_nb_variants += 1
 
     # Expected result
-    expected_result_nb_lines = 8
+    expected_result_nb_lines = 78
+    expected_result_nb_variants = 7
 
     # Compare
     assert result_output_nb_lines == expected_result_nb_lines
+    assert result_output_nb_variants == expected_result_nb_variants
 
 
 def test_process_with_query():
@@ -122,7 +136,7 @@ def test_process_with_query():
     annotations = tests_folder + "/data/annotations/nci60.parquet"
     calculations = "VARTYPE"
     prioritizations = tests_folder + "/data/prioritization_profiles.json"
-    input_query = "SELECT count(*) as count FROM variants WHERE INFO LIKE '%VARTYPE%' AND INFO LIKE '%PZScore%'"
+    input_query = "SELECT count(*) AS '#count' FROM variants WHERE INFO LIKE '%VARTYPE%' AND INFO LIKE '%PZScore%'"
 
     # prepare arguments for the query function
     args = argparse.Namespace(
@@ -144,12 +158,23 @@ def test_process_with_query():
     # read the contents of the actual output file
     with open(output_vcf, 'r') as f:
-        result_output = f.read()
+        result_output_nb_lines = 0
+        result_output_nb_variants = 0
+        result_lines = []
+        for line in f:
+            result_output_nb_lines += 1
+            if not line.startswith("#"):
+                result_output_nb_variants += 1
+                result_lines.append(line.strip())
 
     # Expected result
-    expected_result = "count\n7\n"
+    expected_result_nb_lines = 62
+    expected_result_nb_variants = 1
+    expected_result_lines = ["7"]
 
     # Compare
-    assert result_output == expected_result
+    assert result_output_nb_lines == expected_result_nb_lines
+    assert result_output_nb_variants == expected_result_nb_variants
+    assert result_lines == expected_result_lines
diff --git a/tests/test_tools_query.py b/tests/test_tools_query.py
index efa4ec5..35762af 100644
--- a/tests/test_tools_query.py
+++ b/tests/test_tools_query.py
@@ -36,7 +36,7 @@ def test_query():
     input_vcf = tests_folder + "/data/example.vcf.gz"
     output_vcf = "/tmp/output_file.tsv"
     config = {'threads': 4}
-    input_query = "SELECT count(*) AS count FROM variants"
+    input_query = "SELECT count(*) AS '#count' FROM variants"
 
     for explode_infos in [True, False]:
 
@@ -57,10 +57,21 @@ def test_query():
         # read the contents of the actual output file
         with open(output_vcf, 'r') as f:
-            result_output = f.read()
+            result_output_nb_lines = 0
+            result_output_nb_variants = 0
+            result_lines = []
+            for line in f:
+                result_output_nb_lines += 1
+                if not line.startswith("#"):
+                    result_output_nb_variants += 1
+                    result_lines.append(line.strip())
 
         # Expected result
-        expected_result = "count\n7\n"
= "count\n7\n" + expected_result_nb_lines = 54 + expected_result_nb_variants = 1 + expected_result_lines = ["7"] # Compare - assert result_output == expected_result + assert result_output_nb_lines == expected_result_nb_lines + assert result_output_nb_variants == expected_result_nb_variants + assert result_lines == expected_result_lines \ No newline at end of file