1414
1515import ai .djl .engine .Engine ;
1616
17+ import java .util .Arrays ;
18+ import java .util .Comparator ;
19+ import java .util .List ;
1720import java .util .Map ;
1821import java .util .Objects ;
1922import java .util .concurrent .ConcurrentHashMap ;
2023import java .util .regex .Matcher ;
2124import java .util .regex .Pattern ;
25+ import java .util .stream .Collectors ;
26+ import java .util .stream .IntStream ;
2227
2328/**
2429 * The {@code Device} class provides the specified assignment for CPU/GPU processing on the {@code
3035 * @see <a href="https://d2l.djl.ai/chapter_deep-learning-computation/use-gpu.html">The D2L chapter
3136 * on GPU devices</a>
3237 */
33- public final class Device {
38+ public class Device {
3439
3540 private static final Map <String , Device > CACHE = new ConcurrentHashMap <>();
3641
@@ -39,8 +44,8 @@ public final class Device {
3944
4045 private static final Pattern DEVICE_NAME = Pattern .compile ("([a-z]+)([0-9]*)" );
4146
42- private String deviceType ;
43- private int deviceId ;
47+ protected String deviceType ;
48+ protected int deviceId ;
4449
4550 /**
4651 * Creates a {@code Device} with basic information.
@@ -101,6 +106,13 @@ public static Device fromName(String deviceName, Engine engine) {
101106 return engine .defaultDevice ();
102107 }
103108
109+ if (deviceName .contains ("+" )) {
110+ String [] split = deviceName .split ("\\ +" );
111+ List <Device > subDevices =
112+ Arrays .stream (split ).map (n -> fromName (n , engine )).collect (Collectors .toList ());
113+ return new MultiDevice (subDevices );
114+ }
115+
104116 Matcher matcher = DEVICE_NAME .matcher (deviceName );
105117 if (matcher .matches ()) {
106118 String deviceType = matcher .group (1 );
@@ -214,4 +226,91 @@ public interface Type {
214226 String CPU = "cpu" ;
215227 String GPU = "gpu" ;
216228 }
229+
230+ /** A combined {@link Device} representing the composition of multiple other devices. */
231+ public static class MultiDevice extends Device {
232+
233+ List <Device > devices ;
234+
235+ /**
236+ * Constructs a {@link MultiDevice} with a range of new devices.
237+ *
238+ * @param deviceType the type of the sub-devices
239+ * @param startInclusive the start (inclusive) of the devices range
240+ * @param endExclusive the end (exclusive) of the devices range
241+ */
242+ public MultiDevice (String deviceType , int startInclusive , int endExclusive ) {
243+ this (
244+ IntStream .range (startInclusive , endExclusive )
245+ .mapToObj (i -> Device .of (deviceType , i ))
246+ .collect (Collectors .toList ()));
247+ }
248+
249+ /**
250+ * Constructs a {@link MultiDevice} from sub devices.
251+ *
252+ * @param devices the sub devices
253+ */
254+ public MultiDevice (Device ... devices ) {
255+ this (Arrays .asList (devices ));
256+ }
257+
258+ /**
259+ * Constructs a {@link MultiDevice} from sub devices.
260+ *
261+ * @param devices the sub devices
262+ */
263+ public MultiDevice (List <Device > devices ) {
264+ super (null , -1 );
265+ devices .sort (
266+ Comparator .comparing (Device ::getDeviceType , String .CASE_INSENSITIVE_ORDER )
267+ .thenComparingInt (Device ::getDeviceId ));
268+ this .deviceType =
269+ String .join (
270+ "+" ,
271+ (Iterable <String >)
272+ () ->
273+ devices .stream ()
274+ .map (d -> d .getDeviceType () + d .getDeviceId ())
275+ .iterator ());
276+ this .devices = devices ;
277+ }
278+
279+ /**
280+ * Returns the sub devices.
281+ *
282+ * @return the sub devices
283+ */
284+ public List <Device > getDevices () {
285+ return devices ;
286+ }
287+
288+ /** {@inheritDoc} */
289+ @ Override
290+ public boolean equals (Object o ) {
291+ if (this == o ) {
292+ return true ;
293+ }
294+ if (o == null || getClass () != o .getClass ()) {
295+ return false ;
296+ }
297+ if (!super .equals (o )) {
298+ return false ;
299+ }
300+ MultiDevice that = (MultiDevice ) o ;
301+ return Objects .equals (devices , that .devices );
302+ }
303+
304+ /** {@inheritDoc} */
305+ @ Override
306+ public int hashCode () {
307+ return Objects .hash (super .hashCode (), devices );
308+ }
309+
310+ /** {@inheritDoc} */
311+ @ Override
312+ public String toString () {
313+ return deviceType + "()" ;
314+ }
315+ }
217316}
0 commit comments