Skip to content
This repository was archived by the owner on Jul 18, 2025. It is now read-only.

Dead session detection #57

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 54 additions & 3 deletions spdy/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@ package spdy

import (
"net"
"time"
"errors"
"fmt"
)

// TransportListener is a listener which accepts new
// connections angi rd spawns spdy transports.
type TransportListener struct {
listener net.Listener
auth Authenticator
timeout *time.Duration
}

// NewTransportListener creates a new listen transport using
Expand All @@ -32,11 +36,36 @@ func (l *TransportListener) Close() error {
// and creates a new stream. Connections which fail
// authentication will not be returned.
func (l *TransportListener) AcceptTransport() (*Transport, error) {
tranChan := make(chan interface{})
for {
conn, err := l.listener.Accept()
if err != nil {
return nil, err
// The timeout channel has a buffer of 1
// to allow the timeout goroutine to exit
// if nothing is listening anymore. This prevents
// it form hanging forever waiting on a receiver.
timeoutChan := make(chan bool, 1)
// Launch listener wait inside a goroutine passing
// transport channel
go l.waitForAccept(tranChan)
// If timeout provided launch timeout goroutine with
// duration to wait and timeout channel
if l.timeout != nil {
go waitForTimeout(*l.timeout, timeoutChan)
}

// Wait for new connection channel or timeout channel
var conn net.Conn
select {
case x := <-tranChan:
if x, ok := x.(net.Conn); ok {
conn = x
} else if x, ok := x.(error); ok {
return nil, x
}
case <-timeoutChan:
// We have timed out
return nil, errors.New(fmt.Sprintf("listener timed out (%s)", l.timeout.String()))
}

authErr := l.auth(conn)
if authErr != nil {
// TODO log
Expand All @@ -47,3 +76,25 @@ func (l *TransportListener) AcceptTransport() (*Transport, error) {
return newSession(conn, true)
}
}

func (l *TransportListener) waitForAccept(tranChan chan interface{}) {
conn, err := l.listener.Accept()
if err != nil {
tranChan<-err
}
tranChan<-conn
return
}

// Sets the timeout for this listener. AcceptTransport() will return if no
// connection is opened for t amount of time.
func (l *TransportListener) SetTimeout(t time.Duration) {
l.timeout = &t
}

// Function to wait for a timeout condition and then signal to a channel
func waitForTimeout(d time.Duration, timeoutChan chan bool) {
time.Sleep(d)
timeoutChan<-true
}

57 changes: 57 additions & 0 deletions spdy/listener_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package spdy

import (
"testing"
"time"
"net"
)

// Test that server detects client session is dead while waiting for receive channel
func TestListenerTimeout(t *testing.T) {
// Test default behavior without timeout
noWait(t)
// Test enabling timeout behavior
withWait(t)
}

func noWait(t *testing.T) {
// Start listener, ensure it doesn't throw error after 100 ms of no connecton
timeoutChan := make(chan bool)
go func() {
time.Sleep(time.Millisecond * 200)
close(timeoutChan)
}()
// Start listener, ensure it does throw error after 100 ms of no connection
listener, _ := net.Listen("tcp", "localhost:12945")
go func() {
transportListener, _ := NewTransportListener(listener, NoAuthenticator)
_, err := transportListener.AcceptTransport()
t.Fatal(err)
}()
<-timeoutChan
}

func withWait(t *testing.T) {
timeoutChan := make(chan bool)
go func() {
time.Sleep(time.Millisecond * 200)
timeoutChan<-false
}()
// Start listener, ensure it does throw error after 100 ms of no connection
listener, _ := net.Listen("tcp", "localhost:12946")
go func() {
transportListener, _ := NewTransportListener(listener, NoAuthenticator)
transportListener.SetTimeout(time.Millisecond * 100)
_, err := transportListener.AcceptTransport()
if err.Error() != "listener timed out (100ms)" {
t.Fatal(err.Error() + ", should have timed out at (100ms)")
}
timeoutChan<-true
}()
select {
case ok := <-timeoutChan:
if !ok {
t.Fatal("timeout expected and did not occur")
}
}
}
156 changes: 131 additions & 25 deletions spdy/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import (
"net/http"
"strconv"
"sync"
"time"
"fmt"

"github.com/dmcgowan/go/codec"
"github.com/docker/libchan"
Expand All @@ -18,6 +20,18 @@ type direction uint8
const (
outbound = direction(0x01)
inbound = direction(0x02)
// Defaults for heartbeat.
// Currently ping every 30 seconds and fail if no pings respond
//
// The frequency of pinging the client to
// detect a dead session. This is the default and
// can be overridden via property on the *Transport
defaultHeartbeatInterval = time.Second * 30
// The amount of times the heartbeat fails
// before the session is considered dead. This is
// the default and can be overridden via property
// on the *Transport.
defaultHeartbeatLimit = 3
)

var (
Expand All @@ -29,9 +43,15 @@ var (
// Transport is a transport session on top of a network
// connection using spdy.
type Transport struct {
HeartbeatInterval time.Duration
HeartbeatLimit int

conn *spdystream.Connection
handler codec.Handle

deadSessionChan chan struct{}
deadSessionFlag bool

receiverChan chan *channel
channelC *sync.Cond
channels map[uint64]*channel
Expand Down Expand Up @@ -76,15 +96,18 @@ func newSession(conn net.Conn, server bool) (*Transport, error) {
referenceCounter = 1
}
session := &Transport{
receiverChan: make(chan *channel),
channelC: sync.NewCond(new(sync.Mutex)),
channels: make(map[uint64]*channel),
referenceCounter: referenceCounter,
byteStreamC: sync.NewCond(new(sync.Mutex)),
byteStreams: make(map[uint64]*byteStream),
netConnC: sync.NewCond(new(sync.Mutex)),
netConns: make(map[byte]map[string]net.Conn),
networks: make(map[string]byte),
deadSessionChan: make(chan struct{}),
receiverChan: make(chan *channel),
channelC: sync.NewCond(new(sync.Mutex)),
channels: make(map[uint64]*channel),
referenceCounter: referenceCounter,
byteStreamC: sync.NewCond(new(sync.Mutex)),
byteStreams: make(map[uint64]*byteStream),
netConnC: sync.NewCond(new(sync.Mutex)),
netConns: make(map[byte]map[string]net.Conn),
networks: make(map[string]byte),
HeartbeatInterval: defaultHeartbeatInterval,
HeartbeatLimit: defaultHeartbeatLimit,
}

spdyConn, spdyErr := spdystream.NewConnection(conn, server)
Expand All @@ -96,9 +119,57 @@ func newSession(conn net.Conn, server bool) (*Transport, error) {
session.conn = spdyConn
session.handler = session.initializeHandler()

// Looping heartbeat monitor. Pings the client to
// determine if it has lost connection without sending
// a close.
go session.monitorHeartbeat()

return session, nil
}

// errDeadSession occurs when heartbeat is enabled and
// a ping returns an error trying to contact the client.
// This is useful for managing long runnning connections
// that may die form network failure. This is a method
// rather than a var to allow insertion of time elapsed
// dynamically.
func (s *Transport) errDeadSession() error {
return errors.New(fmt.Sprintf("session appears dead no response after %v", s.HeartbeatInterval*time.Duration(s.HeartbeatLimit)))
}

func (s *Transport) monitorHeartbeat() {
var hbFailures int = 0
for {
// Only loop after waiting for the heartbeatInterval
time.Sleep(s.HeartbeatInterval)
_, err := s.conn.Ping()
if err != nil {
// Increase heartbeat failure count
hbFailures++
// If we have hit out limit on failures we trigger marking
// the session as dead.
if hbFailures >= s.HeartbeatLimit {
// Set the deadSessionFlag to true. This is used to
// check for a dead session before starting a blocking
// op using a channel.
s.deadSessionFlag = true
// Uses the closing of a channel trick to
// broadcast to all waiting threads that
// the session is dead.
// Any thread that needs to wait on a blocking
// op that is dependant on the session being live
// should implement a select that includes this
// channel closing as a signal.
close(s.deadSessionChan)
return
}
} else {
// Reset heartbeat failure count
hbFailures = 0
}
}
}

func (s *Transport) newStreamHandler(stream *spdystream.Stream) {
referenceIDString := stream.Headers().Get("libchan-ref")
parentIDString := stream.Headers().Get("libchan-parent-ref")
Expand Down Expand Up @@ -192,7 +263,7 @@ func (s *Transport) dial(referenceID uint64) (*byteStream, error) {
func (s *Transport) nextReferenceID() uint64 {
s.referenceLock.Lock()
referenceID := s.referenceCounter
s.referenceCounter = referenceID + 2
s.referenceCounter = referenceID+2
s.referenceLock.Unlock()
return referenceID
}
Expand Down Expand Up @@ -306,12 +377,25 @@ func (s *Transport) NewSendChannel() (libchan.Sender, error) {
// WaitReceiveChannel waits for a new channel be created by a remote
// call to NewSendChannel.
func (s *Transport) WaitReceiveChannel() (libchan.Receiver, error) {
r, ok := <-s.receiverChan
if !ok {
return nil, io.EOF
}
for {
// Safety check to see if session is dead before starting select
if s.deadSessionFlag {
return nil, s.errDeadSession()
}
// We use a select to wait for either the receiver channel
// or a dead session channel.
select {
case <-s.deadSessionChan:
// Return nil and ErrDeadSession
return nil, s.errDeadSession()
case r, ok := <-s.receiverChan:
if !ok {
return nil, io.EOF
}
return r, nil
}

return r, nil
}
}

func (c *channel) createSubChannel(direction direction) (libchan.Sender, libchan.Receiver, error) {
Expand Down Expand Up @@ -387,19 +471,41 @@ func (c *channel) Receive(message interface{}) error {
if c.direction == outbound {
return ErrWrongDirection
}
buf, readErr := c.stream.ReadData()
if readErr != nil {
if readErr == io.EOF {
c.stream.Close()
// Use a goroutine and channel to ReadData from channel
buffChan := make(chan interface{})
go c.handleReadData(buffChan)
// Wait for channel response or signal that session is dead
select {
case <-c.session.deadSessionChan:
// Dead session
return c.session.errDeadSession()
case b := <-buffChan:
switch b.(type) {
case error:
if b.(error) == io.EOF {
c.stream.Close()
}
return b.(error)
case []byte:
decoder := codec.NewDecoderBytes(b.([]byte), c.session.handler)
decodeErr := decoder.Decode(message)
if decodeErr != nil {
return decodeErr
}
return nil
default:
panic("unknown type")
}
return readErr
}
decoder := codec.NewDecoderBytes(buf, c.session.handler)
decodeErr := decoder.Decode(message)
if decodeErr != nil {
return decodeErr
}

func (c *channel) handleReadData(buffChan chan interface{}) {
buf, err := c.stream.ReadData()
if err != nil {
buffChan<-err
} else {
buffChan<-buf
}
return nil
}

// Close closes the underlying stream, causing any subsequent
Expand Down
Loading