From 6c72c3623559138569c2b6991fbfae124e8fbd4e Mon Sep 17 00:00:00 2001 From: Scott Leggett Date: Tue, 15 Mar 2022 22:56:34 +0800 Subject: [PATCH 1/3] feat: use nats.EncodedConn instead of manual JSON marshalling --- cmd/ssh-portal/serve.go | 7 ++++++- internal/sshserver/authhandler.go | 19 ++++++------------- internal/sshserver/serve.go | 2 +- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/cmd/ssh-portal/serve.go b/cmd/ssh-portal/serve.go index bc64ce4f..3caf076f 100644 --- a/cmd/ssh-portal/serve.go +++ b/cmd/ssh-portal/serve.go @@ -35,7 +35,7 @@ func (cmd *ServeCmd) Run(log *zap.Logger) error { ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM) defer stop() // get nats server connection - nc, err := nats.Connect(cmd.NATSServer, + nconn, err := nats.Connect(cmd.NATSServer, // exit on connection close nats.ClosedHandler(func(_ *nats.Conn) { log.Error("nats connection closed") @@ -50,6 +50,11 @@ func (cmd *ServeCmd) Run(log *zap.Logger) error { if err != nil { return fmt.Errorf("couldn't connect to NATS server: %v", err) } + nc, err := nats.NewEncodedConn(nconn, "json") + if err != nil { + return fmt.Errorf("couldn't get encoded conn: %v", err) + } + defer nc.Close() // start listening on TCP port l, err := net.Listen("tcp", fmt.Sprintf(":%d", cmd.SSHServerPort)) if err != nil { diff --git a/internal/sshserver/authhandler.go b/internal/sshserver/authhandler.go index 365b8690..78ec459c 100644 --- a/internal/sshserver/authhandler.go +++ b/internal/sshserver/authhandler.go @@ -1,8 +1,6 @@ package sshserver import ( - "bytes" - "encoding/json" "time" "github.com/gliderlabs/ssh" @@ -32,7 +30,7 @@ var ( // pubKeyAuth returns a ssh.PublicKeyHandler which accepts any key, and simply // adds the given key to the connection context. -func pubKeyAuth(log *zap.Logger, nc *nats.Conn, +func pubKeyAuth(log *zap.Logger, nc *nats.EncodedConn, c *k8s.Client) ssh.PublicKeyHandler { return func(ctx ssh.Context, key ssh.PublicKey) bool { authAttemptsTotal.Inc() @@ -52,23 +50,18 @@ func pubKeyAuth(log *zap.Logger, nc *nats.Conn, zap.String("namespace", ctx.User()), zap.Error(err)) return false } - // construct and marshal ssh access query + // construct ssh access query fingerprint := gossh.FingerprintSHA256(pubKey) - data, err := json.Marshal(&sshportalapi.SSHAccessQuery{ + q := sshportalapi.SSHAccessQuery{ SSHFingerprint: fingerprint, NamespaceName: ctx.User(), ProjectID: pid, EnvironmentID: eid, SessionID: ctx.SessionID(), - }) - if err != nil { - log.Warn("couldn't marshal SSHAccessQuery", - zap.String("session-id", ctx.SessionID()), - zap.Error(err)) - return false } // send query - response, err := nc.Request(sshportalapi.SubjectSSHAccessQuery, data, + var response bool + err = nc.Request(sshportalapi.SubjectSSHAccessQuery, q, &response, natsTimeout) if err != nil { log.Warn("couldn't make NATS request", @@ -77,7 +70,7 @@ func pubKeyAuth(log *zap.Logger, nc *nats.Conn, return false } // handle response - if bytes.Equal(response.Data, []byte("true")) { + if response { authSuccessTotal.Inc() log.Debug("authentication successful", zap.String("session-id", ctx.SessionID()), diff --git a/internal/sshserver/serve.go b/internal/sshserver/serve.go index 7db11e78..88d6c1e4 100644 --- a/internal/sshserver/serve.go +++ b/internal/sshserver/serve.go @@ -14,7 +14,7 @@ import ( ) // Serve contains the main ssh session logic -func Serve(ctx context.Context, log *zap.Logger, nc *nats.Conn, +func Serve(ctx context.Context, log *zap.Logger, nc *nats.EncodedConn, l net.Listener, c *k8s.Client, hostKeys [][]byte) error { srv := ssh.Server{ Handler: sessionHandler(log, c), From 999ab04c77c03ecab874297d2603477bf6ab7488 Mon Sep 17 00:00:00 2001 From: Scott Leggett Date: Tue, 15 Mar 2022 22:57:09 +0800 Subject: [PATCH 2/3] chore: minor formatting changes for consistency --- internal/sshportalapi/server.go | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/internal/sshportalapi/server.go b/internal/sshportalapi/server.go index 04a6150e..609d3ca3 100644 --- a/internal/sshportalapi/server.go +++ b/internal/sshportalapi/server.go @@ -34,7 +34,7 @@ func ServeNATS(ctx context.Context, stop context.CancelFunc, log *zap.Logger, wg := sync.WaitGroup{} wg.Add(1) // connect to NATS server - nc, err := nats.Connect(natsURL, + nconn, err := nats.Connect(natsURL, // synchronise exiting ServeNATS() nats.ClosedHandler(func(_ *nats.Conn) { log.Error("nats connection closed") @@ -50,20 +50,21 @@ func ServeNATS(ctx context.Context, stop context.CancelFunc, log *zap.Logger, if err != nil { return fmt.Errorf("couldn't connect to NATS server: %v", err) } - c, err := nats.NewEncodedConn(nc, "json") + nc, err := nats.NewEncodedConn(nconn, "json") if err != nil { return fmt.Errorf("couldn't get encoded conn: %v", err) } - defer c.Close() + defer nc.Close() // set up request/response callback for sshportal - _, err = c.QueueSubscribe(SubjectSSHAccessQuery, queue, sshportal(ctx, log, c, l, k)) + _, err = nc.QueueSubscribe(SubjectSSHAccessQuery, queue, + sshportal(ctx, log, nc, l, k)) if err != nil { return fmt.Errorf("couldn't subscribe to queue: %v", err) } // wait for context cancellation <-ctx.Done() // drain and log errors - if err := c.Drain(); err != nil { + if err := nc.Drain(); err != nil { log.Warn("couldn't drain connection", zap.Error(err)) } // wait for connection to close From c2096c7f83006f14226dbbfbd808e42279bef902 Mon Sep 17 00:00:00 2001 From: Scott Leggett Date: Wed, 16 Mar 2022 10:45:28 +0800 Subject: [PATCH 3/3] feat: add name to NATS connections This facilitates debugging. --- cmd/ssh-portal/serve.go | 1 + internal/sshportalapi/server.go | 1 + 2 files changed, 2 insertions(+) diff --git a/cmd/ssh-portal/serve.go b/cmd/ssh-portal/serve.go index 3caf076f..ab6af791 100644 --- a/cmd/ssh-portal/serve.go +++ b/cmd/ssh-portal/serve.go @@ -36,6 +36,7 @@ func (cmd *ServeCmd) Run(log *zap.Logger) error { defer stop() // get nats server connection nconn, err := nats.Connect(cmd.NATSServer, + nats.Name("ssh-portal"), // exit on connection close nats.ClosedHandler(func(_ *nats.Conn) { log.Error("nats connection closed") diff --git a/internal/sshportalapi/server.go b/internal/sshportalapi/server.go index 609d3ca3..1f37fcff 100644 --- a/internal/sshportalapi/server.go +++ b/internal/sshportalapi/server.go @@ -35,6 +35,7 @@ func ServeNATS(ctx context.Context, stop context.CancelFunc, log *zap.Logger, wg.Add(1) // connect to NATS server nconn, err := nats.Connect(natsURL, + nats.Name("ssh-portal-api"), // synchronise exiting ServeNATS() nats.ClosedHandler(func(_ *nats.Conn) { log.Error("nats connection closed")