From aeea975bd3bd2bad892b4b79d217fe61926aa32a Mon Sep 17 00:00:00 2001 From: Vesa Karvonen Date: Sun, 4 Feb 2024 16:27:28 +0200 Subject: [PATCH] Fix to avoid write after closes (fixes #16) --- src/Domain_local_timeout.ml | 114 ++++++++++++++++++++---------------- 1 file changed, 65 insertions(+), 49 deletions(-) diff --git a/src/Domain_local_timeout.ml b/src/Domain_local_timeout.ml index c4ac26e..ba137ad 100644 --- a/src/Domain_local_timeout.ml +++ b/src/Domain_local_timeout.ml @@ -27,100 +27,117 @@ end module Q = Psq.Make (Int) (Entry) let system_on_current_domain (module Thread : Thread) (module Unix : Unix) = - let error = ref None in - let check () = match !error with None -> () | Some exn -> raise exn in - let running = ref true in - let needs_wakeup = ref true in - let reading, writing = Unix.pipe () in - let[@poll error] [@inline never] wakeup_needed_atomically () = - !needs_wakeup && !error == None + let open struct + type state = { + mutable needs_wakeup : bool; + mutable counter : int; + mutable running : bool; + mutable error : exn option; + reading : Unix.file_descr; + writing : Unix.file_descr; + timeouts : Q.t Atomic.t; + } + end in + let s = + let reading, writing = Unix.pipe () in + { + needs_wakeup = true; + counter = 0; + running = true; + error = None; + reading; + writing; + timeouts = Atomic.make Q.empty; + } + in + let check s = match s.error with None -> () | Some exn -> raise exn in + let[@poll error] [@inline never] wakeup_needed_atomically s = + s.needs_wakeup && begin - needs_wakeup := false; + s.needs_wakeup <- false; true end in - let wakeup () = - if wakeup_needed_atomically () then begin - let n = Unix.write writing (Bytes.create 1) 0 1 in + let wakeup s = + if wakeup_needed_atomically s then begin + let n = Unix.write s.writing (Bytes.create 1) 0 1 in assert (n = 1) end in - let counter = ref 0 in - let[@poll error] [@inline never] next_id_atomically () = - let id = !counter + 1 in - counter := id; + let[@poll error] [@inline never] next_id_atomically s = + let id = s.counter + 1 in + s.counter <- id; id in - let timeouts = Atomic.make Q.empty in - let[@poll error] [@inline never] running_atomically () = - !running - && begin - needs_wakeup := true; - true - end + let[@poll error] [@inline never] running_atomically s = + let running = s.running in + s.needs_wakeup <- running; + running in - let rec timeout_thread next = - if running_atomically () then begin + let rec timeout_thread s next = + if running_atomically s then begin begin - match Unix.select [ reading ] [] [] next with + match Unix.select [ s.reading ] [] [] next with | [ reading ], _, _ -> let n = Unix.read reading (Bytes.create 1) 0 1 in assert (n = 1) | _, _, _ -> () end; - let rec loop () = - let ts_old = Atomic.get timeouts in + let rec loop s = + let ts_old = Atomic.get s.timeouts in match Q.pop ts_old with | None -> -1.0 | Some ((_, t), ts) -> let elapsed = Mtime_clock.elapsed () in if Mtime.Span.compare t.time elapsed <= 0 then begin - if Atomic.compare_and_set timeouts ts_old ts then t.action (); - loop () + if Atomic.compare_and_set s.timeouts ts_old ts then t.action (); + loop s end else Mtime.Span.to_float_ns (Mtime.Span.abs_diff t.time elapsed) *. (1. /. 1_000_000_000.) in - timeout_thread (loop ()) + timeout_thread s (loop s) end in - let timeout_thread () = + let timeout_thread s = begin - match timeout_thread (-1.0) with + match timeout_thread s (-1.0) with | () -> () - | exception exn -> error := Some exn + | exception exn -> + s.needs_wakeup <- false; + s.error <- Some exn end; - Unix.close reading; - Unix.close writing + Unix.close s.reading; + Unix.close s.writing in - let tid = Thread.create timeout_thread () in + let tid = Thread.create timeout_thread s in let stop () = - running := false; - wakeup (); + s.running <- false; + wakeup s; Thread.join tid; - check () + check s in let set_timeoutf seconds action = match Mtime.Span.of_float_ns (seconds *. 1_000_000_000.) with | None -> invalid_arg "timeout should be between 0 to pow(2, 53) nanoseconds" | Some span -> - check (); + check s; let time = Mtime.Span.add (Mtime_clock.elapsed ()) span in let e' = Entry.{ time; action } in - let id = next_id_atomically () in + let id = next_id_atomically s in let rec insert_loop () = - let ts = Atomic.get timeouts in + let ts = Atomic.get s.timeouts in let ts' = Q.add id e' ts in - if not (Atomic.compare_and_set timeouts ts ts') then insert_loop () + if not (Atomic.compare_and_set s.timeouts ts ts') then insert_loop () else match Q.min ts' with Some (id', _) -> id = id' | None -> false in - if insert_loop () then wakeup (); + if insert_loop () then wakeup s; let rec cancel () = - let ts = Atomic.get timeouts in + let ts = Atomic.get s.timeouts in let ts' = Q.remove id ts in - if not (Atomic.compare_and_set timeouts ts ts') then cancel () + if not (Atomic.compare_and_set s.timeouts ts ts') then cancel () in cancel in @@ -144,9 +161,8 @@ let try_system = ref unimplemented let default seconds action = !try_system seconds action let key = Domain.DLS.new_key @@ fun () -> Per_domain { set_timeoutf = default } -let[@poll error] [@inline never] update_set_timeoutf_atomically state - set_timeoutf = - match state with +let[@poll error] [@inline never] update_set_timeoutf_atomically s set_timeoutf = + match s with | Per_domain r -> let current = r.set_timeoutf in if current == default then begin