@@ -28,24 +28,29 @@ import (
28
28
29
29
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
30
30
"github.com/NVIDIA/nvidia-container-toolkit/internal/nvsandboxutils"
31
- "github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
32
31
)
33
32
34
33
type nvmllib nvcdilib
35
34
36
- var _ Interface = (* nvmllib )(nil )
35
+ var _ wrapped = (* nvmllib )(nil )
37
36
38
- // GetSpec should not be called for nvmllib
39
- func (l * nvmllib ) GetSpec (... string ) (spec.Interface , error ) {
40
- return nil , fmt .Errorf ("unexpected call to nvmllib.GetSpec()" )
41
- }
37
+ // GetCommonEdits generates a CDI specification that can be used for ANY devices
38
+ func (l * nvmllib ) GetCommonEdits () (* cdi.ContainerEdits , error ) {
39
+ common , err := l .newCommonNVMLDiscoverer ()
40
+ if err != nil {
41
+ return nil , fmt .Errorf ("failed to create discoverer for common entities: %v" , err )
42
+ }
42
43
43
- // GetAllDeviceSpecs returns the device specs for all available devices.
44
- func (l * nvmllib ) GetAllDeviceSpecs () ([]specs.Device , error ) {
45
- var deviceSpecs []specs.Device
44
+ return edits .FromDiscoverer (common )
45
+ }
46
46
47
+ // GetDeviceSpecsByID returns the CDI device specs for the devices represented
48
+ // by the requested identifiers. Here an identifier is one of the following:
49
+ // * an index of a GPU or MIG device
50
+ // * a UUID of a GPU or MIG device
51
+ func (l * nvmllib ) GetDeviceSpecsByID (ids ... string ) ([]specs.Device , error ) {
47
52
if r := l .nvmllib .Init (); r != nvml .SUCCESS {
48
- return nil , fmt .Errorf ("failed to initialize NVML: %v " , r )
53
+ return nil , fmt .Errorf ("failed to initialize NVML: %w " , r )
49
54
}
50
55
defer func () {
51
56
if r := l .nvmllib .Shutdown (); r != nvml .SUCCESS {
@@ -66,93 +71,83 @@ func (l *nvmllib) GetAllDeviceSpecs() ([]specs.Device, error) {
66
71
}()
67
72
}
68
73
69
- gpuDeviceSpecs , err := l .getGPUDeviceSpecs ()
70
- if err != nil {
71
- return nil , err
72
- }
73
- deviceSpecs = append (deviceSpecs , gpuDeviceSpecs ... )
74
-
75
- migDeviceSpecs , err := l .getMigDeviceSpecs ()
74
+ generators , err := l .getDeviceSpecGeneratorsForIDs (ids ... )
76
75
if err != nil {
77
76
return nil , err
78
77
}
79
- deviceSpecs = append (deviceSpecs , migDeviceSpecs ... )
80
78
81
- return deviceSpecs , nil
79
+ return generators . GetDeviceSpecs ()
82
80
}
83
81
84
- // GetCommonEdits generates a CDI specification that can be used for ANY devices
85
- func (l * nvmllib ) GetCommonEdits () (* cdi.ContainerEdits , error ) {
86
- common , err := l .newCommonNVMLDiscoverer ()
87
- if err != nil {
88
- return nil , fmt .Errorf ("failed to create discoverer for common entities: %v" , err )
82
+ func (l * nvmllib ) newDeviceSpecGeneratorFromNVMLDevice (id string , nvmlDevice nvml.Device ) (deviceSpecGenerator , error ) {
83
+ isMig , ret := nvmlDevice .IsMigDeviceHandle ()
84
+ if ret != nvml .SUCCESS {
85
+ return nil , ret
86
+ }
87
+ if isMig {
88
+ return l .newMIGDeviceSpecGeneratorFromNVMLDevice (id , nvmlDevice )
89
89
}
90
90
91
- return edits . FromDiscoverer ( common )
91
+ return l . newFullGPUDeviceSpecGeneratorFromNVMLDevice ( id , nvmlDevice )
92
92
}
93
93
94
- // GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by
95
- // the provided identifiers, where an identifier is an index or UUID of a valid
96
- // GPU device.
97
- // Deprecated: Use GetDeviceSpecsBy instead.
98
- func (l * nvmllib ) GetDeviceSpecsByID (ids ... string ) ([]specs.Device , error ) {
94
+ func (l * nvmllib ) getDeviceSpecGeneratorsForIDs (ids ... string ) (deviceSpecGenerators , error ) {
99
95
var identifiers []device.Identifier
100
96
for _ , id := range ids {
101
- identifiers = append (identifiers , device .Identifier (id ))
102
- }
103
- return l .GetDeviceSpecsBy (identifiers ... )
104
- }
105
-
106
- // GetDeviceSpecsBy returns the device specs for devices with the specified identifiers.
107
- func (l * nvmllib ) GetDeviceSpecsBy (identifiers ... device.Identifier ) ([]specs.Device , error ) {
108
- for _ , id := range identifiers {
109
97
if id == "all" {
110
- return l .GetAllDeviceSpecs ()
98
+ return l .getDeviceSpecGeneratorsForAllDevices ()
111
99
}
100
+ identifiers = append (identifiers , device .Identifier (id ))
112
101
}
113
102
114
- var deviceSpecs []specs.Device
115
-
116
- if r := l .nvmllib .Init (); r != nvml .SUCCESS {
117
- return nil , fmt .Errorf ("failed to initialize NVML: %w" , r )
103
+ devices , err := l .getNVMLDevicesByID (identifiers ... )
104
+ if err != nil {
105
+ return nil , err
118
106
}
119
- defer func () {
120
- if r := l .nvmllib .Shutdown (); r != nvml .SUCCESS {
121
- l .logger .Warningf ("failed to shutdown NVML: %v" , r )
122
- }
123
- }()
124
107
125
- if l .nvsandboxutilslib != nil {
126
- if r := l .nvsandboxutilslib .Init (l .driverRoot ); r != nvsandboxutils .SUCCESS {
127
- l .logger .Warningf ("Failed to init nvsandboxutils: %v; ignoring" , r )
128
- l .nvsandboxutilslib = nil
108
+ var DeviceSpecGenerators deviceSpecGenerators
109
+ for i , device := range devices {
110
+ editor , err := l .newDeviceSpecGeneratorFromNVMLDevice (ids [i ], device )
111
+ if err != nil {
112
+ return nil , err
129
113
}
130
- defer func () {
131
- if l .nvsandboxutilslib == nil {
132
- return
133
- }
134
- _ = l .nvsandboxutilslib .Shutdown ()
135
- }()
114
+ DeviceSpecGenerators = append (DeviceSpecGenerators , editor )
136
115
}
137
116
138
- nvmlDevices , err := l .getNVMLDevicesByID (identifiers ... )
117
+ return DeviceSpecGenerators , nil
118
+ }
119
+
120
+ func (l * nvmllib ) getDeviceSpecGeneratorsForAllDevices () ([]deviceSpecGenerator , error ) {
121
+ var DeviceSpecGenerators []deviceSpecGenerator
122
+ err := l .devicelib .VisitDevices (func (i int , d device.Device ) error {
123
+ e := & fullGPUDeviceSpecGenerator {
124
+ nvmllib : l ,
125
+ id : fmt .Sprintf ("%d" , i ),
126
+ device : d ,
127
+ }
128
+
129
+ DeviceSpecGenerators = append (DeviceSpecGenerators , e )
130
+ return nil
131
+ })
139
132
if err != nil {
140
- return nil , fmt .Errorf ("failed to get NVML device handles : %w" , err )
133
+ return nil , fmt .Errorf ("failed to get full GPU device editors : %w" , err )
141
134
}
142
135
143
- for i , nvmlDevice := range nvmlDevices {
144
- deviceEdits , err := l .getEditsForDevice (nvmlDevice )
145
- if err != nil {
146
- return nil , fmt .Errorf ("failed to get CDI device edits for identifier %q: %w" , identifiers [i ], err )
147
- }
148
- deviceSpec := specs.Device {
149
- Name : string (identifiers [i ]),
150
- ContainerEdits : * deviceEdits .ContainerEdits ,
136
+ err = l .devicelib .VisitMigDevices (func (i int , d device.Device , j int , mig device.MigDevice ) error {
137
+ e := & migDeviceSpecGenerator {
138
+ nvmllib : l ,
139
+ id : fmt .Sprintf ("%d:%d" , i , j ),
140
+ parent : d ,
141
+ device : mig ,
151
142
}
152
- deviceSpecs = append (deviceSpecs , deviceSpec )
143
+ DeviceSpecGenerators = append (DeviceSpecGenerators , e )
144
+ return nil
145
+ })
146
+ if err != nil {
147
+ return nil , fmt .Errorf ("failed to get MIG device editors: %w" , err )
153
148
}
154
149
155
- return deviceSpecs , nil
150
+ return DeviceSpecGenerators , nil
156
151
}
157
152
158
153
// TODO: move this to go-nvlib?
@@ -201,76 +196,21 @@ func (l *nvmllib) getNVMLDeviceByID(id device.Identifier) (nvml.Device, error) {
201
196
return nil , fmt .Errorf ("identifier is not a valid UUID or index: %q" , id )
202
197
}
203
198
204
- func (l * nvmllib ) getEditsForDevice (nvmlDevice nvml.Device ) (* cdi.ContainerEdits , error ) {
205
- mig , err := nvmlDevice .IsMigDeviceHandle ()
206
- if err != nvml .SUCCESS {
207
- return nil , fmt .Errorf ("failed to determine if device handle is a MIG device: %w" , err )
208
- }
209
- if mig {
210
- return l .getEditsForMIGDevice (nvmlDevice )
211
- }
212
- return l .getEditsForGPUDevice (nvmlDevice )
213
- }
214
-
215
- func (l * nvmllib ) getEditsForGPUDevice (nvmlDevice nvml.Device ) (* cdi.ContainerEdits , error ) {
216
- nvlibDevice , err := l .devicelib .NewDevice (nvmlDevice )
217
- if err != nil {
218
- return nil , fmt .Errorf ("failed to construct device: %w" , err )
219
- }
220
- deviceEdits , err := l .GetGPUDeviceEdits (nvlibDevice )
221
- if err != nil {
222
- return nil , fmt .Errorf ("failed to get GPU device edits: %w" , err )
223
- }
199
+ type deviceSpecGenerators []deviceSpecGenerator
224
200
225
- return deviceEdits , nil
226
- }
227
-
228
- func (l * nvmllib ) getEditsForMIGDevice (nvmlDevice nvml.Device ) (* cdi.ContainerEdits , error ) {
229
- nvmlParentDevice , ret := nvmlDevice .GetDeviceHandleFromMigDeviceHandle ()
230
- if ret != nvml .SUCCESS {
231
- return nil , fmt .Errorf ("failed to get parent device handle: %w" , ret )
232
- }
233
- nvlibMigDevice , err := l .devicelib .NewMigDevice (nvmlDevice )
234
- if err != nil {
235
- return nil , fmt .Errorf ("failed to construct device: %w" , err )
236
- }
237
- nvlibParentDevice , err := l .devicelib .NewDevice (nvmlParentDevice )
238
- if err != nil {
239
- return nil , fmt .Errorf ("failed to construct parent device: %w" , err )
240
- }
241
- return l .GetMIGDeviceEdits (nvlibParentDevice , nvlibMigDevice )
242
- }
243
-
244
- func (l * nvmllib ) getGPUDeviceSpecs () ([]specs.Device , error ) {
245
- var deviceSpecs []specs.Device
246
- err := l .devicelib .VisitDevices (func (i int , d device.Device ) error {
247
- specsForDevice , err := l .GetGPUDeviceSpecs (i , d )
248
- if err != nil {
249
- return err
201
+ // GetDeviceSpecs returns the combined specs for each device spec generator.
202
+ func (g deviceSpecGenerators ) GetDeviceSpecs () ([]specs.Device , error ) {
203
+ var allDeviceSpecs []specs.Device
204
+ for _ , dsg := range g {
205
+ if dsg == nil {
206
+ continue
250
207
}
251
- deviceSpecs = append (deviceSpecs , specsForDevice ... )
252
-
253
- return nil
254
- })
255
- if err != nil {
256
- return nil , fmt .Errorf ("failed to generate CDI edits for GPU devices: %v" , err )
257
- }
258
- return deviceSpecs , err
259
- }
260
-
261
- func (l * nvmllib ) getMigDeviceSpecs () ([]specs.Device , error ) {
262
- var deviceSpecs []specs.Device
263
- err := l .devicelib .VisitMigDevices (func (i int , d device.Device , j int , mig device.MigDevice ) error {
264
- specsForDevice , err := l .GetMIGDeviceSpecs (i , d , j , mig )
208
+ deviceSpecs , err := dsg .GetDeviceSpecs ()
265
209
if err != nil {
266
- return err
210
+ return nil , err
267
211
}
268
- deviceSpecs = append (deviceSpecs , specsForDevice ... )
269
-
270
- return nil
271
- })
272
- if err != nil {
273
- return nil , fmt .Errorf ("failed to generate CDI edits for GPU devices: %v" , err )
212
+ allDeviceSpecs = append (allDeviceSpecs , deviceSpecs ... )
274
213
}
275
- return deviceSpecs , err
214
+
215
+ return allDeviceSpecs , nil
276
216
}
0 commit comments