Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -916,6 +916,8 @@
</googleJavaFormat>
<removeUnusedImports />
<formatAnnotations />
<!-- enable the spotless:off/on toggle markers -->
<toggleOffOn />
<importOrder>
<order>org.apache.seatunnel.shade,org.apache.seatunnel,org.apache,org,,javax,java,\#</order>
</importOrder>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,48 @@ transform {
}

product_name_vector = product_name

multi_field_text_vector = [product_name, description]

multi_field_image_vector = [
{
field = product_image_url
modality = jpeg
format = url
},
{
field = thumbnail_image
modality = png
format = url
}
]

multi_field_video_vector = [
{
field = product_video_url
modality = mp4
format = url
},
{
field = promotional_video
modality = mov
format = url
}
]

multi_field_mix_vector = [
product_name,
{
field = product_image_url
modality = jpeg
format = url
},
{
field = product_video_url
modality = mp4
format = url
}
]
}

plugin_output = "multimodal_embedding_output"
Expand Down Expand Up @@ -219,6 +261,42 @@ sink {
}
]
},
{
field_name = multi_field_text_vector
field_type = float_vector
field_value = [
{
rule_type = NOT_NULL
}
]
},
{
field_name = multi_field_image_vector
field_type = float_vector
field_value = [
{
rule_type = NOT_NULL
}
]
},
{
field_name = multi_field_video_vector
field_type = float_vector
field_value = [
{
rule_type = NOT_NULL
}
]
},
{
field_name = multi_field_mix_vector
field_type = float_vector
field_value = [
{
rule_type = NOT_NULL
}
]
},
{
field_name = category
field_type = string
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,16 @@
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

@Slf4j
public class EmbeddingTransform extends MultipleFieldOutputTransform {

private final ReadonlyConfig config;
private List<Integer> fieldOriginalIndexes;
private transient Model model;
private Integer dimension;
private boolean isMultimodalFields = false;
private Map<Integer, FieldSpec> fieldSpecMap;
private Map<VectorFieldSpec, List<Integer>> fieldSpecMap;
private List<String> fieldNames;

private final Map<String, TreeMap<Long, byte[]>> binaryFileCache = new ConcurrentHashMap<>();
Expand Down Expand Up @@ -197,29 +197,33 @@ public void open() {
}

private void initOutputFields(SeaTunnelRowType inputRowType, ReadonlyConfig config) {
Map<Integer, FieldSpec> fieldSpecMap = new HashMap<>();
List<String> fieldNames = new ArrayList<>();
Map<String, Object> fieldsConfig =
config.get(EmbeddingTransformConfig.VECTORIZATION_FIELDS);
if (fieldsConfig == null || fieldsConfig.isEmpty()) {
throw new IllegalArgumentException("vectorization_fields configuration is required");
}

for (Map.Entry<String, Object> field : fieldsConfig.entrySet()) {
FieldSpec fieldSpec = new FieldSpec(field);
log.info("Field spec: {}", fieldSpec.toString());
String srcField = fieldSpec.getFieldName();
int srcFieldIndex;
try {
srcFieldIndex = inputRowType.indexOf(srcField);
} catch (IllegalArgumentException e) {
throw TransformCommonError.cannotFindInputFieldError(getPluginName(), srcField);
}
if (fieldSpec.isMultimodalField()) {
isMultimodalFields = true;
List<String> fieldNames = new ArrayList<>();
Map<VectorFieldSpec, List<Integer>> fieldSpecMap = new HashMap<>();
for (Map.Entry<String, Object> fieldConfig : fieldsConfig.entrySet()) {
VectorFieldSpec vectorFieldSpec = new VectorFieldSpec(fieldConfig);
log.info("Vector field spec: {}", vectorFieldSpec);
List<String> srcFieldNames =
vectorFieldSpec.getSrcFieldSpecs().stream()
.map(SrcFieldSpec::getFieldName)
.collect(Collectors.toList());
List<Integer> srcFieldIndexes = new ArrayList<>();
for (String srcFieldName : srcFieldNames) {
try {
srcFieldIndexes.add(inputRowType.indexOf(srcFieldName));
} catch (IllegalArgumentException e) {
throw TransformCommonError.cannotFindInputFieldsError(
getPluginName(), srcFieldNames);
}
}
fieldSpecMap.put(srcFieldIndex, fieldSpec);
fieldNames.add(field.getKey());
isMultimodalFields = vectorFieldSpec.isMultimodalField();
fieldSpecMap.put(vectorFieldSpec, srcFieldIndexes);
fieldNames.add(vectorFieldSpec.getFieldName());
}
this.fieldSpecMap = fieldSpecMap;
this.fieldNames = fieldNames;
Expand All @@ -232,19 +236,28 @@ protected Object[] getOutputFieldValues(SeaTunnelRowAccessor inputRow) {
if (MetadataUtil.isBinaryFormat(inputRow)) {
return vectorizationBinaryRow(inputRow);
}
Set<Integer> fieldOriginalIndexes = fieldSpecMap.keySet();
Object[] fieldValues = new Object[fieldOriginalIndexes.size()];
List<ByteBuffer> vectorization;

Set<VectorFieldSpec> vectorFieldSpecs = fieldSpecMap.keySet();
Object[] fieldValues = new Object[vectorFieldSpecs.size()];
int i = 0;

for (Integer fieldOriginalIndex : fieldOriginalIndexes) {
FieldSpec fieldSpec = fieldSpecMap.get(fieldOriginalIndex);
Object value = inputRow.getField(fieldOriginalIndex);
for (VectorFieldSpec vectorFieldSpec : vectorFieldSpecs) {
List<SrcFieldSpec> srcFieldSpecs = vectorFieldSpec.getSrcFieldSpecs();
List<Integer> srcFieldIndexes = fieldSpecMap.get(vectorFieldSpec);
List<SrcField> srcFields = new ArrayList<>();
for (int j = 0; j < srcFieldSpecs.size(); j++) {
srcFields.add(
new SrcField(
srcFieldSpecs.get(j),
inputRow.getField(srcFieldIndexes.get(j))));
}
fieldValues[i++] =
isMultimodalFields ? new MultimodalFieldValue(fieldSpec, value) : value;
isMultimodalFields
? new MultimodalFieldValue(srcFields)
: srcFields.get(0).getFieldValue();
}

vectorization = model.vectorization(fieldValues);
List<ByteBuffer> vectorization = model.vectorization(fieldValues);
return vectorization.toArray();
} catch (Exception e) {
throw new RuntimeException("Failed to data vectorization", e);
Expand Down Expand Up @@ -282,32 +295,34 @@ public boolean isMultimodalFields() {

/** Process a row in binary format: [data, relativePath, partIndex] */
private Object[] vectorizationBinaryRow(SeaTunnelRowAccessor inputRow) throws Exception {

byte[] completeData = processBinaryRow(inputRow);
if (completeData == null) {
return null;
}
Set<Integer> fieldOriginalIndexes = fieldSpecMap.keySet();
Object[] fieldValues = new Object[fieldOriginalIndexes.size()];

Set<VectorFieldSpec> vectorFieldSpecs = fieldSpecMap.keySet();
Object[] fieldValues = new Object[vectorFieldSpecs.size()];
int i = 0;

for (Integer fieldOriginalIndex : fieldOriginalIndexes) {
FieldSpec fieldSpec = fieldSpecMap.get(fieldOriginalIndex);
if (fieldSpec.isBinary()) {
fieldValues[i++] = new MultimodalFieldValue(fieldSpec, completeData);
} else {
log.warn(
"Non-binary field {} configured in binary format data",
fieldSpec.getFieldName());
fieldValues[i++] = null;
for (VectorFieldSpec vectorFieldSpec : vectorFieldSpecs) {
List<SrcFieldSpec> srcFieldSpecs = vectorFieldSpec.getSrcFieldSpecs();
List<SrcField> srcFields = new ArrayList<>();
for (SrcFieldSpec srcFieldSpec : srcFieldSpecs) {
if (srcFieldSpec.isBinary()) {
srcFields.add(new SrcField(srcFieldSpec, completeData));
} else {
log.warn(
"Non-binary field {} configured in binary format data",
srcFieldSpec.getFieldName());
}
}
fieldValues[i++] = srcFields.isEmpty() ? null : new MultimodalFieldValue(srcFields);
}

try {
return model.vectorization(fieldValues).toArray();
} catch (Exception e) {
throw new RuntimeException(
"Failed to vectorize binary data for file: " + inputRow.toString(), e);
throw new RuntimeException("Failed to vectorize binary data for file: " + inputRow, e);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.seatunnel.transform.nlpmodel.embedding;

import lombok.Data;

import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.Base64;

/**
 * A single source field participating in vectorization: pairs the field's spec
 * (name, modality, payload format) with the value read from the input row.
 */
@Data
public class SrcField implements Serializable {

    private static final long serialVersionUID = 1L;

    /** Spec describing how this field's payload should be interpreted. */
    private SrcFieldSpec fieldSpec;

    /** Raw value read from the input row; a byte[] for binary payloads. */
    private Object fieldValue;

    public SrcField(SrcFieldSpec spec, Object value) {
        this.fieldSpec = spec;
        this.fieldValue = value;
    }

    /**
     * Encodes a binary payload as a Base64 string.
     *
     * @return the Base64-encoded payload
     * @throws IllegalArgumentException if the spec is missing/non-binary or the value is null
     */
    public String toBase64() {
        if (fieldSpec == null || !fieldSpec.isBinary()) {
            throw new IllegalArgumentException("Payload format must be binary");
        }
        if (fieldValue == null) {
            throw new IllegalArgumentException("Binary data cannot be null or empty");
        }
        // Binary payloads arrive as byte[]; calling toString() on a byte[] would
        // Base64-encode the "[B@hashcode" identity string instead of the data.
        if (fieldValue instanceof byte[]) {
            return Base64.getEncoder().encodeToString((byte[]) fieldValue);
        }
        // Fallback for non-byte values: encode the UTF-8 bytes of the text form
        // rather than relying on the platform-default charset.
        return Base64.getEncoder()
                .encodeToString(fieldValue.toString().getBytes(StandardCharsets.UTF_8));
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,48 +26,20 @@
import java.util.Map;

@Data
public class FieldSpec implements Serializable {
public class SrcFieldSpec implements Serializable {

private static final long serialVersionUID = 1L;

private String fieldName;
private ModalityType modalityType;
private PayloadFormat payloadFormat;

public FieldSpec(String fieldName) {
this.fieldName = fieldName;
this.modalityType = ModalityType.TEXT;
this.payloadFormat = PayloadFormat.TEXT;
}

public FieldSpec(Map.Entry<String, Object> fieldConfig) {
String outputFieldName = fieldConfig.getKey();
if (outputFieldName == null) {
throw new IllegalArgumentException("Field spec cannot be null");
}
Object fieldValue = fieldConfig.getValue();
try {
if (fieldValue instanceof String) {
parseBasicFieldSpec((String) fieldValue);
} else {
Map<String, Object> fieldSpecConfig = (Map<String, Object>) fieldValue;
parseMultimodalFieldSpec(fieldSpecConfig);
}
} catch (Exception e) {
String errorMessage =
String.format(
"Invalid field spec for output field '%s': %s",
outputFieldName, fieldConfig);
throw new IllegalArgumentException(errorMessage, e);
}
}

/** Parse basic field spec: just the field name, defaults to TEXT modality and default format */
private void parseBasicFieldSpec(String fieldSpec) {
if (fieldSpec == null || fieldSpec.trim().isEmpty()) {
throw new IllegalArgumentException("Field spec cannot be null or empty");
public SrcFieldSpec(String fieldName) {
if (fieldName == null || fieldName.trim().isEmpty()) {
throw new IllegalArgumentException("Field name cannot be null or empty");
}
this.fieldName = fieldSpec.trim();
this.fieldName = fieldName.trim();
this.modalityType = ModalityType.TEXT;
this.payloadFormat = PayloadFormat.TEXT;
}
Expand All @@ -76,9 +48,9 @@ private void parseBasicFieldSpec(String fieldSpec) {
* Parse multimodal field spec: field name, modality, and format Supports both formats: 1.
* Separate modality and format
*/
private void parseMultimodalFieldSpec(Map<String, Object> fieldConfig) {
public SrcFieldSpec(Map<String, Object> fieldConfig) {
if (fieldConfig == null || fieldConfig.isEmpty()) {
throw new IllegalArgumentException("Field configuration cannot be null or empty");
throw new IllegalArgumentException("Field config cannot be null or empty");
}

Object fieldNameObj = fieldConfig.get("field");
Expand Down Expand Up @@ -109,10 +81,6 @@ private void parseMultimodalFieldSpec(Map<String, Object> fieldConfig) {
}
}

public boolean isMultimodalField() {
return !ModalityType.TEXT.equals(modalityType);
}

/** Whether this field's payload is declared as binary data. */
public boolean isBinary() {
    // Enum constants are singletons, so identity comparison is equivalent to equals().
    return payloadFormat == PayloadFormat.BINARY;
}
Expand Down
Loading