Skip to content

Commit

Permalink
Improved Assistant streaming methods
Browse files Browse the repository at this point in the history
  • Loading branch information
kayhantolga committed Jun 15, 2024
1 parent b25961b commit 64e376b
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 28 deletions.
137 changes: 127 additions & 10 deletions OpenAI.Playground/TestHelpers/AssistantHelpers/RunTestHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -165,19 +165,58 @@ public static async Task CreateRunAsStreamTest(IOpenAIService openAI)
var result = openAI.Beta.Runs.RunCreateAsStream(CreatedThreadId, new()
{
AssistantId = assistantResult.Id
});
},justDataMode:false);

await foreach (var run in result)
{
if (run.Successful)
{
if (string.IsNullOrEmpty(run.Status))
Console.WriteLine($"Event:{run.StreamEvent}");
if (run is RunResponse runResponse)
{
Console.Write(".");
if (string.IsNullOrEmpty(runResponse.Status))
{
Console.Write(".");
}
else
{
ConsoleExtensions.WriteLine($"Run Id: {runResponse.Id}, Status: {runResponse.Status}");
}
}

else if (run is RunStepResponse runStepResponse)
{
if (string.IsNullOrEmpty(runStepResponse.Status))
{
Console.Write(".");
}
else
{
ConsoleExtensions.WriteLine($"Run Step Id: {runStepResponse.Id}, Status: {runStepResponse.Status}");
}
}

else if (run is MessageResponse messageResponse)
{
if (string.IsNullOrEmpty(messageResponse.Id))
{
Console.Write(".");
}
else
{
ConsoleExtensions.WriteLine($"Message Id: {messageResponse.Id}, Message: {messageResponse.Content?.FirstOrDefault()?.Text?.Value}");
}
}
else
{
ConsoleExtensions.WriteLine($"Run Id: {run.Id}, Status: {run.Status}");
if (run.StreamEvent!=null)
{
Console.WriteLine(run.StreamEvent);
}
else
{
Console.Write(".");
}
}
}
else
Expand Down Expand Up @@ -450,13 +489,52 @@ public static async Task SubmitToolOutputsAsStreamToRunTest(IOpenAIService openA
{
if (run.Successful)
{
if (string.IsNullOrEmpty(run.Status))
Console.WriteLine($"Event:{run.StreamEvent}");
if (run is RunResponse runResponse)
{
Console.Write(".");
if (string.IsNullOrEmpty(runResponse.Status))
{
Console.Write(".");
}
else
{
ConsoleExtensions.WriteLine($"Run Id: {runResponse.Id}, Status: {runResponse.Status}");
}
}

else if (run is RunStepResponse runStepResponse)
{
if (string.IsNullOrEmpty(runStepResponse.Status))
{
Console.Write(".");
}
else
{
ConsoleExtensions.WriteLine($"Run Step Id: {runStepResponse.Id}, Status: {runStepResponse.Status}");
}
}

else if (run is MessageResponse messageResponse)
{
if (string.IsNullOrEmpty(messageResponse.Id))
{
Console.Write(".");
}
else
{
ConsoleExtensions.WriteLine($"Message Id: {messageResponse.Id}, Message: {messageResponse.Content?.FirstOrDefault()?.Text?.Value}");
}
}
else
{
ConsoleExtensions.WriteLine($"Run Id: {run.Id}, Status: {run.Status}");
if (run.StreamEvent != null)
{
Console.WriteLine(run.StreamEvent);
}
else
{
Console.Write(".");
}
}
}
else
Expand Down Expand Up @@ -642,13 +720,52 @@ public static async Task CreateThreadAndRunAsStream(IOpenAIService sdk)
{
if (run.Successful)
{
if (string.IsNullOrEmpty(run.Status))
Console.WriteLine($"Event:{run.StreamEvent}");
if (run is RunResponse runResponse)
{
Console.Write(".");
if (string.IsNullOrEmpty(runResponse.Status))
{
Console.Write(".");
}
else
{
ConsoleExtensions.WriteLine($"Run Id: {runResponse.Id}, Status: {runResponse.Status}");
}
}

else if (run is RunStepResponse runStepResponse)
{
if (string.IsNullOrEmpty(runStepResponse.Status))
{
Console.Write(".");
}
else
{
ConsoleExtensions.WriteLine($"Run Step Id: {runStepResponse.Id}, Status: {runStepResponse.Status}");
}
}

else if (run is MessageResponse messageResponse)
{
if (string.IsNullOrEmpty(messageResponse.Id))
{
Console.Write(".");
}
else
{
ConsoleExtensions.WriteLine($"Message Id: {messageResponse.Id}, Message: {messageResponse.Content?.FirstOrDefault()?.Text?.Value}");
}
}
else
{
ConsoleExtensions.WriteLine($"Run Id: {run.Id}, Status: {run.Status}");
if (run.StreamEvent != null)
{
Console.WriteLine(run.StreamEvent);
}
else
{
Console.Write(".");
}
}
}
else
Expand Down
22 changes: 22 additions & 0 deletions OpenAI.SDK/Extensions/JsonToObjectRouterExtension.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
using System.Text.Json;
using OpenAI.ObjectModels.ResponseModels;
using OpenAI.ObjectModels.SharedModels;

namespace OpenAI.Extensions;

public static class JsonToObjectRouterExtension
{
public static Type Route(string json)
{
var apiResponse = JsonSerializer.Deserialize<ObjectBaseResponse>(json);

return apiResponse?.ObjectTypeName switch
{
"thread.run.step" => typeof(RunStepResponse),
"thread.run" => typeof(RunResponse),
"thread.message" => typeof(MessageResponse),
"thread.message.delta" => typeof(MessageResponse),
_ => typeof(BaseResponse)
};
}
}
39 changes: 36 additions & 3 deletions OpenAI.SDK/Extensions/StreamHandleExtension.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Runtime.CompilerServices;
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Text.Json;
using OpenAI.ObjectModels;
using OpenAI.ObjectModels.RequestModels;
Expand All @@ -8,6 +9,10 @@ namespace OpenAI.Extensions;

public static class StreamHandleExtension
{
public static async IAsyncEnumerable<BaseResponse> AsStream(this HttpResponseMessage response, bool justDataMode = true, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
await foreach (var baseResponse in AsStream<BaseResponse>(response, justDataMode, cancellationToken)) yield return baseResponse;
}
public static async IAsyncEnumerable<TResponse> AsStream<TResponse>(this HttpResponseMessage response, bool justDataMode = true, [EnumeratorCancellation] CancellationToken cancellationToken = default)
where TResponse : BaseResponse, new()
{
Expand All @@ -20,13 +25,15 @@ public static async IAsyncEnumerable<TResponse> AsStream<TResponse>(this HttpRes

await using var stream = await response.Content.ReadAsStreamAsync(cancellationToken);
using var reader = new StreamReader(stream);

string? tempStreamEvent = null;
bool isEventDelta;
// Continuously read the stream until the end of it
while (true)
{
cancellationToken.ThrowIfCancellationRequested();

var line = await reader.ReadLineAsync();
// Console.WriteLine("---" + line);
// Break the loop if we have reached the end of the stream
if (line == null)
{
Expand All @@ -39,11 +46,28 @@ public static async IAsyncEnumerable<TResponse> AsStream<TResponse>(this HttpRes
continue;
}

if (line.StartsWith("event: "))
{
line = line.RemoveIfStartWith("event: ");
tempStreamEvent = line;
isEventDelta = true;
}
else
{
isEventDelta = false;
}

if (justDataMode && !line.StartsWith("data: "))
{
continue;
}

if (!justDataMode && isEventDelta )
{
yield return new(){ObjectTypeName = "base.stream.event",StreamEvent = tempStreamEvent};
continue;
}

line = line.RemoveIfStartWith("data: ");

// Exit the loop if the stream is done
Expand All @@ -56,7 +80,14 @@ public static async IAsyncEnumerable<TResponse> AsStream<TResponse>(this HttpRes
try
{
// When the response is good, each line is a serializable CompletionCreateRequest
block = JsonSerializer.Deserialize<TResponse>(line);
if (typeof(TResponse) == typeof(BaseResponse))
{
block =JsonSerializer.Deserialize(line, JsonToObjectRouterExtension.Route(line), new JsonSerializerOptions()) as TResponse;
}
else
{
block = JsonSerializer.Deserialize<TResponse>(line);
}
}
catch (Exception)
{
Expand All @@ -78,6 +109,8 @@ public static async IAsyncEnumerable<TResponse> AsStream<TResponse>(this HttpRes
{
block.HttpStatusCode = httpStatusCode;
block.HeaderValues = headerValues;
block.StreamEvent = tempStreamEvent;
tempStreamEvent = null;
yield return block;
}
}
Expand Down
13 changes: 8 additions & 5 deletions OpenAI.SDK/Interfaces/IRunService.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System.Runtime.CompilerServices;
using OpenAI.ObjectModels.RequestModels;
using OpenAI.ObjectModels.ResponseModels;
using OpenAI.ObjectModels.SharedModels;

namespace OpenAI.Interfaces;
Expand All @@ -24,8 +25,8 @@ public interface IRunService
/// <param name="modelId"></param>
/// <param name="justDataMode"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
IAsyncEnumerable<RunResponse> RunCreateAsStream(string threadId, RunCreateRequest request, string? modelId = null, bool justDataMode = true, [EnumeratorCancellation] CancellationToken cancellationToken = default);
/// <returns><see cref="BaseResponse"/> also returns <see cref="RunResponse"/>,<see cref="RunStepResponse"/>, <see cref="MessageResponse"/> </returns>
IAsyncEnumerable<BaseResponse> RunCreateAsStream(string threadId, RunCreateRequest request, string? modelId = null, bool justDataMode = true, [EnumeratorCancellation] CancellationToken cancellationToken = default);

/// <summary>
/// Retrieves a run.
Expand Down Expand Up @@ -71,9 +72,10 @@ public interface IRunService
/// <param name="threadId"></param>
/// <param name="runId"></param>
/// <param name="request"></param>
/// <param name="justDataMode"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
IAsyncEnumerable<RunResponse> RunSubmitToolOutputsAsStream(string threadId, string runId, SubmitToolOutputsToRunRequest request, [EnumeratorCancellation] CancellationToken cancellationToken = default);
/// <returns><see cref="BaseResponse"/> also returns <see cref="RunResponse"/>,<see cref="RunStepResponse"/>, <see cref="MessageResponse"/> </returns>
IAsyncEnumerable<BaseResponse> RunSubmitToolOutputsAsStream(string threadId, string runId, SubmitToolOutputsToRunRequest request, bool justDataMode = true, [EnumeratorCancellation] CancellationToken cancellationToken = default);

/// <summary>
/// Modifies a run.
Expand All @@ -93,7 +95,8 @@ public interface IRunService
/// <summary>
/// Create a thread and run it in one request as Stream.
/// </summary>
IAsyncEnumerable<RunResponse> CreateThreadAndRunAsStream(CreateThreadAndRunRequest createThreadAndRunRequest, string? modelId = null, bool justDataMode = true, [EnumeratorCancellation] CancellationToken cancellationToken = default);
/// <returns><see cref="BaseResponse"/> also returns <see cref="RunResponse"/>,<see cref="RunStepResponse"/>, <see cref="MessageResponse"/> </returns>
IAsyncEnumerable<BaseResponse> CreateThreadAndRunAsStream(CreateThreadAndRunRequest createThreadAndRunRequest, string? modelId = null, bool justDataMode = true, [EnumeratorCancellation] CancellationToken cancellationToken = default);

/// <summary>
/// Returns a list of runs belonging to a thread.
Expand Down
Loading

0 comments on commit 64e376b

Please sign in to comment.