Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

.Net Add support for quality and style for OpenAI image generation #8064

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,21 +57,21 @@ public void ConstructorsAddRequiredMetadata()
}

[Theory]
[InlineData(256, 256, "dall-e-2")]
[InlineData(512, 512, "dall-e-2")]
[InlineData(1024, 1024, "dall-e-2")]
[InlineData(1024, 1024, "dall-e-3")]
[InlineData(1024, 1792, "dall-e-3")]
[InlineData(1792, 1024, "dall-e-3")]
[InlineData(123, 321, "custom-model-1")]
[InlineData(179, 124, "custom-model-2")]
public async Task GenerateImageWorksCorrectlyAsync(int width, int height, string modelId)
[InlineData(256, 256, "dall-e-2", "HIGH", "VIVID")]
[InlineData(512, 512, "dall-e-2", "STANDARD", "NATURAL")]
[InlineData(1024, 1024, "dall-e-2", "HIGH", "NATURAL")]
[InlineData(1024, 1024, "dall-e-3", "STANDARD", "VIVID")]
[InlineData(1024, 1792, "dall-e-3", "HIGH", "VIVID")]
[InlineData(1792, 1024, "dall-e-3", "STANDARD", "NATURAL")]
[InlineData(123, 321, "custom-model-1", "HIGH", "VIVID")]
[InlineData(179, 124, "custom-model-2", "STANDARD", "NATURAL")]
public async Task GenerateImageWorksCorrectlyAsync(int width, int height, string modelId, string quality, string style)
{
// Arrange
var sut = new AzureOpenAITextToImageService("deployment", "https://api-host", "api-key", modelId, this._httpClient, loggerFactory: this._mockLoggerFactory.Object);

// Act
var result = await sut.GenerateImageAsync("description", width, height);
var result = await sut.GenerateImageAsync("description", width, height, quality, style);

// Assert
Assert.Equal("https://image-url/", result);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,6 @@ public AzureOpenAITextToImageService(
}

/// <inheritdoc/>
public Task<string> GenerateImageAsync(string description, int width, int height, Kernel? kernel = null, CancellationToken cancellationToken = default)
=> this._client.GenerateImageAsync(this._client.DeploymentName, description, width, height, cancellationToken);
public Task<string> GenerateImageAsync(string description, int width, int height, string quality = "HIGH", string style = "VIVID", Kernel? kernel = null, CancellationToken cancellationToken = default)
=> this._client.GenerateImageAsync(this._client.DeploymentName, description, width, height, quality, style, cancellationToken);
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,22 @@ public void ConstructorWorksCorrectly()
}

[Theory]
[InlineData(256, 256, "dall-e-2")]
[InlineData(512, 512, "dall-e-2")]
[InlineData(1024, 1024, "dall-e-2")]
[InlineData(1024, 1024, "dall-e-3")]
[InlineData(1024, 1792, "dall-e-3")]
[InlineData(1792, 1024, "dall-e-3")]
[InlineData(123, 321, "custom-model-1")]
[InlineData(179, 124, "custom-model-2")]
public async Task GenerateImageWorksCorrectlyAsync(int width, int height, string modelId)
[InlineData(256, 256, "dall-e-2", "HIGH", "VIVID")]
[InlineData(512, 512, "dall-e-2", "STANDARD", "NATURAL")]
[InlineData(1024, 1024, "dall-e-2", "HIGH", "NATURAL")]
[InlineData(1024, 1024, "dall-e-3", "STANDARD", "VIVID")]
[InlineData(1024, 1792, "dall-e-3", "HIGH", "VIVID")]
[InlineData(1792, 1024, "dall-e-3", "STANDARD", "NATURAL")]
[InlineData(123, 321, "custom-model-1", "HIGH", "VIVID")]
[InlineData(179, 124, "custom-model-2", "STANDARD", "NATURAL")]
public async Task GenerateImageWorksCorrectlyAsync(int width, int height, string modelId, string quality, string style)
{
// Arrange
var sut = new OpenAITextToImageService("api-key", modelId: modelId, httpClient: this._httpClient);
Assert.Equal(modelId, sut.Attributes["ModelId"]);

// Act
var result = await sut.GenerateImageAsync("description", width, height);
var result = await sut.GenerateImageAsync("description", width, height, quality, style);

// Assert
Assert.Equal("https://image-url/", result);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.ClientModel;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -19,13 +20,17 @@ internal partial class ClientCore
/// <param name="prompt">Prompt to generate the image</param>
/// <param name="width">Width of the image</param>
/// <param name="height">Height of the image</param>
/// <param name="quality">The quality of the generated image</param>
/// <param name="style">The style of the generated image</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>Url of the generated image</returns>
internal async Task<string> GenerateImageAsync(
string? targetModel,
string prompt,
int width,
int height,
string quality,
string style,
CancellationToken cancellationToken)
{
Verify.NotNullOrWhiteSpace(prompt);
Expand All @@ -35,7 +40,9 @@ internal async Task<string> GenerateImageAsync(
var imageOptions = new ImageGenerationOptions()
{
Size = size,
ResponseFormat = GeneratedImageFormat.Uri
ResponseFormat = GeneratedImageFormat.Uri,
Quality = GetGeneratedImageQuality(quality),
Style = GetGeneratedImageStyle(style)
};

// The model is not required by the OpenAI API and defaults to the DALL-E 2 server-side - https://platform.openai.com/docs/api-reference/images/create#images-create-model.
Expand All @@ -47,4 +54,20 @@ internal async Task<string> GenerateImageAsync(

return generatedImage.ImageUri?.ToString() ?? throw new KernelException("The generated image is not in url format");
}

private static GeneratedImageQuality GetGeneratedImageQuality(string? quality)
=> quality?.ToUpperInvariant() switch
{
"HIGH" => GeneratedImageQuality.High,
"STANDARD" => GeneratedImageQuality.Standard,
_ => throw new NotSupportedException($"The image quality '{quality}' is not supported."),
};

private static GeneratedImageStyle GetGeneratedImageStyle(string? style)
=> style?.ToUpperInvariant() switch
{
"VIVID" => GeneratedImageStyle.Vivid,
"NATURAL" => GeneratedImageStyle.Natural,
_ => throw new NotSupportedException($"The image style '{style}' is not supported."),
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,6 @@ public OpenAITextToImageService(
}

/// <inheritdoc/>
public Task<string> GenerateImageAsync(string description, int width, int height, Kernel? kernel = null, CancellationToken cancellationToken = default)
=> this._client.GenerateImageAsync(this._client.ModelId, description, width, height, cancellationToken);
public Task<string> GenerateImageAsync(string description, int width, int height, string quality = "HIGH", string style = "VIVID", Kernel? kernel = null, CancellationToken cancellationToken = default)
=> this._client.GenerateImageAsync(this._client.ModelId, description, width, height, quality, style, cancellationToken);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ public interface ITextToImageService : IAIService
/// <param name="description">Image generation prompt</param>
/// <param name="width">Image width in pixels</param>
/// <param name="height">Image height in pixels</param>
/// <param name="quality">The quality of the generated image</param>
/// <param name="style">The style of the generated image</param>
/// <param name="kernel">The <see cref="Kernel"/> containing services, plugins, and other state for use throughout the operation.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>Generated image in base64 format or image URL</returns>
Expand All @@ -27,6 +29,8 @@ public Task<string> GenerateImageAsync(
string description,
int width,
int height,
string quality = "HIGH",
string style = "VIVID",
Kernel? kernel = null,
CancellationToken cancellationToken = default);
}
Loading