Skip to content

Commit

Permalink
Merge pull request #8 from getmomentum/knowledege_graph
Browse files Browse the repository at this point in the history
Async functions, Added support for knowledge graph + minor refactors
  • Loading branch information
vineetshar authored Jul 1, 2024
2 parents 59762bd + f9b409c commit 4d765c2
Show file tree
Hide file tree
Showing 29 changed files with 618 additions and 480 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ venv/
.venv
.momentum
/server/.hypothesis
.hypothesis
.hypothesis
db
13 changes: 10 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ langchain
langchain-community
langchain-openai
langchain-experimental
crewai==0.19.0
crewai
PyGithub==2.3.0
psycopg2-binary==2.9.6
firebase-admin==6.5.0
neo4j==5.2
google-cloud-secret-manager
posthog==1.4.0
posthog==3.5.0
starlette==0.35.0
loguru
python-dotenv
Expand All @@ -26,4 +26,11 @@ PyJWT
setuptools
portkey_ai
gunicorn
sentry-sdk[fastapi]
sentry-sdk[fastapi]
pydantic==2.7.4
pydantic_core==2.18.4
psycopg
embedchain==0.1.103
kombu==5.4.0rc1
embedchain[weaviate]
celery[redis]
5 changes: 2 additions & 3 deletions server/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dotenv import load_dotenv
from sqlalchemy import engine_from_config, pool

from server.schema.base import Base
from server.schemas.base import Base

CURRENT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(CURRENT_DIR)
Expand All @@ -26,8 +26,7 @@
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Base.metadata
# escaped_url = os.environ["POSTGRES_SERVER"].replace("%", "%%")
escaped_url = "postgresql://postgres:mysecretpassword@localhost:5432/momentum"
escaped_url = os.environ["POSTGRES_SERVER"].replace("%", "%%")
config.set_main_option("sqlalchemy.url", escaped_url)
# other values from the config, defined by the needs of env.py,
# can be acquired:
Expand Down
1 change: 0 additions & 1 deletion server/alembic/versions/6c007877a09d_jmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from server.schema.projects import ProjectStatusEnum

# revision identifiers, used by Alembic.
revision: str = '6c007877a09d'
Expand Down
1 change: 0 additions & 1 deletion server/change_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from tree_sitter_languages import get_language, get_parser
import os
import argparse

# Load Python grammar for Tree-sitter
parser = get_parser("python")
Expand Down
8 changes: 4 additions & 4 deletions server/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def add_codebase_map_path(self, directory):
return f"{directory}/{db_path}"


def dependencies_from_function(self, project_details, function_identifier: str,
async def dependencies_from_function(self, project_details, function_identifier: str,
function_to_test: str,
flow: list,
print_text: bool = True, # optionally prints text; helpful for understanding the function & debugging
Expand Down Expand Up @@ -73,17 +73,17 @@ def dependencies_from_function(self, project_details, function_identifier: str,
print_messages(detect_messages)


explanation = llm_call(self.detect_client, detect_messages)
explanation = await llm_call(self.detect_client, detect_messages)
return [x.strip() for x in explanation.content.split(",") if x.strip() != ""]


def get_dependencies(self, project_details, function_identifier):
async def get_dependencies(self, project_details, function_identifier):
flow = get_flow(function_identifier, project_details[2])
flow_trimmed = [x.split(':')[1] for x in flow if x != function_identifier]
output = []
for function in flow:
node = get_node(function, project_details)
code = GithubService.fetch_method_from_repo(node)
output += ( self.dependencies_from_function(project_details, function, code, flow_trimmed))
output += ( await self.dependencies_from_function(project_details, function, code, flow_trimmed))
return output+flow_trimmed

43 changes: 8 additions & 35 deletions server/endpoint_detection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
import os
import re
Expand Down Expand Up @@ -32,39 +33,8 @@ def __init__(
self.router_prefix_file_mapping = router_prefix_file_mapping
self.file_index = file_index

# SQLite database setup

def setup_database(self):
conn = psycopg2.connect(os.getenv("POSTGRES_SERVER"))
cursor = conn.cursor()
try:
# Create table if it doesn't exist
cursor.execute("""
CREATE TABLE IF NOT EXISTS endpoints (
path TEXT,
identifier TEXT,
test_plan TEXT,
preferences TEXT,
project_id integer NOT NULL,
PRIMARY KEY (project_id, identifier),
CONSTRAINT fk_project FOREIGN KEY (project_id)
REFERENCES projects (id)
ON DELETE CASCADE
)
""")

# Commit the transaction
conn.commit()
except psycopg2.Error as e:
print(f"PostgreSQL error: {e}")
finally:
# Close the connection
if conn:
cursor.close()
conn.close()



def extract_path(self, decorator):
# Find the position of the first opening parenthesis and the following comma
start = decorator.find("(") + 1
Expand Down Expand Up @@ -472,8 +442,8 @@ def extract_function_metadata(self, node):

return function_name, parameters, start, end, text

def analyse_endpoints(self, project_id):
self.setup_database()
async def analyse_endpoints(self, project_id, user_id):

conn = psycopg2.connect(os.getenv("POSTGRES_SERVER"))
cursor = conn.cursor()
detected_endpoints = []
Expand Down Expand Up @@ -512,6 +482,9 @@ def analyse_endpoints(self, project_id):
)

conn.close()
from server.knowledge_graph.flow import understand_flows

asyncio.create_task(understand_flows(project_id, self.directory, user_id))

def get_qualified_endpoint_name(self, path, prefix):
if prefix == None:
Expand Down Expand Up @@ -686,7 +659,7 @@ def get_test_plan_preferences(self, identifier, project_id):
else:
test_plan = None # No test plan found for the given identifier

if row[1]:
if row and row[1]:
preferences = json.loads(
row[1]
) # Deserialize the test plan back into a Python dictionary
Expand Down
2 changes: 1 addition & 1 deletion server/handler/user_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from server.db.session import SessionManager
from server.models.user import CreateUser
from server.schema.users import User
from server.schemas.users import User


class UserHandler:
Expand Down
Loading

0 comments on commit 4d765c2

Please # to comment.