|
1 | 1 | package testutil |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "encoding/binary" |
4 | 5 | "fmt" |
5 | 6 | "math" |
6 | 7 | "math/rand" |
@@ -734,30 +735,113 @@ func BuildArrayData(schema *schemapb.CollectionSchema, insertData *storage.Inser |
734 | 735 | columns = append(columns, builder.NewListArray()) |
735 | 736 | } |
736 | 737 | case schemapb.DataType_ArrayOfVector: |
737 | | - data := insertData.Data[fieldID].(*storage.VectorArrayFieldData).Data |
738 | | - rows := len(data) |
| 738 | + vectorArrayData := insertData.Data[fieldID].(*storage.VectorArrayFieldData) |
| 739 | + dim, err := typeutil.GetDim(field) |
| 740 | + if err != nil { |
| 741 | + return nil, err |
| 742 | + } |
| 743 | + elemType, err := storage.VectorArrayToArrowType(elementType, int(dim)) |
| 744 | + if err != nil { |
| 745 | + return nil, err |
| 746 | + } |
739 | 747 |
|
740 | | - switch elementType { |
741 | | - case schemapb.DataType_FloatVector: |
742 | | - // ArrayOfVector is flattened in Arrow - just a list of floats |
743 | | - // where total floats = dim * num_vectors |
744 | | - builder := array.NewListBuilder(mem, &arrow.Float32Type{}) |
745 | | - valueBuilder := builder.ValueBuilder().(*array.Float32Builder) |
| 748 | + // Create ListBuilder with "item" field name to match convertToArrowDataType |
| 749 | + // Always represented as a list of fixed-size binary values |
| 750 | + listBuilder := array.NewListBuilderWithField(mem, arrow.Field{ |
| 751 | + Name: "item", |
| 752 | + Type: elemType, |
| 753 | + Nullable: true, |
| 754 | + Metadata: arrow.Metadata{}, |
| 755 | + }) |
| 756 | + fixedSizeBuilder, ok := listBuilder.ValueBuilder().(*array.FixedSizeBinaryBuilder) |
| 757 | + if !ok { |
| 758 | + return nil, fmt.Errorf("unexpected list value builder for VectorArray field %s: %T", field.GetName(), listBuilder.ValueBuilder()) |
| 759 | + } |
746 | 760 |
|
747 | | - for i := 0; i < rows; i++ { |
748 | | - vectorArray := data[i].GetFloatVector() |
749 | | - if vectorArray == nil || len(vectorArray.GetData()) == 0 { |
750 | | - builder.AppendNull() |
| 761 | + vectorArrayData.Dim = dim |
| 762 | + |
| 763 | + bytesPerVector := int(fixedSizeBuilder.Type().(*arrow.FixedSizeBinaryType).ByteWidth) |
| 764 | + |
| 765 | + appendBinarySlice := func(data []byte, stride int) error { |
| 766 | + if stride == 0 { |
| 767 | + return fmt.Errorf("zero stride for VectorArray field %s", field.GetName()) |
| 768 | + } |
| 769 | + if len(data)%stride != 0 { |
| 770 | + return fmt.Errorf("vector array data length %d is not divisible by stride %d for field %s", len(data), stride, field.GetName()) |
| 771 | + } |
| 772 | + for offset := 0; offset < len(data); offset += stride { |
| 773 | + fixedSizeBuilder.Append(data[offset : offset+stride]) |
| 774 | + } |
| 775 | + return nil |
| 776 | + } |
| 777 | + |
| 778 | + for _, vectorField := range vectorArrayData.Data { |
| 779 | + if vectorField == nil { |
| 780 | + listBuilder.Append(false) |
| 781 | + continue |
| 782 | + } |
| 783 | + |
| 784 | + listBuilder.Append(true) |
| 785 | + |
| 786 | + switch elementType { |
| 787 | + case schemapb.DataType_FloatVector: |
| 788 | + floatArray := vectorField.GetFloatVector() |
| 789 | + if floatArray == nil { |
| 790 | + return nil, fmt.Errorf("expected FloatVector data for field %s", field.GetName()) |
| 791 | + } |
| 792 | + data := floatArray.GetData() |
| 793 | + if len(data) == 0 { |
| 794 | + continue |
| 795 | + } |
| 796 | + if len(data)%int(dim) != 0 { |
| 797 | + return nil, fmt.Errorf("float vector data length %d is not divisible by dim %d for field %s", len(data), dim, field.GetName()) |
| 798 | + } |
| 799 | + for offset := 0; offset < len(data); offset += int(dim) { |
| 800 | + vectorBytes := make([]byte, bytesPerVector) |
| 801 | + for j := 0; j < int(dim); j++ { |
| 802 | + binary.LittleEndian.PutUint32(vectorBytes[j*4:], math.Float32bits(data[offset+j])) |
| 803 | + } |
| 804 | + fixedSizeBuilder.Append(vectorBytes) |
| 805 | + } |
| 806 | + case schemapb.DataType_BinaryVector: |
| 807 | + binaryData := vectorField.GetBinaryVector() |
| 808 | + if len(binaryData) == 0 { |
| 809 | + continue |
| 810 | + } |
| 811 | + bytesPer := int((dim + 7) / 8) |
| 812 | + if err := appendBinarySlice(binaryData, bytesPer); err != nil { |
| 813 | + return nil, err |
| 814 | + } |
| 815 | + case schemapb.DataType_Float16Vector: |
| 816 | + float16Data := vectorField.GetFloat16Vector() |
| 817 | + if len(float16Data) == 0 { |
| 818 | + continue |
| 819 | + } |
| 820 | + if err := appendBinarySlice(float16Data, int(dim)*2); err != nil { |
| 821 | + return nil, err |
| 822 | + } |
| 823 | + case schemapb.DataType_BFloat16Vector: |
| 824 | + bfloat16Data := vectorField.GetBfloat16Vector() |
| 825 | + if len(bfloat16Data) == 0 { |
751 | 826 | continue |
752 | 827 | } |
753 | | - builder.Append(true) |
754 | | - // Append all flattened vector data |
755 | | - valueBuilder.AppendValues(vectorArray.GetData(), nil) |
| 828 | + if err := appendBinarySlice(bfloat16Data, int(dim)*2); err != nil { |
| 829 | + return nil, err |
| 830 | + } |
| 831 | + case schemapb.DataType_Int8Vector: |
| 832 | + int8Data := vectorField.GetInt8Vector() |
| 833 | + if len(int8Data) == 0 { |
| 834 | + continue |
| 835 | + } |
| 836 | + if err := appendBinarySlice(int8Data, int(dim)); err != nil { |
| 837 | + return nil, err |
| 838 | + } |
| 839 | + default: |
| 840 | + return nil, fmt.Errorf("unsupported element type in VectorArray: %s", elementType.String()) |
756 | 841 | } |
757 | | - columns = append(columns, builder.NewListArray()) |
758 | | - default: |
759 | | - return nil, fmt.Errorf("unsupported element type in VectorArray: %s", elementType.String()) |
760 | 842 | } |
| 843 | + |
| 844 | + columns = append(columns, listBuilder.NewListArray()) |
761 | 845 | } |
762 | 846 | } |
763 | 847 | return columns, nil |
|
0 commit comments