Merged
Changes from 4 commits
@@ -25,6 +25,7 @@ import com.yahoo.sketches.kll.KllFloatsSketch
import com.yahoo.sketches.{ArrayOfDoublesSerDe, ArrayOfItemsSerDe, ArrayOfLongsSerDe, ArrayOfStringsSerDe}

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.security.MessageDigest
import java.util
import scala.collection.mutable
import scala.jdk.CollectionConverters._
@@ -599,6 +600,59 @@ class ApproxHistogram[T: FrequentItemsFriendly](mapSize: Int, errorType: ErrorType
}
}

class BoundedUniqueCount[T](inputType: DataType, k: Int = 8) extends SimpleAggregator[T, util.Set[String], Long] {
private def toBytes(input: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
val out = new ObjectOutputStream(bos)
out.writeObject(input)
out.flush()
bos.toByteArray
}

private def md5Hex(bytes: Array[Byte]): String =
MessageDigest.getInstance("MD5").digest(bytes).map("%02x".format(_)).mkString
Collaborator: Let's say I want to unique-count a bunch of user / merchant ids (long values) - won't this be less efficient than simply keeping the set of longs?

Collaborator: Made the code change to keep the numeric type as is.

(A sketch of what that variant could look like follows the class below.)

private def hashInput(input: T): String =
md5Hex(toBytes(input))

override def prepare(input: T): util.Set[String] = {
val result = new util.HashSet[String](k)
result.add(hashInput(input))
result
}

override def update(ir: util.Set[String], input: T): util.Set[String] = {
if (ir.size() >= k) {
return ir
Collaborator: memory optimization: we can use a sentinel set when k is reached.

Suggested change:
-      return ir
+      if (ir == Constants.SentinelSet || ir.size() >= k) return Constants.SentinelSet

Collaborator: Hm.. I don't think we have sentinel set yet in OSS branch.

}

ir.add(hashInput(input))
ir
}

override def outputType: DataType = LongType

override def irType: DataType = ListType(StringType)

override def merge(ir1: util.Set[String], ir2: util.Set[String]): util.Set[String] = {
ir2.asScala.foreach(v =>
Collaborator (suggested change):
-    ir2.asScala.foreach(v =>
+    ir2.iterator().asScala.foreach(v =>

Collaborator: otherwise it will create intermediate collections

Collaborator: Good call out!

Collaborator: done

if (ir1.size() < k) {
ir1.add(v)
})

ir1
}

override def finalize(ir: util.Set[String]): Long = ir.size()
Collaborator (suggested change):
-  override def finalize(ir: util.Set[String]): Long = ir.size()
+  override def finalize(ir: util.Set[String]): Long = if (ir == Constants.SentinelSet) k else ir.size()


override def clone(ir: util.Set[String]): util.Set[String] = new util.HashSet[String](ir)

override def normalize(ir: util.Set[String]): Any = new util.ArrayList[String](ir)

override def denormalize(ir: Any): util.Set[String] =
new util.HashSet[String](ir.asInstanceOf[util.ArrayList[String]])
}
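The two review threads above suggest (a) keeping numeric inputs as-is rather than MD5-hashing them, and (b) iterating without building intermediate collections in merge. Below is a minimal sketch of what such a variant could look like. It is not the code merged in this PR: the class name BoundedUniqueCountRaw is hypothetical, and it assumes the SimpleAggregator contract is exactly the set of methods BoundedUniqueCount implements above.

```scala
// Hypothetical sketch only (not part of this PR): stores raw input values of type T
// in the IR set instead of MD5 hex strings, so counting Long ids keeps a set of longs.
class BoundedUniqueCountRaw[T](inputType: DataType, k: Int = 8)
    extends SimpleAggregator[T, util.Set[T], Long] {

  override def prepare(input: T): util.Set[T] = {
    val result = new util.HashSet[T](k)
    result.add(input)
    result
  }

  override def update(ir: util.Set[T], input: T): util.Set[T] = {
    // The set stops growing once the bound is reached; the count saturates at k.
    if (ir.size() < k) ir.add(input)
    ir
  }

  override def merge(ir1: util.Set[T], ir2: util.Set[T]): util.Set[T] = {
    // Walk the Java iterator directly so no intermediate Scala collection is built,
    // and stop early once ir1 has reached the bound.
    val it = ir2.iterator()
    while (it.hasNext && ir1.size() < k) {
      ir1.add(it.next())
    }
    ir1
  }

  override def finalize(ir: util.Set[T]): Long = ir.size()

  override def clone(ir: util.Set[T]): util.Set[T] = new util.HashSet[T](ir)

  override def outputType: DataType = LongType

  // Assumes the IR can be encoded as a list of the raw input type.
  override def irType: DataType = ListType(inputType)

  override def normalize(ir: util.Set[T]): Any = new util.ArrayList[T](ir)

  override def denormalize(ir: Any): util.Set[T] =
    new util.HashSet[T](ir.asInstanceOf[util.ArrayList[T]])
}
```

With this shape, a bounded count over Long ids holds at most k boxed longs per window instead of k 32-character hex strings, at the cost of the IR schema depending on the input type.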

// Based on CPC sketch (a faster, smaller and more accurate version of HLL)
// See: Back to the future: an even more nearly optimal cardinality estimation algorithm, 2017
// https://arxiv.org/abs/1708.06839
@@ -307,7 +307,19 @@ object ColumnAggregator {
case BinaryType => simple(new ApproxDistinctCount[Array[Byte]](aggregationPart.getInt("k", Some(8))))
case _ => mismatchException
}
case Operation.BOUNDED_UNIQUE_COUNT =>
val k = aggregationPart.getInt("k", Some(8))

inputType match {
case IntType => simple(new BoundedUniqueCount[Int](inputType, k))
case LongType => simple(new BoundedUniqueCount[Long](inputType, k))
case ShortType => simple(new BoundedUniqueCount[Short](inputType, k))
case DoubleType => simple(new BoundedUniqueCount[Double](inputType, k))
case FloatType => simple(new BoundedUniqueCount[Float](inputType, k))
case StringType => simple(new BoundedUniqueCount[String](inputType, k))
case BinaryType => simple(new BoundedUniqueCount[Array[Byte]](inputType, k))
case _ => mismatchException
}
case Operation.APPROX_PERCENTILE =>
val k = aggregationPart.getInt("k", Some(128))
val mapper = new ObjectMapper()
@@ -0,0 +1,44 @@
package ai.chronon.aggregator.test

import ai.chronon.aggregator.base.BoundedUniqueCount
import ai.chronon.api.StringType
import junit.framework.TestCase
import org.junit.Assert._

import java.util
import scala.jdk.CollectionConverters._

class BoundedUniqueCountTest extends TestCase {
def testHappyCase(): Unit = {
val boundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
var ir = boundedDistinctCount.prepare("1")
ir = boundedDistinctCount.update(ir, "1")
ir = boundedDistinctCount.update(ir, "2")

val result = boundedDistinctCount.finalize(ir)
assertEquals(2, result)
}

def testExceedSize(): Unit = {
val boundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
var ir = boundedDistinctCount.prepare("1")
ir = boundedDistinctCount.update(ir, "2")
ir = boundedDistinctCount.update(ir, "3")
ir = boundedDistinctCount.update(ir, "4")
ir = boundedDistinctCount.update(ir, "5")
ir = boundedDistinctCount.update(ir, "6")
ir = boundedDistinctCount.update(ir, "7")

val result = boundedDistinctCount.finalize(ir)
assertEquals(5, result)
}

def testMerge(): Unit = {
val boundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
val ir1 = new util.HashSet[String](Seq("1", "2", "3").asJava)
val ir2 = new util.HashSet[String](Seq("4", "5", "6").asJava)

val merged = boundedDistinctCount.merge(ir1, ir2)
assertEquals(merged.size(), 5)
}
}
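The tests above cover prepare/update/merge/finalize but not the normalize/denormalize pair, which presumably runs whenever the IR is persisted. A minimal sketch of such a test, under the same harness as above (the test name is hypothetical):

```scala
def testNormalizeRoundTrip(): Unit = {
  val boundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
  var ir = boundedDistinctCount.prepare("1")
  ir = boundedDistinctCount.update(ir, "2")

  // normalize returns a util.ArrayList[String]; denormalize should rebuild an equivalent set.
  val restored = boundedDistinctCount.denormalize(boundedDistinctCount.normalize(ir))
  assertEquals(2, boundedDistinctCount.finalize(restored))
}
```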
1 change: 1 addition & 0 deletions api/py/ai/chronon/group_by.py
@@ -65,6 +65,7 @@ class Operation:
# https://github.com/apache/incubator-datasketches-java/blob/master/src/main/java/org/apache/datasketches/cpc/CpcSketch.java#L180
APPROX_UNIQUE_COUNT_LGK = collector(ttypes.Operation.APPROX_UNIQUE_COUNT)
UNIQUE_COUNT = ttypes.Operation.UNIQUE_COUNT
BOUNDED_UNIQUE_COUNT_K = collector(ttypes.Operation.BOUNDED_UNIQUE_COUNT)
COUNT = ttypes.Operation.COUNT
SUM = ttypes.Operation.SUM
AVERAGE = ttypes.Operation.AVERAGE
3 changes: 2 additions & 1 deletion api/thrift/api.thrift
@@ -161,7 +161,8 @@ enum Operation {
BOTTOM_K = 16

HISTOGRAM = 17, // use this only if you know the set of inputs is bounded
APPROX_HISTOGRAM_K = 18
APPROX_HISTOGRAM_K = 18,
BOUNDED_UNIQUE_COUNT = 19
}

// integers map to milliseconds in the timeunit
1 change: 1 addition & 0 deletions docs/source/authoring_features/GroupBy.md
@@ -147,6 +147,7 @@ Limitations:
| approx_unique_count | primitive types | list, map | long | no | k=8 | yes |
| approx_percentile | primitive types | list, map | list<input,> | no | k=128, percentiles | yes |
| unique_count | primitive types | list, map | long | no | | no |
| bounded_unique_count | primitive types | list, map | long | no | k=inf | yes |
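For reference, a GroupBy aggregation matching this row can be declared as sketched below; it mirrors the pattern used in the FetcherTest changes later in this diff, and the column name is illustrative.

```scala
// Assumes the usual Chronon API imports (Builders, Operation, Window, TimeUnit)
// as used in FetcherTest below; "user" is an illustrative input column.
val boundedUniqueUsers = Builders.Aggregation(
  operation = Operation.BOUNDED_UNIQUE_COUNT,
  inputColumn = "user",
  argMap = Map("k" -> "5"),              // bound on distinct values tracked per window
  windows = Seq(new Window(2, TimeUnit.DAYS))
)
```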


## Accuracy
12 changes: 10 additions & 2 deletions spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala
@@ -273,8 +273,11 @@ class FetcherTest extends TestCase {
Builders.Aggregation(operation = Operation.LAST_K,
argMap = Map("k" -> "300"),
inputColumn = "user",
windows = Seq(new Window(2, TimeUnit.DAYS), new Window(30, TimeUnit.DAYS)))
),
windows = Seq(new Window(2, TimeUnit.DAYS), new Window(30, TimeUnit.DAYS))),
Builders.Aggregation(operation = Operation.BOUNDED_UNIQUE_COUNT,
argMap = Map("k" -> "5"),
inputColumn = "user",
windows = Seq(new Window(2, TimeUnit.DAYS), new Window(30, TimeUnit.DAYS)))),
metaData = Builders.MetaData(name = "unit_test/vendor_ratings", namespace = namespace),
accuracy = Accuracy.SNAPSHOT
)
@@ -503,6 +506,11 @@
operation = Operation.APPROX_HISTOGRAM_K,
inputColumn = "rating",
windows = Seq(new Window(1, TimeUnit.DAYS))
),
Builders.Aggregation(
operation = Operation.BOUNDED_UNIQUE_COUNT,
inputColumn = "rating",
windows = Seq(new Window(1, TimeUnit.DAYS))
)
),
accuracy = Accuracy.TEMPORAL,
43 changes: 43 additions & 0 deletions spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala
@@ -724,4 +724,47 @@ class GroupByTest {
assert(count > 0, s"Found a count value that is not greater than zero: $count")
}
}

@Test
def testBoundedUniqueCounts(): Unit = {
lazy val spark: SparkSession = SparkSessionBuilder.build("GroupByTest" + "_" + Random.alphanumeric.take(6).mkString, local = true)
val (source, endPartition) = createTestSource(suffix = "_bounded_counts")
val tableUtils = TableUtils(spark)
val namespace = "test_bounded_counts"
val aggs = Seq(
Builders.Aggregation(
operation = Operation.BOUNDED_UNIQUE_COUNT,
inputColumn = "item",
windows = Seq(
new Window(15, TimeUnit.DAYS),
new Window(60, TimeUnit.DAYS)
),
argMap = Map("k" -> "5")
),
Builders.Aggregation(
operation = Operation.BOUNDED_UNIQUE_COUNT,
inputColumn = "price",
windows = Seq(
new Window(15, TimeUnit.DAYS),
new Window(60, TimeUnit.DAYS)
),
argMap = Map("k" -> "5")
),
)
backfill(name = "unit_test_group_by_bounded_counts",
source = source,
endPartition = endPartition,
namespace = namespace,
tableUtils = tableUtils,
additionalAgg = aggs)

val result = spark.sql(
"""
|select *
|from test_bounded_counts.unit_test_group_by_bounded_counts
|where item_bounded_unique_count_60d > 5 or price_bounded_unique_count_60d > 5
|""".stripMargin)

assertTrue(result.isEmpty)
}
}