diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/DataSourceOptions.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/DataSourceOptions.scala
index c694174b8c79a..e8ffb09ff9100 100644
--- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/DataSourceOptions.scala
+++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/hudi/DataSourceOptions.scala
@@ -23,7 +23,7 @@ import org.apache.hudi.common.config.{ConfigProperty, DFSPropertiesConfiguration
 import org.apache.hudi.common.fs.ConsistencyGuardConfig
 import org.apache.hudi.common.model.{HoodieTableType, WriteOperationType}
 import org.apache.hudi.common.table.HoodieTableConfig
-import org.apache.hudi.common.util.Option
+import org.apache.hudi.common.util.{Option, StringUtils}
 import org.apache.hudi.common.util.ValidationUtils.checkState
 import org.apache.hudi.config.{HoodieClusteringConfig, HoodieWriteConfig}
 import org.apache.hudi.hive.{HiveSyncConfig, HiveSyncConfigHolder, HiveSyncTool}
@@ -787,9 +787,13 @@ object DataSourceOptionsHelper {
 
   def inferKeyGenClazz(props: TypedProperties): String = {
     val partitionFields = props.getString(DataSourceWriteOptions.PARTITIONPATH_FIELD.key(), null)
-    if (partitionFields != null) {
+    val recordsKeyFields = props.getString(DataSourceWriteOptions.RECORDKEY_FIELD.key(), DataSourceWriteOptions.RECORDKEY_FIELD.defaultValue())
+    inferKeyGenClazz(recordsKeyFields, partitionFields)
+  }
+
+  def inferKeyGenClazz(recordsKeyFields: String, partitionFields: String): String = {
+    if (!StringUtils.isNullOrEmpty(partitionFields)) {
       val numPartFields = partitionFields.split(",").length
-      val recordsKeyFields = props.getString(DataSourceWriteOptions.RECORDKEY_FIELD.key(), DataSourceWriteOptions.RECORDKEY_FIELD.defaultValue())
       val numRecordKeyFields = recordsKeyFields.split(",").length
       if (numPartFields == 1 && numRecordKeyFields == 1) {
         classOf[SimpleKeyGenerator].getName
diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/catalyst/catalog/HoodieCatalogTable.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/catalyst/catalog/HoodieCatalogTable.scala
index 09981e845a108..f1357723200d3 100644
--- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/catalyst/catalog/HoodieCatalogTable.scala
+++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/catalyst/catalog/HoodieCatalogTable.scala
@@ -17,19 +17,19 @@
 
 package org.apache.spark.sql.catalyst.catalog
 
-import org.apache.hudi.AvroConversionUtils
 import org.apache.hudi.DataSourceWriteOptions.OPERATION
 import org.apache.hudi.HoodieWriterUtils._
 import org.apache.hudi.common.config.DFSPropertiesConfiguration
 import org.apache.hudi.common.model.HoodieTableType
 import org.apache.hudi.common.table.{HoodieTableConfig, HoodieTableMetaClient}
 import org.apache.hudi.common.util.{StringUtils, ValidationUtils}
-import org.apache.hudi.keygen.ComplexKeyGenerator
 import org.apache.hudi.keygen.factory.HoodieSparkKeyGeneratorFactory
+import org.apache.hudi.{AvroConversionUtils, DataSourceOptionsHelper}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.avro.SchemaConverters
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.hudi.HoodieOptionConfig
+import org.apache.spark.sql.hudi.HoodieOptionConfig.SQL_KEY_TABLE_PRIMARY_KEY
 import org.apache.spark.sql.hudi.HoodieSqlCommonUtils._
 import org.apache.spark.sql.types.{StructField, StructType}
 import org.apache.spark.sql.{AnalysisException, SparkSession}
@@ -288,7 +288,10 @@ class HoodieCatalogTable(val spark: SparkSession, var table: CatalogTable) exten
         HoodieSparkKeyGeneratorFactory.convertToSparkKeyGenerator(
           originTableConfig(HoodieTableConfig.KEY_GENERATOR_CLASS_NAME.key))
       } else {
-        extraConfig(HoodieTableConfig.KEY_GENERATOR_CLASS_NAME.key) = classOf[ComplexKeyGenerator].getCanonicalName
+        val primaryKeys = table.properties.get(SQL_KEY_TABLE_PRIMARY_KEY.sqlKeyName).getOrElse(SQL_KEY_TABLE_PRIMARY_KEY.defaultValue.get)
+        val partitions = table.partitionColumnNames.mkString(",")
+        extraConfig(HoodieTableConfig.KEY_GENERATOR_CLASS_NAME.key) =
+          DataSourceOptionsHelper.inferKeyGenClazz(primaryKeys, partitions)
       }
       extraConfig.toMap
     }
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieSparkSqlWriter.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieSparkSqlWriter.scala
index 93469f2796cf9..d3640474b252c 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieSparkSqlWriter.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestHoodieSparkSqlWriter.scala
@@ -872,7 +872,7 @@ class TestHoodieSparkSqlWriter {
       .setBasePath(tablePath1).build().getTableConfig
     assert(tableConfig1.getHiveStylePartitioningEnable == "true")
     assert(tableConfig1.getUrlEncodePartitioning == "false")
-    assert(tableConfig1.getKeyGeneratorClassName == classOf[ComplexKeyGenerator].getName)
+    assert(tableConfig1.getKeyGeneratorClassName == classOf[SimpleKeyGenerator].getName)
     df.write.format("hudi")
       .options(options)
       .option(HoodieWriteConfig.TBL_NAME.key, tableName1)
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestCreateTable.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestCreateTable.scala
index d3dbf9a6e6aab..6a6b41da7fb73 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestCreateTable.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/spark/sql/hudi/TestCreateTable.scala
@@ -28,7 +28,6 @@ import org.apache.spark.sql.SaveMode
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.CatalogTableType
 import org.apache.spark.sql.types._
-
 import org.junit.jupiter.api.Assertions.assertFalse
 
 import scala.collection.JavaConverters._
@@ -137,7 +136,7 @@ class TestCreateTable extends HoodieSparkSqlTestBase {
     assertResult("dt")(tableConfig(HoodieTableConfig.PARTITION_FIELDS.key))
     assertResult("id")(tableConfig(HoodieTableConfig.RECORDKEY_FIELDS.key))
     assertResult("ts")(tableConfig(HoodieTableConfig.PRECOMBINE_FIELD.key))
-    assertResult(classOf[ComplexKeyGenerator].getCanonicalName)(tableConfig(HoodieTableConfig.KEY_GENERATOR_CLASS_NAME.key))
+    assertResult(classOf[SimpleKeyGenerator].getCanonicalName)(tableConfig(HoodieTableConfig.KEY_GENERATOR_CLASS_NAME.key))
     assertResult("default")(tableConfig(HoodieTableConfig.DATABASE_NAME.key()))
     assertResult(tableName)(tableConfig(HoodieTableConfig.NAME.key()))
     assertFalse(tableConfig.contains(OPERATION.key()))
@@ -944,4 +943,75 @@ class TestCreateTable extends HoodieSparkSqlTestBase {
 
     spark.sql("use default")
   }
+
+  test("Test Infer KeyGenClazz") {
+    def checkKeyGenerator(targetGenerator: String, tableName: String) = {
+      val tablePath = spark.sessionState.catalog.getTableMetadata(TableIdentifier(tableName)).location.getPath
+      val metaClient = HoodieTableMetaClient.builder()
+        .setBasePath(tablePath)
+        .setConf(spark.sessionState.newHadoopConf())
+        .build()
+      val realKeyGenerator =
+        metaClient.getTableConfig.getProps.asScala.toMap.get(HoodieTableConfig.KEY_GENERATOR_CLASS_NAME.key).get
+      assertResult(targetGenerator)(realKeyGenerator)
+    }
+
+    val tableName = generateTableName
+
+    // Test non-partitioned table
+    spark.sql(
+      s"""
+         | create table $tableName (
+         |  id int,
+         |  name string,
+         |  price double,
+         |  ts long
+         | ) using hudi
+         | comment "This is a simple hudi table"
+         | tblproperties (
+         |   primaryKey = 'id',
+         |   preCombineField = 'ts'
+         | )
+       """.stripMargin)
+    checkKeyGenerator("org.apache.hudi.keygen.NonpartitionedKeyGenerator", tableName)
+    spark.sql(s"drop table $tableName")
+
+    // Test single partitioned table
+    spark.sql(
+      s"""
+         | create table $tableName (
+         |  id int,
+         |  name string,
+         |  price double,
+         |  ts long
+         | ) using hudi
+         | comment "This is a simple hudi table"
+         | partitioned by (ts)
+         | tblproperties (
+         |   primaryKey = 'id',
+         |   preCombineField = 'ts'
+         | )
+       """.stripMargin)
+    checkKeyGenerator("org.apache.hudi.keygen.SimpleKeyGenerator", tableName)
+    spark.sql(s"drop table $tableName")
+
+    // Test single partitioned table with dual record keys
+    spark.sql(
+      s"""
+         | create table $tableName (
+         |  id int,
+         |  name string,
+         |  price double,
+         |  ts long
+         | ) using hudi
+         | comment "This is a simple hudi table"
+         | partitioned by (ts)
+         | tblproperties (
+         |   primaryKey = 'id,name',
+         |   preCombineField = 'ts'
+         | )
+       """.stripMargin)
+    checkKeyGenerator("org.apache.hudi.keygen.ComplexKeyGenerator", tableName)
+    spark.sql(s"drop table $tableName")
+  }
 }
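
Note for reviewers: the rule encoded by the new inferKeyGenClazz(recordsKeyFields, partitionFields) overload, as exercised by the three CREATE TABLE cases above, is roughly the sketch below. Only the SimpleKeyGenerator branch appears verbatim in the hunk; the ComplexKeyGenerator fallback and the NonpartitionedKeyGenerator else-branch lie outside the hunk context and are inferred from the tests, so treat them as assumptions. KeyGenInferenceSketch is a hypothetical name used purely for illustration.

import org.apache.hudi.common.util.StringUtils
import org.apache.hudi.keygen.{ComplexKeyGenerator, NonpartitionedKeyGenerator, SimpleKeyGenerator}

object KeyGenInferenceSketch {
  // Standalone mirror of the inferred rule, for illustration only:
  //   no partition columns                     -> NonpartitionedKeyGenerator (assumed from the tests)
  //   one record key field + one partition field -> SimpleKeyGenerator (shown in the hunk)
  //   anything else                            -> ComplexKeyGenerator (assumed from the tests)
  def inferKeyGenClazz(recordsKeyFields: String, partitionFields: String): String = {
    if (!StringUtils.isNullOrEmpty(partitionFields)) {
      val numPartFields = partitionFields.split(",").length
      val numRecordKeyFields = recordsKeyFields.split(",").length
      if (numPartFields == 1 && numRecordKeyFields == 1) {
        classOf[SimpleKeyGenerator].getName
      } else {
        classOf[ComplexKeyGenerator].getName
      }
    } else {
      classOf[NonpartitionedKeyGenerator].getName
    }
  }
}

// Expected behaviour, matching the three CREATE TABLE cases tested above:
//   inferKeyGenClazz("id", null)      -> org.apache.hudi.keygen.NonpartitionedKeyGenerator
//   inferKeyGenClazz("id", "ts")      -> org.apache.hudi.keygen.SimpleKeyGenerator
//   inferKeyGenClazz("id,name", "ts") -> org.apache.hudi.keygen.ComplexKeyGenerator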