diff --git a/.github/workflows/build_and_push.yml b/.github/workflows/build_and_push.yml index a2e6251b..327b6f2a 100644 --- a/.github/workflows/build_and_push.yml +++ b/.github/workflows/build_and_push.yml @@ -21,6 +21,6 @@ jobs: password: ${{ secrets.GITHUB_TOKEN }} # add --with-dev to below commands to build & push the dev image - name: Build docker image - run: ./docker/build + run: ./docker/build --with-dev # TODO: remove --with-dev before merge - name: Push docker image - run: ./docker/push + run: ./docker/push --with-dev # TODO: remove --with-dev before merge diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 351d5c89..dd49f71b 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -89,8 +89,8 @@ jobs: sudo apt-get install -y '^libxcb.*-dev' libx11-xcb-dev libglu1-mesa-dev libxrender-dev libxi-dev libxkbcommon-dev libxkbcommon-x11-dev pip install joblib==1.1.0 conda install -c bioconda pp-sketchlib=2.0.0 - pip3 install git+https://github.com/bacpop/PopPUNK@v2.7.2#egg=PopPUNK - + pip3 install git+https://github.com/bacpop/PopPUNK@network_relabelling#egg=PopPUNK +# TODO revert above branch before merge - name: Install Poetry uses: snok/install-poetry@v1 with: diff --git a/beebop/app.py b/beebop/app.py index 022f49fd..90990ba3 100644 --- a/beebop/app.py +++ b/beebop/app.py @@ -410,11 +410,11 @@ def get_network_graphs(p_hash) -> json: """ fs = PoppunkFileStore(storage_location) try: - cluster_result = get_clusters_internal(p_hash, storage_location) + cluster_result = get_cluster_assignments(p_hash, storage_location) graphmls = {} for cluster_info in cluster_result.values(): cluster = cluster_info["cluster"] - path = fs.network_output_component( + path = fs.pruned_network_output_component( p_hash, get_cluster_num(cluster) ) with open(path, "r") as graphml_file: @@ -494,9 +494,9 @@ def get_results(result_type) -> json: storage_location) -def get_clusters_internal(p_hash: str, storage_location: str) -> dict: +def get_cluster_assignments(p_hash: str, storage_location: str) -> dict: """ - [returns cluster assignment results ] + [returns cluster assignment results] :param p_hash: [project hash] :param storage_location: [storage location] @@ -516,7 +516,7 @@ def get_clusters_json(p_hash: str, storage_location: str) -> json: :param storage_location: [storage location] :return json: [response object with cluster results stored in 'data'] """ - cluster_result = get_clusters_internal(p_hash, storage_location) + cluster_result = get_cluster_assignments(p_hash, storage_location) cluster_dict = {value['hash']: value for value in cluster_result.values()} failed_samples = get_failed_samples_internal(p_hash, storage_location) @@ -620,7 +620,7 @@ def get_project(p_hash) -> json: if "error" in status: return jsonify(error=response_failure(status)), 500 else: - clusters_result = get_clusters_internal(p_hash, storage_location) + clusters_result = get_cluster_assignments(p_hash, storage_location) failed_samples = get_failed_samples_internal(p_hash, storage_location) fs = PoppunkFileStore(storage_location) diff --git a/beebop/assignClusters.py b/beebop/assignClusters.py index 3c0a6bd4..7c0f3ec6 100644 --- a/beebop/assignClusters.py +++ b/beebop/assignClusters.py @@ -208,7 +208,7 @@ def handle_external_clusters( sketches_dict, not_found_query_names, output_full_tmp, - not_found_query_clusters + not_found_query_clusters, ) ) queries_names.extend(not_found_query_names_new) @@ -269,13 +269,11 @@ def handle_not_found_queries( assign_query_clusters( config, 
config.full_db_fs, not_found_query_names, output_full_tmp ) - query_names, query_clusters, _, _, _, _, _ = ( - summarise_clusters( - output_full_tmp, - config.species, - config.full_db_fs.db, - not_found_query_names, - ) + query_names, query_clusters, _, _, _, _, _ = summarise_clusters( + output_full_tmp, + config.species, + config.full_db_fs.db, + not_found_query_names, ) handle_files_manipulation( @@ -347,15 +345,17 @@ def update_external_clusters( not_found_prev_querying = config.fs.external_previous_query_clustering_tmp( config.p_hash ) - external_clusters_not_found, _ = get_external_clusters_from_file( + + update_external_clusters_csv( + previous_query_clustering, not_found_prev_querying, not_found_query_names, - config.external_clusters_prefix, ) - update_external_clusters_csv( - previous_query_clustering, + + external_clusters_not_found, _ = get_external_clusters_from_file( + not_found_prev_querying, not_found_query_names, - external_clusters_not_found, + config.external_clusters_prefix, ) external_clusters.update(external_clusters_not_found) @@ -399,9 +399,7 @@ def copy_include_files(output_full_tmp: str, outdir: str) -> None: merge_txt_files(dest_file, source_file) os.remove(source_file) else: - os.rename( - source_file, dest_file - ) + os.rename(source_file, dest_file) def filter_queries( diff --git a/beebop/filestore.py b/beebop/filestore.py index 960404db..15dd9845 100644 --- a/beebop/filestore.py +++ b/beebop/filestore.py @@ -230,6 +230,24 @@ def network_output_component(self, p_hash, component_number) -> str: ) ) + def pruned_network_output_component(self, p_hash, component_number) -> str: + """ + [Generates the path to the pruned network component file + for the given project hash and component number.] + + :param p_hash: [project hash] + :param component_number: [component number, + which is the same as cluster number] + :return str: [path to pruned network component file] + """ + return str( + PurePath( + self.output(p_hash), + "network", + f"pruned_network_component_{component_number}.graphml", + ) + ) + def tmp(self, p_hash) -> str: """ :param p_hash: [project hash] diff --git a/beebop/poppunkWrapper.py b/beebop/poppunkWrapper.py index a6959b17..d344071a 100644 --- a/beebop/poppunkWrapper.py +++ b/beebop/poppunkWrapper.py @@ -108,7 +108,7 @@ def create_microreact(self, cluster: str, internal_cluster: str) -> None: ), previous_mst=None, previous_distances=None, - network_file=self.fs.network_file(self.p_hash), + network_file=None, gpu_graph=self.args.visualise.gpu_graph, info_csv=self.args.visualise.info_csv, rapidnj=shutil.which("rapidnj"), @@ -120,6 +120,7 @@ def create_microreact(self, cluster: str, internal_cluster: str) -> None: recalculate_distances=self.args.visualise.recalculate_distances, use_partial_query_graph=self.fs.partial_query_graph(self.p_hash), tmp=self.fs.tmp(self.p_hash), + extend_query_graph=True ) def create_network(self) -> None: @@ -152,7 +153,7 @@ def create_network(self) -> None: ), previous_mst=None, previous_distances=None, - network_file=self.fs.network_file(self.p_hash), + network_file=None, gpu_graph=self.args.visualise.gpu_graph, info_csv=self.args.visualise.info_csv, rapidnj=shutil.which("rapidnj"), @@ -164,4 +165,5 @@ def create_network(self) -> None: recalculate_distances=self.args.visualise.recalculate_distances, use_partial_query_graph=self.fs.partial_query_graph(self.p_hash), tmp=self.fs.tmp(self.p_hash), + extend_query_graph=True ) diff --git a/beebop/utils.py b/beebop/utils.py index 2c2a3c2f..052c2179 100644 --- a/beebop/utils.py 
+++ b/beebop/utils.py @@ -7,6 +7,7 @@ import glob import pandas as pd from beebop.filestore import PoppunkFileStore +from networkx import read_graphml, write_graphml, Graph ET.register_namespace("", "http://graphml.graphdrawing.org/xmlns") ET.register_namespace("xsi", "http://www.w3.org/2001/XMLSchema-instance") @@ -82,9 +83,82 @@ def replace_filehashes(folder: str, filename_dict: dict) -> None: print(line) -def add_query_ref_status( - fs: PoppunkFileStore, p_hash: str, filename_dict: dict -) -> None: +def create_subgraphs(network_folder: str, filename_dict: dict) -> None: + """ + [Create subgraphs for the network visualisation. These are what + will be sent back to the user for display. + The subgraphs are created + by selecting a maximum number of nodes, prioritizing query nodes and adding + neighbor nodes until the maximum number of nodes is reached. The query + nodes are highlighted in the network graph by adding a ref or query status + to the .graphml files.] + + :param network_folder: [path to the network folder] + :param filename_dict: [dict that maps filehashes (keys) to + corresponding filenames (values) of all query samples. We only need + the filenames here.] + """ + query_names = list(filename_dict.values()) + + for path in get_component_filenames(network_folder): + SubGraph = build_subgraph(path, query_names) + + add_query_ref_to_graph(SubGraph, query_names) + + write_graphml( + SubGraph, + path.replace("network_component", "pruned_network_component"), + ) + + +def get_component_filenames(network_folder: str) -> list[str]: + """ + [Get all network component filenames in the network folder.] + + :param network_folder: [path to the network folder] + :return list: [list of all network component filenames] + """ + return glob.glob(network_folder + "/network_component_*.graphml") + + +def build_subgraph(path: str, query_names: list) -> Graph: + """ + [Build a subgraph from a network graph, prioritizing query nodes and + adding neighbor nodes until the maximum number of nodes is reached.] + + :param path: [path to the network graph] + :param query_names: [list of query sample names] + :return nx.Graph: [subgraph] + """ + MAX_NODES = 30 # arbitrary number based on performance + Graph = read_graphml(path) + + # get query nodes + query_nodes = { + node for (node, id) in Graph.nodes(data="id") if id in query_names + } + + # get neighbor nodes of query nodes + neighbor_nodes = set() + for node in query_nodes: + neighbor_nodes.update(Graph.neighbors(node)) + + # remove query nodes from neighbor nodes + neighbor_nodes = neighbor_nodes - query_nodes + + # create final set of nodes, prioritizing query nodes + sub_graph_nodes = set() + sub_graph_nodes.update(query_nodes) + + # add neighbor nodes until we reach the maximum number of nodes + remaining_capacity = MAX_NODES - len(sub_graph_nodes) + if remaining_capacity > 0: + sub_graph_nodes.update(list(neighbor_nodes)[:remaining_capacity]) + + return Graph.subgraph(sub_graph_nodes) + + +def add_query_ref_to_graph(graph: Graph, query_names: list) -> None: """ [The standard poppunk visualisation output for the cytoscape network graph (.graphml file) does not include information on whether a sample has been provided by the user (we call these query samples) or is from the database (called reference samples). To highlight the query samples in the network, we need to add this information to the .graphml files. This is done by adding a new element to the nodes, with the key "ref_query" and the value being coded as either 'query' or 'ref'.]
- :param fs: [filestore to locate output files] - :param p_hash: [project hash to find right project folder] - :param filename_dict: [dict that maps filehashes(keys) toclear - corresponding filenames (values) of all query samples. We only need - the filenames here.] + :param graph: [networkx graph object] + :param query_names: [list of query sample names] """ - # list of query filenames - query_names = list(filename_dict.values()) - # list of all component graph filenames - file_list = glob.glob( - fs.output_network(p_hash) + "/network_component_*.graphml" - ) - for path in file_list: - xml_tree = ET.parse(path) - graph = xml_tree.getroot() - nodes = graph.findall(".//{http://graphml.graphdrawing.org/xmlns}node") - for node in nodes: - name = node.find("./").text - child = ET.Element("data") - child.set("key", "ref_query") - child.text = "query" if name in query_names else "ref" - node.append(child) - ET.indent(xml_tree, space=" ", level=0) - with open(path, "wb") as f: - xml_tree.write(f, encoding="utf-8") + for node, id in graph.nodes(data="id"): + if id in query_names: + graph.nodes[node]["ref_query"] = "query" + else: + graph.nodes[node]["ref_query"] = "ref" def get_lowest_cluster(clusters_str: str) -> int: @@ -134,6 +191,38 @@ def get_lowest_cluster(clusters_str: str) -> int: return min(clusters) +def replace_merged_component_filenames(network_folder: str) -> None: + """ + [Replace the filenames of merged network components with the lowest + cluster number. These lowest numbers correspond to the external + cluster we use/display] + + :param network_folder: [path to the network folder] + """ + for file_path in get_component_filenames(network_folder): + if ";" in file_path: + filename = os.path.basename(file_path) + cluster_nums_str = re.search( + r"network_component_([^.]+)\.graphml", filename + ).group( + 1 + ) # extracts component string "1;2;3" + cluster_num = get_lowest_cluster(cluster_nums_str) + + new_path = os.path.join( + network_folder, f"network_component_{cluster_num}.graphml" + ) + + # Handle potential file conflict + if not os.path.exists(new_path) or new_path == file_path: + os.rename(file_path, new_path) + else: + print( + "Warning: " + f"{new_path} already exists, " + f"skipping rename of {file_path}" + ) + + def get_external_clusters_from_file( previous_query_clustering_file: str, hashes_list: list, @@ -153,10 +242,9 @@ def get_external_clusters_from_file( :return tuple: [dictionary of sample hash to external cluster name, list of sample hashes that were not found] """ - df, samples_mask = get_df_sample_mask( + filtered_df = get_df_filtered_by_samples( previous_query_clustering_file, hashes_list ) - filtered_df = df[samples_mask] # Split into found and not found based on NaN values found_mask = filtered_df["Cluster"].notna() @@ -174,31 +262,69 @@ def get_external_clusters_from_file( return hash_to_cluster_mapping.to_dict(), not_found_hashes +def get_external_cluster_nums( + previous_query_clustering_file: str, hashes_list: list +) -> dict[str, str]: + """ + [Get external cluster numbers for samples in the external clusters file.] 
+ + :param previous_query_clustering_file: [Path to CSV file + containing sample data] + :param hashes_list: [List of sample hashes to find samples for] + :return dict: [Dictionary mapping sample names to external cluster names] + """ + filtered_df = get_df_filtered_by_samples( + previous_query_clustering_file, hashes_list + ) + + sample_cluster_num_mapping = filtered_df["Cluster"].astype(str) + sample_cluster_num_mapping.index = filtered_df["sample"] + + return sample_cluster_num_mapping.to_dict() + + +def get_df_filtered_by_samples(previous_query_clustering_file: str, + hashes_list: list) -> pd.DataFrame: + """ + [Filter a DataFrame by sample names.] + + :param previous_query_clustering_file: [Path to CSV file + containing sample data] + :param hashes_list: [List of sample hashes to find samples for] + :return pd.DataFrame: [DataFrame containing sample data] + """ + df, samples_mask = get_df_sample_mask( + previous_query_clustering_file, hashes_list + ) + return df[samples_mask] + + def update_external_clusters_csv( - previous_query_clustering_file: str, + dest_query_clustering_file: str, + source_query_clustering_file: str, q_names: list, - external_clusters_to_update: dict, ) -> None: """ [Update the external clusters CSV file with the clusters of the samples that were not found in the external clusters file.] - :param previous_query_clustering_file: [Path to CSV file - containing sample data] - :param q_names: [List of sample names - that were not - found in the external clusters file] - :param external_clusters_to_update: [Dictionary mapping - sample names to external cluster names] + :param dest_query_clustering_file: [Path to CSV file + containing sample data to copy into] + :param source_query_clustering_file: [Path to CSV file + containing sample data to copy from] + :param q_names: [List of sample names to match] """ df, samples_mask = get_df_sample_mask( - previous_query_clustering_file, q_names + dest_query_clustering_file, q_names ) + sample_cluster_num_mapping = get_external_cluster_nums( + source_query_clustering_file, q_names + ) + df.loc[samples_mask, "Cluster"] = [ - get_cluster_num(external_clusters_to_update[sample_id]) - for sample_id in q_names + sample_cluster_num_mapping[sample_id] for sample_id in q_names ] - df.to_csv(previous_query_clustering_file, index=False) + df.to_csv(dest_query_clustering_file, index=False) def get_df_sample_mask( diff --git a/beebop/visualise.py b/beebop/visualise.py index b2601650..9e4c175e 100644 --- a/beebop/visualise.py +++ b/beebop/visualise.py @@ -2,7 +2,11 @@ from rq.job import Dependency from redis import Redis from beebop.poppunkWrapper import PoppunkWrapper -from beebop.utils import replace_filehashes, add_query_ref_status +from beebop.utils import ( + replace_filehashes, + create_subgraphs, + replace_merged_component_filenames, +) from beebop.utils import get_cluster_num from beebop.filestore import PoppunkFileStore, DatabaseFileStore import pickle @@ -182,9 +186,7 @@ def network( # get results from previous job current_job = get_current_job(Redis()) assign_result = current_job.dependency.result - network_internal( - p_hash, fs, db_fs, args, name_mapping, species - ) + network_internal(p_hash, fs, db_fs, args, name_mapping, species) return assign_result @@ -209,5 +211,7 @@ def network_internal( wrapper = PoppunkWrapper(fs, db_fs, args, p_hash, species) wrapper.create_network() - replace_filehashes(fs.output_network(p_hash), name_mapping) - add_query_ref_status(fs, p_hash, name_mapping) + network_folder = 
fs.output_network(p_hash) + replace_merged_component_filenames(network_folder) + replace_filehashes(network_folder, name_mapping) + create_subgraphs(network_folder, name_mapping) diff --git a/docker/common b/docker/common index a985b07e..9c28df14 100644 --- a/docker/common +++ b/docker/common @@ -21,4 +21,5 @@ TAG_LATEST="${REGISTRY}/${PACKAGE_ORG}/${PACKAGE_NAME}:latest" # development image TAG_DEV_SHA="${REGISTRY}/${TAG_SHA}-${PACKAGE_DEV}" TAG_DEV_BRANCH="${REGISTRY}/${TAG_BRANCH}-${PACKAGE_DEV}" -POPPUNK_VERSION=v2.6.7 # can be version, branch or commit \ No newline at end of file +# TODO: revert +POPPUNK_VERSION=network_relabelling # can be version, branch or commit \ No newline at end of file diff --git a/tests/setup.py b/tests/setup.py index 58ace63a..0e4361a6 100644 --- a/tests/setup.py +++ b/tests/setup.py @@ -39,8 +39,13 @@ def generate_json_pneumo(): "9c00583e2f24fed5e3c6baa87a4bfa4c": "name2.fa" } -db_fs = DatabaseFileStore('./storage/dbs/GPS_v9_ref', - "GPS_v9_external_clusters.csv") +ref_db_fs = DatabaseFileStore( + "./storage/dbs/GPS_v9_ref", "GPS_v9_external_clusters.csv" +) +full_db_fs = DatabaseFileStore( + "./storage/dbs/GPS_v9", "GPS_v9_external_clusters.csv" +) + args = get_args() species = "Streptococcus pneumoniae" species_db_name = "GPS_v9_ref" @@ -56,8 +61,8 @@ def do_assign_clusters(p_hash: str): hashes_list, p_hash, fs, - db_fs, - db_fs, + ref_db_fs, + ref_db_fs, args, species) @@ -66,7 +71,7 @@ def do_network_internal(p_hash: str): do_assign_clusters(p_hash) visualise.network_internal(p_hash, fs, - db_fs, + ref_db_fs, args, name_mapping, species) diff --git a/tests/test_integration.py b/tests/test_integration.py index 2363b67d..71337de6 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -148,9 +148,9 @@ def test_results_zip(client): type = 'network' response = client.post("/results/zip", json={ 'projectHash': p_hash, - 'cluster': 'GPSC1', + 'cluster': 'GPSC38', 'type': type}) - assert 'network_component_1.graphml'.encode('utf-8') in response.data + assert 'network_component_38.graphml'.encode('utf-8') in response.data assert 'network_cytoscape.csv'.encode('utf-8') in response.data diff --git a/tests/test_unit.py b/tests/test_unit.py index 6b03b58f..4ae0b13b 100644 --- a/tests/test_unit.py +++ b/tests/test_unit.py @@ -33,7 +33,7 @@ import beebop.schemas from beebop.filestore import PoppunkFileStore, FileStore, DatabaseFileStore - +import networkx as nx fs = setup.fs args = setup.args @@ -154,7 +154,7 @@ def __init__(self, result): visualise.microreact( p_hash, fs, - setup.db_fs, + setup.ref_db_fs, args, setup.name_mapping, setup.species, @@ -162,7 +162,7 @@ def __init__(self, result): {}, ) - time.sleep(30) # wait for jobs to finish + time.sleep(60) # wait for jobs to finish assert os.path.exists( fs.output_microreact(p_hash, 16) + "/microreact_16_core_NJ.nwk" @@ -267,7 +267,7 @@ def __init__(self, result): setup.do_assign_clusters(p_hash) visualise.network( - p_hash, fs, setup.db_fs, args, setup.name_mapping, setup.species + p_hash, fs, setup.ref_db_fs, args, setup.name_mapping, setup.species ) for cluster in external_to_poppunk_clusters.keys(): @@ -624,15 +624,16 @@ def test_send_zip_internal(client): filename2 = "microreact_24_perplexity20.0_accessory_mandrake.dot" assert filename1.encode("utf-8") in response.data assert filename2.encode("utf-8") in response.data + project_hash = "test_network_zip" - cluster = "GPSC1" + cluster = "GPSC38" type = "network" response = app.send_zip_internal( project_hash, type, cluster, storage_location ) 
response.direct_passthrough = False assert "network_cytoscape.csv".encode("utf-8") in response.data - assert "network_component_1.graphml".encode("utf-8") in response.data + assert "network_component_38.graphml".encode("utf-8") in response.data def test_hex_to_decimal(): @@ -708,52 +709,34 @@ def test_add_files(): assert "7622_5_91.fa".encode("utf-8") not in contents2 -def test_replace_filehashes(): - p_hash = "results_modifications" - folder = fs.output_network(p_hash) - filename_dict = { - "filehash1": "filename1", - "filehash2": "filename2", - "filehash3": "filename3", +def test_replace_filehashes(tmp_path): + + folder = tmp_path / "replace_filehashes" + folder.mkdir() + + # Create test files with hash content + test_data = { + "file1": "filehash1", + "file2": "filehash2", + "file3": "filehash3", } - utils.replace_filehashes(folder, filename_dict) - with open(fs.network_output_component(p_hash, 5), "r") as comp5: - comp5_text = comp5.read() - assert "filename1" in comp5_text - assert "filename3" in comp5_text - assert "filehash1" not in comp5_text - assert "filehash3" not in comp5_text - with open(fs.network_output_component(p_hash, 7), "r") as comp7: - comp7_text = comp7.read() - assert "filename2" in comp7_text - assert "filehash2" not in comp7_text - - -def test_add_query_ref_status(): - p_hash = "results_modifications" + for filename, content in test_data.items(): + (folder / filename).write_text(content) + filename_dict = { "filehash1": "filename1", "filehash2": "filename2", "filehash3": "filename3", } - utils.add_query_ref_status(fs, p_hash, filename_dict) - path = fs.network_output_component(p_hash, 5) - print(path) - xml = ET.parse(path) - graph = xml.getroot() - - def get_node_status(node_no): - node = graph.find( - f".//{{http://graphml.graphdrawing.org/xmlns}}" - f"node[@id='n{node_no}']" - ) - return node.find( - "./{http://graphml.graphdrawing.org/xmlns}data[@key='ref_query']" - ).text - assert get_node_status(21) == "query" - assert get_node_status(22) == "query" - assert get_node_status(20) == "ref" + utils.replace_filehashes(str(folder), filename_dict) + + # Verify results + for filename, original_hash in test_data.items(): + expected_name = filename_dict[original_hash] + content = (folder / filename).read_text() + assert expected_name in content + assert original_hash not in content @patch("beebop.poppunkWrapper.assign_query_hdf5") @@ -970,20 +953,30 @@ def test_get_df_sample_mask(sample_clustering_csv): assert sum(mask) == 2 -def test_update_external_clusters_csv(sample_clustering_csv): +@patch("beebop.utils.get_external_cluster_nums") +def test_update_external_clusters_csv( + mock_get_external_cluster_nums, sample_clustering_csv +): not_found_samples = ["sample1", "sample3"] - new_clusters = {"sample1": "GPSC69", "sample3": "GPSC420"} - + sample_cluster_num_mapping = {"sample1": "11", "sample3": "69;191"} + source_query_clustering = "tmp_query_clustering.csv" + mock_get_external_cluster_nums.return_value = sample_cluster_num_mapping utils.update_external_clusters_csv( - sample_clustering_csv, not_found_samples, new_clusters + sample_clustering_csv, + source_query_clustering, + not_found_samples, ) df = pd.read_csv(sample_clustering_csv) - assert df.loc[df["sample"] == "sample1", "Cluster"].values[0] == "69" + + mock_get_external_cluster_nums.assert_called_once_with( + source_query_clustering, not_found_samples + ) + assert df.loc[df["sample"] == "sample1", "Cluster"].values[0] == "11" assert ( df.loc[df["sample"] == "sample2", "Cluster"].values[0] == "309;20;101" ) # 
Unchanged - assert df.loc[df["sample"] == "sample3", "Cluster"].values[0] == "420" + assert df.loc[df["sample"] == "sample3", "Cluster"].values[0] == "69;191" assert ( df.loc[df["sample"] == "sample4", "Cluster"].values[0] == "40" ) # Unchanged @@ -1229,7 +1222,7 @@ def test_update_external_clusters( config.external_clusters_prefix, ) mock_update_external_clusters.assert_called_once_with( - previous_query_clustering, not_found, new_external_clusters + previous_query_clustering, "tmp_previous_query_clustering", not_found ) assert external_clusters == { @@ -1412,3 +1405,144 @@ def test_save_external_to_poppunk_clusters(tmp_path): "GPSC69": "1", "GPSC420": "2", } + + +def test_get_component_filenames(tmp_path): + network_folder = tmp_path / "network" + network_folder.mkdir() + + # Create matching files + expected_files = [ + network_folder / "network_component_1;88.graphml", + network_folder / "network_component_2.graphml", + ] + for f in expected_files: + f.touch() + + # Create non-matching files + (network_folder / "other_file.txt").touch() + (network_folder / "network_other.graphml").touch() + + result = utils.get_component_filenames(str(network_folder)) + + assert len(result) == 2 + assert sorted(result) == sorted([str(f) for f in expected_files]) + + +def test_get_df_filtered_by_samples(sample_clustering_csv): + """Test filtering the DataFrame by existing samples""" + samples = ["sample1", "sample3"] + + filtered_df = utils.get_df_filtered_by_samples( + sample_clustering_csv, samples + ) + + # Check DataFrame + assert isinstance(filtered_df, pd.DataFrame) + assert len(filtered_df) == 2 + assert list(filtered_df["sample"]) == ["sample1", "sample3"] + + +@patch("beebop.utils.build_subgraph") +@patch("beebop.utils.write_graphml") +@patch("beebop.utils.add_query_ref_to_graph") +@patch("beebop.utils.get_component_filenames") +def test_create_subgraphs( + mock_get_component_filenames, + mock_add_query_ref_to_graph, + mock_write_graphml, + mock_build_subgraph, +): + mock_get_component_filenames.return_value = [ + "network_component_1.graphml", + ] + mock_subgraph = Mock() + mock_build_subgraph.return_value = mock_subgraph + filename_dict = { + "filehash1": "filename1", + "filehash2": "filename2", + } + query_names = list(filename_dict.values()) + + utils.create_subgraphs("network_folder", filename_dict) + + mock_build_subgraph.assert_called_once_with( + "network_component_1.graphml", query_names + ) + mock_add_query_ref_to_graph.assert_called_once_with( + mock_subgraph, query_names + ) + mock_write_graphml.assert_called_once_with( + mock_subgraph, "pruned_network_component_1.graphml" + ) + + +@patch("beebop.utils.read_graphml") +def test_build_subgraph(mock_read_graphml): + graph = nx.complete_graph(50) # 50 nodes fully connected + query_names = ["sample1", "sample2", "sample3"] + graph.nodes[45]["id"] = "sample2" + mock_read_graphml.return_value = graph + + subgraph = utils.build_subgraph("network_component_1.graphml", query_names) + + assert len(subgraph.nodes) == 30 # max number + assert subgraph.has_node(45) is True + + +def test_add_query_ref_to_graph(): + graph = nx.complete_graph(10) # 10 nodes fully connected + query_names = ["sample1", "sample2", "sample3"] + graph.nodes[0]["id"] = "sample2" + + utils.add_query_ref_to_graph(graph, query_names) + + assert graph.nodes[0]["ref_query"] == "query" + for i in range(1, 10): + assert graph.nodes[i]["ref_query"] == "ref" + + +def create_test_files(network_folder, filenames): + """Helper to create test files in the network folder""" + for filename in
filenames: + filepath = os.path.join(network_folder, filename) + with open(filepath, "w") as f: + f.write("test content") + + +def test_replace_merged_component_filenames(tmp_path): + network_dir = tmp_path / "network" + network_dir.mkdir() + network_folder = str(network_dir) + create_test_files( + network_folder, + [ + "network_component_10.graphml", + "network_component_15;5;25.graphml", + "network_component_6;4;2.graphml", + "network_component_2.graphml", + ], + ) + + utils.replace_merged_component_filenames(network_folder) + + assert os.path.exists( + os.path.join(network_folder, "network_component_10.graphml") + ) + assert os.path.exists( + os.path.join(network_folder, "network_component_5.graphml") + ) + assert os.path.exists( + os.path.join(network_folder, "network_component_2.graphml") + ) + + +def test_get_external_cluster_nums(sample_clustering_csv): + samples = ["sample1", "sample2"] + + result = utils.get_external_cluster_nums(sample_clustering_csv, samples) + + assert result == { + "sample1": "10", + "sample2": "309;20;101", + }