diff --git a/bind/describe.go b/bind/describe.go new file mode 100644 index 00000000..71bf1751 --- /dev/null +++ b/bind/describe.go @@ -0,0 +1,62 @@ +// Copyright 2023 Sauce Labs Inc. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package bind + +import ( + "encoding/json" + "fmt" + "sort" + "strings" + + "github.com/spf13/pflag" +) + +type DescribeFormat int + +const ( + Plain DescribeFormat = iota + JSON +) + +func DescribeFlags(fs *pflag.FlagSet, showHidden bool, format DescribeFormat) (string, error) { + args := make(map[string]any, fs.NFlag()) + keys := make([]string, 0, fs.NFlag()) + + fs.VisitAll(func(flag *pflag.Flag) { + if flag.Name == "help" { + return + } + + if flag.Hidden && !showHidden { + return + } + + if flag.Value.Type() == "bool" { + args[flag.Name] = flag.Value + } else { + args[flag.Name] = strings.Trim(flag.Value.String(), "[]") + } + + keys = append(keys, flag.Name) + }) + + sort.Strings(keys) + + switch format { + case Plain: + var b strings.Builder + for _, name := range keys { + b.WriteString(fmt.Sprintf("%s=%s\n", name, args[name])) + } + return b.String(), nil + case JSON: + encoded, err := json.Marshal(args) + return string(encoded), err + default: + return "", fmt.Errorf("unknown format requested") + } +} diff --git a/bind/describe_test.go b/bind/describe_test.go new file mode 100644 index 00000000..50f7e846 --- /dev/null +++ b/bind/describe_test.go @@ -0,0 +1,156 @@ +// Copyright 2023 Sauce Labs Inc. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +package bind + +import ( + "testing" + + "github.com/spf13/pflag" +) + +func TestDescribeFlagsAsPlain(t *testing.T) { + tests := map[string]struct { + input map[string]interface{} + expected string + isErr bool + isHidden bool + showHidden bool + }{ + "keys are sorted": { + input: map[string]interface{}{"foo": false, "bar": true}, + expected: "bar=true\nfoo=false\n", + isErr: false, + }, + "bool is correctly formatted": { + input: map[string]interface{}{"key": false}, + expected: "key=false\n", + isErr: false, + }, + "string is correctly formatted": { + input: map[string]interface{}{"key": "val"}, + expected: "key=val\n", + isErr: false, + }, + "help is not shown": { + input: map[string]interface{}{"key": false, "help": true}, + expected: "key=false\n", + isErr: false, + }, + "hidden is shown": { + input: map[string]interface{}{"key": false}, + expected: "key=false\n", + isErr: false, + isHidden: true, + showHidden: true, + }, + "hidden is not shown": { + input: map[string]interface{}{"key": false}, + expected: ``, + isErr: false, + isHidden: true, + showHidden: false, + }, + } + + for name, tc := range tests { + fs := pflag.NewFlagSet("flags", pflag.ContinueOnError) + + for k, v := range tc.input { + switch val := v.(type) { + case bool: + fs.Bool(k, val, "") + case string: + fs.String(k, val, "") + } + + if tc.isHidden { + err := fs.MarkHidden(k) + if err != nil { + t.Errorf("%s: test setup failed: %s", name, err) + } + } + } + result, err := DescribeFlags(fs, tc.showHidden, Plain) + + if (err != nil) != tc.isErr { + t.Errorf("%s: expected error: %v, got %s", name, tc.isErr, err) + } + + if result != tc.expected { + t.Errorf("%s: expected %s, got %s", name, tc.expected, result) + } + } +} + +func TestDescribeFlagsAsJSON(t *testing.T) { + tests := map[string]struct { + input map[string]interface{} + expected string + isErr bool + isHidden bool + showHidden bool + }{ + "bool is not quoted": { + input: map[string]interface{}{"key": false}, + expected: `{"key":false}`, + isErr: false, + }, + "help is not shown": { + input: map[string]interface{}{"key": false, "help": true}, + expected: `{"key":false}`, + isErr: false, + }, + "hidden is shown": { + input: map[string]interface{}{"key": false}, + expected: `{"key":false}`, + isErr: false, + isHidden: true, + showHidden: true, + }, + "hidden is not shown": { + input: map[string]interface{}{"key": false}, + expected: `{}`, + isErr: false, + isHidden: true, + showHidden: false, + }, + "string is quoted": { + input: map[string]interface{}{"key": "val"}, + expected: `{"key":"val"}`, + isErr: false, + }, + } + + for name, tc := range tests { + fs := pflag.NewFlagSet("flags", pflag.ContinueOnError) + + for k, v := range tc.input { + switch val := v.(type) { + case bool: + fs.Bool(k, val, "") + case string: + fs.String(k, val, "") + } + + if tc.isHidden { + err := fs.MarkHidden(k) + if err != nil { + t.Errorf("%s: test setup failed: %s", name, err) + } + } + } + result, err := DescribeFlags(fs, tc.showHidden, JSON) + + if (err != nil) != tc.isErr { + t.Errorf("%s: expected error: %v, got %s", name, tc.isErr, err) + } + + if result != tc.expected { + t.Errorf("%s: expected %s, got %s", name, tc.expected, result) + } + } +} diff --git a/bind/flag.go b/bind/flag.go index acdf6764..4891b295 100644 --- a/bind/flag.go +++ b/bind/flag.go @@ -7,7 +7,6 @@ package bind import ( - "fmt" "net/netip" "net/url" "os" @@ -328,14 +327,3 @@ func MarkFlagFilename(cmd *cobra.Command, names ...string) { } } } - -func DescribeFlags(fs *pflag.FlagSet) string { - var b strings.Builder - fs.VisitAll(func(flag *pflag.Flag) { - if flag.Hidden || flag.Name == "help" { - return - } - b.WriteString(fmt.Sprintf("%s=%s\n", flag.Name, strings.Trim(flag.Value.String(), "[]"))) - }) - return b.String() -} diff --git a/cmd/forwarder/httpbin/httpbin.go b/cmd/forwarder/httpbin/httpbin.go index 870b6c89..7af2284a 100644 --- a/cmd/forwarder/httpbin/httpbin.go +++ b/cmd/forwarder/httpbin/httpbin.go @@ -24,7 +24,10 @@ type command struct { } func (c *command) RunE(cmd *cobra.Command, _ []string) error { - config := bind.DescribeFlags(cmd.Flags()) + config, err := bind.DescribeFlags(cmd.Flags(), false, bind.Plain) + if err != nil { + return err + } if f := c.logConfig.File; f != nil { defer f.Close() diff --git a/cmd/forwarder/pac/server/server.go b/cmd/forwarder/pac/server/server.go index fac99887..79de9f1f 100644 --- a/cmd/forwarder/pac/server/server.go +++ b/cmd/forwarder/pac/server/server.go @@ -31,7 +31,10 @@ type command struct { } func (c *command) RunE(cmd *cobra.Command, _ []string) error { - config := bind.DescribeFlags(cmd.Flags()) + config, err := bind.DescribeFlags(cmd.Flags(), false, bind.Plain) + if err != nil { + return err + } if f := c.logConfig.File; f != nil { defer f.Close() diff --git a/cmd/forwarder/run/run.go b/cmd/forwarder/run/run.go index 573d8dc7..29bed000 100644 --- a/cmd/forwarder/run/run.go +++ b/cmd/forwarder/run/run.go @@ -43,7 +43,10 @@ type command struct { } func (c *command) RunE(cmd *cobra.Command, _ []string) error { - config := bind.DescribeFlags(cmd.Flags()) + config, err := bind.DescribeFlags(cmd.Flags(), false, bind.Plain) + if err != nil { + return err + } if f := c.logConfig.File; f != nil { defer f.Close()