From 00fbabe3b47ce4802ca94ced9af8d9fc4695474a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20R=C3=B8nne=20Petersen?= Date: Wed, 17 Jan 2024 08:00:37 +0100 Subject: [PATCH] Implement terminal I/O cancellation on Windows. Closes #63. --- src/core/Native/TerminalInterop.cs | 11 +++- src/core/Terminals/NativeTerminalReader.cs | 21 +++---- src/core/Terminals/NativeTerminalWriter.cs | 21 +++---- src/core/Terminals/NativeVirtualTerminal.cs | 7 ++- src/core/Terminals/UnixVirtualTerminal.cs | 18 +++--- src/core/Terminals/WindowsVirtualTerminal.cs | 12 +++- src/native/driver-unix.c | 13 ++-- src/native/driver-windows.c | 62 ++++++++++---------- src/native/driver-windows.h | 2 +- src/native/driver.h | 17 +++--- src/samples/cancellation/Program.cs | 35 ++++++----- src/samples/raw/Program.cs | 3 +- 12 files changed, 117 insertions(+), 105 deletions(-) diff --git a/src/core/Native/TerminalInterop.cs b/src/core/Native/TerminalInterop.cs index 3baf86d..00ebe03 100644 --- a/src/core/Native/TerminalInterop.cs +++ b/src/core/Native/TerminalInterop.cs @@ -11,6 +11,7 @@ public enum TerminalException { None, ArgumentOutOfRange, + OperationCanceled, PlatformNotSupported, TerminalNotAttached, TerminalConfiguration, @@ -28,11 +29,11 @@ public struct TerminalResult public readonly void ThrowIfError() { - // For when ArgumentOutOfRangeException is not expected. + // For when ArgumentOutOfRangeException and/or OperationCanceledException are not expected. ThrowIfError(value: (object?)null); } - public readonly void ThrowIfError(in T value, [CallerArgumentExpression(nameof(value))] string? name = null) + public readonly void ThrowIfError(T value, [CallerArgumentExpression(nameof(value))] string? name = null) { _ = value; @@ -43,6 +44,8 @@ public readonly void ThrowIfError(in T value, [CallerArgumentExpression(nameo { case TerminalException.ArgumentOutOfRange: throw new ArgumentOutOfRangeException(name); + case TerminalException.OperationCanceled: + throw new OperationCanceledException(Unsafe.As(ref value)); case TerminalException.PlatformNotSupported: throw new PlatformNotSupportedException(); case TerminalException.TerminalNotAttached: @@ -152,4 +155,8 @@ public static partial TerminalResult SetMode( [LibraryImport(Library, EntryPoint = "cathode_poll")] [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] public static partial void Poll([MarshalAs(UnmanagedType.U1)] bool write, int* fds, bool* results, int count); + + [LibraryImport(Library, EntryPoint = "cathode_cancel")] + [UnmanagedCallConv(CallConvs = [typeof(CallConvCdecl)])] + public static partial void Cancel(TerminalDescriptor* descriptor); } diff --git a/src/core/Terminals/NativeTerminalReader.cs b/src/core/Terminals/NativeTerminalReader.cs index bf87e0d..dd5125b 100644 --- a/src/core/Terminals/NativeTerminalReader.cs +++ b/src/core/Terminals/NativeTerminalReader.cs @@ -22,18 +22,12 @@ internal sealed unsafe class NativeTerminalReader : TerminalReader private readonly SemaphoreSlim _semaphore; - private readonly Action? _cancellationHook; - public NativeTerminalReader( - NativeVirtualTerminal terminal, - TerminalInterop.TerminalDescriptor* descriptor, - SemaphoreSlim semaphore, - Action? cancellationHook) + NativeVirtualTerminal terminal, TerminalInterop.TerminalDescriptor* descriptor, SemaphoreSlim semaphore) { Terminal = terminal; Descriptor = descriptor; _semaphore = semaphore; - _cancellationHook = cancellationHook; Stream = new SynchronizedStream(new TerminalInputStream(this)); TextReader = new SynchronizedTextReader( @@ -58,14 +52,15 @@ private int ReadPartialNative(scoped Span buffer, CancellationToken cancel using (_semaphore.Enter(cancellationToken)) { - _cancellationHook?.Invoke((nuint)Descriptor, cancellationToken); - - int progress; + using (Terminal.ArrangeCancellation(Descriptor, write: false, cancellationToken)) + { + int progress; - fixed (byte* p = buffer) - TerminalInterop.Read(Descriptor, p, buffer.Length, &progress).ThrowIfError(); + fixed (byte* p = buffer) + TerminalInterop.Read(Descriptor, p, buffer.Length, &progress).ThrowIfError(cancellationToken); - return progress; + return progress; + } } } } diff --git a/src/core/Terminals/NativeTerminalWriter.cs b/src/core/Terminals/NativeTerminalWriter.cs index 335fca5..2d92177 100644 --- a/src/core/Terminals/NativeTerminalWriter.cs +++ b/src/core/Terminals/NativeTerminalWriter.cs @@ -21,18 +21,12 @@ internal sealed unsafe class NativeTerminalWriter : TerminalWriter private readonly SemaphoreSlim _semaphore; - private readonly Action? _cancellationHook; - public NativeTerminalWriter( - NativeVirtualTerminal terminal, - TerminalInterop.TerminalDescriptor* descriptor, - SemaphoreSlim semaphore, - Action? cancellationHook) + NativeVirtualTerminal terminal, TerminalInterop.TerminalDescriptor* descriptor, SemaphoreSlim semaphore) { Terminal = terminal; Descriptor = descriptor; _semaphore = semaphore; - _cancellationHook = cancellationHook; Stream = new SynchronizedStream(new TerminalOutputStream(this)); TextWriter = new SynchronizedTextWriter( @@ -55,14 +49,15 @@ private int WritePartialNative(scoped ReadOnlySpan buffer, CancellationTok using (_semaphore.Enter(cancellationToken)) { - _cancellationHook?.Invoke((nuint)Descriptor, cancellationToken); - - int progress; + using (Terminal.ArrangeCancellation(Descriptor, write: true, cancellationToken)) + { + int progress; - fixed (byte* p = buffer) - TerminalInterop.Write(Descriptor, p, buffer.Length, &progress).ThrowIfError(); + fixed (byte* p = buffer) + TerminalInterop.Write(Descriptor, p, buffer.Length, &progress).ThrowIfError(cancellationToken); - return progress; + return progress; + } } } } diff --git a/src/core/Terminals/NativeVirtualTerminal.cs b/src/core/Terminals/NativeVirtualTerminal.cs index b17c6d1..f5bc21e 100644 --- a/src/core/Terminals/NativeVirtualTerminal.cs +++ b/src/core/Terminals/NativeVirtualTerminal.cs @@ -32,12 +32,12 @@ private protected unsafe NativeVirtualTerminal() NativeTerminalReader CreateReader(TerminalInterop.TerminalDescriptor* descriptor, SemaphoreSlim semaphore) { - return new(this, descriptor, semaphore, CreateCancellationHook(write: false)); + return new(this, descriptor, semaphore); } NativeTerminalWriter CreateWriter(TerminalInterop.TerminalDescriptor* descriptor, SemaphoreSlim semaphore) { - return new(this, descriptor, semaphore, CreateCancellationHook(write: true)); + return new(this, descriptor, semaphore); } StandardIn = CreateReader(stdIn, inLock); @@ -47,7 +47,8 @@ NativeTerminalWriter CreateWriter(TerminalInterop.TerminalDescriptor* descriptor TerminalOut = CreateWriter(ttyOut, outLock); } - protected abstract Action? CreateCancellationHook(bool write); + internal abstract unsafe IDisposable? ArrangeCancellation( + TerminalInterop.TerminalDescriptor* descriptor, bool write, CancellationToken cancellationToken); private protected override sealed unsafe Size? QuerySize() { diff --git a/src/core/Terminals/UnixVirtualTerminal.cs b/src/core/Terminals/UnixVirtualTerminal.cs index de2b295..24efd1d 100644 --- a/src/core/Terminals/UnixVirtualTerminal.cs +++ b/src/core/Terminals/UnixVirtualTerminal.cs @@ -1,3 +1,5 @@ +using Vezel.Cathode.Native; + namespace Vezel.Cathode.Terminals; internal sealed class UnixVirtualTerminal : NativeVirtualTerminal @@ -6,6 +8,10 @@ internal sealed class UnixVirtualTerminal : NativeVirtualTerminal public static UnixVirtualTerminal Instance { get; } = new(); + private readonly UnixCancellationPipe _readPipe = new(write: false); + + private readonly UnixCancellationPipe _writePipe = new(write: true); + private readonly PosixSignalRegistration _sigWinch; private readonly PosixSignalRegistration _sigCont; @@ -54,14 +60,12 @@ void HandleSignal(PosixSignalContext context) _sigChld = PosixSignalRegistration.Create(PosixSignal.SIGCHLD, HandleSignal); } - protected override unsafe Action CreateCancellationHook(bool write) + internal override unsafe IDisposable? ArrangeCancellation( + TerminalInterop.TerminalDescriptor* descriptor, bool write, CancellationToken cancellationToken) { - var pipe = new UnixCancellationPipe(write); + if (cancellationToken.CanBeCanceled) + (write ? _writePipe : _readPipe).PollWithCancellation(*(int*)descriptor, cancellationToken); - return (descriptor, cancellationToken) => - { - if (cancellationToken.CanBeCanceled) - pipe.PollWithCancellation(*(int*)descriptor, cancellationToken); - }; + return null; } } diff --git a/src/core/Terminals/WindowsVirtualTerminal.cs b/src/core/Terminals/WindowsVirtualTerminal.cs index a06b37e..df709cd 100644 --- a/src/core/Terminals/WindowsVirtualTerminal.cs +++ b/src/core/Terminals/WindowsVirtualTerminal.cs @@ -1,3 +1,5 @@ +using Vezel.Cathode.Native; + namespace Vezel.Cathode.Terminals; internal sealed class WindowsVirtualTerminal : NativeVirtualTerminal @@ -20,8 +22,14 @@ private WindowsVirtualTerminal() { } - protected override Action? CreateCancellationHook(bool write) + internal override unsafe IDisposable? ArrangeCancellation( + TerminalInterop.TerminalDescriptor* descriptor, bool write, CancellationToken cancellationToken) { - return null; + return cancellationToken.CanBeCanceled + ? cancellationToken.UnsafeRegister( + static descriptor => + TerminalInterop.Cancel((TerminalInterop.TerminalDescriptor*)Unsafe.Unbox(descriptor!)), + (nuint)descriptor) + : null; } } diff --git a/src/native/driver-unix.c b/src/native/driver-unix.c index 9b29e16..0451f77 100644 --- a/src/native/driver-unix.c +++ b/src/native/driver-unix.c @@ -321,7 +321,7 @@ TerminalResult cathode_generate_signal(TerminalSignal signal) } TerminalResult cathode_read( - const TerminalDescriptor *nonnull descriptor, uint8_t *nullable buffer, int32_t length, int32_t *nonnull progress) + TerminalDescriptor *nonnull descriptor, uint8_t *nullable buffer, int32_t length, int32_t *nonnull progress) { assert(descriptor); assert(buffer); @@ -372,10 +372,7 @@ TerminalResult cathode_read( } TerminalResult cathode_write( - const TerminalDescriptor *nonnull descriptor, - const uint8_t *nullable buffer, - int32_t length, - int32_t *nonnull progress) + TerminalDescriptor *nonnull descriptor, const uint8_t *nullable buffer, int32_t length, int32_t *nonnull progress) { assert(descriptor); assert(buffer); @@ -385,9 +382,9 @@ TerminalResult cathode_write( { ssize_t ret; - // Note that this call may get us suspended by way of a SIGTTOU signal if we are a background process, the handle - // refers to a terminal, and the TOSTOP bit is set (we disable TOSTOP but there are ways that it could get set - // anyway). + // Note that this call may get us suspended by way of a SIGTTOU signal if we are a background process, the + // handle refers to a terminal, and the TOSTOP bit is set (we disable TOSTOP but there are ways that it could + // get set anyway). while ((ret = write(descriptor->fd, buffer, (size_t)length)) == -1 && errno == EINTR) { // Retry in case we get interrupted by a signal. diff --git a/src/native/driver-windows.c b/src/native/driver-windows.c index d5f3272..96a2532 100644 --- a/src/native/driver-windows.c +++ b/src/native/driver-windows.c @@ -9,7 +9,8 @@ struct TerminalDescriptor HANDLE handle; }; -typedef struct { +typedef struct +{ TerminalDescriptor descriptor; DWORD original_mode; UINT original_code_page; @@ -38,7 +39,6 @@ static HANDLE open_console_handle(const wchar_t *nonnull name) SECURITY_ATTRIBUTES attrs = { .nLength = sizeof(SECURITY_ATTRIBUTES), - .lpSecurityDescriptor = nullptr, .bInheritHandle = true, }; @@ -310,14 +310,10 @@ TerminalResult cathode_generate_signal(TerminalSignal signal) }; } -TerminalResult cathode_read( - const TerminalDescriptor *nonnull descriptor, uint8_t *nullable buffer, int32_t length, int32_t *nonnull progress) +static TerminalResult create_io_result(BOOL result, const int32_t *nonnull progress) { - assert(descriptor); - assert(buffer); assert(progress); - BOOL result = ReadFile(descriptor->handle, buffer, (DWORD)length, (LPDWORD)progress, nullptr); DWORD error = GetLastError(); // See driver-unix.c for the error handling rationale. @@ -326,39 +322,43 @@ TerminalResult cathode_read( { .exception = TerminalException_None, } - : (TerminalResult) - { - .exception = TerminalException_Terminal, - .message = u"Could not read from input handle.", - .error = (int32_t)error, - }; + : error == ERROR_OPERATION_ABORTED + ? (TerminalResult) + { + .exception = TerminalException_OperationCanceled, + } + : (TerminalResult) + { + .exception = TerminalException_Terminal, + .message = u"Could not read from input handle.", + .error = (int32_t)error, + }; +} + +TerminalResult cathode_read( + TerminalDescriptor *nonnull descriptor, uint8_t *nullable buffer, int32_t length, int32_t *nonnull progress) +{ + assert(descriptor); + assert(buffer); + assert(progress); + + return create_io_result(ReadFile(descriptor->handle, buffer, (DWORD)length, (LPDWORD)progress, nullptr), progress); } TerminalResult cathode_write( - const TerminalDescriptor *nonnull descriptor, - const uint8_t *nullable buffer, - int32_t length, - int32_t *nonnull progress) + TerminalDescriptor *nonnull descriptor, const uint8_t *nullable buffer, int32_t length, int32_t *nonnull progress) { assert(descriptor); assert(buffer); assert(progress); - BOOL result = WriteFile(descriptor->handle, buffer, (DWORD)length, (LPDWORD)progress, nullptr); - DWORD error = GetLastError(); + return create_io_result(WriteFile(descriptor->handle, buffer, (DWORD)length, (LPDWORD)progress, nullptr), progress); +} - // See driver-unix.c for the error handling rationale. - return result || *progress || error == ERROR_HANDLE_EOF || error == ERROR_BROKEN_PIPE || error == ERROR_NO_DATA - ? (TerminalResult) - { - .exception = TerminalException_None, - } - : (TerminalResult) - { - .exception = TerminalException_Terminal, - .message = u"Could not write to output handle.", - .error = (int32_t)error, - }; +void cathode_cancel(TerminalDescriptor *nonnull descriptor) +{ + // This is a best-effort situation; nothing we can do if this fails. + CancelIoEx(descriptor->handle, nullptr); } #endif diff --git a/src/native/driver-windows.h b/src/native/driver-windows.h index b84f6ca..cbb078f 100644 --- a/src/native/driver-windows.h +++ b/src/native/driver-windows.h @@ -2,4 +2,4 @@ #include "driver.h" -// Currently no OS-specific APIs. +CATHODE_API void cathode_cancel(TerminalDescriptor *nonnull descriptor); diff --git a/src/native/driver.h b/src/native/driver.h index b9db50f..8585f9f 100644 --- a/src/native/driver.h +++ b/src/native/driver.h @@ -2,23 +2,27 @@ typedef struct TerminalDescriptor TerminalDescriptor; -typedef enum { +typedef enum +{ TerminalException_None, TerminalException_ArgumentOutOfRange, + TerminalException_OperationCanceled, TerminalException_PlatformNotSupported, TerminalException_TerminalNotAttached, TerminalException_TerminalConfiguration, TerminalException_Terminal, } TerminalException; -typedef struct { +typedef struct +{ TerminalException exception; const uint16_t *nullable message; // TODO: This should be char16_t. int32_t error; } TerminalResult; // Keep in sync with src/core/TerminalSignal.cs (public API). -typedef enum { +typedef enum +{ TerminalSignal_Close, TerminalSignal_Interrupt, TerminalSignal_Quit, @@ -48,10 +52,7 @@ CATHODE_API TerminalResult cathode_set_mode(bool raw, bool flush); CATHODE_API TerminalResult cathode_generate_signal(TerminalSignal signal); CATHODE_API TerminalResult cathode_read( - const TerminalDescriptor *nonnull handle, uint8_t *nullable buffer, int32_t length, int32_t *nonnull progress); + TerminalDescriptor *nonnull descriptor, uint8_t *nullable buffer, int32_t length, int32_t *nonnull progress); CATHODE_API TerminalResult cathode_write( - const TerminalDescriptor *nonnull handle, - const uint8_t *nullable buffer, - int32_t length, - int32_t *nonnull progress); + TerminalDescriptor *nonnull descriptor, const uint8_t *nullable buffer, int32_t length, int32_t *nonnull progress); diff --git a/src/samples/cancellation/Program.cs b/src/samples/cancellation/Program.cs index aab0bf6..d1dd57f 100644 --- a/src/samples/cancellation/Program.cs +++ b/src/samples/cancellation/Program.cs @@ -1,34 +1,39 @@ await OutAsync("Reading cooked input: "); await OutLineAsync(await ReadLineAsync()); -await OutLineAsync("Entering raw mode and reading input. Canceling after 5 seconds."); +await OutLineAsync("Entering raw mode and reading input. Then canceling after 5 seconds."); await OutLineAsync(); -using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); - var array = new byte[1]; EnableRawMode(); try { - while (true) + for (var i = 0; i < 2; i++) { - try + await Task.Delay(TimeSpan.FromSeconds(5)); + + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5)); + + await OutAsync($"Round {i}...\r\n"); + + while (true) { - if (await ReadAsync(array, cts.Token) == 0) + try + { + if (await ReadAsync(array, cts.Token) == 0) + break; + } + catch (OperationCanceledException) + { + await OutAsync("Canceled.\r\n"); + break; - } - catch (OperationCanceledException) - { - await OutAsync("Canceled."); - await OutAsync("\r\n"); + } - break; + await OutAsync($"0x{array[0]:x2}\r\n"); } - - await OutAsync($"0x{array[0]:x2}"); - await OutAsync("\r\n"); } } finally diff --git a/src/samples/raw/Program.cs b/src/samples/raw/Program.cs index 011e16b..42c43e2 100644 --- a/src/samples/raw/Program.cs +++ b/src/samples/raw/Program.cs @@ -14,8 +14,7 @@ if (await ReadAsync(array) == 0) break; - await OutAsync($"0x{array[0]:x2}"); - await OutAsync("\r\n"); + await OutAsync($"0x{array[0]:x2}\r\n"); } } finally