From dac9f91612a41a4a0d6efaa8c389259a6c786dc1 Mon Sep 17 00:00:00 2001
From: Jeffrey Brooks
Date: Tue, 2 Jul 2024 13:52:00 -0700
Subject: [PATCH 1/7] Add bounded unique count aggregation

---
 .../aggregator/base/SimpleAggregators.scala   | 38 ++++++++++++++++
 .../aggregator/row/ColumnAggregator.scala     | 12 +++++
 .../test/BoundedUniqueCountTest.scala         | 44 +++++++++++++++++++
 api/py/ai/chronon/group_by.py                 |  1 +
 api/thrift/api.thrift                         |  3 +-
 docs/source/authoring_features/GroupBy.md     |  1 +
 .../ai/chronon/spark/test/FetcherTest.scala   | 12 ++++-
 .../ai/chronon/spark/test/GroupByTest.scala   | 42 ++++++++++++++++++
 8 files changed, 150 insertions(+), 3 deletions(-)
 create mode 100644 aggregator/src/test/scala/ai/chronon/aggregator/test/BoundedUniqueCountTest.scala

diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
index b120d29e7f..83fd8df64c 100644
--- a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
+++ b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
@@ -599,6 +599,44 @@ class ApproxHistogram[T: FrequentItemsFriendly](mapSize: Int, errorType: ErrorTy
   }
 }
 
+class BoundedUniqueCount[T](inputType: DataType, k: Int = 8) extends SimpleAggregator[T, util.Set[T], Long] {
+  override def prepare(input: T): util.Set[T] = {
+    val result = new util.HashSet[T](k)
+    result.add(input)
+    result
+  }
+
+  override def update(ir: util.Set[T], input: T): util.Set[T] = {
+    if (ir.size() >= k) {
+      return ir
+    }
+
+    ir.add(input)
+    ir
+  }
+
+  override def outputType: DataType = LongType
+
+  override def irType: DataType = ListType(inputType)
+
+  override def merge(ir1: util.Set[T], ir2: util.Set[T]): util.Set[T] = {
+    ir2.asScala.foreach(v =>
+      if (ir1.size() < k) {
+        ir1.add(v)
+      })
+
+    ir1
+  }
+
+  override def finalize(ir: util.Set[T]): Long = ir.size()
+
+  override def clone(ir: util.Set[T]): util.Set[T] = new util.HashSet[T](ir)
+
+  override def normalize(ir: util.Set[T]): Any = new util.ArrayList[T](ir)
+
+  override def denormalize(ir: Any): util.Set[T] = new util.HashSet[T](ir.asInstanceOf[util.ArrayList[T]])
+}
+
 // Based on CPC sketch (a faster, smaller and more accurate version of HLL)
 // See: Back to the future: an even more nearly optimal cardinality estimation algorithm, 2017
 // https://arxiv.org/abs/1708.06839
diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala b/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala
index d5f21b3072..d0be4ead20 100644
--- a/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala
+++ b/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala
@@ -307,7 +307,19 @@ object ColumnAggregator {
        case BinaryType => simple(new ApproxDistinctCount[Array[Byte]](aggregationPart.getInt("k", Some(8))))
        case _          => mismatchException
      }
+    case Operation.BOUNDED_UNIQUE_COUNT =>
+      val k = aggregationPart.getInt("k", Some(8))
+      inputType match {
+        case IntType    => simple(new BoundedUniqueCount[Int](inputType, k))
+        case LongType   => simple(new BoundedUniqueCount[Long](inputType, k))
+        case ShortType  => simple(new BoundedUniqueCount[Short](inputType, k))
+        case DoubleType => simple(new BoundedUniqueCount[Double](inputType, k))
+        case FloatType  => simple(new BoundedUniqueCount[Float](inputType, k))
+        case StringType => simple(new BoundedUniqueCount[String](inputType, k))
+        case BinaryType => simple(new BoundedUniqueCount[Array[Byte]](inputType, k))
+        case _          => mismatchException
+      }
 
     case Operation.APPROX_PERCENTILE =>
       val k = aggregationPart.getInt("k", Some(128))
       val mapper = new ObjectMapper()
diff --git a/aggregator/src/test/scala/ai/chronon/aggregator/test/BoundedUniqueCountTest.scala b/aggregator/src/test/scala/ai/chronon/aggregator/test/BoundedUniqueCountTest.scala
new file mode 100644
index 0000000000..d780aae080
--- /dev/null
+++ b/aggregator/src/test/scala/ai/chronon/aggregator/test/BoundedUniqueCountTest.scala
@@ -0,0 +1,44 @@
+package ai.chronon.aggregator.test
+
+import ai.chronon.aggregator.base.BoundedUniqueCount
+import ai.chronon.api.StringType
+import junit.framework.TestCase
+import org.junit.Assert._
+
+import java.util
+import scala.jdk.CollectionConverters._
+
+class BoundedUniqueCountTest extends TestCase {
+  def testHappyCase(): Unit = {
+    val boundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
+    var ir = boundedDistinctCount.prepare("1")
+    ir = boundedDistinctCount.update(ir, "1")
+    ir = boundedDistinctCount.update(ir, "2")
+
+    val result = boundedDistinctCount.finalize(ir)
+    assertEquals(2, result)
+  }
+
+  def testExceedSize(): Unit = {
+    val boundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
+    var ir = boundedDistinctCount.prepare("1")
+    ir = boundedDistinctCount.update(ir, "2")
+    ir = boundedDistinctCount.update(ir, "3")
+    ir = boundedDistinctCount.update(ir, "4")
+    ir = boundedDistinctCount.update(ir, "5")
+    ir = boundedDistinctCount.update(ir, "6")
+    ir = boundedDistinctCount.update(ir, "7")
+
+    val result = boundedDistinctCount.finalize(ir)
+    assertEquals(5, result)
+  }
+
+  def testMerge(): Unit = {
+    val boundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
+    val ir1 = new util.HashSet[String](Seq("1", "2", "3").asJava)
+    val ir2 = new util.HashSet[String](Seq("4", "5", "6").asJava)
+
+    val merged = boundedDistinctCount.merge(ir1, ir2)
+    assertEquals(merged.size(), 5)
+  }
+}
\ No newline at end of file
diff --git a/api/py/ai/chronon/group_by.py b/api/py/ai/chronon/group_by.py
index c171113356..be84bcb8db 100644
--- a/api/py/ai/chronon/group_by.py
+++ b/api/py/ai/chronon/group_by.py
@@ -65,6 +65,7 @@ class Operation:
     # https://github.com/apache/incubator-datasketches-java/blob/master/src/main/java/org/apache/datasketches/cpc/CpcSketch.java#L180
     APPROX_UNIQUE_COUNT_LGK = collector(ttypes.Operation.APPROX_UNIQUE_COUNT)
     UNIQUE_COUNT = ttypes.Operation.UNIQUE_COUNT
+    BOUNDED_UNIQUE_COUNT = ttypes.Operation.BOUNDED_UNIQUE_COUNT
     COUNT = ttypes.Operation.COUNT
     SUM = ttypes.Operation.SUM
     AVERAGE = ttypes.Operation.AVERAGE
diff --git a/api/thrift/api.thrift b/api/thrift/api.thrift
index c75e566e1a..2ca6242ec5 100644
--- a/api/thrift/api.thrift
+++ b/api/thrift/api.thrift
@@ -161,7 +161,8 @@ enum Operation {
     BOTTOM_K = 16
 
     HISTOGRAM = 17, // use this only if you know the set of inputs is bounded
-    APPROX_HISTOGRAM_K = 18
+    APPROX_HISTOGRAM_K = 18,
+    BOUNDED_UNIQUE_COUNT = 19
 }
 
 // integers map to milliseconds in the timeunit
diff --git a/docs/source/authoring_features/GroupBy.md b/docs/source/authoring_features/GroupBy.md
index c972f4cff3..7f1b7acefe 100644
--- a/docs/source/authoring_features/GroupBy.md
+++ b/docs/source/authoring_features/GroupBy.md
@@ -147,6 +147,7 @@ Limitations:
 | approx_unique_count | primitive types | list, map | long | no | k=8 | yes |
 | approx_percentile | primitive types | list, map | list | no | k=128, percentiles | yes |
 | unique_count | primitive types | list, map | long | no | | no |
+| bounded_unique_count | primitive types | list, map | long | no | k=8 | yes |
 
 ## Accuracy
 
diff --git a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala
index 051ec1be73..a046a16f57 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala
@@ -273,8 +273,11 @@ class FetcherTest extends TestCase {
       Builders.Aggregation(operation = Operation.LAST_K,
                            argMap = Map("k" -> "300"),
                            inputColumn = "user",
-                           windows = Seq(new Window(2, TimeUnit.DAYS), new Window(30, TimeUnit.DAYS)))
-      ),
+                           windows = Seq(new Window(2, TimeUnit.DAYS), new Window(30, TimeUnit.DAYS))),
+      Builders.Aggregation(operation = Operation.BOUNDED_UNIQUE_COUNT,
+                           argMap = Map("k" -> "5"),
+                           inputColumn = "user",
+                           windows = Seq(new Window(2, TimeUnit.DAYS), new Window(30, TimeUnit.DAYS)))),
       metaData = Builders.MetaData(name = "unit_test/vendor_ratings", namespace = namespace),
       accuracy = Accuracy.SNAPSHOT
     )
@@ -503,6 +506,11 @@ class FetcherTest extends TestCase {
           operation = Operation.APPROX_HISTOGRAM_K,
           inputColumn = "rating",
           windows = Seq(new Window(1, TimeUnit.DAYS))
+        ),
+        Builders.Aggregation(
+          operation = Operation.BOUNDED_UNIQUE_COUNT,
+          inputColumn = "rating",
+          windows = Seq(new Window(1, TimeUnit.DAYS))
         )
       ),
       accuracy = Accuracy.TEMPORAL,
diff --git a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala
index 12c64b9cfa..c2ac077c3e 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala
@@ -724,4 +724,46 @@ class GroupByTest {
       assert(count > 0, s"Found a count value that is not greater than zero: $count")
     }
   }
+
+  @Test
+  def testBoundedUniqueCounts(): Unit = {
+    val (source, endPartition) = createTestSource(suffix = "_bounded_counts")
+    val tableUtils = TableUtils(spark)
+    val namespace = "test_bounded_counts"
+    val aggs = Seq(
+      Builders.Aggregation(
+        operation = Operation.BOUNDED_UNIQUE_COUNT,
+        inputColumn = "item",
+        windows = Seq(
+          new Window(15, TimeUnit.DAYS),
+          new Window(60, TimeUnit.DAYS)
+        ),
+        argMap = Map("k" -> "5")
+      ),
+      Builders.Aggregation(
+        operation = Operation.BOUNDED_UNIQUE_COUNT,
+        inputColumn = "price",
+        windows = Seq(
+          new Window(15, TimeUnit.DAYS),
+          new Window(60, TimeUnit.DAYS)
+        ),
+        argMap = Map("k" -> "5")
+      ),
+    )
+    backfill(name = "unit_test_group_by_bounded_counts",
+             source = source,
+             endPartition = endPartition,
+             namespace = namespace,
+             tableUtils = tableUtils,
+             additionalAgg = aggs)
+
+    val result = spark.sql(
+      """
+        |select *
+        |from test_bounded_counts.unit_test_group_by_bounded_counts
+        |where item_bounded_unique_count_60d > 5 or price_bounded_unique_count_60d > 5
+        |""".stripMargin)
+
+    assertTrue(result.isEmpty)
+  }
 }
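The contract this first patch establishes is worth making concrete: the aggregation is an exact distinct count that simply stops admitting values once the intermediate set holds k of them, so the final value is min(true distinct count, k). A minimal lifecycle sketch, not part of the patch itself, assuming only the BoundedUniqueCount and StringType symbols used in the test above:

import ai.chronon.aggregator.base.BoundedUniqueCount
import ai.chronon.api.StringType

object BoundedUniqueCountSketch extends App {
  val agg = new BoundedUniqueCount[String](StringType, k = 3)
  var ir = agg.prepare("a")  // IR is a mutable set: {a}
  ir = agg.update(ir, "b")   // {a, b}
  ir = agg.update(ir, "b")   // duplicate: set semantics leave the IR unchanged
  ir = agg.update(ir, "c")   // {a, b, c}, the set is now at capacity k
  ir = agg.update(ir, "d")   // dropped: update returns the IR untouched once size >= k
  println(agg.finalize(ir))  // 3, i.e. min(true distinct count, k)
}

The same saturation shows up in merge: elements of the right-hand IR are copied only while the left-hand IR is below k, which keeps backfill memory bounded at the cost of an undercount once the bound is hit.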
From d5746dad710386040f0266b6c09fda9fb5db8e82 Mon Sep 17 00:00:00 2001
From: Pengyu Hou <3771747+pengyu-hou@users.noreply.github.com>
Date: Thu, 3 Jul 2025 17:36:54 -0700
Subject: [PATCH 2/7] use hash value instead

---
 .../aggregator/base/SimpleAggregators.scala | 37 +++++++++++++------
 api/py/ai/chronon/group_by.py               |  2 +-
 .../ai/chronon/spark/test/GroupByTest.scala |  1 +
 3 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
index 83fd8df64c..bacbe79174 100644
--- a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
+++ b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
@@ -25,6 +25,7 @@ import com.yahoo.sketches.kll.KllFloatsSketch
 import com.yahoo.sketches.{ArrayOfDoublesSerDe, ArrayOfItemsSerDe, ArrayOfLongsSerDe, ArrayOfStringsSerDe}
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
+import java.security.MessageDigest
 import java.util
 import scala.collection.mutable
 import scala.jdk.CollectionConverters._
@@ -599,19 +600,33 @@ class ApproxHistogram[T: FrequentItemsFriendly](mapSize: Int, errorType: ErrorTy
   }
 }
 
-class BoundedUniqueCount[T](inputType: DataType, k: Int = 8) extends SimpleAggregator[T, util.Set[T], Long] {
-  override def prepare(input: T): util.Set[T] = {
-    val result = new util.HashSet[T](k)
-    result.add(input)
+class BoundedUniqueCount[T](inputType: DataType, k: Int = 8) extends SimpleAggregator[T, util.Set[String], Long] {
+  private def toBytes(input: T): Array[Byte] = {
+    val bos = new ByteArrayOutputStream()
+    val out = new ObjectOutputStream(bos)
+    out.writeObject(input)
+    out.flush()
+    bos.toByteArray
+  }
+
+  private def md5Hex(bytes: Array[Byte]): String =
+    MessageDigest.getInstance("MD5").digest(bytes).map("%02x".format(_)).mkString
+
+  private def hashInput(input: T): String =
+    md5Hex(toBytes(input))
+
+  override def prepare(input: T): util.Set[String] = {
+    val result = new util.HashSet[String](k)
+    result.add(hashInput(input))
     result
   }
 
-  override def update(ir: util.Set[T], input: T): util.Set[T] = {
-    if (ir.size() >= k) {
-      return ir
+  override def update(ir: util.Set[String], input: T): util.Set[String] = {
+    if (ir.size() >= k) {
+      return ir
     }
 
-    ir.add(input)
+    ir.add(hashInput(input))
     ir
   }
 
@@ -619,7 +634,7 @@ class BoundedUniqueCount[T](inputType: DataType, k: Int = 8) extends SimpleAggre
 
   override def irType: DataType = ListType(inputType)
 
-  override def merge(ir1: util.Set[T], ir2: util.Set[T]): util.Set[T] = {
+  override def merge(ir1: util.Set[String], ir2: util.Set[String]): util.Set[String] = {
     ir2.asScala.foreach(v =>
       if (ir1.size() < k) {
         ir1.add(v)
@@ -628,13 +643,13 @@ class BoundedUniqueCount[T](inputType: DataType, k: Int = 8) extends SimpleAggre
     ir1
   }
 
-  override def finalize(ir: util.Set[T]): Long = ir.size()
+  override def finalize(ir: util.Set[String]): Long = ir.size()
 
-  override def clone(ir: util.Set[T]): util.Set[T] = new util.HashSet[T](ir)
+  override def clone(ir: util.Set[String]): util.Set[String] = new util.HashSet[String](ir)
 
-  override def normalize(ir: util.Set[T]): Any = new util.ArrayList[T](ir)
+  override def normalize(ir: util.Set[String]): Any = new util.ArrayList[String](ir)
 
-  override def denormalize(ir: Any): util.Set[T] = new util.HashSet[T](ir.asInstanceOf[util.ArrayList[T]])
+  override def denormalize(ir: Any): util.Set[String] = new util.HashSet[String](ir.asInstanceOf[util.ArrayList[String]])
 }
 
 // Based on CPC sketch (a faster, smaller and more accurate version of HLL)
diff --git a/api/py/ai/chronon/group_by.py b/api/py/ai/chronon/group_by.py
index be84bcb8db..e08d9626ad 100644
--- a/api/py/ai/chronon/group_by.py
+++ b/api/py/ai/chronon/group_by.py
@@ -65,7 +65,7 @@ class Operation:
     # https://github.com/apache/incubator-datasketches-java/blob/master/src/main/java/org/apache/datasketches/cpc/CpcSketch.java#L180
     APPROX_UNIQUE_COUNT_LGK = collector(ttypes.Operation.APPROX_UNIQUE_COUNT)
     UNIQUE_COUNT = ttypes.Operation.UNIQUE_COUNT
-    BOUNDED_UNIQUE_COUNT = ttypes.Operation.BOUNDED_UNIQUE_COUNT
+    BOUNDED_UNIQUE_COUNT_K = collector(ttypes.Operation.BOUNDED_UNIQUE_COUNT)
     COUNT = ttypes.Operation.COUNT
     SUM = ttypes.Operation.SUM
     AVERAGE = ttypes.Operation.AVERAGE
diff --git a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala
index c2ac077c3e..289b4048be 100644
--- a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala
+++ b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala
@@ -727,6 +727,7 @@ class GroupByTest {
 
   @Test
   def testBoundedUniqueCounts(): Unit = {
+    lazy val spark: SparkSession = SparkSessionBuilder.build("GroupByTest" + "_" + Random.alphanumeric.take(6).mkString, local = true)
     val (source, endPartition) = createTestSource(suffix = "_bounded_counts")
     val tableUtils = TableUtils(spark)
     val namespace = "test_bounded_counts"
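Why hash? With raw inputs in the IR, a single large string or blob would be stored verbatim up to k times per group and window. After this change every entry is a fixed-size 32-character MD5 hex digest of the Java-serialized input, so an IR never holds more than k small strings regardless of input payload size. (Note this quietly requires inputs to be Java-serializable, and the IR element type changes; patch 4 below follows up by correcting irType to ListType(StringType).) A self-contained sketch mirroring the private helpers above, assuming only the JDK:

import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import java.security.MessageDigest

object Md5HexSketch extends App {
  // Java-serialize a value, as the private toBytes helper above does.
  def toBytes(input: Any): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bos)
    out.writeObject(input)
    out.flush()
    bos.toByteArray
  }

  def md5Hex(bytes: Array[Byte]): String =
    MessageDigest.getInstance("MD5").digest(bytes).map("%02x".format(_)).mkString

  val a = md5Hex(toBytes("a very long user identifier ..."))
  val b = md5Hex(toBytes("a very long user identifier ..."))
  println(a.length) // always 32, however large the input was
  println(a == b)   // true: equal inputs serialize and hash identically, so dedup still works
}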
From 03a3a3b12b9706d24d5e9a2920599de9c12ca147 Mon Sep 17 00:00:00 2001
From: Pengyu Hou <3771747+pengyu-hou@users.noreply.github.com>
Date: Thu, 3 Jul 2025 17:58:23 -0700
Subject: [PATCH 3/7] scala fmt

---
 .../scala/ai/chronon/aggregator/base/SimpleAggregators.scala | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
index bacbe79174..3b11928dc9 100644
--- a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
+++ b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
@@ -649,7 +649,8 @@ class BoundedUniqueCount[T](inputType: DataType, k: Int = 8) extends SimpleAggre
 
   override def normalize(ir: util.Set[String]): Any = new util.ArrayList[String](ir)
 
-  override def denormalize(ir: Any): util.Set[String] = new util.HashSet[String](ir.asInstanceOf[util.ArrayList[String]])
+  override def denormalize(ir: Any): util.Set[String] =
+    new util.HashSet[String](ir.asInstanceOf[util.ArrayList[String]])
 }
 
 // Based on CPC sketch (a faster, smaller and more accurate version of HLL)

From fd27c7a22ec6b3a372871d0ff862e3dfc5ba26f4 Mon Sep 17 00:00:00 2001
From: Pengyu Hou <3771747+pengyu-hou@users.noreply.github.com>
Date: Wed, 9 Jul 2025 13:37:41 -0700
Subject: [PATCH 4/7] fix ir type

---
 .../scala/ai/chronon/aggregator/base/SimpleAggregators.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
index 3b11928dc9..0d04166a69 100644
--- a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
+++ b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
@@ -632,7 +632,7 @@ class BoundedUniqueCount[T](inputType: DataType, k: Int = 8) extends SimpleAggre
 
   override def outputType: DataType = LongType
 
-  override def irType: DataType = ListType(inputType)
+  override def irType: DataType = ListType(StringType)
 
   override def merge(ir1: util.Set[String], ir2: util.Set[String]): util.Set[String] = {
     ir2.asScala.foreach(v =>

From 0ca15cde1b996f88947d22dc76920f9392e340f3 Mon Sep 17 00:00:00 2001
From: Pengyu Hou <3771747+pengyu-hou@users.noreply.github.com>
Date: Wed, 9 Jul 2025 15:14:18 -0700
Subject: [PATCH 5/7] use iterator

---
 .../scala/ai/chronon/aggregator/base/SimpleAggregators.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
index 0d04166a69..4013f144d8 100644
--- a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
+++ b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
@@ -635,7 +635,7 @@ class BoundedUniqueCount[T](inputType: DataType, k: Int = 8) extends SimpleAggre
   override def irType: DataType = ListType(StringType)
 
   override def merge(ir1: util.Set[String], ir2: util.Set[String]): util.Set[String] = {
-    ir2.asScala.foreach(v =>
+    ir2.iterator().asScala.foreach(v =>
       if (ir1.size() < k) {
         ir1.add(v)
       })

From 5f1ff099f7ca1be2aa8636b68357c543978a269f Mon Sep 17 00:00:00 2001
From: Pengyu Hou <3771747+pengyu-hou@users.noreply.github.com>
Date: Wed, 9 Jul 2025 15:45:32 -0700
Subject: [PATCH 6/7] scala fmt

---
 .../ai/chronon/aggregator/base/SimpleAggregators.scala | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
index 4013f144d8..6499f2be49 100644
--- a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
+++ b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
@@ -635,10 +635,13 @@ class BoundedUniqueCount[T](inputType: DataType, k: Int = 8) extends SimpleAggre
   override def irType: DataType = ListType(StringType)
 
   override def merge(ir1: util.Set[String], ir2: util.Set[String]): util.Set[String] = {
-    ir2.iterator().asScala.foreach(v =>
-      if (ir1.size() < k) {
-        ir1.add(v)
-      })
+    ir2
+      .iterator()
+      .asScala
+      .foreach(v =>
+        if (ir1.size() < k) {
+          ir1.add(v)
+        })
 
     ir1
   }

From 996d050a94570b2ffd2499deaf690f6b7602adde Mon Sep 17 00:00:00 2001
From: Pengyu Hou <3771747+pengyu-hou@users.noreply.github.com>
Date: Thu, 10 Jul 2025 16:35:07 -0700
Subject: [PATCH 7/7] Optimize BoundedUniqueCount for numeric types with
 sentinel set pattern

---
 .../aggregator/base/SimpleAggregators.scala   |  87 +++++++--
 .../test/BoundedUniqueCountTest.scala         | 171 +++++++++++++++++-
 2 files changed, 236 insertions(+), 22 deletions(-)

diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
index 6499f2be49..520ae135b2 100644
--- a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
+++ b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala
@@ -600,7 +600,12 @@ class ApproxHistogram[T: FrequentItemsFriendly](mapSize: Int, errorType: ErrorTy
   }
 }
 
-class BoundedUniqueCount[T](inputType: DataType, k: Int = 8) extends SimpleAggregator[T, util.Set[String], Long] {
+object BoundedUniqueCount {
+  private val SentinelSet: util.Set[Any] = new util.HashSet[Any]()
+  private val SentinelMarker: String = "__SENTINEL__"
+}
+
+class BoundedUniqueCount[T](inputType: DataType, k: Int = 8) extends SimpleAggregator[T, util.Set[Any], Long] {
   private def toBytes(input: T): Array[Byte] = {
     val bos = new ByteArrayOutputStream()
     val out = new ObjectOutputStream(bos)
@@ -612,29 +617,45 @@ class BoundedUniqueCount[T](inputType: DataType, k: Int = 8) extends SimpleAggre
   private def md5Hex(bytes: Array[Byte]): String =
     MessageDigest.getInstance("MD5").digest(bytes).map("%02x".format(_)).mkString
 
-  private def hashInput(input: T): String =
-    md5Hex(toBytes(input))
+  private def processInput(input: T): Any = {
+    inputType match {
+      case IntType | LongType | DoubleType | FloatType | ShortType | BinaryType =>
+        input
+      case _ =>
+        md5Hex(toBytes(input))
+    }
+  }
 
-  override def prepare(input: T): util.Set[String] = {
-    val result = new util.HashSet[String](k)
-    result.add(hashInput(input))
+  override def prepare(input: T): util.Set[Any] = {
+    val result = new util.HashSet[Any](k)
+    result.add(processInput(input))
     result
   }
 
-  override def update(ir: util.Set[String], input: T): util.Set[String] = {
-    if (ir.size() >= k) {
-      return ir
+  override def update(ir: util.Set[Any], input: T): util.Set[Any] = {
+    if (ir == BoundedUniqueCount.SentinelSet || ir.size() >= k) {
+      return BoundedUniqueCount.SentinelSet
     }
 
-    ir.add(hashInput(input))
+    ir.add(processInput(input))
     ir
   }
 
   override def outputType: DataType = LongType
 
-  override def irType: DataType = ListType(StringType)
+  override def irType: DataType =
+    inputType match {
+      case IntType | LongType | DoubleType | FloatType | ShortType | BinaryType =>
+        ListType(inputType)
+      case _ =>
+        ListType(StringType)
+    }
+
+  override def merge(ir1: util.Set[Any], ir2: util.Set[Any]): util.Set[Any] = {
+    if (ir1 == BoundedUniqueCount.SentinelSet || ir2 == BoundedUniqueCount.SentinelSet) {
+      return BoundedUniqueCount.SentinelSet
+    }
 
-  override def merge(ir1: util.Set[String], ir2: util.Set[String]): util.Set[String] = {
     ir2
       .iterator()
       .asScala
@@ -643,17 +664,47 @@ class BoundedUniqueCount[T](inputType: DataType, k: Int = 8) extends SimpleAggre
       .foreach(v =>
         if (ir1.size() < k) {
           ir1.add(v)
         })
 
-    ir1
+    if (ir1.size() >= k) {
+      BoundedUniqueCount.SentinelSet
+    } else {
+      ir1
+    }
   }
 
-  override def finalize(ir: util.Set[String]): Long = ir.size()
+  override def finalize(ir: util.Set[Any]): Long = {
+    if (ir == BoundedUniqueCount.SentinelSet) {
+      k
+    } else {
+      ir.size()
+    }
+  }
 
-  override def clone(ir: util.Set[String]): util.Set[String] = new util.HashSet[String](ir)
+  override def clone(ir: util.Set[Any]): util.Set[Any] = {
+    if (ir == BoundedUniqueCount.SentinelSet) {
+      BoundedUniqueCount.SentinelSet
+    } else {
+      new util.HashSet[Any](ir)
+    }
+  }
 
-  override def normalize(ir: util.Set[String]): Any = new util.ArrayList[String](ir)
+  override def normalize(ir: util.Set[Any]): Any = {
+    if (ir == BoundedUniqueCount.SentinelSet) {
+      val list = new util.ArrayList[Any]()
+      list.add(BoundedUniqueCount.SentinelMarker)
+      list
+    } else {
+      new util.ArrayList[Any](ir)
+    }
+  }
 
-  override def denormalize(ir: Any): util.Set[String] =
-    new util.HashSet[String](ir.asInstanceOf[util.ArrayList[String]])
+  override def denormalize(ir: Any): util.Set[Any] = {
+    val list = ir.asInstanceOf[util.ArrayList[Any]]
+    if (list.size() == 1 && list.get(0) == BoundedUniqueCount.SentinelMarker) {
+      BoundedUniqueCount.SentinelSet
+    } else {
+      new util.HashSet[Any](list)
+    }
+  }
 }
diff --git a/aggregator/src/test/scala/ai/chronon/aggregator/test/BoundedUniqueCountTest.scala b/aggregator/src/test/scala/ai/chronon/aggregator/test/BoundedUniqueCountTest.scala
index d780aae080..5d4189de1f 100644
--- a/aggregator/src/test/scala/ai/chronon/aggregator/test/BoundedUniqueCountTest.scala
+++ b/aggregator/src/test/scala/ai/chronon/aggregator/test/BoundedUniqueCountTest.scala
@@ -1,7 +1,7 @@
 package ai.chronon.aggregator.test
 
 import ai.chronon.aggregator.base.BoundedUniqueCount
-import ai.chronon.api.StringType
+import ai.chronon.api.{StringType, IntType, LongType, DoubleType, FloatType, BinaryType}
 import junit.framework.TestCase
 import org.junit.Assert._
 
@@ -35,10 +35,173 @@ class BoundedUniqueCountTest extends TestCase {
 
   def testMerge(): Unit = {
     val boundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
-    val ir1 = new util.HashSet[String](Seq("1", "2", "3").asJava)
-    val ir2 = new util.HashSet[String](Seq("4", "5", "6").asJava)
+    val ir1 = new util.HashSet[Any](Seq("1", "2", "3").asJava)
+    val ir2 = new util.HashSet[Any](Seq("4", "5", "6").asJava)
 
     val merged = boundedDistinctCount.merge(ir1, ir2)
-    assertEquals(merged.size(), 5)
+    val result = boundedDistinctCount.finalize(merged)
+    assertEquals(5, result) // Should return k=5 when exceeding the limit
+  }
+
+  def testIntTypeHappyCase(): Unit = {
+    val boundedDistinctCount = new BoundedUniqueCount[Int](IntType, 5)
+    var ir = boundedDistinctCount.prepare(1)
+    ir = boundedDistinctCount.update(ir, 1)
+    ir = boundedDistinctCount.update(ir, 2)
+
+    val result = boundedDistinctCount.finalize(ir)
+    assertEquals(2, result)
+  }
+
+  def testIntTypeExceedSize(): Unit = {
+    val boundedDistinctCount = new BoundedUniqueCount[Int](IntType, 5)
+    var ir = boundedDistinctCount.prepare(1)
+    ir = boundedDistinctCount.update(ir, 2)
+    ir = boundedDistinctCount.update(ir, 3)
+    ir = boundedDistinctCount.update(ir, 4)
+    ir = boundedDistinctCount.update(ir, 5)
+    ir = boundedDistinctCount.update(ir, 6)
+    ir = boundedDistinctCount.update(ir, 7)
+
+    val result = boundedDistinctCount.finalize(ir)
+    assertEquals(5, result)
+  }
+
+  def testIntTypeMerge(): Unit = {
+    val boundedDistinctCount = new BoundedUniqueCount[Int](IntType, 5)
+    val ir1 = new util.HashSet[Any](Seq(1, 2, 3).asJava)
+    val ir2 = new util.HashSet[Any](Seq(4, 5, 6).asJava)
+
+    val merged = boundedDistinctCount.merge(ir1, ir2)
+    val result = boundedDistinctCount.finalize(merged)
+    assertEquals(5, result) // Should return k=5 when exceeding the limit
+  }
+
+  def testLongTypeHappyCase(): Unit = {
+    val boundedDistinctCount = new BoundedUniqueCount[Long](LongType, 5)
+    var ir = boundedDistinctCount.prepare(1L)
+    ir = boundedDistinctCount.update(ir, 1L)
+    ir = boundedDistinctCount.update(ir, 2L)
+
+    val result = boundedDistinctCount.finalize(ir)
+    assertEquals(2, result)
+  }
+
+  def testLongTypeExceedSize(): Unit = {
+    val boundedDistinctCount = new BoundedUniqueCount[Long](LongType, 5)
+    var ir = boundedDistinctCount.prepare(1L)
+    ir = boundedDistinctCount.update(ir, 2L)
+    ir = boundedDistinctCount.update(ir, 3L)
+    ir = boundedDistinctCount.update(ir, 4L)
+    ir = boundedDistinctCount.update(ir, 5L)
+    ir = boundedDistinctCount.update(ir, 6L)
+    ir = boundedDistinctCount.update(ir, 7L)
+
+    val result = boundedDistinctCount.finalize(ir)
+    assertEquals(5, result)
+  }
+
+  def testDoubleTypeHappyCase(): Unit = {
+    val boundedDistinctCount = new BoundedUniqueCount[Double](DoubleType, 5)
+    var ir = boundedDistinctCount.prepare(1.0)
+    ir = boundedDistinctCount.update(ir, 1.0)
+    ir = boundedDistinctCount.update(ir, 2.5)
+
+    val result = boundedDistinctCount.finalize(ir)
+    assertEquals(2, result)
+  }
+
+  def testDoubleTypeExceedSize(): Unit = {
+    val boundedDistinctCount = new BoundedUniqueCount[Double](DoubleType, 5)
+    var ir = boundedDistinctCount.prepare(1.0)
+    ir = boundedDistinctCount.update(ir, 2.5)
+    ir = boundedDistinctCount.update(ir, 3.7)
+    ir = boundedDistinctCount.update(ir, 4.2)
+    ir = boundedDistinctCount.update(ir, 5.8)
+    ir = boundedDistinctCount.update(ir, 6.3)
+    ir = boundedDistinctCount.update(ir, 7.9)
+
+    val result = boundedDistinctCount.finalize(ir)
+    assertEquals(5, result)
+  }
+
+  def testFloatTypeHappyCase(): Unit = {
+    val boundedDistinctCount = new BoundedUniqueCount[Float](FloatType, 5)
+    var ir = boundedDistinctCount.prepare(1.0f)
+    ir = boundedDistinctCount.update(ir, 1.0f)
+    ir = boundedDistinctCount.update(ir, 2.5f)
+
+    val result = boundedDistinctCount.finalize(ir)
+    assertEquals(2, result)
+  }
+
+  def testFloatTypeExceedSize(): Unit = {
+    val boundedDistinctCount = new BoundedUniqueCount[Float](FloatType, 5)
+    var ir = boundedDistinctCount.prepare(1.0f)
+    ir = boundedDistinctCount.update(ir, 2.5f)
+    ir = boundedDistinctCount.update(ir, 3.7f)
+    ir = boundedDistinctCount.update(ir, 4.2f)
+    ir = boundedDistinctCount.update(ir, 5.8f)
+    ir = boundedDistinctCount.update(ir, 6.3f)
+    ir = boundedDistinctCount.update(ir, 7.9f)
+
+    val result = boundedDistinctCount.finalize(ir)
+    assertEquals(5, result)
+  }
+
+  def testBinaryTypeHappyCase(): Unit = {
+    val boundedDistinctCount = new BoundedUniqueCount[Array[Byte]](BinaryType, 5)
+    val bytes1 = Array[Byte](1, 2, 3)
+    val bytes2 = Array[Byte](4, 5, 6)
+
+    var ir = boundedDistinctCount.prepare(bytes1)
+    ir = boundedDistinctCount.update(ir, bytes1)
+    ir = boundedDistinctCount.update(ir, bytes2)
+
+    val result = boundedDistinctCount.finalize(ir)
+    assertEquals(2, result)
+  }
+
+  def testBinaryTypeExceedSize(): Unit = {
+    val boundedDistinctCount = new BoundedUniqueCount[Array[Byte]](BinaryType, 5)
+    var ir = boundedDistinctCount.prepare(Array[Byte](1))
+    ir = boundedDistinctCount.update(ir, Array[Byte](2))
+    ir = boundedDistinctCount.update(ir, Array[Byte](3))
+    ir = boundedDistinctCount.update(ir, Array[Byte](4))
+    ir = boundedDistinctCount.update(ir, Array[Byte](5))
+    ir = boundedDistinctCount.update(ir, Array[Byte](6))
+    ir = boundedDistinctCount.update(ir, Array[Byte](7))
+
+    val result = boundedDistinctCount.finalize(ir)
+    assertEquals(5, result)
+  }
+
+  def testBinaryTypeMerge(): Unit = {
+    val boundedDistinctCount = new BoundedUniqueCount[Array[Byte]](BinaryType, 5)
+    val ir1 = new util.HashSet[Any](Seq(Array[Byte](1), Array[Byte](2), Array[Byte](3)).asJava)
+    val ir2 = new util.HashSet[Any](Seq(Array[Byte](4), Array[Byte](5), Array[Byte](6)).asJava)
+
+    val merged = boundedDistinctCount.merge(ir1, ir2)
+    val result = boundedDistinctCount.finalize(merged)
+    assertEquals(5, result) // Should return k=5 when exceeding the limit
+  }
+
+  def testNumericTypeIrType(): Unit = {
+    val intBoundedDistinctCount = new BoundedUniqueCount[Int](IntType, 5)
+    val longBoundedDistinctCount = new BoundedUniqueCount[Long](LongType, 5)
+    val doubleBoundedDistinctCount = new BoundedUniqueCount[Double](DoubleType, 5)
+    val floatBoundedDistinctCount = new BoundedUniqueCount[Float](FloatType, 5)
+    val binaryBoundedDistinctCount = new BoundedUniqueCount[Array[Byte]](BinaryType, 5)
+    val stringBoundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
+
+    // For numeric and binary types, irType should be ListType(inputType)
+    assertEquals(ai.chronon.api.ListType(IntType), intBoundedDistinctCount.irType)
+    assertEquals(ai.chronon.api.ListType(LongType), longBoundedDistinctCount.irType)
+    assertEquals(ai.chronon.api.ListType(DoubleType), doubleBoundedDistinctCount.irType)
+    assertEquals(ai.chronon.api.ListType(FloatType), floatBoundedDistinctCount.irType)
+    assertEquals(ai.chronon.api.ListType(BinaryType), binaryBoundedDistinctCount.irType)
+
+    // For non-numeric types, irType should be ListType(StringType)
+    assertEquals(ai.chronon.api.ListType(StringType), stringBoundedDistinctCount.irType)
   }
 }
\ No newline at end of file
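Taken together, the final patch leaves BoundedUniqueCount with three IR shapes: a growing set of raw values (numeric and binary inputs) or MD5 digests (everything else); a single shared sentinel set that stands for "saturated, the answer is k"; and a normalized list form in which the sentinel is encoded as the one-element __SENTINEL__ marker list so saturation survives the normalize/denormalize round trip. A compressed sketch of that state machine, not the patch itself; it compares the sentinel by reference with eq, a deliberate difference that keeps an empty sentinel distinguishable from any ordinary set:

import java.util

object SentinelSketch extends App {
  val k = 5
  val Sentinel: util.Set[Any] = new util.HashSet[Any]() // one shared instance marks saturation

  def update(ir: util.Set[Any], v: Any): util.Set[Any] =
    if ((ir eq Sentinel) || ir.size() >= k) Sentinel // saturated IRs are never grown again
    else { ir.add(v); ir }

  def merge(ir1: util.Set[Any], ir2: util.Set[Any]): util.Set[Any] =
    if ((ir1 eq Sentinel) || (ir2 eq Sentinel)) Sentinel // short-circuit without copying
    else {
      val it = ir2.iterator()
      while (it.hasNext && ir1.size() < k) ir1.add(it.next())
      if (ir1.size() >= k) Sentinel else ir1
    }

  def finalizeCount(ir: util.Set[Any]): Long =
    if (ir eq Sentinel) k else ir.size()

  var ir: util.Set[Any] = new util.HashSet[Any]()
  (1 to 10).foreach(i => ir = update(ir, i))
  println(finalizeCount(ir)) // 5: capped at k, with O(1) memory once saturated
}

One design note: the patch itself writes the sentinel comparison with ==, which on Java sets means structural equality. That behaves identically in practice only because prepare always seeds an IR with one element, so no ordinary IR is ever the empty set.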