diff --git a/backend/app/routes/documents.py b/backend/app/routes/documents.py index 05916f6..de131b3 100644 --- a/backend/app/routes/documents.py +++ b/backend/app/routes/documents.py @@ -9,7 +9,7 @@ import asyncio import concurrent.futures from datetime import datetime, timezone -from typing import Optional +from typing import List, Optional from pathlib import Path import shutil import socket @@ -38,6 +38,7 @@ DocumentUpdate, ChunkSettings, UploadUrl, + BatchUploadResponse, ) from app.auth import get_current_user from app.config import get_settings @@ -302,6 +303,124 @@ async def upload_document( return DocumentResponse.model_validate(document).model_copy(update={"task_id": task_id}) +@router.post("/upload/batch", response_model=BatchUploadResponse, status_code=status.HTTP_202_ACCEPTED) +async def batch_upload_documents( + files: List[UploadFile] = File(...), + chunk_size: int = Form(1000), + chunk_overlap: int = Form(200), + background_tasks: BackgroundTasks = None, + user: User = Depends(get_current_user), + db: Session = Depends(get_db), +): + """Accept multiple files and enqueue parallel ingestion tasks. + + Each file is validated and saved independently. Successfully saved files + are committed to the database and dispatched to Celery (or the in-process + fallback). Files that fail validation are recorded in the ``failed`` list + and do not block the remaining uploads. + + Args: + files: One or more uploaded files (PDF, DOCX, TXT, MD). + chunk_size: Text chunk size for RAG ingestion (100–2000). + chunk_overlap: Overlap between consecutive chunks (0 < chunk_size). + background_tasks: FastAPI hook for in-process fallback execution. + user: Authenticated user injected by get_current_user. + db: Database session injected by get_db. + + Returns: + BatchUploadResponse with the list of created documents, their task + IDs, total accepted count, and any filenames that were rejected. + """ + if not files: + raise ValidationException("No files provided") + + if chunk_size < 100 or chunk_size > 2000: + raise ValidationException("Chunk size must be between 100 and 2000") + if chunk_overlap < 0 or chunk_overlap >= chunk_size: + raise ValidationException("Chunk overlap must be non-negative and less than chunk_size") + + user_dir = os.path.join(settings.UPLOAD_DIR, user.id) + os.makedirs(user_dir, exist_ok=True) + + created_documents: List[DocumentResponse] = [] + task_ids: List[str] = [] + failed: List[str] = [] + + for file in files: + filename = file.filename or "unknown" + try: + if not file.filename: + raise ValidationException("No filename provided") + + ext = file.filename.rsplit(".", 1)[-1].lower() + if ext not in settings.ALLOWED_EXTENSIONS: + raise ValidationException( + f"File type '.{ext}' not supported. Allowed: {', '.join(settings.ALLOWED_EXTENSIONS)}" + ) + + temp_path = await validate_upload(file) + + stored_filename = f"{uuid.uuid4().hex}.{ext}" + filepath = os.path.join(user_dir, stored_filename) + shutil.move(temp_path, filepath) + file_size = Path(filepath).stat().st_size + + document = Document( + user_id=user.id, + filename=stored_filename, + original_name=file.filename, + file_size=file_size, + status="pending", + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + ) + db.add(document) + db.commit() + db.refresh(document) + + task_id = None + try: + task = process_document.delay( + document_id=document.id, + filepath=filepath, + original_name=file.filename, + user_id=user.id, + ) + task_id = task.id + except Exception as e: + logger.warning(f"Celery queue failed for {file.filename}, falling back to background task: {e}") + if background_tasks: + background_tasks.add_task( + ingest_document, + document_id=document.id, + filepath=filepath, + original_name=file.filename, + user_id=user.id, + ) + task_id = f"local_{uuid.uuid4().hex}" + + created_documents.append( + DocumentResponse.model_validate(document).model_copy(update={"task_id": task_id}) + ) + task_ids.append(task_id) + + except (ValidationException, ExternalServiceException) as e: + logger.warning(f"Batch upload: skipping '{filename}' — {e}") + failed.append(filename) + except Exception as e: + logger.error(f"Batch upload: unexpected error for '{filename}' — {e}") + failed.append(filename) + + if not created_documents and failed: + raise ValidationException(f"All files failed validation: {', '.join(failed)}") + + return BatchUploadResponse( + documents=created_documents, + task_ids=task_ids, + total=len(created_documents), + failed=failed, + ) + @router.post("/urlupload", status_code=status.HTTP_202_ACCEPTED) async def upload_document_url( payload: UploadUrl, diff --git a/backend/app/schemas.py b/backend/app/schemas.py index 81bbb77..788e75a 100644 --- a/backend/app/schemas.py +++ b/backend/app/schemas.py @@ -210,6 +210,11 @@ def parse_keywords(cls, v): class Config: from_attributes = True +class BatchUploadResponse(BaseModel): + documents: List[DocumentResponse] + task_ids: List[str] + total: int + failed: List[str] = [] class DocumentRename(BaseModel): name: str = Field(..., min_length=1, max_length=255) diff --git a/backend/tests/test_batch_upload.py b/backend/tests/test_batch_upload.py new file mode 100644 index 0000000..6f1601f --- /dev/null +++ b/backend/tests/test_batch_upload.py @@ -0,0 +1,227 @@ +""" +Tests for POST /api/v1/documents/upload/batch — issue #435. +""" +import io +import uuid +from unittest.mock import MagicMock, patch + +import pytest + +from app.models import Document + + +# ── helpers ────────────────────────────────────────────────────────────────── + +def _fake_txt_file(name: str = "test.txt", content: bytes = b"hello world") -> tuple: + """Return a multipart files tuple accepted by httpx TestClient.""" + return ("files", (name, io.BytesIO(content), "text/plain")) + + +def _patch_validate(monkeypatch, tmp_path, content: bytes = b"hello world") -> None: + """Make validate_upload write content to a real temp file and return its path.""" + import tempfile, shutil + + async def fake_validate(file): + tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", dir=tmp_path) + tmp.write(content) + tmp.close() + return tmp.name + + monkeypatch.setattr("app.routes.documents.validate_upload", fake_validate) + + +def _patch_celery(monkeypatch) -> MagicMock: + """Stub out Celery so tests never touch Redis.""" + mock_task = MagicMock() + mock_task.id = f"celery-{uuid.uuid4().hex}" + monkeypatch.setattr( + "app.routes.documents.process_document", + MagicMock(delay=MagicMock(return_value=mock_task)), + ) + return mock_task + + +# ── tests ───────────────────────────────────────────────────────────────────── + +def test_batch_upload_single_file(client, auth_headers, monkeypatch, tmp_path): + _patch_validate(monkeypatch, tmp_path) + _patch_celery(monkeypatch) + + response = client.post( + "/api/v1/documents/upload/batch", + headers=auth_headers, + files=[_fake_txt_file("report.txt")], + data={"chunk_size": "1000", "chunk_overlap": "200"}, + ) + + assert response.status_code == 202 + payload = response.json() + assert payload["total"] == 1 + assert len(payload["documents"]) == 1 + assert payload["documents"][0]["original_name"] == "report.txt" + assert payload["documents"][0]["status"] == "pending" + assert len(payload["task_ids"]) == 1 + assert payload["failed"] == [] + + +def test_batch_upload_multiple_files(client, auth_headers, monkeypatch, tmp_path): + _patch_validate(monkeypatch, tmp_path) + _patch_celery(monkeypatch) + + response = client.post( + "/api/v1/documents/upload/batch", + headers=auth_headers, + files=[ + _fake_txt_file("a.txt"), + _fake_txt_file("b.txt"), + _fake_txt_file("c.txt"), + ], + data={"chunk_size": "1000", "chunk_overlap": "200"}, + ) + + assert response.status_code == 202 + payload = response.json() + assert payload["total"] == 3 + assert len(payload["documents"]) == 3 + assert len(payload["task_ids"]) == 3 + assert payload["failed"] == [] + + +def test_batch_upload_rejects_bad_extension(client, auth_headers, monkeypatch, tmp_path): + """A .exe file should land in failed[], not crash the whole batch.""" + _patch_validate(monkeypatch, tmp_path) + _patch_celery(monkeypatch) + + response = client.post( + "/api/v1/documents/upload/batch", + headers=auth_headers, + files=[ + _fake_txt_file("good.txt"), + ("files", ("bad.exe", io.BytesIO(b"binary"), "application/octet-stream")), + ], + data={"chunk_size": "1000", "chunk_overlap": "200"}, + ) + + assert response.status_code == 202 + payload = response.json() + assert payload["total"] == 1 + assert payload["documents"][0]["original_name"] == "good.txt" + assert "bad.exe" in payload["failed"] + + +def test_batch_upload_all_files_fail_returns_400(client, auth_headers, monkeypatch, tmp_path): + """When every file fails, the endpoint should return 400.""" + _patch_validate(monkeypatch, tmp_path) + _patch_celery(monkeypatch) + + response = client.post( + "/api/v1/documents/upload/batch", + headers=auth_headers, + files=[ + ("files", ("bad1.exe", io.BytesIO(b"x"), "application/octet-stream")), + ("files", ("bad2.exe", io.BytesIO(b"y"), "application/octet-stream")), + ], + data={"chunk_size": "1000", "chunk_overlap": "200"}, + ) + + assert response.status_code == 400 + + +def test_batch_upload_requires_auth(client): + response = client.post( + "/api/v1/documents/upload/batch", + files=[_fake_txt_file("test.txt")], + data={"chunk_size": "1000", "chunk_overlap": "200"}, + ) + + assert response.status_code in (401, 403) + + +def test_batch_upload_invalid_chunk_size(client, auth_headers, monkeypatch, tmp_path): + _patch_validate(monkeypatch, tmp_path) + _patch_celery(monkeypatch) + + response = client.post( + "/api/v1/documents/upload/batch", + headers=auth_headers, + files=[_fake_txt_file("test.txt")], + data={"chunk_size": "50", "chunk_overlap": "10"}, + ) + + assert response.status_code == 400 + + +def test_batch_upload_invalid_chunk_overlap(client, auth_headers, monkeypatch, tmp_path): + _patch_validate(monkeypatch, tmp_path) + _patch_celery(monkeypatch) + + response = client.post( + "/api/v1/documents/upload/batch", + headers=auth_headers, + files=[_fake_txt_file("test.txt")], + data={"chunk_size": "500", "chunk_overlap": "600"}, + ) + + assert response.status_code == 400 + + +def test_batch_upload_celery_fallback_uses_background_task(client, auth_headers, monkeypatch, tmp_path): + """When Celery is unavailable, tasks should fall back gracefully.""" + _patch_validate(monkeypatch, tmp_path) + + # Make Celery raise so the fallback branch is taken + monkeypatch.setattr( + "app.routes.documents.process_document", + MagicMock(delay=MagicMock(side_effect=Exception("Redis down"))), + ) + + response = client.post( + "/api/v1/documents/upload/batch", + headers=auth_headers, + files=[_fake_txt_file("fallback.txt")], + data={"chunk_size": "1000", "chunk_overlap": "200"}, + ) + + assert response.status_code == 202 + payload = response.json() + assert payload["total"] == 1 + assert payload["task_ids"][0].startswith("local_") + + +def test_batch_upload_document_persisted_in_db(client, auth_headers, monkeypatch, tmp_path, db_session): + _patch_validate(monkeypatch, tmp_path) + _patch_celery(monkeypatch) + + response = client.post( + "/api/v1/documents/upload/batch", + headers=auth_headers, + files=[_fake_txt_file("persisted.txt")], + data={"chunk_size": "1000", "chunk_overlap": "200"}, + ) + + assert response.status_code == 202 + doc_id = response.json()["documents"][0]["id"] + doc = db_session.get(Document, doc_id) + assert doc is not None + assert doc.original_name == "persisted.txt" + assert doc.status == "pending" + assert doc.chunk_size == 1000 + assert doc.chunk_overlap == 200 + + +def test_batch_upload_chunk_settings_stored(client, auth_headers, monkeypatch, tmp_path, db_session): + _patch_validate(monkeypatch, tmp_path) + _patch_celery(monkeypatch) + + response = client.post( + "/api/v1/documents/upload/batch", + headers=auth_headers, + files=[_fake_txt_file("chunked.txt")], + data={"chunk_size": "800", "chunk_overlap": "100"}, + ) + + assert response.status_code == 202 + doc_id = response.json()["documents"][0]["id"] + doc = db_session.get(Document, doc_id) + assert doc.chunk_size == 800 + assert doc.chunk_overlap == 100