Skip to content

Commit 3b1d2f7

Browse files
authored
Merge pull request moby#49952 from sgopinath1/49824-amd-gpu
Added support for AMD GPUs in "docker run --gpus".
2 parents 349a2d0 + e32715e commit 3b1d2f7

File tree

2 files changed

+51
-11
lines changed

2 files changed

+51
-11
lines changed

daemon/devices_amd_linux.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package daemon
2+
3+
import (
4+
"strings"
5+
6+
"github.com/opencontainers/runtime-spec/specs-go"
7+
)
8+
9+
func setAMDGPUs(s *specs.Spec, dev *deviceInstance) error {
10+
req := dev.req
11+
if req.Count != 0 && len(req.DeviceIDs) > 0 {
12+
return errConflictCountDeviceIDs
13+
}
14+
15+
switch {
16+
case len(req.DeviceIDs) > 0:
17+
s.Process.Env = append(s.Process.Env, "AMD_VISIBLE_DEVICES="+strings.Join(req.DeviceIDs, ","))
18+
case req.Count > 0:
19+
s.Process.Env = append(s.Process.Env, "AMD_VISIBLE_DEVICES="+countToDevices(req.Count))
20+
case req.Count < 0:
21+
s.Process.Env = append(s.Process.Env, "AMD_VISIBLE_DEVICES=all")
22+
case req.Count == 0:
23+
s.Process.Env = append(s.Process.Env, "AMD_VISIBLE_DEVICES=void")
24+
}
25+
26+
return nil
27+
}
Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@ import (
1717

1818
var errConflictCountDeviceIDs = errors.New("cannot set both Count and DeviceIDs on device request")
1919

20-
const nvidiaHook = "nvidia-container-runtime-hook"
20+
const (
21+
nvidiaHook = "nvidia-container-runtime-hook"
22+
amdContainerRuntimeExecutableName = "amd-container-runtime"
23+
)
2124

2225
// These are NVIDIA-specific capabilities stolen from github.com/containerd/containerd/contrib/nvidia.allCaps
2326
var allNvidiaCaps = map[nvidia.Capability]struct{}{
@@ -30,19 +33,29 @@ var allNvidiaCaps = map[nvidia.Capability]struct{}{
3033
}
3134

3235
func init() {
33-
if _, err := exec.LookPath(nvidiaHook); err != nil {
34-
// do not register Nvidia driver if helper binary is not present.
36+
// Register Nvidia driver if Nvidia helper binary is present.
37+
if _, err := exec.LookPath(nvidiaHook); err == nil {
38+
capset := capabilities.Set{"gpu": struct{}{}, "nvidia": struct{}{}}
39+
for c := range allNvidiaCaps {
40+
capset[string(c)] = struct{}{}
41+
}
42+
registerDeviceDriver("nvidia", &deviceDriver{
43+
capset: capset,
44+
updateSpec: setNvidiaGPUs,
45+
})
3546
return
3647
}
37-
capset := capabilities.Set{"gpu": struct{}{}, "nvidia": struct{}{}}
38-
nvidiaDriver := &deviceDriver{
39-
capset: capset,
40-
updateSpec: setNvidiaGPUs,
41-
}
42-
for c := range allNvidiaCaps {
43-
nvidiaDriver.capset[string(c)] = struct{}{}
48+
49+
// Register AMD driver if AMD helper binary is present.
50+
if _, err := exec.LookPath(amdContainerRuntimeExecutableName); err == nil {
51+
registerDeviceDriver("amd", &deviceDriver{
52+
capset: capabilities.Set{"gpu": struct{}{}, "amd": struct{}{}},
53+
updateSpec: setAMDGPUs,
54+
})
55+
return
4456
}
45-
registerDeviceDriver("nvidia", nvidiaDriver)
57+
58+
// No "gpu" capability
4659
}
4760

4861
func setNvidiaGPUs(s *specs.Spec, dev *deviceInstance) error {

0 commit comments

Comments
 (0)