Skip to content

Commit e78abd7

Browse files
author
zhangchen
committed
ut
1 parent d254d9a commit e78abd7

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

spark/v3.2/spark/src/test/java/org/apache/iceberg/spark/SparkTestBase.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
package org.apache.iceberg.spark;
2121

22+
import java.lang.reflect.Array;
2223
import java.util.List;
2324
import java.util.Map;
2425
import java.util.stream.Collectors;
@@ -39,7 +40,6 @@
3940
import org.junit.AfterClass;
4041
import org.junit.Assert;
4142
import org.junit.BeforeClass;
42-
import org.junit.internal.ExactComparisonCriteria;
4343

4444
import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS;
4545

@@ -119,6 +119,8 @@ private Object[] toJava(Row row) {
119119
return row.getList(pos);
120120
} else if (value instanceof scala.collection.Map) {
121121
return row.getJavaMap(pos);
122+
} else if (value.getClass().isArray() && value.getClass().getComponentType().isPrimitive()) {
123+
return IntStream.range(0, Array.getLength(value)).mapToObj(i -> Array.get(value, i)).toArray();
122124
} else {
123125
return value;
124126
}
@@ -158,11 +160,7 @@ private void assertEquals(String context, Object[] expectedRow, Object[] actualR
158160
Object actualValue = actualRow[col];
159161
if (expectedValue != null && expectedValue.getClass().isArray()) {
160162
String newContext = String.format("%s (nested col %d)", context, col + 1);
161-
if (expectedValue.getClass().getComponentType().isPrimitive()) {
162-
new ExactComparisonCriteria().arrayEquals(newContext, expectedValue, actualValue);
163-
} else {
164-
assertEquals(newContext, (Object[]) expectedValue, (Object[]) actualValue);
165-
}
163+
assertEquals(newContext, (Object[]) expectedValue, (Object[]) actualValue);
166164
} else if (expectedValue != ANY) {
167165
Assert.assertEquals(context + " contents should match", expectedValue, actualValue);
168166
}

spark/v3.2/spark/src/test/java/org/apache/iceberg/spark/sql/TestSelect.java

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ public class TestSelect extends SparkCatalogTestBase {
3737
private int scanEventCount = 0;
3838
private ScanEvent lastScanEvent = null;
3939

40+
private final String binaryTableName = tableName("binary_table");
41+
4042
public TestSelect(String catalogName, String implementation, Map<String, String> config) {
4143
super(catalogName, implementation, config);
4244

@@ -49,8 +51,11 @@ public TestSelect(String catalogName, String implementation, Map<String, String>
4951

5052
@Before
5153
public void createTables() {
52-
sql("CREATE TABLE %s (id bigint, data string, float float, binary binary) USING iceberg", tableName);
53-
sql("INSERT INTO %s VALUES (1, 'a', 1.0, X''), (2, 'b', 2.0, X'11'), (3, 'c', float('NaN'), X'1111')", tableName);
54+
sql("CREATE TABLE %s (id bigint, data string, float float) USING iceberg", tableName);
55+
sql("INSERT INTO %s VALUES (1, 'a', 1.0), (2, 'b', 2.0), (3, 'c', float('NaN'))", tableName);
56+
57+
sql("CREATE TABLE %s (id bigint, binary binary) USING iceberg", binaryTableName);
58+
sql("INSERT INTO %s VALUES (1, X''), (2, X'11'), (3, X'1111')", binaryTableName);
5459

5560
this.scanEventCount = 0;
5661
this.lastScanEvent = null;
@@ -59,21 +64,20 @@ public void createTables() {
5964
@After
6065
public void removeTables() {
6166
sql("DROP TABLE IF EXISTS %s", tableName);
67+
sql("DROP TABLE IF EXISTS %s", binaryTableName);
6268
}
6369

6470
@Test
6571
public void testSelect() {
6672
List<Object[]> expected = ImmutableList.of(
67-
row(1L, "a", 1.0F, new byte[]{}),
68-
row(2L, "b", 2.0F, new byte[]{0x11}),
69-
row(3L, "c", Float.NaN, new byte[]{0x11, 0x11}));
73+
row(1L, "a", 1.0F), row(2L, "b", 2.0F), row(3L, "c", Float.NaN));
7074

7175
assertEquals("Should return all expected rows", expected, sql("SELECT * FROM %s", tableName));
7276
}
7377

7478
@Test
7579
public void testSelectRewrite() {
76-
List<Object[]> expected = ImmutableList.of(row(3L, "c", Float.NaN, new byte[]{0x11, 0x11}));
80+
List<Object[]> expected = ImmutableList.of(row(3L, "c", Float.NaN));
7781

7882
assertEquals("Should return all expected rows", expected,
7983
sql("SELECT * FROM %s where float = float('NaN')", tableName));
@@ -125,9 +129,9 @@ public void testMetadataTables() {
125129

126130
@Test
127131
public void testFilterBinary() {
128-
List<Object[]> expected = ImmutableList.of(row(3L, "c", Float.NaN, new byte[]{0x11, 0x11}));
132+
List<Object[]> expected = ImmutableList.of(row(3L, new Byte[]{0x11, 0x11}));
129133

130134
assertEquals("Should return all expected rows", expected,
131-
sql("SELECT * FROM %s where binary > X'1101'", tableName));
135+
sql("SELECT * FROM %s where binary > X'1101'", binaryTableName));
132136
}
133137
}

0 commit comments

Comments
 (0)