Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,8 @@ default void configureFilterDelegation(FilterDelegationHandle handle, BackendExe
}

/**
* Configure task-level resource tracking for delegation callbacks executing on foreign threads.
* Called after {@link #configureFilterDelegation}. Backends should wrap their callback dispatch
* with start/finish tracking calls for the given task.
* Install a thread tracker for attribution of delegation callbacks executing on foreign threads.
* Called after {@link #configureFilterDelegation}. Pass {@code null} to clear.
*/
default void configureTaskTracking(org.opensearch.tasks.TaskResourceTrackingService trackingService, long taskId) {}

/**
* Clear task tracking state after fragment execution completes.
*/
default void clearTaskTracking() {}
default void setDelegationThreadTracker(DelegationThreadTracker tracker) {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.analytics.spi;

/**
 * Callback pair used to attribute resource usage to delegation callbacks that run
 * on foreign threads (e.g., DataFusion/Tokio workers invoking Lucene via FFM).
 *
 * <p>A unit of delegation work is bracketed by {@link #trackStart()} and
 * {@link #trackEnd(long)}; the value returned by the former is handed to the latter.
 *
 * @opensearch.internal
 */
public interface DelegationThreadTracker {

    /**
     * Marks the beginning of delegation work on the calling thread.
     *
     * @return the thread id that must later be supplied to {@link #trackEnd},
     *         or {@code -1} when tracking is inactive
     */
    long trackStart();

    /**
     * Marks the end of delegation work on the given thread.
     *
     * @param threadId the value previously obtained from {@link #trackStart}
     */
    void trackEnd(long threadId);
}
Original file line number Diff line number Diff line change
Expand Up @@ -569,12 +569,7 @@ public void configureFilterDelegation(FilterDelegationHandle handle, BackendExec
}

@Override
public void configureTaskTracking(org.opensearch.tasks.TaskResourceTrackingService trackingService, long taskId) {
FilterTreeCallbacks.setTaskTracking(trackingService, taskId);
}

@Override
public void clearTaskTracking() {
FilterTreeCallbacks.setTaskTracking(null, -1);
public void setDelegationThreadTracker(org.opensearch.analytics.spi.DelegationThreadTracker tracker) {
FilterTreeCallbacks.setThreadTracker(tracker);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.analytics.spi.DelegationThreadTracker;
import org.opensearch.analytics.spi.FilterDelegationHandle;
import org.opensearch.tasks.TaskResourceTrackingService;

import java.lang.foreign.MemorySegment;
import java.util.concurrent.atomic.AtomicReference;
Expand All @@ -36,8 +36,7 @@ public final class FilterTreeCallbacks {
private static final Logger LOGGER = LogManager.getLogger(FilterTreeCallbacks.class);

private static final AtomicReference<FilterDelegationHandle> HANDLE = new AtomicReference<>();
private static final AtomicReference<TaskResourceTrackingService> TRACKING_SERVICE = new AtomicReference<>();
private static long currentTaskId = -1;
private static final AtomicReference<DelegationThreadTracker> TRACKER = new AtomicReference<>();

private FilterTreeCallbacks() {}

Expand All @@ -51,25 +50,21 @@ public static void setHandle(FilterDelegationHandle handle) {
}

/**
* Configure task resource tracking. All subsequent callbacks will attribute
* CPU/heap to the given task until cleared.
* Install or clear the thread tracker for resource attribution.
*/
public static void setTaskTracking(TaskResourceTrackingService trackingService, long taskId) {
TRACKING_SERVICE.set(trackingService);
currentTaskId = taskId;
public static void setThreadTracker(DelegationThreadTracker tracker) {
TRACKER.set(tracker);
}

private static long trackStart() {
TaskResourceTrackingService tracker = TRACKING_SERVICE.get();
if (tracker == null || currentTaskId < 0) return -1;
long threadId = Thread.currentThread().threadId();
tracker.taskExecutionStartedOnThread(currentTaskId, threadId);
return threadId;
DelegationThreadTracker t = TRACKER.get();
return (t != null) ? t.trackStart() : -1;
}

private static void trackEnd(long threadId) {
if (threadId < 0) return;
TRACKING_SERVICE.get().taskExecutionFinishedOnThread(currentTaskId, threadId);
DelegationThreadTracker t = TRACKER.get();
if (t != null) t.trackEnd(threadId);
}

// ── Provider lifecycle (cold path, once per query) ────────────────
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package org.opensearch.be.datafusion.indexfilter;

import org.opensearch.analytics.exec.task.AnalyticsShardTask;
import org.opensearch.analytics.spi.DelegationThreadTracker;
import org.opensearch.analytics.spi.FilterDelegationHandle;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
Expand Down Expand Up @@ -50,27 +51,27 @@ public void setUp() throws Exception {
);
trackingService.setTaskResourceTrackingEnabled(true);
FilterTreeCallbacks.setHandle(null);
FilterTreeCallbacks.setTaskTracking(null, -1);
FilterTreeCallbacks.setThreadTracker(null);
}

@Override
public void tearDown() throws Exception {
FilterTreeCallbacks.setTaskTracking(null, -1);
FilterTreeCallbacks.setThreadTracker(null);
FilterTreeCallbacks.setHandle(null);
terminate(threadPool);
super.tearDown();
}

/**
* Tests the full production wiring: configureTaskTracking via SPI, then
* Tests the full production wiring: setDelegationThreadTracker via SPI, then
* all three callback methods (createProvider, createCollector, collectDocs)
* on a foreign thread. Verifies the thread is tracked against the task.
*/
public void testAllCallbackMethodsTrackedOnForeignThread() throws Exception {
AnalyticsShardTask task = createAndTrackTask(1);

var backendPlugin = new org.opensearch.be.datafusion.DataFusionAnalyticsBackendPlugin(null);
backendPlugin.configureTaskTracking(trackingService, task.getId());
backendPlugin.setDelegationThreadTracker(createTracker(task.getId()));
FilterTreeCallbacks.setHandle(new MockHandle(new long[] { 0xCAFEL }));

CountDownLatch done = new CountDownLatch(1);
Expand All @@ -88,25 +89,25 @@ public void testAllCallbackMethodsTrackedOnForeignThread() throws Exception {
foreignThread.start();
assertTrue(done.await(5, TimeUnit.SECONDS));

backendPlugin.clearTaskTracking();
backendPlugin.setDelegationThreadTracker(null);
trackingService.stopTracking(task);

Map<Long, List<ThreadResourceInfo>> stats = task.getResourceStats();
assertTrue("Foreign thread should be tracked. Got threads: " + stats.keySet(), stats.containsKey(foreignThread.threadId()));
}

/**
* Tests that clearTaskTracking stops attribution. After clearing,
* Tests that clearing the thread tracker stops attribution. After clearing,
* callbacks on a new thread should NOT be attributed to the old task.
*/
public void testClearTaskTrackingStopsAttribution() throws Exception {
AnalyticsShardTask task = createAndTrackTask(2);

FilterTreeCallbacks.setTaskTracking(trackingService, task.getId());
FilterTreeCallbacks.setThreadTracker(createTracker(task.getId()));
FilterTreeCallbacks.setHandle(new MockHandle(new long[] { 1L }));

// Clear tracking BEFORE running callbacks
FilterTreeCallbacks.setTaskTracking(null, -1);
FilterTreeCallbacks.setThreadTracker(null);

CountDownLatch done = new CountDownLatch(1);
Thread foreignThread = new Thread(() -> {
Expand All @@ -126,7 +127,7 @@ public void testClearTaskTrackingStopsAttribution() throws Exception {
trackingService.stopTracking(task);

Map<Long, List<ThreadResourceInfo>> stats = task.getResourceStats();
assertFalse("Thread after clearTaskTracking should NOT be tracked", stats.containsKey(foreignThread.threadId()));
assertFalse("Thread after clearing tracker should NOT be tracked", stats.containsKey(foreignThread.threadId()));
}

/**
Expand All @@ -135,7 +136,7 @@ public void testClearTaskTrackingStopsAttribution() throws Exception {
public void testConcurrentThreadsAllTracked() throws Exception {
AnalyticsShardTask task = createAndTrackTask(3);

FilterTreeCallbacks.setTaskTracking(trackingService, task.getId());
FilterTreeCallbacks.setThreadTracker(createTracker(task.getId()));
FilterTreeCallbacks.setHandle(new MockHandle(new long[] { 0xFFL }));

int threadCount = 4;
Expand Down Expand Up @@ -165,7 +166,7 @@ public void testConcurrentThreadsAllTracked() throws Exception {
}
assertTrue(done.await(10, TimeUnit.SECONDS));

FilterTreeCallbacks.setTaskTracking(null, -1);
FilterTreeCallbacks.setThreadTracker(null);
trackingService.stopTracking(task);

Map<Long, List<ThreadResourceInfo>> stats = task.getResourceStats();
Expand All @@ -174,6 +175,23 @@ public void testConcurrentThreadsAllTracked() throws Exception {
}
}

/**
 * Builds a tracker that reports the start and finish of work on the calling thread
 * to the test's tracking service, attributed to the given task id.
 *
 * @param taskId id of the task that callbacks are attributed to
 * @return a tracker bound to {@code taskId}
 */
private DelegationThreadTracker createTracker(long taskId) {
    // Snapshot the field so the tracker keeps using the service instance current at creation time.
    final TaskResourceTrackingService resourceTracking = trackingService;
    return new DelegationThreadTracker() {
        @Override
        public long trackStart() {
            final long callerThreadId = Thread.currentThread().threadId();
            resourceTracking.taskExecutionStartedOnThread(taskId, callerThreadId);
            return callerThreadId;
        }

        @Override
        public void trackEnd(long threadId) {
            resourceTracking.taskExecutionFinishedOnThread(taskId, threadId);
        }
    };
}

private AnalyticsShardTask createAndTrackTask(long id) {
AnalyticsShardTask task = new AnalyticsShardTask(
id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.analytics.spi.AnalyticsSearchBackendPlugin;
import org.opensearch.analytics.spi.BackendExecutionContext;
import org.opensearch.analytics.spi.DelegationDescriptor;
import org.opensearch.analytics.spi.DelegationThreadTracker;
import org.opensearch.analytics.spi.FilterDelegationHandle;
import org.opensearch.analytics.spi.FragmentInstructionHandler;
import org.opensearch.analytics.spi.FragmentInstructionHandlerFactory;
Expand Down Expand Up @@ -99,8 +100,6 @@ public FragmentResources executeFragmentStreaming(FragmentExecutionRequest reque
} catch (Exception e) {
listener.onFragmentFailure(resolved.queryId, resolved.stageId, resolved.shardIdStr, e);
throw new RuntimeException("Failed to start streaming fragment on " + shard.shardId(), e);
} finally {
backends.get(resolved.plan.getBackendId()).clearTaskTracking();
}
}

Expand All @@ -110,6 +109,7 @@ private FragmentResources startFragment(FragmentExecutionRequest request, Resolv
SearchExecEngine<ShardScanExecutionContext, EngineResultStream> engine = null;
EngineResultStream stream = null;
BackendExecutionContext backendContext = null;
Runnable trackerCleanup = null;
try {
ShardScanExecutionContext ctx = buildContext(request, gatedReader.get(), resolved.plan, shard, task);
AnalyticsSearchBackendPlugin backend = backends.get(resolved.plan.getBackendId());
Expand All @@ -135,16 +135,31 @@ private FragmentResources startFragment(FragmentExecutionRequest request, Resolv
backend.configureFilterDelegation(handle, backendContext);

if (task != null && taskResourceTrackingService != null) {
backend.configureTaskTracking(taskResourceTrackingService, task.getId());
long taskId = task.getId();
TaskResourceTrackingService service = taskResourceTrackingService;
backend.setDelegationThreadTracker(new DelegationThreadTracker() {
@Override
public long trackStart() {
long threadId = Thread.currentThread().threadId();
service.taskExecutionStartedOnThread(taskId, threadId);
return threadId;
}

@Override
public void trackEnd(long threadId) {
service.taskExecutionFinishedOnThread(taskId, threadId);
}
});
trackerCleanup = () -> backend.setDelegationThreadTracker(null);
}
}

engine = backend.getSearchExecEngineProvider().createSearchExecEngine(ctx, backendContext);
stream = engine.execute(ctx);
return new FragmentResources(gatedReader, engine, stream);
return new FragmentResources(gatedReader, engine, stream, trackerCleanup);
} catch (Exception e) {
try {
new FragmentResources(gatedReader, engine, stream).close();
new FragmentResources(gatedReader, engine, stream, trackerCleanup).close();
} catch (Exception suppressed) {
e.addSuppressed(suppressed);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,26 @@ public final class FragmentResources implements AutoCloseable {
private final GatedCloseable<Reader> gatedReader;
private final SearchExecEngine<ShardScanExecutionContext, EngineResultStream> engine;
private final EngineResultStream stream;
private final Runnable onClose;

/**
 * Creates a resource bundle without a close hook; delegates to the four-argument
 * constructor with {@code null} as the hook.
 *
 * @param gatedReader the gated reader held by this bundle
 * @param engine the execution engine held by this bundle
 * @param stream the result stream held by this bundle
 */
public FragmentResources(
    GatedCloseable<Reader> gatedReader,
    SearchExecEngine<ShardScanExecutionContext, EngineResultStream> engine,
    EngineResultStream stream
) {
    this(gatedReader, engine, stream, null);
}

/**
 * Creates a resource bundle with an optional hook that {@code close()} invokes
 * when it is non-null.
 *
 * @param gatedReader the gated reader held by this bundle
 * @param engine the execution engine held by this bundle
 * @param stream the result stream held by this bundle
 * @param onClose hook run from {@code close()}; may be {@code null} for no hook
 */
public FragmentResources(
    GatedCloseable<Reader> gatedReader,
    SearchExecEngine<ShardScanExecutionContext, EngineResultStream> engine,
    EngineResultStream stream,
    Runnable onClose
) {
    this.onClose = onClose;
    this.stream = stream;
    this.engine = engine;
    this.gatedReader = gatedReader;
}

public EngineResultStream stream() {
Expand All @@ -42,8 +53,15 @@ public EngineResultStream stream() {

@Override
public void close() throws Exception {
Exception first;
first = closeQuietly(stream, null);
Exception first = null;
if (onClose != null) {
try {
onClose.run();
} catch (Exception e) {
first = e;
}
}
first = closeQuietly(stream, first);
first = closeQuietly(engine, first);
first = closeQuietly(gatedReader, first);
if (first != null) throw first;
Expand Down
Loading