Skip to content

Mutate and Validate RayCluster on SecurityContext #574

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions pkg/controllers/raycluster_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package controllers

import (
"context"
"reflect"
"strconv"

rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
Expand Down Expand Up @@ -123,6 +124,16 @@ func (w *rayClusterWebhook) Default(ctx context.Context, obj runtime.Object) err
}
}

// Set the security context for the head container and worker containers
for i := range rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers {
rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[i].SecurityContext = securityContext()
}
for i := range rayCluster.Spec.WorkerGroupSpecs {
for j := range rayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers {
rayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers[j].SecurityContext = securityContext()
}
}

return nil
}

Expand All @@ -133,6 +144,7 @@ func (w *rayClusterWebhook) ValidateCreate(ctx context.Context, obj runtime.Obje
var allErrors field.ErrorList

allErrors = append(allErrors, validateIngress(rayCluster)...)
allErrors = append(allErrors, validateSecurityContext(rayCluster)...)

if ptr.Deref(w.Config.RayDashboardOAuthEnabled, true) {
allErrors = append(allErrors, validateOAuthProxyContainer(rayCluster)...)
Expand All @@ -155,6 +167,7 @@ func (w *rayClusterWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj r
}

allErrors = append(allErrors, validateIngress(rayCluster)...)
allErrors = append(allErrors, validateSecurityContext(rayCluster)...)

if ptr.Deref(w.Config.RayDashboardOAuthEnabled, true) {
allErrors = append(allErrors, validateOAuthProxyContainer(rayCluster)...)
Expand Down Expand Up @@ -202,6 +215,32 @@ func validateOAuthProxyVolume(rayCluster *rayv1.RayCluster) field.ErrorList {
return allErrors
}

// validateSecurityContext verifies that every container in the head group and
// in each worker group carries exactly the SecurityContext produced by
// securityContext(). One field.Invalid error is returned per container whose
// SecurityContext differs; a nil list means every container matched.
func validateSecurityContext(rayCluster *rayv1.RayCluster) field.ErrorList {
	var allErrors field.ErrorList

	// The required context is the same for every container, so build it once.
	required := securityContext()

	// Head group containers.
	headContainers := rayCluster.Spec.HeadGroupSpec.Template.Spec.Containers
	for c := range headContainers {
		actual := headContainers[c].SecurityContext
		if reflect.DeepEqual(actual, required) {
			continue
		}
		path := field.NewPath("spec", "headGroupSpec", "template", "spec", "containers", strconv.Itoa(c), "securityContext")
		allErrors = append(allErrors, field.Invalid(path, actual, "SecurityContext is immutable"))
	}

	// Worker group containers, group by group.
	for g := range rayCluster.Spec.WorkerGroupSpecs {
		workerContainers := rayCluster.Spec.WorkerGroupSpecs[g].Template.Spec.Containers
		for c := range workerContainers {
			actual := workerContainers[c].SecurityContext
			if reflect.DeepEqual(actual, required) {
				continue
			}
			path := field.NewPath("spec", "workerGroupSpecs", strconv.Itoa(g), "template", "spec", "containers", strconv.Itoa(c), "securityContext")
			allErrors = append(allErrors, field.Invalid(path, actual, "SecurityContext is immutable"))
		}
	}

	return allErrors
}

func validateIngress(rayCluster *rayv1.RayCluster) field.ErrorList {
var allErrors field.ErrorList

Expand Down Expand Up @@ -268,6 +307,18 @@ func oauthProxyContainer(rayCluster *rayv1.RayCluster) corev1.Container {
}
}

// securityContext returns a newly allocated SecurityContext that disallows
// privilege escalation, drops the "ALL" capability set, and selects the
// "RuntimeDefault" seccomp profile. Each call yields an independent value, so
// callers may assign the result directly to a container.
func securityContext() *corev1.SecurityContext {
	droppedCaps := &corev1.Capabilities{Drop: []corev1.Capability{"ALL"}}
	seccomp := &corev1.SeccompProfile{Type: "RuntimeDefault"}

	sc := new(corev1.SecurityContext)
	sc.AllowPrivilegeEscalation = ptr.To(false)
	sc.Capabilities = droppedCaps
	sc.SeccompProfile = seccomp
	return sc
}

func oauthProxyTLSSecretVolume(rayCluster *rayv1.RayCluster) corev1.Volume {
return corev1.Volume{
Name: oauthProxyVolumeName,
Expand Down
70 changes: 70 additions & 0 deletions pkg/controllers/raycluster_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/utils/ptr"

"github.com/project-codeflare/codeflare-operator/pkg/config"
)
Expand Down Expand Up @@ -226,6 +227,22 @@ func TestRayClusterWebhookDefault(t *testing.T) {
}
})

t.Run("Expected required SecurityContext for each head group container", func(t *testing.T) {
for _, container := range validRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers {
test.Expect(container.SecurityContext).To(Equal(securityContext()),
"Expected the required SecurityContext to be present in each head group container")
}
})

t.Run("Expected required SecurityContext for each worker group container", func(t *testing.T) {
for _, workerGroup := range validRayCluster.Spec.WorkerGroupSpecs {
for _, container := range workerGroup.Template.Spec.Containers {
test.Expect(container.SecurityContext).To(Equal(securityContext()),
"Expected the required SecurityContext to be present in each worker group container")
}
}
})

}

func TestValidateCreate(t *testing.T) {
Expand Down Expand Up @@ -277,6 +294,15 @@ func TestValidateCreate(t *testing.T) {
ReadOnly: true,
},
},
SecurityContext: &corev1.SecurityContext{
AllowPrivilegeEscalation: ptr.To(false),
Capabilities: &corev1.Capabilities{
Drop: []corev1.Capability{"ALL"},
},
SeccompProfile: &corev1.SeccompProfile{
Type: "RuntimeDefault",
},
},
},
},
Volumes: []corev1.Volume{
Expand Down Expand Up @@ -346,6 +372,14 @@ func TestValidateCreate(t *testing.T) {
test.Expect(err).Should(HaveOccurred(), "Expected errors on call to ValidateCreate function due to manipulated head group service account name")
})

t.Run("Negative: Expected errors on call to ValidateCreate function due to manipulated head group container SecurityContext", func(t *testing.T) {
for i := range invalidRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers {
invalidRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[i].SecurityContext.AllowPrivilegeEscalation = ptr.To(true)
}
_, err = rcWebhook.ValidateCreate(test.Ctx(), runtime.Object(invalidRayCluster))
test.Expect(err).Should(HaveOccurred(), "Expected errors on call to ValidateCreate function due to manipulated head group container SecurityContext")
})

}

func TestValidateUpdate(t *testing.T) {
Expand Down Expand Up @@ -409,6 +443,15 @@ func TestValidateUpdate(t *testing.T) {
ReadOnly: true,
},
},
SecurityContext: &corev1.SecurityContext{
AllowPrivilegeEscalation: ptr.To(false),
Capabilities: &corev1.Capabilities{
Drop: []corev1.Capability{"ALL"},
},
SeccompProfile: &corev1.SeccompProfile{
Type: "RuntimeDefault",
},
},
},
},
InitContainers: []corev1.Container{
Expand Down Expand Up @@ -485,6 +528,15 @@ func TestValidateUpdate(t *testing.T) {
{Name: "RAY_TLS_SERVER_KEY", Value: "/home/ray/workspace/tls/server.key"},
{Name: "RAY_TLS_CA_CERT", Value: "/home/ray/workspace/tls/ca.crt"},
},
SecurityContext: &corev1.SecurityContext{
AllowPrivilegeEscalation: ptr.To(false),
Capabilities: &corev1.Capabilities{
Drop: []corev1.Capability{"ALL"},
},
SeccompProfile: &corev1.SeccompProfile{
Type: "RuntimeDefault",
},
},
},
},
InitContainers: []corev1.Container{
Expand Down Expand Up @@ -644,4 +696,22 @@ func TestValidateUpdate(t *testing.T) {
_, err := rcWebhook.ValidateUpdate(test.Ctx(), runtime.Object(validRayCluster), runtime.Object(invalidRayCluster))
test.Expect(err).Should(HaveOccurred(), "Expected errors on call to ValidateUpdate function due to manipulated env vars in the worker group")
})

t.Run("Negative: Expected errors on call to ValidateUpdate function due to manipulated SecurityContext in the head group container", func(t *testing.T) {
for i := range invalidRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers {
invalidRayCluster.Spec.HeadGroupSpec.Template.Spec.Containers[i].SecurityContext.AllowPrivilegeEscalation = ptr.To(true)
}
_, err := rcWebhook.ValidateUpdate(test.Ctx(), runtime.Object(validRayCluster), runtime.Object(invalidRayCluster))
test.Expect(err).Should(HaveOccurred(), "Expected errors on call to ValidateUpdate function due to manipulated SecurityContext in the head group container")
})

t.Run("Negative: Expected errors on call to ValidateUpdate function due to manipulated SecurityContext in the worker group container", func(t *testing.T) {
for i := range invalidRayCluster.Spec.WorkerGroupSpecs {
for j := range invalidRayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers {
invalidRayCluster.Spec.WorkerGroupSpecs[i].Template.Spec.Containers[j].SecurityContext.AllowPrivilegeEscalation = ptr.To(true)
}
}
_, err := rcWebhook.ValidateUpdate(test.Ctx(), runtime.Object(validRayCluster), runtime.Object(invalidRayCluster))
test.Expect(err).Should(HaveOccurred(), "Expected errors on call to ValidateUpdate function due to manipulated SecurityContext in the worker group container")
})
}