diff --git a/src/plot_output.py b/src/plot_output.py index 51242cc8..9dd754e0 100644 --- a/src/plot_output.py +++ b/src/plot_output.py @@ -2,6 +2,7 @@ import matplotlib.pyplot as plt import matplotlib.ticker as ticker import numpy as np +import pprint class PlotOutput: @@ -198,42 +199,58 @@ def plot_transcript_usage(self): def make_pie_charts(self): """ Create pie charts for transcript alignment classifications and read assignment consistency. + Handles both combined and separate sample data structures. """ + print("self.reads_and_class structure:") + pprint.pprint(self.reads_and_class) + titles = ["Transcript Alignment Classifications", "Read Assignment Consistency"] - for i, (title, reads_dict) in enumerate(zip(titles, self.reads_and_class)): - labels = reads_dict.keys() - sizes = reads_dict.values() - total = sum(sizes) - - # Generate a file-friendly title - file_title = title.lower().replace(" ", "_") - - plt.figure() - wedges, texts, autotexts = plt.pie( - sizes, - labels=labels, - autopct="%1.1f%%", - startangle=140, - textprops=dict(color="w"), - ) - plt.setp(autotexts, size=10, weight="bold") - plt.setp(texts, size=9) + for title, data in zip(titles, self.reads_and_class): + if isinstance(data, dict): + if any(isinstance(v, dict) for v in data.values()): + # Separate 'Mutants' and 'WildType' case + for sample_name, sample_data in data.items(): + self._create_pie_chart(f"{title} - {sample_name}", sample_data) + else: + # Combined data case + self._create_pie_chart(title, data) + else: + print(f"Skipping unexpected data type for {title}: {type(data)}") - plt.axis( - "equal" - ) # Equal aspect ratio ensures that pie is drawn as a circle. - plt.title(f"{title}\nTotal: {total}") + def _create_pie_chart(self, title, data): + """ + Helper method to create a single pie chart. + """ + labels = list(data.keys()) + sizes = list(data.values()) + total = sum(sizes) - plt.legend( - wedges, - labels, - title="Categories", - loc="center left", - bbox_to_anchor=(1, 0, 0.5, 1), - ) - plot_path = os.path.join( - self.visualization_dir, f"{file_title}_pie_chart.png" - ) - plt.savefig(plot_path, bbox_inches="tight") - plt.close() + # Generate a file-friendly title + file_title = title.lower().replace(" ", "_").replace("-", "_") + + plt.figure(figsize=(12, 8)) + wedges, texts, autotexts = plt.pie( + sizes, + labels=labels, + autopct=lambda pct: f"{pct:.1f}%\n({int(pct/100.*total):d})", + startangle=140, + textprops=dict(color="w"), + ) + plt.setp(autotexts, size=8, weight="bold") + plt.setp(texts, size=7) + + plt.axis("equal") # Equal aspect ratio ensures that pie is drawn as a circle. + plt.title(f"{title}\nTotal: {total}") + + plt.legend( + wedges, + labels, + title="Categories", + loc="center left", + bbox_to_anchor=(1, 0, 0.5, 1), + fontsize=8, + ) + plot_path = os.path.join(self.visualization_dir, f"{file_title}_pie_chart.png") + plt.savefig(plot_path, bbox_inches="tight", dpi=300) + plt.close() diff --git a/src/post_process.py b/src/post_process.py index f801d6a8..94a1102f 100644 --- a/src/post_process.py +++ b/src/post_process.py @@ -8,6 +8,7 @@ from argparse import Namespace import tempfile import gffutils +import yaml class OutputConfig: @@ -20,6 +21,8 @@ def __init__(self, output_directory, use_counts=False, ref_only=None, gtf=None): self.read_assignments = None self.input_gtf = gtf # Initialize with the provided gtf flag self.genedb_filename = None + self.yaml_input = True + self.yaml_input_path = None self.gtf_flag_needed = False # Initialize flag to check if "--gtf" is needed. self.conditions = False self.gene_grouped_counts = None @@ -37,7 +40,7 @@ def __init__(self, output_directory, use_counts=False, ref_only=None, gtf=None): self.use_counts = use_counts self.ref_only = ref_only - self._load_params_file() # Load the params file instead of parsing the log + self._load_params_file() self._find_files() self._conditional_unzip() @@ -68,13 +71,23 @@ def _process_params(self, params): self.input_gtf = self.input_gtf or params.get("genedb") self.genedb_filename = params.get("genedb_filename") - processing_sample = params.get("prefix") - if processing_sample: - self.output_directory = os.path.join( - self.output_directory, processing_sample - ) + if params.get("yaml"): + # YAML input case + self.yaml_input = True + self.yaml_input_path = params.get("yaml") + # Keep the output_directory as is, don't modify it else: - raise ValueError("Processing sample directory not found in params.") + # Non-YAML input case + self.yaml_input = False + processing_sample = params.get("prefix") + if processing_sample: + self.output_directory = os.path.join( + self.output_directory, processing_sample + ) + else: + raise ValueError( + "Processing sample directory not found in params for non-YAML input." + ) def _conditional_unzip(self): """Check if unzip is needed and perform it conditionally based on the model use.""" @@ -106,9 +119,17 @@ def _unzip_file(self, file_path): def _find_files(self): """Locate the necessary files in the directory and determine the need for the "--gtf" flag.""" + if self.yaml_input: + self.conditions = True + self.ref_only = True + self._find_files_from_yaml() + return # Exit the method after processing YAML input + if not os.path.exists(self.output_directory): print(f"Directory not found: {self.output_directory}") # Debugging output - raise FileNotFoundError("Specified sample subdirectory does not exist.") + raise FileNotFoundError( + f"Specified sample subdirectory does not exist: {self.output_directory}" + ) for file_name in os.listdir(self.output_directory): if file_name.endswith(".extended_annotation.gtf"): @@ -174,6 +195,80 @@ def _find_files(self): if self.ref_only is None: self.ref_only = not self.extended_annotation + def _find_files_from_yaml(self): + """Locate the necessary files in the directory, set specific grouped count and TPM files, and process read assignments.""" + if not os.path.exists(self.yaml_input_path): + print(f"YAML file not found: {self.yaml_input_path}") + raise FileNotFoundError( + f"Specified YAML file does not exist: {self.yaml_input_path}" + ) + + # Set the four specific attributes + self.gene_grouped_counts = os.path.join( + self.output_directory, "combined_gene_counts.tsv" + ) + self.transcript_grouped_counts = os.path.join( + self.output_directory, "combined_transcript_counts.tsv" + ) + self.transcript_grouped_tpm = os.path.join( + self.output_directory, "combined_transcript_tpm.tsv" + ) + self.gene_grouped_tpm = os.path.join( + self.output_directory, "combined_gene_tpm.tsv" + ) + + # Check if the files exist + for attr in [ + "gene_grouped_counts", + "transcript_grouped_counts", + "transcript_grouped_tpm", + "gene_grouped_tpm", + ]: + file_path = getattr(self, attr) + if not os.path.exists(file_path): + print(f"Warning: {attr} file not found at {file_path}") + setattr(self, attr, None) + + # Initialize read_assignments list + self.read_assignments = [] + + # Read and process the YAML file + with open(self.yaml_input_path, "r") as yaml_file: + yaml_data = yaml.safe_load(yaml_file) + + # Check if yaml_data is a list + if isinstance(yaml_data, list): + samples = yaml_data + else: + # If it's not a list, assume it's a dictionary with a 'samples' key + samples = yaml_data.get("samples", []) + + for sample in samples: + name = sample.get("name") + if name: + sample_dir = os.path.join(self.output_directory, name) + + # Check for .read_assignments.tsv.gz + gz_file = os.path.join(sample_dir, f"{name}.read_assignments.tsv.gz") + if os.path.exists(gz_file): + unzipped_file = self._unzip_file(gz_file) + if unzipped_file: + self.read_assignments.append((name, unzipped_file)) + else: + print(f"Warning: Failed to unzip {gz_file}") + else: + # Check for .read_assignments.tsv + non_gz_file = os.path.join( + sample_dir, f"{name}.read_assignments.tsv" + ) + if os.path.exists(non_gz_file): + self.read_assignments.append((name, non_gz_file)) + else: + print(f"Warning: No read assignments file found for {name}") + + if not self.read_assignments: + print("Warning: No read assignment files found for any samples") + class DictionaryBuilder: """Class to build dictionaries from the output files of the pipeline.""" @@ -189,25 +284,42 @@ def build_gene_transcript_exon_dictionaries(self): return self.parse_input_gtf() def build_read_assignment_and_classification_dictionaries(self): - """Indexes classifications and assignment types from the read_assignments.tsv.""" + """Indexes classifications and assignment types from read_assignments.tsv file(s).""" + if not self.config.read_assignments: + raise FileNotFoundError("No read assignments file(s) found.") + + if isinstance(self.config.read_assignments, list): + # YAML input case (multiple files) + classification_counts_dict = {} + assignment_type_counts_dict = {} + for sample_name, read_assignment_file in self.config.read_assignments: + classification_counts, assignment_type_counts = ( + self._process_read_assignment_file(read_assignment_file) + ) + classification_counts_dict[sample_name] = classification_counts + assignment_type_counts_dict[sample_name] = assignment_type_counts + return classification_counts_dict, assignment_type_counts_dict + else: + # Non-YAML input case (single file) + return self._process_read_assignment_file(self.config.read_assignments) + + def _process_read_assignment_file(self, file_path): classification_counts = {} assignment_type_counts = {} - if not self.config.read_assignments: - raise FileNotFoundError("Read assignments file is missing.") - with open(self.config.read_assignments, "r") as file: - next(file) - next(file) - next(file) + with open(file_path, "r") as file: + # Skip header lines + for _ in range(3): + next(file, None) + for line in file: - parts = line.split("\t") + parts = line.strip().split("\t") if len(parts) < 6: continue + additional_info = parts[-1] classification = ( - additional_info.split("Classification=")[-1] - .replace(";", "") - .strip() + additional_info.split("Classification=")[-1].split(";")[0].strip() ) assignment_type = parts[5] @@ -224,7 +336,7 @@ def parse_input_gtf(self): """Parses the GTF file using gffutils to build a detailed dictionary of genes, transcripts, and exons.""" gene_dict = {} if not self.config.genedb_filename: - # convert GFT to DB if we use previous IsoQuant runs + # convert GTF to DB if we use previous IsoQuant runs # remove this functionality later tmp_file = tempfile.NamedTemporaryFile(suffix=".db") self.config.genedb_filename = tmp_file.name @@ -239,50 +351,46 @@ def parse_input_gtf(self): disable_infer_genes=True, disable_infer_transcripts=True, ) - # raise FileNotFoundError("IsoQuant annotation DB file is missing.") try: - # Create a temporary database - with gffutils.FeatureDB(self.config.genedb_filename) as db: - for gene in db.features_of_type("gene"): - gene_id = gene.id - gene_dict[gene_id] = { - "chromosome": gene.seqid, - "start": gene.start, - "end": gene.end, - "strand": gene.strand, - "name": gene.attributes.get("gene_name", [""])[0], - "biotype": gene.attributes.get("gene_biotype", [""])[0], - "transcripts": {}, + # Create a database without using a context manager + db = gffutils.FeatureDB(self.config.genedb_filename) + + for gene in db.features_of_type("gene"): + gene_id = gene.id + gene_dict[gene_id] = { + "chromosome": gene.seqid, + "start": gene.start, + "end": gene.end, + "strand": gene.strand, + "name": gene.attributes.get("gene_name", [""])[0], + "biotype": gene.attributes.get("gene_biotype", [""])[0], + "transcripts": {}, + } + + for transcript in db.children(gene, featuretype="transcript"): + transcript_id = transcript.id + gene_dict[gene_id]["transcripts"][transcript_id] = { + "start": transcript.start, + "end": transcript.end, + "name": transcript.attributes.get("transcript_name", [""])[0], + "biotype": transcript.attributes.get( + "transcript_biotype", [""] + )[0], + "exons": [], + "tags": transcript.attributes.get("tag", [""])[0].split(","), } - for transcript in db.children(gene, featuretype="transcript"): - transcript_id = transcript.id - gene_dict[gene_id]["transcripts"][transcript_id] = { - "start": transcript.start, - "end": transcript.end, - "name": transcript.attributes.get("transcript_name", [""])[ - 0 - ], - "biotype": transcript.attributes.get( - "transcript_biotype", [""] - )[0], - "exons": [], - "tags": transcript.attributes.get("tag", [""])[0].split( - "," - ), + for exon in db.children(transcript, featuretype="exon"): + exon_info = { + "exon_id": exon.id, + "start": exon.start, + "end": exon.end, + "number": exon.attributes.get("exon_number", [""])[0], } - - for exon in db.children(transcript, featuretype="exon"): - exon_info = { - "exon_id": exon.id, - "start": exon.start, - "end": exon.end, - "number": exon.attributes.get("exon_number", [""])[0], - } - gene_dict[gene_id]["transcripts"][transcript_id][ - "exons" - ].append(exon_info) + gene_dict[gene_id]["transcripts"][transcript_id][ + "exons" + ].append(exon_info) except Exception as e: raise Exception(f"Error parsing GTF file: {str(e)}")