Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 120 additions & 1 deletion backend/app/routes/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import asyncio
import concurrent.futures
from datetime import datetime, timezone
from typing import Optional
from typing import List, Optional
from pathlib import Path
import shutil
import socket
Expand Down Expand Up @@ -38,6 +38,7 @@
DocumentUpdate,
ChunkSettings,
UploadUrl,
BatchUploadResponse,
)
from app.auth import get_current_user
from app.config import get_settings
Expand Down Expand Up @@ -302,6 +303,124 @@ async def upload_document(

return DocumentResponse.model_validate(document).model_copy(update={"task_id": task_id})

@router.post("/upload/batch", response_model=BatchUploadResponse, status_code=status.HTTP_202_ACCEPTED)
async def batch_upload_documents(
files: List[UploadFile] = File(...),
chunk_size: int = Form(1000),
chunk_overlap: int = Form(200),
background_tasks: BackgroundTasks = None,
user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""Accept multiple files and enqueue parallel ingestion tasks.

Each file is validated and saved independently. Successfully saved files
are committed to the database and dispatched to Celery (or the in-process
fallback). Files that fail validation are recorded in the ``failed`` list
and do not block the remaining uploads.

Args:
files: One or more uploaded files (PDF, DOCX, TXT, MD).
chunk_size: Text chunk size for RAG ingestion (100–2000).
chunk_overlap: Overlap between consecutive chunks (0 < chunk_size).
background_tasks: FastAPI hook for in-process fallback execution.
user: Authenticated user injected by get_current_user.
db: Database session injected by get_db.

Returns:
BatchUploadResponse with the list of created documents, their task
IDs, total accepted count, and any filenames that were rejected.
"""
if not files:
raise ValidationException("No files provided")

if chunk_size < 100 or chunk_size > 2000:
raise ValidationException("Chunk size must be between 100 and 2000")
if chunk_overlap < 0 or chunk_overlap >= chunk_size:
raise ValidationException("Chunk overlap must be non-negative and less than chunk_size")

user_dir = os.path.join(settings.UPLOAD_DIR, user.id)
os.makedirs(user_dir, exist_ok=True)

created_documents: List[DocumentResponse] = []
task_ids: List[str] = []
failed: List[str] = []

for file in files:
filename = file.filename or "unknown"
try:
if not file.filename:
raise ValidationException("No filename provided")

ext = file.filename.rsplit(".", 1)[-1].lower()
if ext not in settings.ALLOWED_EXTENSIONS:
raise ValidationException(
f"File type '.{ext}' not supported. Allowed: {', '.join(settings.ALLOWED_EXTENSIONS)}"
)

temp_path = await validate_upload(file)

stored_filename = f"{uuid.uuid4().hex}.{ext}"
filepath = os.path.join(user_dir, stored_filename)
shutil.move(temp_path, filepath)
file_size = Path(filepath).stat().st_size

document = Document(
user_id=user.id,
filename=stored_filename,
original_name=file.filename,
file_size=file_size,
status="pending",
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
)
db.add(document)
db.commit()
db.refresh(document)

task_id = None
try:
task = process_document.delay(
document_id=document.id,
filepath=filepath,
original_name=file.filename,
user_id=user.id,
)
task_id = task.id
except Exception as e:
logger.warning(f"Celery queue failed for {file.filename}, falling back to background task: {e}")
if background_tasks:
background_tasks.add_task(
ingest_document,
document_id=document.id,
filepath=filepath,
original_name=file.filename,
user_id=user.id,
)
task_id = f"local_{uuid.uuid4().hex}"

created_documents.append(
DocumentResponse.model_validate(document).model_copy(update={"task_id": task_id})
)
task_ids.append(task_id)

except (ValidationException, ExternalServiceException) as e:
logger.warning(f"Batch upload: skipping '{filename}' β€” {e}")
failed.append(filename)
except Exception as e:
logger.error(f"Batch upload: unexpected error for '{filename}' β€” {e}")
failed.append(filename)

if not created_documents and failed:
raise ValidationException(f"All files failed validation: {', '.join(failed)}")

return BatchUploadResponse(
documents=created_documents,
task_ids=task_ids,
total=len(created_documents),
failed=failed,
)

@router.post("/urlupload", status_code=status.HTTP_202_ACCEPTED)
async def upload_document_url(
payload: UploadUrl,
Expand Down
5 changes: 5 additions & 0 deletions backend/app/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,11 @@ def parse_keywords(cls, v):
class Config:
from_attributes = True

class BatchUploadResponse(BaseModel):
documents: List[DocumentResponse]
task_ids: List[str]
total: int
failed: List[str] = []

class DocumentRename(BaseModel):
name: str = Field(..., min_length=1, max_length=255)
Expand Down
227 changes: 227 additions & 0 deletions backend/tests/test_batch_upload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
"""
Tests for POST /api/v1/documents/upload/batch β€” issue #435.
"""
import io
import uuid
from unittest.mock import MagicMock, patch

import pytest

from app.models import Document


# ── helpers ──────────────────────────────────────────────────────────────────

def _fake_txt_file(name: str = "test.txt", content: bytes = b"hello world") -> tuple:
"""Return a multipart files tuple accepted by httpx TestClient."""
return ("files", (name, io.BytesIO(content), "text/plain"))


def _patch_validate(monkeypatch, tmp_path, content: bytes = b"hello world") -> None:
"""Make validate_upload write content to a real temp file and return its path."""
import tempfile, shutil

async def fake_validate(file):
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", dir=tmp_path)
tmp.write(content)
tmp.close()
return tmp.name

monkeypatch.setattr("app.routes.documents.validate_upload", fake_validate)


def _patch_celery(monkeypatch) -> MagicMock:
"""Stub out Celery so tests never touch Redis."""
mock_task = MagicMock()
mock_task.id = f"celery-{uuid.uuid4().hex}"
monkeypatch.setattr(
"app.routes.documents.process_document",
MagicMock(delay=MagicMock(return_value=mock_task)),
)
return mock_task


# ── tests ─────────────────────────────────────────────────────────────────────

def test_batch_upload_single_file(client, auth_headers, monkeypatch, tmp_path):
_patch_validate(monkeypatch, tmp_path)
_patch_celery(monkeypatch)

response = client.post(
"/api/v1/documents/upload/batch",
headers=auth_headers,
files=[_fake_txt_file("report.txt")],
data={"chunk_size": "1000", "chunk_overlap": "200"},
)

assert response.status_code == 202
payload = response.json()
assert payload["total"] == 1
assert len(payload["documents"]) == 1
assert payload["documents"][0]["original_name"] == "report.txt"
assert payload["documents"][0]["status"] == "pending"
assert len(payload["task_ids"]) == 1
assert payload["failed"] == []


def test_batch_upload_multiple_files(client, auth_headers, monkeypatch, tmp_path):
_patch_validate(monkeypatch, tmp_path)
_patch_celery(monkeypatch)

response = client.post(
"/api/v1/documents/upload/batch",
headers=auth_headers,
files=[
_fake_txt_file("a.txt"),
_fake_txt_file("b.txt"),
_fake_txt_file("c.txt"),
],
data={"chunk_size": "1000", "chunk_overlap": "200"},
)

assert response.status_code == 202
payload = response.json()
assert payload["total"] == 3
assert len(payload["documents"]) == 3
assert len(payload["task_ids"]) == 3
assert payload["failed"] == []


def test_batch_upload_rejects_bad_extension(client, auth_headers, monkeypatch, tmp_path):
"""A .exe file should land in failed[], not crash the whole batch."""
_patch_validate(monkeypatch, tmp_path)
_patch_celery(monkeypatch)

response = client.post(
"/api/v1/documents/upload/batch",
headers=auth_headers,
files=[
_fake_txt_file("good.txt"),
("files", ("bad.exe", io.BytesIO(b"binary"), "application/octet-stream")),
],
data={"chunk_size": "1000", "chunk_overlap": "200"},
)

assert response.status_code == 202
payload = response.json()
assert payload["total"] == 1
assert payload["documents"][0]["original_name"] == "good.txt"
assert "bad.exe" in payload["failed"]


def test_batch_upload_all_files_fail_returns_400(client, auth_headers, monkeypatch, tmp_path):
"""When every file fails, the endpoint should return 400."""
_patch_validate(monkeypatch, tmp_path)
_patch_celery(monkeypatch)

response = client.post(
"/api/v1/documents/upload/batch",
headers=auth_headers,
files=[
("files", ("bad1.exe", io.BytesIO(b"x"), "application/octet-stream")),
("files", ("bad2.exe", io.BytesIO(b"y"), "application/octet-stream")),
],
data={"chunk_size": "1000", "chunk_overlap": "200"},
)

assert response.status_code == 400


def test_batch_upload_requires_auth(client):
response = client.post(
"/api/v1/documents/upload/batch",
files=[_fake_txt_file("test.txt")],
data={"chunk_size": "1000", "chunk_overlap": "200"},
)

assert response.status_code in (401, 403)


def test_batch_upload_invalid_chunk_size(client, auth_headers, monkeypatch, tmp_path):
_patch_validate(monkeypatch, tmp_path)
_patch_celery(monkeypatch)

response = client.post(
"/api/v1/documents/upload/batch",
headers=auth_headers,
files=[_fake_txt_file("test.txt")],
data={"chunk_size": "50", "chunk_overlap": "10"},
)

assert response.status_code == 400


def test_batch_upload_invalid_chunk_overlap(client, auth_headers, monkeypatch, tmp_path):
_patch_validate(monkeypatch, tmp_path)
_patch_celery(monkeypatch)

response = client.post(
"/api/v1/documents/upload/batch",
headers=auth_headers,
files=[_fake_txt_file("test.txt")],
data={"chunk_size": "500", "chunk_overlap": "600"},
)

assert response.status_code == 400


def test_batch_upload_celery_fallback_uses_background_task(client, auth_headers, monkeypatch, tmp_path):
"""When Celery is unavailable, tasks should fall back gracefully."""
_patch_validate(monkeypatch, tmp_path)

# Make Celery raise so the fallback branch is taken
monkeypatch.setattr(
"app.routes.documents.process_document",
MagicMock(delay=MagicMock(side_effect=Exception("Redis down"))),
)

response = client.post(
"/api/v1/documents/upload/batch",
headers=auth_headers,
files=[_fake_txt_file("fallback.txt")],
data={"chunk_size": "1000", "chunk_overlap": "200"},
)

assert response.status_code == 202
payload = response.json()
assert payload["total"] == 1
assert payload["task_ids"][0].startswith("local_")


def test_batch_upload_document_persisted_in_db(client, auth_headers, monkeypatch, tmp_path, db_session):
_patch_validate(monkeypatch, tmp_path)
_patch_celery(monkeypatch)

response = client.post(
"/api/v1/documents/upload/batch",
headers=auth_headers,
files=[_fake_txt_file("persisted.txt")],
data={"chunk_size": "1000", "chunk_overlap": "200"},
)

assert response.status_code == 202
doc_id = response.json()["documents"][0]["id"]
doc = db_session.get(Document, doc_id)
assert doc is not None
assert doc.original_name == "persisted.txt"
assert doc.status == "pending"
assert doc.chunk_size == 1000
assert doc.chunk_overlap == 200


def test_batch_upload_chunk_settings_stored(client, auth_headers, monkeypatch, tmp_path, db_session):
_patch_validate(monkeypatch, tmp_path)
_patch_celery(monkeypatch)

response = client.post(
"/api/v1/documents/upload/batch",
headers=auth_headers,
files=[_fake_txt_file("chunked.txt")],
data={"chunk_size": "800", "chunk_overlap": "100"},
)

assert response.status_code == 202
doc_id = response.json()["documents"][0]["id"]
doc = db_session.get(Document, doc_id)
assert doc.chunk_size == 800
assert doc.chunk_overlap == 100
Loading