Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions application/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@
<groupId>javax.annotation</groupId>
<artifactId>javax.annotation-api</artifactId>
</exclusion>
<exclusion>
<groupId>com.dev-smart</groupId>
<artifactId>ubjson</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
Expand Down
1 change: 1 addition & 0 deletions config-model-fat/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@

<!-- 3rd party artifacts embedded -->
<i>aopalliance:aopalliance:*:*</i>
<i>com.dev-smart:ubjson:*:*</i>
<i>com.google.errorprone:error_prone_annotations:*:*</i>
<i>com.google.guava:failureaccess:*:*</i>
<i>com.google.guava:guava:*:*</i>
Expand Down
4 changes: 4 additions & 0 deletions container-dev/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@
<groupId>org.lz4</groupId>
<artifactId>lz4-java</artifactId>
</exclusion>
<exclusion>
<groupId>com.dev-smart</groupId>
<artifactId>ubjson</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
Expand Down
1 change: 1 addition & 0 deletions dependency-versions/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

<!-- DO NOT UPGRADE THESE TO A NEW MAJOR VERSION WITHOUT CHECKING FOR BINARY COMPATIBILITY -->
<aopalliance.vespa.version>1.0</aopalliance.vespa.version>
<dev-smart-ubjson.vespa.version>0.1.8</dev-smart-ubjson.vespa.version>
<error-prone-annotations.vespa.version>2.30.0</error-prone-annotations.vespa.version>
<guava.vespa.version>33.2.1-jre</guava.vespa.version>
<guice.vespa.version>6.0.0</guice.vespa.version>
Expand Down
4 changes: 4 additions & 0 deletions model-integration/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,10 @@
<version>${testcontainers.vespa.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.dev-smart</groupId>
<artifactId>ubjson</artifactId>
</dependency>
</dependencies>
<build>
<plugins>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.xgboost;

/**
 * Common base for the XGBoost model parsers. It holds the logic, shared by the
 * concrete parsers, that renders a parsed tree as a Vespa ranking expression string.
 *
 * @author arnej
 */
abstract class AbstractXGBoostParser {

    /**
     * Recursively renders an XGBoost tree node as a Vespa ranking expression.
     * A leaf is rendered as its numeric value; a split node is rendered as
     * {@code if (condition, yesBranch, noBranch)}.
     *
     * @param node XGBoost tree node to render.
     * @return the Vespa ranking expression for {@code node} and everything below it.
     */
    protected String treeToRankExp(XGBoostTree node) {
        // Leaves terminate the recursion with a plain numeric literal.
        if (node.isLeaf()) {
            return Double.toString(node.getLeaf());
        }

        assert node.getChildren().size() == 2;

        // Find which of the two children is the "yes" child; the other one is the "no" child.
        int yesIndex = (node.getYes() == node.getChildren().get(0).getNodeid()) ? 0 : 1;
        String yesBranch = treeToRankExp(node.getChildren().get(yesIndex));
        String noBranch = treeToRankExp(node.getChildren().get(1 - yesIndex));

        // xgboost uses float only internally, so the split point is first rounded to the
        // closest float; Vespa expects rank profile literals in double precision, hence
        // the value is widened back to double before it is printed.
        double threshold = (float) node.getSplit_condition();
        String feature = formatSplit(node.getSplit());

        // When a missing value follows the "yes" branch the test is written in negated form:
        // this is for handling missing features, as the backend handles comparison
        // with NaN as false.
        String test = (node.getMissing() == node.getYes())
                ? "!(" + feature + " >= " + threshold + ")"
                : feature + " < " + threshold;

        return "if (" + test + ", " + yesBranch + ", " + noBranch + ")";
    }

    /**
     * Turns the split field of a tree node into the feature reference used in conditions.
     * A value that parses as an integer is prefixed with {@code xgboost_input_}; any other
     * value is returned unchanged (for backward compatibility with the JSON format, which
     * already carries the full attribute name).
     *
     * @param split the split field value from the tree node.
     * @return the feature expression to use on the left-hand side of a split condition.
     */
    protected String formatSplit(String split) {
        boolean parsesAsInteger;
        try {
            Integer.parseInt(split);
            parsesAsInteger = true;
        } catch (NumberFormatException notAnInteger) {
            // Not an integer: the value is used exactly as given.
            parsesAsInteger = false;
        }
        return parsesAsInteger ? "xgboost_input_" + split : split;
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ public boolean canImport(String modelPath) {
File modelFile = new File(modelPath);
if ( ! modelFile.isFile()) return false;

if (modelFile.toString().endsWith(".ubj")) {
return XGBoostUbjParser.probe(modelPath);
}
return modelFile.toString().endsWith(".json") && probe(modelFile);
}

Expand Down Expand Up @@ -52,9 +55,15 @@ private boolean probe(File modelFile) {
public ImportedModel importModel(String modelName, String modelPath) {
try {
ImportedModel model = new ImportedModel(modelName, modelPath, ImportedMlModel.ModelType.XGBOOST);
XGBoostParser parser = new XGBoostParser(modelPath);
RankingExpression expression = new RankingExpression(parser.toRankingExpression());
model.expression(modelName, expression);
if (modelPath.endsWith(".ubj")) {
XGBoostUbjParser parser = new XGBoostUbjParser(modelPath);
RankingExpression expression = new RankingExpression(parser.toRankingExpression());
model.expression(modelName, expression);
} else {
XGBoostParser parser = new XGBoostParser(modelPath);
RankingExpression expression = new RankingExpression(parser.toRankingExpression());
model.expression(modelName, expression);
}
return model;
} catch (IOException e) {
throw new IllegalArgumentException("Could not import XGBoost model from '" + modelPath + "'", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
/**
* @author grace-lam
*/
class XGBoostParser {
class XGBoostParser extends AbstractXGBoostParser {

private final List<XGBoostTree> xgboostTrees;

Expand Down Expand Up @@ -49,39 +49,4 @@ String toRankingExpression() {
return ret.toString();
}

/**
* Recursive helper function for toRankingExpression().
*
* @param node XGBoost tree node to convert.
* @return Vespa ranking expression for input node.
*/
private String treeToRankExp(XGBoostTree node) {
if (node.isLeaf()) {
return Double.toString(node.getLeaf());
} else {
assert node.getChildren().size() == 2;
String trueExp;
String falseExp;
if (node.getYes() == node.getChildren().get(0).getNodeid()) {
trueExp = treeToRankExp(node.getChildren().get(0));
falseExp = treeToRankExp(node.getChildren().get(1));
} else {
trueExp = treeToRankExp(node.getChildren().get(1));
falseExp = treeToRankExp(node.getChildren().get(0));
}
// xgboost uses float only internally, so round to closest float
float xgbSplitPoint = (float)node.getSplit_condition();
// but Vespa expects rank profile literals in double precision:
double vespaSplitPoint = xgbSplitPoint;
String condition;
if (node.getMissing() == node.getYes()) {
// Note: this is for handling missing features, as the backend handles comparison with NaN as false.
condition = "!(" + node.getSplit() + " >= " + vespaSplitPoint + ")";
} else {
condition = node.getSplit() + " < " + vespaSplitPoint;
}
return "if (" + condition + ", " + trueExp + ", " + falseExp + ")";
}
}

}
Loading