@@ -17,7 +17,10 @@ import (
1717
1818var errConflictCountDeviceIDs = errors .New ("cannot set both Count and DeviceIDs on device request" )
1919
20- const nvidiaHook = "nvidia-container-runtime-hook"
20+ const (
21+ nvidiaHook = "nvidia-container-runtime-hook"
22+ amdContainerRuntimeExecutableName = "amd-container-runtime"
23+ )
2124
2225// These are NVIDIA-specific capabilities stolen from github.com/containerd/containerd/contrib/nvidia.allCaps
2326var allNvidiaCaps = map [nvidia.Capability ]struct {}{
@@ -30,19 +33,29 @@ var allNvidiaCaps = map[nvidia.Capability]struct{}{
3033}
3134
3235func init () {
33- if _ , err := exec .LookPath (nvidiaHook ); err != nil {
34- // do not register Nvidia driver if helper binary is not present.
36+ // Register Nvidia driver if Nvidia helper binary is present.
37+ if _ , err := exec .LookPath (nvidiaHook ); err == nil {
38+ capset := capabilities.Set {"gpu" : struct {}{}, "nvidia" : struct {}{}}
39+ for c := range allNvidiaCaps {
40+ capset [string (c )] = struct {}{}
41+ }
42+ registerDeviceDriver ("nvidia" , & deviceDriver {
43+ capset : capset ,
44+ updateSpec : setNvidiaGPUs ,
45+ })
3546 return
3647 }
37- capset := capabilities.Set {"gpu" : struct {}{}, "nvidia" : struct {}{}}
38- nvidiaDriver := & deviceDriver {
39- capset : capset ,
40- updateSpec : setNvidiaGPUs ,
41- }
42- for c := range allNvidiaCaps {
43- nvidiaDriver .capset [string (c )] = struct {}{}
48+
49+ // Register AMD driver if AMD helper binary is present.
50+ if _ , err := exec .LookPath (amdContainerRuntimeExecutableName ); err == nil {
51+ registerDeviceDriver ("amd" , & deviceDriver {
52+ capset : capabilities.Set {"gpu" : struct {}{}, "amd" : struct {}{}},
53+ updateSpec : setAMDGPUs ,
54+ })
55+ return
4456 }
45- registerDeviceDriver ("nvidia" , nvidiaDriver )
57+
58+ // No "gpu" capability
4659}
4760
4861func setNvidiaGPUs (s * specs.Spec , dev * deviceInstance ) error {
0 commit comments