From 609a8eb129d90420e0f6e0e542010e90ad9fe016 Mon Sep 17 00:00:00 2001 From: dhirenmathur Date: Wed, 18 Dec 2024 12:24:47 +0530 Subject: [PATCH 1/6] Q&A streaminh --- .../agents/agents/callback_handler.py | 111 ++++++++++++++++ .../intelligence/agents/agents/rag_agent.py | 38 +++++- .../agents/chat_agents/qna_chat_agent.py | 123 +++++++++++------- 3 files changed, 221 insertions(+), 51 deletions(-) create mode 100644 app/modules/intelligence/agents/agents/callback_handler.py diff --git a/app/modules/intelligence/agents/agents/callback_handler.py b/app/modules/intelligence/agents/agents/callback_handler.py new file mode 100644 index 00000000..94669aa7 --- /dev/null +++ b/app/modules/intelligence/agents/agents/callback_handler.py @@ -0,0 +1,111 @@ +import os +from datetime import datetime +import json +from typing import Any, Dict, List, Optional, Tuple, Union +from crewai.agents.parser import AgentAction + +class FileCallbackHandler: + def __init__(self, filename: str = "agent_execution_log.md"): + """Initialize the file callback handler. + + Args: + filename (str): The markdown file to write the logs to + """ + self.filename = filename + # Create or clear the file initially + with open(self.filename, 'w', encoding='utf-8') as f: + f.write(f"# Agent Execution Log\nStarted at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n") + + def __call__(self, step_output: Union[str, List[Tuple[Dict[str, Any], str]], AgentAction]) -> None: + """Callback function to handle agent execution steps. + + Args: + step_output: Output from the agent's execution step. Can be: + - string + - list of (action, observation) tuples + - AgentAction from CrewAI + """ + with open(self.filename, 'a', encoding='utf-8') as f: + f.write(f"\n## Step - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write("---\n") + + # Handle AgentAction output + if isinstance(step_output, AgentAction): + # Write thought section + if hasattr(step_output, 'thought') and step_output.thought: + f.write("### Thought\n") + f.write(f"{step_output.thought}\n\n") + + # Write tool section + if hasattr(step_output, 'tool'): + f.write("### Action\n") + f.write(f"**Tool:** {step_output.tool}\n") + + # if hasattr(step_output, 'tool_input'): + # try: + # # Try to parse and pretty print JSON input + # tool_input = json.loads(step_output.tool_input) + # formatted_input = json.dumps(tool_input, indent=2) + # f.write(f"**Input:**\n```json\n{formatted_input}\n```\n") + # except (json.JSONDecodeError, TypeError): + # # Fallback to raw string if not JSON + # f.write(f"**Input:**\n```\n{step_output.tool_input}\n```\n") + + # # Write result section + # if hasattr(step_output, 'result'): + # f.write("\n### Result\n") + # try: + # # Try to parse and pretty print JSON result + # result = json.loads(step_output.result) + # formatted_result = json.dumps(result, indent=2) + # f.write(f"```json\n{formatted_result}\n```\n") + # except (json.JSONDecodeError, TypeError): + # # Fallback to raw string if not JSON + # f.write(f"```\n{step_output.result}\n```\n") + + f.write("\n") + return + + # Handle single string output + if isinstance(step_output, str): + f.write(step_output + "\n") + return + + for step in step_output: + if not isinstance(step, tuple): + f.write(str(step) + "\n") + continue + + action, observation = step + + # Handle action section + f.write("### Action\n") + if isinstance(action, dict): + if "tool" in action: + f.write(f"**Tool:** {action['tool']}\n") + if "tool_input" in action: + f.write(f"**Input:**\n```\n{action['tool_input']}\n```\n") + if "log" in action: + f.write(f"**Log:** {action['log']}\n") + if "Action" in action: + f.write(f"**Action Type:** {action['Action']}\n") + else: + f.write(f"{str(action)}\n") + + # Handle observation section + f.write("\n### Observation\n") + if isinstance(observation, str): + # Handle special formatting for search-like results + lines = observation.split('\n') + for line in lines: + if line.startswith(('Title:', 'Link:', 'Snippet:')): + key, value = line.split(':', 1) + f.write(f"**{key.strip()}:**{value}\n") + elif line.startswith('-'): + f.write(line + "\n") + else: + f.write(line + "\n") + else: + f.write(str(observation) + "\n") + + f.write("\n") \ No newline at end of file diff --git a/app/modules/intelligence/agents/agents/rag_agent.py b/app/modules/intelligence/agents/agents/rag_agent.py index b33a7e1d..09f8cbae 100644 --- a/app/modules/intelligence/agents/agents/rag_agent.py +++ b/app/modules/intelligence/agents/agents/rag_agent.py @@ -1,12 +1,16 @@ +import asyncio import os -from typing import Any, Dict, List +from typing import Any, AsyncGenerator, Dict, List +import aiofiles +from contextlib import redirect_stdout import agentops from crewai import Agent, Crew, Process, Task from pydantic import BaseModel, Field from app.modules.code_provider.code_provider_service import CodeProviderService from app.modules.conversations.message.message_schema import NodeContext +from app.modules.intelligence.agents.agents.callback_handler import FileCallbackHandler from app.modules.intelligence.provider.provider_service import ( AgentType, ProviderService, @@ -71,6 +75,7 @@ def __init__(self, sql_db, llm, mini_llm, user_id): self.llm = llm self.mini_llm = mini_llm self.user_id = user_id + #self.callback_handler = FileCallbackHandler("rag_agent_execution.md") async def create_agents(self): query_agent = Agent( @@ -101,6 +106,7 @@ async def create_agents(self): verbose=True, llm=self.llm, max_iter=self.max_iter, + #step_callback=self.callback_handler, ) return query_agent @@ -267,7 +273,7 @@ async def kickoff_rag_agent( llm, mini_llm, user_id: str, -) -> str: +) -> AsyncGenerator[str, None]: provider_service = ProviderService(sql_db, user_id) crew_ai_llm = provider_service.get_large_llm(agent_type=AgentType.CREWAI) crew_ai_mini_llm = provider_service.get_small_llm(agent_type=AgentType.CREWAI) @@ -275,7 +281,31 @@ async def kickoff_rag_agent( file_structure = await CodeProviderService(sql_db).get_project_structure_async( project_id ) - result = await rag_agent.run( + + + read_fd, write_fd = os.pipe() + + async def kickoff(): + with os.fdopen(write_fd, "w", buffering=1) as write_file: + with redirect_stdout(write_file): + await rag_agent.run( query, project_id, chat_history, node_ids, file_structure ) - return result + + + asyncio.create_task(kickoff()) + + # Yield CrewAgent logs as they are written to the pipe + final_answer_streaming = False + async with aiofiles.open(read_fd, mode='r') as read_file: + async for line in read_file: + if not line: + break + else: + if final_answer_streaming: + if line.endswith('\\x1b[00m\\n'): + yield line[:-6] + else: + yield line + if "## Final Answer:" in line: + final_answer_streaming = True \ No newline at end of file diff --git a/app/modules/intelligence/agents/chat_agents/qna_chat_agent.py b/app/modules/intelligence/agents/chat_agents/qna_chat_agent.py index c29a1c37..d73740b5 100644 --- a/app/modules/intelligence/agents/chat_agents/qna_chat_agent.py +++ b/app/modules/intelligence/agents/chat_agents/qna_chat_agent.py @@ -2,8 +2,14 @@ import logging import time from functools import lru_cache -from typing import AsyncGenerator, Dict, List +from typing import Any, AsyncGenerator, Dict, List, Optional +from typing import Annotated +from langgraph.types import StreamWriter +from typing_extensions import TypedDict + +from langgraph.graph import StateGraph, START, END +from langgraph.graph.message import add_messages from langchain.schema import HumanMessage, SystemMessage from langchain_core.output_parsers import PydanticOutputParser from langchain_core.prompts import ( @@ -81,7 +87,57 @@ async def _classify_query(self, query: str, history: List[HumanMessage]): return response.classification - async def run( + + + class State(TypedDict): + query: str + project_id: str + user_id: str + conversation_id: str + node_ids: List[NodeContext] + + + + async def _stream_rag_agent(self, state: State, writer: StreamWriter): + async for chunk in self.execute( + state["query"], + state["project_id"], + state["user_id"], + state["conversation_id"], + state["node_ids"], + ): + + writer(chunk) + + def _create_graph(self): + graph_builder = StateGraph(QNAChatAgent.State) + + graph_builder.add_node( + "rag_agent", + self._stream_rag_agent, + ) + graph_builder.add_edge(START, "rag_agent") + graph_builder.add_edge("rag_agent", END) + return graph_builder.compile() + + async def run(self, + query: str, + project_id: str, + user_id: str, + conversation_id: str, + node_ids: List[NodeContext],): + state = { + "query": query, + "project_id": project_id, + "user_id": user_id, + "conversation_id": conversation_id, + "node_ids": node_ids, + } + graph = self._create_graph() + async for chunk in graph.astream(state,stream_mode="custom"): + yield chunk + + async def execute( self, query: str, project_id: str, @@ -117,8 +173,7 @@ async def run( tool_results = [] citations = [] if classification == ClassificationResult.AGENT_REQUIRED: - rag_start_time = time.time() # Start timer for RAG agent - rag_result = await kickoff_rag_agent( + async for chunk in kickoff_rag_agent( query, project_id, [ @@ -131,54 +186,28 @@ async def run( self.llm, self.mini_llm, user_id, - ) - rag_duration = time.time() - rag_start_time # Calculate duration - logger.info( - f"Time elapsed since entering run: {time.time() - start_time:.2f}s, " - f"Duration of RAG agent: {rag_duration:.2f}s" - ) + ): + content = str(chunk) - if rag_result.pydantic: - citations = rag_result.pydantic.citations - response = rag_result.pydantic.response - result = [node for node in response] - else: - citations = [] - result = rag_result.raw - tool_results = [SystemMessage(content=result)] - # Timing for adding message chunk - add_chunk_start_time = ( - time.time() - ) # Start timer for adding message chunk - self.history_manager.add_message_chunk( - conversation_id, - tool_results[0].content, - MessageType.AI_GENERATED, - citations=citations, - ) - add_chunk_duration = ( - time.time() - add_chunk_start_time - ) # Calculate duration - logger.info( - f"Time elapsed since entering run: {time.time() - start_time:.2f}s, " - f"Duration of adding message chunk: {add_chunk_duration:.2f}s" - ) + self.history_manager.add_message_chunk( + conversation_id, + content, + MessageType.AI_GENERATED, + citations=citations, + ) + + yield json.dumps( + { + "citations": citations, + "message": content, + } + ) + - # Timing for flushing message buffer - flush_buffer_start_time = ( - time.time() - ) # Start timer for flushing message buffer self.history_manager.flush_message_buffer( conversation_id, MessageType.AI_GENERATED ) - flush_buffer_duration = ( - time.time() - flush_buffer_start_time - ) # Calculate duration - logger.info( - f"Time elapsed since entering run: {time.time() - start_time:.2f}s, " - f"Duration of flushing message buffer: {flush_buffer_duration:.2f}s" - ) - yield json.dumps({"citations": citations, "message": result}) + if classification != ClassificationResult.AGENT_REQUIRED: inputs = { From 1e41250e5f601d29ff0b398aaf440e2f288b4adf Mon Sep 17 00:00:00 2001 From: dhirenmathur Date: Wed, 18 Dec 2024 16:06:13 +0530 Subject: [PATCH 2/6] streaming for all agents --- .../agents/agents/callback_handler.py | 60 +++++----- .../agents/agents/code_gen_agent.py | 52 +++++--- .../agents/agents/debug_rag_agent.py | 59 ++++++--- .../agents/agents/low_level_design_agent.py | 50 +++++--- .../intelligence/agents/agents/rag_agent.py | 23 ++-- .../agents/chat_agents/code_gen_chat_agent.py | 103 +++++++++------- .../chat_agents/debugging_chat_agent.py | 113 +++++++++++------- .../agents/chat_agents/lld_chat_agent.py | 113 ++++++++++-------- .../agents/chat_agents/qna_chat_agent.py | 28 ++--- .../graph_construction/code_graph_service.py | 1 - .../graph_construction/parsing_repomap.py | 2 +- 11 files changed, 356 insertions(+), 248 deletions(-) diff --git a/app/modules/intelligence/agents/agents/callback_handler.py b/app/modules/intelligence/agents/agents/callback_handler.py index 94669aa7..bfd6223a 100644 --- a/app/modules/intelligence/agents/agents/callback_handler.py +++ b/app/modules/intelligence/agents/agents/callback_handler.py @@ -1,46 +1,50 @@ -import os from datetime import datetime -import json -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Tuple, Union + from crewai.agents.parser import AgentAction + class FileCallbackHandler: def __init__(self, filename: str = "agent_execution_log.md"): """Initialize the file callback handler. - + Args: filename (str): The markdown file to write the logs to """ self.filename = filename # Create or clear the file initially - with open(self.filename, 'w', encoding='utf-8') as f: - f.write(f"# Agent Execution Log\nStarted at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n") - - def __call__(self, step_output: Union[str, List[Tuple[Dict[str, Any], str]], AgentAction]) -> None: + with open(self.filename, "w", encoding="utf-8") as f: + f.write( + f"# Agent Execution Log\nStarted at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n" + ) + + def __call__( + self, step_output: Union[str, List[Tuple[Dict[str, Any], str]], AgentAction] + ) -> None: """Callback function to handle agent execution steps. - + Args: step_output: Output from the agent's execution step. Can be: - string - list of (action, observation) tuples - AgentAction from CrewAI """ - with open(self.filename, 'a', encoding='utf-8') as f: + with open(self.filename, "a", encoding="utf-8") as f: f.write(f"\n## Step - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") f.write("---\n") - + # Handle AgentAction output if isinstance(step_output, AgentAction): # Write thought section - if hasattr(step_output, 'thought') and step_output.thought: + if hasattr(step_output, "thought") and step_output.thought: f.write("### Thought\n") f.write(f"{step_output.thought}\n\n") - + # Write tool section - if hasattr(step_output, 'tool'): + if hasattr(step_output, "tool"): f.write("### Action\n") f.write(f"**Tool:** {step_output.tool}\n") - + # if hasattr(step_output, 'tool_input'): # try: # # Try to parse and pretty print JSON input @@ -50,7 +54,7 @@ def __call__(self, step_output: Union[str, List[Tuple[Dict[str, Any], str]], Age # except (json.JSONDecodeError, TypeError): # # Fallback to raw string if not JSON # f.write(f"**Input:**\n```\n{step_output.tool_input}\n```\n") - + # # Write result section # if hasattr(step_output, 'result'): # f.write("\n### Result\n") @@ -62,22 +66,22 @@ def __call__(self, step_output: Union[str, List[Tuple[Dict[str, Any], str]], Age # except (json.JSONDecodeError, TypeError): # # Fallback to raw string if not JSON # f.write(f"```\n{step_output.result}\n```\n") - + f.write("\n") return - + # Handle single string output if isinstance(step_output, str): f.write(step_output + "\n") return - + for step in step_output: if not isinstance(step, tuple): f.write(str(step) + "\n") continue - + action, observation = step - + # Handle action section f.write("### Action\n") if isinstance(action, dict): @@ -91,21 +95,21 @@ def __call__(self, step_output: Union[str, List[Tuple[Dict[str, Any], str]], Age f.write(f"**Action Type:** {action['Action']}\n") else: f.write(f"{str(action)}\n") - + # Handle observation section f.write("\n### Observation\n") if isinstance(observation, str): # Handle special formatting for search-like results - lines = observation.split('\n') + lines = observation.split("\n") for line in lines: - if line.startswith(('Title:', 'Link:', 'Snippet:')): - key, value = line.split(':', 1) + if line.startswith(("Title:", "Link:", "Snippet:")): + key, value = line.split(":", 1) f.write(f"**{key.strip()}:**{value}\n") - elif line.startswith('-'): + elif line.startswith("-"): f.write(line + "\n") else: f.write(line + "\n") else: f.write(str(observation) + "\n") - - f.write("\n") \ No newline at end of file + + f.write("\n") diff --git a/app/modules/intelligence/agents/agents/code_gen_agent.py b/app/modules/intelligence/agents/agents/code_gen_agent.py index c543c38f..59e58440 100644 --- a/app/modules/intelligence/agents/agents/code_gen_agent.py +++ b/app/modules/intelligence/agents/agents/code_gen_agent.py @@ -1,7 +1,9 @@ +import asyncio import os -from typing import Any, Dict, List +from contextlib import redirect_stdout +from typing import Any, AsyncGenerator, Dict, List -import agentops +import aiofiles from crewai import Agent, Crew, Process, Task from app.modules.conversations.message.message_schema import NodeContext @@ -261,7 +263,7 @@ async def run( project_id: str, history: str, node_ids: List[NodeContext], - ) -> str: + ) -> AsyncGenerator[str, None]: code_results = [] if len(node_ids) > 0: code_results = await GetCodeFromMultipleNodeIdsTool( @@ -278,16 +280,34 @@ async def run( code_generator, ) - crew = Crew( - agents=[code_generator], - tasks=[generation_task], - process=Process.sequential, - verbose=False, - ) - agentops.init(os.getenv("AGENTOPS_API_KEY")) - result = await crew.kickoff_async() - agentops.end_session("Success") - return result + read_fd, write_fd = os.pipe() + + async def kickoff(): + with os.fdopen(write_fd, "w", buffering=1) as write_file: + with redirect_stdout(write_file): + crew = Crew( + agents=[code_generator], + tasks=[generation_task], + process=Process.sequential, + verbose=True, + ) + await crew.kickoff_async() + + asyncio.create_task(kickoff()) + + # Stream the output + final_answer_streaming = False + async with aiofiles.open(read_fd, mode="r") as read_file: + async for line in read_file: + if not line: + break + if final_answer_streaming: + if line.endswith("\\x1b[00m\\n"): + yield line[:-6] + else: + yield line + if "## Final Answer:" in line: + final_answer_streaming = True async def kickoff_code_generation_crew( @@ -299,10 +319,10 @@ async def kickoff_code_generation_crew( llm, mini_llm, user_id: str, -) -> str: +) -> AsyncGenerator[str, None]: provider_service = ProviderService(sql_db, user_id) crew_ai_mini_llm = provider_service.get_small_llm(agent_type=AgentType.CREWAI) crew_ai_llm = provider_service.get_large_llm(agent_type=AgentType.CREWAI) code_gen_agent = CodeGenerationAgent(sql_db, crew_ai_llm, crew_ai_mini_llm, user_id) - result = await code_gen_agent.run(query, project_id, history, node_ids) - return result + async for chunk in code_gen_agent.run(query, project_id, history, node_ids): + yield chunk diff --git a/app/modules/intelligence/agents/agents/debug_rag_agent.py b/app/modules/intelligence/agents/agents/debug_rag_agent.py index 030924dc..e5cb6d0d 100644 --- a/app/modules/intelligence/agents/agents/debug_rag_agent.py +++ b/app/modules/intelligence/agents/agents/debug_rag_agent.py @@ -1,7 +1,10 @@ +import asyncio import os -from typing import Any, Dict, List +from contextlib import redirect_stdout +from typing import Any, AsyncGenerator, Dict, List import agentops +import aiofiles from crewai import Agent, Crew, Process, Task from pydantic import BaseModel, Field @@ -227,7 +230,7 @@ async def run( chat_history: List, node_ids: List[NodeContext], file_structure: str, - ) -> str: + ) -> AsyncGenerator[str, None]: agentops.init( os.getenv("AGENTOPS_API_KEY"), default_tags=["openai-gpt-notebook"] ) @@ -236,27 +239,47 @@ async def run( code_results = await GetCodeFromMultipleNodeIdsTool( self.sql_db, self.user_id ).run_multiple(project_id, [node.node_id for node in node_ids]) - query_agent = await self.create_agents() - query_task = await self.create_tasks( + debug_agent = await self.create_agents() + debug_task = await self.create_tasks( query, project_id, chat_history, node_ids, file_structure, code_results, - query_agent, + debug_agent, ) - crew = Crew( - agents=[query_agent], - tasks=[query_task], - process=Process.sequential, - verbose=False, - ) + read_fd, write_fd = os.pipe() + + async def kickoff(): + with os.fdopen(write_fd, "w", buffering=1) as write_file: + with redirect_stdout(write_file): + crew = Crew( + agents=[debug_agent], + tasks=[debug_task], + process=Process.sequential, + verbose=True, + ) + await crew.kickoff_async() - result = await crew.kickoff_async() agentops.end_session("Success") - return result + + asyncio.create_task(kickoff()) + + # Stream the output + final_answer_streaming = False + async with aiofiles.open(read_fd, mode="r") as read_file: + async for line in read_file: + if not line: + break + if final_answer_streaming: + if line.endswith("\\x1b[00m\\n"): + yield line[:-6] + else: + yield line + if "## Final Answer:" in line: + final_answer_streaming = True async def kickoff_debug_rag_agent( @@ -268,15 +291,15 @@ async def kickoff_debug_rag_agent( llm, mini_llm, user_id: str, -) -> str: +) -> AsyncGenerator[str, None]: provider_service = ProviderService(sql_db, user_id) - crew_ai_mini_llm = provider_service.get_small_llm(agent_type=AgentType.CREWAI) crew_ai_llm = provider_service.get_large_llm(agent_type=AgentType.CREWAI) + crew_ai_mini_llm = provider_service.get_small_llm(agent_type=AgentType.CREWAI) debug_agent = DebugRAGAgent(sql_db, crew_ai_llm, crew_ai_mini_llm, user_id) file_structure = await CodeProviderService(sql_db).get_project_structure_async( project_id ) - result = await debug_agent.run( + async for chunk in debug_agent.run( query, project_id, chat_history, node_ids, file_structure - ) - return result + ): + yield chunk diff --git a/app/modules/intelligence/agents/agents/low_level_design_agent.py b/app/modules/intelligence/agents/agents/low_level_design_agent.py index 002cffab..ec6cfe15 100644 --- a/app/modules/intelligence/agents/agents/low_level_design_agent.py +++ b/app/modules/intelligence/agents/agents/low_level_design_agent.py @@ -1,6 +1,9 @@ +import asyncio import os -from typing import Dict, List +from contextlib import redirect_stdout +from typing import AsyncGenerator, Dict, List +import aiofiles from crewai import Agent, Crew, Process, Task from pydantic import BaseModel, Field @@ -165,21 +168,40 @@ async def create_tasks( async def run( self, functional_requirements: str, project_id: str - ) -> LowLevelDesignPlan: + ) -> AsyncGenerator[str, None]: codebase_analyst, design_planner = await self.create_agents() tasks = await self.create_tasks( functional_requirements, project_id, codebase_analyst, design_planner ) - crew = Crew( - agents=[codebase_analyst, design_planner], - tasks=tasks, - process=Process.sequential, - verbose=True, - ) - - result = await crew.kickoff_async() - return result + read_fd, write_fd = os.pipe() + + async def kickoff(): + with os.fdopen(write_fd, "w", buffering=1) as write_file: + with redirect_stdout(write_file): + crew = Crew( + agents=[codebase_analyst, design_planner], + tasks=tasks, + process=Process.sequential, + verbose=True, + ) + await crew.kickoff_async() + + asyncio.create_task(kickoff()) + + # Stream the output + final_answer_streaming = False + async with aiofiles.open(read_fd, mode="r") as read_file: + async for line in read_file: + if not line: + break + if final_answer_streaming: + if line.endswith("\\x1b[00m\\n"): + yield line[:-6] + else: + yield line + if "## Final Answer:" in line: + final_answer_streaming = True async def create_low_level_design_agent( @@ -188,9 +210,9 @@ async def create_low_level_design_agent( sql_db, llm, user_id: str, -) -> LowLevelDesignPlan: +) -> AsyncGenerator[str, None]: provider_service = ProviderService(sql_db, user_id) crew_ai_llm = provider_service.get_large_llm(agent_type=AgentType.CREWAI) design_agent = LowLevelDesignAgent(sql_db, crew_ai_llm, user_id) - result = await design_agent.run(functional_requirements, project_id) - return result + async for chunk in design_agent.run(functional_requirements, project_id): + yield chunk diff --git a/app/modules/intelligence/agents/agents/rag_agent.py b/app/modules/intelligence/agents/agents/rag_agent.py index 09f8cbae..7919fb78 100644 --- a/app/modules/intelligence/agents/agents/rag_agent.py +++ b/app/modules/intelligence/agents/agents/rag_agent.py @@ -1,16 +1,15 @@ import asyncio import os +from contextlib import redirect_stdout from typing import Any, AsyncGenerator, Dict, List -import aiofiles -from contextlib import redirect_stdout import agentops +import aiofiles from crewai import Agent, Crew, Process, Task from pydantic import BaseModel, Field from app.modules.code_provider.code_provider_service import CodeProviderService from app.modules.conversations.message.message_schema import NodeContext -from app.modules.intelligence.agents.agents.callback_handler import FileCallbackHandler from app.modules.intelligence.provider.provider_service import ( AgentType, ProviderService, @@ -75,7 +74,7 @@ def __init__(self, sql_db, llm, mini_llm, user_id): self.llm = llm self.mini_llm = mini_llm self.user_id = user_id - #self.callback_handler = FileCallbackHandler("rag_agent_execution.md") + # self.callback_handler = FileCallbackHandler("rag_agent_execution.md") async def create_agents(self): query_agent = Agent( @@ -106,7 +105,7 @@ async def create_agents(self): verbose=True, llm=self.llm, max_iter=self.max_iter, - #step_callback=self.callback_handler, + # step_callback=self.callback_handler, ) return query_agent @@ -282,30 +281,28 @@ async def kickoff_rag_agent( project_id ) - read_fd, write_fd = os.pipe() async def kickoff(): with os.fdopen(write_fd, "w", buffering=1) as write_file: with redirect_stdout(write_file): - await rag_agent.run( - query, project_id, chat_history, node_ids, file_structure - ) - + await rag_agent.run( + query, project_id, chat_history, node_ids, file_structure + ) asyncio.create_task(kickoff()) # Yield CrewAgent logs as they are written to the pipe final_answer_streaming = False - async with aiofiles.open(read_fd, mode='r') as read_file: + async with aiofiles.open(read_fd, mode="r") as read_file: async for line in read_file: if not line: break else: if final_answer_streaming: - if line.endswith('\\x1b[00m\\n'): + if line.endswith("\\x1b[00m\\n"): yield line[:-6] else: yield line if "## Final Answer:" in line: - final_answer_streaming = True \ No newline at end of file + final_answer_streaming = True diff --git a/app/modules/intelligence/agents/chat_agents/code_gen_chat_agent.py b/app/modules/intelligence/agents/chat_agents/code_gen_chat_agent.py index a5cdb2c3..68527ba4 100644 --- a/app/modules/intelligence/agents/chat_agents/code_gen_chat_agent.py +++ b/app/modules/intelligence/agents/chat_agents/code_gen_chat_agent.py @@ -1,7 +1,10 @@ import json import logging import time -from typing import AsyncGenerator, List +from typing import AsyncGenerator, Dict, List, Annotated +from langgraph.types import StreamWriter +from typing_extensions import TypedDict +from langgraph.graph import StateGraph, START, END from langchain.schema import HumanMessage, SystemMessage from sqlalchemy.orm import Session @@ -28,6 +31,30 @@ def __init__(self, mini_llm, llm, db: Session): self.chain = None self.db = db + class State(TypedDict): + query: str + project_id: str + user_id: str + conversation_id: str + node_ids: List[NodeContext] + + async def _stream_code_gen_agent(self, state: State, writer: StreamWriter): + async for chunk in self.execute( + state["query"], + state["project_id"], + state["user_id"], + state["conversation_id"], + state["node_ids"], + ): + writer(chunk) + + def _create_graph(self): + graph_builder = StateGraph(CodeGenerationChatAgent.State) + graph_builder.add_node("code_gen_agent", self._stream_code_gen_agent) + graph_builder.add_edge(START, "code_gen_agent") + graph_builder.add_edge("code_gen_agent", END) + return graph_builder.compile() + async def run( self, query: str, @@ -35,8 +62,26 @@ async def run( user_id: str, conversation_id: str, node_ids: List[NodeContext], + ): + state = { + "query": query, + "project_id": project_id, + "user_id": user_id, + "conversation_id": conversation_id, + "node_ids": node_ids, + } + graph = self._create_graph() + async for chunk in graph.astream(state, stream_mode="custom"): + yield chunk + + async def execute( + self, + query: str, + project_id: str, + user_id: str, + conversation_id: str, + node_ids: List[NodeContext], ) -> AsyncGenerator[str, None]: - start_time = time.time() try: history = self.history_manager.get_session_history(user_id, conversation_id) validated_history = [ @@ -48,12 +93,8 @@ async def run( for msg in history ] - tool_results = [] citations = [] - code_gen_start_time = time.time() - - # Call multi-agent code generation instead of RAG - code_gen_result = await kickoff_code_generation_crew( + async for chunk in kickoff_code_generation_crew( query, project_id, validated_history[-5:], @@ -62,45 +103,23 @@ async def run( self.llm, self.mini_llm, user_id, - ) - - code_gen_duration = time.time() - code_gen_start_time - logger.info( - f"Time elapsed since entering run: {time.time() - start_time:.2f}s, " - f"Duration of Code Generation: {code_gen_duration:.2f}s" - ) - - result = code_gen_result.raw - - tool_results = [SystemMessage(content=result)] - - add_chunk_start_time = time.time() - self.history_manager.add_message_chunk( - conversation_id, - tool_results[0].content, - MessageType.AI_GENERATED, - citations=citations, - ) - add_chunk_duration = time.time() - add_chunk_start_time - logger.info( - f"Time elapsed since entering run: {time.time() - start_time:.2f}s, " - f"Duration of adding message chunk: {add_chunk_duration:.2f}s" - ) + ): + content = str(chunk) + self.history_manager.add_message_chunk( + conversation_id, + content, + MessageType.AI_GENERATED, + citations=citations, + ) + yield json.dumps({ + "citations": citations, + "message": content, + }) - # Timing for flushing message buffer - flush_buffer_start_time = time.time() self.history_manager.flush_message_buffer( conversation_id, MessageType.AI_GENERATED ) - flush_buffer_duration = time.time() - flush_buffer_start_time - logger.info( - f"Time elapsed since entering run: {time.time() - start_time:.2f}s, " - f"Duration of flushing message buffer: {flush_buffer_duration:.2f}s" - ) - - yield json.dumps({"citations": citations, "message": result}) except Exception as e: logger.error(f"Error in code generation: {str(e)}") - error_message = f"An error occurred during code generation: {str(e)}" - yield json.dumps({"error": error_message}) + yield json.dumps({"error": f"An error occurred during code generation: {str(e)}"}) diff --git a/app/modules/intelligence/agents/chat_agents/debugging_chat_agent.py b/app/modules/intelligence/agents/chat_agents/debugging_chat_agent.py index 276dd31c..b6a98368 100644 --- a/app/modules/intelligence/agents/chat_agents/debugging_chat_agent.py +++ b/app/modules/intelligence/agents/chat_agents/debugging_chat_agent.py @@ -2,7 +2,7 @@ import logging import time from functools import lru_cache -from typing import AsyncGenerator, Dict, List +from typing import AsyncGenerator, Dict, List, Annotated, TypedDict from langchain.schema import HumanMessage, SystemMessage from langchain_core.output_parsers import PydanticOutputParser @@ -14,6 +14,8 @@ ) from langchain_core.runnables import RunnableSequence from sqlalchemy.orm import Session +from langgraph.types import StreamWriter +from langgraph.graph import StateGraph, START, END from app.modules.conversations.message.message_model import MessageType from app.modules.conversations.message.message_schema import NodeContext @@ -83,6 +85,34 @@ async def _classify_query(self, query: str, history: List[HumanMessage]): return response.classification + class State(TypedDict): + query: str + project_id: str + user_id: str + conversation_id: str + node_ids: List[NodeContext] + logs: str + stacktrace: str + + async def _stream_rag_agent(self, state: State, writer: StreamWriter): + async for chunk in self.execute( + state["query"], + state["project_id"], + state["user_id"], + state["conversation_id"], + state["node_ids"], + state["logs"], + state["stacktrace"], + ): + writer(chunk) + + def _create_graph(self): + graph_builder = StateGraph(DebuggingChatAgent.State) + graph_builder.add_node("rag_agent", self._stream_rag_agent) + graph_builder.add_edge(START, "rag_agent") + graph_builder.add_edge("rag_agent", END) + return graph_builder.compile() + async def run( self, query: str, @@ -92,6 +122,29 @@ async def run( node_ids: List[NodeContext], logs: str = "", stacktrace: str = "", + ): + state = { + "query": query, + "project_id": project_id, + "user_id": user_id, + "conversation_id": conversation_id, + "node_ids": node_ids, + "logs": logs, + "stacktrace": stacktrace, + } + graph = self._create_graph() + async for chunk in graph.astream(state, stream_mode="custom"): + yield chunk + + async def execute( + self, + query: str, + project_id: str, + user_id: str, + conversation_id: str, + node_ids: List[NodeContext], + logs: str = "", + stacktrace: str = "", ) -> AsyncGenerator[str, None]: start_time = time.time() # Start the timer @@ -114,61 +167,31 @@ async def run( tool_results = [] citations = [] if classification == ClassificationResult.AGENT_REQUIRED: - rag_result = await kickoff_debug_rag_agent( + async for chunk in kickoff_debug_rag_agent( query, project_id, - [ - msg.content - for msg in validated_history - if isinstance(msg, HumanMessage) - ], + [msg.content for msg in validated_history if isinstance(msg, HumanMessage)], node_ids, self.db, self.llm, self.mini_llm, user_id, - ) - if rag_result.pydantic: - response = rag_result.pydantic.response - citations = rag_result.pydantic.citations - result = [node.model_dump() for node in response] - else: - result = rag_result.raw - citations = [] - - tool_results = [SystemMessage(content=result)] - add_chunk_start_time = ( - time.time() - ) # Start timer for adding message chunk - self.history_manager.add_message_chunk( - conversation_id, - tool_results[0].content, - MessageType.AI_GENERATED, - citations=citations, - ) - add_chunk_duration = ( - time.time() - add_chunk_start_time - ) # Calculate duration - logger.info( - f"Time elapsed since entering run: {time.time() - start_time:.2f}s, " - f"Duration of adding message chunk: {add_chunk_duration:.2f}s" - ) + ): + content = str(chunk) + self.history_manager.add_message_chunk( + conversation_id, + content, + MessageType.AI_GENERATED, + citations=citations, + ) + yield json.dumps({ + "citations": citations, + "message": content, + }) - # Timing for flushing message buffer - flush_buffer_start_time = ( - time.time() - ) # Start timer for flushing message buffer self.history_manager.flush_message_buffer( conversation_id, MessageType.AI_GENERATED ) - flush_buffer_duration = ( - time.time() - flush_buffer_start_time - ) # Calculate duration - logger.info( - f"Time elapsed since entering run: {time.time() - start_time:.2f}s, " - f"Duration of flushing message buffer: {flush_buffer_duration:.2f}s" - ) - yield json.dumps({"citations": citations, "message": result}) if classification != ClassificationResult.AGENT_REQUIRED: full_query = f"Query: {query}\nProject ID: {project_id}\nLogs: {logs}\nStacktrace: {stacktrace}" diff --git a/app/modules/intelligence/agents/chat_agents/lld_chat_agent.py b/app/modules/intelligence/agents/chat_agents/lld_chat_agent.py index 8971619e..6e15cec8 100644 --- a/app/modules/intelligence/agents/chat_agents/lld_chat_agent.py +++ b/app/modules/intelligence/agents/chat_agents/lld_chat_agent.py @@ -2,7 +2,10 @@ import logging import time from functools import lru_cache -from typing import AsyncGenerator, Dict, List +from typing import AsyncGenerator, Dict, List, Annotated +from langgraph.types import StreamWriter +from typing_extensions import TypedDict +from langgraph.graph import StateGraph, START, END from langchain.schema import HumanMessage, SystemMessage from langchain_core.output_parsers import PydanticOutputParser @@ -81,6 +84,30 @@ async def _classify_query(self, query: str, history: List[HumanMessage]): return response.classification + class State(TypedDict): + query: str + project_id: str + user_id: str + conversation_id: str + node_ids: List[NodeContext] + + async def _stream_rag_agent(self, state: State, writer: StreamWriter): + async for chunk in self.execute( + state["query"], + state["project_id"], + state["user_id"], + state["conversation_id"], + state["node_ids"], + ): + writer(chunk) + + def _create_graph(self): + graph_builder = StateGraph(LLDChatAgent.State) + graph_builder.add_node("rag_agent", self._stream_rag_agent) + graph_builder.add_edge(START, "rag_agent") + graph_builder.add_edge("rag_agent", END) + return graph_builder.compile() + async def run( self, query: str, @@ -88,6 +115,25 @@ async def run( user_id: str, conversation_id: str, node_ids: List[NodeContext], + ): + state = { + "query": query, + "project_id": project_id, + "user_id": user_id, + "conversation_id": conversation_id, + "node_ids": node_ids, + } + graph = self._create_graph() + async for chunk in graph.astream(state, stream_mode="custom"): + yield chunk + + async def execute( + self, + query: str, + project_id: str, + user_id: str, + conversation_id: str, + node_ids: List[NodeContext], ) -> AsyncGenerator[str, None]: start_time = time.time() # Start the timer try: @@ -117,68 +163,31 @@ async def run( tool_results = [] citations = [] if classification == ClassificationResult.AGENT_REQUIRED: - rag_start_time = time.time() # Start timer for RAG agent - rag_result = await kickoff_rag_agent( + async for chunk in kickoff_rag_agent( query, project_id, - [ - msg.content - for msg in validated_history - if isinstance(msg, HumanMessage) - ], + [msg.content for msg in validated_history if isinstance(msg, HumanMessage)], node_ids, self.db, self.llm, self.mini_llm, user_id, - ) - rag_duration = time.time() - rag_start_time # Calculate duration - logger.info( - f"Time elapsed since entering run: {time.time() - start_time:.2f}s, " - f"Duration of RAG agent: {rag_duration:.2f}s" - ) - - if rag_result.pydantic: - citations = rag_result.pydantic.citations - response = rag_result.pydantic.response - result = [node for node in response] - else: - citations = [] - result = rag_result.raw - tool_results = [SystemMessage(content=result)] - # Timing for adding message chunk - add_chunk_start_time = ( - time.time() - ) # Start timer for adding message chunk - self.history_manager.add_message_chunk( - conversation_id, - tool_results[0].content, - MessageType.AI_GENERATED, - citations=citations, - ) - add_chunk_duration = ( - time.time() - add_chunk_start_time - ) # Calculate duration - logger.info( - f"Time elapsed since entering run: {time.time() - start_time:.2f}s, " - f"Duration of adding message chunk: {add_chunk_duration:.2f}s" - ) + ): + content = str(chunk) + self.history_manager.add_message_chunk( + conversation_id, + content, + MessageType.AI_GENERATED, + citations=citations, + ) + yield json.dumps({ + "citations": citations, + "message": content, + }) - # Timing for flushing message buffer - flush_buffer_start_time = ( - time.time() - ) # Start timer for flushing message buffer self.history_manager.flush_message_buffer( conversation_id, MessageType.AI_GENERATED ) - flush_buffer_duration = ( - time.time() - flush_buffer_start_time - ) # Calculate duration - logger.info( - f"Time elapsed since entering run: {time.time() - start_time:.2f}s, " - f"Duration of flushing message buffer: {flush_buffer_duration:.2f}s" - ) - yield json.dumps({"citations": citations, "message": result}) if classification != ClassificationResult.AGENT_REQUIRED: inputs = { diff --git a/app/modules/intelligence/agents/chat_agents/qna_chat_agent.py b/app/modules/intelligence/agents/chat_agents/qna_chat_agent.py index d73740b5..9384dac1 100644 --- a/app/modules/intelligence/agents/chat_agents/qna_chat_agent.py +++ b/app/modules/intelligence/agents/chat_agents/qna_chat_agent.py @@ -2,15 +2,9 @@ import logging import time from functools import lru_cache -from typing import Any, AsyncGenerator, Dict, List, Optional -from typing import Annotated -from langgraph.types import StreamWriter - -from typing_extensions import TypedDict +from typing import AsyncGenerator, Dict, List -from langgraph.graph import StateGraph, START, END -from langgraph.graph.message import add_messages -from langchain.schema import HumanMessage, SystemMessage +from langchain.schema import HumanMessage from langchain_core.output_parsers import PydanticOutputParser from langchain_core.prompts import ( ChatPromptTemplate, @@ -19,7 +13,10 @@ SystemMessagePromptTemplate, ) from langchain_core.runnables import RunnableSequence +from langgraph.graph import END, START, StateGraph +from langgraph.types import StreamWriter from sqlalchemy.orm import Session +from typing_extensions import TypedDict from app.modules.conversations.message.message_model import MessageType from app.modules.conversations.message.message_schema import NodeContext @@ -87,8 +84,6 @@ async def _classify_query(self, query: str, history: List[HumanMessage]): return response.classification - - class State(TypedDict): query: str project_id: str @@ -96,8 +91,6 @@ class State(TypedDict): conversation_id: str node_ids: List[NodeContext] - - async def _stream_rag_agent(self, state: State, writer: StreamWriter): async for chunk in self.execute( state["query"], @@ -106,7 +99,6 @@ async def _stream_rag_agent(self, state: State, writer: StreamWriter): state["conversation_id"], state["node_ids"], ): - writer(chunk) def _create_graph(self): @@ -120,12 +112,14 @@ def _create_graph(self): graph_builder.add_edge("rag_agent", END) return graph_builder.compile() - async def run(self, + async def run( + self, query: str, project_id: str, user_id: str, conversation_id: str, - node_ids: List[NodeContext],): + node_ids: List[NodeContext], +): state = { "query": query, "project_id": project_id, @@ -134,7 +128,7 @@ async def run(self, "node_ids": node_ids, } graph = self._create_graph() - async for chunk in graph.astream(state,stream_mode="custom"): + async for chunk in graph.astream(state, stream_mode="custom"): yield chunk async def execute( @@ -202,13 +196,11 @@ async def execute( "message": content, } ) - self.history_manager.flush_message_buffer( conversation_id, MessageType.AI_GENERATED ) - if classification != ClassificationResult.AGENT_REQUIRED: inputs = { "history": validated_history[-10:], diff --git a/app/modules/parsing/graph_construction/code_graph_service.py b/app/modules/parsing/graph_construction/code_graph_service.py index d0af621e..84d9b1ee 100644 --- a/app/modules/parsing/graph_construction/code_graph_service.py +++ b/app/modules/parsing/graph_construction/code_graph_service.py @@ -44,7 +44,6 @@ def create_and_store_graph(self, repo_dir, project_id, user_id): nx_graph = self.repo_map.create_graph(repo_dir) with self.driver.session() as session: - start_time = time.time() node_count = nx_graph.number_of_nodes() logging.info(f"Creating {node_count} nodes") diff --git a/app/modules/parsing/graph_construction/parsing_repomap.py b/app/modules/parsing/graph_construction/parsing_repomap.py index 511a0bc0..7261066e 100644 --- a/app/modules/parsing/graph_construction/parsing_repomap.py +++ b/app/modules/parsing/graph_construction/parsing_repomap.py @@ -25,7 +25,7 @@ class RepoMap: # Parsing logic adapted from aider (https://github.com/paul-gauthier/aider) - # Modified and customized for potpie's parsing needs with detailed tags, relationship tracking etc + # Modified and customized for potpie's parsing needs with detailed tags, relationship tracking etc def __init__( self, From 3c05de4e98a666dc098ea36b1339f529ca4320a0 Mon Sep 17 00:00:00 2001 From: dhirenmathur Date: Wed, 18 Dec 2024 20:11:49 +0530 Subject: [PATCH 3/6] supervisour routing --- .../conversation/conversation_service.py | 166 ++++++++++++++++-- .../intelligence/agents/agent_classifier.py | 80 +++++++++ .../intelligence/agents/agent_factory.py | 55 ++++++ .../agents/chat_agents/code_gen_chat_agent.py | 25 +-- .../chat_agents/debugging_chat_agent.py | 24 ++- .../agents/chat_agents/lld_chat_agent.py | 26 +-- .../agents/chat_agents/qna_chat_agent.py | 7 +- .../chat_agents/unit_test_chat_agent.py | 2 +- 8 files changed, 342 insertions(+), 43 deletions(-) create mode 100644 app/modules/intelligence/agents/agent_classifier.py create mode 100644 app/modules/intelligence/agents/agent_factory.py diff --git a/app/modules/conversations/conversation/conversation_service.py b/app/modules/conversations/conversation/conversation_service.py index ee16eea2..bca6b0c3 100644 --- a/app/modules/conversations/conversation/conversation_service.py +++ b/app/modules/conversations/conversation/conversation_service.py @@ -1,12 +1,17 @@ import asyncio +import json import logging from datetime import datetime, timezone -from typing import AsyncGenerator, List +from typing import AsyncGenerator, Dict, Any, List, Optional, TypedDict +from langgraph.types import StreamWriter +from fastapi import HTTPException +from langgraph.graph import END, StateGraph +from langgraph.types import Command from langchain.prompts import ChatPromptTemplate +from sqlalchemy import func from sqlalchemy.exc import IntegrityError, SQLAlchemyError from sqlalchemy.orm import Session -from sqlalchemy.sql import func from uuid6 import uuid7 from app.modules.code_provider.code_provider_service import CodeProviderService @@ -30,7 +35,10 @@ MessageResponse, NodeContext, ) + from app.modules.intelligence.agents.agent_injector_service import AgentInjectorService +from app.modules.intelligence.agents.agents_service import AgentsService +from app.modules.intelligence.agents.agent_factory import AgentFactory from app.modules.intelligence.agents.custom_agents.custom_agents_service import ( CustomAgentsService, ) @@ -47,24 +55,161 @@ class ConversationServiceError(Exception): - """Base exception class for ConversationService errors.""" + pass class ConversationNotFoundError(ConversationServiceError): - """Raised when a conversation is not found.""" + pass class MessageNotFoundError(ConversationServiceError): - """Raised when a message is not found.""" + pass class AccessTypeNotFoundError(ConversationServiceError): - """Raised when an access type is not found.""" + pass class AccessTypeReadError(ConversationServiceError): - """Raised when an access type is read-only.""" + pass + + +from langgraph.graph import END, StateGraph +from langgraph.types import Command +from typing import AsyncGenerator, Dict, Any + +class SimplifiedAgentSupervisor: + def __init__(self, db, provider_service): + self.db = db + self.provider_service = provider_service + self.agents = {} + self.classifier = None + self.agents_service = AgentsService(db) + self.agent_factory = AgentFactory(db, provider_service) + + async def initialize(self, user_id: str): + # Get available agents using AgentsService + available_agents = await self.agents_service.list_available_agents( + current_user={"user_id": user_id}, + list_system_agents=True + ) + + # Create agent instances dictionary + self.agents = { + agent.id: self.agent_factory.get_agent(agent.id, user_id) + for agent in available_agents + } + + self.llm = self.provider_service.get_small_llm(user_id) + + # Enhanced classifier prompt with agent descriptions + self.classifier_prompt = """ + Given the user query, determine which agent should handle it based on their specialties: + + Query: {query} + + Available agents and their specialties: + {agent_descriptions} + + Return ONLY the agent id and confidence score in format: agent_id|confidence + Example: debugging_agent|0.85 + """ + + # Format agent descriptions for the prompt + self.agent_descriptions = "\n".join([ + f"- {agent.id}: {agent.description}" + for agent in available_agents + ]) + class State(TypedDict): + query: str + project_id: str + conversation_id: str + response: Optional[str] + agent_id: Optional[str] + user_id: str + node_ids: List[NodeContext] + + async def classifier_node(self, state: State) -> Command: + """Classifies the query and routes to appropriate agent""" + if not state.get("query"): + return Command(update={"response": "No query provided"}, goto=END) + + # Classification using LLM with enhanced prompt + prompt = self.classifier_prompt.format( + query=state["query"], + agent_descriptions=self.agent_descriptions + ) + response = await self.llm.ainvoke(prompt) + + # Parse response + try: + agent_id, confidence = response.content.split("|") + confidence = float(confidence) + except (ValueError, TypeError): + return Command( + update={"response": "Error in classification format"}, + goto=END + ) + if confidence < 0.5 or agent_id not in self.agents: + return Command( + update={"agent_id":state["agent_id"]}, + goto=state["agent_id"] + ) + + return Command( + update={"agent_id": agent_id}, + goto=agent_id + ) + + async def agent_node(self, state: State, writer: StreamWriter): + """Creates a node function for a specific agent""" + agent = self.agents[state["agent_id"]] + async for chunk in agent.run( + query=state["query"], + project_id=state["project_id"], + conversation_id=state["conversation_id"], + user_id=state["user_id"], + node_ids=state["node_ids"] + ): + if isinstance(chunk, str): + writer(chunk) + + + + + def build_graph(self) -> StateGraph: + """Builds the graph with classifier and agent nodes""" + builder = StateGraph(self.State) + + # Add classifier as entry point + builder.add_node("classifier", self.classifier_node) + #builder.add_edge("classifier", END) + + # # Add agent nodes + #node_func = await self.agent_node(self.State, StreamWriter) + for agent_id in self.agents: + builder.add_node(agent_id, self.agent_node) + builder.add_edge(agent_id, END) + + builder.set_entry_point("classifier") + return builder.compile() + + async def process_query(self, query: str, project_id: str, conversation_id: str, user_id: str, node_ids: List[NodeContext], agent_id: str) -> AsyncGenerator[Dict[str, Any], None]: + """Main method to process queries""" + state = { + "query": query, + "project_id": project_id, + "conversation_id": conversation_id, + "response": None, + "user_id": user_id, + "node_ids": node_ids, + "agent_id": agent_id + } + + graph = self.build_graph() + async for chunk in graph.astream(state, stream_mode="custom"): + yield chunk class ConversationService: def __init__( @@ -450,7 +595,8 @@ async def _generate_and_stream_ai_response( agent_id = conversation.agent_ids[0] project_id = conversation.project_ids[0] if conversation.project_ids else None - + supervisor = SimplifiedAgentSupervisor(self.sql_db, self.provider_service) + await supervisor.initialize(user_id) try: agent = self.agent_injector_service.get_agent(agent_id) @@ -466,8 +612,8 @@ async def _generate_and_stream_ai_response( yield response else: # For other agents that support streaming - async for chunk in agent.run( - query, project_id, user_id, conversation.id, node_ids + async for chunk in supervisor.process_query( + query, project_id, conversation.id, user_id, node_ids, agent_id ): yield chunk diff --git a/app/modules/intelligence/agents/agent_classifier.py b/app/modules/intelligence/agents/agent_classifier.py new file mode 100644 index 00000000..59669846 --- /dev/null +++ b/app/modules/intelligence/agents/agent_classifier.py @@ -0,0 +1,80 @@ +from typing import List +from langchain.schema import BaseMessage +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.output_parsers import PydanticOutputParser +from pydantic import BaseModel, Field + +from app.modules.conversations.message.message_schema import MessageResponse +from app.modules.intelligence.prompts.classification_prompts import ClassificationResult + +class AgentClassification(BaseModel): + agent_id: str = Field(..., description="ID of the agent that should handle the query") + confidence: float = Field(..., description="Confidence score between 0 and 1") + reasoning: str = Field(..., description="Reasoning behind the agent selection") + +class AgentClassifier: + def __init__(self, llm, available_agents): + self.llm = llm + self.available_agents = available_agents + self.parser = PydanticOutputParser(pydantic_object=AgentClassification) + + def create_prompt(self) -> str: + + return """You are an expert agent router. + User's query: {query} + Conversation history: {history} + + Based on the user's query and conversation history, + select the most appropriate agent from the following options: + + Available Agents: + {agents_desc} + + Analyze the query and select the agent that best matches the user's needs. + Consider: + 1. The specific task or question type + 2. Required expertise + 3. Context from conversation history + 4. Any explicit agent requests + + {format_instructions} + """ + + async def classify(self, messages: List[MessageResponse]) -> AgentClassification: + """Classify the conversation and determine which agent should handle it""" + + if not messages: + return AgentClassification( + agent_id=self.available_agents[0].id, # Default to first agent + confidence=0.0, + reasoning="No messages to classify" + ) + + # Format agent descriptions + agents_desc = "\n".join([ + f"{i+1}. {agent.id}: {agent.description}" + for i, agent in enumerate(self.available_agents) + ]) + + # Get the last message and up to 10 messages of history + last_message = messages[-1].content if messages else "" + history = [msg.content for msg in messages[-10:]] if len(messages) > 1 else [] + + inputs = { + "query": last_message, + "history": history, + "agents_desc": agents_desc, + "format_instructions": self.parser.get_format_instructions() + } + + # Rest of the classification logic... + + prompt = ChatPromptTemplate.from_template(self.create_prompt()) + + chain = prompt | self.llm | self.parser + + result = await chain.ainvoke( + inputs + ) + + return result \ No newline at end of file diff --git a/app/modules/intelligence/agents/agent_factory.py b/app/modules/intelligence/agents/agent_factory.py new file mode 100644 index 00000000..903c787c --- /dev/null +++ b/app/modules/intelligence/agents/agent_factory.py @@ -0,0 +1,55 @@ +from typing import Dict, Any +from sqlalchemy.orm import Session + +from app.modules.intelligence.provider.provider_service import AgentType, ProviderService +from app.modules.intelligence.agents.chat_agents.code_changes_chat_agent import CodeChangesChatAgent +from app.modules.intelligence.agents.chat_agents.debugging_chat_agent import DebuggingChatAgent +from app.modules.intelligence.agents.chat_agents.qna_chat_agent import QNAChatAgent +from app.modules.intelligence.agents.chat_agents.unit_test_chat_agent import UnitTestAgent +from app.modules.intelligence.agents.chat_agents.integration_test_chat_agent import IntegrationTestChatAgent +from app.modules.intelligence.agents.chat_agents.lld_chat_agent import LLDChatAgent +from app.modules.intelligence.agents.chat_agents.code_gen_chat_agent import CodeGenerationChatAgent +from app.modules.intelligence.agents.custom_agents.custom_agent import CustomAgent + +class AgentFactory: + def __init__(self, db: Session, provider_service: ProviderService): + self.db = db + self.provider_service = provider_service + self._agent_cache: Dict[str, Any] = {} + + def get_agent(self, agent_id: str, user_id: str) -> Any: + """Get or create an agent instance""" + cache_key = f"{agent_id}_{user_id}" + + if cache_key in self._agent_cache: + return self._agent_cache[cache_key] + + mini_llm = self.provider_service.get_small_llm(agent_type=AgentType.LANGCHAIN) + reasoning_llm = self.provider_service.get_large_llm(agent_type=AgentType.LANGCHAIN) + + agent = self._create_agent(agent_id, mini_llm, reasoning_llm, user_id) + self._agent_cache[cache_key] = agent + return agent + + def _create_agent(self, agent_id: str, mini_llm, reasoning_llm, user_id: str) -> Any: + """Create a new agent instance""" + agent_map = { + "debugging_agent": lambda: DebuggingChatAgent(mini_llm, reasoning_llm, self.db), + "codebase_qna_agent": lambda: QNAChatAgent(mini_llm, reasoning_llm, self.db), + "unit_test_agent": lambda: UnitTestAgent(mini_llm, reasoning_llm, self.db), + "integration_test_agent": lambda: IntegrationTestChatAgent(mini_llm, reasoning_llm, self.db), + "code_changes_agent": lambda: CodeChangesChatAgent(mini_llm, reasoning_llm, self.db), + "LLD_agent": lambda: LLDChatAgent(mini_llm, reasoning_llm, self.db), + "code_generation_agent": lambda: CodeGenerationChatAgent(mini_llm, reasoning_llm, self.db), + } + + if agent_id in agent_map: + return agent_map[agent_id]() + + # If not a system agent, create custom agent + return CustomAgent( + llm=reasoning_llm, + db=self.db, + agent_id=agent_id, + user_id=user_id + ) \ No newline at end of file diff --git a/app/modules/intelligence/agents/chat_agents/code_gen_chat_agent.py b/app/modules/intelligence/agents/chat_agents/code_gen_chat_agent.py index 68527ba4..c676b590 100644 --- a/app/modules/intelligence/agents/chat_agents/code_gen_chat_agent.py +++ b/app/modules/intelligence/agents/chat_agents/code_gen_chat_agent.py @@ -1,13 +1,12 @@ import json import logging -import time -from typing import AsyncGenerator, Dict, List, Annotated -from langgraph.types import StreamWriter -from typing_extensions import TypedDict -from langgraph.graph import StateGraph, START, END +from typing import AsyncGenerator, List -from langchain.schema import HumanMessage, SystemMessage +from langchain.schema import HumanMessage +from langgraph.graph import END, START, StateGraph +from langgraph.types import StreamWriter from sqlalchemy.orm import Session +from typing_extensions import TypedDict from app.modules.conversations.message.message_model import MessageType from app.modules.conversations.message.message_schema import NodeContext @@ -111,10 +110,12 @@ async def execute( MessageType.AI_GENERATED, citations=citations, ) - yield json.dumps({ - "citations": citations, - "message": content, - }) + yield json.dumps( + { + "citations": citations, + "message": content, + } + ) self.history_manager.flush_message_buffer( conversation_id, MessageType.AI_GENERATED @@ -122,4 +123,6 @@ async def execute( except Exception as e: logger.error(f"Error in code generation: {str(e)}") - yield json.dumps({"error": f"An error occurred during code generation: {str(e)}"}) + yield json.dumps( + {"error": f"An error occurred during code generation: {str(e)}"} + ) diff --git a/app/modules/intelligence/agents/chat_agents/debugging_chat_agent.py b/app/modules/intelligence/agents/chat_agents/debugging_chat_agent.py index b6a98368..4c08535e 100644 --- a/app/modules/intelligence/agents/chat_agents/debugging_chat_agent.py +++ b/app/modules/intelligence/agents/chat_agents/debugging_chat_agent.py @@ -2,9 +2,9 @@ import logging import time from functools import lru_cache -from typing import AsyncGenerator, Dict, List, Annotated, TypedDict +from typing import AsyncGenerator, Dict, List, TypedDict -from langchain.schema import HumanMessage, SystemMessage +from langchain.schema import HumanMessage from langchain_core.output_parsers import PydanticOutputParser from langchain_core.prompts import ( ChatPromptTemplate, @@ -13,9 +13,9 @@ SystemMessagePromptTemplate, ) from langchain_core.runnables import RunnableSequence -from sqlalchemy.orm import Session +from langgraph.graph import END, START, StateGraph from langgraph.types import StreamWriter -from langgraph.graph import StateGraph, START, END +from sqlalchemy.orm import Session from app.modules.conversations.message.message_model import MessageType from app.modules.conversations.message.message_schema import NodeContext @@ -170,7 +170,11 @@ async def execute( async for chunk in kickoff_debug_rag_agent( query, project_id, - [msg.content for msg in validated_history if isinstance(msg, HumanMessage)], + [ + msg.content + for msg in validated_history + if isinstance(msg, HumanMessage) + ], node_ids, self.db, self.llm, @@ -184,10 +188,12 @@ async def execute( MessageType.AI_GENERATED, citations=citations, ) - yield json.dumps({ - "citations": citations, - "message": content, - }) + yield json.dumps( + { + "citations": citations, + "message": content, + } + ) self.history_manager.flush_message_buffer( conversation_id, MessageType.AI_GENERATED diff --git a/app/modules/intelligence/agents/chat_agents/lld_chat_agent.py b/app/modules/intelligence/agents/chat_agents/lld_chat_agent.py index 6e15cec8..5d341f3e 100644 --- a/app/modules/intelligence/agents/chat_agents/lld_chat_agent.py +++ b/app/modules/intelligence/agents/chat_agents/lld_chat_agent.py @@ -2,12 +2,9 @@ import logging import time from functools import lru_cache -from typing import AsyncGenerator, Dict, List, Annotated -from langgraph.types import StreamWriter -from typing_extensions import TypedDict -from langgraph.graph import StateGraph, START, END +from typing import AsyncGenerator, Dict, List -from langchain.schema import HumanMessage, SystemMessage +from langchain.schema import HumanMessage from langchain_core.output_parsers import PydanticOutputParser from langchain_core.prompts import ( ChatPromptTemplate, @@ -16,7 +13,10 @@ SystemMessagePromptTemplate, ) from langchain_core.runnables import RunnableSequence +from langgraph.graph import END, START, StateGraph +from langgraph.types import StreamWriter from sqlalchemy.orm import Session +from typing_extensions import TypedDict from app.modules.conversations.message.message_model import MessageType from app.modules.conversations.message.message_schema import NodeContext @@ -166,7 +166,11 @@ async def execute( async for chunk in kickoff_rag_agent( query, project_id, - [msg.content for msg in validated_history if isinstance(msg, HumanMessage)], + [ + msg.content + for msg in validated_history + if isinstance(msg, HumanMessage) + ], node_ids, self.db, self.llm, @@ -180,10 +184,12 @@ async def execute( MessageType.AI_GENERATED, citations=citations, ) - yield json.dumps({ - "citations": citations, - "message": content, - }) + yield json.dumps( + { + "citations": citations, + "message": content, + } + ) self.history_manager.flush_message_buffer( conversation_id, MessageType.AI_GENERATED diff --git a/app/modules/intelligence/agents/chat_agents/qna_chat_agent.py b/app/modules/intelligence/agents/chat_agents/qna_chat_agent.py index 9384dac1..ef72ac0d 100644 --- a/app/modules/intelligence/agents/chat_agents/qna_chat_agent.py +++ b/app/modules/intelligence/agents/chat_agents/qna_chat_agent.py @@ -110,6 +110,7 @@ def _create_graph(self): ) graph_builder.add_edge(START, "rag_agent") graph_builder.add_edge("rag_agent", END) + graph_builder.set_entry_point("rag_agent") return graph_builder.compile() async def run( @@ -119,7 +120,7 @@ async def run( user_id: str, conversation_id: str, node_ids: List[NodeContext], -): + ): state = { "query": query, "project_id": project_id, @@ -129,7 +130,9 @@ async def run( } graph = self._create_graph() async for chunk in graph.astream(state, stream_mode="custom"): - yield chunk + if isinstance(chunk, str): + yield chunk + async def execute( self, diff --git a/app/modules/intelligence/agents/chat_agents/unit_test_chat_agent.py b/app/modules/intelligence/agents/chat_agents/unit_test_chat_agent.py index f5769b21..f3129fe1 100644 --- a/app/modules/intelligence/agents/chat_agents/unit_test_chat_agent.py +++ b/app/modules/intelligence/agents/chat_agents/unit_test_chat_agent.py @@ -96,7 +96,7 @@ async def run( try: if not self.chain: self.chain = await self._create_chain() - + citations = [] if not node_ids: content = "It looks like there is no context selected. Please type @ followed by file or function name to interact with the unit test agent" self.history_manager.add_message_chunk( From ade7591a1870ec0201cb93e1b4e094afa43e2edc Mon Sep 17 00:00:00 2001 From: dhirenmathur Date: Fri, 20 Dec 2024 17:51:27 +0530 Subject: [PATCH 4/6] update classifier prompt --- .../conversation/conversation_service.py | 48 +++++++++-- .../intelligence/agents/agent_classifier.py | 80 ------------------- .../intelligence/agents/agents_service.py | 2 +- 3 files changed, 41 insertions(+), 89 deletions(-) delete mode 100644 app/modules/intelligence/agents/agent_classifier.py diff --git a/app/modules/conversations/conversation/conversation_service.py b/app/modules/conversations/conversation/conversation_service.py index bca6b0c3..350a068b 100644 --- a/app/modules/conversations/conversation/conversation_service.py +++ b/app/modules/conversations/conversation/conversation_service.py @@ -104,15 +104,46 @@ async def initialize(self, user_id: str): # Enhanced classifier prompt with agent descriptions self.classifier_prompt = """ - Given the user query, determine which agent should handle it based on their specialties: - + Given the user query and the current agent ID, select the most appropriate agent by comparing the query’s requirements with each agent’s specialties. + Query: {query} - + Current Agent ID: {agent_id} + Available agents and their specialties: {agent_descriptions} - - Return ONLY the agent id and confidence score in format: agent_id|confidence - Example: debugging_agent|0.85 + + Follow the instructions below to determine the best matching agent and provide a confidence score: + + Analysis Instructions (DO NOT include these instructions in the final answer): + 1. **Semantic Analysis:** + - Identify the key topics, technical terms, and the user’s intent from the query. + - Compare these elements to each agent’s detailed specialty description. + - Focus on specific skills, tools, frameworks, and domain expertise mentioned. + + 2. **Contextual Weighting:** + - If the query strongly aligns with the current agent’s known capabilities, add +0.15 confidence for direct core expertise and +0.1 for related domain knowledge. + - If the query introduces new topics outside the current agent’s domain, do not apply the current agent bias. Instead, evaluate all agents equally based on their described expertise. + + 3. **Multi-Agent Evaluation:** + - Consider all agents’ described specialties thoroughly, not just the current agent. + - For overlapping capabilities, favor the agent with more specialized expertise or more relevant tools/methodologies. + - If no agent clearly surpasses a 0.5 confidence threshold, select the agent with the highest confidence score, even if it is below 0.5. + + 4. **Confidence Scoring Guidelines:** + - 0.9-1.0: Ideal match with the agent’s core, primary expertise. + - 0.7-0.9: Strong match with the agent’s known capabilities. + - 0.5-0.7: Partial or related match, not a direct specialty. + - Below 0.5: Weak match; consider if another agent is more suitable, but still choose the best available option. + + Final Output Requirements: + - Return ONLY the chosen agent_id and the confidence score in the format: + `agent_id|confidence` + + Examples: + - Direct expertise match: `debugging_agent|0.95` + - Related capability (current agent): `current_agent_id|0.75` + - Need different expertise: `ml_training_agent|0.85` + - Overlapping domains, choose more specialized: `choose_higher_expertise_agent|0.80` """ # Format agent descriptions for the prompt @@ -136,8 +167,9 @@ async def classifier_node(self, state: State) -> Command: # Classification using LLM with enhanced prompt prompt = self.classifier_prompt.format( - query=state["query"], - agent_descriptions=self.agent_descriptions + query=state["query"], + agent_id=state["agent_id"], + agent_descriptions=self.agent_descriptions, ) response = await self.llm.ainvoke(prompt) diff --git a/app/modules/intelligence/agents/agent_classifier.py b/app/modules/intelligence/agents/agent_classifier.py deleted file mode 100644 index 59669846..00000000 --- a/app/modules/intelligence/agents/agent_classifier.py +++ /dev/null @@ -1,80 +0,0 @@ -from typing import List -from langchain.schema import BaseMessage -from langchain_core.prompts import ChatPromptTemplate -from langchain_core.output_parsers import PydanticOutputParser -from pydantic import BaseModel, Field - -from app.modules.conversations.message.message_schema import MessageResponse -from app.modules.intelligence.prompts.classification_prompts import ClassificationResult - -class AgentClassification(BaseModel): - agent_id: str = Field(..., description="ID of the agent that should handle the query") - confidence: float = Field(..., description="Confidence score between 0 and 1") - reasoning: str = Field(..., description="Reasoning behind the agent selection") - -class AgentClassifier: - def __init__(self, llm, available_agents): - self.llm = llm - self.available_agents = available_agents - self.parser = PydanticOutputParser(pydantic_object=AgentClassification) - - def create_prompt(self) -> str: - - return """You are an expert agent router. - User's query: {query} - Conversation history: {history} - - Based on the user's query and conversation history, - select the most appropriate agent from the following options: - - Available Agents: - {agents_desc} - - Analyze the query and select the agent that best matches the user's needs. - Consider: - 1. The specific task or question type - 2. Required expertise - 3. Context from conversation history - 4. Any explicit agent requests - - {format_instructions} - """ - - async def classify(self, messages: List[MessageResponse]) -> AgentClassification: - """Classify the conversation and determine which agent should handle it""" - - if not messages: - return AgentClassification( - agent_id=self.available_agents[0].id, # Default to first agent - confidence=0.0, - reasoning="No messages to classify" - ) - - # Format agent descriptions - agents_desc = "\n".join([ - f"{i+1}. {agent.id}: {agent.description}" - for i, agent in enumerate(self.available_agents) - ]) - - # Get the last message and up to 10 messages of history - last_message = messages[-1].content if messages else "" - history = [msg.content for msg in messages[-10:]] if len(messages) > 1 else [] - - inputs = { - "query": last_message, - "history": history, - "agents_desc": agents_desc, - "format_instructions": self.parser.get_format_instructions() - } - - # Rest of the classification logic... - - prompt = ChatPromptTemplate.from_template(self.create_prompt()) - - chain = prompt | self.llm | self.parser - - result = await chain.ainvoke( - inputs - ) - - return result \ No newline at end of file diff --git a/app/modules/intelligence/agents/agents_service.py b/app/modules/intelligence/agents/agents_service.py index 7619ad3b..ee1704e3 100644 --- a/app/modules/intelligence/agents/agents_service.py +++ b/app/modules/intelligence/agents/agents_service.py @@ -52,7 +52,7 @@ async def list_available_agents( AgentInfo( id="code_changes_agent", name="Code Changes Agent", - description="An agent specialized in generating detailed analysis of code changes in your current branch compared to default branch. Works best with Py, JS, TS", + description="An agent specialized in generating blast radius of the code changes in your current branch compared to default branch. Use this for functional review of your code changes. Works best with Py, JS, TS", status="SYSTEM", ), AgentInfo( From b280503d198c5d413801f6c47e784294859c0642 Mon Sep 17 00:00:00 2001 From: dhirenmathur Date: Fri, 20 Dec 2024 17:52:24 +0530 Subject: [PATCH 5/6] pre-commit --- .../conversation/conversation_service.py | 70 +++++++--------- .../intelligence/agents/agent_factory.py | 80 ++++++++++++------- .../agents/chat_agents/qna_chat_agent.py | 1 - 3 files changed, 84 insertions(+), 67 deletions(-) diff --git a/app/modules/conversations/conversation/conversation_service.py b/app/modules/conversations/conversation/conversation_service.py index 350a068b..3cb9d302 100644 --- a/app/modules/conversations/conversation/conversation_service.py +++ b/app/modules/conversations/conversation/conversation_service.py @@ -1,14 +1,11 @@ import asyncio -import json import logging from datetime import datetime, timezone -from typing import AsyncGenerator, Dict, Any, List, Optional, TypedDict -from langgraph.types import StreamWriter +from typing import Any, AsyncGenerator, Dict, List, Optional, TypedDict -from fastapi import HTTPException -from langgraph.graph import END, StateGraph -from langgraph.types import Command from langchain.prompts import ChatPromptTemplate +from langgraph.graph import END, StateGraph +from langgraph.types import Command, StreamWriter from sqlalchemy import func from sqlalchemy.exc import IntegrityError, SQLAlchemyError from sqlalchemy.orm import Session @@ -35,10 +32,9 @@ MessageResponse, NodeContext, ) - +from app.modules.intelligence.agents.agent_factory import AgentFactory from app.modules.intelligence.agents.agent_injector_service import AgentInjectorService from app.modules.intelligence.agents.agents_service import AgentsService -from app.modules.intelligence.agents.agent_factory import AgentFactory from app.modules.intelligence.agents.custom_agents.custom_agents_service import ( CustomAgentsService, ) @@ -74,10 +70,6 @@ class AccessTypeReadError(ConversationServiceError): pass -from langgraph.graph import END, StateGraph -from langgraph.types import Command -from typing import AsyncGenerator, Dict, Any - class SimplifiedAgentSupervisor: def __init__(self, db, provider_service): self.db = db @@ -90,10 +82,9 @@ def __init__(self, db, provider_service): async def initialize(self, user_id: str): # Get available agents using AgentsService available_agents = await self.agents_service.list_available_agents( - current_user={"user_id": user_id}, - list_system_agents=True + current_user={"user_id": user_id}, list_system_agents=True ) - + # Create agent instances dictionary self.agents = { agent.id: self.agent_factory.get_agent(agent.id, user_id) @@ -147,10 +138,10 @@ async def initialize(self, user_id: str): """ # Format agent descriptions for the prompt - self.agent_descriptions = "\n".join([ - f"- {agent.id}: {agent.description}" - for agent in available_agents - ]) + self.agent_descriptions = "\n".join( + [f"- {agent.id}: {agent.description}" for agent in available_agents] + ) + class State(TypedDict): query: str project_id: str @@ -167,32 +158,27 @@ async def classifier_node(self, state: State) -> Command: # Classification using LLM with enhanced prompt prompt = self.classifier_prompt.format( - query=state["query"], + query=state["query"], agent_id=state["agent_id"], agent_descriptions=self.agent_descriptions, ) response = await self.llm.ainvoke(prompt) - + # Parse response try: agent_id, confidence = response.content.split("|") confidence = float(confidence) except (ValueError, TypeError): return Command( - update={"response": "Error in classification format"}, - goto=END + update={"response": "Error in classification format"}, goto=END ) if confidence < 0.5 or agent_id not in self.agents: return Command( - update={"agent_id":state["agent_id"]}, - goto=state["agent_id"] + update={"agent_id": state["agent_id"]}, goto=state["agent_id"] ) - return Command( - update={"agent_id": agent_id}, - goto=agent_id - ) + return Command(update={"agent_id": agent_id}, goto=agent_id) async def agent_node(self, state: State, writer: StreamWriter): """Creates a node function for a specific agent""" @@ -202,24 +188,21 @@ async def agent_node(self, state: State, writer: StreamWriter): project_id=state["project_id"], conversation_id=state["conversation_id"], user_id=state["user_id"], - node_ids=state["node_ids"] + node_ids=state["node_ids"], ): if isinstance(chunk, str): writer(chunk) - - - def build_graph(self) -> StateGraph: """Builds the graph with classifier and agent nodes""" builder = StateGraph(self.State) - + # Add classifier as entry point builder.add_node("classifier", self.classifier_node) - #builder.add_edge("classifier", END) + # builder.add_edge("classifier", END) # # Add agent nodes - #node_func = await self.agent_node(self.State, StreamWriter) + # node_func = await self.agent_node(self.State, StreamWriter) for agent_id in self.agents: builder.add_node(agent_id, self.agent_node) builder.add_edge(agent_id, END) @@ -227,22 +210,31 @@ def build_graph(self) -> StateGraph: builder.set_entry_point("classifier") return builder.compile() - async def process_query(self, query: str, project_id: str, conversation_id: str, user_id: str, node_ids: List[NodeContext], agent_id: str) -> AsyncGenerator[Dict[str, Any], None]: + async def process_query( + self, + query: str, + project_id: str, + conversation_id: str, + user_id: str, + node_ids: List[NodeContext], + agent_id: str, + ) -> AsyncGenerator[Dict[str, Any], None]: """Main method to process queries""" state = { "query": query, - "project_id": project_id, + "project_id": project_id, "conversation_id": conversation_id, "response": None, "user_id": user_id, "node_ids": node_ids, - "agent_id": agent_id + "agent_id": agent_id, } graph = self.build_graph() async for chunk in graph.astream(state, stream_mode="custom"): yield chunk + class ConversationService: def __init__( self, diff --git a/app/modules/intelligence/agents/agent_factory.py b/app/modules/intelligence/agents/agent_factory.py index 903c787c..a6f9a606 100644 --- a/app/modules/intelligence/agents/agent_factory.py +++ b/app/modules/intelligence/agents/agent_factory.py @@ -1,55 +1,81 @@ -from typing import Dict, Any +from typing import Any, Dict + from sqlalchemy.orm import Session -from app.modules.intelligence.provider.provider_service import AgentType, ProviderService -from app.modules.intelligence.agents.chat_agents.code_changes_chat_agent import CodeChangesChatAgent -from app.modules.intelligence.agents.chat_agents.debugging_chat_agent import DebuggingChatAgent -from app.modules.intelligence.agents.chat_agents.qna_chat_agent import QNAChatAgent -from app.modules.intelligence.agents.chat_agents.unit_test_chat_agent import UnitTestAgent -from app.modules.intelligence.agents.chat_agents.integration_test_chat_agent import IntegrationTestChatAgent +from app.modules.intelligence.agents.chat_agents.code_changes_chat_agent import ( + CodeChangesChatAgent, +) +from app.modules.intelligence.agents.chat_agents.code_gen_chat_agent import ( + CodeGenerationChatAgent, +) +from app.modules.intelligence.agents.chat_agents.debugging_chat_agent import ( + DebuggingChatAgent, +) +from app.modules.intelligence.agents.chat_agents.integration_test_chat_agent import ( + IntegrationTestChatAgent, +) from app.modules.intelligence.agents.chat_agents.lld_chat_agent import LLDChatAgent -from app.modules.intelligence.agents.chat_agents.code_gen_chat_agent import CodeGenerationChatAgent +from app.modules.intelligence.agents.chat_agents.qna_chat_agent import QNAChatAgent +from app.modules.intelligence.agents.chat_agents.unit_test_chat_agent import ( + UnitTestAgent, +) from app.modules.intelligence.agents.custom_agents.custom_agent import CustomAgent +from app.modules.intelligence.provider.provider_service import ( + AgentType, + ProviderService, +) + class AgentFactory: def __init__(self, db: Session, provider_service: ProviderService): self.db = db self.provider_service = provider_service self._agent_cache: Dict[str, Any] = {} - + def get_agent(self, agent_id: str, user_id: str) -> Any: """Get or create an agent instance""" cache_key = f"{agent_id}_{user_id}" - + if cache_key in self._agent_cache: return self._agent_cache[cache_key] - + mini_llm = self.provider_service.get_small_llm(agent_type=AgentType.LANGCHAIN) - reasoning_llm = self.provider_service.get_large_llm(agent_type=AgentType.LANGCHAIN) - + reasoning_llm = self.provider_service.get_large_llm( + agent_type=AgentType.LANGCHAIN + ) + agent = self._create_agent(agent_id, mini_llm, reasoning_llm, user_id) self._agent_cache[cache_key] = agent return agent - - def _create_agent(self, agent_id: str, mini_llm, reasoning_llm, user_id: str) -> Any: + + def _create_agent( + self, agent_id: str, mini_llm, reasoning_llm, user_id: str + ) -> Any: """Create a new agent instance""" agent_map = { - "debugging_agent": lambda: DebuggingChatAgent(mini_llm, reasoning_llm, self.db), - "codebase_qna_agent": lambda: QNAChatAgent(mini_llm, reasoning_llm, self.db), + "debugging_agent": lambda: DebuggingChatAgent( + mini_llm, reasoning_llm, self.db + ), + "codebase_qna_agent": lambda: QNAChatAgent( + mini_llm, reasoning_llm, self.db + ), "unit_test_agent": lambda: UnitTestAgent(mini_llm, reasoning_llm, self.db), - "integration_test_agent": lambda: IntegrationTestChatAgent(mini_llm, reasoning_llm, self.db), - "code_changes_agent": lambda: CodeChangesChatAgent(mini_llm, reasoning_llm, self.db), + "integration_test_agent": lambda: IntegrationTestChatAgent( + mini_llm, reasoning_llm, self.db + ), + "code_changes_agent": lambda: CodeChangesChatAgent( + mini_llm, reasoning_llm, self.db + ), "LLD_agent": lambda: LLDChatAgent(mini_llm, reasoning_llm, self.db), - "code_generation_agent": lambda: CodeGenerationChatAgent(mini_llm, reasoning_llm, self.db), + "code_generation_agent": lambda: CodeGenerationChatAgent( + mini_llm, reasoning_llm, self.db + ), } - + if agent_id in agent_map: return agent_map[agent_id]() - + # If not a system agent, create custom agent return CustomAgent( - llm=reasoning_llm, - db=self.db, - agent_id=agent_id, - user_id=user_id - ) \ No newline at end of file + llm=reasoning_llm, db=self.db, agent_id=agent_id, user_id=user_id + ) diff --git a/app/modules/intelligence/agents/chat_agents/qna_chat_agent.py b/app/modules/intelligence/agents/chat_agents/qna_chat_agent.py index ef72ac0d..cacfdf9b 100644 --- a/app/modules/intelligence/agents/chat_agents/qna_chat_agent.py +++ b/app/modules/intelligence/agents/chat_agents/qna_chat_agent.py @@ -132,7 +132,6 @@ async def run( async for chunk in graph.astream(state, stream_mode="custom"): if isinstance(chunk, str): yield chunk - async def execute( self, From 5f428d6aab373c6caba926d27d3d112fd3babe7a Mon Sep 17 00:00:00 2001 From: dhirenmathur Date: Fri, 20 Dec 2024 17:54:14 +0530 Subject: [PATCH 6/6] pre-commit --- .../conversations/conversation/conversation_service.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/app/modules/conversations/conversation/conversation_service.py b/app/modules/conversations/conversation/conversation_service.py index 3cb9d302..c6578706 100644 --- a/app/modules/conversations/conversation/conversation_service.py +++ b/app/modules/conversations/conversation/conversation_service.py @@ -80,12 +80,10 @@ def __init__(self, db, provider_service): self.agent_factory = AgentFactory(db, provider_service) async def initialize(self, user_id: str): - # Get available agents using AgentsService available_agents = await self.agents_service.list_available_agents( current_user={"user_id": user_id}, list_system_agents=True ) - # Create agent instances dictionary self.agents = { agent.id: self.agent_factory.get_agent(agent.id, user_id) for agent in available_agents @@ -93,7 +91,6 @@ async def initialize(self, user_id: str): self.llm = self.provider_service.get_small_llm(user_id) - # Enhanced classifier prompt with agent descriptions self.classifier_prompt = """ Given the user query and the current agent ID, select the most appropriate agent by comparing the query’s requirements with each agent’s specialties. @@ -137,7 +134,6 @@ async def initialize(self, user_id: str): - Overlapping domains, choose more specialized: `choose_higher_expertise_agent|0.80` """ - # Format agent descriptions for the prompt self.agent_descriptions = "\n".join( [f"- {agent.id}: {agent.description}" for agent in available_agents] ) @@ -164,7 +160,6 @@ async def classifier_node(self, state: State) -> Command: ) response = await self.llm.ainvoke(prompt) - # Parse response try: agent_id, confidence = response.content.split("|") confidence = float(confidence) @@ -199,10 +194,8 @@ def build_graph(self) -> StateGraph: # Add classifier as entry point builder.add_node("classifier", self.classifier_node) - # builder.add_edge("classifier", END) - # # Add agent nodes - # node_func = await self.agent_node(self.State, StreamWriter) + # Add agent nodes for agent_id in self.agents: builder.add_node(agent_id, self.agent_node) builder.add_edge(agent_id, END)