@@ -180,60 +180,47 @@ object BulkCopyUtils extends Logging {
180180 }
181181
182182 /**
183- * getComputedCols
184- * utility function to get computed columns.
185- * Use computed column names to exclude computed column when matching schema.
183+ * getAutoCols
184+ * utility function to get auto generated columns.
185+ * Use auto generated column names to exclude them when matching schema.
186186 */
187- private [spark] def getComputedCols (
187+ private [spark] def getAutoCols (
188188 conn : Connection ,
189- table : String ,
190- hideGraphColumns : Boolean ): List [String ] = {
191- // TODO can optimize this, also evaluate SQLi issues
192- val queryStr = if (hideGraphColumns) s """ IF (SERVERPROPERTY('EngineEdition') = 5 OR SERVERPROPERTY('ProductMajorVersion') >= 14)
193- exec sp_executesql N'SELECT name
194- FROM sys.computed_columns
195- WHERE object_id = OBJECT_ID('' ${table}'')
196- UNION ALL
197- SELECT C.name
198- FROM sys.tables AS T
199- JOIN sys.columns AS C
200- ON T.object_id = C.object_id
201- WHERE T.object_id = OBJECT_ID('' ${table}'')
202- AND (T.is_edge = 1 OR T.is_node = 1)
203- AND C.is_hidden = 0
204- AND C.graph_type = 2'
205- ELSE
206- SELECT name
207- FROM sys.computed_columns
189+ table : String ): List [String ] = {
190+ // auto cols union computed cols, generated always cols, and node / edge table auto cols
191+ val queryStr = s """ SELECT name
192+ FROM sys.columns
208193 WHERE object_id = OBJECT_ID(' ${table}')
194+ AND (is_computed = 1 -- computed column
195+ OR generated_always_type > 0 -- generated always / temporal table
196+ OR (is_hidden = 0 AND graph_type = 2)) -- graph table
209197 """
210- else s " SELECT name FROM sys.computed_columns WHERE object_id = OBJECT_ID(' ${table}'); "
211198
212- val computedColRs = conn.createStatement.executeQuery(queryStr)
213- val computedCols = ListBuffer [String ]()
214- while (computedColRs .next()) {
215- val colName = computedColRs .getString(" name" )
216- computedCols .append(colName)
199+ val autoColRs = conn.createStatement.executeQuery(queryStr)
200+ val autoCols = ListBuffer [String ]()
201+ while (autoColRs .next()) {
202+ val colName = autoColRs .getString(" name" )
203+ autoCols .append(colName)
217204 }
218- computedCols .toList
205+ autoCols .toList
219206 }
220207
221208 /**
222- * dfComputedColCount
209+ * dfAutoColCount
223210 * utility function to get number of computed columns in dataframe.
224211 * Use number of computed columns in dataframe to get number of non computed column in df,
225212 * and compare with the number of non computed column in sql table
226213 */
227- private [spark] def dfComputedColCount (
214+ private [spark] def dfAutoColCount (
228215 dfColNames : List [String ],
229- computedCols : List [String ],
216+ autoCols : List [String ],
230217 dfColCaseMap : Map [String , String ],
231218 isCaseSensitive : Boolean ): Int = {
232219 var dfComputedColCt = 0
233- for (j <- 0 to computedCols .length- 1 ){
234- if (isCaseSensitive && dfColNames.contains(computedCols (j)) ||
235- ! isCaseSensitive && dfColCaseMap.contains(computedCols (j).toLowerCase())
236- && dfColCaseMap(computedCols (j).toLowerCase()) == computedCols (j)) {
220+ for (j <- 0 to autoCols .length- 1 ){
221+ if (isCaseSensitive && dfColNames.contains(autoCols (j)) ||
222+ ! isCaseSensitive && dfColCaseMap.contains(autoCols (j).toLowerCase())
223+ && dfColCaseMap(autoCols (j).toLowerCase()) == autoCols (j)) {
237224 dfComputedColCt += 1
238225 }
239226 }
@@ -284,7 +271,7 @@ SELECT name
284271 val colMetaData = {
285272 if (checkSchema) {
286273 checkExTableType(conn, options)
287- matchSchemas(conn, options.dbtable, df, rs, options.url, isCaseSensitive, options.schemaCheckEnabled, options.hideGraphColumns )
274+ matchSchemas(conn, options.dbtable, df, rs, options.url, isCaseSensitive, options.schemaCheckEnabled)
288275 } else {
289276 defaultColMetadataMap(rs.getMetaData())
290277 }
@@ -310,7 +297,6 @@ SELECT name
310297 * @param url: String,
311298 * @param isCaseSensitive: Boolean
312299 * @param strictSchemaCheck: Boolean
313- * @param hideGraphColumns - Whether to hide the $node_id, $from_id, $to_id, $edge_id columns in SQL graph tables
314300 */
315301 private [spark] def matchSchemas (
316302 conn : Connection ,
@@ -319,40 +305,39 @@ SELECT name
319305 rs : ResultSet ,
320306 url : String ,
321307 isCaseSensitive : Boolean ,
322- strictSchemaCheck : Boolean ,
323- hideGraphColumns : Boolean ): Array [ColumnMetadata ]= {
308+ strictSchemaCheck : Boolean ): Array [ColumnMetadata ]= {
324309 val dfColCaseMap = (df.schema.fieldNames.map(item => item.toLowerCase)
325310 zip df.schema.fieldNames.toList).toMap
326311 val dfCols = df.schema
327312
328313 val tableCols = getSchema(rs, JdbcDialects .get(url))
329- val computedCols = getComputedCols (conn, dbtable, hideGraphColumns )
314+ val autoCols = getAutoCols (conn, dbtable)
330315
331316 val prefix = " Spark Dataframe and SQL Server table have differing"
332317
333- if (computedCols .length == 0 ) {
318+ if (autoCols .length == 0 ) {
334319 assertIfCheckEnabled(dfCols.length == tableCols.length, strictSchemaCheck,
335320 s " ${prefix} numbers of columns " )
336321 } else if (strictSchemaCheck) {
337322 val dfColNames = df.schema.fieldNames.toList
338- val dfComputedColCt = dfComputedColCount (dfColNames, computedCols , dfColCaseMap, isCaseSensitive)
323+ val dfComputedColCt = dfAutoColCount (dfColNames, autoCols , dfColCaseMap, isCaseSensitive)
339324 // if df has computed column(s), check column length using non computed column in df and table.
340325 // non computed column number in df: dfCols.length - dfComputedColCt
341- // non computed column number in table: tableCols.length - computedCols .length
342- assertIfCheckEnabled(dfCols.length- dfComputedColCt == tableCols.length- computedCols .length, strictSchemaCheck,
326+ // non computed column number in table: tableCols.length - autoCols .length
327+ assertIfCheckEnabled(dfCols.length- dfComputedColCt == tableCols.length- autoCols .length, strictSchemaCheck,
343328 s " ${prefix} numbers of columns " )
344329 }
345330
346331
347- val result = new Array [ColumnMetadata ](tableCols.length - computedCols .length)
332+ val result = new Array [ColumnMetadata ](tableCols.length - autoCols .length)
348333 var nonAutoColIndex = 0
349334
350335 for (i <- 0 to tableCols.length- 1 ) {
351336 val tableColName = tableCols(i).name
352337 var dfFieldIndex = - 1
353338 // set dfFieldIndex = -1 for all computed columns to skip ColumnMetadata
354- if (computedCols .contains(tableColName)) {
355- logDebug(s " skipping computed col index $i col name $tableColName dfFieldIndex $dfFieldIndex" )
339+ if (autoCols .contains(tableColName)) {
340+ logDebug(s " skipping auto generated col index $i col name $tableColName dfFieldIndex $dfFieldIndex" )
356341 }else {
357342 var dfColName : String = " "
358343 if (isCaseSensitive) {
0 commit comments