diff --git a/src/main/java/org/apache/datasketches/hll/HllUnion.java b/src/main/java/org/apache/datasketches/hll/HllUnion.java
index dca2cd7da..2738d68d7 100644
--- a/src/main/java/org/apache/datasketches/hll/HllUnion.java
+++ b/src/main/java/org/apache/datasketches/hll/HllUnion.java
@@ -176,7 +176,7 @@ public double getEstimate() {
checkRebuildCurMinNumKxQ(gadget);
return gadget.getEstimate();
}
-
+
/**
* Gets the effective lgConfigK for the HllUnion operator, which may be less than
* lgMaxK.
@@ -320,6 +320,14 @@ public void update(final HllSketch sketch) {
gadget.hllSketchImpl = unionImpl(sketch, gadget, lgMaxK);
}
+ /**
+ * Update this HllUnion operator with the given HllUnion.
+ * @param union the given HllUnion.
+ */
+ public void update(final HllUnion union) {
+ gadget.hllSketchImpl = unionImpl(union.gadget, gadget, lgMaxK);
+ }
+
@Override
void couponUpdate(final int coupon) {
if (coupon == EMPTY) { return; }
diff --git a/src/test/java/org/apache/datasketches/hll/UnionCaseTest.java b/src/test/java/org/apache/datasketches/hll/UnionCaseTest.java
index 84ec7ce9a..608585f5e 100644
--- a/src/test/java/org/apache/datasketches/hll/UnionCaseTest.java
+++ b/src/test/java/org/apache/datasketches/hll/UnionCaseTest.java
@@ -31,6 +31,8 @@
import static org.testng.Assert.assertTrue;
import java.lang.foreign.MemorySegment;
+//import java.lang.invoke.MethodHandles;
+//import java.lang.invoke.VarHandle;
import org.apache.datasketches.common.SketchesStateException;
import org.testng.annotations.Test;
@@ -42,8 +44,8 @@ public class UnionCaseTest {
private static final String LS = System.getProperty("line.separator");
long v = 0;
final static int maxLgK = 12;
- HllSketch source;
- //HllUnion union;
+ HllSketch skSource;
+ HllUnion uSource;
String hfmt = "%10s%10s%10s%10s%10s%10s%10s%10s%10s%10s%10s" + LS;
String hdr = String.format(hfmt, "caseNum","srcLgKStr","gdtLgKStr","srcType","gdtType",
"srcSeg","gdtSeg","srcMode","gdtMode","srcOoof","gdtOoof");
@@ -52,48 +54,68 @@ public class UnionCaseTest {
public void checkAllCases() {
print(hdr);
for (int i = 0; i < 24; i++) {
- checkCase(i, HLL_4, false);
+ checkCase(i, HLL_4, false, false);
}
println("");
print(hdr);
for (int i = 0; i < 24; i++) {
- checkCase(i, HLL_6, false);
+ checkCase(i, HLL_6, false, false);
}
println("");
print(hdr);
for (int i = 0; i < 24; i++) {
- checkCase(i, HLL_8, false);
+ checkCase(i, HLL_8, false, false);
}
println("");
print(hdr);
for (int i = 0; i < 24; i++) {
- checkCase(i, HLL_4, true);
+ checkCase(i, HLL_8, false, true); //srcUnion
+ }
+ println("");
+
+ print(hdr);
+ for (int i = 0; i < 24; i++) {
+ checkCase(i, HLL_4, true, false);
}
println("");
print(hdr);
for (int i = 0; i < 24; i++) {
- checkCase(i, HLL_6, true);
+ checkCase(i, HLL_6, true, false);
}
println("");
print(hdr);
for (int i = 0; i < 24; i++) {
- checkCase(i, HLL_8, true);
+ checkCase(i, HLL_8, true, false);
+ }
+ println("");
+
+ print(hdr);
+ for (int i = 0; i < 24; i++) {
+ checkCase(i, HLL_8, true, true); //srcUnion
}
println("");
}
- private void checkCase(final int caseNum, final TgtHllType srcType, final boolean srcSeg) {
- source = getSource(caseNum, srcType, srcSeg);
+ private void checkCase(final int caseNum, final TgtHllType srcType, final boolean srcSeg, final boolean srcUnion) {
+ if (srcUnion) {
+ uSource = getUnionSrc(caseNum);
+ } else {
+ skSource = getSkSource(caseNum, srcType, srcSeg);
+ }
final boolean gdtSeg = (caseNum & 1) > 0;
final HllUnion union = getUnion(caseNum, gdtSeg);
- union.update(source);
+ if (srcUnion) {
+ union.update(uSource);
+ } else {
+ union.update(skSource);
+ }
final int totalU = getSrcCount(caseNum, maxLgK) + getUnionCount(caseNum);
- output(caseNum, source, union, totalU);
+ output(caseNum, skSource, union, totalU);
}
private void output(final int caseNum, final HllSketch source, final HllUnion union, final int totalU) {
@@ -121,7 +143,7 @@ private void output(final int caseNum, final HllSketch source, final HllUnion un
assertTrue(err < rse, "Err: " + err + ", RSE: " + rse);
}
- private HllSketch getSource(final int caseNum, final TgtHllType tgtHllType, final boolean useMemorySegment) {
+ private HllSketch getSkSource(final int caseNum, final TgtHllType tgtHllType, final boolean useMemorySegment) {
final int srcLgK = getSrcLgK(caseNum, maxLgK);
final int srcU = getSrcCount(caseNum, maxLgK);
if (useMemorySegment) {
@@ -131,9 +153,18 @@ private HllSketch getSource(final int caseNum, final TgtHllType tgtHllType, fina
}
}
+ private HllUnion getUnionSrc(final int caseNum) {
+ final int srcLgK = getSrcLgK(caseNum, maxLgK);
+ final int srcU = getSrcCount(caseNum, maxLgK);
+ final HllSketch sk = buildHeapSketch(srcLgK, HLL_8, srcU);
+ final HllUnion u = new HllUnion(maxLgK);
+ u.update(sk);
+ return u;
+ }
+
private HllUnion getUnion(final int caseNum, final boolean useMemorySegment) {
final int unionU = getUnionCount(caseNum);
- return (useMemorySegment) ? buildMemorSegmentUnion(maxLgK, unionU) : buildHeapUnion(maxLgK, unionU);
+ return (useMemorySegment) ? buildMemorySegmentUnion(maxLgK, unionU) : buildHeapUnion(maxLgK, unionU);
}
private static int getUnionCount(final int caseNum) {
@@ -394,7 +425,7 @@ private HllUnion buildHeapUnion(final int lgMaxK, final int n) {
return u;
}
- private HllUnion buildMemorSegmentUnion(final int lgMaxK, final int n) {
+ private HllUnion buildMemorySegmentUnion(final int lgMaxK, final int n) {
final int bytes = HllSketch.getMaxUpdatableSerializationBytes(lgMaxK, TgtHllType.HLL_8);
final MemorySegment wseg = MemorySegment.ofArray(new byte[bytes]);
final HllUnion u = new HllUnion(lgMaxK, wseg);