diff --git a/README.md b/README.md index be49195f..ac0c4b49 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,18 @@ Audio annotation tools are another key feature of LabelU. These tools possess ef ### Artificial Intelligence Assisted Labelling LabelU supports one-click loading of pre-annotated data, which can be refined and adjusted according to actual needs. This feature improves the efficiency and accuracy of annotation. +### AI Auto-Annotation +LabelU integrates AI model services for automatic annotation of image data. Click the "AI Annotate" button on the annotation page to have the model automatically detect and segment objects. Supports batch annotation for entire tasks with real-time progress tracking. Three reference model servers are provided out of the box: + +- **Florence-2** — lightweight, CPU-friendly (~4GB VRAM) +- **GroundingDINO + EfficientSAM** — high-quality detection + segmentation (~4GB VRAM) +- **SAM 3** — state-of-the-art unified model (~8GB VRAM, requires high-end GPU) + +See [`model_server/README.md`](./model_server/README.md) for setup instructions. + +### S3 Data Source Import +LabelU supports importing annotation data directly from S3-compatible object storage (AWS S3, MinIO, etc.). Configure data source connections in the task settings, browse and preview files, then import selected files or all files under a path with one click. 
+ https://github.com/user-attachments/assets/0fa5bc39-20ba-46b6-9839-379a49f692cf diff --git a/README_zh-CN.md b/README_zh-CN.md index 8ca52443..02198114 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -28,9 +28,25 @@ LabelU为图像标注提供了全面的工具集,包括2D框、语义分割、 ### 人工智能辅助标注 LabelU 支持预标注数据的一键载入,用户可以根据实际需要对其进行细化和调整。这一特性提高了标注的效率和准确性。 +### AI 自动标注 +LabelU 集成了 AI 模型服务,支持图像数据的自动标注。在标注页面点击「AI 标注」按钮即可让模型自动检测和分割目标,也支持对整个任务的所有未标注样本进行批量标注,并可实时查看进度。项目内置提供了三个参考模型服务: + +- **Florence-2** — 轻量级,CPU 友好(约 4GB 显存) +- **GroundingDINO + EfficientSAM** — 高质量检测 + 分割(约 4GB 显存) +- **SAM 3** — 最新一代统一模型(约 8GB 显存,需要高端 GPU) + +详见 [`model_server/README.md`](./model_server/README.md) 了解部署方式。 + +### S3 数据源导入 +LabelU 支持从 S3 兼容对象存储(AWS S3、MinIO 等)直接导入标注数据。在任务设置中配置数据源连接,浏览和预览文件,然后一键导入选定文件或路径下的所有文件。 + + https://github.com/user-attachments/assets/f90e5a66-ab4d-456e-af4d-e6408a623812 +https://github.com/user-attachments/assets/0fa5bc39-20ba-46b6-9839-379a49f692cf + + ## 特性 - 简易,提供多种图像标注工具,通过简单可视化配置即可标注。 diff --git a/labelu/alembic_labelu/versions/a1b2c3d4e5f6_add_export_job_table.py b/labelu/alembic_labelu/versions/a1b2c3d4e5f6_add_export_job_table.py index 161b8cd7..2aad8a5e 100644 --- a/labelu/alembic_labelu/versions/a1b2c3d4e5f6_add_export_job_table.py +++ b/labelu/alembic_labelu/versions/a1b2c3d4e5f6_add_export_job_table.py @@ -10,7 +10,7 @@ # revision identifiers, used by Alembic. 
revision = 'a1b2c3d4e5f6' -down_revision = '2eb983c9a254' +down_revision = '034c7045b540' branch_labels = None depends_on = None diff --git a/labelu/alembic_labelu/versions/b2c3d4e5f6a7_add_data_source.py b/labelu/alembic_labelu/versions/b2c3d4e5f6a7_add_data_source.py new file mode 100644 index 00000000..17eb359e --- /dev/null +++ b/labelu/alembic_labelu/versions/b2c3d4e5f6a7_add_data_source.py @@ -0,0 +1,73 @@ +"""add data_source table and attachment.data_source_id + +Revision ID: b2c3d4e5f6a7 +Revises: a1b2c3d4e5f6 +Create Date: 2026-04-17 12:00:00.000000 + +""" +from alembic import op +import sqlalchemy as sa + +revision = "b2c3d4e5f6a7" +down_revision = "a1b2c3d4e5f6" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + bind = op.get_bind() + inspector = sa.inspect(bind) + existing_tables = inspector.get_table_names() + + if "data_source" not in existing_tables: + op.create_table( + "data_source", + sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), + sa.Column("name", sa.String(128), nullable=False), + sa.Column("type", sa.String(32), nullable=False, server_default="S3"), + sa.Column("endpoint", sa.String(512)), + sa.Column("region", sa.String(64)), + sa.Column("bucket", sa.String(256), nullable=False), + sa.Column("prefix", sa.String(512), server_default=""), + sa.Column("access_key_id", sa.String(512)), + sa.Column("secret_access_key", sa.String(1024)), + sa.Column("path_style", sa.Boolean(), server_default=sa.text("0")), + sa.Column("use_ssl", sa.Boolean(), server_default=sa.text("1")), + sa.Column("presign_expire_secs", sa.Integer(), server_default=sa.text("3600")), + sa.Column("created_by", sa.Integer(), sa.ForeignKey("user.id")), + sa.Column("updated_by", sa.Integer(), sa.ForeignKey("user.id")), + sa.Column("created_at", sa.DateTime(timezone=True)), + sa.Column("updated_at", sa.DateTime(timezone=True)), + sa.Column("deleted_at", sa.DateTime()), + ) + op.create_index("ix_data_source_id", "data_source", ["id"]) + 
op.create_index("ix_data_source_created_by", "data_source", ["created_by"]) + op.create_index("ix_data_source_deleted_at", "data_source", ["deleted_at"]) + + existing_columns = [c["name"] for c in inspector.get_columns("task_attachment")] + if "data_source_id" not in existing_columns: + with op.batch_alter_table("task_attachment", naming_convention={"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s"}) as batch_op: + batch_op.add_column( + sa.Column("data_source_id", sa.Integer(), nullable=True) + ) + batch_op.create_foreign_key( + "fk_task_attachment_data_source_id_data_source", + "data_source", + ["data_source_id"], + ["id"], + ) + batch_op.create_index("ix_task_attachment_data_source_id", ["data_source_id"]) + + +def downgrade() -> None: + bind = op.get_bind() + inspector = sa.inspect(bind) + + existing_columns = [c["name"] for c in inspector.get_columns("task_attachment")] + if "data_source_id" in existing_columns: + with op.batch_alter_table("task_attachment") as batch_op: + batch_op.drop_index("ix_task_attachment_data_source_id") + batch_op.drop_column("data_source_id") + + if "data_source" in inspector.get_table_names(): + op.drop_table("data_source") diff --git a/labelu/alembic_labelu/versions/c3d4e5f6a7b8_add_auto_label_job_table.py b/labelu/alembic_labelu/versions/c3d4e5f6a7b8_add_auto_label_job_table.py new file mode 100644 index 00000000..9ee86352 --- /dev/null +++ b/labelu/alembic_labelu/versions/c3d4e5f6a7b8_add_auto_label_job_table.py @@ -0,0 +1,54 @@ +"""add auto_label_job table + +Revision ID: c3d4e5f6a7b8 +Revises: b2c3d4e5f6a7 +Create Date: 2026-04-20 10:00:00.000000 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. 
+revision = 'c3d4e5f6a7b8' +down_revision = 'b2c3d4e5f6a7' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) + tables = inspector.get_table_names() + + if 'auto_label_job' not in tables: + op.create_table( + 'auto_label_job', + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('task_id', sa.Integer(), nullable=True), + sa.Column('created_by', sa.Integer(), nullable=True), + sa.Column('status', sa.String(length=32), nullable=True), + sa.Column('sample_count', sa.Integer(), nullable=True), + sa.Column('processed_count', sa.Integer(), nullable=True), + sa.Column('success_count', sa.Integer(), nullable=True), + sa.Column('failed_count', sa.Integer(), nullable=True), + sa.Column('filter_by_labels', sa.Boolean(), nullable=True), + sa.Column('error_message', sa.Text(), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=True), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint(['task_id'], ['task.id']), + sa.ForeignKeyConstraint(['created_by'], ['user.id']), + sa.PrimaryKeyConstraint('id'), + ) + op.create_index('ix_auto_label_job_id', 'auto_label_job', ['id']) + op.create_index('ix_auto_label_job_task_id', 'auto_label_job', ['task_id']) + + +def downgrade() -> None: + conn = op.get_bind() + inspector = sa.inspect(conn) + tables = inspector.get_table_names() + + if 'auto_label_job' in tables: + op.drop_index('ix_auto_label_job_task_id', table_name='auto_label_job') + op.drop_index('ix_auto_label_job_id', table_name='auto_label_job') + op.drop_table('auto_label_job') diff --git a/labelu/internal/adapter/persistence/crud_auto_label_job.py b/labelu/internal/adapter/persistence/crud_auto_label_job.py new file mode 100644 index 00000000..9f72f1cb --- /dev/null +++ b/labelu/internal/adapter/persistence/crud_auto_label_job.py @@ -0,0 +1,40 @@ +from typing import Optional +from sqlalchemy.orm import Session 
+from labelu.internal.domain.models.auto_label_job import AutoLabelJob, AutoLabelStatus + + +def create(db: Session, task_id: int, user_id: int, sample_count: int, filter_by_labels: bool) -> AutoLabelJob: + job = AutoLabelJob( + task_id=task_id, + created_by=user_id, + sample_count=sample_count, + filter_by_labels=filter_by_labels, + ) + db.add(job) + db.flush() + db.refresh(job) + return job + + +def get(db: Session, job_id: int) -> Optional[AutoLabelJob]: + return db.query(AutoLabelJob).filter(AutoLabelJob.id == job_id).first() + + +def update_status(db: Session, job: AutoLabelJob, status: str, **kwargs) -> AutoLabelJob: + job.status = status + for k, v in kwargs.items(): + setattr(job, k, v) + db.flush() + db.refresh(job) + return job + + +def increment_progress(db: Session, job: AutoLabelJob, success: bool) -> AutoLabelJob: + job.processed_count = (job.processed_count or 0) + 1 + if success: + job.success_count = (job.success_count or 0) + 1 + else: + job.failed_count = (job.failed_count or 0) + 1 + db.flush() + db.refresh(job) + return job diff --git a/labelu/internal/adapter/persistence/crud_datasource.py b/labelu/internal/adapter/persistence/crud_datasource.py new file mode 100644 index 00000000..31b81697 --- /dev/null +++ b/labelu/internal/adapter/persistence/crud_datasource.py @@ -0,0 +1,45 @@ +from typing import Optional, List, Tuple +from datetime import datetime + +from sqlalchemy.orm import Session + +from labelu.internal.domain.models.data_source import DataSource + + +def create(db: Session, data_source: DataSource) -> DataSource: + db.add(data_source) + db.flush() + db.refresh(data_source) + return data_source + + +def get(db: Session, ds_id: int) -> Optional[DataSource]: + return ( + db.query(DataSource) + .filter(DataSource.id == ds_id, DataSource.deleted_at.is_(None)) + .first() + ) + + +def list_by_user( + db: Session, user_id: int, page: int = 0, size: int = 100 +) -> Tuple[List[DataSource], int]: + query = db.query(DataSource).filter( + 
DataSource.created_by == user_id, DataSource.deleted_at.is_(None) + ) + total = query.count() + items = query.order_by(DataSource.id.desc()).offset(page * size).limit(size).all() + return items, total + + +def update(db: Session, db_obj: DataSource, obj_in: dict) -> DataSource: + for k, v in obj_in.items(): + setattr(db_obj, k, v) + db.flush() + db.refresh(db_obj) + return db_obj + + +def soft_delete(db: Session, db_obj: DataSource) -> None: + db_obj.deleted_at = datetime.now() + db.flush() diff --git a/labelu/internal/adapter/persistence/crud_sample.py b/labelu/internal/adapter/persistence/crud_sample.py index ba0836b2..5a38443b 100644 --- a/labelu/internal/adapter/persistence/crud_sample.py +++ b/labelu/internal/adapter/persistence/crud_sample.py @@ -81,6 +81,18 @@ def get_by_ids(db: Session, sample_ids: List[int], task_id: Union[int, None] = N return db.query(TaskSample).filter(*query_filter).all() +def list_new_samples(db: Session, task_id: int) -> List[TaskSample]: + return ( + db.query(TaskSample) + .filter( + TaskSample.task_id == task_id, + TaskSample.state == SampleState.NEW.value, + TaskSample.deleted_at == None, + ) + .all() + ) + + def update(db: Session, db_obj: TaskSample, obj_in: Dict[str, Any]) -> TaskSample: obj_data = jsonable_encoder(obj_in) for field in obj_data: diff --git a/labelu/internal/adapter/routers/__init__.py b/labelu/internal/adapter/routers/__init__.py index bb271012..5b037e2e 100644 --- a/labelu/internal/adapter/routers/__init__.py +++ b/labelu/internal/adapter/routers/__init__.py @@ -6,6 +6,7 @@ from labelu.internal.adapter.routers import sample from labelu.internal.adapter.routers import attachment from labelu.internal.adapter.routers import pre_annotation +from labelu.internal.adapter.routers import datasource def add_router(app: FastAPI): @@ -14,3 +15,4 @@ def add_router(app: FastAPI): app.include_router(attachment.router, prefix=settings.API_V1_STR) app.include_router(sample.router, prefix=settings.API_V1_STR) 
app.include_router(pre_annotation.router, prefix=settings.API_V1_STR) + app.include_router(datasource.router, prefix=settings.API_V1_STR) diff --git a/labelu/internal/adapter/routers/attachment.py b/labelu/internal/adapter/routers/attachment.py index 3c836d35..4e957037 100644 --- a/labelu/internal/adapter/routers/attachment.py +++ b/labelu/internal/adapter/routers/attachment.py @@ -4,7 +4,7 @@ from sqlalchemy.orm import Session from fastapi import APIRouter, status, Depends, Security from fastapi import File, Header, UploadFile -from fastapi.responses import FileResponse, StreamingResponse, Response +from fastapi.responses import FileResponse, StreamingResponse, Response, RedirectResponse from fastapi.security import HTTPAuthorizationCredentials import mimetypes @@ -60,8 +60,9 @@ async def download_attachment(file_path: str): # business logic data = await service.download_attachment(file_path=file_path) - - return data + if data.get("redirect_url"): + return RedirectResponse(url=data["redirect_url"], status_code=status.HTTP_307_TEMPORARY_REDIRECT) + return FileResponse(path=data["local_path"]) @router.get( "/partial/{file_path:path}", @@ -75,7 +76,9 @@ async def get_content(file_path: str, range: str = Header(None)): try: full_path = await service.download_attachment(file_path=file_path) - full_path = Path(full_path) + if full_path.get("redirect_url"): + return RedirectResponse(url=full_path["redirect_url"], status_code=status.HTTP_307_TEMPORARY_REDIRECT) + full_path = Path(full_path["local_path"]) except (FileNotFoundError, OSError, LabelUException): raise LabelUException( code=ErrorCode.CODE_51001_TASK_ATTACHMENT_NOT_FOUND, diff --git a/labelu/internal/adapter/routers/datasource.py b/labelu/internal/adapter/routers/datasource.py new file mode 100644 index 00000000..51020583 --- /dev/null +++ b/labelu/internal/adapter/routers/datasource.py @@ -0,0 +1,129 @@ +from typing import List, Union + +from sqlalchemy.orm import Session +from fastapi import APIRouter, 
Depends, Query, status, Security +from fastapi.security import HTTPAuthorizationCredentials + +from labelu.internal.common import db as db_module +from labelu.internal.common.security import security +from labelu.internal.domain.models.user import User +from labelu.internal.dependencies.user import get_current_user +from labelu.internal.application.service import datasource as service +from labelu.internal.application.command.datasource import ( + CreateDataSourceCommand, + UpdateDataSourceCommand, +) +from labelu.internal.application.response.base import ( + CommonDataResp, + MetaData, + OkResp, + OkRespWithMeta, +) +from labelu.internal.application.response.datasource import ( + DataSourceResponse, + S3ObjectListResponse, +) + +router = APIRouter(prefix="/datasources", tags=["datasources"]) + + +@router.post( + "", + response_model=OkResp[DataSourceResponse], + status_code=status.HTTP_201_CREATED, +) +async def create( + cmd: CreateDataSourceCommand, + authorization: HTTPAuthorizationCredentials = Security(security), + db: Session = Depends(db_module.get_db), + current_user: User = Depends(get_current_user), +): + data = await service.create(db=db, cmd=cmd, current_user=current_user) + return OkResp[DataSourceResponse](data=data) + + +@router.get( + "", + response_model=OkRespWithMeta[List[DataSourceResponse]], + status_code=status.HTTP_200_OK, +) +async def list_all( + page: int = Query(default=0, ge=0), + size: int = Query(default=100, ge=1, le=500), + authorization: HTTPAuthorizationCredentials = Security(security), + db: Session = Depends(db_module.get_db), + current_user: User = Depends(get_current_user), +): + data, total = await service.list_by(db=db, current_user=current_user, page=page, size=size) + return OkRespWithMeta[List[DataSourceResponse]]( + meta_data=MetaData(total=total, page=page, size=len(data)), + data=data, + ) + + +@router.get( + "/{ds_id}", + response_model=OkResp[DataSourceResponse], + status_code=status.HTTP_200_OK, +) +async def get( + 
ds_id: int, + authorization: HTTPAuthorizationCredentials = Security(security), + db: Session = Depends(db_module.get_db), + current_user: User = Depends(get_current_user), +): + data = await service.get(db=db, ds_id=ds_id) + return OkResp[DataSourceResponse](data=data) + + +@router.patch( + "/{ds_id}", + response_model=OkResp[DataSourceResponse], + status_code=status.HTTP_200_OK, +) +async def update( + ds_id: int, + cmd: UpdateDataSourceCommand, + authorization: HTTPAuthorizationCredentials = Security(security), + db: Session = Depends(db_module.get_db), + current_user: User = Depends(get_current_user), +): + data = await service.update(db=db, ds_id=ds_id, cmd=cmd, current_user=current_user) + return OkResp[DataSourceResponse](data=data) + + +@router.delete( + "/{ds_id}", + response_model=OkResp[CommonDataResp], + status_code=status.HTTP_200_OK, +) +async def delete( + ds_id: int, + authorization: HTTPAuthorizationCredentials = Security(security), + db: Session = Depends(db_module.get_db), + current_user: User = Depends(get_current_user), +): + await service.delete(db=db, ds_id=ds_id, current_user=current_user) + return OkResp[CommonDataResp](data=CommonDataResp(ok=True)) + + +@router.get( + "/{ds_id}/objects", + response_model=OkResp[S3ObjectListResponse], + status_code=status.HTTP_200_OK, +) +async def list_objects( + ds_id: int, + prefix: Union[str, None] = Query(default=None), + extension: Union[str, None] = Query(default=None, description="Comma-separated extensions, e.g. 
.jpg,.png"), + page_token: Union[str, None] = Query(default=None), + size: int = Query(default=100, ge=1, le=1000), + authorization: HTTPAuthorizationCredentials = Security(security), + db: Session = Depends(db_module.get_db), + current_user: User = Depends(get_current_user), +): + data = await service.list_objects( + db=db, ds_id=ds_id, prefix=prefix, extension=extension, + page_token=page_token, size=size, + ) + return OkResp[S3ObjectListResponse](data=data) diff --git a/labelu/internal/adapter/routers/sample.py b/labelu/internal/adapter/routers/sample.py index 55e86bd0..ddd77483 100644 --- a/labelu/internal/adapter/routers/sample.py +++ b/labelu/internal/adapter/routers/sample.py @@ -3,7 +3,7 @@ from sqlalchemy.orm import Session from fastapi import APIRouter, Depends, Query, status, Security -from fastapi.responses import FileResponse +from fastapi.responses import FileResponse, RedirectResponse from fastapi.security import HTTPAuthorizationCredentials from labelu.internal.common import db as db_module @@ -13,19 +13,24 @@ from labelu.internal.domain.models.user import User from labelu.internal.dependencies.user import get_current_user from labelu.internal.application.service import sample as service +from labelu.internal.application.service import auto_label as auto_label_service +from labelu.internal.application.command.auto_label import AutoLabelCommand, BatchAutoLabelCommand from labelu.internal.application.command.sample import ExportType from labelu.internal.application.command.sample import PatchSampleCommand from labelu.internal.application.command.sample import CreateSampleCommand from labelu.internal.application.command.sample import DeleteSampleCommand from labelu.internal.application.command.sample import ExportSampleCommand +from labelu.internal.application.command.datasource import ImportS3SamplesCommand from labelu.internal.application.response.base import OkResp from labelu.internal.application.response.base import MetaData from 
labelu.internal.application.response.base import CommonDataResp from labelu.internal.application.response.base import OkRespWithMeta from labelu.internal.application.response.sample import SampleResponse from labelu.internal.application.response.sample import CreateSampleResponse +from labelu.internal.application.response.auto_label import AutoLabelResponse, AutoLabelJobResponse from labelu.internal.application.response.export import ExportJobResponse from labelu.internal.adapter.persistence import crud_export_job +from labelu.internal.common.storage import get_storage_backend router = APIRouter(prefix="/tasks", tags=["samples"]) @@ -162,6 +167,84 @@ async def update( return OkResp[SampleResponse](data=data) +@router.post( + "/{task_id}/samples/{sample_id}/auto_label", + response_model=OkResp[AutoLabelResponse], + status_code=status.HTTP_200_OK, +) +async def auto_label( + task_id: int, + sample_id: int, + cmd: AutoLabelCommand, + authorization: HTTPAuthorizationCredentials = Security(security), + db: Session = Depends(db_module.get_db), + current_user: User = Depends(get_current_user), +): + data = await auto_label_service.create( + db=db, + task_id=task_id, + sample_id=sample_id, + cmd=cmd, + current_user=current_user, + ) + return OkResp[AutoLabelResponse](data=data) + + +@router.post( + "/{task_id}/auto_label_job", + response_model=OkResp[AutoLabelJobResponse], + status_code=status.HTTP_200_OK, +) +async def create_auto_label_job( + task_id: int, + cmd: BatchAutoLabelCommand, + authorization: HTTPAuthorizationCredentials = Security(security), + db: Session = Depends(db_module.get_db), + current_user: User = Depends(get_current_user), +): + data = await auto_label_service.create_batch_job( + db=db, + task_id=task_id, + cmd=cmd, + current_user=current_user, + ) + return OkResp[AutoLabelJobResponse](data=data) + + +@router.get( + "/{task_id}/auto_label_job/{job_id}", + response_model=OkResp[AutoLabelJobResponse], + status_code=status.HTTP_200_OK, +) +async def 
get_auto_label_job_status( + task_id: int, + job_id: int, + authorization: HTTPAuthorizationCredentials = Security(security), + db: Session = Depends(db_module.get_db), + current_user: User = Depends(get_current_user), +): + data = auto_label_service.get_batch_job(db=db, task_id=task_id, job_id=job_id) + return OkResp[AutoLabelJobResponse](data=data) + + +@router.post( + "/{task_id}/samples/import_s3", + response_model=OkResp[CreateSampleResponse], + status_code=status.HTTP_201_CREATED, +) +async def import_s3( + task_id: int, + cmd: ImportS3SamplesCommand, + authorization: HTTPAuthorizationCredentials = Security(security), + db: Session = Depends(db_module.get_db), + current_user: User = Depends(get_current_user), +): + data = await service.import_from_s3( + db=db, task_id=task_id, cmd=cmd, current_user=current_user, + ) + return OkResp[CreateSampleResponse](data=data) + + @router.delete( "/{task_id}/samples", response_model=OkResp[CommonDataResp], @@ -290,6 +373,11 @@ async def download_export( status_code=status.HTTP_404_NOT_FOUND, ) + storage = get_storage_backend() + if storage.is_remote: + download_url = storage.get_read_url(job.file_path) + return RedirectResponse(url=download_url, status_code=status.HTTP_307_TEMPORARY_REDIRECT) + file_path = Path(job.file_path) media_type = ".json" if file_path.suffix == ".json" else file_path.suffix.strip(".") return FileResponse( diff --git a/labelu/internal/application/command/auto_label.py b/labelu/internal/application/command/auto_label.py new file mode 100644 index 00000000..9bb4a3d8 --- /dev/null +++ b/labelu/internal/application/command/auto_label.py @@ -0,0 +1,14 @@ +from typing import Union + +from pydantic import BaseModel, Field + + +class AutoLabelCommand(BaseModel): + overwrite: bool = Field(default=True, description="overwrite latest ai-generated pre-annotation") + template_id: Union[int, None] = Field(default=None, description="reserved prompt template id") + prompt: Union[str, None] = Field(default=None, 
description="optional prompt override") + filter_by_labels: bool = Field(default=True, description="only keep results matching configured labels") + + +class BatchAutoLabelCommand(BaseModel): + filter_by_labels: bool = Field(default=True, description="only keep results matching configured labels") diff --git a/labelu/internal/application/command/datasource.py b/labelu/internal/application/command/datasource.py new file mode 100644 index 00000000..ee4576b6 --- /dev/null +++ b/labelu/internal/application/command/datasource.py @@ -0,0 +1,43 @@ +from typing import Union + +from pydantic import BaseModel, Field, model_validator + + +class CreateDataSourceCommand(BaseModel): + name: str = Field(max_length=128, description="Display name") + type: str = Field(default="S3", max_length=32) + endpoint: Union[str, None] = Field(default=None, max_length=512) + region: Union[str, None] = Field(default=None, max_length=64) + bucket: str = Field(max_length=256) + prefix: str = Field(default="", max_length=512) + access_key_id: str = Field(max_length=256) + secret_access_key: str = Field(max_length=256) + path_style: bool = Field(default=False) + use_ssl: bool = Field(default=True) + presign_expire_secs: int = Field(default=3600, ge=60, le=86400) + + +class UpdateDataSourceCommand(BaseModel): + name: Union[str, None] = Field(default=None, max_length=128) + endpoint: Union[str, None] = Field(default=None, max_length=512) + region: Union[str, None] = Field(default=None, max_length=64) + bucket: Union[str, None] = Field(default=None, max_length=256) + prefix: Union[str, None] = Field(default=None, max_length=512) + access_key_id: Union[str, None] = Field(default=None, max_length=256) + secret_access_key: Union[str, None] = Field(default=None, max_length=256) + path_style: Union[bool, None] = None + use_ssl: Union[bool, None] = None + presign_expire_secs: Union[int, None] = Field(default=None, ge=60, le=86400) + + +class ImportS3SamplesCommand(BaseModel): + data_source_id: int = 
Field(gt=0) + object_keys: list[str] = Field(default_factory=list, max_length=10000) + prefix: Union[str, None] = Field(default=None, max_length=512) + extension: Union[str, None] = Field(default=None, max_length=256) + + @model_validator(mode='after') + def check_keys_or_prefix(self): + if not self.object_keys and self.prefix is None: + raise ValueError("Either object_keys or prefix must be provided") + return self diff --git a/labelu/internal/application/response/attachment.py b/labelu/internal/application/response/attachment.py index 017ee134..dd57c33f 100644 --- a/labelu/internal/application/response/attachment.py +++ b/labelu/internal/application/response/attachment.py @@ -9,6 +9,15 @@ class AttachmentResponse(BaseModel): url: Union[str, None] = Field( default=None, description="description: upload file url" ) + thumbnail_url: Union[str, None] = Field( + default=None, description="description: upload file thumbnail url" + ) + stream_url: Union[str, None] = Field( + default=None, description="description: upload file stream url" + ) + storage_backend: Union[str, None] = Field( + default=None, description="description: upload file storage backend" + ) filename: Union[str, None] = Field( default=None, description="description: upload file name" ) diff --git a/labelu/internal/application/response/auto_label.py b/labelu/internal/application/response/auto_label.py new file mode 100644 index 00000000..32621d5a --- /dev/null +++ b/labelu/internal/application/response/auto_label.py @@ -0,0 +1,28 @@ +from typing import Union +from datetime import datetime + +from pydantic import BaseModel, Field + + +class AutoLabelResponse(BaseModel): + status: str = Field(description="auto-label job status") + task_id: int = Field(description="task id") + sample_id: int = Field(description="sample id") + media_type: str = Field(description="media type") + provider: str = Field(description="ai provider") + model: Union[str, None] = Field(default=None, description="model name") + 
latency_ms: Union[int, None] = Field(default=None, description="model latency in milliseconds") + pre_annotation_id: Union[int, None] = Field(default=None, description="created or reused pre-annotation id") + warning_message: Union[str, None] = Field(default=None, description="non-blocking warning message") + + +class AutoLabelJobResponse(BaseModel): + id: int + task_id: int + status: str + sample_count: int + processed_count: int + success_count: int + failed_count: int + error_message: Union[str, None] = None + created_at: Union[datetime, None] = None diff --git a/labelu/internal/application/response/datasource.py b/labelu/internal/application/response/datasource.py new file mode 100644 index 00000000..e26d3e80 --- /dev/null +++ b/labelu/internal/application/response/datasource.py @@ -0,0 +1,32 @@ +from typing import Union +from datetime import datetime + +from pydantic import BaseModel, Field + + +class DataSourceResponse(BaseModel): + id: int + name: str + type: str + endpoint: Union[str, None] = None + region: Union[str, None] = None + bucket: str + prefix: str = "" + path_style: bool = False + use_ssl: bool = True + presign_expire_secs: int = 3600 + created_by: int + created_at: Union[datetime, None] = None + updated_at: Union[datetime, None] = None + + +class S3ObjectItem(BaseModel): + key: str + size: int = 0 + last_modified: Union[str, None] = None + + +class S3ObjectListResponse(BaseModel): + objects: list[S3ObjectItem] = [] + next_page_token: Union[str, None] = None + truncated: bool = False diff --git a/labelu/internal/application/service/attachment.py b/labelu/internal/application/service/attachment.py index 0f9b745c..52a42612 100644 --- a/labelu/internal/application/service/attachment.py +++ b/labelu/internal/application/service/attachment.py @@ -1,7 +1,7 @@ import re import aiofiles import os -from PIL import Image +import tempfile from pathlib import Path from loguru import logger @@ -12,6 +12,13 @@ from labelu.internal.common.config import settings 
from labelu.internal.common.error_code import ErrorCode from labelu.internal.common.error_code import LabelUException +from labelu.internal.common.storage import ( + build_attachment_api_path, + build_partial_api_path, + build_thumbnail_key, + create_thumbnail_bytes, + get_storage_backend, +) from labelu.internal.domain.models.user import User from labelu.internal.domain.models.attachment import TaskAttachment from labelu.internal.adapter.persistence import crud_task @@ -22,9 +29,48 @@ from labelu.internal.application.response.attachment import AttachmentResponse +def build_attachment_response(attachment) -> AttachmentResponse | None: + if not attachment: + return None + + # External data source attachment — use the data source's own credentials + if getattr(attachment, "data_source_id", None) and getattr(attachment, "data_source", None): + from labelu.internal.application.service.datasource import get_presigned_url + url = get_presigned_url(attachment.data_source, attachment.path) + return AttachmentResponse( + id=attachment.id, + filename=attachment.filename, + url=url, + thumbnail_url=None, + stream_url=url, + storage_backend="s3", + ) + + # Local or global S3 storage + storage = get_storage_backend() + if storage.is_remote: + url = storage.get_read_url(attachment.path) + thumbnail_key = build_thumbnail_key(attachment.path) + thumbnail_url = storage.get_read_url(thumbnail_key) if storage.exists(thumbnail_key) else None + stream_url = url + else: + url = attachment.url or build_attachment_api_path(attachment.path) + thumbnail_url = build_attachment_api_path(build_thumbnail_key(attachment.path)) + stream_url = build_partial_api_path(attachment.path) + return AttachmentResponse( + id=attachment.id, + filename=attachment.filename, + url=url, + thumbnail_url=thumbnail_url, + stream_url=stream_url, + storage_backend=storage.backend_name, + ) + + async def create( db: Session, task_id: int, cmd: AttachmentCommand, current_user: User ) -> AttachmentResponse: + storage = 
get_storage_backend() task = crud_task.get(db=db, task_id=task_id) if not task: @@ -46,29 +92,25 @@ async def create( ) attachment_relative_path = str(attachment_relative_base_dir.joinpath(sanitized)) - # file full path - attachment_full_base_dir = Path(settings.MEDIA_ROOT).joinpath( - attachment_relative_base_dir - ) - attachment_full_path = Path(settings.MEDIA_ROOT).joinpath( - attachment_relative_path - ) - # check file exist - if attachment_full_path.exists(): - logger.error("file already exists:{}", attachment_full_path) + if storage.exists(attachment_relative_path): + logger.error("file already exists:{}", attachment_relative_path) raise LabelUException( code=ErrorCode.CODE_51002_TASK_ATTACHMENT_ALREADY_EXISTS, status_code=status.HTTP_400_BAD_REQUEST, ) - - # create dicreatory - attachment_full_base_dir.mkdir(parents=True, exist_ok=True) - + CHUNK_SIZE = 8 * 1024 * 1024 # 8MB - logger.info(attachment_full_path) + thumbnail_key = build_thumbnail_key(attachment_relative_path) + thumbnail_bytes = None + temp_file_path = None + try: - async with aiofiles.open(attachment_full_path, "wb") as out_file: + temp_file = tempfile.NamedTemporaryFile(delete=False) + temp_file_path = Path(temp_file.name) + temp_file.close() + + async with aiofiles.open(temp_file_path, "wb") as out_file: total_size = 0 while True: chunk = await cmd.file.read(CHUNK_SIZE) @@ -78,43 +120,50 @@ async def create( total_size += len(chunk) logger.debug(f"{total_size} bytes written") - logger.info(f"File saved: {attachment_full_path}, size: {total_size} bytes") + if cmd.file.content_type and cmd.file.content_type.startswith("image/"): + thumbnail_bytes = create_thumbnail_bytes(temp_file_path) + + storage.save_file( + local_path=temp_file_path, + key=attachment_relative_path, + content_type=cmd.file.content_type, + ) + + if thumbnail_bytes: + storage.save_bytes( + content=thumbnail_bytes, + key=thumbnail_key, + content_type=cmd.file.content_type, + ) + + if temp_file_path and temp_file_path.exists(): 
+ os.remove(temp_file_path) + + logger.info(f"File saved: {attachment_relative_path}, size: {total_size} bytes") except Exception as e: - if attachment_full_path.exists(): - os.remove(attachment_full_path) + if temp_file_path and temp_file_path.exists(): + os.remove(temp_file_path) + if storage.exists(attachment_relative_path): + storage.delete(attachment_relative_path) + if thumbnail_bytes and storage.exists(thumbnail_key): + storage.delete(thumbnail_key) logger.error(f"Upload failed: {str(e)}") raise LabelUException( code=ErrorCode.CODE_51000_CREATE_ATTACHMENT_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - message=f"Upload failed: {str(e)}" - ) - - # create thumbnail for image - if cmd.file.content_type.startswith("image/"): - tumbnail_full_path = Path( - f"{attachment_full_path.parent}/{attachment_full_path.stem}-thumbnail{attachment_full_path.suffix}" ) - logger.info(tumbnail_full_path) - image = Image.open(attachment_full_path) - image.thumbnail( - ( - round(image.width / image.height * settings.THUMBNAIL_HEIGH_PIXEL), - settings.THUMBNAIL_HEIGH_PIXEL, - ), - ) - if image.mode != "RGB": - image = image.convert("RGB") - image.save(tumbnail_full_path) # check file already saved - if not attachment_full_path.exists() or ( - cmd.file.content_type.startswith("image/") and not tumbnail_full_path.exists() + if not storage.exists(attachment_relative_path) or ( + cmd.file.content_type + and cmd.file.content_type.startswith("image/") + and not storage.exists(thumbnail_key) ): logger.error( "cannot find saved images, path is:{}, image content-type is:{}, thumbnail path is:{}", - attachment_full_path, + attachment_relative_path, cmd.file.content_type, - tumbnail_full_path, + thumbnail_key, ) raise LabelUException( code=ErrorCode.CODE_51000_CREATE_ATTACHMENT_ERROR, @@ -122,7 +171,7 @@ async def create( ) attachment_url_path = attachment_relative_path.replace("\\", "/") - attachment_api_url = f"{settings.API_V1_STR}/tasks/attachment/{attachment_url_path}" + 
attachment_api_url = build_attachment_api_path(attachment_url_path) # add a task file record with begin_transaction(db): attachment = crud_attachment.create( @@ -138,31 +187,45 @@ async def create( ) # response + if storage.is_remote: + url = storage.get_read_url(attachment_url_path) + thumb_url = storage.get_read_url(thumbnail_key) if thumbnail_bytes else None + stream = url + else: + url = attachment_api_url + thumb_url = build_attachment_api_path(thumbnail_key) if thumbnail_bytes else None + stream = build_partial_api_path(attachment_url_path) return AttachmentResponse( id=attachment.id, - url=attachment_api_url, + url=url, + thumbnail_url=thumb_url, + stream_url=stream, + storage_backend=storage.backend_name, filename=sanitized, ) -async def download_attachment(file_path: str) -> str: +async def download_attachment(file_path: str) -> dict: + storage = get_storage_backend() # check file exist - file_full_path = settings.MEDIA_ROOT.joinpath(file_path.lstrip("/")) - if not file_full_path.is_file() or not file_full_path.exists(): - logger.error("attachment not found:{}", file_full_path) + if not storage.exists(file_path): + logger.error("attachment not found:{}", file_path) raise LabelUException( code=ErrorCode.CODE_51001_TASK_ATTACHMENT_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND, ) - # response - return file_full_path + local_path = storage.get_local_path(file_path) + if local_path: + return {"local_path": str(local_path)} + return {"redirect_url": storage.get_read_url(file_path)} async def delete( db: Session, task_id: int, cmd: AttachmentDeleteCommand, current_user: User ) -> CommonDataResp: + storage = get_storage_backend() # get task task = crud_task.get(db=db, task_id=task_id) @@ -190,8 +253,10 @@ async def delete( db=db, attachment_ids=cmd.attachment_ids ) for attachment in attachments: - file_full_path = Path(settings.MEDIA_ROOT).joinpath(attachment.path) - os.remove(file_full_path) + storage.delete(attachment.path) + thumbnail_key = 
build_thumbnail_key(attachment.path) + if storage.exists(thumbnail_key): + storage.delete(thumbnail_key) except Exception as e: logger.error(e) diff --git a/labelu/internal/application/service/auto_label.py b/labelu/internal/application/service/auto_label.py new file mode 100644 index 00000000..f40fdc7f --- /dev/null +++ b/labelu/internal/application/service/auto_label.py @@ -0,0 +1,537 @@ +import asyncio +import json +import time +import uuid +from typing import Any + +import httpx +from fastapi import status +from loguru import logger +from sqlalchemy.orm import Session + +from labelu.internal.adapter.persistence import crud_pre_annotation, crud_sample, crud_task +from labelu.internal.adapter.persistence import crud_auto_label_job +from labelu.internal.application.command.auto_label import AutoLabelCommand, BatchAutoLabelCommand +from labelu.internal.application.response.auto_label import AutoLabelResponse, AutoLabelJobResponse +from labelu.internal.common.db import begin_transaction +from labelu.internal.common.error_code import ErrorCode, LabelUException +from labelu.internal.common.storage import get_model_read_url +from labelu.internal.common.config import settings +from labelu.internal.domain.models.auto_label_job import AutoLabelStatus +from labelu.internal.domain.models.pre_annotation import TaskPreAnnotation +from labelu.internal.domain.models.task import MediaType +from labelu.internal.domain.models.user import User + + +SUPPORTED_IMAGE_TOOLS = {"rectTool", "polygonTool", "pointTool", "lineTool"} + + +def _parse_task_config(raw_config: str | dict | None) -> dict[str, Any]: + if not raw_config: + return {} + if isinstance(raw_config, dict): + return raw_config + try: + return json.loads(raw_config) + except json.JSONDecodeError: + logger.warning("failed to parse task config: {}", raw_config) + return {} + + +def _extract_tool_configs(task_config: dict[str, Any]) -> tuple[list[dict[str, Any]], dict[str, Any]]: + labels: list[dict[str, Any]] = [] + 
config_by_tool: dict[str, Any] = {} + for tool in task_config.get("tools", []) or []: + tool_name = tool.get("tool") + tool_config = tool.get("config", {}) or {} + if tool_name not in SUPPORTED_IMAGE_TOOLS: + continue + attributes = tool_config.get("attributes", []) or [] + if not attributes: + continue + config_by_tool[tool_name] = tool_config + for attr in attributes: + labels.append( + { + "name": attr.get("value") or attr.get("key"), + "display_name": attr.get("key"), + "color": attr.get("color"), + "tool": tool_name, + } + ) + return labels, config_by_tool + + +def _get_image_url(attachment) -> str: + """Get a readable URL for the image, handling both local/global-S3 and external data source attachments.""" + if getattr(attachment, "data_source_id", None) and getattr(attachment, "data_source", None): + from labelu.internal.application.service.datasource import get_presigned_url + return get_presigned_url(attachment.data_source, attachment.path) + return get_model_read_url(attachment.path) + + +def _build_model_payload(sample, task, task_config: dict[str, Any], cmd: AutoLabelCommand) -> dict[str, Any]: + labels, config_by_tool = _extract_tool_configs(task_config) + if not labels: + raise LabelUException( + code=ErrorCode.CODE_56002_AUTO_LABEL_NO_LABELS_CONFIGURED, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + return { + "request_id": str(uuid.uuid4()), + "image_url": _get_image_url(sample.file), + "task": { + "id": task.id, + "name": task.name, + }, + "labels": labels, + "constraints": { + "allowed_tools": list(config_by_tool.keys()), + "max_results_per_label": 100, + "filter_by_labels": cmd.filter_by_labels, + }, + "prompt": cmd.prompt, + } + + +def _annotation_meta(data: str | dict | None) -> dict[str, Any]: + if not data: + return {} + parsed = data if isinstance(data, dict) else json.loads(data) + return parsed.get("meta", {}) or {} + + +def _is_ai_generated(pre_annotation: TaskPreAnnotation) -> bool: + try: + return 
_annotation_meta(pre_annotation.data).get("source_type") == "ai_generated" + except Exception: + return False + + +def _normalize_single_result(tool_name: str, item: dict[str, Any], order: int) -> dict[str, Any]: + result_payload = item.get("result", item) + attributes = item.get("attributes") or {} + if item.get("score") is not None: + attributes = {**attributes, "score": str(item["score"])} + + base = { + "id": item.get("id") or str(uuid.uuid4()), + "order": item.get("order", order), + "label": item.get("label"), + "visible": item.get("visible", True), + } + if attributes: + base["attributes"] = attributes + + if tool_name == "rectTool": + return { + **base, + "x": result_payload["x"], + "y": result_payload["y"], + "width": result_payload["width"], + "height": result_payload["height"], + } + + if tool_name == "polygonTool": + return { + **base, + "type": result_payload.get("type", "line"), + "points": result_payload.get("points") or result_payload.get("pointList") or [], + } + + if tool_name == "lineTool": + return { + **base, + "type": result_payload.get("type", "line"), + "points": result_payload.get("points") or result_payload.get("pointList") or [], + } + + if tool_name == "pointTool": + return { + **base, + "x": result_payload["x"], + "y": result_payload["y"], + } + + raise LabelUException( + code=ErrorCode.CODE_56005_AUTO_LABEL_INVALID_RESPONSE, + status_code=status.HTTP_502_BAD_GATEWAY, + ) + + +def _normalize_results(model_data: dict[str, Any], sample_name: str, task_config: dict[str, Any]) -> dict[str, Any]: + results = model_data.get("results", []) + flattened: list[tuple[str, dict[str, Any]]] = [] + if isinstance(results, dict): + for tool_name, items in results.items(): + for item in items or []: + flattened.append((tool_name, item)) + else: + for item in results: + flattened.append((item.get("toolName"), item)) + + grouped_annotations: dict[str, dict[str, Any]] = {} + for index, (tool_name, item) in enumerate(flattened): + if tool_name not in 
SUPPORTED_IMAGE_TOOLS: + continue + grouped_annotations.setdefault(tool_name, {"toolName": tool_name, "result": []}) + grouped_annotations[tool_name]["result"].append(_normalize_single_result(tool_name, item, index)) + + _, config_by_tool = _extract_tool_configs(task_config) + # Pre-annotation config format: { toolName: [{key, value, color}, ...] } + pre_annotation_config = { + tool_name: tool_cfg.get("attributes", []) + for tool_name, tool_cfg in config_by_tool.items() + } + return { + "sample_name": sample_name, + "annotations": grouped_annotations, + "config": pre_annotation_config, + "meta": { + "source_type": "ai_generated", + "provider": settings.AI_PROVIDER, + "model": model_data.get("model") or settings.AI_MODEL_NAME, + "latency_ms": model_data.get("latency_ms"), + "warning_message": model_data.get("warning_message"), + }, + } + + +async def create( + db: Session, + task_id: int, + sample_id: int, + cmd: AutoLabelCommand, + current_user: User, +) -> AutoLabelResponse: + if not settings.AI_AUTO_LABEL_ENABLED: + raise LabelUException( + code=ErrorCode.CODE_56000_AUTO_LABEL_DISABLED, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + if not settings.AI_MODEL_ENDPOINT: + logger.error("AI_MODEL_ENDPOINT is not configured") + raise LabelUException( + code=ErrorCode.CODE_56004_AUTO_LABEL_NOT_CONFIGURED, + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + task = crud_task.get(db=db, task_id=task_id) + if not task: + raise LabelUException( + code=ErrorCode.CODE_50002_TASK_NOT_FOUND, + status_code=status.HTTP_404_NOT_FOUND, + ) + + collaborator_ids = {c.id for c in task.collaborators} + if task.created_by != current_user.id and current_user.id not in collaborator_ids: + raise LabelUException( + code=ErrorCode.CODE_30001_NO_PERMISSION, + status_code=status.HTTP_403_FORBIDDEN, + ) + + if task.media_type != MediaType.IMAGE.value: + raise LabelUException( + code=ErrorCode.CODE_56001_AUTO_LABEL_UNSUPPORTED_MEDIA, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + 
sample = crud_sample.get(db=db, sample_id=sample_id) + if not sample or sample.task_id != task_id or not sample.file: + raise LabelUException( + code=ErrorCode.CODE_55001_SAMPLE_NOT_FOUND, + status_code=status.HTTP_404_NOT_FOUND, + ) + + existing_pre_annotations, _ = crud_pre_annotation.list_by( + db=db, + task_id=task_id, + sample_name=sample.file.filename, + page=0, + size=100, + ) + existing_ai_annotations = [item for item in existing_pre_annotations if _is_ai_generated(item)] + if existing_ai_annotations and not cmd.overwrite: + latest_pre_annotation = existing_ai_annotations[-1] + meta = _annotation_meta(latest_pre_annotation.data) + return AutoLabelResponse( + status="COMPLETED", + task_id=task_id, + sample_id=sample_id, + media_type=task.media_type, + provider=meta.get("provider") or settings.AI_PROVIDER, + model=meta.get("model") or settings.AI_MODEL_NAME, + latency_ms=meta.get("latency_ms"), + pre_annotation_id=latest_pre_annotation.id, + warning_message=meta.get("warning_message"), + ) + + model_payload = _build_model_payload(sample, task, _parse_task_config(task.config), cmd) + start = time.perf_counter() + try: + async with httpx.AsyncClient(timeout=settings.AI_MODEL_TIMEOUT_SECONDS) as client: + response = await client.post(settings.AI_MODEL_ENDPOINT, json=model_payload) + if response.status_code >= 400: + logger.error( + "model service returned {}: {}", + response.status_code, + response.text, + ) + response.raise_for_status() + model_data = response.json() + except Exception as exc: + logger.opt(exception=exc).error("auto label model request failed") + raise LabelUException( + code=ErrorCode.CODE_56003_AUTO_LABEL_MODEL_ERROR, + message=response.text, + status_code=status.HTTP_502_BAD_GATEWAY, + ) + + normalized_payload = _normalize_results( + model_data=model_data, + sample_name=sample.file.filename, + task_config=_parse_task_config(task.config), + ) + normalized_payload["meta"]["latency_ms"] = int((time.perf_counter() - start) * 1000) + + with 
begin_transaction(db): + if existing_ai_annotations: + crud_pre_annotation.delete( + db=db, + pre_annotation_ids=[item.id for item in existing_ai_annotations], + ) + + pre_annotations = crud_pre_annotation.batch( + db=db, + pre_annotations=[ + TaskPreAnnotation( + task_id=task_id, + file_id=None, + sample_name=sample.file.filename, + data=json.dumps(normalized_payload, ensure_ascii=False), + created_by=current_user.id, + updated_by=current_user.id, + ) + ], + ) + + created_pre_annotation = pre_annotations[0] + meta = normalized_payload.get("meta", {}) + return AutoLabelResponse( + status="COMPLETED", + task_id=task_id, + sample_id=sample_id, + media_type=task.media_type, + provider=meta.get("provider") or settings.AI_PROVIDER, + model=meta.get("model") or settings.AI_MODEL_NAME, + latency_ms=meta.get("latency_ms"), + pre_annotation_id=created_pre_annotation.id, + warning_message=meta.get("warning_message"), + ) + + +# ── Batch auto-label ────────────────────────────────────────────────────────── + + +async def create_batch_job( + db: Session, + task_id: int, + cmd: BatchAutoLabelCommand, + current_user: User, +) -> AutoLabelJobResponse: + if not settings.AI_AUTO_LABEL_ENABLED: + raise LabelUException( + code=ErrorCode.CODE_56000_AUTO_LABEL_DISABLED, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + if not settings.AI_MODEL_ENDPOINT: + raise LabelUException( + code=ErrorCode.CODE_56004_AUTO_LABEL_NOT_CONFIGURED, + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + task = crud_task.get(db=db, task_id=task_id) + if not task: + raise LabelUException( + code=ErrorCode.CODE_50002_TASK_NOT_FOUND, + status_code=status.HTTP_404_NOT_FOUND, + ) + + if task.media_type != MediaType.IMAGE.value: + raise LabelUException( + code=ErrorCode.CODE_56001_AUTO_LABEL_UNSUPPORTED_MEDIA, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + task_config = _parse_task_config(task.config) + labels, _ = _extract_tool_configs(task_config) + if not labels: + raise LabelUException( + 
code=ErrorCode.CODE_56002_AUTO_LABEL_NO_LABELS_CONFIGURED, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + # Get all NEW state samples + new_samples = crud_sample.list_new_samples(db=db, task_id=task_id) + if not new_samples: + raise LabelUException( + code=ErrorCode.CODE_56006_AUTO_LABEL_NO_SAMPLES, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + sample_ids = [s.id for s in new_samples] + + with begin_transaction(db): + job = crud_auto_label_job.create( + db=db, + task_id=task_id, + user_id=current_user.id, + sample_count=len(sample_ids), + filter_by_labels=cmd.filter_by_labels, + ) + job_id = job.id + + asyncio.get_event_loop().run_in_executor( + None, _run_batch_auto_label_sync, job_id, task_id, sample_ids, cmd.filter_by_labels + ) + + return AutoLabelJobResponse( + id=job_id, + task_id=task_id, + status=AutoLabelStatus.PENDING.value, + sample_count=len(sample_ids), + processed_count=0, + success_count=0, + failed_count=0, + created_at=job.created_at, + ) + + +def get_batch_job(db: Session, task_id: int, job_id: int) -> AutoLabelJobResponse: + job = crud_auto_label_job.get(db=db, job_id=job_id) + if not job or job.task_id != task_id: + raise LabelUException( + code=ErrorCode.CODE_50002_TASK_NOT_FOUND, + status_code=status.HTTP_404_NOT_FOUND, + ) + return AutoLabelJobResponse( + id=job.id, + task_id=job.task_id, + status=job.status, + sample_count=job.sample_count or 0, + processed_count=job.processed_count or 0, + success_count=job.success_count or 0, + failed_count=job.failed_count or 0, + error_message=job.error_message, + created_at=job.created_at, + ) + + +def _run_batch_auto_label_sync(job_id: int, task_id: int, sample_ids: list[int], filter_by_labels: bool): + """Run batch auto-label in a thread. 
Processes samples sequentially.""" + from labelu.internal.common.db import SessionLocal + + db = SessionLocal() + try: + job = crud_auto_label_job.get(db=db, job_id=job_id) + with begin_transaction(db): + crud_auto_label_job.update_status(db, job, AutoLabelStatus.PROCESSING.value) + + task = crud_task.get(db=db, task_id=task_id) + task_config = _parse_task_config(task.config) + + for sample_id in sample_ids: + try: + _process_single_sample(db, task, task_config, sample_id, filter_by_labels, job) + with begin_transaction(db): + crud_auto_label_job.increment_progress(db, job, success=True) + except Exception as exc: + logger.error("Batch auto-label failed for sample {}: {}", sample_id, str(exc)) + with begin_transaction(db): + crud_auto_label_job.increment_progress(db, job, success=False) + + with begin_transaction(db): + crud_auto_label_job.update_status(db, job, AutoLabelStatus.COMPLETED.value) + except Exception as e: + logger.error("Batch auto-label job {} failed: {}", job_id, str(e)) + try: + job = crud_auto_label_job.get(db=db, job_id=job_id) + with begin_transaction(db): + crud_auto_label_job.update_status( + db, job, AutoLabelStatus.FAILED.value, + error_message=str(e), + ) + except Exception: + logger.error("Failed to update auto-label job status for job {}", job_id) + finally: + db.close() + + +def _process_single_sample( + db: Session, + task, + task_config: dict[str, Any], + sample_id: int, + filter_by_labels: bool, + job, +): + """Process a single sample: call model, normalize results, save pre-annotation.""" + sample = crud_sample.get(db=db, sample_id=sample_id) + if not sample or not sample.file: + raise ValueError(f"Sample {sample_id} not found or has no file") + + # Build payload + labels, config_by_tool = _extract_tool_configs(task_config) + payload = { + "request_id": str(uuid.uuid4()), + "image_url": _get_image_url(sample.file), + "task": {"id": task.id, "name": task.name}, + "labels": labels, + "constraints": { + "allowed_tools": 
list(config_by_tool.keys()), + "max_results_per_label": 100, + "filter_by_labels": filter_by_labels, + }, + "prompt": None, + } + + # Synchronous HTTP call (we're in a thread) + with httpx.Client(timeout=settings.AI_MODEL_TIMEOUT_SECONDS) as client: + response = client.post(settings.AI_MODEL_ENDPOINT, json=payload) + if response.status_code >= 400: + logger.error("model service returned {}: {}", response.status_code, response.text) + response.raise_for_status() + model_data = response.json() + + normalized_payload = _normalize_results( + model_data=model_data, + sample_name=sample.file.filename, + task_config=task_config, + ) + + # Delete existing AI-generated pre-annotations for this sample + existing_pre_annotations, _ = crud_pre_annotation.list_by( + db=db, task_id=task.id, sample_name=sample.file.filename, page=0, size=100, + ) + existing_ai = [item for item in existing_pre_annotations if _is_ai_generated(item)] + + with begin_transaction(db): + if existing_ai: + crud_pre_annotation.delete(db=db, pre_annotation_ids=[item.id for item in existing_ai]) + + crud_pre_annotation.batch( + db=db, + pre_annotations=[ + TaskPreAnnotation( + task_id=task.id, + file_id=None, + sample_name=sample.file.filename, + data=json.dumps(normalized_payload, ensure_ascii=False), + created_by=job.created_by, + updated_by=job.created_by, + ) + ], + ) diff --git a/labelu/internal/application/service/datasource.py b/labelu/internal/application/service/datasource.py new file mode 100644 index 00000000..feef3f53 --- /dev/null +++ b/labelu/internal/application/service/datasource.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +from typing import List, Tuple, Optional + +import boto3 +from botocore.config import Config +from fastapi import status +from loguru import logger +from sqlalchemy.orm import Session + +from labelu.internal.adapter.persistence import crud_datasource +from labelu.internal.application.command.datasource import ( + CreateDataSourceCommand, + 
UpdateDataSourceCommand, +) +from labelu.internal.application.response.datasource import ( + DataSourceResponse, + S3ObjectItem, + S3ObjectListResponse, +) +from labelu.internal.common.crypto import decrypt_value, encrypt_value +from labelu.internal.common.db import begin_transaction +from labelu.internal.common.error_code import ErrorCode, LabelUException +from labelu.internal.domain.models.data_source import DataSource +from labelu.internal.domain.models.user import User + + +IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tiff", ".tif", ".gif"} + + +def _to_response(ds: DataSource) -> DataSourceResponse: + return DataSourceResponse( + id=ds.id, + name=ds.name, + type=ds.type, + endpoint=ds.endpoint, + region=ds.region, + bucket=ds.bucket, + prefix=ds.prefix or "", + path_style=ds.path_style, + use_ssl=ds.use_ssl, + presign_expire_secs=ds.presign_expire_secs, + created_by=ds.created_by, + created_at=ds.created_at, + updated_at=ds.updated_at, + ) + + +def _build_s3_client(ds: DataSource): + """Create a boto3 S3 client from a DataSource's (decrypted) credentials.""" + ak = decrypt_value(ds.access_key_id) if ds.access_key_id else None + sk = decrypt_value(ds.secret_access_key) if ds.secret_access_key else None + kwargs = {} + if ds.region: + kwargs["region_name"] = ds.region + kwargs["config"] = Config( + s3={"addressing_style": "path" if ds.path_style else "auto"} + ) + return boto3.client( + "s3", + endpoint_url=ds.endpoint or None, + aws_access_key_id=ak, + aws_secret_access_key=sk, + use_ssl=ds.use_ssl, + **kwargs, + ) + + +def get_presigned_url(ds: DataSource, key: str, expires_in: Optional[int] = None) -> str: + """Generate a presigned read URL for a given object key using the data source credentials.""" + client = _build_s3_client(ds) + return client.generate_presigned_url( + ClientMethod="get_object", + Params={"Bucket": ds.bucket, "Key": key}, + ExpiresIn=expires_in or ds.presign_expire_secs or 3600, + ) + + +# ── CRUD 
────────────────────────────────────────────────────────────── + +async def create( + db: Session, cmd: CreateDataSourceCommand, current_user: User +) -> DataSourceResponse: + ds = DataSource( + name=cmd.name, + type=cmd.type, + endpoint=cmd.endpoint, + region=cmd.region, + bucket=cmd.bucket, + prefix=cmd.prefix, + access_key_id=encrypt_value(cmd.access_key_id), + secret_access_key=encrypt_value(cmd.secret_access_key), + path_style=cmd.path_style, + use_ssl=cmd.use_ssl, + presign_expire_secs=cmd.presign_expire_secs, + created_by=current_user.id, + updated_by=current_user.id, + ) + with begin_transaction(db): + ds = crud_datasource.create(db=db, data_source=ds) + return _to_response(ds) + + +async def list_by( + db: Session, current_user: User, page: int = 0, size: int = 100 +) -> Tuple[List[DataSourceResponse], int]: + items, total = crud_datasource.list_by_user( + db=db, user_id=current_user.id, page=page, size=size + ) + return [_to_response(ds) for ds in items], total + + +async def get(db: Session, ds_id: int) -> DataSourceResponse: + ds = crud_datasource.get(db=db, ds_id=ds_id) + if not ds: + raise LabelUException( + code=ErrorCode.CODE_61000_NO_DATA, + status_code=status.HTTP_404_NOT_FOUND, + ) + return _to_response(ds) + + +async def update( + db: Session, ds_id: int, cmd: UpdateDataSourceCommand, current_user: User +) -> DataSourceResponse: + ds = crud_datasource.get(db=db, ds_id=ds_id) + if not ds: + raise LabelUException( + code=ErrorCode.CODE_61000_NO_DATA, + status_code=status.HTTP_404_NOT_FOUND, + ) + obj_in = cmd.model_dump(exclude_unset=True) + if "access_key_id" in obj_in and obj_in["access_key_id"] is not None: + obj_in["access_key_id"] = encrypt_value(obj_in["access_key_id"]) + if "secret_access_key" in obj_in and obj_in["secret_access_key"] is not None: + obj_in["secret_access_key"] = encrypt_value(obj_in["secret_access_key"]) + obj_in["updated_by"] = current_user.id + with begin_transaction(db): + ds = crud_datasource.update(db=db, db_obj=ds, 
obj_in=obj_in) + return _to_response(ds) + + +async def delete(db: Session, ds_id: int, current_user: User) -> None: + ds = crud_datasource.get(db=db, ds_id=ds_id) + if not ds: + raise LabelUException( + code=ErrorCode.CODE_61000_NO_DATA, + status_code=status.HTTP_404_NOT_FOUND, + ) + with begin_transaction(db): + crud_datasource.soft_delete(db=db, db_obj=ds) + + +# ── S3 file listing ────────────────────────────────────────────────── + +async def list_objects( + db: Session, + ds_id: int, + prefix: Optional[str] = None, + extension: Optional[str] = None, + page_token: Optional[str] = None, + size: int = 100, +) -> S3ObjectListResponse: + ds = crud_datasource.get(db=db, ds_id=ds_id) + if not ds: + raise LabelUException( + code=ErrorCode.CODE_61000_NO_DATA, + status_code=status.HTTP_404_NOT_FOUND, + ) + + client = _build_s3_client(ds) + full_prefix = prefix if prefix is not None else (ds.prefix or "") + + kwargs = { + "Bucket": ds.bucket, + "Prefix": full_prefix, + "MaxKeys": size, + } + if page_token: + kwargs["ContinuationToken"] = page_token + + allowed_exts = None + if extension: + allowed_exts = {("." + e.strip().lower().lstrip(".")) for e in extension.split(",") if e.strip()} + + try: + resp = client.list_objects_v2(**kwargs) + except Exception as exc: + logger.opt(exception=exc).error("S3 list_objects_v2 failed for datasource {}", ds_id) + raise LabelUException( + code=ErrorCode.CODE_62002_S3_REQUEST_FAILED, + status_code=status.HTTP_502_BAD_GATEWAY, + ) + + objects: list[S3ObjectItem] = [] + for obj in resp.get("Contents", []): + key: str = obj["Key"] + if key.endswith("/"): + continue + if allowed_exts: + ext = "." + key.rsplit(".", 1)[-1].lower() if "." 
in key else "" + if ext not in allowed_exts: + continue + objects.append(S3ObjectItem( + key=key, + size=obj.get("Size", 0), + last_modified=obj.get("LastModified", "").isoformat() if obj.get("LastModified") else None, + )) + + return S3ObjectListResponse( + objects=objects, + next_page_token=resp.get("NextContinuationToken"), + truncated=resp.get("IsTruncated", False), + ) diff --git a/labelu/internal/application/service/pre_annotation.py b/labelu/internal/application/service/pre_annotation.py index 9b00afc5..297cfae2 100644 --- a/labelu/internal/application/service/pre_annotation.py +++ b/labelu/internal/application/service/pre_annotation.py @@ -1,6 +1,5 @@ import json import os -from pathlib import Path from typing import List, Tuple, Optional from loguru import logger @@ -11,6 +10,12 @@ from labelu.internal.common.config import settings from labelu.internal.common.error_code import ErrorCode from labelu.internal.common.error_code import LabelUException +from labelu.internal.common.storage import ( + build_attachment_api_path, + build_partial_api_path, + build_thumbnail_key, + get_storage_backend, +) from labelu.internal.adapter.persistence import crud_task from labelu.internal.adapter.persistence import crud_pre_annotation from labelu.internal.adapter.persistence import crud_attachment @@ -22,30 +27,26 @@ from labelu.internal.application.response.base import CommonDataResp from labelu.internal.application.response.pre_annotation import CreatePreAnnotationResponse, PreAnnotationFileResponse from labelu.internal.application.response.pre_annotation import PreAnnotationResponse -from labelu.internal.application.response.attachment import AttachmentResponse +from labelu.internal.application.service.attachment import build_attachment_response + def read_pre_annotation_file(attachment: TaskAttachment) -> List[dict]: if attachment is None: return [] - attachment_path = attachment.path - file_full_path = settings.MEDIA_ROOT.joinpath(attachment_path.lstrip("/")) + 
storage = get_storage_backend() # check if the file exists - if not file_full_path.exists() or (not attachment.filename.endswith('.jsonl') and not attachment.filename.endswith('.json')): + if not storage.exists(attachment.path) or (not attachment.filename.endswith('.jsonl') and not attachment.filename.endswith('.json')): return [] try: if attachment.filename.endswith('.jsonl'): - with open(file_full_path, "r", encoding="utf-8") as f: - data = f.readlines() - return [json.loads(line) for line in data] + data = storage.read_text(attachment.path, encoding="utf-8").splitlines() + return [json.loads(line) for line in data if line.strip()] else: - with open(file_full_path, "r", encoding="utf-8") as f: - # parse result - parsed_data = json.load(f) - - return [{**item, "result": json.loads(item["result"])} for item in parsed_data] + parsed_data = json.loads(storage.read_text(attachment.path, encoding="utf-8")) + return [{**item, "result": json.loads(item["result"])} for item in parsed_data] except FileNotFoundError: raise LabelUException(status_code=404, code=ErrorCode.CODE_51001_TASK_ATTACHMENT_NOT_FOUND) @@ -123,7 +124,7 @@ async def list_by( PreAnnotationResponse( id=pre_annotation.id, data=pre_annotation.data, - file=AttachmentResponse(id=pre_annotation.file.id, filename=pre_annotation.file.filename, url=pre_annotation.file.url) if pre_annotation.file else None, + file=build_attachment_response(pre_annotation.file), created_at=pre_annotation.created_at, created_by=UserResp( id=pre_annotation.owner.id, @@ -159,10 +160,22 @@ async def list_pre_annotation_files( if pre_annotation['file_id'] in _attachment_ids and pre_annotation['sample_name'] is not None: sample_names_those_has_pre_annotations.append(pre_annotation['sample_name']) + storage = get_storage_backend() return [ PreAnnotationFileResponse( id=attachment.id, - url=attachment.url, + url=(storage.get_read_url(attachment.path) if storage.is_remote else attachment.url), + thumbnail_url=( + 
storage.get_read_url(build_thumbnail_key(attachment.path)) + if storage.is_remote + else build_attachment_api_path(build_thumbnail_key(attachment.path)) + ), + stream_url=( + storage.get_read_url(attachment.path) + if storage.is_remote + else build_partial_api_path(attachment.path) + ), + storage_backend=storage.backend_name, filename=attachment.filename, sample_names=sample_names_those_has_pre_annotations, ) @@ -195,7 +208,7 @@ async def get( # response return PreAnnotationResponse( id=pre_annotation.id, - file=AttachmentResponse(id=pre_annotation.file.id, filename=pre_annotation.file.filename, url=pre_annotation.file.url) if pre_annotation.file else None, + file=build_attachment_response(pre_annotation.file), created_at=pre_annotation.created_at, created_by=UserResp( id=pre_annotation.owner.id, @@ -211,6 +224,7 @@ async def get( async def delete_pre_annotation_file( db: Session, task_id: int, file_id: int, current_user: User ) -> CommonDataResp: + storage = get_storage_backend() with begin_transaction(db): task = crud_task.get(db=db, task_id=task_id, lock=True) if not task: @@ -231,8 +245,10 @@ async def delete_pre_annotation_file( ) for attachment in attachments: - file_full_path = Path(settings.MEDIA_ROOT).joinpath(attachment.path) - os.remove(file_full_path) + storage.delete(attachment.path) + thumbnail_key = build_thumbnail_key(attachment.path) + if storage.exists(thumbnail_key): + storage.delete(thumbnail_key) pre_annotations = crud_pre_annotation.list_by_task_id_and_file_id(db=db, task_id=task_id, file_id=file_id) pre_annotation_ids = [pre_annotation.id for pre_annotation in pre_annotations] diff --git a/labelu/internal/application/service/sample.py b/labelu/internal/application/service/sample.py index adf4076a..4e7bac68 100644 --- a/labelu/internal/application/service/sample.py +++ b/labelu/internal/application/service/sample.py @@ -15,9 +15,14 @@ from labelu.internal.common.converter import converter from labelu.internal.common.error_code import ErrorCode 
from labelu.internal.common.error_code import LabelUException +from labelu.internal.common.storage import ( + build_thumbnail_key, + get_storage_backend, +) from labelu.internal.adapter.persistence import crud_attachment, crud_pre_annotation, crud_task from labelu.internal.adapter.persistence import crud_sample from labelu.internal.adapter.persistence import crud_export_job +from labelu.internal.adapter.persistence import crud_datasource from labelu.internal.domain.models.pre_annotation import TaskPreAnnotation from labelu.internal.domain.models.user import User from labelu.internal.domain.models.task import Task @@ -28,26 +33,28 @@ from labelu.internal.application.command.sample import ExportType from labelu.internal.application.command.sample import PatchSampleCommand from labelu.internal.application.command.sample import CreateSampleCommand +from labelu.internal.application.command.datasource import ImportS3SamplesCommand from labelu.internal.application.response.base import UserResp from labelu.internal.application.response.base import CommonDataResp from labelu.internal.application.response.sample import CreateSampleResponse from labelu.internal.application.response.sample import SampleResponse -from labelu.internal.application.response.attachment import AttachmentResponse +from labelu.internal.application.service.attachment import build_attachment_response from labelu.internal.clients.ws import sampleConnectionManager from labelu.internal.common.websocket import Message, MessageType from labelu.internal.adapter.ws.sample import TaskSampleWsPayload -def is_sample_pre_annotated(db: Session, task_id: int, sample_name: str | None = None) -> Tuple[List[TaskPreAnnotation], int]: + +def is_sample_pre_annotated(db: Session, task_id: int, sample_name: str | None = None) -> bool: if sample_name is None: return False - + _, total = crud_pre_annotation.list_by( db=db, task_id=task_id, sample_name=sample_name, size=1, ) - + return total > 0 async def create( @@ -86,6 
+93,134 @@ async def create( return CreateSampleResponse(ids=ids) +MAX_IMPORT_KEYS = 10000 + + +def _collect_s3_keys(ds, prefix: str, extension: str | None) -> list[str]: + """List all matching S3 object keys under *prefix*, paginating internally.""" + from labelu.internal.application.service.datasource import _build_s3_client + + client = _build_s3_client(ds) + full_prefix = prefix if prefix else (ds.prefix or "") + + allowed_exts = None + if extension: + allowed_exts = { + ("." + e.strip().lower().lstrip(".")) + for e in extension.split(",") + if e.strip() + } + + keys: list[str] = [] + continuation_token = None + + while True: + kwargs = {"Bucket": ds.bucket, "Prefix": full_prefix, "MaxKeys": 1000} + if continuation_token: + kwargs["ContinuationToken"] = continuation_token + + try: + resp = client.list_objects_v2(**kwargs) + except Exception as exc: + logger.opt(exception=exc).error( + "S3 list_objects_v2 failed for datasource {}", ds.id + ) + raise LabelUException( + code=ErrorCode.CODE_62002_S3_REQUEST_FAILED, + status_code=status.HTTP_502_BAD_GATEWAY, + ) + + for obj in resp.get("Contents", []): + key: str = obj["Key"] + if key.endswith("/"): + continue + if allowed_exts: + ext = ("." + key.rsplit(".", 1)[-1].lower()) if "." 
in key else "" + if ext not in allowed_exts: + continue + keys.append(key) + if len(keys) > MAX_IMPORT_KEYS: + raise LabelUException( + code=ErrorCode.CODE_62000_S3_IMPORT_TOO_MANY, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + if not resp.get("IsTruncated"): + break + continuation_token = resp.get("NextContinuationToken") + + return keys + + +async def import_from_s3( + db: Session, task_id: int, cmd: ImportS3SamplesCommand, current_user: User, +) -> CreateSampleResponse: + """Import S3 objects as task samples (no file copy — stores reference only).""" + from labelu.internal.domain.models.attachment import TaskAttachment + from labelu.internal.application.service.datasource import _build_s3_client + + with begin_transaction(db): + task = crud_task.get(db=db, task_id=task_id, lock=True) + if not task: + raise LabelUException( + code=ErrorCode.CODE_50002_TASK_NOT_FOUND, + status_code=status.HTTP_404_NOT_FOUND, + ) + + ds = crud_datasource.get(db=db, ds_id=cmd.data_source_id) + if not ds: + raise LabelUException( + code=ErrorCode.CODE_61000_NO_DATA, + status_code=status.HTTP_404_NOT_FOUND, + ) + + # Resolve object keys: either from explicit list or by listing S3 prefix + object_keys = cmd.object_keys + if not object_keys and cmd.prefix is not None: + object_keys = _collect_s3_keys(ds, cmd.prefix, cmd.extension) + + if not object_keys: + raise LabelUException( + code=ErrorCode.CODE_62001_S3_IMPORT_NO_MATCH, + status_code=status.HTTP_400_BAD_REQUEST, + ) + + attachments = [] + for key in object_keys: + filename = key.rsplit("/", 1)[-1] if "/" in key else key + att = TaskAttachment( + path=key, + url="", + filename=filename, + task_id=task_id, + data_source_id=ds.id, + created_by=current_user.id, + updated_by=current_user.id, + ) + db.add(att) + attachments.append(att) + db.flush() + + samples = [ + TaskSample( + inner_id=task.last_sample_inner_id + i + 1, + task_id=task_id, + file_id=att.id, + created_by=current_user.id, + updated_by=current_user.id, + 
data=json.dumps({}), + ) + for i, att in enumerate(attachments) + ] + obj_in = {Task.last_sample_inner_id.key: task.last_sample_inner_id + len(samples)} + if task.status == TaskStatus.DRAFT.value: + obj_in[Task.status.key] = TaskStatus.IMPORTED + crud_task.update(db=db, db_obj=task, obj_in=obj_in) + new_samples = crud_sample.batch(db=db, samples=samples) + + return CreateSampleResponse(ids=[s.id for s in new_samples]) + + async def list_by( db: Session, task_id: Union[int, None], @@ -116,7 +251,7 @@ async def list_by( data=json.loads(sample.data), annotated_count=sample.annotated_count, is_pre_annotated=is_sample_pre_annotated(db=db, task_id=task_id, sample_name=sample.file.filename if sample.file else None), - file=AttachmentResponse(id=sample.file.id, filename=sample.file.filename, url=sample.file.url) if sample.file else None, + file=build_attachment_response(sample.file), created_at=sample.created_at, created_by=UserResp( id=sample.owner.id, @@ -154,7 +289,7 @@ async def get( state=sample.state, data=json.loads(sample.data), is_pre_annotated=is_sample_pre_annotated(db=db, task_id=task_id, sample_name=sample.file.filename if sample.file else None), - file=AttachmentResponse(id=sample.file.id, filename=sample.file.filename, url=sample.file.url) if sample.file else None, + file=build_attachment_response(sample.file), annotated_count=sample.annotated_count, created_at=sample.created_at, created_by=UserResp( @@ -249,6 +384,7 @@ async def patch( state=updated_sample.state, data=json.loads(updated_sample.data), is_pre_annotated=is_sample_pre_annotated(db=db, task_id=task_id, sample_name=sample.file.filename if sample.file else None), + file=build_attachment_response(updated_sample.file), annotated_count=updated_sample.annotated_count, created_at=updated_sample.created_at, created_by=UserResp( @@ -266,6 +402,7 @@ async def patch( async def delete( db: Session, sample_ids: List[int], current_user: User ) -> CommonDataResp: + storage = get_storage_backend() with 
begin_transaction(db): # delete media @@ -277,8 +414,10 @@ async def delete( db=db, attachment_ids=attachment_ids ) for attachment in attachments: - file_full_path = Path(settings.MEDIA_ROOT).joinpath(attachment.path) - os.remove(file_full_path) + storage.delete(attachment.path) + thumbnail_key = build_thumbnail_key(attachment.path) + if storage.exists(thumbnail_key): + storage.delete(thumbnail_key) crud_sample.delete(db=db, sample_ids=sample_ids) # response @@ -363,10 +502,20 @@ def _run_export_sync(job_id: int, task_id: int, export_type: ExportType, sample_ format=export_type.value, ) + storage = get_storage_backend() + if storage.is_remote: + local_export_path = Path(file_full_path) + export_key = f"{settings.EXPORT_DIR}/{local_export_path.name}" + storage.save_file(local_export_path, export_key) + local_export_path.unlink(missing_ok=True) + stored_path = export_key + else: + stored_path = str(file_full_path) + with begin_transaction(db): crud_export_job.update_status( db, job, ExportStatus.COMPLETED.value, - file_path=str(file_full_path), + file_path=stored_path, processed_count=len(data), ) except Exception as e: diff --git a/labelu/internal/common/config.py b/labelu/internal/common/config.py index 4555930a..3b896fe9 100644 --- a/labelu/internal/common/config.py +++ b/labelu/internal/common/config.py @@ -27,6 +27,24 @@ class Settings(BaseSettings): EXPORT_DIR: str = "export" UPLOAD_FILE_MAX_SIZE: int = 200_000_000 # ~200MB THUMBNAIL_HEIGH_PIXEL: int = 120 + STORAGE_BACKEND: str = "local" + + S3_ENDPOINT: str = "" + S3_REGION: str = "" + S3_BUCKET: str = "" + S3_ACCESS_KEY_ID: str = "" + S3_SECRET_ACCESS_KEY: str = "" + S3_PUBLIC_BASE_URL: str = "" + S3_PRESIGN_EXPIRE_SECONDS: int = 3600 + S3_PATH_STYLE: bool = False + S3_USE_SSL: bool = True + + AI_AUTO_LABEL_ENABLED: bool = False + AI_PROVIDER: str = "local_http" + AI_MODEL_ENDPOINT: str = "" + AI_MODEL_TIMEOUT_SECONDS: int = 60 + AI_MODEL_NAME: str = "" + AI_IMAGE_URL_EXPIRE_SECONDS: int = 300 DATABASE_URL: 
str = Field( # default="mysql://labelu:labelupass@localhost/labeludb", diff --git a/labelu/internal/common/crypto.py b/labelu/internal/common/crypto.py new file mode 100644 index 00000000..1b5c674b --- /dev/null +++ b/labelu/internal/common/crypto.py @@ -0,0 +1,38 @@ +"""Symmetric encryption utilities for storing sensitive values (e.g. S3 credentials). + +Uses Fernet (AES-128-CBC + HMAC-SHA256) with a key derived from the +application's PASSWORD_SECRET_KEY. +""" + +from __future__ import annotations + +import base64 +import hashlib + +from cryptography.fernet import Fernet + + +def _derive_key(secret: str) -> bytes: + """Derive a 32-byte Fernet-compatible key from an arbitrary secret string.""" + raw = hashlib.sha256(secret.encode()).digest() + return base64.urlsafe_b64encode(raw) + + +def _get_fernet() -> Fernet: + from labelu.internal.common.config import settings + + if not settings.PASSWORD_SECRET_KEY: + raise RuntimeError( + "PASSWORD_SECRET_KEY must be set to use credential encryption" + ) + return Fernet(_derive_key(settings.PASSWORD_SECRET_KEY)) + + +def encrypt_value(plaintext: str) -> str: + """Encrypt a string and return a URL-safe base64-encoded ciphertext.""" + return _get_fernet().encrypt(plaintext.encode()).decode() + + +def decrypt_value(ciphertext: str) -> str: + """Decrypt a URL-safe base64-encoded ciphertext back to plaintext.""" + return _get_fernet().decrypt(ciphertext.encode()).decode() diff --git a/labelu/internal/common/error_code.py b/labelu/internal/common/error_code.py index 40148121..5515e0d3 100644 --- a/labelu/internal/common/error_code.py +++ b/labelu/internal/common/error_code.py @@ -106,10 +106,50 @@ class ErrorCode(Enum): TASK_INIT_CODE + 5003, "Sample name exists", ) + CODE_56000_AUTO_LABEL_DISABLED = ( + TASK_INIT_CODE + 6000, + "AI annotation is not enabled. 
Set AI_AUTO_LABEL_ENABLED=true to activate", + ) + CODE_56001_AUTO_LABEL_UNSUPPORTED_MEDIA = ( + TASK_INIT_CODE + 6001, + "AI annotation only supports image tasks", + ) + CODE_56002_AUTO_LABEL_NO_LABELS_CONFIGURED = ( + TASK_INIT_CODE + 6002, + "No annotation tools with labels configured for AI annotation", + ) + CODE_56003_AUTO_LABEL_MODEL_ERROR = ( + TASK_INIT_CODE + 6003, + "AI model service request failed", + ) + CODE_56004_AUTO_LABEL_NOT_CONFIGURED = ( + TASK_INIT_CODE + 6004, + "AI model endpoint is not configured. Set AI_MODEL_ENDPOINT in settings", + ) + CODE_56005_AUTO_LABEL_INVALID_RESPONSE = ( + TASK_INIT_CODE + 6005, + "AI model returned an unsupported tool type in response", + ) + CODE_56006_AUTO_LABEL_NO_SAMPLES = ( + TASK_INIT_CODE + 6006, + "No unannotated samples found for batch auto-label", + ) CODE_61000_NO_DATA = ( EXPORT_INIT_CODE + 1000, "No data", ) + CODE_62000_S3_IMPORT_TOO_MANY = ( + EXPORT_INIT_CODE + 2000, + "Too many files to import (max 10000)", + ) + CODE_62001_S3_IMPORT_NO_MATCH = ( + EXPORT_INIT_CODE + 2001, + "No matching files found under the specified prefix", + ) + CODE_62002_S3_REQUEST_FAILED = ( + EXPORT_INIT_CODE + 2002, + "Failed to connect to S3 storage service", + ) @@ -128,9 +168,9 @@ def _request_context(request: Request) -> str: class LabelUException(HTTPException): def __init__( - self, code: ErrorCode, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR + self, code: ErrorCode, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, message: str = None ): - self.msg = code.value[1] + self.msg = message or code.value[1] self.code = code.value[0] super().__init__(status_code=status_code, detail=self.msg) diff --git a/labelu/internal/common/storage.py b/labelu/internal/common/storage.py new file mode 100644 index 00000000..fe797c2c --- /dev/null +++ b/labelu/internal/common/storage.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from functools import lru_cache +from pathlib 
import Path +from typing import Optional + +from PIL import Image + +from labelu.internal.common.config import settings + + +class StorageBackend(ABC): + @property + @abstractmethod + def backend_name(self) -> str: + raise NotImplementedError + + @property + def is_remote(self) -> bool: + return False + + @abstractmethod + def save_file(self, local_path: Path, key: str, content_type: Optional[str] = None) -> None: + raise NotImplementedError + + @abstractmethod + def save_bytes(self, content: bytes, key: str, content_type: Optional[str] = None) -> None: + raise NotImplementedError + + @abstractmethod + def delete(self, key: str) -> None: + raise NotImplementedError + + @abstractmethod + def exists(self, key: str) -> bool: + raise NotImplementedError + + @abstractmethod + def read_text(self, key: str, encoding: str = "utf-8") -> str: + raise NotImplementedError + + @abstractmethod + def get_local_path(self, key: str) -> Optional[Path]: + raise NotImplementedError + + @abstractmethod + def get_read_url(self, key: str, expires_in: Optional[int] = None) -> str: + raise NotImplementedError + + +class LocalStorageBackend(StorageBackend): + @property + def backend_name(self) -> str: + return "local" + + def _resolve(self, key: str) -> Path: + return settings.MEDIA_ROOT.joinpath(key.lstrip("/")) + + def save_file(self, local_path: Path, key: str, content_type: Optional[str] = None) -> None: + target_path = self._resolve(key) + target_path.parent.mkdir(parents=True, exist_ok=True) + local_path.replace(target_path) + + def save_bytes(self, content: bytes, key: str, content_type: Optional[str] = None) -> None: + target_path = self._resolve(key) + target_path.parent.mkdir(parents=True, exist_ok=True) + target_path.write_bytes(content) + + def delete(self, key: str) -> None: + file_path = self._resolve(key) + if file_path.exists(): + file_path.unlink() + + def exists(self, key: str) -> bool: + return self._resolve(key).exists() + + def read_text(self, key: str, encoding: str = 
"utf-8") -> str: + return self._resolve(key).read_text(encoding=encoding) + + def get_local_path(self, key: str) -> Optional[Path]: + return self._resolve(key) + + def get_read_url(self, key: str, expires_in: Optional[int] = None) -> str: + return build_absolute_media_url(build_attachment_api_path(key)) + + +class S3StorageBackend(StorageBackend): + def __init__(self) -> None: + import boto3 + from botocore.config import Config + + self._bucket = settings.S3_BUCKET + client_config = {} + if settings.S3_REGION: + client_config["region_name"] = settings.S3_REGION + client_config["config"] = Config( + s3={"addressing_style": "path" if settings.S3_PATH_STYLE else "auto"} + ) + + self._client = boto3.client( + "s3", + endpoint_url=settings.S3_ENDPOINT or None, + aws_access_key_id=settings.S3_ACCESS_KEY_ID or None, + aws_secret_access_key=settings.S3_SECRET_ACCESS_KEY or None, + use_ssl=settings.S3_USE_SSL, + **client_config, + ) + + @property + def backend_name(self) -> str: + return "s3" + + @property + def is_remote(self) -> bool: + return True + + def save_file(self, local_path: Path, key: str, content_type: Optional[str] = None) -> None: + extra_args = {} + if content_type: + extra_args["ContentType"] = content_type + if extra_args: + self._client.upload_file(str(local_path), self._bucket, key, ExtraArgs=extra_args) + else: + self._client.upload_file(str(local_path), self._bucket, key) + + def save_bytes(self, content: bytes, key: str, content_type: Optional[str] = None) -> None: + kwargs = {"Bucket": self._bucket, "Key": key, "Body": content} + if content_type: + kwargs["ContentType"] = content_type + self._client.put_object(**kwargs) + + def delete(self, key: str) -> None: + self._client.delete_object(Bucket=self._bucket, Key=key) + + def exists(self, key: str) -> bool: + try: + self._client.head_object(Bucket=self._bucket, Key=key) + return True + except Exception: + return False + + def read_text(self, key: str, encoding: str = "utf-8") -> str: + result = 
self._client.get_object(Bucket=self._bucket, Key=key) + return result["Body"].read().decode(encoding) + + def get_local_path(self, key: str) -> Optional[Path]: + return None + + def get_read_url(self, key: str, expires_in: Optional[int] = None) -> str: + expires = expires_in or settings.S3_PRESIGN_EXPIRE_SECONDS + return self._client.generate_presigned_url( + ClientMethod="get_object", + Params={"Bucket": self._bucket, "Key": key}, + ExpiresIn=expires, + ) + + +def build_attachment_api_path(key: str) -> str: + normalized = key.replace("\\", "/").lstrip("/") + return f"{settings.API_V1_STR}/tasks/attachment/{normalized}" + + +def build_partial_api_path(key: str) -> str: + normalized = key.replace("\\", "/").lstrip("/") + return f"{settings.API_V1_STR}/tasks/partial/{normalized}" + + +def build_absolute_media_url(relative_url: str) -> str: + return f"{settings.MEDIA_HOST.rstrip('/')}/{relative_url.lstrip('/')}" + + +def build_thumbnail_key(key: str) -> str: + path = Path(key) + return str(path.with_name(f"{path.stem}-thumbnail{path.suffix}")).replace("\\", "/") + + +def create_thumbnail_bytes(local_path: Path) -> bytes: + image = Image.open(local_path) + image.thumbnail( + ( + round(image.width / image.height * settings.THUMBNAIL_HEIGH_PIXEL), + settings.THUMBNAIL_HEIGH_PIXEL, + ), + ) + if image.mode != "RGB": + image = image.convert("RGB") + + from io import BytesIO + + buffer = BytesIO() + extension_to_format = { + ".jpg": "JPEG", + ".jpeg": "JPEG", + ".png": "PNG", + ".webp": "WEBP", + } + format_name = extension_to_format.get(local_path.suffix.lower(), "PNG") + image.save(buffer, format=format_name) + return buffer.getvalue() + + +def get_model_read_url(key: str) -> str: + backend = get_storage_backend() + if backend.is_remote: + return backend.get_read_url(key, expires_in=settings.AI_IMAGE_URL_EXPIRE_SECONDS) + return build_absolute_media_url(build_attachment_api_path(key)) + + +@lru_cache(maxsize=1) +def get_storage_backend() -> StorageBackend: + if 
settings.STORAGE_BACKEND.lower() == "s3": + return S3StorageBackend() + return LocalStorageBackend() diff --git a/labelu/internal/domain/models/__init__.py b/labelu/internal/domain/models/__init__.py index aa1c0b09..0f2ec59e 100644 --- a/labelu/internal/domain/models/__init__.py +++ b/labelu/internal/domain/models/__init__.py @@ -1,4 +1,5 @@ from .attachment import TaskAttachment +from .data_source import DataSource from .sample import TaskSample from .task import Task from .user import User diff --git a/labelu/internal/domain/models/attachment.py b/labelu/internal/domain/models/attachment.py index 35f99625..43b1062b 100644 --- a/labelu/internal/domain/models/attachment.py +++ b/labelu/internal/domain/models/attachment.py @@ -2,6 +2,8 @@ from sqlalchemy.schema import Index from sqlalchemy import Column, DateTime, ForeignKey, Integer, String +from sqlalchemy.orm import relationship +from sqlalchemy.schema import Index from labelu.internal.common.db import Base @@ -12,8 +14,12 @@ class TaskAttachment(Base): id = Column(Integer, primary_key=True, autoincrement=True, index=True) filename = Column(String(256), comment="file name") url = Column(String(256), comment="file url") - path = Column(String(256), comment="file storage path") + path = Column(String(256), comment="file storage path or S3 object key") task_id = Column(Integer, ForeignKey("task.id"), index=True) + data_source_id = Column( + Integer, ForeignKey("data_source.id"), nullable=True, index=True, + comment="NULL = local upload, set = imported from external data source", + ) created_by = Column(Integer, ForeignKey("user.id"), index=True) updated_by = Column(Integer, ForeignKey("user.id"), index=True) created_at = Column( @@ -27,4 +33,6 @@ class TaskAttachment(Base): ) deleted_at = Column(DateTime, index=True, comment="Task delete time") + data_source = relationship("DataSource", lazy="joined") + Index("idx_attachment_id_deleted_at", id, deleted_at) diff --git a/labelu/internal/domain/models/auto_label_job.py 
b/labelu/internal/domain/models/auto_label_job.py new file mode 100644 index 00000000..1a64f708 --- /dev/null +++ b/labelu/internal/domain/models/auto_label_job.py @@ -0,0 +1,28 @@ +from enum import Enum +from datetime import datetime +from sqlalchemy import Column, Integer, String, Boolean, Text, DateTime, ForeignKey +from labelu.internal.common.db import Base + + +class AutoLabelStatus(str, Enum): + PENDING = "PENDING" + PROCESSING = "PROCESSING" + COMPLETED = "COMPLETED" + FAILED = "FAILED" + + +class AutoLabelJob(Base): + __tablename__ = "auto_label_job" + + id = Column(Integer, primary_key=True, autoincrement=True, index=True) + task_id = Column(Integer, ForeignKey("task.id"), index=True) + created_by = Column(Integer, ForeignKey("user.id")) + status = Column(String(32), default=AutoLabelStatus.PENDING.value) + sample_count = Column(Integer, default=0) + processed_count = Column(Integer, default=0) + success_count = Column(Integer, default=0) + failed_count = Column(Integer, default=0) + filter_by_labels = Column(Boolean, default=True) + error_message = Column(Text, nullable=True) + created_at = Column(DateTime(timezone=True), default=datetime.now) + updated_at = Column(DateTime(timezone=True), default=datetime.now, onupdate=datetime.now) diff --git a/labelu/internal/domain/models/data_source.py b/labelu/internal/domain/models/data_source.py new file mode 100644 index 00000000..1056a580 --- /dev/null +++ b/labelu/internal/domain/models/data_source.py @@ -0,0 +1,37 @@ +from datetime import datetime + +from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String, Text +from sqlalchemy.schema import Index + +from labelu.internal.common.db import Base + + +class DataSource(Base): + __tablename__ = "data_source" + + id = Column(Integer, primary_key=True, autoincrement=True, index=True) + name = Column(String(128), nullable=False, comment="Display name") + type = Column(String(32), nullable=False, default="S3", comment="Source type: S3") + endpoint 
= Column(String(512), comment="S3 endpoint URL") + region = Column(String(64), comment="AWS region") + bucket = Column(String(256), nullable=False, comment="Bucket name") + prefix = Column(String(512), default="", comment="Default key prefix") + access_key_id = Column(String(512), comment="Encrypted access key") + secret_access_key = Column(String(1024), comment="Encrypted secret key") + path_style = Column(Boolean, default=False, comment="Use path-style addressing") + use_ssl = Column(Boolean, default=True, comment="Use HTTPS") + presign_expire_secs = Column(Integer, default=3600, comment="Presigned URL TTL") + created_by = Column(Integer, ForeignKey("user.id"), index=True) + updated_by = Column(Integer, ForeignKey("user.id"), index=True) + created_at = Column( + DateTime(timezone=True), default=datetime.now, comment="Created time" + ) + updated_at = Column( + DateTime(timezone=True), + default=datetime.now, + onupdate=datetime.now, + comment="Updated time", + ) + deleted_at = Column(DateTime, index=True, comment="Soft delete time") + + Index("idx_datasource_id_deleted_at", id, deleted_at) diff --git a/model_server/README.md b/model_server/README.md new file mode 100644 index 00000000..8fd68b2d --- /dev/null +++ b/model_server/README.md @@ -0,0 +1,164 @@ +# LabelU Model Server + +LabelU 自动标注功能的参考模型服务实现。提供三套方案,均实现统一的 API 协议。 + +## 方案对比 + +| | Florence-2 | GroundingDINO + EfficientSAM | SAM 3 | +|---|---|---|---| +| 架构 | 单模型(检测+分割分步) | 两模型串联 | 单模型统一 | +| 开放词汇 | 支持 | 支持 | 支持(400万+概念) | +| 检测 (rectTool) | 一般 | 很好 | 很好 | +| 分割 (polygonTool) | 支持 | 精度高 | 精度最高 | +| 最低显存 | ~4GB (base) | ~4GB (tiny + vitt) | ~8GB | +| CPU 可用 | 可以(较慢) | 不推荐 | 不支持 | +| 模型大小 | ~500MB | ~900MB | ~1.7GB (848M params) | +| 部署复杂度 | 低 | 中 | 中 | +| Python 要求 | 3.8+ | 3.8+ | **3.12+** | +| PyTorch 要求 | 2.1+ | 2.1+ | **2.7+** | +| CUDA 要求 | 可选 | 推荐 | **12.6+(必须)** | +| GPU 算力要求 | 无限制 | 无限制 | **SM 80+(Ampere/Ada)** | + +**推荐**: +- 有 4090/A100 等高端 GPU → **SAM 3**(质量最好,单模型最简单) +- 有中端 GPU(如 1660/2060) → 
**GroundingDINO + EfficientSAM** +- 只有 CPU 或显存紧张 → **Florence-2** + +## 快速启动 + +### Florence-2 + +```bash +cd model_server/florence2 +pip install -r requirements.txt +python server.py --device cpu --port 5000 +``` + +### GroundingDINO + EfficientSAM + +```bash +cd model_server/grounding_dino_sam +pip install -r requirements.txt +python server.py --device cuda --port 5000 + +# 仅检测(不加载 SAM,节省显存) +python server.py --device cuda --port 5000 --no-sam +``` + +### SAM 3 + +```bash +# 需要 Python 3.12+, CUDA 12.6+ +conda create -n sam3 python=3.12 +conda activate sam3 +pip install torch==2.10.0 torchvision --index-url https://download.pytorch.org/whl/cu128 + +cd model_server/sam3 +pip install -r requirements.txt +python server.py --device cuda --port 5000 +``` + +### Docker + +```bash +# Florence-2 +cd model_server/florence2 +docker build -t labelu-florence2 . +docker run -p 5000:5000 labelu-florence2 + +# GroundingDINO + EfficientSAM (GPU) +cd model_server/grounding_dino_sam +docker build -t labelu-dino-sam . +docker run --gpus all -p 5000:5000 labelu-dino-sam python server.py --device cuda + +# SAM 3 (GPU, CUDA 12.6+) +cd model_server/sam3 +docker build -t labelu-sam3 . 
+docker run --gpus all -p 5000:5000 labelu-sam3 +``` + +## 配置 LabelU 连接 + +在 LabelU 的 `.env` 或环境变量中设置: + +```env +AI_AUTO_LABEL_ENABLED=true +AI_MODEL_ENDPOINT=http://localhost:5000/ +AI_MODEL_TIMEOUT_SECONDS=60 +AI_MODEL_NAME=florence-2-base # 或 grounding-dino-tiny+efficient-sam / sam3 +``` + +## API 协议 + +### POST / + +**请求:** + +```json +{ + "request_id": "uuid", + "image_url": "https://presigned-or-public-url/image.jpg", + "task": { "id": 10, "name": "demo-task" }, + "labels": [ + { "name": "car", "tool": "rectTool" }, + { "name": "person", "tool": "polygonTool" } + ], + "constraints": { + "allowed_tools": ["rectTool", "polygonTool"], + "max_results_per_label": 100, + "filter_by_labels": false + }, + "prompt": null +} +``` + +**constraints 字段说明:** + +| 字段 | 类型 | 默认值 | 说明 | +|------|------|--------|------| +| `allowed_tools` | string[] | [] | 限制返回的工具类型,为空不限制 | +| `max_results_per_label` | int | 100 | 每个标签最多返回的结果数 | +| `filter_by_labels` | bool | false | 是否过滤掉不在 labels 列表中的检测结果 | + +`filter_by_labels` 说明: +- `false`(默认):模型检测到什么就返回什么,标签名作为 prompt 提示但不做过滤 +- `true`:仅返回与配置标签名精确匹配的结果 + +**响应:** + +```json +{ + "model": "microsoft/Florence-2-base", + "latency_ms": 1840, + "results": [ + { + "toolName": "rectTool", + "label": "car", + "result": { "x": 120, "y": 80, "width": 260, "height": 140 }, + "score": 0.94 + }, + { + "toolName": "polygonTool", + "label": "person", + "result": { + "type": "line", + "points": [ + { "x": 50, "y": 100 }, + { "x": 80, "y": 200 }, + { "x": 30, "y": 200 } + ] + }, + "score": 0.87 + } + ], + "warning_message": null +} +``` + +### GET /health + +返回模型加载状态。 + +## 自定义模型 + +如果要接入其他模型,只需实现上述 API 协议即可。LabelU 后端通过 HTTP POST 调用模型服务,不感知具体模型实现。 diff --git a/model_server/florence2/Dockerfile b/model_server/florence2/Dockerfile new file mode 100644 index 00000000..b1d4cfbb --- /dev/null +++ b/model_server/florence2/Dockerfile @@ -0,0 +1,15 @@ +FROM python:3.11-slim + +WORKDIR /app +COPY requirements.txt . 
+RUN pip install --no-cache-dir -r requirements.txt + +COPY server.py . + +# Pre-download model at build time (optional, ~500MB) +RUN python -c "from transformers import AutoModelForCausalLM, AutoProcessor; \ + AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True); \ + AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)" + +EXPOSE 5000 +CMD ["python", "server.py", "--device", "cpu"] diff --git a/model_server/florence2/requirements.txt b/model_server/florence2/requirements.txt new file mode 100644 index 00000000..af1c9f45 --- /dev/null +++ b/model_server/florence2/requirements.txt @@ -0,0 +1,10 @@ +torch>=2.1.0 +transformers>=4.40.0,<=4.51.3 +fastapi>=0.100.0 +uvicorn>=0.19.0 +httpx>=0.27.0 +pillow>=9.3.0 +numpy>=1.24.0 +opencv-python-headless>=4.8.0 +einops +timm diff --git a/model_server/florence2/server.py b/model_server/florence2/server.py new file mode 100644 index 00000000..d7f6f1e1 --- /dev/null +++ b/model_server/florence2/server.py @@ -0,0 +1,316 @@ +""" +LabelU Model Server — Florence-2 + +Lightweight reference implementation using Microsoft Florence-2 for +open-vocabulary image detection and segmentation. + +Implements the LabelU auto-label model API protocol: + POST / → { request_id, image_url, labels, constraints, prompt } + Returns → { model, results: [{ toolName, label, result, score }], ... 
def _run_florence(image: Image.Image, task_prompt: str, text_input: str = "") -> dict:
    """Run a single Florence-2 task on ``image`` and return the parsed output.

    Args:
        image: RGB image to analyse.
        task_prompt: Florence-2 task token, e.g. ``"<OD>"``.
        text_input: Optional text appended to the task token (used by the
            grounding and referring-expression tasks).

    Returns:
        The dict produced by ``processor.post_process_generation``; callers
        look the result up under ``task_prompt``.
    """
    full_prompt = task_prompt if not text_input else task_prompt + text_input
    # Move the inputs to the model's device AND dtype. main() loads the model
    # in float16 on CUDA; leaving pixel_values in float32 makes generate() fail
    # with an input/weight dtype mismatch. Integer tensors (input_ids) are not
    # affected by the dtype argument.
    inputs = processor(text=full_prompt, images=image, return_tensors="pt").to(
        device, model.dtype
    )
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=1024,
            num_beams=3,
        )
    # Special tokens must be kept: the post-processor parses location tokens.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    return processor.post_process_generation(
        generated_text, task=task_prompt, image_size=image.size
    )


def _bbox_to_rect(bbox: list[float]) -> dict[str, float]:
    """Convert [x1, y1, x2, y2] to {x, y, width, height}."""
    x1, y1, x2, y2 = bbox
    return {"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1}


def _mask_to_polygon_points(mask: np.ndarray) -> list[dict[str, float]]:
    """Extract the largest contour from a binary mask as polygon points.

    Returns an empty list when the mask contains no contour.
    """
    # Imported lazily so the server can start without OpenCV when no mask
    # conversion is requested.
    import cv2

    contours, _ = cv2.findContours(
        mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if not contours:
        return []
    largest = max(contours, key=cv2.contourArea)
    # Simplify the outline with a tolerance of 0.5% of its perimeter.
    epsilon = 0.005 * cv2.arcLength(largest, True)
    approx = cv2.approxPolyDP(largest, epsilon, True)
    return [{"x": float(pt[0][0]), "y": float(pt[0][1])} for pt in approx]


def _detect_with_od(
    image: Image.Image,
    default_tool: str,
) -> list[ResultItem]:
    """Use <OD> for general object detection (Florence-2 built-in categories).

    Every detection is reported with ``default_tool`` and a lower-cased label;
    Florence-2 does not provide confidence scores, so ``score`` is ``None``.
    """
    task = "<OD>"
    data = _run_florence(image, task).get(task, {})
    return [
        ResultItem(
            toolName=default_tool,
            label=det_label.strip().lower(),
            result=_bbox_to_rect(bbox),
            score=None,
        )
        for bbox, det_label in zip(data.get("bboxes", []), data.get("labels", []))
    ]


def _detect_with_grounding(
    image: Image.Image,
    labels: list[LabelItem],
    max_per_label: int,
    filter_by_labels: bool = False,
) -> list[ResultItem]:
    """Use <OD> as baseline, then <CAPTION_TO_PHRASE_GROUNDING> for custom labels.

    Args:
        image: RGB image to analyse.
        labels: Labels configured for the task; they drive the grounding
            prompt and the label -> tool mapping.
        max_per_label: Maximum number of results returned for any one label.
        filter_by_labels: When true, drop detections whose label is not one of
            the configured label names.

    Returns:
        Merged <OD> and grounding detections, capped at ``max_per_label`` per
        label, in detection order.
    """
    # Normalise keys the same way detected labels are normalised below.
    label_tool_map = {lbl.name.strip().lower(): lbl.tool for lbl in labels}
    default_tool = labels[0].tool if labels else "rectTool"

    # Step 1: run <OD> to get all standard detections.
    od_results = _detect_with_od(image, default_tool)
    od_labels_found = {r.label for r in od_results}
    logger.info("<OD> found %d objects: %s", len(od_results), od_labels_found)

    # Step 2: configured labels that <OD> did not report are queried
    # explicitly through phrase grounding.
    uncovered = [lbl for lbl in labels if lbl.name.strip().lower() not in od_labels_found]
    grounding_results: list[ResultItem] = []
    if uncovered:
        task = "<CAPTION_TO_PHRASE_GROUNDING>"
        caption = ". ".join(lbl.name for lbl in uncovered)
        raw = _run_florence(image, task, caption)
        logger.info("<CAPTION_TO_PHRASE_GROUNDING> raw: %s", raw)
        data = raw.get(task, {})
        for bbox, det_label in zip(data.get("bboxes", []), data.get("labels", [])):
            det_clean = det_label.strip().lower()
            grounding_results.append(ResultItem(
                toolName=label_tool_map.get(det_clean, default_tool),
                label=det_clean,
                result=_bbox_to_rect(bbox),
                score=None,
            ))

    merged = od_results + grounding_results
    if filter_by_labels:
        merged = [r for r in merged if r.label in label_tool_map]

    # Apply the configured tool for known labels and enforce the per-label cap
    # (previously the cap was accepted but never applied).
    capped: list[ResultItem] = []
    per_label: dict[str, int] = {}
    for item in merged:
        if item.label in label_tool_map:
            item.toolName = label_tool_map[item.label]
        seen = per_label.get(item.label, 0)
        if seen >= max_per_label:
            continue
        per_label[item.label] = seen + 1
        capped.append(item)
    return capped


def _segment_for_labels(
    image: Image.Image,
    labels: list[LabelItem],
    max_per_label: int,
) -> list[ResultItem]:
    """Use <REFERRING_EXPRESSION_SEGMENTATION> per label for polygon output.

    Only labels whose tool is ``polygonTool`` are queried. Outlines with fewer
    than three points are discarded.
    """
    task = "<REFERRING_EXPRESSION_SEGMENTATION>"
    results: list[ResultItem] = []
    for lbl in labels:
        if lbl.tool != "polygonTool":
            continue
        seg_data = _run_florence(image, task, lbl.name).get(task, {})
        for polygon in seg_data.get("polygons", [])[:max_per_label]:
            if not isinstance(polygon, list) or not polygon:
                continue
            # An entry is either a flat [x0, y0, x1, y1, ...] list or a list
            # of such lists; in the nested case only the first outline is used.
            flat = polygon[0] if isinstance(polygon[0], list) else polygon
            # zip() pairs coordinates and ignores a trailing unpaired value
            # instead of raising IndexError on odd-length input.
            points = [
                {"x": float(x), "y": float(y)}
                for x, y in zip(flat[0::2], flat[1::2])
            ]
            if len(points) >= 3:
                results.append(
                    ResultItem(
                        toolName="polygonTool",
                        label=lbl.name,
                        result={"type": "line", "points": points},
                    )
                )
    return results
def main():
    """Parse CLI flags, load Florence-2 into the module globals and serve the app.

    Sets ``model``, ``processor``, ``device`` and ``MODEL_NAME`` before handing
    control to uvicorn, so request handlers can rely on them being populated.
    """
    global model, processor, device, MODEL_NAME

    # Prefer the GPU when one is visible; the flag can still override this.
    fallback_device = "cuda" if torch.cuda.is_available() else "cpu"

    cli = argparse.ArgumentParser(description="LabelU Florence-2 Model Server")
    cli.add_argument("--port", type=int, default=5000)
    cli.add_argument("--host", type=str, default="0.0.0.0")
    cli.add_argument("--device", type=str, default=fallback_device)
    cli.add_argument("--model", type=str, default=MODEL_NAME)
    opts = cli.parse_args()

    device, MODEL_NAME = opts.device, opts.model

    logger.info("Loading %s on %s ...", MODEL_NAME, device)
    # NOTE(review): trust_remote_code runs Python shipped with the model repo;
    # only point --model at repositories you trust.
    processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)

    # Half precision on CUDA devices, full precision everywhere else.
    if "cuda" in device:
        weights_dtype = torch.float16
    else:
        weights_dtype = torch.float32
    loaded = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        torch_dtype=weights_dtype,
    )
    model = loaded.to(device).eval()
    logger.info("Model loaded.")

    uvicorn.run(app, host=opts.host, port=opts.port)
+ +# Pre-download models at build time +RUN python -c "\ +from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection; \ +AutoProcessor.from_pretrained('IDEA-Research/grounding-dino-tiny'); \ +AutoModelForZeroShotObjectDetection.from_pretrained('IDEA-Research/grounding-dino-tiny'); \ +from transformers import EfficientSamModel, SamImageProcessor; \ +SamImageProcessor.from_pretrained('ybelkada/efficient-sam-vitt'); \ +EfficientSamModel.from_pretrained('ybelkada/efficient-sam-vitt')" + +EXPOSE 5000 +CMD ["python", "server.py", "--device", "cpu"] diff --git a/model_server/grounding_dino_sam/requirements.txt b/model_server/grounding_dino_sam/requirements.txt new file mode 100644 index 00000000..bf9b11e7 --- /dev/null +++ b/model_server/grounding_dino_sam/requirements.txt @@ -0,0 +1,8 @@ +torch>=2.1.0 +transformers>=4.40.0 +fastapi>=0.100.0 +uvicorn>=0.19.0 +httpx>=0.27.0 +pillow>=9.3.0 +numpy>=1.24.0 +opencv-python-headless>=4.8.0 diff --git a/model_server/grounding_dino_sam/server.py b/model_server/grounding_dino_sam/server.py new file mode 100644 index 00000000..fb00809a --- /dev/null +++ b/model_server/grounding_dino_sam/server.py @@ -0,0 +1,268 @@ +""" +LabelU Model Server — GroundingDINO + EfficientSAM + +High-quality reference implementation for open-vocabulary detection +(GroundingDINO) paired with segmentation (EfficientSAM). + +Implements the LabelU auto-label model API protocol: + POST / → { request_id, image_url, labels, constraints, prompt } + Returns → { model, results: [{ toolName, label, result, score }], ... 
} + +Quick start: + pip install -r requirements.txt + python server.py # default: 0.0.0.0:5000 + python server.py --port 5001 --device cuda +""" + +from __future__ import annotations + +import argparse +import io +import time +import logging +from typing import Any + +import cv2 +import httpx +import numpy as np +import torch +import uvicorn +from fastapi import FastAPI, HTTPException +from PIL import Image +from pydantic import BaseModel, Field + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("dino-sam-server") + +app = FastAPI(title="LabelU GroundingDINO + EfficientSAM Model Server") + +# ── globals ─────────────────────────────────────────────────────────── +dino_model = None +dino_processor = None +sam_model = None +sam_processor = None +device = "cpu" + +DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny" +SAM_MODEL_ID = "ybelkada/efficient-sam-vitt" +MODEL_LABEL = "grounding-dino-tiny+efficient-sam" + +BOX_THRESHOLD = 0.25 +TEXT_THRESHOLD = 0.25 + + +# ── request / response schemas ──────────────────────────────────────── +class LabelItem(BaseModel): + name: str + display_name: str | None = None + color: str | None = None + tool: str + + +class Constraints(BaseModel): + allowed_tools: list[str] = [] + max_results_per_label: int = 100 + filter_by_labels: bool = False + + +class PredictRequest(BaseModel): + request_id: str = "" + image_url: str + task: dict[str, Any] = {} + labels: list[LabelItem] = [] + constraints: Constraints = Field(default_factory=Constraints) + prompt: str | None = None + + +class ResultItem(BaseModel): + toolName: str + label: str + result: dict[str, Any] + score: float | None = None + + +class PredictResponse(BaseModel): + model: str = MODEL_LABEL + latency_ms: int = 0 + results: list[ResultItem] = [] + warning_message: str | None = None + + +# ── helpers ─────────────────────────────────────────────────────────── +async def _download_image(url: str) -> Image.Image: + async with httpx.AsyncClient(timeout=30, 
def _detect_objects(
    image: Image.Image,
    text_prompt: str,
    box_threshold: float | None = None,
    text_threshold: float | None = None,
) -> tuple[list[list[float]], list[str], list[float]]:
    """Run GroundingDINO detection. Returns (boxes_xyxy_abs, labels, scores).

    Args:
        image: RGB image to analyse.
        text_prompt: Dot-separated query string passed to the processor.
        box_threshold: Minimum box confidence; ``None`` uses the current
            module-level ``BOX_THRESHOLD``.
        text_threshold: Minimum text-match confidence; ``None`` uses the
            current module-level ``TEXT_THRESHOLD``.

    Returns:
        A tuple of absolute ``[x1, y1, x2, y2]`` boxes, the matched label
        strings and the per-box scores, all in the same order.
    """
    # Resolve the thresholds at call time. Using the globals as parameter
    # defaults froze them at import, so the --box-threshold/--text-threshold
    # values assigned in main() were silently ignored.
    if box_threshold is None:
        box_threshold = BOX_THRESHOLD
    if text_threshold is None:
        text_threshold = TEXT_THRESHOLD

    inputs = dino_processor(images=image, text=text_prompt, return_tensors="pt").to(device)
    with torch.inference_mode():
        outputs = dino_model(**inputs)

    post = dino_processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        # PIL reports (width, height); the post-processor expects (height, width).
        target_sizes=[image.size[::-1]],
    )[0]

    boxes = post["boxes"].cpu().numpy().tolist()
    labels_out = post["labels"]
    scores = post["scores"].cpu().numpy().tolist()
    return boxes, labels_out, scores
@app.post("/", response_model=PredictResponse)
async def predict(req: PredictRequest) -> PredictResponse:
    """Detect (and optionally segment) the requested labels in one image.

    The image is downloaded from ``req.image_url``, queried with GroundingDINO
    using the configured label names, and each detection is returned either as
    a rectangle or, for ``polygonTool`` labels when EfficientSAM is loaded, as
    a polygon traced from the predicted mask.

    Raises:
        HTTPException: 400 when the image cannot be downloaded or decoded.
    """
    start = time.perf_counter()
    try:
        image = await _download_image(req.image_url)
    except Exception as exc:
        raise HTTPException(
            status_code=400, detail=f"Failed to download image: {exc}"
        ) from exc

    # Normalised name -> tool. The dict also de-duplicates names while keeping
    # their order, so it doubles as the query list and the filter set.
    label_tool_map = {lbl.name.lower().strip(): lbl.tool for lbl in req.labels}
    max_per = req.constraints.max_results_per_label
    allowed = set(req.constraints.allowed_tools) if req.constraints.allowed_tools else None
    filter_labels = req.constraints.filter_by_labels
    warning = None if sam_model else "EfficientSAM not loaded; polygon results use bounding boxes"

    results: list[ResultItem] = []

    queries = [name for name in label_tool_map if name]
    if queries:
        # GroundingDINO expects lower-cased queries separated by " . " and
        # terminated by a dot; detected labels are compared in the same form.
        text_prompt = " . ".join(queries) + " ."
        boxes, det_labels, scores = _detect_objects(image, text_prompt)
    else:
        # Nothing to look for: skip inference instead of querying with " .".
        boxes, det_labels, scores = [], [], []

    counts: dict[str, int] = {}
    for box, det_label, score in zip(boxes, det_labels, scores):
        clean = det_label.strip().lower()
        if filter_labels and clean not in label_tool_map:
            continue

        tool = label_tool_map.get(clean, "rectTool")
        if allowed and tool not in allowed:
            continue

        # Count only detections that will actually be returned, so results
        # dropped by the filters above do not use up the per-label quota.
        if counts.get(clean, 0) >= max_per:
            continue
        counts[clean] = counts.get(clean, 0) + 1

        if tool == "polygonTool" and sam_model is not None:
            mask = _segment_box(image, box)
            if mask is not None:
                points = _mask_to_polygon(mask)
                if len(points) >= 3:
                    results.append(ResultItem(
                        toolName="polygonTool",
                        label=clean,
                        result={"type": "line", "points": points},
                        score=round(score, 4),
                    ))
                    continue

        # Rectangle output; also the fallback when no usable polygon exists.
        results.append(ResultItem(
            toolName=tool,
            label=clean,
            result=_bbox_to_rect(box),
            score=round(score, 4),
        ))

    latency = int((time.perf_counter() - start) * 1000)
    return PredictResponse(model=MODEL_LABEL, latency_ms=latency, results=results, warning_message=warning)
"cpu") + parser.add_argument("--no-sam", action="store_true", help="Disable EfficientSAM (detection only)") + parser.add_argument("--box-threshold", type=float, default=BOX_THRESHOLD) + parser.add_argument("--text-threshold", type=float, default=TEXT_THRESHOLD) + args = parser.parse_args() + + device = args.device + global BOX_THRESHOLD, TEXT_THRESHOLD + BOX_THRESHOLD = args.box_threshold + TEXT_THRESHOLD = args.text_threshold + + from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor as AP + + logger.info("Loading GroundingDINO (%s) on %s ...", DINO_MODEL_ID, device) + dino_processor = AP.from_pretrained(DINO_MODEL_ID) + dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(DINO_MODEL_ID).to(device).eval() + logger.info("GroundingDINO loaded.") + + if not args.no_sam: + from transformers import EfficientSamModel, SamImageProcessor + logger.info("Loading EfficientSAM (%s) on %s ...", SAM_MODEL_ID, device) + sam_processor = SamImageProcessor.from_pretrained(SAM_MODEL_ID) + sam_model = EfficientSamModel.from_pretrained(SAM_MODEL_ID).to(device).eval() + logger.info("EfficientSAM loaded.") + else: + logger.info("EfficientSAM disabled (--no-sam).") + + uvicorn.run(app, host=args.host, port=args.port) + + +if __name__ == "__main__": + main() diff --git a/model_server/sam3/Dockerfile b/model_server/sam3/Dockerfile new file mode 100644 index 00000000..d8f70c44 --- /dev/null +++ b/model_server/sam3/Dockerfile @@ -0,0 +1,16 @@ +FROM python:3.12-slim + +RUN apt-get update && apt-get install -y --no-install-recommends \ + git build-essential libgl1 libglib2.0-0 && \ + rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY server.py . 
+ +EXPOSE 5000 + +CMD ["python", "server.py", "--device", "cuda"] diff --git a/model_server/sam3/requirements.txt b/model_server/sam3/requirements.txt new file mode 100644 index 00000000..5423a4f1 --- /dev/null +++ b/model_server/sam3/requirements.txt @@ -0,0 +1,9 @@ +torch>=2.7.0 +torchvision +sam3 @ git+https://github.com/facebookresearch/sam3.git +fastapi>=0.100.0 +uvicorn>=0.19.0 +httpx>=0.27.0 +pillow>=9.3.0 +numpy>=1.24.0 +opencv-python-headless>=4.8.0 diff --git a/model_server/sam3/server.py b/model_server/sam3/server.py new file mode 100644 index 00000000..aa3e2262 --- /dev/null +++ b/model_server/sam3/server.py @@ -0,0 +1,214 @@ +""" +LabelU Model Server — SAM 3 (Segment Anything with Concepts) + +Uses Meta SAM 3 for unified open-vocabulary detection + segmentation. +Single model replaces GroundingDINO + SAM pipeline. + +Implements the LabelU auto-label model API protocol: + POST / -> { request_id, image_url, labels, constraints, prompt } + Returns -> { model, results: [{ toolName, label, result, score }], ... 
def _bbox_to_rect(box: list[float] | torch.Tensor) -> dict[str, float]:
    """Convert [x1, y1, x2, y2] to {x, y, width, height}.

    Accepts a plain sequence, a torch tensor or any array-like exposing
    ``tolist()``. Values are coerced to built-in floats so the response never
    carries tensor or NumPy scalars.
    """
    if isinstance(box, torch.Tensor):
        box = box.cpu().tolist()
    elif hasattr(box, "tolist"):
        box = box.tolist()
    x1, y1, x2, y2 = (float(v) for v in box)
    return {"x": x1, "y": y1, "width": x2 - x1, "height": y2 - y1}


def _mask_to_polygon(mask: np.ndarray | torch.Tensor) -> list[dict[str, float]]:
    """Extract the largest contour as simplified polygon points.

    Returns an empty list when the mask has no foreground contour.
    """
    if isinstance(mask, torch.Tensor):
        mask = mask.cpu().numpy()
    # findContours needs a 2-D image. Collapse any leading dimensions (for
    # example a [1, H, W] per-instance mask) and keep the first plane.
    # NOTE(review): assumes the last two axes are (H, W) - confirm against the
    # SAM 3 processor output.
    if mask.ndim > 2:
        mask = mask.reshape(-1, *mask.shape[-2:])[0]
    mask_uint8 = (mask > 0).astype(np.uint8)
    contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return []
    largest = max(contours, key=cv2.contourArea)
    # Simplify the outline with a tolerance of 0.5% of its perimeter.
    epsilon = 0.005 * cv2.arcLength(largest, True)
    approx = cv2.approxPolyDP(largest, epsilon, True)
    return [{"x": float(pt[0][0]), "y": float(pt[0][1])} for pt in approx]


def _detect_and_segment(
    image: Image.Image,
    label: LabelItem,
    max_per_label: int,
) -> list[ResultItem]:
    """Run SAM 3 text-prompted detection + segmentation for a single label.

    Args:
        image: RGB image to analyse.
        label: Label whose ``name`` is used as the text prompt and whose
            ``tool`` selects polygon or rectangle output.
        max_per_label: Maximum number of results to return.

    Returns:
        Up to ``max_per_label`` results, highest score first when scores are
        available. Polygon labels fall back to a bounding box when no usable
        outline (at least three points) can be traced.
    """
    inference_state = sam3_processor.set_image(image)
    output = sam3_processor.set_text_prompt(
        state=inference_state,
        prompt=label.name,
    )

    masks = output.get("masks")
    boxes = output.get("boxes")    # indexed per detection, xyxy
    scores = output.get("scores")  # indexed per detection

    if boxes is None or len(boxes) == 0:
        return []

    # Keep the best-scoring detections when truncating, not simply the first
    # ones. The sort is stable, so ties keep the model's order.
    order = list(range(len(boxes)))
    if scores is not None and len(scores) == len(boxes):
        order.sort(key=lambda idx: float(scores[idx]), reverse=True)
    limit = max(0, min(len(order), max_per_label))

    results: list[ResultItem] = []
    for i in order[:limit]:
        score = round(float(scores[i]), 4) if scores is not None and i < len(scores) else None

        if label.tool == "polygonTool" and masks is not None and i < len(masks):
            points = _mask_to_polygon(masks[i])
            if len(points) >= 3:
                results.append(ResultItem(
                    toolName="polygonTool",
                    label=label.name,
                    result={"type": "line", "points": points},
                    score=score,
                ))
                continue

        # Default: bounding box
        results.append(ResultItem(
            toolName=label.tool,
            label=label.name,
            result=_bbox_to_rect(boxes[i]),
            score=score,
        ))

    return results
@app.post("/", response_model=PredictResponse)
async def predict(req: PredictRequest) -> PredictResponse:
    """Run SAM 3 once per requested label and return the combined results.

    Labels whose tool is excluded by ``constraints.allowed_tools`` are skipped;
    an empty ``allowed_tools`` list means no restriction. A failed image
    download is reported as HTTP 400.
    """
    t0 = time.perf_counter()
    try:
        image = await _download_image(req.image_url)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Failed to download image: {exc}")

    constraints = req.constraints
    # An empty set is falsy and therefore behaves as "every tool allowed".
    tool_filter = set(constraints.allowed_tools)
    per_label_cap = constraints.max_results_per_label

    collected: list[ResultItem] = []
    for lbl in req.labels:
        if not tool_filter or lbl.tool in tool_filter:
            collected += _detect_and_segment(image, lbl, per_label_cap)

    elapsed_ms = int((time.perf_counter() - t0) * 1000)
    return PredictResponse(model=MODEL_NAME, latency_ms=elapsed_ms, results=collected)
dependencies = [ "fastapi>=0.100.0", "loguru>=0.6.0", "sqlalchemy>=2.0.0", + "boto3", + "cryptography>=41.0.0", "python-jose[cryptography]>=3.3.0", "pydantic>=2.0.0", "pydantic-settings>=2.0.0", diff --git a/uv.lock b/uv.lock index 24e1237d..eee82f61 100644 --- a/uv.lock +++ b/uv.lock @@ -156,6 +156,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8e/0d/52d98722666d6fc6c3dd4c76df339501d6efd40e0ff95e6186a7b7f0befd/black-26.3.1-py3-none-any.whl", hash = "sha256:2bd5aa94fc267d38bb21a70d7410a89f1a1d318841855f698746f8e7f51acd1b", size = 207542, upload-time = "2026-03-12T03:36:01.668Z" }, ] +[[package]] +name = "boto3" +version = "1.42.90" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore" }, + { name = "jmespath" }, + { name = "s3transfer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/62/17/510f31d7d6190c01725710d95415733e467f4406d450a106f6eacfd3a94d/boto3-1.42.90.tar.gz", hash = "sha256:bafb5bb1dea262ac95f9afb1e415f06a9490f05cb203bdd897d0afdcd17733c6", size = 113174, upload-time = "2026-04-16T20:27:43.012Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e1/bd/a0c5011a8eddce39a9b613b13a75057bf960ef2145ff4d1583ed81a2599b/boto3-1.42.90-py3-none-any.whl", hash = "sha256:fde7f7bcad6ec8342d6bf18f56d118d0cb6df189310cfaf73e2eb6443b1cb418", size = 140554, upload-time = "2026-04-16T20:27:40.226Z" }, +] + +[[package]] +name = "botocore" +version = "1.42.90" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jmespath" }, + { name = "python-dateutil" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/86/cf/4eaa0b7ab2ba0b2a0c93e277779b1385127e2f07876a08d698b529affdae/botocore-1.42.90.tar.gz", hash = "sha256:234c39492cd3088acb021d999e3392a4d50238ae3e70b9d9ae1504c30d9009d1", size = 15209231, upload-time = "2026-04-16T20:27:29.323Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/fc/1e/44afcdc3b526b6e1569dd142083c6ed1cb8b92b4141de1c78ded883b449a/botocore-1.42.90-py3-none-any.whl", hash = "sha256:5c95504720346990adc8e3ae1023eb46f9409084b79688e4773ba7099c5fd3db", size = 14892274, upload-time = "2026-04-16T20:27:24.057Z" }, +] + [[package]] name = "certifi" version = "2026.2.25" @@ -745,15 +773,26 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, ] +[[package]] +name = "jmespath" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d3/59/322338183ecda247fb5d1763a6cbe46eff7222eaeebafd9fa65d4bf5cb11/jmespath-1.1.0.tar.gz", hash = "sha256:472c87d80f36026ae83c6ddd0f1d05d4e510134ed462851fd5f754c8c3cbb88d", size = 27377, upload-time = "2026-01-22T16:35:26.279Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/14/2f/967ba146e6d58cf6a652da73885f52fc68001525b4197effc174321d70b4/jmespath-1.1.0-py3-none-any.whl", hash = "sha256:a5663118de4908c91729bea0acadca56526eb2698e83de10cd116ae0f4e97c64", size = 20419, upload-time = "2026-01-22T16:35:24.919Z" }, +] + [[package]] name = "labelu" -version = "1.3.4" +version = "1.3.6" source = { editable = "." 
} dependencies = [ { name = "aiofiles" }, { name = "alembic" }, { name = "appdirs" }, { name = "bcrypt" }, + { name = "boto3" }, + { name = "cryptography" }, { name = "email-validator" }, { name = "fastapi" }, { name = "httpx" }, @@ -793,6 +832,8 @@ requires-dist = [ { name = "alembic", specifier = ">=1.9.4" }, { name = "appdirs", specifier = ">=1.4.4" }, { name = "bcrypt", specifier = "==4.3.0" }, + { name = "boto3" }, + { name = "cryptography", specifier = ">=41.0.0" }, { name = "email-validator", specifier = ">=2.0.0" }, { name = "fastapi", specifier = ">=0.100.0" }, { name = "httpx", specifier = ">=0.27.0" }, @@ -1371,6 +1412,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ee/49/1377b49de7d0c1ce41292161ea0f721913fa8722c19fb9c1e3aa0367eecb/pytest_cov-7.0.0-py3-none-any.whl", hash = "sha256:3b8e9558b16cc1479da72058bdecf8073661c7f57f7d3c5f22a1c23507f2d861", size = 22424, upload-time = "2025-09-09T10:57:00.695Z" }, ] +[[package]] +name = "python-dateutil" +version = "2.9.0.post0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/66/c0/0c8b6ad9f17a802ee498c46e004a0eb49bc148f2fd230864601a86dcf6db/python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3", size = 342432, upload-time = "2024-03-01T18:36:20.211Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, +] + [[package]] name = "python-dotenv" version = "1.2.2" @@ -1469,6 +1522,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/64/8d/0133e4eb4beed9e425d9a98ed6e081a55d195481b7632472be1af08d2f6b/rsa-4.9.1-py3-none-any.whl", hash = 
"sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762", size = 34696, upload-time = "2025-04-16T09:51:17.142Z" }, ] +[[package]] +name = "s3transfer" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "botocore" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/05/04/74127fc843314818edfa81b5540e26dd537353b123a4edc563109d8f17dd/s3transfer-0.16.0.tar.gz", hash = "sha256:8e990f13268025792229cd52fa10cb7163744bf56e719e0b9cb925ab79abf920", size = 153827, upload-time = "2025-12-01T02:30:59.114Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl", hash = "sha256:18e25d66fed509e3868dc1572b3f427ff947dd2c56f844a5bf09481ad3f3b2fe", size = 86830, upload-time = "2025-12-01T02:30:57.729Z" }, +] + [[package]] name = "six" version = "1.17.0"