@@ -81,10 +81,10 @@ class SQLMetric(val metricType: String, initValue: Long = 0L) extends AccumulatorV2[Long, Long]
}

object SQLMetrics {
-  private val SUM_METRIC = "sum"
-  private val SIZE_METRIC = "size"
-  private val TIMING_METRIC = "timing"
-  private val AVERAGE_METRIC = "average"
+  val SUM_METRIC = "sum"
+  val SIZE_METRIC = "size"
+  val TIMING_METRIC = "timing"
+  val AVERAGE_METRIC = "average"
Contributor:
why this change?

Contributor (Author):
It was to handle an exception case while aggregating custom metrics, especially to filter out averages, since they are not aggregated correctly. Since we removed the custom average metric, we no longer need to filter them out. Will revert the change as well as the relevant logic.
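[Editor's note] To see why average metrics resist simple aggregation, here is a hypothetical illustration with invented numbers, not code from this patch:

// Two tasks report (average, rowCount). Summing or plainly averaging the
// averages ignores the per-task row counts.
val taskAverages = Seq((10.0, 1L), (2.0, 99L))

val unweighted = taskAverages.map(_._1).sum / taskAverages.size
// 6.0 -- far from the true average

val weighted = taskAverages.map { case (avg, n) => avg * n }.sum /
  taskAverages.map(_._2).sum
// 2.08 -- correct, but requires the per-task counts as well as the averages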


private val baseForAvgMetric: Int = 10

@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.state

import java.io._
import java.util.Locale
+import java.util.concurrent.atomic.LongAdder

import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -164,7 +165,8 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging
}

override def metrics: StateStoreMetrics = {
-    StateStoreMetrics(mapToUpdate.size(), SizeEstimator.estimate(mapToUpdate), Map.empty)
+    StateStoreMetrics(mapToUpdate.size(), SizeEstimator.estimate(mapToUpdate),
+      getCustomMetricsForProvider())
}

/**
@@ -179,6 +181,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging
}
}

+  def getCustomMetricsForProvider(): Map[StateStoreCustomMetric, Long] = synchronized {
+    Map(metricProviderLoaderMapSizeBytes -> SizeEstimator.estimate(loadedMaps),
+      metricLoadedMapCacheHit -> loadedMapCacheHitCount.sum(),
+      metricLoadedMapCacheMiss -> loadedMapCacheMissCount.sum())
+  }

/** Get the state store for making updates to create a new `version` of the store. */
override def getStore(version: Long): StateStore = synchronized {
require(version >= 0, "Version cannot be less than 0")
@@ -224,7 +232,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging
}

override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = {
-    Nil
+    metricProviderLoaderMapSizeBytes :: metricLoadedMapCacheHit :: metricLoadedMapCacheMiss :: Nil
}

override def toString(): String = {
@@ -245,6 +253,21 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging
private lazy val fm = CheckpointFileManager.create(baseDir, hadoopConf)
private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)

+  private val loadedMapCacheHitCount: LongAdder = new LongAdder
+  private val loadedMapCacheMissCount: LongAdder = new LongAdder
+
+  private lazy val metricProviderLoaderMapSizeBytes: StateStoreCustomSizeMetric =
+    StateStoreCustomSizeMetric("providerLoadedMapSizeBytes",
+      "estimated size of states cache in provider")
+
+  private lazy val metricLoadedMapCacheHit: StateStoreCustomMetric =
+    StateStoreCustomSumMetric("loadedMapCacheHitCount",
+      "count of cache hit on states cache in provider")
+
+  private lazy val metricLoadedMapCacheMiss: StateStoreCustomMetric =
+    StateStoreCustomSumMetric("loadedMapCacheMissCount",
+      "count of cache miss on states cache in provider")

private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean)

private def commitUpdates(newVersion: Long, map: MapType, output: DataOutputStream): Unit = {
@@ -276,13 +299,16 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider with Logging
// Shortcut if the map for this version is already there to avoid a redundant put.
val loadedCurrentVersionMap = synchronized { loadedMaps.get(version) }
if (loadedCurrentVersionMap.isDefined) {
+      loadedMapCacheHitCount.increment()
return loadedCurrentVersionMap.get
}

logWarning(s"The state for version $version doesn't exist in loadedMaps. " +
"Reading snapshot file and delta files if needed..." +
"Note that this is normal for the first batch of starting query.")

+    loadedMapCacheMissCount.increment()

val (result, elapsedMs) = Utils.timeTakenMs {
val snapshotCurrentVersionMap = readSnapshotFile(version)
if (snapshotCurrentVersionMap.isDefined) {
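[Editor's note] On the counters introduced above: LongAdder is presumably chosen over AtomicLong for cheap thread-safe increments, and its sum() is weakly consistent, which is acceptable for metrics; this rationale is my inference, not stated in the diff. A minimal sketch:

import java.util.concurrent.atomic.LongAdder

val counter = new LongAdder
(1 to 1000).par.foreach(_ => counter.increment())  // increments from many threads
assert(counter.sum() == 1000)  // sum() is accurate once updates quiesce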
@@ -138,6 +138,8 @@ trait StateStoreCustomMetric {
def name: String
def desc: String
}

+case class StateStoreCustomSumMetric(name: String, desc: String) extends StateStoreCustomMetric
case class StateStoreCustomSizeMetric(name: String, desc: String) extends StateStoreCustomMetric
case class StateStoreCustomTimingMetric(name: String, desc: String) extends StateStoreCustomMetric
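[Editor's note] A minimal sketch of how a provider is expected to pair these metric descriptors with values; the metric name and counter are hypothetical, only the case classes above are from the patch:

import java.util.concurrent.atomic.LongAdder

val hitMetric = StateStoreCustomSumMetric("cacheHitCount", "cache hits in a provider")
val hits = new LongAdder

// Declared once, via supportedCustomMetrics:
val supported: Seq[StateStoreCustomMetric] = hitMetric :: Nil

// Reported per metrics snapshot, via StateStoreMetrics.customMetrics:
hits.increment()
val reported: Map[StateStoreCustomMetric, Long] = Map(hitMetric -> hits.sum())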

@@ -269,6 +269,8 @@ class SymmetricHashJoinStateManager(
keyWithIndexToValueMetrics.numKeys, // represent each buffered row only once
keyToNumValuesMetrics.memoryUsedBytes + keyWithIndexToValueMetrics.memoryUsedBytes,
keyWithIndexToValueMetrics.customMetrics.map {
+        case (s @ StateStoreCustomSumMetric(_, desc), value) =>
+          s.copy(desc = newDesc(desc)) -> value
case (s @ StateStoreCustomSizeMetric(_, desc), value) =>
s.copy(desc = newDesc(desc)) -> value
case (s @ StateStoreCustomTimingMetric(_, desc), value) =>
@@ -90,10 +90,20 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
* the driver after this SparkPlan has been executed and metrics have been updated.
*/
def getProgress(): StateOperatorProgress = {
+    // average metric is a bit tricky, so hard to aggregate: just exclude them to simplify issue
+    val avgExcludedCustomMetrics = stateStoreCustomMetrics
+      .filterNot(_._2.metricType == SQLMetrics.AVERAGE_METRIC)
+      .map(entry => entry._1 -> longMetric(entry._1).value)
+
+    val javaConvertedCustomMetrics: java.util.HashMap[String, java.lang.Long] =
+      new java.util.HashMap(avgExcludedCustomMetrics.mapValues(long2Long).asJava)

new StateOperatorProgress(
numRowsTotal = longMetric("numTotalStateRows").value,
numRowsUpdated = longMetric("numUpdatedStateRows").value,
memoryUsedBytes = longMetric("stateMemory").value)
memoryUsedBytes = longMetric("stateMemory").value,
javaConvertedCustomMetrics
)
}

/** Records the duration of running `body` for the next query progress update. */
@@ -115,6 +125,8 @@ trait StateStoreWriter extends StatefulOperator { self: SparkPlan =>
private def stateStoreCustomMetrics: Map[String, SQLMetric] = {
val provider = StateStoreProvider.create(sqlContext.conf.stateStoreProviderClass)
provider.supportedCustomMetrics.map {
+      case StateStoreCustomSumMetric(name, desc) =>
+        name -> SQLMetrics.createMetric(sparkContext, desc)
case StateStoreCustomSizeMetric(name, desc) =>
name -> SQLMetrics.createSizeMetric(sparkContext, desc)
case StateStoreCustomTimingMetric(name, desc) =>
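[Editor's note] On the conversion in getProgress() above: mapValues in Scala 2.x returns a lazy view, so copying into a java.util.HashMap materializes the boxed values eagerly into a concrete, serializable map; this reading of the intent is mine, not stated in the patch. A small sketch:

import scala.collection.JavaConverters._

val metrics = Map("loadedMapCacheHitCount" -> 1L, "loadedMapCacheMissCount" -> 0L)

// long2Long boxes scala.Long to java.lang.Long; the HashMap copy forces the
// lazy mapValues view into a concrete map.
val javaMetrics: java.util.HashMap[String, java.lang.Long] =
  new java.util.HashMap(metrics.mapValues(long2Long).asJava)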
@@ -38,7 +38,8 @@ import org.apache.spark.annotation.InterfaceStability
class StateOperatorProgress private[sql](
val numRowsTotal: Long,
val numRowsUpdated: Long,
-    val memoryUsedBytes: Long
+    val memoryUsedBytes: Long,
+    val customMetrics: ju.Map[String, JLong] = new ju.HashMap()
) extends Serializable {

/** The compact JSON representation of this progress. */
@@ -48,12 +49,24 @@ class StateOperatorProgress private[sql](
def prettyJson: String = pretty(render(jsonValue))

private[sql] def copy(newNumRowsUpdated: Long): StateOperatorProgress =
-    new StateOperatorProgress(numRowsTotal, newNumRowsUpdated, memoryUsedBytes)
+    new StateOperatorProgress(numRowsTotal, newNumRowsUpdated, memoryUsedBytes, customMetrics)

private[sql] def jsonValue: JValue = {
("numRowsTotal" -> JInt(numRowsTotal)) ~
("numRowsUpdated" -> JInt(numRowsUpdated)) ~
("memoryUsedBytes" -> JInt(memoryUsedBytes))
def safeMapToJValue[T](map: ju.Map[String, T], valueToJValue: T => JValue): JValue = {
Contributor:
T is always Long. Why make a generic function for that? This does not even need a separate function.

Contributor (Author):
I first tried to leverage StreamingQueryProgress.safeMapToJValue, but couldn't find a proper place to move it so it could be shared, so I simply copied it. Will simplify the code block and inline it.

+      if (map.isEmpty) return JNothing
+      val keys = map.keySet.asScala.toSeq.sorted
+      keys.map { k => k -> valueToJValue(map.get(k)) : JObject }.reduce(_ ~ _)
+    }

+    val jsonVal = ("numRowsTotal" -> JInt(numRowsTotal)) ~
+      ("numRowsUpdated" -> JInt(numRowsUpdated)) ~
+      ("memoryUsedBytes" -> JInt(memoryUsedBytes))

+    if (!customMetrics.isEmpty) {
Contributor:
You are already handling the case of the map being empty in safeMapToJValue by returning JNothing. Don't JNothing values just get dropped from the JSON text anyway?

Contributor (Author):
Actually, I didn't notice that. Thanks for letting me know! Will simplify.
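[Editor's note] To confirm the reviewer's point, a standalone check; it assumes json4s, which this file already uses, and is not part of the patch:

import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

// Fields whose value is JNothing are dropped at render time, so the
// !customMetrics.isEmpty guard around the ~ append is redundant.
val json = ("memoryUsedBytes" -> JInt(2)) ~ ("customMetrics" -> JNothing)
assert(compact(render(json)) == """{"memoryUsedBytes":2}""")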

jsonVal ~ ("customMetrics" -> safeMapToJValue[JLong](customMetrics, v => JInt(v.toLong)))
} else {
jsonVal
}
}

override def toString: String = prettyJson
@@ -507,6 +507,90 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
assert(CreateAtomicTestManager.cancelCalledInCreateAtomic)
}

test("expose metrics with custom metrics to StateStoreMetrics") {
def getCustomMetric(metrics: StateStoreMetrics, name: String): Long = {
val metricPair = metrics.customMetrics.find(_._1.name == name)
assert(metricPair.isDefined)
metricPair.get._2
}

def getLoadedMapSizeMetric(metrics: StateStoreMetrics): Long = {
getCustomMetric(metrics, "providerLoadedMapSizeBytes")
}

def assertCacheHitAndMiss(
metrics: StateStoreMetrics,
expectedCacheHitCount: Long,
expectedCacheMissCount: Long): Unit = {
val cacheHitCount = getCustomMetric(metrics, "loadedMapCacheHitCount")
val cacheMissCount = getCustomMetric(metrics, "loadedMapCacheMissCount")
assert(cacheHitCount === expectedCacheHitCount)
assert(cacheMissCount === expectedCacheMissCount)
}

val provider = newStoreProvider()

// Verify state before starting a new set of updates
assert(getLatestData(provider).isEmpty)

val store = provider.getStore(0)
assert(!store.hasCommitted)

assert(store.metrics.numKeys === 0)

val initialLoadedMapSize = getLoadedMapSizeMetric(store.metrics)
assert(initialLoadedMapSize >= 0)
assertCacheHitAndMiss(store.metrics, expectedCacheHitCount = 0, expectedCacheMissCount = 0)

put(store, "a", 1)
assert(store.metrics.numKeys === 1)

put(store, "b", 2)
put(store, "aa", 3)
assert(store.metrics.numKeys === 3)
remove(store, _.startsWith("a"))
assert(store.metrics.numKeys === 1)
assert(store.commit() === 1)

assert(store.hasCommitted)

val loadedMapSizeForVersion1 = getLoadedMapSizeMetric(store.metrics)
assert(loadedMapSizeForVersion1 > initialLoadedMapSize)
assertCacheHitAndMiss(store.metrics, expectedCacheHitCount = 0, expectedCacheMissCount = 0)

val storeV2 = provider.getStore(1)
assert(!storeV2.hasCommitted)
assert(storeV2.metrics.numKeys === 1)

put(storeV2, "cc", 4)
assert(storeV2.metrics.numKeys === 2)
assert(storeV2.commit() === 2)

assert(storeV2.hasCommitted)

val loadedMapSizeForVersion1And2 = getLoadedMapSizeMetric(storeV2.metrics)
assert(loadedMapSizeForVersion1And2 > loadedMapSizeForVersion1)
assertCacheHitAndMiss(storeV2.metrics, expectedCacheHitCount = 1, expectedCacheMissCount = 0)

val reloadedProvider = newStoreProvider(store.id)
// intended to load version 2 instead of 1
// version 2 will not be loaded to the cache in provider
val reloadedStore = reloadedProvider.getStore(1)
assert(reloadedStore.metrics.numKeys === 1)

assert(getLoadedMapSizeMetric(reloadedStore.metrics) === loadedMapSizeForVersion1)
assertCacheHitAndMiss(reloadedStore.metrics, expectedCacheHitCount = 0,
expectedCacheMissCount = 1)

// now we are loading version 2
val reloadedStoreV2 = reloadedProvider.getStore(2)
assert(reloadedStoreV2.metrics.numKeys === 2)

assert(getLoadedMapSizeMetric(reloadedStoreV2.metrics) > loadedMapSizeForVersion1)
assertCacheHitAndMiss(reloadedStoreV2.metrics, expectedCacheHitCount = 0,
expectedCacheMissCount = 2)
}

override def newStoreProvider(): HDFSBackedStateStoreProvider = {
newStoreProvider(opId = Random.nextInt(), partition = 0)
}
@@ -231,7 +231,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
test("event ordering") {
val listener = new EventCollector
withListenerAdded(listener) {
-      for (i <- 1 to 100) {
+      for (i <- 1 to 50) {
Contributor (Author) @HeartSaVioR, Jun 7, 2018:
After the patch this test starts failing: it just means more time is needed to run this loop 100 times, not that the logic is broken. Decreasing the number works for me.

Contributor:
Makes sense, and I agree with the implicit claim that this slowdown isn't too worrying.

listener.reset()
require(listener.startEvent === null)
testStream(MemoryStream[Int].toDS)(
@@ -58,7 +58,12 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
| "stateOperators" : [ {
| "numRowsTotal" : 0,
| "numRowsUpdated" : 1,
| "memoryUsedBytes" : 2
| "memoryUsedBytes" : 2,
| "customMetrics" : {
| "loadedMapCacheHitCount" : 1,
| "loadedMapCacheMissCount" : 0,
| "providerLoadedMapSizeBytes" : 3
| }
| } ],
| "sources" : [ {
| "description" : "source",
@@ -230,7 +235,11 @@ object StreamingQueryStatusAndProgressSuite {
"avg" -> "2016-12-05T20:54:20.827Z",
"watermark" -> "2016-12-05T20:54:20.827Z").asJava),
stateOperators = Array(new StateOperatorProgress(
-      numRowsTotal = 0, numRowsUpdated = 1, memoryUsedBytes = 2)),
+      numRowsTotal = 0, numRowsUpdated = 1, memoryUsedBytes = 2,
+      customMetrics = new java.util.HashMap(Map("providerLoadedMapSizeBytes" -> 3L,
+        "loadedMapCacheHitCount" -> 1L, "loadedMapCacheMissCount" -> 0L)
+        .mapValues(long2Long).asJava)
+    )),
sources = Array(
new SourceProgress(
description = "source",