Skip to content

Commit

Permalink
Better error handling for Sage.
Browse files Browse the repository at this point in the history
  • Loading branch information
QuinnDamerell committed Feb 11, 2025
1 parent 511e54b commit 39785f4
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 48 deletions.
15 changes: 14 additions & 1 deletion homeway/homeway/Proto/SageStreamMessage.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,15 @@ def StatusCode(self):
return self._tab.Get(octoflatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
return 0

# SageStreamMessage
def ErrorMessage(self) -> Optional[str]:
o = octoflatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
if o != 0:
return self._tab.String(o + self._tab.Pos)
return None

def SageStreamMessageStart(builder: octoflatbuffers.Builder):
builder.StartObject(8)
builder.StartObject(9)

def Start(builder: octoflatbuffers.Builder):
SageStreamMessageStart(builder)
Expand Down Expand Up @@ -170,6 +177,12 @@ def SageStreamMessageAddStatusCode(builder: octoflatbuffers.Builder, statusCode:
def AddStatusCode(builder: octoflatbuffers.Builder, statusCode: int):
SageStreamMessageAddStatusCode(builder, statusCode)

def SageStreamMessageAddErrorMessage(builder: octoflatbuffers.Builder, errorMessage: int):
builder.PrependUOffsetTRelativeSlot(8, octoflatbuffers.number_types.UOffsetTFlags.py_type(errorMessage), 0)

def AddErrorMessage(builder: octoflatbuffers.Builder, errorMessage: int):
SageStreamMessageAddErrorMessage(builder, errorMessage)

def SageStreamMessageEnd(builder: octoflatbuffers.Builder) -> int:
return builder.EndObject()

Expand Down
78 changes: 56 additions & 22 deletions homeway/homeway_linuxhost/sage/fibermanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import struct
import logging
import threading
from typing import List, Dict
from typing import List, Dict, Optional
import octoflatbuffers

from homeway.sentry import Sentry
Expand Down Expand Up @@ -81,12 +81,16 @@ def createDataContextOffset(builder:octoflatbuffers.Builder) -> int:
class ResponseContext:
Text:str = None
StatusCode:int = None
ErrorMessage:str = None
response:ResponseContext = ResponseContext()
async def onDataStreamReceived(statusCode:int, data:bytearray, dataContext:SageDataContext, isFinalDataChunk:bool):
async def onDataStreamReceived(statusCode:int, errorMessage:Optional[str], data:bytearray, dataContext:SageDataContext, isFinalDataChunk:bool):
# For listen, this should only be called once
if response.StatusCode is not None:
raise Exception("Sage Listen onDataStreamReceived called more than once.")

# Set the error string, to either a string or None
response.ErrorMessage = errorMessage

# Check for a failure, which can happen at anytime.
# If we have anything but a 200, stop processing now.
response.StatusCode = statusCode
Expand All @@ -107,8 +111,12 @@ async def onDataStreamReceived(statusCode:int, data:bytearray, dataContext:SageD
# Do the operation, stream or wait for the response.
result = await self._SendAndReceive(SageOperationTypes.Listen, audio, createDataContextOffset, onDataStreamReceived, isTransmissionDone)

# If the status code is set at any time and not 200, we failed, regardless of the mode.
# Check this before the function result, so we get the status code.
# If the error string or status code is set at any time, we failed, regardless of the mode.
# Check these before anything else, to ensure they get handled.
if response.ErrorMessage is not None:
# Return the raw string, since it's intended to send to the user.
return ListenResult.Failure(response.ErrorMessage)

if response.StatusCode is not None and response.StatusCode != 200:
# In some special cases, we want to map the status code to a user message.
errorStr = self._MapErrorStatusCodeToUserStr(response.StatusCode)
Expand Down Expand Up @@ -148,8 +156,11 @@ def createDataContextOffset(builder:octoflatbuffers.Builder) -> int:
# to stream back.
class ResponseContext:
StatusCode:int = None
ErrorMessage:str = None
response:ResponseContext = ResponseContext()
async def onDataStreamReceived(statusCode:int, data:bytearray, dataContext:SageDataContext, isFinalDataChunk:bool):
async def onDataStreamReceived(statusCode:int, errorMessage:Optional[str], data:bytearray, dataContext:SageDataContext, isFinalDataChunk:bool):
# Set the error string, to either a string or None
response.ErrorMessage = errorMessage

# Check for a failure, which can happen at anytime.
# If we have anything but a 200, stop processing now.
Expand All @@ -167,8 +178,10 @@ async def onDataStreamReceived(statusCode:int, data:bytearray, dataContext:SageD
requestBytes = json.dumps(request).encode("utf-8")
result = await self._SendAndReceive(SageOperationTypes.Speak, requestBytes, createDataContextOffset, onDataStreamReceived)

# If the status code is set at any time and not 200, we failed, regardless of the mode.
# Check this before the function result, so we get the status code.
# Check if either of the error systems are set.
if response.ErrorMessage is not None:
self.Logger.error(f"Sage Speak failed with error msg: {response.ErrorMessage}")
return False
if response.StatusCode is not None and response.StatusCode != 200:
self.Logger.error(f"Sage Speak failed with status code {response.StatusCode}")
return False
Expand All @@ -182,7 +195,7 @@ async def onDataStreamReceived(statusCode:int, data:bytearray, dataContext:SageD

# Takes a chat json object and returns the assistant's response text.
# Returns None on failure.
async def Chat(self, requestJson:str, homeContext_CanBeNone:CompressionResult, states_CanBeNone:CompressionResult) -> bytearray:
async def Chat(self, requestJson:str, homeContext_CanBeNone:CompressionResult, states_CanBeNone:CompressionResult) -> str:

# Our data type is a json string.
def createDataContextOffset(builder:octoflatbuffers.Builder) -> int:
Expand All @@ -191,13 +204,17 @@ def createDataContextOffset(builder:octoflatbuffers.Builder) -> int:
# We expect the onDataStreamReceived handler to be called once, with the full response.
class ResponseContext:
Bytes = None
StatusCode = None
StatusCode:int = None
ErrorMessage:str = None
response:ResponseContext = ResponseContext()
async def onDataStreamReceived(statusCode:int, data:bytearray, dataContext:SageDataContext, isFinalDataChunk:bool):
async def onDataStreamReceived(statusCode:int, errorMessage:Optional[str], data:bytearray, dataContext:SageDataContext, isFinalDataChunk:bool):
# For Chat, this should only be called once
if response.StatusCode is not None:
raise Exception("Sage Chat onDataStreamReceived called more than once.")

# Set the error string, to either a string or None
response.ErrorMessage = errorMessage

# Check for a failure, which can happen at anytime.
# If we have anything but a 200, stop processing now.
response.StatusCode = statusCode
Expand All @@ -220,13 +237,18 @@ async def onDataStreamReceived(statusCode:int, data:bytearray, dataContext:SageD
# Do the operation, wait for the result.
result = await self._SendAndReceive(SageOperationTypes.Chat, data, createDataContextOffset, onDataStreamReceived, True)

# If the status code is set at any time and not 200, we failed.
# Check this before the function result, so we get the status code.
# If the error string or status code is set at any time, we failed, regardless of the mode.
# Check these before anything else, to ensure they get handled.
# Note, the server might also return a valid response but override the response text with the error.
if response.ErrorMessage is not None:
# Return the raw string, since it's intended to send to the user.
return response.ErrorMessage

if response.StatusCode is not None and response.StatusCode != 200:
# In some special cases, we want to map the status code to a user message.
userError = self._MapErrorStatusCodeToUserStr(response.StatusCode)
if userError is not None:
return userError
errorStr = self._MapErrorStatusCodeToUserStr(response.StatusCode)
if errorStr is not None:
return errorStr
self.Logger.error(f"Sage Chat failed with status code {response.StatusCode}")
return None

Expand Down Expand Up @@ -316,7 +338,8 @@ async def _SendAndReceive(self,
# If we are doing the upload stream, check that the stream wasn't aborted from the server side before returning.
if context.StatusCode is not None or context.IsAborted:
# If the stream was aborted, fire the callback and return False
await onDataStreamReceivedCallbackAsync(context.StatusCode, bytearray(), None, True)
# ErrorMessage is either None or a string to send to the user.
await onDataStreamReceivedCallbackAsync(context.StatusCode, context.ErrorMessage, bytearray(), None, True)
return False

# Otherwise, keep the context around and return success.
Expand All @@ -333,6 +356,7 @@ async def _SendAndReceive(self,
data:List[bytearray] = None
dataContext:SageDataContext = None
statusCode:int = None
errorMessage:str = None
isDataDownloadComplete:bool = True
with self.StreamContextLock:
# Grab all of the data we currently have and process it.
Expand All @@ -341,21 +365,22 @@ async def _SendAndReceive(self,
isDataDownloadComplete = context.IsDataDownloadComplete
dataContext = context.DataContext
statusCode = context.StatusCode
errorMessage = context.ErrorMessage

# Clear the event under lock to ensure we don't miss a set.
context.Event.clear()

# Check for a stream abort, before we check if there's no data.
if context.IsAborted:
self.Logger.error(f"Sage message stream was aborted: {context.StatusCode}")
await onDataStreamReceivedCallbackAsync(context.StatusCode, bytearray(), None, True)
self.Logger.error(f"Sage message stream was aborted: {statusCode}")
await onDataStreamReceivedCallbackAsync(statusCode, errorMessage, bytearray(), None, True)
return False

# Regardless of the other vars, if we didn't get any data, the response wait timed out.
if len(data) == 0:
self.Logger.error("Sage message timed out while waiting for a response.")
context.StatusCode = 608
await onDataStreamReceivedCallbackAsync(context.StatusCode, bytearray(), None, True)
await onDataStreamReceivedCallbackAsync(statusCode, errorMessage, bytearray(), None, True)
return False

# Process the data
Expand All @@ -368,7 +393,7 @@ async def _SendAndReceive(self,
for d in data:
count += 1
isLastChunk = isDataDownloadComplete and count == len(data)
if await onDataStreamReceivedCallbackAsync(statusCode, d, dataContext, isLastChunk) is False:
if await onDataStreamReceivedCallbackAsync(statusCode, errorMessage, d, dataContext, isLastChunk) is False:
return False

# If we processed all the data and the stream is done, we're done.
Expand Down Expand Up @@ -530,6 +555,10 @@ def OnIncomingMessage(self, buf:bytearray):
self.Logger.debug(f"Sage got a message for stream [{streamId}] but couldn't find the context.")
return

# If there's an error string, always set or update it.
if msg.ErrorMessage() is not None:
context.ErrorMessage = msg.ErrorMessage().decode("utf-8")

# The abort message doesn't require the other checks.
if msg.IsAbortMsg():
# The abort message should always have a non-success status code.
Expand Down Expand Up @@ -575,7 +604,7 @@ def OnIncomingMessage(self, buf:bytearray):
raise Exception("Sage Fiber message has is missing the data context.")
context.DataContext = dataContext

# Append the data.
# Append the data.
context.Data.append(data)

# Set the event so the caller can consume the data.
Expand Down Expand Up @@ -636,8 +665,13 @@ def __init__(self, streamId:int, requestType:SageOperationTypes):
# Note this will be set multiple times if the response is being streamed.
self.Event = threading.Event()

# Set when the first response is received.
# Usually set to 200 for success otherwise it can be an error.
self.StatusCode:int = None

# Optional - If set, this is an error message that should be shown to the user.
self.ErrorMessage:str = None

# Set when the first response is received.
self.DataContext:SageDataContext = None

# When set, the data has been fully downloaded.
Expand Down
27 changes: 20 additions & 7 deletions homeway/homeway_linuxhost/sage/sagehandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ class SageHandler(AsyncEventHandler):
# This is more of a sanity check, the server will enforce it's own limits.
c_MaxStringLength = 500

# A quick static cache of the info event, since it's gets called in rapid succession sometimes.
c_InfoEventMaxCacheTimeSec = 15.0
# Cache the describe info because it doesn't change often and it's sometimes requested many times rapidly.
# Sometimes the describe event is called before an action for some reason and is blocking, so the cache helps there as well.
# Since this hardly ever changes, there's no reason to update it often. It will update when the addon restarts.
c_InfoEventMaxCacheTimeSec = 60 * 60 * 24 * 3
s_InfoEvent:Info = None
s_InfoEventTime:float = 0.0
s_InfoEventLock = threading.Lock()
Expand Down Expand Up @@ -238,19 +240,30 @@ async def _HandleDescribe(self) -> bool:
attempt += 1

# Attempt getting a valid response.
serviceInfo:Info = None
try:
self.Logger.debug(f"Sage - Starting Info Service Request - {url}")
# Attempt to get and parse the info object.
response = requests.get(url, timeout=10)
if response.status_code == 200:
info = self._BuildInfoEvent(response.json())
if info is None:
serviceInfo = self._BuildInfoEvent(response.json())
if serviceInfo is None:
raise Exception("Failed to build info event.")
await self._CacheAndWriteInfoEvent(info)
return True
self.Logger.warning(f"Sage - Failed to get models from service. Attempt: {attempt} - {response.status_code}")
else:
self.Logger.warning(f"Sage - Failed to get models from service. Attempt: {attempt} - {response.status_code}")
except Exception as e:
self.Logger.warning(f"Sage - Failed to get models from service. Attempt: {attempt} - {e}")

# If we got service info, try to write it to the client.
if serviceInfo is not None:
self.Logger.debug("Sage - Service request successful, sending to Wyoming protocol.")
try:
await self._CacheAndWriteInfoEvent(serviceInfo)
# Success
return True
except Exception as e:
self.Logger.warning(f"Sage - Failed to send info to wyoming protocol. Attempt: {attempt} - {e}")

# If we fail, try a few times. Throw when we hit the limit.
if attempt > 3:
raise Exception("Failed to get models from service after 3 attempts.")
Expand Down
Loading

0 comments on commit 39785f4

Please # to comment.