Skip to content

Commit 918dd5f

Browse files
authored
Merge pull request #23 from elezar/fix-devicelist-filter
Fix devicelist filter
2 parents b057784 + d8061f4 commit 918dd5f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+23560
-8
lines changed

go.mod

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,17 @@ module github.com/NVIDIA/go-gpuallocator
22

33
go 1.20
44

5-
require github.com/NVIDIA/go-nvlib v0.0.0-20240109130712-11603560817a
5+
require (
6+
github.com/NVIDIA/go-nvlib v0.0.0-20240109130712-11603560817a
7+
github.com/stretchr/testify v1.8.4
8+
)
69

710
require (
811
github.com/NVIDIA/go-nvml v0.12.0-1.0.20231020145430-e06766c5e74f // indirect
12+
github.com/davecgh/go-spew v1.1.1 // indirect
913
github.com/google/uuid v1.4.0 // indirect
14+
github.com/pmezard/go-difflib v1.0.0 // indirect
15+
gopkg.in/yaml.v3 v3.0.1 // indirect
1016
)
1117

1218
replace (

go.sum

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
1616
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
1717
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
1818
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
19+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
1920
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
2021
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
2122
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

gpuallocator/device.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ func NewDevices(opts ...Option) (DeviceList, error) {
7878
opt(o)
7979
}
8080
if o.nvmllib == nil {
81-
o.nvmllib = nvml.New()
81+
o.nvmllib = nvmlNew()
8282
}
8383
if o.devicelib == nil {
8484
o.devicelib = device.New(
@@ -139,6 +139,9 @@ func (o *deviceListBuilder) build() (DeviceList, error) {
139139

140140
// NewDevicesFrom creates a list of Devices from the specific set of GPU uuids passed in.
141141
func NewDevicesFrom(uuids []string) (DeviceList, error) {
142+
if len(uuids) == 0 {
143+
return DeviceList{}, nil
144+
}
142145
devices, err := NewDevices()
143146
if err != nil {
144147
return nil, err
@@ -147,14 +150,9 @@ func NewDevicesFrom(uuids []string) (DeviceList, error) {
147150
}
148151

149152
// Filter filters out the selected devices from the list.
150-
// If the supplied list of uuids is nil, no filtering is performed.
151153
// Note that the specified uuids must exist in the list of devices.
152154
func (d DeviceList) Filter(uuids []string) (DeviceList, error) {
153-
if uuids == nil {
154-
return d, nil
155-
}
156-
157-
filtered := []*Device{}
155+
var filtered DeviceList
158156
for _, uuid := range uuids {
159157
for _, device := range d {
160158
if device.UUID == uuid {
@@ -254,3 +252,8 @@ func (ds DeviceSet) SortedSlice() []*Device {
254252

255253
return devices
256254
}
255+
256+
// nvmlNew is implemented as a function here to allow for this to be replaced for testing.
257+
var nvmlNew = func() nvml.Interface {
258+
return nvml.New()
259+
}

gpuallocator/device_test.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
/**
2+
# Copyright 2024 NVIDIA CORPORATION
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package gpuallocator
18+
19+
import (
20+
"testing"
21+
22+
"github.com/NVIDIA/go-nvlib/pkg/nvml"
23+
"github.com/stretchr/testify/require"
24+
)
25+
26+
func TestDeviceListFilter(t *testing.T) {
27+
singleDeviceNVML := &nvml.InterfaceMock{
28+
InitFunc: func() nvml.Return {
29+
return nvml.SUCCESS
30+
},
31+
ShutdownFunc: func() nvml.Return {
32+
return nvml.SUCCESS
33+
},
34+
DeviceGetCountFunc: func() (int, nvml.Return) {
35+
return 1, nvml.SUCCESS
36+
},
37+
DeviceGetHandleByIndexFunc: func(Index int) (nvml.Device, nvml.Return) {
38+
device := &nvml.DeviceMock{
39+
GetNameFunc: func() (string, nvml.Return) {
40+
return "Device0", nvml.SUCCESS
41+
},
42+
GetUUIDFunc: func() (string, nvml.Return) {
43+
return "GPU-0", nvml.SUCCESS
44+
},
45+
GetPciInfoFunc: func() (nvml.PciInfo, nvml.Return) {
46+
return nvml.PciInfo{}, nvml.SUCCESS
47+
},
48+
}
49+
return device, nvml.SUCCESS
50+
},
51+
}
52+
53+
testCases := []struct {
54+
description string
55+
uuids []string
56+
nvmllib nvml.Interface
57+
expectedDeviceList DeviceList
58+
expectedError error
59+
}{
60+
{
61+
description: "nil uuids returns empty list",
62+
nvmllib: singleDeviceNVML,
63+
expectedDeviceList: DeviceList{},
64+
},
65+
{
66+
description: "empty uuids returns empty list",
67+
uuids: []string{},
68+
nvmllib: singleDeviceNVML,
69+
expectedDeviceList: DeviceList{},
70+
},
71+
}
72+
73+
for _, tc := range testCases {
74+
t.Run(tc.description, func(t *testing.T) {
75+
defer setNVMLNewDuringTest(tc.nvmllib)()
76+
deviceList, err := NewDevicesFrom(tc.uuids)
77+
require.ErrorIs(t, tc.expectedError, err)
78+
require.EqualValues(t, tc.expectedDeviceList, deviceList)
79+
})
80+
}
81+
}
82+
83+
func setNVMLNewDuringTest(to nvml.Interface) func() {
84+
original := nvmlNew
85+
nvmlNew = func() nvml.Interface {
86+
return to
87+
}
88+
89+
return func() {
90+
nvmlNew = original
91+
}
92+
}

vendor/github.com/davecgh/go-spew/LICENSE

Lines changed: 15 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vendor/github.com/davecgh/go-spew/spew/bypass.go

Lines changed: 145 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vendor/github.com/davecgh/go-spew/spew/bypasssafe.go

Lines changed: 38 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)