diff --git a/.github/workflows/build-release-lite.yml b/.github/workflows/build-release-lite.yml new file mode 100644 index 00000000..ca083e0b --- /dev/null +++ b/.github/workflows/build-release-lite.yml @@ -0,0 +1,106 @@ +name: Build and Release SuperPicky Lite + +on: + workflow_dispatch: + inputs: + version: + description: Version number for the lite release + required: true + +jobs: + build-windows-lite: + runs-on: windows-latest + permissions: + contents: write + + steps: + - name: Checkout repository + uses: actions/checkout@v6 + + - name: Set up Python 3.12 + uses: actions/setup-python@v6 + with: + python-version: '3.12' + cache: 'pip' + cache-dependency-path: | + requirements.txt + requirements_cuda.txt + + - name: Resolve release metadata + id: release_meta + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + INPUT_VERSION: ${{ github.event.inputs.version }} + GITHUB_REF_NAME: ${{ github.ref_name }} + run: python scripts/ci_release.py resolve-metadata + + - name: Build Windows Lite release + env: + RELEASE_TAG: ${{ steps.release_meta.outputs.tag }} + run: python build_release_win.py --build-type lite --copy-dir output/lite-win --no-zip + + - name: Create Windows Lite installer + uses: Minionguyjpro/Inno-Setup-Action@v1.2.7 + with: + path: ./output/lite-win/installer_lite/SuperPicky-lite.iss + + - name: Collect Windows Lite assets + run: | + python scripts/ci_release.py collect-assets --output-dir release_assets/lite-win --pattern output/lite-win/installer_lite/Output/SuperPicky_Setup_Lite_Win64_*.exe + + - name: Upload Windows Lite assets to GitHub Release + uses: softprops/action-gh-release@v3 + with: + tag_name: ${{ steps.release_meta.outputs.tag }} + name: SuperPicky Lite ${{ steps.release_meta.outputs.tag }} + body_path: ChangeLog.md + fail_on_unmatched_files: true + prerelease: ${{ contains(steps.release_meta.outputs.tag, '-rc') }} + files: | + release_assets/lite-win/* + + build-mac-lite: + runs-on: macos-14 + permissions: + contents: write + + steps: + - name: Checkout repository + uses: actions/checkout@v6 + + - name: Set up Python 3.12 + uses: actions/setup-python@v6 + with: + python-version: '3.12' + cache: 'pip' + cache-dependency-path: requirements_mac.txt + + - name: Resolve release metadata + id: release_meta + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + INPUT_VERSION: ${{ github.event.inputs.version }} + GITHUB_REF_NAME: ${{ github.ref_name }} + run: python scripts/ci_release.py resolve-metadata + + - name: Install macOS Lite build dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -r requirements_mac.txt + + - name: Build unsigned macOS Lite release + env: + RELEASE_TAG: ${{ steps.release_meta.outputs.tag }} + run: | + python build_release_mac.py --build-type lite --arch arm64 --copy-dir output/lite-mac + + - name: Upload macOS Lite assets to GitHub Release + uses: softprops/action-gh-release@v3 + with: + tag_name: ${{ steps.release_meta.outputs.tag }} + name: SuperPicky Lite ${{ steps.release_meta.outputs.tag }} + body_path: ChangeLog.md + fail_on_unmatched_files: true + prerelease: ${{ contains(steps.release_meta.outputs.tag, '-rc') }} + files: | + output/lite-mac/*.dmg \ No newline at end of file diff --git a/.github/workflows/build-release.yml b/.github/workflows/build-release.yml index af09e39b..a9d1613e 100644 --- a/.github/workflows/build-release.yml +++ b/.github/workflows/build-release.yml @@ -18,10 +18,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Python 3.12 - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.12' cache: 'pip' @@ -30,7 +30,6 @@ jobs: requirements_cuda.txt - name: Prepare telemetry build override - shell: pwsh env: COUNTLY_APP_KEY: ${{ secrets.COUNTLY_APP_KEY }} COUNTLY_SERVER_URL: ${{ secrets.COUNTLY_SERVER_URL }} @@ -38,175 +37,53 @@ jobs: - name: Resolve release metadata id: release_meta - shell: pwsh env: - EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_NAME: ${{ github.event_name }} INPUT_VERSION: ${{ github.event.inputs.version }} - REF_NAME: ${{ github.ref_name }} - run: | - if ($env:EVENT_NAME -eq "workflow_dispatch") { - $rawVersion = $env:INPUT_VERSION - } else { - $rawVersion = $env:REF_NAME - } - - if (-not $rawVersion) { - throw "Release version is required." - } - - $tag = if ($rawVersion.StartsWith("v")) { $rawVersion } else { "v$rawVersion" } - "tag=$tag" >> $env:GITHUB_OUTPUT - "name=SuperPicky $tag" >> $env:GITHUB_OUTPUT - - - name: Create build virtual environment - shell: pwsh - run: python -m venv .venv - - - name: Install CPU build dependencies - shell: pwsh - run: | - .\.venv\Scripts\python.exe -m pip install --upgrade pip - .\.venv\Scripts\python.exe -m pip install -r requirements.txt + GITHUB_REF_NAME: ${{ github.ref_name }} + run: python scripts/ci_release.py resolve-metadata - - name: Build CPU and CUDA patch release payloads - shell: pwsh + - name: Build Full release env: RELEASE_TAG: ${{ steps.release_meta.outputs.tag }} - run: | - .\.venv\Scripts\python.exe scripts\download_models.py - .\.venv\Scripts\python.exe build_release_win.py --build-type cuda-patch --copy-dir output --debug --no-zip + run: python build_release_win.py --build-type cpu --copy-dir output --no-zip - - name: Create CPU installer with Inno Setup + - name: Create Full installer with Inno Setup uses: Minionguyjpro/Inno-Setup-Action@v1.2.7 with: path: ./output/installer_cpu/SuperPicky.iss - - name: Create CUDA patch installer with Inno Setup + - name: Build Lite release + env: + RELEASE_TAG: ${{ steps.release_meta.outputs.tag }} + run: python build_release_win.py --build-type lite --copy-dir output --no-zip + + - name: Create Lite installer with Inno Setup uses: Minionguyjpro/Inno-Setup-Action@v1.2.7 with: - path: ./output/cuda_patch_installer/SuperPicky_CUDA_Patch.iss + path: ./output/installer_lite/SuperPicky-lite.iss - name: Prepare release assets - shell: pwsh run: | - New-Item -ItemType Directory -Path release_assets -Force | Out-Null - - $assetPatterns = @( - "output/cuda_patch_installer/Output/SuperPicky_CUDA_Patch_Win64_*.exe", - "output/installer_cpu/Output/SuperPicky_Setup_Win64_*.exe" - ) - - foreach ($pattern in $assetPatterns) { - $foundFiles = Get-ChildItem -Path $pattern -File -ErrorAction Stop - if ($foundFiles.Count -ne 1) { - throw "Expected exactly one asset for pattern '$pattern', found $($foundFiles.Count)." - } - Copy-Item -Path $foundFiles[0].FullName -Destination release_assets/ - } + python scripts/ci_release.py collect-assets --output-dir release_assets --pattern output/installer_cpu/Output/SuperPicky_Setup_Full_Win64_*.exe --pattern output/installer_lite/Output/SuperPicky_Setup_Lite_Win64_*.exe - name: Build code patch zip - shell: pwsh - run: | - $version = .\\.venv\\Scripts\\python.exe -c "from constants import APP_VERSION; print(APP_VERSION)" - $patchVersion = "${{ steps.release_meta.outputs.tag }}" - $zipName = "code_patch_${patchVersion}.zip" - $zipPath = "release_assets/$zipName" - - # 7-Zip 在 windows-latest runner 上内置 - # 包含所有运行时 Python 文件;排除: - # main.py — PyInstaller 入口,patch 后不会生效 - # build_release_win.py — 仅构建用,不需随 patch 分发 - # _telemetry_build.py — 构建时临时覆盖文件 - # pyi_rth_*.py — PyInstaller runtime hook,非运行时逻辑 - # test_*.py — 测试脚本 - $items = @( - "constants.py", "advanced_config.py", "ai_model.py", - "birdid_server.py", "birdid_cli.py", "iqa_scorer.py", - "post_adjustment_engine.py", "server_manager.py", - "superpicky_cli.py", "topiq_model.py", - "tools", "core", "ui", "birdid", "locales" - ) | Where-Object { Test-Path $_ } - & "C:\Program Files\7-Zip\7z.exe" a -tzip $zipPath @items ` - -xr!"__pycache__" -xr!"*.pyc" -xr!"*.pyo" ` - -xr!"main.py" | Out-Null - - $meta = @{ - patch_version = $patchVersion - base_version = $version - applied_at = (Get-Date -Format "o") - } | ConvertTo-Json - $meta | Out-File -Encoding utf8 "release_assets/patch_meta.json" + run: python scripts/ci_release.py build-patch --output-dir release_assets --patch-version ${{ steps.release_meta.outputs.tag }} - name: Clean telemetry build override if: always() - shell: pwsh - run: Remove-Item -Path _telemetry_build.py -Force -ErrorAction SilentlyContinue + run: python scripts/ci_release.py cleanup-paths --path _telemetry_build.py - name: Create GitHub Release - uses: softprops/action-gh-release@v2 + uses: softprops/action-gh-release@v3 with: tag_name: ${{ steps.release_meta.outputs.tag }} name: ${{ steps.release_meta.outputs.name }} body_path: ChangeLog.md + fail_on_unmatched_files: true files: | release_assets/* - - name: Create GitCode Release - continue-on-error: true - shell: pwsh - env: - GITCODE_TOKEN: ${{ secrets.GITCODE_TOKEN }} - TAG: ${{ steps.release_meta.outputs.tag }} - run: | - $baseUrl = "https://gitcode.com/api/v4/projects/Jamesphotography%2FSuperPicky" - $tag = $env:TAG - $headers = @{ "PRIVATE-TOKEN" = $env:GITCODE_TOKEN; "Content-Type" = "application/json" } - - # 尝试创建 GitCode Release(Mac job 可能已抢先创建,忽略冲突) - $body = @{ tag_name = $tag; name = "SuperPicky $tag"; description = "See GitHub release for details." } | ConvertTo-Json - try { - Invoke-RestMethod -Uri "$baseUrl/releases" -Method Post -Body $body -Headers $headers -ErrorAction Stop - } catch { - Write-Host "Release already exists or creation failed, continuing..." - } - - - name: Update downloads_github.json and push - shell: pwsh - env: - TAG: ${{ steps.release_meta.outputs.tag }} - GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - REF: ${{ github.ref_name }} - run: | - $tag = $env:TAG - # 从 GitHub Release 资产列表中提取各平台文件名 - $assets = gh release view $tag --json assets --jq '.assets[].name' 2>$null - $macArm64 = ($assets | Where-Object { $_ -match '(?i)arm64.*\.dmg$' } | Select-Object -First 1) - $winCpu = ($assets | Where-Object { $_ -match '(?i)Setup.*Win.*\.exe$' } | Select-Object -First 1) - $winCuda = ($assets | Where-Object { $_ -match '(?i)CUDA.*Win.*\.exe$' } | Select-Object -First 1) - - $json = [ordered]@{ - beta = [ordered]@{ - tag = $tag - updated_at = (Get-Date -Format "o") - files = [ordered]@{ - mac_arm64 = if ($macArm64) { $macArm64 } else { $null } - win_cpu = if ($winCpu) { $winCpu } else { $null } - win_cuda = if ($winCuda) { $winCuda } else { $null } - } - } - } - # 保存 JSON 内容,切换到 nightly 分支后写入并推送 - $jsonContent = $json | ConvertTo-Json -Depth 5 - git config user.name "github-actions[bot]" - git config user.email "github-actions[bot]@users.noreply.github.com" - git fetch origin nightly - git checkout -B nightly origin/nightly - $jsonContent | Out-File -Encoding utf8NoBOM "docs/downloads_github.json" - git add docs/downloads_github.json - git diff --cached --quiet && exit 0 - git commit -m "chore(web): 自动更新 downloads_github.json → $tag [skip ci]" - git push origin HEAD:nightly - build-mac: runs-on: macos-14 permissions: @@ -214,10 +91,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Python 3.12 - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.12' cache: 'pip' @@ -232,315 +109,44 @@ jobs: - name: Resolve release metadata id: release_meta env: - EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_NAME: ${{ github.event_name }} INPUT_VERSION: ${{ github.event.inputs.version }} - REF_NAME: ${{ github.ref_name }} - run: | - if [ "$EVENT_NAME" = "workflow_dispatch" ]; then - RAW_VERSION="$INPUT_VERSION" - else - RAW_VERSION="$REF_NAME" - fi - - if [ -z "$RAW_VERSION" ]; then - echo "Release version is required." >&2 - exit 1 - fi - - if [[ "$RAW_VERSION" == v* ]]; then - TAG="$RAW_VERSION" - else - TAG="v$RAW_VERSION" - fi - - echo "tag=$TAG" >> "$GITHUB_OUTPUT" - echo "name=SuperPicky $TAG" >> "$GITHUB_OUTPUT" + GITHUB_REF_NAME: ${{ github.ref_name }} + run: python scripts/ci_release.py resolve-metadata - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements_mac.txt + python -m pip install -r requirements_mac.txt - - name: Import signing certificate + - name: Materialize signing certificate env: MACOS_CERTIFICATE: ${{ secrets.MACOS_CERTIFICATE }} - MACOS_CERTIFICATE_PWD: ${{ secrets.MACOS_CERTIFICATE_PWD }} - run: | - KEYCHAIN_PATH="$RUNNER_TEMP/build.keychain" - KEYCHAIN_PWD=$(openssl rand -hex 16) - - echo "$MACOS_CERTIFICATE" | base64 --decode > "$RUNNER_TEMP/certificate.p12" - - security create-keychain -p "$KEYCHAIN_PWD" "$KEYCHAIN_PATH" - security set-keychain-settings -lut 21600 "$KEYCHAIN_PATH" - security unlock-keychain -p "$KEYCHAIN_PWD" "$KEYCHAIN_PATH" - security import "$RUNNER_TEMP/certificate.p12" \ - -k "$KEYCHAIN_PATH" \ - -P "$MACOS_CERTIFICATE_PWD" \ - -T /usr/bin/codesign - security list-keychain -d user -s "$KEYCHAIN_PATH" - security set-key-partition-list -S apple-tool:,apple: -s -k "$KEYCHAIN_PWD" "$KEYCHAIN_PATH" - - echo "KEYCHAIN_PWD=$KEYCHAIN_PWD" >> "$GITHUB_ENV" - echo "KEYCHAIN_PATH=$KEYCHAIN_PATH" >> "$GITHUB_ENV" - - - name: Download models - run: python scripts/download_models.py - - - name: Inject commit hash and run PyInstaller - env: - TAG: ${{ steps.release_meta.outputs.tag }} - run: | - COMMIT_HASH=$(git rev-parse --short HEAD) - BUILD_INFO_FILE="core/build_info.py" - - # 判断渠道:纯 vX.Y.Z 为 official,含 -RC 为 nightly - if echo "$TAG" | grep -iqE '\-rc'; then - RELEASE_CHANNEL="nightly" - else - RELEASE_CHANNEL="official" - fi - - cp "$BUILD_INFO_FILE" "${BUILD_INFO_FILE}.backup" - sed -i.tmp "s/COMMIT_HASH = .*/COMMIT_HASH = \"${COMMIT_HASH}\"/" "$BUILD_INFO_FILE" - sed -i.tmp "s/RELEASE_CHANNEL = .*/RELEASE_CHANNEL = \"${RELEASE_CHANNEL}\"/" "$BUILD_INFO_FILE" - rm -f "${BUILD_INFO_FILE}.tmp" - - pyinstaller SuperPicky.spec --clean --noconfirm - - mv "${BUILD_INFO_FILE}.backup" "$BUILD_INFO_FILE" - - echo "COMMIT_HASH=$COMMIT_HASH" >> "$GITHUB_ENV" - - - name: Organize .app bundle resources - run: | - APP_PATH="dist/SuperPicky.app" - mkdir -p "${APP_PATH}/Contents/MacOS" "${APP_PATH}/Contents/Resources" - - if [ -d "dist/SuperPicky" ] && [ ! -f "${APP_PATH}/Contents/MacOS/SuperPicky" ]; then - mv dist/SuperPicky/* "${APP_PATH}/Contents/MacOS/" - fi - - for res in SuperBirdIDPlugin.lrplugin en.lproj zh-Hans.lproj; do - if [ -d "${APP_PATH}/Contents/MacOS/$res" ]; then - mv "${APP_PATH}/Contents/MacOS/$res" "${APP_PATH}/Contents/Resources/" - fi - done - - echo "APP_PATH=$APP_PATH" >> "$GITHUB_ENV" + run: python scripts/ci_release.py materialize-secret-file --env-name MACOS_CERTIFICATE --output $RUNNER_TEMP/certificate.p12 --decode-base64 - - name: Sign application + - name: Build notarized macOS release env: - DEVELOPER_ID: ${{ secrets.MACOS_DEVELOPER_ID }} - run: | - find "${APP_PATH}/Contents" -type f \( -name "*.dylib" -o -name "*.so" -o -perm +111 \) -print0 | \ - xargs -0 -P 8 -I {} codesign --force --sign "$DEVELOPER_ID" --timestamp --options runtime {} 2>/dev/null || true - - codesign --force --deep --sign "$DEVELOPER_ID" \ - --timestamp \ - --options runtime \ - --entitlements entitlements.plist \ - "${APP_PATH}" - - codesign --verify --deep --strict --verbose=2 "${APP_PATH}" - - - name: Install create-dmg - run: brew install create-dmg - - - name: Create DMG - env: - DEVELOPER_ID: ${{ secrets.MACOS_DEVELOPER_ID }} - run: | - VERSION=$(python -c "from constants import APP_VERSION; print(APP_VERSION)") - DMG_NAME="SuperPicky_v${VERSION}_arm64_${COMMIT_HASH}.dmg" - DMG_TEMP="dist/dmg_temp" - - rm -rf "$DMG_TEMP" - mkdir -p "$DMG_TEMP" - - cp -R "${APP_PATH}" "${DMG_TEMP}/" - - if [ -d "SuperBirdIDPlugin.lrplugin" ]; then - cp -R "SuperBirdIDPlugin.lrplugin" "${DMG_TEMP}/" - fi - - # README 安装说明 - if [ -f "resources/DMG_README.txt" ]; then - cp "resources/DMG_README.txt" "${DMG_TEMP}/README.txt" - fi - - # 使用 create-dmg 生成带 Applications 图标的标准安装 DMG - create-dmg \ - --volname "SuperPicky ${VERSION}" \ - --window-pos 200 120 \ - --window-size 580 380 \ - --icon-size 100 \ - --icon "SuperPicky.app" 140 180 \ - --hide-extension "SuperPicky.app" \ - --app-drop-link 440 180 \ - --no-internet-enable \ - "dist/${DMG_NAME}" \ - "$DMG_TEMP" - - rm -rf "$DMG_TEMP" - - codesign --force --sign "$DEVELOPER_ID" --timestamp "dist/${DMG_NAME}" - - echo "DMG_PATH=dist/${DMG_NAME}" >> "$GITHUB_ENV" - - - name: Notarize and staple DMG - env: - APPLE_ID: ${{ secrets.APPLE_ID }} + RELEASE_TAG: ${{ steps.release_meta.outputs.tag }} + MACOS_CERTIFICATE_PWD: ${{ secrets.MACOS_CERTIFICATE_PWD }} APPLE_APP_PASSWORD: ${{ secrets.APPLE_APP_PASSWORD }} - TEAM_ID: ${{ secrets.MACOS_TEAM_ID }} run: | - NOTARIZE_OUTPUT=$(xcrun notarytool submit "${DMG_PATH}" \ - --apple-id "$APPLE_ID" \ - --password "$APPLE_APP_PASSWORD" \ - --team-id "$TEAM_ID" \ - --wait \ - --output-format json 2>&1) - - echo "$NOTARIZE_OUTPUT" - - if ! echo "$NOTARIZE_OUTPUT" | grep -Eq '"status"[[:space:]]*:[[:space:]]*"Accepted"'; then - echo "Notarization failed." >&2 - exit 1 - fi - - xcrun stapler staple "${DMG_PATH}" - xcrun stapler validate "${DMG_PATH}" + python build_release_mac.py --build-type full --arch arm64 --copy-dir output/mac --sign-p12 $RUNNER_TEMP/certificate.p12 --sign-p12-password-env MACOS_CERTIFICATE_PWD --notarize --apple-id ${{ secrets.APPLE_ID }} --team-id ${{ secrets.MACOS_TEAM_ID }} - name: Build code patch zip - env: - TAG: ${{ steps.release_meta.outputs.tag }} - run: | - VERSION=$(python3 -c "from constants import APP_VERSION; print(APP_VERSION)") - PATCH_VERSION="${TAG}" - ZIP_NAME="code_patch_${PATCH_VERSION}.zip" - - # 与 Windows CI 保持一致:包含所有运行时 Python 文件 - # 排除 main.py(PyInstaller 入口,patch 无效) - # build_info.py 必须包含:CI 已注入 RELEASE_CHANNEL,patch 覆盖后不影响正确性 - # 若排除则 code_updates/core/ 遮蔽冻结包导致 ModuleNotFoundError - PATCH_ITEMS=( - constants.py advanced_config.py ai_model.py - birdid_server.py birdid_cli.py iqa_scorer.py - post_adjustment_engine.py server_manager.py - superpicky_cli.py topiq_model.py - tools/ core/ ui/ birdid/ locales/ - ) - EXISTING_ITEMS=() - for item in "${PATCH_ITEMS[@]}"; do - [ -e "$item" ] && EXISTING_ITEMS+=("$item") - done - - zip -r "dist/${ZIP_NAME}" "${EXISTING_ITEMS[@]}" \ - --exclude "**/__pycache__/*" --exclude "**/*.pyc" --exclude "**/*.pyo" - - APPLIED_AT=$(python3 -c "import datetime; print(datetime.datetime.now(datetime.timezone.utc).isoformat())") - printf '{"patch_version":"%s","base_version":"%s","applied_at":"%s"}\n' \ - "$PATCH_VERSION" "$VERSION" "$APPLIED_AT" > dist/patch_meta.json - - echo "PATCH_ZIP=dist/${ZIP_NAME}" >> "$GITHUB_ENV" - echo "PATCH_META=dist/patch_meta.json" >> "$GITHUB_ENV" - - - name: Push patch files to mirror server - continue-on-error: true - env: - MIRROR_SSH_KEY: ${{ secrets.MIRROR_SSH_KEY }} - TAG: ${{ steps.release_meta.outputs.tag }} - run: | - mkdir -p ~/.ssh - echo "$MIRROR_SSH_KEY" > ~/.ssh/mirror_key - chmod 600 ~/.ssh/mirror_key - - VERSION=$(python3 -c "from constants import APP_VERSION; print(APP_VERSION)") - PATCH_VERSION="${TAG}" - MIRROR_BASE="http://1.119.150.179:59080/superpicky" - - # 生成 latest.json(供中国大陆用户版本查询 fallback) - python3 -c "import json,datetime,sys; d={'tag_name':sys.argv[1],'version':sys.argv[2],'published_at':datetime.datetime.now(datetime.timezone.utc).isoformat(),'patch_meta_url':sys.argv[3]+'/patch_meta.json','patch_zip_url':sys.argv[3]+'/code_patch_'+sys.argv[4]+'.zip'}; json.dump(d,open('dist/latest.json','w'),indent=2)" "$TAG" "$VERSION" "$MIRROR_BASE" "$PATCH_VERSION" - - # 建立远端目录 - ssh -i ~/.ssh/mirror_key -p 22 \ - -o StrictHostKeyChecking=no \ - -o ConnectTimeout=10 \ - jordan@1.119.150.179 \ - "mkdir -p /opt/1panel/www/sites/github/index/superpicky" - - # 推送 latest.json、patch_meta.json、code_patch zip - scp -i ~/.ssh/mirror_key -P 22 \ - -o StrictHostKeyChecking=no \ - -o ConnectTimeout=10 \ - dist/latest.json \ - "${PATCH_META}" \ - "${PATCH_ZIP}" \ - jordan@1.119.150.179:/opt/1panel/www/sites/github/index/superpicky/ - - rm -f ~/.ssh/mirror_key + run: python scripts/ci_release.py build-patch --output-dir output/mac --patch-version ${{ steps.release_meta.outputs.tag }} - name: Clean up if: always() - run: | - rm -f _telemetry_build.py - security delete-keychain "$KEYCHAIN_PATH" 2>/dev/null || true + run: python scripts/ci_release.py cleanup-paths --path _telemetry_build.py --path $RUNNER_TEMP/certificate.p12 - name: Upload to GitHub Release - uses: softprops/action-gh-release@v2 + uses: softprops/action-gh-release@v3 with: tag_name: ${{ steps.release_meta.outputs.tag }} name: ${{ steps.release_meta.outputs.name }} body_path: ChangeLog.md + fail_on_unmatched_files: true files: | - ${{ env.DMG_PATH }} - ${{ env.PATCH_ZIP }} - ${{ env.PATCH_META }} - - - name: Upload patch files to GitCode - continue-on-error: true - env: - GITCODE_TOKEN: ${{ secrets.GITCODE_TOKEN }} - TAG: ${{ steps.release_meta.outputs.tag }} - run: | - BASE_URL="https://gitcode.com/api/v4/projects/Jamesphotography%2FSuperPicky" - FILE_BASE="https://gitcode.com/Jamesphotography/SuperPicky/-/package_files/generic/release" - PATCH_ZIP_NAME=$(basename "${PATCH_ZIP}") - - # 上传三个小文件到 GitCode generic packages(不上传大 DMG,避免 413) - for FILE in dist/latest.json "${PATCH_META}" "${PATCH_ZIP}"; do - FILENAME=$(basename "$FILE") - curl -s --fail -T "$FILE" \ - --header "PRIVATE-TOKEN: $GITCODE_TOKEN" \ - "$BASE_URL/packages/generic/release/$TAG/$FILENAME" || true - done - - # 尝试在 GitCode 创建 Release(Windows job 可能已抢先创建,忽略冲突) - curl -s \ - --header "PRIVATE-TOKEN: $GITCODE_TOKEN" \ - --header "Content-Type: application/json" \ - --request POST \ - --data "{\"tag_name\":\"$TAG\",\"name\":\"SuperPicky $TAG\",\"description\":\"See GitHub release for details.\"}" \ - "$BASE_URL/releases" || true - - # 添加补丁文件 asset links 到 GitCode Release - curl -s \ - --header "PRIVATE-TOKEN: $GITCODE_TOKEN" \ - --header "Content-Type: application/json" \ - --request POST \ - --data "{\"name\":\"patch_meta.json\",\"url\":\"$FILE_BASE/$TAG/patch_meta.json\",\"link_type\":\"package\"}" \ - "$BASE_URL/releases/$TAG/assets/links" || true - - curl -s \ - --header "PRIVATE-TOKEN: $GITCODE_TOKEN" \ - --header "Content-Type: application/json" \ - --request POST \ - --data "{\"name\":\"$PATCH_ZIP_NAME\",\"url\":\"$FILE_BASE/$TAG/$PATCH_ZIP_NAME\",\"link_type\":\"package\"}" \ - "$BASE_URL/releases/$TAG/assets/links" || true - - curl -s \ - --header "PRIVATE-TOKEN: $GITCODE_TOKEN" \ - --header "Content-Type: application/json" \ - --request POST \ - --data "{\"name\":\"latest.json\",\"url\":\"$FILE_BASE/$TAG/latest.json\",\"link_type\":\"package\"}" \ - "$BASE_URL/releases/$TAG/assets/links" || true + output/mac/*.dmg + output/mac/code_patch_*.zip + output/mac/patch_meta.json diff --git a/.gitignore b/.gitignore index a7aea6cc..66c858b7 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ build_dist*/* dist*/* !dist/*.dmg build/ +.python-version # IDE .idea/copilot.* diff --git a/AGENTS.md b/AGENTS.md index f957debe..4a086ca7 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,22 +1,204 @@ -# AGENTS.md (Codex / OpenAI Coding Agents) +## 第一性原理 / First Principles -Follow `scripts_dev/AI_CODING_RULES.md` as the project baseline. +请使用第一性原理思考。你不能总是假设我非常清楚自己想要什么和该怎么得到。请保持审慎,从原始需求和问题出发,如果动机和目标不清晰,停下来和我讨论。 +Please use first principles thinking. You should not assume that I always know exactly what I want or how to achieve it. Be cautious and start from the original needs and problems. If the motivation and goals are unclear, stop and discuss with me. -## Mandatory Project Constraints +## 技术方案规范 / Technical Solution Specifications -- Keep files in UTF-8; avoid introducing mojibake. -- For ExifTool non-ASCII metadata writes, prefer UTF-8 temp-file redirection (`-Tag<=file`) over inline command args. -- Preserve Windows/macOS compatibility for paths and subprocess behavior. -- For SQLite in threaded code: either serialize shared-connection access with a lock or use per-thread connections; never assume `check_same_thread=False` is enough. -- Do not directly access private DB connection internals from business code (e.g., `report_db._conn.*`); add thread-safe wrapper methods instead. -- Ensure transaction handling is consistent and defensive (avoid unsynchronized mixed transaction styles; commit only when valid). -- Ensure persistent external processes (like `exiftool -stay_open`) have explicit shutdown and are closed on exit. -- For packaged-only CUDA issues, first suspect packaging/runtime differences. -- In Windows PyInstaller spec for Torch/CUDA, keep `upx=False` unless explicitly re-validated. +当需要你给出修改或者重构方案时必须符合以下规范: +The following specifications must be followed when giving modification or refactoring plans: -## Validation Minimum +* 你是技术专家,所以设计方案时要使用各种工具查询网络资料,确定基本事实,不要给出虚假观点。 + You are a technical expert, so when designing solutions, use various tools to check online resources and ensure the basic facts are correct. Do not provide false opinions. +* 除非我很确定,不然不能随意迁就我的观点,因为我的观点很可能是错的,需要基于基本事实有理有据的说服我同意你的新观点。 + Unless I am very sure, do not easily accommodate my opinions because they may be wrong. You need to convince me to agree with your new views based on facts. +* 给出兼容性或者补丁性的方案时需要给出确定性的理由与我讨论。 + When proposing compatibility or patch solutions, provide definitive reasons for discussion. +* 必须确保方案的逻辑正确,必须经过全链路的逻辑验证。 + Ensure that the solution is logically correct and has been verified across the entire system. -- Run `py -3 -m py_compile` on changed Python files. +## 编码规范 / Coding Specifications + +所有文件读写均需要满足如下规范: +All file reading and writing must meet the following specifications: + +* 使用UTF-8编码,强制所有的中文输出,均为UTF-8。 + Use UTF-8 encoding, and enforce all Chinese output to be UTF-8. +* 在PowerShell中读取含有中文的文件时,限制性** **`chcp 65001`并设置UTF-8输出。 + When reading Chinese files in PowerShell, use** **`chcp 65001` and set UTF-8 output. +* 读取时用** **`open(file, 'r', encoding='utf-8')`方式读取。 + Use** **`open(file, 'r', encoding='utf-8')` to read files. +* 不要使用shell脚本(如sed/awk)处理含中文的文件,优先使用Python(Python 3.x),如果Python环境无法满足需求,再考虑其他语言,最后才考虑PowerShell。 + Do not use shell scripts (like sed/awk) to handle files with Chinese characters. Prefer Python (Python 3.x), and if Python environment cannot meet the requirements, consider other languages, and only as a last resort consider PowerShell. + +## 代码规范 / Code Specifications + +所有代码增删查改均需要满足如下规范: +All code changes (addition, deletion, modification) must meet the following specifications: + +* 先阅读相关代码段落,预先评估代码修改量,如果发现改动文件过多,或者改动量很大,提前分成几个小部分进行修补,避免系统拒绝修补。 + First, read the relevant code sections, assess the extent of the changes, and if too many files are affected or the changes are too large, break them down into smaller parts to avoid rejection by the system. +* 代码按照逻辑顺序进行修补,避免改完之后又回头改。 + Make code changes in logical order to avoid having to go back and modify things again. +* 代码改动完毕后要重新整体阅读全链路,避免出现变量函数未定义未声明导致编译不通过。 + After code changes, review the entire system to ensure there are no undefined or undeclared variables or functions that could cause compilation errors. +* 代码优化精简的时候需要按照逻辑顺序对变量函数进行重排,方便维护者从上到下进行阅读。 + When optimizing and simplifying the code, rearrange variables and functions in logical order to make it easier for maintainers to read from top to bottom. +* 跨文件代码边界维护要清晰分明,高内聚低耦合。 + Maintain clear boundaries for cross-file code, ensuring high cohesion and low coupling. +* 在Python中,避免使用全局变量。优先选择函数或类封装,保持数据和功能分离。 + In Python, avoid using global variables. Prefer encapsulation in functions or classes to separate data and functionality. + +## 注释规范 / Commenting Specifications + +所有注释增删查改均需要满足如下规范: +All comment changes (addition, deletion, modification) must meet the following specifications: + +* 如果没有额外指定,请使用UTF-8编码的中文注释 + 相同格式的英文注释。 + If not otherwise specified, use UTF-8 encoded Chinese comments + corresponding English comments in the same format. +* 需要给出详细且必要的功能说明,增加可维护性,让不熟悉相关类型代码的人也能看懂。 + Provide detailed and necessary functional descriptions to increase maintainability, so that those unfamiliar with the relevant code can understand it. +* 使用docstring格式进行函数、类注释,确保清晰描述函数的功能、参数、返回值及可能的异常。 + Use docstring format for function and class comments, ensuring clear descriptions of the function's functionality, parameters, return values, and possible exceptions. + +```python +def example_function(param: int) -> str: + """ + 这是一个示例函数,接受一个整数作为输入,返回字符串。 + + 参数: + param (int): 输入的整数 + + 返回: + str: 返回一个简单的字符串,表示输入的平方值 + + This is a sample function that takes an integer as input and returns a string. + + Parameters: + param (int): The integer to input + + Return: + str: Returns a simple string representing the square of the input. + """ + + return f"The square is {param ** 2}" +``` + +## 总结汇报规范 / Summary Reporting Specifications + +所有的总结汇报均需要满足如下规范: +All summary reports must meet the following specifications: + +* 改动部分请加上具体文件的行号,如果涉及多个跨行的改动,给出相关段落,方便进行查找。 + Specify the line numbers of the changed parts, and provide relevant sections for easy search if multiple lines are involved. +* 对于Python项目,考虑到代码可能涉及模块导入、功能封装等,需要明确指出哪些模块或类的修改或新增影响了其他模块的功能。 + For Python projects, since the code may involve module imports and function encapsulation, clearly indicate which module or class changes or additions affect the functionality of other modules. + +## Python使用规范 / Python Usage Specifications + +在使用Python语言时均需要满足如下规范: +The following specifications must be met when using Python: + +* **类型注解 / Type Annotations** :尽量使用类型注解(Python 3.x),以增强代码可读性和静态检查工具的支持。例如,函数的输入和输出应该明确标注类型。 + **Type annotations** : Try to use type annotations (Python 3.x) to enhance code readability and static analysis tool support. For example, the input and output of functions should clearly annotate their types. + +```python + def add_numbers(a: int, b: int) -> int: + return a + b +``` + +* **避免使用过于宽泛的类型标注 / Avoid overly broad type annotations** :Python中不存在** **`any`类型,但要尽量避免过于宽泛的类型标注。 + Python does not have an** **`any` type, but avoid overly broad type annotations whenever possible. +* **操作用户文件规范 / User File Operations** :当使用代码操作用户系统中的文件时,要使用安全的方法,并注意权限。对于配置文件的存放位置应该局限在一个文件夹内,不要在用户的文件夹中到处存放零星文件。 + When manipulating user files, use secure methods and be mindful of permissions. The storage location for configuration files should be limited to a single folder, and avoid scattering files across the user's directories. +* **遵循PEP8规范 / Follow PEP8** :始终遵循Python的官方代码风格PEP8,并且使用自动化工具(如** **`black`)进行格式化。 + Always follow the official Python coding style PEP8 and use automation tools (like** **`black`) for formatting. +* **严格使用UTF-8 / Strict Use of UTF-8** :始终遵循Python的官方代码标准PEP686,始终使用 UTF-8 作为文件、标准输入输出和管道的默认编码。 + Always follow Python's official code standard PEP686, and use UTF-8 as the default encoding for files, standard input/output, and pipes. +* **注重安全性 / Focus on Security** :避免直接执行来自不可信来源的代码,如避免使用** **`eval()`或** **`exec()`等函数。使用适当的输入验证和参数化查询,避免SQL注入、XSS等安全漏洞。 + Avoid executing code from untrusted sources, such as using** **`eval()` or** **`exec()`. Use proper input validation and parameterized queries to avoid SQL injection, XSS, and other security vulnerabilities. + +```python + import sqlite3 + connection = sqlite3.connect('database.db') + cursor = connection.cursor() + + # 避免 SQL 注入,使用参数化查询 + cursor.execute("SELECT * FROM users WHERE username = ?", (username,)) +``` + +* **异常处理 / Exception Handling** :要优雅地处理可能的错误和异常,避免程序崩溃。优先使用Python标准库提供的异常机制。 + Handle potential errors and exceptions gracefully to avoid crashes. Use Python's standard exception mechanisms first. + +```python + try: + result = 10 / 0 + except ZeroDivisionError as e: + print(f"Error occurred: {e}") +``` + +## Python 3 环境配置与工具使用规范 / Python 3 Environment Setup and Tool Usage Specifications + +为了避免Python 3工具默认使用系统中的Python环境(可能导致许多不可预料的问题),请务必采用以下规范进行配置: + +* **使用虚拟环境 / Virtual Environment** :优先使用 `venv`或 `conda`等工具创建独立的Python环境,避免使用系统全局环境。 + Prefer using** **`venv` or** **`conda` to create isolated Python environments, avoiding the use of the system's global environment. +* **确保包管理一致性 / Ensure Package Management Consistency** :在项目中使用 `pip`来管理依赖,确保依赖版本的一致性,避免版本冲突和意外问题。 + Use** **`pip` to manage dependencies in the project, ensuring version consistency and avoiding conflicts and unexpected issues. +* **工具使用推荐 / Recommended Tool Usage** :为了避免依赖于系统环境的Python,建议使用虚拟环境中的解释器进行构建和运行。 + To avoid relying on the system environment's Python, it is recommended to use the interpreter in the virtual environment for builds and executions. + +## 多系统规范 / Multi-System Specifications + +### 1. 避免多系统之间的差异导致程序出现无法运行甚至安全漏洞 / Avoid System-Specific Differences Leading to Errors or Security Vulnerabilities + +- 在开发跨平台应用时,需避免代码中因操作系统差异(如Windows与Linux、macOS之间的差异)导致程序无法运行或出现安全漏洞。 + When developing cross-platform applications, avoid code differences that cause errors or security vulnerabilities due to differences between operating systems (e.g., Windows vs. Linux or macOS). +- **路径问题**:文件路径的格式在不同操作系统间有所不同。确保使用跨平台兼容的路径分隔符,推荐使用Python的 `os.path`模块,或 `pathlib`模块来自动处理路径分隔符。 + **Path Issues**: File path formats differ across operating systems. Ensure the use of cross-platform compatible path separators. It is recommended to use Python's `os.path` or `pathlib` modules to automatically handle path separators. + + ``` + from pathlib import Path + + file_path = Path("some_folder") / "file.txt" # This works across all OS + ``` +- **换行符问题**:Windows和类Unix系统的换行符不同。 + **Line Endings**: Line endings differ between Windows and Unix-based systems. + +### 2. 不同系统的文件存储策略和文件夹权限管理不同,需要提前预防 / Different Systems Have Different File Storage and Folder Permissions + +- 在设计涉及文件存储和访问的应用时,需注意不同操作系统对文件权限和路径访问的管理差异。Windows、Linux和macOS在文件权限、符号链接和隐藏文件的处理上有所不同。 + When designing applications that involve file storage and access, be aware of the differences in file permission and path access management across operating systems. Windows, Linux, and macOS handle file permissions, symlinks, and hidden files differently. +- **权限问题**:Linux和macOS有严格的文件权限控制,而Windows则使用ACL(访问控制列表)来管理权限。确保文件的读写权限适合所使用的操作系统,并且文件夹权限应在应用设计时进行适当配置。 + **Permission Issues**: Linux and macOS have strict file permission controls, while Windows uses ACLs (Access Control Lists) for permission management. Ensure that file read/write permissions are suitable for the operating system in use, and folder permissions should be appropriately configured during application design. + +### 3. 避免大量使用PowerShell代码 / Avoid Excessive Use of PowerShell Code + +- PowerShell主要是Windows环境下使用的脚本语言,避免在跨平台项目中广泛使用PowerShell。为了确保程序的兼容性,尽量使用Python脚本或其他语言。 + PowerShell is primarily used in Windows environments. Avoid using PowerShell extensively in cross-platform projects. To ensure compatibility, try to use Python scripts or other languages instead. +- 如果必须使用PowerShell,请确保通过条件语句检查操作系统类型,并仅在Windows系统中执行相关命令。 + If PowerShell must be used, ensure that conditional statements are used to check the operating system and only execute related commands on Windows systems. + + ``` + import platform + + if platform.system() == "Windows": + # Execute PowerShell command + pass + ``` + +## Always Enforce + +- UTF-8 safety first; do not introduce Chinese text corruption. +- ExifTool Chinese metadata writes must use UTF-8 temp files (`-XMP:Title<=tmp.txt`) instead of inline CLI values. +- Keep changes cross-platform (Windows + macOS). +- Any persistent external process must have deterministic cleanup on task/app exit. +- Packaged CUDA failures: prioritize packaging/runtime diagnosis before algorithm refactors. +- Keep Windows Torch/CUDA packaging with `upx=False` unless explicitly requested and validated. + +## Minimum Verification + +- Run `.venv*/bin/python -m py_compile` on changed Python files. - For metadata changes: write + read-back verification with Chinese sample values. - For `.spec` changes: packaged startup smoke test. - For DB/threading changes: run a small multi-thread write/read stress check and confirm no transaction-state errors. diff --git a/CLAUDE.md b/CLAUDE.md index fd3176d5..216e850c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -13,7 +13,199 @@ Use `scripts_dev/AI_CODING_RULES.md` as the single source of truth for this repo ## Minimum Verification -- `py -3 -m py_compile` for changed Python modules. -- Metadata write/read-back check for non-ASCII fields. -- Packaged app smoke test when `.spec` or runtime packaging behavior changes. +- Run `.venv*/bin/python -m py_compile` on changed Python files. +- For metadata changes: write + read-back verification with Chinese sample values. +- For `.spec` changes: packaged startup smoke test. +- For DB/threading changes: run a small multi-thread write/read stress check and confirm no transaction-state errors. +## 第一性原理 / First Principles + +请使用第一性原理思考。你不能总是假设我非常清楚自己想要什么和该怎么得到。请保持审慎,从原始需求和问题出发,如果动机和目标不清晰,停下来和我讨论。 +Please use first principles thinking. You should not assume that I always know exactly what I want or how to achieve it. Be cautious and start from the original needs and problems. If the motivation and goals are unclear, stop and discuss with me. + +## 技术方案规范 / Technical Solution Specifications + +当需要你给出修改或者重构方案时必须符合以下规范: +The following specifications must be followed when giving modification or refactoring plans: + +* 你是技术专家,所以设计方案时要使用各种工具查询网络资料,确定基本事实,不要给出虚假观点。 + You are a technical expert, so when designing solutions, use various tools to check online resources and ensure the basic facts are correct. Do not provide false opinions. +* 除非我很确定,不然不能随意迁就我的观点,因为我的观点很可能是错的,需要基于基本事实有理有据的说服我同意你的新观点。 + Unless I am very sure, do not easily accommodate my opinions because they may be wrong. You need to convince me to agree with your new views based on facts. +* 给出兼容性或者补丁性的方案时需要给出确定性的理由与我讨论。 + When proposing compatibility or patch solutions, provide definitive reasons for discussion. +* 必须确保方案的逻辑正确,必须经过全链路的逻辑验证。 + Ensure that the solution is logically correct and has been verified across the entire system. + +## 编码规范 / Coding Specifications + +所有文件读写均需要满足如下规范: +All file reading and writing must meet the following specifications: + +* 使用UTF-8编码,强制所有的中文输出,均为UTF-8。 + Use UTF-8 encoding, and enforce all Chinese output to be UTF-8. +* 在PowerShell中读取含有中文的文件时,限制性** **`chcp 65001`并设置UTF-8输出。 + When reading Chinese files in PowerShell, use** **`chcp 65001` and set UTF-8 output. +* 读取时用** **`open(file, 'r', encoding='utf-8')`方式读取。 + Use** **`open(file, 'r', encoding='utf-8')` to read files. +* 不要使用shell脚本(如sed/awk)处理含中文的文件,优先使用Python(Python 3.x),如果Python环境无法满足需求,再考虑其他语言,最后才考虑PowerShell。 + Do not use shell scripts (like sed/awk) to handle files with Chinese characters. Prefer Python (Python 3.x), and if Python environment cannot meet the requirements, consider other languages, and only as a last resort consider PowerShell. + +## 代码规范 / Code Specifications + +所有代码增删查改均需要满足如下规范: +All code changes (addition, deletion, modification) must meet the following specifications: + +* 先阅读相关代码段落,预先评估代码修改量,如果发现改动文件过多,或者改动量很大,提前分成几个小部分进行修补,避免系统拒绝修补。 + First, read the relevant code sections, assess the extent of the changes, and if too many files are affected or the changes are too large, break them down into smaller parts to avoid rejection by the system. +* 代码按照逻辑顺序进行修补,避免改完之后又回头改。 + Make code changes in logical order to avoid having to go back and modify things again. +* 代码改动完毕后要重新整体阅读全链路,避免出现变量函数未定义未声明导致编译不通过。 + After code changes, review the entire system to ensure there are no undefined or undeclared variables or functions that could cause compilation errors. +* 代码优化精简的时候需要按照逻辑顺序对变量函数进行重排,方便维护者从上到下进行阅读。 + When optimizing and simplifying the code, rearrange variables and functions in logical order to make it easier for maintainers to read from top to bottom. +* 跨文件代码边界维护要清晰分明,高内聚低耦合。 + Maintain clear boundaries for cross-file code, ensuring high cohesion and low coupling. +* 在Python中,避免使用全局变量。优先选择函数或类封装,保持数据和功能分离。 + In Python, avoid using global variables. Prefer encapsulation in functions or classes to separate data and functionality. + +## 注释规范 / Commenting Specifications + +所有注释增删查改均需要满足如下规范: +All comment changes (addition, deletion, modification) must meet the following specifications: + +* 如果没有额外指定,请使用UTF-8编码的中文注释 + 相同格式的英文注释。 + If not otherwise specified, use UTF-8 encoded Chinese comments + corresponding English comments in the same format. +* 需要给出详细且必要的功能说明,增加可维护性,让不熟悉相关类型代码的人也能看懂。 + Provide detailed and necessary functional descriptions to increase maintainability, so that those unfamiliar with the relevant code can understand it. +* 使用docstring格式进行函数、类注释,确保清晰描述函数的功能、参数、返回值及可能的异常。 + Use docstring format for function and class comments, ensuring clear descriptions of the function's functionality, parameters, return values, and possible exceptions. + +```python +def example_function(param: int) -> str: + """ + 这是一个示例函数,接受一个整数作为输入,返回字符串。 + + 参数: + param (int): 输入的整数 + + 返回: + str: 返回一个简单的字符串,表示输入的平方值 + + This is a sample function that takes an integer as input and returns a string. + + Parameters: + param (int): The integer to input + + Return: + str: Returns a simple string representing the square of the input. + """ + + return f"The square is {param ** 2}" +``` + +## 总结汇报规范 / Summary Reporting Specifications + +所有的总结汇报均需要满足如下规范: +All summary reports must meet the following specifications: + +* 改动部分请加上具体文件的行号,如果涉及多个跨行的改动,给出相关段落,方便进行查找。 + Specify the line numbers of the changed parts, and provide relevant sections for easy search if multiple lines are involved. +* 对于Python项目,考虑到代码可能涉及模块导入、功能封装等,需要明确指出哪些模块或类的修改或新增影响了其他模块的功能。 + For Python projects, since the code may involve module imports and function encapsulation, clearly indicate which module or class changes or additions affect the functionality of other modules. + +## Python使用规范 / Python Usage Specifications + +在使用Python语言时均需要满足如下规范: +The following specifications must be met when using Python: + +* **类型注解 / Type Annotations** :尽量使用类型注解(Python 3.x),以增强代码可读性和静态检查工具的支持。例如,函数的输入和输出应该明确标注类型。 + **Type annotations** : Try to use type annotations (Python 3.x) to enhance code readability and static analysis tool support. For example, the input and output of functions should clearly annotate their types. + +```python + def add_numbers(a: int, b: int) -> int: + return a + b +``` + +* **避免使用过于宽泛的类型标注 / Avoid overly broad type annotations** :Python中不存在** **`any`类型,但要尽量避免过于宽泛的类型标注。 + Python does not have an** **`any` type, but avoid overly broad type annotations whenever possible. +* **操作用户文件规范 / User File Operations** :当使用代码操作用户系统中的文件时,要使用安全的方法,并注意权限。对于配置文件的存放位置应该局限在一个文件夹内,不要在用户的文件夹中到处存放零星文件。 + When manipulating user files, use secure methods and be mindful of permissions. The storage location for configuration files should be limited to a single folder, and avoid scattering files across the user's directories. +* **遵循PEP8规范 / Follow PEP8** :始终遵循Python的官方代码风格PEP8,并且使用自动化工具(如** **`black`)进行格式化。 + Always follow the official Python coding style PEP8 and use automation tools (like** **`black`) for formatting. +* **严格使用UTF-8 / Strict Use of UTF-8** :始终遵循Python的官方代码标准PEP686,始终使用 UTF-8 作为文件、标准输入输出和管道的默认编码。 + Always follow Python's official code standard PEP686, and use UTF-8 as the default encoding for files, standard input/output, and pipes. +* **注重安全性 / Focus on Security** :避免直接执行来自不可信来源的代码,如避免使用** **`eval()`或** **`exec()`等函数。使用适当的输入验证和参数化查询,避免SQL注入、XSS等安全漏洞。 + Avoid executing code from untrusted sources, such as using** **`eval()` or** **`exec()`. Use proper input validation and parameterized queries to avoid SQL injection, XSS, and other security vulnerabilities. + +```python + import sqlite3 + connection = sqlite3.connect('database.db') + cursor = connection.cursor() + + # 避免 SQL 注入,使用参数化查询 + cursor.execute("SELECT * FROM users WHERE username = ?", (username,)) +``` + +* **异常处理 / Exception Handling** :要优雅地处理可能的错误和异常,避免程序崩溃。优先使用Python标准库提供的异常机制。 + Handle potential errors and exceptions gracefully to avoid crashes. Use Python's standard exception mechanisms first. + +```python + try: + result = 10 / 0 + except ZeroDivisionError as e: + print(f"Error occurred: {e}") +``` + +## Python 3 环境配置与工具使用规范 / Python 3 Environment Setup and Tool Usage Specifications + +为了避免Python 3工具默认使用系统中的Python环境(可能导致许多不可预料的问题),请务必采用以下规范进行配置: + +* **使用虚拟环境 / Virtual Environment** :优先使用 `venv`或 `conda`等工具创建独立的Python环境,避免使用系统全局环境。 + Prefer using** **`venv` or** **`conda` to create isolated Python environments, avoiding the use of the system's global environment. +* **确保包管理一致性 / Ensure Package Management Consistency** :在项目中使用 `pip`来管理依赖,确保依赖版本的一致性,避免版本冲突和意外问题。 + Use** **`pip` to manage dependencies in the project, ensuring version consistency and avoiding conflicts and unexpected issues. +* **工具使用推荐 / Recommended Tool Usage** :为了避免依赖于系统环境的Python,建议使用虚拟环境中的解释器进行构建和运行。 + To avoid relying on the system environment's Python, it is recommended to use the interpreter in the virtual environment for builds and executions. + +## 多系统规范 / Multi-System Specifications + +### 1. 避免多系统之间的差异导致程序出现无法运行甚至安全漏洞 / Avoid System-Specific Differences Leading to Errors or Security Vulnerabilities + +- 在开发跨平台应用时,需避免代码中因操作系统差异(如Windows与Linux、macOS之间的差异)导致程序无法运行或出现安全漏洞。 + When developing cross-platform applications, avoid code differences that cause errors or security vulnerabilities due to differences between operating systems (e.g., Windows vs. Linux or macOS). + +- **路径问题**:文件路径的格式在不同操作系统间有所不同。确保使用跨平台兼容的路径分隔符,推荐使用Python的 `os.path`模块,或 `pathlib`模块来自动处理路径分隔符。 + **Path Issues**: File path formats differ across operating systems. Ensure the use of cross-platform compatible path separators. It is recommended to use Python's `os.path` or `pathlib` modules to automatically handle path separators. + + ``` + from pathlib import Path + + file_path = Path("some_folder") / "file.txt" # This works across all OS + ``` + +- **换行符问题**:Windows和类Unix系统的换行符不同。 + **Line Endings**: Line endings differ between Windows and Unix-based systems. + +### 2. 不同系统的文件存储策略和文件夹权限管理不同,需要提前预防 / Different Systems Have Different File Storage and Folder Permissions + +- 在设计涉及文件存储和访问的应用时,需注意不同操作系统对文件权限和路径访问的管理差异。Windows、Linux和macOS在文件权限、符号链接和隐藏文件的处理上有所不同。 + When designing applications that involve file storage and access, be aware of the differences in file permission and path access management across operating systems. Windows, Linux, and macOS handle file permissions, symlinks, and hidden files differently. +- **权限问题**:Linux和macOS有严格的文件权限控制,而Windows则使用ACL(访问控制列表)来管理权限。确保文件的读写权限适合所使用的操作系统,并且文件夹权限应在应用设计时进行适当配置。 + **Permission Issues**: Linux and macOS have strict file permission controls, while Windows uses ACLs (Access Control Lists) for permission management. Ensure that file read/write permissions are suitable for the operating system in use, and folder permissions should be appropriately configured during application design. + +### 3. 避免大量使用PowerShell代码 / Avoid Excessive Use of PowerShell Code + +- PowerShell主要是Windows环境下使用的脚本语言,避免在跨平台项目中广泛使用PowerShell。为了确保程序的兼容性,尽量使用Python脚本或其他语言。 + PowerShell is primarily used in Windows environments. Avoid using PowerShell extensively in cross-platform projects. To ensure compatibility, try to use Python scripts or other languages instead. + +- 如果必须使用PowerShell,请确保通过条件语句检查操作系统类型,并仅在Windows系统中执行相关命令。 + If PowerShell must be used, ensure that conditional statements are used to check the operating system and only execute related commands on Windows systems. + + ``` + import platform + + if platform.system() == "Windows": + # Execute PowerShell command + pass + ``` diff --git a/SuperPicky.spec b/SuperPicky.spec index 2c362f21..0f24fbe9 100644 --- a/SuperPicky.spec +++ b/SuperPicky.spec @@ -5,6 +5,14 @@ import sys sys.path.append(os.path.abspath('.')) from constants import APP_VERSION + +def _env_or_none(name): + value = os.environ.get(name, "").strip() + return value or None + + +APP_VERSION = os.environ.get("SUPERPICKY_APP_VERSION", APP_VERSION) + # 获取当前工作目录 base_path = os.path.abspath('.') @@ -161,9 +169,9 @@ exe = EXE( console=False, disable_windowed_traceback=False, argv_emulation=False, - target_arch=None, - codesign_identity=None, - entitlements_file=None, + target_arch=_env_or_none("SUPERPICKY_TARGET_ARCH"), + codesign_identity=_env_or_none("SUPERPICKY_CODESIGN_IDENTITY"), + entitlements_file=_env_or_none("SUPERPICKY_ENTITLEMENTS_FILE"), icon=os.path.join(base_path, 'img', 'SuperPicky-V0.02.icns') if os.path.exists(os.path.join(base_path, 'img', 'SuperPicky-V0.02.icns')) else None, ) diff --git a/SuperPicky_full.spec b/SuperPicky_full.spec new file mode 100644 index 00000000..5f67072f --- /dev/null +++ b/SuperPicky_full.spec @@ -0,0 +1,5 @@ +# Full-package compatibility wrapper. +# We intentionally keep the legacy full build spec unchanged and expose a new +# explicit entrypoint so release automation can choose between full/lite builds. +exec(open("SuperPicky.spec", "r", encoding="utf-8").read()) + diff --git a/SuperPicky_lite.spec b/SuperPicky_lite.spec new file mode 100644 index 00000000..05f08ca3 --- /dev/null +++ b/SuperPicky_lite.spec @@ -0,0 +1,194 @@ +import os +import site +from PyInstaller.utils.hooks import collect_data_files, copy_metadata +import sys +sys.path.append(os.path.abspath('.')) +from constants import APP_VERSION + + +def _env_or_none(name): + value = os.environ.get(name, "").strip() + return value or None + + +def _optional_copy_metadata(package_name): + try: + return copy_metadata(package_name) + except Exception: + return [] + + +APP_VERSION = os.environ.get("SUPERPICKY_APP_VERSION", APP_VERSION) + +base_path = os.path.abspath('.') +sp = [p for p in site.getsitepackages() if os.path.isdir(p)] +site_packages = sp[0] if sp else site.getusersitepackages() + +ultralytics_base = site_packages +if not os.path.exists(os.path.join(ultralytics_base, 'ultralytics')): + try: + import ultralytics + ultralytics_base = os.path.dirname(os.path.dirname(ultralytics.__file__)) + except ImportError: + pass + +ultralytics_datas = collect_data_files('ultralytics') +imageio_datas = collect_data_files('imageio') +rawpy_datas = collect_data_files('rawpy') +pillow_heif_datas = collect_data_files('pillow_heif') + +all_datas = [ + # Lite keeps the AI runtime bundled for startup stability while still + # allowing the first-run flow to fetch missing resources on demand. + (os.path.join(base_path, 'exiftools_mac'), 'exiftools_mac'), + (os.path.join(base_path, 'img'), 'img'), + (os.path.join(base_path, 'locales'), 'locales'), + (os.path.join(base_path, 'locales', 'en.lproj'), 'en.lproj'), + (os.path.join(base_path, 'locales', 'zh-Hans.lproj'), 'zh-Hans.lproj'), + (os.path.join(base_path, 'models', 'yolo11l-seg.pt'), 'models'), + (os.path.join(base_path, 'birdid', 'data', 'bird_reference.sqlite'), 'birdid/data'), + (os.path.join(base_path, 'birdid', 'data', 'ebird_classid_mapping.json'), 'birdid/data'), + (os.path.join(base_path, 'birdid', 'data', 'ebird_regions.json'), 'birdid/data'), + (os.path.join(base_path, 'birdid', 'data', 'offline_ebird_data'), 'birdid/data/offline_ebird_data'), + (os.path.join(ultralytics_base, 'ultralytics/cfg'), 'ultralytics/cfg'), + (os.path.join(base_path, 'SuperBirdIDPlugin.lrplugin'), 'SuperBirdIDPlugin.lrplugin'), + (os.path.join(base_path, 'ioc'), 'ioc'), +] + +all_datas.extend(ultralytics_datas) +all_datas.extend(imageio_datas) +all_datas.extend(rawpy_datas) +all_datas.extend(pillow_heif_datas) +all_datas.extend(_optional_copy_metadata('imageio')) +all_datas.extend(_optional_copy_metadata('rawpy')) +all_datas.extend(_optional_copy_metadata('ultralytics')) +all_datas.extend(_optional_copy_metadata('pillow_heif')) +all_datas.extend(_optional_copy_metadata('pi_heif')) + +app_hiddenimports = [ + 'ultralytics', + 'torch', + 'torchvision', + 'torchvision.models', + 'torchvision.transforms', + 'torchvision.transforms.functional', + 'torchaudio', + 'PIL', + 'cv2', + 'numpy', + 'yaml', + 'matplotlib', + 'matplotlib.pyplot', + 'matplotlib.backends.backend_agg', + 'PySide6', + 'PySide6.QtCore', + 'PySide6.QtGui', + 'PySide6.QtWidgets', + 'timm', + 'timm.models', + 'timm.models.resnet', + 'imageio', + 'rawpy', + 'imagehash', + 'pywt', + 'pillow_heif', + 'pi_heif', + 'core', + 'core.burst_detector', + 'core.config_manager', + 'core.exposure_detector', + 'core.file_manager', + 'core.flight_detector', + 'core.focus_point_detector', + 'core.initialization_manager', + 'core.keypoint_detector', + 'core.photo_processor', + 'core.rating_engine', + 'core.source_probe', + 'core.stats_formatter', + 'multiprocessing', + 'multiprocessing.spawn', + 'tools.update_checker', + 'packaging', + 'packaging.version', + 'birdid', + 'birdid.bird_identifier', + 'birdid.ebird_country_filter', + 'birdid_server', + 'server_manager', + 'flask', + 'flask.json', + 'cryptography', + 'cryptography.fernet', + '_telemetry_build', + 'app_user_stat._telemetry_build', + 'app_user_stat', + 'app_user_stat.telemetry', + 'app_user_stat.consent_texts', + 'app_user_stat.consent_texts.en_US', + 'app_user_stat.consent_texts.zh_CN', +] + +a = Analysis( + ['main.py'], + pathex=[base_path], + binaries=[], + datas=all_datas, + hiddenimports=app_hiddenimports, + hookspath=[], + hooksconfig={}, + runtime_hooks=['pyi_rth_cv2.py'] if os.path.exists('pyi_rth_cv2.py') else [], + excludes=[ + 'PyQt5', 'PyQt6', 'tkinter', + 'polars', 'numba', 'llvmlite', 'pyarrow', 'facexlib', 'datasets', + ], + noarchive=False, + optimize=0, +) + +pyz = PYZ(a.pure) + +exe = EXE( + pyz, + a.scripts, + [], + exclude_binaries=True, + name='SuperPickyLite', + debug=False, + bootloader_ignore_signals=False, + strip=False, + upx=False, + console=False, + disable_windowed_traceback=False, + argv_emulation=False, + target_arch=_env_or_none("SUPERPICKY_TARGET_ARCH"), + codesign_identity=_env_or_none("SUPERPICKY_CODESIGN_IDENTITY"), + entitlements_file=_env_or_none("SUPERPICKY_ENTITLEMENTS_FILE"), + icon=os.path.join(base_path, 'img', 'SuperPicky-V0.02.icns') if os.path.exists(os.path.join(base_path, 'img', 'SuperPicky-V0.02.icns')) else None, +) + +coll = COLLECT( + exe, + a.binaries, + a.datas, + strip=False, + upx=False, + upx_exclude=[], + name='SuperPickyLite', +) + +app = BUNDLE( + coll, + name='SuperPickyLite.app', + icon=os.path.join(base_path, 'img', 'SuperPicky-V0.02.icns') if os.path.exists(os.path.join(base_path, 'img', 'SuperPicky-V0.02.icns')) else None, + bundle_identifier='com.jamesphotography.superpicky.lite', + info_plist={ + 'CFBundleName': 'SuperPickyLite', + 'CFBundleDisplayName': 'SuperPickyLite', + 'CFBundleVersion': APP_VERSION, + 'CFBundleShortVersionString': APP_VERSION, + 'NSHighResolutionCapable': True, + 'NSAppleEventsUsageDescription': '慧眼选鸟需要发送 AppleEvents 与其他应用通信。', + 'NSAppleScriptEnabled': False, + }, +) diff --git a/SuperPicky_lite_win.spec b/SuperPicky_lite_win.spec new file mode 100644 index 00000000..3cdfe74b --- /dev/null +++ b/SuperPicky_lite_win.spec @@ -0,0 +1,223 @@ +import os +import site +import sys +from PyInstaller.utils.hooks import collect_data_files, copy_metadata + +sys.path.append(os.path.abspath('.')) + +base_path = os.path.abspath('.') + + +def _optional_copy_metadata(package_name): + try: + return copy_metadata(package_name) + except Exception: + return [] + + +sp = [p for p in site.getsitepackages() if os.path.isdir(p)] +site_packages = sp[0] if sp else site.getusersitepackages() + +ultralytics_base = site_packages +if not os.path.exists(os.path.join(ultralytics_base, 'ultralytics')): + try: + import ultralytics + ultralytics_base = os.path.dirname(os.path.dirname(ultralytics.__file__)) + except ImportError: + pass + +ultralytics_datas = collect_data_files('ultralytics') +imageio_datas = collect_data_files('imageio') +rawpy_datas = collect_data_files('rawpy') +pillow_heif_datas = collect_data_files('pillow_heif') + +all_datas = [ + (os.path.join(base_path, 'exiftools_win'), 'exiftools_win'), + (os.path.join(base_path, 'img'), 'img'), + (os.path.join(base_path, 'locales'), 'locales'), + (os.path.join(base_path, 'ioc'), 'ioc'), + (os.path.join(base_path, 'models', 'yolo11l-seg.pt'), 'models'), + (os.path.join(base_path, 'birdid', 'data', 'bird_reference.sqlite'), 'birdid/data'), + (os.path.join(base_path, 'birdid', 'data', 'ebird_classid_mapping.json'), 'birdid/data'), + (os.path.join(base_path, 'birdid', 'data', 'ebird_regions.json'), 'birdid/data'), + (os.path.join(base_path, 'birdid', 'data', 'offline_ebird_data'), 'birdid/data/offline_ebird_data'), + (os.path.join(base_path, 'SuperBirdIDPlugin.lrplugin'), 'SuperBirdIDPlugin.lrplugin'), + (os.path.join(base_path, 'requirements_base.txt'), '.'), + (os.path.join(base_path, 'core', 'runtime_requirements.py'), 'core'), + (os.path.join(ultralytics_base, 'ultralytics', 'cfg'), 'ultralytics/cfg'), +] + +all_datas.extend(ultralytics_datas) +all_datas.extend(imageio_datas) +all_datas.extend(rawpy_datas) +all_datas.extend(pillow_heif_datas) +all_datas.extend(_optional_copy_metadata('imageio')) +all_datas.extend(_optional_copy_metadata('rawpy')) +all_datas.extend(_optional_copy_metadata('ultralytics')) +all_datas.extend(_optional_copy_metadata('pillow_heif')) +all_datas.extend(_optional_copy_metadata('pi_heif')) + +# Windows Lite 冻结包会在主程序最早期进入 `--runtime-bootstrap` 路径, +# 由打包后的可执行文件自身安装并校验 Torch 运行时。 +# PyInstaller 对这条链路里的标准库模块并不总能静态识别, +# 所以需要在这里集中声明,避免后续再零散追加到主 hiddenimports 列表。 +# The Windows Lite frozen build enters `--runtime-bootstrap` very early and +# installs/verifies the Torch runtime from the packaged executable itself. +# PyInstaller does not always detect the stdlib modules used along that path, +# so keep them centralized here instead of appending ad hoc entries later. +runtime_bootstrap_stdlib_hiddenimports = [ + 'argparse', + 'ast', + 'base64', + 'bisect', + 'cProfile', + 'concurrent', + 'copy', + 'csv', + 'ctypes', + 'dataclasses', + 'datetime', + 'difflib', + 'dis', + 'enum', + 'faulthandler', + 'fnmatch', + 'gc', + 'getpass', + 'glob', + 'gzip', + 'hashlib', + 'heapq', + 'inspect', + 'ipaddress', + 'linecache', + 'locale', + 'modulefinder', + 'numbers', + 'pickletools', + 'profile', + 'pprint', + 'pstats', + 'queue', + 'resource', + 'runpy', + 'shlex', + 'signal', + 'sqlite3', + 'statistics', + 'sysconfig', + 'tarfile', + 'timeit', + 'tokenize', + 'traceback', + 'unittest', + 'uuid', + 'weakref', + 'xml', + 'zipfile', +] + +app_hiddenimports = [ + 'ultralytics', + 'PIL', + 'cv2', + 'numpy', + 'yaml', + 'PySide6', + 'PySide6.QtCore', + 'PySide6.QtGui', + 'PySide6.QtWidgets', + 'imageio', + 'rawpy', + 'imagehash', + 'pywt', + 'pillow_heif', + 'core', + 'core.burst_detector', + 'core.config_manager', + 'core.exposure_detector', + 'core.file_manager', + 'core.flight_detector', + 'core.focus_point_detector', + 'core.initialization_manager', + 'core.keypoint_detector', + 'core.photo_processor', + 'core.rating_engine', + 'core.runtime_bootstrap', + 'core.source_probe', + 'core.stats_formatter', + 'multiprocessing', + 'multiprocessing.spawn', + 'tools.update_checker', + 'packaging', + 'packaging.version', + 'birdid', + 'birdid.bird_identifier', + 'birdid.ebird_country_filter', + 'birdid_server', + 'server_manager', + 'flask', + 'flask.json', + 'cryptography', + 'cryptography.fernet', + '_telemetry_build', + 'app_user_stat._telemetry_build', + 'app_user_stat', + 'app_user_stat.telemetry', + 'app_user_stat.consent_texts', + 'app_user_stat.consent_texts.en_US', + 'app_user_stat.consent_texts.zh_CN', +] + +a = Analysis( + ['main.py'], + pathex=[base_path], + binaries=[], + datas=all_datas, + hiddenimports=app_hiddenimports + runtime_bootstrap_stdlib_hiddenimports, + hookspath=[], + hooksconfig={}, + runtime_hooks=['pyi_rth_cv2.py'] if os.path.exists('pyi_rth_cv2.py') else [], + excludes=[ + 'torch', 'torchvision', 'torchaudio', 'timm', + 'PyQt5', 'PyQt6', 'tkinter', + 'polars', 'numba', 'llvmlite', 'pyarrow', 'facexlib', 'datasets', + ], + noarchive=False, + optimize=0, +) + +pyz = PYZ(a.pure) + +icon_path = os.path.join(base_path, 'img', 'icon.ico') +if not os.path.exists(icon_path): + icon_path = None + +exe = EXE( + pyz, + a.scripts, + [], + exclude_binaries=True, + name='SuperPicky', + debug=False, + bootloader_ignore_signals=False, + strip=False, + upx=False, + console=False, + disable_windowed_traceback=False, + argv_emulation=False, + target_arch=None, + codesign_identity=None, + entitlements_file=None, + icon=icon_path, +) + +coll = COLLECT( + exe, + a.binaries, + a.datas, + strip=False, + upx=False, + upx_exclude=[], + name='SuperPicky', +) diff --git a/advanced_config.py b/advanced_config.py index d8223ae6..a9535370 100644 --- a/advanced_config.py +++ b/advanced_config.py @@ -88,7 +88,32 @@ class AdvancedConfig: # 更新提醒控制 "ignored_update_version": None, # 跳过提醒的版本号,如 "4.3.0" "include_prerelease": False, # 是否接收 Beta/RC 更新提醒 - "auto_check_updates": True, # 启动时自动检查更新(含补丁) + "auto_check_updates": False, # 启动时自动检查更新(含补丁) + + # V4.3+: 轻量底包首启初始化状态 + "initialization_completed": False, + "initialization_manifest_version": "v1", + "initialization_in_progress": False, + "last_init_exit_reason": "none", + "last_init_mode": "none", + + # V4.3+: 运行时选择与能力探测 + "selected_runtime_variant": "auto", # auto | cpu | cuda | mac + "detected_cuda_capable": False, + "runtime_install_location_preference": None, # None | default | install + "resolved_runtime_dir": None, + + # V4.3+: 首启启用的功能集与资源记录 + "enabled_feature_set": [ + "core_detection", + "quality", + "keypoint", + "flight", + "birdid", + ], + "downloaded_resources": {}, + "resolved_source_map": {}, + "last_init_error": None, # 主界面复选框状态 "flight_check": False, # 飞鸟检测默认关闭(开启后速度较慢,用户可手动开启) @@ -368,12 +393,115 @@ def set_include_prerelease(self, value: bool): @property def auto_check_updates(self) -> bool: - return self.config.get("auto_check_updates", True) + return self.config.get("auto_check_updates", False) def set_auto_check_updates(self, value: bool): """设置启动时是否自动检查更新。""" self.config["auto_check_updates"] = bool(value) + # V4.3+: 首启初始化状态 getter/setter + def _set_init_config(self, key: str, value): + self.config[key] = value + + @property + def initialization_completed(self) -> bool: + return self.config.get("initialization_completed", False) + + def set_initialization_completed(self, value: bool): + self._set_init_config("initialization_completed", bool(value)) + + @property + def initialization_manifest_version(self) -> str: + return str(self.config.get("initialization_manifest_version", "v1")) + + def set_initialization_manifest_version(self, value: str): + self._set_init_config("initialization_manifest_version", str(value or "v1")) + + @property + def initialization_in_progress(self) -> bool: + return self.config.get("initialization_in_progress", False) + + def set_initialization_in_progress(self, value: bool): + self._set_init_config("initialization_in_progress", bool(value)) + + @property + def last_init_exit_reason(self) -> str: + value = str(self.config.get("last_init_exit_reason", "none") or "none") + return value if value in ("none", "interrupted", "failed") else "none" + + def set_last_init_exit_reason(self, value: str): + normalized = value if value in ("none", "interrupted", "failed") else "none" + self._set_init_config("last_init_exit_reason", normalized) + + @property + def last_init_mode(self) -> str: + value = str(self.config.get("last_init_mode", "none") or "none") + return value if value in ("none", "init", "repair") else "none" + + def set_last_init_mode(self, value: str): + normalized = value if value in ("none", "init", "repair") else "none" + self._set_init_config("last_init_mode", normalized) + + @property + def selected_runtime_variant(self) -> str: + return str(self.config.get("selected_runtime_variant", "auto")) + + def set_selected_runtime_variant(self, value: str): + if value in ("auto", "cpu", "cuda", "mac"): + self._set_init_config("selected_runtime_variant", value) + + @property + def detected_cuda_capable(self) -> bool: + return self.config.get("detected_cuda_capable", False) + + def set_detected_cuda_capable(self, value: bool): + self._set_init_config("detected_cuda_capable", bool(value)) + + @property + def runtime_install_location_preference(self): + value = self.config.get("runtime_install_location_preference", None) + return value if value in ("default", "install", None) else None + + def set_runtime_install_location_preference(self, value): + normalized = value if value in ("default", "install") else None + self._set_init_config("runtime_install_location_preference", normalized) + + @property + def resolved_runtime_dir(self): + value = self.config.get("resolved_runtime_dir", None) + return None if value in (None, "") else str(value) + + def set_resolved_runtime_dir(self, value): + self._set_init_config("resolved_runtime_dir", None if not value else str(value)) + + @property + def enabled_feature_set(self) -> list: + return list(self.config.get("enabled_feature_set", [])) + + def set_enabled_feature_set(self, features: list): + self._set_init_config("enabled_feature_set", list(features)) + + @property + def downloaded_resources(self) -> dict: + return dict(self.config.get("downloaded_resources", {})) + + def set_downloaded_resources(self, resources: dict): + self._set_init_config("downloaded_resources", dict(resources)) + + @property + def resolved_source_map(self) -> dict: + return dict(self.config.get("resolved_source_map", {})) + + def set_resolved_source_map(self, source_map: dict): + self._set_init_config("resolved_source_map", dict(source_map)) + + @property + def last_init_error(self): + return self.config.get("last_init_error", None) + + def set_last_init_error(self, value): + self._set_init_config("last_init_error", value if value is None else str(value)) + # 主界面复选框状态 getter/setter @property def flight_check(self): diff --git a/ai_model.py b/ai_model.py index b48fb431..94d75c8c 100644 --- a/ai_model.py +++ b/ai_model.py @@ -17,7 +17,9 @@ def load_yolo_model(log_callback=None): """加载 YOLO 模型(使用最佳计算设备)""" - model_path = config.ai.get_model_path() + model_path = os.path.abspath(config.ai.get_model_path()) + if not os.path.exists(model_path): + raise FileNotFoundError(f"YOLO model file not found: {model_path}") model = YOLO(str(model_path)) # 使用统一的设备检测逻辑 @@ -431,4 +433,4 @@ def detect_and_draw_birds(image_path, model, output_path, dir, ui_settings, i18n # Mask processing failed, ignore pass - return found_bird, bird_result, bird_confidence, bird_sharpness, nima_score, bird_bbox, img_dims, bird_mask, bird_count \ No newline at end of file + return found_bird, bird_result, bird_confidence, bird_sharpness, nima_score, bird_bbox, img_dims, bird_mask, bird_count diff --git a/birdid/avonet_filter.py b/birdid/avonet_filter.py index 8ed383a4..0e8d41c3 100644 --- a/birdid/avonet_filter.py +++ b/birdid/avonet_filter.py @@ -14,148 +14,155 @@ import os import sqlite3 from typing import Set, List, Optional, Tuple +from config import get_install_scoped_resource_path from tools.i18n import t as _t -# 区域边界定义 (south, north, west, east) -# 格式: REGION_CODE: (南纬界, 北纬界, 西经界, 东经界) REGION_BOUNDS = { - # 全球 "GLOBAL": (-90, 90, -180, 180), - - # 六大洲 (宽泛定义,用于大范围检索) - "AF": (-35, 37, -17, 51), # 非洲 (Africa) - "AS": (-10, 81, 26, 170), # 亚洲 (Asia) - "EU": (34, 71, -25, 45), # 欧洲 (Europe) - "NA": (14, 83, -168, -52), # 北美洲 (North America) - "SA": (-56, 13, -81, -34), # 南美洲 (South America) - "OC": (-47, -10, 110, 180), # 大洋洲 (Oceania) - - # 亚太地区 - 国家 - "AU": (-44, -10, 112, 155), # 澳大利亚 - "NZ": (-47.5, -34, 166, 179), # 新西兰 - "CN": (18, 54, 73, 135), # 中国 - "JP": (24, 46, 122, 154), # 日本 - "KR": (33, 43, 124, 132), # 韩国 - "TW": (21.5, 25.5, 119, 122.5), # 台湾 - "HK": (22.1, 22.6, 113.8, 114.5), # 香港 - "TH": (5.5, 20.5, 97.5, 105.5), # 泰国 - "MY": (0.5, 7.5, 99.5, 119.5), # 马来西亚 - "SG": (1.1, 1.5, 103.6, 104.1), # 新加坡 - "ID": (-11, 6, 95, 141), # 印度尼西亚 - "PH": (4.5, 21, 116, 127), # 菲律宾 - "VN": (8, 23.5, 102, 110), # 越南 - "IN": (6, 36, 68, 98), # 印度 - "LK": (5, 10, 79, 82), # 斯里兰卡 - "NP": (26, 31, 80, 88), # 尼泊尔 - "MN": (41, 52, 87, 120), # 蒙古 - "RU": (41, 82, 19, 180), # 俄罗斯 - - # 美洲 - "US": (24, 49, -125, -66), # 美国本土 - "CA": (42, 83, -141, -52), # 加拿大 - "MX": (14, 33, -118, -86), # 墨西哥 - "BR": (-34, 5.5, -74, -34), # 巴西 - "AR": (-55, -21, -73, -53), # 阿根廷 - "CL": (-56, -17, -76, -66), # 智利 - "CO": (-4.5, 13, -79, -66), # 哥伦比亚 - "PE": (-18.5, 0, -81, -68), # 秘鲁 - "EC": (-5, 2, -81, -75), # 厄瓜多尔 - "CR": (8, 11.5, -86, -82.5), # 哥斯达黎加 - - # 欧洲 - "GB": (49, 61, -8, 2), # 英国 - "FR": (41, 51.5, -5, 10), # 法国 - "DE": (47, 55.5, 5.5, 15.5), # 德国 - "ES": (35.5, 44, -10, 4.5), # 西班牙 - "IT": (36, 47.5, 6.5, 18.5), # 意大利 - "NO": (57.5, 71.5, 4.5, 31.5), # 挪威 - "SE": (55, 69.5, 10.5, 24.5), # 瑞典 - "FI": (59.5, 70.5, 19.5, 31.5), # 芬兰 - "PL": (49, 55, 14, 24.5), # 波兰 - "TR": (35.5, 42.5, 25.5, 45), # 土耳其 - "PT": (36, 42, -10, -6), # 葡萄牙 - "NL": (50, 54, 3, 8), # 荷兰 - "CH": (45, 48, 5, 11), # 瑞士 - "GR": (34, 42, 19, 29), # 希腊 - "UA": (44, 53, 22, 41), # 乌克兰 - - # 非洲 - "MG": (-26, -11, 43, 51), # 马达加斯加 - "ZA": (-35, -22, 16.5, 33), # 南非 - "KE": (-5, 5, 33.5, 42), # 肯尼亚 - "TZ": (-12, -1, 29, 41), # 坦桑尼亚 - "EG": (22, 32, 24.5, 37), # 埃及 - "MA": (27, 36, -13, -1), # 摩洛哥 - - # 澳大利亚各州 - "AU-QLD": (-29, -10, 138, 154), # Queensland - "AU-NSW": (-37.5, -28, 141, 154), # New South Wales - "AU-VIC": (-39.2, -34, 141, 150), # Victoria - "AU-TAS": (-43.7, -39.5, 143.5, 148.5), # Tasmania - "AU-SA": (-38, -26, 129, 141), # South Australia - "AU-WA": (-35, -13.5, 112.5, 129), # Western Australia - "AU-NT": (-26, -10.5, 129, 138), # Northern Territory - "AU-ACT": (-35.95,-35.1, 148.75,149.4), # Australian Capital Territory - - # 美国各州 (south, north, west, east) - "US-AL": (30, 35, -88.5, -84.9), "US-AK": (51, 72, -168, -130), - "US-AZ": (31.3, 37, -114.8,-109), "US-AR": (33, 36.5, -94.6, -89.6), - "US-CA": (32.5, 42, -124.5,-114), "US-CO": (37, 41, -109, -102), - "US-CT": (40.9, 42.1, -73.7, -71.8),"US-DE": (38.4, 39.8, -75.8, -75), - "US-FL": (24.4, 31, -87.7, -80), "US-GA": (30.4, 35, -85.6, -80.8), - "US-HI": (18.9, 22.2, -160.3,-154.8),"US-ID": (42, 49, -117.2,-111), - "US-IL": (36.9, 42.5, -91.5, -87.5),"US-IN": (37.8, 41.8, -88.1, -84.8), - "US-IA": (40.4, 43.5, -96.6, -90.1),"US-KS": (37, 40, -102.1,-94.6), - "US-KY": (36.5, 39.2, -89.6, -81.9),"US-LA": (28.9, 33.1, -94.1, -88.8), - "US-ME": (43.1, 47.5, -71.1, -66.9),"US-MD": (37.9, 39.7, -79.5, -75), - "US-MA": (41.2, 42.9, -73.5, -69.9),"US-MI": (41.7, 48.3, -90.4, -82.4), - "US-MN": (43.5, 49.4, -97.2, -89.5),"US-MS": (30, 35, -91.7, -88.1), - "US-MO": (36, 40.6, -95.8, -89.1),"US-MT": (44.4, 49, -116.1,-104), - "US-NE": (40, 43, -104.1,-95.3),"US-NV": (35, 42, -120, -114), - "US-NH": (42.7, 45.3, -72.6, -70.7),"US-NJ": (38.9, 41.4, -75.6, -73.9), - "US-NM": (31.3, 37, -109.1,-103), "US-NY": (40.5, 45.1, -79.8, -71.9), - "US-NC": (33.8, 36.6, -84.3, -75.5),"US-ND": (45.9, 49, -104.1,-96.6), - "US-OH": (38.4, 42, -84.8, -80.5),"US-OK": (33.6, 37, -103, -94.4), - "US-OR": (41.9, 46.3, -124.6,-116.5),"US-PA": (39.7, 42.3, -80.5, -74.7), - "US-RI": (41.1, 42.1, -71.9, -71.1),"US-SC": (32, 35.2, -83.4, -78.5), - "US-SD": (42.5, 45.9, -104.1,-96.4),"US-TN": (35, 36.7, -90.3, -81.6), - "US-TX": (25.8, 36.5, -106.6,-93.5),"US-UT": (37, 42, -114.1,-109), - "US-VT": (42.7, 45.1, -73.4, -71.5),"US-VA": (36.5, 39.5, -83.7, -75.2), - "US-WA": (45.5, 49, -124.8,-116.9),"US-WV": (37.2, 40.6, -82.7, -77.7), - "US-WI": (42.5, 47.1, -92.9, -86.8),"US-WY": (41, 45, -111.1,-104), - - # 中国各省(south, north, west, east) - "CN-11": (39.4, 41.1, 115.4, 117.7), # 北京 - "CN-12": (38.6, 40.3, 116.7, 118.1), # 天津 - "CN-13": (36, 42.7, 113.5, 119.8), # 河北 - "CN-14": (34.6, 40.7, 110.2, 114.6), # 山西 - "CN-15": (37.5, 53.3, 97.2, 126.1), # 内蒙古 - "CN-21": (38.7, 43.5, 118.8, 125.7), # 辽宁 - "CN-22": (41.2, 46, 121.6, 131.3), # 吉林 - "CN-23": (43.4, 53.6, 121.1, 135.1), # 黑龙江 - "CN-31": (30.7, 31.9, 120.8, 122), # 上海 - "CN-32": (30.8, 35.1, 116.4, 121.9), # 江苏 - "CN-33": (27.1, 31.2, 118.1, 122.9), # 浙江 - "CN-34": (29.4, 34.7, 114.9, 119.9), # 安徽 - "CN-35": (23.5, 28.3, 115.8, 120.7), # 福建 - "CN-36": (24.5, 30.1, 113.6, 118.5), # 江西 - "CN-37": (34.4, 38.3, 114.8, 122.7), # 山东 - "CN-41": (31.4, 36.4, 110.4, 116.7), # 河南 - "CN-42": (29.1, 33.2, 108.4, 116.1), # 湖北 - "CN-43": (24.6, 30.1, 108.8, 114.3), # 湖南 - "CN-44": (20.2, 25.5, 109.7, 117.3), # 广东 - "CN-45": (20.9, 26.4, 104.5, 112.1), # 广西 - "CN-46": (18.1, 20.2, 108.4, 111.2), # 海南 - "CN-50": (28.2, 32.2, 105.3, 110.2), # 重庆 - "CN-51": (26, 34.3, 97.4, 108.5), # 四川 - "CN-52": (24.6, 29.2, 103.6, 109.6), # 贵州 - "CN-53": (21.1, 29.3, 97.5, 106.2), # 云南 - "CN-54": (26.8, 36.5, 78.4, 99.1), # 西藏 - "CN-61": (31.7, 39.6, 105.5, 111.3), # 陕西 - "CN-62": (32.6, 42.8, 92.4, 108.7), # 甘肃 - "CN-63": (31.6, 39.2, 89.4, 103.1), # 青海 - "CN-64": (35.2, 39.4, 104.3, 107.7), # 宁夏 - "CN-65": (34.3, 49.2, 73.5, 96.4), # 新疆 + "AF": (-35, 37, -17, 51), + "AS": (-10, 81, 26, 170), + "EU": (34, 71, -25, 45), + "NA": (14, 83, -168, -52), + "SA": (-56, 13, -81, -34), + "OC": (-47, -10, 110, 180), + "AU": (-44, -10, 112, 155), + "NZ": (-47.5, -34, 166, 179), + "CN": (18, 54, 73, 135), + "JP": (24, 46, 122, 154), + "KR": (33, 43, 124, 132), + "TW": (21.5, 25.5, 119, 122.5), + "HK": (22.1, 22.6, 113.8, 114.5), + "TH": (5.5, 20.5, 97.5, 105.5), + "MY": (0.5, 7.5, 99.5, 119.5), + "SG": (1.1, 1.5, 103.6, 104.1), + "ID": (-11, 6, 95, 141), + "PH": (4.5, 21, 116, 127), + "VN": (8, 23.5, 102, 110), + "IN": (6, 36, 68, 98), + "LK": (5, 10, 79, 82), + "NP": (26, 31, 80, 88), + "MN": (41, 52, 87, 120), + "RU": (41, 82, 19, 180), + "US": (24, 49, -125, -66), + "CA": (42, 83, -141, -52), + "MX": (14, 33, -118, -86), + "BR": (-34, 5.5, -74, -34), + "AR": (-55, -21, -73, -53), + "CL": (-56, -17, -76, -66), + "CO": (-4.5, 13, -79, -66), + "PE": (-18.5, 0, -81, -68), + "EC": (-5, 2, -81, -75), + "CR": (8, 11.5, -86, -82.5), + "GB": (49, 61, -8, 2), + "FR": (41, 51.5, -5, 10), + "DE": (47, 55.5, 5.5, 15.5), + "ES": (35.5, 44, -10, 4.5), + "IT": (36, 47.5, 6.5, 18.5), + "NO": (57.5, 71.5, 4.5, 31.5), + "SE": (55, 69.5, 10.5, 24.5), + "FI": (59.5, 70.5, 19.5, 31.5), + "PL": (49, 55, 14, 24.5), + "TR": (35.5, 42.5, 25.5, 45), + "PT": (36, 42, -10, -6), + "NL": (50, 54, 3, 8), + "CH": (45, 48, 5, 11), + "GR": (34, 42, 19, 29), + "UA": (44, 53, 22, 41), + "MG": (-26, -11, 43, 51), + "ZA": (-35, -22, 16.5, 33), + "KE": (-5, 5, 33.5, 42), + "TZ": (-12, -1, 29, 41), + "EG": (22, 32, 24.5, 37), + "MA": (27, 36, -13, -1), + "AU-QLD": (-29, -10, 138, 154), + "AU-NSW": (-37.5, -28, 141, 154), + "AU-VIC": (-39.2, -34, 141, 150), + "AU-TAS": (-43.7, -39.5, 143.5, 148.5), + "AU-SA": (-38, -26, 129, 141), + "AU-WA": (-35, -13.5, 112.5, 129), + "AU-NT": (-26, -10.5, 129, 138), + "AU-ACT": (-35.95, -35.1, 148.75, 149.4), + "US-AL": (30, 35, -88.5, -84.9), + "US-AK": (51, 72, -168, -130), + "US-AZ": (31.3, 37, -114.8, -109), + "US-AR": (33, 36.5, -94.6, -89.6), + "US-CA": (32.5, 42, -124.5, -114), + "US-CO": (37, 41, -109, -102), + "US-CT": (40.9, 42.1, -73.7, -71.8), + "US-DE": (38.4, 39.8, -75.8, -75), + "US-FL": (24.4, 31, -87.7, -80), + "US-GA": (30.4, 35, -85.6, -80.8), + "US-HI": (18.9, 22.2, -160.3, -154.8), + "US-ID": (42, 49, -117.2, -111), + "US-IL": (36.9, 42.5, -91.5, -87.5), + "US-IN": (37.8, 41.8, -88.1, -84.8), + "US-IA": (40.4, 43.5, -96.6, -90.1), + "US-KS": (37, 40, -102.1, -94.6), + "US-KY": (36.5, 39.2, -89.6, -81.9), + "US-LA": (28.9, 33.1, -94.1, -88.8), + "US-ME": (43.1, 47.5, -71.1, -66.9), + "US-MD": (37.9, 39.7, -79.5, -75), + "US-MA": (41.2, 42.9, -73.5, -69.9), + "US-MI": (41.7, 48.3, -90.4, -82.4), + "US-MN": (43.5, 49.4, -97.2, -89.5), + "US-MS": (30, 35, -91.7, -88.1), + "US-MO": (36, 40.6, -95.8, -89.1), + "US-MT": (44.4, 49, -116.1, -104), + "US-NE": (40, 43, -104.1, -95.3), + "US-NV": (35, 42, -120, -114), + "US-NH": (42.7, 45.3, -72.6, -70.7), + "US-NJ": (38.9, 41.4, -75.6, -73.9), + "US-NM": (31.3, 37, -109.1, -103), + "US-NY": (40.5, 45.1, -79.8, -71.9), + "US-NC": (33.8, 36.6, -84.3, -75.5), + "US-ND": (45.9, 49, -104.1, -96.6), + "US-OH": (38.4, 42, -84.8, -80.5), + "US-OK": (33.6, 37, -103, -94.4), + "US-OR": (41.9, 46.3, -124.6, -116.5), + "US-PA": (39.7, 42.3, -80.5, -74.7), + "US-RI": (41.1, 42.1, -71.9, -71.1), + "US-SC": (32, 35.2, -83.4, -78.5), + "US-SD": (42.5, 45.9, -104.1, -96.4), + "US-TN": (35, 36.7, -90.3, -81.6), + "US-TX": (25.8, 36.5, -106.6, -93.5), + "US-UT": (37, 42, -114.1, -109), + "US-VT": (42.7, 45.1, -73.4, -71.5), + "US-VA": (36.5, 39.5, -83.7, -75.2), + "US-WA": (45.5, 49, -124.8, -116.9), + "US-WV": (37.2, 40.6, -82.7, -77.7), + "US-WI": (42.5, 47.1, -92.9, -86.8), + "US-WY": (41, 45, -111.1, -104), + "CN-11": (39.4, 41.1, 115.4, 117.7), + "CN-12": (38.6, 40.3, 116.7, 118.1), + "CN-13": (36, 42.7, 113.5, 119.8), + "CN-14": (34.6, 40.7, 110.2, 114.6), + "CN-15": (37.5, 53.3, 97.2, 126.1), + "CN-21": (38.7, 43.5, 118.8, 125.7), + "CN-22": (41.2, 46, 121.6, 131.3), + "CN-23": (43.4, 53.6, 121.1, 135.1), + "CN-31": (30.7, 31.9, 120.8, 122), + "CN-32": (30.8, 35.1, 116.4, 121.9), + "CN-33": (27.1, 31.2, 118.1, 122.9), + "CN-34": (29.4, 34.7, 114.9, 119.9), + "CN-35": (23.5, 28.3, 115.8, 120.7), + "CN-36": (24.5, 30.1, 113.6, 118.5), + "CN-37": (34.4, 38.3, 114.8, 122.7), + "CN-41": (31.4, 36.4, 110.4, 116.7), + "CN-42": (29.1, 33.2, 108.4, 116.1), + "CN-43": (24.6, 30.1, 108.8, 114.3), + "CN-44": (20.2, 25.5, 109.7, 117.3), + "CN-45": (20.9, 26.4, 104.5, 112.1), + "CN-46": (18.1, 20.2, 108.4, 111.2), + "CN-50": (28.2, 32.2, 105.3, 110.2), + "CN-51": (26, 34.3, 97.4, 108.5), + "CN-52": (24.6, 29.2, 103.6, 109.6), + "CN-53": (21.1, 29.3, 97.5, 106.2), + "CN-54": (26.8, 36.5, 78.4, 99.1), + "CN-61": (31.7, 39.6, 105.5, 111.3), + "CN-62": (32.6, 42.8, 92.4, 108.7), + "CN-63": (31.6, 39.2, 89.4, 103.1), + "CN-64": (35.2, 39.4, 104.3, 107.7), + "CN-65": (34.3, 49.2, 73.5, 96.4), } @@ -176,14 +183,12 @@ def __init__(self, db_path: Optional[str] = None): db_path: avonet.db 的路径,如果为 None 则自动定位 """ if db_path is None: - # 自动定位数据库文件 db_path = self._find_database() self.db_path = db_path self._conn: Optional[sqlite3.Connection] = None - self._ebird_cls_map: Optional[dict] = None # eBird code -> class_id(懒加载) + self._ebird_cls_map: Optional[dict] = None - # 尝试连接数据库 if self.db_path and os.path.exists(self.db_path): try: self._conn = sqlite3.connect(self.db_path, check_same_thread=False) @@ -197,17 +202,12 @@ def _find_database(self) -> Optional[str]: 自动查找 avonet.db 文件 查找顺序: - 1. birdid/data/avonet.db (相对于当前文件) - 2. data/avonet.db (相对于当前工作目录) - 3. 常见安装位置 + 1. 统一安装目录资源路径 + 2. 兼容旧开发目录 """ - # 相对于当前模块的位置 - module_dir = os.path.dirname(os.path.abspath(__file__)) possible_paths = [ - os.path.join(module_dir, "data", "avonet.db"), - os.path.join(module_dir, "..", "data", "avonet.db"), - os.path.join(os.getcwd(), "birdid", "data", "avonet.db"), - os.path.join(os.getcwd(), "data", "avonet.db"), + str(get_install_scoped_resource_path(os.path.join("birdid", "data", "avonet.db"))), + os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "avonet.db"), ] for path in possible_paths: @@ -251,7 +251,6 @@ def get_species_by_gps(self, lat: float, lon: float) -> Set[int]: return set() try: - # 查询包含该GPS点的网格中的所有物种 query = """ SELECT DISTINCT sm.cls FROM distributions d @@ -306,7 +305,6 @@ def _get_species_by_bounds( return set() try: - # 查询与边界框重叠的所有网格中的物种 query = """ SELECT DISTINCT sm.cls FROM distributions d @@ -365,8 +363,6 @@ def __del__(self): """析构时关闭连接""" self.close() - # ==================== eBird 国家级回退 ==================== - def _load_ebird_cls_map(self) -> dict: """懒加载 ebird_classid_mapping.json,返回 ebird_code -> class_id 的反向映射""" if self._ebird_cls_map is not None: @@ -381,7 +377,7 @@ def _load_ebird_cls_map(self) -> dict: try: import json with open(map_path, "r", encoding="utf-8") as f: - raw = json.load(f) # {str(class_id): ebird_code} + raw = json.load(f) self._ebird_cls_map = {v: int(k) for k, v in raw.items()} except Exception as e: print(_t("logs.avonet_classid_failed", e=e)) @@ -394,10 +390,8 @@ def _detect_country_from_gps(self, lat: float, lon: float) -> Optional[str]: 根据 GPS 坐标离线判定国家代码(仅返回国家级,不含州级)。 优先匹配面积最小的边界框,避免大国遮蔽小国。 """ - # 大陆级/全球代码,跳过 _SKIP = {"GLOBAL", "AF", "AS", "EU", "NA", "SA", "OC"} - # 收集匹配的国家及其面积 candidates = [] for code, bounds in REGION_BOUNDS.items(): if code in _SKIP: @@ -410,7 +404,6 @@ def _detect_country_from_gps(self, lat: float, lon: float) -> Optional[str]: if not candidates: return None - # 返回面积最小的匹配(最具体的) candidates.sort() return candidates[0][1] @@ -431,7 +424,6 @@ def get_species_by_country_ebird( if not country_code: return set(), None - # 加载对应国家的 eBird 物种列表 module_dir = os.path.dirname(os.path.abspath(__file__)) species_file = os.path.join( module_dir, "data", "offline_ebird_data", @@ -450,7 +442,6 @@ def get_species_by_country_ebird( print(_t("logs.avonet_read_ebird_failed", code=country_code, e=e)) return set(), None - # 转换 eBird 代码 -> class_id cls_map = self._load_ebird_cls_map() class_ids: Set[int] = set() for code in ebird_codes: @@ -481,7 +472,6 @@ def _load_ebird_file(code: str) -> Optional[List[str]]: try: with open(path, "r", encoding="utf-8") as f: data = json.load(f) - # 支持两种格式:纯 list 或 {"species": [...]} if isinstance(data, list): return data return data.get("species", []) @@ -490,11 +480,9 @@ def _load_ebird_file(code: str) -> Optional[List[str]]: return None region_code = region_code.upper() - # 尝试加载州级数据 species_codes = _load_ebird_file(region_code) actual_region = region_code - # 州级无数据则回退到国家级 if not species_codes and "-" in region_code: country = region_code.split("-")[0] species_codes = _load_ebird_file(country) @@ -511,85 +499,3 @@ def _load_ebird_file(code: str) -> Optional[List[str]]: class_ids.add(cls_id) return class_ids, actual_region - - -if __name__ == "__main__": - print("=" * 60) - print("AvonetFilter 测试") - print("=" * 60) - - # 创建过滤器实例 - af = AvonetFilter() - - # 检查数据库是否可用 - print(f"\n数据库路径: {af.db_path}") - print(f"数据库可用: {af.is_available()}") - - if not af.is_available(): - print("错误: 数据库不可用,无法继续测试") - exit(1) - - # 测试 GPS 查询 - print("\n" + "-" * 40) - print("GPS 坐标查询测试") - print("-" * 40) - - test_locations = [ - ("吉隆坡 (马来西亚)", 3.0, 101.7), - ("悉尼 (澳大利亚)", -33.9, 151.2), - ("东京 (日本)", 35.7, 139.7), - ("伦敦 (英国)", 51.5, -0.1), - ] - - for name, lat, lon in test_locations: - species = af.get_species_by_gps(lat, lon) - print(f" {name}: {len(species)} 个物种") - if species: - sample = sorted(list(species))[:5] - print(f" 样例 class_ids: {sample}") - - # 测试区域查询 - print("\n" + "-" * 40) - print("区域代码查询测试") - print("-" * 40) - - test_regions = ["AU", "AU-SA", "CN", "JP"] - - for region in test_regions: - species = af.get_species_by_region(region) - bounds = af.get_region_bounds(region) - print(f" {region}: {len(species)} 个物种") - print(f" 边界: {bounds}") - if species: - sample = sorted(list(species))[:5] - print(f" 样例 class_ids: {sample}") - - # 显示支持的区域列表 - print("\n" + "-" * 40) - print("支持的区域代码") - print("-" * 40) - - regions = af.get_supported_regions() - print(f" 共 {len(regions)} 个区域:") - - # 按类别分组显示 - global_regions = [r for r in regions if r == "GLOBAL"] - au_states = [r for r in regions if r.startswith("AU-")] - au_country = [r for r in regions if r == "AU"] - asia = [r for r in regions if r in ["CN", "JP", "KR", "TW", "TH", "MY", "SG", "ID", "PH", "VN", "IN", "NZ"]] - americas = [r for r in regions if r in ["US", "CA", "MX", "BR", "AR", "CL", "CO", "PE", "EC", "CR"]] - europe = [r for r in regions if r in ["GB", "FR", "DE", "ES", "IT", "NO", "SE", "FI", "PL", "TR"]] - africa = [r for r in regions if r in ["ZA", "KE", "TZ", "EG", "MA"]] - - print(f" 全球: {global_regions}") - print(f" 澳大利亚: {au_country + au_states}") - print(f" 亚太: {asia}") - print(f" 美洲: {americas}") - print(f" 欧洲: {europe}") - print(f" 非洲: {africa}") - - # 关闭连接 - af.close() - print("\n" + "=" * 60) - print("测试完成") - print("=" * 60) diff --git a/birdid/bird_identifier.py b/birdid/bird_identifier.py index 1a8c2e36..95c9d520 100644 --- a/birdid/bird_identifier.py +++ b/birdid/bird_identifier.py @@ -1,7 +1,11 @@ #!/usr/bin/env python3 """ -鸟类识别核心模块 -从 SuperBirdID 移植,提供鸟类检测与分类识别功能 +鸟类识别核心模块。 +Core bird-identification module. + +从 SuperBirdID 移植,负责鸟类检测、分类与离线资源路径兼容。 +Ported from SuperBirdID and responsible for bird detection, classification, +and compatibility with offline resource paths. """ __version__ = "1.0.0" @@ -15,113 +19,141 @@ import io import os import sys -from typing import Optional, List, Dict, Tuple, Set +from typing import Any, Optional, List, Dict, Tuple, Set, cast from tools.i18n import t as _t -from config import get_best_device, get_lazy_registry - -# ==================== 设备配置 ==================== +from config import ( + get_best_device, + get_lazy_registry, + get_app_config_dir, + get_install_scoped_resource_path, + get_packaged_model_relative_path, + get_runtime_meipass, +) -CLASSIFIER_DEVICE = get_best_device() +CLASSIFIER_DEVICE = torch.device(str(get_best_device())) -# ==================== 可选依赖检测 ==================== +RESAMPLING_LANCZOS = Image.Resampling.LANCZOS -# RAW格式支持 try: import rawpy import imageio + RAW_SUPPORT = True except ImportError: + rawpy = cast(Any, None) + imageio = cast(Any, None) RAW_SUPPORT = False -# YOLO检测支持 try: from ultralytics import YOLO + YOLO_AVAILABLE = True except ImportError: + YOLO = cast(Any, None) YOLO_AVAILABLE = False -# ==================== 路径配置 ==================== - -# birdid 模块目录 BIRDID_DIR = os.path.dirname(os.path.abspath(__file__)) -# 项目根目录(code_updates overlay 场景下 __file__ 指向 code_updates/birdid/,需通过 sys.path 找真实根) + + def _find_project_root() -> str: candidate = os.path.dirname(BIRDID_DIR) - if os.path.exists(os.path.join(candidate, 'models', 'model20240824.pth')): + if os.path.exists(os.path.join(candidate, "models", "model20240824.pth")): return candidate for p in sys.path: - if p and os.path.isdir(p) and os.path.exists(os.path.join(p, 'models', 'model20240824.pth')): + if ( + p + and os.path.isdir(p) + and os.path.exists(os.path.join(p, "models", "model20240824.pth")) + ): return p - return candidate # 兜底 + return candidate + def _find_birdid_dir() -> str: - if os.path.exists(os.path.join(BIRDID_DIR, 'data', 'bird_reference.sqlite')): + if os.path.exists(os.path.join(BIRDID_DIR, "data", "bird_reference.sqlite")): return BIRDID_DIR for p in sys.path: if p and os.path.isdir(p): - candidate = os.path.join(p, 'birdid') - if os.path.exists(os.path.join(candidate, 'data', 'bird_reference.sqlite')): + candidate = os.path.join(p, "birdid") + if os.path.exists(os.path.join(candidate, "data", "bird_reference.sqlite")): return candidate - return BIRDID_DIR # 兜底 + return BIRDID_DIR + PROJECT_ROOT = _find_project_root() BIRDID_DIR = _find_birdid_dir() def get_birdid_path(relative_path: str) -> str: - """获取 birdid 模块内的资源路径""" - if getattr(sys, 'frozen', False): - # PyInstaller 打包环境 - return os.path.join(sys._MEIPASS, 'birdid', relative_path) + """ + 返回 `birdid/` 目录下的资源路径。 + Return a resource path under the `birdid/` directory. + + Windows Lite 构建需要从安装目录 `_internal` 读取资源,其余冻结环境仍跟随 + PyInstaller bundle 目录;源码环境则回退到仓库内的 `birdid/` 目录。 + Windows Lite builds read from the install-scoped `_internal` tree, other + frozen builds still follow the PyInstaller bundle, and source runs fall back + to the repository `birdid/` directory. + """ + if getattr(sys, "frozen", False) and sys.platform == "win32": + return str( + get_install_scoped_resource_path(os.path.join("birdid", relative_path)) + ) + if getattr(sys, "frozen", False): + meipass = get_runtime_meipass() + if meipass is not None: + return os.path.join(meipass, "birdid", relative_path) return os.path.join(BIRDID_DIR, relative_path) def get_project_path(relative_path: str) -> str: - """获取项目根目录下的资源路径""" - if getattr(sys, 'frozen', False): - return os.path.join(sys._MEIPASS, relative_path) + """ + 返回项目级资源路径。 + Return a project-level resource path. + + 这里统一兼容 Windows Lite 安装目录、普通 PyInstaller bundle 与源码目录, + 避免各调用方再自行拼接 `_MEIPASS` 路径。 + This helper centralizes path selection for Windows Lite installs, regular + PyInstaller bundles, and source checkouts so callers do not rebuild + `_MEIPASS`-based paths themselves. + """ + if getattr(sys, "frozen", False) and sys.platform == "win32": + packaged_relative_path = None + if relative_path.startswith("models/"): + packaged_relative_path = get_packaged_model_relative_path(relative_path) + return str( + get_install_scoped_resource_path( + relative_path, packaged_relative_path=packaged_relative_path + ) + ) + if getattr(sys, "frozen", False): + meipass = get_runtime_meipass() + if meipass is not None: + return os.path.join(meipass, relative_path) return os.path.join(PROJECT_ROOT, relative_path) def get_user_data_dir() -> str: - """获取用户数据目录""" - if sys.platform == 'darwin': - user_data_dir = os.path.expanduser('~/Documents/SuperPicky_Data') - elif sys.platform == 'win32': - user_data_dir = os.path.join(os.path.expanduser('~'), 'Documents', 'SuperPicky_Data') - else: - user_data_dir = os.path.join(os.path.expanduser('~'), 'Documents', 'SuperPicky_Data') + user_data_dir = str(get_app_config_dir()) os.makedirs(user_data_dir, exist_ok=True) return user_data_dir -# ==================== 模型路径 ==================== -# 鸟类识别专用模型和数据(在 birdid/ 目录下) -# OSEA ResNet34 模型(替代旧 birdid2024) -MODEL_PATH = get_project_path('models/model20240824.pth') -# 旧模型路径(保留作为回退) -MODEL_PATH_LEGACY = get_birdid_path('models/birdid2024.pt') -MODEL_PATH_ENC = get_birdid_path('models/birdid2024.pt.enc') -# OSEA 模型类别数 +MODEL_PATH = get_project_path("models/model20240824.pth") +MODEL_PATH_LEGACY = get_birdid_path("models/birdid2024.pt") +MODEL_PATH_ENC = get_birdid_path("models/birdid2024.pt.enc") OSEA_NUM_CLASSES = 11000 -DATABASE_PATH = get_birdid_path('data/bird_reference.sqlite') - -# YOLO 模型(共用项目根目录的模型) -YOLO_MODEL_PATH = get_project_path('models/yolo11l-seg.pt') +DATABASE_PATH = get_birdid_path("data/bird_reference.sqlite") +YOLO_MODEL_PATH = get_project_path("models/yolo11l-seg.pt") -# ==================== 全局变量(懒加载)==================== -# 已迁移至 config.get_lazy_registry() 统一管理 - -# ==================== 模型加密解密 ==================== def decrypt_model(encrypted_path: str, password: str) -> bytes: - """解密模型文件""" from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC - with open(encrypted_path, 'rb') as f: + with open(encrypted_path, "rb") as f: encrypted_data = f.read() salt = encrypted_data[:16] @@ -133,15 +165,11 @@ def decrypt_model(encrypted_path: str, password: str) -> bytes: length=32, salt=salt, iterations=100000, - backend=default_backend() + backend=default_backend(), ) key = kdf.derive(password.encode()) - cipher = Cipher( - algorithms.AES(key), - modes.CBC(iv), - backend=default_backend() - ) + cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) decryptor = cipher.decryptor() plaintext_padded = decryptor.update(ciphertext) + decryptor.finalize() @@ -150,15 +178,11 @@ def decrypt_model(encrypted_path: str, password: str) -> bytes: def _load_torchscript_from_bytes(model_data: bytes): - """Load TorchScript from bytes to avoid Windows non-ASCII temp path issues.""" buffer = io.BytesIO(model_data) - return torch.jit.load(buffer, map_location='cpu') + return torch.jit.load(buffer, map_location="cpu") -# ==================== 懒加载函数 ==================== - def get_classifier(): - """懒加载分类模型(OSEA ResNet34)""" registry = get_lazy_registry() def _factory(): @@ -166,11 +190,10 @@ def _factory(): if os.path.exists(MODEL_PATH): model = models.resnet34(num_classes=OSEA_NUM_CLASSES) - state_dict = torch.load(MODEL_PATH, map_location='cpu', weights_only=True) + state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=True) model.load_state_dict(state_dict) - model = model.to(CLASSIFIER_DEVICE) + model = model.to(device=CLASSIFIER_DEVICE) model.eval() - print(f"[BirdID] OSEA ResNet34 model loaded, device: {CLASSIFIER_DEVICE}") return model SECRET_PASSWORD = "SuperBirdID_2024_AI_Model_Encryption_Key_v1" @@ -179,40 +202,38 @@ def _factory(): model = _load_torchscript_from_bytes(model_data) elif os.path.exists(MODEL_PATH_LEGACY): try: - model = torch.jit.load(MODEL_PATH_LEGACY, map_location='cpu') + model = torch.jit.load(MODEL_PATH_LEGACY, map_location="cpu") except RuntimeError as e: - if 'open file failed' not in str(e) or 'fopen' not in str(e): + if "open file failed" not in str(e) or "fopen" not in str(e): raise - with open(MODEL_PATH_LEGACY, 'rb') as f: + with open(MODEL_PATH_LEGACY, "rb") as f: model_data = f.read() model = _load_torchscript_from_bytes(model_data) else: raise RuntimeError(f"未找到分类模型: {MODEL_PATH} 或 {MODEL_PATH_LEGACY}") - model = model.to(CLASSIFIER_DEVICE) - model.eval() # noqa: model.eval() is a PyTorch API call, not Python eval() - print(_t("logs.birdid_fallback_model")) + model = model.to(device=CLASSIFIER_DEVICE) + model.eval() return model return registry.get_or_create("birdid.classifier", _factory) def get_bird_model(): - """获取识鸟模型(get_classifier 的别名,用于模型预加载)""" return get_classifier() def get_database_manager(): - """懒加载数据库管理器""" registry = get_lazy_registry() def _factory(): try: from birdid.bird_database_manager import BirdDatabaseManager + if os.path.exists(DATABASE_PATH): return BirdDatabaseManager(DATABASE_PATH) except Exception as e: - print(_t("logs.db_load_failed", e=e)) + pass return False result = registry.get_or_create("birdid.database_manager", _factory) @@ -220,40 +241,38 @@ def _factory(): def get_yolo_detector(): - """懒加载YOLO检测器""" if not YOLO_AVAILABLE: return None registry = get_lazy_registry() return registry.get_or_create( "birdid.yolo_detector", - lambda: YOLOBirdDetector(YOLO_MODEL_PATH) if os.path.exists(YOLO_MODEL_PATH) else None, + lambda: ( + YOLOBirdDetector(YOLO_MODEL_PATH) + if os.path.exists(YOLO_MODEL_PATH) + else None + ), ) def get_species_filter(): - """懒加载 AvonetFilter(单例模式)""" registry = get_lazy_registry() def _factory(): try: from birdid.avonet_filter import AvonetFilter + filt = AvonetFilter() if filt.is_available(): - print(_t("logs.avonet_loaded")) return filt except Exception as e: - print(_t("logs.avonet_init_failed", e=e)) + pass return None return registry.get_or_create("birdid.avonet_filter", _factory) -# ==================== YOLO 鸟类检测器 ==================== - class YOLOBirdDetector: - """YOLO 鸟类检测器""" - - def __init__(self, model_path: str = None): + def __init__(self, model_path: Optional[str] = None): if not YOLO_AVAILABLE: self.model = None return @@ -261,10 +280,14 @@ def __init__(self, model_path: str = None): if model_path is None: model_path = YOLO_MODEL_PATH + model_path = os.path.abspath(model_path) + if not os.path.exists(model_path): + self.model = None + return + try: self.model = YOLO(model_path) except Exception as e: - print(_t("logs.yolo_load_failed", e=e)) self.model = None def detect_and_crop_bird( @@ -272,26 +295,8 @@ def detect_and_crop_bird( image_input, confidence_threshold: float = 0.25, padding_ratio: float = 0.15, - fill_color: Tuple[int, int, int] = (0, 0, 0) + fill_color: Tuple[int, int, int] = (0, 0, 0), ) -> Tuple[Optional[Image.Image], str]: - """ - 检测并裁剪鸟类区域(智能正方形裁剪 + Letterboxing) - - 处理流程: - 1. YOLO 检测获取 bounding box - 2. 智能正方形扩展: max_side * (1 + padding_ratio) - 3. 边界限制: 裁剪区域不超出图片范围 - 4. Letterboxing: 如果裁剪后非正方形,用 fill_color 填充成正方形 - - Args: - image_input: 文件路径或 PIL Image - confidence_threshold: 置信度阈值 - padding_ratio: padding 比例(基于 bbox 最大边长),默认 0.15 (15%) - fill_color: Letterboxing 填充颜色,默认黑色 (0, 0, 0) - - Returns: - (裁剪后的正方形图像, 检测信息) 或 (None, 错误信息) - """ if self.model is None: return None, "YOLO模型未可用" @@ -315,29 +320,27 @@ def detect_and_crop_bird( confidence = box.conf[0].cpu().numpy() class_id = int(box.cls[0].cpu().numpy()) - # COCO 数据集中鸟类的 class_id 是 14 if class_id == 14: - detections.append({ - 'bbox': [int(x1), int(y1), int(x2), int(y2)], - 'confidence': float(confidence) - }) + detections.append( + { + "bbox": [int(x1), int(y1), int(x2), int(y2)], + "confidence": float(confidence), + } + ) if not detections: - return None, _t("logs.no_bird_detected") + return None, "未检测到鸟类" - best = max(detections, key=lambda x: x['confidence']) + best = max(detections, key=lambda x: x["confidence"]) img_width, img_height = image.size - # Phase 1: 获取 bbox - x1, y1, x2, y2 = best['bbox'] + x1, y1, x2, y2 = best["bbox"] bbox_width = x2 - x1 bbox_height = y2 - y1 - # Phase 2: 智能正方形扩展 (基于最大边长 + padding_ratio) max_side = max(bbox_width, bbox_height) target_side = int(max_side * (1 + padding_ratio)) - # 以 bbox 中心为基准扩展 cx = (x1 + x2) // 2 cy = (y1 + y2) // 2 half = target_side // 2 @@ -347,7 +350,6 @@ def detect_and_crop_bird( sq_x2 = cx + half sq_y2 = cy + half - # Phase 3: 边界限制 crop_x1 = max(0, sq_x1) crop_y1 = max(0, sq_y1) crop_x2 = min(img_width, sq_x2) @@ -356,10 +358,9 @@ def detect_and_crop_bird( cropped = image.crop((crop_x1, crop_y1, crop_x2, crop_y2)) crop_w, crop_h = cropped.size - # Phase 4: Letterboxing (如果裁剪后非正方形) if crop_w != crop_h: sq_size = max(crop_w, crop_h) - square = Image.new('RGB', (sq_size, sq_size), fill_color) + square = Image.new("RGB", (sq_size, sq_size), fill_color) paste_x = (sq_size - crop_w) // 2 paste_y = (sq_size - crop_h) // 2 square.paste(cropped, (paste_x, paste_y)) @@ -373,68 +374,79 @@ def detect_and_crop_bird( return None, f"检测失败: {e}" -# ==================== 图像加载 ==================== - def load_image(image_path: str) -> Image.Image: - """ - 加载图像,支持标准格式和 RAW 格式 - 对 RAW 文件优先提取内嵌 JPEG 预览图(更适合 YOLO 检测) - """ if not os.path.exists(image_path): raise FileNotFoundError(f"文件不存在: {image_path}") ext = os.path.splitext(image_path)[1].lower() raw_extensions = [ - '.cr2', '.cr3', '.nef', '.nrw', '.arw', '.srf', '.dng', - '.raf', '.orf', '.rw2', '.pef', '.srw', '.raw', '.rwl', - '.3fr', '.fff', '.erf', '.mef', '.mos', '.mrw', '.x3f', - '.hif', '.heif', '.heic', # Sony HIF / HEIF + ".cr2", + ".cr3", + ".nef", + ".nrw", + ".arw", + ".srf", + ".dng", + ".raf", + ".orf", + ".rw2", + ".pef", + ".srw", + ".raw", + ".rwl", + ".3fr", + ".fff", + ".erf", + ".mef", + ".mos", + ".mrw", + ".x3f", + ".hif", + ".heif", + ".heic", ] - # HEIF 格式(rawpy 不支持):直接补 pillow-heif 路径 - heif_extensions = {'.hif', '.heif', '.heic'} + heif_extensions = {".hif", ".heif", ".heic"} if ext in raw_extensions: if ext in heif_extensions: return _load_heif(image_path) if RAW_SUPPORT: + thumb_format_enum = getattr(rawpy, "ThumbFormat", None) + jpeg_thumb_format = getattr(thumb_format_enum, "JPEG", None) + bitmap_thumb_format = getattr(thumb_format_enum, "BITMAP", None) + rawpy_internal = getattr(rawpy, "_rawpy", None) + unsupported_error = getattr( + rawpy_internal, "LibRawFileUnsupportedError", None + ) try: with rawpy.imread(image_path) as raw: - # 优先尝试提取内嵌的 JPEG 预览图 try: thumb = raw.extract_thumb() - if thumb.format == rawpy.ThumbFormat.JPEG: - # 直接使用内嵌的 JPEG + if thumb.format == jpeg_thumb_format: from io import BytesIO + img = Image.open(BytesIO(thumb.data)).convert("RGB") - print(_t("logs.raw_embedded_jpeg", w=img.size[0], h=img.size[1])) return img - elif thumb.format == rawpy.ThumbFormat.BITMAP: - # 位图格式 + elif thumb.format == bitmap_thumb_format: img = Image.fromarray(thumb.data).convert("RGB") - print(_t("logs.raw_embedded_bitmap", w=img.size[0], h=img.size[1])) return img except Exception as e: - print(_t("logs.raw_preview_failed", e=e)) - - # 如果无法提取预览,使用半尺寸后处理 + pass + rgb = raw.postprocess( use_camera_wb=True, output_bps=8, no_auto_bright=False, auto_bright_thr=0.01, - half_size=True # 使用半尺寸,加快处理 + half_size=True, ) img = Image.fromarray(rgb) - print(_t("logs.raw_half_size", w=img.size[0], h=img.size[1])) return img - except rawpy._rawpy.LibRawFileUnsupportedError: - # LibRaw 不支持的格式(如 Sony A7M5 NeXt/Compressed RAW 2) - # 回退:使用 exiftool -b -JpgFromRaw 提取相机内嵌 JPEG - print(f"[RAW] rawpy 不支持此 RAW 格式,尝试 ExifTool JpgFromRaw 回退...") - return _load_raw_via_exiftool(image_path) except Exception as e: + if unsupported_error is not None and isinstance(e, unsupported_error): + return _load_raw_via_exiftool(image_path) raise Exception(f"RAW处理失败: {e}") else: raise ImportError("需要安装 rawpy 来处理 RAW 格式") @@ -442,19 +454,19 @@ def load_image(image_path: str) -> Image.Image: return Image.open(image_path).convert("RGB") -def _load_raw_via_exiftool(image_path: str) -> "Image.Image": +def _load_raw_via_exiftool(image_path: str) -> Image.Image: """ - 使用 ExifTool 从 RAW 文件提取内嵌 JPEG。 - 用于 LibRaw 不支持的格式(如 Sony A7M5 NeXt/Compressed RAW 2)。 - 按优先级依次尝试:JpgFromRaw → PreviewImage → ThumbnailImage + 使用 ExifTool 从 RAW 文件提取可解码预览图。 + Extract a decodable preview image from a RAW file via ExifTool. """ import subprocess from io import BytesIO - # 查找 exiftool(优先使用打包内的版本) possible_paths = [] if getattr(sys, "frozen", False): - possible_paths.append(os.path.join(sys._MEIPASS, "exiftools_mac", "exiftool")) + meipass = get_runtime_meipass() + if meipass is not None: + possible_paths.append(os.path.join(meipass, "exiftools_mac", "exiftool")) possible_paths += [ os.path.join(PROJECT_ROOT, "exiftools_mac", "exiftool"), "/opt/homebrew/bin/exiftool", @@ -463,165 +475,157 @@ def _load_raw_via_exiftool(image_path: str) -> "Image.Image": ] exiftool = next((p for p in possible_paths if os.path.isfile(p)), "exiftool") - # 依次尝试各种嵌入图像标签 for tag in ["-JpgFromRaw", "-PreviewImage", "-ThumbnailImage"]: try: result = subprocess.run( - [exiftool, "-b", tag, image_path], - capture_output=True, timeout=15 + [exiftool, "-b", tag, image_path], capture_output=True, timeout=15 ) if result.returncode == 0 and result.stdout and len(result.stdout) > 1000: img = Image.open(BytesIO(result.stdout)).convert("RGB") - print(f"[RAW] ExifTool {tag} 提取成功: {img.size[0]}x{img.size[1]}") return img except Exception as e: - print(f"[RAW] ExifTool {tag} 失败: {e}") continue raise Exception( f"\u6682\u4e0d\u652f\u6301\u6b64 RAW \u683c\u5f0f\uff08{os.path.basename(image_path)}\uff09\u3002" "Sony A7M5 \u7b49\u76f8\u673a\u7684 NeXt/Compressed RAW 2 \u683c\u5f0f\u76ee\u524d\u7b2c\u4e09\u65b9\u5e93\u5c1a\u672a\u5b8c\u6574\u652f\u6301\uff0c" - "\u5c06\u5728\u540e\u7eed\u7248\u672c\u4e2d\u4fee\u590d\u3002\u5efa\u8bae\u4e34\u65f6\u4f7f\u7528\u65e0\u538b\u7f29 RAW \u6216 JPEG \u683c\u5f0f\u62cd\u6444\u3002" + "\u5c06\u0627\u5728\u540e\u7eed\u7248\u672c\u4e2d\u4fee\u590d\u3002\u5efa\u8bae\u4e34\u65f6\u4f7f\u7528\u65e0\u538b\u7f29 RAW \u6216 JPEG \u683c\u5f0f\u62cd\u6444\u3002" ) -def _load_heif(image_path: str) -> "Image.Image": - """ - \u4f7f\u7528 pillow-heif \u89e3\u7801 HEIF/HIF \u6587\u4ef6\uff08Sony HIF \u3001\u82f9\u679c HEIC \u7b49\uff09\u4e3a PIL Image\u3002 - """ +def _load_heif(image_path: str) -> Image.Image: try: import pillow_heif + heif_file = pillow_heif.read_heif(image_path) + if heif_file.data is None: + raise ValueError("HEIF 解码结果缺少像素数据") img = Image.frombytes( heif_file.mode, heif_file.size, heif_file.data, "raw", ).convert("RGB") - print(f"[HEIF] pillow-heif \u89e3\u7801\u6210\u529f: {img.size[0]}x{img.size[1]}") return img except ImportError: raise Exception( - "\u8bf7\u5b89\u88c5 pillow-heif \u6765\u652f\u6301 HIF/HEIC \u683c\u5f0f\uff1a pip install pillow-heif" + "请安装 pillow-heif 来支持 HIF/HEIC 格式: pip install pillow-heif" ) except Exception as e: - raise Exception(f"HEIF \u89e3\u7801\u5931\u8d25 ({os.path.basename(image_path)}): {e}") + raise Exception(f"HEIF 解码失败 ({os.path.basename(image_path)}): {e}") -# ==================== GPS 提取 ==================== -def extract_gps_from_exif(image_path: str) -> Tuple[Optional[float], Optional[float], str]: - """ - 从图像 EXIF 提取 GPS 坐标 - 支持 RAW 文件(使用 exiftool) - - Returns: - (纬度, 经度, 信息) 或 (None, None, 错误信息) - """ +def extract_gps_from_exif( + image_path: str, +) -> Tuple[Optional[float], Optional[float], str]: import subprocess import json as json_module - - # 首先尝试使用 exiftool(支持 RAW 格式) + try: - # 查找 exiftool exiftool_paths = [ - '/usr/local/bin/exiftool', - '/opt/homebrew/bin/exiftool', - 'exiftool', # 在 PATH 中查找 + "/usr/local/bin/exiftool", + "/opt/homebrew/bin/exiftool", + "exiftool", ] - + exiftool_path = None for path in exiftool_paths: try: - result = subprocess.run([path, '-ver'], capture_output=True, text=False, timeout=5) + result = subprocess.run( + [path, "-ver"], capture_output=True, text=False, timeout=5 + ) if result.returncode == 0: - # 解码输出 stdout_bytes = result.stdout - # 尝试多种编码解码 decoded_output = None - for encoding in ['utf-8', 'gbk', 'gb2312', 'latin-1']: + for encoding in ["utf-8", "gbk", "gb2312", "latin-1"]: try: decoded_output = stdout_bytes.decode(encoding) break except UnicodeDecodeError: continue - + if decoded_output is None: - # 如果所有编码都失败,使用 latin-1 作为最后手段(不会失败) - decoded_output = stdout_bytes.decode('latin-1') - - # 检查是否成功获取版本 + decoded_output = stdout_bytes.decode("latin-1") + if decoded_output.strip(): exiftool_path = path break except: continue - + if exiftool_path: - # 使用 exiftool 提取 GPS 信息 result = subprocess.run( - [exiftool_path, '-j', '-GPSLatitude', '-GPSLongitude', '-GPSLatitudeRef', '-GPSLongitudeRef', image_path], + [ + exiftool_path, + "-j", + "-GPSLatitude", + "-GPSLongitude", + "-GPSLatitudeRef", + "-GPSLongitudeRef", + image_path, + ], capture_output=True, - text=False, # 使用 bytes 模式,避免自动解码 - timeout=10 + text=False, + timeout=10, ) - + if result.returncode == 0 and result.stdout: stdout_bytes = result.stdout - # 尝试多种编码解码 decoded_output = None - for encoding in ['utf-8', 'gbk', 'gb2312', 'latin-1']: + for encoding in ["utf-8", "gbk", "gb2312", "latin-1"]: try: decoded_output = stdout_bytes.decode(encoding) break except UnicodeDecodeError: continue - + if decoded_output is None: - # 如果所有编码都失败,使用 latin-1 作为最后手段(不会失败) - decoded_output = stdout_bytes.decode('latin-1') - + decoded_output = stdout_bytes.decode("latin-1") + data = json_module.loads(decoded_output) if data and len(data) > 0: gps_data = data[0] - - lat_str = gps_data.get('GPSLatitude', '') - lon_str = gps_data.get('GPSLongitude', '') - lat_ref = gps_data.get('GPSLatitudeRef', 'N') - lon_ref = gps_data.get('GPSLongitudeRef', 'E') - + + lat_str = gps_data.get("GPSLatitude", "") + lon_str = gps_data.get("GPSLongitude", "") + lat_ref = gps_data.get("GPSLatitudeRef", "N") + lon_ref = gps_data.get("GPSLongitudeRef", "E") + if lat_str and lon_str: - # 解析度分秒格式,如 "27 deg 25' 0.53\" S" + def parse_dms(dms_str): import re - match = re.search(r'(\d+)\s*deg\s*(\d+)\'\s*([\d.]+)"?', str(dms_str)) + + match = re.search( + r'(\d+)\s*deg\s*(\d+)\'\s*([\d.]+)"?', str(dms_str) + ) if match: - d, m, s = float(match.group(1)), float(match.group(2)), float(match.group(3)) - return d + m/60 + s/3600 - # 尝试直接作为数字解析 + d, m, s = ( + float(match.group(1)), + float(match.group(2)), + float(match.group(3)), + ) + return d + m / 60 + s / 3600 try: return float(dms_str) except: return None - + lat = parse_dms(lat_str) lon = parse_dms(lon_str) - + if lat is not None and lon is not None: - # 处理南纬 (S 或 South) - if lat_ref and lat_ref.upper().startswith('S'): + if lat_ref and lat_ref.upper().startswith("S"): lat = -lat - # 处理西经 (W 或 West) - if lon_ref and lon_ref.upper().startswith('W'): + if lon_ref and lon_ref.upper().startswith("W"): lon = -lon - print(_t("logs.gps_extracted", lat=f"{lat:.6f}", lon=f"{lon:.6f}")) return lat, lon, f"GPS: {lat:.6f}, {lon:.6f}" except Exception as e: - print(_t("logs.gps_failed", e=e)) - - # 回退到 PIL(仅支持 JPEG 等常规格式) + pass + try: image = Image.open(image_path) - exif_data = image._getexif() + exif_data = image.getexif() if not exif_data: return None, None, "无EXIF数据" @@ -641,18 +645,22 @@ def parse_dms(dms_str): def convert_to_degrees(coord, ref): d, m, s = coord decimal = d + (m / 60.0) + (s / 3600.0) - if ref in ['S', 'W']: + if ref in ["S", "W"]: decimal = -decimal return decimal lat = None lon = None - if 'GPSLatitude' in gps_info and 'GPSLatitudeRef' in gps_info: - lat = convert_to_degrees(gps_info['GPSLatitude'], gps_info['GPSLatitudeRef']) + if "GPSLatitude" in gps_info and "GPSLatitudeRef" in gps_info: + lat = convert_to_degrees( + gps_info["GPSLatitude"], gps_info["GPSLatitudeRef"] + ) - if 'GPSLongitude' in gps_info and 'GPSLongitudeRef' in gps_info: - lon = convert_to_degrees(gps_info['GPSLongitude'], gps_info['GPSLongitudeRef']) + if "GPSLongitude" in gps_info and "GPSLongitudeRef" in gps_info: + lon = convert_to_degrees( + gps_info["GPSLongitude"], gps_info["GPSLongitudeRef"] + ) if lat is not None and lon is not None: return lat, lon, f"GPS: {lat:.6f}, {lon:.6f}" @@ -663,24 +671,20 @@ def convert_to_degrees(coord, ref): return None, None, f"GPS解析失败: {e}" -# ==================== 图像预处理 ==================== - def smart_resize(image: Image.Image, target_size: int = 224) -> Image.Image: - """智能图像尺寸调整""" width, height = image.size max_dim = max(width, height) if max_dim < 1000: - return image.resize((target_size, target_size), Image.LANCZOS) + return image.resize((target_size, target_size), RESAMPLING_LANCZOS) - resized = image.resize((256, 256), Image.LANCZOS) + resized = image.resize((256, 256), RESAMPLING_LANCZOS) left = (256 - target_size) // 2 top = (256 - target_size) // 2 return resized.crop((left, top, left + target_size, top + target_size)) def apply_enhancement(image: Image.Image, method: str = "unsharp_mask") -> Image.Image: - """应用图像增强""" if method == "unsharp_mask": return image.filter(ImageFilter.UnsharpMask()) elif method == "edge_enhance_more": @@ -694,72 +698,51 @@ def apply_enhancement(image: Image.Image, method: str = "unsharp_mask") -> Image return image -# ==================== OSEA 预处理 ==================== - -# CenterCrop 预处理: Resize(256) + CenterCrop(224) + ImageNet Normalize -# 用于原始大图(未经 YOLO 裁剪) -OSEA_TRANSFORM = transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), -]) - -# 直接缩放预处理: Resize(224, 224) with Lanczos + ImageNet Normalize -# 用于 YOLO 裁剪后的正方形图片(已经过 Letterboxing 处理) -# 使用 Lanczos 插值保证高质量缩放 -OSEA_TRANSFORM_DIRECT = transforms.Compose([ - transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.LANCZOS), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), -]) - +OSEA_TRANSFORM = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] +) + +OSEA_TRANSFORM_DIRECT = transforms.Compose( + [ + transforms.Resize( + (224, 224), interpolation=transforms.InterpolationMode.LANCZOS + ), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] +) -# ==================== 核心识别函数 ==================== def predict_bird( image: Image.Image, top_k: int = 5, species_class_ids: Optional[Set[int]] = None, is_yolo_cropped: bool = False, - name_format: str = None + name_format: Optional[str] = None, ) -> List[Dict]: - """ - 识别鸟类(OSEA ResNet34) - - Args: - image: PIL Image 对象 - top_k: 返回前 K 个结果 - species_class_ids: 区域物种的 class_id 集合(用于过滤) - is_yolo_cropped: 图片是否经过 YOLO 裁剪(用于选择预处理方式) - - Returns: - 识别结果列表 [{cn_name, en_name, confidence, ebird_code, ...}, ...] - """ model = get_classifier() db_manager = get_database_manager() - # 根据是否经过 YOLO 裁剪选择预处理方式 - # - YOLO 裁剪后: 直接 Resize(224,224),避免 CenterCrop 丢失特征 - # - 原始大图: Resize(256) + CenterCrop(224),鸟在中心时效果更好 - if image.mode != 'RGB': - image = image.convert('RGB') + if image.mode != "RGB": + image = image.convert("RGB") transform = OSEA_TRANSFORM_DIRECT if is_yolo_cropped else OSEA_TRANSFORM - input_tensor = transform(image).unsqueeze(0).to(CLASSIFIER_DEVICE) + transformed_tensor = cast(torch.Tensor, transform(image)) + input_tensor = transformed_tensor.unsqueeze(0).to(CLASSIFIER_DEVICE) - # 推理 with torch.no_grad(): output = model(input_tensor)[0] - # 截取有效类别数(模型输出可能多于实际物种数) num_classes = min(10964, output.shape[0]) output = output[:num_classes] - # Softmax(温度=0.9 更平滑:降低过高置信度,避免 99%+ 输出) TEMPERATURE = 0.9 best_probs = torch.nn.functional.softmax(output / TEMPERATURE, dim=0) - # 获取 top-k 结果 k = min(100 if species_class_ids else top_k, len(best_probs)) top_probs, top_indices = torch.topk(best_probs, k) @@ -767,7 +750,6 @@ def predict_bird( for i in range(len(top_indices)): class_id = top_indices[i].item() confidence = top_probs[i].item() * 100 - # 置信度阈值:使用区域过滤时降低阈值以保留更多候选 min_confidence = 0.3 if species_class_ids else 1.0 if confidence < min_confidence: continue @@ -778,54 +760,57 @@ def predict_bird( ebird_code = None description = None - # 优先从数据库获取 if db_manager: info = db_manager.get_bird_by_class_id(class_id) if info: - cn_name = info.get('chinese_simplified') - en_name = info.get('english_name') - scientific_name = info.get('scientific_name') - ebird_code = info.get('ebird_code') - description = info.get('short_description_zh') + cn_name = info.get("chinese_simplified") + en_name = info.get("english_name") + scientific_name = info.get("scientific_name") + ebird_code = info.get("ebird_code") + description = info.get("short_description_zh") if not cn_name: cn_name = f"Unknown (ID: {class_id})" en_name = f"Unknown (ID: {class_id})" - # AviList name format override if name_format and name_format != "default" and db_manager: avilist_info = db_manager.get_avilist_names_by_class_id(class_id) - if avilist_info and avilist_info.get('match_type') != 'no_match': + if avilist_info and avilist_info.get("match_type") != "no_match": if name_format == "scientific": - en_name = avilist_info.get('scientific_name_avilist') or scientific_name or en_name + en_name = ( + avilist_info.get("scientific_name_avilist") + or scientific_name + or en_name + ) else: - # Map format to column: avilist/clements/birdlife col = f"en_name_{name_format}" alt_name = avilist_info.get(col) - # Fallback chain: selected -> avilist -> keep default if alt_name: en_name = alt_name - elif name_format != "avilist" and avilist_info.get('en_name_avilist'): - en_name = avilist_info['en_name_avilist'] + elif name_format != "avilist" and avilist_info.get( + "en_name_avilist" + ): + en_name = avilist_info["en_name_avilist"] - # Avonet 地理过滤 region_match = False if species_class_ids: if class_id in species_class_ids: region_match = True else: - continue # 不在区域物种列表中,跳过 - - results.append({ - 'class_id': class_id, - 'cn_name': cn_name, - 'en_name': en_name, - 'scientific_name': scientific_name, - 'confidence': confidence, - 'ebird_code': ebird_code, - 'region_match': region_match, - 'description': description or '' - }) + continue + + results.append( + { + "class_id": class_id, + "cn_name": cn_name, + "en_name": en_name, + "scientific_name": scientific_name, + "confidence": confidence, + "ebird_code": ebird_code, + "region_match": region_match, + "description": description or "", + } + ) if len(results) >= top_k: break @@ -838,205 +823,157 @@ def identify_bird( use_yolo: bool = True, use_gps: bool = True, use_ebird: bool = True, - country_code: str = None, - region_code: str = None, + country_code: Optional[str] = None, + region_code: Optional[str] = None, top_k: int = 5, - name_format: str = None, - preloaded_crop=None, # PIL Image,主流水线已裁剪好时传入,跳过重复 YOLO + name_format: Optional[str] = None, + preloaded_crop: Optional[Image.Image] = None, ) -> Dict: - """ - 端到端鸟类识别 - - Args: - image_path: 图像路径(仍用于 GPS 提取) - use_yolo: 是否使用 YOLO 裁剪(preloaded_crop 存在时忽略) - use_gps: 是否使用 GPS 自动检测区域 - use_ebird: 是否启用 eBird 区域过滤 - country_code: 手动指定国家代码(如 "AU") - region_code: 手动指定区域代码(如 "AU-SA") - top_k: 返回前 K 个结果 - preloaded_crop: 预裁剪的鸟类区域 PIL Image,由调用方传入时跳过 YOLO - - Returns: - 识别结果字典 - """ result = { - 'success': False, - 'image_path': image_path, - 'results': [], - 'yolo_info': None, - 'gps_info': None, - 'ebird_info': None, - 'error': None + "success": False, + "image_path": image_path, + "results": [], + "yolo_info": None, + "gps_info": None, + "ebird_info": None, + "error": None, } try: - # 若调用方已提供裁剪好的鸟类区域,直接使用,跳过图像加载和 YOLO + is_yolo_cropped = False if preloaded_crop is not None: image = preloaded_crop is_yolo_cropped = True - result['yolo_info'] = {'preloaded': True} + result["yolo_info"] = {"preloaded": True} else: - # 加载图像 image = load_image(image_path) - # YOLO 裁剪(preloaded_crop 存在时已跳过) - if preloaded_crop is None: - is_yolo_cropped = False - print(f"[YOLO] use_yolo={use_yolo}, YOLO_AVAILABLE={YOLO_AVAILABLE}") if preloaded_crop is None and use_yolo and YOLO_AVAILABLE: width, height = image.size - print(f"[YOLO] image size: {width}x{height}") if max(width, height) > 640: detector = get_yolo_detector() - print(f"[YOLO] detector={detector is not None}") if detector: cropped, info = detector.detect_and_crop_bird(image) - print(f"[YOLO] detect result: cropped={cropped is not None}, info={info}") if cropped: image = cropped - result['yolo_info'] = info - result['cropped_image'] = cropped # square-cropped PIL Image + result["yolo_info"] = info + result["cropped_image"] = cropped is_yolo_cropped = True - print(f"[YOLO] ✅ Bird region cropped") else: - print(f"[YOLO] ⚠️ No bird detected") - # strict mode: no bird found, short-circuit - result['success'] = True - result['results'] = [] - result['yolo_info'] = {'bird_count': 0} + result["success"] = True + result["results"] = [] + result["yolo_info"] = {"bird_count": 0} return result - else: - print(f"[YOLO] Image too small, skipping crop") - else: - print(f"[YOLO] YOLO not enabled or unavailable") - # Avonet 地理过滤 species_class_ids = None - lat = lon = None # GPS 坐标(供后续回退使用) - species_filter = None # 物种过滤器(供后续回退使用) + lat = lon = None + species_filter = None - if use_ebird: # 参数名保持兼容,实际使用 Avonet + if use_ebird: try: species_filter = get_species_filter() - if not species_filter: - print(_t("logs.avonet_unavailable")) - else: - # 优先使用 GPS 坐标 + if species_filter: if use_gps: lat, lon, gps_msg = extract_gps_from_exif(image_path) if lat and lon: - result['gps_info'] = { - 'latitude': lat, - 'longitude': lon, - 'info': gps_msg + result["gps_info"] = { + "latitude": lat, + "longitude": lon, + "info": gps_msg, } - species_class_ids = species_filter.get_species_by_gps(lat, lon) - if species_class_ids: - print(f"[Avonet] GPS ({lat:.2f}, {lon:.2f}): {len(species_class_ids)} species") + species_class_ids = species_filter.get_species_by_gps( + lat, lon + ) - # 回退到区域代码(优先 eBird 离线物种列表,其次 Avonet 边界) if species_class_ids is None and (region_code or country_code): effective_region = region_code or country_code - # 优先使用 eBird 离线物种 JSON(精确到州/省) try: - ebird_ids, actual_region = species_filter.get_species_by_region_ebird(effective_region) + ebird_ids, actual_region = ( + species_filter.get_species_by_region_ebird( + effective_region + ) + ) if ebird_ids: species_class_ids = ebird_ids - print(f"[eBird] Region {effective_region}: {len(species_class_ids)} species (offline JSON)") except Exception as _e: - print(f"[eBird] State filter failed: {_e}") - # 如果 eBird 数据不可用,回退到 Avonet 边界框查询 + pass if not species_class_ids: - species_class_ids = species_filter.get_species_by_region(effective_region) - if species_class_ids: - print(f"[Avonet] Region {effective_region}: {len(species_class_ids)} species (bounds)") + species_class_ids = species_filter.get_species_by_region( + effective_region + ) - # 记录过滤信息 if species_class_ids: - result['ebird_info'] = { # 保持键名兼容 - 'enabled': True, - 'species_count': len(species_class_ids), - 'data_source': 'avonet.db (offline)', - 'region_code': region_code or country_code if not result.get('gps_info') else None + result["ebird_info"] = { + "enabled": True, + "species_count": len(species_class_ids), + "data_source": "avonet.db (offline)", + "region_code": ( + region_code or country_code + if not result.get("gps_info") + else None + ), } except Exception as e: - print(f"[Avonet] Filter init failed: {e}") + pass - # 执行识别 results = predict_bird( image, top_k=top_k, species_class_ids=species_class_ids, is_yolo_cropped=is_yolo_cropped, - name_format=name_format + name_format=name_format, ) - # GPS 过滤无匹配时,先尝试 eBird 国家级回退,再全局 if not results and species_class_ids: - print(f"[Avonet] ⚠️ No match after GPS filter ({len(species_class_ids)} species), trying eBird country fallback") - - # 第一步:eBird 国家级回退 country_cls_ids = None country_cc = None if lat is not None and lon is not None and species_filter is not None: try: - country_cls_ids, country_cc = species_filter.get_species_by_country_ebird(lat, lon) + country_cls_ids, country_cc = ( + species_filter.get_species_by_country_ebird(lat, lon) + ) except Exception as _e: - print(f"[eBird] Country fallback failed: {_e}") + pass if country_cls_ids: - print(f"[eBird] Trying country fallback: {country_cc} ({len(country_cls_ids)} species)") results = predict_bird( image, top_k=top_k, species_class_ids=country_cls_ids, is_yolo_cropped=is_yolo_cropped, - name_format=name_format + name_format=name_format, ) if results: - if not result.get('ebird_info'): - result['ebird_info'] = {} - result['ebird_info']['country_fallback'] = True - result['ebird_info']['country_code'] = country_cc + if not result.get("ebird_info"): + result["ebird_info"] = {} + result["ebird_info"]["country_fallback"] = True + result["ebird_info"]["country_code"] = country_cc - # 第二步:仍无结果 → 全局模式 if not results: - print(f"[Avonet] ⚠️ Country fallback still no match, switching to global mode") results = predict_bird( image, top_k=top_k, species_class_ids=None, is_yolo_cropped=is_yolo_cropped, - name_format=name_format + name_format=name_format, ) - if results and result.get('ebird_info'): - result['ebird_info']['gps_fallback'] = True + if results and result.get("ebird_info"): + result["ebird_info"]["gps_fallback"] = True - result['success'] = True - result['results'] = results + result["success"] = True + result["results"] = results except Exception as e: - result['error'] = str(e) + result["error"] = str(e) return result -# ==================== 便捷函数 ==================== - def quick_identify(image_path: str, top_k: int = 3) -> List[Dict]: - """ - 快速识别(简化接口) - - Returns: - 识别结果列表 - """ result = identify_bird(image_path, top_k=top_k) - return result.get('results', []) - + return result.get("results", []) -# ==================== 测试 ==================== if __name__ == "__main__": print("BirdIdentifier 模块测试") diff --git a/birdid/osea_classifier.py b/birdid/osea_classifier.py index f5b6b5e5..71716190 100644 --- a/birdid/osea_classifier.py +++ b/birdid/osea_classifier.py @@ -14,14 +14,17 @@ import os import sqlite3 -import sys from pathlib import Path from typing import Dict, List, Optional, Set import torch from PIL import Image from torchvision import models, transforms -from config import get_best_device +from config import ( + get_best_device, + get_install_scoped_resource_path, + get_packaged_model_relative_path, +) def _torch_load_compat(path: str, *, map_location: str, weights_only: bool): @@ -73,11 +76,14 @@ def _load_osea_checkpoint(model_path: str): return _torch_load_compat(model_path, map_location="cpu", weights_only=True) except Exception as e: if _should_retry_without_weights_only(e): - print("[OSEA] weights_only=True 加载失败,回退 weights_only=False(仅限可信模型)") - return _torch_load_compat(model_path, map_location="cpu", weights_only=False) + print( + "[OSEA] weights_only=True 加载失败,回退 weights_only=False(仅限可信模型)" + ) + return _torch_load_compat( + model_path, map_location="cpu", weights_only=False + ) raise -# ==================== 路径配置 ==================== def _get_birdid_dir() -> Path: """获取 birdid 模块目录""" @@ -90,38 +96,36 @@ def _get_project_root() -> Path: def _get_resource_path(relative_path: str) -> Path: - """获取资源路径 (支持 PyInstaller 打包)""" - if getattr(sys, 'frozen', False): - base = Path(sys._MEIPASS) - else: - base = _get_project_root() - return base / relative_path - - -# ==================== 设备配置 ==================== + """获取资源路径 (支持安装目录约束的打包场景)""" + packaged_relative_path = None + if relative_path.startswith("models/"): + packaged_relative_path = get_packaged_model_relative_path(relative_path) + return get_install_scoped_resource_path( + relative_path, packaged_relative_path=packaged_relative_path + ) DEVICE = get_best_device() -# ==================== 预处理 transforms ==================== - -CENTER_CROP_TRANSFORM = transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), -]) +CENTER_CROP_TRANSFORM = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] +) -BASELINE_TRANSFORM = transforms.Compose([ - transforms.Resize((224, 224)), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), -]) +BASELINE_TRANSFORM = transforms.Compose( + [ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] +) -# ==================== OSEA 分类器 ==================== - class OSEAClassifier: """ OSEA ResNet34 鸟类分类器 @@ -152,7 +156,9 @@ def __init__( """ self.device = device or DEVICE self.use_center_crop = use_center_crop - self.transform = CENTER_CROP_TRANSFORM if use_center_crop else BASELINE_TRANSFORM + self.transform = ( + CENTER_CROP_TRANSFORM if use_center_crop else BASELINE_TRANSFORM + ) self.model_path = model_path or str(_get_resource_path(self.DEFAULT_MODEL_PATH)) self.model = self._load_model() @@ -199,13 +205,15 @@ def _load_bird_info(self) -> List[List[str]]: conn.close() num_classes = 10964 - bird_info: List[List[str]] = [['Unknown', 'Unknown', ''] for _ in range(num_classes)] + bird_info: List[List[str]] = [ + ["Unknown", "Unknown", ""] for _ in range(num_classes) + ] for class_id, cn_name, en_name, scientific_name in rows: if 0 <= class_id < num_classes: bird_info[class_id] = [ - cn_name or 'Unknown', - en_name or 'Unknown', - scientific_name or '', + cn_name or "Unknown", + en_name or "Unknown", + scientific_name or "", ] return bird_info @@ -228,8 +236,8 @@ def predict( Returns: 识别结果列表 [{cn_name, en_name, scientific_name, confidence, class_id}, ...] """ - if image.mode != 'RGB': - image = image.convert('RGB') + if image.mode != "RGB": + image = image.convert("RGB") input_tensor = self.transform(image).unsqueeze(0).to(self.device) @@ -237,7 +245,7 @@ def predict( with torch.no_grad(): output = self.model(input_tensor)[0] - output = output[:self.num_classes] + output = output[: self.num_classes] probs = torch.nn.functional.softmax(output / temperature, dim=0) k = min(100 if ebird_species_set else top_k, self.num_classes) @@ -257,16 +265,16 @@ def predict( en_name = info[1] scientific_name = info[2] if len(info) > 2 else None - ebird_match = False - - results.append({ - 'class_id': class_id, - 'cn_name': cn_name, - 'en_name': en_name, - 'scientific_name': scientific_name, - 'confidence': confidence, - 'ebird_match': ebird_match, - }) + results.append( + { + "class_id": class_id, + "cn_name": cn_name, + "en_name": en_name, + "scientific_name": scientific_name, + "confidence": confidence, + "ebird_match": False, + } + ) if len(results) >= top_k: break @@ -286,8 +294,8 @@ def predict_with_tta( TTA 策略: 原图 + 水平翻转取平均 推理时间翻倍,但可能提高准确率 """ - if image.mode != 'RGB': - image = image.convert('RGB') + if image.mode != "RGB": + image = image.convert("RGB") input1 = self.transform(image).unsqueeze(0).to(self.device) @@ -296,8 +304,8 @@ def predict_with_tta( self.model.eval() with torch.no_grad(): - output1 = self.model(input1)[0][:self.num_classes] - output2 = self.model(input2)[0][:self.num_classes] + output1 = self.model(input1)[0][: self.num_classes] + output2 = self.model(input2)[0][: self.num_classes] avg_output = (output1 + output2) / 2 probs = torch.nn.functional.softmax(avg_output / temperature, dim=0) @@ -319,14 +327,16 @@ def predict_with_tta( en_name = info[1] scientific_name = info[2] if len(info) > 2 else None - results.append({ - 'class_id': class_id, - 'cn_name': cn_name, - 'en_name': en_name, - 'scientific_name': scientific_name, - 'confidence': confidence, - 'ebird_match': False, - }) + results.append( + { + "class_id": class_id, + "cn_name": cn_name, + "en_name": en_name, + "scientific_name": scientific_name, + "confidence": confidence, + "ebird_match": False, + } + ) if len(results) >= top_k: break @@ -334,8 +344,6 @@ def predict_with_tta( return results -# ==================== 全局单例 ==================== - _osea_classifier: Optional[OSEAClassifier] = None @@ -347,8 +355,6 @@ def get_osea_classifier() -> OSEAClassifier: return _osea_classifier -# ==================== 便捷函数 ==================== - def osea_predict(image: Image.Image, top_k: int = 5) -> List[Dict]: """快速 OSEA 预测""" classifier = get_osea_classifier() @@ -358,12 +364,11 @@ def osea_predict(image: Image.Image, top_k: int = 5) -> List[Dict]: def osea_predict_file(image_path: str, top_k: int = 5) -> List[Dict]: """OSEA 预测 (从文件路径)""" from birdid.bird_identifier import load_image + image = load_image(image_path) return osea_predict(image, top_k=top_k) -# ==================== 测试 ==================== - if __name__ == "__main__": import argparse @@ -374,6 +379,7 @@ def osea_predict_file(image_path: str, top_k: int = 5) -> List[Dict]: args = parser.parse_args() from birdid.bird_identifier import load_image + image = load_image(args.image) classifier = OSEAClassifier() diff --git a/build_release.bat b/build_release.bat index e358fe9e..1ac60924 100644 --- a/build_release.bat +++ b/build_release.bat @@ -1,404 +1,46 @@ @echo off -chcp 65001 >nul -setlocal EnableExtensions EnableDelayedExpansion +setlocal EnableExtensions -set "APP_NAME=SuperPicky" -set "SPEC_FILE=SuperPicky_win64.spec" -set "ROOT_DIR=%~dp0" -set "ROOT_DIR=%ROOT_DIR:~0,-1%" -cd /d "%ROOT_DIR%" +set "SCRIPT_DIR=%~dp0" +set "PYTHON_EXE=%SCRIPT_DIR%.venv\Scripts\python.exe" +if not exist "%PYTHON_EXE%" set "PYTHON_EXE=python" -set "VERSION_ARG=" -set "ZIP_COPY_DIR=" - -if "!OUT_DIST_DIR!"=="" set "OUT_DIST_DIR=dist" - -set "BUILD_ZIP=1" - -call :parse_args %* -if errorlevel 1 exit /b 1 -if defined SHOW_HELP goto :show_help - -goto :start - -:show_help -echo SuperPicky Windows build script -echo. -echo Usage: -echo %~nx0 [version] [zip_copy_dir] -echo. -echo version 版本号 ^(如 4.0.6^),用于 ZIP 文件名;缺省则从 ui/about_dialog.py 读取 -echo zip_copy_dir 目标目录 ^(如 E:\_SuperPickyVersions^);若指定则复制 SuperPicky 为 SuperPicky_版本号 并打 zip -echo. -exit /b 0 - -:parse_args -:parse_args_loop -if "%~1"=="" exit /b 0 - -if /i "%~1"=="--help" ( - set "SHOW_HELP=1" - exit /b 0 -) -if /i "%~1"=="-h" ( - set "SHOW_HELP=1" - exit /b 0 -) -if "%VERSION_ARG%"=="" ( - set "VERSION_ARG=%~1" -) else if "!ZIP_COPY_DIR!"=="" ( - set "ZIP_COPY_DIR=%~1" -) else ( - echo [WARNING] Ignored extra argument: %~1 -) -shift -goto :parse_args_loop - -:start -echo. -echo [========================================] -echo Step 0: Clean old build files -echo [========================================] - -rem Set Inno Setup directory -set "INNO_DIR=%ROOT_DIR%\inno" - -rem Clean old build directories -if exist "%ROOT_DIR%\build_dist" rd /s /q "%ROOT_DIR%\build_dist" >nul 2>&1 -if exist "%ROOT_DIR%\build_dist_cpu" rd /s /q "%ROOT_DIR%\build_dist_cpu" >nul 2>&1 -if exist "%ROOT_DIR%\build_dist_cuda" rd /s /q "%ROOT_DIR%\build_dist_cuda" >nul 2>&1 -if exist "%ROOT_DIR%\dist" rd /s /q "%ROOT_DIR%\dist" >nul 2>&1 -if exist "%ROOT_DIR%\dist_cpu" rd /s /q "%ROOT_DIR%\dist_cpu" >nul 2>&1 -if exist "%ROOT_DIR%\dist_cuda" rd /s /q "%ROOT_DIR%\dist_cuda" >nul 2>&1 -if exist "%ROOT_DIR%\output" rd /s /q "%ROOT_DIR%\output" >nul 2>&1 - -echo [SUCCESS] Cleaned old build files - -echo. -echo [========================================] -echo Step 1: Environment check -echo [========================================] - -if not exist "%SPEC_FILE%" ( - echo [ERROR] Missing spec file: %SPEC_FILE% - exit /b 1 -) - -echo [SUCCESS] Spec file found: %SPEC_FILE% - -if "!PYTHON_EXE!"=="" set "PYTHON_EXE=python" -rem Prefer current env Python (e.g. activated venv): use first "python" in PATH -if "!PYTHON_EXE!"=="python" ( - where python >nul 2>nul && for /f "tokens=*" %%i in ('python -c "import sys; print(sys.executable)" 2^>nul') do set "PYTHON_EXE=%%i" -) -if "!PYTHON_EXE!"=="" set "PYTHON_EXE=python" -call :check_python "!PYTHON_EXE!" "default" -if errorlevel 1 exit /b 1 - -echo. -echo [========================================] -echo Step 1: Resolve version -echo [========================================] - -set "VERSION=4.0.5_sp3" -if not "!VERSION_ARG!"=="" ( - set "VERSION=!VERSION_ARG!" - echo [SUCCESS] Use version from args: !VERSION! -) else ( - for /f "usebackq delims=" %%i in (`powershell -NoProfile -Command "$c=Get-Content -Path 'ui/about_dialog.py' -Raw -Encoding UTF8; if($c -match 'v([0-9A-Za-z._-]+)'){ $matches[1] }"`) do ( - set "VERSION=%%i" - ) - if "!VERSION!"=="" set "VERSION=0.0.0" - echo [SUCCESS] Detected version: v!VERSION! -) - -echo. -echo [========================================] -echo Step 1.5: Inject build metadata -echo [========================================] - -set "COMMIT_HASH=unknown" -rem 优先从 Python 代码读取 COMMIT_HASH(保证跨平台一致) -for /f "tokens=*" %%i in ('"%PYTHON_EXE%" -c "exec('try:\n from core.build_info_local import COMMIT_HASH\nexcept ImportError:\n from core.build_info import COMMIT_HASH\nprint(COMMIT_HASH or chr(0))')" 2^>nul') do set "COMMIT_HASH=%%i" -if "%COMMIT_HASH%"=="" for /f "tokens=*" %%i in ('git rev-parse --short HEAD 2^>nul') do set "COMMIT_HASH=%%i" -if "%COMMIT_HASH%"=="" set "COMMIT_HASH=unknown" -echo [INFO] Commit hash: %COMMIT_HASH% - -set "BUILD_INFO_FILE=core\build_info.py" -set "BUILD_INFO_BACKUP=core\build_info.py.backup" -if exist "%BUILD_INFO_FILE%" copy /y "%BUILD_INFO_FILE%" "%BUILD_INFO_BACKUP%" >nul - -powershell -NoProfile -Command "(Get-Content -Path '%BUILD_INFO_FILE%' -Raw -Encoding UTF8) -replace 'COMMIT_HASH\s*=\s*.*', 'COMMIT_HASH = \"%COMMIT_HASH%\"' | Set-Content -Path '%BUILD_INFO_FILE%' -Encoding UTF8" -if errorlevel 1 ( - echo [ERROR] Failed to inject build info - call :restore_build_info >nul - exit /b 1 -) +if /I "%~1"=="--help" goto :show_help +if /I "%~1"=="-h" goto :show_help -echo [SUCCESS] Build info injected +set "FIRST_ARG=%~1" -call :build_single -set "RET=%ERRORLEVEL%" -call :restore_build_info >nul -exit /b %RET% +if "%~1"=="" goto :run_default +if "%FIRST_ARG:~0,2%"=="--" goto :run_passthrough +if not "%~3"=="" goto :show_positional_error -:check_python -set "CHECK_PY=%~1" -set "CHECK_LABEL=%~2" - -echo [INFO] Checking Python (%CHECK_LABEL%): %CHECK_PY% -"%CHECK_PY%" -c "import sys; print(sys.executable)" >nul 2>nul -if errorlevel 1 ( - echo [ERROR] Python not available: %CHECK_PY% - exit /b 1 -) -for /f "tokens=*" %%i in ('"%CHECK_PY%" -c "import sys; print(sys.executable)" 2^>nul') do set "_PY_RESOLVED=%%i" -echo [SUCCESS] Python (%CHECK_LABEL%): !_PY_RESOLVED! - -echo [INFO] Checking PyInstaller (%CHECK_LABEL%)... -"%CHECK_PY%" -c "import PyInstaller" >nul 2>nul -if errorlevel 1 ( - echo [ERROR] PyInstaller missing in %CHECK_LABEL% environment - exit /b 1 -) -echo [SUCCESS] PyInstaller is available (%CHECK_LABEL%) -exit /b 0 - -:build_with_python -set "B_PY=%~1" -set "B_WORK=%~2" -set "B_DIST=%~3" -set "B_LABEL=%~4" - -echo. -echo [========================================] -echo Build: %B_LABEL% -echo [========================================] - -if exist "%B_WORK%" rd /s /q "%B_WORK%" -if exist "%B_DIST%" rd /s /q "%B_DIST%" - -"%B_PY%" -m PyInstaller "%SPEC_FILE%" --clean --noconfirm --workpath "%B_WORK%" --distpath "%B_DIST%" -set "PYI_RC=%ERRORLEVEL%" -echo [INFO] PyInstaller process rc (%B_LABEL%): %PYI_RC% -if not "%PYI_RC%"=="0" ( - echo [WARNING] PyInstaller returned non-zero [%B_LABEL%]: %PYI_RC% -) - -if not exist "%B_DIST%\%APP_NAME%\SuperPicky.exe" ( - echo [ERROR] Missing output exe: %B_DIST%\%APP_NAME%\SuperPicky.exe - exit /b 1 -) - -echo [SUCCESS] Build completed (%B_LABEL%) -exit /b 0 - -rem Copy folder C_SRC into C_DST using robocopy (more reliable than xcopy for deep trees). -:copy_dir -set "C_SRC=%~1" -set "C_DST=%~2" - -if not exist "%C_SRC%" ( - echo [ERROR] Copy source not found: %C_SRC% - exit /b 1 -) - -if not exist "%C_DST%" mkdir "%C_DST%" -if errorlevel 1 ( - echo [ERROR] Failed to create target dir: %C_DST% - exit /b 1 -) - -robocopy "%C_SRC%" "%C_DST%" /E /R:2 /W:1 /NFL /NDL /NJH /NJS /NP >nul -set "COPY_RC=%ERRORLEVEL%" -if !COPY_RC! GEQ 8 ( - echo [ERROR] Failed to copy to %C_DST% ^(robocopy exit code !COPY_RC!^) - exit /b 1 -) - -echo [SUCCESS] Copied directory: %C_SRC% -^> %C_DST% -exit /b 0 - -rem Zip folder Z_SRC into Z_OUT. Archive contains one top-level folder (e.g. SuperPicky\) so unzip gives one dir. -:zip_dir -set "Z_SRC=%~1" -set "Z_OUT=%~2" - -if not exist "%Z_SRC%" ( - echo [ERROR] Zip source not found: %Z_SRC% - exit /b 1 -) - -if exist "%Z_OUT%" del /q "%Z_OUT%" >nul 2>&1 - -where 7z >nul 2>&1 -if not errorlevel 1 ( - 7z a -tzip "%Z_OUT%" "%Z_SRC%" -r >nul - if errorlevel 1 ( - echo [ERROR] Failed to create zip with 7z: %Z_OUT% - exit /b 1 - ) -) else ( - powershell -NoProfile -Command "Compress-Archive -Path '%Z_SRC%' -DestinationPath '%Z_OUT%' -Force" - if errorlevel 1 ( - echo [ERROR] Failed to create zip with Compress-Archive: %Z_OUT% - exit /b 1 - ) -) - -echo [SUCCESS] Created zip: %Z_OUT% -exit /b 0 +set "VERSION_ARG=" +set "COPY_DIR_ARG=" -:build_single -set "WORK_DIR=%ROOT_DIR%\build_!OUT_DIST_DIR!" -set "DIST_DIR=%ROOT_DIR%\!OUT_DIST_DIR!" +if not "%~1"=="" set "VERSION_ARG=--version %~1" +if not "%~2"=="" set "COPY_DIR_ARG=--copy-dir %~2" -call :build_with_python "%PYTHON_EXE%" "!WORK_DIR!" "!DIST_DIR!" "release" -set "BUILD_RC=%ERRORLEVEL%" -echo [INFO] build_with_python rc: !BUILD_RC! -if !BUILD_RC! NEQ 0 exit /b !BUILD_RC! +:run_default +"%PYTHON_EXE%" "%SCRIPT_DIR%build_release_win.py" --build-type cpu %VERSION_ARG% %COPY_DIR_ARG% +exit /b %ERRORLEVEL% -rem Default: always create one release ZIP -if "%BUILD_ZIP%"=="1" ( - set "ZIP_NAME=!APP_NAME!_v!VERSION!_Win64.zip" - - rem Remove Inno Setup files before creating zip - if exist "!DIST_DIR!\!APP_NAME!\SuperPicky.iss" del /q "!DIST_DIR!\!APP_NAME!\SuperPicky.iss" >nul 2>&1 - if exist "!DIST_DIR!\!APP_NAME!\ChineseSimplified.isl" del /q "!DIST_DIR!\!APP_NAME!\ChineseSimplified.isl" >nul 2>&1 - - call :zip_dir "!DIST_DIR!\!APP_NAME!" "!DIST_DIR!\!ZIP_NAME!" - if errorlevel 1 exit /b 1 - - rem Restore Inno Setup files after creating zip - if exist "%INNO_DIR%\SuperPicky.iss" ( - copy /y "%INNO_DIR%\SuperPicky.iss" "!DIST_DIR!\!APP_NAME!\SuperPicky.iss" >nul - rem Update version in iss file - powershell -NoProfile -Command "(Get-Content -Path '!DIST_DIR!\!APP_NAME!\SuperPicky.iss' -Raw -Encoding UTF8) -replace 'VersionInfoVersion=.*', 'VersionInfoVersion=!VERSION!' | Set-Content -Path '!DIST_DIR!\!APP_NAME!\SuperPicky.iss' -Encoding UTF8" - ) - if exist "%INNO_DIR%\ChineseSimplified.isl" ( - copy /y "%INNO_DIR%\ChineseSimplified.isl" "!DIST_DIR!\!APP_NAME!\ChineseSimplified.isl" >nul - ) - - if defined ZIP_COPY_DIR ( - set "TARGET_SUBDIR=%APP_NAME%_!VERSION!" - set "TARGET_DIR=!ZIP_COPY_DIR!\!TARGET_SUBDIR!" - if not exist "!ZIP_COPY_DIR!" mkdir "!ZIP_COPY_DIR!" - if errorlevel 1 ( - echo [ERROR] Failed to create copy root dir: !ZIP_COPY_DIR! - exit /b 1 - ) - if exist "!TARGET_DIR!" rd /s /q "!TARGET_DIR!" - if exist "!TARGET_DIR!" ( - echo [ERROR] Failed to clean old target dir: !TARGET_DIR! - exit /b 1 - ) - call :copy_dir "%DIST_DIR%\%APP_NAME%" "!TARGET_DIR!" - if errorlevel 1 exit /b 1 - - rem Copy Inno Setup files to target directory - if exist "%INNO_DIR%\SuperPicky.iss" ( - copy /y "%INNO_DIR%\SuperPicky.iss" "!TARGET_DIR!\SuperPicky.iss" >nul - if errorlevel 1 ( - echo [ERROR] Failed to copy SuperPicky.iss to target directory - exit /b 1 - ) - echo [SUCCESS] Copied SuperPicky.iss to !TARGET_DIR! - - rem Update version in iss file - powershell -NoProfile -Command "(Get-Content -Path '!TARGET_DIR!\SuperPicky.iss' -Raw -Encoding UTF8) -replace 'VersionInfoVersion=.*', 'VersionInfoVersion=!VERSION!' | Set-Content -Path '!TARGET_DIR!\SuperPicky.iss' -Encoding UTF8" - if errorlevel 1 ( - echo [ERROR] Failed to update version in SuperPicky.iss in target directory - exit /b 1 - ) - echo [SUCCESS] Updated version in SuperPicky.iss to !VERSION! in target directory - ) - - if exist "%INNO_DIR%\ChineseSimplified.isl" ( - copy /y "%INNO_DIR%\ChineseSimplified.isl" "!TARGET_DIR!\ChineseSimplified.isl" >nul - if errorlevel 1 ( - echo [ERROR] Failed to copy ChineseSimplified.isl to target directory - exit /b 1 - ) - echo [SUCCESS] Copied ChineseSimplified.isl to !TARGET_DIR! - ) - - rem Remove Inno Setup files before creating zip - if exist "!TARGET_DIR!\SuperPicky.iss" del /q "!TARGET_DIR!\SuperPicky.iss" >nul 2>&1 - if exist "!TARGET_DIR!\ChineseSimplified.isl" del /q "!TARGET_DIR!\ChineseSimplified.isl" >nul 2>&1 - - call :zip_dir "!TARGET_DIR!" "!ZIP_COPY_DIR!\!TARGET_SUBDIR!.zip" - if errorlevel 1 exit /b 1 - - rem Restore Inno Setup files after creating zip - if exist "%INNO_DIR%\SuperPicky.iss" ( - copy /y "%INNO_DIR%\SuperPicky.iss" "!TARGET_DIR!\SuperPicky.iss" >nul - rem Update version in iss file - powershell -NoProfile -Command "(Get-Content -Path '!TARGET_DIR!\SuperPicky.iss' -Raw -Encoding UTF8) -replace 'VersionInfoVersion=.*', 'VersionInfoVersion=!VERSION!' | Set-Content -Path '!TARGET_DIR!\SuperPicky.iss' -Encoding UTF8" - ) - if exist "%INNO_DIR%\ChineseSimplified.isl" ( - copy /y "%INNO_DIR%\ChineseSimplified.isl" "!TARGET_DIR!\ChineseSimplified.isl" >nul - ) - - echo [SUCCESS] Copied !TARGET_SUBDIR! + created !ZIP_COPY_DIR!\!TARGET_SUBDIR!.zip - ) -) else ( - set "ZIP_NAME=" - echo [INFO] ZIP creation skipped ^(--no-zip^) -) +:run_passthrough +"%PYTHON_EXE%" "%SCRIPT_DIR%build_release_win.py" --build-type cpu %* +exit /b %ERRORLEVEL% +:show_help +echo SuperPicky Windows compatibility wrapper echo. -echo [========================================] -echo Step 4: Copy Inno Setup files -echo [========================================] - -set "INNO_DIR=%ROOT_DIR%\inno" -set "OUTPUT_EXE_DIR=%DIST_DIR%\%APP_NAME%" - -if exist "%INNO_DIR%\SuperPicky.iss" ( - copy /y "%INNO_DIR%\SuperPicky.iss" "%OUTPUT_EXE_DIR%\SuperPicky.iss" >nul - if errorlevel 1 ( - echo [ERROR] Failed to copy SuperPicky.iss - exit /b 1 - ) - echo [SUCCESS] Copied SuperPicky.iss to %OUTPUT_EXE_DIR% - - rem Update version in iss file - powershell -NoProfile -Command "(Get-Content -Path '%OUTPUT_EXE_DIR%\SuperPicky.iss' -Raw -Encoding UTF8) -replace 'VersionInfoVersion=.*', 'VersionInfoVersion=%VERSION%' | Set-Content -Path '%OUTPUT_EXE_DIR%\SuperPicky.iss' -Encoding UTF8" - if errorlevel 1 ( - echo [ERROR] Failed to update version in SuperPicky.iss - exit /b 1 - ) - echo [SUCCESS] Updated version in SuperPicky.iss to %VERSION% -) else ( - echo [WARNING] SuperPicky.iss not found in %INNO_DIR% -) - -if exist "%INNO_DIR%\ChineseSimplified.isl" ( - copy /y "%INNO_DIR%\ChineseSimplified.isl" "%OUTPUT_EXE_DIR%\ChineseSimplified.isl" >nul - if errorlevel 1 ( - echo [ERROR] Failed to copy ChineseSimplified.isl - exit /b 1 - ) - echo [SUCCESS] Copied ChineseSimplified.isl to %OUTPUT_EXE_DIR% -) else ( - echo [WARNING] ChineseSimplified.isl not found in %INNO_DIR% -) - +echo Usage: +echo %~nx0 [version] [copy_dir] +echo %~nx0 [build_release_win.py options] echo. -echo [========================================] -echo Build finished -echo [========================================] -echo EXE: %DIST_DIR%\%APP_NAME%\SuperPicky.exe -if defined ZIP_NAME ( -echo ZIP: %DIST_DIR%\%ZIP_NAME% -if defined ZIP_COPY_DIR echo Copy: %ZIP_COPY_DIR%\%APP_NAME%_%VERSION% + .zip -) else ( -echo ZIP: ^(skipped^) -) +echo This wrapper forwards to build_release_win.py --build-type cpu. +echo If the first argument starts with --, all arguments are passed through directly. exit /b 0 -:restore_build_info -if exist "%BUILD_INFO_BACKUP%" ( - move /y "%BUILD_INFO_BACKUP%" "%BUILD_INFO_FILE%" >nul -) -exit /b 0 +:show_positional_error +echo [ERROR] Positional compatibility mode only accepts [version] [copy_dir]. +echo [ERROR] If you need extra options such as --debug or --help, use explicit option mode. +echo [ERROR] Example: build_release.bat --debug --help +exit /b 1 \ No newline at end of file diff --git a/build_release.sh b/build_release.sh index 7d004a39..bd624256 100755 --- a/build_release.sh +++ b/build_release.sh @@ -1,378 +1,34 @@ -#!/bin/bash -# SuperPicky - 打包、签名和公证脚本 -# 作者: James Zhen Yu -# 版本: 1.0 -# -# 用法: -# ./build_release.sh --test # 仅打包和签名(跳过公证) -# ./build_release.sh --release # 完整流程:打包、签名、公证 -# ./build_release.sh --help # 显示帮助 +#!/bin/sh +set -eu -set -e # 遇到错误立即退出 - -# ============================================ -# 配置参数 -# ============================================ -APP_NAME="SuperPicky" -BUNDLE_ID="com.jamesphotography.superpicky" -DEVELOPER_ID="Developer ID Application: James Zhen Yu (JWR6FDB52H)" -APPLE_ID="james@jamesphotography.com.au" -TEAM_ID="JWR6FDB52H" -KEYCHAIN_ITEM="SuperPicky-Notarize" - -# 颜色输出 -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -BLUE='\033[0;34m' -CYAN='\033[0;36m' -NC='\033[0m' # No Color - -# ============================================ -# 辅助函数 -# ============================================ -log_info() { - echo -e "${BLUE}[INFO]${NC} $1" -} - -log_success() { - echo -e "${GREEN}[SUCCESS]${NC} $1" -} - -log_warning() { - echo -e "${YELLOW}[WARNING]${NC} $1" -} - -log_error() { - echo -e "${RED}[ERROR]${NC} $1" -} - -log_step() { - echo -e "\n${CYAN}========================================${NC}" - echo -e "${CYAN}步骤$1: $2${NC}" - echo -e "${CYAN}========================================${NC}" -} +SCRIPT_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) +MODE_ARGS="" show_help() { - echo "SuperPicky 构建脚本" - echo "" - echo "用法: $0 [选项]" + echo "SuperPicky macOS compatibility wrapper" echo "" - echo "选项:" - echo " --test 仅打包和签名,跳过公证(用于快速测试)" - echo " --release 完整流程:打包、签名、公证、装订" - echo " --help 显示此帮助信息" - echo "" - echo "首次使用前,需要配置 Keychain:" - echo " security add-generic-password -a \"${APPLE_ID}\" \\" - echo " -s \"${KEYCHAIN_ITEM}\" -w \"你的App-Specific-Password\"" + echo "Usage:" + echo " ./build_release.sh --test [extra build_release_mac.py args]" + echo " ./build_release.sh --release [extra build_release_mac.py args]" echo "" + echo "This wrapper forwards to build_release_mac.py --build-type full." + echo "Use --release to append --notarize." } -# ============================================ -# 参数解析 -# ============================================ -MODE="" -if [ $# -eq 0 ]; then - show_help - exit 0 -fi - -case "$1" in - --test) - MODE="test" - ;; - --release) - MODE="release" - ;; - --help|-h) - show_help - exit 0 - ;; - *) - log_error "未知选项: $1" - show_help - exit 1 - ;; -esac - -# ============================================ -# 步骤0: 环境检查 -# ============================================ -log_step "0" "环境检查" - -# 检查开发者证书 -log_info "检查开发者证书..." -if ! security find-identity -v -p codesigning | grep -q "${DEVELOPER_ID}"; then - log_error "未找到开发者证书: ${DEVELOPER_ID}" - log_info "请确保已在 Keychain 中安装有效的开发者证书" - exit 1 -fi -log_success "开发者证书已就绪" - -# 检查 Keychain 密码(仅 release 模式) -if [ "$MODE" = "release" ]; then - log_info "检查 Keychain 中的 App-Specific Password..." - if ! security find-generic-password -a "${APPLE_ID}" -s "${KEYCHAIN_ITEM}" -w &>/dev/null; then - log_error "未在 Keychain 中找到 App-Specific Password" - log_info "请运行以下命令添加密码:" - echo "" - echo " security add-generic-password -a \"${APPLE_ID}\" \\" - echo " -s \"${KEYCHAIN_ITEM}\" -w \"你的密码\"" - echo "" - exit 1 - fi - log_success "Keychain 密码已配置" -fi - -# 检查 PyInstaller -log_info "检查 PyInstaller..." -if ! command -v pyinstaller &>/dev/null; then - # 尝试从虚拟环境 - if [ -f ".venv/bin/pyinstaller" ]; then - PYINSTALLER=".venv/bin/pyinstaller" - else - log_error "未找到 PyInstaller,请先安装: pip install pyinstaller" - exit 1 - fi -else - PYINSTALLER="pyinstaller" -fi -log_success "PyInstaller 已就绪" - -# 检查 entitlements.plist -if [ ! -f "entitlements.plist" ]; then - log_error "未找到 entitlements.plist 文件" - exit 1 -fi - -# ============================================ -# 步骤1: 提取版本号 -# ============================================ -log_step "1" "提取版本号" - -VERSION=$(grep 'APP_VERSION' constants.py | grep -oE '"[0-9]+\.[0-9]+\.[0-9]+"' | tr -d '"' | head -1) -if [ -z "$VERSION" ]; then - log_error "无法从 constants.py 提取版本号" - exit 1 -fi -log_success "检测到版本: v${VERSION}" - -# ============================================ -# 步骤1.5: 检测 CPU 架构 -# ============================================ -log_info "检测 CPU 架构..." -ARCH=$(uname -m) -if [ "$ARCH" = "arm64" ]; then - ARCH_SUFFIX="arm64" - log_success "检测到 Apple Silicon (arm64)" -elif [ "$ARCH" = "x86_64" ]; then - ARCH_SUFFIX="intel" - log_success "检测到 Intel (x86_64)" -else - ARCH_SUFFIX="$ARCH" - log_warning "未知架构: $ARCH" -fi - -# 设置输出文件名(包含架构信息) -if [ "$MODE" = "test" ]; then - DMG_NAME="${APP_NAME}_v${VERSION}_${ARCH_SUFFIX}_test.dmg" -else - DMG_NAME="${APP_NAME}_v${VERSION}_${ARCH_SUFFIX}.dmg" -fi -DMG_PATH="dist/${DMG_NAME}" - -# ============================================ -# 步骤2: 清理旧文件 -# ============================================ -log_step "2" "清理旧文件" - -rm -rf build dist -mkdir -p dist -log_success "清理完成" - -# ============================================ -# 步骤2.5: 注入 Git Commit Hash -# ============================================ -log_step "2.5" "注入构建信息" - -# 从 Python 代码读取 Commit Hash(保证跨平台一致) -COMMIT_HASH=$(python3 -c " -try: - from core.build_info_local import COMMIT_HASH -except ImportError: - from core.build_info import COMMIT_HASH -print(COMMIT_HASH or 'unknown') -") -log_info "Commit Hash: ${COMMIT_HASH}" - -# 备份原始 build_info.py -BUILD_INFO_FILE="core/build_info.py" -BUILD_INFO_BACKUP="${BUILD_INFO_FILE}.backup" -cp "${BUILD_INFO_FILE}" "${BUILD_INFO_BACKUP}" - -# 注入 commit hash -sed -i.tmp "s/COMMIT_HASH = None/COMMIT_HASH = \"${COMMIT_HASH}\"/" "${BUILD_INFO_FILE}" -rm -f "${BUILD_INFO_FILE}.tmp" # macOS sed 的临时文件 - -log_success "构建信息已注入" - -# ============================================ -# 步骤3: PyInstaller 打包 -# ============================================ -log_step "3" "PyInstaller 打包" - -log_info "正在打包应用..." -${PYINSTALLER} SuperPicky.spec --clean --noconfirm - -# 恢复原始 build_info.py -if [ -f "${BUILD_INFO_BACKUP}" ]; then - mv "${BUILD_INFO_BACKUP}" "${BUILD_INFO_FILE}" - log_info "已恢复原始 build_info.py" -fi - -if [ ! -d "dist/${APP_NAME}.app" ]; then - log_error "打包失败!未找到 dist/${APP_NAME}.app" - exit 1 -fi -log_success "打包完成" - -# ============================================ -# 步骤4: 深度代码签名 -# ============================================ -log_step "4" "深度代码签名" - -# 签名所有嵌入的二进制文件和库 -log_info "签名嵌入的框架和库..." -find "dist/${APP_NAME}.app/Contents" -type f \( -name "*.dylib" -o -name "*.so" \) -print0 | while IFS= read -r -d '' file; do - codesign --force --sign "${DEVELOPER_ID}" --timestamp --options runtime "$file" 2>/dev/null || true -done - -# 签名可执行文件 -find "dist/${APP_NAME}.app/Contents/MacOS" -type f -perm +111 -print0 | while IFS= read -r -d '' file; do - codesign --force --sign "${DEVELOPER_ID}" --timestamp --options runtime "$file" 2>/dev/null || true -done - -# 签名主应用 -log_info "签名主应用..." -codesign --force --deep --sign "${DEVELOPER_ID}" \ - --timestamp \ - --options runtime \ - --entitlements entitlements.plist \ - "dist/${APP_NAME}.app" - -# 验证签名 -log_info "验证代码签名..." -codesign --verify --deep --strict --verbose=2 "dist/${APP_NAME}.app" -log_success "代码签名完成" - -# ============================================ -# 步骤5: 创建 DMG 安装包 -# ============================================ -log_step "5" "创建 DMG 安装包" - -# 创建临时 DMG 文件夹 -TEMP_DMG_DIR="dist/dmg_temp" -rm -rf "${TEMP_DMG_DIR}" -mkdir -p "${TEMP_DMG_DIR}" - -# 复制应用到临时文件夹 -cp -R "dist/${APP_NAME}.app" "${TEMP_DMG_DIR}/" - -# 创建 Applications 快捷方式 -ln -s /Applications "${TEMP_DMG_DIR}/Applications" - -# 创建 DMG -log_info "使用 hdiutil 创建 DMG..." -hdiutil create -volname "${APP_NAME}" -srcfolder "${TEMP_DMG_DIR}" -ov -format UDZO "${DMG_PATH}" - -# 清理临时文件夹 -rm -rf "${TEMP_DMG_DIR}" -log_success "DMG 创建完成: ${DMG_PATH}" - -# ============================================ -# 步骤6: 签名 DMG -# ============================================ -log_step "6" "签名 DMG" - -codesign --force --sign "${DEVELOPER_ID}" --timestamp "${DMG_PATH}" -codesign --verify --verbose=2 "${DMG_PATH}" -log_success "DMG 签名完成" - -# ============================================ -# 步骤7: 公证(仅 release 模式) -# ============================================ -if [ "$MODE" = "release" ]; then - log_step "7" "Apple 公证" - - # 从 Keychain 获取密码 - APP_PASSWORD=$(security find-generic-password -a "${APPLE_ID}" -s "${KEYCHAIN_ITEM}" -w) - - log_info "提交到 Apple 公证服务..." - log_info "(这可能需要几分钟时间)" - - NOTARIZE_OUTPUT=$(xcrun notarytool submit "${DMG_PATH}" \ - --apple-id "${APPLE_ID}" \ - --password "${APP_PASSWORD}" \ - --team-id "${TEAM_ID}" \ - --wait \ - --output-format json 2>&1) - - echo "${NOTARIZE_OUTPUT}" - - # 检查公证结果 - if echo "${NOTARIZE_OUTPUT}" | grep -Eq '"status"[[:space:]]*:[[:space:]]*"Accepted"'; then - log_success "公证成功!" - - # 步骤8: 装订公证票据 - log_step "8" "装订公证票据" - xcrun stapler staple "${DMG_PATH}" - xcrun stapler validate "${DMG_PATH}" - log_success "票据装订完成" - else - log_error "公证失败!" - - # 提取 RequestUUID 并获取详细日志 - REQUEST_UUID=$(echo "${NOTARIZE_OUTPUT}" | sed -n 's/.*"id"[[:space:]]*:[[:space:]]*"\([^"]*\)".*/\1/p' | head -1) - if [ -n "${REQUEST_UUID}" ]; then - log_info "获取详细公证日志..." - xcrun notarytool log "${REQUEST_UUID}" \ - --apple-id "${APPLE_ID}" \ - --password "${APP_PASSWORD}" \ - --team-id "${TEAM_ID}" - fi - exit 1 - fi -else - log_info "测试模式:跳过公证步骤" -fi - -# ============================================ -# 完成报告 -# ============================================ -echo "" -echo -e "${GREEN}========================================${NC}" -echo -e "${GREEN}构建完成!${NC}" -echo -e "${GREEN}========================================${NC}" -echo "" -echo -e "应用: ${CYAN}dist/${APP_NAME}.app${NC}" -echo -e "DMG: ${CYAN}${DMG_PATH}${NC}" -echo -e "架构: ${CYAN}${ARCH_SUFFIX}${NC}" -echo "" - -if [ "$MODE" = "release" ]; then - echo -e "状态: ${GREEN}已公证,可分发${NC}" - echo "" - echo "下一步:" - echo " 1. 测试 DMG 安装包" - echo " 2. 上传到 GitHub Releases" - echo "" - echo "注意: 如需构建其他架构版本,请在对应架构的 Mac 上重新运行此脚本" -else - echo -e "状态: ${YELLOW}已签名(未公证)${NC}" - echo "" - echo "注意: 测试模式下未进行公证,用户首次打开需要右键菜单" - echo "发布正式版本请使用: ./build_release.sh --release" -fi - -echo -e "${GREEN}========================================${NC}" +if [ "$#" -gt 0 ]; then + case "$1" in + --help|-h) + show_help + exit 0 + ;; + --release) + MODE_ARGS="--notarize" + shift + ;; + --test) + shift + ;; + esac +fi + +exec python3 "$SCRIPT_DIR/build_release_mac.py" --build-type full $MODE_ARGS "$@" \ No newline at end of file diff --git a/build_release_all.bat b/build_release_all.bat index c7e69713..36cf6ef0 100644 --- a/build_release_all.bat +++ b/build_release_all.bat @@ -1,3 +1,8 @@ +@echo off +setlocal EnableExtensions -call build_release_cpu.bat %1 -:: call build_release_cuda.bat %1 \ No newline at end of file +call "%~dp0build_release_cpu.bat" %* +if errorlevel 1 exit /b %ERRORLEVEL% + +call "%~dp0build_release_lite_win.bat" %* +exit /b %ERRORLEVEL% \ No newline at end of file diff --git a/build_release_cpu.bat b/build_release_cpu.bat index d60bec4f..16a9abb0 100644 --- a/build_release_cpu.bat +++ b/build_release_cpu.bat @@ -1,19 +1,5 @@ @echo off setlocal EnableExtensions -set "VERSION_INPUT=%~1" -if "%VERSION_INPUT%"=="" ( - set "VERSION_ARG=Win64_CPU" -) else ( - set "VERSION_ARG=%VERSION_INPUT%Win64_CPU" -) - -call "%~dp0.venv\Scripts\activate.bat" -if errorlevel 1 exit /b 1 - -set "OUT_DIST_DIR=dist_cpu" -call "%~dp0build_release.bat" "%VERSION_ARG%" "output" -set "RET=%ERRORLEVEL%" - -call "%~dp0.venv\Scripts\deactivate.bat" >nul 2>&1 -exit /b %RET% +call "%~dp0build_release.bat" "%~1" "output" +exit /b %ERRORLEVEL% diff --git a/build_release_cuda.bat b/build_release_cuda.bat index 47bb43a7..1e20d9fd 100644 --- a/build_release_cuda.bat +++ b/build_release_cuda.bat @@ -1,19 +1,12 @@ @echo off setlocal EnableExtensions -set "VERSION_INPUT=%~1" -if "%VERSION_INPUT%"=="" ( - set "VERSION_ARG=_Win64_CUDA" -) else ( - set "VERSION_ARG=%VERSION_INPUT%_Win64_CUDA" -) +set "SCRIPT_DIR=%~dp0" +set "PYTHON_EXE=%SCRIPT_DIR%.venv\Scripts\python.exe" +if not exist "%PYTHON_EXE%" set "PYTHON_EXE=python" -call "%~dp0.venv\Scripts\activate.bat" -if errorlevel 1 exit /b 1 +set "VERSION_ARG=" +if not "%~1"=="" set "VERSION_ARG=--version %~1" -set "OUT_DIST_DIR=dist_cuda" -call "%~dp0build_release.bat" "%VERSION_ARG%" "output\win64_cuda" -set "RET=%ERRORLEVEL%" - -call "%~dp0.venv\Scripts\deactivate.bat" >nul 2>&1 -exit /b %RET% +"%PYTHON_EXE%" "%SCRIPT_DIR%build_release_win.py" --build-type cuda %VERSION_ARG% --copy-dir output\win64_cuda +exit /b %ERRORLEVEL% diff --git a/build_release_full_mac.sh b/build_release_full_mac.sh new file mode 100755 index 00000000..4fe53940 --- /dev/null +++ b/build_release_full_mac.sh @@ -0,0 +1,5 @@ +#!/bin/sh +set -eu + +SCRIPT_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) +exec python3 "$SCRIPT_DIR/build_release_mac.py" --build-type full "$@" diff --git a/build_release_lite_mac.sh b/build_release_lite_mac.sh new file mode 100755 index 00000000..0434f40e --- /dev/null +++ b/build_release_lite_mac.sh @@ -0,0 +1,5 @@ +#!/bin/sh +set -eu + +SCRIPT_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) +exec python3 "$SCRIPT_DIR/build_release_mac.py" --build-type lite "$@" diff --git a/build_release_lite_win.bat b/build_release_lite_win.bat new file mode 100644 index 00000000..a4ce8b2f --- /dev/null +++ b/build_release_lite_win.bat @@ -0,0 +1,10 @@ +@echo off +setlocal EnableExtensions + +set "PYTHON_EXE=%~dp0.venv\Scripts\python.exe" +if not exist "%PYTHON_EXE%" ( + set "PYTHON_EXE=python" +) + +"%PYTHON_EXE%" "%~dp0build_release_win.py" --build-type lite %* +exit /b %ERRORLEVEL% diff --git a/build_release_mac.py b/build_release_mac.py new file mode 100644 index 00000000..95f4e81f --- /dev/null +++ b/build_release_mac.py @@ -0,0 +1,1018 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +SuperPicky macOS 构建脚本 / SuperPicky macOS build script. + +支持 full 与 lite 两种构建类型,并可选执行 Developer ID 签名。 +Supports both full and lite builds with optional Developer ID signing. +""" + +from __future__ import annotations + +import argparse +import ast +import importlib.metadata +import json +import logging +import os +import platform +import re +import secrets +import shutil +import subprocess +import sys +import tempfile +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable, Sequence + +from packaging.requirements import Requirement + + +ROOT_DIR = Path(__file__).resolve().parent +APP_NAME = "SuperPicky" +LITE_APP_NAME = "SuperPickyLite" +BUILD_INFO_FILE = ROOT_DIR / "core" / "build_info.py" +DOWNLOAD_MODELS_SCRIPT = ROOT_DIR / "scripts" / "download_models.py" +FULL_SPEC_FILE = ROOT_DIR / "SuperPicky_full.spec" +LITE_SPEC_FILE = ROOT_DIR / "SuperPicky_lite.spec" +REQUIREMENTS_MAC_FILE = ROOT_DIR / "requirements_mac.txt" +ENTITLEMENTS_FILE = ROOT_DIR / "entitlements.plist" +DMG_README_FILE = ROOT_DIR / "resources" / "DMG_README.txt" + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class BuildPaths: + """ + 构建路径集合 / Build path collection. + """ + + label: str + work_dir: Path + dist_dir: Path + app_dir: Path + dmg_path: Path + + +@dataclass(frozen=True) +class BuildConfig: + """ + 构建配置 / Build configuration. + """ + + build_type: str + arch: str + copy_dir: Path | None + debug: bool + app_version: str + commit_hash: str + sign_p12: Path | None + sign_p12_password_env: str + sign_identity: str | None + release_channel: str + notarize: bool + apple_id: str | None + apple_password_env: str + team_id: str | None + notary_keychain_profile: str | None + + +@dataclass +class SigningContext: + """ + 签名上下文 / Signing context. + """ + + keychain_path: Path + keychain_password: str + imported_p12_path: Path + identity: str + + +def configure_logging(debug: bool) -> None: + """ + 配置 UTF-8 日志输出 / Configure UTF-8 logging output. + """ + + if hasattr(sys.stdout, "reconfigure"): + sys.stdout.reconfigure(encoding="utf-8", errors="strict") # pyright: ignore[reportAttributeAccessIssue] + if hasattr(sys.stderr, "reconfigure"): + sys.stderr.reconfigure(encoding="utf-8", errors="strict") # pyright: ignore[reportAttributeAccessIssue] + + logger.setLevel(logging.DEBUG if debug else logging.INFO) + logger.propagate = False + + formatter = logging.Formatter("[%(levelname)s] %(message)s") + handler = logging.StreamHandler(sys.stdout) + handler.setLevel(logging.DEBUG if debug else logging.INFO) + handler.setFormatter(formatter) + + logger.handlers.clear() + logger.addHandler(handler) + + +def log_step(title: str) -> None: + """ + 记录步骤标题 / Log a step title. + """ + + logger.info("[========================================]") + logger.info(title) + logger.info("[========================================]") + + +def log_verbose(message: str, *args) -> None: + """ + 仅在调试模式输出详细日志 / Emit verbose logs only in debug mode. + """ + + logger.debug(message, *args) + + +def detect_host_arch() -> str: + """ + 规范化当前主机架构 / Normalize the current host architecture. + """ + + machine = platform.machine().lower() + return {"amd64": "x86_64", "x86_64": "x86_64", "arm64": "arm64", "aarch64": "arm64"}.get(machine, machine) + + +def optional_text(value: str | None) -> str | None: + """ + 规范化可选字符串 / Normalize optional text values. + """ + + if value is None: + return None + normalized = value.strip() + return normalized or None + + +def parse_args() -> argparse.Namespace: + """ + 解析命令行参数 / Parse command-line arguments. + """ + + parser = argparse.ArgumentParser(description="SuperPicky macOS 构建脚本") + parser.add_argument("--build-type", choices=["full", "lite"], required=True, help="构建类型:full 或 lite") + parser.add_argument( + "--arch", + choices=["arm64", "x86_64"], + default=detect_host_arch(), + help="目标架构,默认使用当前主机架构", + ) + parser.add_argument("--version", help="覆盖构建版本号,例如 4.2.5") + parser.add_argument("--copy-dir", help="复制最终产物的目标目录") + parser.add_argument("--debug", action="store_true", help="输出调试日志") + parser.add_argument("--sign-p12", help="Developer ID 证书 .p12 文件路径") + parser.add_argument( + "--sign-p12-password-env", + default="MACOS_CERTIFICATE_PWD", + help="读取 .p12 密码的环境变量名(默认: MACOS_CERTIFICATE_PWD)", + ) + parser.add_argument("--sign-identity", help="可选,显式指定 Developer ID Application identity") + parser.add_argument("--notarize", action="store_true", help="提交 Apple 公证并自动 staple DMG") + parser.add_argument("--apple-id", help="Apple notarization 使用的 Apple ID") + parser.add_argument("--team-id", help="Apple notarization 使用的 Team ID") + parser.add_argument( + "--apple-password-env", + default="APPLE_APP_PASSWORD", + help="读取 notarization 密码的环境变量名(默认: APPLE_APP_PASSWORD)", + ) + parser.add_argument("--notary-keychain-profile", help="可选,使用 notarytool keychain profile 进行认证") + return parser.parse_args() + + +def run_command( + command: Sequence[str], + *, + cwd: Path = ROOT_DIR, + check: bool = True, + capture_output: bool = False, + env: dict[str, str] | None = None, + label: str | None = None, +) -> subprocess.CompletedProcess[str]: + """ + 运行外部命令 / Run an external command. + """ + + if logger.isEnabledFor(logging.DEBUG): + logger.debug("执行命令: %s", " ".join(command)) + + result = subprocess.run( + list(command), + cwd=str(cwd), + text=True, + capture_output=capture_output, + env=env, + ) + + if check and result.returncode != 0: + if capture_output: + if result.stdout: + logger.error(result.stdout.strip()) + if result.stderr: + logger.error(result.stderr.strip()) + raise RuntimeError(f"{label or '命令执行'}失败,返回码: {result.returncode}") + + return result + + +def remove_path(path: Path) -> None: + """ + 删除文件或目录 / Remove a file or directory. + """ + + if path.is_dir() and not path.is_symlink(): + shutil.rmtree(path, ignore_errors=True) + elif path.exists() or path.is_symlink(): + path.unlink(missing_ok=True) + + +def copy_tree(src: Path, dst: Path) -> None: + """ + 复制目录 / Copy a directory tree. + """ + + if not src.exists(): + raise FileNotFoundError(f"复制源目录不存在: {src}") + remove_path(dst) + shutil.copytree(src, dst, symlinks=True) + + +def copy_file(src: Path, dst: Path) -> None: + """ + 复制文件 / Copy a file. + """ + + if not src.exists(): + raise FileNotFoundError(f"复制源文件不存在: {src}") + dst.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(src, dst) + + +def read_app_version() -> str: + """ + 从 constants.py 读取版本号 / Read version from constants.py. + """ + + content = (ROOT_DIR / "constants.py").read_text(encoding="utf-8") + match = re.search(r'APP_VERSION\s*=\s*["\']([0-9A-Za-z._-]+)["\']', content) + return match.group(1) if match else "0.0.0" + + +def get_commit_hash() -> str: + """ + 获取当前提交哈希 / Get the current commit hash. + """ + + try: + result = run_command( + ["git", "rev-parse", "--short=7", "HEAD"], + capture_output=True, + label="获取提交哈希", + ) + return result.stdout.strip() or "unknown" + except Exception: + content = BUILD_INFO_FILE.read_text(encoding="utf-8") + match = re.search(r'COMMIT_HASH\s*=\s*"([^"]*)"', content) + return match.group(1) if match else "unknown" + + +def parse_release_channel() -> str: + """ + 根据 RELEASE_TAG 判断发布渠道 / Infer release channel from RELEASE_TAG. + """ + + release_tag = os.environ.get("RELEASE_TAG", "") + if release_tag and "-rc" in release_tag.lower(): + return "nightly" + return "official" + + +def inject_build_info(commit_hash: str, release_channel: str) -> Path | None: + """ + 注入构建信息并返回备份路径 / Inject build info and return the backup path. + """ + + log_step("步骤 1: 注入构建元数据") + if not BUILD_INFO_FILE.exists(): + logger.warning("未找到 build_info.py,跳过注入") + return None + + backup_path = BUILD_INFO_FILE.with_suffix(".py.backup") + shutil.copy2(BUILD_INFO_FILE, backup_path) + + content = BUILD_INFO_FILE.read_text(encoding="utf-8-sig") + updated = re.sub( + r'COMMIT_HASH\s*=\s*".*"', + f'COMMIT_HASH = "{commit_hash}"', + content, + count=1, + ) + updated = re.sub( + r'RELEASE_CHANNEL\s*=\s*".*"', + f'RELEASE_CHANNEL = "{release_channel}"', + updated, + count=1, + ) + BUILD_INFO_FILE.write_text(updated, encoding="utf-8") + log_verbose("[成功] 已写入 COMMIT_HASH=%s RELEASE_CHANNEL=%s", commit_hash, release_channel) + return backup_path + + +def restore_build_info(backup_path: Path | None) -> None: + """ + 恢复构建信息文件 / Restore the build info file. + """ + + if backup_path and backup_path.exists(): + shutil.move(str(backup_path), str(BUILD_INFO_FILE)) + + +def spec_file_for(build_type: str) -> Path: + """ + 返回构建类型对应的 spec 文件 / Return the spec file for a build type. + """ + + if build_type == "lite": + return LITE_SPEC_FILE + return FULL_SPEC_FILE + + +def app_name_for(build_type: str) -> str: + """ + 返回构建类型对应的应用名 / Return the app name for a build type. + """ + + return LITE_APP_NAME if build_type == "lite" else APP_NAME + + +def artifact_name_for(build_type: str) -> str: + """ + 返回发布产物名称前缀 / Return the artifact name prefix for releases. + """ + + return "SuperPicky_Lite" if build_type == "lite" else APP_NAME + + +def display_name_for(build_type: str) -> str: + """ + 返回面向用户的展示名称 / Return the user-facing display name. + """ + + return "SuperPicky Lite" if build_type == "lite" else APP_NAME + + +def get_build_paths(build_type: str, arch: str, app_version: str, commit_hash: str) -> BuildPaths: + """ + 生成构建路径 / Build output paths. + """ + + label = f"{build_type}_{arch}" + app_name = app_name_for(build_type) + artifact_name = artifact_name_for(build_type) + dist_dir = ROOT_DIR / f"dist_{label}" + dmg_name = f"{artifact_name}_v{app_version}_{arch}_{commit_hash}.dmg" + return BuildPaths( + label=label, + work_dir=ROOT_DIR / f"build_dist_{label}", + dist_dir=dist_dir, + app_dir=dist_dir / f"{app_name}.app", + dmg_path=dist_dir / dmg_name, + ) + + +def ensure_macos_host() -> None: + """ + 确保当前系统为 macOS / Ensure the current host is macOS. + """ + + if sys.platform != "darwin": + raise RuntimeError("build_release_mac.py 只能在 macOS 上运行") + + +def ensure_arch_matches(target_arch: str) -> None: + """ + 确保目标架构与当前机器匹配 / Ensure target architecture matches the host. + """ + + normalized = detect_host_arch() + if normalized != target_arch: + raise RuntimeError( + f"当前机器架构为 {normalized},不能直接构建 {target_arch}。" + "请在对应架构的 macOS 环境中运行此脚本。" + ) + + +def _iter_requirement_lines(requirements_file: Path) -> Iterable[tuple[Path, str]]: + """ + 递归展开 requirements 文件 / Recursively expand requirements files. + """ + + for raw_line in requirements_file.read_text(encoding="utf-8").splitlines(): + line = raw_line.strip() + if not line or line.startswith("#"): + continue + if line.startswith("-r "): + nested_path = (requirements_file.parent / line[3:].strip()).resolve() + yield from _iter_requirement_lines(nested_path) + continue + if line.startswith("--requirement "): + nested_path = (requirements_file.parent / line.split(None, 1)[1].strip()).resolve() + yield from _iter_requirement_lines(nested_path) + continue + if line.startswith("-"): + logger.debug("跳过未处理的 requirements 条目: %s", line) + continue + yield requirements_file, line + + +def validate_python_environment() -> None: + """ + 检查当前 Python 环境是否满足 requirements_mac.txt / Validate the current Python environment. + """ + + log_step("步骤 2: 检查 Python 构建环境") + + missing_packages: list[str] = [] + version_conflicts: list[str] = [] + + for source_file, requirement_text in _iter_requirement_lines(REQUIREMENTS_MAC_FILE): + requirement = Requirement(requirement_text) + try: + installed_version = importlib.metadata.version(requirement.name) + except importlib.metadata.PackageNotFoundError: + missing_packages.append(f"{requirement.name} ({source_file.name})") + continue + if requirement.specifier and installed_version not in requirement.specifier: + version_conflicts.append( + f"{requirement.name}=={installed_version} 不满足 {requirement.specifier} ({source_file.name})" + ) + + if missing_packages or version_conflicts: + details = "\n".join([*missing_packages, *version_conflicts]) + raise RuntimeError( + "当前 Python 环境未满足 requirements_mac.txt。\n" + "请先执行 `python -m pip install -r requirements_mac.txt`。\n" + f"{details}" + ) + + run_command([sys.executable, "-c", "import PyInstaller; print(PyInstaller.__version__)"], label="PyInstaller 检查") + log_verbose("[成功] 当前 Python 环境满足 macOS 构建要求") + + +def load_required_models() -> list[dict[str, str]]: + """ + 从 download_models.py 解析模型清单 / Parse the required model list from download_models.py. + """ + + fallback = [ + {"filename": "model20240824.pth", "dest_dir": "models"}, + {"filename": "superFlier_efficientnet.pth", "dest_dir": "models"}, + {"filename": "cub200_keypoint_resnet50_slim.pth", "dest_dir": "models"}, + {"filename": "avonet.db", "dest_dir": "birdid/data"}, + {"filename": "cfanet_iaa_ava_res50-3cd62bb3.pth", "dest_dir": "models"}, + {"filename": "yolo11l-seg.pt", "dest_dir": "models"}, + ] + + if not DOWNLOAD_MODELS_SCRIPT.exists(): + logger.warning("未找到 download_models.py,使用默认模型列表") + return fallback + + try: + module_ast = ast.parse(DOWNLOAD_MODELS_SCRIPT.read_text(encoding="utf-8")) + models = None + for node in module_ast.body: + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == "MODELS_TO_DOWNLOAD": + models = ast.literal_eval(node.value) + break + if models is not None: + break + if models is None: + raise RuntimeError("download_models.py 中未找到 MODELS_TO_DOWNLOAD") + return [{"filename": item["filename"], "dest_dir": item["dest_dir"]} for item in models] + except BaseException as exc: + if isinstance(exc, KeyboardInterrupt): + raise + logger.warning("无法解析 download_models.py,使用默认模型列表: %s", exc) + return fallback + + +REQUIRED_MODELS = load_required_models() + + +def ensure_models() -> None: + """ + 检查 full 构建所需模型并在缺失时下载 / Ensure required models for the full build. + """ + + log_step("步骤 3: 检查并下载模型文件") + missing_paths = [ + ROOT_DIR / model["dest_dir"] / model["filename"] + for model in REQUIRED_MODELS + if not (ROOT_DIR / model["dest_dir"] / model["filename"]).exists() + ] + + if not missing_paths: + log_verbose("[成功] 所有模型文件已就绪") + return + + logger.warning("检测到 %d 个缺失模型,开始下载", len(missing_paths)) + run_command([sys.executable, str(DOWNLOAD_MODELS_SCRIPT)], label="模型下载") + + remaining = [path for path in missing_paths if not path.exists()] + if remaining: + missing_text = "\n".join(str(path) for path in remaining) + raise RuntimeError(f"模型下载后仍有缺失:\n{missing_text}") + log_verbose("[成功] 所有模型文件已就绪") + + +def clean_build_outputs(paths: BuildPaths) -> None: + """ + 清理构建目录 / Clean build outputs. + """ + + log_step("步骤 4: 清理旧的构建目录") + remove_path(paths.work_dir) + remove_path(paths.dist_dir) + log_verbose("[成功] 已清理 %s 和 %s", paths.work_dir, paths.dist_dir) + + +def build_environment(config: BuildConfig) -> dict[str, str]: + """ + 生成 PyInstaller 环境变量 / Build PyInstaller environment variables. + """ + + env = os.environ.copy() + env["SUPERPICKY_TARGET_ARCH"] = config.arch + env["SUPERPICKY_APP_VERSION"] = config.app_version + env["SUPERPICKY_CODESIGN_IDENTITY"] = "" + env["SUPERPICKY_ENTITLEMENTS_FILE"] = "" + return env + + +def build_bundle(config: BuildConfig, paths: BuildPaths) -> None: + """ + 执行 PyInstaller 构建 / Run the PyInstaller build. + """ + + log_step("步骤 5: 执行 PyInstaller 构建") + spec_file = spec_file_for(config.build_type) + if config.build_type == "lite": + log_verbose("[信息] macOS Lite 当前采用内置 Torch/Torchvision/Timm 的单包运行时策略") + pyinstaller_command = [ + sys.executable, + "-m", + "PyInstaller", + str(spec_file), + "--clean", + "--noconfirm", + f"--workpath={paths.work_dir}", + f"--distpath={paths.dist_dir}", + ] + logger.info("启动 PyInstaller 构建:开始") + logger.info("PyInstaller 参数:%s", " ".join(str(item) for item in pyinstaller_command[2:])) + run_command( + pyinstaller_command, + capture_output=not logger.isEnabledFor(logging.DEBUG), + env=build_environment(config), + label="PyInstaller 构建", + ) + + if not paths.app_dir.exists(): + raise FileNotFoundError(f"构建完成后未找到 .app: {paths.app_dir}") + logger.info("PyInstaller 构建成功!") + logger.info("构建产物位置:%s", paths.app_dir) + log_verbose("[成功] 已完成 %s 构建: %s", config.build_type, paths.app_dir) + + +def organize_app_bundle_resources(app_dir: Path) -> None: + """ + 将资源移至 .app 的 Resources 目录 / Move resources into the app bundle Resources directory. + """ + + log_step("步骤 6: 整理 .app 资源目录") + macos_dir = app_dir / "Contents" / "MacOS" + resources_dir = app_dir / "Contents" / "Resources" + resources_dir.mkdir(parents=True, exist_ok=True) + + for resource_name in ("SuperBirdIDPlugin.lrplugin", "en.lproj", "zh-Hans.lproj"): + source_path = macos_dir / resource_name + destination_path = resources_dir / resource_name + if source_path.exists(): + remove_path(destination_path) + shutil.move(str(source_path), str(destination_path)) + log_verbose("[成功] 已移动资源到 Resources: %s", resource_name) + + +def create_dmg(config: BuildConfig, paths: BuildPaths) -> None: + """ + 生成 DMG 安装镜像 / Create a DMG installer image. + """ + + log_step("步骤 7: 生成 DMG") + staging_dir = paths.dist_dir / "dmg_staging" + remove_path(staging_dir) + staging_dir.mkdir(parents=True, exist_ok=True) + + staged_app = staging_dir / paths.app_dir.name + copy_tree(paths.app_dir, staged_app) + + plugin_dir = paths.app_dir / "Contents" / "Resources" / "SuperBirdIDPlugin.lrplugin" + if plugin_dir.exists(): + copy_tree(plugin_dir, staging_dir / plugin_dir.name) + + if DMG_README_FILE.exists(): + copy_file(DMG_README_FILE, staging_dir / "README.txt") + + applications_link = staging_dir / "Applications" + if not applications_link.exists(): + os.symlink("/Applications", applications_link) + + paths.dmg_path.parent.mkdir(parents=True, exist_ok=True) + remove_path(paths.dmg_path) + run_command( + [ + "hdiutil", + "create", + "-volname", + f"{display_name_for(config.build_type)} {config.app_version}", + "-srcfolder", + str(staging_dir), + "-ov", + "-format", + "UDZO", + str(paths.dmg_path), + ], + label="生成 DMG", + ) + + remove_path(staging_dir) + log_verbose("[成功] 已生成 DMG: %s", paths.dmg_path) + + +def prepare_signing(config: BuildConfig) -> SigningContext | None: + """ + 如果提供 .p12,则导入临时 keychain / Import a temporary keychain when a .p12 is provided. + """ + + if config.sign_p12 is None: + return None + + log_step("步骤 8: 导入签名证书") + if not config.sign_p12.exists(): + raise FileNotFoundError(f"未找到签名证书文件: {config.sign_p12}") + + password = os.environ.get(config.sign_p12_password_env, "") + if not password: + raise RuntimeError(f"环境变量 {config.sign_p12_password_env} 未设置,无法导入 .p12") + + temp_dir = Path(tempfile.mkdtemp(prefix="superpicky_sign_", dir=str(ROOT_DIR))) + imported_p12_path = temp_dir / config.sign_p12.name + shutil.copy2(config.sign_p12, imported_p12_path) + + keychain_path = temp_dir / "build.keychain-db" + keychain_password = secrets.token_hex(16) + + run_command(["security", "create-keychain", "-p", keychain_password, str(keychain_path)], label="创建临时 keychain") + run_command(["security", "set-keychain-settings", "-lut", "21600", str(keychain_path)], label="配置 keychain") + run_command(["security", "unlock-keychain", "-p", keychain_password, str(keychain_path)], label="解锁 keychain") + run_command( + [ + "security", + "import", + str(imported_p12_path), + "-k", + str(keychain_path), + "-P", + password, + "-T", + "/usr/bin/codesign", + ], + label="导入 .p12", + ) + run_command( + [ + "security", + "set-key-partition-list", + "-S", + "apple-tool:,apple:", + "-s", + "-k", + keychain_password, + str(keychain_path), + ], + label="配置 keychain 访问权限", + ) + + identity = config.sign_identity or discover_signing_identity(keychain_path) + log_verbose("[成功] 已加载签名 identity: %s", identity) + return SigningContext( + keychain_path=keychain_path, + keychain_password=keychain_password, + imported_p12_path=imported_p12_path, + identity=identity, + ) + + +def discover_signing_identity(keychain_path: Path) -> str: + """ + 从 keychain 中解析 Developer ID Application identity / Resolve Developer ID Application identity from keychain. + """ + + result = run_command( + ["security", "find-identity", "-v", "-p", "codesigning", str(keychain_path)], + capture_output=True, + label="解析签名 identity", + ) + pattern = re.compile(r'"(Developer ID Application:[^"]+)"') + for line in result.stdout.splitlines(): + match = pattern.search(line) + if match: + return match.group(1) + raise RuntimeError("未在 .p12 对应 keychain 中找到 Developer ID Application identity") + + +def iter_signable_files(contents_dir: Path) -> list[Path]: + """ + 枚举需要优先签名的文件 / Enumerate files that should be signed first. + """ + + signable: list[Path] = [] + for path in contents_dir.rglob("*"): + if not path.is_file(): + continue + if path.suffix in {".dylib", ".so"} or os.access(path, os.X_OK): + signable.append(path) + signable.sort(key=lambda item: len(item.parts), reverse=True) + return signable + + +def codesign_path( + path: Path, + identity: str, + *, + entitlements: Path | None = None, + keychain_path: Path | None = None, + use_runtime: bool = False, +) -> None: + """ + 对指定路径执行 codesign / Sign a path with codesign. + """ + + command = ["codesign", "--force", "--sign", identity] + if identity != "-": + command.append("--timestamp") + if use_runtime: + command.extend(["--options", "runtime"]) + if entitlements is not None: + command.extend(["--entitlements", str(entitlements)]) + if keychain_path is not None: + command.extend(["--keychain", str(keychain_path)]) + command.append(str(path)) + run_command(command, label=f"签名 {path.name}") + + +def sign_app_bundle(app_dir: Path, signing_context: SigningContext | None) -> None: + """ + 对 .app 执行签名并验证 / Sign and verify the app bundle. + """ + + log_step("步骤 9: 签名并验证 .app") + identity = signing_context.identity if signing_context else "-" + keychain_path = signing_context.keychain_path if signing_context else None + entitlements = ENTITLEMENTS_FILE if signing_context and ENTITLEMENTS_FILE.exists() else None + + for nested_path in iter_signable_files(app_dir / "Contents"): + codesign_path(nested_path, identity, keychain_path=keychain_path, use_runtime=True) + + codesign_path(app_dir, identity, entitlements=entitlements, keychain_path=keychain_path, use_runtime=True) + verify_command = ["codesign", "--verify", "--deep", "--strict", "--verbose=2", str(app_dir)] + run_command(verify_command, label="校验 .app 签名") + log_verbose("[成功] .app 签名校验通过") + + +def sign_dmg(dmg_path: Path, signing_context: SigningContext | None) -> None: + """ + 如有证书则签名 DMG / Sign the DMG when a certificate is provided. + """ + + if signing_context is None: + log_verbose("[信息] 未提供 .p12,跳过 DMG 签名") + return + + log_step("步骤 10: 签名 DMG") + codesign_path(dmg_path, signing_context.identity, keychain_path=signing_context.keychain_path, use_runtime=False) + run_command(["codesign", "--verify", "--verbose=2", str(dmg_path)], label="校验 DMG 签名") + log_verbose("[成功] DMG 签名校验通过") + + +def notary_auth_arguments(config: BuildConfig) -> list[str]: + """ + 构造 notarytool 认证参数 / Build notarytool authentication arguments. + """ + + if config.notary_keychain_profile: + return ["--keychain-profile", config.notary_keychain_profile] + + if not config.apple_id: + raise RuntimeError("启用 --notarize 时必须提供 Apple ID 或设置 APPLE_ID 环境变量") + if not config.team_id: + raise RuntimeError("启用 --notarize 时必须提供 Team ID 或设置 MACOS_TEAM_ID/TEAM_ID 环境变量") + + password = os.environ.get(config.apple_password_env, "").strip() + if not password: + raise RuntimeError(f"启用 --notarize 时必须设置环境变量 {config.apple_password_env}") + + return [ + "--apple-id", + config.apple_id, + "--password", + password, + "--team-id", + config.team_id, + ] + + +def notarize_dmg(dmg_path: Path, config: BuildConfig) -> None: + """ + 公证并装订 DMG / Notarize and staple the DMG. + """ + + if not config.notarize: + log_verbose("[信息] 未启用 --notarize,跳过 Apple 公证") + return + + log_step("步骤 11: Apple 公证并装订") + auth_args = notary_auth_arguments(config) + submit_command = [ + "xcrun", + "notarytool", + "submit", + str(dmg_path), + *auth_args, + "--wait", + "--output-format", + "json", + ] + result = run_command(submit_command, capture_output=True, label="Apple 公证") + output = result.stdout.strip() + if output: + logger.info(output) + + status = "" + request_id = "" + if output: + try: + payload = json.loads(output) + except json.JSONDecodeError: + payload = None + if isinstance(payload, dict): + status = str(payload.get("status", "")).strip().lower() + request_id = str(payload.get("id", "")).strip() + elif "Accepted" in output: + status = "accepted" + + if status != "accepted": + if request_id: + log_verbose("[信息] 公证失败,尝试读取详细日志: %s", request_id) + log_result = run_command( + ["xcrun", "notarytool", "log", request_id, *auth_args], + capture_output=True, + check=False, + label="读取公证日志", + ) + if log_result.stdout: + logger.error(log_result.stdout.strip()) + if log_result.stderr: + logger.error(log_result.stderr.strip()) + raise RuntimeError("Apple 公证失败") + + run_command(["xcrun", "stapler", "staple", str(dmg_path)], label="装订公证票据") + run_command(["xcrun", "stapler", "validate", str(dmg_path)], label="验证公证票据") + log_verbose("[成功] DMG 公证与装订完成") + + +def publish_artifacts(paths: BuildPaths, config: BuildConfig) -> tuple[Path, Path]: + """ + 输出最终产物位置 / Publish final artifact locations. + """ + + if config.copy_dir is None: + return paths.app_dir, paths.dmg_path + + config.copy_dir.mkdir(parents=True, exist_ok=True) + destination_app = config.copy_dir / paths.app_dir.name + destination_dmg = config.copy_dir / paths.dmg_path.name + copy_tree(paths.app_dir, destination_app) + copy_file(paths.dmg_path, destination_dmg) + log_verbose("[成功] 已复制最终产物到: %s", config.copy_dir) + return destination_app, destination_dmg + + +def cleanup_signing_context(signing_context: SigningContext | None) -> None: + """ + 清理临时 keychain 和证书文件 / Clean up the temporary keychain and certificate file. + """ + + if signing_context is None: + return + + parent_dir = signing_context.keychain_path.parent + run_command(["security", "delete-keychain", str(signing_context.keychain_path)], check=False) + remove_path(parent_dir) + + +def create_config(args: argparse.Namespace) -> BuildConfig: + """ + 根据参数创建构建配置 / Create build configuration from CLI args. + """ + + return BuildConfig( + build_type=args.build_type, + arch=args.arch, + copy_dir=Path(args.copy_dir).resolve() if args.copy_dir else None, + debug=args.debug, + app_version=args.version or read_app_version(), + commit_hash=get_commit_hash(), + sign_p12=Path(args.sign_p12).resolve() if args.sign_p12 else None, + sign_p12_password_env=args.sign_p12_password_env, + sign_identity=optional_text(args.sign_identity), + release_channel=parse_release_channel(), + notarize=args.notarize, + apple_id=optional_text(args.apple_id) or optional_text(os.environ.get("APPLE_ID")), + apple_password_env=args.apple_password_env, + team_id=( + optional_text(args.team_id) + or optional_text(os.environ.get("MACOS_TEAM_ID")) + or optional_text(os.environ.get("TEAM_ID")) + ), + notary_keychain_profile=( + optional_text(args.notary_keychain_profile) + or optional_text(os.environ.get("NOTARY_KEYCHAIN_PROFILE")) + ), + ) + + +def run_build(config: BuildConfig) -> None: + """ + 执行完整构建流程 / Run the complete build flow. + """ + + ensure_macos_host() + ensure_arch_matches(config.arch) + validate_python_environment() + + if config.build_type == "full": + ensure_models() + + paths = get_build_paths(config.build_type, config.arch, config.app_version, config.commit_hash) + clean_build_outputs(paths) + build_bundle(config, paths) + organize_app_bundle_resources(paths.app_dir) + + signing_context: SigningContext | None = None + try: + signing_context = prepare_signing(config) + if config.notarize and signing_context is None and not config.sign_identity: + raise RuntimeError("启用 --notarize 时必须提供 --sign-p12 或 --sign-identity 以完成正式签名") + sign_app_bundle(paths.app_dir, signing_context) + create_dmg(config, paths) + sign_dmg(paths.dmg_path, signing_context) + notarize_dmg(paths.dmg_path, config) + final_app, final_dmg = publish_artifacts(paths, config) + logger.info("[========================================]") + logger.info("构建完成") + logger.info("[========================================]") + logger.info("构建类型: %s", config.build_type) + logger.info("目标架构: %s", config.arch) + logger.info(".app: %s", final_app) + logger.info(".dmg: %s", final_dmg) + finally: + cleanup_signing_context(signing_context) + + +def main() -> None: + """ + 程序入口 / Program entrypoint. + """ + + args = parse_args() + configure_logging(args.debug) + config = create_config(args) + backup_path = inject_build_info(config.commit_hash, config.release_channel) + try: + run_build(config) + finally: + restore_build_info(backup_path) + + +if __name__ == "__main__": + main() diff --git a/build_release_win.py b/build_release_win.py index 1190a391..b3fa00a5 100644 --- a/build_release_win.py +++ b/build_release_win.py @@ -16,10 +16,12 @@ import ast import hashlib import logging +import os import re import shutil import subprocess import sys +import zipfile from dataclasses import dataclass from pathlib import Path from typing import Sequence @@ -31,10 +33,12 @@ BUILD_INFO_FILE = ROOT_DIR / "core" / "build_info.py" DOWNLOAD_MODELS_SCRIPT = ROOT_DIR / "scripts" / "download_models.py" SPEC_FILE = ROOT_DIR / "SuperPicky_win64.spec" +LITE_SPEC_FILE = ROOT_DIR / "SuperPicky_lite_win.spec" CPU_VENV_DIR = ROOT_DIR / ".venv" CUDA_VENV_DIR = ROOT_DIR / ".venv-cuda" DEFAULT_PATCH_OUTPUT_ROOT = ROOT_DIR / "output" STANDARD_INNO_TEMPLATE = INNO_DIR / "SuperPicky.iss" +LITE_INNO_TEMPLATE = INNO_DIR / "SuperPicky-lite.iss" PATCH_INNO_TEMPLATE = INNO_DIR / "SuperPicky_CUDA_Patch.iss" INNO_LANGUAGE_FILE = INNO_DIR / "ChineseSimplified.isl" CPU_REQUIREMENTS_FILE = ROOT_DIR / "requirements.txt" @@ -42,6 +46,7 @@ PATCH_MANIFEST_RELATIVE_PATH = Path("_internal") / "cuda_patch_manifest.txt" CPU_INSTALLER_STAGING_DIRNAME = "installer_cpu" CUDA_INSTALLER_STAGING_DIRNAME = "installer_cuda" +LITE_INSTALLER_STAGING_DIRNAME = "installer_lite" CUDA_PATCH_PORTABLE_DIRNAME = "cuda_patch" CUDA_PATCH_INSTALLER_STAGING_DIRNAME = "cuda_patch_installer" @@ -68,9 +73,9 @@ class BuildConfig: def configure_logging(debug: bool) -> None: if hasattr(sys.stdout, "reconfigure"): - sys.stdout.reconfigure(encoding="utf-8", errors="strict") + sys.stdout.reconfigure(encoding="utf-8", errors="strict") # pyright: ignore[reportAttributeAccessIssue] if hasattr(sys.stderr, "reconfigure"): - sys.stderr.reconfigure(encoding="utf-8", errors="strict") + sys.stderr.reconfigure(encoding="utf-8", errors="strict") # pyright: ignore[reportAttributeAccessIssue] logger.setLevel(logging.DEBUG if debug else logging.INFO) logger.propagate = False @@ -90,13 +95,17 @@ def log_step(title: str) -> None: logger.info("[========================================]") +def log_verbose(message: str, *args) -> None: + logger.debug(message, *args) + + def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="SuperPicky Windows 构建脚本") parser.add_argument( "--build-type", - choices=["cpu", "cuda", "cuda-patch"], - default="cpu", - help="构建类型:cpu, cuda, cuda-patch (默认: cpu)", + choices=["cpu", "cuda", "cuda-patch", "lite"], + default="lite", + help="构建类型:cpu, cuda, cuda-patch, lite (默认: lite)", ) parser.add_argument("--version", help="覆盖基础版本号,例如 4.2.0") parser.add_argument("--copy-dir", help="复制最终产物的目标目录") @@ -172,6 +181,7 @@ def load_required_models() -> list[dict[str, str]]: {"filename": "cub200_keypoint_resnet50_slim.pth", "dest_dir": "models"}, {"filename": "avonet.db", "dest_dir": "birdid/data"}, {"filename": "cfanet_iaa_ava_res50-3cd62bb3.pth", "dest_dir": "models"}, + {"filename": "yolo11l-seg.pt", "dest_dir": "models"}, ] if not DOWNLOAD_MODELS_SCRIPT.exists(): @@ -191,7 +201,7 @@ def load_required_models() -> list[dict[str, str]]: break if models is None: raise RuntimeError("download_models.py 中未找到 MODELS_TO_DOWNLOAD") - logger.info("[成功] 已从 download_models.py 加载模型列表") + log_verbose("[成功] 已从 download_models.py 加载模型列表") return [{"filename": item["filename"], "dest_dir": item["dest_dir"]} for item in models] except BaseException as exc: if isinstance(exc, KeyboardInterrupt): @@ -216,7 +226,7 @@ def ensure_models(python_exe: Path) -> None: log_step("步骤 0: 检查并下载模型文件") missing = find_missing_models() if not missing: - logger.info("[成功] 所有模型文件已就绪") + log_verbose("[成功] 所有模型文件已就绪") return logger.warning("缺失 %d 个模型文件,开始下载", len(missing)) @@ -231,7 +241,7 @@ def ensure_models(python_exe: Path) -> None: logger.error("仍然缺失: %s", path) raise RuntimeError("模型下载后仍有缺失") - logger.info("[成功] 所有模型文件已就绪") + log_verbose("[成功] 所有模型文件已就绪") def read_app_version() -> str: @@ -279,7 +289,7 @@ def inject_build_info(commit_hash: str, release_channel: str = "official") -> Pa count=1, ) BUILD_INFO_FILE.write_text(updated, encoding="utf-8") - logger.info("[成功] 已写入 COMMIT_HASH=%s RELEASE_CHANNEL=%s", commit_hash, release_channel) + log_verbose("[成功] 已写入 COMMIT_HASH=%s RELEASE_CHANNEL=%s", commit_hash, release_channel) return backup_path @@ -288,16 +298,24 @@ def restore_build_info(backup_path: Path | None) -> None: shutil.move(str(backup_path), str(BUILD_INFO_FILE)) -def ensure_spec_file() -> None: - if not SPEC_FILE.exists(): - raise FileNotFoundError(f"缺少 spec 文件: {SPEC_FILE}") +def spec_file_for(build_type: str) -> Path: + if build_type == "lite": + return LITE_SPEC_FILE + return SPEC_FILE + + +def ensure_spec_file(build_type: str) -> None: + spec_file = spec_file_for(build_type) + if not spec_file.exists(): + raise FileNotFoundError(f"缺少 spec 文件: {spec_file}") + log_verbose("[信息] %s 构建将使用 spec: %s", build_type, spec_file.name) def check_python_environment(python_exe: Path, label: str) -> None: - logger.info("[信息] 检查 Python 环境 (%s): %s", label, python_exe) + log_verbose("[信息] 检查 Python 环境 (%s): %s", label, python_exe) run_command([str(python_exe), "-c", "import sys; print(sys.executable)"], label=f"{label} Python 检查") run_command([str(python_exe), "-c", "import PyInstaller"], label=f"{label} PyInstaller 检查") - logger.info("[成功] %s 环境可用", label) + log_verbose("[成功] %s 环境可用", label) def python_in_venv(venv_dir: Path) -> Path: @@ -315,7 +333,7 @@ def ensure_virtual_environment( venv_python = python_in_venv(venv_dir) if not venv_python.exists(): - logger.info("[信息] 创建 %s 虚拟环境: %s", label, venv_dir) + logger.info("创建 %s 虚拟环境...", label) run_command([str(bootstrap_python), "-m", "venv", str(venv_dir)], label=f"创建 {label} 虚拟环境") run_command([str(venv_python), "-m", "pip", "install", "--upgrade", "pip"], label=f"升级 {label} 环境 pip") @@ -348,99 +366,126 @@ def ensure_cuda_environment(bootstrap_python: Path) -> Path: def clean_build_outputs() -> None: log_step("步骤 2: 清理旧的构建目录") - for label in ("cpu", "cuda", "cuda_patch"): + for label in ("cpu", "cuda", "cuda_patch", "lite"): paths = get_build_paths(label) remove_path(paths.work_dir) remove_path(paths.dist_dir) remove_path(ROOT_DIR / "build_dist") remove_path(ROOT_DIR / "dist") - logger.info("[成功] 已清理构建目录") + log_verbose("[成功] 已清理构建目录") -def build_bundle(python_exe: Path, build_paths: BuildPaths) -> None: +def build_bundle(python_exe: Path, build_paths: BuildPaths, spec_file: Path) -> None: log_step(f"步骤 3: 构建 {build_paths.label.upper()} 版本") remove_path(build_paths.work_dir) remove_path(build_paths.dist_dir) + pyinstaller_command = [ + str(python_exe), + "-m", + "PyInstaller", + str(spec_file), + "--clean", + "--noconfirm", + f"--workpath={build_paths.work_dir}", + f"--distpath={build_paths.dist_dir}", + ] + logger.info("启动 PyInstaller 构建:开始") + logger.info("PyInstaller 参数:%s", " ".join(str(item) for item in pyinstaller_command[2:])) run_command( - [ - str(python_exe), - "-m", - "PyInstaller", - str(SPEC_FILE), - "--clean", - "--noconfirm", - f"--workpath={build_paths.work_dir}", - f"--distpath={build_paths.dist_dir}", - ], + pyinstaller_command, + capture_output=not logger.isEnabledFor(logging.DEBUG), label=f"{build_paths.label} PyInstaller 构建", ) exe_path = build_paths.bundle_dir / f"{APP_NAME}.exe" if not exe_path.exists(): raise FileNotFoundError(f"构建完成后未找到可执行文件: {exe_path}") - logger.info("[成功] %s 构建完成", build_paths.label.upper()) - - -def find_7z_executable() -> str: - candidates = [ - shutil.which("7z"), - shutil.which("7zz"), - shutil.which("7za"), - str(Path("C:/Program Files/7-Zip/7z.exe")), - str(Path("C:/Program Files (x86)/7-Zip/7z.exe")), - ] - for candidate in candidates: - if not candidate: - continue - if Path(candidate).exists(): - return candidate - raise FileNotFoundError("未找到 7z 可执行文件,请先安装 7-Zip 或确保 7z/7zz/7za 已加入 PATH") + logger.info("PyInstaller 构建成功!") + logger.info("构建产物位置:%s", exe_path) + log_verbose("[成功] %s 构建完成", build_paths.label.upper()) def create_zip_archive(source_dir: Path, archive_path: Path) -> None: + """ + 使用标准库创建 ZIP 包 / Create ZIP archives with the Python standard library. + """ + archive_path.parent.mkdir(parents=True, exist_ok=True) archive_path.unlink(missing_ok=True) - seven_zip = find_7z_executable() - run_command( - [ - seven_zip, - "a", - "-tzip", - "-mx=9", - "-mm=LZMA", - "-md=256m", - "-mfb=128", - "-mmt=on", - str(archive_path), - source_dir.name, - ], - cwd=source_dir.parent, - label="ZIP 压缩", - ) + with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=9) as archive: + for file_path in sorted(source_dir.rglob("*")): + if file_path.is_dir(): + continue + archive.write(file_path, arcname=str(Path(source_dir.name) / file_path.relative_to(source_dir))) def archive_name_for(label: str, app_version: str, commit_hash: str) -> str: return f"{APP_NAME}_Win64_{app_version}_{commit_hash}_{label}.zip" -def update_inno_content(content: str, *, app_version: str, commit_hash: str, patch: bool) -> str: - version_value = f"{app_version}-{commit_hash}" - if patch: - output_base = f"SuperPicky_CUDA_Patch_Win64_{app_version}_{commit_hash}" - else: - output_base = f"SuperPicky_Setup_Win64_{app_version}_{commit_hash}" +def normalize_version(version: str) -> str: + """确保版本号以 'v' 前缀开头。 + + Ensure version string starts with 'v' prefix. + + 参数 / Parameters: + version (str): 原始版本号,例如 "4.2.0" 或 "v4.2.0" + + 返回 / Return: + str: 带 'v' 前缀的版本号,例如 "v4.2.0" + """ + return version if version.startswith("v") else f"v{version}" + + +def update_inno_content(content: str, *, app_version: str, commit_hash: str) -> str: + """替换 ISS 模板中的 #define 预处理器变量,注入版本号和提交哈希。 - content = re.sub(r"(?m)^AppVersion=.*$", f"AppVersion={version_value}", content) - content = re.sub(r"(?m)^OutputBaseFilename=.*$", f"OutputBaseFilename={output_base}", content) + Replace #define preprocessor variables in ISS template with version and commit hash. + + 参数 / Parameters: + content (str): ISS 模板原始内容 + app_version (str): 应用版本号(将自动添加 'v' 前缀) + commit_hash (str): Git 提交哈希 + + 返回 / Return: + str: 替换后的 ISS 内容 + """ + versioned = normalize_version(app_version) + content = re.sub( + r'(?m)^(#define\s+MyAppVersion\s+").*?(")\s*$', + rf'\g<1>{versioned}\2', + content, + ) + content = re.sub( + r'(?m)^(#define\s+MyAppCommitHash\s+").*?(")\s*$', + rf'\g<1>{commit_hash}\2', + content, + ) return content -def write_inno_script(template_path: Path, destination_path: Path, *, app_version: str, commit_hash: str, patch: bool) -> None: +def write_inno_script( + template_path: Path, + destination_path: Path, + *, + app_version: str, + commit_hash: str, +) -> None: + """读取 ISS 模板,注入版本号和哈希后写入目标路径。 + + Read ISS template, inject version and hash, write to destination. + + 参数 / Parameters: + template_path (Path): ISS 模板文件路径 + destination_path (Path): 输出 ISS 文件路径 + app_version (str): 应用版本号 + commit_hash (str): Git 提交哈希 + """ content = template_path.read_text(encoding="utf-8") destination_path.parent.mkdir(parents=True, exist_ok=True) destination_path.write_text( - update_inno_content(content, app_version=app_version, commit_hash=commit_hash, patch=patch), + update_inno_content(content, app_version=app_version, commit_hash=commit_hash), encoding="utf-8", ) @@ -450,22 +495,55 @@ def installer_staging_dir_name(label: str) -> str: return CPU_INSTALLER_STAGING_DIRNAME if label == "cuda": return CUDA_INSTALLER_STAGING_DIRNAME + if label == "lite": + return LITE_INSTALLER_STAGING_DIRNAME raise ValueError(f"不支持的标准安装包标签: {label}") +def inno_template_for(label: str) -> Path: + """根据构建标签返回对应的 ISS 模板路径。 + + Return the ISS template path for the given build label. + + 参数 / Parameters: + label (str): 构建标签,"lite" 或其他(Full/CPU/CUDA) + + 返回 / Return: + Path: ISS 模板文件路径 + """ + if label == "lite": + return LITE_INNO_TEMPLATE + return STANDARD_INNO_TEMPLATE + + def prepare_standard_installer_staging(source_bundle_dir: Path, staging_root: Path, config: BuildConfig, *, label: str) -> Path: + """准备标准安装包的 staging 目录,包含构建产物、ISS 脚本和依赖资源。 + + Prepare standard installer staging directory with build artifacts, ISS script and dependencies. + + 参数 / Parameters: + source_bundle_dir (Path): PyInstaller 构建产物目录 + staging_root (Path): staging 根目录 + config (BuildConfig): 构建配置 + label (str): 构建标签("cpu", "cuda", "lite") + + 返回 / Return: + Path: 生成的 ISS 脚本路径 + """ staging_dir = staging_root / installer_staging_dir_name(label) copy_tree(source_bundle_dir, staging_dir) + template = inno_template_for(label) + iss_filename = template.name write_inno_script( - STANDARD_INNO_TEMPLATE, - staging_dir / "SuperPicky.iss", + template, + staging_dir / iss_filename, app_version=config.app_version, commit_hash=config.commit_hash, - patch=False, ) copy_file(INNO_LANGUAGE_FILE, staging_dir / INNO_LANGUAGE_FILE.name) - logger.info("[成功] 已准备标准安装包脚本目录: %s", staging_dir) - return staging_dir / "SuperPicky.iss" + copy_tree(ROOT_DIR / "img", staging_dir / "img") + log_verbose("[成功] 已准备标准安装包脚本目录: %s", staging_dir) + return staging_dir / iss_filename def publish_standard_build( @@ -489,12 +567,12 @@ def publish_standard_build( if not config.no_zip: zip_path = artifact_root / archive_name_for(label, config.app_version, config.commit_hash) create_zip_archive(zip_source_dir, zip_path) - logger.info("[成功] 已创建 ZIP 压缩包: %s", zip_path) + log_verbose("[成功] 已创建 ZIP 压缩包: %s", zip_path) else: zip_path = None - logger.info("[信息] 跳过 ZIP 压缩包创建 (--no-zip)") + log_verbose("[信息] 跳过 ZIP 压缩包创建 (--no-zip)") - logger.info("[成功] 已准备目录: %s", final_bundle_dir) + log_verbose("[成功] 已准备目录: %s", final_bundle_dir) return final_bundle_dir, zip_path, installer_script_path @@ -555,9 +633,9 @@ def prepare_patch_directory(cpu_bundle: Path, cuda_bundle: Path, config: BuildCo shutil.copy2(cuda_file, destination) manifest_path = write_patch_manifest(patch_dir, copied_patch_files) - logger.info("[成功] 已导出差异文件: %d 个不同文件, %d 个 CUDA 独有文件", different_count, cuda_only_count) - logger.info("[成功] 已写入补丁清单: %s", manifest_path) - logger.info("[成功] 补丁目录: %s", patch_dir) + log_verbose("[成功] 已导出差异文件: %d 个不同文件, %d 个 CUDA 独有文件", different_count, cuda_only_count) + log_verbose("[成功] 已写入补丁清单: %s", manifest_path) + log_verbose("[成功] 补丁目录: %s", patch_dir) return patch_dir @@ -572,14 +650,17 @@ def prepare_patch_installer_staging(portable_patch_dir: Path, config: BuildConfi staging_dir / PATCH_INNO_TEMPLATE.name, app_version=config.app_version, commit_hash=config.commit_hash, - patch=True, ) - logger.info("[成功] 已准备 CUDA 补丁安装包脚本目录: %s", staging_dir) + log_verbose("[成功] 已准备 CUDA 补丁安装包脚本目录: %s", staging_dir) return staging_dir / PATCH_INNO_TEMPLATE.name def ensure_inno_templates() -> None: - for path in (STANDARD_INNO_TEMPLATE, PATCH_INNO_TEMPLATE, INNO_LANGUAGE_FILE): + """检查所有 Inno Setup 模板和依赖文件是否存在。 + + Verify all Inno Setup templates and dependency files exist. + """ + for path in (STANDARD_INNO_TEMPLATE, LITE_INNO_TEMPLATE, PATCH_INNO_TEMPLATE, INNO_LANGUAGE_FILE): if not path.exists(): raise FileNotFoundError(f"缺少 Inno 相关文件: {path}") @@ -593,7 +674,7 @@ def resolve_final_root(build_type: str, copy_dir: Path | None) -> Path | None: def build_single_target(config: BuildConfig, label: str, python_exe: Path) -> tuple[BuildPaths, Path, Path | None, Path]: check_python_environment(python_exe, label.upper()) build_paths = get_build_paths(label) - build_bundle(python_exe, build_paths) + build_bundle(python_exe, build_paths, spec_file_for(label if label == "lite" else config.build_type)) final_root = resolve_final_root(config.build_type, config.copy_dir) final_bundle, zip_path, installer_script_path = publish_standard_build( label=label, @@ -623,6 +704,20 @@ def run_cpu_or_cuda_build(config: BuildConfig) -> None: logger.info("安装包脚本: %s", installer_script_path) +def run_lite_build(config: BuildConfig) -> None: + bootstrap_python = Path(sys.executable) + build_python = ensure_cpu_environment(bootstrap_python) + + clean_build_outputs() + _, final_bundle, zip_path, installer_script_path = build_single_target(config, "lite", build_python) + logger.info("[========================================]") + logger.info("Lite 构建完成") + logger.info("[========================================]") + logger.info("可执行文件: %s", final_bundle / f"{APP_NAME}.exe") + logger.info("压缩文件: %s", zip_path if zip_path else "(已跳过)") + logger.info("安装包脚本: %s", installer_script_path) + + def run_cuda_patch_build(config: BuildConfig) -> None: bootstrap_python = Path(sys.executable) cpu_python = ensure_cpu_environment(bootstrap_python) @@ -633,7 +728,7 @@ def run_cuda_patch_build(config: BuildConfig) -> None: cuda_python = ensure_cuda_environment(bootstrap_python) cuda_paths = get_build_paths("cuda") - build_bundle(cuda_python, cuda_paths) + build_bundle(cuda_python, cuda_paths, spec_file_for("cuda")) patch_dir = prepare_patch_directory(cpu_paths.bundle_dir, cuda_paths.bundle_dir, config) patch_installer_script = prepare_patch_installer_staging(patch_dir, config) @@ -644,15 +739,15 @@ def run_cuda_patch_build(config: BuildConfig) -> None: config.commit_hash, ) create_zip_archive(patch_dir, patch_zip) - logger.info("[成功] 已创建 CUDA 补丁 ZIP 压缩包: %s", patch_zip) + log_verbose("[成功] 已创建 CUDA 补丁 ZIP 压缩包: %s", patch_zip) else: patch_zip = None - logger.info("[信息] 跳过 CUDA 补丁 ZIP 压缩包创建 (--no-zip)") + log_verbose("[信息] 跳过 CUDA 补丁 ZIP 压缩包创建 (--no-zip)") log_step("步骤 7: 清理 CUDA 中间产物") remove_path(cuda_paths.work_dir) remove_path(cuda_paths.dist_dir) - logger.info("[成功] 已清理 CUDA 中间目录") + log_verbose("[成功] 已清理 CUDA 中间目录") logger.info("[========================================]") logger.info("CUDA Patch 构建完成") @@ -666,12 +761,23 @@ def run_cuda_patch_build(config: BuildConfig) -> None: def create_config(args: argparse.Namespace) -> BuildConfig: + """根据命令行参数创建构建配置。版本号自动添加 'v' 前缀。 + + Create build config from CLI arguments. Version is auto-prefixed with 'v'. + + 参数 / Parameters: + args (argparse.Namespace): 解析后的命令行参数 + + 返回 / Return: + BuildConfig: 构建配置对象 + """ + raw_version = args.version or read_app_version() return BuildConfig( build_type=args.build_type, copy_dir=Path(args.copy_dir).resolve() if args.copy_dir else None, no_zip=args.no_zip, debug=args.debug, - app_version=args.version or read_app_version(), + app_version=normalize_version(raw_version), commit_hash=get_commit_hash(), ) @@ -681,16 +787,17 @@ def main() -> None: configure_logging(args.debug) config = create_config(args) - ensure_spec_file() + ensure_spec_file(config.build_type) ensure_inno_templates() - import os as _os - _tag = _os.environ.get("RELEASE_TAG", "") - _channel = "nightly" if _tag and "-rc" in _tag.lower() else "official" - backup_path = inject_build_info(config.commit_hash, _channel) + release_tag = os.environ.get("RELEASE_TAG", "") + release_channel = "nightly" if release_tag and "-rc" in release_tag.lower() else "official" + backup_path = inject_build_info(config.commit_hash, release_channel) try: if config.build_type == "cuda-patch": run_cuda_patch_build(config) + elif config.build_type == "lite": + run_lite_build(config) else: run_cpu_or_cuda_build(config) finally: diff --git a/config.py b/config.py index e8c82add..3a932d2e 100644 --- a/config.py +++ b/config.py @@ -1,28 +1,18 @@ """ -SuperPicky 配置管理模块。 -SuperPicky configuration management module. +SuperPicky 配置管理模块 / SuperPicky configuration management module. 本文件负责静态常量、路径约定、轻量运行时覆盖入口与共享懒加载注册器。 This file owns static constants, path conventions, lightweight runtime overrides, and the shared lazy registry. -维护分层: -Maintenance layering: -- `config.py`:公共读取入口与基础配置。 - `config.py`: shared read entry points and foundational configuration. -- `advanced_config.py`:高级持久化配置的默认值、读写、迁移与 UI 对接。 - `advanced_config.py`: defaults, persistence, migration, and UI integration for advanced persistent settings. -- 用户配置文件默认位于 `get_app_config_dir() / "advanced_config.json"`。 - The user config file defaults to `get_app_config_dir() / "advanced_config.json"`. - -文档入口: -Documentation entry points: -- 维护指南:当前文件本身。 - Maintenance guide: this file itself. -- 中英对照配置指南:TODO - 待补正式路径。 - Bilingual configuration guide: TODO - add the final path later. +维护分层 / Maintenance layering: +- `config.py`:公共读取入口与基础配置 / shared read entry points and foundational configuration. +- `advanced_config.py`:高级持久化配置的默认值、读写、迁移与 UI 对接 / defaults, persistence, migration, and UI integration for advanced persistent settings. +- 用户配置文件默认位于 `get_app_config_dir() / "advanced_config.json"` / The user config file defaults to `get_app_config_dir() / "advanced_config.json"`. """ import json +import importlib +import logging import os import platform import sys @@ -31,35 +21,162 @@ from pathlib import Path from typing import Any, Callable, Dict, List, Optional -import torch +logger = logging.getLogger(__name__) +# Torch is intentionally imported lazily. +# macOS Lite frozen builds bundle Torch inside the app, and importing it at +# module-load time makes it easier to bind to a wrong partial path before the +# frozen runtime is fully settled. +torch = None -# ========================= -# 基础路径工具 -# ========================= -# 这一层只定义路径约定,不负责真实配置读写。 -# This layer only defines path conventions and does not implement actual config read/write behavior. +class _FallbackDevice: + def __init__(self, device_type: str): + self.type = device_type + + def __str__(self) -> str: + return self.type + + +def _get_torch_module(): + """Lazily (re)load torch so lightweight init can install it at runtime.""" + global torch + if torch is not None: + return torch + try: + torch = importlib.import_module("torch") + except Exception: + torch = None + return torch + + +def get_app_install_dir() -> Path: + """ + 返回应用安装根目录 / Return the application install root. + + Windows Lite 打包场景下,运行时、模型和数据库必须固定落在该目录内。 + In Windows Lite builds, runtime files, models, and databases must stay under this directory. + """ + if getattr(sys, "frozen", False): + executable = Path(sys.executable).resolve() + if sys.platform == "darwin" and executable.parent.name == "MacOS": + return executable.parents[2] + return executable.parent + return Path(__file__).resolve().parent + + +def get_runtime_meipass() -> Optional[str]: + """ + 返回 PyInstaller 注入的 `_MEIPASS` 路径字符串。 + Return the `_MEIPASS` path string injected by PyInstaller. + + 这是运行时动态属性,静态类型检查器并不知道它一定存在, + 所以所有调用方都应通过此函数统一访问,而不是直接读取 `sys._MEIPASS`。 + This is a runtime-only dynamic attribute that static type checkers do not + know about, so callers should go through this helper instead of touching + `sys._MEIPASS` directly. + """ + meipass = getattr(sys, "_MEIPASS", None) + if isinstance(meipass, str) and meipass: + return meipass + return None + + +def get_runtime_app_root() -> Optional[str]: + """ + 返回补丁覆盖层记录的真实应用根目录字符串。 + Return the real application root string recorded for the patch overlay. + + 在线补丁覆盖层会优先导入用户目录中的模块,导致 `__file__` 可能指向 + `code_updates/`。这里统一读取主入口注入的真实根目录,避免各模块自行 + 读取 `sys._SUPERPICKY_APP_ROOT` 触发静态告警。 + The patch overlay may cause `__file__` to point at `code_updates/`, so this + helper reads the real app root injected by the main entrypoint and avoids + direct `sys._SUPERPICKY_APP_ROOT` access across modules. + """ + app_root = getattr(sys, "_SUPERPICKY_APP_ROOT", None) + if isinstance(app_root, str) and app_root: + return app_root + return None + + +def set_runtime_app_root(app_root: str) -> str: + """ + 写入补丁覆盖层共享的真实应用根目录。 + Persist the real application root shared by the patch overlay. + + 这里使用 `setattr` 写入运行时动态属性,既保留现有打包/补丁行为, + 也避免直接赋值 `sys._SUPERPICKY_APP_ROOT` 触发 Pylance 属性告警。 + This helper uses `setattr` to preserve the existing runtime contract while + avoiding direct `sys._SUPERPICKY_APP_ROOT` assignments that trip Pylance. + """ + setattr(sys, "_SUPERPICKY_APP_ROOT", app_root) + return app_root + + +def get_bundled_resource_dir() -> Path: + """返回静态打包资源根目录 / Return the root directory for bundled static resources.""" + if getattr(sys, "frozen", False): + if sys.platform == "darwin": + executable = Path(sys.executable).resolve() + if executable.parent.name == "MacOS": + return executable.parents[1] / "Resources" + meipass = get_runtime_meipass() + if meipass is not None: + return Path(meipass) + app_root = get_runtime_app_root() + if app_root is not None: + return Path(app_root) + return get_app_install_dir() + + +def get_app_internal_dir() -> Path: + """ + 返回应用内部运行目录 / Return the application internal runtime directory. + + Windows one-dir 打包产物使用安装目录下的 `_internal/`。 + Other environments fall back to the bundled resource directory. + """ + if getattr(sys, "frozen", False) and sys.platform == "win32": + return get_app_install_dir() / "_internal" + return get_bundled_resource_dir() + + +def get_install_scoped_resource_path( + relative_path: str, *, packaged_relative_path: Optional[str] = None +) -> Path: + """ + 返回安装目录约束下的资源路径 / Return a resource path constrained to the install directory when required. + + Windows Lite 打包环境下,模型/数据库/运行时等可变资源必须位于安装目录。 + Other environments keep using the bundled resource layout. + """ + if getattr(sys, "frozen", False) and sys.platform == "win32": + target_relative_path = packaged_relative_path or relative_path + return get_app_internal_dir() / target_relative_path + return get_bundled_resource_dir() / relative_path + + +def get_packaged_model_relative_path(relative_path: str) -> str: + """返回 Windows Lite 打包环境下模型的内部相对路径 / Return the packaged relative path for models in Windows Lite builds.""" + normalized = relative_path.replace("\\", "/") + if normalized.startswith("models/"): + return "models/" + normalized.split("/", 1)[1] + return normalized def resource_path(relative_path: str) -> str: """ - 返回打包资源路径,兼容开发环境与 PyInstaller。 - Return a packaged resource path compatible with development mode and PyInstaller. + 返回打包资源路径,兼容开发环境与 PyInstaller / Return a packaged resource path compatible with development mode and PyInstaller. `relative_path` 是资源相对路径,例如 `models/yolo11l-seg.pt`。 - `relative_path` is the resource-relative path, for example `models/yolo11l-seg.pt`. - 这里只用于内置资源定位,不能拿来定位用户配置或用户数据。 This is only for bundled resource lookup and must not be used for user config or user data paths. """ - meipass = getattr(sys, '_MEIPASS', None) - if isinstance(meipass, str): - return os.path.join(meipass, relative_path) - return os.path.join(os.path.abspath('.'), relative_path) + return str(get_bundled_resource_dir() / relative_path) -def get_app_config_dir(app_name: str = 'SuperPicky') -> Path: +def get_app_config_dir(app_name: str = "SuperPicky") -> Path: """ 返回跨平台应用配置目录(存放 advanced_config.json、补丁等程序配置)。 Return the cross-platform application config directory. @@ -71,57 +188,50 @@ def get_app_config_dir(app_name: str = 'SuperPicky') -> Path: 用途:advanced_config.json、code_updates/(补丁目录)等程序级配置。 """ - if sys.platform == 'darwin': - return Path.home() / 'Library' / 'Application Support' / app_name - if sys.platform == 'win32': - return Path.home() / 'AppData' / 'Local' / app_name - return Path.home() / '.config' / app_name + if sys.platform == "darwin": + return Path.home() / "Library" / "Application Support" / app_name + if sys.platform == "win32": + return Path.home() / "AppData" / "Local" / app_name + return Path.home() / ".config" / app_name -def get_app_data_dir(app_name: str = 'SuperPicky') -> Path: +def get_app_data_dir(app_name: str = "SuperPicky") -> Path: """ 返回跨平台用户数据目录(存放 birdid 设置等用户产物)。 Return the cross-platform user data directory. - ⚠️ 与 get_app_config_dir() 完全不同的路径,请勿混用: - 所有平台:~/Documents/SuperPicky_Data/ + ⚠️ 现已统一使用标准配置目录,与 get_app_config_dir() 返回相同路径: + macOS : ~/Library/Application Support/SuperPicky/ + Windows: ~/AppData/Local/SuperPicky/ + Linux : ~/.config/SuperPicky/ 用途:birdid_dock_settings.json 等用户可见的数据文件。 切勿用于存放补丁或程序内部配置(应使用 get_app_config_dir())。 """ - return Path.home() / 'Documents' / f'{app_name}_Data' + return get_app_config_dir(app_name) -def get_patch_dir(app_name: str = 'SuperPicky') -> Path: - """ - 返回在线补丁目录。 - Return the online patch directory. +def get_patch_dir(app_name: str = "SuperPicky") -> Path: + """返回在线补丁目录 / Return the online patch directory.""" + return get_app_config_dir(app_name) / "code_updates" - 补丁目录派生自配置目录。 - The patch directory is derived from the config directory. - """ - return get_app_config_dir(app_name) / 'code_updates' +def get_birdid_settings_path(app_name: str = "SuperPicky") -> Path: + """返回 BirdID Dock 设置文件路径 / Return the BirdID Dock settings file path.""" + return get_app_data_dir(app_name) / "birdid_dock_settings.json" -def get_birdid_settings_path(app_name: str = 'SuperPicky') -> Path: - """ - 返回 BirdID Dock 设置文件路径。 - Return the BirdID Dock settings file path. - 该文件属于用户数据,因此放在 app data 目录下。 - This file belongs to user data, so it lives under the app data directory. +def get_birdname_settings_path(app_name: str = "SuperPicky") -> Path: """ - return get_app_data_dir(app_name) / 'birdid_dock_settings.json' - + 返回 BirdName IOC 设置文件路径 / Return the BirdName IOC settings file path. -# ========================= -# 可覆盖配置(ENV + 配置文件) -# ========================= + 该文件属于全局用户配置,应统一收敛到标准配置目录下的 ioc/ 子目录。 + This file belongs to global user configuration and should live under the standard config directory's ioc/ subdirectory. + """ + settings_dir = get_app_config_dir(app_name) / "ioc" + settings_dir.mkdir(parents=True, exist_ok=True) + return settings_dir / "birdname_settings.ini" -# 这里只读取覆盖值,不定义高级配置 schema。 -# This layer only reads override values and does not define the advanced config schema. -# 优先级:ENV > advanced_config.json > 默认值。 -# Priority: ENV > advanced_config.json > default value. _override_cache: Optional[Dict[str, Any]] = None _override_lock = threading.RLock() @@ -140,13 +250,13 @@ def _load_override_file() -> Dict[str, Any]: if _override_cache is not None: return _override_cache - cfg_path = get_app_config_dir() / 'advanced_config.json' + cfg_path = get_app_config_dir() / "advanced_config.json" if not cfg_path.exists(): _override_cache = {} return _override_cache try: - _override_cache = json.loads(cfg_path.read_text(encoding='utf-8')) + _override_cache = json.loads(cfg_path.read_text(encoding="utf-8")) except Exception: _override_cache = {} return _override_cache if _override_cache is not None else {} @@ -154,8 +264,7 @@ def _load_override_file() -> Dict[str, Any]: def _parse_bool(value: Optional[str], default: bool) -> bool: """ - 把字符串值解析为布尔值。 - Parse a string-like value into a boolean. + 把字符串值解析为布尔值 / Parse a string-like value into a boolean. 支持 `1/true/yes/on` 和 `0/false/no/off`,否则返回默认值。 Supports `1/true/yes/on` and `0/false/no/off`, otherwise returns the default. @@ -163,23 +272,22 @@ def _parse_bool(value: Optional[str], default: bool) -> bool: if value is None: return default norm = str(value).strip().lower() - if norm in {'1', 'true', 'yes', 'on'}: + if norm in {"1", "true", "yes", "on"}: return True - if norm in {'0', 'false', 'no', 'off'}: + if norm in {"0", "false", "no", "off"}: return False return default def _env_or_override(name: str, override_key: Optional[str], default: Any) -> Any: """ - 按 ENV > JSON > 默认值 的优先级读取覆盖值。 - Read an override using the priority order ENV > JSON > default. + 按 ENV > JSON > 默认值 的优先级读取覆盖值 / Read an override using the priority order ENV > JSON > default. 这里不做类型转换,调用方自行转成 `int`、`float` 或 `str`。 No type conversion is done here; callers should convert to `int`, `float`, or `str` themselves. """ env_value = os.getenv(name) - if env_value is not None and str(env_value).strip() != '': + if env_value is not None and str(env_value).strip() != "": return env_value if override_key: @@ -190,217 +298,182 @@ def _env_or_override(name: str, override_key: Optional[str], default: Any) -> An return default -# ========================= -# 静态常量分层 -# ========================= - -# 这些 dataclass 用来按领域收拢常量。 -# These dataclasses group constants by domain. - - @dataclass class FileConfig: """ - 文件处理相关静态配置。 - Static configuration related to file handling. + 文件处理相关静态配置 / Static configuration related to file handling. 这些列表会被 RAW/JPG 分类逻辑直接消费。 These lists are consumed directly by RAW/JPG classification logic. """ - # RAW_EXTENSIONS:被视为 RAW 的扩展名列表。 - # RAW_EXTENSIONS: extensions treated as RAW files. - RAW_EXTENSIONS: List[str] = field(default_factory=lambda: [ - '.nef', '.cr2', '.cr3', '.arw', '.raf', - '.orf', '.rw2', '.pef', '.dng', '.3fr', '.iiq' - ]) - # JPG_EXTENSIONS:被视为 JPG/JPEG 的扩展名列表。 - # JPG_EXTENSIONS: extensions treated as JPG/JPEG files. - JPG_EXTENSIONS: List[str] = field(default_factory=lambda: ['.jpg', '.jpeg']) + RAW_EXTENSIONS: List[str] = field( + default_factory=lambda: [ + ".nef", + ".cr2", + ".cr3", + ".arw", + ".raf", + ".orf", + ".rw2", + ".pef", + ".dng", + ".3fr", + ".iiq", + ] + ) + JPG_EXTENSIONS: List[str] = field(default_factory=lambda: [".jpg", ".jpeg"]) @dataclass class DirectoryConfig: """ - 输出目录与报告文件命名配置。 - Naming configuration for output directories and report files. + 输出目录与报告文件命名配置 / Naming configuration for output directories and report files. 修改这些值会影响结果目录结构与报告文件名。 Changing these values affects result folder layout and report filenames. """ - # 高质量照片目录。 - # Directory for excellent photos. - EXCELLENT_DIR: str = '优秀' - # 普通保留照片目录。 - # Directory for standard keepers. - STANDARD_DIR: str = '标准' - # 无鸟或废片目录。 - # Directory for no-bird or rejected photos. - NO_BIRDS_DIR: str = '没鸟' - # 内部临时目录。 - # Internal temporary directory. - TEMP_DIR: str = '_temp' - # 特定工作流使用的 Redbox 目录。 - # Redbox directory for specific workflows. - REDBOX_DIR: str = 'Redbox' - # 裁切临时目录。 - # Temporary crop directory. - CROP_TEMP_DIR: str = '.crop_temp' - - # 旧算法优秀目录。 - # Old-algorithm excellent directory. - OLD_ALGORITHM_EXCELLENT: str = '老算法优秀' - # 新算法优秀目录。 - # New-algorithm excellent directory. - NEW_ALGORITHM_EXCELLENT: str = '新算法优秀' - # 双算法共同优秀目录。 - # Intersection directory for both algorithms. - BOTH_ALGORITHMS_EXCELLENT: str = '双算法优秀' - # 算法差异目录。 - # Directory for algorithm-difference samples. - ALGORITHM_DIFF_DIR: str = '算法差异' - - # 处理日志文件名。 - # Processing log filename. - LOG_FILE: str = '.process_log.txt' - # 主报告 SQLite 文件名。 - # Primary SQLite report filename. - REPORT_FILE: str = '.report.db' - # 算法对比 CSV 文件名。 - # Algorithm comparison CSV filename. - COMPARISON_REPORT_FILE: str = '.algorithm_comparison.csv' + EXCELLENT_DIR: str = "优秀" + STANDARD_DIR: str = "标准" + NO_BIRDS_DIR: str = "没鸟" + TEMP_DIR: str = "_temp" + REDBOX_DIR: str = "Redbox" + CROP_TEMP_DIR: str = ".crop_temp" + OLD_ALGORITHM_EXCELLENT: str = "老算法优秀" + NEW_ALGORITHM_EXCELLENT: str = "新算法优秀" + BOTH_ALGORITHMS_EXCELLENT: str = "双算法优秀" + ALGORITHM_DIFF_DIR: str = "算法差异" + LOG_FILE: str = ".process_log.txt" + REPORT_FILE: str = ".report.db" + COMPARISON_REPORT_FILE: str = ".algorithm_comparison.csv" @dataclass class AIConfig: """ - AI 模型与推理相关静态配置。 - Static configuration related to AI models and inference. + AI 模型与推理相关静态配置 / Static configuration related to AI models and inference. 这些值服务于模型定位与基础推理行为,不替代高级用户参数。 These values support model lookup and baseline inference behavior and do not replace advanced user-facing parameters. """ - # 主模型相对路径。 - # Relative path to the main model. - MODEL_FILE: str = 'models/yolo11l-seg.pt' - # “鸟”类别的 class id。 - # Class id for the "bird" category. + MODEL_FILE: str = "models/yolo11l-seg.pt" BIRD_CLASS_ID: int = 14 - # 推理目标尺寸。 - # Target inference image size. TARGET_IMAGE_SIZE: int = 1024 - # 主体居中判断默认阈值。 - # Default threshold for centered-subject checks. CENTER_THRESHOLD: float = 0.15 - # 锐度归一化策略标识,默认不指定。 - # Sharpness normalization strategy marker, unset by default. SHARPNESS_NORMALIZATION: Optional[str] = None def get_model_path(self) -> str: """ - 返回主模型的实际可访问路径。 - Return the actual accessible path to the main model. + 返回主模型的实际可访问路径 / Return the actual accessible path to the main model. 调用方不应自行拼 PyInstaller 临时目录。 Callers should not manually stitch together PyInstaller temporary paths. """ - return resource_path(self.MODEL_FILE) + return str( + get_install_scoped_resource_path( + self.MODEL_FILE, + packaged_relative_path=get_packaged_model_relative_path( + self.MODEL_FILE + ), + ) + ) @dataclass class UIConfig: """ - UI 展示层静态常量。 - Static constants for the UI presentation layer. + UI 展示层静态常量 / Static constants for the UI presentation layer. 这里只放显示缩放和进度边界,不放业务阈值。 This group is for display scaling and progress bounds, not business thresholds. """ - # 置信度百分比缩放。 - # Percentage scale for confidence display. CONFIDENCE_SCALE: float = 100.0 - # 面积显示缩放。 - # Display scale for area values. AREA_SCALE: float = 1000.0 - # 锐度显示缩放。 - # Display scale for sharpness values. SHARPNESS_SCALE: int = 20 - # 进度条最小值。 - # Minimum progress-bar value. PROGRESS_MIN: int = 0 - # 进度条最大值。 - # Maximum progress-bar value. PROGRESS_MAX: int = 100 - # 默认提示音次数。 - # Default completion beep count. BEEP_COUNT: int = 3 @dataclass class CSVConfig: """ - CSV 报告结构配置。 - CSV report structure configuration. + CSV 报告结构配置 / CSV report structure configuration. `HEADERS` 定义导出列顺序,兼容性要求较高。 `HEADERS` defines export-column order and carries relatively high compatibility requirements. """ - # 报告列名顺序定义。 - # Ordered definition of report header names. - HEADERS: List[str] = field(default_factory=lambda: [ - 'filename', 'found_bird', 'AI score', 'bird_centre_x', - 'bird_centre_y', 'bird_area', 's_bird_area', - 'laplacian_var', 'sobel_var', 'fft_high_freq', 'contrast', - 'edge_density', 'background_complexity', 'motion_blur', - 'normalized_new', 'composite_score', 'result_new', - 'dominant_bool', 'centred_bool', 'sharp_bool', 'class_id' - ]) + HEADERS: List[str] = field( + default_factory=lambda: [ + "filename", + "found_bird", + "AI score", + "bird_centre_x", + "bird_centre_y", + "bird_area", + "s_bird_area", + "laplacian_var", + "sobel_var", + "fft_high_freq", + "contrast", + "edge_density", + "background_complexity", + "motion_blur", + "normalized_new", + "composite_score", + "result_new", + "dominant_bool", + "centred_bool", + "sharp_bool", + "class_id", + ] + ) @dataclass class ServerConfig: """ - BirdID 服务默认配置。 - Default configuration for the BirdID service. + BirdID 服务默认配置 / Default configuration for the BirdID service. 这些值可被 ENV 覆盖,并影响绑定地址、启动等待和健康检查节奏。 These values can be overridden by ENV and affect bind address, startup wait, and health-check cadence. """ - # 默认监听地址。ENV: SUPERPICKY_SERVER_HOST - # Default bind host. ENV: SUPERPICKY_SERVER_HOST - HOST: str = '127.0.0.1' - # 默认端口。ENV: SUPERPICKY_SERVER_PORT - # Default port. ENV: SUPERPICKY_SERVER_PORT + HOST: str = "127.0.0.1" PORT: int = 5156 - # 单次健康检查超时。ENV: SUPERPICKY_SERVER_HEALTH_TIMEOUT - # Single health-check timeout. ENV: SUPERPICKY_SERVER_HEALTH_TIMEOUT HEALTH_TIMEOUT_SECONDS: float = 2.0 - # 启动就绪总等待时间。ENV: SUPERPICKY_SERVER_STARTUP_WAIT - # Total startup readiness wait time. ENV: SUPERPICKY_SERVER_STARTUP_WAIT STARTUP_WAIT_SECONDS: float = 10.0 - # 健康检查轮询间隔。ENV: SUPERPICKY_SERVER_POLL_INTERVAL - # Health-check polling interval. ENV: SUPERPICKY_SERVER_POLL_INTERVAL POLL_INTERVAL_SECONDS: float = 0.5 @classmethod - def load(cls) -> 'ServerConfig': + def load(cls) -> "ServerConfig": """ - 按覆盖优先级构造 ServerConfig。 - Build ServerConfig using the override priority rules. + 按覆盖优先级构造 ServerConfig / Build ServerConfig using the override priority rules. 类型转换统一在这里做,避免调用方重复解析 ENV。 Type coercion is centralized here so callers do not repeat ENV parsing. """ - host = str(_env_or_override('SUPERPICKY_SERVER_HOST', None, cls.HOST)) - port = int(_env_or_override('SUPERPICKY_SERVER_PORT', None, cls.PORT)) - health_timeout = float(_env_or_override('SUPERPICKY_SERVER_HEALTH_TIMEOUT', None, cls.HEALTH_TIMEOUT_SECONDS)) - startup_wait = float(_env_or_override('SUPERPICKY_SERVER_STARTUP_WAIT', None, cls.STARTUP_WAIT_SECONDS)) - poll = float(_env_or_override('SUPERPICKY_SERVER_POLL_INTERVAL', None, cls.POLL_INTERVAL_SECONDS)) + host = str(_env_or_override("SUPERPICKY_SERVER_HOST", None, cls.HOST)) + port = int(_env_or_override("SUPERPICKY_SERVER_PORT", None, cls.PORT)) + health_timeout = float( + _env_or_override( + "SUPERPICKY_SERVER_HEALTH_TIMEOUT", None, cls.HEALTH_TIMEOUT_SECONDS + ) + ) + startup_wait = float( + _env_or_override( + "SUPERPICKY_SERVER_STARTUP_WAIT", None, cls.STARTUP_WAIT_SECONDS + ) + ) + poll = float( + _env_or_override( + "SUPERPICKY_SERVER_POLL_INTERVAL", None, cls.POLL_INTERVAL_SECONDS + ) + ) return cls( HOST=host, PORT=port, @@ -413,47 +486,50 @@ def load(cls) -> 'ServerConfig': @dataclass class EndpointConfig: """ - 远程服务端点默认配置。 - Default configuration for remote service endpoints. + 远程服务端点默认配置 / Default configuration for remote service endpoints. 这些 URL 会影响下载页、eBird 查询与 Nominatim 反查等网络行为。 These URLs affect network behavior such as download pages, eBird queries, and Nominatim reverse lookups. """ - # 镜像或资源下载基础地址。 - # Base URL for mirrors or downloadable resources. - MIRROR_BASE_URL: str = 'http://1.119.150.179:59080/superpicky' - # 给用户打开的下载页面地址。 - # Download page URL opened for the user. - UPDATE_DOWNLOAD_PAGE: str = 'https://superpicky.jamesphotography.com.au/#download' - # eBird API 根地址。 - # Root URL for the eBird API. - EBIRD_API_BASE: str = 'https://api.ebird.org/v2' - # Nominatim 反向地理编码接口。 - # Reverse geocoding endpoint for Nominatim. - NOMINATIM_REVERSE_URL: str = 'https://nominatim.openstreetmap.org/reverse' + MIRROR_BASE_URL: str = "http://1.119.150.179:59080/superpicky" + UPDATE_DOWNLOAD_PAGE: str = "https://superpicky.jamesphotography.com.au/#download" + EBIRD_API_BASE: str = "https://api.ebird.org/v2" + NOMINATIM_REVERSE_URL: str = "https://nominatim.openstreetmap.org/reverse" @classmethod - def load(cls) -> 'EndpointConfig': + def load(cls) -> "EndpointConfig": """ - 按覆盖优先级构造 EndpointConfig。 - Build EndpointConfig using the override priority rules. + 按覆盖优先级构造 EndpointConfig / Build EndpointConfig using the override priority rules. 统一入口便于未来继续扩展 ENV 覆盖。 A unified entry point makes future ENV override expansion easier. """ return cls( - MIRROR_BASE_URL=str(_env_or_override('SUPERPICKY_MIRROR_BASE_URL', None, cls.MIRROR_BASE_URL)), - UPDATE_DOWNLOAD_PAGE=str(_env_or_override('SUPERPICKY_DOWNLOAD_PAGE', None, cls.UPDATE_DOWNLOAD_PAGE)), - EBIRD_API_BASE=str(_env_or_override('SUPERPICKY_EBIRD_API_BASE', None, cls.EBIRD_API_BASE)), - NOMINATIM_REVERSE_URL=str(_env_or_override('SUPERPICKY_NOMINATIM_REVERSE_URL', None, cls.NOMINATIM_REVERSE_URL)), + MIRROR_BASE_URL=str( + _env_or_override( + "SUPERPICKY_MIRROR_BASE_URL", None, cls.MIRROR_BASE_URL + ) + ), + UPDATE_DOWNLOAD_PAGE=str( + _env_or_override( + "SUPERPICKY_DOWNLOAD_PAGE", None, cls.UPDATE_DOWNLOAD_PAGE + ) + ), + EBIRD_API_BASE=str( + _env_or_override("SUPERPICKY_EBIRD_API_BASE", None, cls.EBIRD_API_BASE) + ), + NOMINATIM_REVERSE_URL=str( + _env_or_override( + "SUPERPICKY_NOMINATIM_REVERSE_URL", None, cls.NOMINATIM_REVERSE_URL + ) + ), ) class Config: """ - 主配置聚合类。 - Main configuration aggregation class. + 主配置聚合类 / Main configuration aggregation class. 这是项目中最常用的统一读取入口,用来整合不同层次的配置。 This is the most common unified read entry point used to aggregate different configuration layers. @@ -461,8 +537,7 @@ class Config: def __init__(self): """ - 构造统一配置对象。 - Construct the unified configuration object. + 构造统一配置对象 / Construct the unified configuration object. 初始化时建立静态常量分组,并加载服务与端点配置。 Initialization builds static config groups and loads service and endpoint configuration. @@ -477,25 +552,23 @@ def __init__(self): def get_directory_names(self) -> Dict[str, str]: """ - 返回常用输出目录名映射。 - Return a mapping of commonly used output directory names. + 返回常用输出目录名映射 / Return a mapping of commonly used output directory names. 适合 UI 展示和流程内统一引用目录名。 Useful for UI display and for consistent directory references inside processing flows. """ return { - 'excellent': self.directory.EXCELLENT_DIR, - 'standard': self.directory.STANDARD_DIR, - 'no_birds': self.directory.NO_BIRDS_DIR, - 'temp': self.directory.TEMP_DIR, - 'redbox': self.directory.REDBOX_DIR, - 'crop_temp': self.directory.CROP_TEMP_DIR, + "excellent": self.directory.EXCELLENT_DIR, + "standard": self.directory.STANDARD_DIR, + "no_birds": self.directory.NO_BIRDS_DIR, + "temp": self.directory.TEMP_DIR, + "redbox": self.directory.REDBOX_DIR, + "crop_temp": self.directory.CROP_TEMP_DIR, } def is_raw_file(self, filename: str) -> bool: """ - 判断文件名是否属于 RAW 扩展名集合。 - Check whether a filename belongs to the RAW extension set. + 判断文件名是否属于 RAW 扩展名集合 / Check whether a filename belongs to the RAW extension set. 这里只按扩展名判断,不检查内容或 MIME。 This only checks file extensions and does not inspect content or MIME type. @@ -505,8 +578,7 @@ def is_raw_file(self, filename: str) -> bool: def is_jpg_file(self, filename: str) -> bool: """ - 判断文件名是否属于 JPG/JPEG 扩展名集合。 - Check whether a filename belongs to the JPG/JPEG extension set. + 判断文件名是否属于 JPG/JPEG 扩展名集合 / Check whether a filename belongs to the JPG/JPEG extension set. 这是轻量判断入口,不做内容探测。 This is a lightweight classification entry and does not inspect file contents. @@ -515,36 +587,25 @@ def is_jpg_file(self, filename: str) -> bool: return ext.lower() in self.file.JPG_EXTENSIONS -# ========================= -# 懒加载资源注册器 -# ========================= - -# 这个注册器用于跨模块共享可缓存、可复用、构造成本高的对象。 -# This registry is for cacheable, reusable, high-construction-cost objects shared across modules. -# 不适合用来存放短生命周期业务状态。 -# It is not suitable for short-lived business state. - _MISSING = object() class LazyRegistry: """ - 线程安全懒加载注册器。 - Thread-safe lazy registry. + 线程安全区加载注册器 / Thread-safe lazy registry. 目标是避免重复初始化重量级对象,并提供统一的共享入口。 Its goal is to avoid repeated heavy initialization and provide a unified sharing entry point. """ def __init__(self): - """初始化内部存储和锁。 Initialize the internal storage and lock.""" + """初始化内部存储和锁 / Initialize the internal storage and lock.""" self._values: Dict[str, Any] = {} self._lock = threading.RLock() def get_or_create(self, key: str, factory: Callable[[], Any]) -> Any: """ - 读取缓存对象,不存在时在锁内创建并缓存。 - Read a cached object and create/cache it inside the lock if it is missing. + 读取缓存对象,不存在时在锁内创建并缓存 / Read a cached object and create/cache it inside the lock if it is missing. 采用无锁快速读取加锁内二次检查,避免并发重复创建。 This uses a fast unlocked read plus a locked second check to avoid duplicate concurrent construction. @@ -560,17 +621,13 @@ def get_or_create(self, key: str, factory: Callable[[], Any]) -> Any: return value def get(self, key: str, default: Any = None) -> Any: - """ - 读取缓存值,不触发创建。 - Read a cached value without triggering creation. - """ + """读取缓存值,不触发创建 / Read a cached value without triggering creation.""" with self._lock: return self._values.get(key, default) def set(self, key: str, value: Any) -> None: """ - 显式设置缓存值。 - Explicitly set a cached value. + 显式设置缓存值 / Explicitly set a cached value. 不要用它存放临时业务状态。 Do not use this to stash temporary business state. @@ -580,8 +637,7 @@ def set(self, key: str, value: Any) -> None: def clear(self, key: str) -> None: """ - 清除单个缓存项。 - Clear a single cached item. + 清除单个缓存项 / Clear a single cached item. 适合测试隔离或强制下次重建。 Useful for test isolation or forcing the next access to rebuild the object. @@ -591,8 +647,7 @@ def clear(self, key: str) -> None: def clear_all(self) -> None: """ - 清空所有缓存项。 - Clear all cached items. + 清空所有缓存项 / Clear all cached items. 若缓存对象持有外部资源,应先确保存在显式关闭流程。 If cached objects hold external resources, make sure an explicit shutdown flow exists first. @@ -606,8 +661,7 @@ def clear_all(self) -> None: def get_lazy_registry() -> LazyRegistry: """ - 返回全局懒加载注册器实例。 - Return the global lazy-registry instance. + 返回全局懒加载注册器实例 / Return the global lazy-registry instance. 调用方应通过此入口获取共享注册器,避免自行新建导致缓存割裂。 Callers should use this entry point to get the shared registry and avoid cache fragmentation caused by creating their own registries. @@ -615,20 +669,9 @@ def get_lazy_registry() -> LazyRegistry: return _lazy_registry -# ========================= -# 设备选择 -# ========================= - -# 设备选择逻辑必须集中,避免不同模块各自判断 CUDA/MPS/CPU。 -# Device selection must stay centralized so different modules do not each make their own CUDA/MPS/CPU decisions. -# 若打包版与源码行为不同,优先怀疑打包环境中的 Torch/CUDA 运行时差异。 -# If packaged behavior differs from source behavior, suspect Torch/CUDA runtime differences in the packaged environment first. - - def get_best_device(): """ - 返回当前环境下最合适的 Torch 设备对象。 - Return the most appropriate Torch device object for the current environment. + 返回当前环境下最合适的 Torch 设备对象 / Return the most appropriate Torch device object for the current environment. 顺序为:macOS 先 MPS 再 CPU,其他平台先 CUDA 再 CPU。 The order is: on macOS use MPS then CPU, on other platforms use CUDA then CPU. @@ -637,19 +680,157 @@ def get_best_device(): On any detection failure, conservatively fall back to CPU. """ try: + torch_module = _get_torch_module() + if torch_module is None: + return _FallbackDevice("cpu") system = platform.system() - if system == 'Darwin': - if torch.backends.mps.is_available(): - return torch.device('mps') - return torch.device('cpu') - - if torch.cuda.is_available(): - return torch.device('cuda') - return torch.device('cpu') + if system == "Darwin": + if torch_module.backends.mps.is_available(): + return torch_module.device("mps") + return torch_module.device("cpu") + + if torch_module.cuda.is_available(): + return torch_module.device("cuda") + return torch_module.device("cpu") except Exception: - return torch.device('cpu') + torch_module = _get_torch_module() + return ( + torch_module.device("cpu") + if torch_module is not None + else _FallbackDevice("cpu") + ) + + +def migrate_old_data() -> bool: + """ + 迁移旧路径数据到新路径 / Migrate old path data to new path. + + 检测 ~/Documents/SuperPicky_Data 目录是否存在数据, + 如果存在则迁移到 get_app_config_dir() 返回的标准配置目录。 + + Returns: + bool: 迁移是否成功(如果没有旧数据也返回 True) + """ + try: + old_data_dir = Path.home() / "Documents" / "SuperPicky_Data" + new_data_dir = get_app_config_dir() + + if not old_data_dir.exists() or not old_data_dir.is_dir(): + return True + + files = list(old_data_dir.iterdir()) + if not files: + return True + + logger.info(f"检测到旧数据目录: {old_data_dir}") + logger.info(f"开始迁移到新目录: {new_data_dir}") + + new_data_dir.mkdir(parents=True, exist_ok=True) + + copied_files = [] + for file_path in files: + try: + dest_path = new_data_dir / file_path.name + if file_path.is_file(): + import shutil + + shutil.copy2(file_path, dest_path) + copied_files.append(file_path.name) + elif file_path.is_dir(): + import shutil + + shutil.copytree(file_path, dest_path, dirs_exist_ok=True) + copied_files.append(file_path.name) + except Exception as e: + logger.error(f"复制文件失败 {file_path.name}: {e}") + return False + + logger.info(f"成功迁移 {len(copied_files)} 个文件/目录") + + for file_name in copied_files: + try: + old_path = old_data_dir / file_name + if old_path.exists(): + if old_path.is_file(): + old_path.unlink() + elif old_path.is_dir(): + import shutil + + shutil.rmtree(old_path) + except Exception as e: + logger.warning(f"删除旧文件失败 {file_name}: {e}") + + try: + if old_data_dir.exists() and old_data_dir.is_dir(): + import shutil + + shutil.rmtree(old_data_dir) + logger.info(f"已删除旧数据目录: {old_data_dir}") + except Exception as e: + logger.warning(f"删除旧目录失败: {e}") + + logger.info("数据迁移完成") + return True + + except Exception as e: + logger.error(f"数据迁移失败: {e}") + return False + + +def migrate_legacy_ioc_settings(app_name: str = "SuperPicky") -> bool: + """ + 迁移旧的用户主目录 IOC 设置到标准配置目录。 + Migrate legacy IOC settings from the user home directory to the standard config directory. + + 仅处理 ~/.superpicky/ioc/birdname_settings.ini 这类全局配置残留, + 不涉及照片目录中的 .superpicky 工作文件。 + """ + try: + import shutil + + old_settings_path = ( + Path.home() / ".superpicky" / "ioc" / "birdname_settings.ini" + ) + new_settings_path = get_birdname_settings_path(app_name) + + if not old_settings_path.exists() or not old_settings_path.is_file(): + return True + + if new_settings_path.exists(): + logger.info(f"检测到新的 IOC 配置已存在,保留新路径: {new_settings_path}") + return True + + shutil.copy2(old_settings_path, new_settings_path) + logger.info(f"已迁移 IOC 配置: {old_settings_path} -> {new_settings_path}") + + try: + old_settings_path.unlink() + except Exception as e: + logger.warning(f"删除旧 IOC 配置失败: {e}") + return True + + old_ioc_dir = old_settings_path.parent + old_superpicky_dir = old_ioc_dir.parent + try: + if ( + old_ioc_dir.exists() + and old_ioc_dir.is_dir() + and not any(old_ioc_dir.iterdir()) + ): + old_ioc_dir.rmdir() + if ( + old_superpicky_dir.exists() + and old_superpicky_dir.is_dir() + and not any(old_superpicky_dir.iterdir()) + ): + old_superpicky_dir.rmdir() + except Exception as e: + logger.warning(f"清理旧 IOC 目录失败: {e}") + + return True + except Exception as e: + logger.error(f"IOC 配置迁移失败: {e}") + return False -# 全局配置实例,供多数模块直接 import 使用。 -# Global configuration instance intended for direct import by most modules. config = Config() diff --git a/constants.py b/constants.py index 1e647543..48fd83e7 100644 --- a/constants.py +++ b/constants.py @@ -1,77 +1,77 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -SuperPicky 常量定义 -统一管理全局常量,避免重复定义 -""" - -# 应用版本号 -# 应用版本号 -APP_VERSION = "4.2.5" - - -# 评分对应的文件夹名称映射(向后兼容,默认中文) -RATING_FOLDER_NAMES = { - 3: "3星_优选", - 2: "2星_良好", - 1: "1星_普通", - 0: "0星_放弃", - -1: "0星_放弃" # 无鸟照片也放入0星目录 -} - -# 英文文件夹名称 -RATING_FOLDER_NAMES_EN = { - 3: "3star_excellent", - 2: "2star_good", - 1: "1star_average", - 0: "0star_reject", - -1: "0star_reject" -} - -def get_rating_folder_names(): - """ - 获取当前语言的评分文件夹名称映射 - - Returns: - dict: {评分: 文件夹名称} - """ - try: - from tools.i18n import get_i18n - i18n = get_i18n() - if i18n.current_lang.startswith('en'): - return RATING_FOLDER_NAMES_EN.copy() - except Exception: - pass - return RATING_FOLDER_NAMES.copy() - -def get_rating_folder_name(rating: int) -> str: - """ - 获取指定评分的文件夹名称(根据当前语言) - - Args: - rating: 评分 (-1 to 3) - - Returns: - str: 文件夹名称 - """ - folders = get_rating_folder_names() - return folders.get(rating, folders.get(0, "0star_reject")) - -# 支持的 RAW 文件扩展名(小写) -RAW_EXTENSIONS = ['.nef', '.cr2', '.cr3', '.arw', '.raf', '.orf', '.rw2', '.pef', '.dng', '.3fr', '.iiq'] - -# 支持的 HEIF 文件扩展名(小写)- Sony HIF / Apple HEIC 等 -HEIF_EXTENSIONS = ['.hif', '.heif', '.heic'] - -# 支持的 JPG 文件扩展名(小写) -JPG_EXTENSIONS = ['.jpg', '.jpeg'] - -# 所有支持的图片扩展名(用于文件查找,包含大小写) -IMAGE_EXTENSIONS = ( - [ext.lower() for ext in RAW_EXTENSIONS] + - [ext.upper() for ext in RAW_EXTENSIONS] + - [ext.lower() for ext in HEIF_EXTENSIONS] + - [ext.upper() for ext in HEIF_EXTENSIONS] + - [ext.lower() for ext in JPG_EXTENSIONS] + - [ext.upper() for ext in JPG_EXTENSIONS] -) +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +SuperPicky 常量定义 +统一管理全局常量,避免重复定义 +""" + +# 应用版本号 +# 应用版本号 +APP_VERSION = "4.2.5" + + +# 评分对应的文件夹名称映射(向后兼容,默认中文) +RATING_FOLDER_NAMES = { + 3: "3星_优选", + 2: "2星_良好", + 1: "1星_普通", + 0: "0星_放弃", + -1: "0星_放弃" # 无鸟照片也放入0星目录 +} + +# 英文文件夹名称 +RATING_FOLDER_NAMES_EN = { + 3: "3star_excellent", + 2: "2star_good", + 1: "1star_average", + 0: "0star_reject", + -1: "0star_reject" +} + +def get_rating_folder_names(): + """ + 获取当前语言的评分文件夹名称映射 + + Returns: + dict: {评分: 文件夹名称} + """ + try: + from tools.i18n import get_i18n + i18n = get_i18n() + if i18n.current_lang.startswith('en'): + return RATING_FOLDER_NAMES_EN.copy() + except Exception: + pass + return RATING_FOLDER_NAMES.copy() + +def get_rating_folder_name(rating: int) -> str: + """ + 获取指定评分的文件夹名称(根据当前语言) + + Args: + rating: 评分 (-1 to 3) + + Returns: + str: 文件夹名称 + """ + folders = get_rating_folder_names() + return folders.get(rating, folders.get(0, "0star_reject")) + +# 支持的 RAW 文件扩展名(小写) +RAW_EXTENSIONS = ['.nef', '.cr2', '.cr3', '.arw', '.raf', '.orf', '.rw2', '.pef', '.dng', '.3fr', '.iiq'] + +# 支持的 HEIF 文件扩展名(小写)- Sony HIF / Apple HEIC 等 +HEIF_EXTENSIONS = ['.hif', '.heif', '.heic'] + +# 支持的 JPG 文件扩展名(小写) +JPG_EXTENSIONS = ['.jpg', '.jpeg'] + +# 所有支持的图片扩展名(用于文件查找,包含大小写) +IMAGE_EXTENSIONS = ( + [ext.lower() for ext in RAW_EXTENSIONS] + + [ext.upper() for ext in RAW_EXTENSIONS] + + [ext.lower() for ext in HEIF_EXTENSIONS] + + [ext.upper() for ext in HEIF_EXTENSIONS] + + [ext.lower() for ext in JPG_EXTENSIONS] + + [ext.upper() for ext in JPG_EXTENSIONS] +) diff --git a/core/batch_processor.py b/core/batch_processor.py index 880130c0..439e7123 100644 --- a/core/batch_processor.py +++ b/core/batch_processor.py @@ -9,10 +9,10 @@ import os import json import time -from typing import List, Dict, Optional, Callable +from typing import Callable, Dict, List, Optional, Sequence, Union from dataclasses import dataclass, field -from core.recursive_scanner import scan_recursive, is_processed, count_photos +from core.recursive_scanner import DEFAULT_SCAN_MAX_DEPTH, ScannedDirectory, count_photos, is_processed, scan_directories @dataclass @@ -39,7 +39,7 @@ def __init__( root_dir: str, settings, # ProcessingSettings skip_existing: bool = False, - max_depth: int = 10, + max_depth: int = DEFAULT_SCAN_MAX_DEPTH, log_fn: Optional[Callable[[str], None]] = None, ): self.root_dir = os.path.abspath(root_dir) @@ -48,13 +48,13 @@ def __init__( self.max_depth = max_depth self.log = log_fn or print - def scan(self) -> List[str]: - """扫描并返回待处理的原子目录列表""" - return scan_recursive(self.root_dir, self.max_depth) + def scan(self) -> List[ScannedDirectory]: + """扫描并返回待处理的原子目录摘要列表""" + return scan_directories(self.root_dir, self.max_depth) def process( self, - dirs: List[str], + dirs: Sequence[Union[str, ScannedDirectory]], organize_files: bool = True, cleanup_temp: bool = True, ) -> BatchResult: @@ -71,16 +71,30 @@ def process( """ from core.photo_processor import PhotoProcessor, ProcessingCallbacks - result = BatchResult(total_dirs=len(dirs)) + normalized_dirs: List[ScannedDirectory] = [] + for entry in dirs: + if isinstance(entry, ScannedDirectory): + normalized_dirs.append(entry) + continue + normalized_dirs.append( + ScannedDirectory( + path=entry, + depth=-1, + photo_count=count_photos(entry), + ) + ) + + result = BatchResult(total_dirs=len(normalized_dirs)) batch_start = time.time() - for i, dir_path in enumerate(dirs, 1): + for i, scanned_dir in enumerate(normalized_dirs, 1): + dir_path = scanned_dir.path dir_name = os.path.relpath(dir_path, self.root_dir) - photo_count = count_photos(dir_path) + photo_count = scanned_dir.photo_count # 增量跳过 if self.skip_existing and is_processed(dir_path): - self.log(f"\n⏭️ [{i}/{len(dirs)}] 跳过已处理: {dir_name} ({photo_count} 张)") + self.log(f"\n⏭️ [{i}/{len(normalized_dirs)}] 跳过已处理: {dir_name} ({photo_count} 张)") result.skipped_dirs += 1 result.dir_results.append({ 'dir': dir_name, @@ -90,7 +104,7 @@ def process( continue self.log(f"\n{'━' * 60}") - self.log(f"📂 [{i}/{len(dirs)}] 处理: {dir_name} ({photo_count} 张)") + self.log(f"📂 [{i}/{len(normalized_dirs)}] 处理: {dir_name} ({photo_count} 张)") self.log(f"{'━' * 60}") dir_start = time.time() diff --git a/core/build_info.py b/core/build_info.py index 7eb1dfea..e1fab1b2 100644 --- a/core/build_info.py +++ b/core/build_info.py @@ -1,15 +1,15 @@ # -*- coding: utf-8 -*- """ 构建信息 -此文件在发布构建时由 build_release.sh 自动修改,以注入 git commit hash +此文件在发布构建时由 Python 构建脚本自动修改,以注入 git commit hash 和 release channel """ # 在打包时会被替换为实际的 commit hash -COMMIT_HASH = "be2f41a3" +COMMIT_HASH = "6f2049e" # 发布渠道:CI 打包时自动注入("nightly" = RC 预发布,"official" = 正式版) # 本地开发默认 "dev",不触发更新检查 -RELEASE_CHANNEL = "dev" +RELEASE_CHANNEL = "official" diff --git a/core/flight_detector.py b/core/flight_detector.py index 1079e470..f3f17f4d 100644 --- a/core/flight_detector.py +++ b/core/flight_detector.py @@ -1,69 +1,77 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -Flight Detector - 飞版检测模块 -使用 EfficientNet-B3 模型检测鸟类是否处于飞行状态 +Flight Detector - 飞版检测模块。 +Flight Detector module. -V3.4 新增功能 +使用 EfficientNet-B3 模型检测鸟类是否处于飞行。 +Uses an EfficientNet-B3 model to determine whether a bird is in flight. """ from pathlib import Path from dataclasses import dataclass -from typing import Optional, Union +from typing import Optional, Union, cast import numpy as np import torch import torch.nn as nn from torchvision import models, transforms from PIL import Image -from config import get_best_device +from config import ( + get_best_device, + get_install_scoped_resource_path, + get_packaged_model_relative_path, + get_runtime_app_root, + get_runtime_meipass, +) @dataclass class FlightResult: """飞版检测结果""" - is_flying: bool # 是否飞行 - confidence: float # 置信度 (0-1) + is_flying: bool + confidence: float class FlightDetector: """ 飞版检测器 - - 使用 EfficientNet-B3 二分类模型判断鸟类是否处于飞行状态。 - 模型训练自 superFlier 项目,使用 RMSprop + last_block freeze 策略。 + 使用 EfficientNet-B3 二分类模型判断鸟类是否处于飞行状态 """ - - # 模型配置 - IMAGE_SIZE = 384 # 训练时的输入尺寸 - THRESHOLD = 0.5 # 默认分类阈值 + + IMAGE_SIZE = 384 + THRESHOLD = 0.5 def __init__(self, model_path: Optional[str] = None): """ 初始化检测器 - + Args: model_path: 模型文件路径,如果为 None 则使用默认路径 """ - self.model = None - self.device = None + self.model: Optional[nn.Module] = None + self.device: Optional[torch.device] = None self.model_loaded = False - - # 确定模型路径(支持 PyInstaller 打包) + if model_path is None: import sys - if hasattr(sys, '_MEIPASS'): - # PyInstaller 打包后的路径 - self.model_path = Path(sys._MEIPASS) / "models" / "superFlier_efficientnet.pth" + if getattr(sys, 'frozen', False) and sys.platform == 'win32': + self.model_path = get_install_scoped_resource_path( + "models/superFlier_efficientnet.pth", + packaged_relative_path=get_packaged_model_relative_path("models/superFlier_efficientnet.pth"), + ) else: - # 开发环境:优先使用 main.py 注入的真实 app 根目录(补丁覆盖层兼容) - project_root = Path(getattr(sys, '_SUPERPICKY_APP_ROOT', - str(Path(__file__).parent.parent))) - self.model_path = project_root / "models" / "superFlier_efficientnet.pth" + meipass = get_runtime_meipass() + if meipass is not None: + self.model_path = Path(meipass) / "models" / "superFlier_efficientnet.pth" + else: + project_root = get_runtime_app_root() + if project_root is None: + project_root = str(Path(__file__).parent.parent) + self.model_path = Path(project_root) / "models" / "superFlier_efficientnet.pth" else: self.model_path = Path(model_path) - - # 图像预处理(与训练时一致) + self.transform = transforms.Compose([ transforms.Resize((self.IMAGE_SIZE, self.IMAGE_SIZE)), transforms.ToTensor(), @@ -76,39 +84,38 @@ def __init__(self, model_path: Optional[str] = None): def _build_model(self) -> nn.Module: """ 构建 EfficientNet-B3 模型结构 - - 必须与训练时的结构完全一致: - - 使用 Dropout(0.2) - - 输出层为 Linear(in_features, 1) + Sigmoid """ - model = models.efficientnet_b3(weights=None) # 不需要预训练权重 - in_features = model.classifier[1].in_features - - # 替换分类头(与 grid_search.py 中的 DROPOUT=0.2 一致) - model.classifier = nn.Sequential( + model = cast(nn.Module, models.efficientnet_b3(weights=None)) + classifier = cast(nn.Sequential, getattr(model, "classifier")) + classifier_linear = cast(nn.Linear, classifier[1]) + in_features = classifier_linear.in_features + + setattr( + model, + "classifier", + nn.Sequential( nn.Dropout(0.2), nn.Linear(in_features, 1), nn.Sigmoid() + ), ) - + return model def load_model(self) -> None: """ 加载模型权重 - + Raises: FileNotFoundError: 模型文件不存在 RuntimeError: 模型加载失败 """ if not self.model_path.exists(): raise FileNotFoundError(f"飞版检测模型未找到: {self.model_path}") - - self.device = get_best_device() - - # 构建并加载模型 + + self.device = torch.device(str(get_best_device())) self.model = self._build_model() - + try: state_dict = torch.load( self.model_path, @@ -118,44 +125,38 @@ def load_model(self) -> None: self.model.load_state_dict(state_dict) except Exception as e: raise RuntimeError(f"加载飞版检测模型失败: {e}") - - self.model.to(self.device) + + self.model.to(device=self.device) self.model.eval() self.model_loaded = True def detect( - self, + self, image: Union[np.ndarray, Image.Image, str], - threshold: float = None + threshold: Optional[float] = None ) -> FlightResult: """ 检测图像中的鸟是否处于飞行状态 - + Args: - image: 输入图像,支持以下格式: - - numpy.ndarray (BGR 或 RGB,由 OpenCV 或其他库读取) - - PIL.Image - - str (图像文件路径) + image: 输入图像,支持 numpy.ndarray、PIL.Image 或文件路径 threshold: 分类阈值,默认使用 self.THRESHOLD (0.5) - + Returns: FlightResult: 包含 is_flying 和 confidence - + Raises: RuntimeError: 模型未加载 """ if not self.model_loaded: raise RuntimeError("飞版检测模型未加载,请先调用 load_model()") - + if threshold is None: threshold = self.THRESHOLD - - # 处理不同输入类型 + if isinstance(image, str): - # 文件路径 pil_image = Image.open(image).convert('RGB') elif isinstance(image, np.ndarray): - # numpy 数组(假设是 BGR,需要转换) import cv2 if len(image.shape) == 3 and image.shape[2] == 3: rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) @@ -166,14 +167,17 @@ def detect( pil_image = image.convert('RGB') else: raise ValueError(f"不支持的图像类型: {type(image)}") - - # 预处理 - image_tensor = self.transform(pil_image).unsqueeze(0).to(self.device) - - # 推理 + + transformed_tensor = cast(torch.Tensor, self.transform(pil_image)) + image_tensor = transformed_tensor.unsqueeze(0).to(self.device) + + if self.model is None: + raise RuntimeError("飞版检测模型尚未初始化") + with torch.no_grad(): prob = self.model(image_tensor).item() - + del image_tensor + return FlightResult( is_flying=prob > threshold, confidence=prob @@ -182,33 +186,32 @@ def detect( def detect_batch( self, images: list, - threshold: float = None, + threshold: Optional[float] = None, batch_size: int = 8 ) -> list: """ 批量检测多张图像 - + Args: images: 图像列表(支持混合类型) threshold: 分类阈值 batch_size: 批处理大小 - + Returns: list[FlightResult]: 检测结果列表 """ if not self.model_loaded: raise RuntimeError("飞版检测模型未加载,请先调用 load_model()") - + if threshold is None: threshold = self.THRESHOLD - + results = [] - - # 分批处理 + for i in range(0, len(images), batch_size): batch = images[i:i + batch_size] batch_tensors = [] - + for img in batch: if isinstance(img, str): pil_image = Image.open(img).convert('RGB') @@ -220,40 +223,41 @@ def detect_batch( pil_image = img.convert('RGB') else: continue - - batch_tensors.append(self.transform(pil_image)) - + + batch_tensors.append(cast(torch.Tensor, self.transform(pil_image))) + if not batch_tensors: continue - - # 组合为批次 + + if self.device is None: + raise RuntimeError("飞版检测设备尚未初始化") batch_tensor = torch.stack(batch_tensors).to(self.device) - - # 推理 + + if self.model is None: + raise RuntimeError("飞版检测模型尚未初始化") + model = self.model with torch.no_grad(): - probs = self.model(batch_tensor).squeeze().cpu().numpy() - - # 处理单个元素的情况 + probs = model(batch_tensor).squeeze().cpu().numpy() # type: ignore + if probs.ndim == 0: probs = [probs.item()] - + for prob in probs: results.append(FlightResult( is_flying=prob > threshold, confidence=float(prob) )) - + return results -# 全局单例(延迟初始化) _flight_detector_instance: Optional[FlightDetector] = None def get_flight_detector() -> FlightDetector: """ 获取全局飞版检测器实例(单例模式) - + Returns: FlightDetector: 全局检测器实例 """ diff --git a/core/initialization_manager.py b/core/initialization_manager.py new file mode 100644 index 00000000..a43e9937 --- /dev/null +++ b/core/initialization_manager.py @@ -0,0 +1,1436 @@ +# -*- coding: utf-8 -*- +""" +First-run initialization manager for lightweight builds. + +The old first-run onboarding path is intentionally preserved elsewhere for +full-package compatibility. This manager only takes over when runtime or +required resources are missing. + +轻量级构建的首次运行初始化管理器。 + +旧的首次运行引导路径在其他地方保留,以实现完整包兼容性。 +此管理器仅在运行时或所需资源缺失时接管。 +""" + +from __future__ import annotations + +import importlib +import importlib.util +import logging +import os +import shutil +import subprocess +import sys +import tempfile +import threading +import time +import urllib.parse +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, Optional + +from PySide6.QtCore import QObject, Signal + +from advanced_config import get_advanced_config +from config import ( + get_app_config_dir, + get_app_internal_dir, + get_bundled_resource_dir, +) +from core.initialization_progress import ( + InitializationProgressEvent, + PROGRESS_KIND_DOWNLOAD, + PROGRESS_KIND_RUNTIME, + STAGE_DOWNLOADING, + STAGE_PREPARING_RUNTIME, + parse_pip_raw_progress_line, +) +from core.runtime_requirements import RuntimeRequirements, get_runtime_requirements +from core.source_probe import pick_best_source, probe_sources +from scripts.download_models import ( + download_resource, + resolve_download_plan, + resolve_resource_destination_dir, +) + +logging.basicConfig(level=logging.INFO) + + +PIPY_SOURCES = [ + {"name": "cernet", "url": "https://mirrors.cernet.edu.cn/pypi/web/simple"}, + {"name": "official", "url": "https://pypi.org/simple"}, +] + +FULL_FEATURE_SET = ("core_detection", "quality", "keypoint", "flight", "birdid") + +STAGE_NOT_STARTED = "not_started" +STAGE_PROBING = "probing_sources" +STAGE_CHECKING_UPDATES = "checking_updates" +STAGE_PREPARING_RUNTIME = "preparing_runtime" +STAGE_DOWNLOADING = "downloading_resources" +STAGE_VERIFYING = "verifying" +STAGE_READY = "ready" +STAGE_FAILED = "failed" + + +class InitializationInterrupted(RuntimeError): + """用户主动中断初始化 / User-requested initialization interruption.""" + + +@dataclass +class RuntimeSelection: + variant: str + detected_cuda_capable: bool + reason: str + + +@dataclass(frozen=True) +class RuntimeInstallLocation: + key: str + runtime_dir: Path + free_bytes: Optional[int] + writable: bool + + +@dataclass +class ResourceProgressState: + """ + Aggregate state for one resource inside the download phase. + + 下载阶段中单个资源的聚合状态。 + """ + + ratio: float = 0.0 + bytes_done: int | None = None + bytes_total: int | None = None + is_terminal: bool = False + last_logged_bucket: int = -1 + last_logged_message: str | None = None + last_logged_source: str | None = None + + +class InitializationManager(QObject): + """ + Coordinate runtime repair, resource preparation, and structured progress events. + + 负责协调运行时修复、资源准备以及结构化进度事件的初始化管理器。 + """ + + stage_changed = Signal(str, str) + progress_event = Signal(object) + progress_changed = Signal(int, str, int, int) + item_status_changed = Signal(str, str, str) + finished = Signal(bool, object) + + def __init__(self, parent=None): + """ + 初始化初始化管理器。 + + Initialize the initialization manager. + + 参数 Parameters: + parent: 父 QObject 对象 + """ + super().__init__(parent) + self.config = get_advanced_config() + self._thread: Optional[threading.Thread] = None + self._last_options: Optional[dict] = None + self._last_mode: str = "init" + self._project_root = self._resolve_project_root() + self._runtime_dir = self.resolve_runtime_dir( + self.config.runtime_install_location_preference + ) + self._source_map: Dict[str, str] = {} + self._cancel_requested = threading.Event() + self._active_process: Optional[subprocess.Popen[str]] = None + self._resource_progress: dict[str, ResourceProgressState] = {} + self._resource_progress_item_count = 0 + + self._ensure_hf_endpoint_configured() + logging.info("初始化管理器已创建,项目根目录: %s", self._project_root) + + def _resolve_project_root(self) -> Path: + """ + 解析项目根目录。 + + Resolve project root directory. + + 返回 Returns: + Path: 项目根目录路径 + """ + if getattr(sys, "frozen", False) and sys.platform == "win32": + return get_app_internal_dir() + return Path(__file__).resolve().parent.parent + + def _ensure_hf_endpoint_configured(self) -> None: + """ + 确保 Hugging Face 端点环境变量已正确设置。 + + Ensure Hugging Face endpoint environment variables are properly configured. + """ + hf_mirror_endpoint = "https://hf-mirror.com" + + if ( + "HF_ENDPOINT" not in os.environ + or os.environ["HF_ENDPOINT"] != hf_mirror_endpoint + ): + os.environ["HF_ENDPOINT"] = hf_mirror_endpoint + logging.info("已设置 HF_ENDPOINT = %s", hf_mirror_endpoint) + + env_vars = { + "HF_HUB_DISABLE_TELEMETRY": "1", + "HF_HUB_DISABLE_XET": "1", + "DO_NOT_TRACK": "1", + } + + for key, value in env_vars.items(): + if key not in os.environ or os.environ[key] != value: + os.environ[key] = value + logging.debug("已设置 %s = %s", key, value) + + def _resolve_runtime_requirements_path(self, runtime_variant: str) -> Path: + """Resolve runtime requirements file path for backward compatibility.""" + requirements = get_runtime_requirements(runtime_variant) # pyright: ignore[reportArgumentType] + requirements_content = requirements.to_requirements_string( + include_indexes=False, + package_urls=self._selected_torch_package_urls(runtime_variant), + ) + + temp_file = tempfile.NamedTemporaryFile( + mode="w", + suffix=".txt", + prefix=f"requirements_{runtime_variant}_", + delete=False, + encoding="utf-8", + ) + try: + temp_file.write(requirements_content) + temp_file.close() + return Path(temp_file.name) + except Exception: + temp_file.close() + Path(temp_file.name).unlink(missing_ok=True) + raise + + def _runtime_requirements(self, runtime_variant: str) -> RuntimeRequirements: + """ + Return the unified runtime requirement definition for one variant. + + 返回指定运行时变体的统一依赖定义。 + """ + return get_runtime_requirements(runtime_variant) # pyright: ignore[reportArgumentType] + + def _selected_torch_package_urls(self, runtime_variant: str) -> dict[str, str]: + """ + Build direct wheel references for Torch packages on Windows runtime installs. + + 为 Windows 运行时安装构建 Torch 系列包的直链引用。 + """ + if runtime_variant not in ("cpu", "cuda"): + return {} + primary_source = self._source_map.get("torch_primary", "").strip() + if not primary_source or sys.platform != "win32": + return {} + + requirements = self._runtime_requirements(runtime_variant) + python_tag = f"cp{sys.version_info.major}{sys.version_info.minor}" + abi_tag = python_tag + platform_tag = "win_amd64" + source_base = primary_source.rstrip("/") + + package_versions = { + "torch": requirements.torch_version, + "torchvision": requirements.torchvision_version, + "torchaudio": requirements.torchaudio_version, + } + selected_urls: dict[str, str] = {} + for package_name, version in package_versions.items(): + normalized_version = (version or "").strip() + if not normalized_version: + continue + filename = ( + f"{package_name}-{normalized_version}-{python_tag}-{abi_tag}-{platform_tag}.whl" + ) + quoted_filename = urllib.parse.quote(filename) + selected_urls[package_name] = ( + f"{package_name} @ {source_base}/{quoted_filename}" + ) + return selected_urls + + @staticmethod + def _torch_source_candidates(runtime_variant: str) -> list[dict[str, str]]: + """ + Build Torch wheel source candidates from the shared runtime requirements. + + 基于统一运行时依赖定义构建 Torch wheel 源候选列表。 + """ + requirements = get_runtime_requirements(runtime_variant) # pyright: ignore[reportArgumentType] + candidates: list[dict[str, str]] = [] + for index, url in enumerate(requirements.extra_index_urls): + lowered = url.lower() + if "mirror" in lowered or "nju" in lowered: + name = f"mirror-{index}" + elif "download.pytorch.org" in lowered: + name = f"official-{index}" + else: + name = f"torch-{index}" + candidates.append({"name": name, "url": url}) + return candidates + + @staticmethod + def _normalize_features(selected_features: Optional[Iterable[str]]) -> list[str]: + features = list(selected_features or FULL_FEATURE_SET) + if "core_detection" not in features: + features.insert(0, "core_detection") + return features + + def _save_config(self, **updates) -> None: + setters = { + "initialization_completed": self.config.set_initialization_completed, + "initialization_in_progress": self.config.set_initialization_in_progress, + "selected_runtime_variant": self.config.set_selected_runtime_variant, + "detected_cuda_capable": self.config.set_detected_cuda_capable, + "runtime_install_location_preference": self.config.set_runtime_install_location_preference, + "resolved_runtime_dir": self.config.set_resolved_runtime_dir, + "enabled_feature_set": self.config.set_enabled_feature_set, + "downloaded_resources": self.config.set_downloaded_resources, + "resolved_source_map": self.config.set_resolved_source_map, + "last_init_error": self.config.set_last_init_error, + "last_init_exit_reason": self.config.set_last_init_exit_reason, + "last_init_mode": self.config.set_last_init_mode, + "is_first_run": self.config.set_is_first_run, + } + for key, value in updates.items(): + setter = setters.get(key) + if setter is not None: + setter(value) + self.config.save() + + def _emit_item_status(self, resource_id: str, status: str, detail: str) -> None: + self.item_status_changed.emit(resource_id, status, detail) + + def _emit_progress_event(self, event: InitializationProgressEvent) -> None: + """ + Emit the new structured progress event and the deprecated legacy signal. + + 发出新的结构化进度事件,并兼容发出旧版信号。 + """ + self.progress_event.emit(event) + ratio = event.normalized_ratio() + if ratio is None: + return + self.progress_changed.emit( + int(round(ratio * 100.0)), + event.message, + event.item_index or 0, + event.item_count or 0, + ) + + def _emit_phase_completion( + self, + progress_kind: str, + message: str, + *, + bytes_done: int | None = None, + bytes_total: int | None = None, + ) -> None: + """ + Emit an explicit terminal progress event for one visual phase. + + 为单个视觉阶段发出显式终态进度事件。 + """ + stage = ( + STAGE_PREPARING_RUNTIME + if progress_kind == PROGRESS_KIND_RUNTIME + else STAGE_DOWNLOADING + ) + self._emit_progress_event( + InitializationProgressEvent( + stage=stage, + progress_kind=progress_kind, + message=message, + ratio=1.0, + bytes_done=bytes_done, + bytes_total=bytes_total, + is_terminal=True, + ) + ) + + def _installation_root(self) -> Path: + if getattr(sys, "frozen", False): + executable = Path(sys.executable).resolve() + if sys.platform == "darwin" and executable.parent.name == "MacOS": + return executable.parents[2] + return executable.parent + return self._project_root + + def _runtime_install_locations(self) -> dict[str, Path]: + install_runtime_dir = self._installation_root() / "runtime_env" + if self._requires_install_local_runtime(): + install_runtime_dir = get_app_internal_dir() / "runtime_env" + return { + "default": get_app_config_dir() / "runtime_env", + "install": install_runtime_dir, + } + + def _requires_install_local_runtime(self) -> bool: + return getattr(sys, "frozen", False) and sys.platform == "win32" + + def _uses_bundled_runtime(self) -> bool: + return getattr(sys, "frozen", False) and sys.platform == "darwin" + + @staticmethod + def _existing_probe_path(path: Path) -> Path: + current = path + while not current.exists() and current != current.parent: + current = current.parent + return current + + def _free_bytes_for_path(self, path: Path) -> Optional[int]: + probe_path = self._existing_probe_path(path) + try: + return shutil.disk_usage(probe_path).free + except Exception: + return None + + def _writable_probe_dir(self, path: Path) -> Path: + probe_root = path if path.exists() else path.parent + probe_dir = self._existing_probe_path(probe_root) + return probe_dir if probe_dir.is_dir() else probe_dir.parent + + def _is_runtime_dir_writable(self, path: Path) -> bool: + try: + probe_dir = self._writable_probe_dir(path) + with tempfile.TemporaryDirectory(dir=probe_dir, prefix="sp_runtime_probe_"): + pass + return True + except Exception: + return False + + def get_runtime_install_location_options(self) -> list[RuntimeInstallLocation]: + if self._requires_install_local_runtime(): + install_dir = self._runtime_install_locations()["install"] + return [ + RuntimeInstallLocation( + key="install", + runtime_dir=install_dir, + free_bytes=self._free_bytes_for_path(install_dir), + writable=self._is_runtime_dir_writable(install_dir), + ) + ] + + options = [] + for key, runtime_dir in self._runtime_install_locations().items(): + options.append( + RuntimeInstallLocation( + key=key, + runtime_dir=runtime_dir, + free_bytes=self._free_bytes_for_path(runtime_dir), + writable=self._is_runtime_dir_writable(runtime_dir), + ) + ) + return options + + def choose_runtime_install_location( + self, preferred_key: Optional[str] = None + ) -> RuntimeInstallLocation: + if self._requires_install_local_runtime(): + install_dir = self._runtime_install_locations()["install"] + return RuntimeInstallLocation( + "install", + install_dir, + self._free_bytes_for_path(install_dir), + self._is_runtime_dir_writable(install_dir), + ) + + options = [ + item + for item in self.get_runtime_install_location_options() + if item.writable + ] + if not options: + default_dir = self._runtime_install_locations()["default"] + return RuntimeInstallLocation("default", default_dir, None, True) + + by_key = {item.key: item for item in options} + if preferred_key in by_key: + return by_key[preferred_key] + + comparable = [item for item in options if item.free_bytes is not None] + if comparable: + return max( + comparable, + key=lambda item: (item.free_bytes or -1, item.key == "default"), + ) + return by_key.get("default", options[0]) + + def resolve_runtime_dir(self, preferred_key: Optional[str] = None) -> Path: + return self.choose_runtime_install_location(preferred_key).runtime_dir + + def runtime_display_dir(self, preferred_key: Optional[str] = None) -> Path: + if self._uses_bundled_runtime(): + return get_bundled_resource_dir() + if self._requires_install_local_runtime(): + return get_app_internal_dir() / "runtime_env" + return self.resolve_runtime_dir(preferred_key) + + def start(self, options: dict, mode: str = "init") -> None: + normalized_options = dict(options) + normalized_options["features"] = self._normalize_features( + normalized_options.get("features") + ) + normalized_options["runtime_install_location"] = ( + self.choose_runtime_install_location( + normalized_options.get("runtime_install_location") + or self.config.runtime_install_location_preference + ).key + ) + self._last_options = normalized_options + self._last_mode = mode + if self._thread and self._thread.is_alive(): + return + self._cancel_requested.clear() + self._thread = threading.Thread( + target=self._run, args=(dict(normalized_options), mode), daemon=True + ) + self._thread.start() + + def start_initialization(self, options: dict) -> None: + self.start(options, mode="init") + + def start_repair(self, options: dict) -> None: + self.start(options, mode="repair") + + def retry_failed(self) -> None: + if self._last_options is not None: + self.start(self._last_options, mode=self._last_mode) + + def resume_pending(self) -> None: + if self._last_options is not None: + self.start(self._last_options, mode=self._last_mode) + + def cancel(self) -> None: + self._cancel_requested.set() + process = self._active_process + if process is not None and process.poll() is None: + try: + process.terminate() + except Exception: + pass + + def is_ready_for_main_ui( + self, selected_features: Optional[Iterable[str]] = None + ) -> bool: + return self._has_runtime_available() and self._resources_available( + selected_features + ) + + def needs_initialization( + self, selected_features: Optional[Iterable[str]] = None + ) -> bool: + return not self.is_ready_for_main_ui(selected_features) + + def check_runtime_health(self) -> bool: + """ + 检查运行时健康状态。 + + Check runtime health status. + + 返回 Returns: + bool: 运行时是否健康 + """ + runtime_available = self._has_runtime_available() + import_ok = self._runtime_import_ok() + + logging.info("运行时健康检查: 可用=%s, 导入=%s", runtime_available, import_ok) + + return runtime_available and import_ok + + def check_resource_health( + self, selected_features: Optional[Iterable[str]] + ) -> Dict[str, bool]: + """ + 检查资源健康状态。 + + Check resource health status. + + 参数 Parameters: + selected_features (Optional[Iterable[str]]): 选定的功能特性 + + 返回 Returns: + Dict[str, bool]: 资源 ID 到健康状态的映射 + """ + plan = resolve_download_plan(self._normalize_features(selected_features)) + health_status = { + item["resource_id"]: self._resource_item_available(item) for item in plan + } + + healthy_count = sum(1 for status in health_status.values() if status) + logging.info("资源健康检查: %d/%d 资源可用", healthy_count, len(health_status)) + + return health_status + + def repair_runtime_if_needed(self, runtime_variant: str) -> bool: + """ + 如果需要,修复运行时环境。 + + Repair runtime environment if needed. + + 参数 Parameters: + runtime_variant (str): 运行时变体(cpu/cuda/mac) + + 返回 Returns: + bool: 是否执行了修复 + """ + if self.check_runtime_health(): + self._emit_item_status("runtime", "done", "Runtime already healthy") + self._emit_phase_completion( + PROGRESS_KIND_RUNTIME, + f"{runtime_variant} runtime already available", + ) + logging.info("运行时环境健康,无需修复") + return False + + if self._uses_bundled_runtime(): + raise RuntimeError( + "Bundled macOS Lite Torch runtime is unavailable; rebuild the app bundle." + ) + + logging.info("运行时环境需要修复,开始准备 %s 运行时...", runtime_variant) + self._emit_stage( + STAGE_PREPARING_RUNTIME, f"Preparing {runtime_variant} runtime..." + ) + self._cleanup_partial_runtime() + self._purge_pip_cache_if_needed() + + start_time = time.perf_counter() + try: + self._prepare_runtime(runtime_variant) + self._emit_phase_completion( + PROGRESS_KIND_RUNTIME, + f"{runtime_variant} runtime ready", + ) + elapsed = time.perf_counter() - start_time + logging.info("运行时环境修复完成,耗时 %.2f 秒", elapsed) + return True + except Exception as exc: + elapsed = time.perf_counter() - start_time + logging.error("运行时环境修复失败,耗时 %.2f 秒: %s", elapsed, exc) + raise + + def repair_resources_if_needed( + self, selected_features: Optional[Iterable[str]] + ) -> bool: + """ + 如果需要,修复资源文件。 + + Repair resource files if needed. + + 参数 Parameters: + selected_features (Optional[Iterable[str]]): 选定的功能特性 + + 返回 Returns: + bool: 是否执行了修复 + """ + plan = resolve_download_plan(self._normalize_features(selected_features)) + pending = [item for item in plan if not self._resource_item_available(item)] + total_items = max(1, len(pending)) + + if not pending: + self._emit_item_status("resources", "done", "Resources already healthy") + logging.info("所有资源已就绪,无需修复") + return False + + self._resource_progress = { + item["resource_id"]: ResourceProgressState() for item in pending + } + self._resource_progress_item_count = total_items + logging.info("需要修复 %d 个资源文件", len(pending)) + self._emit_stage(STAGE_DOWNLOADING, "Downloading required resources...") + + start_time = time.perf_counter() + success_count = 0 + + for index, resource in enumerate(pending, start=1): + label = resource["filename"] + resource_id = resource["resource_id"] + + logging.info("开始下载资源 [%d/%d]: %s", index, total_items, label) + + self._emit_item_status(resource_id, "running", f"Preparing {label}") + + try: + download_resource( + resource, + project_root=self._project_root, + progress_cb=self._resource_progress_cb(index, total_items, resource_id), + ) + success_count += 1 + self._emit_item_status(resource_id, "done", f"{label} ready") + logging.info("资源 [%s] 下载成功", resource_id) + except Exception as exc: + logging.error("资源 [%s] 下载失败: %s", resource_id, exc) + self._emit_item_status(resource_id, "error", f"{label} failed: {exc}") + raise + + elapsed = time.perf_counter() - start_time + logging.info( + "资源修复完成: %d/%d 成功,总耗时 %.2f 秒", + success_count, + total_items, + elapsed, + ) + self._emit_phase_completion( + PROGRESS_KIND_DOWNLOAD, + "All required resources ready", + ) + + return True + + def detect_runtime_selection( + self, preferred_variant: str = "auto" + ) -> RuntimeSelection: + if sys.platform == "darwin": + if preferred_variant in ("cpu", "mac"): + return RuntimeSelection("mac", False, "macOS runtime") + return RuntimeSelection("mac", False, "macOS runtime") + + detected_cuda = self._detect_cuda_capable() + if preferred_variant == "cuda" and detected_cuda: + return RuntimeSelection("cuda", True, "user requested CUDA") + if preferred_variant == "cuda" and not detected_cuda: + return RuntimeSelection( + "cpu", False, "CUDA unavailable, falling back to CPU" + ) + if preferred_variant == "cpu": + return RuntimeSelection("cpu", detected_cuda, "user requested CPU") + if detected_cuda: + return RuntimeSelection("cuda", True, "detected NVIDIA/CUDA support") + return RuntimeSelection("cpu", False, "default CPU runtime") + + def _run(self, options: dict, mode: str) -> None: + try: + self._raise_if_cancelled() + selected_features = self._normalize_features(options.get("features")) + self._last_mode = mode + runtime_location = self.choose_runtime_install_location( + options.get("runtime_install_location") + or self.config.runtime_install_location_preference + ) + self._runtime_dir = runtime_location.runtime_dir + self._save_config( + initialization_in_progress=(mode == "init"), + last_init_error=None, + last_init_exit_reason="none", + last_init_mode=mode, + runtime_install_location_preference=runtime_location.key, + resolved_runtime_dir=str(runtime_location.runtime_dir), + ) + + runtime_choice = self.detect_runtime_selection( + options.get("runtime_variant", "auto") + ) + if mode == "init": + self._save_config( + selected_runtime_variant=runtime_choice.variant, + detected_cuda_capable=runtime_choice.detected_cuda_capable, + runtime_install_location_preference=runtime_location.key, + resolved_runtime_dir=str(runtime_location.runtime_dir), + enabled_feature_set=selected_features, + ) + + self._raise_if_cancelled() + self._emit_stage(STAGE_PROBING, "Probing download sources...") + self._source_map = self._resolve_best_sources(runtime_choice.variant) + self._emit_item_status( + "source_probe", "done", f"PyPI -> {self._source_map['pypi_primary']}" + ) + if self._source_map.get("torch_primary"): + self._emit_item_status( + "source_probe", "done", f"Torch -> {self._source_map['torch_primary']}" + ) + else: + self._emit_item_status( + "source_probe", "done", "Torch -> bundled runtime" + ) + self._save_config(resolved_source_map=self._source_map) + + if options.get("auto_update_enabled", True): + self._emit_stage(STAGE_CHECKING_UPDATES, "Checking updates...") + self._check_updates_if_enabled() + else: + try: + from tools.patch_manager import safe_clear_patch + + cleared, clear_message = safe_clear_patch() + clear_status = "done" if cleared else "warning" + self._emit_item_status("updates", clear_status, clear_message) + except Exception as exc: + self._emit_item_status( + "updates", "warning", f"Patch cleanup skipped: {exc}" + ) + self._emit_item_status( + "updates", "skipped", "Automatic updates disabled by user" + ) + + self._raise_if_cancelled() + self.repair_runtime_if_needed(runtime_choice.variant) + self._raise_if_cancelled() + self.repair_resources_if_needed(selected_features) + + self._emit_stage(STAGE_VERIFYING, "Verifying resources...") + if not self.is_ready_for_main_ui(selected_features): + raise RuntimeError( + "Initialization completed with missing runtime or resources" + ) + + success_updates: dict[str, object] = { + "initialization_in_progress": False, + "last_init_mode": "none", + } + if mode == "init": + success_updates.update( + initialization_completed=True, + last_init_exit_reason="none", + is_first_run=False, + downloaded_resources={ + item["resource_id"]: True + for item in resolve_download_plan(selected_features) + }, + ) + self._save_config(**success_updates) + final_message = ( + "Initialization completed" + if mode == "init" + else "Environment repair completed" + ) + self._emit_stage(STAGE_READY, final_message) + self.finished.emit( + True, + { + "runtime_variant": runtime_choice.variant, + "source_map": self._source_map, + "mode": mode, + }, + ) + except InitializationInterrupted: + self._cleanup_partial_runtime() + self._purge_pip_cache_if_needed() + self._save_config( + initialization_in_progress=False, + last_init_error=None, + last_init_exit_reason="interrupted", + last_init_mode=mode, + ) + self.finished.emit(False, {"interrupted": True, "mode": mode}) + except Exception as exc: + self._save_config( + initialization_in_progress=False, + last_init_error=str(exc), + last_init_exit_reason="failed", + last_init_mode=mode, + ) + self._emit_stage(STAGE_FAILED, str(exc)) + self.finished.emit(False, {"error": str(exc), "mode": mode}) + + def _resource_progress_cb( + self, + item_index: int, + total_items: int, + fallback_resource_id: str, + ): + """ + Adapt per-resource download events into one aggregated download stream. + + 将单个资源的下载事件适配为统一的聚合下载进度流。 + """ + + def _callback(event: InitializationProgressEvent) -> None: + resource_id = event.resource_id or fallback_resource_id + enriched_event = InitializationProgressEvent( + stage=event.stage, + progress_kind=event.progress_kind, + message=event.message, + ratio=event.ratio, + bytes_done=event.bytes_done, + bytes_total=event.bytes_total, + item_index=item_index - 1, + item_count=total_items, + resource_id=resource_id, + source=event.source, + is_terminal=event.is_terminal, + ) + aggregate_event = self._update_resource_aggregate(enriched_event) + if self._should_emit_resource_log(enriched_event): + self._emit_item_status(resource_id, "progress", enriched_event.message) + self._emit_progress_event(aggregate_event) + + return _callback + + def _update_resource_aggregate( + self, + event: InitializationProgressEvent, + ) -> InitializationProgressEvent: + """ + Fold one resource event into the cross-resource aggregate download ratio. + + 将单个资源事件折叠到跨资源的聚合下载进度中。 + """ + resource_id = event.resource_id or f"resource-{len(self._resource_progress)}" + ratio = event.normalized_ratio() + bytes_total = event.bytes_total if event.bytes_total and event.bytes_total > 0 else None + + state = self._resource_progress.get(resource_id) + if state is None: + state = ResourceProgressState() + self._resource_progress[resource_id] = state + + if ratio is not None: + state.ratio = max(state.ratio, ratio) + elif event.is_terminal: + state.ratio = 1.0 + + if event.bytes_done is not None: + state.bytes_done = max(state.bytes_done or 0, event.bytes_done) + if bytes_total is not None: + state.bytes_total = max(state.bytes_total or 0, bytes_total) + state.is_terminal = state.is_terminal or event.is_terminal or state.ratio >= 1.0 + + total_items = max(1, self._resource_progress_item_count or len(self._resource_progress)) + completed_ratio_sum = sum(item.ratio for item in self._resource_progress.values()) + overall_ratio = completed_ratio_sum / total_items + + known_total = 0 + known_done = 0 + all_have_bytes = bool(self._resource_progress) + for item in self._resource_progress.values(): + if item.bytes_total is None: + all_have_bytes = False + continue + known_total += item.bytes_total + known_done += min(item.bytes_done or 0, item.bytes_total) + + aggregate_terminal = ( + len(self._resource_progress) >= self._resource_progress_item_count + and all(item.is_terminal for item in self._resource_progress.values()) + ) + return InitializationProgressEvent( + stage=STAGE_DOWNLOADING, + progress_kind=event.progress_kind, + message=event.message, + ratio=min(1.0, max(0.0, overall_ratio)), + bytes_done=known_done if all_have_bytes else None, + bytes_total=known_total if all_have_bytes else None, + item_index=event.item_index, + item_count=event.item_count, + resource_id=event.resource_id, + source=event.source, + is_terminal=aggregate_terminal, + ) + + def _should_emit_resource_log(self, event: InitializationProgressEvent) -> bool: + """ + Throttle noisy per-byte download events into human-readable milestone logs. + + 将高频字节级下载事件节流为更适合阅读的里程碑日志。 + """ + resource_id = event.resource_id + if not resource_id: + return False + + state = self._resource_progress.get(resource_id) + if state is None: + return True + + ratio = event.normalized_ratio() + message = event.message + source_changed = bool(event.source and event.source != state.last_logged_source) + terminal_message = event.is_terminal or any( + token in message.lower() + for token in ("failed", "validated", "downloaded", "already present", "copied from local fallback") + ) + + should_log = False + if state.last_logged_message is None: + should_log = True + elif terminal_message or source_changed: + should_log = True + elif ratio is not None: + bucket = min(10, int(ratio * 10.0)) + if bucket > state.last_logged_bucket: + should_log = True + elif message != state.last_logged_message: + should_log = True + + if should_log: + if ratio is not None: + state.last_logged_bucket = max( + state.last_logged_bucket, + min(10, int(ratio * 10.0)), + ) + state.last_logged_message = message + state.last_logged_source = event.source or state.last_logged_source + + return should_log + + def _emit_stage(self, stage: str, message: str) -> None: + self.stage_changed.emit(stage, message) + + def _check_updates_if_enabled(self) -> None: + try: + from tools.update_checker import UpdateChecker + + self._raise_if_cancelled() + checker = UpdateChecker() + checker.check_for_updates() + self._emit_item_status("updates", "done", "Update probe finished") + except Exception as exc: + self._emit_item_status("updates", "warning", f"Update probe skipped: {exc}") + + def _resolve_best_sources(self, runtime_variant: str) -> Dict[str, str]: + pypi_results = probe_sources("pypi", PIPY_SOURCES) + best_pypi = self._pick_preferred_source(pypi_results) + + pypi_primary = best_pypi.url if best_pypi else PIPY_SOURCES[0]["url"] + pypi_fallback = self._resolve_fallback_url(pypi_results, pypi_primary) + + torch_primary = "" + torch_fallback = "" + torch_sources = self._torch_source_candidates(runtime_variant) + if torch_sources: + torch_results = probe_sources(f"torch-{runtime_variant}", torch_sources) + best_torch = self._pick_preferred_source(torch_results) + torch_primary = best_torch.url if best_torch else torch_sources[0]["url"] + torch_fallback = self._resolve_fallback_url(torch_results, torch_primary) + + selected = { + "pypi_primary": pypi_primary, + "pypi_fallback": pypi_fallback, + "torch_primary": torch_primary, + "torch_fallback": torch_fallback, + } + return selected + + @staticmethod + def _pick_preferred_source(results): + successful = [item for item in results if item.ok] + if not successful: + return None + + non_official = [ + item for item in successful if "official" not in item.name.lower() + ] + if non_official: + return pick_best_source(non_official) + return pick_best_source(successful) + + @staticmethod + def _resolve_fallback_url(results, primary_url: str) -> str: + successful = [item for item in results if item.ok and item.url != primary_url] + if not successful: + return primary_url + + non_official = [ + item for item in successful if "official" not in item.name.lower() + ] + if non_official: + fallback = pick_best_source(non_official) + return fallback.url if fallback else primary_url + + fallback = pick_best_source(successful) + return fallback.url if fallback else primary_url + + def _prepare_runtime(self, runtime_variant: str) -> None: + """ + Create or repair the runtime environment for the selected variant. + + 为当前选择的运行时变体创建或修复运行环境。 + """ + if self._uses_bundled_runtime(): + raise RuntimeError( + "Bundled macOS Lite Torch runtime is unavailable; runtime installation is disabled." + ) + + if self._use_packaged_runtime_bootstrap(): + self._prepare_runtime_with_packaged_bootstrap(runtime_variant) + return + + python_cmd = self._resolve_python_command() + if not self._runtime_dir.exists(): + self._run_subprocess( + [*python_cmd, "-m", "venv", str(self._runtime_dir)], + "Create runtime venv", + ) + + pip_executable = ( + self._runtime_dir + / ("Scripts" if os.name == "nt" else "bin") + / ("pip.exe" if os.name == "nt" else "pip") + ) + requirements_file = self._resolve_runtime_requirements_path(runtime_variant) + runtime_requirements = self._runtime_requirements(runtime_variant) + install_cmd = [ + str(pip_executable), + "install", + "--no-cache-dir", + "--progress-bar", + "raw", + "-r", + str(requirements_file), + "-i", + self._source_map["pypi_primary"], + "--extra-index-url", + self._source_map["pypi_fallback"], + ] + if runtime_requirements.extra_index_urls: + install_cmd.extend(["--extra-index-url", self._source_map["torch_primary"]]) + if ( + self._source_map["torch_fallback"] + and self._source_map["torch_fallback"] != self._source_map["torch_primary"] + ): + install_cmd.extend( + ["--extra-index-url", self._source_map["torch_fallback"]] + ) + self._run_subprocess( + install_cmd, + f"Install {runtime_variant} runtime", + progress_stage=STAGE_PREPARING_RUNTIME, + progress_kind=PROGRESS_KIND_RUNTIME, + ) + self._inject_runtime_site_packages() + self._verify_runtime_import(runtime_variant) + + def _use_packaged_runtime_bootstrap(self) -> bool: + return getattr(sys, "frozen", False) and sys.platform == "win32" + + def _runtime_site_packages_candidates(self) -> list[Path]: + candidates: list[Path] = [ + self._runtime_dir / "site-packages", + self._runtime_dir / "Lib" / "site-packages", + ] + lib_dir = self._runtime_dir / "lib" + if lib_dir.exists(): + candidates.extend(sorted(lib_dir.glob("python*/site-packages"))) + version_tag = f"python{sys.version_info.major}.{sys.version_info.minor}" + candidates.append(self._runtime_dir / "lib" / version_tag / "site-packages") + + unique_candidates: list[Path] = [] + seen_paths: set[Path] = set() + for candidate in candidates: + if candidate in seen_paths: + continue + seen_paths.add(candidate) + unique_candidates.append(candidate) + return unique_candidates + + def _runtime_python_executable(self) -> Path: + if os.name == "nt": + return self._runtime_dir / "Scripts" / "python.exe" + return self._runtime_dir / "bin" / "python3" + + def _prepare_runtime_with_packaged_bootstrap(self, runtime_variant: str) -> None: + requirements_file = self._resolve_runtime_requirements_path(runtime_variant) + runtime_site_packages = self._runtime_dir / "site-packages" + runtime_site_packages.mkdir(parents=True, exist_ok=True) + runtime_requirements = self._runtime_requirements(runtime_variant) + + command = [ + str(Path(sys.executable).resolve()), + "--runtime-bootstrap", + "--runtime-dir", + str(self._runtime_dir), + "--requirements", + str(requirements_file), + "--index-url", + self._source_map["pypi_primary"], + "--extra-index-url", + self._source_map["pypi_fallback"], + ] + if runtime_requirements.extra_index_urls: + command.extend(["--extra-index-url", self._source_map["torch_primary"]]) + if ( + self._source_map["torch_fallback"] + and self._source_map["torch_fallback"] != self._source_map["torch_primary"] + ): + command.extend( + ["--extra-index-url", self._source_map["torch_fallback"]] + ) + + self._run_subprocess( + command, + f"Install {runtime_variant} runtime", + progress_stage=STAGE_PREPARING_RUNTIME, + progress_kind=PROGRESS_KIND_RUNTIME, + ) + self._inject_runtime_site_packages() + self._verify_runtime_import(runtime_variant) + + def _run_subprocess( + self, + command: list[str], + label: str, + *, + progress_stage: str | None = None, + progress_kind: str | None = None, + ) -> None: + """ + Run a subprocess while forwarding structured progress updates when possible. + + 运行子进程,并在可能时转发结构化进度更新。 + """ + popen_kwargs = { + "stdout": subprocess.PIPE, + "stderr": subprocess.STDOUT, + "text": True, + "encoding": "utf-8", + "errors": "replace", + } + if os.name == "nt": + popen_kwargs["creationflags"] = getattr(subprocess, "CREATE_NO_WINDOW", 0) + + process = subprocess.Popen(command, **popen_kwargs) + self._active_process = process + latest_progress_bytes: tuple[int, int | None] | None = None + try: + assert process.stdout is not None + for line in process.stdout: + self._raise_if_cancelled() + text = line.strip() + if text: + parsed = parse_pip_raw_progress_line(text) + if ( + parsed is not None + and progress_stage is not None + and progress_kind is not None + ): + current, total = parsed + latest_progress_bytes = (current, total if total > 0 else None) + self._emit_progress_event( + InitializationProgressEvent( + stage=progress_stage, + progress_kind=progress_kind, + message=f"{label}: raw progress {current}/{total}", + ratio=(current / total) if total > 0 else None, + bytes_done=current, + bytes_total=total if total > 0 else None, + ) + ) + continue + self.item_status_changed.emit("runtime", "progress", f"{label}: {text}") + return_code = process.wait() + if self._cancel_requested.is_set(): + raise InitializationInterrupted("Initialization interrupted by user") + if return_code != 0: + raise RuntimeError(f"{label} failed with exit code {return_code}") + if progress_stage is not None and progress_kind is not None: + done_bytes = latest_progress_bytes[0] if latest_progress_bytes else None + total_bytes = latest_progress_bytes[1] if latest_progress_bytes else None + self._emit_progress_event( + InitializationProgressEvent( + stage=progress_stage, + progress_kind=progress_kind, + message=f"{label} completed", + ratio=1.0, + bytes_done=total_bytes or done_bytes, + bytes_total=total_bytes, + is_terminal=True, + ) + ) + finally: + self._active_process = None + + def _resolve_python_command(self) -> list[str]: + if os.environ.get("VIRTUAL_ENV") and shutil.which("python"): + return [shutil.which("python") or "python"] + + candidates = [ + ( + [sys.executable] + if sys.executable and not getattr(sys, "frozen", False) + else None + ), + [shutil.which("python3")] if shutil.which("python3") else None, + [shutil.which("python")] if shutil.which("python") else None, + ["py", "-3"] if shutil.which("py") else None, + ] + for candidate in candidates: + if not candidate: + continue + try: + subprocess.run( + [*candidate, "-c", "import sys; print(sys.executable)"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=True, + text=True, + ) + return candidate + except Exception: + continue + raise RuntimeError("Unable to find a Python interpreter for runtime bootstrap") + + def _inject_runtime_site_packages(self) -> None: + if self._uses_bundled_runtime(): + return + importlib.invalidate_caches() + for candidate in self._runtime_site_packages_candidates(): + if candidate.exists(): + path = str(candidate) + if path not in sys.path: + sys.path.insert(0, path) + + def _verify_runtime_import(self, runtime_variant: str) -> None: + try: + if self._uses_bundled_runtime(): + torch_module = importlib.import_module("torch") + torch_version = getattr(torch_module, "__version__", "unknown") + self._emit_item_status( + "runtime", + "done", + f"Bundled Torch import OK: {torch_version} ({runtime_variant})", + ) + return + importlib.invalidate_caches() + self._inject_runtime_site_packages() + runtime_python = self._runtime_python_executable() + if runtime_python.exists(): + result = subprocess.run( + [ + str(runtime_python), + "-c", + ( + "import torch, sys; " + "print(torch.__version__); " + "print(sys.executable)" + ), + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + encoding="utf-8", + errors="replace", + check=True, + ) + runtime_lines = [line.strip() for line in result.stdout.splitlines() if line.strip()] + torch_version = runtime_lines[0] if runtime_lines else "unknown" + runtime_executable = runtime_lines[1] if len(runtime_lines) > 1 else str(runtime_python) + self._emit_item_status( + "runtime", + "done", + f"Torch import OK: {torch_version} ({runtime_variant}) via {runtime_executable}", + ) + return + torch_module = importlib.import_module("torch") + torch_version = getattr(torch_module, "__version__", "unknown") + self._emit_item_status( + "runtime", + "done", + f"Torch import OK: {torch_version} ({runtime_variant})", + ) + except Exception as exc: + raise RuntimeError( + f"Runtime installed but Torch import failed: {exc}" + ) from exc + + def _runtime_import_ok(self) -> bool: + try: + if self._uses_bundled_runtime(): + importlib.invalidate_caches() + importlib.import_module("torch") + return True + self._inject_runtime_site_packages() + importlib.invalidate_caches() + importlib.import_module("torch") + return True + except Exception: + return False + + def _raise_if_cancelled(self) -> None: + if self._cancel_requested.is_set(): + raise InitializationInterrupted("Initialization interrupted by user") + + def _cleanup_partial_runtime(self) -> None: + if self._uses_bundled_runtime(): + return + runtime_dir = self._runtime_dir + if not runtime_dir.exists(): + return + + removable_paths = [ + runtime_dir / "site-packages", + runtime_dir / "Lib" / "site-packages", + runtime_dir / "runtime_install_manifest.json", + ] + for candidate in removable_paths: + try: + if candidate.is_dir(): + shutil.rmtree(candidate, ignore_errors=True) + else: + candidate.unlink(missing_ok=True) + except Exception as exc: + logging.warning("清理运行时残留失败: %s (%s)", candidate, exc) + + def _pip_cache_roots(self) -> list[Path]: + roots: list[Path] = [] + if sys.platform == "win32": + local_app_data = os.environ.get("LOCALAPPDATA") + if local_app_data: + roots.append(Path(local_app_data) / "pip" / "Cache") + else: + roots.append(Path.home() / ".cache" / "pip") + return roots + + def _purge_pip_cache_if_needed(self) -> None: + for cache_root in self._pip_cache_roots(): + if not cache_root.exists(): + continue + for relative_name in ("http-v2", "http", "wheels", "selfcheck"): + candidate = cache_root / relative_name + try: + if candidate.is_dir(): + shutil.rmtree(candidate, ignore_errors=True) + else: + candidate.unlink(missing_ok=True) + except Exception as exc: + logging.warning("清理 pip 缓存失败: %s (%s)", candidate, exc) + + def _has_runtime_available(self) -> bool: + if self._uses_bundled_runtime(): + importlib.invalidate_caches() + return importlib.util.find_spec("torch") is not None + if importlib.util.find_spec("torch") is not None: + return True + self._inject_runtime_site_packages() + return importlib.util.find_spec("torch") is not None + + def _resources_available(self, selected_features: Optional[Iterable[str]]) -> bool: + features = self._normalize_features(selected_features) + plan = resolve_download_plan(features) + return all( + self._resource_item_available(item) + for item in plan + if item.get("required") or selected_features + ) + + def _resource_item_available(self, item: dict) -> bool: + path = ( + resolve_resource_destination_dir(self._project_root, item) + / item["filename"] + ) + return path.exists() + + def _detect_cuda_capable(self) -> bool: + if sys.platform != "win32": + return False + try: + result = subprocess.run( + ["nvidia-smi", "-L"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + timeout=4, + creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0), + ) + return result.returncode == 0 and bool(result.stdout.strip()) + except Exception: + return False diff --git a/core/initialization_progress.py b/core/initialization_progress.py new file mode 100644 index 00000000..25ceaf7d --- /dev/null +++ b/core/initialization_progress.py @@ -0,0 +1,504 @@ +# -*- coding: utf-8 -*- +""" +Initialization progress models for lightweight onboarding flows. + +This module centralizes structured progress events, stage-to-phase mapping, +and the deterministic progress animation policy shared by the welcome dialog +and the repair dialog. The model keeps UI timing logic out of Qt widgets so it +can be verified with small, repeatable tests. + +轻量化初始化流程的进度模型。 + +此模块集中定义结构化进度事件、阶段到动画阶段的映射关系,以及欢迎向导与修复对话框共享的确定性进度动画策略。 +这样可以把 UI 的时间推进逻辑从 Qt 控件中抽离出来,便于使用小型、可重复的测试进行验证。 +""" + +from __future__ import annotations + +import math +import re +from dataclasses import dataclass + + +PROGRESS_KIND_RUNTIME = "runtime_install" +PROGRESS_KIND_DOWNLOAD = "resource_download" + +STAGE_PROBING = "probing_sources" +STAGE_CHECKING_UPDATES = "checking_updates" +STAGE_PREPARING_RUNTIME = "preparing_runtime" +STAGE_DOWNLOADING = "downloading_resources" +STAGE_VERIFYING = "verifying" +STAGE_READY = "ready" +STAGE_FAILED = "failed" + +PIP_RAW_PROGRESS_PATTERN = re.compile(r"^Progress\s+(?P\d+)\s+of\s+(?P\d+)$") + + +@dataclass(frozen=True) +class InitializationProgressEvent: + """ + Structured progress payload emitted by initialization subsystems. + + 初始化子系统发出的结构化进度负载。 + + Attributes: + stage: Pipeline stage that owns this event. + progress_kind: Logical progress stream, such as runtime install or resource download. + message: Human-readable status text for logs or labels. + ratio: Normalized completion ratio in [0.0, 1.0] when known. + bytes_done: Downloaded or processed bytes when known. + bytes_total: Expected total bytes when known. + item_index: Zero-based item index in a multi-item batch. + item_count: Total item count in a multi-item batch. + resource_id: Resource identifier for per-item tracking. + source: Active mirror or backend source name. + is_terminal: Whether the current item or phase has completed. + """ + + stage: str + progress_kind: str + message: str + ratio: float | None = None + bytes_done: int | None = None + bytes_total: int | None = None + item_index: int | None = None + item_count: int | None = None + resource_id: str | None = None + source: str | None = None + is_terminal: bool = False + + def normalized_ratio(self) -> float | None: + """ + Return a clamped completion ratio when one can be inferred. + + 当可以推导出完成比时,返回一个夹紧后的进度比例。 + """ + if self.ratio is not None: + return max(0.0, min(1.0, float(self.ratio))) + if self.bytes_total and self.bytes_total > 0 and self.bytes_done is not None: + return max(0.0, min(1.0, float(self.bytes_done) / float(self.bytes_total))) + return None + + +@dataclass(frozen=True) +class ProgressSnapshot: + """ + Render-ready progress state returned by the animation model. + + 动画模型返回的可直接用于渲染的进度状态。 + """ + + display_percent: int + actual_percent: int + target_percent: int + display_value: float + actual_value: float + target_value: float + active_phase: str | None + is_finishing: bool + is_settled: bool + suggested_interval_ms: int + + +@dataclass(frozen=True) +class ProgressPhaseProfile: + """ + Visual progress allocation and timing policy for one logical phase. + + 单个逻辑阶段的视觉进度分配与时间策略。 + """ + + key: str + start_percent: float + end_percent: float + min_duration_seconds: float + max_duration_seconds: float + + @property + def span(self) -> float: + """Return the percentage span owned by this phase.""" + return self.end_percent - self.start_percent + + +def phase_from_stage(stage: str) -> str | None: + """ + Map an initialization stage to the owning visual phase. + + 将初始化阶段映射到对应的视觉动画阶段。 + """ + if stage == STAGE_PREPARING_RUNTIME: + return PROGRESS_KIND_RUNTIME + if stage == STAGE_DOWNLOADING: + return PROGRESS_KIND_DOWNLOAD + return None + + +def parse_pip_raw_progress_line(line: str) -> tuple[int, int] | None: + """ + Parse `pip --progress-bar raw` lines into byte counters. + + 解析 `pip --progress-bar raw` 输出行,提取字节级进度。 + """ + match = PIP_RAW_PROGRESS_PATTERN.match(line.strip()) + if not match: + return None + return int(match.group("current")), int(match.group("total")) + + +class InitializationProgressModel: + """ + Deterministic progress animation model for long-running initialization tasks. + + 长耗时初始化任务的确定性进度动画模型。 + + The model uses fixed visual budgets per phase and lets time drive the bar. + Real progress is still tracked for logging and phase completion, but it no + longer directly drives the visible percentage. + + 该模型为每个阶段分配固定视觉预算,并以时间作为进度条主驱动。 + 真实进度仍用于日志与阶段完成判定,但不再直接驱动可见百分比。 + """ + + PHASES: dict[str, ProgressPhaseProfile] = { + PROGRESS_KIND_RUNTIME: ProgressPhaseProfile( + key=PROGRESS_KIND_RUNTIME, + start_percent=0.0, + end_percent=30.0, + min_duration_seconds=2.0, + max_duration_seconds=420.0, + ), + PROGRESS_KIND_DOWNLOAD: ProgressPhaseProfile( + key=PROGRESS_KIND_DOWNLOAD, + start_percent=30.0, + end_percent=99.0, + min_duration_seconds=250.0, + max_duration_seconds=420.0, + ), + } + + def __init__( + self, + *, + seed: int = 17, + min_visible_progress: int = 4, + suggested_interval_ms: int = 80, + ) -> None: + """ + Initialize the animation model with deterministic timing parameters. + + 使用确定性时间参数初始化动画模型。 + """ + self._seed = float(seed) + self._min_visible_progress = min_visible_progress + self._suggested_interval_ms = suggested_interval_ms + self.reset(0.0) + + def reset(self, now: float = 0.0) -> ProgressSnapshot: + """ + Reset the model to its initial idle state. + + 将模型重置为初始空闲状态。 + """ + self._display_percent = 0.0 + self._actual_percent = 0.0 + self._target_percent = 0.0 + self._active_phase: str | None = None + self._phase_started_at = now + self._phase_target_seconds = 0.0 + self._phase_complete_requested = False + self._phase_completion_started_at: float | None = None + self._phase_completion_from = 0.0 + self._phase_completion_duration = 0.0 + self._phase_completion_target = 0.0 + self._phase_was_observed = False + self._finish_started_at: float | None = None + self._finish_duration = 0.0 + self._finish_from = 0.0 + self._settled = False + return self.snapshot() + + def on_stage_changed(self, stage: str, now: float) -> ProgressSnapshot: + """ + React to a stage transition emitted by the initialization manager. + + 响应初始化管理器发出的阶段切换事件。 + """ + phase_key = phase_from_stage(stage) + if phase_key is not None: + self._activate_phase(phase_key, now) + elif stage == STAGE_VERIFYING: + self._actual_percent = max(self._actual_percent, 99.0) + self._target_percent = max(self._target_percent, 99.0) + return self.advance(now) + + def on_progress_event( + self, + event: InitializationProgressEvent, + now: float, + ) -> ProgressSnapshot: + """ + Update actual progress from a structured subsystem event. + + 使用结构化子系统事件更新真实进度。 + """ + phase_key = event.progress_kind if event.progress_kind in self.PHASES else phase_from_stage(event.stage) + if phase_key is not None: + self._activate_phase( + phase_key, + now, + bytes_total=event.bytes_total, + item_count=event.item_count, + event=event, + ) + phase = self.PHASES[phase_key] + ratio = event.normalized_ratio() + if ratio is not None: + actual = phase.start_percent + (phase.span * ratio) + self._actual_percent = max(self._actual_percent, actual) + if event.is_terminal: + self._actual_percent = max(self._actual_percent, phase.end_percent) + return self.advance(now) + + def on_finished(self, success: bool, now: float) -> ProgressSnapshot: + """ + Enter the success settle animation or stop immediately on failure. + + 成功时进入收尾动画,失败时立即停止动画推进。 + """ + if not success: + self._settled = True + return self.snapshot() + + self._actual_percent = max(self._actual_percent, 100.0) + self._target_percent = max(self._target_percent, 100.0) + self._finish_started_at = now + self._finish_from = max(self._display_percent, min(self._target_percent, 99.5)) + remaining = max(0.0, 100.0 - self._finish_from) + self._finish_duration = min(2.0, max(1.0, 1.0 + (remaining / 60.0))) + self._settled = False + return self.advance(now) + + def advance(self, now: float) -> ProgressSnapshot: + """ + Advance the animation to the specified monotonic time. + + 将动画推进到指定的单调时间点。 + """ + if self._finish_started_at is not None: + elapsed = max(0.0, now - self._finish_started_at) + ratio = min(1.0, elapsed / max(0.001, self._finish_duration)) + eased = 1.0 - (1.0 - ratio) ** 3 + self._display_percent = self._finish_from + ((100.0 - self._finish_from) * eased) + if ratio >= 1.0: + self._display_percent = 100.0 + self._settled = True + return self.snapshot() + + if self._active_phase is None: + return self.snapshot() + + if self._phase_completion_started_at is not None: + elapsed = max(0.0, now - self._phase_completion_started_at) + ratio = min(1.0, elapsed / max(0.001, self._phase_completion_duration)) + eased = 1.0 - (1.0 - ratio) ** 3 + target = self._phase_completion_target + self._display_percent = self._phase_completion_from + ( + (target - self._phase_completion_from) * eased + ) + self._target_percent = max(self._target_percent, self._display_percent) + if ratio >= 1.0: + self._display_percent = target + self._target_percent = max(self._target_percent, target) + self._phase_completion_started_at = None + return self.snapshot() + + time_target = self._compute_time_target(now) + desired = time_target + self._target_percent = max(self._target_percent, desired) + self._display_percent = max(self._display_percent, desired) + if self._phase_complete_requested and self._should_begin_phase_completion(now): + self._start_phase_completion(now) + return self.snapshot() + + def snapshot(self) -> ProgressSnapshot: + """ + Return an immutable render snapshot for the current model state. + + 返回当前模型状态的不可变渲染快照。 + """ + display_percent = int(round(self._display_percent)) + actual_percent = int(round(self._actual_percent)) + target_percent = int(round(self._target_percent)) + + if 0 < display_percent < self._min_visible_progress: + display_percent = self._min_visible_progress + + return ProgressSnapshot( + display_percent=max(0, min(100, display_percent)), + actual_percent=max(0, min(100, actual_percent)), + target_percent=max(0, min(100, target_percent)), + display_value=max(0.0, min(100.0, self._display_percent)), + actual_value=max(0.0, min(100.0, self._actual_percent)), + target_value=max(0.0, min(100.0, self._target_percent)), + active_phase=self._active_phase, + is_finishing=self._finish_started_at is not None and not self._settled, + is_settled=self._settled, + suggested_interval_ms=self._suggested_interval_ms, + ) + + def _activate_phase( + self, + phase_key: str, + now: float, + *, + bytes_total: int | None = None, + item_count: int | None = None, + event: InitializationProgressEvent | None = None, + ) -> None: + """ + Enter or retune a visual phase when new information arrives. + + 在新的信息到达时进入或重新调整视觉阶段。 + """ + profile = self.PHASES[phase_key] + is_new_phase = self._active_phase != phase_key + if is_new_phase: + self._active_phase = phase_key + self._phase_started_at = now + self._phase_complete_requested = False + self._phase_completion_started_at = None + self._phase_completion_from = max(self._display_percent, profile.start_percent) + self._phase_completion_duration = 0.0 + self._phase_completion_target = profile.end_percent + self._phase_was_observed = False + self._display_percent = max(self._display_percent, profile.start_percent) + self._actual_percent = max(self._actual_percent, profile.start_percent) + + self._phase_target_seconds = self._choose_target_duration( + profile, + bytes_total=bytes_total, + item_count=item_count, + event=event, + ) + if event is not None: + self._phase_was_observed = self._phase_was_observed or ( + event.bytes_total is not None + or event.bytes_done is not None + or event.normalized_ratio() is not None + ) + if event.is_terminal: + self._phase_complete_requested = True + + def _choose_target_duration( + self, + profile: ProgressPhaseProfile, + *, + bytes_total: int | None, + item_count: int | None, + event: InitializationProgressEvent | None, + ) -> float: + """ + Choose a phase duration inside the configured long-task window. + + 在配置好的长任务窗口内选择当前阶段的目标时长。 + """ + if profile.key == PROGRESS_KIND_RUNTIME: + if event is not None and event.is_terminal and not self._phase_was_observed: + return 4.0 + self._small_duration_jitter(profile.key) + base = 180.0 + self._large_duration_jitter(profile.key) + if bytes_total and bytes_total > 0: + base += min(35.0, bytes_total / float(1024 ** 3) * 20.0) + return min(profile.max_duration_seconds, max(profile.min_duration_seconds, base)) + + base = 300.0 + self._large_duration_jitter(profile.key) + if bytes_total and bytes_total > 0: + size_gib = bytes_total / float(1024 ** 3) + base += min(55.0, size_gib * 45.0) + if item_count and item_count > 1: + base += min(35.0, float(item_count - 1) * 8.0) + return min(profile.max_duration_seconds, max(profile.min_duration_seconds, base)) + + def _compute_time_target(self, now: float) -> float: + """ + Compute the monotonic time-driven target for the active phase. + + 计算当前活动阶段的单调时间驱动目标值。 + """ + assert self._active_phase is not None + profile = self.PHASES[self._active_phase] + elapsed = max(0.0, now - self._phase_started_at) + duration = max(1.0, self._phase_target_seconds) + progress_ratio = min(0.985, elapsed / duration) + jitter = self._bounded_jitter(profile, elapsed, progress_ratio) + raw_ratio = min(0.985, max(0.0, progress_ratio + jitter)) + target = profile.start_percent + (profile.span * raw_ratio) + return max(self._display_percent, target) + + def _bounded_jitter( + self, + profile: ProgressPhaseProfile, + elapsed: float, + progress_ratio: float, + ) -> float: + """ + Return a small deterministic oscillation for natural-looking speed changes. + + 返回一个小幅、确定性的振荡量,用于制造更自然的速度变化。 + """ + damping = max(0.18, 1.0 - progress_ratio) + amplitude = 0.014 * damping + if profile.key == PROGRESS_KIND_DOWNLOAD: + amplitude += 0.004 + + wave = ( + math.sin((elapsed * 0.10) + self._seed) + + 0.45 * math.sin((elapsed * 0.27) + (self._seed * 1.7)) + + 0.2 * math.sin((elapsed * 0.51) + (self._seed * 2.3)) + ) + jitter = (wave / 1.65) * amplitude + return min(0.018, max(-0.014, jitter)) + + def _small_duration_jitter(self, phase_key: str) -> float: + """ + Return a deterministic short-duration jitter in seconds. + + 返回一个确定性的短时长抖动(秒)。 + """ + offset = self._seed + (0.37 if phase_key == PROGRESS_KIND_RUNTIME else 0.91) + return math.sin(offset) * 1.8 + + def _large_duration_jitter(self, phase_key: str) -> float: + """ + Return a deterministic long-duration jitter in seconds. + + 返回一个确定性的长时长抖动(秒)。 + """ + offset = self._seed + (1.13 if phase_key == PROGRESS_KIND_RUNTIME else 2.41) + return math.sin(offset) * 18.0 + + def _should_begin_phase_completion(self, now: float) -> bool: + """ + Decide whether the phase may start its completion settle animation. + + 判断阶段是否可以开始进入完成收敛动画。 + """ + if self._active_phase is None: + return False + elapsed = max(0.0, now - self._phase_started_at) + return elapsed >= self._phase_target_seconds * 0.92 + + def _start_phase_completion(self, now: float) -> None: + """ + Start a short settle animation into the phase end percentage. + + 启动一个短暂的收敛动画,把进度推进到当前阶段的终点百分比。 + """ + if self._active_phase is None: + return + profile = self.PHASES[self._active_phase] + self._phase_completion_started_at = now + self._phase_completion_from = max(self._display_percent, profile.start_percent) + remaining = max(0.0, profile.end_percent - self._phase_completion_from) + self._phase_completion_duration = min(1.8, max(0.8, 0.6 + (remaining / 45.0))) + self._phase_completion_target = profile.end_percent + self._phase_complete_requested = False diff --git a/core/keypoint_detector.py b/core/keypoint_detector.py index ef84065a..d1f4eb9f 100644 --- a/core/keypoint_detector.py +++ b/core/keypoint_detector.py @@ -1,6 +1,10 @@ """ -关键点检测器模块 -使用 CUB-200 训练的 ResNet50 模型检测鸟类关键点(左眼、右眼、喙) +关键点检测器模块。 +Keypoint detector module. + +使用 CUB-200 训练的 ResNet50 模型检测鸟类关键点(左眼、右眼、喙)。 +Uses a CUB-200-trained ResNet50 model to detect bird keypoints (left eye, +right eye, and beak). """ import os @@ -13,26 +17,30 @@ import numpy as np import cv2 from dataclasses import dataclass -from typing import Optional, Tuple -from config import get_best_device +from typing import Any, Optional, Tuple, cast +from config import ( + get_best_device, + get_install_scoped_resource_path, + get_packaged_model_relative_path, + get_runtime_app_root, + get_runtime_meipass, +) @dataclass class KeypointResult: """关键点检测结果""" - left_eye: Tuple[float, float] # (x, y) 归一化坐标 + left_eye: Tuple[float, float] right_eye: Tuple[float, float] beak: Tuple[float, float] - left_eye_vis: float # 可见性概率 0-1 + left_eye_vis: float right_eye_vis: float beak_vis: float - - # 派生属性 - both_eyes_hidden: bool # 双眼是否都不可见(保留兼容) - all_keypoints_hidden: bool # 所有关键点(双眼+鸟喙)都不可见 - best_eye_visibility: float # 双眼中较高的置信度 max(左眼, 右眼) - visible_eye: Optional[str] # 'left', 'right', 'both', None - head_sharpness: float # 头部区域锐度 + both_eyes_hidden: bool + all_keypoints_hidden: bool + best_eye_visibility: float + visible_eye: Optional[str] + head_sharpness: float class PartLocalizer(nn.Module): @@ -40,9 +48,10 @@ class PartLocalizer(nn.Module): def __init__(self, num_parts=3, hidden_dim=512, dropout=0.2): super().__init__() self.num_parts = num_parts - self.backbone = models.resnet50(weights=None) - in_features = self.backbone.fc.in_features - self.backbone.fc = nn.Identity() + self.backbone = cast(nn.Module, models.resnet50(weights=None)) + backbone_fc = cast(nn.Linear, getattr(self.backbone, "fc")) + in_features = backbone_fc.in_features + setattr(self.backbone, "fc", nn.Identity()) self.head = nn.Sequential( nn.Linear(in_features, hidden_dim), @@ -67,26 +76,35 @@ def forward(self, x): class KeypointDetector: """鸟类关键点检测器""" - # 默认配置 IMG_SIZE = 416 - VISIBILITY_THRESHOLD = 0.3 # 至少一个关键点可见性需≥0.3才不算"全部不可见" - RADIUS_MULTIPLIER = 1.2 # 有喙时的半径系数 - NO_BEAK_RADIUS_RATIO = 0.15 # 无喙时用检测框的15% + VISIBILITY_THRESHOLD = 0.3 + RADIUS_MULTIPLIER = 1.2 + NO_BEAK_RADIUS_RATIO = 0.15 @staticmethod def _get_default_model_path() -> str: - """获取默认模型路径(支持 PyInstaller 打包)""" + """ + 获取默认模型路径并兼容冻结环境。 + Resolve the default model path while remaining compatible with frozen builds. + """ import sys - if hasattr(sys, '_MEIPASS'): - # PyInstaller 打包后的路径 - return os.path.join(sys._MEIPASS, 'models', 'cub200_keypoint_resnet50_slim.pth') - else: - # 开发环境:优先使用 main.py 注入的真实 app 根目录(补丁覆盖层兼容) - project_root = getattr(sys, '_SUPERPICKY_APP_ROOT', - os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - return os.path.join(project_root, 'models', 'cub200_keypoint_resnet50_slim.pth') + if getattr(sys, 'frozen', False) and sys.platform == 'win32': + return str( + get_install_scoped_resource_path( + 'models/cub200_keypoint_resnet50_slim.pth', + packaged_relative_path=get_packaged_model_relative_path('models/cub200_keypoint_resnet50_slim.pth'), + ) + ) + meipass = get_runtime_meipass() + if meipass is not None: + return os.path.join(meipass, 'models', 'cub200_keypoint_resnet50_slim.pth') + + project_root = get_runtime_app_root() + if project_root is None: + project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + return os.path.join(project_root, 'models', 'cub200_keypoint_resnet50_slim.pth') - def __init__(self, model_path: str = None): + def __init__(self, model_path: Optional[str] = None): """ 初始化关键点检测器 @@ -94,9 +112,8 @@ def __init__(self, model_path: str = None): model_path: 模型文件路径,默认使用自动检测的路径 """ self.model_path = model_path or self._get_default_model_path() - # 使用统一的设备检测逻辑 - self.device = get_best_device() - self.model = None + self.device = torch.device(str(get_best_device())) + self.model: Optional[PartLocalizer] = None self.transform = transforms.Compose([ transforms.Resize((self.IMG_SIZE, self.IMG_SIZE)), transforms.ToTensor(), @@ -112,27 +129,33 @@ def load_model(self): raise FileNotFoundError(f"关键点模型不存在: {self.model_path}") self.model = PartLocalizer() - checkpoint = torch.load(self.model_path, map_location=self.device, weights_only=True) + checkpoint = torch.load( + self.model_path, + map_location=self.device, + weights_only=True, + ) if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint: self.model.load_state_dict(checkpoint['model_state_dict']) else: self.model.load_state_dict(checkpoint) - self.model.to(self.device) - - # V4.0.5: 启用 FP16 半精度推理,提速约 30-50% - # MPS 和 CUDA 都支持 FP16 + self.model.to(device=self.device) + if self.device.type in ('mps', 'cuda'): self.model = self.model.half() self._use_fp16 = True else: self._use_fp16 = False - + self.model.eval() - def detect(self, bird_crop: np.ndarray, box: Tuple[int, int, int, int] = None, - seg_mask: np.ndarray = None) -> Optional[KeypointResult]: + def detect( + self, + bird_crop: np.ndarray, + box: Optional[Tuple[int, int, int, int]] = None, + seg_mask: Optional[np.ndarray] = None, + ) -> Optional[KeypointResult]: """ 检测鸟类关键点并计算头部锐度 @@ -148,22 +171,24 @@ def detect(self, bird_crop: np.ndarray, box: Tuple[int, int, int, int] = None, if bird_crop is None or bird_crop.size == 0: return None - - # 转换为PIL并进行推理 + pil_crop = Image.fromarray(bird_crop) - tensor = self.transform(pil_crop).unsqueeze(0).to(self.device) - - # V4.0.5: 使用 FP16 和 inference_mode 优化推理 + transformed_tensor = cast(torch.Tensor, self.transform(pil_crop)) + tensor = transformed_tensor.unsqueeze(0).to(self.device) + if hasattr(self, '_use_fp16') and self._use_fp16: tensor = tensor.half() - + + if self.model is None: + raise RuntimeError("关键点检测模型尚未初始化") + with torch.inference_mode(): coords, vis = self.model(tensor) - + coords = coords[0].cpu().numpy() vis = vis[0].cpu().numpy() - - # 解析结果 + del tensor + left_eye = (float(coords[0, 0]), float(coords[0, 1])) right_eye = (float(coords[1, 0]), float(coords[1, 1])) beak = (float(coords[2, 0]), float(coords[2, 1])) @@ -171,15 +196,12 @@ def detect(self, bird_crop: np.ndarray, box: Tuple[int, int, int, int] = None, left_eye_vis = float(vis[0]) right_eye_vis = float(vis[1]) beak_vis = float(vis[2]) - - # 判断可见性 + left_visible = left_eye_vis >= self.VISIBILITY_THRESHOLD right_visible = right_eye_vis >= self.VISIBILITY_THRESHOLD beak_visible = beak_vis >= self.VISIBILITY_THRESHOLD - - # 保留旧属性(兼容) + both_eyes_hidden = not left_visible and not right_visible - # 新逻辑:只有当双眼和鸟喙都不可见时才算"全部不可见" all_keypoints_hidden = not left_visible and not right_visible and not beak_visible if left_visible and right_visible: @@ -190,8 +212,7 @@ def detect(self, bird_crop: np.ndarray, box: Tuple[int, int, int, int] = None, visible_eye = 'right' else: visible_eye = None - - # 计算头部锐度 + head_sharpness = 0.0 if visible_eye is not None: head_sharpness = self._calculate_head_sharpness( @@ -199,10 +220,9 @@ def detect(self, bird_crop: np.ndarray, box: Tuple[int, int, int, int] = None, left_eye_vis, right_eye_vis, beak_visible, box, seg_mask ) - - # V3.8: 计算双眼中较高的置信度,用于评分封顶逻辑 + best_eye_visibility = max(left_eye_vis, right_eye_vis) - + return KeypointResult( left_eye=left_eye, right_eye=right_eye, @@ -226,25 +246,19 @@ def _calculate_head_sharpness( left_eye_vis: float, right_eye_vis: float, beak_visible: bool, - box: Tuple[int, int, int, int] = None, - seg_mask: np.ndarray = None + box: Optional[Tuple[int, int, int, int]] = None, + seg_mask: Optional[np.ndarray] = None ) -> float: """ 计算头部区域锐度 - - 使用眼睛为圆心,眼喙距离×1.2为半径,与seg掩码取交集 """ h, w = bird_crop.shape[:2] - # 如果双眼都不可见(如鸟侧面、头部转向等): - # 模型坐标仍然大致指向头部位置,用置信度较高的那只眼做 fallback - # 沿用与正常流程完全相同的"圆形区域 Sobel"算法,结果 ×0.8 作为惩罚 - # 这样与正常眼睛检测的锐度值在同一量级,不会因用全身 Sobel 而虚高 if left_eye_vis < self.VISIBILITY_THRESHOLD and right_eye_vis < self.VISIBILITY_THRESHOLD: eye = left_eye if left_eye_vis >= right_eye_vis else right_eye eye_px = (int(eye[0] * w), int(eye[1] * h)) beak_px = (int(beak[0] * w), int(beak[1] * h)) - if beak_vis >= self.VISIBILITY_THRESHOLD: + if beak_visible: radius = int(self._distance(eye_px, beak_px) * self.RADIUS_MULTIPLIER) elif box is not None: box_size = max(box[2], box[3]) @@ -258,111 +272,86 @@ def _calculate_head_sharpness( head_mask = cv2.bitwise_and(circle_mask, seg_mask) else: head_mask = circle_mask - LOW_VIS_PENALTY = 0.8 # 眼睛不可见时降分但不误杀 + LOW_VIS_PENALTY = 0.8 return self._calculate_sharpness(bird_crop, head_mask) * LOW_VIS_PENALTY - - # 选择眼睛:用更远离喙的那只眼 if left_eye_vis >= self.VISIBILITY_THRESHOLD and right_eye_vis >= self.VISIBILITY_THRESHOLD: - # 两眼都可见,选更远离喙的 left_dist = self._distance(left_eye, beak) right_dist = self._distance(right_eye, beak) eye = left_eye if left_dist >= right_dist else right_eye elif left_eye_vis >= self.VISIBILITY_THRESHOLD: eye = left_eye else: - # 只有一只眼可见(右眼) eye = right_eye - - # 转换为像素坐标 + eye_px = (int(eye[0] * w), int(eye[1] * h)) beak_px = (int(beak[0] * w), int(beak[1] * h)) - - # 计算半径 + if beak_visible: radius = int(self._distance(eye_px, beak_px) * self.RADIUS_MULTIPLIER) elif box is not None: - # 无喙时用检测框的15% - # box 格式是 (x, y, w, h),所以 box[2]=width, box[3]=height box_size = max(box[2], box[3]) radius = int(box_size * self.NO_BEAK_RADIUS_RATIO) else: - # 最后fallback:用裁剪区域的15% radius = int(max(w, h) * self.NO_BEAK_RADIUS_RATIO) - - # 确保半径合理 + radius = max(10, min(radius, min(w, h) // 2)) - - # 创建圆形掩码 + circle_mask = np.zeros((h, w), dtype=np.uint8) cv2.circle(circle_mask, eye_px, radius, 255, -1) - - # 如果有seg掩码,取交集 + if seg_mask is not None: - # seg_mask可能需要裁剪到bird_crop区域 - # 这里假设bird_crop已经是裁剪后的,seg_mask也已相应处理 if seg_mask.shape[:2] == (h, w): head_mask = cv2.bitwise_and(circle_mask, seg_mask) else: head_mask = circle_mask else: head_mask = circle_mask - - # 计算锐度 + return self._calculate_sharpness(bird_crop, head_mask) def _calculate_sharpness(self, image: np.ndarray, mask: np.ndarray) -> float: """ 计算掩码区域的锐度(Tenengrad + 对数归一化) - - V3.7 改动: 使用 Tenengrad (Sobel梯度) 替代 Laplacian以减少噪点干扰 - 并使用对数归一化将结果映射到 0-1000 范围 """ if mask.sum() == 0: return 0.0 - - # 转灰度 + if len(image.shape) == 3: gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) else: gray = image - - # Tenengrad 算子 (Sobel梯度平方和) + gx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3) gy = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3) gradient_magnitude = gx ** 2 + gy ** 2 - - # 只取掩码区域的平均值 + mask_pixels = mask > 0 if mask_pixels.sum() == 0: return 0.0 - + raw_sharpness = float(gradient_magnitude[mask_pixels].mean()) - - # 对数归一化到 0-1000 - # V4.0 修复: 降低 MIN_VAL,之前 1460 太高导致锐利照片也返回 0 - # 测试显示: 锐利照片梯度平均值约 800-2000 - MIN_VAL = 100.0 # 降低阈值,保留更多低锐度信息 + + MIN_VAL = 100.0 MAX_VAL = 154016.0 - + if raw_sharpness <= MIN_VAL: return 0.0 if raw_sharpness >= MAX_VAL: return 1000.0 - + log_val = math.log(raw_sharpness) - math.log(MIN_VAL) log_max = math.log(MAX_VAL) - math.log(MIN_VAL) - + return (log_val / log_max) * 1000.0 - + @staticmethod def _distance(p1: Tuple[float, float], p2: Tuple[float, float]) -> float: """计算两点距离""" return math.sqrt((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2) -# 便捷函数 -_detector_instance = None +_detector_instance: Optional[KeypointDetector] = None def get_keypoint_detector() -> KeypointDetector: """获取全局关键点检测器实例""" diff --git a/core/photo_processor.py b/core/photo_processor.py index da96adca..aa670dc9 100644 --- a/core/photo_processor.py +++ b/core/photo_processor.py @@ -1458,6 +1458,21 @@ def exif_prefetch_worker(): elif self._perf_enabled: self._log(" ⚙️ EXIF prefetch: off") + # 周期性 GPU 显存清理间隔(MPS 每 50 张,CUDA 每 200 张) + # 提前计算避免在循环内 import torch 引发 UnboundLocalError + try: + import torch as _torch_module + import gc as _gc_module + _use_mps = hasattr(_torch_module, 'backends') and _torch_module.backends.mps.is_available() + _use_cuda = not _use_mps and _torch_module.cuda.is_available() + _cache_interval = 50 if _use_mps else 200 + except Exception: + _torch_module = None + _gc_module = None + _use_mps = False + _use_cuda = False + _cache_interval = 200 + for local_index in range(1, len(files_tbr) + 1): cancel_processing() i = display_start + local_index - 1 @@ -1577,20 +1592,17 @@ def add_photo_stage(stage: str, ms: float): progress = int((i / total_files) * 100) self._progress(progress) - # 周期性 GPU 显存清理(每 200 张) - # MPS 不像 CUDA 会自动回收,长批次(如 13000 张)会导致显存耗尽 - if i % 200 == 0: + if i % _cache_interval == 0 and _torch_module is not None: try: - import torch, gc - if torch.backends.mps.is_available(): - torch.mps.empty_cache() + if _use_mps: + _torch_module.mps.empty_cache() self._log(self.i18n.t("logs.mps_cache_cleared", index=i), "info") - elif torch.cuda.is_available(): - torch.cuda.empty_cache() + elif _use_cuda: + _torch_module.cuda.empty_cache() self._log(self.i18n.t("logs.cuda_cache_cleared", index=i), "info") else: self._log(f" 🧹 [第{i}张] GC 已执行", "info") - gc.collect() + _gc_module.collect() except Exception: pass diff --git a/core/recursive_scanner.py b/core/recursive_scanner.py index 4f5307ea..f557e8e8 100644 --- a/core/recursive_scanner.py +++ b/core/recursive_scanner.py @@ -7,7 +7,10 @@ """ import os -from typing import List, Set +import platform +from dataclasses import dataclass +from pathlib import PurePosixPath, PureWindowsPath +from typing import List, Optional, Set, Tuple from constants import RAW_EXTENSIONS, JPG_EXTENSIONS, HEIF_EXTENSIONS, RATING_FOLDER_NAMES, RATING_FOLDER_NAMES_EN @@ -17,6 +20,17 @@ # 星级目录名(中 + 英) _RATING_DIR_NAMES: Set[str] = set(RATING_FOLDER_NAMES.values()) | set(RATING_FOLDER_NAMES_EN.values()) +DEFAULT_SCAN_MAX_DEPTH = 16 + + +@dataclass(frozen=True) +class ScannedDirectory: + """扫描结果条目""" + + path: str + depth: int + photo_count: int + def is_excluded(dirname: str) -> bool: """判断目录是否应被排除(非用户照片目录)""" @@ -31,17 +45,137 @@ def is_excluded(dirname: str) -> bool: return False +def _scan_directory_once(dir_path: str) -> Tuple[int, List[str]]: + """单次扫描目录,返回直接照片数量与可继续扫描的子目录。""" + photo_count = 0 + child_dirs: List[str] = [] + + try: + with os.scandir(dir_path) as entries: + for entry in entries: + if entry.is_file(follow_symlinks=False): + ext = os.path.splitext(entry.name)[1].lower() + if ext in _PHOTO_EXTENSIONS: + photo_count += 1 + continue + + if not entry.is_dir(follow_symlinks=False): + continue + if is_excluded(entry.name): + continue + + child_dirs.append(entry.path) + except (FileNotFoundError, NotADirectoryError, PermissionError): + return 0, [] + + child_dirs.sort(key=lambda value: os.path.basename(value).casefold()) + return photo_count, child_dirs + + +def _scan_directories_dfs(root: str, max_depth: int) -> List[ScannedDirectory]: + root = os.path.abspath(root) + if max_depth < 0: + return [] + + result: List[ScannedDirectory] = [] + stack: List[Tuple[str, int]] = [(root, 0)] + while stack: + dir_path, depth = stack.pop() + photo_count, child_dirs = _scan_directory_once(dir_path) + if photo_count > 0: + result.append(ScannedDirectory(path=dir_path, depth=depth, photo_count=photo_count)) + if depth >= max_depth: + continue + for child_dir in reversed(child_dirs): + stack.append((child_dir, depth + 1)) + result.sort(key=lambda item: item.path.casefold()) + return result + + +def _is_windows_path(path: str) -> bool: + drive, _ = os.path.splitdrive(path) + return bool(drive) or "\\" in path + + +def _is_subpath(candidate_parts: Tuple[str, ...], protected_parts: Tuple[str, ...]) -> bool: + if len(candidate_parts) < len(protected_parts): + return False + return candidate_parts[:len(protected_parts)] == protected_parts + + +def is_dangerous_root( + root: str, + platform_name: Optional[str] = None, + home_dir: Optional[str] = None, +) -> Tuple[bool, str]: + """判断根目录是否属于危险目录。""" + platform_name = (platform_name or platform.system()).lower() + home_dir = os.path.expanduser(home_dir or "~") + + if platform_name.startswith("win") or _is_windows_path(root): + normalized = str(PureWindowsPath(os.path.realpath(os.path.abspath(root)))) + root_path = PureWindowsPath(normalized) + anchor = root_path.anchor.rstrip("\\/") + current = normalized.rstrip("\\/") + if anchor and current.lower() == anchor.lower(): + return True, "磁盘根目录 / Drive root" + + protected_paths = [ + PureWindowsPath(os.path.realpath(os.environ.get("SystemRoot", "C:\\Windows"))), + PureWindowsPath(os.path.realpath("C:\\Program Files")), + PureWindowsPath(os.path.realpath("C:\\Program Files (x86)")), + PureWindowsPath(os.path.realpath(os.path.join(home_dir, "AppData"))), + ] + root_parts = tuple(part.casefold() for part in root_path.parts) + for protected in protected_paths: + protected_parts = tuple(part.casefold() for part in protected.parts) + if _is_subpath(root_parts, protected_parts): + return True, f"受保护的系统或设置目录 / Protected path: {protected}" + return False, "" + + normalized = str(PurePosixPath(os.path.realpath(os.path.abspath(root)))) + root_path = PurePosixPath(normalized) + root_parts = tuple(root_path.parts) + + # 严格相等判断:仅当路径恰好是文件系统根目录时才阻断。 + # 不能将 "/" 放入 protected_paths 用 _is_subpath() 做前缀匹配, + # 因为所有绝对路径的第一个 part 均为 "/",会导致误判一切目录。 + # Strict equality check: only block when the path is exactly the filesystem root. + # Do NOT put "/" into protected_paths for _is_subpath() prefix matching, + # because every absolute path starts with "/" as its first part, + # which would cause all directories to be falsely flagged. + if normalized == "/": + return True, "文件系统根目录 / Filesystem root" + + # 对受保护路径同样执行 realpath 解析,以处理符号链接。 + # 例如 macOS 上 /etc -> /private/etc,/var -> /private/var, + # 若不解析则与已 realpath 处理的 normalized 无法匹配。 + # Resolve protected paths with realpath too, to handle symlinks. + # e.g. on macOS: /etc -> /private/etc, /var -> /private/var. + _raw_protected = [ + "/usr", + "/etc", + "/var", + "/System", + "/Library", + os.path.join(home_dir, "Library"), + ] + protected_paths = [ + PurePosixPath(os.path.realpath(p)) for p in _raw_protected + ] + for protected in protected_paths: + protected_parts = tuple(protected.parts) + if _is_subpath(root_parts, protected_parts): + return True, f"受保护的系统或设置目录 / Protected path: {protected}" + if normalized in ("/home", os.path.realpath("/home")): + return True, "系统用户根目录 / System user root" + return False, "" + + def has_photos(dir_path: str) -> bool: """判断目录是否直接包含至少 1 个照片文件""" - try: - for entry in os.scandir(dir_path): - if entry.is_file(follow_symlinks=False): - ext = os.path.splitext(entry.name)[1].lower() - if ext in _PHOTO_EXTENSIONS: - return True - except PermissionError: - pass - return False + photo_count, _ = _scan_directory_once(dir_path) + return photo_count > 0 def is_processed(dir_path: str) -> bool: @@ -49,56 +183,34 @@ def is_processed(dir_path: str) -> bool: return os.path.exists(os.path.join(dir_path, '.superpicky', 'report.db')) -def scan_recursive(root: str, max_depth: int = 10) -> List[str]: +def scan_directories( + root: str, + max_depth: int = DEFAULT_SCAN_MAX_DEPTH, +) -> List[ScannedDirectory]: + """扫描根目录,返回包含照片的目录摘要列表。""" + return _scan_directories_dfs(root, max_depth) + + +def scan_dfs(root: str, max_depth: int = DEFAULT_SCAN_MAX_DEPTH) -> List[ScannedDirectory]: + """使用 DFS 扫描根目录。""" + return scan_directories(root, max_depth=max_depth) + + +def scan_recursive(root: str, max_depth: int = DEFAULT_SCAN_MAX_DEPTH) -> List[str]: """ 递归扫描根目录,返回所有原子目录(包含照片的非排除目录)的绝对路径列表。 Args: root: 根目录路径 - max_depth: 最大递归深度(默认 10) + max_depth: 最大递归深度(默认 16) Returns: 原子目录绝对路径列表,按字母排序 """ - result: List[str] = [] - - # 根目录本身如果包含照片,也加入列表 - if has_photos(root): - result.append(root) - - def _scan(dir_path: str, depth: int): - if depth > max_depth: - return - try: - entries = sorted(os.scandir(dir_path), key=lambda e: e.name) - except PermissionError: - return - - for entry in entries: - if not entry.is_dir(follow_symlinks=False): - continue - if is_excluded(entry.name): - continue - - if has_photos(entry.path): - result.append(entry.path) - - # 即使当前目录有照片,也继续扫描子目录 - _scan(entry.path, depth + 1) - - _scan(root, 0) - return result + return [item.path for item in scan_dfs(root, max_depth=max_depth)] def count_photos(dir_path: str) -> int: """统计目录中直接包含的照片文件数量""" - count = 0 - try: - for entry in os.scandir(dir_path): - if entry.is_file(follow_symlinks=False): - ext = os.path.splitext(entry.name)[1].lower() - if ext in _PHOTO_EXTENSIONS: - count += 1 - except PermissionError: - pass + count, _ = _scan_directory_once(dir_path) return count diff --git a/core/runtime_bootstrap.py b/core/runtime_bootstrap.py new file mode 100644 index 00000000..74b175ca --- /dev/null +++ b/core/runtime_bootstrap.py @@ -0,0 +1,170 @@ +# -*- coding: utf-8 -*- +""" +Packaged runtime bootstrap helper. + +This entrypoint is designed for frozen Windows lightweight builds. It runs in a +separate hidden process mode of the packaged executable and installs runtime +dependencies into an app-local site-packages directory without relying on any +system Python interpreter. +""" + +from __future__ import annotations + +import argparse +import os +import json +import sys +import importlib +from datetime import datetime, timezone +from pathlib import Path + + +def _parse_args(argv: list[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--runtime-bootstrap", action="store_true") + parser.add_argument("--runtime-dir", required=True) + parser.add_argument("--requirements", required=True) + parser.add_argument("--index-url", default=None) + parser.add_argument("--extra-index-url", action="append", default=[]) + return parser.parse_args(argv) + + +def _ensure_utf8_stdio() -> None: + for stream_name in ("stdout", "stderr"): + stream = getattr(sys, stream_name, None) + reconfigure = getattr(stream, "reconfigure", None) + if callable(reconfigure): + reconfigure(encoding="utf-8", errors="replace") + + +def _build_pip_args(args: argparse.Namespace, site_packages_dir: Path) -> list[str]: + """ + Build the pip command line for the packaged runtime bootstrap. + + 为打包运行时引导流程构建 pip 命令行参数。 + """ + pip_args = [ + "install", + "--disable-pip-version-check", + "--no-warn-script-location", + "--no-cache-dir", + "--progress-bar", + "raw", + "--use-deprecated=legacy-certs", + "--upgrade", + "--target", + str(site_packages_dir), + "-r", + str(Path(args.requirements).resolve()), + ] + if args.index_url: + pip_args.extend(["-i", args.index_url]) + for extra_index_url in args.extra_index_url: + pip_args.extend(["--extra-index-url", extra_index_url]) + return pip_args + + +def _write_manifest(runtime_dir: Path, site_packages_dir: Path, args: argparse.Namespace) -> None: + manifest = { + "generated_at": datetime.now(timezone.utc).isoformat(), + "runtime_dir": str(runtime_dir), + "site_packages_dir": str(site_packages_dir), + "requirements": str(Path(args.requirements).resolve()), + "index_url": args.index_url, + "extra_index_urls": list(args.extra_index_url), + "python_version": sys.version, + "bootstrap_executable": sys.executable, + } + manifest_path = runtime_dir / "runtime_install_manifest.json" + manifest_path.write_text(json.dumps(manifest, indent=2, ensure_ascii=False), encoding="utf-8") + + +def _bundled_root() -> Path: + meipass = getattr(sys, "_MEIPASS", None) + if isinstance(meipass, str) and meipass: + return Path(meipass) + if getattr(sys, "frozen", False) and sys.platform == "win32": + return Path(sys.executable).resolve().parent / "_internal" + return Path(__file__).resolve().parent.parent + + +def _configure_ca_bundle() -> Path | None: + cert_path = _bundled_root() / "certifi" / "cacert.pem" + if not cert_path.exists(): + return None + os.environ.setdefault("PIP_CERT", str(cert_path)) + os.environ.setdefault("SSL_CERT_FILE", str(cert_path)) + os.environ.setdefault("REQUESTS_CA_BUNDLE", str(cert_path)) + return cert_path + + +def _patch_pip_for_frozen_bootstrap() -> None: + """ + Patch pip vendored distlib so it can run from a PyInstaller-frozen process. + + distlib expects a standard loader/resource finder pair and Windows launcher + executables inside the distlib package. In a frozen app, the PyInstaller + loader type is unknown to distlib, and launcher stubs are not needed for + our app-local runtime bootstrap. We register a fallback finder and disable + launcher generation on Windows. + """ + os.environ.setdefault("PIP_DISABLE_PIP_VERSION_CHECK", "1") + os.environ.setdefault("PIP_USE_DEPRECATED", "legacy-certs") + cert_path = _configure_ca_bundle() + + from pip._vendor.distlib import resources as distlib_resources + + distlib_pkg = importlib.import_module("pip._vendor.distlib") + loader = getattr(distlib_pkg, "__loader__", None) + if loader is not None: + finder_registry = getattr(distlib_resources, "_finder_registry", {}) + if type(loader) not in finder_registry: + distlib_resources.register_finder(loader, distlib_resources.ResourceFinder) + + import pip._vendor.distlib.scripts as distlib_scripts + import pip._vendor.certifi as pip_certifi + import pip._internal.cli.index_command as pip_index_command + + if not getattr(distlib_scripts.ScriptMaker.__init__, "_superpicky_patched", False): + original_init = distlib_scripts.ScriptMaker.__init__ + + def _patched_init(self, source_dir, target_dir, add_launchers=True, dry_run=False, fileop=None): + # We do not need .exe launcher stubs in the app-local target dir. + return original_init(self, source_dir, target_dir, add_launchers=False, dry_run=dry_run, fileop=fileop) + + _patched_init._superpicky_patched = True # type: ignore[attr-defined] + distlib_scripts.ScriptMaker.__init__ = _patched_init + + # Frozen bootstrap can resolve vendored certifi resources incorrectly under + # truststore. Force pip to skip the truststore SSL context path and fall + # back to its legacy certificate behavior. + pip_index_command._create_truststore_ssl_context = lambda: None + if cert_path is not None: + pip_certifi.where = lambda: str(cert_path) + + +def run_runtime_bootstrap(argv: list[str]) -> int: + _ensure_utf8_stdio() + args = _parse_args(argv) + runtime_dir = Path(args.runtime_dir).resolve() + site_packages_dir = runtime_dir / "site-packages" + site_packages_dir.mkdir(parents=True, exist_ok=True) + + _patch_pip_for_frozen_bootstrap() + + from pip._internal.cli.main import main as pip_main + + pip_args = _build_pip_args(args, site_packages_dir) + print(f"[runtime-bootstrap] target={site_packages_dir}") + exit_code = int(pip_main(pip_args)) + if exit_code != 0: + return exit_code + + if str(site_packages_dir) not in sys.path: + sys.path.insert(0, str(site_packages_dir)) + + import torch # noqa: F401 + + _write_manifest(runtime_dir, site_packages_dir, args) + print("[runtime-bootstrap] torch import verified") + return 0 diff --git a/core/runtime_requirements.py b/core/runtime_requirements.py new file mode 100644 index 00000000..5873f5de --- /dev/null +++ b/core/runtime_requirements.py @@ -0,0 +1,171 @@ +# -*- coding: utf-8 -*- +""" +Runtime requirements manager for lightweight builds. + +This module provides a unified interface for managing platform-specific runtime +dependencies across CPU, CUDA, and macOS builds. It consolidates the previously +separate requirements_runtime_*.txt files into a single Python module with +type-safe configuration access. + +轻量化构建的运行时依赖管理模块。 + +该模块为 CPU、CUDA 和 macOS 构建统一描述运行时依赖, +避免把平台差异散落在多个 requirements 文本文件与调用点之间。 +""" + +from __future__ import annotations + +import sys +from dataclasses import dataclass +from typing import Literal + + +PlatformType = Literal["cpu", "cuda", "mac"] + + +@dataclass(frozen=True) +class RuntimeRequirements: + """ + Runtime dependency configuration for a specific platform. + + 单个平台对应的运行时依赖配置。 + """ + + torch_version: str + torchvision_version: str + torchaudio_version: str + timm_version: str + extra_index_urls: list[str] + index_url: str | None = None + + @staticmethod + def _format_pinned_requirement(package_name: str, version: str) -> str: + """ + Return a pinned requirement only when a version is provided. + + 仅在存在版本号时返回带固定版本的依赖声明。 + """ + + normalized_version = version.strip() + if not normalized_version: + return package_name + return f"{package_name}=={normalized_version}" + + def to_requirements_string( + self, + *, + include_indexes: bool = True, + package_urls: dict[str, str] | None = None, + ) -> str: + """ + Convert configuration to pip requirements file format. + + 将配置转换为 pip requirements 文件格式。 + """ + lines = [] + package_urls = package_urls or {} + if include_indexes and self.index_url: + lines.append(f"--index-url {self.index_url}") + if include_indexes: + for url in self.extra_index_urls: + lines.append(f"--extra-index-url {url}") + lines.append( + package_urls.get( + "torch", + self._format_pinned_requirement("torch", self.torch_version), + ) + ) + lines.append( + package_urls.get( + "torchvision", + self._format_pinned_requirement("torchvision", self.torchvision_version), + ) + ) + lines.append( + package_urls.get( + "torchaudio", + self._format_pinned_requirement("torchaudio", self.torchaudio_version), + ) + ) + lines.append(f"timm{self.timm_version}") + return "\n".join(lines) + + +def get_cpu_requirements() -> RuntimeRequirements: + """Get runtime requirements for CPU builds. / 获取 CPU 构建的运行时依赖。""" + return RuntimeRequirements( + torch_version="2.7.1+cpu", + torchvision_version="0.22.1+cpu", + torchaudio_version="2.7.1+cpu", + timm_version=">=0.9.0", + extra_index_urls=[ + "https://mirror.nju.edu.cn/pytorch/whl/cpu/", + "https://download.pytorch.org/whl/cpu", + ], + ) + + +def get_cuda_requirements() -> RuntimeRequirements: + """Get runtime requirements for CUDA builds. / 获取 CUDA 构建的运行时依赖。""" + return RuntimeRequirements( + torch_version="2.7.1+cu118", + torchvision_version="0.22.1+cu118", + torchaudio_version="2.7.1+cu118", + timm_version=">=0.9.0", + extra_index_urls=[ + "https://mirror.nju.edu.cn/pytorch/whl/cu118/", + "https://download.pytorch.org/whl/cu118", + ], + ) + + +def get_mac_requirements() -> RuntimeRequirements: + """Get runtime requirements for macOS builds. / 获取 macOS 构建的运行时依赖。""" + return RuntimeRequirements( + torch_version="2.8.0", + torchvision_version="", + torchaudio_version="", + timm_version=">=0.9.0", + extra_index_urls=[], + ) + + +def detect_platform() -> PlatformType: + """Detect the current platform type. / 检测当前平台类型。""" + if sys.platform == "darwin": + return "mac" + if sys.platform == "win32": + return "cuda" + return "cpu" + + +def get_runtime_requirements(platform: PlatformType | None = None) -> RuntimeRequirements: + """ + Get runtime requirements for the specified or detected platform. + + Args: + platform: Platform type ('cpu', 'cuda', 'mac'). If None, auto-detects. + 平台类型;若为 None,则自动检测。 + + Returns: + RuntimeRequirements: Platform-specific dependency configuration. + 对应平台的依赖配置。 + + Raises: + ValueError: If platform type is invalid. + 当平台类型非法时抛出。 + """ + if platform is None: + platform = detect_platform() + + requirements_getters = { + "cpu": get_cpu_requirements, + "cuda": get_cuda_requirements, + "mac": get_mac_requirements, + } + + getter = requirements_getters.get(platform) + if getter is None: + raise ValueError(f"Unsupported platform: {platform}") + + return getter() diff --git a/core/source_probe.py b/core/source_probe.py new file mode 100644 index 00000000..6a24e6bd --- /dev/null +++ b/core/source_probe.py @@ -0,0 +1,221 @@ +# -*- coding: utf-8 -*- +""" +HTTP source probe helpers for initialization. + +Notes: +- We intentionally do not use ICMP ping as the primary selection mechanism. +- Some networks block ping while HTTPS still works normally. +- Selection is based on real HTTP responsiveness and cached for the current run. + +HTTP 源探测辅助工具,用于初始化。 + +注意事项: +- 我们有意不使用 ICMP ping 作为主要选择机制。 +- 某些网络阻止 ping,但 HTTPS 仍然正常工作。 +- 选择基于真实的 HTTP 响应能力,并在当前运行中缓存。 +""" + +from __future__ import annotations + +import logging +import time +import urllib.request +from dataclasses import dataclass +from typing import Dict, Iterable, List, Optional + +logging.basicConfig(level=logging.INFO) + + +DEFAULT_TIMEOUT_SECONDS = 4.0 +LARGE_FILE_TIMEOUT_SECONDS = 10.0 + + +@dataclass +class ProbeResult: + """ + 源探测结果数据类。 + + Source probe result dataclass. + + 属性 Attributes: + name (str): 源名称 + url (str): 源 URL + ok (bool): 探测是否成功 + total_ms (float): 总响应时间(毫秒) + first_byte_ms (float): 首字节响应时间(毫秒) + error (Optional[str]): 错误信息(如果失败) + status_code (Optional[int]): HTTP 状态码 + response_headers (Optional[Dict[str, str]]): 响应头 + """ + + name: str + url: str + ok: bool + total_ms: float + first_byte_ms: float + error: Optional[str] = None + status_code: Optional[int] = None + response_headers: Optional[Dict[str, str]] = None + + +_PROBE_CACHE: Dict[str, List[ProbeResult]] = {} + + +def _normalize_probe_url(url: str) -> str: + """ + 标准化化探测 URL。 + + Normalize probe URL. + + 参数 Parameters: + url (str): 原始 URL + + 返回 Returns: + str: 标准化后的 URL + """ + if url.endswith("/simple"): + return url.rstrip("/") + "/pip/" + return url + + +def probe_url( + name: str, url: str, timeout: float = DEFAULT_TIMEOUT_SECONDS +) -> ProbeResult: + """ + 探测单个 URL 的响应能力。 + + Probe the responsiveness of a single URL. + + 参数 Parameters: + name (str): 源名称 + url (str): 要探测的 URL + timeout (float): 超时时间(秒) + + 返回 Returns: + ProbeResult: 探测结果 + """ + start = time.perf_counter() + request = urllib.request.Request( + _normalize_probe_url(url), + headers={"User-Agent": "SuperPicky-InitProbe/1.0"}, + method="GET", + ) + try: + with urllib.request.urlopen(request, timeout=timeout) as response: + first_byte_start = time.perf_counter() + response.read(256) + first_byte_ms = (time.perf_counter() - first_byte_start) * 1000.0 + + status_code = response.getcode() + response_headers = dict(response.headers.items()) + + total_ms = (time.perf_counter() - start) * 1000.0 + + logging.info( + "源探测成功: %s (%s) - 状态码: %d, 总耗时: %.2f ms, 首字节: %.2f ms", + name, + url, + status_code, + total_ms, + first_byte_ms, + ) + + return ProbeResult( + name=name, + url=url, + ok=True, + total_ms=total_ms, + first_byte_ms=first_byte_ms, + status_code=status_code, + response_headers=response_headers, + ) + except Exception as exc: + total_ms = (time.perf_counter() - start) * 1000.0 + error_msg = f"{type(exc).__name__}: {exc}" + + logging.warning( + "源探测失败: %s (%s) - 错误: %s, 耗时: %.2f ms", + name, + url, + error_msg, + total_ms, + ) + + return ProbeResult( + name=name, + url=url, + ok=False, + total_ms=total_ms, + first_byte_ms=0.0, + error=error_msg, + status_code=None, + response_headers=None, + ) + + +def probe_sources( + group_name: str, sources: Iterable[dict], timeout: float = DEFAULT_TIMEOUT_SECONDS +) -> List[ProbeResult]: + """ + 探测一组源并返回结果。 + + Probe a group of sources and return results. + + 参数 Parameters: + group_name (str): 源组名称(用于缓存) + sources (Iterable[dict]): 源列表,每个源包含 name 和 url + timeout (float): 超时时间(秒) + + 返回 Returns: + List[ProbeResult]: 探测结果列表 + """ + if group_name in _PROBE_CACHE: + logging.info("使用缓存的探测结果: %s", group_name) + return list(_PROBE_CACHE[group_name]) + + sources_list = list(sources) + logging.info("开始探测源组: %s,共 %d 个源", group_name, len(sources_list)) + results: List[ProbeResult] = [] + for source in sources_list: + results.append(probe_url(source["name"], source["url"], timeout=timeout)) + + _PROBE_CACHE[group_name] = list(results) + + successful_count = sum(1 for item in results if item.ok) + logging.info( + "源组 %s 探测完成: %d/%d 成功", group_name, successful_count, len(results) + ) + + return results + + +def pick_best_source(results: Iterable[ProbeResult]) -> Optional[ProbeResult]: + """ + 从探测结果中选择最佳源。 + + Select the best source from probe results. + + 参数 Parameters: + results (Iterable[ProbeResult]): 探测结果列表 + + 返回 Returns: + Optional[ProbeResult]: 最佳源,如果没有成功的源则返回 None + """ + successful = [item for item in results if item.ok] + if not successful: + logging.warning("没有可用的源") + return None + + best = min(successful, key=lambda item: (item.total_ms, item.first_byte_ms)) + logging.info( + "选择最佳源: %s (%s) - 总耗时: %.2f ms, 首字节: %.2f ms", + best.name, + best.url, + best.total_ms, + best.first_byte_ms, + ) + return best + + +def clear_probe_cache() -> None: + _PROBE_CACHE.clear() diff --git a/docs/downloads.html b/docs/downloads.html index 0063e8d2..4ae49efe 100644 --- a/docs/downloads.html +++ b/docs/downloads.html @@ -153,7 +153,7 @@

2026-03-25