From 84188dc2df08a570d9002a51ad744c6f7a828229 Mon Sep 17 00:00:00 2001 From: hadleyking Date: Wed, 2 Oct 2024 09:49:14 -0400 Subject: [PATCH] Docerizing Changes to be committed: modified: .gitignore new file: Dockerfile new file: docker-compose.yml deleted: lib/censu-scope.py new file: lib/censuscope.py new file: lib/process_blast_file_chunk.py --- .gitignore | 4 + Dockerfile | 30 +++ docker-compose.yml | 20 ++ lib/censu-scope.py | 194 ----------------- lib/censuscope.py | 364 ++++++++++++++++++++++++++++++++ lib/process_blast_file_chunk.py | 157 ++++++++++++++ 6 files changed, 575 insertions(+), 194 deletions(-) create mode 100644 Dockerfile create mode 100644 docker-compose.yml delete mode 100644 lib/censu-scope.py create mode 100644 lib/censuscope.py create mode 100644 lib/process_blast_file_chunk.py diff --git a/.gitignore b/.gitignore index 6b6300b..a1b1b6d 100644 --- a/.gitignore +++ b/.gitignore @@ -161,3 +161,7 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +.vscode +/temp_dirs +/test_data \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..1fed4e0 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,30 @@ +# Dockerfile + +FROM python:3.10-slim +RUN apt-get update && apt-get install -y \ + wget \ + gcc \ + make \ + zlib1g-dev \ + build-essential \ + ca-certificates \ + && rm -rf /var/lib/apt/lists/* + +# Install NCBI BLAST+ +RUN apt-get update && apt-get install -y ncbi-blast+ \ + && rm -rf /var/lib/apt/lists/* + +# Install Seqtk +RUN wget https://github.com/lh3/seqtk/archive/refs/tags/v1.3.tar.gz && \ + tar -xzvf v1.3.tar.gz && \ + cd seqtk-1.3 && \ + make && \ + mv seqtk /usr/local/bin/ && \ + cd .. && \ + rm -rf seqtk-1.3 v1.3.tar.gz + +WORKDIR /app + +COPY . . + +CMD ["python", "lib/censuscope.py"] diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..c57c0be --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,20 @@ +# docker-compose.yml + +services: + censuscope: + build: . 
+ image: censuscope_image + container_name: censuscope_container + command: > + python3 lib/censuscope.py + -i 10 + -s 5 + -t 5 + -q /app/input_files/GW-04_S4_L001_R1_001.fastq + -d /app/input_files/blast_db/gfkb/HumanGutDB-v2.6.fasta + + volumes: + - /tmp/test:/app/input_files + - /tmp/test/outputs:/app/temp_dirs + stdin_open: true + tty: true diff --git a/lib/censu-scope.py b/lib/censu-scope.py deleted file mode 100644 index b252c32..0000000 --- a/lib/censu-scope.py +++ /dev/null @@ -1,194 +0,0 @@ -#!/usr/bin/env python3 -""" -""" - -import csv -import json -import argparse -import sys -import os -import time -import random -import shutil -import subprocess -from urllib.parse import urlparse -from Bio import SeqIO -import tempfile - -__version__ = "0.1" -__status__ = "BETA" - -def usr_args(): - """User supplied arguments for functions - """ - - parser = argparse.ArgumentParser() - - parser = argparse.ArgumentParser( - prog='censuscope', - usage='%(prog)s [options]') - - # version - parser.add_argument( - '-v', '--version', - action='version', - version='%(prog)s ' + __version__) - - parser.add_argument( - "-i", '--iterations', - required=True, - help="The number of sample iterations to perform" - ) - - parser.add_argument( - "-s", '--sample-size', - required=True, - help="The number of reads to sample for each iteration" - ) - - parser.add_argument( - "-t", '--tax-depth', - required=True, - help="The taxonomy depth to report in the final results" - ) - - parser.add_argument( - "-q", '--query', - required=True, - help="Input file name" - ) - - parser.add_argument( - "-d", '--database', - required=True, - help="BLAST database name" - ) - - if len(sys.argv) <= 1: - sys.argv.append('--help') - - options = parser.parse_args() - return options - - -def sample_randomizer(iteration_count: int, query:str, sample_size: int): - """ - # Initialization - - # Open the file and read first char - # Process - # Get the number of lines sequences in the query file - # Generate a list of random indexes - # Random sampling or file copying - """ - - try: - with open(query, 'r') as query_file: - first_char = query_file.read(1) - print(first_char) - except Exception as error: - print(f"Could not get handle: {error}") - # break - - if first_char == ">": - records = list(SeqIO.parse(query, "fasta")) - total_reads = len(records) - seq_format = "FASTA" - print(f"{total_reads} {seq_format} records") - elif first_char == "@": - records = list(SeqIO.parse(query, "fastq")) - total_reads = len(records) - seq_format = "FASTQ" - print(f"{total_reads} {seq_format} records") - - if total_reads > sample_size: - print("subset file") - for it in range(1, iteration_count + 1): - print(f"{it}- out of {iteration_count}") - sample_list = random.sample(range(0,total_reads), sample_size) - with open(f"home/random_samples/random_sample.{it}.fasta", "w") as random_file: - for sample in sample_list: - SeqIO.write(records[sample], random_file, "fasta") - - else: - print("whole file") - with open(f"home/random_samples/random_sample.1.fasta", "w") as random_file: - for sample in records: - SeqIO.write(sample, random_file, "fasta") - - - -def blastn(query, database): - """""" - - filenames = next(os.walk("home/random_samples"), (None, None, []))[2] - for random_sample in filenames: - identifier = random_sample.split(".")[1] - output = f"home/blastn/result_{identifier}.txt" - - query = ( - f"blastn -db {database} " - f"-query home/random_samples/{random_sample} -out {output} -outfmt 10 " - f"-num_threads 10 -evalue 1e-6 -max_target_seqs 1 
-perc_identity 80" - ) - - subprocess.run(query, shell=True) - -def refine(): - """ - # Read the contents of the file into a list of lines - # Extract the first element (before the comma) of each line and store it in 'read' list - # Get the unique elements from the 'read' list - # Build the total string by appending lines that have a unique identifier - # Write the total string to the refine file - """ - - blast_results = next(os.walk("home/blastn"), (None, None, []))[2] - - for result in blast_results: - - total = "" - identifier = result.split("_")[1].split(".")[0] - refine_name = f"home/blastn/refined.{identifier}.txt" - print(identifier, result) - - with open(f"home/blastn/{result}", "r") as blast_file: - data = blast_file.readlines() - - read_ids = [line.split(",")[0] for line in data] - unique_ids = list(dict.fromkeys(read_ids)) - - import pdb; pdb.set_trace() - for datum in range(len(data)): - if datum < len(unique_ids): - total += data[datum] - with open(refine_name, "w") as refined_file: - refined_file.write(total) - - -def main(): - """ - Main function - """ - options = usr_args() - iter_counter = 0 - time_start = time.time() - iteration_count = int(options.iterations) - - # sample_randomizer( - # iteration_count=iteration_count, - # query = options.query, - # sample_size=int(options.sample_size) - # ) - - # blastn( - # query=options.query, - # database=options.database - # ) - - refine() - - - -if __name__ == "__main__": - main() diff --git a/lib/censuscope.py b/lib/censuscope.py new file mode 100644 index 0000000..8f798fa --- /dev/null +++ b/lib/censuscope.py @@ -0,0 +1,364 @@ +#!/usr/bin/env python3 +""" +""" + +import argparse +import concurrent.futures +import csv +import json +import os +import logging +import random +import shutil +import subprocess +import sys +import time +from collections import defaultdict +from datetime import datetime +from threading import Thread, Lock +from urllib.parse import urlparse + +write_lock = Lock() + +__version__ = "0.1" +__status__ = "BETA" + + +class GlobalState: + def __init__(self): + self.base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + self.start_time = datetime.fromtimestamp(time.time()).strftime('%Y_%m_%d_%H.%M.%S%z') + self.temp_path = f"{self.base_dir}/temp_dirs/{self.start_time}" + self.temp_dirs = { + "random_samples": f"{self.temp_path}/random_samples", + "results": f"{self.temp_path}/results", + "blastn": f"{self.temp_path}/blastn", + "inputs": f"{self.temp_path}/inputs", + } + self.iteration_count = 0 + + os.makedirs(self.temp_path, exist_ok=True) + for dir_path in self.temp_dirs.values(): + os.makedirs(dir_path, exist_ok=True) + + logging.info( + "#________________________________________________________________________________#" + ) + logging.info( + f"Global State values set: {self.start_time}, {self.base_dir}, {self.temp_path}" + ) +global_state = GlobalState() + +log_file_path = os.path.join(global_state.temp_path, "CensuScope_app.log") +logging.basicConfig( + level=logging.DEBUG, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler(log_file_path), + # logging.StreamHandler() + ] +) + + +def usr_args(): + """User supplied arguments for functions + """ + + parser = argparse.ArgumentParser() + + parser = argparse.ArgumentParser( + prog='censuscope', + usage='%(prog)s [options]') + + # version + parser.add_argument( + '-v', '--version', + action='version', + version='%(prog)s ' + __version__) + + parser.add_argument( + "-i", '--iterations', + required=True, + 
help="The number of sample iterations to perform" + ) + + parser.add_argument( + "-s", '--sample-size', + required=True, + help="The number of reads to sample for each iteration" + ) + + parser.add_argument( + "-t", '--tax-depth', + required=True, + help="The taxonomy depth to report in the final results" + ) + + parser.add_argument( + "-q", '--query_path', + required=True, + help="Input file name" + ) + + parser.add_argument( + "-d", '--database', + required=True, + help="BLAST database name" + ) + + if len(sys.argv) <= 1: + sys.argv.append('--help') + + options = parser.parse_args() + return options + + + +def configure_env(): + """ + """ + + for key, value in global_state.temp_dirs.items(): + os.makedirs(value, exist_ok=True) + + +def count_sequences(query_path: str) -> int: + """ + Use grep to count the number of sequences in a file. + """ + + try: + result = subprocess.run(['grep', '-c', '>', query_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + return int(result.stdout.strip()) + except subprocess.CalledProcessError as e: + raise ValueError(f"Error counting sequences: {e}") + + +def sample_randomizer(iteration_count: int, query_path: str, sample_size: int): + """ + Sample random sequences from a large FASTA or FASTQ file using awk to extract sequences. + """ + + random_samples_path = global_state.temp_dirs["random_samples"] + + # Step 1: Determine how many reads we have + total_reads = count_sequences(query_path) # Assuming count_sequences uses grep to count headers + logging.info(f"{total_reads} FASTA records") + + if total_reads > sample_size: + logging.info("Subset file") + + for it in range(1, iteration_count + 1): + logging.info(f"{it}- out of {iteration_count}") + + # Step 2: Generate random sample indices (these are read indices, not line indices) + sample_indices = random.sample(range(total_reads), sample_size) + sample_indices.sort() + logging.info(f"Sample indices (sorted): {sample_indices}") + + # Step 3: Create an awk script that will extract sequences based on sample_indices + # We store the indices in an associative array + awk_script = f"""BEGIN {{ split("{','.join(map(str, sample_indices))}", samples, ","); for (i in samples) sample_map[samples[i]] = 1 }} + /^>/ {{ n++; if (n in sample_map) {{ print; while (getline && !/^>/) print }} }}""" + + # Use awk to extract the sequences in one go + awk_command = f"awk '{awk_script}' {query_path} > {random_samples_path}/random_sample.{it}.fasta" + + try: + subprocess.run(awk_command, shell=True, check=True) + except subprocess.CalledProcessError as e: + logging.exception(f"Error during awk execution: {e}") + + else: + logging.info("Whole file") + subprocess.run(f"cp {query_path} home/random_samples/random_sample.1.fasta", shell=True) + + +def blastn(query_path, database): + """""" + random_samples_path = global_state.temp_dirs["random_samples"] + blastn_path = global_state.temp_dirs["blastn"] + filenames = next(os.walk(random_samples_path), (None, None, []))[2] + for random_sample in filenames: + identifier = random_sample.split(".")[1] + output = f"{blastn_path}/result_{identifier}.txt" + + blast_command = ( + f"blastn -db {database} " + f"-query {random_samples_path}/{random_sample} -out {output} -outfmt 10 " + f"-num_threads 10 -evalue 1e-6 -max_target_seqs 1 -perc_identity 80" + ) + + subprocess.run(blast_command, shell=True) + + +def process_blast_file(result, refine_name, iteration, overall_hits): + """ + Process a single blast result file: + - Filter out unique GB accessions. 
+ - Track hit counts for each accession. + - Write refined data and taxonomy hit counts. + """ + unique_accessions = set() + filtered_data = [] + tax_data = defaultdict(int) # Use defaultdict for simpler counting + + with open(result, "r") as blast_file: + reader = csv.reader(blast_file) + + # Process each row in the blast file + for row in reader: + read_id = row[0] + # Parse the GB accession from the row + accession = row[1].split("|")[3] if len(row[1].split("|")) > 3 else row[1] + + # Ensure we only process unique accessions + if accession not in unique_accessions: + unique_accessions.add(accession) + filtered_data.append(row) + tax_data[accession] += 1 + + # Update the overall hit counts (protected by a lock for multithreading) + with write_lock: + for accession, hit_count in tax_data.items(): + overall_hits[accession].append(hit_count) + + # Write the filtered data to the refined file + with write_lock: + with open(refine_name, "w", newline='') as refined_file: + writer = csv.writer(refined_file) + writer.writerows(filtered_data) + + # Write taxonomy hit counts for this iteration + with write_lock: + with open(iteration, "w", newline='') as tax_file: + writer = csv.writer(tax_file) + for accession, count in tax_data.items(): + writer.writerow([accession, count]) + + +def refine_blast_files(): + """ + Iterate over all blast result files, process each one. + Use multithreading to process multiple files concurrently. + Track hit counts across all files and calculate averages. + """ + + blastn_path = global_state.temp_dirs["blastn"] + results_path = global_state.temp_dirs["results"] + blast_results = next(os.walk(blastn_path), (None, None, []))[2] + overall_hits = defaultdict(list) # A dictionary to track hits for each accession across files + + # Use ThreadPoolExecutor for parallel processing + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [] + for result in blast_results: + if "result" not in result: + continue + + identifier = result.split("_")[1].split(".")[0] + iteration = f"{results_path}/iteration{identifier}.csv" + refine_name = f"{results_path}/refined.{identifier}.csv" + + # Pass the file names and paths to the thread pool + futures.append(executor.submit(process_blast_file, f"{blastn_path}/{result}", refine_name, iteration, overall_hits)) + + # Wait for all threads to complete + for future in concurrent.futures.as_completed(futures): + future.result() # Raise any exceptions that occurred during processing + + # After all files are processed, calculate averages and write the final table + write_final_table(overall_hits) + + +def write_final_table(overall_hits): + """ + Calculate the average hit count for each GB accession and write the final output. + """ + results_path = global_state.temp_dirs["results"] + total_hits = sum([sum(hits) for hits in overall_hits.values()]) # Total hits across all accessions + + final_table = [] + + for accession, hits in overall_hits.items(): + hit_sum = sum(hits) + iterations_present = len(hits) + average_hits = round(hit_sum / total_hits, 4) if total_hits > 0 else 0 + final_table.append([accession, hit_sum, iterations_present, average_hits]) + + # Write the final table to a file + with open(f"{results_path}/final_table.csv", "w", newline='') as final_file: + writer = csv.writer(final_file) + writer.writerow(["Accession", "Total Hits", "Iterations Present", "Average Hits"]) + writer.writerows(final_table) + + +def fastq_to_fasta(query_path): + """ + Use grep to get the first character of the first line in the file. 
+ Determine the file format (FASTA or FASTQ) + If format is FASTQ then convert to a FASTA and stort in a temp file. + This uses the seqtk (https://github.com/lh3/seqtk) program. + """ + + output_fasta = global_state.temp_dirs["inputs"]+"/temp.fasta" + + head_command = f"head -n 1 {query_path} | cut -c1" + try: + head_char = subprocess.run( + head_command, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ).stdout.strip() + except subprocess.CalledProcessError as e: + logging.exception(f"Error counting sequences: {e}") + return 0 + + + if head_char == ">": + return query_path + + elif head_char == "@": + try: + subprocess.run(f"seqtk seq -a {query_path} > {output_fasta}", shell=True, check=True) + logging.info(f"Conversion complete: {output_fasta}") + return output_fasta + except subprocess.CalledProcessError as e: + raise ValueError(f"Error during conversion: {e}") + + else: + logging.critical(f"Unknown file format {head_char}") + raise ValueError(f"Unsupported sequence header format: {head_char}") + + +def main(): + """ + Main function + """ + + options = usr_args() + configure_env() + + options.query_path = fastq_to_fasta( + query_path=options.query_path + ) + + sample_randomizer( + iteration_count=int(options.iterations), + query_path = options.query_path, + sample_size=int(options.sample_size) + ) + + blastn( + query_path=options.query_path, + database=options.database + ) + + refine_blast_files() + + + +if __name__ == "__main__": + main() diff --git a/lib/process_blast_file_chunk.py b/lib/process_blast_file_chunk.py new file mode 100644 index 0000000..56baec0 --- /dev/null +++ b/lib/process_blast_file_chunk.py @@ -0,0 +1,157 @@ +import csv +from threading import Thread, Lock + +"""Key Modifications for Large File Processing: +Row-by-Row Processing: + +Each chunk of the file is processed row by row without loading the entire file into memory. +This ensures efficient memory usage, particularly for large files. +Threaded Processing: + +The file is split into chunks, and each thread processes a portion of the file. +Each thread processes its chunk and checks for unique IDs both locally (for efficiency) and globally (to ensure uniqueness across threads). +Locking for Thread Safety: + +A Lock is used to ensure that only one thread writes to the output file at a time, preventing race conditions. +The global unique_ids_global set is protected by the lock as well, ensuring that only one thread updates it at a time. +Efficient Memory and I/O Usage: + +The file is streamed, meaning only portions of it are processed at a time, making it more efficient for large files. + +How This Works: +File Splitting: + +The file size is determined, and the file is split into num_threads chunks, each processed by a separate thread. +Processing Chunks: + +Each thread processes its chunk of the file, ensuring that only unique IDs are processed. +For efficiency, each thread maintains its own local set (unique_ids_local) and checks the global set (unique_ids_global) to avoid duplicates. +Writing to File: + +The write_lock ensures that only one thread writes to the output file at a time, avoiding conflicts and ensuring the file is written sequentially. +Handling Large Files: + +By streaming the file row-by-row and processing in chunks, memory usage is minimized, which is critical when handling very large files. +Further Considerations: +Chunk Splitting: The chunk-splitting method uses file byte positions (seek(start)), but this assumes that no row is split between chunks. 
+To handle splitting gracefully, each chunk after the first advances its start position to the
+beginning of the next complete row before it begins parsing.
+
+Exception Handling: You can add error handling within each thread to ensure that any issues
+(such as file reading/writing errors) are properly managed.
+
+This approach should be suitable for processing large files efficiently in a multithreaded
+environment.
+
+----------------------
+Setting Up the Logger:
+
+The basicConfig() method configures logging to both a file (process.log) and the console
+(via StreamHandler()).
+level=logging.DEBUG ensures that all log levels (from DEBUG to CRITICAL) are captured.
+The log format includes timestamps, the severity level of the log, and the message.
+
+Using the Logger in Functions:
+
+logger.info(): Logs general information, such as when a thread starts or finishes processing.
+logger.error(): Logs any exceptions that occur, with the exc_info=True option to include a
+traceback in the log file.
+logger.critical(): Logs critical issues that may halt the program, like failing to process
+the file.
+
+Thread-Specific Logging:
+
+Each thread logs its progress, making it easier to debug any issues specific to a certain
+part of the file being processed.
+
+Why Use Logging Over print() for Debugging?
+Thread Safety: Logging is thread-safe, meaning multiple threads can write to the log without
+conflicts.
+Severity Levels: You can filter logs based on severity (DEBUG, INFO, ERROR, etc.), making it
+easier to focus on specific kinds of messages during debugging.
+Persistent Logs: Logs are saved to a file for later analysis, unlike print() statements that
+only appear in the console.
+Configurable: You can easily change the logging behavior (e.g., log to a different file,
+change verbosity) without modifying the function code.
+Structured and Timestamped: Logs provide more structure, making it easier to track what
+happened when.
+
+Additional Enhancements:
+Dynamic Log Levels: Allow the user to set the log level dynamically (e.g., via command-line
+arguments), so you can increase or decrease verbosity without changing the code.
+Logging Per Module: You can set up separate loggers for different modules of your program,
+allowing for more granular control over logging.
+
+Conclusion:
+Incorporating the logging module into your Python functions is a robust and efficient way to
+track progress and debug issues, especially in a multithreaded or complex program. It gives
+you flexibility and control over how much information you capture and where that information
+is stored.
+""" + +import logging +import os +from threading import Thread, Lock + +# Set up logging +logging.basicConfig(level=logging.DEBUG, + format='%(asctime)s [%(levelname)s] %(message)s', + handlers=[logging.FileHandler("process.log"), logging.StreamHandler()]) + +logger = logging.getLogger(__name__) + +write_lock = Lock() + +def process_blast_file_chunk(result, refine_name, start, end, unique_ids_global): + try: + # Log the start of chunk processing + logger.info(f"Thread {start}-{end} started processing chunk.") + + with open(result, "r") as blast_file: + reader = csv.reader(blast_file) + + # Move to the start of the chunk + blast_file.seek(start) + + filtered_data = [] + unique_ids_local = set() + + for row in reader: + # Stop processing if we've reached the end of the chunk + if blast_file.tell() > end: + break + + read_id = row[0] + + if read_id not in unique_ids_local and read_id not in unique_ids_global: + unique_ids_local.add(read_id) + with write_lock: + unique_ids_global.add(read_id) + filtered_data.append(row) + + # Log when writing the output + logger.info(f"Thread {start}-{end} is writing filtered data to the file.") + + with write_lock: + with open(refine_name, "a", newline='') as refined_file: + writer = csv.writer(refined_file) + writer.writerows(filtered_data) + + # Log the completion of the chunk processing + logger.info(f"Thread {start}-{end} completed processing chunk.") + + except Exception as e: + # Log any errors encountered + logger.error(f"Error processing chunk {start}-{end}: {str(e)}", exc_info=True) + +def refine_blast_file(result, refine_name, num_threads=4): + try: + logger.info(f"Starting to process the file {result} with {num_threads} threads.") + + # Find file size + file_size = os.path.getsize(result) + + # Global set to track all unique ids across threads + unique_ids_global = set() + + # Split the file into chunks for each thread + chunk_size = file_size // num_threads + threads = [] + + # Clear the output file before starting + open(refine_name, "w").close() + + # Create threads to process each chunk + for i in range(num_threads): + start = i * chunk_size + end = (i + 1) * chunk_size if i != num_threads - 1 else file_size + + thread = Thread(target=process_blast_file_chunk, args=(result, refine_name, start, end, unique_ids_global)) + threads.append(thread) + thread.start() + + # Wait for all threads to finish + for thread in threads: + thread.join() + + logger.info(f"Processing completed for file {result}. Output saved to {refine_name}.") + + except Exception as e: + logger.critical(f"Failed to process the file {result}: {str(e)}", exc_info=True)