Incorporate review comments
Cheolsoo Park committed Jul 13, 2015
1 parent 5e93f9d commit c212c4d
Showing 9 changed files with 83 additions and 104 deletions.
@@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.hive.client._
import org.apache.spark.sql.hive.execution.HiveTableScan
import org.apache.spark.sql.parquet.ParquetRelation2
import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource}
import org.apache.spark.sql.types._
@@ -302,9 +301,9 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
val result = if (metastoreRelation.hiveQlTable.isPartitioned) {
val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys)
val partitionColumnDataTypes = partitionSchema.map(_.dataType)
// We're converting the entire table into a ParquetRelation, so the filter to Hive metastore
// is None.
val partitions = metastoreRelation.getHiveQlPartitions(None).map { p =>
// We're converting the entire table into ParquetRelation, so predicates to Hive metastore
// are empty.
val partitions = metastoreRelation.getHiveQlPartitions().map { p =>
val location = p.getLocation
val values = InternalRow.fromSeq(p.getValues.zip(partitionColumnDataTypes).map {
case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null)
Expand Down Expand Up @@ -667,8 +666,8 @@ private[hive] case class MetastoreRelation
}
)

def getHiveQlPartitions(filter: Option[String]): Seq[Partition] = {
table.getPartitions(filter).map { p =>
def getHiveQlPartitions(predicates: Seq[Expression] = Nil): Seq[Partition] = {
table.getPartitions(predicates).map { p =>
val tPartition = new org.apache.hadoop.hive.metastore.api.Partition
tPartition.setDbName(databaseName)
tPartition.setTableName(tableName)
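As a minimal usage sketch (not part of this commit), the new default-argument signature lets callers either fetch every partition or push Catalyst predicates down to the metastore; `relation` and `pruningPredicates` below are hypothetical names:

// Hypothetical caller, assuming relation: MetastoreRelation and
// pruningPredicates: Seq[Expression] already collected by the planner.
val allPartitions    = relation.getHiveQlPartitions()                  // no metastore-side filtering
val prunedPartitions = relation.getHiveQlPartitions(pruningPredicates) // predicates pushed to the Hive shim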
56 changes: 1 addition & 55 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.hive

import java.io.{InputStream, OutputStream}
import java.rmi.server.UID
import java.util.List

/* Implicit conversions */
import scala.collection.JavaConversions._
@@ -31,18 +30,15 @@ import com.esotericsoftware.kryo.io.{Input, Output}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.hive.metastore.api.FieldSchema
import org.apache.hadoop.hive.ql.exec.{UDF, Utilities}
import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc}
import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.serde2.ColumnProjectionUtils
import org.apache.hadoop.hive.serde2.avro.AvroGenericRecordWritable
import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector
import org.apache.hadoop.io.Writable

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryComparison, Expression}
import org.apache.spark.sql.types.{StringType, IntegralType, Decimal}
import org.apache.spark.sql.types.Decimal
import org.apache.spark.util.Utils

private[hive] object HiveShim {
@@ -104,56 +100,6 @@ private[hive] object HiveShim {
}
}

def toMetastoreFilter(
predicates: Seq[Expression],
partitionKeys: List[FieldSchema],
hiveMetastoreVersion: String): Option[String] = {

// Binary comparison has been supported in getPartitionsByFilter() since Hive 0.13.
// So if Hive metastore version is older than 0.13, predicates cannot be pushed down.
// See HIVE-4888.
val versionPattern = "([\\d]+\\.[\\d]+).*".r
hiveMetastoreVersion match {
case versionPattern(version) if (version.toDouble < 0.13) => return None
case _ => // continue
}

// hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
val varcharKeys = partitionKeys
.filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME))
.map(col => col.getName).toSet

// Hive getPartitionsByFilter() takes a string that represents partition
// predicates like "str_key=\"value\" and int_key=1 ..."
Option(predicates.foldLeft("") {
(prevStr, expr) => {
expr match {
case op @ BinaryComparison(lhs, rhs) => {
val curr: Option[String] =
lhs match {
case AttributeReference(_, _, _, _) => {
rhs.dataType match {
case _: IntegralType =>
Some(lhs.prettyString + op.symbol + rhs.prettyString)
case _: StringType if (!varcharKeys.contains(lhs.prettyString)) =>
Some(lhs.prettyString + op.symbol + "\"" + rhs.prettyString + "\"")
case _ => None
}
}
case _ => None
}
curr match {
case Some(currStr) if (prevStr.nonEmpty) => s"$prevStr and $currStr"
case Some(currStr) if (prevStr.isEmpty) => currStr
case None => prevStr
}
}
case _ => prevStr
}
}
}).filter(_.nonEmpty)
}

/**
* This class provides the UDF creation and also the UDF instance serialization and
* de-serialization cross process boundary.
@@ -107,12 +107,6 @@ private[hive] trait HiveStrategies {

try {
if (relation.hiveQlTable.isPartitioned) {
val metastoreFilter =
HiveShim.toMetastoreFilter(
pruningPredicates,
relation.hiveQlTable.getPartitionKeys,
hiveContext.hiveMetastoreVersion)

val rawPredicate = pruningPredicates.reduceOption(And).getOrElse(Literal(true))
// Translate the predicate so that it automatically casts the input values to the
// correct data types during evaluation.
@@ -131,9 +125,7 @@
InterpretedPredicate.create(castedPredicate)
}

logDebug(s"Hive metastore filter is $metastoreFilter")

val partitions = relation.getHiveQlPartitions(metastoreFilter).filter { part =>
val partitions = relation.getHiveQlPartitions(pruningPredicates).filter { part =>
val partitionValues = part.getValues
var i = 0
while (i < partitionValues.size()) {
@@ -21,6 +21,7 @@ import java.io.PrintStream
import java.util.{Map => JMap}

import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchTableException}
import org.apache.spark.sql.catalyst.expressions.Expression

private[hive] case class HiveDatabase(
name: String,
@@ -71,10 +72,10 @@ private[hive] case class HiveTable(

def isPartitioned: Boolean = partitionColumns.nonEmpty

def getPartitions(filter: Option[String]): Seq[HivePartition] = {
filter match {
case None => client.getAllPartitions(this)
case Some(expr) => client.getPartitionsByFilter(this, expr)
def getPartitions(predicates: Seq[Expression]): Seq[HivePartition] = {
predicates match {
case Nil => client.getAllPartitions(this)
case _ => client.getPartitionsByFilter(this, predicates)
}
}

@@ -138,7 +139,7 @@ private[hive] trait ClientInterface {
def getAllPartitions(hTable: HiveTable): Seq[HivePartition]

/** Returns partitions filtered by predicates for the given table. */
def getPartitionsByFilter(hTable: HiveTable, filter: String): Seq[HivePartition]
def getPartitionsByFilter(hTable: HiveTable, predicates: Seq[Expression]): Seq[HivePartition]

/** Loads a static partition into an existing table. */
def loadPartition(
@@ -17,11 +17,11 @@

package org.apache.spark.sql.hive.client

import java.io.{BufferedReader, File, InputStreamReader, PrintStream}
import java.net.URI
import java.util.{ArrayList => JArrayList, Map => JMap, List => JList, Set => JSet}
import java.io.{File, PrintStream}
import java.util.{Map => JMap}
import javax.annotation.concurrent.GuardedBy

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.util.CircularBuffer

import scala.collection.JavaConversions._
@@ -315,9 +315,9 @@ private[hive] class ClientWrapper(

override def getPartitionsByFilter(
hTable: HiveTable,
filter: String): Seq[HivePartition] = withHiveState {
predicates: Seq[Expression]): Seq[HivePartition] = withHiveState {
val qlTable = toQlTable(hTable)
shim.getPartitionsByFilter(client, qlTable, filter).map(toHivePartition)
shim.getPartitionsByFilter(client, qlTable, predicates).map(toHivePartition)
}

override def listTables(dbName: String): Seq[String] = withHiveState {
@@ -31,6 +31,11 @@ import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.metadata.{Hive, Partition, Table}
import org.apache.hadoop.hive.ql.processors.{CommandProcessor, CommandProcessorFactory}
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.serde.serdeConstants

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference, BinaryComparison}
import org.apache.spark.sql.types.{StringType, IntegralType}

/**
* A shim that defines the interface between ClientWrapper and the underlying Hive library used to
@@ -61,7 +66,7 @@ private[client] sealed abstract class Shim {

def getAllPartitions(hive: Hive, table: Table): Seq[Partition]

def getPartitionsByFilter(hive: Hive, table: Table, filter: String): Seq[Partition]
def getPartitionsByFilter(hive: Hive, table: Table, predicates: Seq[Expression]): Seq[Partition]

def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor

@@ -111,7 +116,7 @@

}

private[client] class Shim_v0_12 extends Shim {
private[client] class Shim_v0_12 extends Shim with Logging {

private lazy val startMethod =
findStaticMethod(
@@ -129,12 +134,6 @@ private[client] class Shim_v0_12 {
classOf[Hive],
"getAllPartitionsForPruner",
classOf[Table])
private lazy val getPartitionsByFilterMethod =
findMethod(
classOf[Hive],
"getPartitionsByFilter",
classOf[Table],
classOf[String])
private lazy val getCommandProcessorMethod =
findStaticMethod(
classOf[CommandProcessorFactory],
@@ -204,9 +203,16 @@
override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq

override def getPartitionsByFilter(hive: Hive, table: Table, filter: String): Seq[Partition] =
getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]]
.toSeq
override def getPartitionsByFilter(
hive: Hive,
table: Table,
predicates: Seq[Expression]): Seq[Partition] = {
// getPartitionsByFilter() doesn't support binary comparison ops in Hive 0.12.
// See HIVE-4888.
logDebug("Hive 0.12 doesn't support predicate pushdown to metastore. " +
"Please use Hive 0.13 or higher.")
getAllPartitions(hive, table)
}

override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
getCommandProcessorMethod.invoke(null, token, conf).asInstanceOf[CommandProcessor]
@@ -306,9 +312,47 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
override def getAllPartitions(hive: Hive, table: Table): Seq[Partition] =
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]].toSeq

override def getPartitionsByFilter(hive: Hive, table: Table, filter: String): Seq[Partition] =
getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]]
.toSeq
override def getPartitionsByFilter(
hive: Hive,
table: Table,
predicates: Seq[Expression]): Seq[Partition] = {
// hive varchar is treated as catalyst string, but hive varchar can't be pushed down.
val varcharKeys = table.getPartitionKeys
.filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME))
.map(col => col.getName).toSet

// Hive getPartitionsByFilter() takes a string that represents partition
// predicates like "str_key=\"value\" and int_key=1 ..."
val filter = predicates.flatMap { expr =>
expr match {
case op @ BinaryComparison(lhs, rhs) => {
lhs match {
case AttributeReference(_, _, _, _) => {
rhs.dataType match {
case _: IntegralType =>
Some(lhs.prettyString + op.symbol + rhs.prettyString)
case _: StringType if (!varcharKeys.contains(lhs.prettyString)) =>
Some(lhs.prettyString + op.symbol + "\"" + rhs.prettyString + "\"")
case _ => None
}
}
case _ => None
}
}
case _ => None
}
}.mkString(" and ")

val partitions =
if (filter.isEmpty) {
getAllPartitionsMethod.invoke(hive, table).asInstanceOf[JSet[Partition]]
} else {
logDebug(s"Hive metastore filter is '$filter'.")
getPartitionsByFilterMethod.invoke(hive, table, filter).asInstanceOf[JArrayList[Partition]]
}

partitions.toSeq
}

override def getCommandProcessor(token: String, conf: HiveConf): CommandProcessor =
getCommandProcessorMethod.invoke(null, Array(token), conf).asInstanceOf[CommandProcessor]
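To make the new pushdown path concrete, here is a rough standalone sketch of the same predicate-to-filter-string conversion performed above (illustration only, not the committed code; the helper name and its arguments are hypothetical):

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BinaryComparison, Expression}
import org.apache.spark.sql.types.{IntegralType, StringType}

// Sketch: build the Hive metastore filter string from Catalyst predicates.
// varcharKeys holds the names of varchar partition columns, which cannot be pushed down.
def toMetastoreFilterString(predicates: Seq[Expression], varcharKeys: Set[String]): String =
  predicates.flatMap {
    // Only "column <op> literal" comparisons on integral or non-varchar string
    // partition keys can be expressed in the metastore filter syntax.
    case op @ BinaryComparison(lhs: AttributeReference, rhs) =>
      rhs.dataType match {
        case _: IntegralType =>
          Some(lhs.prettyString + op.symbol + rhs.prettyString)
        case _: StringType if !varcharKeys.contains(lhs.prettyString) =>
          Some(lhs.prettyString + op.symbol + "\"" + rhs.prettyString + "\"")
        case _ => None
      }
    case _ => None // anything else stays on the Spark side for evaluation
  }.mkString(" and ")

// For example, predicates equivalent to key = 1 and state = "CA" would yield: key=1 and state="CA"
// When the resulting string is empty, getAllPartitions is used instead, as in the code above.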
@@ -54,12 +54,6 @@ case class HiveTableScan(
// Retrieve the original attributes based on expression ID so that capitalization matches.
val attributes = requestedAttributes.map(relation.attributeMap)

val metastoreFilter: Option[String] =
HiveShim.toMetastoreFilter(
partitionPruningPred,
relation.hiveQlTable.getPartitionKeys,
context.hiveMetastoreVersion)

// Bind all partition key attribute references in the partition pruning predicate for later
// evaluation.
private[this] val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred =>
@@ -139,9 +133,8 @@
protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) {
hadoopReader.makeRDDForTable(relation.hiveQlTable)
} else {
logDebug(s"Hive metastore filter is $metastoreFilter")
hadoopReader.makeRDDForPartitionedTable(
prunePartitions(relation.getHiveQlPartitions(metastoreFilter)))
prunePartitions(relation.getHiveQlPartitions(partitionPruningPred)))
}

override def output: Seq[Attribute] = attributes
@@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.client

import java.io.File

import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo}
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.{Logging, SparkFunSuite}
import org.apache.spark.sql.catalyst.util.quietly
import org.apache.spark.util.Utils
@@ -152,7 +154,9 @@ class VersionsSuite extends SparkFunSuite with Logging {
}

test(s"$version: getPartitionsByFilter") {
client.getPartitionsByFilter(client.getTable("default", "src_part"), "key = 1")
client.getPartitionsByFilter(client.getTable("default", "src_part"), Seq(EqualTo(
AttributeReference("key", IntegerType, false)(NamedExpression.newExprId),
Literal(1))))
}

test(s"$version: loadPartition") {
@@ -151,7 +151,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter {
case p @ HiveTableScan(columns, relation, _) =>
val columnNames = columns.map(_.name)
val partValues = if (relation.table.isPartitioned) {
p.prunePartitions(relation.getHiveQlPartitions(None)).map(_.getValues)
p.prunePartitions(relation.getHiveQlPartitions()).map(_.getValues)
} else {
Seq.empty
}
