From 1a2aca35078ffefe4f2970cec4e661bbe26f1d1c Mon Sep 17 00:00:00 2001 From: Kathryn Baldauf Date: Sun, 30 Apr 2023 23:08:54 -0700 Subject: [PATCH] Guest agent support for partitions on SCSI devices * Update `ControllerLunToName` to `GetDevicePath` and take in partition as an additional param * Wait for partition subdirectory to appear for the devices * Update device encryption and verity device names with partition index * Update device encryption and verity device tests * Add new unit tests for `GetDevicePath` Signed-off-by: Kathryn Baldauf --- internal/guest/runtime/hcsv2/uvm.go | 8 +- internal/guest/storage/scsi/scsi.go | 80 +++++-- internal/guest/storage/scsi/scsi_test.go | 267 ++++++++++++++++++++--- 3 files changed, 306 insertions(+), 49 deletions(-) diff --git a/internal/guest/runtime/hcsv2/uvm.go b/internal/guest/runtime/hcsv2/uvm.go index 54e00f571f..f1c5540a54 100644 --- a/internal/guest/runtime/hcsv2/uvm.go +++ b/internal/guest/runtime/hcsv2/uvm.go @@ -555,8 +555,7 @@ func (h *Host) modifyHostSettings(ctx context.Context, containerID string, req * if !mvd.ReadOnly { localCtx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() - var source string - source, err = scsi.ControllerLunToName(localCtx, mvd.Controller, mvd.Lun) + source, err := scsi.GetDevicePath(localCtx, mvd.Controller, mvd.Lun, mvd.Partition) if err != nil { return err } @@ -982,7 +981,7 @@ func modifyMappedVirtualDisk( } } - return scsi.Mount(mountCtx, mvd.Controller, mvd.Lun, mvd.MountPath, + return scsi.Mount(mountCtx, mvd.Controller, mvd.Lun, mvd.Partition, mvd.MountPath, mvd.ReadOnly, mvd.Encrypted, mvd.Options, mvd.VerityInfo) } return nil @@ -994,7 +993,8 @@ func modifyMappedVirtualDisk( } } - if err := scsi.Unmount(ctx, mvd.Controller, mvd.Lun, mvd.MountPath, mvd.Encrypted, mvd.VerityInfo); err != nil { + if err := scsi.Unmount(ctx, mvd.Controller, mvd.Lun, mvd.Partition, + mvd.MountPath, mvd.Encrypted, mvd.VerityInfo); err != nil { return err } } diff --git a/internal/guest/storage/scsi/scsi.go b/internal/guest/storage/scsi/scsi.go index 32dffc425c..ffcfb24521 100644 --- a/internal/guest/storage/scsi/scsi.go +++ b/internal/guest/storage/scsi/scsi.go @@ -6,6 +6,7 @@ package scsi import ( "context" "fmt" + "io/fs" "os" "path" "path/filepath" @@ -32,8 +33,12 @@ var ( osRemoveAll = os.RemoveAll unixMount = unix.Mount - // controllerLunToName is stubbed to make testing `Mount` easier. - controllerLunToName = ControllerLunToName + // mock functions for testing getDevicePath + osReadDir = os.ReadDir + osStat = os.Stat + + // getDevicePath is stubbed to make testing `Mount` easier. + getDevicePath = GetDevicePath // createVerityTarget is stubbed for unit testing `Mount`. createVerityTarget = dm.CreateVerityTarget // removeDevice is stubbed for unit testing `Mount`. @@ -49,8 +54,8 @@ var ( const ( scsiDevicesPath = "/sys/bus/scsi/devices" vmbusDevicesPath = "/sys/bus/vmbus/devices" - verityDeviceFmt = "dm-verity-scsi-contr%d-lun%d-%s" - cryptDeviceFmt = "dm-crypt-scsi-contr%d-lun%d" + verityDeviceFmt = "dm-verity-scsi-contr%d-lun%d-p%d-%s" + cryptDeviceFmt = "dm-crypt-scsi-contr%d-lun%d-p%d" ) // ActualControllerNumber retrieves the actual controller number assigned to a SCSI controller @@ -98,6 +103,7 @@ func Mount( ctx context.Context, controller, lun uint8, + partition uint64, target string, readonly bool, encrypted bool, @@ -109,9 +115,11 @@ func Mount( span.AddAttributes( trace.Int64Attribute("controller", int64(controller)), - trace.Int64Attribute("lun", int64(lun))) + trace.Int64Attribute("lun", int64(lun)), + trace.Int64Attribute("partition", int64(partition)), + ) - source, err := controllerLunToName(spnCtx, controller, lun) + source, err := getDevicePath(spnCtx, controller, lun, partition) if err != nil { return err } @@ -123,7 +131,7 @@ func Mount( } if verityInfo != nil { - dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, deviceHash) + dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, partition, deviceHash) if source, err = createVerityTarget(spnCtx, source, dmVerityName, verityInfo); err != nil { return err } @@ -156,7 +164,7 @@ func Mount( mountType := "ext4" if encrypted { - cryptDeviceName := fmt.Sprintf(cryptDeviceFmt, controller, lun) + cryptDeviceName := fmt.Sprintf(cryptDeviceFmt, controller, lun, partition) encryptedSource, err := encryptDevice(spnCtx, source, cryptDeviceName) if err != nil { // todo (maksiman): add better retry logic, similar to how SCSI device mounts are @@ -173,7 +181,7 @@ func Mount( for { if err := unixMount(source, target, mountType, flags, data); err != nil { - // The `source` found by controllerLunToName can take some time + // The `source` found by GetDevicePath can take some time // before its actually available under `/dev/sd*`. Retry while we // wait for `source` to show up. if errors.Is(err, unix.ENOENT) || errors.Is(err, unix.ENXIO) { @@ -210,6 +218,7 @@ func Unmount( ctx context.Context, controller, lun uint8, + partition uint64, target string, encrypted bool, verityInfo *guestresource.DeviceVerityInfo, @@ -221,6 +230,7 @@ func Unmount( span.AddAttributes( trace.Int64Attribute("controller", int64(controller)), trace.Int64Attribute("lun", int64(lun)), + trace.Int64Attribute("partition", int64(partition)), trace.StringAttribute("target", target)) // unmount target @@ -229,7 +239,7 @@ func Unmount( } if verityInfo != nil { - dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, verityInfo.RootDigest) + dmVerityName := fmt.Sprintf(verityDeviceFmt, controller, lun, partition, verityInfo.RootDigest) if err := removeDevice(dmVerityName); err != nil { // Ignore failures, since the path has been unmounted at this point. log.G(ctx).WithError(err).Debugf("failed to remove dm verity target: %s", dmVerityName) @@ -237,7 +247,7 @@ func Unmount( } if encrypted { - dmCryptName := fmt.Sprintf(cryptDeviceFmt, controller, lun) + dmCryptName := fmt.Sprintf(cryptDeviceFmt, controller, lun, partition) if err := cleanupCryptDevice(dmCryptName); err != nil { return fmt.Errorf("failed to cleanup dm-crypt target %s: %w", dmCryptName, err) } @@ -246,16 +256,18 @@ func Unmount( return nil } -// ControllerLunToName finds the `/dev/sd*` path to the SCSI device on -// `controller` index `lun`. -func ControllerLunToName(ctx context.Context, controller, lun uint8) (_ string, err error) { - ctx, span := oc.StartSpan(ctx, "scsi::ControllerLunToName") +// GetDevicePath finds the `/dev/sd*` path to the SCSI device on `controller` +// index `lun` with partition index `partition`. +func GetDevicePath(ctx context.Context, controller, lun uint8, partition uint64) (_ string, err error) { + ctx, span := oc.StartSpan(ctx, "scsi::GetDevicePath") defer span.End() defer func() { oc.SetSpanStatus(span, err) }() span.AddAttributes( trace.Int64Attribute("controller", int64(controller)), - trace.Int64Attribute("lun", int64(lun))) + trace.Int64Attribute("lun", int64(lun)), + trace.Int64Attribute("partition", int64(partition)), + ) scsiID := fmt.Sprintf("%d:0:0:%d", controller, lun) // Devices matching the given SCSI code should each have a subdirectory @@ -263,8 +275,8 @@ func ControllerLunToName(ctx context.Context, controller, lun uint8) (_ string, blockPath := filepath.Join(scsiDevicesPath, scsiID, "block") var deviceNames []os.DirEntry for { - deviceNames, err = os.ReadDir(blockPath) - if err != nil && !os.IsNotExist(err) { + deviceNames, err = osReadDir(blockPath) + if err != nil && !errors.Is(err, fs.ErrNotExist) { return "", err } if len(deviceNames) == 0 { @@ -282,8 +294,38 @@ func ControllerLunToName(ctx context.Context, controller, lun uint8) (_ string, if len(deviceNames) > 1 { return "", errors.Errorf("more than one block device could match SCSI ID \"%s\"", scsiID) } + deviceName := deviceNames[0].Name() + + // devices that have partitions have a subdirectory under + // /sys/bus/scsi/devices//block/ for each partition. + // Partitions use 1-based indexing, so if `partition` is 0, then we should + // return the device name without a partition index. + if partition != 0 { + partitionName := fmt.Sprintf("%s%d", deviceName, partition) + partitionPath := filepath.Join(blockPath, deviceName, partitionName) + + // Wait for the device partition to show up + for { + fi, err := osStat(partitionPath) + if err != nil && !errors.Is(err, fs.ErrNotExist) { + return "", err + } else if fi == nil { + // if the fileinfo is nil that means we didn't find the device, keep + // trying until the context is done or the device path shows up + select { + case <-ctx.Done(): + return "", ctx.Err() + default: + time.Sleep(time.Millisecond * 10) + continue + } + } + break + } + deviceName = partitionName + } - devicePath := filepath.Join("/dev", deviceNames[0].Name()) + devicePath := filepath.Join("/dev", deviceName) log.G(ctx).WithField("devicePath", devicePath).Debug("found device path") return devicePath, nil } diff --git a/internal/guest/storage/scsi/scsi_test.go b/internal/guest/storage/scsi/scsi_test.go index eb83002381..b5a176fcdb 100644 --- a/internal/guest/storage/scsi/scsi_test.go +++ b/internal/guest/storage/scsi/scsi_test.go @@ -8,17 +8,21 @@ import ( "errors" "fmt" "os" + "path/filepath" "testing" + "time" "github.com/Microsoft/hcsshim/internal/protocol/guestresource" "golang.org/x/sys/unix" ) func clearTestDependencies() { + osReadDir = nil + osStat = nil osMkdirAll = nil osRemoveAll = nil unixMount = nil - controllerLunToName = nil + getDevicePath = nil createVerityTarget = nil encryptDevice = nil cleanupCryptDevice = nil @@ -33,7 +37,7 @@ func Test_Mount_Mkdir_Fails_Error(t *testing.T) { return expectedErr } - controllerLunToName = func(ctx context.Context, controller, lun uint8) (string, error) { + getDevicePath = func(ctx context.Context, controller, lun uint8, partition uint64) (string, error) { return "", nil } @@ -41,6 +45,7 @@ func Test_Mount_Mkdir_Fails_Error(t *testing.T) { context.Background(), 0, 0, + 0, "", false, false, @@ -65,7 +70,7 @@ func Test_Mount_Mkdir_ExpectedPath(t *testing.T) { } return nil } - controllerLunToName = func(ctx context.Context, controller, lun uint8) (string, error) { + getDevicePath = func(ctx context.Context, controller, lun uint8, partition uint64) (string, error) { return "", nil } unixMount = func(source string, target string, fstype string, flags uintptr, data string) error { @@ -77,6 +82,7 @@ func Test_Mount_Mkdir_ExpectedPath(t *testing.T) { context.Background(), 0, 0, + 0, target, false, false, @@ -101,7 +107,7 @@ func Test_Mount_Mkdir_ExpectedPerm(t *testing.T) { } return nil } - controllerLunToName = func(ctx context.Context, controller, lun uint8) (string, error) { + getDevicePath = func(ctx context.Context, controller, lun uint8, partition uint64) (string, error) { return "", nil } unixMount = func(source string, target string, fstype string, flags uintptr, data string) error { @@ -113,6 +119,7 @@ func Test_Mount_Mkdir_ExpectedPerm(t *testing.T) { context.Background(), 0, 0, + 0, target, false, false, @@ -123,7 +130,7 @@ func Test_Mount_Mkdir_ExpectedPerm(t *testing.T) { } } -func Test_Mount_ControllerLunToName_Valid_Controller(t *testing.T) { +func Test_Mount_GetDevicePath_Valid_Controller(t *testing.T) { clearTestDependencies() // NOTE: Do NOT set osRemoveAll because the mount succeeds. Expect it not to @@ -133,7 +140,7 @@ func Test_Mount_ControllerLunToName_Valid_Controller(t *testing.T) { return nil } expectedController := uint8(2) - controllerLunToName = func(ctx context.Context, controller, lun uint8) (string, error) { + getDevicePath = func(ctx context.Context, controller, lun uint8, partition uint64) (string, error) { if expectedController != controller { t.Errorf("expected controller: %v, got: %v", expectedController, controller) return "", errors.New("unexpected controller") @@ -149,6 +156,7 @@ func Test_Mount_ControllerLunToName_Valid_Controller(t *testing.T) { context.Background(), expectedController, 0, + 0, "/fake/path", false, false, @@ -159,7 +167,7 @@ func Test_Mount_ControllerLunToName_Valid_Controller(t *testing.T) { } } -func Test_Mount_ControllerLunToName_Valid_Lun(t *testing.T) { +func Test_Mount_GetDevicePath_Valid_Lun(t *testing.T) { clearTestDependencies() // NOTE: Do NOT set osRemoveAll because the mount succeeds. Expect it not to @@ -169,7 +177,7 @@ func Test_Mount_ControllerLunToName_Valid_Lun(t *testing.T) { return nil } expectedLun := uint8(2) - controllerLunToName = func(ctx context.Context, controller, lun uint8) (string, error) { + getDevicePath = func(ctx context.Context, controller, lun uint8, partition uint64) (string, error) { if expectedLun != lun { t.Errorf("expected lun: %v, got: %v", expectedLun, lun) return "", errors.New("unexpected lun") @@ -185,6 +193,44 @@ func Test_Mount_ControllerLunToName_Valid_Lun(t *testing.T) { context.Background(), 0, expectedLun, + 0, + "/fake/path", + false, + false, + nil, + nil, + ); err != nil { + t.Fatalf("expected nil error got: %v", err) + } +} + +func Test_Mount_GetDevicePath_Valid_Partition(t *testing.T) { + clearTestDependencies() + + // NOTE: Do NOT set osRemoveAll because the mount succeeds. Expect it not to + // be called. + + osMkdirAll = func(path string, perm os.FileMode) error { + return nil + } + expectedPartition := uint64(3) + getDevicePath = func(ctx context.Context, controller, lun uint8, partition uint64) (string, error) { + if expectedPartition != partition { + t.Errorf("expected partition: %v, got: %v", expectedPartition, partition) + return "", errors.New("unexpected lun") + } + return "", nil + } + unixMount = func(source string, target string, fstype string, flags uintptr, data string) error { + // Fake the mount success + return nil + } + + if err := Mount( + context.Background(), + 0, + 0, + expectedPartition, "/fake/path", false, false, @@ -201,7 +247,7 @@ func Test_Mount_Calls_RemoveAll_OnMountFailure(t *testing.T) { osMkdirAll = func(path string, perm os.FileMode) error { return nil } - controllerLunToName = func(ctx context.Context, controller, lun uint8) (string, error) { + getDevicePath = func(ctx context.Context, controller, lun uint8, partition uint64) (string, error) { return "", nil } target := "/fake/path" @@ -224,6 +270,7 @@ func Test_Mount_Calls_RemoveAll_OnMountFailure(t *testing.T) { context.Background(), 0, 0, + 0, target, false, false, @@ -247,7 +294,7 @@ func Test_Mount_Valid_Source(t *testing.T) { return nil } expectedSource := "/dev/sdz" - controllerLunToName = func(ctx context.Context, controller, lun uint8) (string, error) { + getDevicePath = func(ctx context.Context, controller, lun uint8, partition uint64) (string, error) { return expectedSource, nil } unixMount = func(source string, target string, fstype string, flags uintptr, data string) error { @@ -257,7 +304,7 @@ func Test_Mount_Valid_Source(t *testing.T) { } return nil } - err := Mount(context.Background(), 0, 0, "/fake/path", false, false, nil, nil) + err := Mount(context.Background(), 0, 0, 0, "/fake/path", false, false, nil, nil) if err != nil { t.Fatalf("expected nil err, got: %v", err) } @@ -272,7 +319,7 @@ func Test_Mount_Valid_Target(t *testing.T) { osMkdirAll = func(path string, perm os.FileMode) error { return nil } - controllerLunToName = func(ctx context.Context, controller, lun uint8) (string, error) { + getDevicePath = func(ctx context.Context, controller, lun uint8, partition uint64) (string, error) { return "", nil } expectedTarget := "/fake/path" @@ -288,6 +335,7 @@ func Test_Mount_Valid_Target(t *testing.T) { context.Background(), 0, 0, + 0, expectedTarget, false, false, @@ -307,7 +355,7 @@ func Test_Mount_Valid_FSType(t *testing.T) { osMkdirAll = func(path string, perm os.FileMode) error { return nil } - controllerLunToName = func(ctx context.Context, controller, lun uint8) (string, error) { + getDevicePath = func(ctx context.Context, controller, lun uint8, partition uint64) (string, error) { return "", nil } unixMount = func(source string, target string, fstype string, flags uintptr, data string) error { @@ -323,6 +371,7 @@ func Test_Mount_Valid_FSType(t *testing.T) { context.Background(), 0, 0, + 0, "/fake/path", false, false, @@ -342,7 +391,7 @@ func Test_Mount_Valid_Flags(t *testing.T) { osMkdirAll = func(path string, perm os.FileMode) error { return nil } - controllerLunToName = func(ctx context.Context, controller, lun uint8) (string, error) { + getDevicePath = func(ctx context.Context, controller, lun uint8, partition uint64) (string, error) { return "", nil } unixMount = func(source string, target string, fstype string, flags uintptr, data string) error { @@ -358,6 +407,7 @@ func Test_Mount_Valid_Flags(t *testing.T) { context.Background(), 0, 0, + 0, "/fake/path", false, false, @@ -377,7 +427,7 @@ func Test_Mount_Readonly_Valid_Flags(t *testing.T) { osMkdirAll = func(path string, perm os.FileMode) error { return nil } - controllerLunToName = func(ctx context.Context, controller, lun uint8) (string, error) { + getDevicePath = func(ctx context.Context, controller, lun uint8, partition uint64) (string, error) { return "", nil } unixMount = func(source string, target string, fstype string, flags uintptr, data string) error { @@ -393,6 +443,7 @@ func Test_Mount_Readonly_Valid_Flags(t *testing.T) { context.Background(), 0, 0, + 0, "/fake/path", true, false, @@ -412,7 +463,7 @@ func Test_Mount_Valid_Data(t *testing.T) { osMkdirAll = func(path string, perm os.FileMode) error { return nil } - controllerLunToName = func(ctx context.Context, controller, lun uint8) (string, error) { + getDevicePath = func(ctx context.Context, controller, lun uint8, partition uint64) (string, error) { return "", nil } unixMount = func(source string, target string, fstype string, flags uintptr, data string) error { @@ -427,6 +478,7 @@ func Test_Mount_Valid_Data(t *testing.T) { context.Background(), 0, 0, + 0, "/fake/path", false, false, @@ -446,7 +498,7 @@ func Test_Mount_Readonly_Valid_Data(t *testing.T) { osMkdirAll = func(path string, perm os.FileMode) error { return nil } - controllerLunToName = func(ctx context.Context, controller, lun uint8) (string, error) { + getDevicePath = func(ctx context.Context, controller, lun uint8, partition uint64) (string, error) { return "", nil } unixMount = func(source string, target string, fstype string, flags uintptr, data string) error { @@ -462,6 +514,7 @@ func Test_Mount_Readonly_Valid_Data(t *testing.T) { context.Background(), 0, 0, + 0, "/fake/path", true, false, @@ -477,13 +530,13 @@ func Test_Mount_Readonly_Valid_Data(t *testing.T) { func Test_CreateVerityTarget_And_Mount_Called_With_Correct_Parameters(t *testing.T) { clearTestDependencies() - expectedVerityName := fmt.Sprintf(verityDeviceFmt, 0, 0, "hash") + expectedVerityName := fmt.Sprintf(verityDeviceFmt, 0, 0, 0, "hash") expectedSource := "/dev/sdb" expectedMapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityName) expectedTarget := "/foo" createVerityTargetCalled := false - controllerLunToName = func(_ context.Context, _, _ uint8) (string, error) { + getDevicePath = func(_ context.Context, _, _ uint8, _ uint64) (string, error) { return expectedSource, nil } @@ -519,6 +572,7 @@ func Test_CreateVerityTarget_And_Mount_Called_With_Correct_Parameters(t *testing context.Background(), 0, 0, + 0, expectedTarget, true, false, @@ -536,10 +590,10 @@ func Test_osMkdirAllFails_And_RemoveDevice_Called(t *testing.T) { clearTestDependencies() expectedError := errors.New("osMkdirAll error") - expectedVerityName := fmt.Sprintf(verityDeviceFmt, 0, 0, "hash") + expectedVerityName := fmt.Sprintf(verityDeviceFmt, 0, 0, 0, "hash") removeDeviceCalled := false - controllerLunToName = func(_ context.Context, _, _ uint8) (string, error) { + getDevicePath = func(_ context.Context, _, _ uint8, _ uint64) (string, error) { return "/dev/sdb", nil } @@ -567,6 +621,7 @@ func Test_osMkdirAllFails_And_RemoveDevice_Called(t *testing.T) { context.Background(), 0, 0, + 0, "/foo", true, false, @@ -586,7 +641,7 @@ func Test_Mount_EncryptDevice_Called(t *testing.T) { osMkdirAll = func(string, os.FileMode) error { return nil } - controllerLunToName = func(context.Context, uint8, uint8) (string, error) { + getDevicePath = func(context.Context, uint8, uint8, uint64) (string, error) { return "", nil } unixMount = func(string, string, string, uintptr, string) error { @@ -594,7 +649,7 @@ func Test_Mount_EncryptDevice_Called(t *testing.T) { } encryptDeviceCalled := false encryptDevice = func(_ context.Context, source string, devName string) (string, error) { - expectedCryptTarget := fmt.Sprintf(cryptDeviceFmt, 0, 0) + expectedCryptTarget := fmt.Sprintf(cryptDeviceFmt, 0, 0, 0) if devName != expectedCryptTarget { t.Fatalf("expected crypt device %q got %q", expectedCryptTarget, devName) } @@ -605,6 +660,7 @@ func Test_Mount_EncryptDevice_Called(t *testing.T) { context.Background(), 0, 0, + 0, "/fake/path", false, true, @@ -624,7 +680,7 @@ func Test_Mount_RemoveAllCalled_When_EncryptDevice_Fails(t *testing.T) { osMkdirAll = func(string, os.FileMode) error { return nil } - controllerLunToName = func(context.Context, uint8, uint8) (string, error) { + getDevicePath = func(context.Context, uint8, uint8, uint64) (string, error) { return "", nil } unixMount = func(string, string, string, uintptr, string) error { @@ -644,6 +700,7 @@ func Test_Mount_RemoveAllCalled_When_EncryptDevice_Fails(t *testing.T) { context.Background(), 0, 0, + 0, "/fake/path", false, true, @@ -669,7 +726,7 @@ func Test_Unmount_CleanupCryptDevice_Called(t *testing.T) { } cleanupCryptDeviceCalled := false cleanupCryptDevice = func(devName string) error { - expectedDevName := fmt.Sprintf(cryptDeviceFmt, 0, 0) + expectedDevName := fmt.Sprintf(cryptDeviceFmt, 0, 0, 0) if devName != expectedDevName { t.Fatalf("expected crypt target %q, got %q", expectedDevName, devName) } @@ -677,10 +734,168 @@ func Test_Unmount_CleanupCryptDevice_Called(t *testing.T) { return nil } - if err := Unmount(context.Background(), 0, 0, "/fake/path", true, nil); err != nil { + if err := Unmount(context.Background(), 0, 0, 0, "/fake/path", true, nil); err != nil { t.Fatalf("unexpected error: %s", err) } if !cleanupCryptDeviceCalled { t.Fatal("cleanupCryptDevice not called") } } + +// fakeFileInfo is a mock os.FileInfo that can be used to return +// in mock os calls +type fakeFileInfo struct { + name string +} + +func (f *fakeFileInfo) Name() string { + return f.name +} + +func (f *fakeFileInfo) Size() int64 { + // fake size + return 100 +} + +func (f *fakeFileInfo) Mode() os.FileMode { + // fake mode + return os.ModeDir +} + +func (f *fakeFileInfo) ModTime() time.Time { + // fake time + return time.Now() +} + +func (f *fakeFileInfo) IsDir() bool { + // fake isDir + return false +} + +func (f *fakeFileInfo) Sys() interface{} { + return nil +} + +// fakeDirEntry is a mock os.DirEntry that can be used to return in +// the response from the mocked os.ReadDir call. +type fakeDirEntry struct { + name string +} + +func (d *fakeDirEntry) Name() string { + return d.name +} + +func (d *fakeDirEntry) IsDir() bool { + return true +} + +func (d *fakeDirEntry) Type() os.FileMode { + return os.ModeDir +} + +func (d *fakeDirEntry) Info() (os.FileInfo, error) { + return &fakeFileInfo{name: d.name}, nil +} + +func Test_GetDevicePath_Device_With_Partition(t *testing.T) { + clearTestDependencies() + + deviceName := "sdd" + partition := uint64(1) + deviceWithPartitionName := deviceName + fmt.Sprintf("%d", partition) + expectedDevicePath := filepath.Join("/dev", deviceWithPartitionName) + + osReadDir = func(_ string) ([]os.DirEntry, error) { + entry := &fakeDirEntry{name: deviceName} + return []os.DirEntry{entry}, nil + } + + osStat = func(_ string) (os.FileInfo, error) { + return &fakeFileInfo{ + name: deviceWithPartitionName, + }, nil + } + + getDevicePath = GetDevicePath + + actualPath, err := getDevicePath(context.Background(), 0, 0, partition) + if err != nil { + t.Fatalf("expected to get no error, instead got %v", err) + } + if actualPath != expectedDevicePath { + t.Fatalf("expected to get %v, instead got %v", expectedDevicePath, actualPath) + } +} + +func Test_GetDevicePath_Device_With_Partition_Error(t *testing.T) { + clearTestDependencies() + + deviceName := "sdd" + partition := uint64(1) + + osReadDir = func(_ string) ([]os.DirEntry, error) { + entry := &fakeDirEntry{name: deviceName} + return []os.DirEntry{entry}, nil + } + + osStat = func(_ string) (os.FileInfo, error) { + return nil, nil + } + + getDevicePath = GetDevicePath + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + actualPath, err := getDevicePath(ctx, 0, 0, partition) + if err == nil { + t.Fatalf("expected to get an error, instead got %v", actualPath) + } +} + +func Test_GetDevicePath_Device_No_Partition(t *testing.T) { + clearTestDependencies() + + deviceName := "sdd" + expectedDevicePath := filepath.Join("/dev", deviceName) + + osReadDir = func(_ string) ([]os.DirEntry, error) { + entry := &fakeDirEntry{name: deviceName} + return []os.DirEntry{entry}, nil + } + + osStat = func(name string) (os.FileInfo, error) { + return nil, fmt.Errorf("should not make this call: %v", name) + } + + getDevicePath = GetDevicePath + + actualPath, err := getDevicePath(context.Background(), 0, 0, 0) + if err != nil { + t.Fatalf("expected to get no error, instead got %v", err) + } + if actualPath != expectedDevicePath { + t.Fatalf("expected to get %v, instead got %v", expectedDevicePath, actualPath) + } +} + +func Test_GetDevicePath_Device_No_Partition_Error(t *testing.T) { + clearTestDependencies() + + osReadDir = func(_ string) ([]os.DirEntry, error) { + return nil, nil + } + + osStat = func(name string) (os.FileInfo, error) { + return nil, fmt.Errorf("should not make this call: %v", name) + } + + getDevicePath = GetDevicePath + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + actualPath, err := getDevicePath(ctx, 0, 0, 0) + if err == nil { + t.Fatalf("expected to get an error, instead got %v", actualPath) + } +}