Skip to content

Commit eb99064

Browse files
ryanboganjmazanec15
andcommitted
Add model version to model metadata and change model metadata reads to be from cluster metadata (#2005)
* Add model version to model metadata Signed-off-by: Ryan Bogan <rbogan@amazon.com> * Add model version to model metadata and change model metadata reads to be from cluster metadata Signed-off-by: Ryan Bogan <rbogan@amazon.com> * Add changelog entry Signed-off-by: Ryan Bogan <rbogan@amazon.com> * Set version from config context Signed-off-by: Ryan Bogan <rbogan@amazon.com> * Fix spotless Signed-off-by: Ryan Bogan <rbogan@amazon.com> * Update model index mappings Signed-off-by: Ryan Bogan <rbogan@amazon.com> * Change field mapper to read model version Signed-off-by: Ryan Bogan <rbogan@amazon.com> * Fix tests Signed-off-by: Ryan Bogan <rbogan@amazon.com> * remove println Signed-off-by: John Mazanec <jmazane@amazon.com> --------- Signed-off-by: Ryan Bogan <rbogan@amazon.com> Signed-off-by: John Mazanec <jmazane@amazon.com> Co-authored-by: John Mazanec <jmazane@amazon.com> (cherry picked from commit 6814c8f)
1 parent f3d38bc commit eb99064

27 files changed

Lines changed: 484 additions & 319 deletions

release-notes/opensearch-knn.release-notes-2.17.0.0.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Compatible with OpenSearch 2.17.0
1111
* Add spaceType as a top level optional parameter while creating vector field. [#2044](https://github.com/opensearch-project/k-NN/pull/2044)
1212
### Enhancements
1313
* Adds iterative graph build capability into a faiss index to improve the memory footprint during indexing and Integrates KNNVectorsFormat for native engines[#1950](https://github.com/opensearch-project/k-NN/pull/1950)
14+
* Add model version to model metadata and change model metadata reads to be from cluster metadata [#2005](https://github.com/opensearch-project/k-NN/pull/2005)
1415
### Bug Fixes
1516
* Corrected search logic for scenario with non-existent fields in filter [#1874](https://github.com/opensearch-project/k-NN/pull/1874)
1617
* Add script_fields context to KNNAllowlist [#1917] (https://github.com/opensearch-project/k-NN/pull/1917)

src/main/java/org/opensearch/knn/common/KNNConstants.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ public class KNNConstants {
7777
public static final String TOP_LEVEL_SPACE_TYPE_FEATURE = "top_level_space_type_feature";
7878

7979
public static final String RADIAL_SEARCH_KEY = "radial_search";
80+
public static final String MODEL_VERSION = "model_version";
8081
public static final String QUANTIZATION_STATE_FILE_SUFFIX = "osknnqstate";
8182

8283
// Lucene specific constants

src/main/java/org/opensearch/knn/index/mapper/ModelFieldMapper.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ private static KNNMethodConfigContext getKNNMethodConfigContextFromModelMetadata
290290
return KNNMethodConfigContext.builder()
291291
.vectorDataType(modelMetadata.getVectorDataType())
292292
.dimension(modelMetadata.getDimension())
293-
.versionCreated(Version.V_2_14_0)
293+
.versionCreated(modelMetadata.getModelVersion())
294294
.mode(modelMetadata.getMode())
295295
.compressionLevel(modelMetadata.getCompressionLevel())
296296
.build();

src/main/java/org/opensearch/knn/index/util/IndexUtil.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ public class IndexUtil {
5454
private static final Version MINIMAL_RESCORE_FEATURE = Version.V_2_17_0;
5555
private static final Version MINIMAL_MODE_AND_COMPRESSION_FEATURE = Version.V_2_17_0;
5656
private static final Version MINIMAL_TOP_LEVEL_SPACE_TYPE_FEATURE = Version.V_2_17_0;
57+
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VERSION = Version.V_2_17_0;
5758
// public so neural search can access it
5859
public static final Map<String, Version> minimalRequiredVersionMap = initializeMinimalRequiredVersionMap();
5960
public static final Set<VectorDataType> VECTOR_DATA_TYPES_NOT_SUPPORTING_ENCODERS = Set.of(VectorDataType.BINARY, VectorDataType.BYTE);
@@ -394,6 +395,7 @@ private static Map<String, Version> initializeMinimalRequiredVersionMap() {
394395
put(RESCORE_PARAMETER, MINIMAL_RESCORE_FEATURE);
395396
put(KNNConstants.MINIMAL_MODE_AND_COMPRESSION_FEATURE, MINIMAL_MODE_AND_COMPRESSION_FEATURE);
396397
put(KNNConstants.TOP_LEVEL_SPACE_TYPE_FEATURE, MINIMAL_TOP_LEVEL_SPACE_TYPE_FEATURE);
398+
put(KNNConstants.MODEL_VERSION, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_VERSION);
397399
}
398400
};
399401

src/main/java/org/opensearch/knn/indices/ModelDao.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ private void putInternal(Model model, ActionListener<IndexResponse> listener, Do
301301
if (CompressionLevel.isConfigured(modelMetadata.getCompressionLevel())) {
302302
put(KNNConstants.COMPRESSION_LEVEL_PARAMETER, modelMetadata.getCompressionLevel().getName());
303303
}
304+
put(KNNConstants.MODEL_VERSION, modelMetadata.getModelVersion());
304305
MethodComponentContext methodComponentContext = modelMetadata.getMethodComponentContext();
305306
if (!methodComponentContext.getName().isEmpty()) {
306307
XContentBuilder builder = XContentFactory.jsonBuilder().startObject();

src/main/java/org/opensearch/knn/indices/ModelMetadata.java

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import lombok.extern.log4j.Log4j2;
1616
import org.apache.commons.lang.builder.EqualsBuilder;
1717
import org.apache.commons.lang.builder.HashCodeBuilder;
18+
import org.opensearch.Version;
1819
import org.opensearch.common.xcontent.json.JsonXContent;
1920
import org.opensearch.core.common.io.stream.StreamInput;
2021
import org.opensearch.core.common.io.stream.StreamOutput;
@@ -59,14 +60,14 @@ public class ModelMetadata implements Writeable, ToXContentObject {
5960
private String error;
6061
@Getter
6162
private final CompressionLevel compressionLevel;
63+
private final Version version;
6264

6365
/**
6466
* Constructor
6567
*
6668
* @param in Stream input
6769
*/
6870
public ModelMetadata(StreamInput in) throws IOException {
69-
String tempTrainingNodeAssignment;
7071
this.knnEngine = KNNEngine.getEngine(in.readString());
7172
this.spaceType = SpaceType.getSpace(in.readString());
7273
this.dimension = in.readInt();
@@ -96,7 +97,6 @@ public ModelMetadata(StreamInput in) throws IOException {
9697
} else {
9798
this.vectorDataType = VectorDataType.DEFAULT;
9899
}
99-
100100
if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), KNNConstants.MINIMAL_MODE_AND_COMPRESSION_FEATURE)) {
101101
this.mode = Mode.fromName(in.readOptionalString());
102102
this.compressionLevel = CompressionLevel.fromName(in.readOptionalString());
@@ -105,6 +105,11 @@ public ModelMetadata(StreamInput in) throws IOException {
105105
this.compressionLevel = CompressionLevel.NOT_CONFIGURED;
106106
}
107107

108+
if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(in.getVersion(), KNNConstants.MODEL_VERSION)) {
109+
this.version = Version.fromString(in.readString());
110+
} else {
111+
this.version = Version.V_EMPTY;
112+
}
108113
}
109114

110115
/**
@@ -133,7 +138,8 @@ public ModelMetadata(
133138
MethodComponentContext methodComponentContext,
134139
VectorDataType vectorDataType,
135140
Mode mode,
136-
CompressionLevel compressionLevel
141+
CompressionLevel compressionLevel,
142+
Version version
137143
) {
138144
this.knnEngine = Objects.requireNonNull(knnEngine, "knnEngine must not be null");
139145
this.spaceType = Objects.requireNonNull(spaceType, "spaceType must not be null");
@@ -159,6 +165,7 @@ public ModelMetadata(
159165
this.vectorDataType = Objects.requireNonNull(vectorDataType, "vector data type must not be null");
160166
this.mode = Objects.requireNonNull(mode, "Mode must not be null");
161167
this.compressionLevel = Objects.requireNonNull(compressionLevel, "Compression level must not be null");
168+
this.version = Objects.requireNonNull(version, "model version must not be null");
162169
}
163170

164171
/**
@@ -246,6 +253,14 @@ public VectorDataType getVectorDataType() {
246253
return vectorDataType;
247254
}
248255

256+
/**
257+
* Getter for the model version
258+
* @return version
259+
*/
260+
public Version getModelVersion() {
261+
return version;
262+
}
263+
249264
/**
250265
* setter for model's state
251266
*
@@ -279,7 +294,8 @@ public String toString() {
279294
methodComponentContext.toClusterStateString(),
280295
vectorDataType.getValue(),
281296
mode.getName(),
282-
compressionLevel.getName()
297+
compressionLevel.getName(),
298+
version.toString()
283299
);
284300
}
285301

@@ -317,6 +333,7 @@ public int hashCode() {
317333
.append(getVectorDataType())
318334
.append(getMode())
319335
.append(getCompressionLevel())
336+
.append(getModelVersion())
320337
.toHashCode();
321338
}
322339

@@ -329,15 +346,15 @@ public int hashCode() {
329346
public static ModelMetadata fromString(String modelMetadataString) {
330347
String[] modelMetadataArray = modelMetadataString.split(DELIMITER, -1);
331348
int length = modelMetadataArray.length;
332-
333-
if (length < 7 || length > 12) {
349+
if (length < 7 || length > 13) {
334350
throw new IllegalArgumentException(
335351
"Illegal format for model metadata. Must be of the form "
336352
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>\" or "
337353
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>\" or "
338354
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>,<MethodContext>\" or "
339355
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>,<MethodContext>,<VectorDataType>\". or "
340-
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>,<MethodContext>,<VectorDataType>,<Mode>,<CompressionLevel>\"."
356+
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>,<MethodContext>,<VectorDataType>,<Mode>,<CompressionLevel>\" or "
357+
+ "\"<KNNEngine>,<SpaceType>,<Dimension>,<ModelState>,<Timestamp>,<Description>,<Error>,<NodeAssignment>,<MethodContext>,<VectorDataType>,<Mode>,<CompressionLevel>,<Version>\"."
341358
);
342359
}
343360

@@ -357,6 +374,7 @@ public static ModelMetadata fromString(String modelMetadataString) {
357374
CompressionLevel compressionLevel = length > 11
358375
? CompressionLevel.fromName(modelMetadataArray[11])
359376
: CompressionLevel.NOT_CONFIGURED;
377+
Version version = length > 12 ? Version.fromString(modelMetadataArray[12]) : Version.V_EMPTY;
360378

361379
log.debug(getLogMessage(length));
362380

@@ -372,7 +390,8 @@ public static ModelMetadata fromString(String modelMetadataString) {
372390
methodComponentContext,
373391
vectorDataType,
374392
mode,
375-
compressionLevel
393+
compressionLevel,
394+
version
376395
);
377396
}
378397

@@ -386,9 +405,10 @@ private static String getLogMessage(int length) {
386405
return "Model metadata contains training node assignment and method context.";
387406
case 10:
388407
return "Model metadata contains training node assignment, method context and vector data type.";
389-
case 11:
390408
case 12:
391409
return "Model metadata contains mode and compression level";
410+
case 13:
411+
return "Model metadata contains training node assignment, method context, vector data type, and version";
392412
default:
393413
throw new IllegalArgumentException("Unexpected metadata array length: " + length);
394414
}
@@ -423,6 +443,7 @@ public static ModelMetadata getMetadataFromSourceMap(final Map<String, Object> m
423443
Object vectorDataType = modelSourceMap.get(KNNConstants.VECTOR_DATA_TYPE_FIELD);
424444
Object mode = modelSourceMap.get(KNNConstants.MODE_PARAMETER);
425445
Object compressionLevel = modelSourceMap.get(KNNConstants.COMPRESSION_LEVEL_PARAMETER);
446+
Object version = modelSourceMap.get(KNNConstants.MODEL_VERSION);
426447

427448
if (trainingNodeAssignment == null) {
428449
trainingNodeAssignment = "";
@@ -447,6 +468,10 @@ public static ModelMetadata getMetadataFromSourceMap(final Map<String, Object> m
447468
vectorDataType = VectorDataType.DEFAULT.getValue();
448469
}
449470

471+
if (version == null) {
472+
version = Version.V_EMPTY;
473+
}
474+
450475
ModelMetadata modelMetadata = new ModelMetadata(
451476
KNNEngine.getEngine(objectToString(engine)),
452477
SpaceType.getSpace(objectToString(space)),
@@ -459,7 +484,8 @@ public static ModelMetadata getMetadataFromSourceMap(final Map<String, Object> m
459484
(MethodComponentContext) methodComponentContext,
460485
VectorDataType.get(objectToString(vectorDataType)),
461486
Mode.fromName(objectToString(mode)),
462-
CompressionLevel.fromName(objectToString(compressionLevel))
487+
CompressionLevel.fromName(objectToString(compressionLevel)),
488+
Version.fromString(version.toString())
463489
);
464490
return modelMetadata;
465491
}
@@ -486,6 +512,9 @@ public void writeTo(StreamOutput out) throws IOException {
486512
out.writeOptionalString(mode.getName());
487513
out.writeOptionalString(compressionLevel.getName());
488514
}
515+
if (IndexUtil.isVersionOnOrAfterMinRequiredVersion(out.getVersion(), KNNConstants.MODEL_VERSION)) {
516+
out.writeString(version.toString());
517+
}
489518
}
490519

491520
@Override
@@ -517,6 +546,13 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
517546
builder.field(KNNConstants.COMPRESSION_LEVEL_PARAMETER, compressionLevel.getName());
518547
}
519548
}
549+
if (IndexUtil.isClusterOnOrAfterMinRequiredVersion(KNNConstants.MODEL_VERSION)) {
550+
String versionString = "unknown";
551+
if (version != Version.V_EMPTY) {
552+
versionString = version.toString();
553+
}
554+
builder.field(KNNConstants.MODEL_VERSION, versionString);
555+
}
520556
return builder;
521557
}
522558
}

src/main/java/org/opensearch/knn/indices/ModelUtil.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ public static ModelMetadata getModelMetadata(final String modelId) {
4848
if (StringUtils.isEmpty(modelId)) {
4949
return null;
5050
}
51-
final Model model = ModelCache.getInstance().get(modelId);
52-
final ModelMetadata modelMetadata = model.getModelMetadata();
51+
ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance();
52+
final ModelMetadata modelMetadata = modelDao.getMetadata(modelId);
5353
if (isModelCreated(modelMetadata) == false) {
5454
throw new IllegalArgumentException(String.format(Locale.ROOT, "Model ID '%s' is not created.", modelId));
5555
}

src/main/java/org/opensearch/knn/training/TrainingJob.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ public TrainingJob(
9999
knnMethodContext.getMethodComponentContext(),
100100
knnMethodConfigContext.getVectorDataType(),
101101
mode,
102-
compressionLevel
102+
compressionLevel,
103+
knnMethodConfigContext.getVersionCreated()
103104
),
104105
null,
105106
this.modelId

src/main/java/org/opensearch/knn/training/TrainingJobRunner.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import org.opensearch.core.action.ActionListener;
1717
import org.opensearch.action.index.IndexResponse;
1818
import org.opensearch.common.ValidationException;
19-
import org.opensearch.knn.indices.Model;
2019
import org.opensearch.knn.indices.ModelDao;
2120
import org.opensearch.knn.indices.ModelMetadata;
2221
import org.opensearch.knn.indices.ModelState;
@@ -166,11 +165,11 @@ private void train(TrainingJob trainingJob) {
166165
private void serializeModel(TrainingJob trainingJob, ActionListener<IndexResponse> listener, boolean update) throws IOException,
167166
ExecutionException, InterruptedException {
168167
if (update) {
169-
Model model = modelDao.get(trainingJob.getModelId());
170-
if (model.getModelMetadata().getState().equals(ModelState.TRAINING)) {
168+
ModelMetadata modelMetadata = modelDao.getMetadata(trainingJob.getModelId());
169+
if (modelMetadata.getState().equals(ModelState.TRAINING)) {
171170
modelDao.update(trainingJob.getModel(), listener);
172171
} else {
173-
logger.info("Model state is {}. Skipping serialization of trained data", model.getModelMetadata().getState());
172+
logger.info("Model state is {}. Skipping serialization of trained data", modelMetadata.getState());
174173
}
175174
} else {
176175
modelDao.put(trainingJob.getModel(), listener);

src/main/resources/mappings/model-index.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
},
3939
"compression_level": {
4040
"type": "keyword"
41+
},
42+
"model_version": {
43+
"type": "keyword"
4144
}
4245
}
4346
}

0 commit comments

Comments
 (0)