diff --git a/src/cmd/dev.go b/src/cmd/dev.go index b10f2dde51..80a602afdb 100644 --- a/src/cmd/dev.go +++ b/src/cmd/dev.go @@ -142,7 +142,7 @@ var devSha256SumCmd = &cobra.Command{ Aliases: []string{"s"}, Short: lang.CmdDevSha256sumShort, Args: cobra.ExactArgs(1), - RunE: func(_ *cobra.Command, args []string) error { + RunE: func(cmd *cobra.Command, args []string) error { hashErr := errors.New("unable to compute the SHA256SUM hash") fileName := args[0] @@ -169,7 +169,7 @@ var devSha256SumCmd = &cobra.Command{ } downloadPath := filepath.Join(tmp, fileBase) - err = utils.DownloadToFile(fileName, downloadPath, "") + err = utils.DownloadToFile(cmd.Context(), fileName, downloadPath, "") if err != nil { return errors.Join(hashErr, err) } diff --git a/src/internal/packager/helm/repo.go b/src/internal/packager/helm/repo.go index 148a4176a4..2924321a38 100644 --- a/src/internal/packager/helm/repo.go +++ b/src/internal/packager/helm/repo.go @@ -53,14 +53,14 @@ func (h *Helm) PackageChart(ctx context.Context, cosignKeyPath string) error { return fmt.Errorf("unable to pull the chart %q from git: %w", h.chart.Name, err) } } else { - err = h.DownloadPublishedChart(cosignKeyPath) + err = h.DownloadPublishedChart(ctx, cosignKeyPath) if err != nil { return fmt.Errorf("unable to download the published chart %q: %w", h.chart.Name, err) } } } else { - err := h.PackageChartFromLocalFiles(cosignKeyPath) + err := h.PackageChartFromLocalFiles(ctx, cosignKeyPath) if err != nil { return fmt.Errorf("unable to package the %q chart: %w", h.chart.Name, err) } @@ -69,7 +69,7 @@ func (h *Helm) PackageChart(ctx context.Context, cosignKeyPath string) error { } // PackageChartFromLocalFiles creates a chart archive from a path to a chart on the host os. -func (h *Helm) PackageChartFromLocalFiles(cosignKeyPath string) error { +func (h *Helm) PackageChartFromLocalFiles(ctx context.Context, cosignKeyPath string) error { spinner := message.NewProgressSpinner("Processing helm chart %s:%s from %s", h.chart.Name, h.chart.Version, h.chart.LocalPath) defer spinner.Stop() @@ -103,7 +103,7 @@ func (h *Helm) PackageChartFromLocalFiles(cosignKeyPath string) error { } // Finalize the chart - err = h.finalizeChartPackage(saved, cosignKeyPath) + err = h.finalizeChartPackage(ctx, saved, cosignKeyPath) if err != nil { return err } @@ -127,11 +127,11 @@ func (h *Helm) PackageChartFromGit(ctx context.Context, cosignKeyPath string) er // Set the directory for the chart and package it h.chart.LocalPath = filepath.Join(gitPath, h.chart.GitPath) - return h.PackageChartFromLocalFiles(cosignKeyPath) + return h.PackageChartFromLocalFiles(ctx, cosignKeyPath) } // DownloadPublishedChart loads a specific chart version from a remote repo. -func (h *Helm) DownloadPublishedChart(cosignKeyPath string) error { +func (h *Helm) DownloadPublishedChart(ctx context.Context, cosignKeyPath string) error { spinner := message.NewProgressSpinner("Processing helm chart %s:%s from repo %s", h.chart.Name, h.chart.Version, h.chart.URL) defer spinner.Stop() @@ -222,7 +222,7 @@ func (h *Helm) DownloadPublishedChart(cosignKeyPath string) error { } // Finalize the chart - err = h.finalizeChartPackage(saved, cosignKeyPath) + err = h.finalizeChartPackage(ctx, saved, cosignKeyPath) if err != nil { return err } @@ -246,7 +246,7 @@ func DownloadChartFromGitToTemp(ctx context.Context, url string, spinner *messag return gitCfg.GitPath, nil } -func (h *Helm) finalizeChartPackage(saved, cosignKeyPath string) error { +func (h *Helm) finalizeChartPackage(ctx context.Context, saved, cosignKeyPath string) error { // Ensure the name is consistent for deployments destinationTarball := StandardName(h.chartPath, h.chart) + ".tgz" err := os.Rename(saved, destinationTarball) @@ -254,19 +254,19 @@ func (h *Helm) finalizeChartPackage(saved, cosignKeyPath string) error { return fmt.Errorf("unable to save the final chart tarball: %w", err) } - err = h.packageValues(cosignKeyPath) + err = h.packageValues(ctx, cosignKeyPath) if err != nil { return fmt.Errorf("unable to process the values for the package: %w", err) } return nil } -func (h *Helm) packageValues(cosignKeyPath string) error { +func (h *Helm) packageValues(ctx context.Context, cosignKeyPath string) error { for valuesIdx, path := range h.chart.ValuesFiles { dst := StandardValuesName(h.valuesPath, h.chart, valuesIdx) if helpers.IsURL(path) { - if err := utils.DownloadToFile(path, dst, cosignKeyPath); err != nil { + if err := utils.DownloadToFile(ctx, path, dst, cosignKeyPath); err != nil { return fmt.Errorf(lang.ErrDownloading, path, err.Error()) } } else { diff --git a/src/pkg/packager/creator/normal.go b/src/pkg/packager/creator/normal.go index 3b34b7e846..de6a8cbf57 100644 --- a/src/pkg/packager/creator/normal.go +++ b/src/pkg/packager/creator/normal.go @@ -387,7 +387,7 @@ func (pc *PackageCreator) addComponent(ctx context.Context, component types.Zarf compressedFile := filepath.Join(componentPaths.Temp, compressedFileName) // If the file is an archive, download it to the componentPath.Temp - if err := utils.DownloadToFile(file.Source, compressedFile, component.DeprecatedCosignKeyPath); err != nil { + if err := utils.DownloadToFile(ctx, file.Source, compressedFile, component.DeprecatedCosignKeyPath); err != nil { return fmt.Errorf(lang.ErrDownloading, file.Source, err.Error()) } @@ -396,7 +396,7 @@ func (pc *PackageCreator) addComponent(ctx context.Context, component types.Zarf return fmt.Errorf(lang.ErrFileExtract, file.ExtractPath, compressedFileName, err.Error()) } } else { - if err := utils.DownloadToFile(file.Source, dst, component.DeprecatedCosignKeyPath); err != nil { + if err := utils.DownloadToFile(ctx, file.Source, dst, component.DeprecatedCosignKeyPath); err != nil { return fmt.Errorf(lang.ErrDownloading, file.Source, err.Error()) } } @@ -447,7 +447,7 @@ func (pc *PackageCreator) addComponent(ctx context.Context, component types.Zarf dst := filepath.Join(componentPaths.Base, rel) if helpers.IsURL(data.Source) { - if err := utils.DownloadToFile(data.Source, dst, component.DeprecatedCosignKeyPath); err != nil { + if err := utils.DownloadToFile(ctx, data.Source, dst, component.DeprecatedCosignKeyPath); err != nil { return fmt.Errorf(lang.ErrDownloading, data.Source, err.Error()) } } else { @@ -480,7 +480,7 @@ func (pc *PackageCreator) addComponent(ctx context.Context, component types.Zarf // Copy manifests without any processing. spinner.Updatef("Copying manifest %s", path) if helpers.IsURL(path) { - if err := utils.DownloadToFile(path, dst, component.DeprecatedCosignKeyPath); err != nil { + if err := utils.DownloadToFile(ctx, path, dst, component.DeprecatedCosignKeyPath); err != nil { return fmt.Errorf(lang.ErrDownloading, path, err.Error()) } } else { diff --git a/src/pkg/packager/prepare.go b/src/pkg/packager/prepare.go index 340ac1c85a..4950e78748 100644 --- a/src/pkg/packager/prepare.go +++ b/src/pkg/packager/prepare.go @@ -233,7 +233,7 @@ func (p *Packager) findImages(ctx context.Context) (imgMap map[string][]string, if helpers.IsURL(f) { mname := fmt.Sprintf("manifest-%s-%d.yaml", manifest.Name, idx) destination := filepath.Join(componentPaths.Manifests, mname) - if err := utils.DownloadToFile(f, destination, component.DeprecatedCosignKeyPath); err != nil { + if err := utils.DownloadToFile(ctx, f, destination, component.DeprecatedCosignKeyPath); err != nil { return nil, fmt.Errorf(lang.ErrDownloading, f, err.Error()) } f = destination diff --git a/src/pkg/packager/sources/url.go b/src/pkg/packager/sources/url.go index 02fc785d81..d3d79af237 100644 --- a/src/pkg/packager/sources/url.go +++ b/src/pkg/packager/sources/url.go @@ -30,7 +30,7 @@ type URLSource struct { } // Collect downloads a package from the source URL. -func (s *URLSource) Collect(_ context.Context, dir string) (string, error) { +func (s *URLSource) Collect(ctx context.Context, dir string) (string, error) { if !config.CommonOptions.Insecure && s.Shasum == "" && !strings.HasPrefix(s.PackageSource, helpers.SGETURLPrefix) { return "", fmt.Errorf("remote package provided without a shasum, use --insecure to ignore, or provide one w/ --shasum") } @@ -43,7 +43,7 @@ func (s *URLSource) Collect(_ context.Context, dir string) (string, error) { dstTarball := filepath.Join(dir, "zarf-package-url-unknown") - if err := utils.DownloadToFile(packageURL, dstTarball, s.SGetKeyPath); err != nil { + if err := utils.DownloadToFile(ctx, packageURL, dstTarball, s.SGetKeyPath); err != nil { return "", err } diff --git a/src/pkg/utils/network.go b/src/pkg/utils/network.go index e17c086161..4bd6945f36 100644 --- a/src/pkg/utils/network.go +++ b/src/pkg/utils/network.go @@ -39,7 +39,7 @@ func parseChecksum(src string) (string, string, error) { } // DownloadToFile downloads a given URL to the target filepath (including the cosign key if necessary). -func DownloadToFile(src string, dst string, cosignKeyPath string) (err error) { +func DownloadToFile(ctx context.Context, src string, dst string, cosignKeyPath string) (err error) { message.Debugf("Downloading %s to %s", src, dst) // check if the parsed URL has a checksum // if so, remove it and use the checksum to validate the file @@ -66,7 +66,7 @@ func DownloadToFile(src string, dst string, cosignKeyPath string) (err error) { } // If the source url starts with the sget protocol use that, otherwise do a typical GET call if parsed.Scheme == helpers.SGETURLScheme { - err = Sget(context.TODO(), src, cosignKeyPath, file) + err = Sget(ctx, src, cosignKeyPath, file) if err != nil { return fmt.Errorf("unable to download file with sget: %s: %w", src, err) } diff --git a/src/pkg/utils/network_test.go b/src/pkg/utils/network_test.go index d2357f907d..02e308f170 100644 --- a/src/pkg/utils/network_test.go +++ b/src/pkg/utils/network_test.go @@ -13,6 +13,8 @@ import ( "strings" "testing" + "github.com/zarf-dev/zarf/src/test/testutil" + "github.com/stretchr/testify/require" "github.com/defenseunicorns/pkg/helpers/v2" @@ -136,7 +138,7 @@ func TestDownloadToFile(t *testing.T) { } fmt.Println(src) dst := filepath.Join(t.TempDir(), tt.fileName) - err := DownloadToFile(src, dst, "") + err := DownloadToFile(testutil.TestContext(t), src, dst, "") if tt.expectedErr != "" { require.ErrorContains(t, err, tt.expectedErr) return diff --git a/src/test/e2e/50_oci_publish_deploy_test.go b/src/test/e2e/50_oci_publish_deploy_test.go index 81e5e39f54..f8df882fcf 100644 --- a/src/test/e2e/50_oci_publish_deploy_test.go +++ b/src/test/e2e/50_oci_publish_deploy_test.go @@ -5,7 +5,6 @@ package test import ( - "context" "fmt" "path/filepath" "strings" @@ -16,6 +15,7 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/zarf-dev/zarf/src/pkg/zoci" + "github.com/zarf-dev/zarf/src/test/testutil" "oras.land/oras-go/v2/registry" "oras.land/oras-go/v2/registry/remote" ) @@ -139,7 +139,7 @@ func (suite *PublishDeploySuiteTestSuite) Test_3_Copy() { suite.NoError(err) reg.PlainHTTP = true attempt := 0 - ctx := context.TODO() + ctx := testutil.TestContext(t) for attempt <= 5 { err = reg.Ping(ctx) if err == nil { diff --git a/src/test/external/ext_in_cluster_test.go b/src/test/external/ext_in_cluster_test.go index 97a52fda0a..cde98f31cf 100644 --- a/src/test/external/ext_in_cluster_test.go +++ b/src/test/external/ext_in_cluster_test.go @@ -19,6 +19,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/zarf-dev/zarf/src/pkg/cluster" "github.com/zarf-dev/zarf/src/pkg/utils/exec" + "github.com/zarf-dev/zarf/src/test/testutil" "k8s.io/apimachinery/pkg/runtime/schema" "sigs.k8s.io/cli-utils/pkg/object" ) @@ -103,7 +104,7 @@ func (suite *ExtInClusterTestSuite) Test_0_Mirror() { c, err := cluster.NewCluster() suite.NoError(err) - ctx := context.TODO() + ctx := testutil.TestContext(suite.T()) // Check that the registry contains the images we want tunnelReg, err := c.NewTunnel("external-registry", "svc", "external-registry-docker-registry", "", 0, 5000) diff --git a/src/test/external/ext_out_cluster_test.go b/src/test/external/ext_out_cluster_test.go index dfcc9da580..88820cebea 100644 --- a/src/test/external/ext_out_cluster_test.go +++ b/src/test/external/ext_out_cluster_test.go @@ -20,6 +20,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/zarf-dev/zarf/src/pkg/utils" "github.com/zarf-dev/zarf/src/pkg/utils/exec" + "github.com/zarf-dev/zarf/src/test/testutil" "helm.sh/helm/v3/pkg/repo" ) @@ -207,7 +208,7 @@ func (suite *ExtOutClusterTestSuite) createHelmChartInGitea(baseURL string, user podinfoTarballPath := filepath.Join(tempDir, fmt.Sprintf("podinfo-%s.tgz", podInfoVersion)) suite.NoError(err, "Unable to package chart") - err = utils.DownloadToFile(fmt.Sprintf("https://stefanprodan.github.io/podinfo/podinfo-%s.tgz", podInfoVersion), podinfoTarballPath, "") + err = utils.DownloadToFile(testutil.TestContext(suite.T()), fmt.Sprintf("https://stefanprodan.github.io/podinfo/podinfo-%s.tgz", podInfoVersion), podinfoTarballPath, "") suite.NoError(err) url := fmt.Sprintf("%s/api/packages/%s/helm/api/charts", baseURL, username) diff --git a/src/test/testutil/testutil.go b/src/test/testutil/testutil.go new file mode 100644 index 0000000000..862bca937d --- /dev/null +++ b/src/test/testutil/testutil.go @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: 2021-Present The Zarf Authors + +// Package testutil provides global testing helper functions +package testutil + +import ( + "context" + "testing" +) + +// TestContext takes a testing.T and returns a context that is +// attached to the test by t.Cleanup() +func TestContext(t *testing.T) context.Context { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + return ctx +}