diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/SqlIdUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/SqlIdUtil.scala
new file mode 100644
index 000000000000..5c9178131d22
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/SqlIdUtil.scala
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.Locale
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.SparkException
+
+/**
+ * A three part table identifier. The first two parts can be null.
+ *
+ * @param database The database name.
+ * @param schema The schema name.
+ * @param table The table name.
+ */
+case class TableId(database: String, schema: String, table: String)
+
+/**
+ * Utility methods for SQL identifiers. These methods were loosely
+ * translated from org.apache.derby.iapi.util.IdUtil and
+ * org.apache.derby.iapi.util.StringUtil.
+ */
+object SqlIdUtil {
+
+  private val OneQuote = """""""
+  private val TwoQuotes = """"""""
+  private val DefaultQuote = '"'
+
+  // Regular expression defining one id in a dot-separated SQL identifier chain
+  private val OneIdString =
+    "(\\s)*((" +                    // leading spaces ok
+    """\p{Alpha}(\p{Alnum}|_)*""" + // regular identifier (no quotes)
+    ")|(" +                         // or
+    """"(""|[^"])+"""" +            // delimited identifier (quoted)
+    "))(\\s)*"                      // trailing spaces ok
+
+  /**
+   * Quote a string so that it can be used as an identifier or a string
+   * literal in SQL statements. Identifiers are usually surrounded by double
+   * quotes and string literals are surrounded by single quotes. If the
+   * string contains quote characters, they are escaped.
+   *
+   * @param source the string to quote
+   * @param quote the framing quote character (e.g. ', ", `)
+   * @return a string quoted with the indicated quote character
+   */
+  def quoteString(source: String, quote: Char): String = {
+    // Normally, the quoted string is two characters longer than the source
+    // string (because of the start quote and the end quote).
+    val quoted = new StringBuilder(source.length() + 2)
+
+    quoted.append(quote)
+    for (ch <- source) {
+      quoted.append(ch)
+      if (ch == quote) quoted.append(quote)
+    }
+    quoted.append(quote)
+    quoted.toString()
+  }
+
+  /**
+   * Parse a user-supplied object id of the form
+   *
+   *   [[database.]schema.]objectName
+   *
+   * into a TableId(database, schema, objectName). The database and schema
+   * names may be null. The caller must supply the database-specific quote
+   * character which is used to frame delimited ids. For most databases this
+   * is the " character; for Hive and MySQL it is the ` character. The
+   * caller must also specify whether the database uppercases or lowercases
+   * unquoted identifiers when they are stored in its metadata catalogs.
+   *
+   * The fields of the TableId are normalized to the case convention used
+   * by the database's catalogs. So for a database which uses " for quoted
+   * identifiers and which uppercases ids in its metadata catalogs, the
+   * string
+   *
+   *   "foo".bar
+   *
+   * would result in
+   *
+   *   TableId(null, foo, BAR)
+   *
+   * @param rawName The user-supplied name.
+   * @param quote The db-specific character which frames delimited ids.
+   * @param upperCase True if the db uppercases un-delimited ids.
+   */
+  def parseSqlIds(
+      rawName: String,
+      quote: Char,
+      upperCase: Boolean): TableId = {
+    val parsed = parseMultiPartSqlIdentifier(rawName, quote, upperCase)
+
+    parsed.length match {
+      case 1 => TableId(null, null, parsed(0))
+      case 2 => TableId(null, parsed(0), parsed(1))
+      case 3 => TableId(parsed(0), parsed(1), parsed(2))
+      case _ => throw new SparkException("Unparsable object id: " + rawName)
+    }
+  }
+
+  /**
+   * Parse a multi-part (dot-separated) chain of SQL identifiers from the
+   * string provided. Raises an exception if the string does not contain
+   * valid SQL identifiers. The returned strings are the normalized forms
+   * of those identifiers.
+   *
+   * @param rawName The string to be parsed
+   * @param quote The character which frames a delimited id (e.g., " or `)
+   * @param upperCase True if SQL ids are normalized to upper case.
+   * @return An array of strings made by breaking the input string at its dots, '.'.
+   * @throws SparkException Invalid SQL identifier.
+   */
+  private def parseMultiPartSqlIdentifier(
+      rawName: String,
+      quote: Char,
+      upperCase: Boolean): ArrayBuffer[String] = {
+
+    // Construct the regex, accounting for the caller-supplied quote character.
+    var regexString = OneIdString
+    if (quote != DefaultQuote) {
+      regexString = regexString.replace(DefaultQuote, quote)
+    }
+    val oneIdRegex = regexString.r
+
+    // Loop through the raw string, one identifier at a time. Discard spaces
+    // around the identifiers. Discard the dots which separate one identifier
+    // from the next.
+    val result = ArrayBuffer[String]()
+    var keepGoing = true
+    var remainingString = rawName
+    while (keepGoing) {
+      oneIdRegex.findPrefixOf(remainingString) match {
+
+        case Some(paddedId) =>
+          val paddedIdLength = paddedId.length
+          result.append(normalize(paddedId.trim, quote, upperCase))
+          if (remainingString.length == paddedIdLength) {
+            keepGoing = false // we're done
+          } else if (remainingString.charAt(paddedIdLength) == '.') {
+            // Chop off the old identifier and the dot separator, then
+            // continue looking for more ids in the rest of the string.
+            remainingString = remainingString.substring(paddedIdLength + 1)
+          } else {
+            throw parseError(rawName)
+          }
+
+        case _ =>
+          throw parseError(rawName)
+      } // end matching an id
+    } // end of loop through ids
+
+    result
+  }
+
+  /**
+   * Normalize a SQL identifier to the case used by the target
+   * database's metadata catalogs.
+   *
+   * @param rawName The string to be normalized (may be framed by quotes)
+   * @param quote The character which frames a delimited id (e.g., " or `)
+   * @param upperCase True if SQL ids are normalized to upper case.
+   * @return The normalized identifier.
+   */
+  private def normalize(rawName: String, quote: Char, upperCase: Boolean): String = {
+    // regular id
+    if (rawName.charAt(0) != quote) adjustCase(rawName, upperCase)
+    // delimited id
+    else stripQuotes(rawName, quote)
+  }
+
+  /**
+   * Adjust the case of an unquoted identifier to the case convention used
+   * by the metadata catalogs of the target database. Always uses the
+   * java.util.Locale.ENGLISH locale.
+   *
+   * @param rawName The identifier whose case should be adjusted.
+   * @param upperCase True if SQL ids are normalized to upper case.
+   * @return The properly cased identifier.
+   */
+  private def adjustCase(rawName: String, upperCase: Boolean): String = {
+    if (upperCase) rawName.toUpperCase(Locale.ENGLISH)
+    else rawName.toLowerCase(Locale.ENGLISH)
+  }
+
+  /**
+   * Strip framing quotes from a delimited id and un-escape interior quotes.
+   *
+   * @param rawName The delimited id, including its framing quotes.
+   * @param quote The database-specific quote character.
+   * @return The id without framing quotes, with interior quotes un-escaped.
+   */
+  private def stripQuotes(rawName: String, quote: Char): String = {
+    var oneQuote = OneQuote
+    var twoQuotes = TwoQuotes
+    if (quote != DefaultQuote) {
+      oneQuote = OneQuote.replace(DefaultQuote, quote)
+      twoQuotes = TwoQuotes.replace(DefaultQuote, quote)
+    }
+    rawName.substring(1, rawName.length - 1).replace(twoQuotes, oneQuote)
+  }
+
+  /**
+   * Create a parsing exception.
+   *
+   * @param orig The full text being parsed
+   * @return A SparkException describing a parsing error.
+   */
+  private def parseError(orig: String): SparkException = {
+    new SparkException("Error parsing SQL identifier: " + orig)
+  }
+
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 7887e559a302..1bce5d39a0ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import java.util.Properties
 
 import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.jdbc.JdbcDialects
@@ -255,6 +256,7 @@ final class DataFrameWriter private[sql](df: DataFrame) {
     // connectionProperties should override settings in extraOptions
     props.putAll(connectionProperties)
     val conn = JdbcUtils.createConnection(url, props)
+    val dialect = JdbcDialects.get(url)
 
     try {
       var tableExists = JdbcUtils.tableExists(conn, url, table)
@@ -268,13 +270,14 @@ final class DataFrameWriter private[sql](df: DataFrame) {
       }
 
       if (mode == SaveMode.Overwrite && tableExists) {
-        JdbcUtils.dropTable(conn, table)
+        JdbcUtils.dropTable(conn, dialect, table)
         tableExists = false
       }
 
       // Create the table if the table didn't exist.
       if (!tableExists) {
         val schema = JdbcUtils.schemaString(df, url)
+        dialect.vetSqlIdentifier(table)
         val sql = s"CREATE TABLE $table ($schema)"
         conn.prepareStatement(sql).executeUpdate()
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index f89d55b20e21..fc50b91a7622 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -23,7 +23,7 @@ import java.util.Properties
 
 import scala.util.Try
 
 import org.apache.spark.Logging
-import org.apache.spark.sql.jdbc.JdbcDialects
+import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Row}
@@ -54,14 +54,20 @@ object JdbcUtils extends Logging {
   /**
    * Drops a table from the JDBC database.
    */
-  def dropTable(conn: Connection, table: String): Unit = {
+  def dropTable(conn: Connection, dialect: JdbcDialect, table: String): Unit = {
+    dialect.vetSqlIdentifier(table)
     conn.prepareStatement(s"DROP TABLE $table").executeUpdate()
   }
 
   /**
    * Returns a PreparedStatement that inserts a row into table via conn.
    */
-  def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = {
+  def insertStatement(
+      conn: Connection,
+      dialect: JdbcDialect,
+      table: String,
+      rddSchema: StructType): PreparedStatement = {
+    dialect.vetSqlIdentifier(table)
     val sql = new StringBuilder(s"INSERT INTO $table VALUES (")
     var fieldsLeft = rddSchema.fields.length
     while (fieldsLeft > 0) {
@@ -88,6 +94,7 @@ object JdbcUtils extends Logging {
    */
   def savePartition(
       getConnection: () => Connection,
+      dialect: JdbcDialect,
       table: String,
       iterator: Iterator[Row],
       rddSchema: StructType,
@@ -97,7 +104,7 @@ object JdbcUtils extends Logging {
     var committed = false
     try {
       conn.setAutoCommit(false) // Everything in the same db transaction.
-      val stmt = insertStatement(conn, table, rddSchema)
+      val stmt = insertStatement(conn, dialect, table, rddSchema)
       try {
         var rowCount = 0
         while (iterator.hasNext) {
@@ -225,8 +232,10 @@ object JdbcUtils extends Logging {
     val driver: String = DriverRegistry.getDriverClassName(url)
     val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties)
     val batchSize = properties.getProperty("batchsize", "1000").toInt
+    val jdbcDialect = JdbcDialects.get(url)
     df.foreachPartition { iterator =>
-      savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize)
+      savePartition(getConnection, jdbcDialect, table, iterator,
+        rddSchema, nullTypes, batchSize)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index 88ae83957a70..876d782c0c90 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.jdbc
 
 import java.sql.Types
 
+import org.apache.spark.SparkException
+import org.apache.spark.sql.SqlIdUtil._
 import org.apache.spark.sql.types._
 import org.apache.spark.annotation.DeveloperApi
 
@@ -62,6 +64,12 @@ abstract class JdbcDialect {
    */
   def canHandle(url : String): Boolean
 
+  /**
+   * Return the character used to frame delimited identifiers in this database.
+   * @return The delimited id character (usually ", sometimes `)
+   */
+  def quoteChar: Char = '"'
+
   /**
    * Get the custom datatype mapping for the given jdbc meta information.
    * @param sqlType The sql type (see java.sql.Types)
@@ -86,19 +94,51 @@ abstract class JdbcDialect {
    * name is a reserved keyword, or in case it contains characters that require quotes (e.g. space).
    */
   def quoteIdentifier(colName: String): String = {
-    s""""$colName""""
+    quoteString(colName, quoteChar)
+  }
+
+  /**
+   * Get the SQL query that should be used to find if the given table exists.
+   * Call this method (and not tableExistsQuery) in order to verify
+   * that the table name is properly formed.
+   * @param table The name of the table.
+   * @return The SQL query to use for checking the table.
+   * @throws org.apache.spark.SparkException On invalid table name.
+   */
+  final def getTableExistsQuery(table: String): String = {
+    vetSqlIdentifier(table)
+    tableExistsQuery(table)
   }
 
   /**
    * Get the SQL query that should be used to find if the given table exists. Dialects can
    * override this method to return a query that works best in a particular database.
+   * Don't expose this method outside this class and its subclasses.
+   * Other consumers should call getTableExistsQuery instead. That method
+   * verifies that the table name is properly formed.
    * @param table The name of the table.
    * @return The SQL query to use for checking the table.
    */
-  def getTableExistsQuery(table: String): String = {
+  protected def tableExistsQuery(table: String): String = {
     s"SELECT * FROM $table WHERE 1=0"
   }
 
+  /**
+   * Vet a user-supplied object id of the form
+   *   [[database.]schema.]objectName
+   * by parsing it into a (database, schema, objectName) triple.
+   * The database and schema names may be empty. Raises a
+   * SparkException if the user-supplied id is malformed, e.g., a
+   * string like "foo; drop database finance;" intended for a SQL
+   * injection attack.
+   *
+   * @param rawId The user-supplied object id (name).
+   * @throws org.apache.spark.SparkException On invalid ids.
+   */
+  def vetSqlIdentifier(rawId: String): Unit = {
+    // Raises a SparkException if the string doesn't parse.
+    parseSqlIds(rawId, quoteChar, upperCase = true)
+  }
 }
 
 /**
@@ -152,6 +192,12 @@ object JdbcDialects {
       case _ => new AggregatedDialect(matchingDialects)
     }
   }
+
+  /**
+   * Get all dialects (useful for testing purposes).
+   */
+  private[sql] def getAllDialects(): List[JdbcDialect] = dialects
+
 }
 
 /**
@@ -217,7 +263,7 @@ case object PostgresDialect extends JdbcDialect {
     case _ => None
   }
 
-  override def getTableExistsQuery(table: String): String = {
+  override protected def tableExistsQuery(table: String): String = {
     s"SELECT 1 FROM $table LIMIT 1"
   }
 
@@ -230,6 +276,7 @@ case object PostgresDialect extends JdbcDialect {
 @DeveloperApi
 case object MySQLDialect extends JdbcDialect {
   override def canHandle(url : String): Boolean = url.startsWith("jdbc:mysql")
+  override def quoteChar: Char = '`'
   override def getCatalystType(
       sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
     if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
@@ -242,13 +289,20 @@ case object MySQLDialect extends JdbcDialect {
     } else None
   }
 
-  override def quoteIdentifier(colName: String): String = {
-    s"`$colName`"
-  }
-
-  override def getTableExistsQuery(table: String): String = {
+  override protected def tableExistsQuery(table: String): String = {
     s"SELECT 1 FROM $table LIMIT 1"
   }
+
+  // The default implementation of vetSqlIdentifier allows embedded,
+  // escaped quotes inside quoted identifiers. A database which
+  // forbids embedded quotes will not have such names rejected here.
+  // Instead, those names will reach the database as an ungrammatical
+  // sequence of quoted identifiers. In order to get a better error
+  // message for those names, a dialect may provide its own
+  // implementation which follows the database's identifier grammar
+  // more closely:
+  //
+  // override def vetSqlIdentifier(rawId: String): Unit
 }
 
 /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index d530b1a469ce..6579bb0fe5a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -25,6 +25,7 @@ import org.h2.jdbc.JdbcSQLException
 import org.scalatest.BeforeAndAfter
 
+import org.apache.spark.SparkException
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -484,4 +485,63 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext
     assert(h2.getTableExistsQuery(table) == defaultQuery)
     assert(derby.getTableExistsQuery(table) == defaultQuery)
   }
+
+  /**
+   * Verify that the JdbcDialect rejects an illegal name.
+   * @param dialect The JdbcDialect.
+   * @param tableName A bad table name.
+   */
+  def badNameVetter(dialect: JdbcDialect, tableName: String): Unit = {
+    val badTableName = tableName.replace('"', dialect.quoteChar)
+    val dialectName = dialect.getClass.getName
+    try {
+      dialect.vetSqlIdentifier(badTableName)
+      fail(dialectName + " should have rejected " + badTableName)
+    } catch {
+      case exc: SparkException =>
+        val expectedMessage = "Error parsing SQL identifier: " + badTableName
+        assert(exc.getMessage === expectedMessage)
+      case obj: Throwable =>
+        fail("Unexpected vetting failure when dialect " + dialectName +
+          " vets table name " + badTableName + ": " + obj.toString)
+    }
+  }
+
+  test("verify that JDBC dialects vet table names") {
+    val badNames =
+      """bad"name""" ::
+      """foo; drop database finance;""" ::
+      """foo.""" ::
+      """.foo""" ::
+      """foo bar""" ::
+      """fo"o""" ::
+      Nil
+    val allDialects = JdbcDialects.getAllDialects()
+    for (d <- allDialects; b <- badNames) badNameVetter(d, b)
+
+    val goodNames =
+      """foo.bar""" ::
+      """"foo".bar""" ::
+      """"foo"."bar"""" ::
+      """foo."bar"""" ::
+      """"foo.bar"""" ::
+      """ foo""" ::
+      """"foo"""" ::
+      """"foo bar"""" ::
+      """"foo."""" ::
+      """foo .bar""" ::
+      """foo.bar.wibble""" ::
+      """"fo""o"""" ::
+      Nil
+    for (d <- allDialects; g <- goodNames) {
+      val goodTableName = g.replace('"', d.quoteChar)
+      try {
+        d.getTableExistsQuery(goodTableName)
+      } catch {
+        case obj: Throwable =>
+          fail(d.getClass.getName + " couldn't handle " +
+            goodTableName + ": " + obj.toString)
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index e23ee6693133..0430c2b50d9c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc
 import java.sql.DriverManager
 import java.util.Properties
 
 import org.scalatest.BeforeAndAfter
 
+import org.apache.spark.SparkException
 import org.apache.spark.sql.{Row, SaveMode}
@@ -151,4 +152,24 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
     assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
     assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
   }
+
+  test("Negative: CREATE with illegal table name") {
+    val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+    df.write.jdbc(url, "TEST.DUMMY", new Properties)
+    assert(2 === sqlContext.read.jdbc(url, "TEST.DUMMY", new Properties).count)
+
+    try {
+      df.write.jdbc(url, "TEST.FOO(A INT); DROP TABLE TEST.DUMMY;", new Properties)
+      fail("Table creation should have failed.")
+    } catch {
+      case exc: SparkException =>
+        val expectedMessage =
+          "Error parsing SQL identifier: TEST.FOO(A INT); DROP TABLE TEST.DUMMY;"
+        assert(exc.getMessage === expectedMessage)
+      case obj: Throwable =>
+        fail("Unexpected failure for table creation: " + obj.toString)
+    }
+    assert(2 === sqlContext.read.jdbc(url, "TEST.DUMMY", new Properties).count)
+  }
 }
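A minimal usage sketch of the behavior this patch introduces (an illustration, not code from the patch itself): parseSqlIds normalizes regular ids to the catalog's case convention, strips and un-escapes delimited ids, and raises a SparkException for anything that does not parse as a dot-separated chain of identifiers. This is exactly the check that vetSqlIdentifier runs before a table name is spliced into a DROP/CREATE/INSERT statement.

    import org.apache.spark.SparkException
    import org.apache.spark.sql.SqlIdUtil

    // A regular id is case-normalized; a delimited id keeps its exact spelling.
    SqlIdUtil.parseSqlIds("\"foo\".bar", '"', upperCase = true)
    // => TableId(null, "foo", "BAR")

    // MySQLDialect frames delimited ids with backticks (quoteChar = '`').
    SqlIdUtil.parseSqlIds("`my table`", '`', upperCase = false)
    // => TableId(null, null, "my table")

    // An injection-style string fails to parse instead of reaching the database.
    try {
      SqlIdUtil.parseSqlIds("foo; drop database finance;", '"', upperCase = true)
    } catch {
      case e: SparkException => println(e.getMessage)
      // Error parsing SQL identifier: foo; drop database finance;
    }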