diff --git a/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java b/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java index 01943d2572..fed56d3373 100644 --- a/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java +++ b/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java @@ -226,8 +226,18 @@ public static NeuralNetworkModelManager getInstance() { private static final Logger logger = new Logger(NeuralNetworkModelManager.class, LogGroup.Config); public enum Family { - RKNN, - RUBIK + RKNN(".rknn"), + RUBIK(".tflite"); + + private final String fileExtension; + + private Family(String fileExtension) { + this.fileExtension = fileExtension; + } + + public String extension() { + return fileExtension; + } } public enum Version { @@ -358,7 +368,11 @@ public void discoverModels() { try { Files.walk(modelsDirectory.toPath()) .filter(Files::isRegularFile) - .forEach(path -> loadModel(path)); + .filter( + path -> + supportedBackends.stream() + .anyMatch(family -> path.toString().endsWith(family.extension()))) + .forEach(this::loadModel); } catch (IOException e) { logger.error("Failed to discover models at " + modelsDirectory.getAbsolutePath(), e); } diff --git a/photon-server/src/main/java/org/photonvision/server/RequestHandler.java b/photon-server/src/main/java/org/photonvision/server/RequestHandler.java index 84bc34604a..e25bb5f433 100644 --- a/photon-server/src/main/java/org/photonvision/server/RequestHandler.java +++ b/photon-server/src/main/java/org/photonvision/server/RequestHandler.java @@ -605,16 +605,13 @@ public static void onImportObjectDetectionModelRequest(Context ctx) { return; } - String modelFileExtension = ""; NeuralNetworkModelManager.Family family; switch (Platform.getCurrentPlatform()) { case LINUX_QCS6490: - modelFileExtension = "tflite"; family = NeuralNetworkModelManager.Family.RUBIK; break; case LINUX_RK3588_64: - modelFileExtension = "rknn"; family = NeuralNetworkModelManager.Family.RKNN; break; default: @@ -625,19 +622,19 @@ public static void onImportObjectDetectionModelRequest(Context ctx) { } // If adding additional platforms, check platform matches - if (!modelFile.extension().contains(modelFileExtension)) { + if (!modelFile.extension().contains(family.extension())) { ctx.status(400); ctx.result( "The uploaded file was not of type '" - + modelFileExtension + + family.extension() + "'. The uploaded file should be a ." - + modelFileExtension + + family.extension() + " file."); logger.error( "The uploaded file was not of type '" - + modelFileExtension + + family.extension() + "'. The uploaded file should be a ." - + modelFileExtension + + family.extension() + " file."); return; } @@ -665,7 +662,7 @@ public static void onImportObjectDetectionModelRequest(Context ctx) { .addModelProperties( new ModelProperties( modelPath, - modelFile.filename().replaceAll("." + modelFileExtension, ""), + modelFile.filename().replaceAll("." + family.extension(), ""), labels, width, height,