Skip to content

Commit

Permalink
Support WinHttp
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesNK committed Feb 17, 2021
1 parent 96750a3 commit 31eaaf8
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 48 deletions.
20 changes: 13 additions & 7 deletions src/Grpc.Net.Client/Internal/GrpcCall.cs
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,7 @@ public CancellationToken CancellationToken
IClientStreamWriter<TRequest>? IGrpcCall<TRequest, TResponse>.ClientStreamWriter => ClientStreamWriter;
IAsyncStreamReader<TResponse>? IGrpcCall<TRequest, TResponse>.ClientStreamReader => ClientStreamReader;

public void StartUnary(TRequest request) => StartUnaryCore(new PushUnaryContent<TRequest, TResponse>(stream =>
{
return WriteMessageAsync(stream, request, Options);
}));
public void StartUnary(TRequest request) => StartUnaryCore(CreatePushUnaryContent(request));

public void StartClientStreaming()
{
Expand All @@ -119,10 +116,19 @@ public void StartClientStreaming()
StartClientStreamingCore(clientStreamWriter, content);
}

public void StartServerStreaming(TRequest request) => StartServerStreamingCore(new PushUnaryContent<TRequest, TResponse>(stream =>
public void StartServerStreaming(TRequest request) => StartServerStreamingCore(CreatePushUnaryContent(request));

private HttpContent CreatePushUnaryContent(TRequest request)
{
return WriteMessageAsync(stream, request, Options);
}));
return !Channel.IsWinHttp
? new PushUnaryContent<TRequest, TResponse>(request, WriteAsync)
: new WinHttpUnaryContent<TRequest, TResponse>(request, WriteAsync, this);

ValueTask WriteAsync(TRequest request, Stream stream)
{
return WriteMessageAsync(stream, request, Options);
}
}

public void StartDuplexStreaming()
{
Expand Down
8 changes: 5 additions & 3 deletions src/Grpc.Net.Client/Internal/Http/PushUnaryContent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,20 @@ internal class PushUnaryContent<TRequest, TResponse> : HttpContent
where TRequest : class
where TResponse : class
{
private readonly Func<Stream, ValueTask> _startCallback;
private readonly TRequest _request;
private readonly Func<TRequest, Stream, ValueTask> _startCallback;

public PushUnaryContent(Func<Stream, ValueTask> startCallback)
public PushUnaryContent(TRequest request, Func<TRequest, Stream, ValueTask> startCallback)
{
_request = request;
_startCallback = startCallback;
Headers.ContentType = GrpcProtocolConstants.GrpcContentTypeHeaderValue;
}

protected override Task SerializeToStreamAsync(Stream stream, TransportContext? context)
{
#pragma warning disable CA2012 // Use ValueTasks correctly
var writeMessageTask = _startCallback(stream);
var writeMessageTask = _startCallback(_request, stream);
#pragma warning restore CA2012 // Use ValueTasks correctly
if (writeMessageTask.IsCompletedSuccessfully())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,64 +37,66 @@ namespace Grpc.Net.Client.Internal.Http
/// The payload is then written directly to the request using specialized context
/// and serializer method.
/// </summary>
internal class LengthUnaryContent<TRequest, TResponse> : HttpContent
internal class WinHttpUnaryContent<TRequest, TResponse> : HttpContent
where TRequest : class
where TResponse : class
{
private readonly TRequest _content;
private readonly TRequest _request;
private readonly Func<TRequest, Stream, ValueTask> _startCallback;
private readonly GrpcCall<TRequest, TResponse> _call;
private byte[]? _payload;

public LengthUnaryContent(TRequest content, GrpcCall<TRequest, TResponse> call, MediaTypeHeaderValue mediaType)
public WinHttpUnaryContent(TRequest request, Func<TRequest, Stream, ValueTask> startCallback, GrpcCall<TRequest, TResponse> call)
{
_content = content;
_request = request;
_startCallback = startCallback;
_call = call;
Headers.ContentType = mediaType;
Headers.ContentType = GrpcProtocolConstants.GrpcContentTypeHeaderValue;
}

// Serialize message. Need to know size to prefix the length in the header.
private byte[] SerializePayload()
protected override Task SerializeToStreamAsync(Stream stream, TransportContext? context)
{
var serializationContext = _call.SerializationContext;
serializationContext.CallOptions = _call.Options;
serializationContext.Initialize();

try
{
_call.Method.RequestMarshaller.ContextualSerializer(_content, serializationContext);

return serializationContext.GetWrittenPayload().ToArray();
}
finally
#pragma warning disable CA2012 // Use ValueTasks correctly
var writeMessageTask = _startCallback(_request, stream);
#pragma warning restore CA2012 // Use ValueTasks correctly
if (writeMessageTask.IsCompletedSuccessfully())
{
serializationContext.Reset();
GrpcEventSource.Log.MessageSent();
return Task.CompletedTask;
}

return WriteMessageCore(writeMessageTask);
}

protected override async Task SerializeToStreamAsync(Stream stream, TransportContext? context)
private static async Task WriteMessageCore(ValueTask writeMessageTask)
{
if (_payload == null)
{
_payload = SerializePayload();
}

await _call.WriteMessageAsync(
stream,
_payload,
_call.Options.CancellationToken).ConfigureAwait(false);

await writeMessageTask.ConfigureAwait(false);
GrpcEventSource.Log.MessageSent();
}

protected override bool TryComputeLength(out long length)
{
if (_payload == null)
// This will serialize the request message again.
// Consider caching serialized content if it is a problem.
length = GetPayloadLength();
return true;
}

private int GetPayloadLength()
{
var serializationContext = _call.SerializationContext;
serializationContext.CallOptions = _call.Options;
serializationContext.Initialize();

try
{
_payload = SerializePayload();
}
_call.Method.RequestMarshaller.ContextualSerializer(_request, serializationContext);

length = _payload.Length;
return true;
return serializationContext.GetWrittenPayload().Length;
}
finally
{
serializationContext.Reset();
}
}
}
}
12 changes: 10 additions & 2 deletions src/Grpc.Net.Client/Internal/Retry/RetryCallBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Grpc.Core;
Expand Down Expand Up @@ -168,9 +169,16 @@ public void StartDuplexStreaming()
});
}

private PushUnaryContent<TRequest, TResponse> CreatePushUnaryContent(TRequest request, GrpcCall<TRequest, TResponse> call)
private HttpContent CreatePushUnaryContent(TRequest request, GrpcCall<TRequest, TResponse> call)
{
return new PushUnaryContent<TRequest, TResponse>(stream => WriteNewMessage(call, stream, call.Options, request));
return !Channel.IsWinHttp
? new PushUnaryContent<TRequest, TResponse>(request, WriteAsync)
: new WinHttpUnaryContent<TRequest, TResponse>(request, WriteAsync, call);

ValueTask WriteAsync(TRequest request, Stream stream)
{
return WriteNewMessage(call, stream, call.Options, request);
}
}

private PushStreamContent<TRequest, TResponse> CreatePushStreamContent(GrpcCall<TRequest, TResponse> call, HttpContentClientStreamWriter<TRequest, TResponse> clientStreamWriter)
Expand Down
34 changes: 34 additions & 0 deletions test/Grpc.Net.Client.Tests/AsyncUnaryCallTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ public async Task AsyncUnaryCall_Success_HttpRequestMessagePopulated()
Assert.AreEqual(new MediaTypeHeaderValue("application/grpc"), httpRequestMessage.Content?.Headers?.ContentType);
Assert.AreEqual(GrpcProtocolConstants.TEHeaderValue, httpRequestMessage.Headers.TE.Single().Value);
Assert.AreEqual("identity,gzip", httpRequestMessage.Headers.GetValues(GrpcProtocolConstants.MessageAcceptEncodingHeader).Single());
Assert.AreEqual(null, httpRequestMessage!.Content!.Headers!.ContentLength);

var userAgent = httpRequestMessage.Headers.UserAgent.Single()!;
Assert.AreEqual("grpc-dotnet", userAgent.Product?.Name);
Expand All @@ -83,6 +84,39 @@ public async Task AsyncUnaryCall_Success_HttpRequestMessagePopulated()
Assert.IsTrue(userAgent.Product!.Version!.Length <= 10);
}

[Test]
public async Task AsyncUnaryCall_HasWinHttpHandler_ContentLengthOnHttpRequestMessagePopulated()
{
// Arrange
HttpRequestMessage? httpRequestMessage = null;

var handler = TestHttpMessageHandler.Create(async request =>
{
httpRequestMessage = request;
HelloReply reply = new HelloReply
{
Message = "Hello world"
};
var streamContent = await ClientTestHelpers.CreateResponseContent(reply).DefaultTimeout();
return ResponseUtils.CreateResponse(HttpStatusCode.OK, streamContent);
});
// Just need to have a type called WinHttpHandler to activate new behavior.
var winHttpHandler = new WinHttpHandler(handler);
var invoker = HttpClientCallInvokerFactory.Create(winHttpHandler, "https://localhost");

// Act
var rs = await invoker.AsyncUnaryCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions(), new HelloRequest { Name = "Hello world" });

// Assert
Assert.AreEqual("Hello world", rs.Message);

Assert.IsNotNull(httpRequestMessage);
Assert.AreEqual(18, httpRequestMessage!.Content!.Headers!.ContentLength);
}

[Test]
public async Task AsyncUnaryCall_Success_RequestContentSent()
{
Expand Down
28 changes: 28 additions & 0 deletions test/Grpc.Net.Client.Tests/Infrastructure/WinHttpHandler.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#endregion

// Namespace and class name needs to resolve to System.Net.Http.WinHttpHandler.
namespace System.Net.Http
{
public class WinHttpHandler : DelegatingHandler
{
public WinHttpHandler(HttpMessageHandler innerHandler) : base(innerHandler)
{
}
}
}

0 comments on commit 31eaaf8

Please sign in to comment.