Skip to content

Commit

Permalink
Adding support for Text-To-Image Settings
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerBarreto committed Sep 13, 2024
1 parent f58e689 commit 36d4fb9
Show file tree
Hide file tree
Showing 16 changed files with 584 additions and 530 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@

using System;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Text;
using System.Text.Json;
using System.Text.Json.Nodes;
using System.Threading.Tasks;
using Azure.AI.OpenAI;
using Azure.Core;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.AzureOpenAI;
using Microsoft.SemanticKernel.Connectors.OpenAI;
using Microsoft.SemanticKernel.Services;
using Moq;
using OpenAI.Images;

namespace SemanticKernel.Connectors.AzureOpenAI.UnitTests.Services;

Expand All @@ -30,7 +35,7 @@ public AzureOpenAITextToImageServiceTests()
{
ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK)
{
Content = new StringContent(File.ReadAllText("./TestData/text-to-image-response.txt"))
Content = new StringContent(File.ReadAllText("./TestData/text-to-image-response.json"))
}
};
this._httpClient = new HttpClient(this._messageHandlerStub, false);
Expand Down Expand Up @@ -143,6 +148,191 @@ public void ItShouldThrowExceptionIfNoEndpointProvided(bool useTokeCredential, s
}
}

/// <summary>
/// Checks that the <c>ResponseFormat</c> execution setting reaches the request body as the expected
/// <c>response_format</c> value, and that the field is left out entirely when no format is set.
/// </summary>
/// <param name="uri">
/// Response format to request (despite the name, this is not a URI). The placeholders
/// "GeneratedImage.Uri" and "GeneratedImage.Bytes" are mapped to the matching
/// <see cref="GeneratedImageFormat"/> values; any other value is passed through as-is.
/// </param>
/// <param name="expectedResponseFormat">
/// Value expected for <c>response_format</c> in the request body, or <c>null</c> when the field must be absent.
/// </param>
[Theory]
[InlineData(null, null)]
[InlineData("uri", "url")]
[InlineData("url", "url")]
[InlineData("GeneratedImage.Uri", "url")]
[InlineData("bytes", "b64_json")]
[InlineData("b64_json", "b64_json")]
[InlineData("GeneratedImage.Bytes", "b64_json")]
public async Task GetUriImageContentsResponseFormatRequestWorksCorrectlyAsync(string? uri, string? expectedResponseFormat)
{
    // Arrange
    // Start from the raw input and swap in the SDK value only for the two placeholder strings.
    object? requestedFormat = uri;
    if (uri == "GeneratedImage.Uri")
    {
        requestedFormat = GeneratedImageFormat.Uri;
    }
    else if (uri == "GeneratedImage.Bytes")
    {
        requestedFormat = GeneratedImageFormat.Bytes;
    }

    this._httpClient.BaseAddress = new Uri("https://api-host");
    var sut = new AzureOpenAITextToImageService("deployment", endpoint: null!, credential: new Mock<TokenCredential>().Object, "dall-e-3", this._httpClient);
    var settings = new OpenAITextToImageExecutionSettings { ResponseFormat = requestedFormat };

    // Act
    var result = await sut.GetImageContentsAsync("my prompt", settings);

    // Assert
    Assert.NotNull(result);
    Assert.NotNull(this._messageHandlerStub.RequestContent);

    var sentJson = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent);
    if (expectedResponseFormat is null)
    {
        // When no response format is provided, it should not be included in the request body
        Assert.DoesNotContain("response_format", sentJson);
        return;
    }

    Assert.Contains($"\"response_format\":\"{expectedResponseFormat}\"", sentJson);
}

/// <summary>
/// Checks that the <c>Quality</c> execution setting reaches the request body as the expected
/// <c>quality</c> value, and that the field is left out entirely when no quality is set.
/// </summary>
/// <param name="quality">Quality value assigned to the execution settings.</param>
/// <param name="expectedQuality">
/// Value expected for <c>quality</c> in the request body, or <c>null</c> when the field must be absent.
/// </param>
[Theory]
[InlineData(null, null)]
[InlineData("hd", "hd")]
[InlineData("high", "hd")]
[InlineData("standard", "standard")]
public async Task GetUriImageContentsImageQualityRequestWorksCorrectlyAsync(string? quality, string? expectedQuality)
{
    // Arrange
    this._httpClient.BaseAddress = new Uri("https://api-host");
    var sut = new AzureOpenAITextToImageService("deployment", endpoint: null!, credential: new Mock<TokenCredential>().Object, "dall-e-3", this._httpClient);
    var settings = new OpenAITextToImageExecutionSettings { Quality = quality };

    // Act
    var result = await sut.GetImageContentsAsync("my prompt", settings);

    // Assert
    Assert.NotNull(result);
    Assert.NotNull(this._messageHandlerStub.RequestContent);

    var sentJson = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent);
    if (expectedQuality is null)
    {
        // When no quality is provided, it should not be included in the request body
        Assert.DoesNotContain("quality", sentJson);
        return;
    }

    Assert.Contains($"\"quality\":\"{expectedQuality}\"", sentJson);
}

/// <summary>
/// Checks that the <c>Style</c> execution setting reaches the request body as the expected
/// <c>style</c> value, and that the field is left out entirely when no style is set.
/// </summary>
/// <param name="style">Style value assigned to the execution settings.</param>
/// <param name="expectedStyle">
/// Value expected for <c>style</c> in the request body, or <c>null</c> when the field must be absent.
/// </param>
[Theory]
[InlineData(null, null)]
[InlineData("vivid", "vivid")]
[InlineData("natural", "natural")]
public async Task GetUriImageContentsImageStyleRequestWorksCorrectlyAsync(string? style, string? expectedStyle)
{
    // Arrange
    this._httpClient.BaseAddress = new Uri("https://api-host");
    var sut = new AzureOpenAITextToImageService("deployment", endpoint: null!, credential: new Mock<TokenCredential>().Object, "dall-e-3", this._httpClient);
    var settings = new OpenAITextToImageExecutionSettings { Style = style };

    // Act
    var result = await sut.GetImageContentsAsync("my prompt", settings);

    // Assert
    Assert.NotNull(result);
    Assert.NotNull(this._messageHandlerStub.RequestContent);

    var sentJson = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent);
    if (expectedStyle is null)
    {
        // When no style is provided, it should not be included in the request body
        Assert.DoesNotContain("style", sentJson);
        return;
    }

    Assert.Contains($"\"style\":\"{expectedStyle}\"", sentJson);
}

/// <summary>
/// Checks that the <c>Size</c> execution setting reaches the request body as a
/// <c>"{width}x{height}"</c> string, and that the field is left out entirely when no size is set.
/// </summary>
/// <param name="width">Requested width; a size is only set when both width and height have a value.</param>
/// <param name="height">Requested height; a size is only set when both width and height have a value.</param>
/// <param name="expectedSize">
/// Value expected for <c>size</c> in the request body, or <c>null</c> when the field must be absent.
/// </param>
[Theory]
[InlineData(null, null, null)]
[InlineData(1, 2, "1x2")]
public async Task GetUriImageContentsImageSizeRequestWorksCorrectlyAsync(int? width, int? height, string? expectedSize)
{
    // Arrange
    this._httpClient.BaseAddress = new Uri("https://api-host");
    var sut = new AzureOpenAITextToImageService("deployment", endpoint: null!, credential: new Mock<TokenCredential>().Object, "dall-e-3", this._httpClient);
    var settings = new OpenAITextToImageExecutionSettings
    {
        // A partial size (only one dimension supplied) is treated the same as no size at all.
        Size = width is not null && height is not null
            ? (width.Value, height.Value)
            : null
    };

    // Act
    var result = await sut.GetImageContentsAsync("my prompt", settings);

    // Assert
    Assert.NotNull(result);
    Assert.NotNull(this._messageHandlerStub.RequestContent);

    var sentJson = Encoding.UTF8.GetString(this._messageHandlerStub.RequestContent);
    if (expectedSize is null)
    {
        // When no size is provided, it should not be included in the request body
        Assert.DoesNotContain("size", sentJson);
        return;
    }

    Assert.Contains($"\"size\":\"{expectedSize}\"", sentJson);
}

/// <summary>
/// Checks that a <c>b64_json</c> response is surfaced as a single readable PNG image content
/// whose inner content is the SDK <see cref="GeneratedImage"/> carrying the revised prompt.
/// </summary>
[Fact]
public async Task GetByteImageContentsResponseWorksCorrectlyAsync()
{
    // Arrange
    // Replace the stubbed response with the base64 payload used for this scenario.
    this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(System.Net.HttpStatusCode.OK)
    {
        Content = new StringContent(File.ReadAllText("./TestData/text-to-image-b64_json-format-response.json"))
    };

    this._httpClient.BaseAddress = new Uri("https://api-host");
    var sut = new AzureOpenAITextToImageService("deployment", endpoint: null!, credential: new Mock<TokenCredential>().Object, "dall-e-3", this._httpClient);

    // Act
    var result = await sut.GetImageContentsAsync("my prompt", new OpenAITextToImageExecutionSettings { ResponseFormat = "b64_json" });

    // Assert
    Assert.NotNull(result);
    Assert.Single(result);
    var imageContent = result[0];
    Assert.NotNull(imageContent);
    Assert.True(imageContent.CanRead);
    Assert.Equal("image/png", imageContent.MimeType);
    Assert.NotNull(imageContent.InnerContent);

    // Assert.IsType returns the value as the asserted type, so no separate "as" cast
    // (and null-forgiving operator) is needed to reach the SDK-level properties.
    var breakingGlass = Assert.IsType<GeneratedImage>(imageContent.InnerContent);
    Assert.Equal("my prompt", breakingGlass.RevisedPrompt);
}

/// <summary>
/// Checks that a <c>url</c> response is surfaced as a single non-readable image content pointing at
/// the returned URL, whose inner content is the SDK <see cref="GeneratedImage"/> carrying the revised prompt.
/// </summary>
[Fact]
public async Task GetUrlImageContentsResponseWorksCorrectlyAsync()
{
    // Arrange
    this._httpClient.BaseAddress = new Uri("https://api-host");
    var sut = new AzureOpenAITextToImageService("deployment", endpoint: null!, credential: new Mock<TokenCredential>().Object, "dall-e-3", this._httpClient);

    // Act
    var result = await sut.GetImageContentsAsync("my prompt", new OpenAITextToImageExecutionSettings { ResponseFormat = "url" });

    // Assert
    Assert.NotNull(result);
    Assert.Single(result);
    var imageContent = result[0];
    Assert.NotNull(imageContent);
    Assert.False(imageContent.CanRead);
    Assert.Equal(new Uri("https://image-url/"), imageContent.Uri);
    Assert.NotNull(imageContent.InnerContent);

    // Assert.IsType returns the value as the asserted type, so no separate "as" cast
    // (and null-forgiving operator) is needed to reach the SDK-level properties.
    var breakingGlass = Assert.IsType<GeneratedImage>(imageContent.InnerContent);
    Assert.Equal("my prompt", breakingGlass.RevisedPrompt);
}

public void Dispose()
{
this._httpClient.Dispose();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"created": 1726234481,
"data": [
{
"b64_json": "iVBORw0KGgoAAA==",
"revised_prompt": "my prompt"
}
]
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"created": 1702575371,
"data": [
{
"revised_prompt": "my prompt",
"url": "https://image-url/"
}
]
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,8 @@ public AzureOpenAITextToImageService(
{
Verify.NotNullOrWhiteSpace(apiKey);

var connectorEndpoint = !string.IsNullOrWhiteSpace(endpoint) ? endpoint! : httpClient?.BaseAddress?.AbsoluteUri;
if (connectorEndpoint is null)
{
throw new ArgumentException($"The {nameof(httpClient)}.{nameof(HttpClient.BaseAddress)} and {nameof(endpoint)} are both null or empty. Please ensure at least one is provided.");
}
var connectorEndpoint = (!string.IsNullOrWhiteSpace(endpoint) ? endpoint! : httpClient?.BaseAddress?.AbsoluteUri)
?? throw new ArgumentException($"The {nameof(httpClient)}.{nameof(HttpClient.BaseAddress)} and {nameof(endpoint)} are both null or empty. Please ensure at least one is provided.");

var options = AzureClientCore.GetAzureOpenAIClientOptions(
httpClient,
Expand Down Expand Up @@ -87,11 +84,8 @@ public AzureOpenAITextToImageService(
{
Verify.NotNull(credential);

var connectorEndpoint = !string.IsNullOrWhiteSpace(endpoint) ? endpoint! : httpClient?.BaseAddress?.AbsoluteUri;
if (connectorEndpoint is null)
{
throw new ArgumentException($"The {nameof(httpClient)}.{nameof(HttpClient.BaseAddress)} and {nameof(endpoint)} are both null or empty. Please ensure at least one is provided.");
}
var connectorEndpoint = (!string.IsNullOrWhiteSpace(endpoint) ? endpoint! : httpClient?.BaseAddress?.AbsoluteUri)
?? throw new ArgumentException($"The {nameof(httpClient)}.{nameof(HttpClient.BaseAddress)} and {nameof(endpoint)} are both null or empty. Please ensure at least one is provided.");

var options = AzureClientCore.GetAzureOpenAIClientOptions(
httpClient,
Expand Down Expand Up @@ -133,4 +127,8 @@ public AzureOpenAITextToImageService(
// Thin pass-through to the underlying client: the client's own DeploymentName is supplied as the
// first argument, and the kernel argument is not forwarded.
/// <inheritdoc/>
public Task<string> GenerateImageAsync(string description, int width, int height, Kernel? kernel = null, CancellationToken cancellationToken = default)
=> this._client.GenerateImageAsync(this._client.DeploymentName, description, width, height, cancellationToken);

// Thin pass-through to the underlying client: input, execution settings, kernel and cancellation
// token are all forwarded unchanged.
/// <inheritdoc/>
public Task<IReadOnlyList<ImageContent>> GetImageContentsAsync(TextContent input, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
=> this._client.GetImageContentsAsync(input, executionSettings, kernel, cancellationToken);
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@
<None Update="TestData\text-embeddings-response.txt">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
<None Update="TestData\text-to-image-response.txt">
<None Update="TestData\text-to-image-b64_json-format-response.json">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
<None Update="TestData\text-to-image-response.json">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</None>
</ItemGroup>
Expand Down
Loading

0 comments on commit 36d4fb9

Please sign in to comment.