Skip to content

Commit

Permalink
GH-43907: [C#][FlightRPC] Add Grpc Call Options support on Flight Cli…
Browse files Browse the repository at this point in the history
…ent (#43910)

### Rationale for this change

This implementation add default grpc call options on the csharp implementation FlightClient

### What changes are included in this PR?

- FlightClient.cs with updated signature for all the methods accepting grpc call options
- FlightTest.cs update test to verify the raise of the right exception

### Are these changes tested?

Yes, tests are added in FlightTest.cs
I've tested locally with the C++ implementation.

### Are there any user-facing changes?

No is transparent for the user, following the already present documentation should be sufficient.

### References

* GitHub Issue: #43907

Authored-by: Marco Malagoli <mmalagoli@board.com>
Signed-off-by: Curt Hagenlocher <curt@hagenlocher.org>
  • Loading branch information
qmmk committed Sep 3, 2024
1 parent 170c599 commit b0786d4
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 16 deletions.
69 changes: 58 additions & 11 deletions csharp/src/Apache.Arrow.Flight/Client/FlightClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Flight.Internal;
using Apache.Arrow.Flight.Protocol;
Expand All @@ -34,35 +35,55 @@ public FlightClient(ChannelBase grpcChannel)

public AsyncServerStreamingCall<FlightInfo> ListFlights(FlightCriteria criteria = null, Metadata headers = null)
{
if(criteria == null)
return ListFlights(criteria, headers, null, CancellationToken.None);
}

public AsyncServerStreamingCall<FlightInfo> ListFlights(FlightCriteria criteria, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
if (criteria == null)
{
criteria = FlightCriteria.Empty;
}
var response = _client.ListFlights(criteria.ToProtocol(), headers);

var response = _client.ListFlights(criteria.ToProtocol(), headers, deadline, cancellationToken);
var convertStream = new StreamReader<Protocol.FlightInfo, FlightInfo>(response.ResponseStream, inFlight => new FlightInfo(inFlight));

return new AsyncServerStreamingCall<FlightInfo>(convertStream, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
}

public AsyncServerStreamingCall<FlightActionType> ListActions(Metadata headers = null)
{
var response = _client.ListActions(EmptyInstance, headers);
return ListActions(headers, null, CancellationToken.None);
}

public AsyncServerStreamingCall<FlightActionType> ListActions(Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var response = _client.ListActions(EmptyInstance, headers, deadline, cancellationToken);
var convertStream = new StreamReader<Protocol.ActionType, FlightActionType>(response.ResponseStream, actionType => new FlightActionType(actionType));

return new AsyncServerStreamingCall<FlightActionType>(convertStream, response.ResponseHeadersAsync, response.GetStatus, response.GetTrailers, response.Dispose);
}

public FlightRecordBatchStreamingCall GetStream(FlightTicket ticket, Metadata headers = null)
{
var stream = _client.DoGet(ticket.ToProtocol(), headers);
return GetStream(ticket, headers, null, CancellationToken.None);
}

public FlightRecordBatchStreamingCall GetStream(FlightTicket ticket, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var stream = _client.DoGet(ticket.ToProtocol(), headers, deadline, cancellationToken);
var responseStream = new FlightClientRecordBatchStreamReader(stream.ResponseStream);
return new FlightRecordBatchStreamingCall(responseStream, stream.ResponseHeadersAsync, stream.GetStatus, stream.GetTrailers, stream.Dispose);
}

public AsyncUnaryCall<FlightInfo> GetInfo(FlightDescriptor flightDescriptor, Metadata headers = null)
{
var flightInfoResult = _client.GetFlightInfoAsync(flightDescriptor.ToProtocol(), headers);
return GetInfo(flightDescriptor, headers, null, CancellationToken.None);
}

public AsyncUnaryCall<FlightInfo> GetInfo(FlightDescriptor flightDescriptor, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var flightInfoResult = _client.GetFlightInfoAsync(flightDescriptor.ToProtocol(), headers, deadline, cancellationToken);

var flightInfo = flightInfoResult
.ResponseAsync
Expand All @@ -79,7 +100,12 @@ public AsyncUnaryCall<FlightInfo> GetInfo(FlightDescriptor flightDescriptor, Met

public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDescriptor, Metadata headers = null)
{
var channels = _client.DoPut(headers);
return StartPut(flightDescriptor, headers, null, CancellationToken.None);
}

public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDescriptor, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var channels = _client.DoPut(headers, deadline, cancellationToken);
var requestStream = new FlightClientRecordBatchStreamWriter(channels.RequestStream, flightDescriptor);
var readStream = new StreamReader<Protocol.PutResult, FlightPutResult>(channels.ResponseStream, putResult => new FlightPutResult(putResult));
return new FlightRecordBatchDuplexStreamingCall(
Expand All @@ -93,7 +119,13 @@ public FlightRecordBatchDuplexStreamingCall StartPut(FlightDescriptor flightDesc

public AsyncDuplexStreamingCall<FlightHandshakeRequest, FlightHandshakeResponse> Handshake(Metadata headers = null)
{
var channel = _client.Handshake(headers);
return Handshake(headers, null, CancellationToken.None);

}

public AsyncDuplexStreamingCall<FlightHandshakeRequest, FlightHandshakeResponse> Handshake(Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var channel = _client.Handshake(headers, deadline, cancellationToken);
var readStream = new StreamReader<HandshakeResponse, FlightHandshakeResponse>(channel.ResponseStream, response => new FlightHandshakeResponse(response));
var writeStream = new FlightHandshakeStreamWriterAdapter(channel.RequestStream);
var call = new AsyncDuplexStreamingCall<FlightHandshakeRequest, FlightHandshakeResponse>(
Expand All @@ -109,7 +141,12 @@ public AsyncDuplexStreamingCall<FlightHandshakeRequest, FlightHandshakeResponse>

public FlightRecordBatchExchangeCall DoExchange(FlightDescriptor flightDescriptor, Metadata headers = null)
{
var channel = _client.DoExchange(headers);
return DoExchange(flightDescriptor, headers, null, CancellationToken.None);
}

public FlightRecordBatchExchangeCall DoExchange(FlightDescriptor flightDescriptor, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var channel = _client.DoExchange(headers, deadline, cancellationToken);
var requestStream = new FlightClientRecordBatchStreamWriter(channel.RequestStream, flightDescriptor);
var responseStream = new FlightClientRecordBatchStreamReader(channel.ResponseStream);
var call = new FlightRecordBatchExchangeCall(
Expand All @@ -125,14 +162,24 @@ public FlightRecordBatchExchangeCall DoExchange(FlightDescriptor flightDescripto

public AsyncServerStreamingCall<FlightResult> DoAction(FlightAction action, Metadata headers = null)
{
var stream = _client.DoAction(action.ToProtocol(), headers);
return DoAction(action, headers, null, CancellationToken.None);
}

public AsyncServerStreamingCall<FlightResult> DoAction(FlightAction action, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var stream = _client.DoAction(action.ToProtocol(), headers, deadline, cancellationToken);
var streamReader = new StreamReader<Protocol.Result, FlightResult>(stream.ResponseStream, result => new FlightResult(result));
return new AsyncServerStreamingCall<FlightResult>(streamReader, stream.ResponseHeadersAsync, stream.GetStatus, stream.GetTrailers, stream.Dispose);
}

public AsyncUnaryCall<Schema> GetSchema(FlightDescriptor flightDescriptor, Metadata headers = null)
{
var schemaResult = _client.GetSchemaAsync(flightDescriptor.ToProtocol(), headers);
return GetSchema(flightDescriptor, headers, null, CancellationToken.None);
}

public AsyncUnaryCall<Schema> GetSchema(FlightDescriptor flightDescriptor, Metadata headers, System.DateTime? deadline, CancellationToken cancellationToken = default)
{
var schemaResult = _client.GetSchemaAsync(flightDescriptor.ToProtocol(), headers, deadline, cancellationToken);

var schema = schemaResult
.ResponseAsync
Expand Down
97 changes: 92 additions & 5 deletions csharp/test/Apache.Arrow.Flight.Tests/FlightTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Apache.Arrow.Flight.Client;
using Apache.Arrow.Flight.TestWeb;
using Apache.Arrow.Tests;
using Google.Protobuf;
using Grpc.Core;
using Grpc.Core.Utils;
using Python.Runtime;
using Xunit;

namespace Apache.Arrow.Flight.Tests
Expand Down Expand Up @@ -70,7 +73,7 @@ private FlightInfo GivenStoreBatches(FlightDescriptor flightDescriptor, params R

var flightHolder = new FlightHolder(flightDescriptor, initialBatch.RecordBatch.Schema, _testWebFactory.GetAddress());

foreach(var batch in batches)
foreach (var batch in batches)
{
flightHolder.AddBatch(batch);
}
Expand Down Expand Up @@ -187,8 +190,8 @@ public async Task TestGetFlightMetadata()

var getStream = _flightClient.GetStream(endpoint.Ticket);

List<ByteString> actualMetadata = new List<ByteString>();
while(await getStream.ResponseStream.MoveNext(default))
List<ByteString> actualMetadata = new List<ByteString>();
while (await getStream.ResponseStream.MoveNext(default))
{
actualMetadata.AddRange(getStream.ResponseStream.ApplicationMetadata);
}
Expand Down Expand Up @@ -277,7 +280,7 @@ public async Task TestListFlights()

var actualFlights = await listFlightStream.ResponseStream.ToListAsync();

for(int i = 0; i < expectedFlightInfo.Count; i++)
for (int i = 0; i < expectedFlightInfo.Count; i++)
{
FlightInfoComparer.Compare(expectedFlightInfo[i], actualFlights[i]);
}
Expand Down Expand Up @@ -386,7 +389,7 @@ public async Task TestGetBatchesWithAsyncEnumerable()


List<RecordBatch> resultList = new List<RecordBatch>();
await foreach(var recordBatch in getStream.ResponseStream)
await foreach (var recordBatch in getStream.ResponseStream)
{
resultList.Add(recordBatch);
}
Expand Down Expand Up @@ -415,5 +418,89 @@ public async Task EnsureTheSerializedBatchContainsTheProperTotalRecordsAndTotalB
Assert.Equal(expectedBatch.Length, result.TotalRecords);
Assert.Equal(expectedTotalBytes, result.TotalBytes);
}

[Fact]
public async Task EnsureCallRaisesDeadlineExceeded()
{
var flightDescriptor = FlightDescriptor.CreatePathDescriptor("raise_deadline");
var deadline = DateTime.UtcNow;
var batch = CreateTestBatch(0, 100);

RpcException exception = null;

var asyncServerStreamingCallFlights = _flightClient.ListFlights(null, null, deadline);
Assert.Equal(StatusCode.DeadlineExceeded, asyncServerStreamingCallFlights.GetStatus().StatusCode);

var asyncServerStreamingCallActions = _flightClient.ListActions(null, deadline);
Assert.Equal(StatusCode.DeadlineExceeded, asyncServerStreamingCallFlights.GetStatus().StatusCode);

GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(batch));
exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetInfo(flightDescriptor, null, deadline));
Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);

var flightInfo = await _flightClient.GetInfo(flightDescriptor);
var endpoint = flightInfo.Endpoints.FirstOrDefault();
var getStream = _flightClient.GetStream(endpoint.Ticket, null, deadline);
Assert.Equal(StatusCode.DeadlineExceeded, getStream.GetStatus().StatusCode);

var duplexStreamingCall = _flightClient.DoExchange(flightDescriptor, null, deadline);
exception = await Assert.ThrowsAsync<RpcException>(async () => await duplexStreamingCall.RequestStream.WriteAsync(batch));
Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);

var putStream = _flightClient.StartPut(flightDescriptor, null, deadline);
exception = await Assert.ThrowsAsync<RpcException>(async () => await putStream.RequestStream.WriteAsync(batch));
Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);

exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetSchema(flightDescriptor, null, deadline));
Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);

var handshakeStreamingCall = _flightClient.Handshake(null, deadline);
exception = await Assert.ThrowsAsync<RpcException>(async () => await handshakeStreamingCall.RequestStream.WriteAsync(new FlightHandshakeRequest(ByteString.Empty)));
Assert.Equal(StatusCode.DeadlineExceeded, exception.StatusCode);
}

[Fact]
public async Task EnsureCallRaisesRequestCancelled()
{
var cts = new CancellationTokenSource();
cts.CancelAfter(1);

var batch = CreateTestBatch(0, 100);
var metadata = new Metadata();
var flightDescriptor = FlightDescriptor.CreatePathDescriptor("raise_cancelled");
await Task.Delay(5);
RpcException exception = null;

var asyncServerStreamingCallFlights = _flightClient.ListFlights(null, null, null, cts.Token);
Assert.Equal(StatusCode.Cancelled, asyncServerStreamingCallFlights.GetStatus().StatusCode);

var asyncServerStreamingCallActions = _flightClient.ListActions(null, null, cts.Token);
Assert.Equal(StatusCode.Cancelled, asyncServerStreamingCallFlights.GetStatus().StatusCode);

GivenStoreBatches(flightDescriptor, new RecordBatchWithMetadata(batch));
exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetInfo(flightDescriptor, null, null, cts.Token));
Assert.Equal(StatusCode.Cancelled, exception.StatusCode);

var flightInfo = await _flightClient.GetInfo(flightDescriptor);
var endpoint = flightInfo.Endpoints.FirstOrDefault();
var getStream = _flightClient.GetStream(endpoint.Ticket, null, null, cts.Token);
Assert.Equal(StatusCode.Cancelled, getStream.GetStatus().StatusCode);

var duplexStreamingCall = _flightClient.DoExchange(flightDescriptor, null, null, cts.Token);
exception = await Assert.ThrowsAsync<RpcException>(async () => await duplexStreamingCall.RequestStream.WriteAsync(batch));
Assert.Equal(StatusCode.Cancelled, exception.StatusCode);

var putStream = _flightClient.StartPut(flightDescriptor, null, null, cts.Token);
exception = await Assert.ThrowsAsync<RpcException>(async () => await putStream.RequestStream.WriteAsync(batch));
Assert.Equal(StatusCode.Cancelled, exception.StatusCode);

exception = await Assert.ThrowsAsync<RpcException>(async () => await _flightClient.GetSchema(flightDescriptor, null, null, cts.Token));
Assert.Equal(StatusCode.Cancelled, exception.StatusCode);

var handshakeStreamingCall = _flightClient.Handshake(null, null, cts.Token);
exception = await Assert.ThrowsAsync<RpcException>(async () => await handshakeStreamingCall.RequestStream.WriteAsync(new FlightHandshakeRequest(ByteString.Empty)));
Assert.Equal(StatusCode.Cancelled, exception.StatusCode);

}
}
}

0 comments on commit b0786d4

Please sign in to comment.