Skip to content

Commit

Permalink
Merge pull request #243 from potpie-ai/cs_optimization
Browse files Browse the repository at this point in the history
avoiding creation of multiple agent nodes to improve latency
  • Loading branch information
dhirenmathur authored Feb 5, 2025
2 parents ca1a73b + 89fc352 commit 0360f58
Showing 1 changed file with 58 additions and 49 deletions.
107 changes: 58 additions & 49 deletions app/modules/conversations/conversation/conversation_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,23 +74,18 @@ class SimplifiedAgentSupervisor:
def __init__(self, db, provider_service):
self.db = db
self.provider_service = provider_service
self.agents = {}
self.agent = None
self.current_agent_id = None
self.classifier = None
self.agents_service = AgentsService(db)
self.agent_factory = AgentFactory(db, provider_service)
self.available_agents = []

async def initialize(self, user_id: str):
available_agents = await self.agents_service.list_available_agents(
self.available_agents = await self.agents_service.list_available_agents(
current_user={"user_id": user_id}, list_system_agents=True
)
self.available_agents = available_agents
self.agents = {
agent.id: self.agent_factory.get_agent(agent.id, user_id)
for agent in available_agents
}
self.llm = self.provider_service.get_small_llm(user_id)

self.classifier_prompt = """
Given the user query and the current agent ID, select the most appropriate agent by comparing the query’s requirements with each agent’s specialties.
Expand Down Expand Up @@ -135,7 +130,7 @@ async def initialize(self, user_id: str):
"""

self.agent_descriptions = "\n".join(
[f"- {agent.id}: {agent.description}" for agent in available_agents]
[f"- {agent.id}: {agent.description}" for agent in self.available_agents]
)

class State(TypedDict):
Expand All @@ -151,69 +146,83 @@ async def classifier_node(self, state: State) -> Command:
"""Classifies the query and routes to appropriate agent"""
if not state.get("query"):
return Command(update={"response": "No query provided"}, goto=END)
agent_list = {agent.id: agent.status for agent in self.available_agents}

# Do not route for custom agents
if (
state["agent_id"] in agent_list
and agent_list[state["agent_id"]] != "SYSTEM"
):
return Command(
update={"agent_id": state["agent_id"]}, goto=state["agent_id"]
)

# Classification using LLM with enhanced prompt
agent_list = {agent.id: agent.status for agent in self.available_agents}

# First check - if this is a custom agent (non-SYSTEM), route directly
if state["agent_id"] in agent_list and agent_list[state["agent_id"]] != "SYSTEM":
# Initialize the agent if needed
if not self.agent or self.current_agent_id != state["agent_id"]:
try:
self.agent = self.agent_factory.get_agent(state["agent_id"], state["user_id"])
self.current_agent_id = state["agent_id"]
except Exception as e:
logger.error(f"Failed to create agent {state['agent_id']}: {e}")
return Command(update={"response": "Failed to initialize agent"}, goto=END)
return Command(update={"agent_id": state["agent_id"]}, goto="agent_node")

# For system agents, perform classification
prompt = self.classifier_prompt.format(
query=state["query"],
agent_id=state["agent_id"],
agent_descriptions=self.agent_descriptions,
)

response = await self.llm.ainvoke(prompt)
response = response.content.strip("`")
try:
agent_id, confidence = response.split("|")
confidence = float(confidence)
selected_agent_id = agent_id if confidence >= 0.5 and agent_id in agent_list else state["agent_id"]
except (ValueError, TypeError):
return Command(
update={"response": "Error in classification format"}, goto=END
)
if confidence < 0.5 or agent_id not in self.agents:
logger.info(
f"Streaming AI response for conversation {state['conversation_id']} for user {state['user_id']} using agent {agent_id}"
)
return Command(
update={"agent_id": state["agent_id"]}, goto=state["agent_id"]
)
logger.error("Classification format error, falling back to current agent")
selected_agent_id = state["agent_id"]

# Initialize the selected system agent
if not self.agent or self.current_agent_id != selected_agent_id:
try:
self.agent = self.agent_factory.get_agent(selected_agent_id, state["user_id"])
self.current_agent_id = selected_agent_id
except Exception as e:
logger.error(f"Failed to create agent {selected_agent_id}: {e}")
return Command(update={"response": "Failed to initialize agent"}, goto=END)

logger.info(
f"Streaming AI response for conversation {state['conversation_id']} for user {state['user_id']} using agent {agent_id}"
f"Streaming AI response for conversation {state['conversation_id']} "
f"for user {state['user_id']} using agent {selected_agent_id}"
)
return Command(update={"agent_id": agent_id}, goto=agent_id)
return Command(update={"agent_id": selected_agent_id}, goto="agent_node")

async def agent_node(self, state: State, writer: StreamWriter):
"""Creates a node function for a specific agent"""
agent = self.agents[state["agent_id"]]
async for chunk in agent.run(
query=state["query"],
project_id=state["project_id"],
conversation_id=state["conversation_id"],
user_id=state["user_id"],
node_ids=state["node_ids"],
):
if isinstance(chunk, str):
writer(chunk)
"""Single agent node that uses the created agent"""
if not self.agent:
logger.error("Agent not initialized before agent_node execution")
return Command(update={"response": "Agent not initialized"}, goto=END)

try:
async for chunk in self.agent.run(
query=state["query"],
project_id=state["project_id"],
conversation_id=state["conversation_id"],
user_id=state["user_id"],
node_ids=state["node_ids"],
):
if isinstance(chunk, str):
writer(chunk)
except Exception as e:
logger.error(f"Error in agent execution: {e}")
writer("An error occurred while processing your request")

def build_graph(self) -> StateGraph:
"""Builds the graph with classifier and agent nodes"""
"""Builds simplified graph with classifier and single agent node"""
builder = StateGraph(self.State)

# Add classifier as entry point
builder.add_node("classifier", self.classifier_node)

# Add agent nodes
for agent_id in self.agents:
builder.add_node(agent_id, self.agent_node)
builder.add_edge(agent_id, END)
# Add single agent node
builder.add_node("agent_node", self.agent_node)
builder.add_edge("agent_node", END)

builder.set_entry_point("classifier")
return builder.compile()
Expand Down

0 comments on commit 0360f58

Please # to comment.