Skip to content
78 changes: 67 additions & 11 deletions src/java/org/apache/cassandra/db/tries/CollectionMergeCursor.java
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ abstract class CollectionMergeCursor<T, C extends Cursor<T>> implements Cursor<T
/// The collected content.
T collectedContent;

private long currentPosition;

<I> CollectionMergeCursor(Trie.CollectionMergeResolver<T> resolver, Direction direction, Collection<I> inputs, IntFunction<C[]> cursorArrayConstructor, BiFunction<I, Direction, C> extractor)
{
this.resolver = resolver;
Expand All @@ -114,15 +116,16 @@ <I> CollectionMergeCursor(Trie.CollectionMergeResolver<T> resolver, Direction di
++i;
}

// The cursors are all currently positioned on the root and thus in valid heap order.
// Initialize currentPosition since encodedPosition() now returns it directly
collectAndCachePositionFlags();
}

/// Interface for internal operations that can be applied to selected top elements of the heap.
interface HeapOp<T, C extends Cursor<T>>
{
void apply(CollectionMergeCursor<T, C> self, C cursor, int index);

default boolean shouldContinueWithChild(C child, C head)
default boolean shouldContinueWithChild(CollectionMergeCursor<T, C> self, C child, C head)
{
return equalCursor(child, head);
}
Expand Down Expand Up @@ -199,7 +202,7 @@ private void applyToSelectedElementsInHeap(HeapOp<T, C> action, int index)
if (index >= heap.length)
return;
C item = heap[index];
if (!action.shouldContinueWithChild(item, head))
if (!action.shouldContinueWithChild(this, item, head))
return;

// If the children are at the same position, they also need advancing and their subheap
Expand All @@ -212,6 +215,52 @@ private void applyToSelectedElementsInHeap(HeapOp<T, C> action, int index)
action.apply(this, item, index);
}

/// Refreshes the cached `currentPosition` from the head cursor and returns it.
///
/// The head's encoded position is used unchanged when the head is exhausted, when the current branch does not
/// have multiple sources, or when the head already carries `MAY_HAVE_CONTENT_BIT`. Otherwise the heap is walked
/// with `FLAG_COLLECTOR`, which ORs in the encoded positions of the heap cursors it selects as equal to the head.
///
/// Within the visible code this is invoked at the end of construction and from `maybeSwapHead`.
///
/// @return the freshly cached value of `currentPosition`
private long collectAndCachePositionFlags()
{
    final long headPosition = head.encodedPosition();

    // Same short-circuit order as before: exhaustion, then source count, then the content flag.
    final boolean headPositionIsFinal = Cursor.isExhausted(headPosition)
                                        || !branchHasMultipleSources()
                                        || (headPosition & Cursor.MAY_HAVE_CONTENT_BIT) == Cursor.MAY_HAVE_CONTENT_BIT;

    // Seed the cache with the head's value; the collector (if run) only adds bits on top of it.
    currentPosition = headPosition;

    if (!headPositionIsFinal)
    {
        // The collector ORs whole encoded positions; this relies on the selected cursors sharing the head's
        // position bits, so only flag bits can be added.
        applyToSelectedElementsInHeap(FLAG_COLLECTOR, 0);
    }

    return currentPosition;
}

/// [HeapOp] that folds the encoded positions of selected heap cursors into the owning merge cursor's
/// `currentPosition`. Descent into a subheap stops as soon as the child no longer equals the head or once
/// `MAY_HAVE_CONTENT_BIT` has already been gathered.
private static class FlagCollector<T, C extends Cursor<T>> implements HeapOp<T, C>
{
    /// ORs the given cursor's encoded position into `self.currentPosition`.
    @Override
    public void apply(CollectionMergeCursor<T, C> self, C cursor, int index)
    {
        self.currentPosition = self.currentPosition | cursor.encodedPosition();
    }

    /// Returns true only while the child is equal to the head and the content flag is still missing from
    /// `self.currentPosition`.
    @Override
    public boolean shouldContinueWithChild(CollectionMergeCursor<T, C> self, C child, C head)
    {
        if (!equalCursor(child, head))
            return false;

        // Nothing more to learn once the content flag has been seen.
        return (self.currentPosition & Cursor.MAY_HAVE_CONTENT_BIT) == 0;
    }
}

/// Shared collector instance. [FlagCollector] declares no fields of its own, so a single raw-typed instance is
/// reused for every `T`/`C` instantiation.
@SuppressWarnings("rawtypes")
private static final HeapOp FLAG_COLLECTOR = new FlagCollector();

/// Push the given state down in the heap from the given index until it finds its proper place among
/// the subheap rooted at that position.
private void heapifyDown(C item, int index)
Expand Down Expand Up @@ -241,14 +290,21 @@ private void heapifyDown(C item, int index)
private long maybeSwapHead(long headPosition)
{
long heap0Position = heap[0].encodedPosition();
if (Cursor.compare(headPosition, heap0Position) <= 0)
long cmp = Cursor.compare(headPosition, heap0Position);
if (cmp < 0)
{
currentPosition = headPosition;
return headPosition; // head is still smallest
}

// otherwise we need to swap heap and heap[0]
C newHeap0 = head;
head = heap[0];
heapifyDown(newHeap0, 0);
return heap0Position;
if (cmp > 0)
{
// otherwise we need to swap heap and heap[0]
C newHeap0 = head;
head = heap[0];
heapifyDown(newHeap0, 0);
}
return collectAndCachePositionFlags();
}

boolean branchHasMultipleSources()
Expand Down Expand Up @@ -297,7 +353,7 @@ public long skipTo(long encodedSkipPosition)
class SkipTo implements AdvancingHeapOp<T, C>
{
@Override
public boolean shouldContinueWithChild(C child, C head)
public boolean shouldContinueWithChild(CollectionMergeCursor<T, C> self, C child, C head)
{
// When the requested position descends, the implicit prefix bytes are those of the head cursor,
// and thus we need to check against that if it is a match.
Expand All @@ -324,7 +380,7 @@ public void apply(C cursor)
@Override
public long encodedPosition()
{
return head.encodedPosition();
return currentPosition;
}

@Override
Expand Down
65 changes: 53 additions & 12 deletions src/java/org/apache/cassandra/db/tries/Cursor.java
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,12 @@ interface Cursor<T>
/// Used for sets and ranges to correctly define the range states for branch-inclusive ranges.
long ON_RETURN_PATH_BIT = 1L << 19;

/// Mask for the bits used for flags that should not affect position comparison.
long FLAGS_MASK = 0xFFFFL;

/// Flag indicating whether this position may have content.
long MAY_HAVE_CONTENT_BIT = 1L;

/// Mask of the transition bits including the direction. We apply xor with this value to form a position in the
/// reverse direction.
long TRANSITION_MASK = 0x8FFL << TRANSITION_SHIFT;
Expand Down Expand Up @@ -198,9 +204,37 @@ static boolean isOnReturnPath(long encodedPosition)
static long compare(long encoded1, long encoded2)
{
// This can support depth of 2^31 - 1 without overflowing.
return encoded1 - encoded2;
// Normalise the flag bits to the same value in both operands before subtracting.
return (encoded1 | FLAGS_MASK) - (encoded2 | FLAGS_MASK);
}


/// Returns `pos1` with the bits selected by `flags` replaced by the union of those bits in `pos1` and `pos2`.
/// All bits of `pos1` outside `flags` are returned unchanged, and bits of `pos2` outside `flags` are ignored.
///
/// Equivalent to `(pos1 | pos2) & flags | (pos1 & ~flags)`.
static long unionFlags(long pos1, long pos2, long flags)
{
    final long contributedByPos2 = pos2 & flags;
    return pos1 | contributedByPos2;
}

/// Merges two encoded positions that `compare` treats as equal by OR-ing them together.
/// When assertions are enabled, a mismatch between the two positions (a non-zero `compare` result) fails
/// with a message showing both values in hex.
static long unionFlagsMatchingPositions(long pos1, long pos2)
{
    assert compare(pos1, pos2) == 0
        : String.format("Position mismatch in unionFlagsMatchingPositions: compare(%016x, %016x) = %d",
                        pos1,
                        pos2,
                        compare(pos1, pos2));

    final long merged = pos1 | pos2;
    return merged;
}

/// Returns `pos1` with the bits selected by `flags` replaced by the intersection of those bits in `pos1` and
/// `pos2`. All bits of `pos1` outside `flags` are returned unchanged, and bits of `pos2` outside `flags` are
/// ignored.
///
/// Equivalent to `pos1 & pos2 & flags | (pos1 & ~flags)`.
static long intersectionFlags(long pos1, long pos2, long flags)
{
    // A bit of pos1 survives if it lies outside the flag mask or is also set in pos2.
    final long keepMask = pos2 | ~flags;
    return pos1 & keepMask;
}


static long rootPosition(Direction direction)
{
return direction.select(ROOT_POSITION_FORWARD, ROOT_POSITION_REVERSE);
Expand All @@ -225,7 +259,7 @@ static long exhaustedPosition(long prevPosition)

static boolean isRootPosition(long encodedPosition)
{
return encodedPosition == ROOT_POSITION_FORWARD || encodedPosition == ROOT_POSITION_REVERSE;
return compare(encodedPosition, ROOT_POSITION_FORWARD) == 0 || compare(encodedPosition, ROOT_POSITION_REVERSE) == 0;
}

static long encode(int depth, int transition, Direction direction)
Expand All @@ -252,7 +286,7 @@ static long positionForDescentWithByte(long encodedPosition, int incomingByte)
/// returned encoded position is a valid `skipTo` position for the current state.
static long positionForSkippingBranch(long encodedBranchPosition)
{
return encodedBranchPosition + (1L << TRANSITION_SHIFT);
return (encodedBranchPosition & ~FLAGS_MASK) + (1L << TRANSITION_SHIFT);
}

/// Returns true if the given `currPosition` as returned by `advance`, `advanceMultiple` or `skipTo` is the result
Expand All @@ -267,10 +301,11 @@ static boolean ascended(long currPosition, long prevPosition)

static String toString(long encodedPosition)
{
return String.format("depth %d incomingTransition %02x%s %s",
return String.format("depth %d incomingTransition %02x%s%s %s",
depth(encodedPosition),
incomingTransition(encodedPosition),
isOnReturnPath(encodedPosition) ? "↑" : " ",
(encodedPosition & MAY_HAVE_CONTENT_BIT) != 0 ? "C" : " ",
direction(encodedPosition));
}

Expand Down Expand Up @@ -361,9 +396,12 @@ default T advanceToContent(ResettingTransitionsReceiver receiver)
if (isOnReturnPath(currPosition))
receiver.onReturnPath();
}
T content = content();
if (content != null)
return content;
if ((currPosition & MAY_HAVE_CONTENT_BIT) != 0)
{
T content = content();
if (content != null)
return content;
}
prevPosition = currPosition;
}
}
Expand Down Expand Up @@ -406,7 +444,8 @@ default boolean descendAlong(ByteSource bytes)
while (next != ByteSource.END_OF_STREAM)
{
long nextPosition = positionForDescentWithByte(position, next);
if (compare(skipTo(nextPosition), nextPosition) != 0)
long arrived = skipTo(nextPosition);
if (compare(arrived, nextPosition) != 0)
return false;
next = bytes.next();
position = nextPosition;
Expand Down Expand Up @@ -463,7 +502,8 @@ interface Walker<T, R> extends Cursor.ResettingTransitionsReceiver
default <R> R process(Cursor.Walker<? super T, R> walker)
{
assertFresh();
T content = content(); // handle content on the root node
long currentPosition = encodedPosition();
T content = (currentPosition & MAY_HAVE_CONTENT_BIT) != 0 ? content() : null;

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be able to put this in a static Cursor.content(cursor, cursorPosition) method without changing the performance.

We can also add a line in the content() javadoc to say it's preferable to get this via the static. If you prefer, call it Cursor.checkFlagAndGetContent or something similar.

if (content == null)
content = advanceToContent(walker);

Expand All @@ -487,7 +527,8 @@ default <R> R process(Cursor.Walker<? super T, R> walker)
default <R> R processSkippingBranches(Cursor.Walker<? super T, R> walker)
{
assertFresh();
T content = content(); // handle content on the root node
long currentPosition = encodedPosition();

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about making assertFresh return the position (which it has to fetch anyway) and rename it to something like getPositionAndAssertFresh?

T content = (currentPosition & MAY_HAVE_CONTENT_BIT) != 0 ? content() : null;
if (content != null)
{
walker.content(content);
Expand All @@ -504,7 +545,7 @@ default <R> R processSkippingBranches(Cursor.Walker<? super T, R> walker)
break;
walker.resetPathLength(depth(current) - 1);
walker.addPathByte(incomingTransition(current));
content = content();
content = (current & MAY_HAVE_CONTENT_BIT) != 0 ? content() : null;
if (content == null)
content = advanceToContent(walker);
}
Expand Down Expand Up @@ -541,7 +582,7 @@ public ByteComparable.Version byteComparableVersion()
@Override
public Cursor<T> tailCursor(Direction direction)
{
assert position == Cursor.rootPosition(direction) : "tailTrie called on exhausted cursor";
assert compare(position, Cursor.rootPosition(direction)) == 0 : "tailTrie called on exhausted cursor";
return new Empty<>(direction, byteComparableVersion);
}

Expand Down
14 changes: 10 additions & 4 deletions src/java/org/apache/cassandra/db/tries/DeletionAwareCursor.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,21 @@ default <R> R process(DeletionAwareWalker<? super T, ? super D, R> walker)

while (true)
{
T content = content(); // handle content on the root node
if (content != null)
walker.content(content);
// Always check for deletion branches as they are independent of content
RangeCursor<D> deletionBranch = deletionBranchCursor(direction());
if (deletionBranch != null && walker.enterDeletionsBranch())
{
processDeletionBranch(walker, deletionBranch);
walker.exitDeletionsBranch();
}

// MAY_HAVE_CONTENT_BIT optimization: only call content() if flag indicates potential content
if ((currentPosition & MAY_HAVE_CONTENT_BIT) != 0)

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's a good idea for the deletion branch to be presented before content that covers it. Dumps, for one, look weird with this order.

{
T content = content();
if (content != null)
walker.content(content);
}

long prevPosition = currentPosition;
currentPosition = advanceMultiple(walker);
Expand All @@ -113,7 +119,7 @@ default <R> R process(DeletionAwareWalker<? super T, ? super D, R> walker)
private static <D> void processDeletionBranch(DeletionAwareWalker<?, ? super D, ?> walker, Cursor<D> cursor)
{
cursor.assertFresh();
D content = cursor.content(); // handle content on the root node
D content = (cursor.encodedPosition() & MAY_HAVE_CONTENT_BIT) != 0 ? cursor.content() : null;
if (content == null)
content = cursor.advanceToContent(walker);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ private long skipDeletionsToDataPosition(long dataPosition)
if (Cursor.isExhausted(deletionsPositionUncorrected))
return leaveDeletionsBranch(dataPosition);
else
return setAtDeletionsAndReturnPosition(deletionsPositionUncorrected == deletionsSkipPosition,
return setAtDeletionsAndReturnPosition(Cursor.compare(deletionsPositionUncorrected, deletionsSkipPosition) == 0,
dataPosition);
}

Expand Down
Loading