diff --git a/worker/message_port.go b/worker/message_port.go index 8fb0e56..5767b77 100644 --- a/worker/message_port.go +++ b/worker/message_port.go @@ -5,6 +5,7 @@ package worker import ( "context" "fmt" + "sync" "github.com/hack-pad/safejs" ) @@ -47,14 +48,25 @@ func (p *messagePort) Listen(ctx context.Context) (_ <-chan MessageEvent, err er }() events := make(chan MessageEvent) - messageHandler, err := nonBlocking(func(args []safejs.Value) { - events <- parseMessageEvent(args[0]) + var wg sync.WaitGroup + messageHandler, err := nonBlocking(&wg, func(args []safejs.Value) { + select { + case <-ctx.Done(): + return + default: + events <- parseMessageEvent(args[0]) + } }) if err != nil { return nil, err } - errorHandler, err := nonBlocking(func(args []safejs.Value) { - events <- parseMessageEvent(args[0]) + errorHandler, err := nonBlocking(&wg, func(args []safejs.Value) { + select { + case <-ctx.Done(): + return + default: + events <- parseMessageEvent(args[0]) + } }) if err != nil { return nil, err @@ -70,6 +82,7 @@ func (p *messagePort) Listen(ctx context.Context) (_ <-chan MessageEvent, err er if err == nil { errorHandler.Release() } + wg.Wait() close(events) }() _, err = p.jsMessagePort.Call("addEventListener", "message", messageHandler) @@ -90,9 +103,13 @@ func (p *messagePort) Listen(ctx context.Context) (_ <-chan MessageEvent, err er return events, nil } -func nonBlocking(fn func(args []safejs.Value)) (safejs.Func, error) { +func nonBlocking(wg *sync.WaitGroup, fn func(args []safejs.Value)) (safejs.Func, error) { return safejs.FuncOf(func(_ safejs.Value, args []safejs.Value) any { - go fn(args) + wg.Add(1) + go func() { + fn(args) + wg.Done() + }() return nil }) } diff --git a/worker/worker_test.go b/worker/worker_test.go index 52eb094..943cb82 100644 --- a/worker/worker_test.go +++ b/worker/worker_test.go @@ -301,3 +301,41 @@ self.addEventListener("message", event => { } }) } + +func TestWorkerStopListen(t *testing.T) { + t.Parallel() + const pingPongScript = ` +"use strict"; + +self.addEventListener("message", event => { + self.postMessage("foo"); + self.postMessage("bar"); +}); +` + worker, err := NewFromScript(pingPongScript, Options{}) + if err != nil { + t.Fatal(err) + } + cleanUpWorker(t, worker) + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + _, err = worker.Listen(ctx) + if err != nil { + t.Fatal(err) + } + + msg, err := safejs.ValueOf("start") + if err != nil { + t.Fatal(err) + } + + err = worker.PostMessage(msg, nil) + if err != nil { + t.Fatal(err) + } + + time.Sleep(100 * time.Millisecond) + + cancel() +}