Skip to content

Refactor cdi api #1166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmd/nvidia-ctk/cdi/generate/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ func (m command) generateSpec(opts *options) (spec.Interface, error) {
return nil, fmt.Errorf("failed to create CDI library: %v", err)
}

deviceSpecs, err := cdilib.GetAllDeviceSpecs()
deviceSpecs, err := cdilib.GetDeviceSpecsByID("all")
if err != nil {
return nil, fmt.Errorf("failed to create device CDI specs: %v", err)
}
Expand Down
21 changes: 14 additions & 7 deletions pkg/nvcdi/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package nvcdi

import (
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
"tags.cncf.io/container-device-interface/pkg/cdi"
"tags.cncf.io/container-device-interface/specs-go"

Expand All @@ -27,14 +26,22 @@ import (

// Interface defines the API for the nvcdi package
type Interface interface {
GetSpec(...string) (spec.Interface, error)
SpecGenerator
GetCommonEdits() (*cdi.ContainerEdits, error)
GetAllDeviceSpecs() ([]specs.Device, error)
GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error)
GetGPUDeviceSpecs(int, device.Device) ([]specs.Device, error)
GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.ContainerEdits, error)
GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) ([]specs.Device, error)
GetDeviceSpecsByID(...string) ([]specs.Device, error)
// Deprecated: GetAllDeviceSpecs is deprecated. Use GetDeviceSpecsByID("all") instead.
GetAllDeviceSpecs() ([]specs.Device, error)
}

// A SpecGenerator is used to generate a complete CDI spec for a collected set
// of devices.
type SpecGenerator interface {
GetSpec(...string) (spec.Interface, error)
}

// A DeviceSpecGenerator is used to generate the specs for one or more devices.
type DeviceSpecGenerator interface {
GetDeviceSpecs() ([]specs.Device, error)
}

// A HookName represents one of the predefined NVIDIA CDI hooks.
Expand Down
59 changes: 45 additions & 14 deletions pkg/nvcdi/full-gpu-nvml.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,41 +19,68 @@ package nvcdi
import (
"fmt"

"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
"tags.cncf.io/container-device-interface/pkg/cdi"
"tags.cncf.io/container-device-interface/specs-go"

"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
"github.com/NVIDIA/go-nvml/pkg/nvml"

"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/dgpu"
)

// GetGPUDeviceSpecs returns the CDI device specs for the full GPU represented by 'device'.
func (l *nvmllib) GetGPUDeviceSpecs(i int, d device.Device) ([]specs.Device, error) {
edits, err := l.GetGPUDeviceEdits(d)
// A fullGPUDeviceSpecGenerator generates the CDI device specifications for a
// single full GPU.
type fullGPUDeviceSpecGenerator struct {
*nvmllib
id string
index int
device device.Device
}

var _ DeviceSpecGenerator = (*fullGPUDeviceSpecGenerator)(nil)

func (l *nvmllib) newFullGPUDeviceSpecGeneratorFromNVMLDevice(id string, nvmlDevice nvml.Device) (DeviceSpecGenerator, error) {
device, err := l.devicelib.NewDevice(nvmlDevice)
if err != nil {
return nil, fmt.Errorf("failed to get edits for device: %v", err)
return nil, err
}

var deviceSpecs []specs.Device
names, err := l.deviceNamers.GetDeviceNames(i, convert{d})
e := &fullGPUDeviceSpecGenerator{
nvmllib: l,
id: id,
Copy link
Preview

Copilot AI Jul 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fullGPUDeviceSpecGenerator is not initializing its index field, so subsequent calls to getNames() will always use index 0. You should add index: i (or parse from id) when constructing the struct.

Suggested change
id: id,
id: id,
index: index,

Copilot uses AI. Check for mistakes.

device: device,
}
return e, nil
}

func (l *fullGPUDeviceSpecGenerator) GetDeviceSpecs() ([]specs.Device, error) {
deviceEdits, err := l.getDeviceEdits()
if err != nil {
return nil, fmt.Errorf("failed to get CDI device edits for identifier %q: %w", l.id, err)
}

names, err := l.getNames()
if err != nil {
return nil, fmt.Errorf("failed to get device name: %v", err)
return nil, fmt.Errorf("failed to get device names: %w", err)
}

var deviceSpecs []specs.Device
for _, name := range names {
spec := specs.Device{
deviceSpec := specs.Device{
Name: name,
ContainerEdits: *edits.ContainerEdits,
ContainerEdits: *deviceEdits.ContainerEdits,
}
deviceSpecs = append(deviceSpecs, spec)
deviceSpecs = append(deviceSpecs, deviceSpec)
}

return deviceSpecs, nil
}

// GetGPUDeviceEdits returns the CDI edits for the full GPU represented by 'device'.
func (l *nvmllib) GetGPUDeviceEdits(d device.Device) (*cdi.ContainerEdits, error) {
device, err := l.newFullGPUDiscoverer(d)
func (l *fullGPUDeviceSpecGenerator) getDeviceEdits() (*cdi.ContainerEdits, error) {
device, err := l.newFullGPUDiscoverer(l.device)
if err != nil {
return nil, fmt.Errorf("failed to create device discoverer: %v", err)
}
Expand All @@ -66,8 +93,12 @@ func (l *nvmllib) GetGPUDeviceEdits(d device.Device) (*cdi.ContainerEdits, error
return editsForDevice, nil
}

func (l *fullGPUDeviceSpecGenerator) getNames() ([]string, error) {
return l.deviceNamers.GetDeviceNames(l.index, convert{l.device})
}

// newFullGPUDiscoverer creates a discoverer for the full GPU defined by the specified device.
func (l *nvmllib) newFullGPUDiscoverer(d device.Device) (discover.Discover, error) {
func (l *fullGPUDeviceSpecGenerator) newFullGPUDiscoverer(d device.Device) (discover.Discover, error) {
deviceNodes, err := dgpu.NewForDevice(d,
dgpu.WithDevRoot(l.devRoot),
dgpu.WithLogger(l.logger),
Expand Down
45 changes: 7 additions & 38 deletions pkg/nvcdi/gds.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,23 @@ package nvcdi
import (
"fmt"

"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
"tags.cncf.io/container-device-interface/pkg/cdi"
"tags.cncf.io/container-device-interface/specs-go"

"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
)

type gdslib nvcdilib

var _ Interface = (*gdslib)(nil)
var _ deviceSpecGeneratorFactory = (*gdslib)(nil)

// GetAllDeviceSpecs returns the device specs for all available devices.
func (l *gdslib) GetAllDeviceSpecs() ([]specs.Device, error) {
func (l *gdslib) DeviceSpecGenerators(...string) (DeviceSpecGenerator, error) {
return l, nil
}

// GetDeviceSpecs returns the CDI device specs for a single all device.
func (l *gdslib) GetDeviceSpecs() ([]specs.Device, error) {
discoverer, err := discover.NewGDSDiscoverer(l.logger, l.driverRoot, l.devRoot)
if err != nil {
return nil, fmt.Errorf("failed to create GPUDirect Storage discoverer: %v", err)
Expand All @@ -55,36 +57,3 @@ func (l *gdslib) GetAllDeviceSpecs() ([]specs.Device, error) {
func (l *gdslib) GetCommonEdits() (*cdi.ContainerEdits, error) {
return edits.FromDiscoverer(discover.None{})
}

// GetSpec is unsppported for the gdslib specs.
// gdslib is typically wrapped by a spec that implements GetSpec.
func (l *gdslib) GetSpec(...string) (spec.Interface, error) {
return nil, fmt.Errorf("GetSpec is not supported")
}

// GetGPUDeviceEdits is unsupported for the gdslib specs
func (l *gdslib) GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error) {
return nil, fmt.Errorf("GetGPUDeviceEdits is not supported")
}

// GetGPUDeviceSpecs is unsupported for the gdslib specs
func (l *gdslib) GetGPUDeviceSpecs(int, device.Device) ([]specs.Device, error) {
return nil, fmt.Errorf("GetGPUDeviceSpecs is not supported")
}

// GetMIGDeviceEdits is unsupported for the gdslib specs
func (l *gdslib) GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.ContainerEdits, error) {
return nil, fmt.Errorf("GetMIGDeviceEdits is not supported")
}

// GetMIGDeviceSpecs is unsupported for the gdslib specs
func (l *gdslib) GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) ([]specs.Device, error) {
return nil, fmt.Errorf("GetMIGDeviceSpecs is not supported")
}

// GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by
// the provided identifiers, where an identifier is an index or UUID of a valid
// GPU device.
func (l *gdslib) GetDeviceSpecsByID(...string) ([]specs.Device, error) {
return nil, fmt.Errorf("GetDeviceSpecsByID is not supported")
}
52 changes: 15 additions & 37 deletions pkg/nvcdi/lib-csv.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,33 @@ package nvcdi
import (
"fmt"

"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
"tags.cncf.io/container-device-interface/pkg/cdi"
"tags.cncf.io/container-device-interface/specs-go"

"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra"
"github.com/NVIDIA/nvidia-container-toolkit/pkg/nvcdi/spec"
)

type csvlib nvcdilib

var _ Interface = (*csvlib)(nil)
var _ deviceSpecGeneratorFactory = (*csvlib)(nil)

// GetSpec should not be called for wsllib
func (l *csvlib) GetSpec(...string) (spec.Interface, error) {
return nil, fmt.Errorf("unexpected call to csvlib.GetSpec()")
func (l *csvlib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) {
for _, id := range ids {
switch id {
case "all":
case "0":
default:
return nil, fmt.Errorf("unsupported device id: %v", id)
}
}

return l, nil
}

// GetAllDeviceSpecs returns the device specs for all available devices.
func (l *csvlib) GetAllDeviceSpecs() ([]specs.Device, error) {
// GetDeviceSpecs returns the CDI device specs for a single device.
func (l *csvlib) GetDeviceSpecs() ([]specs.Device, error) {
d, err := tegra.New(
tegra.WithLogger(l.logger),
tegra.WithDriverRoot(l.driverRoot),
Expand Down Expand Up @@ -76,33 +82,5 @@ func (l *csvlib) GetAllDeviceSpecs() ([]specs.Device, error) {

// GetCommonEdits generates a CDI specification that can be used for ANY devices
func (l *csvlib) GetCommonEdits() (*cdi.ContainerEdits, error) {
d := discover.None{}
return edits.FromDiscoverer(d)
}

// GetGPUDeviceEdits generates a CDI specification that can be used for GPU devices
func (l *csvlib) GetGPUDeviceEdits(device.Device) (*cdi.ContainerEdits, error) {
return nil, fmt.Errorf("GetGPUDeviceEdits is not supported for CSV files")
}

// GetGPUDeviceSpecs returns the CDI device specs for the full GPU represented by 'device'.
func (l *csvlib) GetGPUDeviceSpecs(i int, d device.Device) ([]specs.Device, error) {
return nil, fmt.Errorf("GetGPUDeviceSpecs is not supported for CSV files")
}

// GetMIGDeviceEdits generates a CDI specification that can be used for MIG devices
func (l *csvlib) GetMIGDeviceEdits(device.Device, device.MigDevice) (*cdi.ContainerEdits, error) {
return nil, fmt.Errorf("GetMIGDeviceEdits is not supported for CSV files")
}

// GetMIGDeviceSpecs returns the CDI device specs for the full MIG represented by 'device'.
func (l *csvlib) GetMIGDeviceSpecs(int, device.Device, int, device.MigDevice) ([]specs.Device, error) {
return nil, fmt.Errorf("GetMIGDeviceSpecs is not supported for CSV files")
}

// GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by
// the provided identifiers, where an identifier is an index or UUID of a valid
// GPU device.
func (l *csvlib) GetDeviceSpecsByID(...string) ([]specs.Device, error) {
return nil, fmt.Errorf("GetDeviceSpecsByID is not supported for CSV files")
return edits.FromDiscoverer(discover.None{})
}
Loading
Loading