Skip to content

Commit

Permalink
Create bufconn package for a local, buffered net.Conn and dialer/list…
Browse files Browse the repository at this point in the history
…ener
  • Loading branch information
dfawley committed Jul 18, 2017
1 parent ce03e9c commit 9a71c79
Show file tree
Hide file tree
Showing 2 changed files with 344 additions and 0 deletions.
227 changes: 227 additions & 0 deletions test/bufconn/bufconn.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

// Package bufconn provides a net.Conn implemented by a buffer and related
// dialing and listening functionality.
package bufconn

import (
"fmt"
"io"
"net"
"sync"
"time"
)

// Listener implements a net.Listener that creates local, buffered net.Conns
// via its Accept and Dial method.
type Listener struct {
mu sync.Mutex
sz int
ch chan net.Conn
closed bool
}

var errClosed = fmt.Errorf("Closed")

// Listen returns a Listener that can only be contacted by its own Dialers and
// creates buffered connections between the two.
func Listen(sz int) *Listener {
return &Listener{sz: sz, ch: make(chan net.Conn)}
}

// Accept blocks until Dial is called, then returns a net.Conn for the server
// half of the connection.
func (l *Listener) Accept() (net.Conn, error) {
c := <-l.ch
if c == nil {
return nil, errClosed
}
return c, nil
}

// Close stops the listener.
func (l *Listener) Close() error {
l.mu.Lock()
defer l.mu.Unlock()
if l.closed {
return nil
}
l.closed = true
close(l.ch)
return nil
}

// Addr reports the address of the listener.
func (l *Listener) Addr() net.Addr { return addr{} }

// Dial creates an in-memory full-duplex network connection, unblocks Accept by
// providing it the server half of the connection, and returns the client half
// of the connection.
func (l *Listener) Dial() (net.Conn, error) {
l.mu.Lock()
defer l.mu.Unlock()
if l.closed {
return nil, errClosed
}
p1, p2 := newPipe(l.sz), newPipe(l.sz)
l.ch <- &conn{p1, p2}
return &conn{p2, p1}, nil
}

type pipe struct {
mu sync.Mutex

// buf contains the data in the pipe. It is a ring buffer of fixed capacity,
// with r and w pointing to the offset to read and write, respsectively.
//
// Data is read between [r, w) and written to [w, r), wrapping around the end
// of the slice if necessary.
//
// The buffer is empty if r == len(buf), otherwise if r == w, it is full.
//
// w and r are always in the range [0, cap(buf)) and [0, len(buf)].
buf []byte
w, r int

wwait sync.Cond
rwait sync.Cond
closed bool
}

func newPipe(sz int) *pipe {
p := &pipe{buf: make([]byte, 0, sz)}
p.wwait.L = &p.mu
p.rwait.L = &p.mu
return p
}

func (p *pipe) empty() bool {
return p.r == len(p.buf)
}

func (p *pipe) full() bool {
return p.r < len(p.buf) && p.r == p.w
}

func (p *pipe) Read(b []byte) (n int, err error) {
p.mu.Lock()
defer p.mu.Unlock()
// Block until p has data.
for {
if p.closed {
return 0, io.ErrClosedPipe
}
if !p.empty() {
break
}
p.rwait.Wait()
}
wasFull := p.full()

n = copy(b, p.buf[p.r:len(p.buf)])
p.r += n
if p.r == cap(p.buf) {
p.r = 0
p.buf = p.buf[:p.w]
}

// Signal a blocked writer, if any
if wasFull {
p.wwait.Signal()
}

return n, nil
}

func (p *pipe) Write(b []byte) (n int, err error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return 0, io.ErrClosedPipe
}
for len(b) > 0 {
// Block until p is not full.
for {
if p.closed {
return 0, io.ErrClosedPipe
}
if !p.full() {
break
}
p.wwait.Wait()
}
wasEmpty := p.empty()

end := cap(p.buf)
if p.w < p.r {
end = p.r
}
x := copy(p.buf[p.w:end], b)
b = b[x:]
n += x
p.w += x
if p.w > len(p.buf) {
p.buf = p.buf[:p.w]
}
if p.w == cap(p.buf) {
p.w = 0
}

// Signal a blocked reader, if any.
if wasEmpty {
p.rwait.Signal()
}
}
return n, nil
}

func (p *pipe) Close() error {
p.mu.Lock()
defer p.mu.Unlock()
p.closed = true
// Signal all blocked readers and writers to return an error.
p.rwait.Broadcast()
p.wwait.Broadcast()
return nil
}

type conn struct {
io.ReadCloser
io.WriteCloser
}

func (c *conn) Close() error {
err1 := c.ReadCloser.Close()
err2 := c.WriteCloser.Close()
if err1 != nil {
return err1
}
return err2
}

func (*conn) LocalAddr() net.Addr { return addr{} }
func (*conn) RemoteAddr() net.Addr { return addr{} }
func (c *conn) SetDeadline(t time.Time) error { return fmt.Errorf("unsupported") }
func (c *conn) SetReadDeadline(t time.Time) error { return fmt.Errorf("unsupported") }
func (c *conn) SetWriteDeadline(t time.Time) error { return fmt.Errorf("unsupported") }

type addr struct{}

func (addr) Network() string { return "bufconn" }
func (addr) String() string { return "bufconn" }
117 changes: 117 additions & 0 deletions test/bufconn/bufconn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
*
* Copyright 2017 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package bufconn

import (
"fmt"
"io"
"net"
"reflect"
"testing"
"time"
)

func testRW(r io.Reader, w io.Writer) error {
for i := 0; i < 20; i++ {
d := make([]byte, i)
for j := 0; j < i; j++ {
d[j] = byte(i - j)
}
var rn int
var rerr error
b := make([]byte, i)
done := make(chan struct{})
go func() {
for rn < len(b) && rerr == nil {
var x int
x, rerr = r.Read(b[rn:])
rn += x
}
close(done)
}()
wn, werr := w.Write(d)
if wn != i || werr != nil {
return fmt.Errorf("%v: w.Write(%v) = %v, %v; want %v, nil", i, d, wn, werr, i)
}
select {
case <-done:
case <-time.After(500 * time.Millisecond):
return fmt.Errorf("%v: r.Read never returned", i)
}
if rn != i || rerr != nil {
return fmt.Errorf("%v: r.Read = %v, %v; want %v, nil", i, rn, rerr, i)
}
if !reflect.DeepEqual(b, d) {
return fmt.Errorf("%v: r.Read read %v; want %v", i, b, d)
}
}
return nil
}

func TestPipe(t *testing.T) {
p := newPipe(10)
if err := testRW(p, p); err != nil {
t.Fatalf(err.Error())
}
}

func TestPipeClose(t *testing.T) {
p := newPipe(10)
p.Close()
if _, err := p.Write(nil); err != io.ErrClosedPipe {
t.Fatalf("p.Write = _, %v; want _, %v", err, io.ErrClosedPipe)
}
if _, err := p.Read(nil); err != io.ErrClosedPipe {
t.Fatalf("p.Read = _, %v; want _, %v", err, io.ErrClosedPipe)
}
}

func TestConn(t *testing.T) {
p1, p2 := newPipe(10), newPipe(10)
c1, c2 := &conn{p1, p2}, &conn{p2, p1}

if err := testRW(c1, c2); err != nil {
t.Fatalf(err.Error())
}
if err := testRW(c2, c1); err != nil {
t.Fatalf(err.Error())
}
}

func TestListener(t *testing.T) {
l := Listen(7)
var s net.Conn
var serr error
done := make(chan struct{})
go func() {
s, serr = l.Accept()
close(done)
}()
c, cerr := l.Dial()
<-done
if cerr != nil || serr != nil {
t.Fatalf("cerr = %v, serr = %v; want nil, nil", cerr, serr)
}
if err := testRW(c, s); err != nil {
t.Fatalf(err.Error())
}
if err := testRW(s, c); err != nil {
t.Fatalf(err.Error())
}
}

0 comments on commit 9a71c79

Please sign in to comment.