diff --git a/internal/goenv/goenv.go b/internal/goenv/goenv.go index 2f207aa0..095e874e 100644 --- a/internal/goenv/goenv.go +++ b/internal/goenv/goenv.go @@ -3,46 +3,24 @@ package goenv import ( "errors" "os/exec" - "runtime" - "strconv" "strings" ) -func Read() (map[string]string, error) { - out, err := exec.Command("go", "env").CombinedOutput() +func Read(varNames []string) (map[string]string, error) { + out, err := exec.Command("go", append([]string{"env"}, varNames...)...).CombinedOutput() if err != nil { return nil, err } - return parseGoEnv(out, runtime.GOOS) + return parseGoEnv(varNames, out) } -func parseGoEnv(data []byte, goos string) (map[string]string, error) { +func parseGoEnv(varNames []string, data []byte) (map[string]string, error) { vars := make(map[string]string) lines := strings.Split(strings.ReplaceAll(string(data), "\r\n", "\n"), "\n") - - if goos == "windows" { - // Line format is: `set $name=$value` - for _, l := range lines { - l = strings.TrimPrefix(l, "set ") - parts := strings.Split(l, "=") - if len(parts) != 2 { - continue - } - vars[parts[0]] = parts[1] - } - } else { - // Line format is: `$name="$value"` - for _, l := range lines { - parts := strings.Split(strings.TrimSpace(l), "=") - if len(parts) != 2 { - continue - } - val, err := strconv.Unquote(parts[1]) - if err != nil { - continue - } - vars[parts[0]] = val + for i, varName := range varNames { + if i < len(lines) && len(lines[i]) > 0 { + vars[varName] = lines[i] } } diff --git a/internal/goenv/goenv_test.go b/internal/goenv/goenv_test.go index 957e7059..259b1a35 100644 --- a/internal/goenv/goenv_test.go +++ b/internal/goenv/goenv_test.go @@ -7,69 +7,83 @@ import ( func TestParse(t *testing.T) { tests := []struct { - goos string lines []string goroot string gopath string + err bool }{ + // handle windows line-endings { - goos: "windows", lines: []string{ - "set GOROOT=C:\\Program Files\\Go\r\n", - "set GOPATH=C:\\Users\\me\\go\r\n", + "C:\\Program Files\\Go\r\n", + "C:\\Users\\me\\go\r\n", }, goroot: "C:\\Program Files\\Go", gopath: "C:\\Users\\me\\go", }, - // Don't do trim on Windows. + // preserve trailing spaces on windows { - goos: "windows", lines: []string{ - "set GOROOT=C:\\Program Files\\Go \r\n", - "set GOPATH=C:\\Users\\me\\go \r\n", + "C:\\Program Files\\Go \r\n", + "C:\\Users\\me\\go \r\n", }, goroot: "C:\\Program Files\\Go ", gopath: "C:\\Users\\me\\go ", }, + // handle linux line-endings { - goos: "linux", lines: []string{ - "GOROOT=\"/usr/local/go\"\n", - "GOPATH=\"/home/me/go\"\n", + "/usr/local/go\n", + "/home/me/go\n", }, goroot: "/usr/local/go", gopath: "/home/me/go", }, - // Trim lines on Linux. + // preserve trailing spaces on linux { - goos: "linux", lines: []string{ - " GOROOT=\"/usr/local/go\" \n", - "GOPATH=\"/home/me/go\" \n", + "/usr/local/go \n", + "/home/me/go \n", }, - goroot: "/usr/local/go", + goroot: "/usr/local/go ", + gopath: "/home/me/go ", + }, + + // handle empty value + { + lines: []string{ + "\n", + "/home/me/go\n", + }, + goroot: "", gopath: "/home/me/go", }, - // Quotes preserve the whitespace. + // handle short output { - goos: "linux", lines: []string{ - " GOROOT=\"/usr/local/go \" \n", - "GOPATH=\"/home/me/go \" \n", + "/usr/local/go", }, - goroot: "/usr/local/go ", - gopath: "/home/me/go ", + goroot: "/usr/local/go", + gopath: "", + }, + + // handle empty output + { + lines: []string{}, + goroot: "", + gopath: "", + err: true, }, } for i, test := range tests { data := []byte(strings.Join(test.lines, "")) - vars, err := parseGoEnv(data, test.goos) - if err != nil { + vars, err := parseGoEnv([]string{"GOROOT", "GOPATH"}, data) + if err != nil != test.err { t.Fatalf("test %d failed: %v", i, err) } if vars["GOROOT"] != test.goroot { diff --git a/ruleguard/engine.go b/ruleguard/engine.go index 88feef92..cadc0ddc 100644 --- a/ruleguard/engine.go +++ b/ruleguard/engine.go @@ -248,7 +248,7 @@ func inferBuildContext() *build.Context { // Inherit most fields from the build.Default. ctx := build.Default - env, err := goenv.Read() + env, err := goenv.Read([]string{"GOROOT", "GOPATH", "GOARCH", "GOOS", "CGO_ENABLED"}) if err != nil { return &ctx }