diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala index b120d29e7f..520ae135b2 100644 --- a/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala +++ b/aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala @@ -25,6 +25,7 @@ import com.yahoo.sketches.kll.KllFloatsSketch import com.yahoo.sketches.{ArrayOfDoublesSerDe, ArrayOfItemsSerDe, ArrayOfLongsSerDe, ArrayOfStringsSerDe} import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream} +import java.security.MessageDigest import java.util import scala.collection.mutable import scala.jdk.CollectionConverters._ @@ -599,6 +600,103 @@ class ApproxHistogram[T: FrequentItemsFriendly](mapSize: Int, errorType: ErrorTy } }
+/**
+  * Counts distinct values exactly, up to a bound `k`: once `k` distinct values have been seen the
+  * aggregator is saturated and reports `k`.
+  *
+  * The intermediate representation (IR) is a set of at most `k` entries that normalizes to a plain
+  * list whose elements match `irType`:
+  *  - numeric inputs are stored as-is,
+  *  - binary inputs are held as Base64 text in memory (so they compare by content), emitted as byte arrays,
+  *  - every other input is stored as the MD5 hex digest of its Java-serialized form.
+  *
+  * @param inputType chronon type of the aggregated column; selects how values are stored in the IR
+  * @param k         maximum number of distinct values to track, and the largest count reported
+  */
+class BoundedUniqueCount[T](inputType: DataType, k: Int = 8) extends SimpleAggregator[T, util.Set[Any], Long] {
+  // True when the normalized IR carries values of the input type rather than MD5 digests.
+  private val storesRawValues: Boolean = inputType match {
+    case IntType | LongType | DoubleType | FloatType | ShortType | BinaryType => true
+    case _                                                                    => false
+  }
+
+  // Java-serializes the input so that arbitrary values can be hashed.
+  private def toBytes(input: T): Array[Byte] = {
+    val bos = new ByteArrayOutputStream()
+    val out = new ObjectOutputStream(bos)
+    out.writeObject(input)
+    out.flush()
+    bos.toByteArray
+  }
+
+  private def md5Hex(bytes: Array[Byte]): String =
+    MessageDigest.getInstance("MD5").digest(bytes).map("%02x".format(_)).mkString
+
+  // Maps an input value to the element that is stored in the IR set.
+  private def processInput(input: T): Any =
+    input match {
+      // Arrays hash and compare by identity, so equal payloads would be counted as distinct values.
+      // Keep a Base64 string in the set instead; normalize turns it back into bytes.
+      case bytes: Array[Byte] if inputType == BinaryType => util.Base64.getEncoder.encodeToString(bytes)
+      case _ if storesRawValues                            => input
+      case _                                               => md5Hex(toBytes(input))
+    }
+
+  override def prepare(input: T): util.Set[Any] = {
+    val result = new util.HashSet[Any](k)
+    result.add(processInput(input))
+    result
+  }
+
+  // A saturated IR already reports `k`, so further inputs are dropped without being hashed.
+  override def update(ir: util.Set[Any], input: T): util.Set[Any] = {
+    if (ir.size() < k) {
+      ir.add(processInput(input))
+    }
+    ir
+  }
+
+  override def outputType: DataType = LongType
+
+  override def irType: DataType =
+    if (storesRawValues) ListType(inputType) else ListType(StringType)
+
+  // Folds ir2 into ir1 in place, stopping as soon as ir1 holds `k` entries.
+  override def merge(ir1: util.Set[Any], ir2: util.Set[Any]): util.Set[Any] = {
+    val incoming = ir2.iterator()
+    while (ir1.size() < k && incoming.hasNext) {
+      ir1.add(incoming.next())
+    }
+    ir1
+  }
+
+  // Capped at `k` so that an IR holding more than `k` entries still reports at most the bound.
+  override def finalize(ir: util.Set[Any]): Long = math.min(ir.size(), k).toLong
+
+  override def clone(ir: util.Set[Any]): util.Set[Any] = new util.HashSet[Any](ir)
+
+  // Emits the IR as a list; Base64-held binary entries are decoded back into byte arrays.
+  override def normalize(ir: util.Set[Any]): Any = {
+    val list = new util.ArrayList[Any](ir.size())
+    ir.iterator().asScala.foreach {
+      case encoded: String if inputType == BinaryType => list.add(util.Base64.getDecoder.decode(encoded))
+      case other                                      => list.add(other)
+    }
+    list
+  }
+
+  // Rebuilds the IR set from its list form, restoring content-based equality for byte arrays.
+  override def denormalize(ir: Any): util.Set[Any] = {
+    val list = ir.asInstanceOf[util.ArrayList[Any]]
+    val result = new util.HashSet[Any]()
+    list.iterator().asScala.foreach {
+      case bytes: Array[Byte] if inputType == BinaryType => result.add(util.Base64.getEncoder.encodeToString(bytes))
+      case other                                         => result.add(other)
+    }
+    result
+  }
+}
+
// Based on CPC sketch (a faster, smaller and more accurate version of HLL) // See: Back to the future: an even more nearly optimal cardinality estimation algorithm, 2017 // https://arxiv.org/abs/1708.06839 diff --git a/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala b/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala index d5f21b3072..d0be4ead20 100644 --- a/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala
+++ b/aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala @@ -307,7 +307,19 @@ object ColumnAggregator { case BinaryType => simple(new ApproxDistinctCount[Array[Byte]](aggregationPart.getInt("k", Some(8)))) case _ => mismatchException } + case Operation.BOUNDED_UNIQUE_COUNT => + val k = aggregationPart.getInt("k", Some(8)) + inputType match { + case IntType => simple(new BoundedUniqueCount[Int](inputType, k)) + case LongType => simple(new BoundedUniqueCount[Long](inputType, k)) + case ShortType => simple(new BoundedUniqueCount[Short](inputType, k)) + case DoubleType => simple(new BoundedUniqueCount[Double](inputType, k)) + case FloatType => simple(new BoundedUniqueCount[Float](inputType, k)) + case StringType => simple(new BoundedUniqueCount[String](inputType, k)) + case BinaryType => simple(new BoundedUniqueCount[Array[Byte]](inputType, k)) + case _ => mismatchException + } case Operation.APPROX_PERCENTILE => val k = aggregationPart.getInt("k", Some(128)) val mapper = new ObjectMapper() diff --git a/aggregator/src/test/scala/ai/chronon/aggregator/test/BoundedUniqueCountTest.scala b/aggregator/src/test/scala/ai/chronon/aggregator/test/BoundedUniqueCountTest.scala new file mode 100644 index 0000000000..5d4189de1f --- /dev/null +++ b/aggregator/src/test/scala/ai/chronon/aggregator/test/BoundedUniqueCountTest.scala @@ -0,0 +1,101 @@
+package ai.chronon.aggregator.test
+
+import ai.chronon.aggregator.base.BoundedUniqueCount
+import ai.chronon.api.{StringType, IntType, LongType, DoubleType, FloatType, BinaryType}
+import junit.framework.TestCase
+import org.junit.Assert._
+
+import java.util
+import scala.jdk.CollectionConverters._
+
+/** Checks BoundedUniqueCount with a bound of 5 across the supported input types. */
+class BoundedUniqueCountTest extends TestCase {
+  private val bound = 5
+
+  // Runs `first` through prepare and each remaining value through update, then finalizes the IR.
+  private def countOf[T](aggregator: BoundedUniqueCount[T], first: T, rest: T*): Long = {
+    val ir = rest.foldLeft(aggregator.prepare(first))((acc, value) => aggregator.update(acc, value))
+    aggregator.finalize(ir)
+  }
+
+  // Merges two hand-built IR sets and finalizes the merged IR.
+  private def mergedCountOf[T](aggregator: BoundedUniqueCount[T], left: Seq[Any], right: Seq[Any]): Long = {
+    val merged = aggregator.merge(new util.HashSet[Any](left.asJava), new util.HashSet[Any](right.asJava))
+    aggregator.finalize(merged)
+  }
+
+  def testHappyCase(): Unit =
+    assertEquals(2L, countOf(new BoundedUniqueCount[String](StringType, bound), "1", "1", "2"))
+
+  def testExceedSize(): Unit =
+    assertEquals(5L, countOf(new BoundedUniqueCount[String](StringType, bound), "1", "2", "3", "4", "5", "6", "7"))
+
+  def testMerge(): Unit = {
+    val aggregator = new BoundedUniqueCount[String](StringType, bound)
+    // Six distinct values across both sides are capped at the bound.
+    assertEquals(5L, mergedCountOf(aggregator, Seq("1", "2", "3"), Seq("4", "5", "6")))
+  }
+
+  def testIntTypeHappyCase(): Unit =
+    assertEquals(2L, countOf(new BoundedUniqueCount[Int](IntType, bound), 1, 1, 2))
+
+  def testIntTypeExceedSize(): Unit =
+    assertEquals(5L, countOf(new BoundedUniqueCount[Int](IntType, bound), 1, 2, 3, 4, 5, 6, 7))
+
+  def testIntTypeMerge(): Unit = {
+    val aggregator = new BoundedUniqueCount[Int](IntType, bound)
+    // Six distinct values across both sides are capped at the bound.
+    assertEquals(5L, mergedCountOf(aggregator, Seq(1, 2, 3), Seq(4, 5, 6)))
+  }
+
+  def testLongTypeHappyCase(): Unit =
+    assertEquals(2L, countOf(new BoundedUniqueCount[Long](LongType, bound), 1L, 1L, 2L))
+
+  def testLongTypeExceedSize(): Unit =
+    assertEquals(5L, countOf(new BoundedUniqueCount[Long](LongType, bound), 1L, 2L, 3L, 4L, 5L, 6L, 7L))
+
+  def testDoubleTypeHappyCase(): Unit =
+    assertEquals(2L, countOf(new BoundedUniqueCount[Double](DoubleType, bound), 1.0, 1.0, 2.5))
+
+  def testDoubleTypeExceedSize(): Unit =
+    assertEquals(5L, countOf(new BoundedUniqueCount[Double](DoubleType, bound), 1.0, 2.5, 3.7, 4.2, 5.8, 6.3, 7.9))
+
+  def testFloatTypeHappyCase(): Unit =
+    assertEquals(2L, countOf(new BoundedUniqueCount[Float](FloatType, bound), 1.0f, 1.0f, 2.5f))
+
+  def testFloatTypeExceedSize(): Unit =
+    assertEquals(5L, countOf(new BoundedUniqueCount[Float](FloatType, bound), 1.0f, 2.5f, 3.7f, 4.2f, 5.8f, 6.3f, 7.9f))
+
+  def testBinaryTypeHappyCase(): Unit = {
+    val bytes1 = Array[Byte](1, 2, 3)
+    val bytes2 = Array[Byte](4, 5, 6)
+    // bytes1 is fed twice as the very same array instance.
+    assertEquals(2L, countOf(new BoundedUniqueCount[Array[Byte]](BinaryType, bound), bytes1, bytes1, bytes2))
+  }
+
+  def testBinaryTypeExceedSize(): Unit = {
+    val inputs = (1 to 7).map(i => Array[Byte](i.toByte))
+    assertEquals(5L, countOf(new BoundedUniqueCount[Array[Byte]](BinaryType, bound), inputs.head, inputs.tail: _*))
+  }
+
+  def testBinaryTypeMerge(): Unit = {
+    val aggregator = new BoundedUniqueCount[Array[Byte]](BinaryType, bound)
+    val left = Seq(Array[Byte](1), Array[Byte](2), Array[Byte](3))
+    val right = Seq(Array[Byte](4), Array[Byte](5), Array[Byte](6))
+    assertEquals(5L, mergedCountOf(aggregator, left, right))
+  }
+
+  def testNumericTypeIrType(): Unit = {
+    import ai.chronon.api.ListType
+
+    // Numeric and binary inputs: irType is a list of the input type itself.
+    assertEquals(ListType(IntType), new BoundedUniqueCount[Int](IntType, bound).irType)
+    assertEquals(ListType(LongType), new BoundedUniqueCount[Long](LongType, bound).irType)
+    assertEquals(ListType(DoubleType), new BoundedUniqueCount[Double](DoubleType, bound).irType)
+    assertEquals(ListType(FloatType), new BoundedUniqueCount[Float](FloatType, bound).irType)
+    assertEquals(ListType(BinaryType), new BoundedUniqueCount[Array[Byte]](BinaryType, bound).irType)
+
+    // Any other input type: irType is a list of strings.
+    assertEquals(ListType(StringType), new BoundedUniqueCount[String](StringType, bound).irType)
+  }
+}
\ No newline at end of file
diff --git a/api/py/ai/chronon/group_by.py b/api/py/ai/chronon/group_by.py index c171113356..e08d9626ad 100644 --- a/api/py/ai/chronon/group_by.py +++ b/api/py/ai/chronon/group_by.py @@ -65,6 +65,7 @@ class Operation: # https://github.com/apache/incubator-datasketches-java/blob/master/src/main/java/org/apache/datasketches/cpc/CpcSketch.java#L180 APPROX_UNIQUE_COUNT_LGK = collector(ttypes.Operation.APPROX_UNIQUE_COUNT) UNIQUE_COUNT = ttypes.Operation.UNIQUE_COUNT + BOUNDED_UNIQUE_COUNT_K =
collector(ttypes.Operation.BOUNDED_UNIQUE_COUNT) COUNT = ttypes.Operation.COUNT SUM = ttypes.Operation.SUM AVERAGE = ttypes.Operation.AVERAGE diff --git a/api/thrift/api.thrift b/api/thrift/api.thrift index c75e566e1a..2ca6242ec5 100644 --- a/api/thrift/api.thrift +++ b/api/thrift/api.thrift @@ -161,7 +161,8 @@ enum Operation { BOTTOM_K = 16 HISTOGRAM = 17, // use this only if you know the set of inputs is bounded - APPROX_HISTOGRAM_K = 18 + APPROX_HISTOGRAM_K = 18, + BOUNDED_UNIQUE_COUNT = 19 } // integers map to milliseconds in the timeunit diff --git a/docs/source/authoring_features/GroupBy.md b/docs/source/authoring_features/GroupBy.md index c972f4cff3..7f1b7acefe 100644 --- a/docs/source/authoring_features/GroupBy.md +++ b/docs/source/authoring_features/GroupBy.md @@ -147,6 +147,7 @@ Limitations: | approx_unique_count | primitive types | list, map | long | no | k=8 | yes | | approx_percentile | primitive types | list, map | list | no | k=128, percentiles | yes | | unique_count | primitive types | list, map | long | no | | no | +| bounded_unique_count | primitive types | list, map | long | no | k=8 | yes | ## Accuracy diff --git a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala index 051ec1be73..a046a16f57 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/FetcherTest.scala @@ -273,8 +273,11 @@ class FetcherTest extends TestCase { Builders.Aggregation(operation = Operation.LAST_K, argMap = Map("k" -> "300"), inputColumn = "user", - windows = Seq(new Window(2, TimeUnit.DAYS), new Window(30, TimeUnit.DAYS))) - ), + windows = Seq(new Window(2, TimeUnit.DAYS), new Window(30, TimeUnit.DAYS))), + Builders.Aggregation(operation = Operation.BOUNDED_UNIQUE_COUNT, + argMap = Map("k" -> "5"), + inputColumn = "user", + windows = Seq(new Window(2, TimeUnit.DAYS), new Window(30, TimeUnit.DAYS)))), metaData =
Builders.MetaData(name = "unit_test/vendor_ratings", namespace = namespace), accuracy = Accuracy.SNAPSHOT ) @@ -503,6 +506,11 @@ class FetcherTest extends TestCase { operation = Operation.APPROX_HISTOGRAM_K, inputColumn = "rating", windows = Seq(new Window(1, TimeUnit.DAYS)) + ), + Builders.Aggregation( + operation = Operation.BOUNDED_UNIQUE_COUNT, + inputColumn = "rating", + windows = Seq(new Window(1, TimeUnit.DAYS)) ) ), accuracy = Accuracy.TEMPORAL, diff --git a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala index 12c64b9cfa..289b4048be 100644 --- a/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala +++ b/spark/src/test/scala/ai/chronon/spark/test/GroupByTest.scala @@ -724,4 +724,42 @@ class GroupByTest { assert(count > 0, s"Found a count value that is not greater than zero: $count") } }
+
+  // Backfills bounded unique counts with k = 5 and checks that no 60 day count exceeds 5.
+  @Test
+  def testBoundedUniqueCounts(): Unit = {
+    lazy val spark: SparkSession =
+      SparkSessionBuilder.build("GroupByTest" + "_" + Random.alphanumeric.take(6).mkString, local = true)
+    val (source, endPartition) = createTestSource(suffix = "_bounded_counts")
+    val tableUtils = TableUtils(spark)
+    val namespace = "test_bounded_counts"
+
+    // One bounded unique count per column, each over a 15 day and a 60 day window.
+    val aggs = Seq("item", "price").map { column =>
+      Builders.Aggregation(
+        operation = Operation.BOUNDED_UNIQUE_COUNT,
+        inputColumn = column,
+        windows = Seq(new Window(15, TimeUnit.DAYS), new Window(60, TimeUnit.DAYS)),
+        argMap = Map("k" -> "5")
+      )
+    }
+
+    backfill(
+      name = "unit_test_group_by_bounded_counts",
+      source = source,
+      endPartition = endPartition,
+      namespace = namespace,
+      tableUtils = tableUtils,
+      additionalAgg = aggs
+    )
+
+    // Rows whose 60 day count is above the bound; there must be none.
+    val overBound = spark.sql("""
+        |select *
+        |from test_bounded_counts.unit_test_group_by_bounded_counts
+        |where item_bounded_unique_count_60d > 5 or price_bounded_unique_count_60d > 5
+        |""".stripMargin)
+
+    assertTrue(overBound.isEmpty)
+  }
 }