diff --git a/Dockerfile b/Dockerfile index 44e739229a..6d4e40d255 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,7 +15,7 @@ FROM golang:1.12.7-stretch as builder WORKDIR /go/src/github.com/kubernetes-sigs/aws-ebs-csi-driver ADD . . -RUN make +RUN make FROM amazonlinux:2 RUN yum install ca-certificates e2fsprogs xfsprogs util-linux -y diff --git a/Makefile b/Makefile index 78c4e62e82..85c76aea25 100644 --- a/Makefile +++ b/Makefile @@ -35,7 +35,7 @@ verify: .PHONY: test test: - go test -v -race ./pkg/... + go test -v -race ./cmd/... ./pkg/... .PHONY: test-sanity test-sanity: diff --git a/aws-ebs-csi-driver/templates/manifest.yaml b/aws-ebs-csi-driver/templates/manifest.yaml index 70726122c0..f77ae1ac5d 100644 --- a/aws-ebs-csi-driver/templates/manifest.yaml +++ b/aws-ebs-csi-driver/templates/manifest.yaml @@ -222,10 +222,11 @@ spec: - name: ebs-plugin image: "{{ .Values.image.repository }}:{{ .Values.image.tag }}" args: + - controller - --endpoint=$(CSI_ENDPOINT) + {{ include "aws-ebs-csi-driver.extra-volume-tags" . }} - --logtostderr - --v=5 - {{ include "aws-ebs-csi-driver.extra-volume-tags" . }} env: - name: CSI_ENDPOINT value: unix:///var/lib/csi/sockets/pluginproxy/csi.sock @@ -241,6 +242,10 @@ spec: name: aws-secret key: access_key optional: true + {{- if .Values.region }} + - name: AWS_REGION + value: {{ .Values.region }} + {{- end }} volumeMounts: - name: socket-dir mountPath: /var/lib/csi/sockets/pluginproxy/ @@ -353,6 +358,7 @@ spec: privileged: true image: "{{ .Values.image.repository }}:{{ .Values.image.tag }}" args: + - node - --endpoint=$(CSI_ENDPOINT) - --logtostderr - --v=5 diff --git a/aws-ebs-csi-driver/values.yaml b/aws-ebs-csi-driver/values.yaml index 608eb3f2fb..3f06b6c768 100644 --- a/aws-ebs-csi-driver/values.yaml +++ b/aws-ebs-csi-driver/values.yaml @@ -68,3 +68,9 @@ affinity: {} # key1: value1 # key2: value2 extraVolumeTags: {} + +# AWS region to use. If not specified then the region will be looked up via the AWS EC2 metadata +# service. +# --- +# region: us-east-1 +region: "" diff --git a/cmd/main.go b/cmd/main.go index e2e096a465..6bf98fba9c 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -18,40 +18,20 @@ package main import ( "flag" - "fmt" - "os" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver" - cliflag "k8s.io/component-base/cli/flag" + "k8s.io/klog" ) func main() { - var ( - version bool - endpoint string - extraVolumeTags map[string]string - ) - - flag.BoolVar(&version, "version", false, "Print the version and exit.") - flag.StringVar(&endpoint, "endpoint", driver.DefaultCSIEndpoint, "CSI Endpoint") - flag.Var(cliflag.NewMapStringString(&extraVolumeTags), "extra-volume-tags", "Extra volume tags to attach to each dynamically provisioned volume. It is a comma separated list of key value pairs like '=,='") - - klog.InitFlags(nil) - flag.Parse() - - if version { - info, err := driver.GetVersionJSON() - if err != nil { - klog.Fatalln(err) - } - fmt.Println(info) - os.Exit(0) - } + fs := flag.NewFlagSet("aws-ebs-csi-driver", flag.ExitOnError) + options := GetOptions(fs) drv, err := driver.NewDriver( - driver.WithEndpoint(endpoint), - driver.WithExtraVolumeTags(extraVolumeTags), + driver.WithEndpoint(options.ServerOptions.Endpoint), + driver.WithExtraVolumeTags(options.ControllerOptions.ExtraVolumeTags), + driver.WithMode(options.DriverMode), ) if err != nil { klog.Fatalln(err) diff --git a/cmd/options.go b/cmd/options.go new file mode 100644 index 0000000000..5431fea6fe --- /dev/null +++ b/cmd/options.go @@ -0,0 +1,110 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "flag" + "fmt" + "os" + "strings" + + "github.com/kubernetes-sigs/aws-ebs-csi-driver/cmd/options" + "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver" + + "k8s.io/klog" +) + +// Options is the combined set of options for all operating modes. +type Options struct { + DriverMode driver.Mode + + *options.ServerOptions + *options.ControllerOptions + *options.NodeOptions +} + +// used for testing +var osExit = os.Exit + +// GetOptions parses the command line options and returns a struct that contains +// the parsed options. +func GetOptions(fs *flag.FlagSet) *Options { + var ( + version = fs.Bool("version", false, "Print the version and exit.") + + args = os.Args[1:] + mode = driver.AllMode + + serverOptions = options.ServerOptions{} + controllerOptions = options.ControllerOptions{} + nodeOptions = options.NodeOptions{} + ) + + serverOptions.AddFlags(fs) + klog.InitFlags(fs) + + if len(os.Args) > 1 { + cmd := os.Args[1] + + switch { + case cmd == string(driver.ControllerMode): + controllerOptions.AddFlags(fs) + args = os.Args[2:] + mode = driver.ControllerMode + + case cmd == string(driver.NodeMode): + nodeOptions.AddFlags(fs) + args = os.Args[2:] + mode = driver.NodeMode + + case cmd == string(driver.AllMode): + controllerOptions.AddFlags(fs) + nodeOptions.AddFlags(fs) + args = os.Args[2:] + + case strings.HasPrefix(cmd, "-"): + controllerOptions.AddFlags(fs) + nodeOptions.AddFlags(fs) + args = os.Args[1:] + + default: + fmt.Printf("unknown command: %s: expected %q, %q or %q", cmd, driver.ControllerMode, driver.NodeMode, driver.AllMode) + os.Exit(1) + } + } + + if err := fs.Parse(args); err != nil { + panic(err) + } + + if *version { + info, err := driver.GetVersionJSON() + if err != nil { + klog.Fatalln(err) + } + fmt.Println(info) + osExit(0) + } + + return &Options{ + DriverMode: mode, + + ServerOptions: &serverOptions, + ControllerOptions: &controllerOptions, + NodeOptions: &nodeOptions, + } +} diff --git a/cmd/options/controller_options.go b/cmd/options/controller_options.go new file mode 100644 index 0000000000..388c387441 --- /dev/null +++ b/cmd/options/controller_options.go @@ -0,0 +1,34 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package options + +import ( + "flag" + + cliflag "k8s.io/component-base/cli/flag" +) + +// ControllerOptions contains options and configuration settings for the controller service. +type ControllerOptions struct { + // ExtraVolumeTags is a map of tags that will be attached to each dynamically provisioned + // volume. + ExtraVolumeTags map[string]string +} + +func (s *ControllerOptions) AddFlags(fs *flag.FlagSet) { + fs.Var(cliflag.NewMapStringString(&s.ExtraVolumeTags), "extra-volume-tags", "Extra volume tags to attach to each dynamically provisioned volume. It is a comma separated list of key value pairs like '=,='") +} diff --git a/cmd/options/controller_options_test.go b/cmd/options/controller_options_test.go new file mode 100644 index 0000000000..5460b63695 --- /dev/null +++ b/cmd/options/controller_options_test.go @@ -0,0 +1,56 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package options + +import ( + "flag" + "testing" +) + +func TestControllerOptions(t *testing.T) { + testCases := []struct { + name string + flag string + found bool + }{ + { + name: "lookup desired flag", + flag: "extra-volume-tags", + found: true, + }, + { + name: "fail for non-desired flag", + flag: "some-other-flag", + found: false, + }, + } + + for _, tc := range testCases { + flagSet := flag.NewFlagSet("test-flagset", flag.ContinueOnError) + controllerOptions := &ControllerOptions{} + + t.Run(tc.name, func(t *testing.T) { + controllerOptions.AddFlags(flagSet) + + flag := flagSet.Lookup(tc.flag) + found := flag != nil + if found != tc.found { + t.Fatalf("result not equal\ngot:\n%v\nexpected:\n%v", found, tc.found) + } + }) + } +} diff --git a/cmd/options/node_options.go b/cmd/options/node_options.go new file mode 100644 index 0000000000..0f696e0386 --- /dev/null +++ b/cmd/options/node_options.go @@ -0,0 +1,26 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package options + +import ( + "flag" +) + +// NodeOptions contains options and configuration settings for the node service. +type NodeOptions struct{} + +func (s *NodeOptions) AddFlags(fs *flag.FlagSet) {} diff --git a/cmd/options/node_options_test.go b/cmd/options/node_options_test.go new file mode 100644 index 0000000000..483d7cac00 --- /dev/null +++ b/cmd/options/node_options_test.go @@ -0,0 +1,51 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package options + +import ( + "flag" + "testing" +) + +func TestNodeOptions(t *testing.T) { + testCases := []struct { + name string + flag string + found bool + }{ + { + name: "fail for non-desired flag", + flag: "some-flag", + found: false, + }, + } + + for _, tc := range testCases { + flagSet := flag.NewFlagSet("test-flagset", flag.ContinueOnError) + nodeOptions := &NodeOptions{} + + t.Run(tc.name, func(t *testing.T) { + nodeOptions.AddFlags(flagSet) + + flag := flagSet.Lookup(tc.flag) + found := flag != nil + if found != tc.found { + t.Fatalf("result not equal\ngot:\n%v\nexpected:\n%v", found, tc.found) + } + }) + } +} diff --git a/cmd/options/server_options.go b/cmd/options/server_options.go new file mode 100644 index 0000000000..0edaca6de0 --- /dev/null +++ b/cmd/options/server_options.go @@ -0,0 +1,33 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package options + +import ( + "flag" + + "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver" +) + +// ServerOptions contains options and configuration settings for the driver server. +type ServerOptions struct { + // Endpoint is the endpoint that the driver server should listen on. + Endpoint string +} + +func (s *ServerOptions) AddFlags(fs *flag.FlagSet) { + fs.StringVar(&s.Endpoint, "endpoint", driver.DefaultCSIEndpoint, "Endpoint for the CSI driver server") +} diff --git a/cmd/options/server_options_test.go b/cmd/options/server_options_test.go new file mode 100644 index 0000000000..726573f338 --- /dev/null +++ b/cmd/options/server_options_test.go @@ -0,0 +1,56 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package options + +import ( + "flag" + "testing" +) + +func TestServerOptions(t *testing.T) { + testCases := []struct { + name string + flag string + found bool + }{ + { + name: "lookup desired flag", + flag: "endpoint", + found: true, + }, + { + name: "fail for non-desired flag", + flag: "some-other-flag", + found: false, + }, + } + + for _, tc := range testCases { + flagSet := flag.NewFlagSet("test-flagset", flag.ContinueOnError) + serverOptions := &ServerOptions{} + + t.Run(tc.name, func(t *testing.T) { + serverOptions.AddFlags(flagSet) + + flag := flagSet.Lookup(tc.flag) + found := flag != nil + if found != tc.found { + t.Fatalf("result not equal\ngot:\n%v\nexpected:\n%v", found, tc.found) + } + }) + } +} diff --git a/cmd/options_test.go b/cmd/options_test.go new file mode 100644 index 0000000000..cd58e19def --- /dev/null +++ b/cmd/options_test.go @@ -0,0 +1,164 @@ +/* +Copyright 2020 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package main + +import ( + "flag" + "os" + "reflect" + "testing" + + "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver" +) + +func TestGetOptions(t *testing.T) { + testFunc := func( + t *testing.T, + additionalArgs []string, + withServerOptions bool, + withControllerOptions bool, + withNodeOptions bool, + ) *Options { + flagSet := flag.NewFlagSet("test-flagset", flag.ContinueOnError) + + endpointFlagName := "endpoint" + endpoint := "foo" + + extraVolumeTagsFlagName := "extra-volume-tags" + extraVolumeTagKey := "bar" + extraVolumeTagValue := "baz" + extraVolumeTags := map[string]string{ + extraVolumeTagKey: extraVolumeTagValue, + } + + args := append([]string{ + "aws-ebs-csi-driver", + }, additionalArgs...) + + if withServerOptions { + args = append(args, "-"+endpointFlagName+"="+endpoint) + } + if withControllerOptions { + args = append(args, "-"+extraVolumeTagsFlagName+"="+extraVolumeTagKey+"="+extraVolumeTagValue) + } + + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + os.Args = args + + options := GetOptions(flagSet) + + if withServerOptions { + endpointFlag := flagSet.Lookup(endpointFlagName) + if endpointFlag == nil { + t.Fatalf("expected %q flag to be added but it is not", endpointFlagName) + } + if options.ServerOptions.Endpoint != endpoint { + t.Fatalf("expected endpoint to be %q but it is %q", endpoint, options.ServerOptions.Endpoint) + } + } + + if withControllerOptions { + extraVolumeTagsFlag := flagSet.Lookup(extraVolumeTagsFlagName) + if extraVolumeTagsFlag == nil { + t.Fatalf("expected %q flag to be added but it is not", extraVolumeTagsFlagName) + } + if !reflect.DeepEqual(options.ControllerOptions.ExtraVolumeTags, extraVolumeTags) { + t.Fatalf("expected extra volume tags to be %q but it is %q", extraVolumeTags, options.ControllerOptions.ExtraVolumeTags) + } + } + + return options + } + + testCases := []struct { + name string + testFunc func(t *testing.T) + }{ + { + name: "no controller mode given - expect all mode", + testFunc: func(t *testing.T) { + options := testFunc(t, nil, true, true, true) + + if options.DriverMode != driver.AllMode { + t.Fatalf("expected driver mode to be %q but it is %q", driver.AllMode, options.DriverMode) + } + }, + }, + { + name: "all mode given - expect all mode", + testFunc: func(t *testing.T) { + options := testFunc(t, []string{"all"}, true, true, true) + + if options.DriverMode != driver.AllMode { + t.Fatalf("expected driver mode to be %q but it is %q", driver.AllMode, options.DriverMode) + } + }, + }, + { + name: "controller mode given - expect controller mode", + testFunc: func(t *testing.T) { + options := testFunc(t, []string{"controller"}, true, true, false) + + if options.DriverMode != driver.ControllerMode { + t.Fatalf("expected driver mode to be %q but it is %q", driver.ControllerMode, options.DriverMode) + } + }, + }, + { + name: "node mode given - expect node mode", + testFunc: func(t *testing.T) { + options := testFunc(t, []string{"node"}, true, false, true) + + if options.DriverMode != driver.NodeMode { + t.Fatalf("expected driver mode to be %q but it is %q", driver.NodeMode, options.DriverMode) + } + }, + }, + { + name: "version flag specified", + testFunc: func(t *testing.T) { + oldOSExit := osExit + defer func() { osExit = oldOSExit }() + + var exitCode int + testExit := func(code int) { + exitCode = code + } + osExit = testExit + + oldArgs := os.Args + defer func() { os.Args = oldArgs }() + os.Args = []string{ + "aws-ebs-csi-driver", + "-version", + } + + flagSet := flag.NewFlagSet("test-flagset", flag.ContinueOnError) + _ = GetOptions(flagSet) + + if exitCode != 0 { + t.Fatalf("expected exit code 0 but got %d", exitCode) + } + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, tc.testFunc) + } +} diff --git a/deploy/kubernetes/base/controller.yaml b/deploy/kubernetes/base/controller.yaml index 291eabf75b..f2583f4051 100644 --- a/deploy/kubernetes/base/controller.yaml +++ b/deploy/kubernetes/base/controller.yaml @@ -26,6 +26,7 @@ spec: - name: ebs-plugin image: amazon/aws-ebs-csi-driver:latest args : + # - {all,controller,node} # specify the driver mode - --endpoint=$(CSI_ENDPOINT) - --logtostderr - --v=5 @@ -44,6 +45,9 @@ spec: name: aws-secret key: access_key optional: true + # overwrite the AWS region instead of looking it up dynamically via the AWS EC2 metadata svc + # - name: AWS_REGION + # value: us-east-1 volumeMounts: - name: socket-dir mountPath: /var/lib/csi/sockets/pluginproxy/ diff --git a/docs/design.md b/docs/design.md index dbb915a24e..1554a7cdee 100644 --- a/docs/design.md +++ b/docs/design.md @@ -216,3 +216,22 @@ Blindly return: rpc: - STAGE\_UNSTAGE\_VOLUME ``` + +## Driver modes + +Traditionally, you run the CSI controllers together with the EBS driver in the same Kubernetes cluster. +Though, in some scenarios you might want to run the CSI controllers (csi-provisioner, csi-attacher, etc.) together with the EBS controller service of this driver separately from the Kubernetes cluster it serves (while the EBS driver with an activated node service still runs inside the cluster). +This may not necessarily have to be in the same AWS region. +Also, the controllers may not necessarily have to run on an AWS EC2 instance. +To support these cases, the AWS EBS CSI driver plugin supports three modes: + +- `all`: This is the standard/default mode that is used for the mentioned traditional scenario. It assumes that the CSI controllers run together with the EBS driver in the same AWS cluster. It starts both the controller and the node service of the driver.\ +Example 1: `/bin/aws-ebs-csi-driver --extra-volume-tags=foo=bar`\ +Example 2: `/bin/aws-ebs-csi-driver all --extra-volume-tags=foo=bar` + +- `controller`: This will only start the controller service of the CSI driver. It enables use-cases as mentioned above, e.g., running the CSI controllers outside of the Kubernetes cluster they serve. Still, this mode assumes that it runs in the same AWS region on an AWS EC2 instance. If this is not true you may overwrite the region by specifying the `AWS_REGION` environment variable (if not specified the controller will try to use the AWS EC2 metadata service to look it up dynamically).\ +Example 1: `/bin/aws-ebs-csi-driver controller --extra-volume-tags=foo=bar`\ +Example 2: `AWS_REGION=us-west-1 /bin/aws-ebs-csi-driver controller --extra-volume-tags=foo=bar`\ + +- `node`: This will only start the node service of the CSI driver.\ +Example: `/bin/aws-ebs-csi-driver node --endpoint=unix://...` diff --git a/go.sum b/go.sum index ffa9f1fee0..88e97fc109 100644 --- a/go.sum +++ b/go.sum @@ -48,6 +48,7 @@ github.com/codedellemc/goscaleio v0.0.0-20170830184815-20e2ce2cf885/go.mod h1:JI github.com/codegangsta/negroni v1.0.0/go.mod h1:v0y3T5G7Y1UlFfyxFn/QLRU4a2EuNau2iZY63YTKWo0= github.com/container-storage-interface/spec v1.1.0 h1:qPsTqtR1VUPvMPeK0UnCZMtXaKGyyLPG8gj/wG6VqMs= github.com/container-storage-interface/spec v1.1.0/go.mod h1:6URME8mwIBbpVyZV93Ce5St17xBiQJQY67NDsuohiy4= +github.com/container-storage-interface/spec v1.2.0 h1:bD9KIVgaVKKkQ/UbVUY9kCaH/CJbhNxe0eeB4JeJV2s= github.com/containerd/console v0.0.0-20170925154832-84eeaae905fa/go.mod h1:Tj/on1eG8kiEhd0+fhSDzsPAFESxzBBvdyEgyryXffw= github.com/containerd/containerd v1.0.2/go.mod h1:bC6axHOhabU15QhwfG7w5PipXdVtMXFTttgp+kVtyUA= github.com/containerd/typeurl v0.0.0-20190228175220-2a93cfde8c20/go.mod h1:Cm3kwCdlkCfMSHURc+r6fwoGH6/F1hH3S4sg0rLFWPc= diff --git a/pkg/driver/controller.go b/pkg/driver/controller.go index 3de13be2c7..49b300bb52 100644 --- a/pkg/driver/controller.go +++ b/pkg/driver/controller.go @@ -18,6 +18,7 @@ package driver import ( "context" + "os" "strconv" "strings" @@ -56,15 +57,28 @@ type controllerService struct { driverOptions *DriverOptions } +var ( + // NewMetadataFunc is a variable for the cloud.NewMetadata function that can + // be overwritten in unit tests. + NewMetadataFunc = cloud.NewMetadata + // NewCloudFunc is a variable for the cloud.NewCloud function that can + // be overwritten in unit tests. + NewCloudFunc = cloud.NewCloud +) + // newControllerService creates a new controller service // it panics if failed to create the service func newControllerService(driverOptions *DriverOptions) controllerService { - metadata, err := cloud.NewMetadata() - if err != nil { - panic(err) + region := os.Getenv("AWS_REGION") + if region == "" { + metadata, err := NewMetadataFunc() + if err != nil { + panic(err) + } + region = metadata.GetRegion() } - region := metadata.GetRegion() - cloud, err := cloud.NewCloud(region) + + cloud, err := NewCloudFunc(region) if err != nil { panic(err) } diff --git a/pkg/driver/controller_test.go b/pkg/driver/controller_test.go index b9a038f984..30f68255bd 100644 --- a/pkg/driver/controller_test.go +++ b/pkg/driver/controller_test.go @@ -18,8 +18,10 @@ package driver import ( "context" + "errors" "fmt" "math/rand" + "os" "reflect" "testing" "time" @@ -38,6 +40,107 @@ const ( expInstanceID = "i-123456789abcdef01" ) +func TestNewControllerService(t *testing.T) { + + var ( + cloudObj cloud.Cloud + testErr = errors.New("test error") + testRegion = "test-region" + + getNewCloudFunc = func(expectedRegion string) func(region string) (cloud.Cloud, error) { + return func(region string) (cloud.Cloud, error) { + if region != expectedRegion { + t.Fatalf("expected region %q but got %q", expectedRegion, region) + } + return cloudObj, nil + } + } + ) + + testCases := []struct { + name string + region string + newCloudFunc func(string) (cloud.Cloud, error) + newMetadataFuncErrors bool + expectPanic bool + }{ + { + name: "AWS_REGION variable set, newCloud does not error", + region: "foo", + newCloudFunc: getNewCloudFunc("foo"), + }, + { + name: "AWS_REGION variable set, newCloud errors", + region: "foo", + newCloudFunc: func(region string) (cloud.Cloud, error) { + return nil, testErr + }, + expectPanic: true, + }, + { + name: "AWS_REGION variable not set, newMetadata does not error", + newCloudFunc: getNewCloudFunc(testRegion), + }, + { + name: "AWS_REGION variable not set, newMetadata errors", + newCloudFunc: getNewCloudFunc(testRegion), + newMetadataFuncErrors: true, + expectPanic: true, + }, + } + + driverOptions := &DriverOptions{ + endpoint: "test", + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + oldNewCloudFunc := NewCloudFunc + defer func() { NewCloudFunc = oldNewCloudFunc }() + NewCloudFunc = tc.newCloudFunc + + if tc.region == "" { + mockCtl := gomock.NewController(t) + defer mockCtl.Finish() + mockMetadataService := mocks.NewMockMetadataService(mockCtl) + + oldNewMetadataFunc := NewMetadataFunc + defer func() { NewMetadataFunc = oldNewMetadataFunc }() + NewMetadataFunc = func() (cloud.MetadataService, error) { + if tc.newMetadataFuncErrors { + return nil, testErr + } + return mockMetadataService, nil + } + + if !tc.newMetadataFuncErrors { + mockMetadataService.EXPECT().GetRegion().Return(testRegion) + } + } else { + os.Setenv("AWS_REGION", tc.region) + defer os.Unsetenv("AWS_REGION") + } + + if tc.expectPanic { + defer func() { + if r := recover(); r == nil { + t.Errorf("The code did not panic") + } + }() + } + + controllerService := newControllerService(driverOptions) + + if controllerService.cloud != cloudObj { + t.Fatalf("expected cloud attribute to be equal to instantiated cloud object") + } + if !reflect.DeepEqual(controllerService.driverOptions, driverOptions) { + t.Fatalf("expected driverOptions attribute to be equal to input") + } + }) + } +} + func TestCreateVolume(t *testing.T) { stdVolCap := []*csi.VolumeCapability{ { diff --git a/pkg/driver/driver.go b/pkg/driver/driver.go index 0367f3069b..e54df9b994 100644 --- a/pkg/driver/driver.go +++ b/pkg/driver/driver.go @@ -27,6 +27,18 @@ import ( "k8s.io/klog" ) +// Mode is the operating mode of the CSI driver. +type Mode string + +const ( + // ControllerMode is the mode that only starts the controller service. + ControllerMode Mode = "controller" + // NodeMode is the mode that only starts the node service. + NodeMode Mode = "node" + // AllMode is the mode that only starts both the controller and the node service. + AllMode Mode = "all" +) + const ( DriverName = "ebs.csi.aws.com" TopologyKey = "topology." + DriverName + "/zone" @@ -43,6 +55,7 @@ type Driver struct { type DriverOptions struct { endpoint string extraVolumeTags map[string]string + mode Mode } func NewDriver(options ...func(*DriverOptions)) (*Driver, error) { @@ -50,6 +63,7 @@ func NewDriver(options ...func(*DriverOptions)) (*Driver, error) { driverOptions := DriverOptions{ endpoint: DefaultCSIEndpoint, + mode: AllMode, } for _, option := range options { option(&driverOptions) @@ -60,9 +74,19 @@ func NewDriver(options ...func(*DriverOptions)) (*Driver, error) { } driver := Driver{ - controllerService: newControllerService(&driverOptions), - nodeService: newNodeService(), - options: &driverOptions, + options: &driverOptions, + } + + switch driverOptions.mode { + case ControllerMode: + driver.controllerService = newControllerService(&driverOptions) + case NodeMode: + driver.nodeService = newNodeService() + case AllMode: + driver.controllerService = newControllerService(&driverOptions) + driver.nodeService = newNodeService() + default: + return nil, fmt.Errorf("unknown mode: %s", driverOptions.mode) } return &driver, nil @@ -92,8 +116,18 @@ func (d *Driver) Run() error { d.srv = grpc.NewServer(opts...) csi.RegisterIdentityServer(d.srv, d) - csi.RegisterControllerServer(d.srv, d) - csi.RegisterNodeServer(d.srv, d) + + switch d.options.mode { + case ControllerMode: + csi.RegisterControllerServer(d.srv, d) + case NodeMode: + csi.RegisterNodeServer(d.srv, d) + case AllMode: + csi.RegisterControllerServer(d.srv, d) + csi.RegisterNodeServer(d.srv, d) + default: + return fmt.Errorf("unknown mode: %s", d.options.mode) + } klog.Infof("Listening for connections on address: %#v", listener.Addr()) return d.srv.Serve(listener) @@ -115,3 +149,9 @@ func WithExtraVolumeTags(extraVolumeTags map[string]string) func(*DriverOptions) o.extraVolumeTags = extraVolumeTags } } + +func WithMode(mode Mode) func(*DriverOptions) { + return func(o *DriverOptions) { + o.mode = mode + } +} diff --git a/pkg/driver/fakes.go b/pkg/driver/fakes.go index 872aaead51..b0e3007e9c 100644 --- a/pkg/driver/fakes.go +++ b/pkg/driver/fakes.go @@ -26,6 +26,7 @@ import ( func NewFakeDriver(endpoint string, fakeCloud cloud.Cloud, fakeMounter *mount.FakeMounter) *Driver { driverOptions := &DriverOptions{ endpoint: endpoint, + mode: AllMode, } return &Driver{ options: driverOptions, diff --git a/pkg/driver/validation.go b/pkg/driver/validation.go index 8fee156497..1a69e1734e 100644 --- a/pkg/driver/validation.go +++ b/pkg/driver/validation.go @@ -28,6 +28,10 @@ func ValidateDriverOptions(options *DriverOptions) error { return fmt.Errorf("Invalid extra volume tags: %v", err) } + if err := validateMode(options.mode); err != nil { + return fmt.Errorf("Invalid mode: %v", err) + } + return nil } @@ -56,3 +60,11 @@ func validateExtraVolumeTags(tags map[string]string) error { return nil } + +func validateMode(mode Mode) error { + if mode != AllMode && mode != ControllerMode && mode != NodeMode { + return fmt.Errorf("Mode is not supported (actual: %s, supported: %v)", mode, []Mode{AllMode, ControllerMode, NodeMode}) + } + + return nil +} diff --git a/pkg/driver/validation_test.go b/pkg/driver/validation_test.go index 137f011af9..64427a00fc 100644 --- a/pkg/driver/validation_test.go +++ b/pkg/driver/validation_test.go @@ -108,3 +108,81 @@ func TestValidateExtraVolumeTags(t *testing.T) { }) } } + +func TestValidateMode(t *testing.T) { + testCases := []struct { + name string + mode Mode + expErr error + }{ + { + name: "valid mode: all", + mode: AllMode, + expErr: nil, + }, + { + name: "valid mode: controller", + mode: ControllerMode, + expErr: nil, + }, + { + name: "valid mode: node", + mode: NodeMode, + expErr: nil, + }, + { + name: "invalid mode: unknown", + mode: Mode("unknown"), + expErr: fmt.Errorf("Mode is not supported (actual: unknown, supported: %v)", []Mode{AllMode, ControllerMode, NodeMode}), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := validateMode(tc.mode) + if !reflect.DeepEqual(err, tc.expErr) { + t.Fatalf("error not equal\ngot:\n%s\nexpected:\n%s", err, tc.expErr) + } + }) + } +} + +func TestValidateDriverOptions(t *testing.T) { + testCases := []struct { + name string + mode Mode + extraVolumeTags map[string]string + expErr error + }{ + { + name: "success", + mode: AllMode, + expErr: nil, + }, + { + name: "fail because validateMode fails", + mode: Mode("unknown"), + expErr: fmt.Errorf("Invalid mode: Mode is not supported (actual: unknown, supported: %v)", []Mode{AllMode, ControllerMode, NodeMode}), + }, + { + name: "fail because validateExtraVolumeTags fails", + mode: AllMode, + extraVolumeTags: map[string]string{ + randomString(cloud.MaxTagKeyLength + 1): "extra-tag-value", + }, + expErr: fmt.Errorf("Invalid extra volume tags: Volume tag key too long (actual: %d, limit: %d)", cloud.MaxTagKeyLength+1, cloud.MaxTagKeyLength), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := ValidateDriverOptions(&DriverOptions{ + extraVolumeTags: tc.extraVolumeTags, + mode: tc.mode, + }) + if !reflect.DeepEqual(err, tc.expErr) { + t.Fatalf("error not equal\ngot:\n%s\nexpected:\n%s", err, tc.expErr) + } + }) + } +}