-
Notifications
You must be signed in to change notification settings - Fork 3
Abort in-flight requests when the client disconnects #136
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
merceod
wants to merge
5
commits into
main
Choose a base branch
from
fix/cancelled-request-gpu-leak
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
6e66ea4
Abort in-flight requests when the client disconnects
merceod 6516808
fix race condition (#137)
NSagan271 9a636c7
Merge main into fix/cancelled-request-gpu-leak
merceod 8b884e9
Ack abandoned result tensors so the producing worker can reclaim them
merceod 4cdf0c8
Avoid infinite loop replaying buffered messages for a removed request
merceod File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,7 +15,7 @@ | |
| from typing import Optional | ||
|
|
||
| import uvicorn | ||
| from fastapi import FastAPI, File, Form, HTTPException, UploadFile | ||
| from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile | ||
| from fastapi.middleware.cors import CORSMiddleware | ||
| from fastapi.responses import JSONResponse, StreamingResponse | ||
| from starlette.concurrency import run_in_threadpool | ||
|
|
@@ -325,10 +325,14 @@ def _process_messages(self) -> None: | |
| message.body.final_outputs | ||
| elif rid in self.recently_completed: | ||
| logger.debug("Late message for completed %s: %s", rid, message.message_type) | ||
| if message.message_type == "result_tensors": | ||
| self.preprocess_worker.discard_result_tensors(message.body) | ||
| else: | ||
| logger.warning( | ||
| "Message for unknown request %s: %s", rid, message.message_type | ||
| ) | ||
| if message.message_type == "result_tensors": | ||
| self.preprocess_worker.discard_result_tensors(message.body) | ||
| for result_chunk in self.preprocess_worker.get_result_chunks(): | ||
| logger.debug( | ||
| "Got result chunk of %s modality for request %s", | ||
|
|
@@ -360,42 +364,46 @@ async def iter_result_chunks(self, request_id: str): | |
| pre-serialized line). | ||
| """ | ||
| start = time.time() | ||
| while True: | ||
| if time.time() - start > self.timeout_seconds: | ||
| with self.request_lock: | ||
| self.pending_requests.pop(request_id, None) | ||
| raise HTTPException(status_code=500, detail="Request timed out") | ||
|
|
||
| new_chunks: list[ResultChunk] = [] | ||
| done = False | ||
| with self.request_lock: | ||
| req = self.pending_requests.get(request_id) | ||
| if req: | ||
| avail = len(req["chunks"]) | ||
| consumed = req["consumed_chunks"] | ||
| new_chunks = req["chunks"][consumed:avail] | ||
| req["consumed_chunks"] = avail | ||
| done = req["event"].is_set() | ||
| else: | ||
| done = True | ||
|
|
||
| for chunk in new_chunks: | ||
| yield chunk | ||
|
|
||
| if done: | ||
| logger.info("Async stream results received finish for %s", request_id) | ||
| # flush remaining | ||
| remaining: list[ResultChunk] = [] | ||
| finished = False | ||
| try: | ||
| while True: | ||
| if time.time() - start > self.timeout_seconds: | ||
| raise HTTPException(status_code=500, detail="Request timed out") | ||
|
|
||
| new_chunks: list[ResultChunk] = [] | ||
| done = False | ||
| with self.request_lock: | ||
| req = self.pending_requests.get(request_id) | ||
| if req: | ||
| remaining = req["chunks"][req["consumed_chunks"]:] | ||
| self.pending_requests.pop(request_id, None) | ||
| for chunk in remaining: | ||
| avail = len(req["chunks"]) | ||
| consumed = req["consumed_chunks"] | ||
| new_chunks = req["chunks"][consumed:avail] | ||
| req["consumed_chunks"] = avail | ||
| done = req["event"].is_set() | ||
| else: | ||
| done = True | ||
|
|
||
| for chunk in new_chunks: | ||
| yield chunk | ||
| break | ||
|
|
||
| await asyncio.sleep(0.001) | ||
| if done: | ||
| logger.info("Async stream results received finish for %s", request_id) | ||
| # flush remaining | ||
| remaining: list[ResultChunk] = [] | ||
| with self.request_lock: | ||
| req = self.pending_requests.get(request_id) | ||
| if req: | ||
| remaining = req["chunks"][req["consumed_chunks"]:] | ||
| self.pending_requests.pop(request_id, None) | ||
| for chunk in remaining: | ||
| yield chunk | ||
| finished = True | ||
| break | ||
|
|
||
| await asyncio.sleep(0.001) | ||
| finally: | ||
| if not finished: | ||
| self.abort_request(request_id) | ||
|
|
||
| async def async_stream_results(self, request_id: str): | ||
| """Yield NDJSON lines as result chunks arrive (``/generate`` format).""" | ||
|
|
@@ -411,28 +419,47 @@ def _chunk_to_ndjson(chunk: ResultChunk) -> str: | |
| }) + "\n" | ||
|
|
||
| # ---------------------------------------------------------- | ||
| # Blocking helper (non-streaming) | ||
| # Non-streaming helper | ||
| # ---------------------------------------------------------- | ||
|
|
||
| def collect_results(self, request_id: str) -> list[ResultChunk]: | ||
| """Block until the request completes, then return all chunks.""" | ||
| with self.request_lock: | ||
| req = self.pending_requests.get(request_id) | ||
| if not req: | ||
| raise HTTPException( | ||
| status_code=404, detail=f"Request {request_id} not found" | ||
| ) | ||
| event = req["event"] | ||
|
|
||
| if not event.wait(timeout=self.timeout_seconds): | ||
| async def collect_results( | ||
| self, request_id: str, raw_request: Request | None = None | ||
| ) -> list[ResultChunk]: | ||
| """Wait for the request to finish (or the client to disconnect), then | ||
| return its chunks. Disconnecting or timing out releases engine state.""" | ||
| start = time.time() | ||
| while True: | ||
| with self.request_lock: | ||
| self.pending_requests.pop(request_id, None) | ||
| raise HTTPException(status_code=500, detail="Request timed out") | ||
| req = self.pending_requests.get(request_id) | ||
| done = req["event"].is_set() if req else True | ||
| if done: | ||
| break | ||
| if time.time() - start > self.timeout_seconds: | ||
| self.abort_request(request_id) | ||
| raise HTTPException(status_code=500, detail="Request timed out") | ||
| if raw_request is not None and await raw_request.is_disconnected(): | ||
| self.abort_request(request_id) | ||
| return [] | ||
| await asyncio.sleep(0.005) | ||
|
|
||
| with self.request_lock: | ||
| chunks = self.pending_requests[request_id]["chunks"][:] | ||
| req = self.pending_requests.pop(request_id, None) | ||
| return list(req["chunks"]) if req else [] | ||
|
|
||
| def abort_request(self, request_id: str) -> None: | ||
| """Stop GPU work for a request the client abandoned and drop its state.""" | ||
| with self.request_lock: | ||
| active = ( | ||
| request_id in self.pending_requests | ||
| or request_id in self.recently_completed | ||
| ) | ||
| self.pending_requests.pop(request_id, None) | ||
| return chunks | ||
| self.recently_completed.pop(request_id, None) | ||
| if not active: | ||
| return | ||
| logger.info("Client cancelled request %s; releasing resources", request_id) | ||
| self.preprocess_worker.abort_request(request_id) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: is it cleaner to have preprocess_worker.abort_request also do the cleanup path, instead of the caller having to remember to call both? |
||
| self.preprocess_worker.cleanup_request(request_id) | ||
|
|
||
| # ---------------------------------------------------------- | ||
| # Cleanup | ||
|
|
@@ -472,6 +499,7 @@ def cleanup(self) -> None: | |
|
|
||
| @app.post("/generate") | ||
| async def generate( | ||
| request: Request, | ||
| text: Optional[str] = Form(None), | ||
| files: Optional[list[UploadFile]] = File(None), | ||
| input_modalities: Optional[str] = Form(None), | ||
|
|
@@ -547,9 +575,7 @@ async def generate( | |
| headers={"Cache-Control": "no-cache"}, | ||
| ) | ||
|
|
||
| chunks = await run_in_threadpool( | ||
| api_server.collect_results, request_id | ||
| ) | ||
| chunks = await api_server.collect_results(request_id, request) | ||
| outputs: dict[str, list[dict]] = {} | ||
| for chunk in chunks: | ||
| outputs.setdefault(chunk.modality, []).append({ | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably for another PR (I can raise an issue), but I think it could be worthwhile to be able to raise/lower timeout_seconds per request? e.g., to fail fast if we know the request should be short, or to bump it up for a longer request. Thoughts?