Skip to content

Commit 9603c3b

Browse files
committed
Improve benchmarks and single frame write path
1 parent 2389eb1 commit 9603c3b

File tree

4 files changed

+62
-43
lines changed

4 files changed

+62
-43
lines changed

bench_test.go

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212
"nhooyr.io/websocket"
1313
)
1414

15-
func benchConn(b *testing.B, stream bool) {
15+
func benchConn(b *testing.B, echo, stream bool) {
1616
name := "buffered"
1717
if stream {
1818
name = "stream"
@@ -25,12 +25,11 @@ func benchConn(b *testing.B, stream bool) {
2525
b.Logf("server handshake failed: %+v", err)
2626
return
2727
}
28-
if stream {
29-
streamEchoLoop(r.Context(), c)
28+
if echo {
29+
echoLoop(r.Context(), c)
3030
} else {
31-
bufferedEchoLoop(r.Context(), c)
31+
discardLoop(r.Context(), c)
3232
}
33-
3433
}))
3534
defer closeFn()
3635

@@ -50,6 +49,7 @@ func benchConn(b *testing.B, stream bool) {
5049
buf := make([]byte, len(msg))
5150
b.Run(strconv.Itoa(n), func(b *testing.B) {
5251
b.SetBytes(int64(len(msg)))
52+
b.ReportAllocs()
5353
for i := 0; i < b.N; i++ {
5454
if stream {
5555
w, err := c.Writer(ctx, websocket.MessageText)
@@ -72,14 +72,17 @@ func benchConn(b *testing.B, stream bool) {
7272
b.Fatal(err)
7373
}
7474
}
75-
_, r, err := c.Reader(ctx)
76-
if err != nil {
77-
b.Fatal(err, b.N)
78-
}
7975

80-
_, err = io.ReadFull(r, buf)
81-
if err != nil {
82-
b.Fatal(err)
76+
if echo {
77+
_, r, err := c.Reader(ctx)
78+
if err != nil {
79+
b.Fatal(err)
80+
}
81+
82+
_, err = io.ReadFull(r, buf)
83+
if err != nil {
84+
b.Fatal(err)
85+
}
8386
}
8487
}
8588
})
@@ -99,6 +102,11 @@ func benchConn(b *testing.B, stream bool) {
99102
}
100103

101104
func BenchmarkConn(b *testing.B) {
102-
benchConn(b, false)
103-
benchConn(b, true)
105+
b.Run("write", func(b *testing.B) {
106+
benchConn(b, false, false)
107+
benchConn(b, false, true)
108+
})
109+
b.Run("echo", func(b *testing.B) {
110+
benchConn(b, true, true)
111+
})
104112
}

export_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@ import (
88
// method for when the entire message is in memory and does not need to be streamed
99
// to the peer via Writer.
1010
//
11-
// Both paths are zero allocation but Writer always has
12-
// to write an additional fin frame when Close is called on the writer which
13-
// can result in worse performance if the full message exceeds the buffer size
14-
// which is 4096 right now as then two syscalls will be necessary to complete the message.
15-
// TODO this is no good as we cannot write data frame msg in between other ones
11+
// This prevents the allocation of the Writer.
12+
// Furthermore Writer always has to write an additional fin frame when Close is
13+
// called on the writer which can result in worse performance if the full message
14+
// exceeds the buffer size which is 4096 right now as then an extra syscall
15+
// will be necessary to complete the message.
1616
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
17-
return c.writeControl(ctx, opcode(typ), p)
17+
return c.writeSingleFrame(ctx, opcode(typ), p)
1818
}

websocket.go

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313
"golang.org/x/xerrors"
1414
)
1515

16-
type control struct {
16+
type frame struct {
1717
opcode opcode
1818
payload []byte
1919
}
@@ -42,7 +42,8 @@ type Conn struct {
4242
// ping on writeDone.
4343
// writeDone will be closed if the data message write errors.
4444
write chan MessageType
45-
control chan control
45+
control chan frame
46+
fastWrite chan frame
4647
writeBytes chan []byte
4748
writeDone chan struct{}
4849
writeFlush chan struct{}
@@ -86,7 +87,8 @@ func (c *Conn) init() {
8687
c.closed = make(chan struct{})
8788

8889
c.write = make(chan MessageType)
89-
c.control = make(chan control)
90+
c.control = make(chan frame)
91+
c.fastWrite = make(chan frame)
9092
c.writeBytes = make(chan []byte)
9193
c.writeDone = make(chan struct{})
9294
c.writeFlush = make(chan struct{})
@@ -103,6 +105,8 @@ func (c *Conn) init() {
103105
go c.readLoop()
104106
}
105107

108+
// We never mask inside here because our mask key is always 0,0,0,0.
109+
// See comment on secWebSocketKey.
106110
func (c *Conn) writeFrame(h header, p []byte) {
107111
b2 := marshalHeader(h)
108112
_, err := c.bw.Write(b2)
@@ -126,14 +130,14 @@ func (c *Conn) writeFrame(h header, p []byte) {
126130
}
127131
}
128132

129-
func (c *Conn) writeLoopControl(control control) {
133+
func (c *Conn) writeLoopFastWrite(frame frame) {
130134
h := header{
131135
fin: true,
132-
opcode: control.opcode,
133-
payloadLength: int64(len(control.payload)),
136+
opcode: frame.opcode,
137+
payloadLength: int64(len(frame.payload)),
134138
masked: c.client,
135139
}
136-
c.writeFrame(h, control.payload)
140+
c.writeFrame(h, frame.payload)
137141
select {
138142
case <-c.closed:
139143
case c.writeDone <- struct{}{}:
@@ -150,7 +154,11 @@ messageLoop:
150154
case <-c.closed:
151155
return
152156
case control := <-c.control:
153-
c.writeLoopControl(control)
157+
c.writeLoopFastWrite(control)
158+
continue
159+
case frame := <-c.fastWrite:
160+
c.writeLoopFastWrite(frame)
161+
continue
154162
case dataType = <-c.write:
155163
}
156164

@@ -160,7 +168,7 @@ messageLoop:
160168
case <-c.closed:
161169
return
162170
case control := <-c.control:
163-
c.writeLoopControl(control)
171+
c.writeLoopFastWrite(control)
164172
case b := <-c.writeBytes:
165173
h := header{
166174
fin: false,
@@ -341,7 +349,7 @@ func (c *Conn) writePong(p []byte) error {
341349
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
342350
defer cancel()
343351

344-
err := c.writeControl(ctx, opPong, p)
352+
err := c.writeSingleFrame(ctx, opPong, p)
345353
return err
346354
}
347355

@@ -384,7 +392,7 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error {
384392
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
385393
defer cancel()
386394

387-
err := c.writeControl(ctx, opClose, p)
395+
err := c.writeSingleFrame(ctx, opClose, p)
388396

389397
c.close(cerr)
390398

@@ -399,11 +407,15 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error {
399407
return nil
400408
}
401409

402-
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
410+
func (c *Conn) writeSingleFrame(ctx context.Context, opcode opcode, p []byte) error {
411+
ch := c.fastWrite
412+
if opcode.controlOp() {
413+
ch = c.control
414+
}
403415
select {
404416
case <-c.closed:
405417
return c.closeErr
406-
case c.control <- control{
418+
case ch <- frame{
407419
opcode: opcode,
408420
payload: p,
409421
}:

websocket_test.go

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ func TestAutobahnServer(t *testing.T) {
448448
t.Logf("server handshake failed: %+v", err)
449449
return
450450
}
451-
streamEchoLoop(r.Context(), c)
451+
echoLoop(r.Context(), c)
452452
}))
453453
defer s.Close()
454454

@@ -495,7 +495,7 @@ func TestAutobahnServer(t *testing.T) {
495495
checkWSTestIndex(t, "./wstest_reports/server/index.json")
496496
}
497497

498-
func streamEchoLoop(ctx context.Context, c *websocket.Conn) {
498+
func echoLoop(ctx context.Context, c *websocket.Conn) {
499499
defer c.Close(websocket.StatusInternalError, "")
500500

501501
ctx, cancel := context.WithTimeout(ctx, time.Minute)
@@ -534,25 +534,24 @@ func streamEchoLoop(ctx context.Context, c *websocket.Conn) {
534534
}
535535
}
536536

537-
func bufferedEchoLoop(ctx context.Context, c *websocket.Conn) {
537+
func discardLoop(ctx context.Context, c *websocket.Conn) {
538538
defer c.Close(websocket.StatusInternalError, "")
539539

540540
ctx, cancel := context.WithTimeout(ctx, time.Minute)
541541
defer cancel()
542542

543-
b := make([]byte, 131072+2)
543+
b := make([]byte, 32768)
544544
echo := func() error {
545-
typ, r, err := c.Reader(ctx)
545+
_, r, err := c.Reader(ctx)
546546
if err != nil {
547547
return err
548548
}
549549

550-
n, err := io.ReadFull(r, b)
551-
if err != io.ErrUnexpectedEOF {
550+
_, err = io.CopyBuffer(ioutil.Discard, r, b)
551+
if err != nil {
552552
return err
553553
}
554-
555-
return c.Write(ctx, typ, b[:n])
554+
return nil
556555
}
557556

558557
for {
@@ -647,7 +646,7 @@ func TestAutobahnClient(t *testing.T) {
647646
if err != nil {
648647
t.Fatalf("failed to dial: %v", err)
649648
}
650-
streamEchoLoop(ctx, c)
649+
echoLoop(ctx, c)
651650
}()
652651
}
653652

0 commit comments

Comments
 (0)