diff --git a/cmd/src/admin.go b/cmd/src/admin.go new file mode 100644 index 0000000000..ac8af95205 --- /dev/null +++ b/cmd/src/admin.go @@ -0,0 +1,38 @@ +package main + +import ( + "flag" + "fmt" +) + +var adminCommands commander + +func init() { + usage := `'src admin' is a tool that manages an initial admin user on a new Sourcegraph instance. + +Usage: + + src admin create [command options] + +The commands are: + + create create an initial admin user + +Use "src admin [command] -h" for more information about a command. +` + + flagSet := flag.NewFlagSet("admin", flag.ExitOnError) + handler := func(args []string) error { + adminCommands.run(flagSet, "srv admin", usage, args) + return nil + } + + commands = append(commands, &command{ + flagSet: flagSet, + aliases: []string{"admin"}, + handler: handler, + usageFunc: func() { + fmt.Println(usage) + }, + }) +} diff --git a/cmd/src/admin_create.go b/cmd/src/admin_create.go new file mode 100644 index 0000000000..5b7bdf9110 --- /dev/null +++ b/cmd/src/admin_create.go @@ -0,0 +1,100 @@ +package main + +import ( + "flag" + "fmt" + "os" + + "github.com/sourcegraph/src-cli/internal/users" + + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +func init() { + usage := ` +Examples: + + Create an initial admin user on a new Sourcegraph deployment: + + $ src admin create -url https://your-sourcegraph-url -username admin -email admin@yourcompany.com -with-token + + Create an initial admin user on a new Sourcegraph deployment using '-password' flag. + WARNING: for security purposes we strongly recommend using the SRC_ADMIN_PASS environment variable when possible. + + $ src admin create -url https://your-sourcegraph-url -username admin -email admin@yourcompany.com -password p@55w0rd -with-token + +Environmental variables + + SRC_ADMIN_PASS The new admin user's password +` + + flagSet := flag.NewFlagSet("create", flag.ExitOnError) + usageFunc := func() { + fmt.Fprintf(flag.CommandLine.Output(), "Usage of 'src users %s':\n", flagSet.Name()) + flagSet.PrintDefaults() + fmt.Println(usage) + } + + var ( + urlFlag = flagSet.String("url", "", "The base URL for the Sourcegraph instance.") + usernameFlag = flagSet.String("username", "", "The new admin user's username.") + emailFlag = flagSet.String("email", "", "The new admin user's email address.") + passwordFlag = flagSet.String("password", "", "The new admin user's password.") + tokenFlag = flagSet.Bool("with-token", false, "Optionally create and output an admin access token.") + ) + + handler := func(args []string) error { + if err := flagSet.Parse(args); err != nil { + return err + } + + ok, _, err := users.NeedsSiteInit(*urlFlag) + if err != nil { + return err + } + if !ok { + return errors.New("failed to create admin, site already initialized") + } + + envAdminPass := os.Getenv("SRC_ADMIN_PASS") + + var client *users.Client + + switch { + case envAdminPass != "" && *passwordFlag == "": + client, err = users.SiteAdminInit(*urlFlag, *emailFlag, *usernameFlag, envAdminPass) + if err != nil { + return err + } + case envAdminPass == "" && *passwordFlag != "": + client, err = users.SiteAdminInit(*urlFlag, *emailFlag, *usernameFlag, *passwordFlag) + if err != nil { + return err + } + case envAdminPass != "" && *passwordFlag != "": + return errors.New("failed to read admin password: environment variable and -password flag both set") + case envAdminPass == "" && *passwordFlag == "": + return errors.New("failed to read admin password from 'SRC_ADMIN_PASS' environment variable or -password flag") + } + + if *tokenFlag { + token, err := client.CreateAccessToken("", []string{"user:all", "site-admin:sudo"}, "src-cli") + if err != nil { + return err + } + + _, err = fmt.Fprintf(flag.CommandLine.Output(), "%s\n", token) + if err != nil { + return err + } + } + + return nil + } + + adminCommands = append(adminCommands, &command{ + flagSet: flagSet, + handler: handler, + usageFunc: usageFunc, + }) +} diff --git a/internal/lazyregexp/lazyregexp.go b/internal/lazyregexp/lazyregexp.go new file mode 100644 index 0000000000..dc55d78f1a --- /dev/null +++ b/internal/lazyregexp/lazyregexp.go @@ -0,0 +1,128 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package lazyregexp is a thin wrapper over regexp, allowing the use of global +// regexp variables without forcing them to be compiled at init. +package lazyregexp + +import ( + "os" + "strings" + "sync" + + "github.com/grafana/regexp" +) + +// Regexp is a wrapper around regexp.Regexp, where the underlying regexp will be +// compiled the first time it is needed. +type Regexp struct { + str string + posix bool + once sync.Once + rx *regexp.Regexp +} + +func (r *Regexp) Re() *regexp.Regexp { + r.once.Do(r.build) + return r.rx +} + +func (r *Regexp) build() { + if r.posix { + r.rx = regexp.MustCompilePOSIX(r.str) + } else { + r.rx = regexp.MustCompile(r.str) + } + r.str = "" +} + +func (r *Regexp) FindSubmatch(s []byte) [][]byte { + return r.Re().FindSubmatch(s) +} + +func (r *Regexp) FindStringSubmatch(s string) []string { + return r.Re().FindStringSubmatch(s) +} + +func (r *Regexp) FindStringSubmatchIndex(s string) []int { + return r.Re().FindStringSubmatchIndex(s) +} + +func (r *Regexp) ReplaceAllString(src, repl string) string { + return r.Re().ReplaceAllString(src, repl) +} + +func (r *Regexp) FindString(s string) string { + return r.Re().FindString(s) +} + +func (r *Regexp) FindAllString(s string, n int) []string { + return r.Re().FindAllString(s, n) +} + +func (r *Regexp) MatchString(s string) bool { + return r.Re().MatchString(s) +} + +func (r *Regexp) SubexpNames() []string { + return r.Re().SubexpNames() +} + +func (r *Regexp) FindAllStringSubmatch(s string, n int) [][]string { + return r.Re().FindAllStringSubmatch(s, n) +} + +func (r *Regexp) Split(s string, n int) []string { + return r.Re().Split(s, n) +} + +func (r *Regexp) ReplaceAllLiteralString(src, repl string) string { + return r.Re().ReplaceAllLiteralString(src, repl) +} + +func (r *Regexp) FindAllIndex(b []byte, n int) [][]int { + return r.Re().FindAllIndex(b, n) +} + +func (r *Regexp) Match(b []byte) bool { + return r.Re().Match(b) +} + +func (r *Regexp) ReplaceAllStringFunc(src string, repl func(string) string) string { + return r.Re().ReplaceAllStringFunc(src, repl) +} + +func (r *Regexp) ReplaceAll(src, repl []byte) []byte { + return r.Re().ReplaceAll(src, repl) +} + +func (r *Regexp) SubexpIndex(s string) int { + return r.Re().SubexpIndex(s) +} + +var inTest = len(os.Args) > 0 && strings.HasSuffix(strings.TrimSuffix(os.Args[0], ".exe"), ".test") + +// New creates a new lazy regexp, delaying the compiling work until it is first +// needed. If the code is being run as part of tests, the regexp compiling will +// happen immediately. +func New(str string) *Regexp { + lr := &Regexp{str: str} + if inTest { + // In tests, always compile the regexps early. + lr.Re() + } + return lr +} + +// NewPOSIX creates a new lazy regexp, delaying the compiling work until it is +// first needed. If the code is being run as part of tests, the regexp +// compiling will happen immediately. +func NewPOSIX(str string) *Regexp { + lr := &Regexp{str: str, posix: true} + if inTest { + // In tests, always compile the regexps early. + lr.Re() + } + return lr +} diff --git a/internal/users/admin.go b/internal/users/admin.go new file mode 100644 index 0000000000..dad8584989 --- /dev/null +++ b/internal/users/admin.go @@ -0,0 +1,330 @@ +package users + +import ( + "bytes" + "fmt" + "io" + "net/http" + "strings" + + jsoniter "github.com/json-iterator/go" + "github.com/sourcegraph/src-cli/internal/lazyregexp" + + "github.com/sourcegraph/sourcegraph/lib/errors" +) + +// NeedsSiteInit returns true if the instance hasn't done "Site admin init" step. +func NeedsSiteInit(baseURL string) (bool, string, error) { + resp, err := http.Get(baseURL + "/sign-in") + if err != nil { + return false, "", errors.Wrap(err, "sign-in page") + } + defer func() { _ = resp.Body.Close() }() + + p, err := io.ReadAll(resp.Body) + if err != nil { + return false, "", errors.Wrap(err, "read body") + } + return strings.Contains(string(p), `"needsSiteInit":true`), string(p), nil +} + +// SiteAdminInit initializes the instance with given admin account. +// It returns an authenticated client as the admin for doing testing. +func SiteAdminInit(baseURL, email, username, password string) (*Client, error) { + return authenticate(baseURL, "/-/site-init", map[string]string{ + "email": email, + "username": username, + "password": password, + }) +} + +// authenticate initializes an authenticated client with given request body. +func authenticate(baseURL, path string, body any) (*Client, error) { + c, err := NewClient(baseURL, nil, nil) + if err != nil { + return nil, errors.Wrap(err, "new client") + } + + err = c.authenticate(path, body) + if err != nil { + return nil, errors.Wrap(err, "authenticate") + } + + return c, nil +} + +// Client is an authenticated client for a Sourcegraph user for doing e2e testing. +// The user may or may not be a site admin depends on how the client is instantiated. +// It works by simulating how the browser would send HTTP requests to the server. +type Client struct { + baseURL string + csrfToken string + csrfCookie *http.Cookie + sessionCookie *http.Cookie + + userID string + requestLogger logFunc + responseLogger logFunc +} + +type logFunc func(payload []byte) + +func noopLog(payload []byte) {} + +// NewClient instantiates a new client by performing a GET request then obtains the +// CSRF token and cookie from its response, if there is one (old versions of Sourcegraph only). +// If request- or responseLogger are provided, the request and response bodies, respectively, +// will be written to them for any GraphQL requests only. +func NewClient(baseURL string, requestLogger, responseLogger logFunc) (*Client, error) { + if requestLogger == nil { + requestLogger = noopLog + } + if responseLogger == nil { + responseLogger = noopLog + } + + resp, err := http.Get(baseURL) + if err != nil { + return nil, errors.Wrap(err, "get URL") + } + defer func() { _ = resp.Body.Close() }() + + p, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errors.Wrap(err, "read GET body") + } + + csrfToken := extractCSRFToken(string(p)) + var csrfCookie *http.Cookie + for _, cookie := range resp.Cookies() { + if cookie.Name == "sg_csrf_token" { + csrfCookie = cookie + break + } + } + + return &Client{ + baseURL: baseURL, + csrfToken: csrfToken, + csrfCookie: csrfCookie, + requestLogger: requestLogger, + responseLogger: responseLogger, + }, nil +} + +// authenticate is used to send an HTTP POST request to a URL that is able to authenticate +// a user with given body (marshalled to JSON), e.g. site admin init, sign in. Once the +// client is authenticated, the session cookie will be stored as a proof of authentication. +func (c *Client) authenticate(path string, body any) error { + p, err := jsoniter.Marshal(body) + if err != nil { + return errors.Wrap(err, "marshal body") + } + + req, err := http.NewRequest("POST", c.baseURL+path, bytes.NewReader(p)) + if err != nil { + return errors.Wrap(err, "new request") + } + req.Header.Set("Content-Type", "application/json") + if c.csrfToken != "" { + req.Header.Set("X-Csrf-Token", c.csrfToken) + } + if c.csrfCookie != nil { + req.AddCookie(c.csrfCookie) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return errors.Wrap(err, "do request") + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + p, err := io.ReadAll(resp.Body) + if err != nil { + return errors.Wrap(err, "read response body") + } + return errors.New(string(p)) + } + + var sessionCookie *http.Cookie + for _, cookie := range resp.Cookies() { + if cookie.Name == "sgs" { + sessionCookie = cookie + break + } + } + if sessionCookie == nil { + return errors.Wrap(err, `"sgs" cookie not found`) + } + c.sessionCookie = sessionCookie + + userID, err := c.CurrentUserID("") + if err != nil { + return errors.Wrap(err, "get current user") + } + c.userID = userID + return nil +} + +// CurrentUserID returns the current authenticated user's GraphQL node ID. +// An optional token can be passed to impersonate other users. +func (c *Client) CurrentUserID(token string) (string, error) { + const query = ` + query { + currentUser { + id + } + } +` + var resp struct { + Data struct { + CurrentUser struct { + ID string `json:"id"` + } `json:"currentUser"` + } `json:"data"` + } + err := c.GraphQL(token, query, nil, &resp) + if err != nil { + return "", errors.Wrap(err, "request GraphQL") + } + + return resp.Data.CurrentUser.ID, nil +} + +// CreateAccessToken creates an access token for the current user. +// An optional token can be passed to impersonate other users. +func (c *Client) CreateAccessToken(token string, scopes []string, note string) (string, error) { + userID, err := c.CurrentUserID("") + if err != nil { + return "", err + } + + const mutation = ` + mutation createAccessToken($user: ID!, $scopes: [String!]!, $note: String!) { + createAccessToken(user: $user, scopes: $scopes, note: $note) { + token + } + } +` + + var resp struct { + Data struct { + CreateAccessToken struct { + Token string `json:"token"` + } `json:"createAccessToken"` + } + } + + err = c.GraphQL(token, mutation, map[string]any{ + "user": userID, + "scopes": scopes, + "note": note, + }, &resp) + + if err != nil { + return "", err + } + + return resp.Data.CreateAccessToken.Token, nil +} + +var graphqlQueryNameRe = lazyregexp.New(`(query|mutation) +(\w)+`) + +// GraphQL makes a GraphQL request to the server on behalf of the user authenticated by the client. +// An optional token can be passed to impersonate other users. A nil target will skip unmarshalling +// the returned JSON response. +func (c *Client) GraphQL(token, query string, variables map[string]any, target any) error { + body, err := jsoniter.Marshal(map[string]any{ + "query": query, + "variables": variables, + }) + if err != nil { + return err + } + + var name string + if matches := graphqlQueryNameRe.FindStringSubmatch(query); len(matches) >= 2 { + name = matches[2] + } + + req, err := http.NewRequest("POST", fmt.Sprintf("%s/.api/graphql?%s", c.baseURL, name), bytes.NewReader(body)) + if err != nil { + return err + } + if token != "" { + req.Header.Set("Authorization", fmt.Sprintf("token %s", token)) + } else { + // NOTE: This header is required to authenticate our session with a session cookie, see: + // https://docs.sourcegraph.com/dev/security/csrf_security_model#authentication-in-api-endpoints + req.Header.Set("X-Requested-With", "Sourcegraph") + req.AddCookie(c.sessionCookie) + + // Older versions of Sourcegraph require a CSRF cookie. + if c.csrfCookie != nil { + req.AddCookie(c.csrfCookie) + } + } + + c.requestLogger(body) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer func() { _ = resp.Body.Close() }() + + body, err = io.ReadAll(resp.Body) + if err != nil { + return errors.Wrap(err, "read response body") + } + + c.responseLogger(body) + + // Check if the response format should be JSON + if strings.Contains(resp.Header.Get("Content-Type"), "application/json") { + // Try and see unmarshalling to errors + var errResp struct { + Errors []struct { + Message string `json:"message"` + } `json:"errors"` + } + err = jsoniter.Unmarshal(body, &errResp) + if err != nil { + return errors.Wrap(err, "unmarshal response body to errors") + } + if len(errResp.Errors) > 0 { + var errs error + for _, err := range errResp.Errors { + errs = errors.Append(errs, errors.New(err.Message)) + } + return errs + } + } + + if resp.StatusCode != http.StatusOK { + return errors.Errorf("%d: %s", resp.StatusCode, string(body)) + } + + if target == nil { + return nil + } + + return jsoniter.Unmarshal(body, &target) +} + +// extractCSRFToken extracts CSRF token from HTML response body. +func extractCSRFToken(body string) string { + anchor := `X-Csrf-Token":"` + i := strings.Index(body, anchor) + if i == -1 { + return "" + } + + j := strings.Index(body[i+len(anchor):], `","`) + if j == -1 { + return "" + } + + return body[i+len(anchor) : i+len(anchor)+j] +}