Skip to content

Commit

Permalink
🐱 Add typed OpenAI.Chat completion
Browse files Browse the repository at this point in the history
The new `ChatClient.CompleteChatAsync<T>` provides automatic structured output from OpenAI (see https://openai.com/index/introducing-structured-outputs-in-the-api/) by leveraging the new (in .NET 9) API for emitting JSON Schema from a .NET type.
  • Loading branch information
kzu committed Oct 9, 2024
1 parent fc83fcb commit 43b6b7b
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 1 deletion.
138 changes: 138 additions & 0 deletions OpenAI/Chat/ChatClientTypedExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// <auto-generated />
#region License
// MIT License
//
// Copyright (c) Daniel Cazzulino
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
#endregion

#nullable enable

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using System.Threading.Tasks;
using System.Threading;
using System.Text.Json.Schema;

namespace OpenAI.Chat;

/// <summary>
/// Provides strong-typed extension methods for <see cref="ChatClient"/>.
/// </summary>
/// <remarks>
/// Requires .NET 8+
/// </remarks>
/// <package id="OpenAI" version="2.0.0" />
/// <package id="System.Text.Json" version="9.0.0-rc.*" />
static partial class ChatClientTypedExtensions
{
    // Upper bound applied to response format names before they are sent to the service.
    const int MaxSchemaNameLength = 64;

    // Schemas are generated once per requested type and reused across calls.
    static readonly ConcurrentDictionary<Type, BinaryData> jsonSchemas = new();

    // Shared options used both for schema generation and for deserializing the model output,
    // so that property names in the schema and in the parsed payload always agree.
    static readonly JsonSerializerOptions jsonOptions = new(JsonSerializerDefaults.Web)
    {
        TypeInfoResolver = new DefaultJsonTypeInfoResolver(),
        PropertyNameCaseInsensitive = true,
        PropertyNamingPolicy = JsonNamingPolicy.CamelCase,
        NumberHandling = System.Text.Json.Serialization.JsonNumberHandling.Strict,
    };

    /// <summary>
    /// Completes the chat requesting a JSON response that conforms to the JSON schema 
    /// generated from <typeparamref name="T"/>, and deserializes the response into it.
    /// </summary>
    /// <typeparam name="T">
    /// The type of the result. Arrays and the common generic collection types 
    /// (<see cref="IEnumerable{T}"/>, <see cref="ICollection{T}"/>, <see cref="IList{T}"/>, 
    /// <see cref="List{T}"/>, <see cref="IReadOnlyCollection{T}"/>, <see cref="IReadOnlyList{T}"/>) 
    /// are requested wrapped in a <see cref="Values{T}"/> object and unwrapped before returning.
    /// </typeparam>
    /// <param name="client">The chat client to use.</param>
    /// <param name="messages">The conversation messages.</param>
    /// <param name="options">
    /// Optional completion options. Note that <see cref="ChatCompletionOptions.ResponseFormat"/> 
    /// is overwritten on the provided instance.
    /// </param>
    /// <param name="cancellationToken">Token to cancel the request.</param>
    /// <returns>
    /// The deserialized result, or <see langword="default"/> if the response contained 
    /// no text content (or the payload deserialized to <see langword="null"/>).
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="client"/> or <paramref name="messages"/> is null.</exception>
    /// <exception cref="JsonException">The response text could not be deserialized to the requested type.</exception>
    public static async Task<T?> CompleteChatAsync<T>(this ChatClient client, IEnumerable<ChatMessage> messages, ChatCompletionOptions? options = null, CancellationToken cancellationToken = default(CancellationToken))
    {
        ArgumentNullException.ThrowIfNull(client);
        ArgumentNullException.ThrowIfNull(messages);

        options ??= new ChatCompletionOptions();
        var elementType = UnwrapCollection(typeof(T));

        if (elementType == typeof(T))
        {
            // Not a collection: the type itself is the root of the schema.
            var schema = jsonSchemas.GetOrAdd(typeof(T), static _ => GetJsonSchema<T>());
            var json = await RequestJsonAsync(client, messages, options, ToSchemaName(elementType.Name), schema, cancellationToken).ConfigureAwait(false);

            if (string.IsNullOrEmpty(json))
                return default;

            return JsonSerializer.Deserialize<T>(json, jsonOptions);
        }
        else
        {
            // Collections are requested wrapped in an object with a single 'data' property, 
            // since the root of a structured output schema is expected to be an object.
            var schema = jsonSchemas.GetOrAdd(typeof(Values<T>), static _ => GetJsonSchema<Values<T>>());
            var json = await RequestJsonAsync(client, messages, options, ToSchemaName($"{elementType.Name}s"), schema, cancellationToken).ConfigureAwait(false);

            if (string.IsNullOrEmpty(json) ||
                JsonSerializer.Deserialize<Values<T>>(json, jsonOptions) is not { } data)
                return default;

            return data.Data;
        }
    }

    // Sets the JSON schema response format, performs the request and returns the first text part, if any.
    static async Task<string?> RequestJsonAsync(ChatClient client, IEnumerable<ChatMessage> messages, ChatCompletionOptions options, string schemaName, BinaryData schema, CancellationToken cancellationToken)
    {
        options.ResponseFormat = ChatResponseFormat.CreateJsonSchemaFormat(schemaName, schema);

        var response = await client.CompleteChatAsync(messages, options, cancellationToken).ConfigureAwait(false);

        return response.Value.Content.FirstOrDefault(x => x.Kind == ChatMessageContentPartKind.Text)?.Text;
    }

    // Returns the element type for arrays and supported generic collections, or the type itself otherwise.
    static Type UnwrapCollection(Type type)
    {
        if (type.IsArray)
            return type.GetElementType()!;

        if (!type.IsGenericType)
            return type;

        var definition = type.GetGenericTypeDefinition();
        if (definition == typeof(IEnumerable<>) ||
            definition == typeof(ICollection<>) ||
            definition == typeof(List<>) ||
            definition == typeof(IList<>) ||
            definition == typeof(IReadOnlyCollection<>) ||
            definition == typeof(IReadOnlyList<>))
            return type.GetGenericArguments()[0];

        return type;
    }

    // Turns a CLR type name (which may contain '`', '[' or ']' for generic/array types) into a 
    // name made only of ASCII letters, digits, '_' and '-', capped at MaxSchemaNameLength characters.
    static string ToSchemaName(string name)
    {
        var sanitized = string.Concat(name.Select(c => (char.IsAsciiLetterOrDigit(c) || c is '_' or '-') ? c : '_'));

        return sanitized.Length > MaxSchemaNameLength ? sanitized[..MaxSchemaNameLength] : sanitized;
    }

    // Generates the JSON schema for T, adding 'description' entries from [Description] attributes on properties.
    static BinaryData GetJsonSchema<T>()
    {
        var node = JsonSchemaExporter.GetJsonSchemaAsNode(jsonOptions, typeof(T), new JsonSchemaExporterOptions
        {
            TreatNullObliviousAsNonNullable = true,
            TransformSchemaNode = static (context, node) =>
            {
                var description = context.PropertyInfo?.AttributeProvider?.GetCustomAttributes(typeof(DescriptionAttribute), false)
                    .OfType<DescriptionAttribute>()
                    .FirstOrDefault()?.Description;

                if (description == null)
                    return node;

                // The exporter may emit a boolean schema (true/false) instead of an object, 
                // which cannot hold a description: promote it to an equivalent object first.
                if (node is not System.Text.Json.Nodes.JsonObject schema)
                {
                    schema = new System.Text.Json.Nodes.JsonObject();
                    if (node.GetValueKind() == JsonValueKind.False)
                        schema.Add("not", true);
                }

                schema["description"] = description;
                return schema;
            },
        });

        return BinaryData.FromString(node.ToJsonString());
    }

    /// <summary>
    /// Wrapper used to request collection results as the <c>data</c> property of a root object.
    /// </summary>
    public class Values<T>
    {
        /// <summary>The wrapped value.</summary>
        public required T Data { get; set; }
    }
}
7 changes: 6 additions & 1 deletion catbag.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
<PropertyGroup>
<RootNamespace>System</RootNamespace>
<AssemblyName>Devlooped</AssemblyName>
<TargetFrameworks>net6.0</TargetFrameworks>
<TargetFrameworks>net8.0</TargetFrameworks>
<IsPackable>false</IsPackable>
<LangVersion>latest</LangVersion>
<NoWarn>$(NoWarn);CS7011</NoWarn>
<ImplicitUsings>disable</ImplicitUsings>
<UserSecretsId>02191a9e-c7d2-482c-bb42-0f5b198b37e9</UserSecretsId>
</PropertyGroup>

<ItemGroup>
Expand All @@ -16,9 +17,11 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="Devlooped.Web" Version="1.2.3" />
<PackageReference Include="Microsoft.Azure.Cosmos.Table" Version="1.0.8" />
<PackageReference Include="Microsoft.Azure.Functions.Worker" Version="1.20.1" />
<PackageReference Include="Microsoft.CSharp" Version="4.7.0" />
<PackageReference Include="Microsoft.Extensions.Configuration.UserSecrets" Version="6.0.1" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="6.0.1" />
<PackageReference Include="Microsoft.Extensions.Http" Version="6.0.0" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.6.2" />
Expand All @@ -27,6 +30,8 @@
<PackageReference Include="Moq" Version="4.18.4" />
<PackageReference Include="PolySharp" Version="1.13.2" PrivateAssets="all" />
<PackageReference Include="System.Reactive" Version="6.0.0" />
<PackageReference Include="System.Text.Json" Version="9.0.0-rc.*" />
<PackageReference Include="Azure.AI.OpenAI" Version="2.0.0" />
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' != 'netstandard2.0'">
Expand Down
84 changes: 84 additions & 0 deletions tests/Attributes.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
using System;
using System.Collections.Generic;
using Microsoft.Extensions.Configuration;
using Xunit;

#pragma warning disable CA1050 // Declare types in namespaces
/// <summary>
/// A fact that is skipped unless all the given configuration keys have a non-empty 
/// value in either environment variables or the test assembly user secrets.
/// </summary>
public class SecretsFactAttribute : FactAttribute
{
    public SecretsFactAttribute(params string[] secrets)
    {
        var builder = new ConfigurationBuilder();
        builder.AddEnvironmentVariables();
        builder.AddUserSecrets<SecretsFactAttribute>();
        var config = builder.Build();

        // A set, so that a key requested more than once is reported only once.
        var unresolved = new HashSet<string>();

        foreach (var key in secrets)
        {
            if (!string.IsNullOrEmpty(config[key]))
                continue;

            unresolved.Add(key);
        }

        if (unresolved.Count == 0)
            return;

        Skip = $"Missing user secrets: {string.Join(',', unresolved)}";
    }
}

/// <summary>
/// A secrets-dependent fact that is additionally skipped whenever 
/// the <c>CI</c> environment variable has a non-empty value.
/// </summary>
public class LocalFactAttribute : SecretsFactAttribute
{
    public LocalFactAttribute(params string[] secrets) : base(secrets)
    {
        var ci = Environment.GetEnvironmentVariable("CI");

        // Replaces any skip reason set by the base constructor.
        if (ci is { Length: > 0 })
            Skip = "Non-CI test";
    }
}

/// <summary>
/// A fact that is skipped unless the <c>CI</c> environment variable has a non-empty value.
/// </summary>
public class CIFactAttribute : FactAttribute
{
    public CIFactAttribute()
    {
        var ci = Environment.GetEnvironmentVariable("CI");

        if (ci is not { Length: > 0 })
            Skip = "CI-only test";
    }
}

/// <summary>
/// A theory that is skipped unless all the given configuration keys have a non-empty 
/// value in either environment variables or the test assembly user secrets.
/// </summary>
public class SecretsTheoryAttribute : TheoryAttribute
{
    public SecretsTheoryAttribute(params string[] secrets)
    {
        // Same configuration sources as SecretsFactAttribute, so that facts and theories
        // resolve secrets identically (i.e. also from environment variables).
        var configuration = new ConfigurationBuilder()
            .AddEnvironmentVariables()
            .AddUserSecrets<SecretsTheoryAttribute>()
            .Build();

        var missing = new HashSet<string>();

        foreach (var secret in secrets)
        {
            if (string.IsNullOrEmpty(configuration[secret]))
                missing.Add(secret);
        }

        if (missing.Count > 0)
            Skip = "Missing user secrets: " + string.Join(',', missing);
    }
}

/// <summary>
/// A theory that is skipped whenever the <c>CI</c> environment variable has a non-empty value.
/// </summary>
public class LocalTheoryAttribute : TheoryAttribute
{
    public LocalTheoryAttribute()
    {
        var ci = Environment.GetEnvironmentVariable("CI");

        if (ci is { Length: > 0 })
            Skip = "Non-CI test";
    }
}

/// <summary>
/// A theory that is skipped unless the <c>CI</c> environment variable has a non-empty value.
/// </summary>
public class CITheoryAttribute : TheoryAttribute
{
    public CITheoryAttribute()
    {
        var ci = Environment.GetEnvironmentVariable("CI");

        if (ci is not { Length: > 0 })
            Skip = "CI-only test";
    }
}
54 changes: 54 additions & 0 deletions tests/OpenAI/Chat/ChatClientTypedExtensionsTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
using System;
using System.Net.Http;
using System.Reflection;
using System.Text.Json;
using System.Threading.Tasks;
using System.Xml.Linq;
using Devlooped.Web;
using Microsoft.Extensions.Configuration;
using Xunit;
using Xunit.Abstractions;

namespace OpenAI.Chat;

/// <summary>
/// Integration tests for the typed chat completion extension. 
/// They require the <c>OpenAI:Key</c> secret and perform live HTTP requests.
/// </summary>
public class ChatClientTypedExtensionsTests(ITestOutputHelper output)
{
    /// <summary>
    /// Downloads a live web page and asks the model to extract a typed list of movies from it.
    /// </summary>
    [SecretsFact("OpenAI:Key")]
    public async Task CanScrapPage()
    {
        var configuration = new ConfigurationBuilder()
            .AddEnvironmentVariables()
            .AddUserSecrets(Assembly.GetExecutingAssembly())
            .Build();

        var client = new OpenAIClient(configuration["OpenAI:Key"]!);
        var chat = client.GetChatClient("gpt-4o");

        using var http = new HttpClient();
        http.DefaultRequestHeaders.AcceptLanguage.Add(new("en-US"));

        var html = await http.GetStringAsync("https://www.imdb.com/chart/moviemeter/?ref_=nv_mv_mpm&genres=thriller&user_rating=6%2C&sort=user_rating%2Cdesc&num_votes=50000%2C");
        // By parsing, we eliminate lots of noise and only keep the relevant parts (i.e. skip headers, scripts and styles).
        var doc = HtmlDocument.Parse(html);

        var response = await chat.CompleteChatAsync<Movie[]>([
            new SystemChatMessage(
                """
                You are an HTML page scraper.
                You use exclusively the data in the following HTML page to parse and return a list of movies.
                You perform smart type conversion and parsing as needed to fit the result schema in JSON format.
                """),
            new UserChatMessage(doc.CssSelectElement("html")!.ToString(SaveOptions.DisableFormatting)),
        ]);

        // The typed completion returns null when the model produces no text content: 
        // fail with a clear assertion rather than a NullReferenceException below.
        Assert.NotNull(response);

        // NOTE(review): these expectations depend on the live page content at the time of the run.
        Assert.Equal(16, response.Length);
        Assert.Contains(response, x => x.Title == "The Dark Knight");
        Assert.Contains(response, x => x.Title == "Joker");

        output.WriteLine(JsonSerializer.Serialize(response, new JsonSerializerOptions(JsonSerializerOptions.Web) { WriteIndented = true }));
    }

    /// <summary>Result shape requested from the model for each movie in the page.</summary>
    public record Movie(string Title, int Year, TimeSpan Duration, string AgeRating, StarsRating Stars, string Url);

    /// <summary>Rating value and number of votes for a movie.</summary>
    public record StarsRating(double Stars, long Votes);
}

0 comments on commit 43b6b7b

Please sign in to comment.