From 00a5f18ebd17235760034b997b24dc58c4ead09b Mon Sep 17 00:00:00 2001 From: Jason Lyu Date: Tue, 24 Oct 2023 10:55:33 +0800 Subject: [PATCH] Feature(proxy): support gost relay protocol (#310) --- engine/parse.go | 18 ++++ go.mod | 1 + go.sum | 2 + proxy/proto/proto.go | 3 + proxy/relay.go | 252 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 276 insertions(+) create mode 100644 proxy/relay.go diff --git a/engine/parse.go b/engine/parse.go index 48408d8f..234dc406 100644 --- a/engine/parse.go +++ b/engine/parse.go @@ -7,6 +7,8 @@ import ( "net/url" "strings" + "github.com/gorilla/schema" + "github.com/xjasonlyu/tun2socks/v2/core/device" "github.com/xjasonlyu/tun2socks/v2/core/device/fdbased" "github.com/xjasonlyu/tun2socks/v2/core/device/tun" @@ -92,6 +94,8 @@ func parseProxy(s string) (proxy.Proxy, error) { return parseSocks5(u) case proto.Shadowsocks.String(): return parseShadowsocks(u) + case proto.Relay.String(): + return parseRelay(u) default: return nil, fmt.Errorf("unsupported protocol: %s", protocol) } @@ -158,6 +162,20 @@ func parseShadowsocks(u *url.URL) (proxy.Proxy, error) { return proxy.NewShadowsocks(address, method, password, obfsMode, obfsHost) } +func parseRelay(u *url.URL) (proxy.Proxy, error) { + address, username := u.Host, u.User.Username() + password, _ := u.User.Password() + + opts := struct { + NoDelay bool + }{} + if err := schema.NewDecoder().Decode(&opts, u.Query()); err != nil { + return nil, err + } + + return proxy.NewRelay(address, username, password, opts.NoDelay) +} + func parseMulticastGroups(s string) (multicastGroups []net.IP, _ error) { ipStrings := strings.Split(s, ",") for _, ipString := range ipStrings { diff --git a/go.mod b/go.mod index 841b11f4..c64869ea 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/go-chi/chi/v5 v5.0.10 github.com/go-chi/cors v1.2.1 github.com/go-chi/render v1.0.3 + github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 github.com/google/uuid v1.3.1 github.com/gorilla/schema v1.2.0 github.com/gorilla/websocket v1.5.0 diff --git a/go.sum b/go.sum index b98ef121..48597dc2 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/go-chi/cors v1.2.1 h1:xEC8UT3Rlp2QuWNEr4Fs/c2EAGVKBwy/1vHx3bppil4= github.com/go-chi/cors v1.2.1/go.mod h1:sSbTewc+6wYHBBCW7ytsFSn836hqM7JxpglAy2Vzc58= github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4= github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0= +github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7 h1:qAG1OyjvdA5h221CfFSS3J359V3d2E7dJWyP29QoDSI= +github.com/go-gost/relay v0.4.1-0.20230916134211-828f314ddfe7/go.mod h1:lcX+23LCQ3khIeASBo+tJ/WbwXFO32/N5YN6ucuYTG8= github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= diff --git a/proxy/proto/proto.go b/proxy/proto/proto.go index 5378a5f9..ef77d053 100644 --- a/proxy/proto/proto.go +++ b/proxy/proto/proto.go @@ -9,6 +9,7 @@ const ( Socks4 Socks5 Shadowsocks + Relay ) type Proto uint8 @@ -27,6 +28,8 @@ func (proto Proto) String() string { return "socks5" case Shadowsocks: return "ss" + case Relay: + return "relay" default: return fmt.Sprintf("proto(%d)", proto) } diff --git a/proxy/relay.go b/proxy/relay.go new file mode 100644 index 00000000..b9858d16 --- /dev/null +++ b/proxy/relay.go @@ -0,0 +1,252 @@ +package proxy + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "math" + "net" + "sync" + + "github.com/go-gost/relay" + + "github.com/xjasonlyu/tun2socks/v2/common/pool" + "github.com/xjasonlyu/tun2socks/v2/dialer" + M "github.com/xjasonlyu/tun2socks/v2/metadata" + "github.com/xjasonlyu/tun2socks/v2/proxy/proto" +) + +var _ Proxy = (*Relay)(nil) + +type Relay struct { + *Base + + user string + pass string + + noDelay bool +} + +func NewRelay(addr, user, pass string, noDelay bool) (*Relay, error) { + return &Relay{ + Base: &Base{ + addr: addr, + proto: proto.Relay, + }, + user: user, + pass: pass, + noDelay: noDelay, + }, nil +} + +func (rl *Relay) DialContext(ctx context.Context, metadata *M.Metadata) (c net.Conn, err error) { + return rl.dialContext(ctx, metadata) +} + +func (rl *Relay) DialUDP(metadata *M.Metadata) (net.PacketConn, error) { + ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout) + defer cancel() + + return rl.dialContext(ctx, metadata) +} + +func (rl *Relay) dialContext(ctx context.Context, metadata *M.Metadata) (rc *relayConn, err error) { + var c net.Conn + + c, err = dialer.DialContext(ctx, "tcp", rl.Addr()) + if err != nil { + return nil, fmt.Errorf("connect to %s: %w", rl.Addr(), err) + } + setKeepAlive(c) + + defer func(c net.Conn) { + safeConnClose(c, err) + }(c) + + req := relay.Request{ + Version: relay.Version1, + Cmd: relay.CmdConnect, + } + + if metadata.Network == M.UDP { + req.Cmd |= relay.FUDP + req.Features = append(req.Features, &relay.NetworkFeature{ + Network: relay.NetworkUDP, + }) + } + + if rl.user != "" { + req.Features = append(req.Features, &relay.UserAuthFeature{ + Username: rl.user, + Password: rl.pass, + }) + } + + req.Features = append(req.Features, serializeRelayAddr(metadata)) + + if rl.noDelay { + if _, err = req.WriteTo(c); err != nil { + return + } + if err = readRelayResponse(c); err != nil { + return + } + } + + switch metadata.Network { + case M.TCP: + rc = newRelayConn(c, metadata.Addr(), rl.noDelay, false) + if !rl.noDelay { + if _, err = req.WriteTo(rc.wbuf); err != nil { + return + } + } + case M.UDP: + rc = newRelayConn(c, metadata.Addr(), rl.noDelay, true) + if !rl.noDelay { + if _, err = req.WriteTo(rc.wbuf); err != nil { + return + } + } + default: + err = fmt.Errorf("network %s is unsupported", metadata.Network) + return + } + + return +} + +type relayConn struct { + net.Conn + udp bool + addr net.Addr + once sync.Once + wbuf *bytes.Buffer +} + +func newRelayConn(c net.Conn, addr net.Addr, noDelay, udp bool) *relayConn { + rc := &relayConn{ + Conn: c, + addr: addr, + udp: udp, + } + if !noDelay { + rc.wbuf = &bytes.Buffer{} + } + return rc +} + +func (rc *relayConn) ReadFrom(b []byte) (int, net.Addr, error) { + n, err := rc.Read(b) + return n, rc.addr, err +} + +func (rc *relayConn) Read(b []byte) (n int, err error) { + rc.once.Do(func() { + if rc.wbuf != nil { + err = readRelayResponse(rc.Conn) + } + }) + if err != nil { + return + } + + if !rc.udp { + return rc.Conn.Read(b) + } + + var bb [2]byte + _, err = io.ReadFull(rc.Conn, bb[:]) + if err != nil { + return + } + + dLen := int(binary.BigEndian.Uint16(bb[:])) + if len(b) >= dLen { + return io.ReadFull(rc.Conn, b[:dLen]) + } + + buf := pool.Get(dLen) + defer pool.Put(buf) + _, err = io.ReadFull(rc.Conn, buf) + n = copy(b, buf) + + return +} + +func (rc *relayConn) WriteTo(b []byte, _ net.Addr) (int, error) { + return rc.Write(b) +} + +func (rc *relayConn) Write(b []byte) (int, error) { + if rc.udp { + return rc.udpWrite(b) + } + return rc.tcpWrite(b) +} + +func (rc *relayConn) tcpWrite(b []byte) (n int, err error) { + if rc.wbuf != nil && rc.wbuf.Len() > 0 { + n = len(b) + rc.wbuf.Write(b) + _, err = rc.Conn.Write(rc.wbuf.Bytes()) + rc.wbuf.Reset() + return + } + return rc.Conn.Write(b) +} + +func (rc *relayConn) udpWrite(b []byte) (n int, err error) { + if len(b) > math.MaxUint16 { + err = errors.New("write: data maximum exceeded") + return + } + + n = len(b) + if rc.wbuf != nil && rc.wbuf.Len() > 0 { + var bb [2]byte + binary.BigEndian.PutUint16(bb[:], uint16(len(b))) + rc.wbuf.Write(bb[:]) + rc.wbuf.Write(b) + _, err = rc.wbuf.WriteTo(rc.Conn) + return + } + + var bb [2]byte + binary.BigEndian.PutUint16(bb[:], uint16(len(b))) + _, err = rc.Conn.Write(bb[:]) + if err != nil { + return + } + return rc.Conn.Write(b) +} + +func readRelayResponse(r io.Reader) error { + resp := relay.Response{} + if _, err := resp.ReadFrom(r); err != nil { + return err + } + if resp.Version != relay.Version1 { + return relay.ErrBadVersion + } + if resp.Status != relay.StatusOK { + return fmt.Errorf("status %d", resp.Status) + } + return nil +} + +func serializeRelayAddr(m *M.Metadata) *relay.AddrFeature { + af := &relay.AddrFeature{ + Host: m.DstIP.String(), + Port: m.DstPort, + } + if m.DstIP.To4() != nil { + af.AType = relay.AddrIPv4 + } else { + af.AType = relay.AddrIPv6 + } + return af +}