Skip to content
This repository has been archived by the owner on Feb 27, 2023. It is now read-only.

Commit

Permalink
Add default session image to config (#347)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaasen committed Jan 20, 2022
1 parent d67770d commit 66fe818
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 35 deletions.
11 changes: 8 additions & 3 deletions cmd/beaker/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func newConfigListCommand() *cobra.Command {
t := reflect.TypeOf(*beakerConfig)
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
propertyKey := field.Tag.Get("yaml")
propertyKey := firstTag(field.Tag.Get("yaml"))
value := reflect.ValueOf(beakerConfig).Elem().FieldByName(field.Name).String()
if value == "" {
value = "(unset)"
Expand Down Expand Up @@ -68,7 +68,7 @@ func newConfigSetCommand() *cobra.Command {
found := false
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
if field.Tag.Get("yaml") == args[0] {
if firstTag(field.Tag.Get("yaml")) == args[0] {
found = true
// The following code assumes all values are strings and will not work with non-string values.
reflect.ValueOf(beakerCfg).Elem().FieldByName(field.Name).SetString(strings.TrimSpace(args[1]))
Expand Down Expand Up @@ -148,7 +148,7 @@ func newConfigUnsetCommand() *cobra.Command {
found := false
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
if field.Tag.Get("yaml") == args[0] {
if firstTag(field.Tag.Get("yaml")) == args[0] {
found = true
reflect.ValueOf(beakerCfg).Elem().FieldByName(field.Name).Set(reflect.Zero(field.Type))
}
Expand All @@ -163,3 +163,8 @@ func newConfigUnsetCommand() *cobra.Command {
},
}
}

// Return first fields from a YAML tag e.g. "name,omitempty" -> "name".
func firstTag(tag string) string {
return strings.Split(tag, ",")[0]
}
90 changes: 61 additions & 29 deletions cmd/beaker/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strconv"
"strings"

"github.com/allenai/beaker/config"
"github.com/allenai/bytefmt"
"github.com/beaker/client/api"
"github.com/beaker/client/client"
Expand All @@ -16,6 +17,8 @@ import (
"github.com/spf13/cobra"
)

const defaultImage = "beaker://ai2/cuda11.2-ubuntu20.04"

func newSessionCommand() *cobra.Command {
cmd := &cobra.Command{
Use: "session <command>",
Expand Down Expand Up @@ -70,11 +73,13 @@ To pass flags, use "--" e.g. "create -- ls -l"`,
var node string
var workspace string
var saveImage bool
cmd.Flags().StringVar(
var noUpdateDefaultImage bool
cmd.Flags().StringVarP(
&image,
"image",
"beaker://ai2/cuda11.2-ubuntu20.04",
"Base image to run, may be a Beaker or Docker image")
"i",
defaultImage,
"Base image to run, may be a Beaker or Docker image. Uses 'default_image' from the Beaker configuration if set.")
cmd.Flags().BoolVar(&localHome, "local-home", false, "Mount the invoking user's home directory, ignoring Beaker configuration")
cmd.Flags().StringVarP(&name, "name", "n", "", "Assign a name to the session")
cmd.Flags().StringVar(&node, "node", "", "Node that the session will run on. Defaults to current node.")
Expand All @@ -85,6 +90,11 @@ To pass flags, use "--" e.g. "create -- ls -l"`,
"s",
false,
"Save the result image of the session. A new image will be created in the session's workspace.")
cmd.Flags().BoolVar(
&noUpdateDefaultImage,
"no-update-default-image",
false,
"Do not update the default image when using --save-image.")

var secretEnv map[string]string
var secretMount map[string]string
Expand Down Expand Up @@ -134,6 +144,10 @@ To pass flags, use "--" e.g. "create -- ls -l"`,
}
}

if image == defaultImage && beakerConfig.DefaultImage != "" {
fmt.Printf("Defaulting to image %s\n", color.BlueString(beakerConfig.DefaultImage))
image = beakerConfig.DefaultImage
}
imageSource, err := getImageSource(image)
if err != nil {
return err
Expand Down Expand Up @@ -244,36 +258,54 @@ Do not write sensitive information outside of the home directory.
}
shouldCancel = false

if saveImage {
var job *api.Job
started := func(ctx context.Context) (bool, error) {
var err error
job, err = beaker.Job(session.ID).Get(ctx)
if err != nil {
return false, err
}
return job.Status.Finalized != nil, nil
}
if err := await(ctx, "Waiting for image capture to complete", started, 0); err != nil {
return fmt.Errorf("waiting for image capture to complete: %w", err)
}
if job.Status.Failed != nil {
return fmt.Errorf("session failed: %s", job.Status.Message)
}
images, err := beaker.Job(job.ID).GetImages(ctx)
if !saveImage {
return nil
}
var job *api.Job
started := func(ctx context.Context) (bool, error) {
var err error
job, err = beaker.Job(session.ID).Get(ctx)
if err != nil {
return err
}
if len(images) == 0 {
return fmt.Errorf("job has no result images")
return false, err
}
return job.Status.Finalized != nil, nil
}
if err := await(ctx, "Waiting for image capture to complete", started, 0); err != nil {
return fmt.Errorf("waiting for image capture to complete: %w", err)
}
if job.Status.Failed != nil {
return fmt.Errorf("session failed: %s", job.Status.Message)
}
images, err := beaker.Job(job.ID).GetImages(ctx)
if err != nil {
return err
}
if len(images) == 0 {
return fmt.Errorf("job has no result images")
}
if !quiet {
fmt.Printf("Image saved to %s: %s/im/%s\n",
color.BlueString(images[0].ID),
beaker.Address(),
images[0].ID)
}
if noUpdateDefaultImage {
if !quiet {
fmt.Printf(`Image saved to %[1]s: %[2]s/im/%[3]s
Resume this session with: beaker session create --image beaker://%[3]s
`, color.BlueString(images[0].ID), beaker.Address(), images[0].ID)
fmt.Printf(`Default image not updated.
Resume this session with: beaker session create --image beaker://%s
`, images[0].ID)
}
return nil
}
beakerConfig.DefaultImage = "beaker://" + images[0].ID
if err := config.WriteConfig(beakerConfig, config.GetFilePath()); err != nil {
return fmt.Errorf("setting default image: %w", err)
}
if !quiet {
fmt.Printf(`Default image updated in your config file: %s
Resume this session with: beaker session create
`, config.GetFilePath())
}

return nil
}
return cmd
Expand Down Expand Up @@ -546,7 +578,7 @@ func awaitSessionStart(session api.Job) (*api.Job, error) {
return false, err
}
if job.Status.Finalized != nil {
return false, fmt.Errorf("session finalized: %s", session.Status.Message)
return false, fmt.Errorf("session finalized: %s", job.Status.Message)
}
return job.Status.Started != nil, nil
}
Expand Down
7 changes: 4 additions & 3 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@ import (
// Config is a structured representation of a Beaker config file.
type Config struct {
// Client settings
BeakerAddress string `yaml:"agent_address"` // TODO: Find a better name than "agent_address"
UserToken string `yaml:"user_token"`
DefaultWorkspace string `yaml:"default_workspace"`
BeakerAddress string `yaml:"agent_address,omitempty"` // TODO: Find a better name than "agent_address"
UserToken string `yaml:"user_token,omitempty"`
DefaultWorkspace string `yaml:"default_workspace,omitempty"`
DefaultImage string `yaml:"default_image,omitempty"`
}

const (
Expand Down

0 comments on commit 66fe818

Please sign in to comment.