runner.executorStdin = stdin
runner.executorStdout = stdout
runner.executorStderr = stderr
+
+ cudaDeviceCount := 0
+ if runner.Container.RuntimeConstraints.CUDADriverVersion != "" ||
+ runner.Container.RuntimeConstraints.CUDAHardwareCapability != "" ||
+ runner.Container.RuntimeConstraints.CUDADeviceCount != 0 {
+ // if any of these are set, enable CUDA GPU support
+ cudaDeviceCount = runner.Container.RuntimeConstraints.CUDADeviceCount
+ if cudaDeviceCount == 0 {
+ cudaDeviceCount = 1
+ }
+ }
+
return runner.executor.Create(containerSpec{
Image: imageID,
VCPUs: runner.Container.RuntimeConstraints.VCPUs,
BindMounts: bindmounts,
Command: runner.Container.Command,
EnableNetwork: enableNetwork,
- CUDADeviceCount: runner.Container.RuntimeConstraints.CUDADeviceCount,
+ CUDADeviceCount: cudaDeviceCount,
NetworkMode: runner.networkMode,
CgroupParent: runner.setCgroupParent,
Stdin: stdin,