Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,18 @@ public static NeuralNetworkModelManager getInstance() {
private static final Logger logger = new Logger(NeuralNetworkModelManager.class, LogGroup.Config);

public enum Family {
RKNN,
RUBIK
RKNN(".rknn"),
RUBIK(".tflite");

private final String fileExtension;

private Family(String fileExtension) {
this.fileExtension = fileExtension;
}

public String extension() {
return fileExtension;
}
}

public enum Version {
Expand Down Expand Up @@ -358,7 +368,11 @@ public void discoverModels() {
try {
Files.walk(modelsDirectory.toPath())
.filter(Files::isRegularFile)
.forEach(path -> loadModel(path));
.filter(
path ->
supportedBackends.stream()
.anyMatch(family -> path.toString().endsWith(family.extension())))
.forEach(this::loadModel);
} catch (IOException e) {
logger.error("Failed to discover models at " + modelsDirectory.getAbsolutePath(), e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -605,16 +605,13 @@ public static void onImportObjectDetectionModelRequest(Context ctx) {
return;
}

String modelFileExtension = "";
NeuralNetworkModelManager.Family family;

switch (Platform.getCurrentPlatform()) {
case LINUX_QCS6490:
modelFileExtension = "tflite";
family = NeuralNetworkModelManager.Family.RUBIK;
break;
case LINUX_RK3588_64:
modelFileExtension = "rknn";
family = NeuralNetworkModelManager.Family.RKNN;
break;
default:
Expand All @@ -625,19 +622,19 @@ public static void onImportObjectDetectionModelRequest(Context ctx) {
}

// If adding additional platforms, check platform matches
if (!modelFile.extension().contains(modelFileExtension)) {
if (!modelFile.extension().contains(family.extension())) {
ctx.status(400);
ctx.result(
"The uploaded file was not of type '"
+ modelFileExtension
+ family.extension()
+ "'. The uploaded file should be a ."
+ modelFileExtension
+ family.extension()
+ " file.");
logger.error(
"The uploaded file was not of type '"
+ modelFileExtension
+ family.extension()
+ "'. The uploaded file should be a ."
+ modelFileExtension
+ family.extension()
+ " file.");
return;
}
Expand Down Expand Up @@ -665,7 +662,7 @@ public static void onImportObjectDetectionModelRequest(Context ctx) {
.addModelProperties(
new ModelProperties(
modelPath,
modelFile.filename().replaceAll("." + modelFileExtension, ""),
modelFile.filename().replaceAll("." + family.extension(), ""),
labels,
width,
height,
Expand Down
Loading