Skip to content

Commit 5da9c32

Browse files
Add bounded unique count aggregation (#781)
* Add bounded unique count aggregation * use hash value instead * scala fmt * fix ir type * use iterator * scala fmt * Optimize BoundedUniqueCount for numeric types with sentinel set pattern --------- Co-authored-by: Pengyu Hou <[email protected]>
1 parent 3a4d8d6 commit 5da9c32

File tree

8 files changed

+384
-3
lines changed

8 files changed

+384
-3
lines changed

aggregator/src/main/scala/ai/chronon/aggregator/base/SimpleAggregators.scala

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import com.yahoo.sketches.kll.KllFloatsSketch
2525
import com.yahoo.sketches.{ArrayOfDoublesSerDe, ArrayOfItemsSerDe, ArrayOfLongsSerDe, ArrayOfStringsSerDe}
2626

2727
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
28+
import java.security.MessageDigest
2829
import java.util
2930
import scala.collection.mutable
3031
import scala.jdk.CollectionConverters._
@@ -599,6 +600,113 @@ class ApproxHistogram[T: FrequentItemsFriendly](mapSize: Int, errorType: ErrorTy
599600
}
600601
}
601602

603+
object BoundedUniqueCount {
604+
private val SentinelSet: util.Set[Any] = new util.HashSet[Any]()
605+
private val SentinelMarker: String = "__SENTINEL__"
606+
}
607+
608+
class BoundedUniqueCount[T](inputType: DataType, k: Int = 8) extends SimpleAggregator[T, util.Set[Any], Long] {
609+
private def toBytes(input: T): Array[Byte] = {
610+
val bos = new ByteArrayOutputStream()
611+
val out = new ObjectOutputStream(bos)
612+
out.writeObject(input)
613+
out.flush()
614+
bos.toByteArray
615+
}
616+
617+
private def md5Hex(bytes: Array[Byte]): String =
618+
MessageDigest.getInstance("MD5").digest(bytes).map("%02x".format(_)).mkString
619+
620+
private def processInput(input: T): Any = {
621+
inputType match {
622+
case IntType | LongType | DoubleType | FloatType | ShortType | BinaryType =>
623+
input
624+
case _ =>
625+
md5Hex(toBytes(input))
626+
}
627+
}
628+
629+
override def prepare(input: T): util.Set[Any] = {
630+
val result = new util.HashSet[Any](k)
631+
result.add(processInput(input))
632+
result
633+
}
634+
635+
override def update(ir: util.Set[Any], input: T): util.Set[Any] = {
636+
if (ir == BoundedUniqueCount.SentinelSet || ir.size() >= k) {
637+
return BoundedUniqueCount.SentinelSet
638+
}
639+
640+
ir.add(processInput(input))
641+
ir
642+
}
643+
644+
override def outputType: DataType = LongType
645+
646+
override def irType: DataType =
647+
inputType match {
648+
case IntType | LongType | DoubleType | FloatType | ShortType | BinaryType =>
649+
ListType(inputType)
650+
case _ =>
651+
ListType(StringType)
652+
}
653+
654+
override def merge(ir1: util.Set[Any], ir2: util.Set[Any]): util.Set[Any] = {
655+
if (ir1 == BoundedUniqueCount.SentinelSet || ir2 == BoundedUniqueCount.SentinelSet) {
656+
return BoundedUniqueCount.SentinelSet
657+
}
658+
659+
ir2
660+
.iterator()
661+
.asScala
662+
.foreach(v =>
663+
if (ir1.size() < k) {
664+
ir1.add(v)
665+
})
666+
667+
if (ir1.size() >= k) {
668+
BoundedUniqueCount.SentinelSet
669+
} else {
670+
ir1
671+
}
672+
}
673+
674+
override def finalize(ir: util.Set[Any]): Long = {
675+
if (ir == BoundedUniqueCount.SentinelSet) {
676+
k
677+
} else {
678+
ir.size()
679+
}
680+
}
681+
682+
override def clone(ir: util.Set[Any]): util.Set[Any] = {
683+
if (ir == BoundedUniqueCount.SentinelSet) {
684+
BoundedUniqueCount.SentinelSet
685+
} else {
686+
new util.HashSet[Any](ir)
687+
}
688+
}
689+
690+
override def normalize(ir: util.Set[Any]): Any = {
691+
if (ir == BoundedUniqueCount.SentinelSet) {
692+
val list = new util.ArrayList[Any]()
693+
list.add(BoundedUniqueCount.SentinelMarker)
694+
list
695+
} else {
696+
new util.ArrayList[Any](ir)
697+
}
698+
}
699+
700+
override def denormalize(ir: Any): util.Set[Any] = {
701+
val list = ir.asInstanceOf[util.ArrayList[Any]]
702+
if (list.size() == 1 && list.get(0) == BoundedUniqueCount.SentinelMarker) {
703+
BoundedUniqueCount.SentinelSet
704+
} else {
705+
new util.HashSet[Any](list)
706+
}
707+
}
708+
}
709+
602710
// Based on CPC sketch (a faster, smaller and more accurate version of HLL)
603711
// See: Back to the future: an even more nearly optimal cardinality estimation algorithm, 2017
604712
// https://arxiv.org/abs/1708.06839

aggregator/src/main/scala/ai/chronon/aggregator/row/ColumnAggregator.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,19 @@ object ColumnAggregator {
307307
case BinaryType => simple(new ApproxDistinctCount[Array[Byte]](aggregationPart.getInt("k", Some(8))))
308308
case _ => mismatchException
309309
}
310+
case Operation.BOUNDED_UNIQUE_COUNT =>
311+
val k = aggregationPart.getInt("k", Some(8))
310312

313+
inputType match {
314+
case IntType => simple(new BoundedUniqueCount[Int](inputType, k))
315+
case LongType => simple(new BoundedUniqueCount[Long](inputType, k))
316+
case ShortType => simple(new BoundedUniqueCount[Short](inputType, k))
317+
case DoubleType => simple(new BoundedUniqueCount[Double](inputType, k))
318+
case FloatType => simple(new BoundedUniqueCount[Float](inputType, k))
319+
case StringType => simple(new BoundedUniqueCount[String](inputType, k))
320+
case BinaryType => simple(new BoundedUniqueCount[Array[Byte]](inputType, k))
321+
case _ => mismatchException
322+
}
311323
case Operation.APPROX_PERCENTILE =>
312324
val k = aggregationPart.getInt("k", Some(128))
313325
val mapper = new ObjectMapper()
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
package ai.chronon.aggregator.test
2+
3+
import ai.chronon.aggregator.base.BoundedUniqueCount
4+
import ai.chronon.api.{StringType, IntType, LongType, DoubleType, FloatType, BinaryType}
5+
import junit.framework.TestCase
6+
import org.junit.Assert._
7+
8+
import java.util
9+
import scala.jdk.CollectionConverters._
10+
11+
class BoundedUniqueCountTest extends TestCase {
12+
def testHappyCase(): Unit = {
13+
val boundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
14+
var ir = boundedDistinctCount.prepare("1")
15+
ir = boundedDistinctCount.update(ir, "1")
16+
ir = boundedDistinctCount.update(ir, "2")
17+
18+
val result = boundedDistinctCount.finalize(ir)
19+
assertEquals(2, result)
20+
}
21+
22+
def testExceedSize(): Unit = {
23+
val boundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
24+
var ir = boundedDistinctCount.prepare("1")
25+
ir = boundedDistinctCount.update(ir, "2")
26+
ir = boundedDistinctCount.update(ir, "3")
27+
ir = boundedDistinctCount.update(ir, "4")
28+
ir = boundedDistinctCount.update(ir, "5")
29+
ir = boundedDistinctCount.update(ir, "6")
30+
ir = boundedDistinctCount.update(ir, "7")
31+
32+
val result = boundedDistinctCount.finalize(ir)
33+
assertEquals(5, result)
34+
}
35+
36+
def testMerge(): Unit = {
37+
val boundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
38+
val ir1 = new util.HashSet[Any](Seq("1", "2", "3").asJava)
39+
val ir2 = new util.HashSet[Any](Seq("4", "5", "6").asJava)
40+
41+
val merged = boundedDistinctCount.merge(ir1, ir2)
42+
val result = boundedDistinctCount.finalize(merged)
43+
assertEquals(5, result) // Should return k=5 when exceeding the limit
44+
}
45+
46+
def testIntTypeHappyCase(): Unit = {
47+
val boundedDistinctCount = new BoundedUniqueCount[Int](IntType, 5)
48+
var ir = boundedDistinctCount.prepare(1)
49+
ir = boundedDistinctCount.update(ir, 1)
50+
ir = boundedDistinctCount.update(ir, 2)
51+
52+
val result = boundedDistinctCount.finalize(ir)
53+
assertEquals(2, result)
54+
}
55+
56+
def testIntTypeExceedSize(): Unit = {
57+
val boundedDistinctCount = new BoundedUniqueCount[Int](IntType, 5)
58+
var ir = boundedDistinctCount.prepare(1)
59+
ir = boundedDistinctCount.update(ir, 2)
60+
ir = boundedDistinctCount.update(ir, 3)
61+
ir = boundedDistinctCount.update(ir, 4)
62+
ir = boundedDistinctCount.update(ir, 5)
63+
ir = boundedDistinctCount.update(ir, 6)
64+
ir = boundedDistinctCount.update(ir, 7)
65+
66+
val result = boundedDistinctCount.finalize(ir)
67+
assertEquals(5, result)
68+
}
69+
70+
def testIntTypeMerge(): Unit = {
71+
val boundedDistinctCount = new BoundedUniqueCount[Int](IntType, 5)
72+
val ir1 = new util.HashSet[Any](Seq(1, 2, 3).asJava)
73+
val ir2 = new util.HashSet[Any](Seq(4, 5, 6).asJava)
74+
75+
val merged = boundedDistinctCount.merge(ir1, ir2)
76+
val result = boundedDistinctCount.finalize(merged)
77+
assertEquals(5, result) // Should return k=5 when exceeding the limit
78+
}
79+
80+
def testLongTypeHappyCase(): Unit = {
81+
val boundedDistinctCount = new BoundedUniqueCount[Long](LongType, 5)
82+
var ir = boundedDistinctCount.prepare(1L)
83+
ir = boundedDistinctCount.update(ir, 1L)
84+
ir = boundedDistinctCount.update(ir, 2L)
85+
86+
val result = boundedDistinctCount.finalize(ir)
87+
assertEquals(2, result)
88+
}
89+
90+
def testLongTypeExceedSize(): Unit = {
91+
val boundedDistinctCount = new BoundedUniqueCount[Long](LongType, 5)
92+
var ir = boundedDistinctCount.prepare(1L)
93+
ir = boundedDistinctCount.update(ir, 2L)
94+
ir = boundedDistinctCount.update(ir, 3L)
95+
ir = boundedDistinctCount.update(ir, 4L)
96+
ir = boundedDistinctCount.update(ir, 5L)
97+
ir = boundedDistinctCount.update(ir, 6L)
98+
ir = boundedDistinctCount.update(ir, 7L)
99+
100+
val result = boundedDistinctCount.finalize(ir)
101+
assertEquals(5, result)
102+
}
103+
104+
def testDoubleTypeHappyCase(): Unit = {
105+
val boundedDistinctCount = new BoundedUniqueCount[Double](DoubleType, 5)
106+
var ir = boundedDistinctCount.prepare(1.0)
107+
ir = boundedDistinctCount.update(ir, 1.0)
108+
ir = boundedDistinctCount.update(ir, 2.5)
109+
110+
val result = boundedDistinctCount.finalize(ir)
111+
assertEquals(2, result)
112+
}
113+
114+
def testDoubleTypeExceedSize(): Unit = {
115+
val boundedDistinctCount = new BoundedUniqueCount[Double](DoubleType, 5)
116+
var ir = boundedDistinctCount.prepare(1.0)
117+
ir = boundedDistinctCount.update(ir, 2.5)
118+
ir = boundedDistinctCount.update(ir, 3.7)
119+
ir = boundedDistinctCount.update(ir, 4.2)
120+
ir = boundedDistinctCount.update(ir, 5.8)
121+
ir = boundedDistinctCount.update(ir, 6.3)
122+
ir = boundedDistinctCount.update(ir, 7.9)
123+
124+
val result = boundedDistinctCount.finalize(ir)
125+
assertEquals(5, result)
126+
}
127+
128+
def testFloatTypeHappyCase(): Unit = {
129+
val boundedDistinctCount = new BoundedUniqueCount[Float](FloatType, 5)
130+
var ir = boundedDistinctCount.prepare(1.0f)
131+
ir = boundedDistinctCount.update(ir, 1.0f)
132+
ir = boundedDistinctCount.update(ir, 2.5f)
133+
134+
val result = boundedDistinctCount.finalize(ir)
135+
assertEquals(2, result)
136+
}
137+
138+
def testFloatTypeExceedSize(): Unit = {
139+
val boundedDistinctCount = new BoundedUniqueCount[Float](FloatType, 5)
140+
var ir = boundedDistinctCount.prepare(1.0f)
141+
ir = boundedDistinctCount.update(ir, 2.5f)
142+
ir = boundedDistinctCount.update(ir, 3.7f)
143+
ir = boundedDistinctCount.update(ir, 4.2f)
144+
ir = boundedDistinctCount.update(ir, 5.8f)
145+
ir = boundedDistinctCount.update(ir, 6.3f)
146+
ir = boundedDistinctCount.update(ir, 7.9f)
147+
148+
val result = boundedDistinctCount.finalize(ir)
149+
assertEquals(5, result)
150+
}
151+
152+
def testBinaryTypeHappyCase(): Unit = {
153+
val boundedDistinctCount = new BoundedUniqueCount[Array[Byte]](BinaryType, 5)
154+
val bytes1 = Array[Byte](1, 2, 3)
155+
val bytes2 = Array[Byte](4, 5, 6)
156+
157+
var ir = boundedDistinctCount.prepare(bytes1)
158+
ir = boundedDistinctCount.update(ir, bytes1)
159+
ir = boundedDistinctCount.update(ir, bytes2)
160+
161+
val result = boundedDistinctCount.finalize(ir)
162+
assertEquals(2, result)
163+
}
164+
165+
def testBinaryTypeExceedSize(): Unit = {
166+
val boundedDistinctCount = new BoundedUniqueCount[Array[Byte]](BinaryType, 5)
167+
var ir = boundedDistinctCount.prepare(Array[Byte](1))
168+
ir = boundedDistinctCount.update(ir, Array[Byte](2))
169+
ir = boundedDistinctCount.update(ir, Array[Byte](3))
170+
ir = boundedDistinctCount.update(ir, Array[Byte](4))
171+
ir = boundedDistinctCount.update(ir, Array[Byte](5))
172+
ir = boundedDistinctCount.update(ir, Array[Byte](6))
173+
ir = boundedDistinctCount.update(ir, Array[Byte](7))
174+
175+
val result = boundedDistinctCount.finalize(ir)
176+
assertEquals(5, result)
177+
}
178+
179+
def testBinaryTypeMerge(): Unit = {
180+
val boundedDistinctCount = new BoundedUniqueCount[Array[Byte]](BinaryType, 5)
181+
val ir1 = new util.HashSet[Any](Seq(Array[Byte](1), Array[Byte](2), Array[Byte](3)).asJava)
182+
val ir2 = new util.HashSet[Any](Seq(Array[Byte](4), Array[Byte](5), Array[Byte](6)).asJava)
183+
184+
val merged = boundedDistinctCount.merge(ir1, ir2)
185+
val result = boundedDistinctCount.finalize(merged)
186+
assertEquals(5, result) // Should return k=5 when exceeding the limit
187+
}
188+
189+
def testNumericTypeIrType(): Unit = {
190+
val intBoundedDistinctCount = new BoundedUniqueCount[Int](IntType, 5)
191+
val longBoundedDistinctCount = new BoundedUniqueCount[Long](LongType, 5)
192+
val doubleBoundedDistinctCount = new BoundedUniqueCount[Double](DoubleType, 5)
193+
val floatBoundedDistinctCount = new BoundedUniqueCount[Float](FloatType, 5)
194+
val binaryBoundedDistinctCount = new BoundedUniqueCount[Array[Byte]](BinaryType, 5)
195+
val stringBoundedDistinctCount = new BoundedUniqueCount[String](StringType, 5)
196+
197+
// For numeric and binary types, irType should be ListType(inputType)
198+
assertEquals(ai.chronon.api.ListType(IntType), intBoundedDistinctCount.irType)
199+
assertEquals(ai.chronon.api.ListType(LongType), longBoundedDistinctCount.irType)
200+
assertEquals(ai.chronon.api.ListType(DoubleType), doubleBoundedDistinctCount.irType)
201+
assertEquals(ai.chronon.api.ListType(FloatType), floatBoundedDistinctCount.irType)
202+
assertEquals(ai.chronon.api.ListType(BinaryType), binaryBoundedDistinctCount.irType)
203+
204+
// For non-numeric types, irType should be ListType(StringType)
205+
assertEquals(ai.chronon.api.ListType(StringType), stringBoundedDistinctCount.irType)
206+
}
207+
}

api/py/ai/chronon/group_by.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ class Operation:
6565
# https://github.com/apache/incubator-datasketches-java/blob/master/src/main/java/org/apache/datasketches/cpc/CpcSketch.java#L180
6666
APPROX_UNIQUE_COUNT_LGK = collector(ttypes.Operation.APPROX_UNIQUE_COUNT)
6767
UNIQUE_COUNT = ttypes.Operation.UNIQUE_COUNT
68+
BOUNDED_UNIQUE_COUNT_K = collector(ttypes.Operation.BOUNDED_UNIQUE_COUNT)
6869
COUNT = ttypes.Operation.COUNT
6970
SUM = ttypes.Operation.SUM
7071
AVERAGE = ttypes.Operation.AVERAGE

api/thrift/api.thrift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ enum Operation {
161161
BOTTOM_K = 16
162162

163163
HISTOGRAM = 17, // use this only if you know the set of inputs is bounded
164-
APPROX_HISTOGRAM_K = 18
164+
APPROX_HISTOGRAM_K = 18,
165+
BOUNDED_UNIQUE_COUNT = 19
165166
}
166167

167168
// integers map to milliseconds in the timeunit

docs/source/authoring_features/GroupBy.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ Limitations:
147147
| approx_unique_count | primitive types | list, map | long | no | k=8 | yes |
148148
| approx_percentile | primitive types | list, map | list<input,> | no | k=128, percentiles | yes |
149149
| unique_count | primitive types | list, map | long | no | | no |
150+
| bounded_unique_count | primitive types | list, map | long | no | k=inf | yes |
150151

151152

152153
## Accuracy

0 commit comments

Comments
 (0)