|
1 | 1 | package websocket
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "encoding/binary" |
4 | 5 | "io"
|
| 6 | + |
| 7 | + "golang.org/x/xerrors" |
5 | 8 | )
|
6 | 9 |
|
| 10 | +// First byte contains fin, rsv1, rsv2, rsv3. |
| 11 | +// Second byte contains mask flag and payload len. |
| 12 | +// Next 8 bytes are the maximum extended payload length. |
| 13 | +// Last 4 bytes are the mask key. |
| 14 | +// https://tools.ietf.org/html/rfc6455#section-5.2 |
| 15 | +const maxHeaderSize = 1 + 1 + 8 + 4 |
| 16 | + |
7 | 17 | // header represents a WebSocket frame header.
|
8 | 18 | // See https://tools.ietf.org/html/rfc6455#section-5.2
|
9 | 19 | // The fields are exported for easy printing for debugging.
|
10 | 20 | type header struct {
|
11 |
| - Fin bool |
12 |
| - Rsv1 bool |
13 |
| - Rsv2 bool |
14 |
| - Rsv3 bool |
15 |
| - Opcode opcode |
| 21 | + fin bool |
| 22 | + rsv1 bool |
| 23 | + rsv2 bool |
| 24 | + rsv3 bool |
| 25 | + opcode opcode |
16 | 26 |
|
17 |
| - PayloadLength int64 |
| 27 | + payloadLength int64 |
18 | 28 |
|
19 |
| - Masked bool |
20 |
| - MaskKey [4]byte |
| 29 | + masked bool |
| 30 | + maskKey [4]byte |
21 | 31 | }
|
22 | 32 |
|
23 |
| -// Bytes returns the bytes of the header. |
24 |
| -func (h header) Bytes() []byte { |
25 |
| - panic("TODO") |
| 33 | +// TODO bitwise helpers |
| 34 | + |
| 35 | +// bytes returns the bytes of the header. |
| 36 | +// See https://tools.ietf.org/html/rfc6455#section-5.2 |
| 37 | +func marshalHeader(h header) ([]byte, error) { |
| 38 | + b := make([]byte, 2, maxHeaderSize) |
| 39 | + |
| 40 | + if h.fin { |
| 41 | + b[0] |= 1 << 7 |
| 42 | + } |
| 43 | + if h.rsv1 { |
| 44 | + b[0] |= 1 << 6 |
| 45 | + } |
| 46 | + if h.rsv2 { |
| 47 | + b[0] |= 1 << 5 |
| 48 | + } |
| 49 | + if h.rsv3 { |
| 50 | + b[0] |= 1 << 4 |
| 51 | + } |
| 52 | + |
| 53 | + b[0] |= byte(h.opcode) |
| 54 | + |
| 55 | + switch { |
| 56 | + case h.payloadLength < 0: |
| 57 | + return nil, xerrors.Errorf("invalid header: negative length: %v", h.payloadLength) |
| 58 | + case h.payloadLength <= 125: |
| 59 | + b[1] = byte(h.payloadLength) |
| 60 | + case h.payloadLength <= 1<<16: |
| 61 | + b[1] = 126 |
| 62 | + b = b[:len(b)+2] |
| 63 | + binary.BigEndian.PutUint16(b[len(b)-2:], uint16(h.payloadLength)) |
| 64 | + default: |
| 65 | + b[1] = 127 |
| 66 | + b = b[:len(b)+8] |
| 67 | + binary.BigEndian.PutUint64(b[len(b)-8:], uint64(h.payloadLength)) |
| 68 | + } |
| 69 | + |
| 70 | + if h.masked { |
| 71 | + b[1] |= 1 << 7 |
| 72 | + b = b[:len(b)+4] |
| 73 | + copy(b[len(b)-4:], h.maskKey[:]) |
| 74 | + } |
| 75 | + |
| 76 | + return b, nil |
26 | 77 | }
|
27 | 78 |
|
28 |
| -// ReadHeader reads a header from the reader. |
29 |
| -func ReadHeader(r io.Reader) []byte { |
30 |
| - panic("TODO") |
| 79 | +// readHeader reads a header from the reader. |
| 80 | +// See https://tools.ietf.org/html/rfc6455#section-5.2 |
| 81 | +func readHeader(r io.Reader) (header, error) { |
| 82 | + // We read the first two bytes directly so that we know |
| 83 | + // exactly how long the header is. |
| 84 | + b := make([]byte, 2, maxHeaderSize-2) |
| 85 | + _, err := io.ReadFull(r, b) |
| 86 | + if err != nil { |
| 87 | + return header{}, err |
| 88 | + } |
| 89 | + |
| 90 | + var h header |
| 91 | + h.fin = b[0]&(1<<7) != 0 |
| 92 | + h.rsv1 = b[0]&(1<<6) != 0 |
| 93 | + h.rsv2 = b[0]&(1<<5) != 0 |
| 94 | + h.rsv3 = b[0]&(1<<4) != 0 |
| 95 | + |
| 96 | + h.opcode = opcode(b[0] & 0xf) |
| 97 | + |
| 98 | + var extra int |
| 99 | + |
| 100 | + h.masked = b[1]&(1<<7) != 0 |
| 101 | + if h.masked { |
| 102 | + extra += 4 |
| 103 | + } |
| 104 | + |
| 105 | + payloadLength := b[1] &^ (1 << 7) |
| 106 | + switch { |
| 107 | + case payloadLength < 126: |
| 108 | + h.payloadLength = int64(payloadLength) |
| 109 | + case payloadLength == 126: |
| 110 | + h.payloadLength = 126 |
| 111 | + extra += 2 |
| 112 | + case payloadLength == 127: |
| 113 | + h.payloadLength = 127 |
| 114 | + extra += 8 |
| 115 | + } |
| 116 | + |
| 117 | + if extra == 0 { |
| 118 | + return h, nil |
| 119 | + } |
| 120 | + |
| 121 | + b = b[:extra] |
| 122 | + _, err = io.ReadFull(r, b) |
| 123 | + if err != nil { |
| 124 | + return header{}, err |
| 125 | + } |
| 126 | + |
| 127 | + switch { |
| 128 | + case payloadLength == 126: |
| 129 | + h.payloadLength = int64(binary.BigEndian.Uint16(b)) |
| 130 | + b = b[2:] |
| 131 | + case payloadLength == 127: |
| 132 | + h.payloadLength = int64(binary.BigEndian.Uint64(b)) |
| 133 | + b = b[8:] |
| 134 | + } |
| 135 | + |
| 136 | + if h.masked { |
| 137 | + copy(h.maskKey[:], b) |
| 138 | + } |
| 139 | + |
| 140 | + return h, nil |
31 | 141 | }
|
0 commit comments