[SPARK-6910] [SQL] Support for pushing predicates down to metastore for partition pruning

This PR supersedes my old one #6921. Since my patch has changed quite a bit, I am opening a new PR to make it easier to review.

The changes include:
* Implement a `toMetastoreFilter()` function in `HiveShim` that takes a `Seq[Expression]` and converts it into a filter string for the Hive metastore.
 * This function matches all the `AttributeReference` + `BinaryComparisonOp` + `Integral/StringType` patterns in the `Seq[Expression]` and folds them into a string (a small illustrative sketch follows this list).
* Change the `hiveQlPartitions` field in `MetastoreRelation` to a `getHiveQlPartitions()` function that takes partition-pruning predicates (`Seq[Expression]`).
* Call `getHiveQlPartitions()` in `HiveTableScan` with the partition-pruning predicates.
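
To make the folding concrete, here is a minimal, self-contained sketch modeled on the `Shim_v0_13.getPartitionsByFilter()` logic in the diff below. The helper name `toMetastoreFilterString` and the example partition columns `ds`/`country` are illustrative only and are not part of this patch:

```scala
import org.apache.spark.sql.catalyst.expressions.{
  AttributeReference, BinaryComparison, EqualTo, Expression, Literal, NamedExpression}
import org.apache.spark.sql.types.{IntegerType, IntegralType, StringType}

// Fold convertible predicates into a Hive metastore filter string; anything that does
// not match the AttributeReference + BinaryComparison + Integral/String pattern (or that
// references a varchar partition key) is silently dropped.
def toMetastoreFilterString(predicates: Seq[Expression], varcharKeys: Set[String]): String =
  predicates.flatMap {
    case op @ BinaryComparison(lhs: AttributeReference, rhs) =>
      rhs.dataType match {
        case _: IntegralType =>
          Some(lhs.prettyString + op.symbol + rhs.prettyString)
        case _: StringType if !varcharKeys.contains(lhs.prettyString) =>
          Some(lhs.prettyString + op.symbol + "\"" + rhs.prettyString + "\"")
        case _ => None
      }
    case _ => None
  }.mkString(" and ")

// Example: produces the string  ds=20150714 and country="US"
val filter = toMetastoreFilterString(
  Seq(
    EqualTo(
      AttributeReference("ds", IntegerType, nullable = false)(NamedExpression.newExprId),
      Literal(20150714)),
    EqualTo(
      AttributeReference("country", StringType, nullable = false)(NamedExpression.newExprId),
      Literal("US"))),
  varcharKeys = Set.empty)
```

Predicates that don't fit the pattern simply fall out of the filter, so the metastore returns a superset of the matching partitions and the existing Spark-side partition pruning still applies.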

But there are some cases in which predicate pushdown is disabled:

Case | Predicate pushdown
------- | -----------------------------
Hive integral and string types | Yes
Hive varchar type | No
Hive 0.13 and newer | Yes
Hive 0.12 and older | No
convertMetastoreParquet=false | Yes
convertMetastoreParquet=true | No

In the `convertMetastoreParquet=true` case, predicates are not pushed down because the conversion happens in an `Analyzer` rule (`HiveMetastoreCatalog.ParquetConversions`). At that point, `HiveTableScan` hasn't run yet, so the predicates are not available. But from reading the source code, I think converting the entire Hive table with all its partitions into a `ParquetRelation` is intentional, because the `ParquetRelation` can then be cached and reused for any query against that table. Please correct me if I am wrong.
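
As a usage note, one way to exercise the metastore pushdown on a Parquet-backed Hive table is to turn the conversion off. This is only a hedged sketch: the `spark.sql.hive.convertMetastoreParquet` setting is real, but the table name and partition column below are made up:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.HiveContext

val sc = new SparkContext(new SparkConf().setAppName("metastore-pushdown-example"))
val hiveContext = new HiveContext(sc)

// With the Parquet conversion disabled, the scan goes through HiveTableScan, so the
// partition predicate below is eligible for pushdown to the Hive metastore
// (assuming Hive 0.13+ and an integral or string partition column, per the table above).
hiveContext.setConf("spark.sql.hive.convertMetastoreParquet", "false")
hiveContext.sql("SELECT * FROM parquet_partitioned_table WHERE ds = 20150714").collect()
```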

cc marmbrus

Author: Cheolsoo Park <cheolsoop@netflix.com>

Closes #7216 from piaozhexiu/SPARK-6910-2 and squashes the following commits:

aa1490f [Cheolsoo Park] Fix ordering of imports
c212c4d [Cheolsoo Park] Incorporate review comments
5e93f9d [Cheolsoo Park] Predicate pushdown into Hive metastore
Cheolsoo Park authored and marmbrus committed Jul 14, 2015
1 parent b7bcbe2 commit 408b384
Showing 9 changed files with 137 additions and 44 deletions.
@@ -301,7 +301,9 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
val result = if (metastoreRelation.hiveQlTable.isPartitioned) {
val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys)
val partitionColumnDataTypes = partitionSchema.map(_.dataType)
val partitions = metastoreRelation.hiveQlPartitions.map { p =>
// We're converting the entire table into ParquetRelation, so predicates to Hive metastore
// are empty.
val partitions = metastoreRelation.getHiveQlPartitions().map { p =>
val location = p.getLocation
val values = InternalRow.fromSeq(p.getValues.zip(partitionColumnDataTypes).map {
case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null)
@@ -644,32 +646,6 @@ private[hive] case class MetastoreRelation
new Table(tTable)
}

@transient val hiveQlPartitions: Seq[Partition] = table.getAllPartitions.map { p =>
val tPartition = new org.apache.hadoop.hive.metastore.api.Partition
tPartition.setDbName(databaseName)
tPartition.setTableName(tableName)
tPartition.setValues(p.values)

val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor()
tPartition.setSd(sd)
sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)))

sd.setLocation(p.storage.location)
sd.setInputFormat(p.storage.inputFormat)
sd.setOutputFormat(p.storage.outputFormat)

val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo
sd.setSerdeInfo(serdeInfo)
serdeInfo.setSerializationLib(p.storage.serde)

val serdeParameters = new java.util.HashMap[String, String]()
serdeInfo.setParameters(serdeParameters)
table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }

new Partition(hiveQlTable, tPartition)
}

@transient override lazy val statistics: Statistics = Statistics(
sizeInBytes = {
val totalSize = hiveQlTable.getParameters.get(StatsSetupConst.TOTAL_SIZE)
@@ -690,6 +666,34 @@ private[hive] case class MetastoreRelation
}
)

def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = {
table.getPartitions(predicates).map { p =>
val tPartition = new org.apache.hadoop.hive.metastore.api.Partition
tPartition.setDbName(databaseName)
tPartition.setTableName(tableName)
tPartition.setValues(p.values)

val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor()
tPartition.setSd(sd)
sd.setCols(table.schema.map(c => new FieldSchema(c.name, c.hiveType, c.comment)))

sd.setLocation(p.storage.location)
sd.setInputFormat(p.storage.inputFormat)
sd.setOutputFormat(p.storage.outputFormat)

val serdeInfo = new org.apache.hadoop.hive.metastore.api.SerDeInfo
sd.setSerdeInfo(serdeInfo)
serdeInfo.setSerializationLib(p.storage.serde)

val serdeParameters = new java.util.HashMap[String, String]()
serdeInfo.setParameters(serdeParameters)
table.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }
p.storage.serdeProperties.foreach { case (k, v) => serdeParameters.put(k, v) }

new Partition(hiveQlTable, tPartition)
}
}

/** Only compare database and tablename, not alias. */
override def sameResult(plan: LogicalPlan): Boolean = {
plan match {
@@ -27,6 +27,7 @@ import scala.reflect.ClassTag

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}
@@ -125,7 +125,7 @@ private[hive] trait HiveStrategies {
InterpretedPredicate.create(castedPredicate)
}

val partitions = relation.hiveQlPartitions.filter { part =>
val partitions = relation.getHiveQlPartitions(pruningPredicates).filter { part =>
val partitionValues = part.getValues
var i = 0
while (i < partitionValues.size()) {
@@ -213,7 +213,7 @@ private[hive] trait HiveStrategies {
projectList,
otherPredicates,
identity[Seq[Expression]],
HiveTableScan(_, relation, pruningPredicates.reduceLeftOption(And))(hiveContext)) :: Nil
HiveTableScan(_, relation, pruningPredicates)(hiveContext)) :: Nil
case _ =>
Nil
}
@@ -21,6 +21,7 @@ import java.io.PrintStream
import java.util.{Map => JMap}

import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException}
import org.apache.spark.sql.catalyst.expressions.Expression

private[hive] case class HiveDatabase(
name: String,
@@ -71,7 +72,12 @@ private[hive] case class HiveTable(

def isPartitioned: Boolean = partitionColumns.nonEmpty

def getAllPartitions: Seq[HivePartition] = client.getAllPartitions(this)
def getPartitions(predicates: Seq[Expression]): Seq[HivePartition] = {
predicates match {
case Nil => client.getAllPartitions(this)
case _ => client.getPartitionsByFilter(this, predicates)
}
}

// Hive does not support backticks when passing names to the client.
def qualifiedName: String = s"$database.$name"
@@ -132,6 +138,9 @@ private[hive] trait ClientInterface {
/** Returns all partitions for the given table. */
def getAllPartitions(hTable: HiveTable): Seq[HivePartition]

/** Returns partitions filtered by predicates for the given table. */
def getPartitionsByFilter(hTable: HiveTable, predicates: Seq[Expression]): Seq[HivePartition]

/** Loads a static partition into an existing table. */
def loadPartition(
loadPath: String,
@@ -17,27 +17,24 @@

package org.apache.spark.sql.hive.client

import java.io.{BufferedReader, InputStreamReader, File, PrintStream}
import java.net.URI
import java.util.{ArrayList => JArrayList, Map => JMap, List => JList, Set => JSet}
import java.io.{File, PrintStream}
import java.util.{Map => JMap}
import javax.annotation.concurrent.GuardedBy

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.util.CircularBuffer

import scala.collection.JavaConversions._
import scala.language.reflectiveCalls

import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.metastore.api.Database
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.metastore.api.{Database, FieldSchema}
import org.apache.hadoop.hive.metastore.{TableType => HTableType}
import org.apache.hadoop.hive.metastore.api
import org.apache.hadoop.hive.metastore.api.FieldSchema
import org.apache.hadoop.hive.ql.metadata
import org.apache.hadoop.hive.ql.metadata.Hive
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.ql.processors._
import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.ql.{Driver, metadata}

import org.apache.spark.Logging
import org.apache.spark.sql.execution.QueryExecutionException
@@ -316,6 +313,13 @@ private[hive] class ClientWrapper(
shim.getAllPartitions(client, qlTable).map(toHivePartition)
}

override def getPartitionsByFilter(
hTable: HiveTable,
predicates: Seq[Expression]): Seq[HivePartition] = withHiveState {
val qlTable = toQlTable(hTable)
shim.getPartitionsByFilter(client, qlTable, predicates).map(toHivePartition)
}

override def listTables(dbName: String): Seq[String] = withHiveState {
client.getAllTables(dbName)
}
@@ -31,6 +31,11 @@ import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table}
import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory}
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.serde.serdeConstants

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference, BinaryComparison}
import org.apache.spark.sql.types.{StringType, IntegralType}

/**
* A shim that defines the interface between ClientWrapper and the underlying Hive library used to
@@ -61,6 +66,8 @@ private[client] sealed abstract class Shim {

def getAllPartitions(hive: Hive, table: Table): Seq[Partition]

def getPartitionsByFilter(hive: Hive, table: Table, predicates: Seq[Expression]): Seq[Partition]

def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor

def getDriverResults(driver: Driver): Seq[String]
@@ -109,7 +116,7 @@ private[client] sealed abstract class Shim {

}

private[client] class Shim_v0_12 extends Shim {
private[client] class Shim_v0_12 extends Shim with Logging {

private lazy val startMethod =
findStaticMethod(
@@ -196,6 +203,17 @@ private[client] class Shim_v0_12 extends Shim {
override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq

override def getPartitionsByFilter(
hive: Hive,
table: Table,
predicates: Seq[Expression]): Seq[Partition] = {
// getPartitionsByFilter() doesn't support binary comparison ops in Hive 0.12.
// See HIVE-4888.
logDebug("Hive 0.12 doesn't support predicate pushdown to metastore. " +
"Please use Hive 0.13 or higher.")
getAllPartitions(hive, table)
}

override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor]

@@ -267,6 +285,12 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
classOf[Hive],
"getAllPartitionsOf",
classOf[Table])
private lazy val getPartitionsByFilterMethod =
findMethod(
classOf[Hive],
"getPartitionsByFilter",
classOf[Table],
classOf[String])
private lazy val getCommandProcessorMethod =
findStaticMethod(
classOf[CommandProcessorFactory],
@@ -288,6 +312,48 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq

override def getPartitionsByFilter(
hive: Hive,
table: Table,
predicates: Seq[Expression]): Seq[Partition] = {
// hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
val varcharKeys = table.getPartitionKeys
.filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME))
.map(col => col.getName).toSet

// Hive getPartitionsByFilter() takes a string that represents partition
// predicates like "str_key=\"value\" and int_key=1 ..."
val filter = predicates.flatMap { expr =>
expr match {
case op @ BinaryComparison(lhs, rhs) => {
lhs match {
case AttributeReference(_, _, _, _) => {
rhs.dataType match {
case _: IntegralType =>
Some(lhs.prettyString + op.symbol + rhs.prettyString)
case _: StringType if (!varcharKeys.contains(lhs.prettyString)) =>
Some(lhs.prettyString + op.symbol + "\"" + rhs.prettyString + "\"")
case _ => None
}
}
case _ => None
}
}
case _ => None
}
}.mkString(" and ")

val partitions =
if (filter.isEmpty) {
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
} else {
logDebug(s"Hive metastore filter is '$filter'.")
getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]]
}

partitions.toSeq
}

override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
getCommandProcessorMethod.invoke(null, Array(token), conf).asInstanceOf[CommandProcessor]

@@ -44,7 +44,7 @@ private[hive]
case class HiveTableScan(
requestedAttributes: Seq[Attribute],
relation: MetastoreRelation,
partitionPruningPred: Option[Expression])(
partitionPruningPred: Seq[Expression])(
@transient val context: HiveContext)
extends LeafNode {

@@ -56,7 +56,7 @@ case class HiveTableScan(

// Bind all partition key attribute references in the partition pruning predicate for later
// evaluation.
private[this] val boundPruningPred = partitionPruningPred.map { pred =>
private[this] val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred =>
require(
pred.dataType == BooleanType,
s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.")
@@ -133,7 +133,8 @@ case class HiveTableScan(
protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) {
hadoopReader.makeRDDForTable(relation.hiveQlTable)
} else {
hadoopReader.makeRDDForPartitionedTable(prunePartitions(relation.hiveQlPartitions))
hadoopReader.makeRDDForPartitionedTable(
prunePartitions(relation.getHiveQlPartitions(partitionPruningPred)))
}

override def output: Seq[Attribute] = attributes
@@ -20,7 +20,9 @@ package org.apache.spark.sql.hive.client
import java.io.File

import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo}
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.util.Utils

/**
@@ -151,6 +153,12 @@ class VersionsSuite extends SparkFunSuite with Logging {
client.getAllPartitions(client.getTable("default", "src_part"))
}

test(s"$version: getPartitionsByFilter") {
client.getPartitionsByFilter(client.getTable("default", "src_part"), Seq(EqualTo(
AttributeReference("key", IntegerType, false)(NamedExpression.newExprId),
Literal(1))))
}

test(s"$version: loadPartition") {
client.loadPartition(
emptyDir,
@@ -151,7 +151,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter {
case p @ HiveTableScan(columns, relation, _) =>
val columnNames = columns.map(_.name)
val partValues = if (relation.table.isPartitioned) {
p.prunePartitions(relation.hiveQlPartitions).map(_.getValues)
p.prunePartitions(relation.getHiveQlPartitions()).map(_.getValues)
} else {
Seq.empty
}
