Profiling tool can miss datasources when they are GPU reads #4804

Merged 17 commits on Feb 17, 2022
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -21,6 +21,8 @@ import org.apache.commons.lang3.StringUtils
class ProfileOutputWriter(outputDir: String, filePrefix: String, numOutputRows: Int,
outputCSV: Boolean = false) {

private val CSVDelimiter = ","

private val textFileWriter = new ToolTextFileWriter(outputDir,
s"$filePrefix.log", "Profile summary")

@@ -62,12 +64,13 @@ class ProfileOutputWriter(outputDir: String, filePrefix: String, numOutputRows:
val csvWriter = new ToolTextFileWriter(outputDir,
s"${suffix}.csv", s"$header CSV:")
try {
val headerString = outRows.head.outputHeaders.mkString(",")
val headerString = outRows.head.outputHeaders.mkString(CSVDelimiter)
csvWriter.write(headerString + "\n")
val rows = outRows.map(_.convertToSeq)
rows.foreach { row =>
val formattedRow = row.map(stringIfempty(_))
val outStr = formattedRow.mkString(",")
val delimiterHandledRow = row.map(ProfileUtils.replaceDelimiter(_, CSVDelimiter))
val formattedRow = delimiterHandledRow.map(stringIfempty(_))
val outStr = formattedRow.mkString(CSVDelimiter)
csvWriter.write(outStr + "\n")
}
} finally {
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -68,4 +68,21 @@ object ProfileUtils
def truncateFailureStr(failureStr: String): String = {
failureStr.substring(0, Math.min(failureStr.size, 100))
}

// if a string contains what we are going to use for a delimiter, replace
// it with something else
def replaceDelimiter(str: String, delimiter: String): String = {
if (str != null && str.contains(delimiter)) {
val replaceWith = if (delimiter.equals(",")) {
";"
} else if (delimiter.equals(";")) {
":"
} else {
";"
}
str.replace(delimiter, replaceWith)
} else {
str
}
}
}
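
For illustration, here is a minimal, self-contained sketch of how the delimiter handling above behaves. The helper mirrors the replaceDelimiter logic from this diff so the snippet runs on its own; the sample row values are made up:

object DelimiterExample {
  // Mirrors ProfileUtils.replaceDelimiter above: swap an embedded delimiter
  // for ";" (or ":" when the delimiter itself is ";").
  def replaceDelimiter(str: String, delimiter: String): String = {
    if (str != null && str.contains(delimiter)) {
      val replaceWith = if (delimiter == ",") ";" else if (delimiter == ";") ":" else ";"
      str.replace(delimiter, replaceWith)
    } else {
      str
    }
  }

  def main(args: Array[String]): Unit = {
    val delimiter = ","
    // Hypothetical row where one field (a schema string) contains commas.
    val row = Seq("app-1234", "loan_id:bigint,orig_rate:double", "Parquet(GPU)")
    val sanitized = row.map(replaceDelimiter(_, delimiter))
    println(sanitized.mkString(delimiter))
    // prints: app-1234,loan_id:bigint;orig_rate:double,Parquet(GPU)
  }
}

Without this step, a schema containing commas would shift every following column in the CSV output.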
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -19,6 +19,7 @@ package com.nvidia.spark.rapids.tool.qualification
import scala.collection.mutable.{Buffer, LinkedHashMap, ListBuffer}

import com.nvidia.spark.rapids.tool.ToolTextFileWriter
import com.nvidia.spark.rapids.tool.profiling.ProfileUtils.replaceDelimiter

import org.apache.spark.sql.rapids.tool.qualification.QualificationSummaryInfo

@@ -238,23 +239,6 @@ object QualOutputWriter {
QualOutputWriter.constructOutputRowFromMap(headersAndSizes, delimiter, prettyPrint)
}

// if a string contains what we are going to use for a delimiter, replace
// it with something else
private def replaceDelimiter(str: String, delimiter: String): String = {
if (str.contains(delimiter)) {
val replaceWith = if (delimiter.equals(",")) {
";"
} else if (delimiter.equals(";")) {
":"
} else {
";"
}
str.replace(delimiter, replaceWith)
} else {
str
}
}

def constructAppDetailedInfo(
appInfo: QualificationSummaryInfo,
headersAndSizes: LinkedHashMap[String, Int],
81 changes: 36 additions & 45 deletions tools/src/main/scala/org/apache/spark/sql/rapids/tool/AppBase.scala
@@ -145,15 +145,6 @@ abstract class AppBase(
}
}

def getJdbcInPlan(planInfo: SparkPlanInfo): Seq[SparkPlanInfo] = {
val childRes = planInfo.children.flatMap(getJdbcInPlan(_))
if (planInfo.simpleString != null && planInfo.simpleString.contains("Scan JDBCRelation")) {
childRes :+ planInfo
} else {
childRes
}
}

// strip off the struct<> part that Spark adds to the ReadSchema
private def formatSchemaStr(schema: String): String = {
schema.stripPrefix("struct<").stripSuffix(">")
@@ -176,62 +167,62 @@
}
}

protected def checkJdbcScan(sqlID: Long, planInfo: SparkPlanInfo): Unit = {
val allJdbcScan = getJdbcInPlan(planInfo)
if (allJdbcScan.nonEmpty) {
dataSourceInfo += DataSourceCase(sqlID, "JDBC", "unknown", "unknown", "")
// This tries to get just the field specified by tag in a string that
// may contain multiple fields. It looks for a comma to delimit fields.
private def getFieldWithoutTag(str: String, tag: String): String = {
val index = str.indexOf(tag)
// remove the tag from the final string returned
val subStr = str.substring(index + tag.size)
val endIndex = subStr.indexOf(", ")
if (endIndex != -1) {
subStr.substring(0, endIndex)
} else {
subStr
}
}

// This will find scans for DataSource V2; if the schema is very large it
// will likely be incomplete and have ... at the end.
protected def checkGraphNodeForBatchScan(sqlID: Long, node: SparkPlanGraphNode): Unit = {
if (node.name.equals("BatchScan")) {
protected def checkGraphNodeForReads(sqlID: Long, node: SparkPlanGraphNode): Unit = {
if (node.name.equals("BatchScan") ||
node.name.contains("GpuScan") ||
node.name.contains("GpuBatchScan") ||
node.name.contains("JDBCRelation")) {
val schemaTag = "ReadSchema: "
val schema = if (node.desc.contains(schemaTag)) {
val index = node.desc.indexOf(schemaTag)
if (index != -1) {
val subStr = node.desc.substring(index + schemaTag.size)
val endIndex = subStr.indexOf(", ")
if (endIndex != -1) {
val schemaOnly = subStr.substring(0, endIndex)
formatSchemaStr(schemaOnly)
} else {
""
}
} else {
""
}
formatSchemaStr(getFieldWithoutTag(node.desc, schemaTag))
} else {
""
}
val locationTag = "Location:"
val locationTag = "Location: "
val location = if (node.desc.contains(locationTag)) {
val index = node.desc.indexOf(locationTag)
val subStr = node.desc.substring(index)
val endIndex = subStr.indexOf(", ")
val location = subStr.substring(0, endIndex)
location
getFieldWithoutTag(node.desc, locationTag)
} else if (node.name.contains("JDBCRelation")) {
// see if we can report table or query
val JDBCPattern = raw".*JDBCRelation\((.*)\).*".r
node.name match {
case JDBCPattern(tableName) => tableName
case _ => "unknown"
}
} else {
"unknown"
}
val pushedFilterTag = "PushedFilters:"
val pushedFilterTag = "PushedFilters: "
val pushedFilters = if (node.desc.contains(pushedFilterTag)) {
val index = node.desc.indexOf(pushedFilterTag)
val subStr = node.desc.substring(index)
val endIndex = subStr.indexOf("]")
val filters = subStr.substring(0, endIndex + 1)
filters
getFieldWithoutTag(node.desc, pushedFilterTag)
} else {
"unknown"
}
val formatTag = "Format: "
val fileFormat = if (node.desc.contains(formatTag)) {
val index = node.desc.indexOf(formatTag)
val subStr = node.desc.substring(index + formatTag.size)
val endIndex = subStr.indexOf(", ")
val format = subStr.substring(0, endIndex)
format
val format = getFieldWithoutTag(node.desc, formatTag)
if (node.name.startsWith("Gpu")) {
s"${format}(GPU)"
} else {
format
}
} else if (node.name.contains("JDBCRelation")) {
"JDBC"
} else {
"unknown"
}
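
As a rough illustration of the parsing that checkGraphNodeForReads now delegates to getFieldWithoutTag, the sketch below re-implements that helper and runs it against a made-up GpuScan node description; the description text and field values are invented for the example, not taken from a real event log:

object ReadNodeParseExample {
  // Mirrors getFieldWithoutTag above: take everything after `tag`,
  // up to the next ", " field separator (or to the end of the string).
  def getFieldWithoutTag(str: String, tag: String): String = {
    val subStr = str.substring(str.indexOf(tag) + tag.length)
    val endIndex = subStr.indexOf(", ")
    if (endIndex != -1) subStr.substring(0, endIndex) else subStr
  }

  def main(args: Array[String]): Unit = {
    // Invented node description, loosely shaped like a Spark scan node's desc string.
    val desc = "GpuScan parquet [loan_id#0L], " +
      "Location: InMemoryFileIndex[file:/data/mortgage], " +
      "PushedFilters: [IsNotNull(loan_id)], " +
      "ReadSchema: struct<loan_id:bigint>, " +
      "Format: Parquet"
    val schema = getFieldWithoutTag(desc, "ReadSchema: ")
      .stripPrefix("struct<").stripSuffix(">")
    val location = getFieldWithoutTag(desc, "Location: ")
    val filters = getFieldWithoutTag(desc, "PushedFilters: ")
    // The node name starts with "Gpu" here, so the tool would tag the format with (GPU).
    val format = getFieldWithoutTag(desc, "Format: ") + "(GPU)"
    println(s"schema=$schema location=$location filters=$filters format=$format")
    // schema=loan_id:bigint location=InMemoryFileIndex[file:/data/mortgage]
    // filters=[IsNotNull(loan_id)] format=Parquet(GPU)
  }
}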
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -264,13 +264,12 @@ class ApplicationInfo(
def processSQLPlanMetrics(): Unit = {
for ((sqlID, planInfo) <- sqlPlan) {
checkMetadataForReadSchema(sqlID, planInfo)
checkJdbcScan(sqlID, planInfo)
val planGraph = SparkPlanGraph(planInfo)
// SQLPlanMetric is a case class of
// (name: String, accumulatorId: Long, metricType: String)
val allnodes = planGraph.allNodes
for (node <- allnodes) {
checkGraphNodeForBatchScan(sqlID, node)
checkGraphNodeForReads(sqlID, node)
if (isDataSetOrRDDPlan(node.desc)) {
sqlIdToInfo.get(sqlID).foreach { sql =>
sqlIDToDataSetOrRDDCase += sqlID
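
A quick sketch of why this swap matters: the old check only fired for nodes literally named "BatchScan", so GPU scan nodes were skipped; the new check in AppBase (see the diff above) also matches GPU and JDBC node names. The node names below are illustrative:

object ReadNodePredicateExample {
  // Same shape as the name check at the top of checkGraphNodeForReads.
  def isReadNode(name: String): Boolean = {
    name.equals("BatchScan") || name.contains("GpuScan") ||
      name.contains("GpuBatchScan") || name.contains("JDBCRelation")
  }

  def main(args: Array[String]): Unit = {
    val names = Seq("BatchScan", "GpuScan parquet", "Scan JDBCRelation(TBLS)", "WholeStageCodegen")
    names.foreach(n => println(s"$n -> ${isReadNode(n)}"))
    // BatchScan -> true, GpuScan parquet -> true,
    // Scan JDBCRelation(TBLS) -> true, WholeStageCodegen -> false
  }
}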
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -227,11 +227,10 @@ class QualificationAppInfo(

private[qualification] def processSQLPlan(sqlID: Long, planInfo: SparkPlanInfo): Unit = {
checkMetadataForReadSchema(sqlID, planInfo)
checkJdbcScan(sqlID, planInfo)
val planGraph = SparkPlanGraph(planInfo)
val allnodes = planGraph.allNodes
for (node <- allnodes) {
checkGraphNodeForBatchScan(sqlID, node)
checkGraphNodeForReads(sqlID, node)
if (isDataSetOrRDDPlan(node.desc)) {
sqlIDToDataSetOrRDDCase += sqlID
}
@@ -216,6 +216,75 @@ class ApplicationInfoSuite extends FunSuite with Logging {
}
}

test("test read GPU datasourcev1") {
TrampolineUtil.withTempDir { tempOutputDir =>
var apps: ArrayBuffer[ApplicationInfo] = ArrayBuffer[ApplicationInfo]()
val appArgs = new ProfileArgs(Array(s"$logDir/eventlog-gpu-dsv1.zstd"))
var index: Int = 1
val eventlogPaths = appArgs.eventlog()
for (path <- eventlogPaths) {
apps += new ApplicationInfo(hadoopConf,
EventLogPathProcessor.getEventLogInfo(path,
sparkSession.sparkContext.hadoopConfiguration).head._1, index)
index += 1
}
assert(apps.size == 1)
val collect = new CollectInformation(apps)
val dsRes = collect.getDataSourceInfo
assert(dsRes.size == 5)
val allFormats = dsRes.map { r =>
r.format
}.toSet
val expectedFormats = Set("Text", "CSV(GPU)", "Parquet(GPU)", "ORC(GPU)", "JSON(GPU)")
assert(allFormats.equals(expectedFormats))
val allSchema = dsRes.map { r =>
r.schema
}.toSet
assert(allSchema.forall(_.nonEmpty))
val schemaParquet = dsRes.filter { r =>
r.sqlID == 4
}
assert(schemaParquet.size == 1)
val parquetRow = schemaParquet.head
assert(parquetRow.schema.contains("loan_id"))
}
}

test("test read GPU datasourcev2") {
TrampolineUtil.withTempDir { tempOutputDir =>
var apps: ArrayBuffer[ApplicationInfo] = ArrayBuffer[ApplicationInfo]()
val appArgs = new ProfileArgs(Array(s"$logDir/eventlog-gpu-dsv2.zstd"))
var index: Int = 1
val eventlogPaths = appArgs.eventlog()
for (path <- eventlogPaths) {
apps += new ApplicationInfo(hadoopConf,
EventLogPathProcessor.getEventLogInfo(path,
sparkSession.sparkContext.hadoopConfiguration).head._1, index)
index += 1
}
assert(apps.size == 1)
val collect = new CollectInformation(apps)
val dsRes = collect.getDataSourceInfo
assert(dsRes.size == 5)
val allFormats = dsRes.map { r =>
r.format
}.toSet
val expectedFormats =
Set("Text", "gpucsv(GPU)", "gpujson(GPU)", "gpuparquet(GPU)", "gpuorc(GPU)")
assert(allFormats.equals(expectedFormats))
val allSchema = dsRes.map { r =>
r.schema
}.toSet
assert(allSchema.forall(_.nonEmpty))
val schemaParquet = dsRes.filter { r =>
r.sqlID == 3
}
assert(schemaParquet.size == 1)
val parquetRow = schemaParquet.head
assert(parquetRow.schema.contains("loan_id"))
}
}

test("test read datasourcev1") {
TrampolineUtil.withTempDir { tempOutputDir =>
var apps: ArrayBuffer[ApplicationInfo] = ArrayBuffer[ApplicationInfo]()
@@ -306,7 +375,13 @@ class ApplicationInfoSuite extends FunSuite with Logging {
val dsRes = collect.getDataSourceInfo
val format = dsRes.map(r => r.format).toSet.mkString
val expectedFormat = "JDBC"
val location = dsRes.map(r => r.location).toSet.mkString
val expectedLocation = "TBLS"
assert(format.equals(expectedFormat))
assert(location.equals(expectedLocation))
dsRes.foreach { r =>
assert(r.schema.contains("bigint"))
}
}
}
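
For reference, a tiny sketch of the JDBCRelation table-name extraction that produces the "TBLS" location asserted above; the node name here is a made-up example of the shape that regex expects:

object JdbcRelationNameExample {
  def main(args: Array[String]): Unit = {
    // Same regex shape as in checkGraphNodeForReads.
    val JDBCPattern = raw".*JDBCRelation\((.*)\).*".r
    val nodeName = "Scan JDBCRelation(TBLS) [numPartitions=1]"
    val location = nodeName match {
      case JDBCPattern(tableName) => tableName
      case _ => "unknown"
    }
    println(location) // TBLS
  }
}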
