Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
293 changes: 292 additions & 1 deletion backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import shutil
import asyncio
import time
import csv
import io
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from pathlib import Path
Expand All @@ -14,7 +16,7 @@
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from fastapi import FastAPI, Query, Request
from fastapi import FastAPI, Query, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from db import delete_pull_proxy as db_delete_pull_proxy
from db import delete_record_policy as db_delete_record_policy
Expand Down Expand Up @@ -202,6 +204,77 @@ def _stream_proxy_key(vhost: str, app: str, stream: str) -> str:
return f"{vhost}/{app}/{stream}"


_APP_STREAM_PATTERN = re.compile(r"^[a-zA-Z0-9._-]+$")
_URL_PREFIXES = ("rtsp://", "rtmp://", "http://", "https://")
_IMPORT_DEFAULT_VHOST = "__defaultVhost__"
_CSV_IMPORT_HEADERS = {
"vhost": {"vhost", "虚拟主机(vhost)", "虚拟主机"},
"app": {"app", "应用名(app)", "应用名"},
"stream": {"stream", "流id(stream)", "流id", "流ID(stream)", "流ID"},
"url": {"url", "源流地址(url)", "源流地址"},
"audio_type": {"audio_type", "音频设置(audio_type)", "音频设置"},
}


def _normalize_csv_header(value: str) -> str:
return str(value or "").strip().lower()


def _resolve_csv_header_map(headers: list[str]) -> dict[str, str]:
resolved: dict[str, str] = {}
normalized_to_raw = {_normalize_csv_header(h): h for h in headers}
for field, aliases in _CSV_IMPORT_HEADERS.items():
found = None
for alias in aliases:
if _normalize_csv_header(alias) in normalized_to_raw:
found = normalized_to_raw[_normalize_csv_header(alias)]
break
if found:
resolved[field] = found
return resolved


def _extract_csv_value(row: dict[str, str], header_map: dict[str, str], field: str) -> str:
key = header_map.get(field)
if not key:
return ""
return str(row.get(key, "")).strip()


def _parse_audio_type(value: str) -> tuple[int | None, str | None]:
text = str(value or "").strip()
if text == "":
return None, None
if text in {"0", "1", "2"}:
return int(text), None
lowered = text.lower()
if lowered in {"disable", "off", "none", "noaudio"}:
return 0, None
if lowered in {"origin", "source", "audio"}:
return 1, None
if lowered in {"mute", "silent"}:
return 2, None
return None, "audio_type 仅支持 0/1/2(可为空)"


def _validate_pull_proxy_input(
*,
app: str,
stream: str,
url: str,
audio_type: int | None,
) -> str | None:
if not _APP_STREAM_PATTERN.match(app):
return "app 只能包含字母、数字、下划线(_)、短横线(-) 或英文句点(.)"
if not _APP_STREAM_PATTERN.match(stream):
return "stream 只能包含字母、数字、下划线(_)、短横线(-) 或英文句点(.)"
if not any(url.startswith(prefix) for prefix in _URL_PREFIXES):
return "源流地址必须以 rtsp://、rtmp://、http:// 或 https:// 开头"
if audio_type is not None and audio_type not in (0, 1, 2):
return "audio_type 仅支持 0/1/2"
return None


async def _add_stream_proxy_to_zlm(
*,
vhost: str,
Expand Down Expand Up @@ -1128,6 +1201,224 @@ async def _restart_self():
return {"code": -1, "msg": "重启失败", "error": str(e), "via": "docker"}


@app.get(
"/api/stream/pull-proxy-template",
summary="Download pull proxy import template",
tags=["流"],
)
async def get_pull_proxy_template():
output = io.StringIO()
writer = csv.writer(output, lineterminator="\n")
writer.writerow(
[
"虚拟主机(vhost)",
"应用名(app)",
"流ID(stream)",
"源流地址(url)",
"音频设置(audio_type)",
]
)
writer.writerow(
[
"__defaultVhost__",
"live",
"cam101",
"rtsp://admin:password@192.168.0.64:554/Streaming/Channels/101",
"1",
]
)
content = "\ufeff" + output.getvalue()
return Response(
content=content,
media_type="text/csv; charset=utf-8",
headers={
"Content-Disposition": (
'attachment; filename="pull-proxy-template.csv"; '
"filename*=UTF-8''pull-proxy-template.csv"
)
},
)


@app.get(
"/api/stream/pull-proxy-export",
summary="Export pull proxies to CSV",
tags=["流"],
)
async def get_pull_proxy_export():
rows = db_list_pull_proxies()
output = io.StringIO()
writer = csv.writer(output, lineterminator="\n")
writer.writerow(
[
"虚拟主机(vhost)",
"应用名(app)",
"流ID(stream)",
"源流地址(url)",
"音频设置(audio_type)",
]
)
for row in rows:
writer.writerow(
[
str(row.get("vhost") or _IMPORT_DEFAULT_VHOST),
str(row.get("app") or ""),
str(row.get("stream") or ""),
str(row.get("url") or ""),
"" if row.get("audio_type") is None else str(row.get("audio_type")),
]
)
content = "\ufeff" + output.getvalue()
now_str = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"pull-proxy-export-{now_str}.csv"
return Response(
content=content,
media_type="text/csv; charset=utf-8",
headers={
"Content-Disposition": (
f'attachment; filename="{filename}"; '
f"filename*=UTF-8''{filename}"
)
},
)


@app.post(
"/api/stream/pull-proxy-import",
summary="Import pull proxies from CSV",
tags=["流"],
)
async def post_pull_proxy_import(request: Request):
try:
body = await request.json()
except Exception:
return {"code": -1, "msg": "请求体必须是 JSON"}

csv_text = str(body.get("csv_text") or "")
mode = str(body.get("mode") or "overwrite").strip().lower()
if mode not in {"overwrite", "skip"}:
return {"code": -1, "msg": "mode 仅支持 overwrite 或 skip"}

if not csv_text.strip():
return {"code": -1, "msg": "导入文件内容为空"}

# Allow UTF-8 BOM.
csv_text = csv_text.lstrip("\ufeff")

try:
reader = csv.DictReader(io.StringIO(csv_text))
headers = list(reader.fieldnames or [])
except Exception:
return {"code": -1, "msg": "CSV 格式解析失败"}

if not headers:
return {"code": -1, "msg": "CSV 缺少表头"}

header_map = _resolve_csv_header_map(headers)
required_fields = ("app", "stream", "url")
missing = [field for field in required_fields if field not in header_map]
if missing:
return {
"code": -1,
"msg": f"CSV 缺少必要列: {', '.join(missing)}",
"supported_headers": _CSV_IMPORT_HEADERS,
}

existing_rows = db_list_pull_proxies()
existing_keys = {
_stream_proxy_key(
str(item.get("vhost") or _IMPORT_DEFAULT_VHOST),
str(item.get("app") or ""),
str(item.get("stream") or ""),
)
for item in existing_rows
}

total = 0
created = 0
overwritten = 0
skipped = 0
failed = 0
seen_in_file: set[str] = set()
errors: list[dict] = []
imported_rows: list[dict] = []

for line_no, row in enumerate(reader, start=2):
total += 1
vhost = _extract_csv_value(row, header_map, "vhost") or _IMPORT_DEFAULT_VHOST
app = _extract_csv_value(row, header_map, "app")
stream = _extract_csv_value(row, header_map, "stream")
url = _extract_csv_value(row, header_map, "url")
audio_raw = _extract_csv_value(row, header_map, "audio_type")

audio_type, audio_err = _parse_audio_type(audio_raw)
if audio_err:
failed += 1
errors.append({"line": line_no, "reason": audio_err, "row": row})
continue

invalid_reason = _validate_pull_proxy_input(
app=app,
stream=stream,
url=url,
audio_type=audio_type,
)
if invalid_reason:
failed += 1
errors.append({"line": line_no, "reason": invalid_reason, "row": row})
continue

key = _stream_proxy_key(vhost, app, stream)
if key in seen_in_file:
failed += 1
errors.append({"line": line_no, "reason": "导入文件内存在重复 app/stream", "row": row})
continue
seen_in_file.add(key)

if mode == "skip" and key in existing_keys:
skipped += 1
continue

db_row = db_upsert_pull_proxy(
vhost=vhost,
app=app,
stream=stream,
url=url,
audio_type=audio_type,
)
asyncio.create_task(
_add_stream_proxy_to_zlm(
vhost=vhost,
app=app,
stream=stream,
url=url,
audio_type=audio_type,
)
)
imported_rows.append({"vhost": vhost, "app": app, "stream": stream, "db": db_row})

if key in existing_keys:
overwritten += 1
else:
created += 1
existing_keys.add(key)

return {
"code": 0,
"msg": "导入完成",
"mode": mode,
"summary": {
"total": total,
"created": created,
"overwritten": overwritten,
"skipped": skipped,
"failed": failed,
},
"errors": errors[:200],
"imported": imported_rows[:200],
}


if __name__ == "__main__":
import uvicorn

Expand Down
Loading