Skip to content

Commit

Permalink
feat: added blob mounting support for oras Copy functions
Browse files Browse the repository at this point in the history
Adds MountFrom and OnMounted to CopyGraphOptions.
Allows for trying to mount from multiple repositories.

Signed-off-by: Kyle M. Tarplee <kmtarplee@ieee.org>
  • Loading branch information
ktarplee committed Jan 3, 2024
1 parent 48f0943 commit 2d6b4c5
Show file tree
Hide file tree
Showing 4 changed files with 410 additions and 12 deletions.
89 changes: 88 additions & 1 deletion copy.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ type CopyGraphOptions struct {
// OnCopySkipped will be called when the sub-DAG rooted by the current node
// is skipped.
OnCopySkipped func(ctx context.Context, desc ocispec.Descriptor) error
// MountFrom returns the candidate repositories that desc may be mounted from.
// The OCI references will be tried in turn. If mounting fails on all of them, then it falls back to a copy.
MountFrom func(ctx context.Context, desc ocispec.Descriptor) ([]string, error)
// OnMounted will be invoked when desc is mounted.
OnMounted func(ctx context.Context, desc ocispec.Descriptor) error
// FindSuccessors finds the successors of the current node.
// fetcher provides cached access to the source storage, and is suitable
// for fetching non-leaf nodes like manifests. Since anything fetched from
Expand Down Expand Up @@ -259,12 +264,94 @@ func copyGraph(ctx context.Context, src content.ReadOnlyStorage, dst content.Sto
if exists {
return copyNode(ctx, proxy.Cache, dst, desc, opts)
}
return copyNode(ctx, src, dst, desc, opts)
return mountOrCopyNode(ctx, src, dst, desc, opts)
}

return syncutil.Go(ctx, limiter, fn, root)
}

// mountOrCopyNode enabled cross repository blob mounting.
// sourceReference is the repository to use for mounting (the mount point).
// mounter is the destination for the mount (a well-known implementation of this is *registry.Repository representing the target).
// OnMounted is called (if provided) when the blob is mounted.
// The original PreCopy hook is called only on copy, and therefore not when the blob is mounted.
func mountOrCopyNode(ctx context.Context, src content.ReadOnlyStorage, dst content.Storage, desc ocispec.Descriptor, opts CopyGraphOptions) error {
mounter, ok := dst.(registry.Mounter)
if !ok {
// mounting is not supported by the destination
return copyNode(ctx, src, dst, desc, opts)
}

// Only care to mount blobs
if descriptor.IsManifest(desc) {
return copyNode(ctx, src, dst, desc, opts)
}

if opts.MountFrom == nil {
return copyNode(ctx, src, dst, desc, opts)
}

sourceRepositories, err := opts.MountFrom(ctx, desc)
if err != nil {
// Technically this error is not fatal, we can still attempt to copy the node
// But for consistency with the other callbacks we bail out.
return err
}

if len(sourceRepositories) == 0 {
return copyNode(ctx, src, dst, desc, opts)
}

skipContent := errors.New("skip content")
for i, sourceRepository := range sourceRepositories {
// try mounting this source repository
var mountFailed bool
getContent := func() (io.ReadCloser, error) {
// the invocation of getContent indicates that mounting has failed
mountFailed = true

if len(sourceRepositories)-1 == i {
// this is the last iteration so we need to actually get the content and do the copy

// call the original PreCopy function if it exists
if opts.PreCopy != nil {
if err := opts.PreCopy(ctx, desc); err != nil {
return nil, err
}
}
return src.Fetch(ctx, desc)
}

// We want to return an error that we will test for from mounter.Mount()
return nil, skipContent
}

// Mount or copy
if err := mounter.Mount(ctx, desc, sourceRepository, getContent); err != nil && !errors.Is(err, skipContent) {
return err
}

if !mountFailed {
// mounted, success
if opts.OnMounted != nil {
if err := opts.OnMounted(ctx, desc); err != nil {
return err
}
}
return nil
}
}

// we copied it
if opts.PostCopy != nil {
if err := opts.PostCopy(ctx, desc); err != nil {
return err
}
}

return nil
}

// doCopyNode copies a single content from the source CAS to the destination CAS.
func doCopyNode(ctx context.Context, src content.ReadOnlyStorage, dst content.Storage, desc ocispec.Descriptor) error {
rc, err := src.Fetch(ctx, desc)
Expand Down
252 changes: 251 additions & 1 deletion copy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1471,11 +1471,251 @@ func TestCopyGraph_WithOptions(t *testing.T) {
t.Errorf("count(Push()) = %d, want %d", got, expected)
}
})

t.Run("MountFrom_Mounted", func(t *testing.T) {
root = descs[6]
dst := &countingStorage{storage: cas.NewMemory()}
var numMount atomic.Int64
dst.mount = func(ctx context.Context,
desc ocispec.Descriptor,
fromRepo string,
getContent func() (io.ReadCloser, error),
) error {
numMount.Add(1)
if expected := "source"; fromRepo != expected {
t.Fatalf("fromRepo = %v, want %v", fromRepo, expected)
}
rc, err := src.Fetch(ctx, desc)
if err != nil {
t.Fatalf("Failed to fetch content: %v", err)
}
defer rc.Close()
err = dst.storage.Push(ctx, desc, rc) // bypass the counters
if err != nil {
t.Fatalf("Failed to push content: %v", err)
}
return nil
}
opts = oras.CopyGraphOptions{}
var numPreCopy, numPostCopy, numOnMounted, numMountFrom atomic.Int64
opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
numPreCopy.Add(1)
return nil
}
opts.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
numPostCopy.Add(1)
return nil
}
opts.OnMounted = func(ctx context.Context, d ocispec.Descriptor) error {
numOnMounted.Add(1)
return nil
}
opts.MountFrom = func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) {
numMountFrom.Add(1)
return []string{"source"}, nil
}
if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil {
t.Fatalf("CopyGraph() error = %v, wantErr %v", err, errdef.ErrSizeExceedsLimit)
}

if got, expected := dst.numExists.Load(), int64(7); got != expected {
t.Errorf("count(Exists()) = %d, want %d", got, expected)
}
if got, expected := dst.numFetch.Load(), int64(0); got != expected {
t.Errorf("count(Fetch()) = %d, want %d", got, expected)
}
// 7 (exists) - 1 (skipped) = 6 pushes expected
if got, expected := dst.numPush.Load(), int64(3); got != expected {
// If we get >=7 then ErrSkipDesc did not short circuit the push like it is supposed to do.
t.Errorf("count(Push()) = %d, want %d", got, expected)
}
if got, expected := numMount.Load(), int64(4); got != expected {
t.Errorf("count(Mount()) = %d, want %d", got, expected)
}
if got, expected := numOnMounted.Load(), int64(4); got != expected {
t.Errorf("count(OnMounted()) = %d, want %d", got, expected)
}
if got, expected := numMountFrom.Load(), int64(4); got != expected {
t.Errorf("count(MountFrom()) = %d, want %d", got, expected)
}
if got, expected := numPreCopy.Load(), int64(3); got != expected {
t.Errorf("count(PreCopy()) = %d, want %d", got, expected)
}
if got, expected := numPostCopy.Load(), int64(3); got != expected {
t.Errorf("count(PostCopy()) = %d, want %d", got, expected)
}
})

t.Run("MountFrom_Copied", func(t *testing.T) {
root = descs[6]
dst := &countingStorage{storage: cas.NewMemory()}
var numMount atomic.Int64
dst.mount = func(ctx context.Context,
desc ocispec.Descriptor,
fromRepo string,
getContent func() (io.ReadCloser, error),
) error {
numMount.Add(1)
if expected := "source"; fromRepo != expected {
t.Fatalf("fromRepo = %v, want %v", fromRepo, expected)
}

rc, err := getContent()
if err != nil {
t.Fatalf("Failed to fetch content: %v", err)
}
defer rc.Close()
err = dst.storage.Push(ctx, desc, rc) // bypass the counters
if err != nil {
t.Fatalf("Failed to push content: %v", err)
}
return nil
}
opts = oras.CopyGraphOptions{}
var numPreCopy, numPostCopy, numOnMounted, numMountFrom atomic.Int64
opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
numPreCopy.Add(1)
return nil
}
opts.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
numPostCopy.Add(1)
return nil
}
opts.OnMounted = func(ctx context.Context, d ocispec.Descriptor) error {
numOnMounted.Add(1)
return nil
}
opts.MountFrom = func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) {
numMountFrom.Add(1)
return []string{"source"}, nil
}
if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil {
t.Fatalf("CopyGraph() error = %v, wantErr %v", err, errdef.ErrSizeExceedsLimit)
}

if got, expected := dst.numExists.Load(), int64(7); got != expected {
t.Errorf("count(Exists()) = %d, want %d", got, expected)
}
if got, expected := dst.numFetch.Load(), int64(0); got != expected {
t.Errorf("count(Fetch()) = %d, want %d", got, expected)
}
// 7 (exists) - 1 (skipped) = 6 pushes expected
if got, expected := dst.numPush.Load(), int64(3); got != expected {
// If we get >=7 then ErrSkipDesc did not short circuit the push like it is supposed to do.
t.Errorf("count(Push()) = %d, want %d", got, expected)
}
if got, expected := numMount.Load(), int64(4); got != expected {
t.Errorf("count(Mount()) = %d, want %d", got, expected)
}
if got, expected := numOnMounted.Load(), int64(0); got != expected {
t.Errorf("count(OnMounted()) = %d, want %d", got, expected)
}
if got, expected := numMountFrom.Load(), int64(4); got != expected {
t.Errorf("count(MountFrom()) = %d, want %d", got, expected)
}
if got, expected := numPreCopy.Load(), int64(7); got != expected {
t.Errorf("count(PreCopy()) = %d, want %d", got, expected)
}
if got, expected := numPostCopy.Load(), int64(7); got != expected {
t.Errorf("count(PostCopy()) = %d, want %d", got, expected)
}
})

t.Run("MountFrom_Mounted_Second_Try", func(t *testing.T) {
root = descs[6]
dst := &countingStorage{storage: cas.NewMemory()}
var numMount atomic.Int64
dst.mount = func(ctx context.Context,
desc ocispec.Descriptor,
fromRepo string,
getContent func() (io.ReadCloser, error),
) error {
numMount.Add(1)
switch fromRepo {
case "source":
rc, err := src.Fetch(ctx, desc)
if err != nil {
t.Fatalf("Failed to fetch content: %v", err)
}
defer rc.Close()
err = dst.storage.Push(ctx, desc, rc) // bypass the counters
if err != nil {
t.Fatalf("Failed to push content: %v", err)
}
return nil
case "missing/the/data":
// simulate a registry mount will fail, so it will request the content to start the copy.
rc, err := getContent()
if err != nil {
return fmt.Errorf("getContent failed: %w", err)
}
defer rc.Close()
err = dst.storage.Push(ctx, desc, rc) // bypass the counters
if err != nil {
t.Fatalf("Failed to push content: %v", err)
}
return nil
default:
t.Fatalf("fromRepo = %v, want either %v or %v", fromRepo, "missing/the/data", "source")
return errors.ErrUnsupported

Check failure on line 1660 in copy_test.go

View workflow job for this annotation

GitHub Actions / build (1.20)

undefined: errors.ErrUnsupported

Check failure on line 1660 in copy_test.go

View workflow job for this annotation

GitHub Actions / Analyze (1.20)

undefined: errors.ErrUnsupported
}
}
opts = oras.CopyGraphOptions{}
var numPreCopy, numPostCopy, numOnMounted, numMountFrom atomic.Int64
opts.PreCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
numPreCopy.Add(1)
return nil
}
opts.PostCopy = func(ctx context.Context, desc ocispec.Descriptor) error {
numPostCopy.Add(1)
return nil
}
opts.OnMounted = func(ctx context.Context, d ocispec.Descriptor) error {
numOnMounted.Add(1)
return nil
}
opts.MountFrom = func(ctx context.Context, desc ocispec.Descriptor) ([]string, error) {
numMountFrom.Add(1)
return []string{"missing/the/data", "source"}, nil
}
if err := oras.CopyGraph(ctx, src, dst, root, opts); err != nil {
t.Fatalf("CopyGraph() error = %v, wantErr %v", err, errdef.ErrSizeExceedsLimit)
}

if got, expected := dst.numExists.Load(), int64(7); got != expected {
t.Errorf("count(Exists()) = %d, want %d", got, expected)
}
if got, expected := dst.numFetch.Load(), int64(0); got != expected {
t.Errorf("count(Fetch()) = %d, want %d", got, expected)
}
// 7 (exists) - 1 (skipped) = 6 pushes expected
if got, expected := dst.numPush.Load(), int64(3); got != expected {
// If we get >=7 then ErrSkipDesc did not short circuit the push like it is supposed to do.
t.Errorf("count(Push()) = %d, want %d", got, expected)
}
if got, expected := numMount.Load(), int64(4*2); got != expected {
t.Errorf("count(Mount()) = %d, want %d", got, expected)
}
if got, expected := numOnMounted.Load(), int64(4); got != expected {
t.Errorf("count(OnMounted()) = %d, want %d", got, expected)
}
if got, expected := numMountFrom.Load(), int64(4); got != expected {
t.Errorf("count(MountFrom()) = %d, want %d", got, expected)
}
if got, expected := numPreCopy.Load(), int64(3); got != expected {
t.Errorf("count(PreCopy()) = %d, want %d", got, expected)
}
if got, expected := numPostCopy.Load(), int64(3); got != expected {
t.Errorf("count(PostCopy()) = %d, want %d", got, expected)
}
})
}

// countingStorage counts the calls to its content.Storage methods
type countingStorage struct {
storage content.Storage
storage content.Storage
mount mountFunc

numExists, numFetch, numPush atomic.Int64
}

Expand All @@ -1494,6 +1734,16 @@ func (cs *countingStorage) Push(ctx context.Context, target ocispec.Descriptor,
return cs.storage.Push(ctx, target, r)
}

type mountFunc func(context.Context, ocispec.Descriptor, string, func() (io.ReadCloser, error)) error

func (cs *countingStorage) Mount(ctx context.Context,
desc ocispec.Descriptor,
fromRepo string,
getContent func() (io.ReadCloser, error),
) error {
return cs.mount(ctx, desc, fromRepo, getContent)
}

func TestCopyGraph_WithConcurrencyLimit(t *testing.T) {
src := cas.NewMemory()
// generate test content
Expand Down
Loading

0 comments on commit 2d6b4c5

Please sign in to comment.