diff --git a/tests/robustness/failpoints.go b/tests/robustness/failpoints.go
index 9fb3d5a2bbad..397a9fa275a6 100644
--- a/tests/robustness/failpoints.go
+++ b/tests/robustness/failpoints.go
@@ -23,6 +23,7 @@ import (
 	"time"
 
 	"go.uber.org/zap"
+	healthpb "google.golang.org/grpc/health/grpc_health_v1"
 
 	clientv3 "go.etcd.io/etcd/client/v3"
 	"go.etcd.io/etcd/tests/v3/framework/e2e"
@@ -85,30 +86,74 @@ var (
 	}}
 )
 
-func triggerFailpoints(ctx context.Context, t *testing.T, lg *zap.Logger, clus *e2e.EtcdProcessCluster, config FailpointConfig) {
+func triggerFailpoints(ctx context.Context, t *testing.T, lg *zap.Logger, clus *e2e.EtcdProcessCluster, config FailpointConfig) error {
 	var err error
 	successes := 0
 	failures := 0
 	for _, proc := range clus.Procs {
 		if !config.failpoint.Available(proc) {
 			t.Errorf("Failpoint %q not available on %s", config.failpoint.Name(), proc.Config().Name)
-			return
+			return nil
 		}
 	}
 	for successes < config.count && failures < config.retries {
 		time.Sleep(config.waitBetweenTriggers)
-		lg.Info("Triggering failpoint\n", zap.String("failpoint", config.failpoint.Name()))
+
+		lg.Info("Verifying cluster health before failpoint", zap.String("failpoint", config.failpoint.Name()))
+		if err = verifyClusterHealth(ctx, t, clus); err != nil {
+			t.Errorf("failed to verify cluster health before failpoint injection, err: %v", err)
+			return err
+		}
+
+		lg.Info("Triggering failpoint", zap.String("failpoint", config.failpoint.Name()))
 		err = config.failpoint.Trigger(ctx, t, lg, clus)
 		if err != nil {
 			lg.Info("Failed to trigger failpoint", zap.String("failpoint", config.failpoint.Name()), zap.Error(err))
 			failures++
 			continue
 		}
+
+		lg.Info("Verifying cluster health after failpoint", zap.String("failpoint", config.failpoint.Name()))
+		if err = verifyClusterHealth(ctx, t, clus); err != nil {
+			t.Errorf("failed to verify cluster health after failpoint injection, err: %v", err)
+			return err
+		}
+
 		successes++
 	}
 	if successes < config.count || failures >= config.retries {
 		t.Errorf("failed to trigger failpoints enough times, err: %v", err)
 	}
+
+	return nil
+}
+
+func verifyClusterHealth(ctx context.Context, t *testing.T, clus *e2e.EtcdProcessCluster) error {
+	for i := 0; i < len(clus.Procs); i++ {
+		clusterClient, err := clientv3.New(clientv3.Config{
+			Endpoints:            clus.Procs[i].EndpointsGRPC(),
+			Logger:               zap.NewNop(),
+			DialKeepAliveTime:    1 * time.Millisecond,
+			DialKeepAliveTimeout: 5 * time.Millisecond,
+		})
+		if err != nil {
+			return fmt.Errorf("error creating client for cluster %s: %v", clus.Procs[i].Config().Name, err)
+		}
+		defer clusterClient.Close()
+
+		cli := healthpb.NewHealthClient(clusterClient.ActiveConnection())
+		resp, err := cli.Check(ctx, &healthpb.HealthCheckRequest{})
+		if err != nil {
+			return fmt.Errorf("error checking member %s health: %v", clus.Procs[i].Config().Name, err)
+		}
+		if resp.Status != healthpb.HealthCheckResponse_SERVING {
+			return fmt.Errorf("member %s health status expected %s, got %s",
+				clus.Procs[i].Config().Name,
+				healthpb.HealthCheckResponse_SERVING,
+				resp.Status)
+		}
+	}
+	return nil
 }
 
 type FailpointConfig struct {
diff --git a/tests/robustness/linearizability_test.go b/tests/robustness/linearizability_test.go
index 5619dc0b6007..958a29762526 100644
--- a/tests/robustness/linearizability_test.go
+++ b/tests/robustness/linearizability_test.go
@@ -197,8 +197,12 @@ func runScenario(ctx context.Context, t *testing.T, lg *zap.Logger, clus *e2e.Et
 	// Run multiple test components (traffic, failpoints, etc) in parallel and use canceling context to propagate stop signal.
 	g := errgroup.Group{}
 	trafficCtx, trafficCancel := context.WithCancel(ctx)
+	watchCtx, watchCancel := context.WithCancel(ctx)
 	g.Go(func() error {
-		triggerFailpoints(ctx, t, lg, clus, failpoint)
+		err := triggerFailpoints(ctx, t, lg, clus, failpoint)
+		if err != nil {
+			watchCancel()
+		}
 		time.Sleep(time.Second)
 		trafficCancel()
 		return nil
@@ -211,7 +215,7 @@ func runScenario(ctx context.Context, t *testing.T, lg *zap.Logger, clus *e2e.Et
 		return nil
 	})
 	g.Go(func() error {
-		responses = collectClusterWatchEvents(ctx, t, clus, maxRevisionChan)
+		responses = collectClusterWatchEvents(watchCtx, t, clus, maxRevisionChan)
 		return nil
 	})
 	g.Wait()
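
For reviewers unfamiliar with the health endpoint that verifyClusterHealth builds on: etcd members expose the standard gRPC health service (grpc.health.v1.Health), and leaving HealthCheckRequest.Service empty asks about the overall health of the server rather than a named sub-service. The standalone sketch below shows the same Check call against a single endpoint. The target address, timeout, and the checkHealth helper name are illustrative assumptions, not part of the patch.

// healthcheck_sketch.go — a minimal sketch of the gRPC health check that
// verifyClusterHealth performs, against one assumed endpoint.
package main

import (
	"context"
	"fmt"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	healthpb "google.golang.org/grpc/health/grpc_health_v1"
)

// checkHealth is a hypothetical helper, not from the PR.
func checkHealth(ctx context.Context, target string) error {
	conn, err := grpc.DialContext(ctx, target,
		grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		return fmt.Errorf("error dialing %s: %w", target, err)
	}
	defer conn.Close()

	// An empty Service field queries the overall health of the server,
	// which is what the robustness test cares about.
	resp, err := healthpb.NewHealthClient(conn).Check(ctx, &healthpb.HealthCheckRequest{})
	if err != nil {
		return fmt.Errorf("error checking %s health: %w", target, err)
	}
	if resp.Status != healthpb.HealthCheckResponse_SERVING {
		return fmt.Errorf("%s health status expected %s, got %s",
			target, healthpb.HealthCheckResponse_SERVING, resp.Status)
	}
	return nil
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	// "localhost:2379" is an assumed member address for illustration.
	if err := checkHealth(ctx, "localhost:2379"); err != nil {
		fmt.Println("unhealthy:", err)
		return
	}
	fmt.Println("serving")
}

The patch reuses the already-dialed clientv3 connection via ActiveConnection() instead of dialing separately as above, which avoids a second connection per member.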
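The linearizability_test.go change relies on a small coordination pattern: sibling goroutines in an errgroup, where the failpoint goroutine cancels the watch goroutine's context on failure so watch collection stops early instead of waiting for a max revision that will never arrive. Below is a minimal sketch of that pattern under stated assumptions; the stage bodies are stand-ins invented for illustration, not the test's real triggerFailpoints or collectClusterWatchEvents.

// coordination_sketch.go — a minimal sketch of cross-goroutine cancellation
// with errgroup, mirroring the shape of runScenario.
package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"golang.org/x/sync/errgroup"
)

func main() {
	ctx := context.Background()
	watchCtx, watchCancel := context.WithCancel(ctx)
	defer watchCancel()

	g := errgroup.Group{}
	g.Go(func() error {
		// Stand-in for triggerFailpoints: pretend injection failed.
		failpointErr := errors.New("failpoint injection failed")
		if failpointErr != nil {
			watchCancel() // unblock the watch goroutine below
		}
		return nil
	})
	g.Go(func() error {
		// Stand-in for collectClusterWatchEvents: run until canceled
		// or until the work would have finished on its own.
		select {
		case <-watchCtx.Done():
			fmt.Println("watch stage stopped early:", watchCtx.Err())
		case <-time.After(10 * time.Second):
			fmt.Println("watch stage finished normally")
		}
		return nil
	})
	g.Wait()
}

Note that, as in the patch, the cancel function only aborts the watch stage on the failure path; on success the watch stage is left to terminate through its normal completion condition.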