From 2eaf999a0e74bfa9c86eae861787c664094fbb2a Mon Sep 17 00:00:00 2001 From: Brian McGee Date: Thu, 2 May 2024 10:56:32 +0100 Subject: [PATCH] feat: refactor some config init logic into config package Signed-off-by: Brian McGee --- cli/format.go | 42 +++++++++--------------------------------- config/config.go | 39 ++++++++++++++++++++++++++++++++++++--- config/config_test.go | 2 +- format/formatter.go | 14 +++++++------- 4 files changed, 53 insertions(+), 44 deletions(-) diff --git a/cli/format.go b/cli/format.go index 617861eb..ae322349 100644 --- a/cli/format.go +++ b/cli/format.go @@ -9,8 +9,6 @@ import ( "os/signal" "path/filepath" "runtime" - "slices" - "sort" "strings" "syscall" @@ -41,10 +39,10 @@ var ( ) func (f *Format) Run() (err error) { - stats.Init() - + // create a prefixed logger l := log.WithPrefix("format") + // ensure cache is closed on return defer func() { if err := cache.Close(); err != nil { l.Errorf("failed to close cache: %v", err) @@ -52,46 +50,21 @@ func (f *Format) Run() (err error) { }() // read config - cfg, err := config.ReadFile(Cli.ConfigFile) + cfg, err := config.ReadFile(Cli.ConfigFile, Cli.Formatters) if err != nil { return fmt.Errorf("%w: failed to read config file", err) } + // compile global exclude globs if globalExcludes, err = format.CompileGlobs(cfg.Global.Excludes); err != nil { return fmt.Errorf("%w: failed to compile global globs", err) } + // initialise pipelines pipelines = make(map[string]*format.Pipeline) formatters = make(map[string]*format.Formatter) - // filter formatters - if len(Cli.Formatters) > 0 { - // first check the cli formatter list is valid - for _, name := range Cli.Formatters { - _, ok := cfg.Formatters[name] - if !ok { - return fmt.Errorf("formatter not found in config: %v", name) - } - } - // next we remove any formatter configs that were not specified - for name := range cfg.Formatters { - if !slices.Contains(Cli.Formatters, name) { - delete(cfg.Formatters, name) - } - } - } - - // sort the formatter names so that, as we construct pipelines, we add formatters in a determinstic fashion. This - // ensures a deterministic order even when all priority values are the same e.g. 0 - - names := make([]string, 0, len(cfg.Formatters)) - for name := range cfg.Formatters { - names = append(names, name) - } - sort.Strings(names) - - // init formatters - for _, name := range names { + for _, name := range cfg.Names { formatterCfg := cfg.Formatters[name] formatter, err := format.NewFormatter(name, Cli.TreeRoot, formatterCfg, globalExcludes) if errors.Is(err, format.ErrCommandNotFound) && Cli.AllowMissingFormatter { @@ -134,6 +107,9 @@ func (f *Format) Run() (err error) { cancel() }() + // initialise stats collection + stats.Init() + // create some groups for concurrent processing and control flow eg, ctx := errgroup.WithContext(ctx) diff --git a/config/config.go b/config/config.go index e116b0ff..2fb5f547 100644 --- a/config/config.go +++ b/config/config.go @@ -1,6 +1,11 @@ package config -import "github.com/BurntSushi/toml" +import ( + "fmt" + "sort" + + "github.com/BurntSushi/toml" +) // Config is used to represent the list of configured Formatters. type Config struct { @@ -8,11 +13,39 @@ type Config struct { // Excludes is an optional list of glob patterns used to exclude certain files from all formatters. Excludes []string } + Names []string `toml:"-"` Formatters map[string]*Formatter `toml:"formatter"` } // ReadFile reads from path and unmarshals toml into a Config instance. -func ReadFile(path string) (cfg *Config, err error) { - _, err = toml.DecodeFile(path, &cfg) +func ReadFile(path string, names []string) (cfg *Config, err error) { + if _, err = toml.DecodeFile(path, &cfg); err != nil { + return nil, fmt.Errorf("failed to decode config file: %w", err) + } + + // filter formatters based on provided names + if len(names) > 0 { + filtered := make(map[string]*Formatter) + + // check if the provided names exist in the config + for _, name := range names { + formatterCfg, ok := cfg.Formatters[name] + if !ok { + return nil, fmt.Errorf("formatter %v not found in config", name) + } + filtered[name] = formatterCfg + } + + // updated formatters + cfg.Formatters = filtered + } + + // sort the formatter names so that, as we construct pipelines, we add formatters in a determinstic fashion. This + // ensures a deterministic order even when all priority values are the same e.g. 0 + for name := range cfg.Formatters { + cfg.Names = append(cfg.Names, name) + } + sort.Strings(cfg.Names) + return } diff --git a/config/config_test.go b/config/config_test.go index cfe67e00..6ad74937 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -9,7 +9,7 @@ import ( func TestReadConfigFile(t *testing.T) { as := require.New(t) - cfg, err := ReadFile("../test/examples/treefmt.toml") + cfg, err := ReadFile("../test/examples/treefmt.toml", nil) as.NoError(err, "failed to read config file") as.NotNil(cfg) diff --git a/format/formatter.go b/format/formatter.go index 1b66a57e..a0e16e33 100644 --- a/format/formatter.go +++ b/format/formatter.go @@ -111,7 +111,7 @@ func (f *Formatter) Wants(file *walk.File) bool { func NewFormatter( name string, treeRoot string, - config *config.Formatter, + cfg *config.Formatter, globalExcludes []glob.Glob, ) (*Formatter, error) { var err error @@ -120,11 +120,11 @@ func NewFormatter( // capture config and the formatter's name f.name = name - f.config = config + f.config = cfg f.workingDir = treeRoot // test if the formatter is available - executable, err := exec.LookPath(config.Command) + executable, err := exec.LookPath(cfg.Command) if errors.Is(err, exec.ErrNotFound) { return nil, ErrCommandNotFound } else if err != nil { @@ -133,18 +133,18 @@ func NewFormatter( f.executable = executable // initialise internal state - if config.Pipeline == "" { + if cfg.Pipeline == "" { f.log = log.WithPrefix(fmt.Sprintf("format | %s", name)) } else { - f.log = log.WithPrefix(fmt.Sprintf("format | %s[%s]", config.Pipeline, name)) + f.log = log.WithPrefix(fmt.Sprintf("format | %s[%s]", cfg.Pipeline, name)) } - f.includes, err = CompileGlobs(config.Includes) + f.includes, err = CompileGlobs(cfg.Includes) if err != nil { return nil, fmt.Errorf("%w: formatter '%v' includes", err, f.name) } - f.excludes, err = CompileGlobs(config.Excludes) + f.excludes, err = CompileGlobs(cfg.Excludes) if err != nil { return nil, fmt.Errorf("%w: formatter '%v' excludes", err, f.name) }