Skip to content

Commit

Permalink
Merge pull request twitter#394 from miguno/twitterGH-392
Browse files Browse the repository at this point in the history
twitterGH-392: Improve hashing of BigInt
  • Loading branch information
johnynek committed Jan 14, 2015
2 parents 5547a2d + 57e2437 commit 20a3019
Show file tree
Hide file tree
Showing 3 changed files with 312 additions and 55 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package com.twitter.algebird.caliper

import com.google.caliper.{Param, SimpleBenchmark}

/**
* Benchmarks the hashing algorithms used by Count-Min sketch for CMS[BigInt].
*
* The input values are generated ahead of time to ensure that each trial uses the same input (and that the RNG is not
* influencing the runtime of the trials).
*
* More details available at https://github.com/twitter/algebird/issues/392.
*/
// Once we can convince cappi (https://github.com/softprops/cappi) -- the sbt plugin we use to run
// caliper benchmarks -- to work with the latest caliper 1.0-beta-1, we would:
// - Let `CMSHashingBenchmark` extend `Benchmark` (instead of `SimpleBenchmark`)
// - Annotate the `time*` methods of this class with `@MacroBenchmark`.
class CMSHashingBenchmark extends SimpleBenchmark {

  /**
   * The `a` parameter for CMS' default ("legacy") hashing algorithm: `h_i(x) = a_i * x + b_i (mod p)`.
   */
  @Param(Array("5123456"))
  val a: Int = 0

  /**
   * The `b` parameter for CMS' default ("legacy") hashing algorithm: `h_i(x) = a_i * x + b_i (mod p)`.
   *
   * Algebird's CMS implementation hard-codes `b` to `0`.
   */
  @Param(Array("0"))
  val b: Int = 0

  /**
   * Width of the counting table.
   */
  @Param(Array("11" /* eps = 0.271 */ , "544" /* eps = 0.005 */ , "2719" /* eps = 1E-3 */ , "271829" /* eps = 1E-5 */))
  val width: Int = 0

  /**
   * Number of operations per benchmark repetition, i.e. the number of pre-generated input values.
   */
  @Param(Array("100000"))
  val operations: Int = 0

  /**
   * Maximum number of bits for randomly generated BigInt instances.
   */
  @Param(Array("128", "1024", "2048"))
  val maxBits: Int = 0

  /** Source of randomness for the inputs; (re-)created in `setUp`. */
  var random: scala.util.Random = _

  /** The pre-generated values that every timed method hashes; populated in `setUp`. */
  var inputs: Seq[BigInt] = _

  /**
   * Generates the `operations` random inputs before the timed methods run.
   *
   * The inputs are materialised into a strict `Vector`. A lazy view must not be used here: a view re-evaluates its
   * mapping function on every traversal, which would draw fresh random numbers inside the timed loops and thereby
   * make the RNG part of the measurement (and give every repetition different inputs).
   */
  override def setUp(): Unit = {
    random = new scala.util.Random
    // We draw numbers randomly from a 2^maxBits address space.
    inputs = Vector.fill(operations)(scala.math.BigInt(maxBits, random))
  }

  /**
   * Murmur3-based hash: seeds `MurmurHash3.arrayHash` with `a` and maps the result to a slot in `[0, width)`.
   *
   * @param a     Seed passed to Murmur3.
   * @param b     Unused; present only so that the signature matches `brokenCurrentHash`.
   * @param width Number of slots; the result is taken modulo this value.
   * @param x     Item to be hashed.
   * @return Slot in `[0, width)`.
   */
  private def murmurHashScala(a: Int, b: Int, width: Int)(x: BigInt) = {
    val hash: Int = scala.util.hashing.MurmurHash3.arrayHash(x.toByteArray, a)
    val h = {
      // We only want positive integers for the subsequent modulo. This method mimics Java's Hashtable
      // implementation. The Java code uses `0x7FFFFFFF` for the bit-wise AND, which is equal to Int.MaxValue.
      val positiveHash = hash & Int.MaxValue
      positiveHash % width
    }
    assert(h >= 0, "hash must not be negative")
    h
  }

  /** The value `2^31 - 1`, used as a bit mask by `brokenCurrentHash`. */
  private val PRIME_MODULUS = (1L << 31) - 1

  /**
   * The "legacy" hash under test: computes `a * x + b` on `BigInt`, folds it with `>> 32`, masks it with
   * `PRIME_MODULUS` and maps the result to a slot via `% width`.
   *
   * The logic is intentionally left untouched because the purpose of this benchmark is to measure it as-is.
   *
   * @return Slot in `[0, width)`; the mask guarantees a non-negative `Int` before the modulo.
   */
  private def brokenCurrentHash(a: Int, b: Int, width: Int)(x: BigInt) = {
    val unModded: BigInt = (x * a) + b
    val modded: BigInt = (unModded + (unModded >> 32)) & PRIME_MODULUS
    val h = modded.toInt % width
    assert(h >= 0, "hash must not be negative")
    h
  }

  /**
   * Times `brokenCurrentHash` over all pre-generated inputs.
   *
   * @param operations Number of repetitions to run. Note that this parameter shadows the `operations` field of
   *                   this class; inside this method the name refers to the repetition count.
   * @return The number of repetitions that were executed.
   */
  def timeBrokenCurrentHashWithRandomMaxBitsNumbers(operations: Int): Int = {
    var dummy = 0
    while (dummy < operations) {
      inputs.foreach { input => brokenCurrentHash(a, b, width)(input) }
      dummy += 1
    }
    dummy
  }

  /**
   * Times `murmurHashScala` over all pre-generated inputs.
   *
   * @param operations Number of repetitions to run. Note that this parameter shadows the `operations` field of
   *                   this class; inside this method the name refers to the repetition count.
   * @return The number of repetitions that were executed.
   */
  def timeMurmurHashScalaWithRandomMaxBitsNumbers(operations: Int): Int = {
    var dummy = 0
    while (dummy < operations) {
      inputs.foreach { input => murmurHashScala(a, b, width)(input) }
      dummy += 1
    }
    dummy
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ case class CMSInstance[K: Ordering](countsTable: CMSInstance.CountsTable[K],
* Let X be a CMS, and let count_X[j, k] denote the value in X's 2-dimensional count table at row j and column k.
* Then the Count-Min sketch estimate of the inner product between A and B is the minimum inner product between their
* rows:
* estimatedInnerProduct = min_j (\sum_k count_A[j, k] * count_B[j, k])
* estimatedInnerProduct = min_j (\sum_k count_A[j, k] * count_B[j, k]|)
*/
def innerProduct(other: CMS[K]): Approximate[Long] = {
other match {
Expand All @@ -491,7 +491,8 @@ case class CMSInstance[K: Ordering](countsTable: CMSInstance.CountsTable[K],
}.sum

val est = (0 to (depth - 1)).iterator.map { innerProductAtDepth }.min
Approximate(est - (eps * totalCount * other.totalCount).toLong, est, est, 1 - delta)
val minimum = math.max(est - (eps * totalCount * other.totalCount).toLong, 0)
Approximate(minimum, est, est, 1 - delta)
case _ => other.innerProduct(this)
}
}
Expand Down Expand Up @@ -663,14 +664,19 @@ case class TopCMSItem[K: Ordering](item: K, override val cms: CMS[K], params: To

override val heavyHitters: Set[K] = Set(item)

override def +(x: K, count: Long): TopCMS[K] = TopCMSInstance(cms, params) + item + (x, count)
override def +(x: K, count: Long): TopCMS[K] = toCMSInstance + (x, count)

override def ++(other: TopCMS[K]): TopCMS[K] = other match {
case other: TopCMSZero[_] => this
case other: TopCMSItem[K] => TopCMSInstance[K](cms, params) + item + other.item
case other: TopCMSItem[K] => toCMSInstance + other.item
case other: TopCMSInstance[K] => other + item
}

private def toCMSInstance: TopCMSInstance[K] = {
val hhs = HeavyHitters.from(HeavyHitter(item, 1L))
TopCMSInstance(cms, hhs, params)
}

}

object TopCMSInstance {
Expand Down Expand Up @@ -798,6 +804,8 @@ object HeavyHitters {

def from[K: Ordering](hhs: Set[HeavyHitter[K]]): HeavyHitters[K] = HeavyHitters(hhs.foldLeft(emptyHhs)(_ + _))

def from[K: Ordering](hh: HeavyHitter[K]): HeavyHitters[K] = HeavyHitters(emptyHhs + hh)

}

case class HeavyHitter[K](item: K, count: Long) extends java.io.Serializable
Expand Down Expand Up @@ -850,7 +858,7 @@ class TopPctCMSMonoid[K: Ordering](cms: CMS[K], heavyHittersPct: Double = 0.01)
/**
* Creates a sketch out of a single item.
*/
def create(item: K): TopCMS[K] = TopCMSItem[K](item, cms, params)
def create(item: K): TopCMS[K] = TopCMSItem[K](item, cms + item, params)

/**
* Creates a sketch out of multiple items.
Expand Down Expand Up @@ -964,7 +972,7 @@ class TopNCMSMonoid[K: Ordering](cms: CMS[K], heavyHittersN: Int = 100) extends
/**
* Creates a sketch out of a single item.
*/
def create(item: K): TopCMS[K] = TopCMSItem[K](item, cms, params)
def create(item: K): TopCMS[K] = TopCMSItem[K](item, cms + item, params)

/**
* Creates a sketch out of multiple items.
Expand Down Expand Up @@ -1022,7 +1030,7 @@ case class TopNCMSAggregator[K](cmsMonoid: TopNCMSMonoid[K])
*/
trait CMSHasher[K] extends java.io.Serializable {

val PRIME_MODULUS = (1L << 31) - 1
val PRIME_MODULUS = Int.MaxValue

/**
* Returns `a * x + b (mod p) (mod width)`.
Expand All @@ -1047,7 +1055,7 @@ object CMSHasherImplicits {

implicit object CMSHasherLong extends CMSHasher[Long] {

def hash(a: Int, b: Int, width: Int)(x: Long) = {
override def hash(a: Int, b: Int, width: Int)(x: Long): Int = {
val unModded: Long = (x * a) + b
// Apparently a super fast way of computing x mod 2^p-1
// See page 149 of http://www.cs.princeton.edu/courses/archive/fall09/cos521/Handouts/universalclasses.pdf
Expand All @@ -1061,32 +1069,55 @@ object CMSHasherImplicits {

implicit object CMSHasherShort extends CMSHasher[Short] {

def hash(a: Int, b: Int, width: Int)(x: Short) = CMSHasherInt.hash(a, b, width)(x)
override def hash(a: Int, b: Int, width: Int)(x: Short): Int = CMSHasherInt.hash(a, b, width)(x)

}

implicit object CMSHasherInt extends CMSHasher[Int] {

def hash(a: Int, b: Int, width: Int)(x: Int) = {
override def hash(a: Int, b: Int, width: Int)(x: Int): Int = {
val unModded: Int = (x * a) + b
val modded: Long = (unModded + (unModded >> 32)) & PRIME_MODULUS
val h = modded.toInt % width
assert(h >= 0, "hash must not be negative")
h
modded.toInt % width
}

}

implicit object CMSHasherBigInt extends CMSHasher[BigInt] {

def hash(a: Int, b: Int, width: Int)(x: BigInt) = {
val unModded: BigInt = (x * a) + b
val modded: BigInt = (unModded + (unModded >> 32)) & PRIME_MODULUS
val h = modded.toInt % width
assert(h >= 0, "hash must not be negative")
h
/**
* =Implementation details=
*
* This hash function is based upon Murmur3. Note that the original CMS paper requires
* `d` (depth) pair-wise independent hash functions; in the specific case of Murmur3 we argue that it is sufficient
* to pass `d` different seed values to Murmur3 to achieve a similar effect.
*
* To seed Murmur3 we use only `a`, which is a randomly drawn `Int` via [[scala.util.Random]] in the CMS code.
* What is important to note is that we intentionally ignore `b`. Why? We need to ensure that we seed Murmur3 with
* a random value, notably one that is uniformly distributed. Somewhat surprisingly, combining two random values
* (such as `a` and `b` in our case) typically worsens the "randomness" of the combination, i.e. the combination is
* less uniformly distributed than either of its original inputs. Hence the combination of two random values is
* discouraged in this context, notably if the two random inputs were generated from the same source anyway, which
* is the case for us because we use Scala's PRNG only.
*
* For further details please refer to the discussion
* [[http://stackoverflow.com/questions/3956478/understanding-randomness Understanding Randomness]] on
* StackOverflow.
*
* @param a Must be a random value, typically created via [[scala.util.Random]].
* @param b Ignored by this particular hash function, see the reasoning above for the justification.
* @param width Width of the CMS counting table, i.e. the width/size of each row in the counting table.
* @param x Item to be hashed.
* @return Slot assigned to item `x` in the vector of size `width`; the slot is a value in `[0, width)`.
*/
override def hash(a: Int, b: Int, width: Int)(x: BigInt): Int = {
val hash: Int = scala.util.hashing.MurmurHash3.arrayHash(x.toByteArray, a)
// We only want positive integers for the subsequent modulo. This method mimics Java's Hashtable
// implementation. The Java code uses `0x7FFFFFFF` for the bit-wise AND, which is equal to Int.MaxValue.
val positiveHash = hash & Int.MaxValue
positiveHash % width
}

}

}
}
Loading

0 comments on commit 20a3019

Please # to comment.