[SPARK-49846][SS] Add numUpdatedStateRows and numRemovedStateRows metrics for use with transformWithState operator

### What changes were proposed in this pull request?
Add numUpdatedStateRows and numRemovedStateRows metrics for use with transformWithState operator

### Why are the changes needed?
Without this change, metrics for these state operations are not available in the streaming query progress.

### Does this PR introduce _any_ user-facing change?
Yes.

Metrics are updated as part of the streaming query progress:

```
    "operatorName" : "transformWithStateExec",
    "numRowsTotal" : 1,
    "numRowsUpdated" : 1,
    "numRowsRemoved" : 1,
```
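These appear per entry of `stateOperators` in the progress payload; the internal `numUpdatedStateRows` and `numRemovedStateRows` metrics surface there as `numRowsUpdated` and `numRowsRemoved`. A hedged sketch of reading them (assuming `query` is a started `StreamingQuery`; the helper function is ours, not part of this PR):

```scala
import org.apache.spark.sql.streaming.StreamingQuery

// Illustrative helper, not part of this PR: print state-operator metrics from
// the most recent progress update, if one has been emitted.
def printStateMetrics(query: StreamingQuery): Unit = {
  Option(query.lastProgress).toSeq.flatMap(_.stateOperators).foreach { op =>
    println(s"${op.operatorName}: total=${op.numRowsTotal}, " +
      s"updated=${op.numRowsUpdated}, removed=${op.numRowsRemoved}")
  }
}
```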

### How was this patch tested?
Added unit tests

```
[info] Run completed in 25 seconds, 697 milliseconds.
[info] Total number of tests run: 2
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 2, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
```

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#48317 from anishshri-db/task/SPARK-49846.

Authored-by: Anish Shrigondekar <anish.shrigondekar@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensource@gmail.com>
anishshri-db authored and himadripal committed Oct 19, 2024
1 parent 3d2623e commit d31a474
Showing 14 changed files with 266 additions and 38 deletions.
ListStateImpl.scala
@@ -19,8 +19,10 @@ package org.apache.spark.sql.execution.streaming
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.streaming.ListState
import org.apache.spark.sql.types.StructType

/**
* Provides concrete implementation for list of values associated with a state variable
@@ -30,14 +32,22 @@ import org.apache.spark.sql.streaming.ListState
* @param stateName - name of logical state partition
* @param keyExprEnc - Spark SQL encoder for key
* @param valEncoder - Spark SQL encoder for value
* @param metrics - metrics to be updated as part of stateful processing
* @tparam S - data type of object that will be stored in the list
*/
class ListStateImpl[S](
store: StateStore,
stateName: String,
keyExprEnc: ExpressionEncoder[Any],
valEncoder: Encoder[S],
metrics: Map[String, SQLMetric] = Map.empty)
extends ListStateMetricsImpl
with ListState[S]
with Logging {

override def stateStore: StateStore = store
override def baseStateName: String = stateName
override def exprEncSchema: StructType = keyExprEnc.schema

private val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder, stateName)

@@ -76,39 +86,56 @@ class ListStateImpl[S](

val encodedKey = stateTypesEncoder.encodeGroupingKey()
var isFirst = true
var entryCount = 0L
TWSMetricsUtils.resetMetric(metrics, "numUpdatedStateRows")

newState.foreach { v =>
val encodedValue = stateTypesEncoder.encodeValue(v)
if (isFirst) {
store.put(encodedKey, encodedValue, stateName)
isFirst = false
} else {
store.merge(encodedKey, encodedValue, stateName)
}
entryCount += 1
TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
}
updateEntryCount(encodedKey, entryCount)
}

/** Append an entry to the list. */
override def appendValue(newState: S): Unit = {
StateStoreErrors.requireNonNullStateValue(newState, stateName)
val encodedKey = stateTypesEncoder.encodeGroupingKey()
val entryCount = getEntryCount(encodedKey)
store.merge(encodedKey,
stateTypesEncoder.encodeValue(newState), stateName)
TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
updateEntryCount(encodedKey, entryCount + 1)
}

/** Append an entire list to the existing value. */
override def appendList(newState: Array[S]): Unit = {
validateNewState(newState)

val encodedKey = stateTypesEncoder.encodeGroupingKey()
var entryCount = getEntryCount(encodedKey)
newState.foreach { v =>
val encodedValue = stateTypesEncoder.encodeValue(v)
store.merge(encodedKey, encodedValue, stateName)
entryCount += 1
TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
}
updateEntryCount(encodedKey, entryCount)
}

/** Remove this state. */
override def clear(): Unit = {
val encodedKey = stateTypesEncoder.encodeGroupingKey()
store.remove(encodedKey, stateName)
val entryCount = getEntryCount(encodedKey)
TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", entryCount)
removeEntryCount(encodedKey)
}

private def validateNewState(newState: Array[S]): Unit = {
…
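The `TWSMetricsUtils` helper called throughout these hunks is defined in one of the changed files not shown here. A minimal sketch consistent with its call sites above (`resetMetric`, and `incrementMetric` with an optional count) might look like the following; the exact definition is an assumption:

```scala
import org.apache.spark.sql.execution.metric.SQLMetric

// Hypothetical reconstruction from the call sites above, not the actual file:
// metrics are looked up by name and silently skipped if absent.
object TWSMetricsUtils {
  def resetMetric(metrics: Map[String, SQLMetric], metricName: String): Unit = {
    metrics.get(metricName).foreach(_.set(0L))
  }

  def incrementMetric(
      metrics: Map[String, SQLMetric],
      metricName: String,
      countValue: Long = 1L): Unit = {
    metrics.get(metricName).foreach(_.add(countValue))
  }
}
```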
ListStateImplWithTTL.scala
@@ -19,9 +19,11 @@ package org.apache.spark.sql.execution.streaming
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.streaming.{ListState, TTLConfig}
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.NextIterator

/**
@@ -34,6 +36,7 @@ import org.apache.spark.util.NextIterator
* @param valEncoder - Spark SQL encoder for value
* @param ttlConfig - TTL configuration for values stored in this state
* @param batchTimestampMs - current batch processing timestamp.
* @param metrics - metrics to be updated as part of stateful processing
* @tparam S - data type of object that will be stored
*/
class ListStateImplWithTTL[S](
@@ -42,9 +45,15 @@ class ListStateImplWithTTL[S](
keyExprEnc: ExpressionEncoder[Any],
valEncoder: Encoder[S],
ttlConfig: TTLConfig,
batchTimestampMs: Long,
metrics: Map[String, SQLMetric] = Map.empty)
extends SingleKeyTTLStateImpl(stateName, store, keyExprEnc, batchTimestampMs)
with ListStateMetricsImpl
with ListState[S] {

override def stateStore: StateStore = store
override def baseStateName: String = stateName
override def exprEncSchema: StructType = keyExprEnc.schema

private lazy val stateTypesEncoder = StateTypesEncoder(keyExprEnc, valEncoder,
stateName, hasTtl = true)
@@ -99,6 +108,8 @@ class ListStateImplWithTTL[S](

val encodedKey = stateTypesEncoder.encodeGroupingKey()
var isFirst = true
var entryCount = 0L
TWSMetricsUtils.resetMetric(metrics, "numUpdatedStateRows")

newState.foreach { v =>
val encodedValue = stateTypesEncoder.encodeValue(v, ttlExpirationMs)
@@ -108,34 +119,48 @@ class ListStateImplWithTTL[S](
} else {
store.merge(encodedKey, encodedValue, stateName)
}
entryCount += 1
TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
}
upsertTTLForStateKey(encodedKey)
updateEntryCount(encodedKey, entryCount)
}

/** Append an entry to the list. */
override def appendValue(newState: S): Unit = {
StateStoreErrors.requireNonNullStateValue(newState, stateName)
val encodedKey = stateTypesEncoder.encodeGroupingKey()
val entryCount = getEntryCount(encodedKey)
store.merge(encodedKey,
stateTypesEncoder.encodeValue(newState, ttlExpirationMs), stateName)
TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
upsertTTLForStateKey(encodedKey)
updateEntryCount(encodedKey, entryCount + 1)
}

/** Append an entire list to the existing value. */
override def appendList(newState: Array[S]): Unit = {
validateNewState(newState)

val encodedKey = stateTypesEncoder.encodeGroupingKey()
var entryCount = getEntryCount(encodedKey)
newState.foreach { v =>
val encodedValue = stateTypesEncoder.encodeValue(v, ttlExpirationMs)
store.merge(encodedKey, encodedValue, stateName)
entryCount += 1
TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
}
upsertTTLForStateKey(encodedKey)
updateEntryCount(encodedKey, entryCount)
}

/** Remove this state. */
override def clear(): Unit = {
val encodedKey = stateTypesEncoder.encodeGroupingKey()
store.remove(encodedKey, stateName)
val entryCount = getEntryCount(encodedKey)
TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", entryCount)
removeEntryCount(encodedKey)
clearTTLState()
}

@@ -158,7 +183,9 @@ class ListStateImplWithTTL[S](
val unsafeRowValuesIterator = store.valuesIterator(groupingKey, stateName)
// We clear the list, and use the iterator to put back all of the non-expired values
store.remove(groupingKey, stateName)
removeEntryCount(groupingKey)
var isFirst = true
var entryCount = 0L
unsafeRowValuesIterator.foreach { encodedValue =>
if (!stateTypesEncoder.isExpired(encodedValue, batchTimestampMs)) {
if (isFirst) {
@@ -167,10 +194,13 @@ class ListStateImplWithTTL[S](
} else {
store.merge(groupingKey, encodedValue, stateName)
}
entryCount += 1
} else {
numValuesExpired += 1
}
}
updateEntryCount(groupingKey, entryCount)
TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows", numValuesExpired)
numValuesExpired
}

…
ListStateMetricsImpl.scala (new file)
@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore}
import org.apache.spark.sql.types._

/**
* Trait that provides helper methods to maintain metrics for a list state.
* For list state, we keep track of the count of entries in the list in a separate column family
* to get an accurate view of the number of entries that are updated/removed from the list and
* reported as part of the query progress metrics.
*/
trait ListStateMetricsImpl {
def stateStore: StateStore

def baseStateName: String

def exprEncSchema: StructType

// We keep track of the count of entries in the list in a separate column family
// to avoid scanning the entire list to get the count.
private val counterCFValueSchema: StructType =
StructType(Seq(StructField("count", LongType, nullable = false)))

private val counterCFProjection = UnsafeProjection.create(counterCFValueSchema)

private val updatedCountRow = new GenericInternalRow(1)

private def getRowCounterCFName(stateName: String) = "$rowCounter_" + stateName

stateStore.createColFamilyIfAbsent(getRowCounterCFName(baseStateName), exprEncSchema,
counterCFValueSchema, NoPrefixKeyStateEncoderSpec(exprEncSchema), isInternal = true)

/**
* Function to get the number of entries in the list state for a given grouping key
* @param encodedKey - encoded grouping key
* @return - number of entries in the list state
*/
def getEntryCount(encodedKey: UnsafeRow): Long = {
val countRow = stateStore.get(encodedKey, getRowCounterCFName(baseStateName))
if (countRow != null) {
countRow.getLong(0)
} else {
0L
}
}

/**
* Function to update the number of entries in the list state for a given grouping key
* @param encodedKey - encoded grouping key
* @param updatedCount - updated count of entries in the list state
*/
def updateEntryCount(
encodedKey: UnsafeRow,
updatedCount: Long): Unit = {
updatedCountRow.setLong(0, updatedCount)
stateStore.put(encodedKey,
counterCFProjection(updatedCountRow.asInstanceOf[InternalRow]),
getRowCounterCFName(baseStateName))
}

/**
* Function to remove the number of entries in the list state for a given grouping key
* @param encodedKey - encoded grouping key
*/
def removeEntryCount(encodedKey: UnsafeRow): Unit = {
stateStore.remove(encodedKey, getRowCounterCFName(baseStateName))
}
}
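As a hedged illustration of the trait's contract (the class and member names below are ours, not from this PR): an implementation supplies the store, state name, and key schema, and the trait maintains an O(1) per-key entry count in an internal `$rowCounter_<stateName>` column family.

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.state.StateStore
import org.apache.spark.sql.types.StructType

// Sketch only: mirrors how ListStateImpl wires in the trait in the first hunk.
class CountedListState(
    store: StateStore,
    stateName: String,
    keySchema: StructType) extends ListStateMetricsImpl {
  override def stateStore: StateStore = store
  override def baseStateName: String = stateName
  override def exprEncSchema: StructType = keySchema

  // Typical append path: bump the stored count alongside the actual write.
  def recordAppend(encodedKey: UnsafeRow): Unit =
    updateEntryCount(encodedKey, getEntryCount(encodedKey) + 1)
}
```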
MapStateImpl.scala
@@ -19,17 +19,30 @@ package org.apache.spark.sql.execution.streaming
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchemaUtils._
import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair}
import org.apache.spark.sql.streaming.MapState
import org.apache.spark.sql.types.StructType

/**
* Class that provides a concrete implementation for map state associated with state
* variables used in the streaming transformWithState operator.
* @param store - reference to the StateStore instance to be used for storing state
* @param stateName - name of logical state partition
* @param keyExprEnc - Spark SQL encoder for key
* @param userKeyEnc - Spark SQL encoder for the user-specified map key
* @param valEncoder - Spark SQL encoder for value
* @param metrics - metrics to be updated as part of stateful processing
* @tparam K - type of key for map state variable
* @tparam V - type of value for map state variable
*/
class MapStateImpl[K, V](
store: StateStore,
stateName: String,
keyExprEnc: ExpressionEncoder[Any],
userKeyEnc: Encoder[K],
valEncoder: Encoder[V],
metrics: Map[String, SQLMetric] = Map.empty) extends MapState[K, V] with Logging {

// Pack grouping key and user key together as a prefixed composite key
private val schemaForCompositeKeyRow: StructType = {
@@ -70,6 +83,7 @@ class MapStateImpl[K, V](
val encodedValue = stateTypesEncoder.encodeValue(value)
val encodedCompositeKey = stateTypesEncoder.encodeCompositeKey(key)
store.put(encodedCompositeKey, encodedValue, stateName)
TWSMetricsUtils.incrementMetric(metrics, "numUpdatedStateRows")
}

/** Get the map associated with grouping key */
@@ -98,6 +112,9 @@ class MapStateImpl[K, V](
StateStoreErrors.requireNonNullStateValue(key, stateName)
val compositeKey = stateTypesEncoder.encodeCompositeKey(key)
store.remove(compositeKey, stateName)
// Note that map state entries are flattened into one stored row per user key,
// so removals are counted per user key removed under the grouping key.
TWSMetricsUtils.incrementMetric(metrics, "numRemovedStateRows")
}

/** Remove this state. */
…
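Because map-state entries are flattened, the operator metrics move one row per user key touched. A hedged usage sketch (assuming `mapState` is a `MapState[String, Long]` handle obtained from the stateful processor; the function is illustrative):

```scala
import org.apache.spark.sql.streaming.MapState

// Illustrative only: each user key is its own stored row, so each call below
// adjusts the corresponding operator metric by exactly one row.
def exerciseMapState(mapState: MapState[String, Long]): Unit = {
  mapState.updateValue("a", 1L) // numUpdatedStateRows += 1
  mapState.updateValue("b", 2L) // numUpdatedStateRows += 1
  mapState.removeKey("a")       // numRemovedStateRows += 1
}
```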
(The remaining 10 changed files are not shown.)
