diff --git a/.gitignore b/.gitignore index f44fdad..3a132ed 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ venv/ .venv .momentum /server/.hypothesis -.hypothesis \ No newline at end of file +.hypothesis +db \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 79159b6..6da09d0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,13 +8,13 @@ langchain langchain-community langchain-openai langchain-experimental -crewai==0.19.0 +crewai PyGithub==2.3.0 psycopg2-binary==2.9.6 firebase-admin==6.5.0 neo4j==5.2 google-cloud-secret-manager -posthog==1.4.0 +posthog==3.5.0 starlette==0.35.0 loguru python-dotenv @@ -26,4 +26,11 @@ PyJWT setuptools portkey_ai gunicorn -sentry-sdk[fastapi] \ No newline at end of file +sentry-sdk[fastapi] +pydantic==2.7.4 +pydantic_core==2.18.4 +psycopg +embedchain==0.1.103 +kombu==5.4.0rc1 +embedchain[weaviate] +celery[redis] \ No newline at end of file diff --git a/server/alembic/env.py b/server/alembic/env.py index 1725034..847c39e 100644 --- a/server/alembic/env.py +++ b/server/alembic/env.py @@ -6,7 +6,7 @@ from dotenv import load_dotenv from sqlalchemy import engine_from_config, pool -from server.schema.base import Base +from server.schemas.base import Base CURRENT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) sys.path.append(CURRENT_DIR) @@ -26,8 +26,7 @@ # from myapp import mymodel # target_metadata = mymodel.Base.metadata target_metadata = Base.metadata -# escaped_url = os.environ["POSTGRES_SERVER"].replace("%", "%%") -escaped_url = "postgresql://postgres:mysecretpassword@localhost:5432/momentum" +escaped_url = os.environ["POSTGRES_SERVER"].replace("%", "%%") config.set_main_option("sqlalchemy.url", escaped_url) # other values from the config, defined by the needs of env.py, # can be acquired: diff --git a/server/alembic/versions/6c007877a09d_jmd.py b/server/alembic/versions/6c007877a09d_jmd.py index 66c7d1a..5fec25d 100644 --- a/server/alembic/versions/6c007877a09d_jmd.py +++ b/server/alembic/versions/6c007877a09d_jmd.py @@ -10,7 +10,6 @@ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql -from server.schema.projects import ProjectStatusEnum # revision identifiers, used by Alembic. 
revision: str = '6c007877a09d' diff --git a/server/change_detection.py b/server/change_detection.py index db8e189..6e516d8 100644 --- a/server/change_detection.py +++ b/server/change_detection.py @@ -4,7 +4,6 @@ from tree_sitter_languages import get_language, get_parser import os -import argparse # Load Python grammar for Tree-sitter parser = get_parser("python") diff --git a/server/dependencies.py b/server/dependencies.py index b47d9fb..e3a689f 100644 --- a/server/dependencies.py +++ b/server/dependencies.py @@ -25,7 +25,7 @@ def add_codebase_map_path(self, directory): return f"{directory}/{db_path}" - def dependencies_from_function(self, project_details, function_identifier: str, + async def dependencies_from_function(self, project_details, function_identifier: str, function_to_test: str, flow: list, print_text: bool = True, # optionally prints text; helpful for understanding the function & debugging @@ -73,17 +73,17 @@ def dependencies_from_function(self, project_details, function_identifier: str, print_messages(detect_messages) - explanation = llm_call(self.detect_client, detect_messages) + explanation = await llm_call(self.detect_client, detect_messages) return [x.strip() for x in explanation.content.split(",") if x.strip() != ""] - def get_dependencies(self, project_details, function_identifier): + async def get_dependencies(self, project_details, function_identifier): flow = get_flow(function_identifier, project_details[2]) flow_trimmed = [x.split(':')[1] for x in flow if x != function_identifier] output = [] for function in flow: node = get_node(function, project_details) code = GithubService.fetch_method_from_repo(node) - output += ( self.dependencies_from_function(project_details, function, code, flow_trimmed)) + output += ( await self.dependencies_from_function(project_details, function, code, flow_trimmed)) return output+flow_trimmed diff --git a/server/endpoint_detection.py b/server/endpoint_detection.py index f179272..8307c4c 100755 --- a/server/endpoint_detection.py +++ b/server/endpoint_detection.py @@ -1,3 +1,4 @@ +import asyncio import json import os import re @@ -32,39 +33,8 @@ def __init__( self.router_prefix_file_mapping = router_prefix_file_mapping self.file_index = file_index - # SQLite database setup - - def setup_database(self): - conn = psycopg2.connect(os.getenv("POSTGRES_SERVER")) - cursor = conn.cursor() - try: - # Create table if it doesn't exist - cursor.execute(""" - CREATE TABLE IF NOT EXISTS endpoints ( - path TEXT, - identifier TEXT, - test_plan TEXT, - preferences TEXT, - project_id integer NOT NULL, - PRIMARY KEY (project_id, identifier), - CONSTRAINT fk_project FOREIGN KEY (project_id) - REFERENCES projects (id) - ON DELETE CASCADE - - ) - """) - - # Commit the transaction - conn.commit() - except psycopg2.Error as e: - print(f"PostgreSQL error: {e}") - finally: - # Close the connection - if conn: - cursor.close() - conn.close() - + def extract_path(self, decorator): # Find the position of the first opening parenthesis and the following comma start = decorator.find("(") + 1 @@ -472,8 +442,8 @@ def extract_function_metadata(self, node): return function_name, parameters, start, end, text - def analyse_endpoints(self, project_id): - self.setup_database() + async def analyse_endpoints(self, project_id, user_id): + conn = psycopg2.connect(os.getenv("POSTGRES_SERVER")) cursor = conn.cursor() detected_endpoints = [] @@ -512,6 +482,9 @@ def analyse_endpoints(self, project_id): ) conn.close() + from server.knowledge_graph.flow import understand_flows + + 
asyncio.create_task(understand_flows(project_id, self.directory, user_id)) def get_qualified_endpoint_name(self, path, prefix): if prefix == None: @@ -686,7 +659,7 @@ def get_test_plan_preferences(self, identifier, project_id): else: test_plan = None # No test plan found for the given identifier - if row[1]: + if row and row[1]: preferences = json.loads( row[1] ) # Deserialize the test plan back into a Python dictionary diff --git a/server/handler/user_handler.py b/server/handler/user_handler.py index 24956a2..d0258c0 100644 --- a/server/handler/user_handler.py +++ b/server/handler/user_handler.py @@ -4,7 +4,7 @@ from server.db.session import SessionManager from server.models.user import CreateUser -from server.schema.users import User +from server.schemas.users import User class UserHandler: diff --git a/server/knowledge_graph/flow.py b/server/knowledge_graph/flow.py new file mode 100644 index 0000000..828f0c4 --- /dev/null +++ b/server/knowledge_graph/flow.py @@ -0,0 +1,216 @@ + +import json +from typing import List, Dict +import os +from server.utils.ai_helper import get_llm_client , llm_call,print_messages +from langchain.schema import SystemMessage, HumanMessage +import hashlib +import psycopg2 +from server.utils.github_helper import GithubService +from utils.graph_db_helper import Neo4jGraph +neo4j_graph = Neo4jGraph() + + +class FlowQuery: + def __init__(self, query: str): + self.query = query + +class FlowInference: + def __init__(self, project_id: str, directory: str, user_id: str): + self.project_id = project_id + self.directory = directory + self.user_id = user_id + self.explain_client = get_llm_client(user_id, "gpt-3.5-turbo-0125") + self.setup_database() + + def setup_database(self): + conn = psycopg2.connect(os.environ['POSTGRES_SERVER']) + + cursor = conn.cursor() + cursor.execute(''' + CREATE TABLE IF NOT EXISTS inference ( + key TEXT, + inference TEXT, + hash TEXT, + explanation TEXT, + project_id INTEGER + ) + ''') + conn.commit() + if conn: + conn.close() + + def insert_inference(self, key: str, inference: str, project_id: str, overall_explanation: str, hash: str): + conn = psycopg2.connect(os.environ['POSTGRES_SERVER']) + cursor = conn.cursor() + cursor.execute("INSERT INTO inference (key, inference, hash, explanation, project_id) VALUES (%s, %s, %s, %s, %s)", (key, inference, hash,overall_explanation, project_id)) + conn.commit() + conn.close() + + def _get_code_for_node(self, node): + return GithubService.fetch_method_from_repo(node) + + def get_flow(self, endpoint_id, project_id): + flow = () + nodes_pro = neo4j_graph.find_outbound_neighbors( + endpoint_id=endpoint_id, project_id=project_id, with_bodies=True + ) + for node in nodes_pro: + if "id" in node: + flow += (node["id"],) + elif "neighbor" in node: + flow += (node["neighbor"]["id"],) + return flow + + def get_code_flow_by_id(self, endpoint_id): + code = "" + nodes = self.get_flow(endpoint_id, self.project_id) + for node in nodes: + node = self.get_node(node) + code += ( + "\n" + + GithubService.fetch_method_from_repo(node) + + "\n code: \n" + + self._get_code_for_node(node) + ) + return code + + def get_node(self, function_identifier): + return neo4j_graph.get_node_by_id(function_identifier, self.project_id) + + def get_endpoints(self) : + conn = psycopg2.connect(os.environ['POSTGRES_SERVER']) + cursor = conn.cursor() + paths = [] + try: + cursor.execute("SELECT path, identifier FROM endpoints where project_id=%s", (self.project_id, )) + endpoints = cursor.fetchall() + + for endpoint in endpoints: + 
paths.append({"path": endpoint[0], "identifier": endpoint[1]}) + + except psycopg2.Error as e: + print("An error occurred: 9", e) + finally: + conn.close() + return paths + + + async def explanation_from_function(self, + function_to_test: str, # Python function to test, as a string +) -> str: + print(function_to_test) + """Returns a integration test for a given Python function, using a 3-step GPT prompt.""" + + # Step 1: Generate an explanation of the function + explain_system_message = SystemMessage( + content="You are a world-class Python developer with an eagle eye for unintended bugs and edge cases. You carefully explain code with great detail and accuracy. You organize your explanations in markdown-formatted, bulleted lists.", + ) + explain_user_message = HumanMessage( + content= f"""Please explain the following Python function. Review what each element of the function is doing precisely and what the author's intentions may have been. Organize your explanation as a markdown-formatted, bulleted list. + + ```python + {function_to_test} + ```""", + ) + + explain_messages = [explain_system_message, explain_user_message] + print_messages(explain_messages) + + + explanation = await llm_call(self.explain_client, explain_messages) + return explanation.content + + async def _get_explanation_for_function(self,function_identifier, node): + conn = psycopg2.connect(os.environ['POSTGRES_SERVER']) + cursor = conn.cursor() + if "code" in node: + code_hash = hashlib.sha256(node["code"].encode('utf-8')).hexdigest() + cursor.execute("SELECT explanation FROM explanation WHERE identifier=? AND hash=?", (function_identifier, code_hash)) + explanation_row = cursor.fetchone() + + if explanation_row: + explanation = explanation_row[0] + else: + explanation = await self.explanation_from_function(node["code"]) + cursor.execute("INSERT INTO explanation (identifier, hash, explanation) VALUES (?, ?, ?)", (function_identifier, code_hash, explanation)) + conn.commit() + + return explanation + + async def generate_overall_explanation(self, endpoint: Dict) -> str: + conn = psycopg2.connect(os.environ['POSTGRES_SERVER']) + cursor = conn.cursor() + code = self.get_code_flow_by_id(endpoint["identifier"]) + if code != '': + code_hash = hashlib.sha256(code.encode('utf-8')).hexdigest() + cursor.execute("SELECT inference FROM inference WHERE key=%s AND hash=%s", (endpoint["path"], code_hash)) + explanation_row = cursor.fetchone() + if explanation_row: + return explanation_row[0], code_hash + else: + result = await self.generate_explanation(code), code_hash + return result + cursor.close() + conn.close() + + return None, None + + + async def generate_explanation(self, code: str) -> str: + explain_system_message = SystemMessage( + content="You are a world-class Python developer with an eagle eye for unintended bugs and edge cases. You carefully explain code with great detail and accuracy. You organize your explanations in markdown-formatted, bulleted lists.", + ) + explain_user_message = HumanMessage( + content=f"""Please analyse the following Python functions. Review what each element of the function is doing precisely and what the author's intentions may have been. Return the overall intent of the API call. Organize your explanation as a markdown-formatted, bulleted list. 
+```python +{code} +``` +""") + explain_messages = [explain_system_message, explain_user_message] + explanation = await llm_call(self.explain_client, explain_messages) + return explanation.content + + async def get_intent_from_explanation(self, explanations: str) -> str: + explain_system_message = SystemMessage( + content="You are a world-class Python developer with an eagle eye for unintended bugs and edge cases. You carefully explain code with great detail and accuracy. You organize your explanations in markdown-formatted, bulleted lists.", + ) + explain_user_message = HumanMessage( + content=f"""You are provided the following explanation of a series of Python functions in the call stack of a given API. From this explanation, extract the intent of the API call. Return only the intent, nothing else. +``` +{explanations} +``` +""") + explain_messages = [explain_system_message, explain_user_message] + explanation = await llm_call(self.explain_client, explain_messages) + return explanation.content + + def get_inferencess(self) -> List[Dict]: + conn = psycopg2.connect(os.environ['POSTGRES_SERVER']) + cursor = conn.cursor() + cursor.execute("SELECT key FROM inference where project_id=%s", (self.project_id,)) + inferences = cursor.fetchall() + conn.close() + return [x[0] for x in inferences] + + async def infer_flows(self) -> Dict[str, str]: + endpoints = self.get_endpoints() + inferred_flows = self.get_inferencess() + flow_explanations = {} + + for endpoint in endpoints: + if endpoint["path"] not in inferred_flows: + overall_explanation, code_hash = await self.generate_overall_explanation(endpoint) + if overall_explanation is not None: + flow_explanations[endpoint["path"]] = (await self.get_intent_from_explanation(overall_explanation), overall_explanation, code_hash) + + return flow_explanations + +async def understand_flows(project_id, directory, user_id): + flow_inference = FlowInference(project_id, directory, user_id) + flow_explanations = await flow_inference.infer_flows() + for key, inference in flow_explanations.items(): + flow_inference.insert_inference(key, inference[0], project_id, inference[1], inference[2]) + from server.knowledge_graph.knowledge_graph import KnowledgeGraph + KnowledgeGraph(project_id) + diff --git a/server/knowledge_graph/knowledge_graph.py b/server/knowledge_graph/knowledge_graph.py new file mode 100644 index 0000000..3797007 --- /dev/null +++ b/server/knowledge_graph/knowledge_graph.py @@ -0,0 +1,25 @@ +from embedchain.loaders.postgres import PostgresLoader +from embedchain.pipeline import Pipeline as App +import os +import psycopg2 + +class KnowledgeGraph: + _instance = None + + def __new__(cls, project_id): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance.postgres_loader = PostgresLoader({"url":os.environ['POSTGRES_SERVER']}) + cls._instance.app = App() + cls._instance.init_app(project_id) + return cls._instance + + def init_app(self, project_id): + self.app.add(f"SELECT key, explanation, inference FROM inference WHERE project_id={project_id};", data_type='postgres', loader=self.postgres_loader, metadata={"project_id": project_id}) + self.app.add(f"SELECT * FROM endpoints WHERE project_id={project_id};", data_type='postgres', loader=self.postgres_loader, metadata={"project_id": project_id}) + self.app.add(f"SELECT identifier, explanation FROM explanation WHERE project_id={project_id};", data_type='postgres', loader=self.postgres_loader, metadata={"project_id": project_id}) + self.app.add(f"SELECT * FROM pydantic WHERE 
project_id={project_id};", data_type='postgres', loader=self.postgres_loader, metadata={"project_id": project_id}) + + def query(self, query, project_id): + prefix = "Always INCLUDE ALL RELEVANT FILEPATH, FUNCTION NAME AND VARIABLE NAMES in your response. If you are asked about an API: ALWAYS include its HTTP verb and url path along with its identifier in the response: \n" + return self.app.query(prefix+query, metadata={"project_id": project_id}) \ No newline at end of file diff --git a/server/parse.py b/server/parse.py index 6487b71..13b0f43 100755 --- a/server/parse.py +++ b/server/parse.py @@ -104,14 +104,11 @@ def put_pydantic_class(filepath, classname, definition, project_id): conn.close() -def get_pydantic_class(classname): +def get_pydantic_class(classname, project_id): conn = psycopg2.connect(os.getenv("POSTGRES_SERVER")) cursor = conn.cursor() try: - cursor.execute( - "SELECT filepath, definition FROM pydantic WHERE classname =" - f" '{classname}'" - ) + cursor.execute("SELECT filepath, definition FROM pydantic WHERE project_id=%s AND classname = %s", (project_id, classname)) conn.commit() except psycopg2.Error as e: @@ -129,17 +126,14 @@ def get_pydantic_class(classname): return edited_definitions -def get_pydantic_classes(classnames, directory): +def get_pydantic_classes(classnames, project_id): try: conn = psycopg2.connect(os.getenv("POSTGRES_SERVER")) cursor = conn.cursor() definitions = [] - placeholders = ", ".join("%s" for classname in classnames) - query = ( - "SELECT filepath, classname, definition FROM pydantic WHERE" - f" classname IN ({placeholders})" - ) - cursor.execute(query, classnames) + placeholders = ', '.join('%s' for classname in classnames) + query = f"SELECT filepath, classname, definition FROM pydantic WHERE project_id=%s AND classname IN ({placeholders})" + cursor.execute(query, (project_id, *classnames)) definitions.extend(cursor.fetchall()) except psycopg2.Error as e: print("An error occurred: 8", e) @@ -183,16 +177,12 @@ def add_node_safe( project_id, ) except psycopg2.IntegrityError: - print( - f"Node with identifier {function_identifier} already exists." - " Skipping insert." - ) + print(f"Node with identifier {function_identifier} already exists. Skipping insert.") + return function_identifier -def add_class_node_safe( - directory, file_path, class_name, start, end, project_id -): - function_identifier = file_path.replace(directory, "") + ":" + class_name +def add_class_node_safe(directory, file_path, class_name, start, end, project_id): + function_identifier = file_path.replace(directory, '') + ":" + class_name try: neo4j_graph.upsert_node( function_identifier, @@ -205,10 +195,8 @@ def add_class_node_safe( project_id, ) except psycopg2.IntegrityError: - print( - f"Node with identifier {function_identifier} already exists." - " Skipping insert." - ) + print(f"Node with identifier {function_identifier} already exists. 
Skipping insert.") + return function_identifier def get_node_text(node, source_code): @@ -906,7 +894,7 @@ def extract_function_metadata(node, parameters=[], class_context=None): # todo: optimise for single run -def analyze_directory(directory, user_id, project_id): +async def analyze_directory(directory, user_id, project_id): _create_pydantic_table(directory) _create_explanation_table_if_not_exists(directory) user_defined_functions = {} @@ -1028,9 +1016,9 @@ def analyze_directory(directory, user_id, project_id): } ( - EndpointManager( + await EndpointManager( directory, router_metadata_file_mapping, file_index - ).analyse_endpoints(project_id) + ).analyse_endpoints(project_id, user_id) ) delete_folder(directory) @@ -1042,12 +1030,11 @@ def get_code_flow_by_id(endpoint_id, project_id): endpoint_id, project_id, with_bodies=True ) for node in nodes_pro: - if "code" in node: - if "file" in json.loads(node[2]): - code += ( + if "file" in json.loads(node[2]): + code += ( f"File: {json.loads(node[2])['file'].replace(dir, '')}\n" ) - code += json.loads(node[2])["code"] + "\n" + code += GithubService.fetch_method_from_repo(node[2]) + "\n" return code @@ -1110,6 +1097,8 @@ def get_code_for_function(function_identifier): def get_node(function_identifier, project_details): return neo4j_graph.get_node_by_id(function_identifier, project_details[2]) +def get_node_by_id(node_id, project_id): + return neo4j_graph.get_node_by_id(node_id, project_id) def get_values(repo_branch, project_manager, user_id): repo_name = repo_branch.repo_name.split("/")[-1] diff --git a/server/plan.py b/server/plan.py index f8aa2f3..ac71f33 100755 --- a/server/plan.py +++ b/server/plan.py @@ -45,7 +45,7 @@ def __init__(self, user_id): self.test_client = self.user_pref_openai_client # example of a function that uses a multi-step prompt to write integration tests - def explanation_from_function( + async def explanation_from_function( self, function_to_test: str, # Python function to test, as a string print_text: bool = True, # optionally prints text; helpful for understanding the function & debugging @@ -77,7 +77,7 @@ def explanation_from_function( if print_text: print_messages(explain_messages) - explanation = llm_call(self.explain_client, explain_messages) + explanation = await llm_call(self.explain_client, explain_messages) return explanation.content async def _plan( @@ -91,23 +91,39 @@ async def _plan( plan_user_message = HumanMessage( content=f"""A good integration test suite should aim to: - - Test the function's behavior for a wide range of possible inputs - - Test edge cases that the author may not have foreseen - - Take advantage of the features of `{test_package}` to make the tests easy to write and maintain - - Be easy to read and understand, with clean code and descriptive names - - Be deterministic, so that the tests always pass or fail in the same way - - Evaluate what scenarios are possible - - Reuse code by using fixtures and other testing utilities for common setup and mocks - - To help integration test the flow above, list diverse scenarios that the function should be able to handle (and under each scenario, include a few examples). - Include exactly 3 scenario statements of happpy paths and 3 scenarios of edge cases. 
- Format your output in JSON format as such, each scenario is only a string statement: - {{ - \"happy_path\": [\"happy_scenario0\", \"happy_scenario1\", happy_scenario2,\" happy_scenario3\", \"happy_scenario4\", \"happy_scenario5\"], - \"edge_case\": [\"edge_scenario1\",\" edge_scenario2\", \"edge_scenario3\"] - }} - - Ensure that your output is JSON parsable.""" + - Test the function's behavior for a wide range of possible inputs +- Test edge cases that the author may not have foreseen +- Take advantage of the features of `{test_package}` to make the tests easy to write and maintain +- Be easy to read and understand, with clean code and descriptive names +- Be deterministic, so that the tests always pass or fail in the same way +Happy Path Scenarios: +- Test cases that cover the expected normal operation of the function, where the inputs are valid and the function produces the expected output without any errors. +Edge Case Scenarios: +- Test cases that explore the boundaries of the function's input domain, such as: + - Boundary values for input parameters + - Unexpected or invalid input types + - Error conditions and exceptions + - Interactions with external dependencies (e.g., databases, APIs) +- These scenarios test the function's robustness and error handling capabilities. +To help integration test the flow above: +1. Analyze the provided code and explanation. +2. List diverse happy path and edge case scenarios that the function should handle. +3. Include exactly 3 scenario statements of happy paths and 3 scenarios of edge cases. +4. Format your output in JSON format as such, each scenario is only a string statement: +{{ +\"happy_path\": [ + \"happy_scenario1\", + \"happy_scenario2\", + ... +], +\"edge_case\": [ + \"edge_case1\", + \"edge_case2\", + ... +] +}} +5. Ensure that your output is JSON parsable. +""" ) plan_messages = [ explain_assistant_message, @@ -117,7 +133,7 @@ async def _plan( print("Plan messages:") print_messages(plan_messages) - plan = llm_call(self.plan_client, plan_messages) + plan = await llm_call(self.plan_client, plan_messages) plan_assistant_message = AIMessage(content=plan.content) # Step 2b: If the plan is short, ask GPT to elaborate further @@ -142,7 +158,7 @@ async def _plan( print_messages([elaboration_user_message]) elaboration = "" if elaboration_needed: - elaboration = llm_call(self.plan_client, elaboration_messages) + elaboration = await llm_call(self.plan_client, elaboration_messages) return elaboration.content else: return plan @@ -205,7 +221,7 @@ async def run_tests(self, identifier, content): except Exception as e: print(e) - def _get_explanation_for_function( + async def _get_explanation_for_function( self, function_identifier, node, project_id ): conn = psycopg2.connect(os.getenv("POSTGRES_SERVER")) @@ -239,7 +255,7 @@ def _get_explanation_for_function( code = GithubService.fetch_method_from_repo( node ) - explanation = self.explanation_from_function(code) + explanation = await self.explanation_from_function(code) cursor.execute( "INSERT INTO explanation (identifier, hash, explanation," " project_id) VALUES (%s, %s, %s, %s)", @@ -283,12 +299,7 @@ async def generate_tests( ) -> str: execute_system_message = SystemMessage( content=( - "You are a world-class Python SDET who specialises in FastAPI," - " pytest, pytest-mocks with an eagle eye for unintended bugs" - " and edge cases. You write careful, accurate integration" - " tests using the aforementioned frameworks. 
When asked to" - " reply only with code, you write all of your code in a single" - " block." +"You are a world-class Python SDET who specialises in FastAPI, pytest, pytest-mocks with an eagle eye for unintended bugs and edge cases. You write careful, accurate integration tests using the aforementioned frameworks. When asked to reply only with code, you write all of your code in a single block." ), ) plan_message = AIMessage(content=plan) @@ -301,9 +312,13 @@ async def generate_tests( The complete path of the endpoint is {endpoint_path}. It is important to use this complete path in the test API call because the code might not contain prefixes. Consider the following points while writing the integration tests: - + * Analyze the provided function code and identify the key components, such as dependencies, database connections, and external API calls, that need to be mocked or set up for testing. * Review the provided test plan and understand the different test scenarios that need to be covered. Consider edge cases, error handling, and potential variations in input data. * Use the provided pydantic classes ({pydantic_classes}) to create the necessary pydantic objects for the test data and mock data setup. This ensures that the tests align with the expected data structures used in the function. + * Pay attention to the preferences provided ({preferences}). If a list of entities (functions, classes, databases, etc.) is specified to be mocked, strictly follow these preferences. If the preferences are empty, use your judgment to determine which components should be mocked, such as the database and any external API calls. + * Utilize FastAPI testing features like TestClient and dependency overrides to set up the test environment. Create fixtures to minimize code duplication and improve test maintainability. + * ALWAYS create a new FastAPI app in the test client and include the relevant routers in it for testing. Do not assume where the main FastAPI app is defined. + * When setting up mocks, use the pytest-mock library. Check if the output structure is defined in the code and use that to create the expected output response data for the test cases. If not defined, infer the expected output based on the test plan outcomes and the provided code under test. * Use your judgment to determine which components should be mocked, such as the database and any external API calls. Don't mock internal methods unless specified. * Utilize FastAPI TestClient and dependency overrides wherever possible to set up the tests. Create fixtures to minimize code duplication. * If there is authorisation involved, mock the authorisation middleware/dependency to always be authenticated. @@ -311,9 +326,11 @@ async def generate_tests( * ALWAYS create a new FastAPI app in the test client and IMPORT THE RELEVANT ROUTERS in it for testing. DO NO TRY to import the main FastAPI app. DO NOT WRITE any new routers in the test file. * Use pytest-mocks library only for mocking. For mocked response objects, use the output structure IF it is defined in the code ELSE infer the expected output structure based on the code and test plan. * When defining the target using pytest mocks, ensure that the target path is the path of the call and not the path of the definition. 
- For a func_a defined at src.utils.helper and imported in code as from src.utils.helper import func_a, the mock would look like : mocker.patch('src.pipeline.node_1.func_a', return_value="some_value") + * For a func_a defined at src.utils.helper and imported in code as from src.utils.helper import func_a, the mock would look like : mocker.patch('src.pipeline.node_1.func_a', return_value="some_value") * Write clear and concise test case names that reflect the scenario being tested. Use assertions to validate the expected behavior and handle potential exceptions gracefully. * Use appropriate setup and teardown methods to manage test resources efficiently. + * Consider the performance implications of the tests and optimize them where possible. Use appropriate setup and teardown methods to manage test resources efficiently. + * Provide meaningful error messages and logging statements to aid in debugging and troubleshooting failed tests. * Reply only with complete code, formatted as follows: ```python # imports @@ -333,7 +350,7 @@ async def generate_tests( if print_text: print_messages([execute_system_message, execute_user_message]) - execution = llm_call( + execution = await llm_call( self.test_client, execute_messages, print_text, temperature ) @@ -382,7 +399,7 @@ async def generate_test_plan_for_endpoint( + "\n code: \n" + self._get_code_for_node(node) + "\n explanation: \n" - + self._get_explanation_for_function( + + await self._get_explanation_for_function( function, node, project_details[2] ) ) diff --git a/server/router.py b/server/router.py index 99f5cd9..2a096f6 100644 --- a/server/router.py +++ b/server/router.py @@ -11,7 +11,6 @@ from server.utils.APIRouter import APIRouter from fastapi.requests import Request from github import Github -from github.GithubException import UnknownObjectException from server.models.repo_details import ( PreferenceDetails, ProjectStatusEnum, @@ -33,7 +32,7 @@ from server.blast_radius_detection import get_paths_from_identifiers from server.utils.github_helper import GithubService from server.utils.graph_db_helper import Neo4jGraph -from server.utils.parse_helper import setup_project_directory, reparse_cleanup +from server.utils.parse_helper import setup_project_directory, delete_folder, reparse_cleanup from server.dependencies import Dependencies from server.auth import check_auth from server.test_agent.crew import GenerateTest @@ -50,7 +49,7 @@ @api_router.post("/parse") -def parse_directory( +async def parse_directory( request: Request, repo_branch: RepoDetails, user=Depends(check_auth) ): dir_details = "" @@ -82,7 +81,7 @@ def parse_directory( dir_details, project_id = setup_project_directory( owner, repo_name, branch_name, app_auth, repo, user_id, project_id ) - analyze_directory(dir_details, user_id, project_id) + await analyze_directory(dir_details, user_id, project_id) new_project = True message = "The project has been parsed successfully" else: @@ -93,7 +92,7 @@ def parse_directory( dir_details, project_id = setup_project_directory(owner, repo_name, branch_name, app_auth, repo, user_id, project_id) - analyze_directory(dir_details, user_id, project_id) + await analyze_directory(dir_details, user_id, project_id) new_project = False message = "The project has been re-parsed successfully" else: @@ -236,7 +235,7 @@ def get_dependencies( @api_router.get("/endpoints/dependencies/more") -def get_more_dependencies_ai( +async def get_more_dependencies_ai( project_id: int, endpoint_id: str, user=Depends(check_auth) ): user_id = user["user_id"] @@ -244,7 
+243,7 @@ def get_more_dependencies_ai( project_id, user_id ) if project_details is not None: - graph_structure = Dependencies(user["user_id"]).get_dependencies( + graph_structure = await Dependencies(user["user_id"]).get_dependencies( project_details, endpoint_id ) return graph_structure @@ -397,7 +396,8 @@ async def generate_test( endpoint_path, str(test_plan), user["user_id"], - project_details[1], + project_dir, + project_id ).write_tests(identifier, preferences, no_of_test_generated, project_details, user_id) else: raise HTTPException( diff --git a/server/routers/webhook.py b/server/routers/webhook.py index 5bcde2a..ee60fc7 100644 --- a/server/routers/webhook.py +++ b/server/routers/webhook.py @@ -74,7 +74,7 @@ async def parse_repos(payload, request: Request): repo_details, user_id ) - analyze_directory(dir_details, user_id, project_id) + await analyze_directory(dir_details, user_id, project_id) project_manager.update_project_status( project_id, ProjectStatusEnum.READY ) @@ -110,7 +110,7 @@ async def parse_repos(payload, request: Request): user_id, project_id ) - analyze_directory(dir_details, user_id, project_id) + await analyze_directory(dir_details, user_id, project_id) request.state.additional_data.append({ "repository_name": repo_name, "branch_name": branch_name, diff --git a/server/schema/__init__.py b/server/schemas/__init__.py similarity index 100% rename from server/schema/__init__.py rename to server/schemas/__init__.py diff --git a/server/schema/base.py b/server/schemas/base.py similarity index 100% rename from server/schema/base.py rename to server/schemas/base.py diff --git a/server/schema/endpoints.py b/server/schemas/endpoints.py similarity index 80% rename from server/schema/endpoints.py rename to server/schemas/endpoints.py index 4a09861..739529d 100644 --- a/server/schema/endpoints.py +++ b/server/schemas/endpoints.py @@ -1,25 +1,25 @@ +import logging from sqlalchemy import Column, Integer, Text -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship from sqlalchemy.schema import ForeignKeyConstraint, PrimaryKeyConstraint - -from server.schema.base import Base +from server.schemas.base import Base class Endpoint(Base): __tablename__ = "endpoints" + path = Column(Text) identifier = Column(Text) test_plan = Column(Text) preferences = Column(Text) project_id = Column(Integer, nullable=False) - __table_args__ = ( PrimaryKeyConstraint("project_id", "identifier"), ForeignKeyConstraint( ["project_id"], ["projects.id"], ondelete="CASCADE" ), ) - - # Relationship to a Project model (assuming it exists) project = relationship("Project", back_populates="endpoints") + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) \ No newline at end of file diff --git a/server/schema/explanation.py b/server/schemas/explanation.py similarity index 89% rename from server/schema/explanation.py rename to server/schemas/explanation.py index a1953e3..fdad0b3 100644 --- a/server/schema/explanation.py +++ b/server/schemas/explanation.py @@ -1,9 +1,8 @@ from sqlalchemy import Column, ForeignKey, Integer, String, Text, create_engine -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship from sqlalchemy.schema import ForeignKeyConstraint, UniqueConstraint -from server.schema.base import Base +from server.schemas.base import Base class Explanation(Base): diff --git a/server/schema/projects.py b/server/schemas/projects.py similarity index 93% rename from server/schema/projects.py rename to 
server/schemas/projects.py index f4946c2..819a4f5 100644 --- a/server/schema/projects.py +++ b/server/schemas/projects.py @@ -1,11 +1,10 @@ from enum import Enum from sqlalchemy import TIMESTAMP, Boolean, CheckConstraint, Column, Integer, String, Text -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship from sqlalchemy.schema import ForeignKeyConstraint from sqlalchemy.sql import func from server.models.repo_details import ProjectStatusEnum -from server.schema.base import Base +from server.schemas.base import Base diff --git a/server/schema/pydantic.py b/server/schemas/pydantic.py similarity index 94% rename from server/schema/pydantic.py rename to server/schemas/pydantic.py index 8e19fa1..b6501a1 100644 --- a/server/schema/pydantic.py +++ b/server/schemas/pydantic.py @@ -1,6 +1,6 @@ from sqlalchemy import Column, ForeignKey, ForeignKeyConstraint, Integer, PrimaryKeyConstraint, Text from sqlalchemy.orm import relationship -from server.schema.base import Base +from server.schemas.base import Base class Pydantic(Base): diff --git a/server/schema/user_subscription_detail.py b/server/schemas/user_subscription_detail.py similarity index 100% rename from server/schema/user_subscription_detail.py rename to server/schemas/user_subscription_detail.py diff --git a/server/schema/user_test_details.py b/server/schemas/user_test_details.py similarity index 100% rename from server/schema/user_test_details.py rename to server/schemas/user_test_details.py diff --git a/server/schema/users.py b/server/schemas/users.py similarity index 89% rename from server/schema/users.py rename to server/schemas/users.py index cc9e369..7d7675b 100644 --- a/server/schema/users.py +++ b/server/schemas/users.py @@ -1,10 +1,9 @@ from sqlalchemy import TIMESTAMP, Boolean, Column, String from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship from sqlalchemy.sql import func -from server.schema.base import Base +from server.schemas.base import Base class User(Base): diff --git a/server/test_agent/agents.py b/server/test_agent/agents.py index 8af0711..f227838 100644 --- a/server/test_agent/agents.py +++ b/server/test_agent/agents.py @@ -9,94 +9,31 @@ def __init__(self, openai_client, directory): self.openai_client = openai_client self.directory = directory - # def testing_agent(self, identifier, test_plan): - # return Agent( - # role="FastAPI Software Developer in Test", - # goal="Write meaningful integration tests for FastAPI endpoints using Pytest framework", - # backstory=f"""You're a software developer in test at a software company building with FastAPI. - # You're responsible for reading the test plan and code under test and writing the necessary integration tests for the FastAPI endpoints using PyTest. - # You're also responsible for integrating and utilising any provided necessary fixtures for the tests. - # Request code definitions wherever needed using the tools provided. 
- # Endpoint identifier : {identifier}, - # {test_plan} - # """, - # tools=[CodeTools().get_code], - # max_iter=10, - # max_rpm=10, - # verbose=True, - # allow_delegation=False, - # llm=openai_client - # ) - - # def fixtures_agent(self): - # return Agent( - # role="Pytest fixtures expert for writing good integration tests", - # goal="Write meaningful pytest fixtures for integration tests for FastAPI endpoints using Pytest framework", - # backstory=f"""You're a software developer specialist in test at a fast-growing startup working with FastAPI. - # you are responsible for writing the necessary fixtures for the tests. - # Request code definitions wherever needed using the tools provided. - # """, - # tools=[CodeTools().get_code], - # max_iter=10, - # max_rpm=10, - # verbose=True, - # allow_delegation=True, - # llm=openai_client - # ) - - # def mocking_agent(self,identifier, test_plan): - # return Agent( - # role="FastAPI Software Developer in Test", - # goal="Write meaningful integration tests for FastAPI endpoints using Pytest framework", - # backstory=f"""You're a software developer in test at a software startup working with FastAPI. - # You're responsible for reading the test plan and code under test and writing the necessary integration tests for the FastAPI endpoints using PyTest. - # You're also responsible for writing the necessary fixtures for the tests. - # Request code definitions wherever needed using the tools provided. - # Endpoint identifier : {identifier}, - # {test_plan} - # """, - # tools=[CodeTools().get_code], - # max_iter=10, - # max_rpm=10, - # verbose=True, - # allow_delegation=True, - # llm=openai_client - # ) - - # def get_input() -> str: - # print("Insert your text. Enter 'q' or press Ctrl-D (or Ctrl-Z on Windows) to end.") - # contents = [] - # while True: - # try: - # line = input() - # except EOFError: - # break - # if line == "q": - # break - # contents.append(line) - # return "\n".join(contents) - - # def data_setup_agent(self, identifier, test_plan): - # return Agent( - # role="Test data setup agent", - # goal="Plan the test data setup required for the test scenarios, and write the necessary pydantic objects for the test data setup", - # backstory=f"""You're a software developer in test at a fast-growing startup working with FastAPI. - # You're responsible for reading the test plan and code under test and creating the test data required for the the necessary integration tests for the FastAPI endpoints using PyTest. - # This includes creating the necessary pydantic objects for the test data setup. - # Request pydantic class definitions wherever needed using the get pydantic definition tool provided. - # Use the get code tool if you require the code under test. - # Use human input if you cannot generate meaningful test data , for example, if you need an s3 bucket name or a user id. - # Create fixtures for test data setup and cleanup where needed. - # Endpoint identifier : {identifier}, - # Test plan : {test_plan} - # """, - # tools=[CodeTools().get_code, CodeTools().get_pydantic_definition, HumanInputRun(input_func=self.get_input) ], - # max_iter=10, - # max_rpm=10, - # verbose=True, - # allow_delegation=True, - # llm=openai_client - # ) + def testing_agent(self, identifier, test_plan): + return Agent( + role="FastAPI Software Developer in Test", + goal="Write comprehensive, meaningful, and maintainable integration tests for FastAPI endpoints using the Pytest framework. 
Ensure high test coverage by thoroughly testing happy paths, edge cases, and error scenarios based on the provided test plan. Utilize Pytest features effectively to create modular, readable, and efficient tests.", + backstory=f"""You are an experienced software developer in test working at a fast-growing software company that builds applications using the FastAPI framework. +Your primary responsibility is to ensure the quality and reliability of the FastAPI endpoints by writing robust integration tests using the Pytest framework. +You can fetch the code under test using the get code tool in order to write correct mocks and function/ class imports. +Talk to knowledge graph using the ask knowledge graph tool to answer any setup questions. +ALWAYS use the pydantic definitions tool to get the accurate file paths of the pydantic classes you want to import. +To accomplish this, you need to carefully analyze the provided test plan, which outlines the expected behavior and requirements for each endpoint. You should also review the code under test to gain a deep understanding of its functionality and potential edge cases. +Based on the test plan and code analysis, your task is to write comprehensive integration tests that cover a wide range of scenarios, including happy paths, edge cases, and error handling. You should leverage the full capabilities of the Pytest framework to create modular, maintainable, and efficient tests. +In addition to writing the tests, you are also responsible for integrating and utilizing any necessary fixtures provided by the development team. These fixtures will help set up the test environment, manage test data, and handle dependencies. +Throughout the testing process, you should communicate with the knowledge graph, requesting code definitions and clarifications whenever needed. You have access to a set of tools that allow you to retrieve relevant code snippets and information. +Remember, your ultimate goal is to ensure that the FastAPI endpoints are thoroughly tested, reliable, and meet the specified requirements. Your integration tests should instill confidence in the codebase and catch any potential issues before they reach production. +Endpoint identifier: {identifier} +Test plan: {test_plan} +DO NOT INCLUDE ANYTHING OTHER THAN PYTHON TEST CODE IN THE FINAL OUTPUT +""", + tools=[CodeTools().get_code, CodeTools().ask_knowledge_graph, CodeTools().get_pydantic_definitions_tool], + max_iter=10, + max_rpm=10, + verbose=True, + allow_delegation=False, + llm=self.openai_client + ) def pydantic_definition_agent(self): return Agent( @@ -129,3 +66,31 @@ def pydantic_definition_agent(self): allow_delegation=True, llm=self.openai_client, ) + + def code_analysis_agent(self, identifier, project_id): + return Agent( + role="Analyze code to determine additional context needed for testing and query the knowledge graph to obtain that context.", + goal="Read the code under test, identify any additional context or dependencies required for effective testing, query the knowledge graph to retrieve the necessary information, and get pydantic definitions where needed.", + + backstory=f"""You are a software developer in test working on a FastAPI project. Your main responsibilities include: +* Analyzing the code under test for project id = {project_id} to identify any additional context or dependencies needed for comprehensive testing. +* Using the "get code" tool to retrieve the relevant code for {identifier} under analysis. 
+* Querying the knowledge graph with query and project id to obtain information on related code elements, such as how to set up test data or interact with the database. +* Using the "get pydantic definitions" tool to retrieve pydantic class definitions required for test data setup or understanding the code under test. +* Leveraging the retrieved context to enhance the quality and coverage of the generated tests. +To effectively analyze the code and retrieve necessary context, please follow these guidelines: +* Use the "get code" tool to fetch the code under test for the given identifier. +* Carefully review the code to understand its functionality and identify any dependencies or related code elements that may impact testing. +* Formulate specific queries to the knowledge graph to retrieve information on how to interact with those dependencies, such as inserting test data into the database. +* Use the "get pydantic definitions" tool to retrieve pydantic class definitions whenever they are needed for understanding the code or setting up test data. +* Analyze the query results and pydantic definitions to gain a deeper understanding of the code's context and how it fits into the overall system. +* Utilize the retrieved context to inform test case generation, ensuring that the tests cover all relevant scenarios and edge cases. +* If the query results or pydantic definitions are insufficient or unclear, refine your queries or seek additional information to fill in any gaps in understanding. +* Apply the insights gained from code analysis, knowledge graph queries, and pydantic definitions to enhance the quality, coverage, and maintainability of the generated tests.""", + tools=[CodeTools().ask_knowledge_graph, CodeTools().get_code, CodeTools().get_pydantic_definitions_tool], + max_iter=10, + max_rpm=10, + verbose=True, + allow_delegation=True, + llm=self.openai_client + ) \ No newline at end of file diff --git a/server/test_agent/crew.py b/server/test_agent/crew.py index 057cfd4..f98e50d 100755 --- a/server/test_agent/crew.py +++ b/server/test_agent/crew.py @@ -5,105 +5,78 @@ from crewai import Crew from crewai.process import Process - -from server.parse import get_code_flow_by_id -from server.plan import Plan +import re +import random +import string +from server.utils.ai_helper import get_llm_client +import asyncio from server.test_agent.agents import TestAgents from server.test_agent.tasks import TestTasks from server.utils.test_detail_handler import UserTestDetailsManager -from server.utils.ai_helper import get_llm_client class GenerateTest: + + def __init__(self, identifier: str, endpoint_path: str, test_plan: dict, user_id: str, directory: str, project_id: str): + self.directory = directory + self.user_id = user_id + self.openai_client = get_llm_client(user_id, "gpt-3.5-turbo-0125") + self.reasoning_client = get_llm_client(user_id, os.environ['OPENAI_MODEL_REASONING']) + self.test_plan = test_plan + self.identifier = identifier + self.endpoint_path = endpoint_path + self.pydantic_definition_task = TestTasks(self.reasoning_client,self.directory).get_pydantic_definition_task(identifier, project_id) + self.pydantic_definition_agent = TestAgents(self.openai_client, self.directory).pydantic_definition_agent() + self.code_analysis_agent = TestAgents(self.openai_client,self.directory).code_analysis_agent(identifier, project_id) + self.knowledge_graph_query_task = TestTasks(self.reasoning_client, self.directory).query_knowledge_graph(identifier, project_id) + self.test_task = 
TestTasks(self.reasoning_client, self.directory ).test_task(identifier , self.test_plan, self.endpoint_path,self.knowledge_graph_query_task, self.pydantic_definition_task, project_id) + self.testing_agent = TestAgents(self.openai_client, self.directory).testing_agent(identifier, self.test_plan) + self.test_crew = Crew(agents=[self.pydantic_definition_agent, self.code_analysis_agent, self.testing_agent], tasks=[self.pydantic_definition_task, self.knowledge_graph_query_task, self.test_task], process=Process.sequential, llm=self.openai_client) + self.user_detail_manager = UserTestDetailsManager() + + def extract_code_blocks(self, text): + if "```python" in text: + code_blocks = re.findall(r'```python(.*?)```', text, re.DOTALL) + else: + code_blocks = text + return code_blocks - def __init__( - self, - identifier: str, - endpoint_path: str, - test_plan: dict, - user_id: str, - directory: str, - ): - self.directory = directory - self.user_id = user_id - - self.openai_client = get_llm_client( - user_id, - "gpt-3.5-turbo-0125", - ) - self.reasoning_client = get_llm_client( - user_id, - os.environ["OPENAI_MODEL_REASONING"], - ) - self.test_plan = test_plan - self.identifier = identifier - self.endpoint_path = endpoint_path - self.pydantic_definition_task = TestTasks( - self.reasoning_client, self.directory - ).get_pydantic_definition_task(identifier) - self.pydantic_definition_agent = TestAgents( - self.reasoning_client, self.directory - ).pydantic_definition_agent() - self.pydantic_crew = Crew( - agents=[self.pydantic_definition_agent], - tasks=[self.pydantic_definition_task], - process=Process.sequential, - llm=self.openai_client, - ) - self.user_detail_manager = UserTestDetailsManager() - - - def extract_code_blocks(self, text): - if "```python" in text: - code_blocks = re.findall(r"```python(.*?)```", text, re.DOTALL) - else: - code_blocks = text - return code_blocks - - async def get_pydantic_definition(self, identifier: str): - self.pydantic_crew.kickoff() - return self.pydantic_definition_task.output.exported_output - - async def write_tests(self, identifier: str, preferences: dict, - no_of_test_generated: int, project_details: list, user_id: str): - print(identifier) - project_id = project_details[2] - repo_name = project_details[3] - branch_name = project_details[4] - func = get_code_flow_by_id(identifier, self.directory) - pydantic_classes = await self.get_pydantic_definition(identifier) - result = await Plan(self.user_id).generate_tests( - self.test_plan, - func, - pydantic_classes, - preferences, - self.endpoint_path, - ) - print(result) - self.user_detail_manager.send_user_test_details( - project_id=project_id, - user_id=user_id, - number_of_tests_generated=no_of_test_generated, - repo_name=repo_name, - branch_name=branch_name - ) - return result + # async def get_pydantic_definition(self, identifier: str): + # print(identifier) + # self.pydantic_crew.kickoff() + # return self.pydantic_definition_task.output.exported_output + async def write_tests(self, identifier: str, preferences: dict, + no_of_test_generated: int, project_details: list, user_id: str): + print(identifier) + project_id = project_details[2] + repo_name = project_details[3] + branch_name = project_details[4] + result = await asyncio.to_thread(self.test_crew.kickoff, None) + print(result) + self.user_detail_manager.send_user_test_details( + project_id=project_id, + user_id=user_id, + number_of_tests_generated=no_of_test_generated, + repo_name=repo_name, + branch_name=branch_name + ) -async def 
create_temp_test_file(identifier, result, directory): + return self.extract_code_blocks(result)[0] - temp_file_id = "".join( - random.choice(string.ascii_letters) for _ in range(8) - ) - if not os.path.exists(f"{directory}/tests"): + @staticmethod + async def create_temp_test_file( identifier, result, directory): + + temp_file_id = ''.join(random.choice(string.ascii_letters) for _ in range(8)) + if not os.path.exists(f"{directory}/tests"): os.mkdir(f"{directory}/tests") + + filename = f"{directory}/tests/test_{identifier.split(':')[-1]}_{temp_file_id}.py" - filename = ( - f"{directory}/tests/test_{identifier.split(':')[-1]}_{temp_file_id}.py" - ) + - with open(filename, "w") as file: + with open(filename, 'w') as file: # Write the string to the file - file.write(result) - return filename + file.write(result) + return filename diff --git a/server/test_agent/tasks.py b/server/test_agent/tasks.py index 6298e0f..8a7c533 100644 --- a/server/test_agent/tasks.py +++ b/server/test_agent/tasks.py @@ -52,97 +52,73 @@ def __init__(self, reasoning_client, directory): # contents.append(line) # return "\n".join(contents) - def get_pydantic_definition_task(self, identifier): + def get_pydantic_definition_task(self, identifier , project_id): + return Task( + description=f"""Endpoint identifier: {identifier} \n Project id: {project_id} \n + Codebase directory: {self.directory} \n + 1. **Identify Pydantic Class Requirements**: Begin by identifying ALL the specific Pydantic classes required for the test data setup and mock response setup using the code under test. + 2. Pydantic classes can include request and response models, data validation classes, any function definition parameters, and ANY other Pydantic structures used in the endpoint code. + 3. DO NOT MAKE UP PYDANTIC DEFINITIONS: Call the get_pydantic_definitions_tool with a python list of classnames structured as a list ["classA","ClassB"] to get the definitions for all the classes you need. + 4. If there are no pydantic objects REQUIRED for mock or test data setup, DO NOT create Pydantic definition for them., + 5. Add the filename of the class as a comment in the pydantic definition. This is important for the get code tool to work properly. + 6. ALWAYS PROVIDE COMPLETE DEFINITIONS DO NOT LEAVE ANYTHING FOR THE USER TO IMPLEMENT + 7. Use the exact provided project id for tools and do not create new project ids """, + expected_output="Properly formatted Pydantic class definition code", + agent=TestAgents(self.reasoning_client, self.directory).pydantic_definition_agent(), + tools=[CodeTools().get_pydantic_definitions_tool, CodeTools().get_code], + async_execution=True, + ) + + def query_knowledge_graph(self, identifier, project_id): return Task( - description=f"""Endpoint identifier: {identifier} \n - Codebase directory: {self.directory} \n - 1. **Identify Pydantic Class Requirements**: Begin by identifying ALL the specific Pydantic classes required for the test data setup and mock response setup using the code under test. - 2. Pydantic classes can include request and response models, data validation classes, any function definition parameters, and ANY other Pydantic structures used in the endpoint code. - 3. DO NOT MAKE UP PYDANTIC DEFINITIONS: Call the get_pydantic_definitions_tool with a python list of classnames structured as a list ["classA","ClassB"] to get the definitions for all the classes you need. - 4. If there are no pydantic objects REQUIRED for mock or test data setup, DO NOT create Pydantic definition for them., - 5. 
Add the filename of the class as a comment in the pydantic definition. This is important for the get code tool to work properly. - 6. ALWAYS PROVIDE COMPLETE DEFINITIONS DO NOT LEAVE ANYTHING FOR THE USER TO IMPLEMENT """, - expected_output=( - "Properly formatted Pydantic class definition code" - ), - agent=TestAgents( - self.reasoning_client, self.directory - ).pydantic_definition_agent(), - tools=[ - CodeTools().get_pydantic_definitions_tool, - CodeTools().get_code, - ], + description=f""" Endpoint identifier: {identifier} \n Project id: {project_id} \n + 1. **Formulate a Clear Query**: Begin by formulating a clear and specific natural language query to retrieve the desired information from the knowledge graph along with project id. Consider the various aspects of the codebase you want to explore, such as API endpoints, code explanations, relationships between code elements, or pydantic definitions. + 2. **Analyze Query Results**: Once you receive the query results, carefully analyze them to understand the relationships and dependencies between different code elements. Look for insights that can help you gain a deeper understanding of the codebase and its functionality. + 3. **Utilize Code Explanations**: Pay attention to the code explanations and inferred knowledge provided in the query results. These explanations can offer valuable insights into the purpose and functionality of specific code segments, helping you make informed decisions during the testing process. + 4. **Leverage Pydantic Definitions**: If your query involves pydantic models, make sure to examine the pydantic definitions using the get pydantic definitions tool. Understanding the data models and schemas used in the project is crucial for creating accurate and comprehensive tests. Include the definitions in the final output AS IS. + 5. **Refine Queries if Needed**: If the initial query results are unclear or insufficient, don't hesitate to refine your queries and seek additional information from the knowledge graph. Iterative querying can help you gather all the necessary details to thoroughly understand the codebase. + 6. **Apply Insights to Testing**: Finally, apply the insights gained from the knowledge graph to enhance the quality and effectiveness of your tests. Use the information to make informed decisions, identify potential edge cases, and ensure that your tests cover a wide range of scenarios relevant to the codebase. + 7. **Plan imports for test file**: The test file should always import functions and classes from the correct path and not a made up path. """, + expected_output="""1. Relevant insights and information from the code knowledge graph to aid in understanding the codebase and creating effective tests., + 2.INCLUDE EXACT PYDANTIC DEFINITIONS OF REQUESTED CLASSES WITH NO MODIFICATIONS IN OUTPUT. 
+ 3.ALWAYS INCLUDE the accurate path for classes or functions that need to be imported in the test file.""", + + agent=TestAgents(self.reasoning_client, self.directory).code_analysis_agent(identifier, project_id), + tools=[CodeTools().ask_knowledge_graph, CodeTools.get_code, CodeTools.get_pydantic_definitions_tool], + async_execution=True, + ) + + def test_task( + self, identifier, test_plan, endpoint_path, code_analysis_task, pydantic_definition_task, project_id + ): + return Task( + description= f"""Test Plan {test_plan} for identifier {identifier} and project id {project_id} + Using Python and ONLY the pytest package along with pytest-mocks for mocking, write a suite of integration tests - one each for every scenario in the test plan above, personalise your tests for the flow defined by the function that can be fetched using the get_code tool. + The complete path of the endpoint is {endpoint_path}. It is important to use this complete path in the test API call because the code might not contain prefixes. + Consider the following points while writing the integration tests: + You can fetch the code under test using the get code tool with correct identifier and project id in order to write correct mocks and function/ class imports. + Talk to knowledge graph using the ask knowledge graph tool to answer any setup questions. + ALWAYS use the pydantic definitions tool to get the accurate file paths of the pydantic classes you want to import. + * Analyze the provided function code and identify the key components, such as dependencies, database connections, and external API calls, that need to be mocked or set up for testing. + * Review the provided test plan and understand the different test scenarios that need to be covered. Consider edge cases, error handling, and potential variations in input data. + * Use the provided context and pydantic classes from the output of the code analysis task to create the necessary pydantic objects for the test data and mock test data setup. This ensures that the tests align with the expected data structures used in the function. + * Pay attention to the preferences provided: ({None}). If a list of entities (functions, classes, databases, etc.) is specified to be mocked, strictly follow these preferences. If the preferences are empty, use your judgment to determine which components should be mocked, such as the database and any external API calls. + * Utilize FastAPI testing features like TestClient and dependency overrides to set up the test environment. Create fixtures to minimize code duplication and improve test maintainability. + * ALWAYS create a new FastAPI app in the test client and INCLUDE THE RELEVANT ROUTERS IN THE APP in it for testing. DO NOT ASSUME where the main FastAPI app is defined. DO NOT REWRITE ROUTERS in the test file. + * When setting up mocks, use the pytest-mock library. Check if the output structure is defined in the code and use that to create the expected output response data for the test cases. If not defined, infer the expected output based on the test plan outcomes and the provided code under test. + * When defining the target using pytest mocks, ensure that the target path is the path of the call and not the path of the definition. + * For a func_a defined at src.utils.helper and imported in code as from src.utils.helper import func_a, the mock would look like : mocker.patch('src.pipeline.node_1.func_a', return_value="some_value") + * Write clear and concise test case names that reflect the scenario being tested. 
Use assertions to validate the expected behavior and handle potential exceptions gracefully. + * Consider the performance implications of the tests and optimize them where possible. Use appropriate setup and teardown methods to manage test resources efficiently. + * Provide meaningful error messages and logging statements to aid in debugging and troubleshooting failed tests. + * Reply only with complete code, formatted as follows (DO NOT INCLUDE ANYTHING OTHER THAN CODE, NOT EVEN MARKDOWN, JUST PYTHON CODE): + # imports + # Any required fixtures can be defined here + #insert integration test code here + ``` + """, + expected_output="Properly formatted pytest test code for FastAPI", + agent=TestAgents(self.reasoning_client, self.directory).testing_agent(identifier, test_plan), + context=[code_analysis_task], + tools=[CodeTools().get_code, CodeTools().ask_knowledge_graph, CodeTools.get_pydantic_definitions_tool], ) - - -# def data_setup_task(self, identifier, test_plan, get_pydantic_definition_task): -# return Task( -# description=""" You are responsible for creating all the test data necessary for testing the different scenarios included in the test plan for the function. -# 1. This includes creating the necessary pydantic objects for the test data setup. -# 2. Request pydantic class definitions wherever needed using the get pydantic definition tool provided. DO NOT MAKE UP PYDANTIC DEFINITIONS. Call the tool multiple times to get the definitions for all the classes you need. -# 3. Use the get code tool if you require the code under test. -# 4. Use human input if you cannot generate meaningful test data , for example, if you need an s3 bucket name or a user id. -# 5. Create pytest fixtures for test data setup and cleanup where needed. -# 6. Use pytest-mocks to create mocks for the different scenarios included in the test plan. Ask the human for input if you need advice on whether to mock a certain class. -# """, -# # description='based on the test plan and the code for the endpoint to be tested, evaluate what kind of data setup is needed for each test scenario. Use the pydantic definition tool in order to understand the input structure of each endpoint and create relevant pydantic object initialized with relevant datat from it to be used in the next test creation step. ', -# expected_output="Input test data for each test scenario in the form of pydantic objects", -# agent=TestAgents().data_setup_agent(identifier, test_plan), -# context=[get_pydantic_definition_task], -# tools=[CodeTools().get_pydantic_definition, CodeTools().get_code], -# ) - -# def mocking_task(self, identifier, test_plan): -# return Task( -# description="""1. **Define the Mocking Scope**: Begin by identifying which external dependencies or services need to be mocked. This includes third-party APIs, databases, or internal services not under test. Clearly articulate the reasons for mocking these components, focusing on the need to isolate the system under test from external interactions. - -# 2. **Mocking Strategies**: Decide between using `pytest-mock` or `unittest.mock` based on the test framework and specific needs of the test suite. Outline the criteria for this choice, such as compatibility with the testing framework, ease of use, and available features. - -# 3. **Direct Mocking vs. Fixtures**: Evaluate the test scenarios to determine if direct mocking within individual tests or the use of fixtures for shared mocked instances is more appropriate. 
Explain the rationale behind this decision, considering factors like test isolation, reusability of mock objects, and the complexity of setup and teardown processes. - -# 4. **Scenario-Based Mocking**: For each external dependency being mocked, describe the different scenarios to be tested. These scenarios should cover normal operation, error conditions, timeouts, and unexpected responses. Detail the steps to configure the mock objects to simulate these conditions and the expected outcomes of the tests. - -# 5. **Implementing Mocks**: Provide a step-by-step guide for implementing mocks in the test suite. This includes creating mock objects, configuring return values or side effects, and integrating mocks into the tests. Emphasize best practices for ensuring that mocks accurately reflect the behavior of the real dependencies they replace. - -# 6. **Verifying Mock Interactions**: Outline methods for verifying that the system under test interacts with the mocks as expected. This includes checking that the correct methods are called with the expected arguments and that the system under test properly handles the mocked responses. - -# 7. **Cleanup and Teardown**: Finally, detail the process for cleaning up mock objects after tests to prevent side effects on subsequent tests. This might involve resetting mock objects, removing any temporary data, or restoring original states if necessary. -# """, -# expected_output="Mocks for each test scenario in the form of json objects objects", -# agent=TestAgents().mocking_agent(identifier, test_plan), -# tools=[CodeTools().get_code], -# ) - -# def test_task( -# self, identifier, test_plan, fixture_task, data_setup_task, mocking_task -# ): -# return Task( -# description=""" - -# Using Python and the pytest and pytest-mocks package, write a suite of integration tests. Following the test plan above, personalise your tests for the code flow defined for the identifier. You can fetch the code flow using the get code tool if needed. -# 1. Use to the output of the data setup task to utilise the fixtures where possible to avoid duplication of code and easy test data setup and cleanup. -# 2. Detail the rationale behind assertions, including status codes, data integrity, and application state, with examples of complex assertions for enhanced understanding. - -# 3. Ensure tests are designed for independence and parallel execution, discussing patterns for shared setup/teardown while maintaining isolation. Mention tools that support these practices. - -# 4. Address error handling and validation, with strategies for testing expected failures and asserting error messages. -# 5. Use FastAPI features for testing like the TestClient and create Pydantic objects for request body and response validation where required. Pydantic definitions of models can be looked up using the get pydantic definition tool. - -# 5. Respond only with a python code block, do not respond with any other text. Format your output as follows: -# Reply only with complete code, do not reply with any other text, formatted as follows: -# ```python -# # imports -# import pytest # used for our integration tests -# #insert other imports as needed -# #fixtures - -# # integration tests -# #insert integration test code here -# ``` -# 6. Always refer the API path from the code to ensure that you do not get 404 errors. 
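The mocking guidance baked into the new `test_task` above (patch the name where it is called, not where it is defined, and build a throwaway FastAPI app that only includes the router under test) is easier to see in a concrete test. A hedged sketch using pytest and pytest-mock; the `src.pipeline.node_1` module and the `/process` route are hypothetical names echoing the example in the prompt, not code from this repository:

```python
# Hypothetical test illustrating the patch-at-call-site rule from test_task.
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from src.pipeline import node_1  # node_1.py does: from src.utils.helper import func_a


@pytest.fixture
def client():
    app = FastAPI()                      # fresh app, never the production one
    app.include_router(node_1.router)    # include only the router under test
    return TestClient(app)


def test_process_returns_mocked_value(client, mocker):
    # func_a is looked up through src.pipeline.node_1, so that is the patch target.
    mocker.patch("src.pipeline.node_1.func_a", return_value="some_value")

    response = client.get("/process")

    assert response.status_code == 200
    assert response.json() == {"result": "some_value"}
```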
-# """, -# expected_output="Properly formatted pytest test code for FastAPI", -# agent=TestAgents().testing_agent(identifier, test_plan), -# context=[data_setup_task], -# tools=[CodeTools().get_code], -# ) diff --git a/server/test_agent/tools.py b/server/test_agent/tools.py index 1441230..0707665 100644 --- a/server/test_agent/tools.py +++ b/server/test_agent/tools.py @@ -1,69 +1,91 @@ from typing import List from langchain.tools import tool - +from server.utils.github_helper import GithubService from server.parse import ( - get_code_flow_by_id, + get_flow, + get_node_by_id, get_pydantic_class, get_pydantic_classes, ) +from server.utils.graph_db_helper import Neo4jGraph +neo4j_graph = Neo4jGraph() class CodeTools: """ A class that provides code tools for generating code for endpoint identifiers. """ - + # Annotate the function with the tool decorator from LangChain @tool("Get accurate code context for given endpoint identifier") - def get_code(identifier, directory): - """ - Get the code for the specified endpoint identifier. - - Parameters: - - identifier: The identifier of the endpoint. - - directory: The directory of the codebase. + def get_code(identifier, project_id): + """ + Get the code for the specified endpoint identifier. + Parameters: + - identifier: The identifier of the endpoint. + - project_id: The exact project id of the project. + Returns: + - The code for the specified endpoint identifier. + """ + code = "" + nodes = get_flow(identifier, project_id) + for node in nodes: + node = get_node_by_id(node, project_id) + code += ( + "\n" + + node["id"] + + "\n code: \n" + + GithubService.fetch_method_from_repo(node) + ) + return code - Returns: - - The code for the specified endpoint identifier. - """ - return get_code_flow_by_id(identifier, directory) - @tool("Get pydantic class definition for a single class name") - def get_pydantic_definition(classname, directory): - """ + @tool("Get pydantic class definition for a single class name") + def get_pydantic_definition( classname, project_id): + """ Get the pydantic class definition for given class name + Parameters: + - classname: The name of a class. + - project_id: The id of the project. + Returns: + - The pydantic class definition for the specified class name. + """ + return get_pydantic_class(classname, project_id) - Parameters: - - classname: The name of a class. - - directory: The directory of the codebase. - - Returns: - - The code definition for the specified pydantic class. - """ - print("pyd inp: " + classname) - - return get_pydantic_class(classname, directory) - - @tool("Get the pydantic class definition for list of class names") - def get_pydantic_definitions_tool(classnames: List[str], directory): - """ - Get the pydantic class definition for list of class names + @tool("Get the pydantic class definition for list of class names") + def get_pydantic_definitions_tool( classnames: List[str], project_id): + """ + Get the pydantic class definition for list of class names + Parameters: + - classnames: The list of the names of pydantic classes. + - project_id: The id of the project. + Returns: + - The code definitions for the specified pydantic classes. 
+ """ + definitions = "" + try: + definitions = get_pydantic_classes(classnames, project_id) + except Exception as e: + print(f"something went wrong during fetching definition for {classnames}", e) + return definitions + @tool("Query the code knowledge graph with specific directed questions using natural language and project id and return the query result") + def ask_knowledge_graph(query: str, project_id) -> str: + """ + Query the code knowledge graph using natural language questions. + DO NOT USE THIS TOOL TO QUERY CODE DIRECTLY. USE GET CODE TOOL TO FETCH CODE + The knowledge graph contains information from various database tables including: + - inference: key-value pairs of inferred knowledge about APIs and their constituting functions + - endpoints: API endpoint paths and identifiers + - explanation: code explanations for function identifiers + - pydantic: pydantic class definitions Parameters: - - classnames: The list of the names of pydantic classes. - - directory: The directory of the codebase. - + - query: A natural language question to ask the knowledge graph.\ + - project_id: The project id metadata for the project bein evaluated Returns: - - The code definitions for the specified pydantic classes. - """ - definitions = "" - try: - definitions = get_pydantic_classes(classnames, directory) - except Exception as e: - print( - "something went wrong during fetching definition for" - f" {classnames}", - e, - ) - return definitions + - The result of querying the knowledge graph with the provided question. + """ + from server.knowledge_graph.knowledge_graph import KnowledgeGraph + + return KnowledgeGraph(project_id).query(query, project_id) \ No newline at end of file diff --git a/server/utils/ai_helper.py b/server/utils/ai_helper.py index 3e8816d..4b16b7e 100644 --- a/server/utils/ai_helper.py +++ b/server/utils/ai_helper.py @@ -2,6 +2,7 @@ from langchain_openai.chat_models import ChatOpenAI from portkey_ai import createHeaders, PORTKEY_GATEWAY_URL +import asyncio color_prefix_by_role = { "system": "\033[0m", # gray @@ -11,9 +12,9 @@ "assistant": "\033[92m", # green } - def get_llm_client(user_id, model_name): - return create_client("openai", os.environ.get("OPENAI_API_KEY"), model_name, user_id) + provider_key = os.getenv("OPENAI_API_KEY") + return create_client("openai", provider_key, model_name, user_id) def create_client(provider, key, model_name, user_id): if provider == "openai": @@ -25,9 +26,8 @@ def create_client(provider, key, model_name, user_id): return ChatOpenAI(api_key=PROVIDER_API_KEY, model=model_name, base_url=PORTKEY_GATEWAY_URL, default_headers=portkey_headers) - -def llm_call(client, messages, print_text=True, temperature=0.4): - response = client(messages=messages, temperature=temperature) +async def llm_call(client, messages, print_text=True, temperature=0.4): + response = await asyncio.to_thread(client, messages=messages, temperature=temperature) if print_text: print_message_delta(response) return response diff --git a/server/utils/test_detail_handler.py b/server/utils/test_detail_handler.py index 9a1f737..7badc7e 100644 --- a/server/utils/test_detail_handler.py +++ b/server/utils/test_detail_handler.py @@ -1,7 +1,7 @@ import datetime -from server.schema.user_subscription_detail import UserSubscriptionDetail -from server.schema.user_test_details import UserTestDetail +from server.schemas.user_subscription_detail import UserSubscriptionDetail +from server.schemas.user_test_details import UserTestDetail class UserTestDetailsManager: @@ -9,30 +9,10 @@ def 
diff --git a/server/utils/test_detail_handler.py b/server/utils/test_detail_handler.py
index 9a1f737..7badc7e 100644
--- a/server/utils/test_detail_handler.py
+++ b/server/utils/test_detail_handler.py
@@ -1,7 +1,7 @@
 import datetime
 
-from server.schema.user_subscription_detail import UserSubscriptionDetail
-from server.schema.user_test_details import UserTestDetail
+from server.schemas.user_subscription_detail import UserSubscriptionDetail
+from server.schemas.user_test_details import UserTestDetail
 
 
 class UserTestDetailsManager:
@@ -9,30 +9,10 @@ def __init__(self):
         pass
 
     def send_user_test_details(self, project_id: str, user_id: str, number_of_tests_generated: int, repo_name: str, branch_name: str):
-        try:
-            user_test_detail = UserTestDetail(
-                project_id=project_id,
-                user_id=user_id,
-                number_of_tests_generated=number_of_tests_generated,
-                date_of_generation=datetime.datetime.utcnow(),
-                repo_name=repo_name,
-                branch_name=branch_name
-            )
-            user_test_detail.save()
-            print("Data successfully added to MongoDB.")
-        except Exception as e:
-            print(f"An error occurred: {e}")
+        return
 
     def get_test_count_last_month(self, user_id: str) -> int:
-        try:
-            one_month_ago = datetime.datetime.utcnow() - datetime.timedelta(days=30)
-            docs = UserTestDetail.objects(user_id=user_id, date_of_generation__gte=one_month_ago)
-            total_tests = sum(doc.number_of_tests_generated for doc in docs)
-            return total_tests
-        except Exception as e:
-            print(f"An error occurred: {e}")
-            return 0
+        return 0
 
     def is_pro_plan(self, user_id):
-        return False
-
+        return False
\ No newline at end of file