Skip to content

Commit ffc81e7

Browse files
buddharajusahilSahil Buddharaju
andcommitted
Allow method parameter override for training based indices (solves issue opensearch-project#2246) (opensearch-project#2290)
* Allow method parameter override for training based indices Signed-off-by: Sahil Buddharaju <[email protected]> * Fixed code squashing imports Signed-off-by: Sahil Buddharaju <[email protected]> * Changed changelog Signed-off-by: Sahil Buddharaju <[email protected]> * spotlessApply styling Signed-off-by: Sahil Buddharaju <[email protected]> --------- Signed-off-by: Sahil Buddharaju <[email protected]> Co-authored-by: Sahil Buddharaju <[email protected]> (cherry picked from commit 19f045d)
1 parent 1e269e5 commit ffc81e7

4 files changed

Lines changed: 103 additions & 4 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
1717
### Features
1818
### Enhancements
1919
- Introduced a writing layer in native engines which relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
20+
- Allow method parameter override for training based indices (#2290)[https://github.com/opensearch-project/k-NN/pull/2290]
2021
### Bug Fixes
2122
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
2223
### Infrastructure

src/main/java/org/opensearch/knn/plugin/rest/RestTrainModelHandler.java

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,6 @@ && ensureSpaceTypeNotSet(topLevelSpaceType)) {
135135
}
136136

137137
ensureAtleastOneSet(KNN_METHOD, knnMethodContext, MODE_PARAMETER, mode, COMPRESSION_LEVEL_PARAMETER, compressionLevel);
138-
ensureMutualExclusion(KNN_METHOD, knnMethodContext, MODE_PARAMETER, mode);
139-
ensureMutualExclusion(KNN_METHOD, knnMethodContext, COMPRESSION_LEVEL_PARAMETER, compressionLevel);
140138

141139
ensureSet(DIMENSION, dimension);
142140
ensureSet(TRAIN_INDEX_PARAMETER, trainingIndex);

src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,17 +26,23 @@
2626
import java.util.List;
2727

2828
import static org.opensearch.knn.common.KNNConstants.COMPRESSION_LEVEL_PARAMETER;
29+
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;
30+
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;
2931
import static org.opensearch.knn.common.KNNConstants.FAISS_NAME;
3032
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
3133
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
34+
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
3235
import static org.opensearch.knn.common.KNNConstants.METHOD_IVF;
3336
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER;
3437
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
38+
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
3539
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST_DEFAULT;
3640
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NPROBES;
41+
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
3742
import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION;
3843
import static org.opensearch.knn.common.KNNConstants.MODE_PARAMETER;
3944
import static org.opensearch.knn.common.KNNConstants.NAME;
45+
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
4046
import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER;
4147
import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER;
4248
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
@@ -260,17 +266,31 @@ public void testCompressionIndexWithNonVectorFieldsSegment_whenValid_ThenSucceed
260266
public void testTraining_whenInvalid_thenFail() {
261267
setupTrainingIndex();
262268
String modelId = "test";
269+
263270
XContentBuilder builder1 = XContentFactory.jsonBuilder()
264271
.startObject()
265272
.field(TRAIN_INDEX_PARAMETER, TRAINING_INDEX_NAME)
266273
.field(TRAIN_FIELD_PARAMETER, TRAINING_FIELD_NAME)
267274
.field(KNNConstants.DIMENSION, DIMENSION)
275+
.field(VECTOR_DATA_TYPE_FIELD, "float")
276+
.field(MODEL_DESCRIPTION, "")
277+
.field(MODE_PARAMETER, Mode.ON_DISK)
278+
.field(COMPRESSION_LEVEL_PARAMETER, "16x")
268279
.startObject(KNN_METHOD)
269280
.field(NAME, METHOD_IVF)
270281
.field(KNN_ENGINE, FAISS_NAME)
282+
.field(METHOD_PARAMETER_SPACE_TYPE, "l2")
283+
.startObject(PARAMETERS)
284+
.field(METHOD_PARAMETER_NLIST, 1)
285+
.startObject(METHOD_ENCODER_PARAMETER)
286+
.field(NAME, "pq")
287+
.startObject(PARAMETERS)
288+
.field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2)
289+
.field(ENCODER_PARAMETER_PQ_M, 8)
290+
.endObject()
291+
.endObject()
292+
.endObject()
271293
.endObject()
272-
.field(MODEL_DESCRIPTION, "")
273-
.field(MODE_PARAMETER, Mode.ON_DISK)
274294
.endObject();
275295
expectThrows(ResponseException.class, () -> trainModel(modelId, builder1));
276296

src/test/java/org/opensearch/knn/plugin/action/RestTrainModelHandlerIT.java

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111

1212
package org.opensearch.knn.plugin.action;
1313

14+
1415
import org.apache.http.util.EntityUtils;
16+
import org.opensearch.client.Request;
1517
import org.opensearch.client.Response;
1618
import org.opensearch.client.ResponseException;
1719
import org.opensearch.core.xcontent.XContentBuilder;
@@ -22,15 +24,22 @@
2224

2325
import java.util.Map;
2426

27+
import static org.opensearch.knn.common.KNNConstants.COMPRESSION_LEVEL_PARAMETER;
28+
import static org.opensearch.knn.common.KNNConstants.DIMENSION;
2529
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;
2630
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;
2731
import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
32+
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
2833
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
2934
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
3035
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE;
36+
import static org.opensearch.knn.common.KNNConstants.MODEL_DESCRIPTION;
3137
import static org.opensearch.knn.common.KNNConstants.MODEL_ID;
38+
import static org.opensearch.knn.common.KNNConstants.MODE_PARAMETER;
3239
import static org.opensearch.knn.common.KNNConstants.NAME;
3340
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
41+
import static org.opensearch.knn.common.KNNConstants.TRAIN_FIELD_PARAMETER;
42+
import static org.opensearch.knn.common.KNNConstants.TRAIN_INDEX_PARAMETER;
3443

3544
public class RestTrainModelHandlerIT extends KNNRestTestCase {
3645

@@ -472,4 +481,75 @@ public void testTrainModel_success_nestedField() throws Exception {
472481

473482
assertTrainingSucceeds(modelId, 30, 1000);
474483
}
484+
485+
// Tests that training succeeds when the user specifies compression level and mode together with an explicit method definition
486+
public void testTrainModel_success_methodOverrideWithCompressionMode() throws Exception {
487+
String modelId = "test-model-id";
488+
String trainingIndexName = "train-index";
489+
String nestedFieldPath = "a.b.train-field";
490+
int dimension = 8;
491+
492+
// Create a training index and randomly ingest data into it
493+
String mapping = createKnnIndexNestedMapping(dimension, nestedFieldPath);
494+
createKnnIndex(trainingIndexName, mapping);
495+
int trainingDataCount = 200;
496+
bulkIngestRandomVectorsWithNestedField(trainingIndexName, nestedFieldPath, trainingDataCount, dimension);
497+
498+
// Call the train API with this definition:
499+
500+
/*
501+
POST /_plugins/_knn/models/test-model-id/_train
502+
{
503+
"training_index": "train-index",
504+
"training_field": "a.b.train-field",
505+
"dimension": 8,
506+
"description": "dummy description",
507+
"compression_level": "16x",
508+
"mode": "on_disk",
509+
"method": {
510+
"name": "ivf",
511+
"params": {
512+
"nlist": 16
513+
}
514+
}
515+
}
516+
517+
*/
518+
XContentBuilder builder = XContentFactory.jsonBuilder()
519+
.startObject()
520+
.field(NAME, "ivf")
521+
.startObject(PARAMETERS)
522+
.field(METHOD_PARAMETER_NLIST, 16)
523+
.endObject()
524+
.endObject();
525+
Map<String, Object> method = xContentBuilderToMap(builder);
526+
527+
XContentBuilder outerParams = XContentFactory.jsonBuilder()
528+
.startObject()
529+
.field(TRAIN_INDEX_PARAMETER, trainingIndexName)
530+
.field(TRAIN_FIELD_PARAMETER, nestedFieldPath)
531+
.field(DIMENSION, dimension)
532+
.field(COMPRESSION_LEVEL_PARAMETER, "16x")
533+
.field(MODE_PARAMETER, "on_disk")
534+
.field(KNN_METHOD, method)
535+
.field(MODEL_DESCRIPTION, "dummy description")
536+
.endObject();
537+
538+
Request request = new Request("POST", "/_plugins/_knn/models/" + modelId + "/_train");
539+
request.setJsonEntity(outerParams.toString());
540+
541+
Response trainResponse = client().performRequest(request);
542+
543+
assertEquals(RestStatus.OK, RestStatus.fromCode(trainResponse.getStatusLine().getStatusCode()));
544+
545+
Response getResponse = getModel(modelId, null);
546+
String responseBody = EntityUtils.toString(getResponse.getEntity());
547+
assertNotNull(responseBody);
548+
549+
Map<String, Object> responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map();
550+
551+
assertEquals(modelId, responseMap.get(MODEL_ID));
552+
553+
assertTrainingSucceeds(modelId, 30, 1000);
554+
}
475555
}

0 commit comments

Comments
 (0)