From 207843fdee83e65d41d54aa5e52bd065b6ccf019 Mon Sep 17 00:00:00 2001 From: Gianluca Arbezzano Date: Wed, 3 Feb 2021 11:26:43 +0100 Subject: [PATCH 1/2] Migrate tink-server to cobra and viper This PR refactors how tink-server starts. It uses Cobra and Viper same as tink-cli and tink-worker. Right now it tries to keep compatibility with the old way of doing things. Signed-off-by: Gianluca Arbezzano --- cmd/tink-server/main.go | 300 +++++++++++++++++++++++++-------- grpc-server/events.go | 3 +- grpc-server/grpc_server.go | 32 ++-- grpc-server/hardware.go | 42 ++--- grpc-server/template.go | 34 ++-- grpc-server/template_test.go | 4 +- grpc-server/tinkerbell.go | 9 +- grpc-server/tinkerbell_test.go | 33 ++-- grpc-server/workflow.go | 38 ++--- grpc-server/workflow_test.go | 4 +- http-server/http_server.go | 41 +++-- 11 files changed, 350 insertions(+), 190 deletions(-) diff --git a/cmd/tink-server/main.go b/cmd/tink-server/main.go index 5ebc324ae..ed291b71b 100644 --- a/cmd/tink-server/main.go +++ b/cmd/tink-server/main.go @@ -6,9 +6,14 @@ import ( "fmt" "os" "os/signal" + "strconv" + "strings" "syscall" "github.com/packethost/pkg/log" + "github.com/spf13/cobra" + "github.com/spf13/pflag" + "github.com/spf13/viper" "github.com/tinkerbell/tink/client/listener" "github.com/tinkerbell/tink/db" rpcServer "github.com/tinkerbell/tink/grpc-server" @@ -18,93 +23,248 @@ import ( var ( // version is set at build time version = "devel" - - logger log.Logger ) -func main() { - log, err := log.Init("github.com/tinkerbell/tink") - if err != nil { - panic(err) +// DaemonConfig represents all the values you can configure as part of the tink-server. +// You can change the configuration via environment variable, or file, or command flags. 
+type DaemonConfig struct { + Facility string + PGDatabase string + PGUSer string + PGPassword string + PGSSLMode string + OnlyMigration bool + GRPCAuthority string + TLSCert string + CertDir string + HTTPAuthority string + HTTPBasicAuthUsername string + HTTPBasicAuthPassword string +} + +func (c *DaemonConfig) AddFlags(fs *pflag.FlagSet) { + fs.StringVar(&c.Facility, "facility", "deprecated", "This is temporary. It will be removed") + fs.StringVar(&c.PGDatabase, "postgres-database", "tinkerbell", "The Postgres database name") + fs.StringVar(&c.PGUSer, "postgres-user", "tinkerbell", "The Postgres database username") + fs.StringVar(&c.PGPassword, "postgres-password", "tinkerbell", "The Postgres database password") + fs.StringVar(&c.PGSSLMode, "postgres-sslmode", "disable", "Enable or disable SSL mode in postgres") + fs.BoolVar(&c.OnlyMigration, "only-migration", false, "When enabled it only applies the database migrations and then exits") + fs.StringVar(&c.GRPCAuthority, "grpc-authority", ":42113", "The address used to expose the gRPC server") + fs.StringVar(&c.TLSCert, "tls-cert", "", "") + fs.StringVar(&c.CertDir, "cert-dir", "", "") + fs.StringVar(&c.HTTPAuthority, "http-authority", ":42114", "The address used to expose the HTTP server") +} + +func (c *DaemonConfig) PopulateFromLegacyEnvVar() { + if f := os.Getenv("FACILITY"); f != "" { + c.Facility = f } - logger = log - defer logger.Close() - log.Info("starting version " + version) - - ctx, closer := context.WithCancel(context.Background()) - errCh := make(chan error, 2) - facility := os.Getenv("FACILITY") - - // TODO(gianarb): I moved this up because we need to be sure that both - // connection, the one used for the resources and the one used for - // listening to events and notification are coming in the same way. 
- // BUT we should be using the right flags - connInfo := fmt.Sprintf("dbname=%s user=%s password=%s sslmode=%s", - os.Getenv("PGDATABASE"), - os.Getenv("PGUSER"), - os.Getenv("PGPASSWORD"), - os.Getenv("PGSSLMODE"), - ) - - dbCon, err := sql.Open("postgres", connInfo) - if err != nil { - logger.Error(err) - panic(err) + if pgdb := os.Getenv("PGDATABASE"); pgdb != "" { + c.PGDatabase = pgdb } - tinkDB := db.Connect(dbCon, logger) - - _, onlyMigration := os.LookupEnv("ONLY_MIGRATION") - if onlyMigration { - logger.Info("Applying migrations. This process will end when migrations will take place.") - numAppliedMigrations, err := tinkDB.Migrate() - if err != nil { - log.Fatal(err) - panic(err) + if pguser := os.Getenv("PGUSER"); pguser != "" { + c.PGUSer = pguser + } + if pgpass := os.Getenv("PGPASSWORD"); pgpass != "" { + c.PGPassword = pgpass + } + if pgssl := os.Getenv("PGSSLMODE"); pgssl != "" { + c.PGSSLMode = pgssl + } + if onlyMigration, isSet := os.LookupEnv("ONLY_MIGRATION"); isSet { + if b, err := strconv.ParseBool(onlyMigration); err == nil { + c.OnlyMigration = b } - log.With("num_applied_migrations", numAppliedMigrations).Info("Migrations applied successfully") - os.Exit(0) } + if tlsCert := os.Getenv("TINKERBELL_TLS_CERT"); tlsCert != "" { + c.TLSCert = tlsCert + } + if certDir := os.Getenv("TINKERBELL_CERTS_DIR"); certDir != "" { + c.CertDir = certDir + } + if grpcAuthority := os.Getenv("TINKERBELL_GRPC_AUTHORITY"); grpcAuthority != "" { + c.GRPCAuthority = grpcAuthority + } + if httpAuthority := os.Getenv("TINKERBELL_HTTP_AUTHORITY"); httpAuthority != "" { + c.HTTPAuthority = httpAuthority + } + if basicAuthUser := os.Getenv("TINK_AUTH_USERNAME"); basicAuthUser != "" { + c.HTTPBasicAuthUsername = basicAuthUser + } + if basicAuthPass := os.Getenv("TINK_AUTH_PASSWORD"); basicAuthPass != "" { + c.HTTPBasicAuthPassword = basicAuthPass + } +} - err = listener.Init(connInfo) +func main() { + logger, err := log.Init("github.com/tinkerbell/tink") if err != nil 
{ - log.Fatal(err) panic(err) } + defer logger.Close() - go tinkDB.PurgeEvents(errCh) + config := &DaemonConfig{} - numAvailableMigrations, err := tinkDB.CheckRequiredMigrations() - if err != nil { - log.Fatal(err) - panic(err) - } - if numAvailableMigrations != 0 { - log.Info("Your database schema is not up to date. Please apply migrations running tink-server with env var ONLY_MIGRATION set.") + cmd := NewRootCommand(config, logger) + if err := cmd.ExecuteContext(context.Background()); err != nil { + os.Exit(1) } - cert, modT := rpcServer.SetupGRPC(ctx, logger, facility, tinkDB, errCh) - httpServer.SetupHTTP(ctx, logger, cert, modT, errCh) +} + +func NewRootCommand(config *DaemonConfig, logger log.Logger) *cobra.Command { + cmd := &cobra.Command{ + Use: "tink-server", + PreRunE: func(cmd *cobra.Command, args []string) error { + viper, err := createViper(logger) + if err != nil { + return err + } + return applyViper(viper, cmd) + }, + RunE: func(cmd *cobra.Command, args []string) error { + // I am not sure if it is right for this to be here, + // but as last step I want to keep compatibility with + // what we have for a little bit and I think that's + // the most aggressive way we have to guarantee that + // the old way works as before. 
+ config.PopulateFromLegacyEnvVar() - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGTERM) - select { - case err = <-errCh: - logger.Error(err) - panic(err) - case sig := <-sigs: - logger.With("signal", sig.String()).Info("signal received, stopping servers") + logger.Info("starting version " + version) + + sigs := make(chan os.Signal, 1) + signal.Notify(sigs, syscall.SIGINT, syscall.SIGQUIT, syscall.SIGTERM) + ctx, closer := context.WithCancel(cmd.Context()) + defer closer() + // TODO(gianarb): I think we can do better in terms of + // graceful shutdown and error management but I want to + // figure this out in another PR + errCh := make(chan error, 2) + + // TODO(gianarb): I moved this up because we need to be sure that both + // connection, the one used for the resources and the one used for + // listening to events and notification are coming in the same way. + // BUT we should be using the right flags + connInfo := fmt.Sprintf("dbname=%s user=%s password=%s sslmode=%s", + config.PGDatabase, + config.PGUSer, + config.PGPassword, + config.PGSSLMode, + ) + + dbCon, err := sql.Open("postgres", connInfo) + if err != nil { + return err + } + tinkDB := db.Connect(dbCon, logger) + + if config.OnlyMigration { + logger.Info("Applying migrations. This process will end when migrations will take place.") + numAppliedMigrations, err := tinkDB.Migrate() + if err != nil { + return err + } + logger.With("num_applied_migrations", numAppliedMigrations).Info("Migrations applied successfully") + return nil + } + + err = listener.Init(connInfo) + if err != nil { + return err + } + + go tinkDB.PurgeEvents(errCh) + + numAvailableMigrations, err := tinkDB.CheckRequiredMigrations() + if err != nil { + return err + } + if numAvailableMigrations != 0 { + logger.Info("Your database schema is not up to date. 
Please apply migrations running tink-server with env var ONLY_MIGRATION set.") + } + + cert, modT := rpcServer.SetupGRPC(ctx, logger, &rpcServer.ConfigGRPCServer{ + Facility: config.Facility, + TLSCert: config.TLSCert, + GRPCAuthority: config.GRPCAuthority, + DB: tinkDB, + }, errCh) + + httpServer.SetupHTTP(ctx, logger, &httpServer.HTTPServerConfig{ + CertPEM: cert, + ModTime: modT, + GRPCAuthority: config.GRPCAuthority, + HTTPAuthority: config.HTTPAuthority, + HTTPBasicAuthUsername: config.HTTPBasicAuthUsername, + HTTPBasicAuthPassword: config.HTTPBasicAuthPassword, + }, errCh) + + <-ctx.Done() + select { + case err = <-errCh: + logger.Error(err) + case sig := <-sigs: + logger.With("signal", sig.String()).Info("signal received, stopping servers") + } + + // wait for grpc server to shutdown + err = <-errCh + if err != nil { + return err + } + err = <-errCh + if err != nil { + return err + } + return nil + }, } - closer() + config.AddFlags(cmd.Flags()) + return cmd +} - // wait for grpc server to shutdown - err = <-errCh - if err != nil { - log.Fatal(err) - panic(err) +func createViper(logger log.Logger) (*viper.Viper, error) { + v := viper.New() + v.AutomaticEnv() + v.SetConfigName("tink-server") + v.AddConfigPath("/etc/tinkerbell") + v.AddConfigPath(".") + v.SetEnvKeyReplacer(strings.NewReplacer("-", "_")) + + // If a config file is found, read it in. 
+ if err := v.ReadInConfig(); err != nil { + if _, ok := err.(viper.ConfigFileNotFoundError); !ok { + logger.With("configFile", v.ConfigFileUsed()).Error(err, "could not load config file") + return nil, err + } + logger.Info("no config file found") + } else { + logger.With("configFile", v.ConfigFileUsed()).Info("loaded config file") } - err = <-errCh - if err != nil { - log.Fatal(err) - panic(err) + + return v, nil +} + +func applyViper(v *viper.Viper, cmd *cobra.Command) error { + errors := []error{} + + cmd.Flags().VisitAll(func(f *pflag.Flag) { + if !f.Changed && v.IsSet(f.Name) { + val := v.Get(f.Name) + if err := cmd.Flags().Set(f.Name, fmt.Sprintf("%v", val)); err != nil { + errors = append(errors, err) + return + } + } + }) + + if len(errors) > 0 { + errs := []string{} + for _, err := range errors { + errs = append(errs, err.Error()) + } + return fmt.Errorf(strings.Join(errs, ", ")) } + + return nil } diff --git a/grpc-server/events.go b/grpc-server/events.go index 857b73622..87175f0c6 100644 --- a/grpc-server/events.go +++ b/grpc-server/events.go @@ -17,14 +17,13 @@ func (s *server) Watch(req *events.WatchRequest, stream events.EventsService_Wat return stream.Send(event) }) if err != nil && err != io.EOF { - logger.Error(err) return err } return listener.Listen(req, func(e *events.Event) error { err := stream.Send(e) if err != nil { - logger.With("eventTypes", req.EventTypes, "resourceTypes", req.ResourceTypes).Info("events stream closed") + s.logger.With("eventTypes", req.EventTypes, "resourceTypes", req.ResourceTypes).Info("events stream closed") return listener.RemoveHandlers(req) } return nil diff --git a/grpc-server/grpc_server.go b/grpc-server/grpc_server.go index 4335416f1..721aec377 100644 --- a/grpc-server/grpc_server.go +++ b/grpc-server/grpc_server.go @@ -24,11 +24,6 @@ import ( "google.golang.org/grpc/reflection" ) -var ( - logger log.Logger - grpcListenAddr = os.Getenv("TINKERBELL_GRPC_AUTHORITY") -) - // Server is the gRPC server for 
tinkerbell type server struct { cert []byte @@ -42,25 +37,34 @@ type server struct { watchLock sync.RWMutex watch map[string]chan string + + logger log.Logger +} + +type ConfigGRPCServer struct { + Facility string + TLSCert string + GRPCAuthority string + DB *db.TinkDB } // SetupGRPC setup and return a gRPC server -func SetupGRPC(ctx context.Context, log log.Logger, facility string, db *db.TinkDB, errCh chan<- error) ([]byte, time.Time) { +func SetupGRPC(ctx context.Context, logger log.Logger, config *ConfigGRPCServer, errCh chan<- error) ([]byte, time.Time) { params := []grpc.ServerOption{ grpc.UnaryInterceptor(grpc_prometheus.UnaryServerInterceptor), grpc.StreamInterceptor(grpc_prometheus.StreamServerInterceptor), } - logger = log - metrics.SetupMetrics(facility, logger) + metrics.SetupMetrics(config.Facility, logger) server := &server{ - db: db, + db: config.DB, dbReady: true, + logger: logger, } - if cert := os.Getenv("TINKERBELL_TLS_CERT"); cert != "" { + if cert := config.TLSCert; cert != "" { server.cert = []byte(cert) server.modT = time.Now() } else { - tlsCert, certPEM, modT := getCerts(facility, logger) + tlsCert, certPEM, modT := getCerts(config.Facility, logger) params = append(params, grpc.Creds(credentials.NewServerTLSFromCert(&tlsCert))) server.cert = certPEM server.modT = modT @@ -77,11 +81,7 @@ func SetupGRPC(ctx context.Context, log log.Logger, facility string, db *db.Tink grpc_prometheus.Register(s) go func() { - logger.Info("serving grpc") - if grpcListenAddr == "" { - grpcListenAddr = ":42113" - } - lis, err := net.Listen("tcp", grpcListenAddr) + lis, err := net.Listen("tcp", config.GRPCAuthority) if err != nil { err = errors.Wrap(err, "failed to listen") logger.Error(err) diff --git a/grpc-server/hardware.go b/grpc-server/hardware.go index fc3404180..ee6428c62 100644 --- a/grpc-server/hardware.go +++ b/grpc-server/hardware.go @@ -22,7 +22,7 @@ const ( ) func (s *server) Push(ctx context.Context, in *hardware.PushRequest) (*hardware.Empty, 
error) { - logger.Info("push") + s.logger.Info("push") labels := prometheus.Labels{"method": "Push", "op": ""} metrics.CacheInFlight.With(labels).Inc() defer metrics.CacheInFlight.With(labels).Dec() @@ -33,7 +33,7 @@ func (s *server) Push(ctx context.Context, in *hardware.PushRequest) (*hardware. hw := in.GetData() if hw == nil { err := errors.New("expected data not to be nil") - logger.Error(err) + s.logger.Error(err) return &hardware.Empty{}, err } @@ -43,7 +43,7 @@ func (s *server) Push(ctx context.Context, in *hardware.PushRequest) (*hardware. metrics.CacheTotals.With(labels).Inc() metrics.CacheErrors.With(labels).Inc() err := errors.New("id must be set to a UUID, got id: " + hw.Id) - logger.Error(err) + s.logger.Error(err) return &hardware.Empty{}, err } @@ -59,7 +59,7 @@ func (s *server) Push(ctx context.Context, in *hardware.PushRequest) (*hardware. const msg = "inserting into DB" data, err := json.Marshal(hw) if err != nil { - logger.Error(err) + s.logger.Error(err) } labels["op"] = "insert" @@ -68,18 +68,18 @@ func (s *server) Push(ctx context.Context, in *hardware.PushRequest) (*hardware. timer := prometheus.NewTimer(metrics.CacheDuration.With(labels)) defer timer.ObserveDuration() - logger.Info(msg) + s.logger.Info(msg) err = s.db.InsertIntoDB(ctx, string(data)) - logger.Info("done " + msg) + s.logger.Info("done " + msg) if err != nil { metrics.CacheErrors.With(labels).Inc() - l := logger + l := s.logger if pqErr := db.Error(err); pqErr != nil { l = l.With("detail", pqErr.Detail, "where", pqErr.Where) } l.Error(err) } - logger.With("id", hw.Id).Info("data pushed") + s.logger.With("id", hw.Id).Info("data pushed") s.watchLock.RLock() if ch := s.watch[hw.Id]; ch != nil { @@ -90,7 +90,7 @@ func (s *server) Push(ctx context.Context, in *hardware.PushRequest) (*hardware. 
} } s.watchLock.RUnlock() - logger.With("id", hw.Id).Info("skipping blocked watcher") + s.logger.With("id", hw.Id).Info("skipping blocked watcher") return &hardware.Empty{}, err } @@ -182,7 +182,7 @@ func (s *server) All(_ *hardware.Empty, stream hardware.HardwareService_AllServe } func (s *server) DeprecatedWatch(in *hardware.GetRequest, stream hardware.HardwareService_DeprecatedWatchServer) error { - l := logger.With("id", in.Id) + l := s.logger.With("id", in.Id) ch := make(chan string, 1) s.watchLock.Lock() @@ -252,7 +252,7 @@ func (s *server) ModTime() time.Time { } func (s *server) Delete(ctx context.Context, in *hardware.DeleteRequest) (*hardware.Empty, error) { - logger.Info("delete") + s.logger.Info("delete") labels := prometheus.Labels{"method": "Delete", "op": ""} metrics.CacheInFlight.With(labels).Inc() defer metrics.CacheInFlight.With(labels).Dec() @@ -264,11 +264,11 @@ func (s *server) Delete(ctx context.Context, in *hardware.DeleteRequest) (*hardw metrics.CacheTotals.With(labels).Inc() metrics.CacheErrors.With(labels).Inc() err := errors.New("id must be set to a UUID") - logger.Error(err) + s.logger.Error(err) return &hardware.Empty{}, err } - logger.With("id", in.Id).Info("data deleted") + s.logger.With("id", in.Id).Info("data deleted") labels["op"] = "delete" const msg = "deleting into DB" @@ -277,16 +277,16 @@ func (s *server) Delete(ctx context.Context, in *hardware.DeleteRequest) (*hardw timer := prometheus.NewTimer(metrics.CacheDuration.With(labels)) defer timer.ObserveDuration() - logger.Info(msg) + s.logger.Info(msg) err := s.db.DeleteFromDB(ctx, in.Id) - logger.Info("done " + msg) + s.logger.Info("done " + msg) if err != nil { metrics.CacheErrors.With(labels).Inc() - l := logger + logger := s.logger if pqErr := db.Error(err); pqErr != nil { - l = l.With("detail", pqErr.Detail, "where", pqErr.Where) + logger = s.logger.With("detail", pqErr.Detail, "where", pqErr.Where) } - l.Error(err) + logger.Error(err) } s.watchLock.RLock() @@ -295,7 +295,7 
@@ func (s *server) Delete(ctx context.Context, in *hardware.DeleteRequest) (*hardw case ch <- in.Id: default: metrics.WatchMissTotal.Inc() - logger.With("id", in.Id).Info("skipping blocked watcher") + s.logger.With("id", in.Id).Info("skipping blocked watcher") } } s.watchLock.RUnlock() @@ -308,11 +308,11 @@ func (s *server) validateHardwareData(ctx context.Context, hw *hardware.Hardware mac := iface.GetDhcp().GetMac() if data, _ := s.db.GetByMAC(ctx, mac); data != "" { - logger.With("MAC", mac).Info(duplicateMAC) + s.logger.With("MAC", mac).Info(duplicateMAC) newhw := hardware.Hardware{} if err := json.Unmarshal([]byte(data), &newhw); err != nil { - logger.Error(err, "Failed to unmarshal hardware data") + s.logger.Error(err, "Failed to unmarshal hardware data") return err } diff --git a/grpc-server/template.go b/grpc-server/template.go index 90ab1eb87..44b054f09 100644 --- a/grpc-server/template.go +++ b/grpc-server/template.go @@ -15,7 +15,7 @@ import ( // CreateTemplate implements template.CreateTemplate func (s *server) CreateTemplate(ctx context.Context, in *template.WorkflowTemplate) (*template.CreateResponse, error) { - logger.Info("createtemplate") + s.logger.Info("createtemplate") labels := prometheus.Labels{"method": "CreateTemplate", "op": ""} metrics.CacheInFlight.With(labels).Inc() defer metrics.CacheInFlight.With(labels).Dec() @@ -28,24 +28,24 @@ func (s *server) CreateTemplate(ctx context.Context, in *template.WorkflowTempla timer := prometheus.NewTimer(metrics.CacheDuration.With(labels)) defer timer.ObserveDuration() - logger.Info(msg) + s.logger.Info(msg) err := s.db.CreateTemplate(ctx, in.Name, in.Data, id) if err != nil { metrics.CacheErrors.With(labels).Inc() - l := logger + l := s.logger if pqErr := db.Error(err); pqErr != nil { l = l.With("detail", pqErr.Detail, "where", pqErr.Where) } l.Error(err) return &template.CreateResponse{}, err } - logger.Info("done " + msg) + s.logger.Info("done " + msg) return &template.CreateResponse{Id: 
id.String()}, err } // GetTemplate implements template.GetTemplate func (s *server) GetTemplate(ctx context.Context, in *template.GetRequest) (*template.WorkflowTemplate, error) { - logger.Info("gettemplate") + s.logger.Info("gettemplate") labels := prometheus.Labels{"method": "GetTemplate", "op": ""} metrics.CacheInFlight.With(labels).Inc() defer metrics.CacheInFlight.With(labels).Dec() @@ -57,16 +57,16 @@ func (s *server) GetTemplate(ctx context.Context, in *template.GetRequest) (*tem timer := prometheus.NewTimer(metrics.CacheDuration.With(labels)) defer timer.ObserveDuration() - logger.Info(msg) + s.logger.Info(msg) fields := map[string]string{ "id": in.GetId(), "name": in.GetName(), } id, n, d, err := s.db.GetTemplate(ctx, fields, false) - logger.Info("done " + msg) + s.logger.Info("done " + msg) if err != nil { metrics.CacheErrors.With(labels).Inc() - l := logger + l := s.logger if pqErr := db.Error(err); pqErr != nil { l = l.With("detail", pqErr.Detail, "where", pqErr.Where) } @@ -77,7 +77,7 @@ func (s *server) GetTemplate(ctx context.Context, in *template.GetRequest) (*tem // DeleteTemplate implements template.DeleteTemplate func (s *server) DeleteTemplate(ctx context.Context, in *template.GetRequest) (*template.Empty, error) { - logger.Info("deletetemplate") + s.logger.Info("deletetemplate") labels := prometheus.Labels{"method": "DeleteTemplate", "op": ""} metrics.CacheInFlight.With(labels).Inc() defer metrics.CacheInFlight.With(labels).Dec() @@ -89,12 +89,12 @@ func (s *server) DeleteTemplate(ctx context.Context, in *template.GetRequest) (* timer := prometheus.NewTimer(metrics.CacheDuration.With(labels)) defer timer.ObserveDuration() - logger.Info(msg) + s.logger.Info(msg) err := s.db.DeleteTemplate(ctx, in.GetId()) - logger.Info("done " + msg) + s.logger.Info("done " + msg) if err != nil { metrics.CacheErrors.With(labels).Inc() - l := logger + l := s.logger if pqErr := db.Error(err); pqErr != nil { l = l.With("detail", pqErr.Detail, "where", pqErr.Where) 
} @@ -105,7 +105,7 @@ func (s *server) DeleteTemplate(ctx context.Context, in *template.GetRequest) (* // ListTemplates implements template.ListTemplates func (s *server) ListTemplates(in *template.ListRequest, stream template.TemplateService_ListTemplatesServer) error { - logger.Info("listtemplates") + s.logger.Info("listtemplates") labels := prometheus.Labels{"method": "ListTemplates", "op": "list"} metrics.CacheTotals.With(labels).Inc() metrics.CacheInFlight.With(labels).Inc() @@ -141,7 +141,7 @@ func (s *server) ListTemplates(in *template.ListRequest, stream template.Templat // UpdateTemplate implements template.UpdateTemplate func (s *server) UpdateTemplate(ctx context.Context, in *template.WorkflowTemplate) (*template.Empty, error) { - logger.Info("updatetemplate") + s.logger.Info("updatetemplate") labels := prometheus.Labels{"method": "UpdateTemplate", "op": ""} metrics.CacheInFlight.With(labels).Inc() defer metrics.CacheInFlight.With(labels).Dec() @@ -153,12 +153,12 @@ func (s *server) UpdateTemplate(ctx context.Context, in *template.WorkflowTempla timer := prometheus.NewTimer(metrics.CacheDuration.With(labels)) defer timer.ObserveDuration() - logger.Info(msg) + s.logger.Info(msg) err := s.db.UpdateTemplate(ctx, in.Name, in.Data, uuid.MustParse(in.Id)) - logger.Info("done " + msg) + s.logger.Info("done " + msg) if err != nil { metrics.CacheErrors.With(labels).Inc() - l := logger + l := s.logger if pqErr := db.Error(err); pqErr != nil { l = l.With("detail", pqErr.Detail, "where", pqErr.Where) } diff --git a/grpc-server/template_test.go b/grpc-server/template_test.go index 5bafe942b..56acad8b7 100644 --- a/grpc-server/template_test.go +++ b/grpc-server/template_test.go @@ -140,7 +140,7 @@ func TestCreateTemplate(t *testing.T) { tc := testCases[name] t.Run(name, func(t *testing.T) { t.Parallel() - s := testServer(tc.args.db) + s := testServer(t, tc.args.db) res, err := s.CreateTemplate(ctx, &pb.WorkflowTemplate{Name: tc.args.name, Data: tc.args.template}) if 
tc.want.expectedError { assert.Error(t, err) @@ -311,7 +311,7 @@ func TestGetTemplate(t *testing.T) { tc := testCases[name] t.Run(name, func(t *testing.T) { t.Parallel() - s := testServer(tc.args.db) + s := testServer(t, tc.args.db) res, err := s.GetTemplate(ctx, tc.args.getRequest) if tc.err { assert.Error(t, err) diff --git a/grpc-server/tinkerbell.go b/grpc-server/tinkerbell.go index 838d62698..956144c2b 100644 --- a/grpc-server/tinkerbell.go +++ b/grpc-server/tinkerbell.go @@ -6,6 +6,7 @@ import ( "strconv" "time" + "github.com/packethost/pkg/log" "github.com/tinkerbell/tink/db" pb "github.com/tinkerbell/tink/protos/workflow" "google.golang.org/grpc/codes" @@ -38,7 +39,7 @@ func (s *server) GetWorkflowContexts(req *pb.WorkflowContextRequest, stream pb.W if err != nil { return status.Errorf(codes.Aborted, err.Error()) } - if isApplicableToSend(context.Background(), wfContext, req.WorkerId, s.db) { + if isApplicableToSend(context.Background(), s.logger, wfContext, req.WorkerId, s.db) { if err := stream.Send(wfContext); err != nil { return err } @@ -92,7 +93,7 @@ func (s *server) ReportActionStatus(context context.Context, req *pb.WorkflowAct return nil, status.Errorf(codes.InvalidArgument, errInvalidActionName) } - l := logger.With("actionName", req.GetActionName(), "workflowID", req.GetWorkflowId()) + l := s.logger.With("actionName", req.GetActionName(), "workflowID", req.GetWorkflowId()) l.Info(fmt.Sprintf(msgReceivedStatus, req.GetActionStatus())) wfContext, err := s.db.GetWorkflowContexts(context, wfID) @@ -134,7 +135,7 @@ func (s *server) ReportActionStatus(context context.Context, req *pb.WorkflowAct return &pb.Empty{}, status.Error(codes.Aborted, err.Error()) } - l = logger.With( + l = s.logger.With( "workflowID", wfContext.GetWorkflowId(), "currentWorker", wfContext.GetCurrentWorker(), "currentTask", wfContext.GetCurrentTask(), @@ -216,7 +217,7 @@ func getWorkflowActions(context context.Context, db db.Database, wfID string) (* // isApplicableToSend checks 
if a particular workflow context is applicable or if it is needed to // be sent to a worker based on the state of the current action and the targeted workerID -func isApplicableToSend(context context.Context, wfContext *pb.WorkflowContext, workerID string, db db.Database) bool { +func isApplicableToSend(context context.Context, logger log.Logger, wfContext *pb.WorkflowContext, workerID string, db db.Database) bool { if wfContext.GetCurrentActionState() == pb.State_STATE_FAILED || wfContext.GetCurrentActionState() == pb.State_STATE_TIMEOUT { return false diff --git a/grpc-server/tinkerbell_test.go b/grpc-server/tinkerbell_test.go index 14dbeeca9..81701b830 100644 --- a/grpc-server/tinkerbell_test.go +++ b/grpc-server/tinkerbell_test.go @@ -28,17 +28,17 @@ const ( var wfData = []byte("{'os': 'ubuntu', 'base_url': 'http://192.168.1.1/'}") -func testServer(db db.Database) *server { +func testServer(t *testing.T, db db.Database) *server { + l, _ := log.Init("github.com/tinkerbell/tink") return &server{ - db: db, + logger: l, + db: db, } } func TestMain(m *testing.M) { - l, _ := log.Init("github.com/tinkerbell/tink") - logger = l.Package("grpcserver") - metrics.SetupMetrics("onprem", logger) + metrics.SetupMetrics("onprem", l.Package("grpcserver")) os.Exit(m.Run()) } @@ -122,7 +122,7 @@ func TestGetWorkflowContextList(t *testing.T) { defer cancel() for name, tc := range testCases { t.Run(name, func(t *testing.T) { - s := testServer(tc.args.db) + s := testServer(t, tc.args.db) res, err := s.GetWorkflowContextList(ctx, &pb.WorkflowContextRequest{WorkerId: tc.args.workerID}) if err != nil { assert.Error(t, err) @@ -205,7 +205,7 @@ func TestGetWorkflowActions(t *testing.T) { defer cancel() for name, tc := range testCases { t.Run(name, func(t *testing.T) { - s := testServer(tc.args.db) + s := testServer(t, tc.args.db) res, err := s.GetWorkflowActions(ctx, &pb.WorkflowActionsRequest{WorkflowId: tc.args.workflowID}) if err != nil { assert.True(t, tc.want.expectedError) @@ 
-546,7 +546,7 @@ func TestReportActionStatus(t *testing.T) { defer cancel() for name, tc := range testCases { t.Run(name, func(t *testing.T) { - s := testServer(tc.args.db) + s := testServer(t, tc.args.db) res, err := s.ReportActionStatus(ctx, &pb.WorkflowActionStatus{ WorkflowId: tc.args.workflowID, @@ -626,7 +626,7 @@ func TestUpdateWorkflowData(t *testing.T) { defer cancel() for name, tc := range testCases { t.Run(name, func(t *testing.T) { - s := testServer(tc.args.db) + s := testServer(t, tc.args.db) res, err := s.UpdateWorkflowData( ctx, &pb.UpdateWorkflowDataRequest{ WorkflowId: tc.args.workflowID, @@ -715,7 +715,7 @@ func TestGetWorkflowData(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() for name, tc := range testCases { - s := testServer(tc.args.db) + s := testServer(t, tc.args.db) t.Run(name, func(t *testing.T) { res, err := s.GetWorkflowData(ctx, &pb.GetWorkflowDataRequest{WorkflowId: tc.args.workflowID}) if err != nil { @@ -798,7 +798,7 @@ func TestGetWorkflowsForWorker(t *testing.T) { for name, tc := range testCases { t.Run(name, func(t *testing.T) { - s := testServer(tc.args.db) + s := testServer(t, tc.args.db) res, err := getWorkflowsForWorker(s.db, tc.args.workerID) if err != nil { assert.True(t, tc.want.expectedError) @@ -886,7 +886,7 @@ func TestGetWorkflowMetadata(t *testing.T) { defer cancel() for name, tc := range testCases { t.Run(name, func(t *testing.T) { - s := testServer(tc.args.db) + s := testServer(t, tc.args.db) res, err := s.GetWorkflowMetadata(ctx, &pb.GetWorkflowDataRequest{WorkflowId: tc.args.workflowID}) if err != nil { assert.True(t, tc.want.expectedError) @@ -957,7 +957,7 @@ func TestGetWorkflowDataVersion(t *testing.T) { defer cancel() for name, tc := range testCases { t.Run(name, func(t *testing.T) { - s := testServer(tc.args.db) + s := testServer(t, tc.args.db) res, err := s.GetWorkflowDataVersion(ctx, &pb.GetWorkflowDataRequest{WorkflowId: workflowID}) 
assert.Equal(t, tc.want.version, res.Version) if err != nil { @@ -1166,13 +1166,14 @@ func TestIsApplicableToSend(t *testing.T) { }, } + logger, _ := log.Init("test") ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) defer cancel() for name, tc := range testCases { t.Run(name, func(t *testing.T) { - s := testServer(tc.args.db) + s := testServer(t, tc.args.db) wfContext, _ := s.db.GetWorkflowContexts(ctx, workflowID) - res := isApplicableToSend(ctx, wfContext, workerID, s.db) + res := isApplicableToSend(ctx, logger, wfContext, workerID, s.db) assert.Equal(t, tc.want.isApplicable, res) }) } @@ -1262,7 +1263,7 @@ func TestIsLastAction(t *testing.T) { defer cancel() for name, tc := range testCases { t.Run(name, func(t *testing.T) { - s := testServer(tc.args.db) + s := testServer(t, tc.args.db) wfContext, _ := s.db.GetWorkflowContexts(ctx, workflowID) actions, _ := s.db.GetWorkflowActions(ctx, workflowID) res := isLastAction(wfContext, actions) diff --git a/grpc-server/workflow.go b/grpc-server/workflow.go index f6ff955b5..cd3c54299 100644 --- a/grpc-server/workflow.go +++ b/grpc-server/workflow.go @@ -25,7 +25,7 @@ const errFailedToGetTemplate = "failed to get template with ID %s" // CreateWorkflow implements workflow.CreateWorkflow func (s *server) CreateWorkflow(ctx context.Context, in *workflow.CreateRequest) (*workflow.CreateResponse, error) { - logger.Info("createworkflow") + s.logger.Info("createworkflow") labels := prometheus.Labels{"method": "CreateWorkflow", "op": ""} metrics.CacheInFlight.With(labels).Inc() defer metrics.CacheInFlight.With(labels).Dec() @@ -41,7 +41,7 @@ func (s *server) CreateWorkflow(ctx context.Context, in *workflow.CreateRequest) timer := prometheus.NewTimer(metrics.CacheDuration.With(labels)) defer timer.ObserveDuration() - logger.Info(msg) + s.logger.Info(msg) fields := map[string]string{ "id": in.GetTemplate(), } @@ -53,7 +53,7 @@ func (s *server) CreateWorkflow(ctx context.Context, in 
*workflow.CreateRequest) if err != nil { metrics.CacheErrors.With(labels).Inc() - logger.Error(err) + s.logger.Error(err) return &workflow.CreateResponse{}, err } @@ -66,7 +66,7 @@ func (s *server) CreateWorkflow(ctx context.Context, in *workflow.CreateRequest) err = s.db.CreateWorkflow(ctx, wf, data, id) if err != nil { metrics.CacheErrors.With(labels).Inc() - l := logger + l := s.logger if pqErr := db.Error(err); pqErr != nil { l = l.With("detail", pqErr.Detail, "where", pqErr.Where) } @@ -74,14 +74,14 @@ func (s *server) CreateWorkflow(ctx context.Context, in *workflow.CreateRequest) return &workflow.CreateResponse{}, err } - l := logger.With("workflowID", id.String()) + l := s.logger.With("workflowID", id.String()) l.Info("done " + msg) return &workflow.CreateResponse{Id: id.String()}, err } // GetWorkflow implements workflow.GetWorkflow func (s *server) GetWorkflow(ctx context.Context, in *workflow.GetRequest) (*workflow.Workflow, error) { - logger.Info("getworkflow") + s.logger.Info("getworkflow") labels := prometheus.Labels{"method": "GetWorkflow", "op": ""} metrics.CacheInFlight.With(labels).Inc() defer metrics.CacheInFlight.With(labels).Dec() @@ -93,11 +93,11 @@ func (s *server) GetWorkflow(ctx context.Context, in *workflow.GetRequest) (*wor timer := prometheus.NewTimer(metrics.CacheDuration.With(labels)) defer timer.ObserveDuration() - logger.Info(msg) + s.logger.Info(msg) w, err := s.db.GetWorkflow(ctx, in.Id) if err != nil { metrics.CacheErrors.With(labels).Inc() - l := logger + l := s.logger if pqErr := db.Error(err); pqErr != nil { l = l.With("detail", pqErr.Detail, "where", pqErr.Where) } @@ -122,21 +122,21 @@ func (s *server) GetWorkflow(ctx context.Context, in *workflow.GetRequest) (*wor State: state[w.State], Data: data, } - l := logger.With("workflowID", w.ID) + l := s.logger.With("workflowID", w.ID) l.Info("done " + msg) return wf, err } // DeleteWorkflow implements workflow.DeleteWorkflow func (s *server) DeleteWorkflow(ctx context.Context, in 
*workflow.GetRequest) (*workflow.Empty, error) { - logger.Info("deleteworkflow") + s.logger.Info("deleteworkflow") labels := prometheus.Labels{"method": "DeleteWorkflow", "op": ""} metrics.CacheInFlight.With(labels).Inc() defer metrics.CacheInFlight.With(labels).Dec() const msg = "deleting a workflow" labels["op"] = "delete" - l := logger.With("workflowID", in.GetId()) + l := s.logger.With("workflowID", in.GetId()) metrics.CacheTotals.With(labels).Inc() timer := prometheus.NewTimer(metrics.CacheDuration.With(labels)) @@ -146,7 +146,7 @@ func (s *server) DeleteWorkflow(ctx context.Context, in *workflow.GetRequest) (* err := s.db.DeleteWorkflow(ctx, in.Id, workflow.State_value[workflow.State_STATE_RUNNING.String()]) if err != nil { metrics.CacheErrors.With(labels).Inc() - l := logger + l := s.logger if pqErr := db.Error(err); pqErr != nil { l = l.With("detail", pqErr.Detail, "where", pqErr.Where) } @@ -158,7 +158,7 @@ func (s *server) DeleteWorkflow(ctx context.Context, in *workflow.GetRequest) (* // ListWorkflows implements workflow.ListWorkflows func (s *server) ListWorkflows(_ *workflow.Empty, stream workflow.WorkflowService_ListWorkflowsServer) error { - logger.Info("listworkflows") + s.logger.Info("listworkflows") labels := prometheus.Labels{"method": "ListWorkflows", "op": "list"} metrics.CacheTotals.With(labels).Inc() metrics.CacheInFlight.With(labels).Inc() @@ -195,7 +195,7 @@ func (s *server) ListWorkflows(_ *workflow.Empty, stream workflow.WorkflowServic } func (s *server) GetWorkflowContext(ctx context.Context, in *workflow.GetRequest) (*workflow.WorkflowContext, error) { - logger.Info("GetworkflowContext") + s.logger.Info("GetworkflowContext") labels := prometheus.Labels{"method": "GetWorkflowContext", "op": ""} metrics.CacheInFlight.With(labels).Inc() defer metrics.CacheInFlight.With(labels).Dec() @@ -207,11 +207,11 @@ func (s *server) GetWorkflowContext(ctx context.Context, in *workflow.GetRequest timer := 
prometheus.NewTimer(metrics.CacheDuration.With(labels)) defer timer.ObserveDuration() - logger.Info(msg) + s.logger.Info(msg) w, err := s.db.GetWorkflowContexts(ctx, in.Id) if err != nil { metrics.CacheErrors.With(labels).Inc() - l := logger + l := s.logger if pqErr := db.Error(err); pqErr != nil { l = l.With("detail", pqErr.Detail, "where", pqErr.Where) } @@ -226,7 +226,7 @@ func (s *server) GetWorkflowContext(ctx context.Context, in *workflow.GetRequest CurrentActionState: workflow.State(w.CurrentActionState), TotalNumberOfActions: w.TotalNumberOfActions, } - l := logger.With( + l := s.logger.With( "workflowID", wf.GetWorkflowId(), "currentWorker", wf.GetCurrentWorker(), "currentTask", wf.GetCurrentTask(), @@ -241,7 +241,7 @@ func (s *server) GetWorkflowContext(ctx context.Context, in *workflow.GetRequest // ShowWorflowevents implements workflow.ShowWorflowEvents func (s *server) ShowWorkflowEvents(req *workflow.GetRequest, stream workflow.WorkflowService_ShowWorkflowEventsServer) error { - logger.Info("List workflows Events") + s.logger.Info("List workflows Events") labels := prometheus.Labels{"method": "ShowWorkflowEvents", "op": "list"} metrics.CacheTotals.With(labels).Inc() metrics.CacheInFlight.With(labels).Inc() @@ -274,7 +274,7 @@ func (s *server) ShowWorkflowEvents(req *workflow.GetRequest, stream workflow.Wo metrics.CacheErrors.With(labels).Inc() return err } - logger.Info("done listing workflows events") + s.logger.Info("done listing workflows events") metrics.CacheHits.With(labels).Inc() return nil } diff --git a/grpc-server/workflow_test.go b/grpc-server/workflow_test.go index 3e47d1402..7969f398d 100644 --- a/grpc-server/workflow_test.go +++ b/grpc-server/workflow_test.go @@ -95,7 +95,7 @@ func TestCreateWorkflow(t *testing.T) { defer cancel() for name, tc := range testCases { t.Run(name, func(t *testing.T) { - s := testServer(tc.args.db) + s := testServer(t, tc.args.db) res, err := s.CreateWorkflow(ctx, &workflow.CreateRequest{ Hardware: 
tc.args.wfHardware, Template: tc.args.wfTemplate, @@ -153,7 +153,7 @@ func TestGetWorkflow(t *testing.T) { defer cancel() for name, tc := range testCases { t.Run(name, func(t *testing.T) { - s := testServer(tc.args.db) + s := testServer(t, tc.args.db) res, err := s.GetWorkflow(ctx, &workflow.GetRequest{ Id: workflowID, }) diff --git a/http-server/http_server.go b/http-server/http_server.go index 092b98706..260559994 100644 --- a/http-server/http_server.go +++ b/http-server/http_server.go @@ -8,7 +8,6 @@ import ( "encoding/json" "net" "net/http" - "os" "runtime" "time" @@ -21,22 +20,27 @@ import ( ) var ( - gitRev = "unknown" - gitRevJSON []byte - grpcEndpoint = os.Getenv("TINKERBELL_GRPC_AUTHORITY") - httpListenAddr = os.Getenv("TINKERBELL_HTTP_AUTHORITY") - authUsername = os.Getenv("TINK_AUTH_USERNAME") - authPassword = os.Getenv("TINK_AUTH_PASSWORD") - startTime = time.Now() - logger log.Logger + gitRev = "unknown" + gitRevJSON []byte + startTime = time.Now() + logger log.Logger ) +type HTTPServerConfig struct { + CertPEM []byte + ModTime time.Time + GRPCAuthority string + HTTPAuthority string + HTTPBasicAuthUsername string + HTTPBasicAuthPassword string +} + // SetupHTTP setup and return an HTTP server -func SetupHTTP(ctx context.Context, lg log.Logger, certPEM []byte, modTime time.Time, errCh chan<- error) { +func SetupHTTP(ctx context.Context, lg log.Logger, config *HTTPServerConfig, errCh chan<- error) { logger = lg cp := x509.NewCertPool() - ok := cp.AppendCertsFromPEM(certPEM) + ok := cp.AppendCertsFromPEM(config.CertPEM) if !ok { logger.Error(errors.New("parse cert")) } @@ -47,9 +51,7 @@ func SetupHTTP(ctx context.Context, lg log.Logger, certPEM []byte, modTime time. 
dialOpts := []grpc.DialOption{grpc.WithTransportCredentials(creds)} - if grpcEndpoint == "" { - grpcEndpoint = "localhost:42113" - } + grpcEndpoint := config.GRPCAuthority host, _, err := net.SplitHostPort(grpcEndpoint) if err != nil { logger.Error(err) @@ -71,19 +73,16 @@ func SetupHTTP(ctx context.Context, lg log.Logger, certPEM []byte, modTime time. } http.HandleFunc("/cert", func(w http.ResponseWriter, r *http.Request) { - http.ServeContent(w, r, "server.pem", modTime, bytes.NewReader(certPEM)) + http.ServeContent(w, r, "server.pem", config.ModTime, bytes.NewReader(config.CertPEM)) }) http.Handle("/metrics", promhttp.Handler()) setupGitRevJSON() http.HandleFunc("/version", versionHandler) http.HandleFunc("/healthz", healthCheckHandler) - http.Handle("/", BasicAuth(mux)) + http.Handle("/", BasicAuth(config.HTTPBasicAuthUsername, config.HTTPBasicAuthPassword, mux)) - if httpListenAddr == "" { - httpListenAddr = ":42114" - } srv := &http.Server{ - Addr: httpListenAddr, + Addr: config.HTTPAuthority, } go func() { logger.Info("serving http") @@ -145,7 +144,7 @@ func setupGitRevJSON() { // BasicAuth adds authentication to the routes handled by handler // skips authentication if both authUsername and authPassword aren't set -func BasicAuth(handler http.Handler) http.Handler { +func BasicAuth(authUsername, authPassword string, handler http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if authUsername != "" || authPassword != "" { user, pass, ok := r.BasicAuth() From 550c63a4b5ee8b9785ed7e28663e5beb0f882f23 Mon Sep 17 00:00:00 2001 From: Gianluca Arbezzano Date: Tue, 9 Feb 2021 17:52:46 +0100 Subject: [PATCH 2/2] Fix doc for only-migration flag Signed-off-by: Gianluca Arbezzano --- cmd/tink-server/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/tink-server/main.go b/cmd/tink-server/main.go index ed291b71b..d37e24453 100644 --- a/cmd/tink-server/main.go +++ b/cmd/tink-server/main.go @@ -48,7 
+48,7 @@ func (c *DaemonConfig) AddFlags(fs *pflag.FlagSet) { fs.StringVar(&c.PGUSer, "postgres-user", "tinkerbell", "The Postgres database username") fs.StringVar(&c.PGPassword, "postgres-password", "tinkerbell", "The Postgres database password") fs.StringVar(&c.PGSSLMode, "postgres-sslmode", "disable", "Enable or disable SSL mode in postgres") - fs.BoolVar(&c.OnlyMigration, "only-migration", false, "When enabled it ") + fs.BoolVar(&c.OnlyMigration, "only-migration", false, "When enabled the server applies the migration to postgres database and it exits") fs.StringVar(&c.GRPCAuthority, "grpc-authority", ":42113", "The address used to expose the gRPC server") fs.StringVar(&c.TLSCert, "tls-cert", "", "") fs.StringVar(&c.CertDir, "cert-dir", "", "")