Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("")
private val plusOrMinus = 2

def render(request: HttpServletRequest): Seq[Node] = {
val requestedPage = Option(request.getParameter("page")).getOrElse("1").toInt
val requestedPage = Option(UIUtils.stripXSS(request.getParameter("page"))).getOrElse("1").toInt
val requestedFirst = (requestedPage - 1) * pageSize

// stripXSS is called first to remove suspicious characters used in XSS attacks
val requestedIncomplete =
Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean
Option(UIUtils.stripXSS(request.getParameter("showIncomplete"))).getOrElse("false").toBoolean

val allApps = parent.getApplicationList()
.filter(_.attempts.head.completed != requestedIncomplete)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app")

/** Executor details for a particular application */
def render(request: HttpServletRequest): Seq[Node] = {
val appId = request.getParameter("appId")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val appId = UIUtils.stripXSS(request.getParameter("appId"))
val state = master.askWithRetry[MasterStateResponse](RequestMasterState)
val app = state.activeApps.find(_.id == appId).getOrElse({
state.completedApps.find(_.id == appId).getOrElse(null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = {
if (parent.killEnabled &&
parent.master.securityMgr.checkModifyPermissions(request.getRemoteUser)) {
val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean
val id = Option(request.getParameter("id"))
// stripXSS is called first to remove suspicious characters used in XSS attacks
val killFlag =
Option(UIUtils.stripXSS(request.getParameter("terminate"))).getOrElse("false").toBoolean
val id = Option(UIUtils.stripXSS(request.getParameter("id")))
if (id.isDefined && killFlag) {
action(id.get)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ import org.apache.spark.ui.{UIUtils, WebUIPage}
private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") {

override def render(request: HttpServletRequest): Seq[Node] = {
val driverId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val driverId = UIUtils.stripXSS(request.getParameter("id"))
require(driverId != null && driverId.nonEmpty, "Missing id parameter")

val state = parent.scheduler.getDriverState(driverId)
Expand Down Expand Up @@ -96,22 +97,22 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")
<td>Mesos Slave ID</td>
<td>{state.slaveId.getValue}</td>
</tr>
<tr>
<td>Mesos Task ID</td>
<td>{state.taskId.getValue}</td>
</tr>
<tr>
<td>Launch Time</td>
<td>{state.startDate}</td>
</tr>
<tr>
<td>Finish Time</td>
<td>{state.finishDate.map(_.toString).getOrElse("")}</td>
</tr>
<tr>
<td>Last Task Status</td>
<td>{state.mesosTaskStatus.map(_.toString).getOrElse("")}</td>
</tr>
<tr>
<td>Mesos Task ID</td>
<td>{state.taskId.getValue}</td>
</tr>
<tr>
<td>Launch Time</td>
<td>{state.startDate}</td>
</tr>
<tr>
<td>Finish Time</td>
<td>{state.finishDate.map(_.toString).getOrElse("")}</td>
</tr>
<tr>
<td>Last Task Status</td>
<td>{state.mesosTaskStatus.map(_.toString).getOrElse("")}</td>
</tr>
}.getOrElse(Seq[Node]())
}

Expand All @@ -127,39 +128,39 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver")
<tr>
<td>Main class</td><td>{command.mainClass}</td>
</tr>
<tr>
<td>Arguments</td><td>{command.arguments.mkString(" ")}</td>
</tr>
<tr>
<td>Class path entries</td><td>{command.classPathEntries.mkString(" ")}</td>
</tr>
<tr>
<td>Java options</td><td>{command.javaOpts.mkString((" "))}</td>
</tr>
<tr>
<td>Library path entries</td><td>{command.libraryPathEntries.mkString((" "))}</td>
</tr>
<tr>
<td>Arguments</td><td>{command.arguments.mkString(" ")}</td>
</tr>
<tr>
<td>Class path entries</td><td>{command.classPathEntries.mkString(" ")}</td>
</tr>
<tr>
<td>Java options</td><td>{command.javaOpts.mkString((" "))}</td>
</tr>
<tr>
<td>Library path entries</td><td>{command.libraryPathEntries.mkString((" "))}</td>
</tr>
}

private def driverRow(driver: MesosDriverDescription): Seq[Node] = {
<tr>
<td>Name</td><td>{driver.name}</td>
</tr>
<tr>
<td>Id</td><td>{driver.submissionId}</td>
</tr>
<tr>
<td>Cores</td><td>{driver.cores}</td>
</tr>
<tr>
<td>Memory</td><td>{driver.mem}</td>
</tr>
<tr>
<td>Submitted</td><td>{driver.submissionDate}</td>
</tr>
<tr>
<td>Supervise</td><td>{driver.supervise}</td>
</tr>
<tr>
<td>Id</td><td>{driver.submissionId}</td>
</tr>
<tr>
<td>Cores</td><td>{driver.cores}</td>
</tr>
<tr>
<td>Memory</td><td>{driver.mem}</td>
</tr>
<tr>
<td>Submitted</td><td>{driver.submissionDate}</td>
</tr>
<tr>
<td>Supervise</td><td>{driver.supervise}</td>
</tr>
}

private def retryRow(retryState: Option[MesosClusterRetryState]): Seq[Node] = {
Expand Down
30 changes: 18 additions & 12 deletions core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,18 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
private val workDir = new File(parent.workDir.toURI.normalize().getPath)
private val supportedLogTypes = Set("stderr", "stdout")

// stripXSS is called first to remove suspicious characters used in XSS attacks
def renderLog(request: HttpServletRequest): String = {
val defaultBytes = 100 * 1024

val appId = Option(request.getParameter("appId"))
val executorId = Option(request.getParameter("executorId"))
val driverId = Option(request.getParameter("driverId"))
val logType = request.getParameter("logType")
val offset = Option(request.getParameter("offset")).map(_.toLong)
val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
val appId = Option(UIUtils.stripXSS(request.getParameter("appId")))
val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId")))
val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId")))
val logType = UIUtils.stripXSS(request.getParameter("logType"))
val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong)
val byteLength =
Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt)
.getOrElse(defaultBytes)

val logDir = (appId, executorId, driverId) match {
case (Some(a), Some(e), None) =>
Expand All @@ -57,14 +60,17 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with
pre + logText
}

// stripXSS is called first to remove suspicious characters used in XSS attacks
def render(request: HttpServletRequest): Seq[Node] = {
val defaultBytes = 100 * 1024
val appId = Option(request.getParameter("appId"))
val executorId = Option(request.getParameter("executorId"))
val driverId = Option(request.getParameter("driverId"))
val logType = request.getParameter("logType")
val offset = Option(request.getParameter("offset")).map(_.toLong)
val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
val appId = Option(UIUtils.stripXSS(request.getParameter("appId")))
val executorId = Option(UIUtils.stripXSS(request.getParameter("executorId")))
val driverId = Option(UIUtils.stripXSS(request.getParameter("driverId")))
val logType = UIUtils.stripXSS(request.getParameter("logType"))
val offset = Option(UIUtils.stripXSS(request.getParameter("offset"))).map(_.toLong)
val byteLength =
Option(UIUtils.stripXSS(request.getParameter("byteLength"))).map(_.toInt)
.getOrElse(defaultBytes)

val (logDir, params, pageName) = (appId, executorId, driverId) match {
case (Some(a), Some(e), None) =>
Expand Down
23 changes: 23 additions & 0 deletions core/src/main/scala/org/apache/spark/ui/UIUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ import scala.util.control.NonFatal
import scala.xml._
import scala.xml.transform.{RewriteRule, RuleTransformer}

import org.apache.spark.Logging
import org.apache.commons.lang3.StringEscapeUtils

import org.apache.spark.Logging
import org.apache.spark.ui.scope.RDDOperationGraph

Expand All @@ -34,6 +37,8 @@ private[spark] object UIUtils extends Logging {
val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped"
val TABLE_CLASS_STRIPED_SORTABLE = TABLE_CLASS_STRIPED + " sortable"

private val NEWLINE_AND_SINGLE_QUOTE_REGEX = raw"(?i)(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)".r

// SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use.
private val dateFormat = new ThreadLocal[SimpleDateFormat]() {
override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
Expand Down Expand Up @@ -467,4 +472,22 @@ private[spark] object UIUtils extends Logging {
}
param
}

/**
* Remove suspicious characters of user input to prevent Cross-Site scripting (XSS) attacks
*
* For more information about XSS testing:
* https://www.owasp.org/index.php/XSS_Filter_Evasion_Cheat_Sheet and
* https://www.owasp.org/index.php/Testing_for_Reflected_Cross_site_scripting_(OTG-INPVAL-001)
*/
def stripXSS(requestParameter: String): String = {
if (requestParameter == null) {
null
} else {
// Remove new lines and single quotes, followed by escaping HTML version 4.0
StringEscapeUtils.escapeHtml4(
NEWLINE_AND_SINGLE_QUOTE_REGEX.replaceAllIn(requestParameter, ""))
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage

private val sc = parent.sc

// stripXSS is called first to remove suspicious characters used in XSS attacks
def render(request: HttpServletRequest): Seq[Node] = {
val executorId = Option(request.getParameter("executorId")).map { executorId =>
val executorId =
Option(UIUtils.stripXSS(request.getParameter("executorId"))).map { executorId =>
UIUtils.decodeURLParameter(executorId)
}.getOrElse {
throw new IllegalArgumentException(s"Missing executorId parameter")
Expand Down
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") {
val listener = parent.jobProgresslistener

listener.synchronized {
val parameterId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")

val jobId = parameterId.toInt
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
*/

package org.apache.spark.ui.jobs

import org.apache.spark.scheduler.SchedulingMode
import org.apache.spark.ui.{SparkUI, SparkUITab}

Expand All @@ -33,4 +32,5 @@ private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") {

attachPage(new AllJobsPage(this))
attachPage(new JobPage(this))

}
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") {

def render(request: HttpServletRequest): Seq[Node] = {
listener.synchronized {
val poolName = Option(request.getParameter("poolname")).map { poolname =>
// stripXSS is called first to remove suspicious characters used in XSS attacks
val poolName = Option(UIUtils.stripXSS(request.getParameter("poolname"))).map { poolname =>
UIUtils.decodeURLParameter(poolname)
}.getOrElse {
throw new IllegalArgumentException(s"Missing poolname parameter")
Expand Down
13 changes: 7 additions & 6 deletions core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,16 +87,17 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {

def render(request: HttpServletRequest): Seq[Node] = {
progressListener.synchronized {
val parameterId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")

val parameterAttempt = request.getParameter("attempt")
val parameterAttempt = UIUtils.stripXSS(request.getParameter("attempt"))
require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter")

val parameterTaskPage = request.getParameter("task.page")
val parameterTaskSortColumn = request.getParameter("task.sort")
val parameterTaskSortDesc = request.getParameter("task.desc")
val parameterTaskPageSize = request.getParameter("task.pageSize")
val parameterTaskPage = UIUtils.stripXSS(request.getParameter("task.page"))
val parameterTaskSortColumn = UIUtils.stripXSS(request.getParameter("task.sort"))
val parameterTaskSortDesc = UIUtils.stripXSS(request.getParameter("task.desc"))
val parameterTaskPageSize = UIUtils.stripXSS(request.getParameter("task.pageSize"))

val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1)
val taskSortColumn = Option(parameterTaskSortColumn).map { sortColumn =>
Expand Down
8 changes: 5 additions & 3 deletions core/src/main/scala/org/apache/spark/ui/jobs/StagesTab.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.ui.jobs
import javax.servlet.http.HttpServletRequest

import org.apache.spark.scheduler.SchedulingMode
import org.apache.spark.ui.{SparkUI, SparkUITab}
import org.apache.spark.ui.{SparkUI, SparkUITab, UIUtils}

/** Web UI showing progress status of all stages in the given SparkContext. */
private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages") {
Expand All @@ -38,8 +38,10 @@ private[ui] class StagesTab(parent: SparkUI) extends SparkUITab(parent, "stages"

def handleKillRequest(request: HttpServletRequest): Unit = {
if (killEnabled && parent.securityManager.checkModifyPermissions(request.getRemoteUser)) {
val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean
val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt
val killFlag = Option(UIUtils.stripXSS(request.getParameter("terminate")))
.getOrElse("false").toBoolean
// stripXSS is called first to remove suspicious characters used in XSS attacks
val stageId = Option(UIUtils.stripXSS(request.getParameter("id"))).getOrElse("-1").toInt
if (stageId >= 0 && killFlag && progressListener.activeStages.contains(stageId)) {
sc.get.cancelStage(stageId)
}
Expand Down
11 changes: 6 additions & 5 deletions core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") {
private val listener = parent.listener

def render(request: HttpServletRequest): Seq[Node] = {
val parameterId = request.getParameter("id")
// stripXSS is called first to remove suspicious characters used in XSS attacks
val parameterId = UIUtils.stripXSS(request.getParameter("id"))
require(parameterId != null && parameterId.nonEmpty, "Missing id parameter")

val parameterBlockPage = request.getParameter("block.page")
val parameterBlockSortColumn = request.getParameter("block.sort")
val parameterBlockSortDesc = request.getParameter("block.desc")
val parameterBlockPageSize = request.getParameter("block.pageSize")
val parameterBlockPage = UIUtils.stripXSS(request.getParameter("block.page"))
val parameterBlockSortColumn = UIUtils.stripXSS(request.getParameter("block.sort"))
val parameterBlockSortDesc = UIUtils.stripXSS(request.getParameter("block.desc"))
val parameterBlockPageSize = UIUtils.stripXSS(request.getParameter("block.pageSize"))

val blockPage = Option(parameterBlockPage).map(_.toInt).getOrElse(1)
val blockSortColumn = Option(parameterBlockSortColumn).getOrElse("Block Name")
Expand Down
Loading