Skip to content

Commit 6839eb6

Browse files
zachgkfrankfliu
andcommitted
Updates XGBoost to 2.0.1 (#2833)
* Updates XGBoost to 2.0.1 * Use devtools 8 * Updates based on new Xgboost JNI API. --------- Co-authored-by: Frank Liu <[email protected]>
1 parent 4560476 commit 6839eb6

File tree

5 files changed

+21
-7
lines changed

5 files changed

+21
-7
lines changed

.github/workflows/native_s3_xgboost.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ jobs:
3434
run: |
3535
yum -y update
3636
yum -y install centos-release-scl-rh epel-release
37-
yum -y install devtoolset-7 git patch libstdc++-static curl python3-devel
37+
yum -y install devtoolset-8 git patch libstdc++-static curl python3-devel
3838
curl -L -o cmake.tar.gz https://github.com/Kitware/CMake/releases/download/v3.27.0-rc2/cmake-3.27.0-rc2-linux-aarch64.tar.gz
3939
tar xvfz cmake.tar.gz
4040
ln -sf $PWD/cmake-3.*/bin/cmake /usr/bin/cmake
@@ -50,7 +50,7 @@ jobs:
5050
XGBOOST_VERSION=${{ github.event.inputs.xgb_version }}
5151
XGBOOST_VERSION=${XGBOOST_VERSION:-$(cat gradle.properties | awk -F '=' '/xgboost_version/ {print $2}')}
5252
git clone https://github.com/dmlc/xgboost --recursive -b v"$XGBOOST_VERSION"
53-
export PATH=$PATH:/opt/rh/devtoolset-7/root/usr/bin
53+
export PATH=$PATH:/opt/rh/devtoolset-8/root/usr/bin
5454
cd xgboost/jvm-packages
5555
python3 create_jni.py
5656
cd ../..

engines/ml/xgboost/src/main/java/ai/djl/ml/xgboost/XgbNDManager.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ public class XgbNDManager extends BaseNDManager {
3939
private static final XgbNDManager SYSTEM_MANAGER = new SystemManager();
4040

4141
private float missingValue = Float.NaN;
42+
private int nthread = 1;
4243

4344
private XgbNDManager(NDManager parent, Device device) {
4445
super(parent, device);
@@ -57,6 +58,15 @@ public void setMissingValue(float missingValue) {
5758
this.missingValue = missingValue;
5859
}
5960

61+
/**
62+
* Sets the default number of threads.
63+
*
64+
* @param nthread the default number of threads
65+
*/
66+
public void setNthread(int nthread) {
67+
this.nthread = nthread;
68+
}
69+
6070
/** {@inheritDoc} */
6171
@Override
6272
public ByteBuffer allocateDirect(int capacity) {
@@ -166,7 +176,7 @@ public NDArray createCSR(Buffer buffer, long[] indptr, long[] indices, Shape sha
166176
int[] intIndices = Arrays.stream(indices).mapToInt(Math::toIntExact).toArray();
167177
float[] data = new float[buffer.remaining()];
168178
((FloatBuffer) buffer).get(data);
169-
long handle = JniUtils.createDMatrixCSR(indptr, intIndices, data);
179+
long handle = JniUtils.createDMatrixCSR(indptr, intIndices, data, missingValue, nthread);
170180
return new XgbNDArray(this, alternativeManager, handle, shape, SparseFormat.CSR);
171181
}
172182

engines/ml/xgboost/src/main/java/ml/dmlc/xgboost4j/java/JniUtils.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,12 @@ public static long createDMatrix(ColumnBatch columnBatch, float missing, int nth
6767
return handles[0];
6868
}
6969

70-
public static long createDMatrixCSR(long[] indptr, int[] indices, float[] array) {
70+
public static long createDMatrixCSR(
71+
long[] indptr, int[] indices, float[] array, float missing, int nthread) {
7172
long[] handles = new long[1];
72-
checkCall(XGBoostJNI.XGDMatrixCreateFromCSREx(indptr, indices, array, 0, handles));
73+
checkCall(
74+
XGBoostJNI.XGDMatrixCreateFromCSR(
75+
indptr, indices, array, 0, missing, nthread, handles));
7376
return handles[0];
7477
}
7578

engines/ml/xgboost/src/test/java/ai/djl/ml/xgboost/XgbModelTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public void downloadXGBoostModel() throws IOException {
5353
@Test
5454
public void testVersion() {
5555
Engine engine = Engine.getEngine("XGBoost");
56-
Assert.assertEquals("1.7.5", engine.getVersion());
56+
Assert.assertEquals("2.0.1", engine.getVersion());
5757
}
5858

5959
/*
@@ -93,6 +93,7 @@ public void testNDArray() {
9393
try (XgbNDManager manager =
9494
(XgbNDManager) XgbNDManager.getSystemManager().newSubManager()) {
9595
manager.setMissingValue(Float.NaN);
96+
manager.setNthread(1);
9697
NDArray zeros = manager.zeros(new Shape(1, 2));
9798
Assert.expectThrows(UnsupportedOperationException.class, zeros::toFloatArray);
9899

gradle.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ paddlepaddle_version=2.3.2
2222
sentencepiece_version=0.1.97
2323
tokenizers_version=0.14.1
2424
fasttext_version=0.9.2
25-
xgboost_version=1.7.5
25+
xgboost_version=2.0.1
2626
lightgbm_version=3.2.110
2727
rapis_version=22.12.0
2828

0 commit comments

Comments
 (0)