Skip to content

Commit eb58bb2

Browse files
zachgkfrankfliu
authored andcommitted
Creates MultiDevice (#2819)
This creates an abstraction for combining devices into a single device. The main use case for now is in DJL Serving TP_parallel. It will allow us to create a WorkerGroup and a PyPredictor for a set of devices and then track the usage of devices properly. It could also be used later for multi-gpu training or other multi-device cases.
1 parent c6499e8 commit eb58bb2

File tree

3 files changed

+114
-3
lines changed

3 files changed

+114
-3
lines changed

api/src/main/java/ai/djl/Device.java

Lines changed: 102 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,16 @@
1414

1515
import ai.djl.engine.Engine;
1616

17+
import java.util.Arrays;
18+
import java.util.Comparator;
19+
import java.util.List;
1720
import java.util.Map;
1821
import java.util.Objects;
1922
import java.util.concurrent.ConcurrentHashMap;
2023
import java.util.regex.Matcher;
2124
import java.util.regex.Pattern;
25+
import java.util.stream.Collectors;
26+
import java.util.stream.IntStream;
2227

2328
/**
2429
* The {@code Device} class provides the specified assignment for CPU/GPU processing on the {@code
@@ -30,7 +35,7 @@
3035
* @see <a href="https://d2l.djl.ai/chapter_deep-learning-computation/use-gpu.html">The D2L chapter
3136
* on GPU devices</a>
3237
*/
33-
public final class Device {
38+
public class Device {
3439

3540
private static final Map<String, Device> CACHE = new ConcurrentHashMap<>();
3641

@@ -39,8 +44,8 @@ public final class Device {
3944

4045
private static final Pattern DEVICE_NAME = Pattern.compile("([a-z]+)([0-9]*)");
4146

42-
private String deviceType;
43-
private int deviceId;
47+
protected String deviceType;
48+
protected int deviceId;
4449

4550
/**
4651
* Creates a {@code Device} with basic information.
@@ -101,6 +106,13 @@ public static Device fromName(String deviceName, Engine engine) {
101106
return engine.defaultDevice();
102107
}
103108

109+
if (deviceName.contains("+")) {
110+
String[] split = deviceName.split("\\+");
111+
List<Device> subDevices =
112+
Arrays.stream(split).map(n -> fromName(n, engine)).collect(Collectors.toList());
113+
return new MultiDevice(subDevices);
114+
}
115+
104116
Matcher matcher = DEVICE_NAME.matcher(deviceName);
105117
if (matcher.matches()) {
106118
String deviceType = matcher.group(1);
@@ -214,4 +226,91 @@ public interface Type {
214226
String CPU = "cpu";
215227
String GPU = "gpu";
216228
}
229+
230+
/** A combined {@link Device} representing the composition of multiple other devices. */
231+
public static class MultiDevice extends Device {
232+
233+
List<Device> devices;
234+
235+
/**
236+
* Constructs a {@link MultiDevice} with a range of new devices.
237+
*
238+
* @param deviceType the type of the sub-devices
239+
* @param startInclusive the start (inclusive) of the devices range
240+
* @param endExclusive the end (exclusive) of the devices range
241+
*/
242+
public MultiDevice(String deviceType, int startInclusive, int endExclusive) {
243+
this(
244+
IntStream.range(startInclusive, endExclusive)
245+
.mapToObj(i -> Device.of(deviceType, i))
246+
.collect(Collectors.toList()));
247+
}
248+
249+
/**
250+
* Constructs a {@link MultiDevice} from sub devices.
251+
*
252+
* @param devices the sub devices
253+
*/
254+
public MultiDevice(Device... devices) {
255+
this(Arrays.asList(devices));
256+
}
257+
258+
/**
259+
* Constructs a {@link MultiDevice} from sub devices.
260+
*
261+
* @param devices the sub devices
262+
*/
263+
public MultiDevice(List<Device> devices) {
264+
super(null, -1);
265+
devices.sort(
266+
Comparator.comparing(Device::getDeviceType, String.CASE_INSENSITIVE_ORDER)
267+
.thenComparingInt(Device::getDeviceId));
268+
this.deviceType =
269+
String.join(
270+
"+",
271+
(Iterable<String>)
272+
() ->
273+
devices.stream()
274+
.map(d -> d.getDeviceType() + d.getDeviceId())
275+
.iterator());
276+
this.devices = devices;
277+
}
278+
279+
/**
280+
* Returns the sub devices.
281+
*
282+
* @return the sub devices
283+
*/
284+
public List<Device> getDevices() {
285+
return devices;
286+
}
287+
288+
/** {@inheritDoc} */
289+
@Override
290+
public boolean equals(Object o) {
291+
if (this == o) {
292+
return true;
293+
}
294+
if (o == null || getClass() != o.getClass()) {
295+
return false;
296+
}
297+
if (!super.equals(o)) {
298+
return false;
299+
}
300+
MultiDevice that = (MultiDevice) o;
301+
return Objects.equals(devices, that.devices);
302+
}
303+
304+
/** {@inheritDoc} */
305+
@Override
306+
public int hashCode() {
307+
return Objects.hash(super.hashCode(), devices);
308+
}
309+
310+
/** {@inheritDoc} */
311+
@Override
312+
public String toString() {
313+
return deviceType + "()";
314+
}
315+
}
217316
}

api/src/main/java/ai/djl/training/ParameterStore.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
package ai.djl.training;
1515

1616
import ai.djl.Device;
17+
import ai.djl.Device.MultiDevice;
1718
import ai.djl.ndarray.NDArray;
1819
import ai.djl.ndarray.NDManager;
1920
import ai.djl.nn.Parameter;
@@ -64,6 +65,10 @@ public void setParameterServer(ParameterServer parameterServer, Device[] devices
6465
this.parameterServer = parameterServer;
6566
deviceMap.clear();
6667
for (int i = 0; i < devices.length; ++i) {
68+
if (devices[i] instanceof MultiDevice) {
69+
throw new IllegalArgumentException(
70+
"The parameter store does not support MultiDevices");
71+
}
6772
if (deviceMap.put(devices[i], i) != null) {
6873
throw new IllegalArgumentException("Duplicated devices are not allowed.");
6974
}

api/src/test/java/ai/djl/DeviceTest.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
package ai.djl;
1515

16+
import ai.djl.Device.MultiDevice;
1617
import ai.djl.engine.Engine;
1718

1819
import org.testng.Assert;
@@ -37,6 +38,8 @@ public void testDevice() {
3738

3839
System.setProperty("test_key", "test");
3940
Engine.debugEnvironment();
41+
42+
Assert.assertEquals(2, new MultiDevice(Device.gpu(1), Device.gpu(2)).getDevices().size());
4043
}
4144

4245
@Test
@@ -54,5 +57,9 @@ public void testDeviceName() {
5457
Device defaultDevice = Engine.getInstance().defaultDevice();
5558
Assert.assertEquals(Device.fromName(""), defaultDevice);
5659
Assert.assertEquals(Device.fromName(null), defaultDevice);
60+
61+
Assert.assertEquals(
62+
Device.fromName("gpu1+gpu2"), new MultiDevice(Device.gpu(2), Device.gpu(1)));
63+
Assert.assertEquals(Device.fromName("gpu1+gpu2"), new MultiDevice("gpu", 1, 3));
5764
}
5865
}

0 commit comments

Comments
 (0)