Skip to content

Commit 7389c4f

Browse files
committed
Refactor nvml CDI spec generation for consistency
Signed-off-by: Evan Lezar <elezar@nvidia.com>
1 parent 0b6863e commit 7389c4f

File tree

5 files changed

+177
-163
lines changed

5 files changed

+177
-163
lines changed

cmd/nvidia-ctk/cdi/generate/generate.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
312312
return nil, fmt.Errorf("failed to create CDI library: %v", err)
313313
}
314314

315-
deviceSpecs, err := cdilib.GetAllDeviceSpecs()
315+
deviceSpecs, err := cdilib.GetDeviceSpecsByID("all")
316316
if err != nil {
317317
return nil, fmt.Errorf("failed to create device CDI specs: %v", err)
318318
}

pkg/nvcdi/api.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ type Interface interface {
3333
GetAllDeviceSpecs() ([]specs.Device, error)
3434
}
3535

36+
// A deviceSpecGenerator is used to generate the specs for a set of devices.
37+
type deviceSpecGenerator interface {
38+
GetDeviceSpecs() ([]specs.Device, error)
39+
}
40+
3641
// A HookName represents one of the predefined NVIDIA CDI hooks.
3742
type HookName = discover.HookName
3843

pkg/nvcdi/full-gpu-nvml.go

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,41 +19,66 @@ package nvcdi
1919
import (
2020
"fmt"
2121

22-
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
2322
"tags.cncf.io/container-device-interface/pkg/cdi"
2423
"tags.cncf.io/container-device-interface/specs-go"
2524

25+
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
26+
"github.com/NVIDIA/go-nvml/pkg/nvml"
27+
2628
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
2729
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
2830
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/dgpu"
2931
)
3032

31-
// GetGPUDeviceSpecs returns the CDI device specs for the full GPU represented by 'device'.
32-
func (l *nvmllib) GetGPUDeviceSpecs(i int, d device.Device) ([]specs.Device, error) {
33-
edits, err := l.GetGPUDeviceEdits(d)
33+
type fullGPUDeviceSpecGenerator struct {
34+
*nvmllib
35+
id string
36+
index int
37+
device device.Device
38+
}
39+
40+
var _ deviceSpecGenerator = (*fullGPUDeviceSpecGenerator)(nil)
41+
42+
func (l *nvmllib) newFullGPUDeviceSpecGeneratorFromNVMLDevice(id string, nvmlDevice nvml.Device) (deviceSpecGenerator, error) {
43+
device, err := l.devicelib.NewDevice(nvmlDevice)
3444
if err != nil {
35-
return nil, fmt.Errorf("failed to get edits for device: %v", err)
45+
return nil, err
3646
}
3747

38-
var deviceSpecs []specs.Device
39-
names, err := l.deviceNamers.GetDeviceNames(i, convert{d})
48+
e := &fullGPUDeviceSpecGenerator{
49+
nvmllib: l,
50+
id: id,
51+
device: device,
52+
}
53+
return e, nil
54+
}
55+
56+
func (l *fullGPUDeviceSpecGenerator) GetDeviceSpecs() ([]specs.Device, error) {
57+
deviceEdits, err := l.getDeviceEdits()
58+
if err != nil {
59+
return nil, fmt.Errorf("failed to get CDI device edits for identifier %q: %w", "TODO", err)
60+
}
61+
62+
names, err := l.getNames()
4063
if err != nil {
41-
return nil, fmt.Errorf("failed to get device name: %v", err)
64+
return nil, fmt.Errorf("failed to get edits for device: %w", err)
4265
}
66+
67+
var deviceSpecs []specs.Device
4368
for _, name := range names {
44-
spec := specs.Device{
69+
deviceSpec := specs.Device{
4570
Name: name,
46-
ContainerEdits: *edits.ContainerEdits,
71+
ContainerEdits: *deviceEdits.ContainerEdits,
4772
}
48-
deviceSpecs = append(deviceSpecs, spec)
73+
deviceSpecs = append(deviceSpecs, deviceSpec)
4974
}
5075

5176
return deviceSpecs, nil
5277
}
5378

5479
// getDeviceEdits returns the CDI edits for the full GPU represented by the generator's device.
55-
func (l *nvmllib) GetGPUDeviceEdits(d device.Device) (*cdi.ContainerEdits, error) {
56-
device, err := l.newFullGPUDiscoverer(d)
80+
func (l *fullGPUDeviceSpecGenerator) getDeviceEdits() (*cdi.ContainerEdits, error) {
81+
device, err := l.newFullGPUDiscoverer(l.device)
5782
if err != nil {
5883
return nil, fmt.Errorf("failed to create device discoverer: %v", err)
5984
}
@@ -66,8 +91,12 @@ func (l *nvmllib) GetGPUDeviceEdits(d device.Device) (*cdi.ContainerEdits, error
6691
return editsForDevice, nil
6792
}
6893

94+
func (l *fullGPUDeviceSpecGenerator) getNames() ([]string, error) {
95+
return l.deviceNamers.GetDeviceNames(l.index, convert{l.device})
96+
}
97+
6998
// newFullGPUDiscoverer creates a discoverer for the full GPU defined by the specified device.
70-
func (l *nvmllib) newFullGPUDiscoverer(d device.Device) (discover.Discover, error) {
99+
func (l *fullGPUDeviceSpecGenerator) newFullGPUDiscoverer(d device.Device) (discover.Discover, error) {
71100
deviceNodes, err := dgpu.NewForDevice(d,
72101
dgpu.WithDevRoot(l.devRoot),
73102
dgpu.WithLogger(l.logger),

pkg/nvcdi/lib-nvml.go

Lines changed: 77 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -28,24 +28,29 @@ import (
2828

2929
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
3030
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils"
31-
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
3231
)
3332

3433
type nvmllib nvcdilib
3534

36-
var _ Interface = (*nvmllib)(nil)
35+
var _ wrapped = (*nvmllib)(nil)
3736

38-
// GetSpec should not be called for nvmllib
39-
func (l *nvmllib) GetSpec(...string) (spec.Interface, error) {
40-
return nil, fmt.Errorf("unexpected call to nvmllib.GetSpec()")
41-
}
37+
// GetCommonEdits returns the CDI container edits that are common to ALL devices.
38+
func (l *nvmllib) GetCommonEdits() (*cdi.ContainerEdits, error) {
39+
common, err := l.newCommonNVMLDiscoverer()
40+
if err != nil {
41+
return nil, fmt.Errorf("failed to create discoverer for common entities: %v", err)
42+
}
4243

43-
// GetAllDeviceSpecs returns the device specs for all available devices.
44-
func (l *nvmllib) GetAllDeviceSpecs() ([]specs.Device, error) {
45-
var deviceSpecs []specs.Device
44+
return edits.FromDiscoverer(common)
45+
}
4646

47+
// GetDeviceSpecsByID returns the CDI device specs for the devices represented
48+
// by the requested identifiers. Here an identifier is one of the following:
49+
// * an index of a GPU or MIG device
50+
// * a UUID of a GPU or MIG device
51+
func (l *nvmllib) GetDeviceSpecsByID(ids ...string) ([]specs.Device, error) {
4752
if r := l.nvmllib.Init(); r != nvml.SUCCESS {
48-
return nil, fmt.Errorf("failed to initialize NVML: %v", r)
53+
return nil, fmt.Errorf("failed to initialize NVML: %w", r)
4954
}
5055
defer func() {
5156
if r := l.nvmllib.Shutdown(); r != nvml.SUCCESS {
@@ -66,93 +71,83 @@ func (l *nvmllib) GetAllDeviceSpecs() ([]specs.Device, error) {
6671
}()
6772
}
6873

69-
gpuDeviceSpecs, err := l.getGPUDeviceSpecs()
74+
generators, err := l.getDeviceSpecGeneratorsForIDs(ids...)
7075
if err != nil {
7176
return nil, err
7277
}
73-
deviceSpecs = append(deviceSpecs, gpuDeviceSpecs...)
7478

75-
migDeviceSpecs, err := l.getMigDeviceSpecs()
76-
if err != nil {
77-
return nil, err
78-
}
79-
deviceSpecs = append(deviceSpecs, migDeviceSpecs...)
80-
81-
return deviceSpecs, nil
79+
return generators.GetDeviceSpecs()
8280
}
8381

84-
// GetCommonEdits generates a CDI specification that can be used for ANY devices
85-
func (l *nvmllib) GetCommonEdits() (*cdi.ContainerEdits, error) {
86-
common, err := l.newCommonNVMLDiscoverer()
87-
if err != nil {
88-
return nil, fmt.Errorf("failed to create discoverer for common entities: %v", err)
82+
func (l *nvmllib) newDeviceSpecGeneratorFromNVMLDevice(id string, nvmlDevice nvml.Device) (deviceSpecGenerator, error) {
83+
isMig, ret := nvmlDevice.IsMigDeviceHandle()
84+
if ret != nvml.SUCCESS {
85+
return nil, ret
86+
}
87+
if isMig {
88+
return l.newMIGDeviceSpecGeneratorFromNVMLDevice(id, nvmlDevice)
8989
}
9090

91-
return edits.FromDiscoverer(common)
91+
return l.newFullGPUDeviceSpecGeneratorFromNVMLDevice(id, nvmlDevice)
9292
}
9393

94-
// GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by
95-
// the provided identifiers, where an identifier is an index or UUID of a valid
96-
// GPU device.
97-
// Deprecated: Use GetDeviceSpecsBy instead.
98-
func (l *nvmllib) GetDeviceSpecsByID(ids ...string) ([]specs.Device, error) {
94+
func (l *nvmllib) getDeviceSpecGeneratorsForIDs(ids ...string) (deviceSpecGenerators, error) {
9995
var identifiers []device.Identifier
100-
for _, id := range ids {
101-
identifiers = append(identifiers, device.Identifier(id))
102-
}
103-
return l.GetDeviceSpecsBy(identifiers...)
104-
}
105-
106-
// GetDeviceSpecsBy returns the device specs for devices with the specified identifiers.
107-
func (l *nvmllib) GetDeviceSpecsBy(identifiers ...device.Identifier) ([]specs.Device, error) {
10896
for _, id := range identifiers {
10997
if id == "all" {
110-
return l.GetAllDeviceSpecs()
98+
return l.getDeviceSpecGeneratorsForAllDevices()
11199
}
100+
identifiers = append(identifiers, id)
112101
}
113102

114-
var deviceSpecs []specs.Device
115-
116-
if r := l.nvmllib.Init(); r != nvml.SUCCESS {
117-
return nil, fmt.Errorf("failed to initialize NVML: %w", r)
103+
devices, err := l.getNVMLDevicesByID(identifiers...)
104+
if err != nil {
105+
return nil, err
118106
}
119-
defer func() {
120-
if r := l.nvmllib.Shutdown(); r != nvml.SUCCESS {
121-
l.logger.Warningf("failed to shutdown NVML: %v", r)
122-
}
123-
}()
124107

125-
if l.nvsandboxutilslib != nil {
126-
if r := l.nvsandboxutilslib.Init(l.driverRoot); r != nvsandboxutils.SUCCESS {
127-
l.logger.Warningf("Failed to init nvsandboxutils: %v; ignoring", r)
128-
l.nvsandboxutilslib = nil
108+
var DeviceSpecGenerators deviceSpecGenerators
109+
for i, device := range devices {
110+
editor, err := l.newDeviceSpecGeneratorFromNVMLDevice(ids[i], device)
111+
if err != nil {
112+
return nil, err
129113
}
130-
defer func() {
131-
if l.nvsandboxutilslib == nil {
132-
return
133-
}
134-
_ = l.nvsandboxutilslib.Shutdown()
135-
}()
114+
DeviceSpecGenerators = append(DeviceSpecGenerators, editor)
136115
}
137116

138-
nvmlDevices, err := l.getNVMLDevicesByID(identifiers...)
117+
return DeviceSpecGenerators, nil
118+
}
119+
120+
func (l *nvmllib) getDeviceSpecGeneratorsForAllDevices() ([]deviceSpecGenerator, error) {
121+
var DeviceSpecGenerators []deviceSpecGenerator
122+
err := l.devicelib.VisitDevices(func(i int, d device.Device) error {
123+
e := &fullGPUDeviceSpecGenerator{
124+
nvmllib: l,
125+
id: fmt.Sprintf("%d", i),
126+
device: d,
127+
}
128+
129+
DeviceSpecGenerators = append(DeviceSpecGenerators, e)
130+
return nil
131+
})
139132
if err != nil {
140-
return nil, fmt.Errorf("failed to get NVML device handles: %w", err)
133+
return nil, fmt.Errorf("failed to get full GPU device editors: %w", err)
141134
}
142135

143-
for i, nvmlDevice := range nvmlDevices {
144-
deviceEdits, err := l.getEditsForDevice(nvmlDevice)
145-
if err != nil {
146-
return nil, fmt.Errorf("failed to get CDI device edits for identifier %q: %w", identifiers[i], err)
147-
}
148-
deviceSpec := specs.Device{
149-
Name: string(identifiers[i]),
150-
ContainerEdits: *deviceEdits.ContainerEdits,
136+
err = l.devicelib.VisitMigDevices(func(i int, d device.Device, j int, mig device.MigDevice) error {
137+
e := &migDeviceSpecGenerator{
138+
nvmllib: l,
139+
id: fmt.Sprintf("%d:%d", i, j),
140+
parent: d,
141+
device: mig,
151142
}
152-
deviceSpecs = append(deviceSpecs, deviceSpec)
143+
DeviceSpecGenerators = append(DeviceSpecGenerators, e)
144+
return nil
145+
})
146+
if err != nil {
147+
return nil, fmt.Errorf("failed to get MIG device editors: %w", err)
153148
}
154149

155-
return deviceSpecs, nil
150+
return DeviceSpecGenerators, nil
156151
}
157152

158153
// TODO: move this to go-nvlib?
@@ -201,76 +196,21 @@ func (l *nvmllib) getNVMLDeviceByID(id device.Identifier) (nvml.Device, error) {
201196
return nil, fmt.Errorf("identifier is not a valid UUID or index: %q", id)
202197
}
203198

204-
func (l *nvmllib) getEditsForDevice(nvmlDevice nvml.Device) (*cdi.ContainerEdits, error) {
205-
mig, err := nvmlDevice.IsMigDeviceHandle()
206-
if err != nvml.SUCCESS {
207-
return nil, fmt.Errorf("failed to determine if device handle is a MIG device: %w", err)
208-
}
209-
if mig {
210-
return l.getEditsForMIGDevice(nvmlDevice)
211-
}
212-
return l.getEditsForGPUDevice(nvmlDevice)
213-
}
214-
215-
func (l *nvmllib) getEditsForGPUDevice(nvmlDevice nvml.Device) (*cdi.ContainerEdits, error) {
216-
nvlibDevice, err := l.devicelib.NewDevice(nvmlDevice)
217-
if err != nil {
218-
return nil, fmt.Errorf("failed to construct device: %w", err)
219-
}
220-
deviceEdits, err := l.GetGPUDeviceEdits(nvlibDevice)
221-
if err != nil {
222-
return nil, fmt.Errorf("failed to get GPU device edits: %w", err)
223-
}
224-
225-
return deviceEdits, nil
226-
}
227-
228-
func (l *nvmllib) getEditsForMIGDevice(nvmlDevice nvml.Device) (*cdi.ContainerEdits, error) {
229-
nvmlParentDevice, ret := nvmlDevice.GetDeviceHandleFromMigDeviceHandle()
230-
if ret != nvml.SUCCESS {
231-
return nil, fmt.Errorf("failed to get parent device handle: %w", ret)
232-
}
233-
nvlibMigDevice, err := l.devicelib.NewMigDevice(nvmlDevice)
234-
if err != nil {
235-
return nil, fmt.Errorf("failed to construct device: %w", err)
236-
}
237-
nvlibParentDevice, err := l.devicelib.NewDevice(nvmlParentDevice)
238-
if err != nil {
239-
return nil, fmt.Errorf("failed to construct parent device: %w", err)
240-
}
241-
return l.GetMIGDeviceEdits(nvlibParentDevice, nvlibMigDevice)
242-
}
199+
type deviceSpecGenerators []deviceSpecGenerator
243200

244-
func (l *nvmllib) getGPUDeviceSpecs() ([]specs.Device, error) {
245-
var deviceSpecs []specs.Device
246-
err := l.devicelib.VisitDevices(func(i int, d device.Device) error {
247-
specsForDevice, err := l.GetGPUDeviceSpecs(i, d)
248-
if err != nil {
249-
return err
201+
// GetDeviceSpecs returns the combined specs for each device spec generator.
202+
func (g deviceSpecGenerators) GetDeviceSpecs() ([]specs.Device, error) {
203+
var allDeviceSpecs []specs.Device
204+
for _, dsg := range g {
205+
if dsg == nil {
206+
continue
250207
}
251-
deviceSpecs = append(deviceSpecs, specsForDevice...)
252-
253-
return nil
254-
})
255-
if err != nil {
256-
return nil, fmt.Errorf("failed to generate CDI edits for GPU devices: %v", err)
257-
}
258-
return deviceSpecs, err
259-
}
260-
261-
func (l *nvmllib) getMigDeviceSpecs() ([]specs.Device, error) {
262-
var deviceSpecs []specs.Device
263-
err := l.devicelib.VisitMigDevices(func(i int, d device.Device, j int, mig device.MigDevice) error {
264-
specsForDevice, err := l.GetMIGDeviceSpecs(i, d, j, mig)
208+
deviceSpecs, err := dsg.GetDeviceSpecs()
265209
if err != nil {
266-
return err
210+
return nil, err
267211
}
268-
deviceSpecs = append(deviceSpecs, specsForDevice...)
269-
270-
return nil
271-
})
272-
if err != nil {
273-
return nil, fmt.Errorf("failed to generate CDI edits for GPU devices: %v", err)
212+
allDeviceSpecs = append(allDeviceSpecs, deviceSpecs...)
274213
}
275-
return deviceSpecs, err
214+
215+
return allDeviceSpecs, nil
276216
}

0 commit comments

Comments
 (0)