-
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
The new `ChatClient.CompleteChatAsync<T>` provides automatic structured output from OpenAI (see https://openai.com/index/introducing-structured-outputs-in-the-api/) by leveraging the new (in .NET9) API for emitting JSON Schema from a .NET type.
- Loading branch information
Showing
4 changed files
with
282 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
// <auto-generated /> | ||
#region License | ||
// MIT License | ||
// | ||
// Copyright (c) Daniel Cazzulino | ||
// | ||
// Permission is hereby granted, free of charge, to any person obtaining a copy | ||
// of this software and associated documentation files (the "Software"), to deal | ||
// in the Software without restriction, including without limitation the rights | ||
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
// copies of the Software, and to permit persons to whom the Software is | ||
// furnished to do so, subject to the following conditions: | ||
// | ||
// The above copyright notice and this permission notice shall be included in all | ||
// copies or substantial portions of the Software. | ||
// | ||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
// SOFTWARE. | ||
#endregion | ||
|
||
#nullable enable | ||
|
||
using System; | ||
using System.Collections.Concurrent; | ||
using System.Collections.Generic; | ||
using System.ComponentModel; | ||
using System.Linq; | ||
using System.Text.Json; | ||
using System.Text.Json.Serialization.Metadata; | ||
using System.Threading.Tasks; | ||
using System.Threading; | ||
using System.Text.Json.Schema; | ||
|
||
namespace OpenAI.Chat; | ||
|
||
/// <summary> | ||
/// Provides strong-typed extension methods for <see cref="ChatClient"/>. | ||
/// </summary> | ||
/// <remarks> | ||
/// Requires .NET 8+ | ||
/// </remarks> | ||
/// <package id="OpenAI" version="2.0.0" /> | ||
/// <package id="System.Text.Json" version="9.0.0-rc.*" /> | ||
static partial class ChatClientTypedExtensions | ||
{ | ||
static ConcurrentDictionary<Type, BinaryData> jsonSchemas = new(); | ||
static JsonSerializerOptions jsonOptions = new(JsonSerializerDefaults.Web) | ||
{ | ||
TypeInfoResolver = new DefaultJsonTypeInfoResolver(), | ||
PropertyNameCaseInsensitive = true, | ||
PropertyNamingPolicy = JsonNamingPolicy.CamelCase, | ||
NumberHandling = System.Text.Json.Serialization.JsonNumberHandling.Strict, | ||
}; | ||
|
||
public static async Task<T?> CompleteChatAsync<T>(this ChatClient client, IEnumerable<ChatMessage> messages, ChatCompletionOptions? options = null, CancellationToken cancellationToken = default(CancellationToken)) | ||
{ | ||
options ??= new ChatCompletionOptions(); | ||
var elementType = typeof(T); | ||
|
||
if (elementType.IsArray) | ||
{ | ||
elementType = elementType.GetElementType()!; | ||
} | ||
else if (elementType.IsGenericType && ( | ||
elementType.GetGenericTypeDefinition() == typeof(IEnumerable<>) || | ||
elementType.GetGenericTypeDefinition() == typeof(ICollection<>) || | ||
elementType.GetGenericTypeDefinition() == typeof(List<>) || | ||
elementType.GetGenericTypeDefinition() == typeof(IList<>) || | ||
elementType.GetGenericTypeDefinition() == typeof(IReadOnlyCollection<>))) | ||
{ | ||
elementType = elementType.GetGenericArguments()[0]; | ||
} | ||
|
||
var typeName = elementType.Name; | ||
|
||
if (elementType == typeof(T)) | ||
{ | ||
var schema = jsonSchemas.GetOrAdd(typeof(T), _ => GetJsonSchema<T>()); | ||
options.ResponseFormat = ChatResponseFormat.CreateJsonSchemaFormat(typeName, schema); | ||
|
||
var response = await client.CompleteChatAsync(messages, options, cancellationToken); | ||
var json = response.Value.Content.FirstOrDefault(x => x.Kind == ChatMessageContentPartKind.Text)?.Text; | ||
|
||
if (string.IsNullOrEmpty(json)) | ||
return default; | ||
|
||
return JsonSerializer.Deserialize<T>(json, jsonOptions); | ||
} | ||
else | ||
{ | ||
typeName = $"{typeName}s"; | ||
|
||
var schema = jsonSchemas.GetOrAdd(typeof(Values<T>), _ => GetJsonSchema<Values<T>>()); | ||
options.ResponseFormat = ChatResponseFormat.CreateJsonSchemaFormat(typeName, schema); | ||
|
||
var response = await client.CompleteChatAsync(messages, options, cancellationToken); | ||
var json = response.Value.Content.FirstOrDefault(x => x.Kind == ChatMessageContentPartKind.Text)?.Text; | ||
|
||
if (string.IsNullOrEmpty(json) || | ||
JsonSerializer.Deserialize<Values<T>>(json, jsonOptions) is not { } data) | ||
return default; | ||
|
||
return data.Data; | ||
} | ||
} | ||
|
||
|
||
static BinaryData GetJsonSchema<T>() | ||
{ | ||
var node = JsonSchemaExporter.GetJsonSchemaAsNode(jsonOptions, typeof(T), new JsonSchemaExporterOptions | ||
{ | ||
TreatNullObliviousAsNonNullable = true, | ||
TransformSchemaNode = (context, node) => | ||
{ | ||
var description = context.PropertyInfo?.AttributeProvider?.GetCustomAttributes(typeof(DescriptionAttribute), false) | ||
.OfType<DescriptionAttribute>() | ||
.FirstOrDefault()?.Description; | ||
if (description != null) | ||
node["description"] = description; | ||
return node; | ||
}, | ||
}); | ||
|
||
return BinaryData.FromString(node.ToJsonString()); | ||
} | ||
|
||
public class Values<T> | ||
{ | ||
public required T Data { get; set; } | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using Microsoft.Extensions.Configuration; | ||
using Xunit; | ||
|
||
#pragma warning disable CA1050 // Declare types in namespaces | ||
public class SecretsFactAttribute : FactAttribute | ||
{ | ||
public SecretsFactAttribute(params string[] secrets) | ||
{ | ||
var configuration = new ConfigurationBuilder() | ||
.AddEnvironmentVariables() | ||
.AddUserSecrets<SecretsFactAttribute>() | ||
.Build(); | ||
|
||
var missing = new HashSet<string>(); | ||
|
||
foreach (var secret in secrets) | ||
{ | ||
if (string.IsNullOrEmpty(configuration[secret])) | ||
missing.Add(secret); | ||
} | ||
|
||
if (missing.Count > 0) | ||
Skip = "Missing user secrets: " + string.Join(',', missing); | ||
} | ||
} | ||
|
||
public class LocalFactAttribute : SecretsFactAttribute | ||
{ | ||
public LocalFactAttribute(params string[] secrets) : base(secrets) | ||
{ | ||
if (!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("CI"))) | ||
Skip = "Non-CI test"; | ||
} | ||
} | ||
|
||
public class CIFactAttribute : FactAttribute | ||
{ | ||
public CIFactAttribute() | ||
{ | ||
if (string.IsNullOrEmpty(Environment.GetEnvironmentVariable("CI"))) | ||
Skip = "CI-only test"; | ||
} | ||
} | ||
|
||
public class SecretsTheoryAttribute : TheoryAttribute | ||
{ | ||
public SecretsTheoryAttribute(params string[] secrets) | ||
{ | ||
var configuration = new ConfigurationBuilder() | ||
.AddUserSecrets<SecretsTheoryAttribute>() | ||
.Build(); | ||
|
||
var missing = new HashSet<string>(); | ||
|
||
foreach (var secret in secrets) | ||
{ | ||
if (string.IsNullOrEmpty(configuration[secret])) | ||
missing.Add(secret); | ||
} | ||
|
||
if (missing.Count > 0) | ||
Skip = "Missing user secrets: " + string.Join(',', missing); | ||
} | ||
} | ||
|
||
public class LocalTheoryAttribute : TheoryAttribute | ||
{ | ||
public LocalTheoryAttribute() | ||
{ | ||
if (!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("CI"))) | ||
Skip = "Non-CI test"; | ||
} | ||
} | ||
|
||
public class CITheoryAttribute : TheoryAttribute | ||
{ | ||
public CITheoryAttribute() | ||
{ | ||
if (string.IsNullOrEmpty(Environment.GetEnvironmentVariable("CI"))) | ||
Skip = "CI-only test"; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
using System; | ||
using System.Net.Http; | ||
using System.Reflection; | ||
using System.Text.Json; | ||
using System.Threading.Tasks; | ||
using System.Xml.Linq; | ||
using Devlooped.Web; | ||
using Microsoft.Extensions.Configuration; | ||
using Xunit; | ||
using Xunit.Abstractions; | ||
|
||
namespace OpenAI.Chat; | ||
|
||
public class ChatClientTypedExtensionsTests(ITestOutputHelper output) | ||
{ | ||
[SecretsFact("OpenAI:Key")] | ||
public async Task CanScrapPage() | ||
{ | ||
var configuration = new ConfigurationBuilder() | ||
.AddEnvironmentVariables() | ||
.AddUserSecrets(Assembly.GetExecutingAssembly()) | ||
.Build(); | ||
|
||
var client = new OpenAIClient(configuration["OpenAI:Key"]!); | ||
var chat = client.GetChatClient("gpt-4o"); | ||
|
||
using var http = new HttpClient(); | ||
http.DefaultRequestHeaders.AcceptLanguage.Add(new("en-US")); | ||
|
||
var html = await http.GetStringAsync("https://www.imdb.com/chart/moviemeter/?ref_=nv_mv_mpm&genres=thriller&user_rating=6%2C&sort=user_rating%2Cdesc&num_votes=50000%2C"); | ||
// By parsing, we eliminate lots of noise and only keep the relevant parts (i.e. skip headers, scripts and styles). | ||
var doc = HtmlDocument.Parse(html); | ||
|
||
var response = await chat.CompleteChatAsync<Movie[]>([ | ||
new SystemChatMessage( | ||
""" | ||
You are an HTML page scraper. | ||
You use exclusively the data in the following HTML page to parse and return a list of movies. | ||
You perform smart type conversion and parsing as needed to fit the result schema in JSON format. | ||
"""), | ||
new UserChatMessage(doc.CssSelectElement("html")!.ToString(SaveOptions.DisableFormatting)), | ||
]); | ||
|
||
Assert.Equal(16, response.Length); | ||
Assert.Contains(response, x => x.Title == "The Dark Knight"); | ||
Assert.Contains(response, x => x.Title == "Joker"); | ||
|
||
output.WriteLine(JsonSerializer.Serialize(response, new JsonSerializerOptions(JsonSerializerOptions.Web) { WriteIndented = true })); | ||
} | ||
|
||
public record Movie(string Title, int Year, TimeSpan Duration, string AgeRating, StarsRating Stars, string Url); | ||
|
||
public record StarsRating(double Stars, long Votes); | ||
} |