diff --git a/Core Modules/WalletConnectSharp.Common/Utils/EventUtils.cs b/Core Modules/WalletConnectSharp.Common/Utils/EventUtils.cs new file mode 100644 index 0000000..d2a426d --- /dev/null +++ b/Core Modules/WalletConnectSharp.Common/Utils/EventUtils.cs @@ -0,0 +1,54 @@ +namespace WalletConnectSharp.Common.Utils; + +public class EventUtils +{ + /// + /// Invoke the given event handler once and then unsubscribe it. + /// Use with abstract events. Otherwise, use the extension method as it is more efficient. + /// + /// + /// + /// EventUtils.ListenOnce((_, _) => WCLogger.Log("Resubscribed")), + /// h => this.Subscriber.Resubscribed += h, + /// h => this.Subscriber.Resubscribed -= h + /// ); + /// + /// + public static Action ListenOnce( + EventHandler handler, + Action subscribe, + Action unsubscribe) + { + EventHandler internalHandler = null; + internalHandler = (sender, args) => + { + unsubscribe(internalHandler); + handler(sender, args); + }; + + subscribe(internalHandler); + + return () => unsubscribe(internalHandler); + } + + /// + /// Invoke the given event handler once and then unsubscribe it. + /// Use with abstract events. Otherwise, use the extension method as it is more efficient. + /// + public static Action ListenOnce( + EventHandler handler, + Action> subscribe, + Action> unsubscribe) + { + EventHandler internalHandler = null; + internalHandler = (sender, args) => + { + unsubscribe(internalHandler); + handler(sender, args); + }; + + subscribe(internalHandler); + + return () => unsubscribe(internalHandler); + } +} diff --git a/Core Modules/WalletConnectSharp.Common/Utils/Extensions.cs b/Core Modules/WalletConnectSharp.Common/Utils/Extensions.cs index 79e30d2..8e504cf 100644 --- a/Core Modules/WalletConnectSharp.Common/Utils/Extensions.cs +++ b/Core Modules/WalletConnectSharp.Common/Utils/Extensions.cs @@ -13,7 +13,7 @@ public static class Extensions /// The object to check /// Returns true if the object has a numeric type public static bool IsNumericType(this object o) - { + { switch (Type.GetTypeCode(o.GetType())) { case TypeCode.Byte: @@ -29,7 +29,7 @@ public static bool IsNumericType(this object o) return false; } } - + /// /// Add a query parameter to the given source string /// @@ -58,7 +58,8 @@ public static string AddQueryParam( + "=" + HttpUtility.UrlEncode(value); } - public static async Task WithTimeout(this Task task, int timeout = 1000, string message = "Timeout of %t exceeded") + public static async Task WithTimeout(this Task task, int timeout = 1000, + string message = "Timeout of %t exceeded") { var resultT = await Task.WhenAny(task, Task.Delay(timeout)); if (resultT != task) @@ -66,10 +67,11 @@ public static async Task WithTimeout(this Task task, int timeout = 1000 throw new TimeoutException(message.Replace("%t", timeout.ToString())); } - return ((Task) resultT).Result; + return ((Task)resultT).Result; } - - public static async Task WithTimeout(this Task task, int timeout = 1000, string message = "Timeout of %t exceeded") + + public static async Task WithTimeout(this Task task, int timeout = 1000, + string message = "Timeout of %t exceeded") { var resultT = await Task.WhenAny(task, Task.Delay(timeout)); if (resultT != task) @@ -77,8 +79,9 @@ public static async Task WithTimeout(this Task task, int timeout = 1000, string throw new TimeoutException(message.Replace("%t", timeout.ToString())); } } - - public static async Task WithTimeout(this Task task, TimeSpan timeout, string message = "Timeout of %t exceeded") + + public static async Task WithTimeout(this Task task, TimeSpan timeout, + string message = "Timeout of %t exceeded") { var resultT = await Task.WhenAny(task, Task.Delay(timeout)); if (resultT != task) @@ -86,10 +89,11 @@ public static async Task WithTimeout(this Task task, TimeSpan timeout, throw new TimeoutException(message.Replace("%t", timeout.ToString())); } - return ((Task) resultT).Result; + return ((Task)resultT).Result; } - - public static async Task WithTimeout(this Task task, TimeSpan timeout, string message = "Timeout of %t exceeded") + + public static async Task WithTimeout(this Task task, TimeSpan timeout, + string message = "Timeout of %t exceeded") { var resultT = await Task.WhenAny(task, Task.Delay(timeout)); if (resultT != task) @@ -97,38 +101,62 @@ public static async Task WithTimeout(this Task task, TimeSpan timeout, string me throw new TimeoutException(message.Replace("%t", timeout.ToString())); } } - + public static bool SetEquals(this IEnumerable first, IEnumerable second, IEqualityComparer comparer) { return new HashSet(second, comparer ?? EqualityComparer.Default) .SetEquals(first); } - - public static Action ListenOnce(this object eventSource, string eventName, EventHandler handler) { - var eventInfo = eventSource.GetType().GetEvent(eventName); + + public static Action ListenOnce(this EventHandler eventHandler, EventHandler handler) + { EventHandler internalHandler = null; - internalHandler = (src, args) => { - eventInfo.RemoveEventHandler(eventSource, internalHandler); - handler(src, args); - }; - void RemoveListener() + internalHandler = (sender, args) => { - eventInfo.RemoveEventHandler(eventSource, internalHandler); - } - eventInfo.AddEventHandler(eventSource, internalHandler); + eventHandler -= internalHandler; + handler(sender, args); + }; + + eventHandler += internalHandler; - return RemoveListener; + return () => + { + try + { + eventHandler -= internalHandler; + } + catch (Exception e) + { + // ignored + } + }; } - public static void ListenOnce(this object eventSource, string eventName, EventHandler handler) { - var eventInfo = eventSource.GetType().GetEvent(eventName); + public static Action ListenOnce( + this EventHandler eventHandler, + EventHandler handler) + { EventHandler internalHandler = null; - internalHandler = (src, args) => { - eventInfo.RemoveEventHandler(eventSource, internalHandler); - handler(src, args); + internalHandler = (sender, args) => + { + eventHandler -= internalHandler; + handler(sender, args); + }; + + eventHandler += internalHandler; + + return () => + { + try + { + eventHandler -= internalHandler; + } + catch (Exception e) + { + // ignored + } }; - eventInfo.AddEventHandler(eventSource, internalHandler); } } } diff --git a/Core Modules/WalletConnectSharp.Network.Websocket/WebsocketConnection.cs b/Core Modules/WalletConnectSharp.Network.Websocket/WebsocketConnection.cs index 8ff0223..ff54b84 100644 --- a/Core Modules/WalletConnectSharp.Network.Websocket/WebsocketConnection.cs +++ b/Core Modules/WalletConnectSharp.Network.Websocket/WebsocketConnection.cs @@ -139,14 +139,14 @@ private async Task Register(string url) TaskCompletionSource registeringTask = new TaskCompletionSource(TaskCreationOptions.None); - this.ListenOnce(nameof(RegisterErrored), (sender, args) => + RegisterErrored.ListenOnce((sender, args) => { registeringTask.SetException(args); }); - this.ListenOnce(nameof(Opened), (sender, args) => + Opened.ListenOnce((sender, args) => { - registeringTask.SetResult(args); + registeringTask.SetResult((WebsocketClient)args); }); await registeringTask.Task; diff --git a/Tests/WalletConnectSharp.Auth.Tests/AuthClientTest.cs b/Tests/WalletConnectSharp.Auth.Tests/AuthClientTest.cs index 504ff77..cbbc1ce 100644 --- a/Tests/WalletConnectSharp.Auth.Tests/AuthClientTest.cs +++ b/Tests/WalletConnectSharp.Auth.Tests/AuthClientTest.cs @@ -24,7 +24,7 @@ public class AuthClientTests : IClassFixture, IAsyncLifetim ChainId = "eip155:1", Nonce = CryptoUtils.GenerateNonce() }; - + private readonly CryptoWalletFixture _cryptoWalletFixture; private IAuthClient PeerA; @@ -64,12 +64,12 @@ public async void TestInit() { Assert.NotNull(PeerA); Assert.NotNull(PeerB); - + Assert.NotNull(PeerA.Core); Assert.NotNull(PeerA.Core.Expirer); Assert.NotNull(PeerA.Core.History); Assert.NotNull(PeerA.Core.Pairing); - + Assert.NotNull(PeerB.Core); Assert.NotNull(PeerB.Core.Expirer); Assert.NotNull(PeerB.Core.History); @@ -93,17 +93,18 @@ public async void TestPairs() await PeerB.Core.Pairing.Pair(uri); await authRequested.Task; - - Assert.Equal(PeerA.Core.Pairing.Pairings.Select(p => p.Key), PeerB.Core.Pairing.Pairings.Select(p => p.Key)); + + Assert.Equal(PeerA.Core.Pairing.Pairings.Select(p => p.Key), + PeerB.Core.Pairing.Pairings.Select(p => p.Key)); Assert.Equal(ogPairSize + 1, PeerA.Core.Pairing.Pairings.Length); var peerAHistory = await PeerA.Core.History.JsonRpcHistoryOfType(); var peerBHistory = await PeerB.Core.History.JsonRpcHistoryOfType(); - + Assert.Equal(peerAHistory.Size, peerBHistory.Size); - + Assert.True(PeerB.Core.Pairing.Pairings[0].Active); - + // Cleanup event listeners PeerB.AuthRequested -= OnPeerBOnAuthRequested; } @@ -114,20 +115,25 @@ public async void TestKnownPairings() var ogSizeA = PeerA.Core.Pairing.Pairings.Length; var history = await PeerA.AuthHistory(); var ogHistorySizeA = history.Keys.Length; - + var ogSizeB = PeerB.Core.Pairing.Pairings.Length; var historyB = await PeerB.AuthHistory(); var ogHistorySizeB = historyB.Keys.Length; - + List responses = new List(); TaskCompletionSource responseTask = new TaskCompletionSource(); async void OnPeerBOnAuthRequested(object sender, AuthRequest request) { var message = PeerB.FormatMessage(request.Parameters.CacaoPayload, this.Iss); - var signature = await CryptoWallet.GetAccount(WalletAddress).AccountSigningService.PersonalSign.SendRequestAsync(Encoding.UTF8.GetBytes(message)); + var signature = await CryptoWallet.GetAccount(WalletAddress).AccountSigningService.PersonalSign + .SendRequestAsync(Encoding.UTF8.GetBytes(message)); - await PeerB.Respond(new Cacao() { Id = request.Id, Signature = new Cacao.CacaoSignature.EIP191CacaoSignature(signature) }, this.Iss); + await PeerB.Respond( + new Cacao() + { + Id = request.Id, Signature = new Cacao.CacaoSignature.EIP191CacaoSignature(signature) + }, this.Iss); Assert.Equal(Validation.Unknown, request.VerifyContext?.Validation); } @@ -162,7 +168,7 @@ void OnPeerAOnAuthError(object sender, AuthErrorResponse args) await PeerB.Core.Pairing.Pair(requestData.Uri); await responseTask.Task; - + // Reset responseTask = new TaskCompletionSource(); @@ -172,15 +178,15 @@ void OnPeerAOnAuthError(object sender, AuthErrorResponse args) var requestData2 = await PeerA.Request(DefaultRequestParams, knownPairing.Topic); await responseTask.Task; - + Assert.Null(requestData2.Uri); - + Assert.Equal(ogSizeA + 1, PeerA.Core.Pairing.Pairings.Length); Assert.Equal(ogHistorySizeA + 2, history.Keys.Length); Assert.Equal(ogSizeB + 1, PeerB.Core.Pairing.Pairings.Length); Assert.Equal(ogHistorySizeB + 2, historyB.Keys.Length); Assert.Equal(responses[0].Topic, responses[1].Topic); - + // Cleanup event listeners PeerB.AuthRequested -= OnPeerBOnAuthRequested; @@ -192,7 +198,7 @@ void OnPeerAOnAuthError(object sender, AuthErrorResponse args) public async void HandlesAuthRequests() { var ogSize = PeerB.Requests.Length; - + TaskCompletionSource receivedAuthRequest = new TaskCompletionSource(); void OnPeerBOnAuthRequested(object o, AuthRequest authRequest) => receivedAuthRequest.SetResult(true); @@ -206,11 +212,11 @@ public async void HandlesAuthRequests() await receivedAuthRequest.Task; Assert.Equal(ogSize + 1, PeerB.Requests.Length); - + // Cleanup event listeners PeerB.AuthRequested -= OnPeerBOnAuthRequested; } - + [Fact, Trait("Category", "unit")] public async void TestErrorResponses() { @@ -220,8 +226,14 @@ public async void TestErrorResponses() async void OnPeerBOnAuthRequested(object sender, AuthRequest request) { - await PeerB.Respond(new ErrorResponse() { Error = new Network.Models.Error() { Code = 14001, Message = "Can not login" }, Id = request.Id }, this.Iss); + await PeerB.Respond( + new ErrorResponse() + { + Error = new Network.Models.Error() { Code = 14001, Message = "Can not login" }, + Id = request.Id + }, this.Iss); } + void OnPeerAOnAuthResponded(object sender, AuthResponse response) { errorResponse.SetResult(false); @@ -241,7 +253,7 @@ void OnPeerAOnAuthResponded(object sender, AuthResponse response) await PeerB.Core.Pairing.Pair(requestData.Uri); await errorResponse.Task; - + Assert.False(PeerA.Core.Pairing.Pairings[^1].Active); Assert.True(errorResponse.Task.Result); @@ -254,22 +266,28 @@ void OnPeerAOnAuthResponded(object sender, AuthResponse response) public async void HandlesSuccessfulResponse() { var ogPSize = PeerA.Core.Pairing.Pairings.Length; - + TaskCompletionSource successfulResponse = new TaskCompletionSource(); async void OnPeerBOnAuthRequested(object sender, AuthRequest request) { var message = PeerB.FormatMessage(request.Parameters.CacaoPayload, this.Iss); - var signature = await CryptoWallet.GetAccount(WalletAddress).AccountSigningService.PersonalSign.SendRequestAsync(Encoding.UTF8.GetBytes(message)); + var signature = await CryptoWallet.GetAccount(WalletAddress).AccountSigningService.PersonalSign + .SendRequestAsync(Encoding.UTF8.GetBytes(message)); - await PeerB.Respond(new ResultResponse() { Id = request.Id, Signature = new Cacao.CacaoSignature.EIP191CacaoSignature(signature) }, this.Iss); + await PeerB.Respond( + new ResultResponse() + { + Id = request.Id, Signature = new Cacao.CacaoSignature.EIP191CacaoSignature(signature) + }, this.Iss); Assert.Equal(Validation.Unknown, request.VerifyContext?.Validation); } PeerB.AuthRequested += OnPeerBOnAuthRequested; - void OnPeerAOnAuthResponded(object sender, AuthResponse response) => successfulResponse.SetResult(response.Response.Result?.Signature != null); + void OnPeerAOnAuthResponded(object sender, AuthResponse response) => + successfulResponse.SetResult(response.Response.Result?.Signature != null); PeerA.AuthResponded += OnPeerAOnAuthResponded; @@ -278,17 +296,17 @@ async void OnPeerBOnAuthRequested(object sender, AuthRequest request) PeerA.AuthError += OnPeerAOnAuthError; var requestData = await PeerA.Request(DefaultRequestParams); - + Assert.Equal(ogPSize + 1, PeerA.Core.Pairing.Pairings.Length); Assert.False(PeerA.Core.Pairing.Pairings[^1].Active); await PeerB.Core.Pairing.Pair(requestData.Uri); await successfulResponse.Task; - + Assert.True(PeerA.Core.Pairing.Pairings[^1].Active); Assert.True(successfulResponse.Task.Result); - + PeerB.AuthRequested -= OnPeerBOnAuthRequested; PeerA.AuthResponded -= OnPeerAOnAuthResponded; PeerA.AuthError -= OnPeerAOnAuthError; @@ -302,12 +320,16 @@ public async void TestCustomRequestExpiry() TaskCompletionSource resolve1 = new TaskCompletionSource(); - PeerA.Core.Relayer.Publisher.ListenOnce(nameof(PeerA.Core.Relayer.Publisher.OnPublishedMessage), (sender, args) => - { - Assert.Equal(expiry, args.Options.TTL); - resolve1.SetResult(true); - }); - + EventUtils.ListenOnce( + (sender, args) => + { + Assert.Equal(expiry, args.Options.TTL); + resolve1.SetResult(true); + }, + h => PeerA.Core.Relayer.Publisher.OnPublishedMessage += h, + h => PeerA.Core.Relayer.Publisher.OnPublishedMessage -= h + ); + await Task.WhenAll(resolve1.Task, Task.Run(async () => { var response = await PeerA.Request(new RequestParams(DefaultRequestParams) { Expiry = expiry }); @@ -319,9 +341,14 @@ await Task.WhenAll(resolve1.Task, Task.Run(async () => async void OnPeerBOnAuthRequested(object sender, AuthRequest request) { var message = PeerB.FormatMessage(request.Parameters.CacaoPayload, this.Iss); - var signature = await CryptoWallet.GetAccount(WalletAddress).AccountSigningService.PersonalSign.SendRequestAsync(Encoding.UTF8.GetBytes(message)); - - await PeerB.Respond(new ResultResponse() { Id = request.Id, Signature = new Cacao.CacaoSignature.EIP191CacaoSignature(signature) }, this.Iss); + var signature = await CryptoWallet.GetAccount(WalletAddress).AccountSigningService.PersonalSign + .SendRequestAsync(Encoding.UTF8.GetBytes(message)); + + await PeerB.Respond( + new ResultResponse() + { + Id = request.Id, Signature = new Cacao.CacaoSignature.EIP191CacaoSignature(signature) + }, this.Iss); resolve3.SetResult(true); } @@ -351,7 +378,7 @@ public async void TestGetPendingPairings() await receivedAuthRequest.Task; var requests = PeerB.PendingRequests; - + Assert.Equal(ogCount + 1, requests.Count); Assert.Contains(requests, r => r.Value.CacaoPayload.Aud == aud); @@ -363,7 +390,7 @@ public async void TestGetPairings() { var peerAOgSize = PeerA.Core.Pairing.Pairings.Length; var peerBOgSize = PeerB.Core.Pairing.Pairings.Length; - + TaskCompletionSource receivedAuthRequest = new TaskCompletionSource(); void OnPeerBOnAuthRequested(object sender, AuthRequest request) => receivedAuthRequest.SetResult(true); @@ -378,14 +405,14 @@ public async void TestGetPairings() var clientPairings = PeerA.Core.Pairing.Pairings; var peerPairings = PeerB.Core.Pairing.Pairings; - + Assert.Equal(peerAOgSize + 1, clientPairings.Length); Assert.Equal(peerBOgSize + 1, peerPairings.Length); Assert.Equal(clientPairings[^1].Topic, peerPairings[^1].Topic); PeerB.AuthRequested -= OnPeerBOnAuthRequested; } - + [Fact, Trait("Category", "unit")] public async void TestPing() { @@ -397,15 +424,15 @@ public async void TestPing() PeerB.AuthRequested += OnPeerBOnAuthRequested; - PeerB.Core.Pairing.ListenOnce(nameof(PeerB.Core.Pairing.PairingPinged), (sender, @event) => + EventUtils.ListenOnce((sender, @event) => { receivedPeerPing.SetResult(true); - }); - - PeerA.Core.Pairing.ListenOnce(nameof(PeerA.Core.Pairing.PairingPinged), (sender, args) => + }, h => PeerB.Core.Pairing.PairingPinged += h, h => PeerB.Core.Pairing.PairingPinged -= h); + + EventUtils.ListenOnce((sender, @event) => { receivedClientPing.SetResult(true); - }); + }, h => PeerA.Core.Pairing.PairingPinged += h, h => PeerA.Core.Pairing.PairingPinged -= h); var requestData = await PeerA.Request(DefaultRequestParams); @@ -418,7 +445,7 @@ public async void TestPing() await PeerB.Core.Pairing.Ping(pairing.Topic); await Task.WhenAll(receivedClientPing.Task, receivedPeerPing.Task); - + Assert.True(receivedClientPing.Task.Result); Assert.True(receivedPeerPing.Task.Result); @@ -438,10 +465,10 @@ public async void TestDisconnectedPairing() PeerB.AuthRequested += OnPeerBOnAuthRequested; - PeerB.Core.Pairing.ListenOnce(nameof(PeerB.Core.Pairing.PairingDeleted), (sender, args) => + EventUtils.ListenOnce((sender, args) => { peerDeletedPairing.SetResult(true); - }); + }, h => PeerB.Core.Pairing.PairingDeleted += h, h => PeerB.Core.Pairing.PairingDeleted -= h); var requestData = await PeerA.Request(DefaultRequestParams); @@ -451,7 +478,7 @@ public async void TestDisconnectedPairing() var clientPairings = PeerA.Core.Pairing.Pairings; var peerPairings = PeerB.Core.Pairing.Pairings; - + Assert.Equal(peerAOgSize + 1, PeerA.Core.Pairing.Pairings.Length); Assert.Equal(peerBOgSize + 1, PeerB.Core.Pairing.Pairings.Length); Assert.Equal(clientPairings[^1].Topic, peerPairings[^1].Topic); @@ -464,7 +491,7 @@ public async void TestDisconnectedPairing() PeerB.AuthRequested -= OnPeerBOnAuthRequested; } - + [Fact, Trait("Category", "unit")] public async void TestReceivesMetadata() { @@ -477,9 +504,14 @@ async void OnPeerBOnAuthRequested(object sender, AuthRequest request) { receivedMetadataName = request.Parameters.Requester?.Metadata?.Name; var message = PeerB.FormatMessage(request.Parameters.CacaoPayload, this.Iss); - var signature = await CryptoWallet.GetAccount(WalletAddress).AccountSigningService.PersonalSign.SendRequestAsync(Encoding.UTF8.GetBytes(message)); + var signature = await CryptoWallet.GetAccount(WalletAddress).AccountSigningService.PersonalSign + .SendRequestAsync(Encoding.UTF8.GetBytes(message)); - await PeerB.Respond(new ResultResponse() { Id = request.Id, Signature = new Cacao.CacaoSignature.EIP191CacaoSignature(signature) }, this.Iss); + await PeerB.Respond( + new ResultResponse() + { + Id = request.Id, Signature = new Cacao.CacaoSignature.EIP191CacaoSignature(signature) + }, this.Iss); hasResponded.SetResult(true); Assert.Equal(Validation.Unknown, request.VerifyContext.Validation); @@ -488,14 +520,14 @@ async void OnPeerBOnAuthRequested(object sender, AuthRequest request) PeerB.AuthRequested += OnPeerBOnAuthRequested; var requestData = await PeerA.Request(DefaultRequestParams); - + Assert.Equal(ogPairingSize + 1, PeerA.Core.Pairing.Pairings.Length); Assert.False(PeerA.Core.Pairing.Pairings[^1].Active); await PeerB.Core.Pairing.Pair(requestData.Uri); await hasResponded.Task; - + Assert.True(PeerA.Core.Pairing.Pairings[^1].Active); Assert.True(hasResponded.Task.Result); Assert.Equal(PeerA.Metadata.Name, receivedMetadataName); diff --git a/WalletConnectSharp.Core/Controllers/Relayer.cs b/WalletConnectSharp.Core/Controllers/Relayer.cs index 62e6aa7..cc09af9 100644 --- a/WalletConnectSharp.Core/Controllers/Relayer.cs +++ b/WalletConnectSharp.Core/Controllers/Relayer.cs @@ -316,11 +316,16 @@ public async Task Subscribe(string topic, SubscribeOptions opts = null) } TaskCompletionSource task1 = new TaskCompletionSource(); - this.Subscriber.ListenOnce(nameof(this.Subscriber.Created), (sender, subscription) => - { - if (subscription.Topic == topic) - task1.TrySetResult(""); - }); + + EventUtils.ListenOnce( + (sender, subscription) => + { + if (subscription.Topic == topic) + task1.TrySetResult(""); + }, + h => this.Subscriber.Created += h, + h => this.Subscriber.Created -= h + ); return (await Task.WhenAll( task1.Task, @@ -376,10 +381,9 @@ public async Task TransportOpen(string relayUrl = null) } else { - this.Subscriber.ListenOnce(nameof(this.Subscriber.Resubscribed), (sender, args) => - { - task1.TrySetResult(true); - }); + EventUtils.ListenOnce((_, _) => task1.TrySetResult(true), + h => this.Subscriber.Resubscribed += h, + h => this.Subscriber.Resubscribed -= h); } TaskCompletionSource task2 = new TaskCompletionSource(); @@ -391,7 +395,7 @@ void RejectTransportOpen(object sender, EventArgs @event) async void Task2() { - var cleanupEvent = this.ListenOnce(nameof(OnTransportClosed), RejectTransportOpen); + var cleanupEvent = OnTransportClosed.ListenOnce(RejectTransportOpen); try { var connectionTask = this.Provider.Connect(); @@ -432,10 +436,10 @@ public async Task RestartTransport(string relayUrl = null) if (this.Connected) { TaskCompletionSource task1 = new TaskCompletionSource(); - this.Provider.ListenOnce(nameof(this.Provider.Disconnected), (sender, args) => - { - task1.TrySetResult(true); - }); + + EventUtils.ListenOnce((_, _) => task1.TrySetResult(true), + h => this.Provider.Disconnected += h, + h => this.Provider.Disconnected -= h); await Task.WhenAll(task1.Task, this.TransportClose()); } diff --git a/WalletConnectSharp.Core/Models/MessageHandler/TypedEventHandler.cs b/WalletConnectSharp.Core/Models/MessageHandler/TypedEventHandler.cs index 2f0cb34..157cb24 100644 --- a/WalletConnectSharp.Core/Models/MessageHandler/TypedEventHandler.cs +++ b/WalletConnectSharp.Core/Models/MessageHandler/TypedEventHandler.cs @@ -17,11 +17,11 @@ namespace WalletConnectSharp.Sign.Models /// The response typ to filter for public class TypedEventHandler { - protected static Dictionary> _instances = new Dictionary>(); + protected static readonly Dictionary> Instances = new(); + protected readonly ICore Ref; - protected Func, bool> requestPredicate; - protected Func, bool> responsePredicate; - protected ICore _ref; + protected Func, bool> RequestPredicate; + protected Func, bool> ResponsePredicate; /// /// Get a singleton instance of this class for the given context. The context @@ -35,24 +35,26 @@ public class TypedEventHandler public static TypedEventHandler GetInstance(ICore engine) { var context = engine.Context; - - if (_instances.ContainsKey(context)) return _instances[context]; - var _instance = new TypedEventHandler(engine); - - _instances.Add(context, _instance); + if (Instances.TryGetValue(context, out var instance)) + return instance; - return _instance; + var newInstance = new TypedEventHandler(engine); + + Instances.Add(context, newInstance); + + return newInstance; } - + /// /// The callback function delegate that handles requests of the type TRequestArgs, TResponseArgs. These /// functions are async and return a Task. /// /// The type of the request this function is for /// The type of the response this function is for - public delegate Task RequestMethod(RequestEventArgs e); - + public delegate Task + RequestMethod(RequestEventArgs e); + /// /// The callback function delegate that handles responses of the type TResponseArgs. These /// functions are async and return a Task. @@ -64,7 +66,7 @@ public static TypedEventHandler GetInstance(ICore engine) private event ResponseMethod _onResponse; private object _eventLock = new object(); private int _activeCount; - + /// /// The event handler that triggers when a new request of type /// T, TR is received. This event handler is only triggered @@ -143,9 +145,9 @@ public event ResponseMethod OnResponse protected TypedEventHandler(ICore engine) { - _ref = engine; + Ref = engine; } - + /// /// Filter request events based on the given predicate. This will return a new instance of this /// that will only fire the event handler @@ -157,10 +159,10 @@ protected TypedEventHandler(ICore engine) public virtual TypedEventHandler FilterRequests(Func, bool> predicate) { var finalPredicate = predicate; - if (this.requestPredicate != null) - finalPredicate = (rea) => this.requestPredicate(rea) && predicate(rea); + if (this.RequestPredicate != null) + finalPredicate = (rea) => this.RequestPredicate(rea) && predicate(rea); - return BuildNew(_ref, finalPredicate, responsePredicate); + return BuildNew(Ref, finalPredicate, ResponsePredicate); } /// @@ -174,25 +176,25 @@ public virtual TypedEventHandler FilterRequests(Func FilterResponses(Func, bool> predicate) { var finalPredicate = predicate; - if (this.responsePredicate != null) - finalPredicate = (rea) => this.responsePredicate(rea) && predicate(rea); + if (this.ResponsePredicate != null) + finalPredicate = (rea) => this.ResponsePredicate(rea) && predicate(rea); - return BuildNew(_ref, requestPredicate, finalPredicate); + return BuildNew(Ref, RequestPredicate, finalPredicate); } - protected virtual TypedEventHandler BuildNew(ICore _ref, Func, bool> requestPredicate, + protected virtual TypedEventHandler BuildNew(ICore _ref, + Func, bool> requestPredicate, Func, bool> responsePredicate) { return new TypedEventHandler(_ref) { - requestPredicate = requestPredicate, - responsePredicate = responsePredicate + RequestPredicate = requestPredicate, ResponsePredicate = responsePredicate }; } protected virtual void Setup() { - _ref.MessageHandler.HandleMessageType(RequestCallback, ResponseCallback); + Ref.MessageHandler.HandleMessageType(RequestCallback, ResponseCallback); } protected virtual void Teardown() @@ -203,18 +205,18 @@ protected virtual void Teardown() protected virtual Task ResponseCallback(string arg1, JsonRpcResponse arg2) { var rea = new ResponseEventArgs(arg2, arg1); - return responsePredicate != null && !responsePredicate(rea) ? Task.CompletedTask : + return ResponsePredicate != null && !ResponsePredicate(rea) ? Task.CompletedTask : _onResponse != null ? _onResponse(rea) : Task.CompletedTask; } protected virtual async Task RequestCallback(string arg1, JsonRpcRequest arg2) { VerifiedContext verifyContext = new VerifiedContext() { Validation = Validation.Unknown }; - + // Find pairing to get metadata - if (_ref.Pairing.Store.Keys.Contains(arg1)) + if (Ref.Pairing.Store.Keys.Contains(arg1)) { - var pairing = _ref.Pairing.Store.Get(arg1); + var pairing = Ref.Pairing.Store.Get(arg1); var hash = HashUtils.HashMessage(JsonConvert.SerializeObject(arg2)); verifyContext = await VerifyContext(hash, pairing.PeerMetadata); @@ -222,29 +224,31 @@ protected virtual async Task RequestCallback(string arg1, JsonRpcRequest arg2 var rea = new RequestEventArgs(arg1, arg2, verifyContext); - if (requestPredicate != null && !requestPredicate(rea)) return; + if (RequestPredicate != null && !RequestPredicate(rea)) return; if (_onRequest == null) return; await _onRequest(rea); - if (rea.Response != null || rea.Error != null) + if (rea.Error != null) { - await _ref.MessageHandler.SendResult(arg2.Id, arg1, rea.Response); + await Ref.MessageHandler.SendError(arg2.Id, arg1, rea.Error); + } + else if (rea.Response != null) + { + await Ref.MessageHandler.SendResult(arg2.Id, arg1, rea.Response); } } - + async Task VerifyContext(string hash, Metadata metadata) { var context = new VerifiedContext() { - VerifyUrl = metadata.VerifyUrl ?? "", - Validation = Validation.Unknown, - Origin = metadata.Url ?? "" + VerifyUrl = metadata.VerifyUrl ?? "", Validation = Validation.Unknown, Origin = metadata.Url ?? "" }; try { - var origin = await _ref.Verify.Resolve(hash); + var origin = await Ref.Verify.Resolve(hash); if (!string.IsNullOrWhiteSpace(origin)) { context.Origin = origin; diff --git a/WalletConnectSharp.Sign/Engine.cs b/WalletConnectSharp.Sign/Engine.cs index 7bdda1b..8006a5d 100644 --- a/WalletConnectSharp.Sign/Engine.cs +++ b/WalletConnectSharp.Sign/Engine.cs @@ -477,18 +477,22 @@ public async Task Pair(string uri) TaskCompletionSource sessionProposeTask = new TaskCompletionSource(); - Client.ListenOnce(nameof(Client.SessionProposed), (sender, args) => - { - var proposal = args.Proposal; - if (topic != proposal.PairingTopic) - return; + EventUtils.ListenOnce( + (sender, args) => + { + var proposal = args.Proposal; + if (topic != proposal.PairingTopic) + return; - if (args.VerifiedContext.Validation == Validation.Invalid) - sessionProposeTask.SetException(new Exception( - $"Could not validate, invalid validation status {args.VerifiedContext.Validation} for origin {args.VerifiedContext.Origin}")); - else - sessionProposeTask.SetResult(proposal); - }); + if (args.VerifiedContext.Validation == Validation.Invalid) + sessionProposeTask.SetException(new Exception( + $"Could not validate, invalid validation status {args.VerifiedContext.Validation} for origin {args.VerifiedContext.Origin}")); + else + sessionProposeTask.SetResult(proposal); + }, + h => Client.SessionProposed += h, + h => Client.SessionProposed -= h + ); return await sessionProposeTask.Task; } diff --git a/WalletConnectSharp.Sign/Models/SessionRequestEventHandler.cs b/WalletConnectSharp.Sign/Models/SessionRequestEventHandler.cs index 7fa28fa..5fd7de1 100644 --- a/WalletConnectSharp.Sign/Models/SessionRequestEventHandler.cs +++ b/WalletConnectSharp.Sign/Models/SessionRequestEventHandler.cs @@ -13,8 +13,8 @@ namespace WalletConnectSharp.Sign.Models /// The type of the response for the session request public class SessionRequestEventHandler : TypedEventHandler { - private IEnginePrivate _enginePrivate; - + private readonly IEnginePrivate _enginePrivate; + /// /// Get a singleton instance of this class for the given context. The context /// string of the given will be used to determine the singleton instance to @@ -24,37 +24,38 @@ public class SessionRequestEventHandler : TypedEventHandler /// The engine this singleton instance is for, and where the context string will /// be read from /// The singleton instance to use for request/response event handlers - public static new TypedEventHandler GetInstance(ICore engine, IEnginePrivate _enginePrivate) + public static TypedEventHandler GetInstance(ICore engine, IEnginePrivate enginePrivate) { var context = engine.Context; - - if (_instances.ContainsKey(context)) return _instances[context]; - var _instance = new SessionRequestEventHandler(engine, _enginePrivate); - - _instances.Add(context, _instance); + if (Instances.TryGetValue(context, out var instance)) + return instance; + + var newInstance = new SessionRequestEventHandler(engine, enginePrivate); - return _instance; + Instances.Add(context, newInstance); + + return newInstance; } - + protected SessionRequestEventHandler(ICore engine, IEnginePrivate enginePrivate) : base(engine) { this._enginePrivate = enginePrivate; } - protected override TypedEventHandler BuildNew(ICore _ref, Func, bool> requestPredicate, Func, bool> responsePredicate) + protected override TypedEventHandler BuildNew(ICore @ref, + Func, bool> requestPredicate, Func, bool> responsePredicate) { - return new SessionRequestEventHandler(_ref, _enginePrivate) + return new SessionRequestEventHandler(@ref, _enginePrivate) { - requestPredicate = requestPredicate, - responsePredicate = responsePredicate + RequestPredicate = requestPredicate, ResponsePredicate = responsePredicate }; } protected override void Setup() { - var wrappedRef = TypedEventHandler, TR>.GetInstance(_ref); - + var wrappedRef = TypedEventHandler, TR>.GetInstance(Ref); + wrappedRef.OnRequest += WrappedRefOnOnRequest; wrappedRef.OnResponse += WrappedRefOnOnResponse; } @@ -70,17 +71,18 @@ private async Task WrappedRefOnOnRequest(RequestEventArgs, TR> var method = RpcMethodAttribute.MethodForType(); var sessionRequest = e.Request.Params.Request; - + if (sessionRequest.Method != method) return; //Set inner request id to match outer request id sessionRequest.Id = e.Request.Id; - + //Add to pending requests //We can't do a simple cast, so we need to copy all the data await _enginePrivate.SetPendingSessionRequest(new PendingRequestStruct() { - Id = e.Request.Id, Parameters = new SessionRequest() + Id = e.Request.Id, + Parameters = new SessionRequest() { ChainId = e.Request.Params.ChainId, Request = new JsonRpcRequest() @@ -89,7 +91,8 @@ await _enginePrivate.SetPendingSessionRequest(new PendingRequestStruct() Method = sessionRequest.Method, Params = sessionRequest.Params } - }, Topic = e.Topic + }, + Topic = e.Topic }); await base.RequestCallback(e.Topic, sessionRequest);