Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ public Repository newInstance(String name, String url) {
}

String uriPath = uri.getPath();
if (uriPath == null) {
uriPath = uri.getSchemeSpecificPart();
}
if (uriPath.startsWith("/") && System.getProperty("os.name").startsWith("Win")) {
uriPath = uriPath.substring(1);
}
Expand Down
40 changes: 36 additions & 4 deletions serving/serving/src/main/java/ai/djl/serving/ModelServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
Expand All @@ -54,6 +56,8 @@ public class ModelServer {

private static final Logger logger = LoggerFactory.getLogger(ModelServer.class);

private static final Pattern MODEL_STORE_PATTERN = Pattern.compile("(\\[(.+)]=)?(.+)");

private ServerGroups serverGroups;
private List<ChannelFuture> futures = new ArrayList<>(2);
private AtomicBoolean stopped = new AtomicBoolean(false);
Expand Down Expand Up @@ -298,18 +302,46 @@ private void initModelStore() throws IOException {

for (String url : urls) {
logger.info("Initializing model: {}", url);
Matcher matcher = MODEL_STORE_PATTERN.matcher(url);
if (!matcher.matches()) {
throw new AssertionError("Invalid model store url: " + url);
}
String endpoint = matcher.group(2);
String modelUrl = matcher.group(3);
String version = null;
String engine = null;
int gpuId = -1;
String modelName;
if (endpoint != null) {
String[] tokens = endpoint.split(":", -1);
modelName = tokens[0];
if (tokens.length > 1) {
version = tokens[1].isEmpty() ? null : tokens[1];
}
if (tokens.length > 2) {
engine = tokens[2].isEmpty() ? null : tokens[2];
}
if (tokens.length > 3) {
gpuId = tokens[3].isEmpty() ? -1 : Integer.parseInt(tokens[3]);
}
} else {
modelName = ModelInfo.inferModelNameFromUrl(modelUrl);
}

int workers = configManager.getDefaultWorkers();
CompletableFuture<ModelInfo> future =
modelManager.registerModel(
ModelInfo.inferModelNameFromUrl(url),
url,
null,
modelName,
version,
modelUrl,
engine,
gpuId,
configManager.getBatchSize(),
configManager.getMaxBatchDelay(),
configManager.getMaxIdleTime());
ModelInfo modelInfo = future.join();
modelManager.triggerModelUpdated(modelInfo.scaleWorkers(workers, workers));
startupModels.add(modelInfo.getModelName());
startupModels.add(modelName);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.QueryStringDecoder;
import java.nio.charset.StandardCharsets;
import java.util.Set;
import java.util.regex.Pattern;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -96,15 +97,23 @@ private void handlePredictions(
if (segments.length < 3) {
throw new ResourceNotFoundException();
}
String modelName = segments[2];
String version;
if (segments.length > 3) {
version = segments[3].isEmpty() ? null : segments[3];
} else {
version = null;
}
Input input = requestParser.parseRequest(ctx, req, decoder);
predict(ctx, req, input, segments[2]);
predict(ctx, req, input, modelName, version);
}

private void handleInvocations(
ChannelHandlerContext ctx, FullHttpRequest req, QueryStringDecoder decoder)
throws ModelNotFoundException {
Input input = requestParser.parseRequest(ctx, req, decoder);
String modelName = NettyUtils.getParameter(decoder, "model_name", null);
String version = NettyUtils.getParameter(decoder, "model_version", null);
if ((modelName == null || modelName.isEmpty())) {
modelName = input.getProperty("model_name", null);
if (modelName == null) {
Expand All @@ -115,21 +124,29 @@ private void handleInvocations(
}
}
if (modelName == null) {
if (ModelManager.getInstance().getStartupModels().size() == 1) {
modelName = ModelManager.getInstance().getStartupModels().iterator().next();
Set<String> startModels = ModelManager.getInstance().getStartupModels();
if (startModels.size() == 1) {
modelName = startModels.iterator().next();
}
if (modelName == null) {
throw new BadRequestException("Parameter model_name is required.");
}
}
predict(ctx, req, input, modelName);
if (version == null) {
version = input.getProperty("model_version", null);
}
predict(ctx, req, input, modelName, version);
}

private void predict(
ChannelHandlerContext ctx, FullHttpRequest req, Input input, String modelName)
ChannelHandlerContext ctx,
FullHttpRequest req,
Input input,
String modelName,
String version)
throws ModelNotFoundException {
ModelManager modelManager = ModelManager.getInstance();
ModelInfo model = modelManager.getModels().get(modelName);
ModelInfo model = modelManager.getModel(modelName, version, true);
if (model == null) {
String regex = ConfigManager.getInstance().getModelUrlPattern();
if (regex == null) {
Expand All @@ -147,22 +164,25 @@ private void predict(
}
}
String engineName = input.getProperty("engine_name", null);
int gpuId = Integer.parseInt(input.getProperty("gpu_id", "-1"));

logger.info("Loading model {} from: {}", modelName, modelUrl);

modelManager
.registerModel(
modelName,
version,
modelUrl,
engineName,
gpuId,
ConfigManager.getInstance().getBatchSize(),
ConfigManager.getInstance().getMaxBatchDelay(),
ConfigManager.getInstance().getMaxIdleTime())
.thenAccept(m -> modelManager.triggerModelUpdated(m.scaleWorkers(1, 1)))
.thenApply(m -> modelManager.triggerModelUpdated(m.scaleWorkers(1, 1)))
.thenAccept(
p -> {
m -> {
try {
if (!modelManager.addJob(new Job(ctx, modelName, input))) {
if (!modelManager.addJob(new Job(ctx, m, input))) {
throw new ServiceUnavailableException(
"No worker is available to serve request: "
+ modelName);
Expand All @@ -186,8 +206,8 @@ private void predict(
return;
}

Job job = new Job(ctx, modelName, input);
if (!ModelManager.getInstance().addJob(job)) {
Job job = new Job(ctx, model, input);
if (!modelManager.addJob(job)) {
logger.error("unable to process prediction. no free worker available.");
throw new ServiceUnavailableException(
"No worker is available to serve request: " + modelName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import ai.djl.ModelException;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.serving.util.NettyUtils;
import ai.djl.serving.wlm.Endpoint;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.ModelManager;
import io.netty.channel.ChannelHandlerContext;
Expand Down Expand Up @@ -45,8 +46,12 @@ public class ManagementRequestHandler extends HttpRequestHandler {
private static final String BATCH_SIZE_PARAMETER = "batch_size";
/** HTTP Parameter "model_name". */
private static final String MODEL_NAME_PARAMETER = "model_name";
/** HTTP Parameter "model_name". */
/** HTTP Parameter "model_version". */
private static final String MODEL_VERSION_PARAMETER = "model_version";
/** HTTP Parameter "engine_name". */
private static final String ENGINE_NAME_PARAMETER = "engine_name";
/** HTTP Parameter "gpu_id". */
private static final String GPU_ID_PARAMETER = "gpu_id";
/** HTTP Parameter "max_batch_delay". */
private static final String MAX_BATCH_DELAY_PARAMETER = "max_batch_delay";
/** HTTP Parameter "max_idle_time". */
Expand Down Expand Up @@ -88,12 +93,18 @@ protected void handleRequest(
throw new MethodNotAllowedException();
}

String modelName = segments[2];
String version = null;
if (segments.length > 3) {
version = segments[3];
}

if (HttpMethod.GET.equals(method)) {
handleDescribeModel(ctx, segments[2]);
handleDescribeModel(ctx, modelName, version);
} else if (HttpMethod.PUT.equals(method)) {
handleScaleModel(ctx, decoder, segments[2]);
handleScaleModel(ctx, decoder, modelName, version);
} else if (HttpMethod.DELETE.equals(method)) {
handleUnregisterModel(ctx, segments[2]);
handleUnregisterModel(ctx, modelName, version);
} else {
throw new MethodNotAllowedException();
}
Expand All @@ -110,9 +121,9 @@ private void handleListModels(ChannelHandlerContext ctx, QueryStringDecoder deco
}

ModelManager modelManager = ModelManager.getInstance();
Map<String, ModelInfo> models = modelManager.getModels();
Map<String, Endpoint> endpoints = modelManager.getEndpoints();

List<String> keys = new ArrayList<>(models.keySet());
List<String> keys = new ArrayList<>(endpoints.keySet());
Collections.sort(keys);
ListModelsResponse list = new ListModelsResponse();

Expand All @@ -125,17 +136,18 @@ private void handleListModels(ChannelHandlerContext ctx, QueryStringDecoder deco

for (int i = pageToken; i < last; ++i) {
String modelName = keys.get(i);
ModelInfo model = models.get(modelName);
list.addModel(modelName, model.getModelUrl());
for (ModelInfo m : endpoints.get(modelName).getModels()) {
list.addModel(modelName, m.getModelUrl());
}
}

NettyUtils.sendJsonResponse(ctx, list);
}

private void handleDescribeModel(ChannelHandlerContext ctx, String modelName)
private void handleDescribeModel(ChannelHandlerContext ctx, String modelName, String version)
throws ModelNotFoundException {
ModelManager modelManager = ModelManager.getInstance();
DescribeModelResponse resp = modelManager.describeModel(modelName);
DescribeModelResponse resp = modelManager.describeModel(modelName, version);
NettyUtils.sendJsonResponse(ctx, resp);
}

Expand All @@ -149,6 +161,8 @@ private void handleRegisterModel(final ChannelHandlerContext ctx, QueryStringDec
if (modelName == null || modelName.isEmpty()) {
modelName = ModelInfo.inferModelNameFromUrl(modelUrl);
}
String version = NettyUtils.getParameter(decoder, MODEL_VERSION_PARAMETER, null);
int gpuId = NettyUtils.getIntParameter(decoder, GPU_ID_PARAMETER, -1);
String engineName = NettyUtils.getParameter(decoder, ENGINE_NAME_PARAMETER, null);
int batchSize = NettyUtils.getIntParameter(decoder, BATCH_SIZE_PARAMETER, 1);
int maxBatchDelay = NettyUtils.getIntParameter(decoder, MAX_BATCH_DELAY_PARAMETER, 100);
Expand All @@ -162,13 +176,19 @@ private void handleRegisterModel(final ChannelHandlerContext ctx, QueryStringDec
final ModelManager modelManager = ModelManager.getInstance();
CompletableFuture<ModelInfo> future =
modelManager.registerModel(
modelName, modelUrl, engineName, batchSize, maxBatchDelay, maxIdleTime);
modelName,
version,
modelUrl,
engineName,
gpuId,
batchSize,
maxBatchDelay,
maxIdleTime);
CompletableFuture<Void> f =
future.thenAccept(
modelInfo ->
m ->
modelManager.triggerModelUpdated(
modelInfo
.scaleWorkers(initialWorkers, initialWorkers)
m.scaleWorkers(initialWorkers, initialWorkers)
.configurePool(maxIdleTime, maxBatchDelay)
.configureModelBatch(batchSize)));

Expand All @@ -187,23 +207,22 @@ private void handleRegisterModel(final ChannelHandlerContext ctx, QueryStringDec
});
}

private void handleUnregisterModel(ChannelHandlerContext ctx, String modelName)
private void handleUnregisterModel(ChannelHandlerContext ctx, String modelName, String version)
throws ModelNotFoundException {
ModelManager modelManager = ModelManager.getInstance();
if (!modelManager.unregisterModel(modelName)) {
if (!modelManager.unregisterModel(modelName, version)) {
throw new ModelNotFoundException("Model not found: " + modelName);
}
String msg = "Model \"" + modelName + "\" unregistered";
NettyUtils.sendJsonResponse(ctx, new StatusResponse(msg));
}

private void handleScaleModel(
ChannelHandlerContext ctx, QueryStringDecoder decoder, String modelName)
ChannelHandlerContext ctx, QueryStringDecoder decoder, String modelName, String version)
throws ModelNotFoundException {
try {

ModelManager modelManager = ModelManager.getInstance();
ModelInfo modelInfo = modelManager.getModels().get(modelName);
ModelInfo modelInfo = modelManager.getModel(modelName, version, false);
if (modelInfo == null) {
throw new ModelNotFoundException("Model not found: " + modelName);
}
Expand Down
Loading