diff --git a/net.go b/net.go index 54beb8ee..ac001dda 100644 --- a/net.go +++ b/net.go @@ -17,6 +17,7 @@ package openssl import ( "errors" "net" + "time" ) type listener struct { @@ -80,6 +81,13 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) { return DialSession(network, addr, ctx, flags, nil) } +// DialTimeout works just like Dial, but with a timeout on the underlying net +// connection. +func DialTimeout(network, addr string, ctx *Ctx, flags DialFlags, + timeout time.Duration) (*Conn, error) { + return DialSessionTimeout(network, addr, ctx, flags, nil, timeout) +} + // DialSession will connect to network/address and then wrap the corresponding // underlying connection with an OpenSSL client connection using context ctx. // If flags includes InsecureSkipHostVerification, the server certificate's @@ -95,7 +103,28 @@ func Dial(network, addr string, ctx *Ctx, flags DialFlags) (*Conn, error) { // can be retrieved from the GetSession method on the Conn. func DialSession(network, addr string, ctx *Ctx, flags DialFlags, session []byte) (*Conn, error) { + return dialSession(network, addr, ctx, flags, session, net.Dial) +} + +// DialSessionTimeout works just like DialSessionTimeout, but with a timeout +// on the underlying net connection. +func DialSessionTimeout(network, addr string, ctx *Ctx, flags DialFlags, + session []byte, timeout time.Duration) (*Conn, error) { + return dialSession( + network, + addr, + ctx, + flags, + session, + func(network, addr string) (net.Conn, error) { + return net.DialTimeout(network, addr, timeout) + }, + ) +} +func dialSession(network, addr string, ctx *Ctx, flags DialFlags, + session []byte, + makeNetConn func(network, addr string) (net.Conn, error)) (*Conn, error) { host, _, err := net.SplitHostPort(addr) if err != nil { return nil, err @@ -108,7 +137,7 @@ func DialSession(network, addr string, ctx *Ctx, flags DialFlags, } // TODO: use operating system default certificate chain? } - c, err := net.Dial(network, addr) + c, err := makeNetConn(network, addr) if err != nil { return nil, err }