diff --git a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs index 70fddc68718..e1e4542d5d0 100644 --- a/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs +++ b/src/Libraries/Microsoft.Extensions.AI/ChatCompletion/FunctionInvokingChatClient.cs @@ -216,7 +216,7 @@ public override async Task CompleteAsync(IList chat // doesn't realize this and is wasting their budget requesting extra choices we'd never use. if (response.Choices.Count > 1) { - throw new InvalidOperationException($"Automatic function call invocation only accepts a single choice, but {response.Choices.Count} choices were received."); + ThrowForMultipleChoices(); } // Extract any function call contents on the first choice. If there are none, we're done. @@ -301,22 +301,47 @@ public override async IAsyncEnumerable CompleteSt _ = Throw.IfNull(chatMessages); HashSet? messagesToRemove = null; + List functionCallContents = []; + int? choice; try { for (int iteration = 0; ; iteration++) { - List? functionCallContents = null; - await foreach (var chunk in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) + choice = null; + functionCallContents.Clear(); + await foreach (var update in base.CompleteStreamingAsync(chatMessages, options, cancellationToken).ConfigureAwait(false)) { // We're going to emit all StreamingChatMessage items upstream, even ones that represent - // function calls, because a given StreamingChatMessage can contain other content too. - yield return chunk; + // function calls, because a given StreamingChatMessage can contain other content, too. + // And if we yield the function calls, and the consumer adds all the content into a message + // that's then added into history, they'll end up with function call contents that aren't + // directly paired with function result contents, which may cause issues for some models + // when the history is later sent again. + + // Find all the FCCs. We need to track these separately in order to be able to process them later. + int preFccCount = functionCallContents.Count; + functionCallContents.AddRange(update.Contents.OfType()); + + // If there were any, remove them from the update. We do this before yielding the update so + // that we're not modifying an instance already provided back to the caller. + int addedFccs = functionCallContents.Count - preFccCount; + if (addedFccs > preFccCount) + { + update.Contents = addedFccs == update.Contents.Count ? + [] : update.Contents.Where(c => c is not FunctionCallContent).ToList(); + } - foreach (var item in chunk.Contents.OfType()) + // Only one choice is allowed with automatic function calling. + if (choice is null) + { + choice = update.ChoiceIndex; + } + else if (choice != update.ChoiceIndex) { - functionCallContents ??= []; - functionCallContents.Add(item); + ThrowForMultipleChoices(); } + + yield return update; } // If there are no tools to call, or for any other reason we should stop, return the response. @@ -373,6 +398,16 @@ public override async IAsyncEnumerable CompleteSt } } + /// Throws an exception when multiple choices are received. + private static void ThrowForMultipleChoices() + { + // If there's more than one choice, we don't know which one to add to chat history, or which + // of their function calls to process. This should not happen except if the developer has + // explicitly requested multiple choices. We fail aggressively to avoid cases where a developer + // doesn't realize this and is wasting their budget requesting extra choices we'd never use. + throw new InvalidOperationException("Automatic function call invocation only accepts a single choice, but multiple choices were received."); + } + /// /// Removes all of the messages in from /// and all of the content in from the messages in . diff --git a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs index d9df2fc89e3..da983243acb 100644 --- a/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs +++ b/test/Libraries/Microsoft.Extensions.AI.Tests/ChatCompletion/FunctionInvokingChatClientTests.cs @@ -12,6 +12,8 @@ using OpenTelemetry.Trace; using Xunit; +#pragma warning disable SA1118 // Parameter should not span multiple lines + namespace Microsoft.Extensions.AI; public class FunctionInvokingChatClientTests @@ -41,14 +43,16 @@ public async Task SupportsSingleFunctionCallPerRequestAsync() { var options = new ChatOptions { - Tools = [ + Tools = + [ AIFunctionFactory.Create(() => "Result 1", "Func1"), AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"), AIFunctionFactory.Create((int i) => { }, "VoidReturn"), ] }; - await InvokeAndAssertAsync(options, [ + List plan = + [ new ChatMessage(ChatRole.User, "hello"), new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]), @@ -57,7 +61,11 @@ await InvokeAndAssertAsync(options, [ new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } })]), new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]), new ChatMessage(ChatRole.Assistant, "world"), - ]); + ]; + + await InvokeAndAssertAsync(options, plan); + + await InvokeAndAssertStreamingAsync(options, plan); } [Theory] @@ -67,31 +75,46 @@ public async Task SupportsMultipleFunctionCallsPerRequestAsync(bool concurrentIn { var options = new ChatOptions { - Tools = [ + Tools = + [ AIFunctionFactory.Create((int i) => "Result 1", "Func1"), AIFunctionFactory.Create((int i) => $"Result 2: {i}", "Func2"), ] }; - await InvokeAndAssertAsync(options, [ + + List plan = + [ new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, [ + new ChatMessage(ChatRole.Assistant, + [ new FunctionCallContent("callId1", "Func1"), new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 34 } }), new FunctionCallContent("callId3", "Func2", arguments: new Dictionary { { "i", 56 } }), ]), - new ChatMessage(ChatRole.Tool, [ + new ChatMessage(ChatRole.Tool, + [ new FunctionResultContent("callId1", "Func1", result: "Result 1"), new FunctionResultContent("callId2", "Func2", result: "Result 2: 34"), new FunctionResultContent("callId3", "Func2", result: "Result 2: 56"), ]), - new ChatMessage(ChatRole.Assistant, [ + new ChatMessage(ChatRole.Assistant, + [ new FunctionCallContent("callId4", "Func2", arguments: new Dictionary { { "i", 78 } }), - new FunctionCallContent("callId5", "Func1")]), - new ChatMessage(ChatRole.Tool, [ + new FunctionCallContent("callId5", "Func1") + ]), + new ChatMessage(ChatRole.Tool, + [ new FunctionResultContent("callId4", "Func2", result: "Result 2: 78"), - new FunctionResultContent("callId5", "Func1", result: "Result 1")]), + new FunctionResultContent("callId5", "Func1", result: "Result 1") + ]), new ChatMessage(ChatRole.Assistant, "world"), - ], configurePipeline: b => b.Use(s => new FunctionInvokingChatClient(s) { ConcurrentInvocation = concurrentInvocation })); + ]; + + Func configure = b => b.Use(s => new FunctionInvokingChatClient(s) { ConcurrentInvocation = concurrentInvocation }); + + await InvokeAndAssertAsync(options, plan, configurePipeline: configure); + + await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure); } [Fact] @@ -101,7 +124,8 @@ public async Task ParallelFunctionCallsMayBeInvokedConcurrentlyAsync() var options = new ChatOptions { - Tools = [ + Tools = + [ AIFunctionFactory.Create((string arg) => { barrier.SignalAndWait(); @@ -110,18 +134,27 @@ public async Task ParallelFunctionCallsMayBeInvokedConcurrentlyAsync() ] }; - await InvokeAndAssertAsync(options, [ + List plan = + [ new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, [ + new ChatMessage(ChatRole.Assistant, + [ new FunctionCallContent("callId1", "Func", arguments: new Dictionary { { "arg", "hello" } }), new FunctionCallContent("callId2", "Func", arguments: new Dictionary { { "arg", "world" } }), ]), - new ChatMessage(ChatRole.Tool, [ + new ChatMessage(ChatRole.Tool, + [ new FunctionResultContent("callId1", "Func", result: "hellohello"), new FunctionResultContent("callId2", "Func", result: "worldworld"), ]), new ChatMessage(ChatRole.Assistant, "done"), - ], configurePipeline: b => b.Use(s => new FunctionInvokingChatClient(s) { ConcurrentInvocation = true })); + ]; + + Func configure = b => b.Use(s => new FunctionInvokingChatClient(s) { ConcurrentInvocation = true }); + + await InvokeAndAssertAsync(options, plan, configurePipeline: configure); + + await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure); } [Fact] @@ -131,7 +164,8 @@ public async Task ConcurrentInvocationOfParallelCallsDisabledByDefaultAsync() var options = new ChatOptions { - Tools = [ + Tools = + [ AIFunctionFactory.Create(async (string arg) => { Interlocked.Increment(ref activeCount); @@ -143,18 +177,25 @@ public async Task ConcurrentInvocationOfParallelCallsDisabledByDefaultAsync() ] }; - await InvokeAndAssertAsync(options, [ + List plan = + [ new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, [ + new ChatMessage(ChatRole.Assistant, + [ new FunctionCallContent("callId1", "Func", arguments: new Dictionary { { "arg", "hello" } }), new FunctionCallContent("callId2", "Func", arguments: new Dictionary { { "arg", "world" } }), ]), - new ChatMessage(ChatRole.Tool, [ + new ChatMessage(ChatRole.Tool, + [ new FunctionResultContent("callId1", "Func", result: "hellohello"), new FunctionResultContent("callId2", "Func", result: "worldworld"), ]), new ChatMessage(ChatRole.Assistant, "done"), - ]); + ]; + + await InvokeAndAssertAsync(options, plan); + + await InvokeAndAssertStreamingAsync(options, plan); } [Theory] @@ -172,36 +213,40 @@ public async Task RemovesFunctionCallingMessagesWhenRequestedAsync(bool keepFunc ] }; -#pragma warning disable SA1118 // Parameter should not span multiple lines - var finalChat = await InvokeAndAssertAsync( - options, - [ - new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func2", result: "Result 2: 42")]), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } })]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]), - new ChatMessage(ChatRole.Assistant, "world"), - ], - expected: keepFunctionCallingMessages ? - null : - [ - new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, "world") - ], - configurePipeline: b => b.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages })); -#pragma warning restore SA1118 - - IEnumerable content = finalChat.SelectMany(m => m.Contents); - if (keepFunctionCallingMessages) - { - Assert.Contains(content, c => c is FunctionCallContent or FunctionResultContent); - } - else + List plan = + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func2", result: "Result 2: 42")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]), + new ChatMessage(ChatRole.Assistant, "world"), + ]; + + List? expected = keepFunctionCallingMessages ? null : + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, "world") + ]; + + Func configure = b => b.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages }); + + Validate(await InvokeAndAssertAsync(options, plan, expected, configure)); + Validate(await InvokeAndAssertStreamingAsync(options, plan, expected, configure)); + + void Validate(List finalChat) { - Assert.All(content, c => Assert.False(c is FunctionCallContent or FunctionResultContent)); + IEnumerable content = finalChat.SelectMany(m => m.Contents); + if (keepFunctionCallingMessages) + { + Assert.Contains(content, c => c is FunctionCallContent or FunctionResultContent); + } + else + { + Assert.All(content, c => Assert.False(c is FunctionCallContent or FunctionResultContent)); + } } } @@ -220,37 +265,56 @@ public async Task RemovesFunctionCallingContentWhenRequestedAsync(bool keepFunct ] }; -#pragma warning disable SA1118 // Parameter should not span multiple lines - var finalChat = await InvokeAndAssertAsync(options, - [ - new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, [new TextContent("extra"), new FunctionCallContent("callId1", "Func1"), new TextContent("stuff")]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func1", result: "Result 1")]), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func2", result: "Result 2: 42")]), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } }), new TextContent("more")]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]), - new ChatMessage(ChatRole.Assistant, "world"), - ], - expected: keepFunctionCallingMessages ? - null : - [ - new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, [new TextContent("extra"), new TextContent("stuff")]), - new ChatMessage(ChatRole.Assistant, "more"), - new ChatMessage(ChatRole.Assistant, "world"), - ], - configurePipeline: b => b.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages })); -#pragma warning restore SA1118 - - IEnumerable content = finalChat.SelectMany(m => m.Contents); - if (keepFunctionCallingMessages) - { - Assert.Contains(content, c => c is FunctionCallContent or FunctionResultContent); - } - else + List plan = + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new TextContent("extra"), new FunctionCallContent("callId1", "Func1"), new TextContent("stuff")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func2", result: "Result 2: 42")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } }), new TextContent("more")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]), + new ChatMessage(ChatRole.Assistant, "world"), + ]; + + Func configure = b => b.Use(client => new FunctionInvokingChatClient(client) { KeepFunctionCallingMessages = keepFunctionCallingMessages }); + +#pragma warning disable SA1005, S125 + Validate(await InvokeAndAssertAsync(options, plan, keepFunctionCallingMessages ? null : + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new TextContent("extra"), new TextContent("stuff")]), + new ChatMessage(ChatRole.Assistant, "more"), + new ChatMessage(ChatRole.Assistant, "world"), + ], configure)); + + Validate(await InvokeAndAssertStreamingAsync(options, plan, keepFunctionCallingMessages ? + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId2", "Func2", arguments: new Dictionary { { "i", 42 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId2", "Func2", result: "Result 2: 42")]), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId3", "VoidReturn", arguments: new Dictionary { { "i", 43 } })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId3", "VoidReturn", result: "Success: Function completed.")]), + new ChatMessage(ChatRole.Assistant, "extrastuffmoreworld"), + ] : + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, "extrastuffmoreworld"), + ], configure)); + + void Validate(List finalChat) { - Assert.All(content, c => Assert.False(c is FunctionCallContent or FunctionResultContent)); + IEnumerable content = finalChat.SelectMany(m => m.Contents); + if (keepFunctionCallingMessages) + { + Assert.Contains(content, c => c is FunctionCallContent or FunctionResultContent); + } + else + { + Assert.All(content, c => Assert.False(c is FunctionCallContent or FunctionResultContent)); + } } } @@ -267,12 +331,19 @@ public async Task ExceptionDetailsOnlyReportedWhenRequestedAsync(bool detailedEr ] }; - await InvokeAndAssertAsync(options, [ + List plan = + [ new ChatMessage(ChatRole.User, "hello"), new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1")]), new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: detailedErrors ? "Error: Function failed. Exception: Oh no!" : "Error: Function failed.")]), new ChatMessage(ChatRole.Assistant, "world"), - ], configurePipeline: b => b.Use(s => new FunctionInvokingChatClient(s) { DetailedErrors = detailedErrors })); + ]; + + Func configure = b => b.Use(s => new FunctionInvokingChatClient(s) { DetailedErrors = detailedErrors }); + + await InvokeAndAssertAsync(options, plan, configurePipeline: configure); + + await InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure); } [Fact] @@ -281,28 +352,36 @@ public async Task RejectsMultipleChoicesAsync() var func1 = AIFunctionFactory.Create(() => "Some result 1", "Func1"); var func2 = AIFunctionFactory.Create(() => "Some result 2", "Func2"); + var expected = new ChatCompletion( + [ + new(ChatRole.Assistant, [new FunctionCallContent("callId1", func1.Metadata.Name)]), + new(ChatRole.Assistant, [new FunctionCallContent("callId2", func2.Metadata.Name)]), + ]); + using var innerClient = new TestChatClient { CompleteAsyncCallback = async (chatContents, options, cancellationToken) => { await Task.Yield(); - - return new ChatCompletion( - [ - new(ChatRole.Assistant, [new FunctionCallContent("callId1", func1.Metadata.Name)]), - new(ChatRole.Assistant, [new FunctionCallContent("callId2", func2.Metadata.Name)]), - ]); - } + return expected; + }, + CompleteStreamingAsyncCallback = (chatContents, options, cancellationToken) => + YieldAsync(expected.ToStreamingChatCompletionUpdates()), }; IChatClient service = innerClient.AsBuilder().UseFunctionInvocation().Build(); List chat = [new ChatMessage(ChatRole.User, "hello")]; - var ex = await Assert.ThrowsAsync( - () => service.CompleteAsync(chat, new ChatOptions { Tools = [func1, func2] })); + ChatOptions options = new() { Tools = [func1, func2] }; - Assert.Contains("only accepts a single choice", ex.Message); - Assert.Single(chat); // It didn't add anything to the chat history + Validate(await Assert.ThrowsAsync(() => service.CompleteAsync(chat, options))); + Validate(await Assert.ThrowsAsync(() => service.CompleteStreamingAsync(chat, options).ToChatCompletionAsync())); + + void Validate(Exception ex) + { + Assert.Contains("only accepts a single choice", ex.Message); + Assert.Single(chat); // It didn't add anything to the chat history + } } [Theory] @@ -311,39 +390,51 @@ public async Task RejectsMultipleChoicesAsync() [InlineData(LogLevel.Information)] public async Task FunctionInvocationsLogged(LogLevel level) { - using CapturingLoggerProvider clp = new(); - - ServiceCollection c = new(); - c.AddLogging(b => b.AddProvider(clp).SetMinimumLevel(level)); - var services = c.BuildServiceProvider(); + List plan = + [ + new ChatMessage(ChatRole.User, "hello"), + new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1", new Dictionary { ["arg1"] = "value1" })]), + new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]), + new ChatMessage(ChatRole.Assistant, "world"), + ]; var options = new ChatOptions { Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")] }; - await InvokeAndAssertAsync(options, [ - new ChatMessage(ChatRole.User, "hello"), - new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1", new Dictionary { ["arg1"] = "value1" })]), - new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]), - new ChatMessage(ChatRole.Assistant, "world"), - ], configurePipeline: b => b.Use(c => new FunctionInvokingChatClient(c, services.GetRequiredService>()))); + Func configure = b => + b.Use((c, services) => new FunctionInvokingChatClient(c, services.GetRequiredService>())); - if (level is LogLevel.Trace) - { - Assert.Collection(clp.Logger.Entries, - entry => Assert.True(entry.Message.Contains("Invoking Func1({") && entry.Message.Contains("\"arg1\": \"value1\"")), - entry => Assert.True(entry.Message.Contains("Func1 invocation completed. Duration:") && entry.Message.Contains("Result: \"Result 1\""))); - } - else if (level is LogLevel.Debug) - { - Assert.Collection(clp.Logger.Entries, - entry => Assert.True(entry.Message.Contains("Invoking Func1") && !entry.Message.Contains("arg1")), - entry => Assert.True(entry.Message.Contains("Func1 invocation completed. Duration:") && !entry.Message.Contains("Result"))); - } - else + await InvokeAsync(services => InvokeAndAssertAsync(options, plan, configurePipeline: configure, services: services)); + + await InvokeAsync(services => InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure, services: services)); + + async Task InvokeAsync(Func work) { - Assert.Empty(clp.Logger.Entries); + using CapturingLoggerProvider clp = new(); + + ServiceCollection c = new(); + c.AddLogging(b => b.AddProvider(clp).SetMinimumLevel(level)); + + await work(c.BuildServiceProvider()); + + if (level is LogLevel.Trace) + { + Assert.Collection(clp.Logger.Entries, + entry => Assert.True(entry.Message.Contains("Invoking Func1({") && entry.Message.Contains("\"arg1\": \"value1\"")), + entry => Assert.True(entry.Message.Contains("Func1 invocation completed. Duration:") && entry.Message.Contains("Result: \"Result 1\""))); + } + else if (level is LogLevel.Debug) + { + Assert.Collection(clp.Logger.Entries, + entry => Assert.True(entry.Message.Contains("Invoking Func1") && !entry.Message.Contains("arg1")), + entry => Assert.True(entry.Message.Contains("Func1 invocation completed. Duration:") && !entry.Message.Contains("Result"))); + } + else + { + Assert.Empty(clp.Logger.Entries); + } } } @@ -353,38 +444,51 @@ await InvokeAndAssertAsync(options, [ public async Task FunctionInvocationTrackedWithActivity(bool enableTelemetry) { string sourceName = Guid.NewGuid().ToString(); - var activities = new List(); - using TracerProvider? tracerProvider = enableTelemetry ? - OpenTelemetry.Sdk.CreateTracerProviderBuilder() - .AddSource(sourceName) - .AddInMemoryExporter(activities) - .Build() : - null; - - var options = new ChatOptions - { - Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")] - }; - await InvokeAndAssertAsync(options, [ + List plan = + [ new ChatMessage(ChatRole.User, "hello"), new ChatMessage(ChatRole.Assistant, [new FunctionCallContent("callId1", "Func1", new Dictionary { ["arg1"] = "value1" })]), new ChatMessage(ChatRole.Tool, [new FunctionResultContent("callId1", "Func1", result: "Result 1")]), new ChatMessage(ChatRole.Assistant, "world"), - ], configurePipeline: b => b.Use(c => - new FunctionInvokingChatClient( - new OpenTelemetryChatClient(c, sourceName: sourceName)))); + ]; - if (enableTelemetry) + ChatOptions options = new() { - Assert.Collection(activities, - activity => Assert.Equal("chat", activity.DisplayName), - activity => Assert.Equal("Func1", activity.DisplayName), - activity => Assert.Equal("chat", activity.DisplayName)); - } - else + Tools = [AIFunctionFactory.Create(() => "Result 1", "Func1")] + }; + + Func configure = b => b.Use(c => + new FunctionInvokingChatClient( + new OpenTelemetryChatClient(c, sourceName: sourceName))); + + await InvokeAsync(() => InvokeAndAssertAsync(options, plan, configurePipeline: configure)); + + await InvokeAsync(() => InvokeAndAssertStreamingAsync(options, plan, configurePipeline: configure)); + + async Task InvokeAsync(Func work) { - Assert.Empty(activities); + var activities = new List(); + using TracerProvider? tracerProvider = enableTelemetry ? + OpenTelemetry.Sdk.CreateTracerProviderBuilder() + .AddSource(sourceName) + .AddInMemoryExporter(activities) + .Build() : + null; + + await work(); + + if (enableTelemetry) + { + Assert.Collection(activities, + activity => Assert.Equal("chat", activity.DisplayName), + activity => Assert.Equal("Func1", activity.DisplayName), + activity => Assert.Equal("chat", activity.DisplayName)); + } + else + { + Assert.Empty(activities); + } } } @@ -392,7 +496,8 @@ private static async Task> InvokeAndAssertAsync( ChatOptions options, List plan, List? expected = null, - Func? configurePipeline = null) + Func? configurePipeline = null, + IServiceProvider? services = null) { Assert.NotEmpty(plan); @@ -400,7 +505,6 @@ private static async Task> InvokeAndAssertAsync( using CancellationTokenSource cts = new(); List chat = [plan[0]]; - int i = 0; using var innerClient = new TestChatClient { @@ -411,11 +515,11 @@ private static async Task> InvokeAndAssertAsync( await Task.Yield(); - return new ChatCompletion([plan[contents.Count]]); + return new ChatCompletion(new ChatMessage(ChatRole.Assistant, [.. plan[contents.Count].Contents])); } }; - IChatClient service = configurePipeline(innerClient.AsBuilder()).Build(); + IChatClient service = configurePipeline(innerClient.AsBuilder()).Build(services); var result = await service.CompleteAsync(chat, options, cts.Token); chat.Add(result.Message); @@ -423,7 +527,7 @@ private static async Task> InvokeAndAssertAsync( expected ??= plan; Assert.NotNull(result); Assert.Equal(expected.Count, chat.Count); - for (; i < expected.Count; i++) + for (int i = 0; i < expected.Count; i++) { var expectedMessage = expected[i]; var chatMessage = chat[i]; @@ -456,4 +560,80 @@ private static async Task> InvokeAndAssertAsync( return chat; } + + private static async Task> InvokeAndAssertStreamingAsync( + ChatOptions options, + List plan, + List? expected = null, + Func? configurePipeline = null, + IServiceProvider? services = null) + { + Assert.NotEmpty(plan); + + configurePipeline ??= static b => b.UseFunctionInvocation(); + + using CancellationTokenSource cts = new(); + List chat = [plan[0]]; + + using var innerClient = new TestChatClient + { + CompleteStreamingAsyncCallback = (contents, actualOptions, actualCancellationToken) => + { + Assert.Same(chat, contents); + Assert.Equal(cts.Token, actualCancellationToken); + + return YieldAsync(new ChatCompletion(new ChatMessage(ChatRole.Assistant, [.. plan[contents.Count].Contents])).ToStreamingChatCompletionUpdates()); + } + }; + + IChatClient service = configurePipeline(innerClient.AsBuilder()).Build(services); + + var result = await service.CompleteStreamingAsync(chat, options, cts.Token).ToChatCompletionAsync(); + chat.Add(result.Message); + + expected ??= plan; + Assert.NotNull(result); + Assert.Equal(expected.Count, chat.Count); + for (int i = 0; i < expected.Count; i++) + { + var expectedMessage = expected[i]; + var chatMessage = chat[i]; + + Assert.Equal(expectedMessage.Role, chatMessage.Role); + Assert.Equal(expectedMessage.Text, chatMessage.Text); + Assert.Equal(expectedMessage.GetType(), chatMessage.GetType()); + + Assert.Equal(expectedMessage.Contents.Count, chatMessage.Contents.Count); + for (int j = 0; j < expectedMessage.Contents.Count; j++) + { + var expectedItem = expectedMessage.Contents[j]; + var chatItem = chatMessage.Contents[j]; + + Assert.Equal(expectedItem.GetType(), chatItem.GetType()); + Assert.Equal(expectedItem.ToString(), chatItem.ToString()); + if (expectedItem is FunctionCallContent expectedFunctionCall) + { + var chatFunctionCall = (FunctionCallContent)chatItem; + Assert.Equal(expectedFunctionCall.Name, chatFunctionCall.Name); + AssertExtensions.EqualFunctionCallParameters(expectedFunctionCall.Arguments, chatFunctionCall.Arguments); + } + else if (expectedItem is FunctionResultContent expectedFunctionResult) + { + var chatFunctionResult = (FunctionResultContent)chatItem; + AssertExtensions.EqualFunctionCallResults(expectedFunctionResult.Result, chatFunctionResult.Result); + } + } + } + + return chat; + } + + private static async IAsyncEnumerable YieldAsync(params T[] items) + { + await Task.Yield(); + foreach (var item in items) + { + yield return item; + } + } }