diff --git a/java/common/src/main/java/org/apache/tsfile/enums/TSDataType.java b/java/common/src/main/java/org/apache/tsfile/enums/TSDataType.java index e4ffc6b1b..cf17aa040 100644 --- a/java/common/src/main/java/org/apache/tsfile/enums/TSDataType.java +++ b/java/common/src/main/java/org/apache/tsfile/enums/TSDataType.java @@ -26,6 +26,13 @@ import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; +import java.util.Arrays; +import java.util.Collections; +import java.util.EnumMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Objects; +import java.util.Set; public enum TSDataType { /** BOOLEAN. */ @@ -63,9 +70,57 @@ public enum TSDataType { /** STRING */ STRING((byte) 11); - ; private final byte type; + private static final Map> compatibleTypes; + + static { + compatibleTypes = new EnumMap<>(TSDataType.class); + + compatibleTypes.put(BOOLEAN, Collections.emptySet()); + + compatibleTypes.put(INT32, Collections.emptySet()); + + Set i64CompatibleTypes = new HashSet<>(); + i64CompatibleTypes.add(INT32); + i64CompatibleTypes.add(TIMESTAMP); + compatibleTypes.put(INT64, i64CompatibleTypes); + + Set floatCompatibleTypes = new HashSet<>(); + floatCompatibleTypes.add(INT32); + compatibleTypes.put(FLOAT, floatCompatibleTypes); + + Set doubleCompatibleTypes = new HashSet<>(); + doubleCompatibleTypes.add(INT32); + doubleCompatibleTypes.add(INT64); + doubleCompatibleTypes.add(FLOAT); + doubleCompatibleTypes.add(TIMESTAMP); + compatibleTypes.put(DOUBLE, doubleCompatibleTypes); + + Set textCompatibleTypes = new HashSet<>(); + textCompatibleTypes.add(STRING); + compatibleTypes.put(TEXT, textCompatibleTypes); + + compatibleTypes.put(VECTOR, Collections.emptySet()); + + compatibleTypes.put(UNKNOWN, Collections.emptySet()); + + Set timestampCompatibleTypes = new HashSet<>(); + timestampCompatibleTypes.add(INT32); + timestampCompatibleTypes.add(INT64); + compatibleTypes.put(TIMESTAMP, timestampCompatibleTypes); + + compatibleTypes.put(DATE, Collections.emptySet()); + + Set blobCompatibleTypes = new HashSet<>(); + blobCompatibleTypes.add(STRING); + blobCompatibleTypes.add(TEXT); + compatibleTypes.put(BLOB, blobCompatibleTypes); + + Set stringCompatibleTypes = new HashSet<>(); + stringCompatibleTypes.add(TEXT); + compatibleTypes.put(STRING, stringCompatibleTypes); + } TSDataType(byte type) { this.type = type; @@ -116,6 +171,211 @@ public static TSDataType getTsDataType(byte type) { } } + /** + * @return if the source type can be cast to this type. + */ + public boolean isCompatible(TSDataType source) { + return this == source + || compatibleTypes.getOrDefault(this, Collections.emptySet()).contains(source); + } + + @SuppressWarnings({"java:S3012", "java:S3776", "java:S6541"}) + public Object castFromSingleValue(TSDataType sourceType, Object value) { + if (Objects.isNull(value)) { + return null; + } + switch (this) { + case BOOLEAN: + if (sourceType == TSDataType.BOOLEAN) { + return value; + } else { + break; + } + case INT32: + if (sourceType == TSDataType.INT32) { + return value; + } else { + break; + } + case INT64: + if (sourceType == TSDataType.INT64) { + return value; + } else if (sourceType == INT32) { + return (long) ((int) value); + } else if (sourceType == TIMESTAMP) { + return value; + } else { + break; + } + case FLOAT: + if (sourceType == TSDataType.FLOAT) { + return value; + } else if (sourceType == INT32) { + return (float) ((int) value); + } else { + break; + } + case DOUBLE: + if (sourceType == TSDataType.DOUBLE) { + return value; + } else if (sourceType == INT32) { + return (double) ((int) value); + } else if (sourceType == INT64) { + return (double) ((long) value); + } else if (sourceType == FLOAT) { + return (double) ((float) value); + } else if (sourceType == TIMESTAMP) { + return (double) ((long) value); + } else { + break; + } + case TEXT: + if (sourceType == TSDataType.TEXT || sourceType == TSDataType.STRING) { + return value; + } else { + break; + } + case TIMESTAMP: + if (sourceType == TSDataType.TIMESTAMP) { + return value; + } else if (sourceType == INT32) { + return (long) ((int) value); + } else if (sourceType == INT64) { + return value; + } else { + break; + } + case DATE: + if (sourceType == TSDataType.DATE) { + return value; + } else { + break; + } + case BLOB: + if (sourceType == TSDataType.BLOB + || sourceType == TSDataType.STRING + || sourceType == TSDataType.TEXT) { + return value; + } else { + break; + } + case STRING: + if (sourceType == TSDataType.STRING || sourceType == TSDataType.TEXT) { + return value; + } else { + break; + } + case VECTOR: + case UNKNOWN: + default: + break; + } + throw new ClassCastException( + String.format("Unsupported cast: from %s to %s", sourceType, this)); + } + + @SuppressWarnings({"java:S3012", "java:S3776", "java:S6541"}) + public Object castFromArray(TSDataType sourceType, Object array) { + switch (this) { + case BOOLEAN: + if (sourceType == TSDataType.BOOLEAN) { + return array; + } else { + break; + } + case INT32: + if (sourceType == TSDataType.INT32) { + return array; + } else { + break; + } + case INT64: + if (sourceType == TSDataType.INT64) { + return array; + } else if (sourceType == INT32) { + return Arrays.stream((int[]) array).mapToLong(Long::valueOf).toArray(); + } else if (sourceType == TIMESTAMP) { + return array; + } else { + break; + } + case FLOAT: + if (sourceType == TSDataType.FLOAT) { + return array; + } else if (sourceType == INT32) { + int[] tmp = (int[]) array; + float[] result = new float[tmp.length]; + for (int i = 0; i < tmp.length; i++) { + result[i] = tmp[i]; + } + return result; + } else { + break; + } + case DOUBLE: + if (sourceType == TSDataType.DOUBLE) { + return array; + } else if (sourceType == INT32) { + return Arrays.stream((int[]) array).mapToDouble(Double::valueOf).toArray(); + } else if (sourceType == INT64) { + return Arrays.stream((long[]) array).mapToDouble(Double::valueOf).toArray(); + } else if (sourceType == FLOAT) { + float[] tmp = (float[]) array; + double[] result = new double[tmp.length]; + for (int i = 0; i < tmp.length; i++) { + result[i] = tmp[i]; + } + return result; + } else if (sourceType == TIMESTAMP) { + return Arrays.stream((long[]) array).mapToDouble(Double::valueOf).toArray(); + } else { + break; + } + case TEXT: + if (sourceType == TSDataType.TEXT || sourceType == STRING) { + return array; + } else { + break; + } + case TIMESTAMP: + if (sourceType == TSDataType.TIMESTAMP) { + return array; + } else if (sourceType == INT32) { + return Arrays.stream((int[]) array).mapToLong(Long::valueOf).toArray(); + } else if (sourceType == INT64) { + return array; + } else { + break; + } + case DATE: + if (sourceType == TSDataType.DATE) { + return array; + } else { + break; + } + case BLOB: + if (sourceType == TSDataType.BLOB + || sourceType == TSDataType.STRING + || sourceType == TSDataType.TEXT) { + return array; + } else { + break; + } + case STRING: + if (sourceType == TSDataType.STRING || sourceType == TSDataType.TEXT) { + return array; + } else { + break; + } + case VECTOR: + case UNKNOWN: + default: + break; + } + throw new ClassCastException( + String.format("Unsupported cast: from %s to %s", sourceType, this)); + } + public static TSDataType deserializeFrom(ByteBuffer buffer) { return deserialize(buffer.get()); } diff --git a/java/tsfile/src/test/java/org/apache/tsfile/utils/TypeCastTest.java b/java/tsfile/src/test/java/org/apache/tsfile/utils/TypeCastTest.java new file mode 100644 index 000000000..10d26db18 --- /dev/null +++ b/java/tsfile/src/test/java/org/apache/tsfile/utils/TypeCastTest.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.tsfile.utils; + +import org.apache.tsfile.enums.TSDataType; + +import org.junit.Test; + +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.fail; + +public class TypeCastTest { + + @Test + public void testSingleCast() { + Set dataTypes = new HashSet<>(); + Collections.addAll(dataTypes, TSDataType.values()); + dataTypes.remove(TSDataType.VECTOR); + dataTypes.remove(TSDataType.UNKNOWN); + + for (TSDataType from : dataTypes) { + for (TSDataType to : dataTypes) { + Object src = genValue(from); + if (to.isCompatible(from)) { + assertEquals(genValue(to), to.castFromSingleValue(from, src)); + } else { + assertThrows(ClassCastException.class, () -> to.castFromSingleValue(from, src)); + } + } + } + } + + @Test + public void testArrayCast() { + Set dataTypes = new HashSet<>(); + Collections.addAll(dataTypes, TSDataType.values()); + dataTypes.remove(TSDataType.VECTOR); + dataTypes.remove(TSDataType.UNKNOWN); + + for (TSDataType from : dataTypes) { + for (TSDataType to : dataTypes) { + Object array = genValueArray(from); + if (!to.isCompatible(from)) { + assertThrows(ClassCastException.class, () -> to.castFromArray(from, array)); + return; + } + switch (to) { + case INT32: + case DATE: + assertArrayEquals((int[]) genValueArray(to), (int[]) to.castFromArray(from, array)); + break; + case INT64: + case TIMESTAMP: + assertArrayEquals((long[]) genValueArray(to), (long[]) to.castFromArray(from, array)); + break; + case BOOLEAN: + assertArrayEquals( + (boolean[]) genValueArray(to), (boolean[]) to.castFromArray(from, array)); + break; + case STRING: + case BLOB: + case TEXT: + assertArrayEquals( + (Binary[]) genValueArray(to), (Binary[]) to.castFromArray(from, array)); + break; + case FLOAT: + assertArrayEquals( + (float[]) genValueArray(to), (float[]) to.castFromArray(from, array), 0.1f); + break; + case DOUBLE: + assertArrayEquals( + (double[]) genValueArray(to), (double[]) to.castFromArray(from, array), 0.1); + break; + case UNKNOWN: + case VECTOR: + default: + fail("Unexpected type: " + to); + } + } + } + } + + private Object genValue(TSDataType dataType) { + int i = 1; + switch (dataType) { + case INT32: + case DATE: + return i; + case TIMESTAMP: + case INT64: + return (long) i; + case BOOLEAN: + return false; + case FLOAT: + return i * 1.0f; + case DOUBLE: + return i * 1.0; + case STRING: + case TEXT: + case BLOB: + return new Binary(Integer.toString(i), StandardCharsets.UTF_8); + case UNKNOWN: + case VECTOR: + default: + throw new IllegalArgumentException("Unsupported data type: " + dataType); + } + } + + private Object genValueArray(TSDataType dataType) { + switch (dataType) { + case INT32: + case DATE: + return new int[] {1, 2, 3}; + case TIMESTAMP: + case INT64: + return new long[] {1, 2, 3}; + case BOOLEAN: + return new boolean[] {true, false}; + case FLOAT: + return new float[] {1.0f, 2.0f, 3.0f}; + case DOUBLE: + return new double[] {1.0, 2.0, 3.0}; + case STRING: + case TEXT: + case BLOB: + return new Binary[] { + new Binary(Integer.toString(1), StandardCharsets.UTF_8), + new Binary(Integer.toString(2), StandardCharsets.UTF_8), + new Binary(Integer.toString(3), StandardCharsets.UTF_8) + }; + case UNKNOWN: + case VECTOR: + default: + throw new IllegalArgumentException("Unsupported data type: " + dataType); + } + } +}