Skip to content

Commit

Permalink
Merge pull request #470 from uselagoon/nats-update
Browse files Browse the repository at this point in the history
feat: avoid deprecated NATS API
  • Loading branch information
smlx authored Sep 27, 2024
2 parents 11395a4 + bb058b4 commit f5c8c3f
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 27 deletions.
16 changes: 10 additions & 6 deletions cmd/ssh-portal/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error {
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM)
defer stop()
// get nats server connection
nconn, err := nats.Connect(cmd.NATSServer,
nc, err := nats.Connect(cmd.NATSServer,
nats.Name("ssh-portal"),
// exit on connection close
nats.ClosedHandler(func(_ *nats.Conn) {
Expand All @@ -52,10 +52,6 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error {
if err != nil {
return fmt.Errorf("couldn't connect to NATS server: %v", err)
}
nc, err := nats.NewEncodedConn(nconn, "json")
if err != nil {
return fmt.Errorf("couldn't get encoded conn: %v", err)
}
defer nc.Close()
// start listening on TCP port
l, err := net.Listen("tcp", fmt.Sprintf(":%d", cmd.SSHServerPort))
Expand Down Expand Up @@ -83,7 +79,15 @@ func (cmd *ServeCmd) Run(log *slog.Logger) error {
eg.Go(func() error {
// start serving SSH connection requests
return sshserver.Serve(
ctx, log, nc, l, c, hostkeys, cmd.LogAccessEnabled, cmd.Banner)
ctx,
log,
nc,
l,
c,
hostkeys,
cmd.LogAccessEnabled,
cmd.Banner,
)
})
return eg.Wait()
}
15 changes: 7 additions & 8 deletions internal/sshportalapi/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func ServeNATS(
wg := sync.WaitGroup{}
wg.Add(1)
// connect to NATS server
nconn, err := nats.Connect(natsURL,
nc, err := nats.Connect(natsURL,
nats.Name("ssh-portal-api"),
// synchronise exiting ServeNATS()
nats.ClosedHandler(func(_ *nats.Conn) {
Expand All @@ -67,14 +67,13 @@ func ServeNATS(
if err != nil {
return fmt.Errorf("couldn't connect to NATS server: %v", err)
}
nc, err := nats.NewEncodedConn(nconn, "json")
if err != nil {
return fmt.Errorf("couldn't get encoded conn: %v", err)
}
defer nc.Close()
// set up request/response callback for sshportal
_, err = nc.QueueSubscribe(bus.SubjectSSHAccessQuery, queue,
sshportal(ctx, log, nc, p, l, k))
// configure callback
_, err = nc.QueueSubscribe(
bus.SubjectSSHAccessQuery,
queue,
sshportal(ctx, log, nc, p, l, k),
)
if err != nil {
return fmt.Errorf("couldn't subscribe to queue: %v", err)
}
Expand Down
28 changes: 21 additions & 7 deletions internal/sshportalapi/sshportal.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sshportalapi

import (
"context"
"encoding/json"
"errors"
"log/slog"
"time"
Expand All @@ -23,20 +24,30 @@ var (
})
)

var (
falseResponse = []byte(`false`)
trueResponse = []byte(`true`)
)

func sshportal(
ctx context.Context,
log *slog.Logger,
c *nats.EncodedConn,
c *nats.Conn,
p *rbac.Permission,
l LagoonDBService,
k KeycloakService,
) nats.Handler {
return func(_, replySubject string, query *bus.SSHAccessQuery) {
) nats.MsgHandler {
return func(msg *nats.Msg) {
var realmRoles, userGroups []string
// set up tracing and update metrics
ctx, span := otel.Tracer(pkgName).Start(ctx, bus.SubjectSSHAccessQuery)
defer span.End()
requestsCounter.Inc()
var query bus.SSHAccessQuery
if err := json.Unmarshal(msg.Data, &query); err != nil {
log.Warn("couldn't unmarshal query", slog.Any("query", msg.Data))
return
}
log := log.With(slog.Any("query", query))
// sanity check the query
if query.SSHFingerprint == "" || query.NamespaceName == "" {
Expand All @@ -48,7 +59,7 @@ func sshportal(
if err != nil {
if errors.Is(err, lagoondb.ErrNoResult) {
log.Warn("unknown namespace name", slog.Any("error", err))
if err = c.Publish(replySubject, false); err != nil {
if err = c.Publish(msg.Reply, falseResponse); err != nil {
log.Error("couldn't publish reply", slog.Any("error", err))
}
return
Expand All @@ -65,7 +76,7 @@ func sshportal(
log.Warn("ID mismatch in environment identification",
slog.Any("env", env),
slog.Any("error", err))
if err = c.Publish(replySubject, false); err != nil {
if err = c.Publish(msg.Reply, falseResponse); err != nil {
log.Error("couldn't publish reply", slog.Any("error", err))
}
return
Expand All @@ -75,7 +86,7 @@ func sshportal(
if err != nil {
if errors.Is(err, lagoondb.ErrNoResult) {
log.Debug("unknown SSH Fingerprint", slog.Any("error", err))
if err = c.Publish(replySubject, false); err != nil {
if err = c.Publish(msg.Reply, falseResponse); err != nil {
log.Error("couldn't publish reply", slog.Any("error", err))
}
return
Expand Down Expand Up @@ -115,10 +126,13 @@ func sshportal(
ok := p.UserCanSSHToEnvironment(
ctx, env, realmRoles, userGroups, groupNameProjectIDsMap)
var logMsg string
var response []byte
if ok {
logMsg = "SSH access authorized"
response = trueResponse
} else {
logMsg = "SSH access not authorized"
response = falseResponse
}
log.Info(logMsg,
slog.Int("environmentID", env.ID),
Expand All @@ -127,7 +141,7 @@ func sshportal(
slog.String("projectName", env.ProjectName),
slog.String("userUUID", user.UUID.String()),
)
if err = c.Publish(replySubject, ok); err != nil {
if err = c.Publish(msg.Reply, response); err != nil {
log.Error("couldn't publish reply",
slog.String("userUUID", user.UUID.String()),
slog.Any("error", err))
Expand Down
27 changes: 27 additions & 0 deletions internal/sshportalapi/sshportal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package sshportalapi

import (
"encoding/json"
"testing"
)

func TestResponseMarshal(t *testing.T) {
var testCases = map[string]struct {
input []byte
expect bool
}{
"true": {input: trueResponse, expect: true},
"false": {input: falseResponse, expect: false},
}
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
var value bool
if err := json.Unmarshal(tc.input, &value); err != nil {
tt.Fatalf("error unmarshaling data %v to bool", tc.input)
}
if value != tc.expect {
tt.Fatalf("expected %v, got %v", tc.expect, value)
}
})
}
}
22 changes: 17 additions & 5 deletions internal/sshserver/authhandler.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package sshserver

import (
"encoding/json"
"log/slog"
"time"

Expand Down Expand Up @@ -40,8 +41,11 @@ var (

// pubKeyAuth returns a ssh.PublicKeyHandler which queries the remote
// ssh-portal-api for Lagoon SSH authorization.
func pubKeyAuth(log *slog.Logger, nc *nats.EncodedConn,
c *k8s.Client) ssh.PublicKeyHandler {
func pubKeyAuth(
log *slog.Logger,
nc *nats.Conn,
c *k8s.Client,
) ssh.PublicKeyHandler {
return func(ctx ssh.Context, key ssh.PublicKey) bool {
authAttemptsTotal.Inc()
log := log.With(slog.String("sessionID", ctx.SessionID()))
Expand All @@ -60,21 +64,29 @@ func pubKeyAuth(log *slog.Logger, nc *nats.EncodedConn,
}
// construct ssh access query
fingerprint := gossh.FingerprintSHA256(pubKey)
q := bus.SSHAccessQuery{
queryData, err := json.Marshal(bus.SSHAccessQuery{
SSHFingerprint: fingerprint,
NamespaceName: ctx.User(),
ProjectID: pid,
EnvironmentID: eid,
SessionID: ctx.SessionID(),
})
if err != nil {
log.Warn("couldn't marshal NATS request", slog.Any("error", err))
return false
}
// send query
var ok bool
err = nc.Request(bus.SubjectSSHAccessQuery, q, &ok, natsTimeout)
msg, err := nc.Request(bus.SubjectSSHAccessQuery, queryData, natsTimeout)
if err != nil {
log.Warn("couldn't make NATS request", slog.Any("error", err))
return false
}
// handle response
var ok bool
if err := json.Unmarshal(msg.Data, &ok); err != nil {
log.Warn("couldn't unmarshal response", slog.Any("response", msg.Data))
return false
}
if !ok {
log.Debug("SSH access not authorized",
slog.String("fingerprint", fingerprint),
Expand Down
2 changes: 1 addition & 1 deletion internal/sshserver/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func disableSHA1Kex(_ ssh.Context) *gossh.ServerConfig {
func Serve(
ctx context.Context,
log *slog.Logger,
nc *nats.EncodedConn,
nc *nats.Conn,
l net.Listener,
c *k8s.Client,
hostKeys [][]byte,
Expand Down

0 comments on commit f5c8c3f

Please sign in to comment.