diff --git a/application/pom.xml b/application/pom.xml index 2fffd911d97a..d526309d40fd 100644 --- a/application/pom.xml +++ b/application/pom.xml @@ -53,6 +53,10 @@ javax.annotation javax.annotation-api + + com.dev-smart + ubjson + diff --git a/config-model-fat/pom.xml b/config-model-fat/pom.xml index 2e883a0c1e35..3a06630e032f 100644 --- a/config-model-fat/pom.xml +++ b/config-model-fat/pom.xml @@ -186,6 +186,7 @@ aopalliance:aopalliance:*:* + com.dev-smart:ubjson:*:* com.google.errorprone:error_prone_annotations:*:* com.google.guava:failureaccess:*:* com.google.guava:guava:*:* diff --git a/container-dev/pom.xml b/container-dev/pom.xml index d4365be0408e..ed9e8e340d6b 100644 --- a/container-dev/pom.xml +++ b/container-dev/pom.xml @@ -94,6 +94,10 @@ org.lz4 lz4-java + + com.dev-smart + ubjson + diff --git a/dependency-versions/pom.xml b/dependency-versions/pom.xml index 5c650cc49195..27923acf5a4a 100644 --- a/dependency-versions/pom.xml +++ b/dependency-versions/pom.xml @@ -33,6 +33,7 @@ 1.0 + 0.1.8 2.30.0 33.2.1-jre 6.0.0 diff --git a/model-integration/pom.xml b/model-integration/pom.xml index ba1e79f28cfa..58ee4017bd05 100644 --- a/model-integration/pom.xml +++ b/model-integration/pom.xml @@ -376,6 +376,10 @@ ${testcontainers.vespa.version} test + + com.dev-smart + ubjson + diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/AbstractXGBoostParser.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/AbstractXGBoostParser.java new file mode 100644 index 000000000000..5bdf69d87fef --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/AbstractXGBoostParser.java @@ -0,0 +1,66 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.xgboost; + +/** + * Base class for XGBoost parsers containing shared tree-to-expression conversion logic. 
+ *
+ * @author arnej
+ */
+abstract class AbstractXGBoostParser {
+
+    /**
+     * Renders one XGBoost tree node, and recursively everything below it, as a
+     * Vespa ranking expression. A leaf becomes its numeric value; a split node
+     * becomes an {@code if (condition, trueBranch, falseBranch)} expression.
+     *
+     * @param node XGBoost tree node to convert.
+     * @return Vespa ranking expression for input node.
+     */
+    protected String treeToRankExp(XGBoostTree node) {
+        if (node.isLeaf()) {
+            return Double.toString(node.getLeaf());
+        }
+
+        assert node.getChildren().size() == 2;
+        // The child whose id equals "yes" is taken when the split condition holds.
+        boolean firstIsYes = node.getYes() == node.getChildren().get(0).getNodeid();
+        String trueExp = treeToRankExp(node.getChildren().get(firstIsYes ? 0 : 1));
+        String falseExp = treeToRankExp(node.getChildren().get(firstIsYes ? 1 : 0));
+
+        // xgboost uses float only internally, so narrow to the closest float first;
+        // Vespa expects rank profile literals in double precision, so widen again.
+        double vespaSplitPoint = (float) node.getSplit_condition();
+        String feature = formatSplit(node.getSplit());
+
+        // When missing values follow the "yes" branch the test is written negated:
+        // the backend evaluates any comparison with NaN as false, so !(x >= t) is true for NaN.
+        String condition = (node.getMissing() == node.getYes())
+                ? "!(" + feature + " >= " + vespaSplitPoint + ")"
+                : feature + " < " + vespaSplitPoint;
+
+        return "if (" + condition + ", " + trueExp + ", " + falseExp + ")";
+    }
+
+    /**
+     * Formats a split field value for use in ranking expressions.
+     * If the split is a plain integer, wraps it with xgboost_input_X format.
+     * Otherwise, uses the split value as-is (for backward compatibility with JSON format).
+     * "Plain integer" means a non-empty run of ASCII digits; signed values such as
+     * "-1" or "+2" are not feature indices and are passed through unchanged.
+     *
+     * @param split The split field value from the tree node
+     * @return Formatted split expression for use in conditions
+     */
+    protected String formatSplit(String split) {
+        if (isPlainIndex(split)) {
+            return "xgboost_input_" + split;
+        }
+        // Not a plain integer, use as-is (JSON format already has full attribute name)
+        return split;
+    }
+
+    /**
+     * Tells whether the given text consists solely of ASCII digits.
+     *
+     * @param text Candidate feature index, may be null.
+     * @return true for a non-empty, digits-only string.
+     */
+    private static boolean isPlainIndex(String text) {
+        if (text == null || text.isEmpty()) {
+            return false;
+        }
+        for (int i = 0; i < text.length(); i++) {
+            char c = text.charAt(i);
+            if (c < '0' || c > '9') {
+                return false;
+            }
+        }
+        return true;
+    }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java index 4d530747d7d3..87cc2478835b 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java @@ -25,6 +25,9 @@ public boolean canImport(String modelPath) { File modelFile = new File(modelPath); if ( ! modelFile.isFile()) return false; + if (modelFile.toString().endsWith(".ubj")) { + return XGBoostUbjParser.probe(modelPath); + } return modelFile.toString().endsWith(".json") && probe(modelFile); } @@ -52,9 +55,15 @@ private boolean probe(File modelFile) { public ImportedModel importModel(String modelName, String modelPath) { try { ImportedModel model = new ImportedModel(modelName, modelPath, ImportedMlModel.ModelType.XGBOOST); - XGBoostParser parser = new XGBoostParser(modelPath); - RankingExpression expression = new RankingExpression(parser.toRankingExpression()); - model.expression(modelName, expression); + if (modelPath.endsWith(".ubj")) { + XGBoostUbjParser parser = new XGBoostUbjParser(modelPath); + RankingExpression expression = new RankingExpression(parser.toRankingExpression()); + model.expression(modelName, expression); + } else { + XGBoostParser parser = new XGBoostParser(modelPath); + RankingExpression expression = new RankingExpression(parser.toRankingExpression()); + model.expression(modelName, expression); + }
return model; } catch (IOException e) { throw new IllegalArgumentException("Could not import XGBoost model from '" + modelPath + "'", e); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java index f869875815fe..4ab7bf8a6787 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java @@ -13,7 +13,7 @@ /** * @author grace-lam */ -class XGBoostParser { +class XGBoostParser extends AbstractXGBoostParser { private final List xgboostTrees; @@ -49,39 +49,4 @@ String toRankingExpression() { return ret.toString(); } - /** - * Recursive helper function for toRankingExpression(). - * - * @param node XGBoost tree node to convert. - * @return Vespa ranking expression for input node. - */ - private String treeToRankExp(XGBoostTree node) { - if (node.isLeaf()) { - return Double.toString(node.getLeaf()); - } else { - assert node.getChildren().size() == 2; - String trueExp; - String falseExp; - if (node.getYes() == node.getChildren().get(0).getNodeid()) { - trueExp = treeToRankExp(node.getChildren().get(0)); - falseExp = treeToRankExp(node.getChildren().get(1)); - } else { - trueExp = treeToRankExp(node.getChildren().get(1)); - falseExp = treeToRankExp(node.getChildren().get(0)); - } - // xgboost uses float only internally, so round to closest float - float xgbSplitPoint = (float)node.getSplit_condition(); - // but Vespa expects rank profile literals in double precision: - double vespaSplitPoint = xgbSplitPoint; - String condition; - if (node.getMissing() == node.getYes()) { - // Note: this is for handling missing features, as the backend handles comparison with NaN as false. 
- condition = "!(" + node.getSplit() + " >= " + vespaSplitPoint + ")"; - } else { - condition = node.getSplit() + " < " + vespaSplitPoint; - } - return "if (" + condition + ", " + trueExp + ", " + falseExp + ")"; - } - } - } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostUbjParser.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostUbjParser.java new file mode 100644 index 000000000000..207c6d69e28d --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostUbjParser.java @@ -0,0 +1,583 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.xgboost; + +import com.devsmart.ubjson.UBArray; +import com.devsmart.ubjson.UBObject; +import com.devsmart.ubjson.UBReader; +import com.devsmart.ubjson.UBValue; + +import java.io.BufferedReader; +import java.io.FileInputStream; +import java.io.FileReader; +import java.io.IOException; +import java.lang.reflect.Field; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * Parser for XGBoost models in Universal Binary JSON (UBJ) format. + * + * @author arnej + */ +class XGBoostUbjParser extends AbstractXGBoostParser { + + private final List xgboostTrees; + private final double baseScore; + private final List featureNames; + private final String objective; + + /** + * Probes a file to check if it looks like an XGBoost UBJ model. + * This performs minimal parsing to validate the structure. + * + * @param filePath Path to the file to probe. + * @return true if the file appears to be an XGBoost UBJ model. 
+     */
+    static boolean probe(String filePath) {
+        try (FileInputStream input = new FileInputStream(filePath);
+             UBReader ubjReader = new UBReader(input)) {
+            UBValue root = ubjReader.read();
+
+            // Simple format: a bare array of trees, judged by its first element.
+            if (root.isArray()) {
+                UBArray forest = root.asArray();
+                return forest.size() != 0
+                        && forest.get(0).isObject()
+                        && hasTreeStructure(forest.get(0).asObject());
+            }
+            if (!root.isObject()) {
+                return false;
+            }
+
+            // Full format: descend learner -> gradient_booster -> model, each of which
+            // must be an object, and expect a "trees" array at the bottom.
+            UBObject current = root.asObject();
+            for (String key : new String[] {"learner", "gradient_booster", "model"}) {
+                UBValue nested = current.get(key);
+                if (nested == null || !nested.isObject()) {
+                    return false;
+                }
+                current = nested.asObject();
+            }
+            UBValue trees = current.get("trees");
+            return trees != null && trees.isArray();
+        } catch (IOException | RuntimeException e) {
+            // Any error during probing means it's not a valid XGBoost UBJ file
+            return false;
+        }
+    }
+
+    /**
+     * Checks if a UBObject has the expected XGBoost tree structure.
+     *
+     * @param treeObj Object to check.
+     * @return true if it has expected tree arrays.
+     */
+    private static boolean hasTreeStructure(UBObject treeObj) {
+        // Every array that convertUbjTree() reads must be present, including default_left;
+        // otherwise probing would accept a file that import later rejects.
+        return treeObj.get("left_children") != null &&
+               treeObj.get("right_children") != null &&
+               treeObj.get("split_conditions") != null &&
+               treeObj.get("split_indices") != null &&
+               treeObj.get("base_weights") != null &&
+               treeObj.get("default_left") != null;
+    }
+
+    /**
+     * Reads an XGBoost UBJ model and stores its trees, objective, base score and feature names.
+     * Accepts either a bare array of trees or the full format with a "learner" object.
+     * Feature names found in a sibling "-features.txt" file replace any names embedded in the model.
+     *
+     * @param filePath XGBoost UBJ input file.
+     * @throws IOException Fails file reading or UBJ parsing.
+     */
+    XGBoostUbjParser(String filePath) throws IOException {
+        this.xgboostTrees = new ArrayList<>();
+        double tmpBaseScore = 0.5; // default value
+        List<String> tmpFeatureNames = new ArrayList<>();
+        String tmpObjective = "reg:squarederror"; // default objective if not found
+        try (FileInputStream fileStream = new FileInputStream(filePath);
+             UBReader reader = new UBReader(fileStream)) {
+            UBValue root = reader.read();
+
+            UBArray forestArray;
+            if (root.isArray()) {
+                // Simple array format (like JSON export)
+                forestArray = root.asArray();
+            } else if (root.isObject()) {
+                UBObject rootObj = root.asObject();
+                UBObject learner = getRequiredObject(rootObj, "learner", "UBJ root");
+
+                // Extract objective if available
+                tmpObjective = extractObjective(learner);
+
+                // Extract base_score if available
+                tmpBaseScore = extractBaseScore(learner, tmpObjective);
+
+                // Extract feature_names if available
+                UBValue featureNamesValue = learner.get("feature_names");
+                if (featureNamesValue != null && featureNamesValue.isArray()) {
+                    UBArray featureNamesArray = featureNamesValue.asArray();
+                    for (int i = 0; i < featureNamesArray.size(); i++) {
+                        tmpFeatureNames.add(featureNamesArray.get(i).asString());
+                    }
+                }
+
+                // Navigate to trees array
+                forestArray = navigateToTreesArray(learner);
+            } else {
+                throw new IOException("Expected UBJ array or object at root, got: " + root.getClass().getSimpleName());
+            }
+
+            // Parse each tree (UBJ format uses flat arrays, not nested objects)
+            for (int i = 0; i < forestArray.size(); i++) {
+                UBValue treeValue = forestArray.get(i);
+                if (!treeValue.isObject()) {
+                    throw new IOException("Expected UBJ object for tree, got: " + treeValue.getClass().getSimpleName());
+                }
+                this.xgboostTrees.add(convertUbjTree(treeValue.asObject()));
+            }
+        }
+        this.baseScore = tmpBaseScore;
+        this.objective = tmpObjective;
+
+        // Check for optional feature names file (e.g., foobar-features.txt for foobar.ubj)
+        String featuresPath = withFeaturesSuffix(filePath);
+        List<String> overrideFeatureNames = loadFeatureNamesFromFile(featuresPath);
+        if (overrideFeatureNames != null) {
+            tmpFeatureNames = overrideFeatureNames;
+        }
+
+        this.featureNames = Collections.unmodifiableList(tmpFeatureNames);
+    }
+
+    /**
+     * Converts parsed UBJ trees to Vespa ranking expressions.
+     * If feature names were loaded from a -features.txt file or extracted from the UBJ file,
+     * they are used, and they must cover every feature index the model splits on.
+     * If no feature names are available, uses indexed format (xgboost_input_X).
+     *
+     * @return Vespa ranking expressions.
+     * @throws IllegalArgumentException if feature names are available but too few for the model.
+     */
+    String toRankingExpression() {
+        if (!featureNames.isEmpty()) {
+            // note: requires enough feature names; there is no fallback to indexed format.
+            return toRankingExpression(featureNames);
+        }
+
+        // Use indexed format (xgboost_input_X)
+        StringBuilder result = new StringBuilder();
+
+        // Convert all trees to expressions and join with " + "
+        for (int i = 0; i < xgboostTrees.size(); i++) {
+            if (i > 0) {
+                result.append(" + \n");
+            }
+            result.append(treeToRankExp(xgboostTrees.get(i)));
+        }
+
+        // Add base_score, with logit transformation only for logistic objectives
+        result.append(" + \n");
+        if (objective.endsWith(":logistic")) {
+            if (baseScore > 0.0 && baseScore < 1.0) {
+                // Add precomputed base_score logit transformation
+                double baseScoreLogit = Math.log(baseScore) - Math.log(1.0 - baseScore);
+                result.append(baseScoreLogit);
+            } else {
+                // The logit is undefined outside (0, 1); report through the logging
+                // framework rather than stderr and contribute a neutral term.
+                java.util.logging.Logger.getLogger(XGBoostUbjParser.class.getName())
+                        .warning("Bad basescore " + baseScore + " for logistic model, should be in range (0.0, 1.0)");
+                result.append("0.0");
+            }
+        } else {
+            result.append(baseScore);
+        }
+
+        return result.toString();
+    }
+
+    /**
+     * Converts parsed UBJ trees to Vespa ranking expressions using provided feature names.
+     *
+     * @param customFeatureNames List of feature names to map indices to actual names.
+     *                           Must contain enough names to cover all feature indices used.
+     * @return Vespa ranking expressions with named features.
+     * @throws IllegalArgumentException if customFeatureNames is insufficient for the indices used
+     */
+    String toRankingExpression(List<String> customFeatureNames) {
+        // Validate that we have enough feature names before generating anything
+        validateFeatureNames(customFeatureNames);
+
+        StringBuilder result = new StringBuilder();
+
+        for (int i = 0; i < xgboostTrees.size(); i++) {
+            if (i > 0) {
+                result.append(" + \n");
+            }
+            result.append(treeToRankExpWithFeatureNames(xgboostTrees.get(i), customFeatureNames));
+        }
+
+        // Add base_score, with logit transformation only for logistic objectives
+        result.append(" + \n");
+        if (objective.endsWith(":logistic")) {
+            if (baseScore > 0.0 && baseScore < 1.0) {
+                double baseScoreLogit = Math.log(baseScore) - Math.log(1.0 - baseScore);
+                result.append(baseScoreLogit);
+            } else {
+                // The logit is undefined outside (0, 1) and would put NaN or an infinity
+                // into the expression; warn and contribute a neutral term, as the
+                // no-argument overload does.
+                java.util.logging.Logger.getLogger(XGBoostUbjParser.class.getName())
+                        .warning("Bad basescore " + baseScore + " for logistic model, should be in range (0.0, 1.0)");
+                result.append("0.0");
+            }
+        } else {
+            result.append(baseScore);
+        }
+
+        return result.toString();
+    }
+
+    /**
+     * Validates that the provided feature names list is non-empty and has at least
+     * as many entries as the highest feature index used by the model requires.
+     * Surplus names are accepted.
+     *
+     * @param customFeatureNames List of feature names to validate
+     * @throws IllegalArgumentException if validation fails
+     */
+    private void validateFeatureNames(List<String> customFeatureNames) {
+        if (customFeatureNames == null || customFeatureNames.isEmpty()) {
+            throw new IllegalArgumentException("Feature names list cannot be null or empty");
+        }
+
+        // Find max feature index used in all trees
+        int maxIndex = findMaxFeatureIndex();
+        int requiredSize = maxIndex + 1;
+
+        if (customFeatureNames.size() < requiredSize) {
+            throw new IllegalArgumentException(
+                    "Feature names list size mismatch: model requires at least " + requiredSize +
+                    " feature names (indices 0-" + maxIndex + ") but " +
+                    customFeatureNames.size() + " names provided"
+            );
+        }
+    }
+
+    /**
+     * Finds the maximum feature index used across all trees.
+     *
+     * @return Maximum feature index, or -1 if no features are used
+     */
+    private int findMaxFeatureIndex() {
+        int highest = -1;
+        for (XGBoostTree root : xgboostTrees) {
+            int candidate = findMaxFeatureIndexInTree(root);
+            if (candidate > highest) {
+                highest = candidate;
+            }
+        }
+        return highest;
+    }
+
+    /**
+     * Walks a tree depth-first and reports the largest numeric split index found.
+     * Splits that are not parseable as integers are ignored.
+     *
+     * @param node Tree node to search
+     * @return Maximum feature index in this tree, or -1 if node is a leaf
+     */
+    private int findMaxFeatureIndexInTree(XGBoostTree node) {
+        if (node.isLeaf()) {
+            return -1; // Leaf node
+        }
+
+        int best = -1;
+        try {
+            best = Math.max(best, Integer.parseInt(node.getSplit()));
+        } catch (NumberFormatException e) {
+            // Split is not a number, skip
+        }
+
+        if (node.getChildren() == null) {
+            return best;
+        }
+        for (XGBoostTree child : node.getChildren()) {
+            best = Math.max(best, findMaxFeatureIndexInTree(child));
+        }
+        return best;
+    }
+
+    /**
+     * Converts a tree to ranking expression using custom feature names.
+     * Mirrors the inherited treeToRankExp, except that the split's feature index is
+     * looked up in the supplied names instead of being rendered as xgboost_input_X.
+     *
+     * @param node Tree node to convert
+     * @param customFeatureNames List of feature names for index lookup
+     * @return Ranking expression string
+     */
+    private String treeToRankExpWithFeatureNames(XGBoostTree node, List<String> customFeatureNames) {
+        if (node.isLeaf()) {
+            return Double.toString(node.getLeaf());
+        }
+
+        assert node.getChildren().size() == 2;
+        String trueExp;
+        String falseExp;
+        if (node.getYes() == node.getChildren().get(0).getNodeid()) {
+            trueExp = treeToRankExpWithFeatureNames(node.getChildren().get(0), customFeatureNames);
+            falseExp = treeToRankExpWithFeatureNames(node.getChildren().get(1), customFeatureNames);
+        } else {
+            trueExp = treeToRankExpWithFeatureNames(node.getChildren().get(1), customFeatureNames);
+            falseExp = treeToRankExpWithFeatureNames(node.getChildren().get(0), customFeatureNames);
+        }
+
+        // The split holds the feature index as text; the parameterized list lets the
+        // lookup yield a String without a cast.
+        int featureIdx = Integer.parseInt(node.getSplit());
+        String featureName = customFeatureNames.get(featureIdx);
+
+        // Use the actual feature name instead of indexed format.
+        // Narrow to float and widen again, the same treatment as in treeToRankExp.
+        float xgbSplitPoint = (float)node.getSplit_condition();
+        double vespaSplitPoint = xgbSplitPoint;
+
+        String condition;
+        if (node.getMissing() == node.getYes()) {
+            // Missing values follow the "yes" branch, so the test is written negated.
+            condition = "!(" + featureName + " >= " + vespaSplitPoint + ")";
+        } else {
+            condition = featureName + " < " + vespaSplitPoint;
+        }
+
+        return "if (" + condition + ", " + trueExp + ", " + falseExp + ")";
+    }
+
+    /**
+     * Derives the path of the optional feature names file belonging to a model file:
+     * a trailing ".ubj" is dropped, then "-features.txt" is appended.
+     *
+     * @param ubjFilePath Path to the UBJ model file.
+     * @return Path where the feature names file is expected.
+     */
+    private static String withFeaturesSuffix(String ubjFilePath) {
+        if (ubjFilePath.endsWith(".ubj")) {
+            ubjFilePath = ubjFilePath.substring(0, ubjFilePath.length() - 4);
+        }
+        return ubjFilePath + "-features.txt";
+    }
+
+    /**
+     * Attempts to load feature names from an optional text file.
+     * For a UBJ file "path/to/model.ubj", looks for "path/to/model-features.txt".
+     * Each line in the file should contain one feature name.
+     * Blank lines and lines starting with '#' are skipped. The file is read as UTF-8.
+     *
+     * @param featuresFilePath Path to the feature names file (not the UBJ file itself)
+     * @return List of feature names if the file exists, is readable and contains at least
+     *         one name; null otherwise
+     */
+    private static List<String> loadFeatureNamesFromFile(String featuresFilePath) {
+        Path path = Paths.get(featuresFilePath);
+
+        if (!Files.exists(path)) {
+            return null; // File doesn't exist, that's okay
+        }
+
+        List<String> featureNames = new ArrayList<>();
+        // Explicit charset: the single-argument FileReader constructor would use the
+        // platform default, making the parsed names depend on the host configuration.
+        try (BufferedReader reader = new BufferedReader(
+                new FileReader(featuresFilePath, java.nio.charset.StandardCharsets.UTF_8))) {
+            String line;
+            while ((line = reader.readLine()) != null) {
+                line = line.trim();
+                if (!line.isEmpty() && !line.startsWith("#")) {
+                    featureNames.add(line);
+                }
+            }
+        } catch (IOException e) {
+            // Deliberately best-effort: if we can't read the file, return null and use default naming
+            return null;
+        }
+        return featureNames.isEmpty() ? null : featureNames;
+    }
+
+    /**
+     * Extracts a required UBObject from a parent object.
+     *
+     * @param parent Parent UBObject to extract from.
+     * @param key Key name to extract.
+     * @param parentDescription Description of parent for error messages.
+     * @return The extracted UBObject.
+     * @throws IOException If the key is missing or not an object.
+     */
+    private static UBObject getRequiredObject(UBObject parent, String key, String parentDescription) throws IOException {
+        UBValue value = parent.get(key);
+        if (value == null || !value.isObject()) {
+            throw new IOException("Expected '" + key + "' object in " + parentDescription);
+        }
+        return value.asObject();
+    }
+
+    /**
+     * Extracts the base_score from learner_model_param if available.
+     *
+     * @param learner The learner UBObject.
+     * @param objective The model objective name, used to choose the default; may be null.
+     * @return The extracted base_score; if not found, 0.5 for objectives ending in
+     *         ":logistic" and 0.0 otherwise.
+     */
+    private static double extractBaseScore(UBObject learner, String objective) {
+        UBValue paramHolder = learner.get("learner_model_param");
+        if (paramHolder != null && paramHolder.isObject()) {
+            UBValue rawScore = paramHolder.asObject().get("base_score");
+            if (rawScore != null && rawScore.isString()) {
+                // Stored as text such as "[6.274165E-1]": drop the brackets, then parse
+                String digits = rawScore.asString().replace("[", "").replace("]", "");
+                return Double.parseDouble(digits);
+            }
+        }
+        // Nothing stored: logistic objectives default to 0.5, everything else to 0.0
+        boolean logistic = objective != null && objective.endsWith(":logistic");
+        return logistic ? 0.5 : 0.0;
+    }
+
+    /**
+     * Reads the objective name stored under "objective" / "name" in the learner.
+     *
+     * @param learner The learner UBObject.
+     * @return The extracted objective name, or "reg:squarederror" if not found.
+     */
+    private static String extractObjective(UBObject learner) {
+        UBValue holder = learner.get("objective");
+        if (holder == null || !holder.isObject()) {
+            return "reg:squarederror"; // default objective if not found
+        }
+        UBValue name = holder.asObject().get("name");
+        return (name != null && name.isString())
+                ? name.asString()
+                : "reg:squarederror"; // default objective if not found
+    }
+
+    /**
+     * Navigates from learner object to the trees array.
+     *
+     * @param learner The learner UBObject.
+     * @return The trees UBArray.
+     * @throws IOException If navigation fails.
+     */
+    private static UBArray navigateToTreesArray(UBObject learner) throws IOException {
+        UBObject gradientBooster = getRequiredObject(learner, "gradient_booster", "learner");
+        UBObject model = getRequiredObject(gradientBooster, "model", "gradient_booster");
+        UBValue treesValue = model.get("trees");
+        if (treesValue == null || !treesValue.isArray()) {
+            throw new IOException("Expected 'trees' array in model");
+        }
+        return treesValue.asArray();
+    }
+
+    /**
+     * Converts a UBJ tree (flat array format) to the root XGBoostTree node (hierarchical format).
+     * All six per-node arrays must be present; a missing one is reported as an
+     * IOException so callers can treat it like any other unreadable model.
+     *
+     * @param treeObj UBJ object containing flat arrays representing the tree.
+     * @return Root XGBoostTree node with hierarchical structure.
+     * @throws IOException If a required array is missing or the tree has no nodes.
+     */
+    private static XGBoostTree convertUbjTree(UBObject treeObj) throws IOException {
+        // Extract flat arrays from UBJ format
+        int[] leftChildren = requireTreeField(treeObj, "left_children").asInt32Array();
+        int[] rightChildren = requireTreeField(treeObj, "right_children").asInt32Array();
+        float[] splitConditions = requireTreeField(treeObj, "split_conditions").asFloat32Array();
+        int[] splitIndices = requireTreeField(treeObj, "split_indices").asInt32Array();
+        float[] baseWeights = requireTreeField(treeObj, "base_weights").asFloat32Array();
+        byte[] defaultLeftBytes = extractDefaultLeft(requireTreeField(treeObj, "default_left"));
+
+        if (leftChildren.length == 0) {
+            throw new IOException("Expected at least one node in tree");
+        }
+
+        // Convert from flat arrays to hierarchical tree structure, starting at root (node 0, depth 0)
+        return buildTreeFromArrays(0, 0, leftChildren, rightChildren, splitConditions,
+                                   splitIndices, baseWeights, defaultLeftBytes);
+    }
+
+    /**
+     * Fetches a per-node array from a tree object, failing with a descriptive error when absent.
+     *
+     * @param treeObj UBJ tree object.
+     * @param key Name of the required entry.
+     * @return The value stored under key, never null.
+     * @throws IOException If the entry is missing.
+     */
+    private static UBValue requireTreeField(UBObject treeObj, String key) throws IOException {
+        UBValue value = treeObj.get(key);
+        if (value == null) {
+            throw new IOException("Expected '" + key + "' array in tree");
+        }
+        return value;
+    }
+
+    /**
+     * Extracts the default_left array from UBJ value.
+     * Handles both UBArray and direct byte array formats.
+     *
+     * @param defaultLeftValue The UBValue containing default_left data.
+     * @return Byte array with default_left values.
+     * @throws IllegalArgumentException If defaultLeftValue is null, i.e. the tree has no default_left entry.
+     */
+    private static byte[] extractDefaultLeft(UBValue defaultLeftValue) {
+        if (defaultLeftValue == null) {
+            // Fail with a message naming the missing entry instead of a bare NullPointerException
+            throw new IllegalArgumentException("Expected 'default_left' array in tree");
+        }
+        if (!defaultLeftValue.isArray()) {
+            return defaultLeftValue.asByteArray();
+        }
+        // It's a UBArray, iterate and convert element by element
+        UBArray defaultLeftArray = defaultLeftValue.asArray();
+        byte[] result = new byte[defaultLeftArray.size()];
+        for (int i = 0; i < result.length; i++) {
+            result[i] = defaultLeftArray.get(i).asByte();
+        }
+        return result;
+    }
+
+    /**
+     * Recursively builds a hierarchical XGBoostTree from flat arrays.
+     *
+     * @param nodeId Current node index in the arrays.
+     * @param depth Current depth in the tree (0 for root).
+     * @param leftChildren Array of left child indices.
+     * @param rightChildren Array of right child indices.
+     * @param splitConditions Array of split threshold values.
+     * @param splitIndices Array of feature indices to split on.
+     * @param baseWeights Array of base weights (leaf values).
+     * @param defaultLeft Array indicating if missing values go left.
+     * @return XGBoostTree node.
+     * @throws IllegalArgumentException If a node index lies outside the arrays or the child
+     *                                  links form a cycle.
+     */
+    private static XGBoostTree buildTreeFromArrays(int nodeId, int depth, int[] leftChildren, int[] rightChildren,
+                                                   float[] splitConditions, int[] splitIndices,
+                                                   float[] baseWeights, byte[] defaultLeft) {
+        // The indices come straight from the model file, so guard against malformed input.
+        if (nodeId < 0 || nodeId >= leftChildren.length) {
+            throw new IllegalArgumentException("Invalid node index " + nodeId +
+                                               " in tree with " + leftChildren.length + " nodes");
+        }
+        // A tree with N nodes cannot be N levels deep; getting here means the links loop.
+        if (depth >= leftChildren.length) {
+            throw new IllegalArgumentException("Cycle detected in tree at node " + nodeId);
+        }
+
+        XGBoostTree node = new XGBoostTree();
+        setField(node, "nodeid", nodeId);
+        setField(node, "depth", depth);
+
+        // Check if this is a leaf node
+        boolean isLeaf = leftChildren[nodeId] == -1;
+
+        if (isLeaf) {
+            // Leaf node: set the leaf value from base_weights.
+            // The stored float is widened to double unchanged.
+            double leafValue = baseWeights[nodeId];
+            setField(node, "leaf", leafValue);
+        } else {
+            // Split node: set split information
+            int featureIdx = splitIndices[nodeId];
+            setField(node, "split", String.valueOf(featureIdx));
+            // The stored float threshold is widened to double unchanged.
+            double splitValue = splitConditions[nodeId];
+            setField(node, "split_condition", splitValue);
+
+            int leftChild = leftChildren[nodeId];
+            int rightChild = rightChildren[nodeId];
+            boolean goLeftOnMissing = defaultLeft[nodeId] != 0;
+
+            // In XGBoost trees:
+            // - Left child is taken when feature < threshold
+            // - Right child is taken when feature >= threshold
+            // - default_left only controls where missing values go
+            setField(node, "yes", leftChild);   // yes = condition is true = feature < threshold = go left
+            setField(node, "no", rightChild);   // no = condition is false = feature >= threshold = go right
+            setField(node, "missing", goLeftOnMissing ? leftChild : rightChild);
+
+            // Recursively build children
+            List<XGBoostTree> children = new ArrayList<>();
+            children.add(buildTreeFromArrays(leftChild, depth + 1, leftChildren, rightChildren,
+                                             splitConditions, splitIndices, baseWeights, defaultLeft));
+            children.add(buildTreeFromArrays(rightChild, depth + 1, leftChildren, rightChildren,
+                                             splitConditions, splitIndices, baseWeights, defaultLeft));
+            setField(node, "children", children);
+        }
+
+        return node;
+    }
+
+    /**
+     * Uses reflection to set a private field on an object.
+     * The field is looked up on the object's own class only, not on superclasses.
+     *
+     * @param obj Object to modify.
+     * @param fieldName Name of the field to set.
+     * @param value Value to set.
+     * @throws RuntimeException If the field does not exist or cannot be written.
+     */
+    private static void setField(Object obj, String fieldName, Object value) {
+        try {
+            Field field = obj.getClass().getDeclaredField(fieldName);
+            field.setAccessible(true);
+            field.set(obj, value);
+        } catch (NoSuchFieldException | IllegalAccessException e) {
+            throw new RuntimeException("Failed to set field '" + fieldName + "' via reflection", e);
+        }
+    }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/UbjToJson.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/UbjToJson.java new file mode 100644 index 000000000000..83c1194c55b6 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/UbjToJson.java @@ -0,0 +1,146 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.xgboost; + +import com.devsmart.ubjson.UBArray; +import com.devsmart.ubjson.UBObject; +import com.devsmart.ubjson.UBReader; +import com.devsmart.ubjson.UBValue; + +import java.io.FileInputStream; +import java.io.IOException; +import java.util.Set; + +/** + * Utility to dump UBJ files as JSON format. + * Outputs the raw UBJ structure without any conversion.
+ * Usage: {@code java UbjToJson <ubj-file>}
+ *
+ * @author arnej
+ */
+public class UbjToJson {
+
+    /**
+     * Command line entry point: prints the JSON rendering of the given UBJ file to stdout.
+     *
+     * @param args Exactly one element, the path of the UBJ file.
+     * @throws IOException If the file cannot be read or parsed.
+     */
+    public static void main(String[] args) throws IOException {
+        if (args.length != 1) {
+            // Name the expected argument so the message is actionable
+            System.err.println("Usage: java UbjToJson <ubj-file>");
+            System.exit(1);
+        }
+
+        String ubjPath = args[0];
+        UbjToJson converter = new UbjToJson();
+        String json = converter.convertUbjToJson(ubjPath);
+        System.out.println(json);
+    }
+
+    /**
+     * Reads a UBJ file and renders its root value as indented JSON text.
+     *
+     * @param filePath Path of the UBJ file.
+     * @return JSON text for the file's root value.
+     * @throws IOException If the file cannot be read or parsed.
+     */
+    public String convertUbjToJson(String filePath) throws IOException {
+        try (FileInputStream fileStream = new FileInputStream(filePath);
+             UBReader reader = new UBReader(fileStream)) {
+            UBValue root = reader.read();
+            return toJson(root, 0);
+        }
+    }
+
+    /** Renders a single value; arrays and objects are rendered recursively at the given indent level. */
+    private String toJson(UBValue value, int indent) {
+        if (value.isNull()) {
+            return "null";
+        } else if (value.isBool()) {
+            return Boolean.toString(value.asBool());
+        } else if (value.isNumber()) {
+            if (value.isInteger()) {
+                return Long.toString(value.asLong());
+            } else {
+                return Double.toString(value.asFloat64());
+            }
+        } else if (value.isString()) {
+            return jsonString(value.asString());
+        } else if (value.isArray()) {
+            return arrayToJson(value.asArray(), indent);
+        } else if (value.isObject()) {
+            return objectToJson(value.asObject(), indent);
+        } else {
+            // Unrecognized value kind: emit an empty string
+            return "\"\"";
+        }
+    }
+
+    /** Renders an array: numeric arrays longer than 10 elements on one line, others one element per line. */
+    private String arrayToJson(UBArray array, int indent) {
+        if (array.size() == 0) {
+            return "[]";
+        }
+
+        // Check if all elements are numbers (the array is known to be non-empty here)
+        boolean isNumberArray = true;
+        for (int i = 0; i < array.size(); i++) {
+            if (!array.get(i).isNumber()) {
+                isNumberArray = false;
+                break;
+            }
+        }
+
+        StringBuilder sb = new StringBuilder();
+        String spaces = " ".repeat(indent);
+        String itemSpaces = " ".repeat(indent + 1);
+
+        if (isNumberArray && array.size() > 10) {
+            // Compact format for large number arrays
+            sb.append("[");
+            for (int i = 0; i < array.size(); i++) {
+                if (i > 0) sb.append(", ");
+                sb.append(toJson(array.get(i), indent + 1));
+            }
+            sb.append("]");
+        } else {
+            // Regular format
+            sb.append("[\n");
+            for (int i = 0; i < array.size(); i++) {
+                sb.append(itemSpaces);
+                sb.append(toJson(array.get(i), indent + 1));
+                if (i < array.size() - 1) {
+                    sb.append(",");
+                }
+                sb.append("\n");
+            }
+            sb.append(spaces).append("]");
+        }
+
+        return sb.toString();
+    }
+
+    /** Renders an object with one key/value pair per line, in the key set's iteration order. */
+    private String objectToJson(UBObject obj, int indent) {
+        Set<String> keys = obj.keySet();
+        if (keys.isEmpty()) {
+            return "{}";
+        }
+
+        StringBuilder sb = new StringBuilder();
+        String spaces = " ".repeat(indent);
+        String itemSpaces = " ".repeat(indent + 1);
+
+        sb.append("{\n");
+        int i = 0;
+        for (String key : keys) {
+            sb.append(itemSpaces);
+            sb.append(jsonString(key));
+            sb.append(": ");
+            sb.append(toJson(obj.get(key), indent + 1));
+            if (i < keys.size() - 1) {
+                sb.append(",");
+            }
+            sb.append("\n");
+            i++;
+        }
+        sb.append(spaces).append("}");
+
+        return sb.toString();
+    }
+
+    /**
+     * Quotes a string as a JSON string literal. Backslash, double quote, newline,
+     * carriage return and tab get their short escapes; every other character
+     * below U+0020 is written as a \\uXXXX escape, since JSON forbids raw control characters.
+     */
+    private String jsonString(String str) {
+        StringBuilder sb = new StringBuilder(str.length() + 2);
+        sb.append('"');
+        for (int i = 0; i < str.length(); i++) {
+            char c = str.charAt(i);
+            switch (c) {
+                case '\\': sb.append("\\\\"); break;
+                case '"':  sb.append("\\\""); break;
+                case '\n': sb.append("\\n"); break;
+                case '\r': sb.append("\\r"); break;
+                case '\t': sb.append("\\t"); break;
+                default:
+                    if (c < 0x20) {
+                        sb.append(String.format("\\u%04x", (int) c));
+                    } else {
+                        sb.append(c);
+                    }
+            }
+        }
+        sb.append('"');
+        return sb.toString();
+    }
+}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java index c45d99274b89..845daa0cdad6 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java @@ -5,9 +5,15 @@ import ai.vespa.rankingexpression.importer.ImportedModel; import org.junit.Test; +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertFalse; /** * @author bratseth @@ -26,4
+32,89 @@ public void testXGBoost() { assertEquals(1, model.outputExpressions().size()); } + @Test + public void testXGBoostUBJ() { + // Test that UBJ format imports successfully and includes base_score adjustment + XGBoostImporter importer = new XGBoostImporter(); + ImportedModel jsonModel = importer.importModel("test", "src/test/models/xgboost/binary_breast_cancer.json"); + ImportedModel ubjModel = importer.importModel("test", "src/test/models/xgboost/binary_breast_cancer.ubj"); + + assertNotNull("JSON model should be imported", jsonModel); + assertNotNull("UBJ model should be imported", ubjModel); + + RankingExpression jsonExpression = jsonModel.expressions().get("test"); + RankingExpression ubjExpression = ubjModel.expressions().get("test"); + + assertNotNull("JSON expression should exist", jsonExpression); + assertNotNull("UBJ expression should exist", ubjExpression); + + String jsonExprStr = jsonExpression.getRoot().toString(); + String ubjExprStr = ubjExpression.getRoot().toString(); + + // UBJ should include the base_score logit transformation + assertTrue("UBJ expression should contain base_score adjustment", + ubjExprStr.contains(" 0.52114942")); + + // JSON should use xgboost_input_X format (from the JSON file) + assertTrue("JSON should use xgboost_input_ format", + jsonExprStr.contains("xgboost_input_")); + + // UBJ should use feature names (auto-loaded from binary_breast_cancer-features.txt) + assertTrue("UBJ should use feature names from file", + ubjExprStr.contains("mean_radius")); + assertFalse("UBJ should not use indexed format", + ubjExprStr.contains("xgboost_input_")); + } + + @Test + public void testXGBoostUBJWithFeatureNames() throws IOException { + XGBoostUbjParser parser = new XGBoostUbjParser("src/test/models/xgboost/binary_breast_cancer.ubj"); + + // Create feature names list (30 features for breast cancer dataset) + List featureNames = Arrays.asList( + "mean_radius", "mean_texture", "mean_perimeter", "mean_area", + "mean_smoothness", 
"mean_compactness", "mean_concavity", + "mean_concave_points", "mean_symmetry", "mean_fractal_dimension", + "radius_error", "texture_error", "perimeter_error", "area_error", + "smoothness_error", "compactness_error", "concavity_error", + "concave_points_error", "symmetry_error", "fractal_dimension_error", + "worst_radius", "worst_texture", "worst_perimeter", "worst_area", + "worst_smoothness", "worst_compactness", "worst_concavity", + "worst_concave_points", "worst_symmetry", "worst_fractal_dimension" + ); + + String expression = parser.toRankingExpression(featureNames); + assertNotNull(expression); + assertTrue("Expression should contain custom feature name", expression.contains("mean_radius")); + assertTrue("Expression should contain custom feature name", expression.contains("mean_texture")); + assertFalse("Expression should not contain indexed format", expression.contains("xgboost_input_")); + } + + @Test + public void testXGBoostUBJWithInsufficientFeatureNames() throws IOException { + XGBoostUbjParser parser = new XGBoostUbjParser("src/test/models/xgboost/binary_breast_cancer.ubj"); + + // Only provide 5 feature names when model needs 30 + List featureNames = Arrays.asList("f0", "f1", "f2", "f3", "f4"); + + assertThrows(IllegalArgumentException.class, () -> { + parser.toRankingExpression(featureNames); + }); + } + + @Test + public void testXGBoostUBJAutoLoadFeatureNames() throws IOException { + // The binary_breast_cancer-features.txt file should be automatically loaded + XGBoostUbjParser parser = new XGBoostUbjParser("src/test/models/xgboost/binary_breast_cancer.ubj"); + + // Call no-arg toRankingExpression() - should use feature names from file + String expression = parser.toRankingExpression(); + assertNotNull(expression); + + // Verify that custom feature names are used (from the -features.txt file) + assertTrue("Expression should contain feature name from file", expression.contains("mean_radius")); + assertTrue("Expression should contain feature name from 
file", expression.contains("worst_texture")); + assertFalse("Expression should not contain indexed format", expression.contains("xgboost_input_")); + } + } diff --git a/model-integration/src/test/models/xgboost/binary_breast_cancer-features.txt b/model-integration/src/test/models/xgboost/binary_breast_cancer-features.txt new file mode 100644 index 000000000000..dd48a897c86b --- /dev/null +++ b/model-integration/src/test/models/xgboost/binary_breast_cancer-features.txt @@ -0,0 +1,32 @@ +# Feature names for Wisconsin Breast Cancer dataset +# One feature name per line +mean_radius +mean_texture +mean_perimeter +mean_area +mean_smoothness +mean_compactness +mean_concavity +mean_concave_points +mean_symmetry +mean_fractal_dimension +radius_error +texture_error +perimeter_error +area_error +smoothness_error +compactness_error +concavity_error +concave_points_error +symmetry_error +fractal_dimension_error +worst_radius +worst_texture +worst_perimeter +worst_area +worst_smoothness +worst_compactness +worst_concavity +worst_concave_points +worst_symmetry +worst_fractal_dimension diff --git a/model-integration/src/test/models/xgboost/binary_breast_cancer.json b/model-integration/src/test/models/xgboost/binary_breast_cancer.json new file mode 100644 index 000000000000..47a4cf498928 --- /dev/null +++ b/model-integration/src/test/models/xgboost/binary_breast_cancer.json @@ -0,0 +1,343 @@ +[ + { "nodeid": 0, "depth": 0, "split": "xgboost_input_20", "split_condition": 16.81999969482422, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": "xgboost_input_27", "split_condition": 0.13570000231266022, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "depth": 2, "split": "xgboost_input_10", "split_condition": 0.6449999809265137, "yes": 7, "no": 8, "missing": 8, "children": [ + { "nodeid": 7, "depth": 3, "leaf": 0.4603551924228668 }, + { "nodeid": 8, "depth": 3, "leaf": -0.01896175928413868 } ] }, + { "nodeid": 4, "depth": 2, "split": 
"xgboost_input_21", "split_condition": 25.579999923706055, "yes": 9, "no": 10, "missing": 10, "children": [ + { "nodeid": 9, "depth": 3, "split": "xgboost_input_23", "split_condition": 806.9000244140625, "yes": 15, "no": 16, "missing": 16, "children": [ + { "nodeid": 15, "depth": 4, "leaf": 0.3054772615432739 }, + { "nodeid": 16, "depth": 4, "leaf": -0.15728549659252167 } ] }, + { "nodeid": 10, "depth": 3, "split": "xgboost_input_6", "split_condition": 0.09696999937295914, "yes": 17, "no": 18, "missing": 18, "children": [ + { "nodeid": 17, "depth": 4, "leaf": -0.09545934945344925 }, + { "nodeid": 18, "depth": 4, "leaf": -0.6689253449440002 } ] } ] } ] }, + { "nodeid": 2, "depth": 1, "split": "xgboost_input_21", "split_condition": 19.59000015258789, "yes": 5, "no": 6, "missing": 6, "children": [ + { "nodeid": 5, "depth": 2, "split": "xgboost_input_7", "split_condition": 0.06254000216722488, "yes": 11, "no": 12, "missing": 12, "children": [ + { "nodeid": 11, "depth": 3, "leaf": 0.3241020143032074 }, + { "nodeid": 12, "depth": 3, "leaf": -0.5246468782424927 } ] }, + { "nodeid": 6, "depth": 2, "split": "xgboost_input_26", "split_condition": 0.1889999955892563, "yes": 13, "no": 14, "missing": 14, "children": [ + { "nodeid": 13, "depth": 3, "leaf": -0.15728549659252167 }, + { "nodeid": 14, "depth": 3, "leaf": -0.7851951718330383 } ] } ] } ] }, + { "nodeid": 0, "depth": 0, "split": "xgboost_input_22", "split_condition": 106.0, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": "xgboost_input_27", "split_condition": 0.15729999542236328, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "depth": 2, "split": "xgboost_input_27", "split_condition": 0.13570000231266022, "yes": 7, "no": 8, "missing": 8, "children": [ + { "nodeid": 7, "depth": 3, "split": "xgboost_input_13", "split_condition": 45.189998626708984, "yes": 15, "no": 16, "missing": 16, "children": [ + { "nodeid": 15, "depth": 4, "split": "xgboost_input_14", 
"split_condition": 0.003289999905973673, "yes": 23, "no": 24, "missing": 24, "children": [ + { "nodeid": 23, "depth": 5, "leaf": 0.11408944427967072 }, + { "nodeid": 24, "depth": 5, "leaf": 0.40121492743492126 } ] }, + { "nodeid": 16, "depth": 4, "leaf": -0.030774641782045364 } ] }, + { "nodeid": 8, "depth": 3, "split": "xgboost_input_1", "split_condition": 19.219999313354492, "yes": 17, "no": 18, "missing": 18, "children": [ + { "nodeid": 17, "depth": 4, "leaf": 0.2960207462310791 }, + { "nodeid": 18, "depth": 4, "leaf": -0.23078586161136627 } ] } ] }, + { "nodeid": 4, "depth": 2, "split": "xgboost_input_22", "split_condition": 97.66000366210938, "yes": 9, "no": 10, "missing": 10, "children": [ + { "nodeid": 9, "depth": 3, "leaf": -0.10795557498931885 }, + { "nodeid": 10, "depth": 3, "leaf": -0.3868221342563629 } ] } ] }, + { "nodeid": 2, "depth": 1, "split": "xgboost_input_7", "split_condition": 0.04845999926328659, "yes": 5, "no": 6, "missing": 6, "children": [ + { "nodeid": 5, "depth": 2, "split": "xgboost_input_17", "split_condition": 0.009996999986469746, "yes": 11, "no": 12, "missing": 12, "children": [ + { "nodeid": 11, "depth": 3, "split": "xgboost_input_1", "split_condition": 19.3799991607666, "yes": 19, "no": 20, "missing": 20, "children": [ + { "nodeid": 19, "depth": 4, "leaf": 0.11794889718294144 }, + { "nodeid": 20, "depth": 4, "leaf": -0.42626824975013733 } ] }, + { "nodeid": 12, "depth": 3, "leaf": 0.3407819867134094 } ] }, + { "nodeid": 6, "depth": 2, "split": "xgboost_input_21", "split_condition": 20.450000762939453, "yes": 13, "no": 14, "missing": 14, "children": [ + { "nodeid": 13, "depth": 3, "split": "xgboost_input_7", "split_condition": 0.07339999824762344, "yes": 21, "no": 22, "missing": 22, "children": [ + { "nodeid": 21, "depth": 4, "leaf": 0.25160130858421326 }, + { "nodeid": 22, "depth": 4, "leaf": -0.39930129051208496 } ] }, + { "nodeid": 14, "depth": 3, "leaf": -0.5225854516029358 } ] } ] } ] }, + { "nodeid": 0, "depth": 0, "split": 
"xgboost_input_7", "split_condition": 0.04938000068068504, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": "xgboost_input_20", "split_condition": 16.81999969482422, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "depth": 2, "split": "xgboost_input_13", "split_condition": 43.95000076293945, "yes": 7, "no": 8, "missing": 8, "children": [ + { "nodeid": 7, "depth": 3, "split": "xgboost_input_14", "split_condition": 0.003289999905973673, "yes": 15, "no": 16, "missing": 16, "children": [ + { "nodeid": 15, "depth": 4, "leaf": 0.0977279469370842 }, + { "nodeid": 16, "depth": 4, "split": "xgboost_input_21", "split_condition": 33.209999084472656, "yes": 21, "no": 22, "missing": 22, "children": [ + { "nodeid": 21, "depth": 5, "leaf": 0.3699623644351959 }, + { "nodeid": 22, "depth": 5, "split": "xgboost_input_1", "split_condition": 26.989999771118164, "yes": 27, "no": 28, "missing": 28, "children": [ + { "nodeid": 27, "depth": 6, "leaf": -0.05926831439137459 }, + { "nodeid": 28, "depth": 6, "leaf": 0.2650623321533203 } ] } ] } ] }, + { "nodeid": 8, "depth": 3, "leaf": -0.064254529774189 } ] }, + { "nodeid": 4, "depth": 2, "split": "xgboost_input_1", "split_condition": 15.729999542236328, "yes": 9, "no": 10, "missing": 10, "children": [ + { "nodeid": 9, "depth": 3, "leaf": 0.21864144504070282 }, + { "nodeid": 10, "depth": 3, "split": "xgboost_input_17", "split_condition": 0.009232999756932259, "yes": 17, "no": 18, "missing": 18, "children": [ + { "nodeid": 17, "depth": 4, "leaf": -0.35293275117874146 }, + { "nodeid": 18, "depth": 4, "leaf": -0.06598913669586182 } ] } ] } ] }, + { "nodeid": 2, "depth": 1, "split": "xgboost_input_22", "split_condition": 101.9000015258789, "yes": 5, "no": 6, "missing": 6, "children": [ + { "nodeid": 5, "depth": 2, "split": "xgboost_input_21", "split_condition": 25.479999542236328, "yes": 11, "no": 12, "missing": 12, "children": [ + { "nodeid": 11, "depth": 3, "leaf": 0.3031226098537445 }, + { 
"nodeid": 12, "depth": 3, "leaf": -0.18023964762687683 } ] }, + { "nodeid": 6, "depth": 2, "split": "xgboost_input_26", "split_condition": 0.21230000257492065, "yes": 13, "no": 14, "missing": 14, "children": [ + { "nodeid": 13, "depth": 3, "leaf": 0.14003755152225494 }, + { "nodeid": 14, "depth": 3, "split": "xgboost_input_1", "split_condition": 15.34000015258789, "yes": 19, "no": 20, "missing": 20, "children": [ + { "nodeid": 19, "depth": 4, "split": "xgboost_input_6", "split_condition": 0.14569999277591705, "yes": 23, "no": 24, "missing": 24, "children": [ + { "nodeid": 23, "depth": 5, "leaf": 0.10976236313581467 }, + { "nodeid": 24, "depth": 5, "leaf": -0.2641395330429077 } ] }, + { "nodeid": 20, "depth": 4, "split": "xgboost_input_4", "split_condition": 0.08354999870061874, "yes": 25, "no": 26, "missing": 26, "children": [ + { "nodeid": 25, "depth": 5, "leaf": -0.11790292710065842 }, + { "nodeid": 26, "depth": 5, "leaf": -0.43853873014450073 } ] } ] } ] } ] } ] }, + { "nodeid": 0, "depth": 0, "split": "xgboost_input_22", "split_condition": 102.5, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": "xgboost_input_24", "split_condition": 0.17820000648498535, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "depth": 2, "split": "xgboost_input_13", "split_condition": 49.0, "yes": 7, "no": 8, "missing": 8, "children": [ + { "nodeid": 7, "depth": 3, "split": "xgboost_input_0", "split_condition": 14.109999656677246, "yes": 13, "no": 14, "missing": 14, "children": [ + { "nodeid": 13, "depth": 4, "split": "xgboost_input_8", "split_condition": 0.23749999701976776, "yes": 19, "no": 20, "missing": 20, "children": [ + { "nodeid": 19, "depth": 5, "leaf": 0.3398503065109253 }, + { "nodeid": 20, "depth": 5, "leaf": 0.08510195463895798 } ] }, + { "nodeid": 14, "depth": 4, "leaf": 0.06697364151477814 } ] }, + { "nodeid": 8, "depth": 3, "leaf": -0.04186965152621269 } ] }, + { "nodeid": 4, "depth": 2, "leaf": -0.16538292169570923 } ] 
}, + { "nodeid": 2, "depth": 1, "split": "xgboost_input_7", "split_condition": 0.04845999926328659, "yes": 5, "no": 6, "missing": 6, "children": [ + { "nodeid": 5, "depth": 2, "split": "xgboost_input_21", "split_condition": 26.440000534057617, "yes": 9, "no": 10, "missing": 10, "children": [ + { "nodeid": 9, "depth": 3, "leaf": 0.2624058723449707 }, + { "nodeid": 10, "depth": 3, "split": "xgboost_input_15", "split_condition": 0.01631000079214573, "yes": 15, "no": 16, "missing": 16, "children": [ + { "nodeid": 15, "depth": 4, "leaf": -0.3455977141857147 }, + { "nodeid": 16, "depth": 4, "leaf": 0.20227135717868805 } ] } ] }, + { "nodeid": 6, "depth": 2, "split": "xgboost_input_26", "split_condition": 0.21230000257492065, "yes": 11, "no": 12, "missing": 12, "children": [ + { "nodeid": 11, "depth": 3, "leaf": 0.12086576223373413 }, + { "nodeid": 12, "depth": 3, "split": "xgboost_input_21", "split_condition": 18.40999984741211, "yes": 17, "no": 18, "missing": 18, "children": [ + { "nodeid": 17, "depth": 4, "leaf": -0.054553307592868805 }, + { "nodeid": 18, "depth": 4, "split": "xgboost_input_16", "split_condition": 0.10270000249147415, "yes": 21, "no": 22, "missing": 22, "children": [ + { "nodeid": 21, "depth": 5, "leaf": -0.3827503025531769 }, + { "nodeid": 22, "depth": 5, "leaf": -0.09866306185722351 } ] } ] } ] } ] } ] }, + { "nodeid": 0, "depth": 0, "split": "xgboost_input_23", "split_condition": 888.2999877929688, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": "xgboost_input_27", "split_condition": 0.1606999933719635, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "depth": 2, "split": "xgboost_input_1", "split_condition": 21.309999465942383, "yes": 7, "no": 8, "missing": 8, "children": [ + { "nodeid": 7, "depth": 3, "split": "xgboost_input_10", "split_condition": 0.550599992275238, "yes": 11, "no": 12, "missing": 12, "children": [ + { "nodeid": 11, "depth": 4, "split": "xgboost_input_23", "split_condition": 
811.2999877929688, "yes": 15, "no": 16, "missing": 16, "children": [ + { "nodeid": 15, "depth": 5, "leaf": 0.3330357074737549 }, + { "nodeid": 16, "depth": 5, "split": "xgboost_input_15", "split_condition": 0.01841999962925911, "yes": 19, "no": 20, "missing": 20, "children": [ + { "nodeid": 19, "depth": 6, "leaf": 0.226839080452919 }, + { "nodeid": 20, "depth": 6, "leaf": -0.04981991648674011 } ] } ] }, + { "nodeid": 12, "depth": 4, "leaf": 0.06199278682470322 } ] }, + { "nodeid": 8, "depth": 3, "split": "xgboost_input_23", "split_condition": 648.2999877929688, "yes": 13, "no": 14, "missing": 14, "children": [ + { "nodeid": 13, "depth": 4, "leaf": 0.26304540038108826 }, + { "nodeid": 14, "depth": 4, "split": "xgboost_input_14", "split_condition": 0.005872000008821487, "yes": 17, "no": 18, "missing": 18, "children": [ + { "nodeid": 17, "depth": 5, "leaf": 0.058584075421094894 }, + { "nodeid": 18, "depth": 5, "split": "xgboost_input_9", "split_condition": 0.06066000089049339, "yes": 21, "no": 22, "missing": 22, "children": [ + { "nodeid": 21, "depth": 6, "leaf": -0.37050333619117737 }, + { "nodeid": 22, "depth": 6, "leaf": -0.07894755154848099 } ] } ] } ] } ] }, + { "nodeid": 4, "depth": 2, "leaf": -0.2753813862800598 } ] }, + { "nodeid": 2, "depth": 1, "split": "xgboost_input_6", "split_condition": 0.0729300007224083, "yes": 5, "no": 6, "missing": 6, "children": [ + { "nodeid": 5, "depth": 2, "split": "xgboost_input_1", "split_condition": 19.510000228881836, "yes": 9, "no": 10, "missing": 10, "children": [ + { "nodeid": 9, "depth": 3, "leaf": 0.18285000324249268 }, + { "nodeid": 10, "depth": 3, "leaf": -0.2675382196903229 } ] }, + { "nodeid": 6, "depth": 2, "leaf": -0.3542742133140564 } ] } ] }, + { "nodeid": 0, "depth": 0, "split": "xgboost_input_27", "split_condition": 0.1111999973654747, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": "xgboost_input_3", "split_condition": 698.7999877929688, "yes": 3, "no": 4, "missing": 4, 
"children": [ + { "nodeid": 3, "depth": 2, "split": "xgboost_input_28", "split_condition": 0.1987999975681305, "yes": 7, "no": 8, "missing": 8, "children": [ + { "nodeid": 7, "depth": 3, "leaf": 0.0755092054605484 }, + { "nodeid": 8, "depth": 3, "split": "xgboost_input_21", "split_condition": 33.209999084472656, "yes": 15, "no": 16, "missing": 16, "children": [ + { "nodeid": 15, "depth": 4, "leaf": 0.3183874487876892 }, + { "nodeid": 16, "depth": 4, "leaf": 0.11721514165401459 } ] } ] }, + { "nodeid": 4, "depth": 2, "split": "xgboost_input_1", "split_condition": 20.219999313354492, "yes": 9, "no": 10, "missing": 10, "children": [ + { "nodeid": 9, "depth": 3, "leaf": 0.08036360889673233 }, + { "nodeid": 10, "depth": 3, "leaf": -0.252962589263916 } ] } ] }, + { "nodeid": 2, "depth": 1, "split": "xgboost_input_23", "split_condition": 734.5999755859375, "yes": 5, "no": 6, "missing": 6, "children": [ + { "nodeid": 5, "depth": 2, "split": "xgboost_input_27", "split_condition": 0.17649999260902405, "yes": 11, "no": 12, "missing": 12, "children": [ + { "nodeid": 11, "depth": 3, "split": "xgboost_input_1", "split_condition": 20.25, "yes": 17, "no": 18, "missing": 18, "children": [ + { "nodeid": 17, "depth": 4, "leaf": 0.2636157274246216 }, + { "nodeid": 18, "depth": 4, "leaf": 0.028473349288105965 } ] }, + { "nodeid": 12, "depth": 3, "leaf": -0.21361969411373138 } ] }, + { "nodeid": 6, "depth": 2, "split": "xgboost_input_21", "split_condition": 19.899999618530273, "yes": 13, "no": 14, "missing": 14, "children": [ + { "nodeid": 13, "depth": 3, "split": "xgboost_input_13", "split_condition": 43.400001525878906, "yes": 19, "no": 20, "missing": 20, "children": [ + { "nodeid": 19, "depth": 4, "leaf": 0.22751691937446594 }, + { "nodeid": 20, "depth": 4, "leaf": -0.22558310627937317 } ] }, + { "nodeid": 14, "depth": 3, "leaf": -0.32582372426986694 } ] } ] } ] }, + { "nodeid": 0, "depth": 0, "split": "xgboost_input_7", "split_condition": 0.04938000068068504, "yes": 1, "no": 2, 
"missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": "xgboost_input_3", "split_condition": 698.7999877929688, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "depth": 2, "split": "xgboost_input_13", "split_condition": 41.5099983215332, "yes": 7, "no": 8, "missing": 8, "children": [ + { "nodeid": 7, "depth": 3, "split": "xgboost_input_21", "split_condition": 33.209999084472656, "yes": 15, "no": 16, "missing": 16, "children": [ + { "nodeid": 15, "depth": 4, "split": "xgboost_input_19", "split_condition": 0.0013810000382363796, "yes": 19, "no": 20, "missing": 20, "children": [ + { "nodeid": 19, "depth": 5, "leaf": 0.03624707832932472 }, + { "nodeid": 20, "depth": 5, "leaf": 0.3080938160419464 } ] }, + { "nodeid": 16, "depth": 4, "split": "xgboost_input_28", "split_condition": 0.24799999594688416, "yes": 21, "no": 22, "missing": 22, "children": [ + { "nodeid": 21, "depth": 5, "leaf": 0.17582161724567413 }, + { "nodeid": 22, "depth": 5, "leaf": -0.09668976068496704 } ] } ] }, + { "nodeid": 8, "depth": 3, "leaf": -0.03945023939013481 } ] }, + { "nodeid": 4, "depth": 2, "split": "xgboost_input_1", "split_condition": 19.510000228881836, "yes": 9, "no": 10, "missing": 10, "children": [ + { "nodeid": 9, "depth": 3, "leaf": 0.045127417892217636 }, + { "nodeid": 10, "depth": 3, "leaf": -0.2194928079843521 } ] } ] }, + { "nodeid": 2, "depth": 1, "split": "xgboost_input_21", "split_condition": 23.75, "yes": 5, "no": 6, "missing": 6, "children": [ + { "nodeid": 5, "depth": 2, "split": "xgboost_input_23", "split_condition": 809.7000122070312, "yes": 11, "no": 12, "missing": 12, "children": [ + { "nodeid": 11, "depth": 3, "leaf": 0.24866445362567902 }, + { "nodeid": 12, "depth": 3, "split": "xgboost_input_15", "split_condition": 0.020749999210238457, "yes": 17, "no": 18, "missing": 18, "children": [ + { "nodeid": 17, "depth": 4, "leaf": -0.005317678675055504 }, + { "nodeid": 18, "depth": 4, "leaf": -0.24861639738082886 } ] } ] }, + { "nodeid": 6, 
"depth": 2, "split": "xgboost_input_23", "split_condition": 680.5999755859375, "yes": 13, "no": 14, "missing": 14, "children": [ + { "nodeid": 13, "depth": 3, "leaf": -0.06991633772850037 }, + { "nodeid": 14, "depth": 3, "leaf": -0.31441980600357056 } ] } ] } ] }, + { "nodeid": 0, "depth": 0, "split": "xgboost_input_27", "split_condition": 0.1111999973654747, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": "xgboost_input_3", "split_condition": 698.7999877929688, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "depth": 2, "split": "xgboost_input_28", "split_condition": 0.1987999975681305, "yes": 7, "no": 8, "missing": 8, "children": [ + { "nodeid": 7, "depth": 3, "leaf": 0.03794674202799797 }, + { "nodeid": 8, "depth": 3, "split": "xgboost_input_21", "split_condition": 33.209999084472656, "yes": 13, "no": 14, "missing": 14, "children": [ + { "nodeid": 13, "depth": 4, "leaf": 0.29745399951934814 }, + { "nodeid": 14, "depth": 4, "leaf": 0.09482062608003616 } ] } ] }, + { "nodeid": 4, "depth": 2, "split": "xgboost_input_1", "split_condition": 20.219999313354492, "yes": 9, "no": 10, "missing": 10, "children": [ + { "nodeid": 9, "depth": 3, "leaf": 0.06832267343997955 }, + { "nodeid": 10, "depth": 3, "leaf": -0.20128701627254486 } ] } ] }, + { "nodeid": 2, "depth": 1, "split": "xgboost_input_22", "split_condition": 116.19999694824219, "yes": 5, "no": 6, "missing": 6, "children": [ + { "nodeid": 5, "depth": 2, "split": "xgboost_input_21", "split_condition": 27.489999771118164, "yes": 11, "no": 12, "missing": 12, "children": [ + { "nodeid": 11, "depth": 3, "split": "xgboost_input_27", "split_condition": 0.1606999933719635, "yes": 15, "no": 16, "missing": 16, "children": [ + { "nodeid": 15, "depth": 4, "leaf": 0.21199870109558105 }, + { "nodeid": 16, "depth": 4, "leaf": -0.09946209192276001 } ] }, + { "nodeid": 12, "depth": 3, "split": "xgboost_input_23", "split_condition": 699.4000244140625, "yes": 17, "no": 18, 
"missing": 18, "children": [ + { "nodeid": 17, "depth": 4, "leaf": -0.02004398964345455 }, + { "nodeid": 18, "depth": 4, "leaf": -0.25623419880867004 } ] } ] }, + { "nodeid": 6, "depth": 2, "leaf": -0.30207687616348267 } ] } ] }, + { "nodeid": 0, "depth": 0, "split": "xgboost_input_27", "split_condition": 0.14239999651908875, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": "xgboost_input_23", "split_condition": 967.0, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "depth": 2, "split": "xgboost_input_13", "split_condition": 35.2400016784668, "yes": 7, "no": 8, "missing": 8, "children": [ + { "nodeid": 7, "depth": 3, "split": "xgboost_input_21", "split_condition": 30.149999618530273, "yes": 13, "no": 14, "missing": 14, "children": [ + { "nodeid": 13, "depth": 4, "leaf": 0.2801840007305145 }, + { "nodeid": 14, "depth": 4, "split": "xgboost_input_1", "split_condition": 23.5, "yes": 17, "no": 18, "missing": 18, "children": [ + { "nodeid": 17, "depth": 5, "leaf": -0.1386641263961792 }, + { "nodeid": 18, "depth": 5, "leaf": 0.20163708925247192 } ] } ] }, + { "nodeid": 8, "depth": 3, "split": "xgboost_input_19", "split_condition": 0.002767999889329076, "yes": 15, "no": 16, "missing": 16, "children": [ + { "nodeid": 15, "depth": 4, "leaf": -0.17439478635787964 }, + { "nodeid": 16, "depth": 4, "leaf": 0.12734214961528778 } ] } ] }, + { "nodeid": 4, "depth": 2, "split": "xgboost_input_28", "split_condition": 0.2533000111579895, "yes": 9, "no": 10, "missing": 10, "children": [ + { "nodeid": 9, "depth": 3, "leaf": -0.007179913576692343 }, + { "nodeid": 10, "depth": 3, "leaf": -0.20481328666210175 } ] } ] }, + { "nodeid": 2, "depth": 1, "split": "xgboost_input_13", "split_condition": 21.459999084472656, "yes": 5, "no": 6, "missing": 6, "children": [ + { "nodeid": 5, "depth": 2, "leaf": -4.309949144953862e-05 }, + { "nodeid": 6, "depth": 2, "split": "xgboost_input_4", "split_condition": 0.08998999744653702, "yes": 11, "no": 
12, "missing": 12, "children": [ + { "nodeid": 11, "depth": 3, "leaf": -0.06140953674912453 }, + { "nodeid": 12, "depth": 3, "leaf": -0.28825199604034424 } ] } ] } ] }, + { "nodeid": 0, "depth": 0, "split": "xgboost_input_7", "split_condition": 0.04938000068068504, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": "xgboost_input_20", "split_condition": 16.81999969482422, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "depth": 2, "split": "xgboost_input_15", "split_condition": 0.012029999867081642, "yes": 7, "no": 8, "missing": 8, "children": [ + { "nodeid": 7, "depth": 3, "split": "xgboost_input_16", "split_condition": 0.012719999998807907, "yes": 15, "no": 16, "missing": 16, "children": [ + { "nodeid": 15, "depth": 4, "leaf": 0.22184161841869354 }, + { "nodeid": 16, "depth": 4, "leaf": -0.15230606496334076 } ] }, + { "nodeid": 8, "depth": 3, "leaf": 0.27187174558639526 } ] }, + { "nodeid": 4, "depth": 2, "split": "xgboost_input_28", "split_condition": 0.2653999924659729, "yes": 9, "no": 10, "missing": 10, "children": [ + { "nodeid": 9, "depth": 3, "leaf": 0.03678994998335838 }, + { "nodeid": 10, "depth": 3, "leaf": -0.13423432409763336 } ] } ] }, + { "nodeid": 2, "depth": 1, "split": "xgboost_input_21", "split_condition": 23.75, "yes": 5, "no": 6, "missing": 6, "children": [ + { "nodeid": 5, "depth": 2, "split": "xgboost_input_23", "split_condition": 809.7000122070312, "yes": 11, "no": 12, "missing": 12, "children": [ + { "nodeid": 11, "depth": 3, "leaf": 0.20270322263240814 }, + { "nodeid": 12, "depth": 3, "leaf": -0.15306414663791656 } ] }, + { "nodeid": 6, "depth": 2, "split": "xgboost_input_6", "split_condition": 0.09060999751091003, "yes": 13, "no": 14, "missing": 14, "children": [ + { "nodeid": 13, "depth": 3, "leaf": -0.05368896201252937 }, + { "nodeid": 14, "depth": 3, "leaf": -0.2783971130847931 } ] } ] } ] }, + { "nodeid": 0, "depth": 0, "split": "xgboost_input_26", "split_condition": 
0.2079000025987625, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": "xgboost_input_13", "split_condition": 40.5099983215332, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "depth": 2, "leaf": 0.27354902029037476 }, + { "nodeid": 4, "depth": 2, "leaf": -0.05269660800695419 } ] }, + { "nodeid": 2, "depth": 1, "split": "xgboost_input_23", "split_condition": 648.2999877929688, "yes": 5, "no": 6, "missing": 6, "children": [ + { "nodeid": 5, "depth": 2, "split": "xgboost_input_7", "split_condition": 0.055959999561309814, "yes": 7, "no": 8, "missing": 8, "children": [ + { "nodeid": 7, "depth": 3, "leaf": 0.19431108236312866 }, + { "nodeid": 8, "depth": 3, "leaf": -0.042131464928388596 } ] }, + { "nodeid": 6, "depth": 2, "split": "xgboost_input_21", "split_condition": 19.899999618530273, "yes": 9, "no": 10, "missing": 10, "children": [ + { "nodeid": 9, "depth": 3, "leaf": 0.04776393994688988 }, + { "nodeid": 10, "depth": 3, "split": "xgboost_input_24", "split_condition": 0.10920000076293945, "yes": 11, "no": 12, "missing": 12, "children": [ + { "nodeid": 11, "depth": 4, "leaf": 0.011151635088026524 }, + { "nodeid": 12, "depth": 4, "leaf": -0.26751622557640076 } ] } ] } ] } ] }, + { "nodeid": 0, "depth": 0, "split": "xgboost_input_23", "split_condition": 967.0, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": "xgboost_input_21", "split_condition": 29.25, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "depth": 2, "split": "xgboost_input_12", "split_condition": 3.430000066757202, "yes": 7, "no": 8, "missing": 8, "children": [ + { "nodeid": 7, "depth": 3, "split": "xgboost_input_29", "split_condition": 0.10189999639987946, "yes": 11, "no": 12, "missing": 12, "children": [ + { "nodeid": 11, "depth": 4, "leaf": 0.25556275248527527 }, + { "nodeid": 12, "depth": 4, "leaf": 0.018566781654953957 } ] }, + { "nodeid": 8, "depth": 3, "leaf": -0.01612720638513565 } ] }, + { 
"nodeid": 4, "depth": 2, "split": "xgboost_input_27", "split_condition": 0.09139999747276306, "yes": 9, "no": 10, "missing": 10, "children": [ + { "nodeid": 9, "depth": 3, "leaf": 0.14816634356975555 }, + { "nodeid": 10, "depth": 3, "split": "xgboost_input_24", "split_condition": 0.13410000503063202, "yes": 13, "no": 14, "missing": 14, "children": [ + { "nodeid": 13, "depth": 4, "leaf": -0.0205707810819149 }, + { "nodeid": 14, "depth": 4, "leaf": -0.2519259452819824 } ] } ] } ] }, + { "nodeid": 2, "depth": 1, "split": "xgboost_input_25", "split_condition": 0.17110000550746918, "yes": 5, "no": 6, "missing": 6, "children": [ + { "nodeid": 5, "depth": 2, "leaf": -0.02155451662838459 }, + { "nodeid": 6, "depth": 2, "leaf": -0.2605815827846527 } ] } ] }, + { "nodeid": 0, "depth": 0, "split": "xgboost_input_22", "split_condition": 120.4000015258789, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": "xgboost_input_21", "split_condition": 29.25, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "depth": 2, "split": "xgboost_input_25", "split_condition": 0.32350000739097595, "yes": 5, "no": 6, "missing": 6, "children": [ + { "nodeid": 5, "depth": 3, "split": "xgboost_input_25", "split_condition": 0.08340000361204147, "yes": 9, "no": 10, "missing": 10, "children": [ + { "nodeid": 9, "depth": 4, "leaf": -0.03031761385500431 }, + { "nodeid": 10, "depth": 4, "leaf": 0.2458493560552597 } ] }, + { "nodeid": 6, "depth": 3, "split": "xgboost_input_23", "split_condition": 734.5999755859375, "yes": 11, "no": 12, "missing": 12, "children": [ + { "nodeid": 11, "depth": 4, "leaf": 0.10233080387115479 }, + { "nodeid": 12, "depth": 4, "leaf": -0.09648152440786362 } ] } ] }, + { "nodeid": 4, "depth": 2, "split": "xgboost_input_26", "split_condition": 0.20280000567436218, "yes": 7, "no": 8, "missing": 8, "children": [ + { "nodeid": 7, "depth": 3, "leaf": 0.13269340991973877 }, + { "nodeid": 8, "depth": 3, "split": "xgboost_input_15", 
"split_condition": 0.017960000783205032, "yes": 13, "no": 14, "missing": 14, "children": [ + { "nodeid": 13, "depth": 4, "leaf": -0.24554097652435303 }, + { "nodeid": 14, "depth": 4, "leaf": -0.033455345779657364 } ] } ] } ] }, + { "nodeid": 2, "depth": 1, "leaf": -0.23360854387283325 } ] }, + { "nodeid": 0, "depth": 0, "split": "xgboost_input_23", "split_condition": 876.5, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": "xgboost_input_24", "split_condition": 0.14069999754428864, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "depth": 2, "split": "xgboost_input_15", "split_condition": 0.01104000024497509, "yes": 7, "no": 8, "missing": 8, "children": [ + { "nodeid": 7, "depth": 3, "leaf": -0.0031496614683419466 }, + { "nodeid": 8, "depth": 3, "split": "xgboost_input_6", "split_condition": 0.1111999973654747, "yes": 11, "no": 12, "missing": 12, "children": [ + { "nodeid": 11, "depth": 4, "leaf": 0.23949843645095825 }, + { "nodeid": 12, "depth": 4, "leaf": 0.03775443509221077 } ] } ] }, + { "nodeid": 4, "depth": 2, "split": "xgboost_input_22", "split_condition": 91.62000274658203, "yes": 9, "no": 10, "missing": 10, "children": [ + { "nodeid": 9, "depth": 3, "leaf": 0.10552486777305603 }, + { "nodeid": 10, "depth": 3, "split": "xgboost_input_21", "split_condition": 27.209999084472656, "yes": 13, "no": 14, "missing": 14, "children": [ + { "nodeid": 13, "depth": 4, "leaf": -0.023241691291332245 }, + { "nodeid": 14, "depth": 4, "leaf": -0.20789241790771484 } ] } ] } ] }, + { "nodeid": 2, "depth": 1, "split": "xgboost_input_6", "split_condition": 0.05928000062704086, "yes": 5, "no": 6, "missing": 6, "children": [ + { "nodeid": 5, "depth": 2, "leaf": 0.004834835417568684 }, + { "nodeid": 6, "depth": 2, "leaf": -0.23364728689193726 } ] } ] }, + { "nodeid": 0, "depth": 0, "split": "xgboost_input_26", "split_condition": 0.2079000025987625, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": 
"xgboost_input_13", "split_condition": 40.5099983215332, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "depth": 2, "leaf": 0.23598936200141907 }, + { "nodeid": 4, "depth": 2, "leaf": -0.060127921402454376 } ] }, + { "nodeid": 2, "depth": 1, "split": "xgboost_input_21", "split_condition": 25.579999923706055, "yes": 5, "no": 6, "missing": 6, "children": [ + { "nodeid": 5, "depth": 2, "split": "xgboost_input_23", "split_condition": 811.2999877929688, "yes": 7, "no": 8, "missing": 8, "children": [ + { "nodeid": 7, "depth": 3, "leaf": 0.173164963722229 }, + { "nodeid": 8, "depth": 3, "leaf": -0.09621238708496094 } ] }, + { "nodeid": 6, "depth": 2, "split": "xgboost_input_4", "split_condition": 0.08946000039577484, "yes": 9, "no": 10, "missing": 10, "children": [ + { "nodeid": 9, "depth": 3, "leaf": -0.03863441199064255 }, + { "nodeid": 10, "depth": 3, "leaf": -0.21613681316375732 } ] } ] } ] }, + { "nodeid": 0, "depth": 0, "split": "xgboost_input_13", "split_condition": 33.0099983215332, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": "xgboost_input_24", "split_condition": 0.13770000636577606, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "depth": 2, "split": "xgboost_input_14", "split_condition": 0.004147999919950962, "yes": 7, "no": 8, "missing": 8, "children": [ + { "nodeid": 7, "depth": 3, "leaf": -0.005957402754575014 }, + { "nodeid": 8, "depth": 3, "leaf": 0.22239074110984802 } ] }, + { "nodeid": 4, "depth": 2, "split": "xgboost_input_22", "split_condition": 91.62000274658203, "yes": 9, "no": 10, "missing": 10, "children": [ + { "nodeid": 9, "depth": 3, "leaf": 0.1025158017873764 }, + { "nodeid": 10, "depth": 3, "leaf": -0.12606994807720184 } ] } ] }, + { "nodeid": 2, "depth": 1, "split": "xgboost_input_24", "split_condition": 0.11180000007152557, "yes": 5, "no": 6, "missing": 6, "children": [ + { "nodeid": 5, "depth": 2, "leaf": 0.04408538341522217 }, + { "nodeid": 6, "depth": 2, "split": 
"xgboost_input_21", "split_condition": 23.190000534057617, "yes": 11, "no": 12, "missing": 12, "children": [ + { "nodeid": 11, "depth": 3, "leaf": -0.03383628651499748 }, + { "nodeid": 12, "depth": 3, "leaf": -0.21980330348014832 } ] } ] } ] }, + { "nodeid": 0, "depth": 0, "split": "xgboost_input_26", "split_condition": 0.2079000025987625, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": "xgboost_input_13", "split_condition": 40.5099983215332, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "depth": 2, "leaf": 0.21418677270412445 }, + { "nodeid": 4, "depth": 2, "leaf": -0.03940589725971222 } ] }, + { "nodeid": 2, "depth": 1, "split": "xgboost_input_23", "split_condition": 967.0, "yes": 5, "no": 6, "missing": 6, "children": [ + { "nodeid": 5, "depth": 2, "split": "xgboost_input_21", "split_condition": 29.25, "yes": 7, "no": 8, "missing": 8, "children": [ + { "nodeid": 7, "depth": 3, "split": "xgboost_input_13", "split_condition": 23.309999465942383, "yes": 9, "no": 10, "missing": 10, "children": [ + { "nodeid": 9, "depth": 4, "leaf": 0.17243850231170654 }, + { "nodeid": 10, "depth": 4, "leaf": -0.014778186567127705 } ] }, + { "nodeid": 8, "depth": 3, "leaf": -0.1282825917005539 } ] }, + { "nodeid": 6, "depth": 2, "leaf": -0.20745433866977692 } ] } ] }, + { "nodeid": 0, "depth": 0, "split": "xgboost_input_27", "split_condition": 0.1606999933719635, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": "xgboost_input_1", "split_condition": 20.200000762939453, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "depth": 2, "split": "xgboost_input_13", "split_condition": 35.029998779296875, "yes": 5, "no": 6, "missing": 6, "children": [ + { "nodeid": 5, "depth": 3, "split": "xgboost_input_20", "split_condition": 16.25, "yes": 9, "no": 10, "missing": 10, "children": [ + { "nodeid": 9, "depth": 4, "leaf": 0.20469191670417786 }, + { "nodeid": 10, "depth": 4, "leaf": 
0.05293010547757149 } ] }, + { "nodeid": 6, "depth": 3, "leaf": 0.0055474769324064255 } ] }, + { "nodeid": 4, "depth": 2, "split": "xgboost_input_23", "split_condition": 653.5999755859375, "yes": 7, "no": 8, "missing": 8, "children": [ + { "nodeid": 7, "depth": 3, "leaf": 0.1103351041674614 }, + { "nodeid": 8, "depth": 3, "split": "xgboost_input_5", "split_condition": 0.07326000183820724, "yes": 11, "no": 12, "missing": 12, "children": [ + { "nodeid": 11, "depth": 4, "leaf": -0.19937729835510254 }, + { "nodeid": 12, "depth": 4, "leaf": -0.011575295589864254 } ] } ] } ] }, + { "nodeid": 2, "depth": 1, "leaf": -0.16972780227661133 } ] }, + { "nodeid": 0, "depth": 0, "split": "xgboost_input_7", "split_condition": 0.04938000068068504, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": "xgboost_input_15", "split_condition": 0.012029999867081642, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "depth": 2, "split": "xgboost_input_17", "split_condition": 0.0074970000423491, "yes": 7, "no": 8, "missing": 8, "children": [ + { "nodeid": 7, "depth": 3, "leaf": 0.0373166985809803 }, + { "nodeid": 8, "depth": 3, "leaf": -0.14507374167442322 } ] }, + { "nodeid": 4, "depth": 2, "leaf": 0.18085254728794098 } ] }, + { "nodeid": 2, "depth": 1, "split": "xgboost_input_21", "split_condition": 23.75, "yes": 5, "no": 6, "missing": 6, "children": [ + { "nodeid": 5, "depth": 2, "leaf": 0.021509487181901932 }, + { "nodeid": 6, "depth": 2, "leaf": -0.1693851202726364 } ] } ] }, + { "nodeid": 0, "depth": 0, "split": "xgboost_input_26", "split_condition": 0.2079000025987625, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": "xgboost_input_13", "split_condition": 40.5099983215332, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "depth": 2, "leaf": 0.18477416038513184 }, + { "nodeid": 4, "depth": 2, "leaf": -0.027129333466291428 } ] }, + { "nodeid": 2, "depth": 1, "split": "xgboost_input_23", 
"split_condition": 734.5999755859375, "yes": 5, "no": 6, "missing": 6, "children": [ + { "nodeid": 5, "depth": 2, "leaf": 0.05231140926480293 }, + { "nodeid": 6, "depth": 2, "split": "xgboost_input_21", "split_condition": 22.149999618530273, "yes": 7, "no": 8, "missing": 8, "children": [ + { "nodeid": 7, "depth": 3, "leaf": -0.007632177323102951 }, + { "nodeid": 8, "depth": 3, "leaf": -0.16687200963497162 } ] } ] } ] } ] diff --git a/model-integration/src/test/models/xgboost/binary_breast_cancer.ubj b/model-integration/src/test/models/xgboost/binary_breast_cancer.ubj new file mode 100644 index 000000000000..e66e270b5208 Binary files /dev/null and b/model-integration/src/test/models/xgboost/binary_breast_cancer.ubj differ diff --git a/model-integration/src/test/models/xgboost/model_parser.py b/model-integration/src/test/models/xgboost/model_parser.py new file mode 100644 index 000000000000..da7675ab81e1 --- /dev/null +++ b/model-integration/src/test/models/xgboost/model_parser.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python3 + +import argparse +import json +from dataclasses import dataclass +from enum import IntEnum, unique +from typing import Any, Dict, List, Sequence, Union + +import numpy as np + +try: + import ubjson +except ImportError: + ubjson = None + + +ParamT = Dict[str, str] + + +def to_integers(data: Union[bytes, List[int]]) -> List[int]: + """Convert a sequence of bytes to a list of Python integer""" + return [v for v in data] + + +@unique +class SplitType(IntEnum): + numerical = 0 + categorical = 1 + + +@dataclass +class Node: + # properties + left: int + right: int + parent: int + split_idx: int + split_cond: float + default_left: bool + split_type: SplitType + categories: List[int] + # statistic + base_weight: float + loss_chg: float + sum_hess: float + + +class Tree: + """A tree built by XGBoost.""" + + def __init__(self, tree_id: int, nodes: Sequence[Node]) -> None: + self.tree_id = tree_id + self.nodes = nodes + + def loss_change(self, node_id: int) 
-> float: + """Loss gain of a node.""" + return self.nodes[node_id].loss_chg + + def sum_hessian(self, node_id: int) -> float: + """Sum Hessian of a node.""" + return self.nodes[node_id].sum_hess + + def base_weight(self, node_id: int) -> float: + """Base weight of a node.""" + return self.nodes[node_id].base_weight + + def split_index(self, node_id: int) -> int: + """Split feature index of node.""" + return self.nodes[node_id].split_idx + + def split_condition(self, node_id: int) -> float: + """Split value of a node.""" + return self.nodes[node_id].split_cond + + def split_categories(self, node_id: int) -> List[int]: + """Categories in a node.""" + return self.nodes[node_id].categories + + def is_categorical(self, node_id: int) -> bool: + """Whether a node has categorical split.""" + return self.nodes[node_id].split_type == SplitType.categorical + + def is_numerical(self, node_id: int) -> bool: + return not self.is_categorical(node_id) + + def parent(self, node_id: int) -> int: + """Parent ID of a node.""" + return self.nodes[node_id].parent + + def left_child(self, node_id: int) -> int: + """Left child ID of a node.""" + return self.nodes[node_id].left + + def right_child(self, node_id: int) -> int: + """Right child ID of a node.""" + return self.nodes[node_id].right + + def is_leaf(self, node_id: int) -> bool: + """Whether a node is leaf.""" + return self.nodes[node_id].left == -1 + + def is_deleted(self, node_id: int) -> bool: + """Whether a node is deleted.""" + return self.split_index(node_id) == np.iinfo(np.uint32).max + + def __str__(self) -> str: + stack = [0] + nodes = [] + while stack: + node: Dict[str, Union[float, int, List[int]]] = {} + nid = stack.pop() + + node["node id"] = nid + node["gain"] = self.loss_change(nid) + node["cover"] = self.sum_hessian(nid) + nodes.append(node) + + if not self.is_leaf(nid) and not self.is_deleted(nid): + left = self.left_child(nid) + right = self.right_child(nid) + stack.append(left) + stack.append(right) + categories 
= self.split_categories(nid) + if categories: + assert self.is_categorical(nid) + node["categories"] = categories + else: + assert self.is_numerical(nid) + node["condition"] = self.split_condition(nid) + if self.is_leaf(nid): + node["weight"] = self.split_condition(nid) + + string = "\n".join(map(lambda x: " " + str(x), nodes)) + return string + + +class Model: + """Gradient boosted tree model.""" + + def __init__(self, model: dict) -> None: + """Construct the Model from a JSON object. + + parameters + ---------- + model : A dictionary loaded by json representing a XGBoost boosted tree model. + """ + # Basic properties of a model + self.learner_model_shape: ParamT = model["learner"]["learner_model_param"] + self.num_output_group = int(self.learner_model_shape["num_class"]) + self.num_feature = int(self.learner_model_shape["num_feature"]) + self.base_score: List[float] = json.loads( + self.learner_model_shape["base_score"] + ) + # A field encoding which output group a tree belongs + self.tree_info = model["learner"]["gradient_booster"]["model"]["tree_info"] + + model_shape: ParamT = model["learner"]["gradient_booster"]["model"][ + "gbtree_model_param" + ] + + # JSON representation of trees + j_trees = model["learner"]["gradient_booster"]["model"]["trees"] + + # Load the trees + self.num_trees = int(model_shape["num_trees"]) + + trees: List[Tree] = [] + for i in range(self.num_trees): + tree: Dict[str, Any] = j_trees[i] + tree_id = int(tree["id"]) + assert tree_id == i, (tree_id, i) + # - properties + left_children: List[int] = tree["left_children"] + right_children: List[int] = tree["right_children"] + parents: List[int] = tree["parents"] + split_conditions: List[float] = tree["split_conditions"] + split_indices: List[int] = tree["split_indices"] + # when ubjson is used, this is a byte array with each element as uint8 + default_left = to_integers(tree["default_left"]) + + # - categorical features + # when ubjson is used, this is a byte array with each element as 
uint8 + split_types = to_integers(tree["split_type"]) + # categories for each node is stored in a CSR style storage with segment as + # the begin ptr and the `categories' as values. + cat_segments: List[int] = tree["categories_segments"] + cat_sizes: List[int] = tree["categories_sizes"] + # node index for categorical nodes + cat_nodes: List[int] = tree["categories_nodes"] + assert len(cat_segments) == len(cat_sizes) == len(cat_nodes) + cats = tree["categories"] + assert len(left_children) == len(split_types) + + # The storage for categories is only defined for categorical nodes to + # prevent unnecessary overhead for numerical splits, we track the + # categorical node that are processed using a counter. + cat_cnt = 0 + if cat_nodes: + last_cat_node = cat_nodes[cat_cnt] + else: + last_cat_node = -1 + node_categories: List[List[int]] = [] + for node_id in range(len(left_children)): + if node_id == last_cat_node: + beg = cat_segments[cat_cnt] + size = cat_sizes[cat_cnt] + end = beg + size + node_cats = cats[beg:end] + # categories are unique for each node + assert len(set(node_cats)) == len(node_cats) + cat_cnt += 1 + if cat_cnt == len(cat_nodes): + last_cat_node = -1 # continue to process the rest of the nodes + else: + last_cat_node = cat_nodes[cat_cnt] + assert node_cats + node_categories.append(node_cats) + else: + # append an empty node, it's either a numerical node or a leaf. 
+ node_categories.append([]) + + # - stats + base_weights: List[float] = tree["base_weights"] + loss_changes: List[float] = tree["loss_changes"] + sum_hessian: List[float] = tree["sum_hessian"] + + # Construct a list of nodes that have complete information + nodes: List[Node] = [ + Node( + left_children[node_id], + right_children[node_id], + parents[node_id], + split_indices[node_id], + split_conditions[node_id], + default_left[node_id] == 1, # to boolean + SplitType(split_types[node_id]), + node_categories[node_id], + base_weights[node_id], + loss_changes[node_id], + sum_hessian[node_id], + ) + for node_id in range(len(left_children)) + ] + + pytree = Tree(tree_id, nodes) + trees.append(pytree) + + self.trees = trees + + def print_model(self) -> None: + for i, tree in enumerate(self.trees): + print("\ntree_id:", i) + print(tree) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Demonstration for loading XGBoost JSON/UBJSON model." + ) + parser.add_argument( + "--model", type=str, required=True, help="Path to .json/.ubj model file." + ) + args = parser.parse_args() + if args.model.endswith("json"): + # use json format + with open(args.model, "r") as fd: + model = json.load(fd) + elif args.model.endswith("ubj"): + if ubjson is None: + raise ImportError("ubjson is not installed.") + # use ubjson format + with open(args.model, "rb") as bfd: + model = ubjson.load(bfd) + else: + raise ValueError( + "Unexpected file extension. Supported file extension are json and ubj." 
+ ) + model = Model(model) + model.print_model() diff --git a/model-integration/src/test/models/xgboost/ubj-to-json.sh b/model-integration/src/test/models/xgboost/ubj-to-json.sh new file mode 100755 index 000000000000..12571b88d335 --- /dev/null +++ b/model-integration/src/test/models/xgboost/ubj-to-json.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Convert XGBoost UBJ file to JSON format +# Usage: ./ubj-to-json.sh [output.json] + +if [ $# -lt 1 ]; then + echo "Usage: $0 [output.json]" + exit 1 +fi + +INPUT="$1" +OUTPUT="${2:-}" + +if [ -n "$OUTPUT" ]; then + mvn exec:java -Dexec.mainClass="ai.vespa.rankingexpression.importer.xgboost.UbjToJson" \ + -Dexec.args="$INPUT" -Dexec.classpathScope=test -q > "$OUTPUT" + echo "Converted $INPUT to $OUTPUT" +else + mvn exec:java -Dexec.mainClass="ai.vespa.rankingexpression.importer.xgboost.UbjToJson" \ + -Dexec.args="$INPUT" -Dexec.classpathScope=test -q +fi diff --git a/parent/pom.xml b/parent/pom.xml index 2183736b0cba..0a7b87f83271 100644 --- a/parent/pom.xml +++ b/parent/pom.xml @@ -529,6 +529,11 @@ java-jwt ${java-jwt.vespa.version} + + com.dev-smart + ubjson + ${dev-smart-ubjson.vespa.version} + com.fasterxml.jackson.dataformat jackson-dataformat-cbor diff --git a/vespa-dependencies-enforcer/allowed-maven-dependencies.txt b/vespa-dependencies-enforcer/allowed-maven-dependencies.txt index 2a116b45f065..0a0df812e6a0 100644 --- a/vespa-dependencies-enforcer/allowed-maven-dependencies.txt +++ b/vespa-dependencies-enforcer/allowed-maven-dependencies.txt @@ -6,6 +6,7 @@ aopalliance:aopalliance:${aopalliance.vespa.version} backport-util-concurrent:backport-util-concurrent:3.1 classworlds:classworlds:1.1-alpha-2 com.auth0:java-jwt:${java-jwt.vespa.version} +com.dev-smart:ubjson:${dev-smart-ubjson.vespa.version} com.ethlo.time:itu:1.10.3 com.fasterxml.jackson.core:jackson-annotations:${jackson2.vespa.version} com.fasterxml.jackson.core:jackson-core:${jackson2.vespa.version}