Skip to content
Closed
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions kyuubi-common/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@
<artifactId>derby</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>com.jakewharton.fliptables</groupId>
<artifactId>fliptables</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,44 @@

package org.apache.kyuubi.service.authentication

import java.sql.{Connection, PreparedStatement, Statement}
import java.util.Properties
import javax.security.sasl.AuthenticationException
import javax.sql.DataSource

import com.zaxxer.hikari.{HikariConfig, HikariDataSource}
import com.zaxxer.hikari.util.DriverDataSource
import org.apache.commons.lang3.StringUtils

import org.apache.kyuubi.Logging
import org.apache.kyuubi.config.KyuubiConf
import org.apache.kyuubi.config.KyuubiConf._
import org.apache.kyuubi.util.JdbcUtils

class JdbcAuthenticationProviderImpl(conf: KyuubiConf) extends PasswdAuthenticationProvider
with Logging {

private val driverClass = conf.get(AUTHENTICATION_JDBC_DRIVER)
private val jdbcUrl = conf.get(AUTHENTICATION_JDBC_URL)
private val jdbcUsername = conf.get(AUTHENTICATION_JDBC_USERNAME)
private val jdbcUserPassword = conf.get(AUTHENTICATION_JDBC_PASSWORD)
private val authQuerySql = conf.get(AUTHENTICATION_JDBC_QUERY)

private val SQL_PLACEHOLDER_REGEX = """\$\{.+?}""".r
private val USERNAME_SQL_PLACEHOLDER = "${username}"
private val PASSWORD_SQL_PLACEHOLDER = "${password}"

private val driverClass = conf.get(AUTHENTICATION_JDBC_DRIVER)
private val jdbcUrl = conf.get(AUTHENTICATION_JDBC_URL)
private val username = conf.get(AUTHENTICATION_JDBC_USERNAME)
private val password = conf.get(AUTHENTICATION_JDBC_PASSWORD)
private val authQuery = conf.get(AUTHENTICATION_JDBC_QUERY)

private val redactedPasswd = password match {
case Some(value) => s"${"*" * value.length}(length: ${value.length})"
case None => "(empty)"
}

checkJdbcConfigs()

private[kyuubi] val hikariDataSource = getHikariDataSource
implicit private[kyuubi] val ds: DataSource = new DriverDataSource(
jdbcUrl.orNull,
driverClass.orNull,
new Properties,
username.orNull,
password.orNull)

/**
* The authenticate method is called by the Kyuubi Server authentication layer
Expand All @@ -62,37 +73,27 @@ class JdbcAuthenticationProviderImpl(conf: KyuubiConf) extends PasswdAuthenticat
s" or contains blank space")
}

if (StringUtils.isBlank(password)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we bring back password not blank checking? If auth method allow passing with missing password, the whole authentication relies on the confidentiality of username alone. @pan3793

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should allow empty password, and the auth query will fail the athentication is password is required.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agree

throw new AuthenticationException(s"Error validating, password is null" +
s" or contains blank space")
}

var connection: Connection = null
var queryStatement: PreparedStatement = null

try {
connection = hikariDataSource.getConnection

queryStatement = getAndPrepareQueryStatement(connection, user, password)

val resultSet = queryStatement.executeQuery()

if (resultSet == null || !resultSet.next()) {
// auth failed
throw new AuthenticationException(s"Password does not match or no such user. user:" +
s" $user , password length: ${password.length}")
debug(s"prepared auth query: $preparedQuery")
JdbcUtils.executeQuery(preparedQuery) { stmt =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we a add a query time out for JdbcUtils.executeQuery to prevent blocking out of connection timeout ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can extend KyuubiConf to allow configuring arbitrary properties, but it's out of scope in this PR, better do it in new PR.

stmt.setMaxRows(1) // minimum result size required for authentication
queryPlaceholders.zipWithIndex.foreach {
case (USERNAME_SQL_PLACEHOLDER, i) => stmt.setString(i + 1, user)
case (PASSWORD_SQL_PLACEHOLDER, i) => stmt.setString(i + 1, password)
case (p, _) => throw new IllegalArgumentException(
s"Unrecognized placeholder in Query SQL: $p")
}
} { resultSet =>
if (resultSet == null || !resultSet.next()) {
throw new AuthenticationException("Password does not match or no such user. " +
s"user: $user, password: $redactedPasswd")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pan3793 It seems throwing woring redacted password with authdb password. It should throws with connection password. Let me fix it in next PR?

Copy link
Member Author

@pan3793 pan3793 Aug 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, my fault, thanks for catching this issue. Sure, let's fix it in followup

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good. And we have the fix and ut in #3288 now.

}
}

// auth passed

} catch {
case e: AuthenticationException =>
throw e
case e: Exception =>
error("Cannot get user info", e);
throw e
} finally {
closeDbConnection(connection, queryStatement)
case rethrow: AuthenticationException =>
throw rethrow
case rethrow: Exception =>
throw new AuthenticationException("Cannot get user info", rethrow)
}
}

Expand All @@ -101,104 +102,31 @@ class JdbcAuthenticationProviderImpl(conf: KyuubiConf) extends PasswdAuthenticat

debug(configLog("Driver Class", driverClass.orNull))
debug(configLog("JDBC URL", jdbcUrl.orNull))
debug(configLog("Database username", jdbcUsername.orNull))
debug(configLog("Database password length", jdbcUserPassword.getOrElse("").length.toString))
debug(configLog("Query SQL", authQuerySql.orNull))
debug(configLog("Database username", username.orNull))
debug(configLog("Database password", redactedPasswd))
debug(configLog("Query SQL", authQuery.orNull))

// Check if JDBC parameters valid
if (driverClass.isEmpty) {
throw new IllegalArgumentException("JDBC driver class is not configured.")
require(driverClass.nonEmpty, "JDBC driver class is not configured.")
require(jdbcUrl.nonEmpty, "JDBC url is not configured.")
require(username.nonEmpty, "JDBC username is not configured")
// allow empty password
require(authQuery.nonEmpty, "Query SQL is not configured")

val query = authQuery.get.trim.toLowerCase
// allow simple select query sql only, complex query like CTE is not allowed
require(query.startsWith("select"), "Query SQL must start with 'SELECT'")
if (!query.contains("where")) {
warn("Query SQL does not contains 'WHERE' keyword")
}

if (jdbcUrl.isEmpty) {
throw new IllegalArgumentException("JDBC url is not configured")
}

if (jdbcUsername.isEmpty || jdbcUserPassword.isEmpty) {
throw new IllegalArgumentException("JDBC username or password is not configured")
}

// Check Query SQL
if (authQuerySql.isEmpty) {
throw new IllegalArgumentException("Query SQL is not configured")
}
val querySqlInLowerCase = authQuerySql.get.trim.toLowerCase
if (!querySqlInLowerCase.startsWith("select")) { // allow select query sql only
throw new IllegalArgumentException("Query SQL must start with \"SELECT\"");
}
if (!querySqlInLowerCase.contains("where")) {
warn("Query SQL does not contains \"WHERE\" keyword");
}
if (!querySqlInLowerCase.contains("${username}")) {
warn("Query SQL does not contains \"${username}\" placeholder");
if (!query.contains("${username}")) {
warn("Query SQL does not contains '${username}' placeholder")
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
}
}
if (!query.contains("${password}")) {
warn("Query SQL does not contains '${password}' placeholder")
}

How about adding check for ${password} placeholder here, as auth method checked password not blank in first place? @pan3793

Copy link
Member Author

@pan3793 pan3793 Aug 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added warning message

}

private def getPlaceholderList(sql: String): List[String] = {
SQL_PLACEHOLDER_REGEX.findAllMatchIn(sql)
.map(m => m.matched)
.toList
}

private def getAndPrepareQueryStatement(
connection: Connection,
user: String,
password: String): PreparedStatement = {
private def preparedQuery: String =
SQL_PLACEHOLDER_REGEX.replaceAllIn(authQuery.get, "?")

val preparedSql: String = {
SQL_PLACEHOLDER_REGEX.replaceAllIn(authQuerySql.get, "?")
}
debug(s"prepared auth query sql: $preparedSql")

val stmt = connection.prepareStatement(preparedSql)
stmt.setMaxRows(1) // minimum result size required for authentication

// Extract placeholder list and fill parameters to placeholders
val placeholderList: List[String] = getPlaceholderList(authQuerySql.get)
for (i <- placeholderList.indices) {
val param = placeholderList(i) match {
case USERNAME_SQL_PLACEHOLDER => user
case PASSWORD_SQL_PLACEHOLDER => password
case otherPlaceholder =>
throw new IllegalArgumentException(
s"Unrecognized Placeholder In Query SQL: $otherPlaceholder")
}

stmt.setString(i + 1, param)
}

stmt
}

private def closeDbConnection(connection: Connection, statement: Statement): Unit = {
if (statement != null && !statement.isClosed) {
try {
statement.close()
} catch {
case e: Exception =>
error("Cannot close PreparedStatement to auth database ", e)
}
}

if (connection != null && !connection.isClosed) {
try {
connection.close()
} catch {
case e: Exception =>
error("Cannot close connection to auth database ", e)
}
}
}

private def getHikariDataSource: HikariDataSource = {
val datasourceProperties = new Properties()
val hikariConfig = new HikariConfig(datasourceProperties)
hikariConfig.setDriverClassName(driverClass.orNull)
hikariConfig.setJdbcUrl(jdbcUrl.orNull)
hikariConfig.setUsername(jdbcUsername.orNull)
hikariConfig.setPassword(jdbcUserPassword.orNull)
hikariConfig.setPoolName("jdbc-auth-pool")

new HikariDataSource(hikariConfig)
}
private def queryPlaceholders: Iterator[String] =
SQL_PLACEHOLDER_REGEX.findAllMatchIn(authQuery.get).map(_.matched)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.kyuubi.util

import java.sql.{Connection, PreparedStatement, ResultSet}
import javax.sql.DataSource

import scala.util.control.NonFatal

import org.apache.kyuubi.Logging

object JdbcUtils extends Logging {

def close(c: AutoCloseable): Unit = {
if (c != null) {
try {
c.close()
} catch {
case NonFatal(t) => warn(s"Error on closing", t)
}
}
}

def withCloseable[R, C <: AutoCloseable](c: C)(block: C => R): R = {
try {
block(c)
} finally {
close(c)
}
}

def withConnection[R](block: Connection => R)(implicit ds: DataSource): R = {
withCloseable(ds.getConnection)(block)
}

def execute(
sqlTemplate: String)(
setParameters: PreparedStatement => Unit = _ => {})(
implicit ds: DataSource): Boolean = withConnection { conn =>
withCloseable(conn.prepareStatement(sqlTemplate)) { pStmt =>
setParameters(pStmt)
pStmt.execute()
}
}

def executeUpdate(
sqlTemplate: String)(
setParameters: PreparedStatement => Unit = _ => {})(
implicit ds: DataSource): Int = withConnection { conn =>
withCloseable(conn.prepareStatement(sqlTemplate)) { pStmt =>
setParameters(pStmt)
pStmt.executeUpdate()
}
}

def executeQuery[R](
sqlTemplate: String)(
setParameters: PreparedStatement => Unit = _ => {})(
processResultSet: ResultSet => R)(
implicit ds: DataSource): R = withConnection { conn =>
withCloseable(conn.prepareStatement(sqlTemplate)) { pStmt =>
setParameters(pStmt)
withCloseable(pStmt.executeQuery()) { rs =>
processResultSet(rs)
}
}
}

def executeQueryWithRowMapper[R](
sqlTemplate: String)(
setParameters: PreparedStatement => Unit = _ => {})(
rowMapper: ResultSet => R)(
implicit ds: DataSource): Seq[R] = withConnection { conn =>
withCloseable(conn.prepareStatement(sqlTemplate)) { pStmt =>
setParameters(pStmt)
withCloseable(pStmt.executeQuery()) { rs =>
val builder = Seq.newBuilder[R]
while (rs.next()) builder += rowMapper(rs)
builder.result
}
}
}
}
16 changes: 16 additions & 0 deletions kyuubi-common/src/test/scala/org/apache/kyuubi/TestUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ package org.apache.kyuubi

import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Path, StandardOpenOption}
import java.sql.ResultSet

import scala.collection.mutable.ArrayBuffer

import com.jakewharton.fliptables.FlipTable
import org.scalatest.Assertions.convertToEqualizer

object TestUtils {
Expand Down Expand Up @@ -59,4 +61,18 @@ object TestUtils {
newOutput.zip(expected).foreach { case (out, in) => assert(out === in, hint) }
}
}

def displayResultSet(resultSet: ResultSet): Unit = {
if (resultSet == null) throw new NullPointerException("resultSet == null")
val resultSetMetaData = resultSet.getMetaData
val columnCount: Int = resultSetMetaData.getColumnCount
val headers = (1 to columnCount).map(resultSetMetaData.getColumnName).toArray
val data = ArrayBuffer.newBuilder[Array[String]]
while (resultSet.next) {
data += (1 to columnCount).map(resultSet.getString).toArray
}
// scalastyle:off println
println(FlipTable.of(headers, data.result().toArray))
// scalastyle:on println
}
}
Loading