From 584f758edeec213df25044784c4fb33babaa5dfa Mon Sep 17 00:00:00 2001 From: Frank Liu Date: Fri, 25 Jun 2021 09:35:35 -0700 Subject: [PATCH] [serving] support load multiple version of a model on the same endpoint Change-Id: I41824ed1e7ae2cec7d779e5dc3e821d9a695fabc --- .../djl/repository/RepositoryFactoryImpl.java | 3 + .../main/java/ai/djl/serving/ModelServer.java | 40 +++++- .../serving/http/InferenceRequestHandler.java | 42 ++++-- .../http/ManagementRequestHandler.java | 57 +++++--- .../java/ai/djl/serving/wlm/Endpoint.java | 130 +++++++++++++++++ .../src/main/java/ai/djl/serving/wlm/Job.java | 16 +-- .../java/ai/djl/serving/wlm/ModelInfo.java | 81 +++++++---- .../java/ai/djl/serving/wlm/ModelManager.java | 136 ++++++++++++------ .../ai/djl/serving/wlm/WorkLoadManager.java | 41 +++--- .../ai/djl/serving/wlm/ModelInfoTest.java | 2 +- .../src/test/resources/config.properties | 4 +- 11 files changed, 413 insertions(+), 139 deletions(-) create mode 100644 serving/serving/src/main/java/ai/djl/serving/wlm/Endpoint.java diff --git a/api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java b/api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java index e5fc5d8331f..4c43581899f 100644 --- a/api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java +++ b/api/src/main/java/ai/djl/repository/RepositoryFactoryImpl.java @@ -79,6 +79,9 @@ public Repository newInstance(String name, String url) { } String uriPath = uri.getPath(); + if (uriPath == null) { + uriPath = uri.getSchemeSpecificPart(); + } if (uriPath.startsWith("/") && System.getProperty("os.name").startsWith("Win")) { uriPath = uriPath.substring(1); } diff --git a/serving/serving/src/main/java/ai/djl/serving/ModelServer.java b/serving/serving/src/main/java/ai/djl/serving/ModelServer.java index 84ea4444618..931ecdee449 100644 --- a/serving/serving/src/main/java/ai/djl/serving/ModelServer.java +++ b/serving/serving/src/main/java/ai/djl/serving/ModelServer.java @@ -40,6 +40,8 @@ import 
java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import java.util.stream.Collectors; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.DefaultParser; @@ -54,6 +56,8 @@ public class ModelServer { private static final Logger logger = LoggerFactory.getLogger(ModelServer.class); + private static final Pattern MODEL_STORE_PATTERN = Pattern.compile("(\\[(.+)]=)?(.+)"); + private ServerGroups serverGroups; private List futures = new ArrayList<>(2); private AtomicBoolean stopped = new AtomicBoolean(false); @@ -298,18 +302,46 @@ private void initModelStore() throws IOException { for (String url : urls) { logger.info("Initializing model: {}", url); + Matcher matcher = MODEL_STORE_PATTERN.matcher(url); + if (!matcher.matches()) { + throw new AssertionError("Invalid model store url: " + url); + } + String endpoint = matcher.group(2); + String modelUrl = matcher.group(3); + String version = null; + String engine = null; + int gpuId = -1; + String modelName; + if (endpoint != null) { + String[] tokens = endpoint.split(":", -1); + modelName = tokens[0]; + if (tokens.length > 1) { + version = tokens[1].isEmpty() ? null : tokens[1]; + } + if (tokens.length > 2) { + engine = tokens[2].isEmpty() ? null : tokens[2]; + } + if (tokens.length > 3) { + gpuId = tokens[3].isEmpty() ? 
-1 : Integer.parseInt(tokens[3]); + } + } else { + modelName = ModelInfo.inferModelNameFromUrl(modelUrl); + } + int workers = configManager.getDefaultWorkers(); CompletableFuture future = modelManager.registerModel( - ModelInfo.inferModelNameFromUrl(url), - url, - null, + modelName, + version, + modelUrl, + engine, + gpuId, configManager.getBatchSize(), configManager.getMaxBatchDelay(), configManager.getMaxIdleTime()); ModelInfo modelInfo = future.join(); modelManager.triggerModelUpdated(modelInfo.scaleWorkers(workers, workers)); - startupModels.add(modelInfo.getModelName()); + startupModels.add(modelName); } } diff --git a/serving/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java b/serving/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java index 6845ce9a79c..ec96bc4663c 100644 --- a/serving/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java +++ b/serving/serving/src/main/java/ai/djl/serving/http/InferenceRequestHandler.java @@ -26,6 +26,7 @@ import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.QueryStringDecoder; import java.nio.charset.StandardCharsets; +import java.util.Set; import java.util.regex.Pattern; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -96,8 +97,15 @@ private void handlePredictions( if (segments.length < 3) { throw new ResourceNotFoundException(); } + String modelName = segments[2]; + String version; + if (segments.length > 3) { + version = segments[3].isEmpty() ? 
null : segments[3]; + } else { + version = null; + } Input input = requestParser.parseRequest(ctx, req, decoder); - predict(ctx, req, input, segments[2]); + predict(ctx, req, input, modelName, version); } private void handleInvocations( @@ -105,6 +113,7 @@ private void handleInvocations( throws ModelNotFoundException { Input input = requestParser.parseRequest(ctx, req, decoder); String modelName = NettyUtils.getParameter(decoder, "model_name", null); + String version = NettyUtils.getParameter(decoder, "model_version", null); if ((modelName == null || modelName.isEmpty())) { modelName = input.getProperty("model_name", null); if (modelName == null) { @@ -115,21 +124,29 @@ private void handleInvocations( } } if (modelName == null) { - if (ModelManager.getInstance().getStartupModels().size() == 1) { - modelName = ModelManager.getInstance().getStartupModels().iterator().next(); + Set startModels = ModelManager.getInstance().getStartupModels(); + if (startModels.size() == 1) { + modelName = startModels.iterator().next(); } if (modelName == null) { throw new BadRequestException("Parameter model_name is required."); } } - predict(ctx, req, input, modelName); + if (version == null) { + version = input.getProperty("model_version", null); + } + predict(ctx, req, input, modelName, version); } private void predict( - ChannelHandlerContext ctx, FullHttpRequest req, Input input, String modelName) + ChannelHandlerContext ctx, + FullHttpRequest req, + Input input, + String modelName, + String version) throws ModelNotFoundException { ModelManager modelManager = ModelManager.getInstance(); - ModelInfo model = modelManager.getModels().get(modelName); + ModelInfo model = modelManager.getModel(modelName, version, true); if (model == null) { String regex = ConfigManager.getInstance().getModelUrlPattern(); if (regex == null) { @@ -147,22 +164,25 @@ private void predict( } } String engineName = input.getProperty("engine_name", null); + int gpuId = 
Integer.parseInt(input.getProperty("gpu_id", "-1")); logger.info("Loading model {} from: {}", modelName, modelUrl); modelManager .registerModel( modelName, + version, modelUrl, engineName, + gpuId, ConfigManager.getInstance().getBatchSize(), ConfigManager.getInstance().getMaxBatchDelay(), ConfigManager.getInstance().getMaxIdleTime()) - .thenAccept(m -> modelManager.triggerModelUpdated(m.scaleWorkers(1, 1))) + .thenApply(m -> modelManager.triggerModelUpdated(m.scaleWorkers(1, 1))) .thenAccept( - p -> { + m -> { try { - if (!modelManager.addJob(new Job(ctx, modelName, input))) { + if (!modelManager.addJob(new Job(ctx, m, input))) { throw new ServiceUnavailableException( "No worker is available to serve request: " + modelName); @@ -186,8 +206,8 @@ private void predict( return; } - Job job = new Job(ctx, modelName, input); - if (!ModelManager.getInstance().addJob(job)) { + Job job = new Job(ctx, model, input); + if (!modelManager.addJob(job)) { logger.error("unable to process prediction. no free worker available."); throw new ServiceUnavailableException( "No worker is available to serve request: " + modelName); diff --git a/serving/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java b/serving/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java index 725ac3f9b6d..82afae1fd43 100644 --- a/serving/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java +++ b/serving/serving/src/main/java/ai/djl/serving/http/ManagementRequestHandler.java @@ -15,6 +15,7 @@ import ai.djl.ModelException; import ai.djl.repository.zoo.ModelNotFoundException; import ai.djl.serving.util.NettyUtils; +import ai.djl.serving.wlm.Endpoint; import ai.djl.serving.wlm.ModelInfo; import ai.djl.serving.wlm.ModelManager; import io.netty.channel.ChannelHandlerContext; @@ -45,8 +46,12 @@ public class ManagementRequestHandler extends HttpRequestHandler { private static final String BATCH_SIZE_PARAMETER = "batch_size"; /** HTTP Parameter "model_name". 
*/ private static final String MODEL_NAME_PARAMETER = "model_name"; - /** HTTP Parameter "model_name". */ + /** HTTP Parameter "model_version". */ + private static final String MODEL_VERSION_PARAMETER = "model_version"; + /** HTTP Parameter "engine_name". */ private static final String ENGINE_NAME_PARAMETER = "engine_name"; + /** HTTP Parameter "gpu_id". */ + private static final String GPU_ID_PARAMETER = "gpu_id"; /** HTTP Parameter "max_batch_delay". */ private static final String MAX_BATCH_DELAY_PARAMETER = "max_batch_delay"; /** HTTP Parameter "max_idle_time". */ @@ -88,12 +93,18 @@ protected void handleRequest( throw new MethodNotAllowedException(); } + String modelName = segments[2]; + String version = null; + if (segments.length > 3) { + version = segments[3]; + } + if (HttpMethod.GET.equals(method)) { - handleDescribeModel(ctx, segments[2]); + handleDescribeModel(ctx, modelName, version); } else if (HttpMethod.PUT.equals(method)) { - handleScaleModel(ctx, decoder, segments[2]); + handleScaleModel(ctx, decoder, modelName, version); } else if (HttpMethod.DELETE.equals(method)) { - handleUnregisterModel(ctx, segments[2]); + handleUnregisterModel(ctx, modelName, version); } else { throw new MethodNotAllowedException(); } @@ -110,9 +121,9 @@ private void handleListModels(ChannelHandlerContext ctx, QueryStringDecoder deco } ModelManager modelManager = ModelManager.getInstance(); - Map models = modelManager.getModels(); + Map endpoints = modelManager.getEndpoints(); - List keys = new ArrayList<>(models.keySet()); + List keys = new ArrayList<>(endpoints.keySet()); Collections.sort(keys); ListModelsResponse list = new ListModelsResponse(); @@ -125,17 +136,18 @@ private void handleListModels(ChannelHandlerContext ctx, QueryStringDecoder deco for (int i = pageToken; i < last; ++i) { String modelName = keys.get(i); - ModelInfo model = models.get(modelName); - list.addModel(modelName, model.getModelUrl()); + for (ModelInfo m : endpoints.get(modelName).getModels()) { + 
list.addModel(modelName, m.getModelUrl()); + } } NettyUtils.sendJsonResponse(ctx, list); } - private void handleDescribeModel(ChannelHandlerContext ctx, String modelName) + private void handleDescribeModel(ChannelHandlerContext ctx, String modelName, String version) throws ModelNotFoundException { ModelManager modelManager = ModelManager.getInstance(); - DescribeModelResponse resp = modelManager.describeModel(modelName); + DescribeModelResponse resp = modelManager.describeModel(modelName, version); NettyUtils.sendJsonResponse(ctx, resp); } @@ -149,6 +161,8 @@ private void handleRegisterModel(final ChannelHandlerContext ctx, QueryStringDec if (modelName == null || modelName.isEmpty()) { modelName = ModelInfo.inferModelNameFromUrl(modelUrl); } + String version = NettyUtils.getParameter(decoder, MODEL_VERSION_PARAMETER, null); + int gpuId = NettyUtils.getIntParameter(decoder, GPU_ID_PARAMETER, -1); String engineName = NettyUtils.getParameter(decoder, ENGINE_NAME_PARAMETER, null); int batchSize = NettyUtils.getIntParameter(decoder, BATCH_SIZE_PARAMETER, 1); int maxBatchDelay = NettyUtils.getIntParameter(decoder, MAX_BATCH_DELAY_PARAMETER, 100); @@ -162,13 +176,19 @@ private void handleRegisterModel(final ChannelHandlerContext ctx, QueryStringDec final ModelManager modelManager = ModelManager.getInstance(); CompletableFuture future = modelManager.registerModel( - modelName, modelUrl, engineName, batchSize, maxBatchDelay, maxIdleTime); + modelName, + version, + modelUrl, + engineName, + gpuId, + batchSize, + maxBatchDelay, + maxIdleTime); CompletableFuture f = future.thenAccept( - modelInfo -> + m -> modelManager.triggerModelUpdated( - modelInfo - .scaleWorkers(initialWorkers, initialWorkers) + m.scaleWorkers(initialWorkers, initialWorkers) .configurePool(maxIdleTime, maxBatchDelay) .configureModelBatch(batchSize))); @@ -187,10 +207,10 @@ private void handleRegisterModel(final ChannelHandlerContext ctx, QueryStringDec }); } - private void 
handleUnregisterModel(ChannelHandlerContext ctx, String modelName) + private void handleUnregisterModel(ChannelHandlerContext ctx, String modelName, String version) throws ModelNotFoundException { ModelManager modelManager = ModelManager.getInstance(); - if (!modelManager.unregisterModel(modelName)) { + if (!modelManager.unregisterModel(modelName, version)) { throw new ModelNotFoundException("Model not found: " + modelName); } String msg = "Model \"" + modelName + "\" unregistered"; @@ -198,12 +218,11 @@ private void handleUnregisterModel(ChannelHandlerContext ctx, String modelName) } private void handleScaleModel( - ChannelHandlerContext ctx, QueryStringDecoder decoder, String modelName) + ChannelHandlerContext ctx, QueryStringDecoder decoder, String modelName, String version) throws ModelNotFoundException { try { - ModelManager modelManager = ModelManager.getInstance(); - ModelInfo modelInfo = modelManager.getModels().get(modelName); + ModelInfo modelInfo = modelManager.getModel(modelName, version, false); if (modelInfo == null) { throw new ModelNotFoundException("Model not found: " + modelName); } diff --git a/serving/serving/src/main/java/ai/djl/serving/wlm/Endpoint.java b/serving/serving/src/main/java/ai/djl/serving/wlm/Endpoint.java new file mode 100644 index 00000000000..7ad2fbde5b9 --- /dev/null +++ b/serving/serving/src/main/java/ai/djl/serving/wlm/Endpoint.java @@ -0,0 +1,130 @@ +/* + * Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. 
+ */ +package ai.djl.serving.wlm; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +/** A class that represents a webservice endpoint. */ +public class Endpoint { + + private List models; + private Map map; + private AtomicInteger position; + + /** Constructs an {@code Endpoint} instance. */ + public Endpoint() { + models = new ArrayList<>(); + map = new ConcurrentHashMap<>(); + position = new AtomicInteger(0); + } + + /** + * Adds a model to the endpoint. + * + * @param modelInfo the model to be added + * @return true if add success + */ + public synchronized boolean add(ModelInfo modelInfo) { + String version = modelInfo.getVersion(); + if (version == null) { + if (models.isEmpty()) { + map.put("default", 0); + return models.add(modelInfo); + } + return false; + } + if (map.containsKey(version)) { + return false; + } + + map.put(version, models.size()); + return models.add(modelInfo); + } + + /** + * Returns the models associated with the endpoint. + * + * @return the models associated with the endpoint + */ + public List getModels() { + return models; + } + + /** + * Removes a model version from the {@code Endpoint}. + * + * @param version the model version + * @return null if the specified version doesn't exist + */ + public synchronized ModelInfo remove(String version) { + if (version == null) { + if (models.isEmpty()) { + return null; + } + ModelInfo model = models.remove(0); + reIndex(); + return model; + } + Integer index = map.remove(version); + if (index == null) { + return null; + } + ModelInfo model = models.remove((int) index); + reIndex(); + return model; + } + + /** + * Returns the {@code ModelInfo} for the specified version.
+ * + * @param version the version of the model to retrieve + * @return the {@code ModelInfo} for the specified version + */ + public ModelInfo get(String version) { + Integer index = map.get(version); + if (index == null) { + return null; + } + return models.get(index); + } + + /** + * Returns the next version of model to serve the inference request. + * + * @return the next version of model to serve the inference request + */ + public ModelInfo next() { + int size = models.size(); + if (size == 1) { + return models.get(0); + } + int index = position.getAndUpdate(operand -> (operand + 1) % size); + return models.get(index); + } + + private void reIndex() { + map.clear(); + int size = models.size(); + for (int i = 0; i < size; ++i) { + ModelInfo modelInfo = models.get(i); + String version = modelInfo.getVersion(); + if (version != null) { + map.put(version, i); + } + } + } +} diff --git a/serving/serving/src/main/java/ai/djl/serving/wlm/Job.java b/serving/serving/src/main/java/ai/djl/serving/wlm/Job.java index aa6c8a737fa..b069b71765f 100644 --- a/serving/serving/src/main/java/ai/djl/serving/wlm/Job.java +++ b/serving/serving/src/main/java/ai/djl/serving/wlm/Job.java @@ -32,7 +32,7 @@ public class Job { private ChannelHandlerContext ctx; - private String modelName; + private ModelInfo modelInfo; private Input input; private long begin; private long scheduled; @@ -41,12 +41,12 @@ public class Job { * Constructs an new {@code Job} instance. 
* * @param ctx the {@code ChannelHandlerContext} - * @param modelName the model name + * @param modelInfo the model to run the job * @param input the input data */ - public Job(ChannelHandlerContext ctx, String modelName, Input input) { + public Job(ChannelHandlerContext ctx, ModelInfo modelInfo, Input input) { this.ctx = ctx; - this.modelName = modelName; + this.modelInfo = modelInfo; this.input = input; begin = System.currentTimeMillis(); @@ -63,12 +63,12 @@ public String getRequestId() { } /** - * Returns the model name that associated with this job. + * Returns the model that associated with this job. * - * @return the model name that associated with this job + * @return the model that associated with this job */ - public String getModelName() { - return modelName; + public ModelInfo getModel() { + return modelInfo; } /** diff --git a/serving/serving/src/main/java/ai/djl/serving/wlm/ModelInfo.java b/serving/serving/src/main/java/ai/djl/serving/wlm/ModelInfo.java index 253ff203382..06b07b0ddd3 100644 --- a/serving/serving/src/main/java/ai/djl/serving/wlm/ModelInfo.java +++ b/serving/serving/src/main/java/ai/djl/serving/wlm/ModelInfo.java @@ -18,15 +18,17 @@ import ai.djl.repository.zoo.ZooModel; import java.net.URI; import java.nio.file.Path; +import java.util.Objects; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** A class represent a loaded model and it's metadata. */ -public final class ModelInfo implements AutoCloseable, Cloneable { +public final class ModelInfo implements AutoCloseable { private static final Logger logger = LoggerFactory.getLogger(ModelInfo.class); private String modelName; + private String version; private String modelUrl; private int minWorkers; @@ -42,6 +44,7 @@ public final class ModelInfo implements AutoCloseable, Cloneable { * Constructs a new {@code ModelInfo} instance. 
* * @param modelName the name of the model that will be used as HTTP endpoint + * @param version the version of the model * @param modelUrl the model url * @param model the {@link ZooModel} * @param queueSize the maximum request queue size @@ -51,6 +54,7 @@ public final class ModelInfo implements AutoCloseable, Cloneable { */ public ModelInfo( String modelName, + String version, String modelUrl, ZooModel model, int queueSize, @@ -58,6 +62,7 @@ public ModelInfo( int maxBatchDelay, int batchSize) { this.modelName = modelName; + this.version = version; this.modelUrl = modelUrl; this.model = model; this.maxBatchDelay = maxBatchDelay; @@ -74,15 +79,8 @@ public ModelInfo( * @return new configured ModelInfo. */ public ModelInfo configureModelBatch(int batchSize) { - ModelInfo clone; - try { - clone = (ModelInfo) this.clone(); - clone.batchSize = batchSize; - } catch (CloneNotSupportedException e) { - // this should never happen, cause we know we are cloneable. - throw new AssertionError(e); - } - return clone; + this.batchSize = batchSize; + return this; } /** @@ -94,16 +92,9 @@ public ModelInfo configureModelBatch(int batchSize) { * @return new configured ModelInfo. */ public ModelInfo scaleWorkers(int minWorkers, int maxWorkers) { - ModelInfo clone; - try { - clone = (ModelInfo) this.clone(); - clone.minWorkers = minWorkers; - clone.maxWorkers = maxWorkers; - } catch (CloneNotSupportedException e) { - // this should never happen, cause we know we are cloneable. - throw new AssertionError(e); - } - return clone; + this.minWorkers = minWorkers; + this.maxWorkers = maxWorkers; + return this; } /** @@ -117,16 +108,9 @@ public ModelInfo scaleWorkers(int minWorkers, int maxWorkers) { * @return new configured ModelInfo. 
*/ public ModelInfo configurePool(int maxIdleTime, int maxBatchDelay) { - ModelInfo clone; - try { - clone = (ModelInfo) this.clone(); - clone.maxIdleTime = maxIdleTime; - clone.maxBatchDelay = maxBatchDelay; - } catch (CloneNotSupportedException e) { - // ..ignore cause we know we are cloneable. - clone = this; // for the compiler - } - return clone; + this.maxIdleTime = maxIdleTime; + this.maxBatchDelay = maxBatchDelay; + return this; } /** @@ -147,6 +131,15 @@ public String getModelName() { return modelName; } + /** + * Returns the model version. + * + * @return the model version + */ + public String getVersion() { + return version; + } + /** * Returns the model url. * @@ -254,4 +247,32 @@ public static String inferModelNameFromUrl(String url) { modelName = modelName.replaceAll("(\\W|^_)", "_"); return modelName; } + + /** {@inheritDoc} */ + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ModelInfo)) { + return false; + } + ModelInfo modelInfo = (ModelInfo) o; + return modelName.equals(modelInfo.modelName) && Objects.equals(version, modelInfo.version); + } + + /** {@inheritDoc} */ + @Override + public int hashCode() { + return Objects.hash(modelName, version); + } + + /** {@inheritDoc} */ + @Override + public String toString() { + if (version != null) { + return modelName + ':' + version; + } + return modelName; + } } diff --git a/serving/serving/src/main/java/ai/djl/serving/wlm/ModelManager.java b/serving/serving/src/main/java/ai/djl/serving/wlm/ModelManager.java index f657048a28b..a5448e1ab7e 100644 --- a/serving/serving/src/main/java/ai/djl/serving/wlm/ModelManager.java +++ b/serving/serving/src/main/java/ai/djl/serving/wlm/ModelManager.java @@ -12,6 +12,7 @@ */ package ai.djl.serving.wlm; +import ai.djl.Device; import ai.djl.ModelException; import ai.djl.modality.Input; import ai.djl.modality.Output; @@ -41,13 +42,13 @@ public final class ModelManager { private ConfigManager configManager; private 
WorkLoadManager wlm; - private ConcurrentHashMap models; + private Map endpoints; private Set startupModels; private ModelManager(ConfigManager configManager) { this.configManager = configManager; wlm = new WorkLoadManager(); - models = new ConcurrentHashMap<>(); + endpoints = new ConcurrentHashMap<>(); startupModels = new HashSet<>(); } @@ -73,8 +74,10 @@ public static ModelManager getInstance() { * Registers and loads a model. * * @param modelName the name of the model for HTTP endpoint + * @param version the model version * @param modelUrl the model url * @param engineName the engine to load the model + * @param gpuId the GPU device id, -1 for auto selection * @param batchSize the batch size * @param maxBatchDelay the maximum delay for batching * @param maxIdleTime the maximum idle time of the worker threads before scaling down. @@ -82,24 +85,30 @@ public static ModelManager getInstance() { */ public CompletableFuture registerModel( final String modelName, + final String version, final String modelUrl, final String engineName, + final int gpuId, final int batchSize, final int maxBatchDelay, final int maxIdleTime) { return CompletableFuture.supplyAsync( () -> { try { - Criteria criteria = + Criteria.Builder builder = Criteria.builder() .setTypes(Input.class, Output.class) .optModelUrls(modelUrl) - .optEngine(engineName) - .build(); - ZooModel model = criteria.loadModel(); + .optEngine(engineName); + if (gpuId != -1) { + builder.optDevice(Device.gpu(gpuId)); + } + + ZooModel model = builder.build().loadModel(); ModelInfo modelInfo = new ModelInfo( modelName, + version, modelUrl, model, configManager.getJobQueueSize(), @@ -107,12 +116,13 @@ public CompletableFuture registerModel( maxBatchDelay, batchSize); - ModelInfo existingModel = models.putIfAbsent(modelName, modelInfo); - if (existingModel != null) { + Endpoint endpoint = + endpoints.computeIfAbsent(modelName, k -> new Endpoint()); + if (!endpoint.add(modelInfo)) { // model already exists model.close(); throw 
new BadRequestException( - "Model " + modelName + " is already registered."); + "Model " + modelInfo + " is already registered."); } logger.info("Model {} loaded.", modelName); @@ -124,22 +134,41 @@ public CompletableFuture registerModel( } /** - * Unregisters a model by its name. + * Unregisters a model by its name and version. * * @param modelName the model name to be unregistered + * @param version the model version * @return {@code true} if unregister success */ - public boolean unregisterModel(String modelName) { - ModelInfo model = models.remove(modelName); - if (model == null) { + public boolean unregisterModel(String modelName, String version) { + Endpoint endpoint = endpoints.get(modelName); + if (endpoint == null) { logger.warn("Model not found: " + modelName); return false; } - model = model.scaleWorkers(0, 0); - wlm.modelChanged(model); - startupModels.remove(modelName); - model.close(); - logger.info("Model {} unregistered.", modelName); + if (version == null) { + // unregister all versions + for (ModelInfo m : endpoint.getModels()) { + m.scaleWorkers(0, 0); + wlm.modelChanged(m); + startupModels.remove(modelName); + m.close(); + } + logger.info("Model {} unregistered.", modelName); + } else { + ModelInfo model = endpoint.remove(version); + if (model == null) { + logger.warn("Model not found: " + modelName + ':' + version); + return false; + } + model.scaleWorkers(0, 0); + wlm.modelChanged(model); + startupModels.remove(modelName); + model.close(); + } + if (endpoint.getModels().isEmpty()) { + endpoints.remove(modelName); + } return true; } @@ -148,23 +177,47 @@ public boolean unregisterModel(String modelName) { * up/down all workers to match the parameters for the model. 
* * @param modelInfo the model that has been updated + * @return the model */ - public void triggerModelUpdated(ModelInfo modelInfo) { - if (!models.containsKey(modelInfo.getModelName())) { - throw new AssertionError("Model not found: " + modelInfo.getModelName()); - } - logger.debug("updateModel: {}", modelInfo.getModelName()); - models.put(modelInfo.getModelName(), modelInfo); + public ModelInfo triggerModelUpdated(ModelInfo modelInfo) { + String modelName = modelInfo.getModelName(); + logger.debug("updateModel: {}", modelName); wlm.modelChanged(modelInfo); + return modelInfo; } /** - * Returns the registry of all models. + * Returns the registry of all endpoints. * - * @return the registry of all models + * @return the registry of all endpoints */ - public Map getModels() { - return models; + public Map getEndpoints() { + return endpoints; + } + + /** + * Returns a version of model. + * + * @param modelName the model name + * @param version the model version + * @param predict true for selecting a model in load balance fashion + * @return the model + */ + public ModelInfo getModel(String modelName, String version, boolean predict) { + Endpoint endpoint = endpoints.get(modelName); + if (endpoint == null) { + return null; + } + if (version == null) { + if (endpoint.getModels().isEmpty()) { + return null; + } + if (predict) { + return endpoint.next(); + } + return endpoint.getModels().get(0); + } + return endpoint.get(version); } /** @@ -184,23 +237,20 @@ public Set getStartupModels() { * @throws ModelNotFoundException if the model is not registered */ public boolean addJob(Job job) throws ModelNotFoundException { - String modelName = job.getModelName(); - ModelInfo model = models.get(modelName); - if (model == null) { - throw new ModelNotFoundException("Model not found: " + modelName); - } - return wlm.addJob(model, job); + return wlm.addJob(job); } /** * Returns a list of worker information for specified model.
* - * @param modelName the model to be queried + * @param modelName the model name to be queried + * @param version the model version to be queried * @return a list of worker information for specified model * @throws ModelNotFoundException if specified model not found */ - public DescribeModelResponse describeModel(String modelName) throws ModelNotFoundException { - ModelInfo model = models.get(modelName); + public DescribeModelResponse describeModel(String modelName, String version) + throws ModelNotFoundException { + ModelInfo model = getModel(modelName, version, false); if (model == null) { throw new ModelNotFoundException("Model not found: " + modelName); } @@ -215,11 +265,11 @@ public DescribeModelResponse describeModel(String modelName) throws ModelNotFoun resp.setMaxIdleTime(model.getMaxIdleTime()); resp.setLoadedAtStartup(startupModels.contains(modelName)); - int activeWorker = wlm.getNumRunningWorkers(modelName); + int activeWorker = wlm.getNumRunningWorkers(model); int targetWorker = model.getMinWorkers(); resp.setStatus(activeWorker >= targetWorker ? 
"Healthy" : "Unhealthy"); - List workers = wlm.getWorkers(modelName); + List workers = wlm.getWorkers(model); for (WorkerThread worker : workers) { int workerId = worker.getWorkerId(); long startTime = worker.getStartTime(); @@ -242,9 +292,11 @@ public CompletableFuture workerStatus() { int numWorking = 0; int numScaled = 0; - for (Map.Entry m : models.entrySet()) { - numScaled += m.getValue().getMinWorkers(); - numWorking += wlm.getNumRunningWorkers(m.getValue().getModelName()); + for (Endpoint endpoint : endpoints.values()) { + for (ModelInfo m : endpoint.getModels()) { + numScaled += m.getMinWorkers(); + numWorking += wlm.getNumRunningWorkers(m); + } } if ((numWorking > 0) && (numWorking < numScaled)) { diff --git a/serving/serving/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java b/serving/serving/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java index 03b6ce8639d..5e4d5156255 100644 --- a/serving/serving/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java +++ b/serving/serving/src/main/java/ai/djl/serving/wlm/WorkLoadManager.java @@ -35,7 +35,7 @@ class WorkLoadManager { private static final Logger logger = LoggerFactory.getLogger(WorkLoadManager.class); private ExecutorService threadPool; - private ConcurrentHashMap workerPools; + private ConcurrentHashMap workerPools; /** Constructs a {@code WorkLoadManager} instance. */ public WorkLoadManager() { @@ -44,14 +44,14 @@ public WorkLoadManager() { } /** - * get the workers for the specific model. + * Returns the workers for the specific model. * - * @param modelName The name of the model we are looking for. + * @param modelInfo the model we are looking for. * @return the list of workers responsible to handle predictions for this model.
*/ - public List getWorkers(String modelName) { + public List getWorkers(ModelInfo modelInfo) { List list; - WorkerPool pool = workerPools.get(modelName); + WorkerPool pool = workerPools.get(modelInfo); if (pool == null) { list = Collections.emptyList(); } else { @@ -67,20 +67,18 @@ public List getWorkers(String modelName) { * Adds an inference job to the job queue of the next free worker. scales up worker if * necessary. * - * @param modelInfo the model to use. * @param job an inference job to be executed. * @return {@code true} if submit success, false otherwise. */ - public boolean addJob(ModelInfo modelInfo, Job job) { + public boolean addJob(Job job) { boolean accepted = false; + ModelInfo modelInfo = job.getModel(); WorkerPool pool = getWorkerPoolForModel(modelInfo); - if (getNumRunningWorkers(modelInfo.getModelName()) > 0) { - + if (getNumRunningWorkers(modelInfo) > 0) { try { accepted = pool.getJobQueue().offer(job); - if (!accepted) { - synchronized (modelInfo.getModelName()) { + synchronized (modelInfo.getModel()) { scaleUpWorkers(modelInfo, pool); accepted = pool.getJobQueue() @@ -100,7 +98,7 @@ public boolean addJob(ModelInfo modelInfo, Job job) { } private void scaleUpWorkers(ModelInfo modelInfo, WorkerPool pool) { - int currentWorkers = getNumRunningWorkers(modelInfo.getModelName()); + int currentWorkers = getNumRunningWorkers(modelInfo); if (currentWorkers < modelInfo.getMaxWorkers()) { logger.debug("scaling up workers for model {} to {} ", modelInfo, currentWorkers + 1); addThreads(pool.getWorkers(), modelInfo, 1, false); @@ -112,15 +110,15 @@ private void scaleUpWorkers(ModelInfo modelInfo, WorkerPool pool) { } /** - * returns the number of running workers of a model. running workers are workers which are not + * Returns the number of running workers of a model. running workers are workers which are not * stopped, in error or scheduled to scale down. * - * @param modelName the model we are interested in. 
+ * @param modelInfo the model we are interested in. * @return number of running workers. */ - public int getNumRunningWorkers(String modelName) { + public int getNumRunningWorkers(ModelInfo modelInfo) { int numWorking = 0; - WorkerPool pool = workerPools.get(modelName); + WorkerPool pool = workerPools.get(modelInfo); if (pool != null) { pool.cleanup(); List threads = pool.getWorkers(); @@ -136,12 +134,12 @@ public int getNumRunningWorkers(String modelName) { } /** - * trigger a model change event. scales up and down workers to match minWorkers/maxWorkers. + * Triggers a model change event. scales up and down workers to match minWorkers/maxWorkers. * * @param modelInfo the changed model. */ public void modelChanged(ModelInfo modelInfo) { - synchronized (modelInfo.getModelName()) { + synchronized (modelInfo.getModel()) { int minWorker = modelInfo.getMinWorkers(); WorkerPool pool = getWorkerPoolForModel(modelInfo); @@ -150,7 +148,7 @@ public void modelChanged(ModelInfo modelInfo) { List threads; if (minWorker == 0) { - workerPools.remove(modelInfo.getModelName()); + workerPools.remove(modelInfo); } threads = pool.getWorkers(); @@ -180,8 +178,7 @@ public void modelChanged(ModelInfo modelInfo) { } private WorkerPool getWorkerPoolForModel(ModelInfo modelInfo) { - return workerPools.computeIfAbsent( - modelInfo.getModelName(), k -> new WorkerPool(modelInfo)); + return workerPools.computeIfAbsent(modelInfo, k -> new WorkerPool(modelInfo)); } private void addThreads( @@ -258,7 +255,7 @@ public void log() { buf.append("-tmpPool\n"); } }); - logger.debug("worker pool for model {}:\n {}", modelName, buf.toString()); + logger.debug("worker pool for model {}:\n {}", modelName, buf); } } diff --git a/serving/serving/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java b/serving/serving/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java index d92ee5d2c1b..166e5248733 100644 --- a/serving/serving/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java +++ 
b/serving/serving/src/test/java/ai/djl/serving/wlm/ModelInfoTest.java @@ -19,7 +19,7 @@ public class ModelInfoTest { @Test public void testQueueSizeIsSet() { - ModelInfo modelInfo = new ModelInfo("", "", null, 4711, 1, 300, 1); + ModelInfo modelInfo = new ModelInfo("", null, "", null, 4711, 1, 300, 1); Assert.assertEquals(4711, modelInfo.getQueueSize()); } } diff --git a/serving/serving/src/test/resources/config.properties b/serving/serving/src/test/resources/config.properties index 7562f4bc599..6fbeae08aa0 100644 --- a/serving/serving/src/test/resources/config.properties +++ b/serving/serving/src/test/resources/config.properties @@ -2,8 +2,8 @@ inference_address=https://127.0.0.1:8443 management_address=https://127.0.0.1:8443 # management_address=unix:/tmp/management.sock -# model_store=../modelarchive/src/test/resources/models -load_models=https://resources.djl.ai/test-models/mlp.tar.gz +# model_store=models +load_models=https://resources.djl.ai/test-models/mlp.tar.gz,[mlp:v1:MXNet:]=https://resources.djl.ai/test-models/mlp.tar.gz # model_url_pattern=.* # number_of_netty_threads=0 # netty_client_threads=0