diff --git a/app/SharedWebComponents/Pages/Chat.razor.cs b/app/SharedWebComponents/Pages/Chat.razor.cs index 8211087a..f9899dd0 100644 --- a/app/SharedWebComponents/Pages/Chat.razor.cs +++ b/app/SharedWebComponents/Pages/Chat.razor.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft. All rights reserved. namespace SharedWebComponents.Pages; +using System.Text; public sealed partial class Chat { @@ -42,17 +43,39 @@ private async Task OnAskClickedAsync() try { var history = _questionAndAnswerMap - .Where(x => x.Value?.Choices is { Length: > 0}) - .SelectMany(x => new ChatMessage[] { new ChatMessage("user", x.Key.Question), new ChatMessage("assistant", x.Value!.Choices[0].Message.Content) }) + .Where(x => x.Value?.Choices is { Length: > 0 }) + .SelectMany(x => new ChatMessage[] { + new ChatMessage("user", x.Key.Question), + new ChatMessage("assistant", x.Value!.Choices[0].Message.Content) + }) .ToList(); history.Add(new ChatMessage("user", _userQuestion)); var request = new ChatRequest([.. history], Settings.Overrides); - var result = await ApiClient.ChatConversationAsync(request); - _questionAndAnswerMap[_currentQuestion] = result.Response; - if (result.IsSuccessful) + try + { + var responseStream = await ApiClient.PostStreamingRequestAsync(request, "api/chat/stream"); + + await foreach (var response in responseStream) + { + _questionAndAnswerMap[_currentQuestion] = new ChatAppResponseOrError( + response.Choices, + null); + + StateHasChanged(); + await Task.Delay(1); + } + } + catch (Exception ex) + { + _questionAndAnswerMap[_currentQuestion] = new ChatAppResponseOrError( + Array.Empty(), + ex.Message); + } + + if (_questionAndAnswerMap[_currentQuestion]?.Error is null) { _userQuestion = ""; _currentQuestion = default; diff --git a/app/SharedWebComponents/Services/ApiClient.cs b/app/SharedWebComponents/Services/ApiClient.cs index 9b15f7fd..3d3c1198 100644 --- a/app/SharedWebComponents/Services/ApiClient.cs +++ b/app/SharedWebComponents/Services/ApiClient.cs @@ -90,47 +90,32 @@ public async IAsyncEnumerable GetDocumentsAsync( } } - public Task> ChatConversationAsync(ChatRequest request) => PostRequestAsync(request, "api/chat"); - - private async Task> PostRequestAsync( + public async Task> PostStreamingRequestAsync( TRequest request, string apiRoute) where TRequest : ApproachRequest { - var result = new AnswerResult( - IsSuccessful: false, - Response: null, - Approach: request.Approach, - Request: request); - var json = JsonSerializer.Serialize( request, SerializerOptions.Default); - using var body = new StringContent( + using var content = new StringContent( json, Encoding.UTF8, "application/json"); - var response = await httpClient.PostAsync(apiRoute, body); + // Use both HttpCompletionOption and CancellationToken + var response = await httpClient.PostAsync( + apiRoute, + content, + CancellationToken.None); if (response.IsSuccessStatusCode) { - var answer = await response.Content.ReadFromJsonAsync(); - return result with - { - IsSuccessful = answer is not null, - Response = answer, - }; + var stream = await response.Content.ReadAsStreamAsync(); + var nullableResponses = JsonSerializer.DeserializeAsyncEnumerable( + stream, + new JsonSerializerOptions { PropertyNameCaseInsensitive = true }); + + return nullableResponses.Where(r => r != null)!; } - else - { - var errorTitle = $"HTTP {(int)response.StatusCode} : {response.ReasonPhrase ?? "☹️ Unknown error..."}"; - var answer = new ChatAppResponseOrError( - Array.Empty(), - errorTitle); - return result with - { - IsSuccessful = false, - Response = answer - }; - } + throw new HttpRequestException($"HTTP {(int)response.StatusCode} : {response.ReasonPhrase ?? "Unknown error"}"); } } diff --git a/app/backend/Extensions/WebApplicationExtensions.cs b/app/backend/Extensions/WebApplicationExtensions.cs index 64464f52..093cf80a 100644 --- a/app/backend/Extensions/WebApplicationExtensions.cs +++ b/app/backend/Extensions/WebApplicationExtensions.cs @@ -12,7 +12,7 @@ internal static WebApplication MapApi(this WebApplication app) api.MapPost("openai/chat", OnPostChatPromptAsync); // Long-form chat w/ contextual history endpoint - api.MapPost("chat", OnPostChatAsync); + api.MapPost("chat/stream", OnPostChatStreamingAsync); // Upload a document api.MapPost("documents", OnPostDocumentAsync); @@ -70,20 +70,23 @@ You will always reply with a Markdown formatted response. } } - private static async Task OnPostChatAsync( + private static async IAsyncEnumerable OnPostChatStreamingAsync( ChatRequest request, ReadRetrieveReadChatService chatService, - CancellationToken cancellationToken) + [EnumeratorCancellation] CancellationToken cancellationToken) { - if (request is { History.Length: > 0 }) + if (request is not { History.Length: > 0 }) { - var response = await chatService.ReplyAsync( - request.History, request.Overrides, cancellationToken); - - return TypedResults.Ok(response); + yield break; } - return Results.BadRequest(); + await foreach (var response in chatService.ReplyStreamingAsync( + request.History, + request.Overrides, + cancellationToken)) + { + yield return response; + } } private static async Task OnPostDocumentAsync( diff --git a/app/backend/Services/ReadRetrieveReadChatService.cs b/app/backend/Services/ReadRetrieveReadChatService.cs index 7c72dcd9..f512ef2e 100644 --- a/app/backend/Services/ReadRetrieveReadChatService.cs +++ b/app/backend/Services/ReadRetrieveReadChatService.cs @@ -4,6 +4,7 @@ using Microsoft.SemanticKernel.ChatCompletion; using Microsoft.SemanticKernel.Connectors.OpenAI; using Microsoft.SemanticKernel.Embeddings; +using System.Text; namespace MinimalApi.Services; #pragma warning disable SKEXP0011 // Mark members as static @@ -56,10 +57,10 @@ public ReadRetrieveReadChatService( _tokenCredential = tokenCredential; } - public async Task ReplyAsync( + public async IAsyncEnumerable ReplyStreamingAsync( ChatMessage[] history, RequestOverrides? overrides, - CancellationToken cancellationToken = default) + [EnumeratorCancellation] CancellationToken cancellationToken = default) { var top = overrides?.Top ?? 3; var useSemanticCaptions = overrides?.SemanticCaptions ?? false; @@ -71,9 +72,8 @@ public async Task ReplyAsync( float[]? embeddings = null; var question = history.LastOrDefault(m => m.IsUser)?.Content is { } userQuestion ? userQuestion - : throw new InvalidOperationException("Use question is null"); + : throw new InvalidOperationException("User question is null"); - string[]? followUpQuestionList = null; if (overrides?.RetrievalMode != RetrievalMode.Text && embedding is not null) { embeddings = (await embedding.GenerateEmbeddingAsync(question, cancellationToken: cancellationToken)).ToArray(); @@ -92,11 +92,20 @@ standard plan AND dental AND employee benefit. "); getQueryChat.AddUserMessage(question); - var result = await chat.GetChatMessageContentAsync( + var queryBuilder = new StringBuilder(); + + await foreach (var content in chat.GetStreamingChatMessageContentsAsync( getQueryChat, - cancellationToken: cancellationToken); + kernel: _kernel, + cancellationToken: cancellationToken)) + { + if (content.Content is { Length: > 0 }) + { + queryBuilder.Append(content.Content); + } + } - query = result.Content ?? throw new InvalidOperationException("Failed to get search query"); + query = queryBuilder.ToString() ?? throw new InvalidOperationException("Failed to get search query"); } // step 2 @@ -110,7 +119,7 @@ standard plan AND dental AND employee benefit. } else { - documentContents = string.Join("\r", documentContentList.Select(x =>$"{x.Title}:{x.Content}")); + documentContents = string.Join("\r", documentContentList.Select(x => $"{x.Title}:{x.Content}")); } // step 2.5 @@ -140,7 +149,7 @@ standard plan AND dental AND employee benefit. } } - + if (images != null) { var prompt = @$"## Source ## @@ -185,63 +194,95 @@ You answer needs to be a json object with the following format. StopSequences = [], }; + var streamingResponse = new StringBuilder(); + var documentContext = new ResponseContext( + DataPointsContent: documentContentList.Select(x => new SupportingContentRecord(x.Title, x.Content)).ToArray(), + DataPointsImages: images?.Select(x => new SupportingImageRecord(x.Title, x.Url)).ToArray(), + FollowupQuestions: Array.Empty(), // Will be populated after full response + Thoughts: Array.Empty()); // Will be populated after full response + // get answer - var answer = await chat.GetChatMessageContentAsync( - answerChat, - promptExecutingSetting, - cancellationToken: cancellationToken); - var answerJson = answer.Content ?? throw new InvalidOperationException("Failed to get search query"); - var answerObject = JsonSerializer.Deserialize(answerJson); - var ans = answerObject.GetProperty("answer").GetString() ?? throw new InvalidOperationException("Failed to get answer"); - var thoughts = answerObject.GetProperty("thoughts").GetString() ?? throw new InvalidOperationException("Failed to get thoughts"); + await foreach (var content in chat.GetStreamingChatMessageContentsAsync( + answerChat, + executionSettings: promptExecutingSetting, + kernel: _kernel, + cancellationToken: cancellationToken)) + { + if (content.Content is { Length: > 0 }) + { + streamingResponse.Append(content.Content); + var responseMessage = new ResponseMessage("assistant", streamingResponse.ToString()); + var choice = new ResponseChoice( + Index: 0, + Message: responseMessage, + Context: documentContext, + CitationBaseUrl: _configuration.ToCitationBaseUrl()); + + + yield return new ChatAppResponse(new[] { choice }); + } + } + + // After streaming completes, parse the final answer + var answerJson = streamingResponse.ToString(); + var finalAnswerObject = JsonSerializer.Deserialize(answerJson); + var ans = finalAnswerObject.GetProperty("answer").GetString() ?? throw new InvalidOperationException("Failed to get answer"); + var finalThoughts = finalAnswerObject.GetProperty("thoughts").GetString() ?? throw new InvalidOperationException("Failed to get thoughts"); + + // Create response context that will be used throughout + var responseContext = new ResponseContext( + DataPointsContent: documentContentList.Select(x => new SupportingContentRecord(x.Title, x.Content)).ToArray(), + DataPointsImages: images?.Select(x => new SupportingImageRecord(x.Title, x.Url)).ToArray(), + FollowupQuestions: Array.Empty(), + Thoughts: new[] { new Thoughts("Thoughts", finalThoughts) }); // step 4 // add follow up questions if requested if (overrides?.SuggestFollowupQuestions is true) { var followUpQuestionChat = new ChatHistory(@"You are a helpful AI assistant"); - followUpQuestionChat.AddUserMessage($@"Generate three follow-up question based on the answer you just generated. + followUpQuestionChat.AddUserMessage($@"Generate three follow-up questions based on the answer you just generated. # Answer {ans} # Format of the response -Return the follow-up question as a json string list. Don't put your answer between ```json and ```, return the json string directly. -e.g. -[ - ""What is the deductible?"", - ""What is the co-pay?"", - ""What is the out-of-pocket maximum?"" -]"); +Generate three questions, one per line. Do not include any JSON formatting or other text. +For example: +What is the deductible? +What is the co-pay? +What is the out-of-pocket maximum?"); - var followUpQuestions = await chat.GetChatMessageContentAsync( + var followUpQuestions = new List(); + var followUpBuilder = new StringBuilder(); + await foreach (var content in chat.GetStreamingChatMessageContentsAsync( followUpQuestionChat, - promptExecutingSetting, - cancellationToken: cancellationToken); - - var followUpQuestionsJson = followUpQuestions.Content ?? throw new InvalidOperationException("Failed to get search query"); - var followUpQuestionsObject = JsonSerializer.Deserialize(followUpQuestionsJson); - var followUpQuestionsList = followUpQuestionsObject.EnumerateArray().Select(x => x.GetString()!).ToList(); - foreach (var followUpQuestion in followUpQuestionsList) + executionSettings: promptExecutingSetting, + kernel: _kernel, + cancellationToken: cancellationToken)) { - ans += $" <<{followUpQuestion}>> "; - } + if (content.Content is { Length: > 0 }) + { + followUpBuilder.Append(content.Content); + var questions = followUpBuilder.ToString().Split('\n', StringSplitOptions.RemoveEmptyEntries); + + var answerWithQuestions = ans; + foreach (var followUpQuestion in questions) + { + answerWithQuestions += $" <<{followUpQuestion.Trim()}>> "; + } - followUpQuestionList = followUpQuestionsList.ToArray(); - } + var responseMessage = new ResponseMessage("assistant", answerWithQuestions); + var updatedContext = responseContext with { FollowupQuestions = questions }; - var responseMessage = new ResponseMessage("assistant", ans); - var responseContext = new ResponseContext( - DataPointsContent: documentContentList.Select(x => new SupportingContentRecord(x.Title, x.Content)).ToArray(), - DataPointsImages: images?.Select(x => new SupportingImageRecord(x.Title, x.Url)).ToArray(), - FollowupQuestions: followUpQuestionList ?? Array.Empty(), - Thoughts: new[] { new Thoughts("Thoughts", thoughts) }); + var choice = new ResponseChoice( + Index: 0, + Message: responseMessage, + Context: updatedContext, + CitationBaseUrl: _configuration.ToCitationBaseUrl()); - var choice = new ResponseChoice( - Index: 0, - Message: responseMessage, - Context: responseContext, - CitationBaseUrl: _configuration.ToCitationBaseUrl()); - - return new ChatAppResponse(new[] { choice }); + yield return new ChatAppResponse(new[] { choice }); + } + } + } } }