Skip to content

Commit

Permalink
Merge pull request #414 from uselagoon/tests
Browse files Browse the repository at this point in the history
 Add more tests to improve coverage
  • Loading branch information
smlx authored Mar 12, 2024
2 parents e33bcd6 + 357c9a2 commit e4c632a
Show file tree
Hide file tree
Showing 10 changed files with 308 additions and 23 deletions.
6 changes: 5 additions & 1 deletion internal/k8s/exec.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ import (
"k8s.io/client-go/tools/remotecommand"
)

const (
idleAnnotation = "idling.amazee.io/unidle-replicas"
)

// podContainer returns the first pod and first container inside that pod for
// the given namespace and deployment.
func (c *Client) podContainer(ctx context.Context, namespace,
Expand Down Expand Up @@ -68,7 +72,7 @@ func (c *Client) hasRunningPod(ctx context.Context,
// replicas to restore. If the label cannot be read or parsed, 1 is returned.
// The return value is clamped to the interval [1,16].
func unidleReplicas(deploy appsv1.Deployment) int {
rs, ok := deploy.Annotations["idling.amazee.io/unidle-replicas"]
rs, ok := deploy.Annotations[idleAnnotation]
if !ok {
return 1
}
Expand Down
37 changes: 37 additions & 0 deletions internal/k8s/exec_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package k8s

import (
"testing"

"github.com/alecthomas/assert/v2"
appsv1 "k8s.io/api/apps/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

func TestUnidleReplicas(t *testing.T) {
var testCases = map[string]struct {
input string
expect int
}{
"simple": {input: "4", expect: 4},
"high edge": {input: "16", expect: 16},
"low edge": {input: "1", expect: 1},
"zero": {input: "0", expect: 1},
"too high": {input: "17", expect: 16},
"way too high": {input: "17000000", expect: 16},
"overflow too high": {input: "9223372036854775808", expect: 1},
"too low": {input: "-1", expect: 1},
"way too low": {input: "-17000000", expect: 1},
"overflow too low": {input: "-9223372036854775808", expect: 1},
}
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
deploy := appsv1.Deployment{
ObjectMeta: metav1.ObjectMeta{
Annotations: map[string]string{idleAnnotation: tc.input},
},
}
assert.Equal(tt, tc.expect, unidleReplicas(deploy), name)
})
}
}
46 changes: 46 additions & 0 deletions internal/k8s/logs_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package k8s

import (
"context"
"io"
"strings"
"testing"
"time"

"github.com/alecthomas/assert/v2"
)

func TestLinewiseCopy(t *testing.T) {
var testCases = map[string]struct {
input string
expect []string
prefix string
}{
"logs": {
input: "foo\nbar\nbaz\n",
expect: []string{"test: foo", "test: bar", "test: baz"},
prefix: "test:",
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
out := make(chan string, 1)
in := io.NopCloser(strings.NewReader(tc.input))
go linewiseCopy(ctx, tc.prefix, out, in)
timer := time.NewTimer(500 * time.Millisecond)
var lines []string
loop:
for {
select {
case <-timer.C:
break loop
case line := <-out:
lines = append(lines, line)
}
}
assert.Equal(tt, tc.expect, lines, name)
})
}
}
41 changes: 41 additions & 0 deletions internal/k8s/namespacedetails_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package k8s

import (
"testing"

"github.com/alecthomas/assert/v2"
)

func TestIntFromLabel(t *testing.T) {
labels := map[string]string{
"foo": "1",
"bar": "hello",
"baz": "true",
"negative": "-1",
"max": "9223372036854775807",
"overflow": "9223372036854775808",
}
var testCases = map[string]struct {
target string
expect int
expectErr bool
}{
"foo": {target: "foo", expect: 1},
"bar": {target: "bar", expectErr: true},
"baz": {target: "baz", expectErr: true},
"negative": {target: "negative", expect: -1},
"max": {target: "max", expect: 9223372036854775807},
"overflow": {target: "overflow", expectErr: true},
}
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
result, err := intFromLabel(labels, tc.target)
if tc.expectErr {
assert.Error(tt, err, name)
} else {
assert.NoError(tt, err, name)
assert.Equal(tt, tc.expect, result, name)
}
})
}
}
37 changes: 37 additions & 0 deletions internal/k8s/spin_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package k8s

import (
"context"
"strings"
"testing"
"time"

"github.com/alecthomas/assert/v2"
)

func TestSpinAfter(t *testing.T) {
wait := 500 * time.Millisecond
var testCases = map[string]struct {
connectTime time.Duration
expectSpinner bool
}{
"spinner": {connectTime: 600 * time.Millisecond, expectSpinner: true},
"no spinner": {connectTime: 400 * time.Millisecond, expectSpinner: false},
}
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
var buf strings.Builder
// start the spinner with a given connect time
ctx, cancel := context.WithTimeout(context.Background(), tc.connectTime)
wg := spinAfter(ctx, &buf, wait)
wg.Wait()
cancel()
// check if the builder has spinner animations
if tc.expectSpinner {
assert.NotZero(tt, buf.Len(), name)
} else {
assert.Zero(tt, buf.Len(), name)
}
})
}
}
39 changes: 39 additions & 0 deletions internal/k8s/termsizequeue_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package k8s

import (
"context"
"testing"

"github.com/alecthomas/assert/v2"
"github.com/gliderlabs/ssh"
"k8s.io/client-go/tools/remotecommand"
)

func TestTermSizeQueue(t *testing.T) {
var testCases = map[string]struct {
input ssh.Window
expect remotecommand.TerminalSize
}{
"term size change": {
input: ssh.Window{
Width: 100,
Height: 200,
},
expect: remotecommand.TerminalSize{
Width: 100,
Height: 200,
},
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
in := make(chan ssh.Window, 1)
tsq := newTermSizeQueue(ctx, in)
in <- tc.input
output := tsq.Next()
assert.Equal(tt, tc.expect, *output, name)
})
}
}
27 changes: 27 additions & 0 deletions internal/k8s/validate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package k8s_test

import (
"testing"

"github.com/alecthomas/assert/v2"
"github.com/uselagoon/ssh-portal/internal/k8s"
)

func TestValidateLabelValues(t *testing.T) {
var testCases = map[string]struct {
input string
expectError bool
}{
"valid": {input: "foo", expectError: false},
"invalid": {input: "naïve", expectError: true},
}
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
if tc.expectError {
assert.Error(tt, k8s.ValidateLabelValue(tc.input), name)
} else {
assert.NoError(tt, k8s.ValidateLabelValue(tc.input), name)
}
})
}
}
9 changes: 7 additions & 2 deletions internal/rbac/usercansshtoenvironment.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,13 @@ var defaultEnvTypeRoleCanSSH = map[lagoon.EnvironmentType][]lagoon.UserRole{
// UserCanSSHToEnvironment returns true if the given environment can be
// connected to via SSH by the user with the given realm roles and user groups,
// and false otherwise.
func (p *Permission) UserCanSSHToEnvironment(ctx context.Context, env *lagoondb.Environment,
realmRoles, userGroups []string, groupProjectIDs map[string][]int) bool {
func (p *Permission) UserCanSSHToEnvironment(
ctx context.Context,
env *lagoondb.Environment,
realmRoles,
userGroups []string,
groupProjectIDs map[string][]int,
) bool {
// set up tracing
_, span := otel.Tracer(pkgName).Start(ctx, "UserCanSSHToEnvironment")
defer span.End()
Expand Down
24 changes: 24 additions & 0 deletions internal/sshserver/serve_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package sshserver

import (
"slices"
"testing"

"github.com/alecthomas/assert/v2"
)

func TestDisableSHA1Kex(t *testing.T) {
var testCases = map[string]struct {
input string
expect bool
}{
"no sha1": {input: "diffie-hellman-group14-sha1", expect: false},
}
for name, tc := range testCases {
t.Run(name, func(tt *testing.T) {
conf := disableSHA1Kex(nil)
assert.Equal(tt, tc.expect,
slices.Contains(conf.Config.KeyExchanges, tc.input), name)
})
}
}
65 changes: 45 additions & 20 deletions internal/sshserver/sessionhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,42 @@ var (
})
)

// authCtxValues extracts the context values set by the authhandler.
func authCtxValues(ctx ssh.Context) (int, string, int, string, string, error) {
var ok bool
var eid, pid int
var ename, pname, fingerprint string
eid, ok = ctx.Value(environmentIDKey).(int)
if !ok {
return eid, ename, pid, pname, fingerprint,
fmt.Errorf("couldn't extract environment ID from session context")
}
ename, ok = ctx.Value(environmentNameKey).(string)
if !ok {
return eid, ename, pid, pname, fingerprint,
fmt.Errorf("couldn't extract environment name from session context")
}
pid, ok = ctx.Value(projectIDKey).(int)
if !ok {
return eid, ename, pid, pname, fingerprint,
fmt.Errorf("couldn't extract project ID from session context")
}
pname, ok = ctx.Value(projectNameKey).(string)
if !ok {
return eid, ename, pid, pname, fingerprint,
fmt.Errorf("couldn't extract project name from session context")
}
fingerprint, ok = ctx.Value(sshFingerprint).(string)
if !ok {
return eid, ename, pid, pname, fingerprint,
fmt.Errorf("couldn't extract SSH key fingerprint from session context")
}
return eid, ename, pid, pname, fingerprint, nil
}

// getSSHIntent analyses the SFTP flag and the raw command strings to determine
// if the command should be wrapped.
// if the command should be wrapped, and returns the given cmd wrapped
// appropriately.
func getSSHIntent(sftp bool, cmd []string) []string {
// if this is an sftp session we ignore any commands
if sftp {
Expand Down Expand Up @@ -104,25 +138,16 @@ func sessionHandler(log *slog.Logger, c K8SAPIService,
return
}
// extract info passed through the context by the authhandler
eid, ok := ctx.Value(environmentIDKey).(int)
if !ok {
log.Warn("couldn't extract environment ID from session context")
}
ename, ok := ctx.Value(environmentNameKey).(string)
if !ok {
log.Warn("couldn't extract environment name from session context")
}
pid, ok := ctx.Value(projectIDKey).(int)
if !ok {
log.Warn("couldn't extract project ID from session context")
}
pname, ok := ctx.Value(projectNameKey).(string)
if !ok {
log.Warn("couldn't extract project name from session context")
}
fingerprint, ok := ctx.Value(sshFingerprint).(string)
if !ok {
log.Warn("couldn't extract SSH key fingerprint from session context")
eid, ename, pid, pname, fingerprint, err := authCtxValues(ctx)
if err != nil {
log.Error("couldn't extract auth values from context",
slog.Any("error", err))
_, err = fmt.Fprintf(s.Stderr(), "error executing command. SID: %s\r\n",
ctx.SessionID())
if err != nil {
log.Debug("couldn't write to session stream", slog.Any("error", err))
}
return
}
if len(logs) != 0 {
if !logAccessEnabled {
Expand Down

0 comments on commit e4c632a

Please sign in to comment.