diff --git a/.github/workflows/build-release-lite.yml b/.github/workflows/build-release-lite.yml new file mode 100644 index 0000000..ca083e0 --- /dev/null +++ b/.github/workflows/build-release-lite.yml @@ -0,0 +1,106 @@ +name: Build and Release SuperPicky Lite + +on: + workflow_dispatch: + inputs: + version: + description: Version number for the lite release + required: true + +jobs: + build-windows-lite: + runs-on: windows-latest + permissions: + contents: write + + steps: + - name: Checkout repository + uses: actions/checkout@v6 + + - name: Set up Python 3.12 + uses: actions/setup-python@v6 + with: + python-version: '3.12' + cache: 'pip' + cache-dependency-path: | + requirements.txt + requirements_cuda.txt + + - name: Resolve release metadata + id: release_meta + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + INPUT_VERSION: ${{ github.event.inputs.version }} + GITHUB_REF_NAME: ${{ github.ref_name }} + run: python scripts/ci_release.py resolve-metadata + + - name: Build Windows Lite release + env: + RELEASE_TAG: ${{ steps.release_meta.outputs.tag }} + run: python build_release_win.py --build-type lite --copy-dir output/lite-win --no-zip + + - name: Create Windows Lite installer + uses: Minionguyjpro/Inno-Setup-Action@v1.2.7 + with: + path: ./output/lite-win/installer_lite/SuperPicky-lite.iss + + - name: Collect Windows Lite assets + run: | + python scripts/ci_release.py collect-assets --output-dir release_assets/lite-win --pattern output/lite-win/installer_lite/Output/SuperPicky_Setup_Lite_Win64_*.exe + + - name: Upload Windows Lite assets to GitHub Release + uses: softprops/action-gh-release@v3 + with: + tag_name: ${{ steps.release_meta.outputs.tag }} + name: SuperPicky Lite ${{ steps.release_meta.outputs.tag }} + body_path: ChangeLog.md + fail_on_unmatched_files: true + prerelease: ${{ contains(steps.release_meta.outputs.tag, '-rc') }} + files: | + release_assets/lite-win/* + + build-mac-lite: + runs-on: macos-14 + permissions: + contents: write + + steps: + - name: Checkout repository + uses: actions/checkout@v6 + + - name: Set up Python 3.12 + uses: actions/setup-python@v6 + with: + python-version: '3.12' + cache: 'pip' + cache-dependency-path: requirements_mac.txt + + - name: Resolve release metadata + id: release_meta + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + INPUT_VERSION: ${{ github.event.inputs.version }} + GITHUB_REF_NAME: ${{ github.ref_name }} + run: python scripts/ci_release.py resolve-metadata + + - name: Install macOS Lite build dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -r requirements_mac.txt + + - name: Build unsigned macOS Lite release + env: + RELEASE_TAG: ${{ steps.release_meta.outputs.tag }} + run: | + python build_release_mac.py --build-type lite --arch arm64 --copy-dir output/lite-mac + + - name: Upload macOS Lite assets to GitHub Release + uses: softprops/action-gh-release@v3 + with: + tag_name: ${{ steps.release_meta.outputs.tag }} + name: SuperPicky Lite ${{ steps.release_meta.outputs.tag }} + body_path: ChangeLog.md + fail_on_unmatched_files: true + prerelease: ${{ contains(steps.release_meta.outputs.tag, '-rc') }} + files: | + output/lite-mac/*.dmg \ No newline at end of file diff --git a/.github/workflows/build-release.yml b/.github/workflows/build-release.yml index b6b2940..a9d1613 100644 --- a/.github/workflows/build-release.yml +++ b/.github/workflows/build-release.yml @@ -18,10 +18,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Python 3.12 - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.12' cache: 'pip' @@ -30,7 +30,6 @@ jobs: requirements_cuda.txt - name: Prepare telemetry build override - shell: pwsh env: COUNTLY_APP_KEY: ${{ secrets.COUNTLY_APP_KEY }} COUNTLY_SERVER_URL: ${{ secrets.COUNTLY_SERVER_URL }} @@ -38,170 +37,53 @@ jobs: - name: Resolve release metadata id: release_meta - shell: pwsh env: - EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_NAME: ${{ github.event_name }} INPUT_VERSION: ${{ github.event.inputs.version }} - REF_NAME: ${{ github.ref_name }} - run: | - if ($env:EVENT_NAME -eq "workflow_dispatch") { - $rawVersion = $env:INPUT_VERSION - } else { - $rawVersion = $env:REF_NAME - } - - if (-not $rawVersion) { - throw "Release version is required." - } - - $tag = if ($rawVersion.StartsWith("v")) { $rawVersion } else { "v$rawVersion" } - "tag=$tag" >> $env:GITHUB_OUTPUT - "name=SuperPicky $tag" >> $env:GITHUB_OUTPUT + GITHUB_REF_NAME: ${{ github.ref_name }} + run: python scripts/ci_release.py resolve-metadata - - name: Create build virtual environment - shell: pwsh - run: python -m venv .venv - - - name: Install CPU build dependencies - shell: pwsh - run: | - .\.venv\Scripts\python.exe -m pip install --upgrade pip - .\.venv\Scripts\python.exe -m pip install -r requirements.txt - - - name: Build CPU and CUDA patch release payloads - shell: pwsh + - name: Build Full release env: RELEASE_TAG: ${{ steps.release_meta.outputs.tag }} - run: | - .\.venv\Scripts\python.exe scripts\download_models.py - .\.venv\Scripts\python.exe build_release_win.py --build-type cuda-patch --copy-dir output --debug --no-zip + run: python build_release_win.py --build-type cpu --copy-dir output --no-zip - - name: Create CPU installer with Inno Setup + - name: Create Full installer with Inno Setup uses: Minionguyjpro/Inno-Setup-Action@v1.2.7 with: path: ./output/installer_cpu/SuperPicky.iss - - name: Create CUDA patch installer with Inno Setup + - name: Build Lite release + env: + RELEASE_TAG: ${{ steps.release_meta.outputs.tag }} + run: python build_release_win.py --build-type lite --copy-dir output --no-zip + + - name: Create Lite installer with Inno Setup uses: Minionguyjpro/Inno-Setup-Action@v1.2.7 with: - path: ./output/cuda_patch_installer/SuperPicky_CUDA_Patch.iss + path: ./output/installer_lite/SuperPicky-lite.iss - name: Prepare release assets - shell: pwsh run: | - New-Item -ItemType Directory -Path release_assets -Force | Out-Null - - $assetPatterns = @( - "output/cuda_patch_installer/Output/SuperPicky_CUDA_Patch_Win64_*.exe", - "output/installer_cpu/Output/SuperPicky_Setup_Win64_*.exe" - ) - - foreach ($pattern in $assetPatterns) { - $foundFiles = Get-ChildItem -Path $pattern -File -ErrorAction Stop - if ($foundFiles.Count -ne 1) { - throw "Expected exactly one asset for pattern '$pattern', found $($foundFiles.Count)." - } - Copy-Item -Path $foundFiles[0].FullName -Destination release_assets/ - } + python scripts/ci_release.py collect-assets --output-dir release_assets --pattern output/installer_cpu/Output/SuperPicky_Setup_Full_Win64_*.exe --pattern output/installer_lite/Output/SuperPicky_Setup_Lite_Win64_*.exe - name: Build code patch zip - shell: pwsh - run: | - $version = .\\.venv\\Scripts\\python.exe -c "from constants import APP_VERSION; print(APP_VERSION)" - $patchVersion = "${{ steps.release_meta.outputs.tag }}" - $zipName = "code_patch_${patchVersion}.zip" - $zipPath = "release_assets/$zipName" - - # 7-Zip 在 windows-latest runner 上内置 - # 包含所有运行时 Python 文件;排除: - # main.py — PyInstaller 入口,patch 后不会生效 - # build_release_win.py — 仅构建用,不需随 patch 分发 - # _telemetry_build.py — 构建时临时覆盖文件 - # pyi_rth_*.py — PyInstaller runtime hook,非运行时逻辑 - # test_*.py — 测试脚本 - $items = @( - "constants.py", "advanced_config.py", "ai_model.py", - "birdid_server.py", "birdid_cli.py", "iqa_scorer.py", - "post_adjustment_engine.py", "server_manager.py", - "superpicky_cli.py", "topiq_model.py", - "tools", "core", "ui", "birdid", "locales" - ) | Where-Object { Test-Path $_ } - & "C:\Program Files\7-Zip\7z.exe" a -tzip $zipPath @items ` - -xr!"__pycache__" -xr!"*.pyc" -xr!"*.pyo" ` - -xr!"main.py" | Out-Null - - $meta = @{ - patch_version = $patchVersion - base_version = $version - applied_at = (Get-Date -Format "o") - } | ConvertTo-Json - $meta | Out-File -Encoding utf8 "release_assets/patch_meta.json" + run: python scripts/ci_release.py build-patch --output-dir release_assets --patch-version ${{ steps.release_meta.outputs.tag }} - name: Clean telemetry build override if: always() - shell: pwsh - run: Remove-Item -Path _telemetry_build.py -Force -ErrorAction SilentlyContinue + run: python scripts/ci_release.py cleanup-paths --path _telemetry_build.py - name: Create GitHub Release - uses: softprops/action-gh-release@v2 + uses: softprops/action-gh-release@v3 with: tag_name: ${{ steps.release_meta.outputs.tag }} name: ${{ steps.release_meta.outputs.name }} body_path: ChangeLog.md + fail_on_unmatched_files: true files: | release_assets/* - # - name: Create GitCode Release - # continue-on-error: true - # shell: pwsh - # env: - # GITCODE_TOKEN: ${{ secrets.GITCODE_TOKEN }} - # TAG: ${{ steps.release_meta.outputs.tag }} - # run: | - # $baseUrl = "https://gitcode.com/api/v4/projects/Jamesphotography%2FSuperPicky" - # $tag = $env:TAG - # $headers = @{ "PRIVATE-TOKEN" = $env:GITCODE_TOKEN; "Content-Type" = "application/json" } - # $body = @{ tag_name = $tag; name = "SuperPicky $tag"; description = "See GitHub release for details." } | ConvertTo-Json - # try { - # Invoke-RestMethod -Uri "$baseUrl/releases" -Method Post -Body $body -Headers $headers -ErrorAction Stop - # } catch { - # Write-Host "Release already exists or creation failed, continuing..." - # } - - # - name: Update downloads_github.json and push - # shell: pwsh - # env: - # TAG: ${{ steps.release_meta.outputs.tag }} - # GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - # REF: ${{ github.ref_name }} - # run: | - # $tag = $env:TAG - # $assets = gh release view $tag --json assets --jq '.assets[].name' 2>$null - # $macArm64 = ($assets | Where-Object { $_ -match '(?i)arm64.*\.dmg$' } | Select-Object -First 1) - # $winCpu = ($assets | Where-Object { $_ -match '(?i)Setup.*Win.*\.exe$' } | Select-Object -First 1) - # $winCuda = ($assets | Where-Object { $_ -match '(?i)CUDA.*Win.*\.exe$' } | Select-Object -First 1) - # $json = [ordered]@{ - # beta = [ordered]@{ - # tag = $tag - # updated_at = (Get-Date -Format "o") - # files = [ordered]@{ - # mac_arm64 = if ($macArm64) { $macArm64 } else { $null } - # win_cpu = if ($winCpu) { $winCpu } else { $null } - # win_cuda = if ($winCuda) { $winCuda } else { $null } - # } - # } - # } - # $jsonContent = $json | ConvertTo-Json -Depth 5 - # git config user.name "github-actions[bot]" - # git config user.email "github-actions[bot]@users.noreply.github.com" - # git fetch origin nightly - # git checkout -B nightly origin/nightly - # $jsonContent | Out-File -Encoding utf8NoBOM "docs/downloads_github.json" - # git add docs/downloads_github.json - # git diff --cached --quiet && exit 0 - # git commit -m "chore(web): 自动更新 downloads_github.json → $tag [skip ci]" - # git push origin HEAD:nightly - build-mac: runs-on: macos-14 permissions: @@ -209,10 +91,10 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Set up Python 3.12 - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.12' cache: 'pip' @@ -227,299 +109,44 @@ jobs: - name: Resolve release metadata id: release_meta env: - EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_NAME: ${{ github.event_name }} INPUT_VERSION: ${{ github.event.inputs.version }} - REF_NAME: ${{ github.ref_name }} - run: | - if [ "$EVENT_NAME" = "workflow_dispatch" ]; then - RAW_VERSION="$INPUT_VERSION" - else - RAW_VERSION="$REF_NAME" - fi - - if [ -z "$RAW_VERSION" ]; then - echo "Release version is required." >&2 - exit 1 - fi - - if [[ "$RAW_VERSION" == v* ]]; then - TAG="$RAW_VERSION" - else - TAG="v$RAW_VERSION" - fi - - echo "tag=$TAG" >> "$GITHUB_OUTPUT" - echo "name=SuperPicky $TAG" >> "$GITHUB_OUTPUT" + GITHUB_REF_NAME: ${{ github.ref_name }} + run: python scripts/ci_release.py resolve-metadata - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements_mac.txt + python -m pip install -r requirements_mac.txt - - name: Import signing certificate + - name: Materialize signing certificate env: MACOS_CERTIFICATE: ${{ secrets.MACOS_CERTIFICATE }} - MACOS_CERTIFICATE_PWD: ${{ secrets.MACOS_CERTIFICATE_PWD }} - run: | - KEYCHAIN_PATH="$RUNNER_TEMP/build.keychain" - KEYCHAIN_PWD=$(openssl rand -hex 16) - - echo "$MACOS_CERTIFICATE" | base64 --decode > "$RUNNER_TEMP/certificate.p12" - - security create-keychain -p "$KEYCHAIN_PWD" "$KEYCHAIN_PATH" - security set-keychain-settings -lut 21600 "$KEYCHAIN_PATH" - security unlock-keychain -p "$KEYCHAIN_PWD" "$KEYCHAIN_PATH" - security import "$RUNNER_TEMP/certificate.p12" \ - -k "$KEYCHAIN_PATH" \ - -P "$MACOS_CERTIFICATE_PWD" \ - -T /usr/bin/codesign - security list-keychain -d user -s "$KEYCHAIN_PATH" - security set-key-partition-list -S apple-tool:,apple: -s -k "$KEYCHAIN_PWD" "$KEYCHAIN_PATH" - - echo "KEYCHAIN_PWD=$KEYCHAIN_PWD" >> "$GITHUB_ENV" - echo "KEYCHAIN_PATH=$KEYCHAIN_PATH" >> "$GITHUB_ENV" - - - name: Download models - run: python scripts/download_models.py + run: python scripts/ci_release.py materialize-secret-file --env-name MACOS_CERTIFICATE --output $RUNNER_TEMP/certificate.p12 --decode-base64 - - name: Inject commit hash and run PyInstaller + - name: Build notarized macOS release env: - TAG: ${{ steps.release_meta.outputs.tag }} - run: | - COMMIT_HASH=$(git rev-parse --short HEAD) - BUILD_INFO_FILE="core/build_info.py" - - # 判断渠道:纯 vX.Y.Z 为 official,含 -RC 为 nightly - if echo "$TAG" | grep -iqE '\-rc'; then - RELEASE_CHANNEL="nightly" - else - RELEASE_CHANNEL="official" - fi - - cp "$BUILD_INFO_FILE" "${BUILD_INFO_FILE}.backup" - sed -i.tmp "s/COMMIT_HASH = .*/COMMIT_HASH = \"${COMMIT_HASH}\"/" "$BUILD_INFO_FILE" - sed -i.tmp "s/RELEASE_CHANNEL = .*/RELEASE_CHANNEL = \"${RELEASE_CHANNEL}\"/" "$BUILD_INFO_FILE" - rm -f "${BUILD_INFO_FILE}.tmp" - - pyinstaller SuperPicky.spec --clean --noconfirm - - mv "${BUILD_INFO_FILE}.backup" "$BUILD_INFO_FILE" - - echo "COMMIT_HASH=$COMMIT_HASH" >> "$GITHUB_ENV" - - - name: Organize .app bundle resources - run: | - APP_PATH="dist/SuperPicky.app" - mkdir -p "${APP_PATH}/Contents/MacOS" "${APP_PATH}/Contents/Resources" - - if [ -d "dist/SuperPicky" ] && [ ! -f "${APP_PATH}/Contents/MacOS/SuperPicky" ]; then - mv dist/SuperPicky/* "${APP_PATH}/Contents/MacOS/" - fi - - for res in SuperBirdIDPlugin.lrplugin en.lproj zh-Hans.lproj; do - if [ -d "${APP_PATH}/Contents/MacOS/$res" ]; then - mv "${APP_PATH}/Contents/MacOS/$res" "${APP_PATH}/Contents/Resources/" - fi - done - - echo "APP_PATH=$APP_PATH" >> "$GITHUB_ENV" - - - name: Sign application - env: - DEVELOPER_ID: ${{ secrets.MACOS_DEVELOPER_ID }} - run: | - find "${APP_PATH}/Contents" -type f \( -name "*.dylib" -o -name "*.so" -o -perm +111 \) -print0 | \ - xargs -0 -P 8 -I {} codesign --force --sign "$DEVELOPER_ID" --timestamp --options runtime {} 2>/dev/null || true - - codesign --force --deep --sign "$DEVELOPER_ID" \ - --timestamp \ - --options runtime \ - --entitlements entitlements.plist \ - "${APP_PATH}" - - codesign --verify --deep --strict --verbose=2 "${APP_PATH}" - - - name: Install create-dmg - run: brew install create-dmg - - - name: Create DMG - env: - DEVELOPER_ID: ${{ secrets.MACOS_DEVELOPER_ID }} - run: | - VERSION=$(python -c "from constants import APP_VERSION; print(APP_VERSION)") - DMG_NAME="SuperPicky_v${VERSION}_arm64_${COMMIT_HASH}.dmg" - DMG_TEMP="dist/dmg_temp" - - rm -rf "$DMG_TEMP" - mkdir -p "$DMG_TEMP" - - cp -R "${APP_PATH}" "${DMG_TEMP}/" - - if [ -d "SuperBirdIDPlugin.lrplugin" ]; then - cp -R "SuperBirdIDPlugin.lrplugin" "${DMG_TEMP}/" - fi - - # README 安装说明 - if [ -f "resources/DMG_README.txt" ]; then - cp "resources/DMG_README.txt" "${DMG_TEMP}/README.txt" - fi - - # 使用 create-dmg 生成带 Applications 图标的标准安装 DMG - create-dmg \ - --volname "SuperPicky ${VERSION}" \ - --window-pos 200 120 \ - --window-size 580 380 \ - --icon-size 100 \ - --icon "SuperPicky.app" 140 180 \ - --hide-extension "SuperPicky.app" \ - --app-drop-link 440 180 \ - --no-internet-enable \ - "dist/${DMG_NAME}" \ - "$DMG_TEMP" - - rm -rf "$DMG_TEMP" - - codesign --force --sign "$DEVELOPER_ID" --timestamp "dist/${DMG_NAME}" - - echo "DMG_PATH=dist/${DMG_NAME}" >> "$GITHUB_ENV" - - - name: Notarize and staple DMG - env: - APPLE_ID: ${{ secrets.APPLE_ID }} + RELEASE_TAG: ${{ steps.release_meta.outputs.tag }} + MACOS_CERTIFICATE_PWD: ${{ secrets.MACOS_CERTIFICATE_PWD }} APPLE_APP_PASSWORD: ${{ secrets.APPLE_APP_PASSWORD }} - TEAM_ID: ${{ secrets.MACOS_TEAM_ID }} run: | - NOTARIZE_OUTPUT=$(xcrun notarytool submit "${DMG_PATH}" \ - --apple-id "$APPLE_ID" \ - --password "$APPLE_APP_PASSWORD" \ - --team-id "$TEAM_ID" \ - --wait \ - --output-format json 2>&1) - - echo "$NOTARIZE_OUTPUT" - - if ! echo "$NOTARIZE_OUTPUT" | grep -Eq '"status"[[:space:]]*:[[:space:]]*"Accepted"'; then - echo "Notarization failed." >&2 - exit 1 - fi - - xcrun stapler staple "${DMG_PATH}" - xcrun stapler validate "${DMG_PATH}" + python build_release_mac.py --build-type full --arch arm64 --copy-dir output/mac --sign-p12 $RUNNER_TEMP/certificate.p12 --sign-p12-password-env MACOS_CERTIFICATE_PWD --notarize --apple-id ${{ secrets.APPLE_ID }} --team-id ${{ secrets.MACOS_TEAM_ID }} - name: Build code patch zip - env: - TAG: ${{ steps.release_meta.outputs.tag }} - run: | - VERSION=$(python3 -c "from constants import APP_VERSION; print(APP_VERSION)") - PATCH_VERSION="${TAG}" - ZIP_NAME="code_patch_${PATCH_VERSION}.zip" - - # 与 Windows CI 保持一致:包含所有运行时 Python 文件 - # 排除 main.py(PyInstaller 入口,patch 无效) - # build_info.py 必须包含:CI 已注入 RELEASE_CHANNEL,patch 覆盖后不影响正确性 - # 若排除则 code_updates/core/ 遮蔽冻结包导致 ModuleNotFoundError - PATCH_ITEMS=( - constants.py advanced_config.py ai_model.py - birdid_server.py birdid_cli.py iqa_scorer.py - post_adjustment_engine.py server_manager.py - superpicky_cli.py topiq_model.py - tools/ core/ ui/ birdid/ locales/ - ) - EXISTING_ITEMS=() - for item in "${PATCH_ITEMS[@]}"; do - [ -e "$item" ] && EXISTING_ITEMS+=("$item") - done - - zip -r "dist/${ZIP_NAME}" "${EXISTING_ITEMS[@]}" \ - --exclude "**/__pycache__/*" --exclude "**/*.pyc" --exclude "**/*.pyo" - - APPLIED_AT=$(python3 -c "import datetime; print(datetime.datetime.now(datetime.timezone.utc).isoformat())") - printf '{"patch_version":"%s","base_version":"%s","applied_at":"%s"}\n' \ - "$PATCH_VERSION" "$VERSION" "$APPLIED_AT" > dist/patch_meta.json - - echo "PATCH_ZIP=dist/${ZIP_NAME}" >> "$GITHUB_ENV" - echo "PATCH_META=dist/patch_meta.json" >> "$GITHUB_ENV" - - # - name: Push patch files to mirror server - # continue-on-error: true - # env: - # MIRROR_SSH_KEY: ${{ secrets.MIRROR_SSH_KEY }} - # TAG: ${{ steps.release_meta.outputs.tag }} - # run: | - # mkdir -p ~/.ssh - # echo "$MIRROR_SSH_KEY" > ~/.ssh/mirror_key - # chmod 600 ~/.ssh/mirror_key - # VERSION=$(python3 -c "from constants import APP_VERSION; print(APP_VERSION)") - # PATCH_VERSION="${TAG}" - # MIRROR_BASE="http://1.119.150.179:59080/superpicky" - # python3 -c "import json,datetime,sys; d={'tag_name':sys.argv[1],'version':sys.argv[2],'published_at':datetime.datetime.now(datetime.timezone.utc).isoformat(),'patch_meta_url':sys.argv[3]+'/patch_meta.json','patch_zip_url':sys.argv[3]+'/code_patch_'+sys.argv[4]+'.zip'}; json.dump(d,open('dist/latest.json','w'),indent=2)" "$TAG" "$VERSION" "$MIRROR_BASE" "$PATCH_VERSION" - # ssh -i ~/.ssh/mirror_key -p 22 \ - # -o StrictHostKeyChecking=no \ - # -o ConnectTimeout=10 \ - # jordan@1.119.150.179 \ - # "mkdir -p /opt/1panel/www/sites/github/index/superpicky" - # scp -i ~/.ssh/mirror_key -P 22 \ - # -o StrictHostKeyChecking=no \ - # -o ConnectTimeout=10 \ - # dist/latest.json \ - # "${PATCH_META}" \ - # "${PATCH_ZIP}" \ - # jordan@1.119.150.179:/opt/1panel/www/sites/github/index/superpicky/ - # rm -f ~/.ssh/mirror_key + run: python scripts/ci_release.py build-patch --output-dir output/mac --patch-version ${{ steps.release_meta.outputs.tag }} - name: Clean up if: always() - run: | - rm -f _telemetry_build.py - security delete-keychain "$KEYCHAIN_PATH" 2>/dev/null || true + run: python scripts/ci_release.py cleanup-paths --path _telemetry_build.py --path $RUNNER_TEMP/certificate.p12 - name: Upload to GitHub Release - uses: softprops/action-gh-release@v2 + uses: softprops/action-gh-release@v3 with: tag_name: ${{ steps.release_meta.outputs.tag }} name: ${{ steps.release_meta.outputs.name }} body_path: ChangeLog.md + fail_on_unmatched_files: true files: | - ${{ env.DMG_PATH }} - ${{ env.PATCH_ZIP }} - ${{ env.PATCH_META }} - - # - name: Upload patch files to GitCode - # continue-on-error: true - # env: - # GITCODE_TOKEN: ${{ secrets.GITCODE_TOKEN }} - # TAG: ${{ steps.release_meta.outputs.tag }} - # run: | - # BASE_URL="https://gitcode.com/api/v4/projects/Jamesphotography%2FSuperPicky" - # FILE_BASE="https://gitcode.com/Jamesphotography/SuperPicky/-/package_files/generic/release" - # PATCH_ZIP_NAME=$(basename "${PATCH_ZIP}") - # for FILE in dist/latest.json "${PATCH_META}" "${PATCH_ZIP}"; do - # FILENAME=$(basename "$FILE") - # curl -s --fail -T "$FILE" \ - # --header "PRIVATE-TOKEN: $GITCODE_TOKEN" \ - # "$BASE_URL/packages/generic/release/$TAG/$FILENAME" || true - # done - # curl -s \ - # --header "PRIVATE-TOKEN: $GITCODE_TOKEN" \ - # --header "Content-Type: application/json" \ - # --request POST \ - # --data "{\"tag_name\":\"$TAG\",\"name\":\"SuperPicky $TAG\",\"description\":\"See GitHub release for details.\"}" \ - # "$BASE_URL/releases" || true - # curl -s \ - # --header "PRIVATE-TOKEN: $GITCODE_TOKEN" \ - # --header "Content-Type: application/json" \ - # --request POST \ - # --data "{\"name\":\"patch_meta.json\",\"url\":\"$FILE_BASE/$TAG/patch_meta.json\",\"link_type\":\"package\"}" \ - # "$BASE_URL/releases/$TAG/assets/links" || true - # curl -s \ - # --header "PRIVATE-TOKEN: $GITCODE_TOKEN" \ - # --header "Content-Type: application/json" \ - # --request POST \ - # --data "{\"name\":\"$PATCH_ZIP_NAME\",\"url\":\"$FILE_BASE/$TAG/$PATCH_ZIP_NAME\",\"link_type\":\"package\"}" \ - # "$BASE_URL/releases/$TAG/assets/links" || true - # curl -s \ - # --header "PRIVATE-TOKEN: $GITCODE_TOKEN" \ - # --header "Content-Type: application/json" \ - # --request POST \ - # --data "{\"name\":\"latest.json\",\"url\":\"$FILE_BASE/$TAG/latest.json\",\"link_type\":\"package\"}" \ - # "$BASE_URL/releases/$TAG/assets/links" || true + output/mac/*.dmg + output/mac/code_patch_*.zip + output/mac/patch_meta.json diff --git a/.gitignore b/.gitignore index a7aea6c..66c858b 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,7 @@ build_dist*/* dist*/* !dist/*.dmg build/ +.python-version # IDE .idea/copilot.* diff --git a/AGENTS.md b/AGENTS.md index f957deb..04064bf 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -1,22 +1,204 @@ -# AGENTS.md (Codex / OpenAI Coding Agents) +## 第一性原理 / First Principles -Follow `scripts_dev/AI_CODING_RULES.md` as the project baseline. +请使用第一性原理思考。你不能总是假设我非常清楚自己想要什么和该怎么得到。请保持审慎,从原始需求和问题出发,如果动机和目标不清晰,停下来和我讨论。 +Please use first principles thinking. You should not assume that I always know exactly what I want or how to achieve it. Be cautious and start from the original needs and problems. If the motivation and goals are unclear, stop and discuss with me. -## Mandatory Project Constraints +## 技术方案规范 / Technical Solution Specifications -- Keep files in UTF-8; avoid introducing mojibake. -- For ExifTool non-ASCII metadata writes, prefer UTF-8 temp-file redirection (`-Tag<=file`) over inline command args. -- Preserve Windows/macOS compatibility for paths and subprocess behavior. -- For SQLite in threaded code: either serialize shared-connection access with a lock or use per-thread connections; never assume `check_same_thread=False` is enough. -- Do not directly access private DB connection internals from business code (e.g., `report_db._conn.*`); add thread-safe wrapper methods instead. -- Ensure transaction handling is consistent and defensive (avoid unsynchronized mixed transaction styles; commit only when valid). -- Ensure persistent external processes (like `exiftool -stay_open`) have explicit shutdown and are closed on exit. -- For packaged-only CUDA issues, first suspect packaging/runtime differences. -- In Windows PyInstaller spec for Torch/CUDA, keep `upx=False` unless explicitly re-validated. +当需要你给出修改或者重构方案时必须符合以下规范: +The following specifications must be followed when giving modification or refactoring plans: -## Validation Minimum +* 你是技术专家,所以设计方案时要使用各种工具查询网络资料,确定基本事实,不要给出虚假观点。 + You are a technical expert, so when designing solutions, use various tools to check online resources and ensure the basic facts are correct. Do not provide false opinions. +* 除非我很确定,不然不能随意迁就我的观点,因为我的观点很可能是错的,需要基于基本事实有理有据的说服我同意你的新观点。 + Unless I am very sure, do not easily accommodate my opinions because they may be wrong. You need to convince me to agree with your new views based on facts. +* 给出兼容性或者补丁性的方案时需要给出确定性的理由与我讨论。 + When proposing compatibility or patch solutions, provide definitive reasons for discussion. +* 必须确保方案的逻辑正确,必须经过全链路的逻辑验证。 + Ensure that the solution is logically correct and has been verified across the entire system. -- Run `py -3 -m py_compile` on changed Python files. +## 编码规范 / Coding Specifications + +所有文件读写均需要满足如下规范: +All file reading and writing must meet the following specifications: + +* 使用UTF-8编码,强制所有的中文输出,均为UTF-8。 + Use UTF-8 encoding, and enforce all Chinese output to be UTF-8. +* 在PowerShell中读取含有中文的文件时,限制性** **`chcp 65001`并设置UTF-8输出。 + When reading Chinese files in PowerShell, use** **`chcp 65001` and set UTF-8 output. +* 读取时用** **`open(file, 'r', encoding='utf-8')`方式读取。 + Use** **`open(file, 'r', encoding='utf-8')` to read files. +* 不要使用shell脚本(如sed/awk)处理含中文的文件,优先使用Python(Python 3.x),如果Python环境无法满足需求,再考虑其他语言,最后才考虑PowerShell。 + Do not use shell scripts (like sed/awk) to handle files with Chinese characters. Prefer Python (Python 3.x), and if Python environment cannot meet the requirements, consider other languages, and only as a last resort consider PowerShell. + +## 代码规范 / Code Specifications + +所有代码增删查改均需要满足如下规范: +All code changes (addition, deletion, modification) must meet the following specifications: + +* 先阅读相关代码段落,预先评估代码修改量,如果发现改动文件过多,或者改动量很大,提前分成几个小部分进行修补,避免系统拒绝修补。 + First, read the relevant code sections, assess the extent of the changes, and if too many files are affected or the changes are too large, break them down into smaller parts to avoid rejection by the system. +* 代码按照逻辑顺序进行修补,避免改完之后又回头改。 + Make code changes in logical order to avoid having to go back and modify things again. +* 代码改动完毕后要重新整体阅读全链路,避免出现变量函数未定义未声明导致编译不通过。 + After code changes, review the entire system to ensure there are no undefined or undeclared variables or functions that could cause compilation errors. +* 代码优化精简的时候需要按照逻辑顺序对变量函数进行重排,方便维护者从上到下进行阅读。 + When optimizing and simplifying the code, rearrange variables and functions in logical order to make it easier for maintainers to read from top to bottom. +* 跨文件代码边界维护要清晰分明,高内聚低耦合。 + Maintain clear boundaries for cross-file code, ensuring high cohesion and low coupling. +* 在Python中,避免使用全局变量。优先选择函数或类封装,保持数据和功能分离。 + In Python, avoid using global variables. Prefer encapsulation in functions or classes to separate data and functionality. + +## 注释规范 / Commenting Specifications + +所有注释增删查改均需要满足如下规范: +All comment changes (addition, deletion, modification) must meet the following specifications: + +* 如果没有额外指定,请使用UTF-8编码的中文注释 + 相同格式的英文注释。 + If not otherwise specified, use UTF-8 encoded Chinese comments + corresponding English comments in the same format. +* 需要给出详细且必要的功能说明,增加可维护性,让不熟悉相关类型代码的人也能看懂。 + Provide detailed and necessary functional descriptions to increase maintainability, so that those unfamiliar with the relevant code can understand it. +* 使用docstring格式进行函数、类注释,确保清晰描述函数的功能、参数、返回值及可能的异常。 + Use docstring format for function and class comments, ensuring clear descriptions of the function's functionality, parameters, return values, and possible exceptions. + +```python +def example_function(param: int) -> str: + """ + 这是一个示例函数,接受一个整数作为输入,返回字符串。 + + 参数: + param (int): 输入的整数 + + 返回: + str: 返回一个简单的字符串,表示输入的平方值 + + This is a sample function that takes an integer as input and returns a string. + + Parameters: + param (int): The integer to input + + Return: + str: Returns a simple string representing the square of the input. + """ + + return f"The square is {param ** 2}" +``` + +## 总结汇报规范 / Summary Reporting Specifications + +所有的总结汇报均需要满足如下规范: +All summary reports must meet the following specifications: + +* 改动部分请加上具体文件的行号,如果涉及多个跨行的改动,给出相关段落,方便进行查找。 + Specify the line numbers of the changed parts, and provide relevant sections for easy search if multiple lines are involved. +* 对于Python项目,考虑到代码可能涉及模块导入、功能封装等,需要明确指出哪些模块或类的修改或新增影响了其他模块的功能。 + For Python projects, since the code may involve module imports and function encapsulation, clearly indicate which module or class changes or additions affect the functionality of other modules. + +## Python使用规范 / Python Usage Specifications + +在使用Python语言时均需要满足如下规范: +The following specifications must be met when using Python: + +* **类型注解 / Type Annotations** :尽量使用类型注解(Python 3.x),以增强代码可读性和静态检查工具的支持。例如,函数的输入和输出应该明确标注类型。 + **Type annotations** : Try to use type annotations (Python 3.x) to enhance code readability and static analysis tool support. For example, the input and output of functions should clearly annotate their types. + +```python + def add_numbers(a: int, b: int) -> int: + return a + b +``` + +* **避免使用过于宽泛的类型标注 / Avoid overly broad type annotations** :Python中不存在** **`any`类型,但要尽量避免过于宽泛的类型标注。 + Python does not have an** **`any` type, but avoid overly broad type annotations whenever possible. +* **操作用户文件规范 / User File Operations** :当使用代码操作用户系统中的文件时,要使用安全的方法,并注意权限。对于配置文件的存放位置应该局限在一个文件夹内,不要在用户的文件夹中到处存放零星文件。 + When manipulating user files, use secure methods and be mindful of permissions. The storage location for configuration files should be limited to a single folder, and avoid scattering files across the user's directories. +* **遵循PEP8规范 / Follow PEP8** :始终遵循Python的官方代码风格PEP8,并且使用自动化工具(如** **`black`)进行格式化。 + Always follow the official Python coding style PEP8 and use automation tools (like** **`black`) for formatting. +* **严格使用UTF-8 / Strict Use of UTF-8** :始终遵循Python的官方代码标准PEP686,始终使用 UTF-8 作为文件、标准输入输出和管道的默认编码。 + Always follow Python's official code standard PEP686, and use UTF-8 as the default encoding for files, standard input/output, and pipes. +* **注重安全性 / Focus on Security** :避免直接执行来自不可信来源的代码,如避免使用** **`eval()`或** **`exec()`等函数。使用适当的输入验证和参数化查询,避免SQL注入、XSS等安全漏洞。 + Avoid executing code from untrusted sources, such as using** **`eval()` or** **`exec()`. Use proper input validation and parameterized queries to avoid SQL injection, XSS, and other security vulnerabilities. + +```python + import sqlite3 + connection = sqlite3.connect('database.db') + cursor = connection.cursor() + + # 避免 SQL 注入,使用参数化查询 + cursor.execute("SELECT * FROM users WHERE username = ?", (username,)) +``` + +* **异常处理 / Exception Handling** :要优雅地处理可能的错误和异常,避免程序崩溃。优先使用Python标准库提供的异常机制。 + Handle potential errors and exceptions gracefully to avoid crashes. Use Python's standard exception mechanisms first. + +```python + try: + result = 10 / 0 + except ZeroDivisionError as e: + print(f"Error occurred: {e}") +``` + +## Python 3 环境配置与工具使用规范 / Python 3 Environment Setup and Tool Usage Specifications + +为了避免Python 3工具默认使用系统中的Python环境(可能导致许多不可预料的问题),请务必采用以下规范进行配置: + +* **使用虚拟环境 / Virtual Environment** :优先使用 `venv`或 `conda`等工具创建独立的Python环境,避免使用系统全局环境。 + Prefer using** **`venv` or** **`conda` to create isolated Python environments, avoiding the use of the system's global environment. +* **确保包管理一致性 / Ensure Package Management Consistency** :在项目中使用 `pip`来管理依赖,确保依赖版本的一致性,避免版本冲突和意外问题。 + Use** **`pip` to manage dependencies in the project, ensuring version consistency and avoiding conflicts and unexpected issues. +* **工具使用推荐 / Recommended Tool Usage** :为了避免依赖于系统环境的Python,建议使用虚拟环境中的解释器进行构建和运行。 + To avoid relying on the system environment's Python, it is recommended to use the interpreter in the virtual environment for builds and executions. + +## 多系统规范 / Multi-System Specifications + +### 1. 避免多系统之间的差异导致程序出现无法运行甚至安全漏洞 / Avoid System-Specific Differences Leading to Errors or Security Vulnerabilities + +- 在开发跨平台应用时,需避免代码中因操作系统差异(如Windows与Linux、macOS之间的差异)导致程序无法运行或出现安全漏洞。 + When developing cross-platform applications, avoid code differences that cause errors or security vulnerabilities due to differences between operating systems (e.g., Windows vs. Linux or macOS). +- **路径问题**:文件路径的格式在不同操作系统间有所不同。确保使用跨平台兼容的路径分隔符,推荐使用Python的 `os.path`模块,或 `pathlib`模块来自动处理路径分隔符。 + **Path Issues**: File path formats differ across operating systems. Ensure the use of cross-platform compatible path separators. It is recommended to use Python's `os.path` or `pathlib` modules to automatically handle path separators. + + ``` + from pathlib import Path + + file_path = Path("some_folder") / "file.txt" # This works across all OS + ``` +- **换行符问题**:Windows和类Unix系统的换行符不同。 + **Line Endings**: Line endings differ between Windows and Unix-based systems. + +### 2. 不同系统的文件存储策略和文件夹权限管理不同,需要提前预防 / Different Systems Have Different File Storage and Folder Permissions + +- 在设计涉及文件存储和访问的应用时,需注意不同操作系统对文件权限和路径访问的管理差异。Windows、Linux和macOS在文件权限、符号链接和隐藏文件的处理上有所不同。 + When designing applications that involve file storage and access, be aware of the differences in file permission and path access management across operating systems. Windows, Linux, and macOS handle file permissions, symlinks, and hidden files differently. +- **权限问题**:Linux和macOS有严格的文件权限控制,而Windows则使用ACL(访问控制列表)来管理权限。确保文件的读写权限适合所使用的操作系统,并且文件夹权限应在应用设计时进行适当配置。 + **Permission Issues**: Linux and macOS have strict file permission controls, while Windows uses ACLs (Access Control Lists) for permission management. Ensure that file read/write permissions are suitable for the operating system in use, and folder permissions should be appropriately configured during application design. + +### 3. 避免大量使用PowerShell代码 / Avoid Excessive Use of PowerShell Code + +- PowerShell主要是Windows环境下使用的脚本语言,避免在跨平台项目中广泛使用PowerShell。为了确保程序的兼容性,尽量使用Python脚本或其他语言。 + PowerShell is primarily used in Windows environments. Avoid using PowerShell extensively in cross-platform projects. To ensure compatibility, try to use Python scripts or other languages instead. +- 如果必须使用PowerShell,请确保通过条件语句检查操作系统类型,并仅在Windows系统中执行相关命令。 + If PowerShell must be used, ensure that conditional statements are used to check the operating system and only execute related commands on Windows systems. + + ``` + import platform + + if platform.system() == "Windows": + # Execute PowerShell command + pass + ``` + +## Always Enforce + +- UTF-8 safety first; do not introduce Chinese text corruption. +- ExifTool Chinese metadata writes must use UTF-8 temp files (`-XMP:Title<=tmp.txt`) instead of inline CLI values. +- Keep changes cross-platform (Windows + macOS). +- Any persistent external process must have deterministic cleanup on task/app exit. +- Packaged CUDA failures: prioritize packaging/runtime diagnosis before algorithm refactors. +- Keep Windows Torch/CUDA packaging with `upx=False` unless explicitly requested and validated. + +## Minimum Verification + +- Run `.venv*/bin/python -m py_compile` on changed Python files. - For metadata changes: write + read-back verification with Chinese sample values. - For `.spec` changes: packaged startup smoke test. - For DB/threading changes: run a small multi-thread write/read stress check and confirm no transaction-state errors. diff --git a/CLAUDE.md b/CLAUDE.md index fd3176d..896179a 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -13,7 +13,199 @@ Use `scripts_dev/AI_CODING_RULES.md` as the single source of truth for this repo ## Minimum Verification -- `py -3 -m py_compile` for changed Python modules. -- Metadata write/read-back check for non-ASCII fields. -- Packaged app smoke test when `.spec` or runtime packaging behavior changes. +- Run `.venv*/bin/python -m py_compile` on changed Python files. +- For metadata changes: write + read-back verification with Chinese sample values. +- For `.spec` changes: packaged startup smoke test. +- For DB/threading changes: run a small multi-thread write/read stress check and confirm no transaction-state errors. +## 第一性原理 / First Principles + +请使用第一性原理思考。你不能总是假设我非常清楚自己想要什么和该怎么得到。请保持审慎,从原始需求和问题出发,如果动机和目标不清晰,停下来和我讨论。 +Please use first principles thinking. You should not assume that I always know exactly what I want or how to achieve it. Be cautious and start from the original needs and problems. If the motivation and goals are unclear, stop and discuss with me. + +## 技术方案规范 / Technical Solution Specifications + +当需要你给出修改或者重构方案时必须符合以下规范: +The following specifications must be followed when giving modification or refactoring plans: + +* 你是技术专家,所以设计方案时要使用各种工具查询网络资料,确定基本事实,不要给出虚假观点。 + You are a technical expert, so when designing solutions, use various tools to check online resources and ensure the basic facts are correct. Do not provide false opinions. +* 除非我很确定,不然不能随意迁就我的观点,因为我的观点很可能是错的,需要基于基本事实有理有据的说服我同意你的新观点。 + Unless I am very sure, do not easily accommodate my opinions because they may be wrong. You need to convince me to agree with your new views based on facts. +* 给出兼容性或者补丁性的方案时需要给出确定性的理由与我讨论。 + When proposing compatibility or patch solutions, provide definitive reasons for discussion. +* 必须确保方案的逻辑正确,必须经过全链路的逻辑验证。 + Ensure that the solution is logically correct and has been verified across the entire system. + +## 编码规范 / Coding Specifications + +所有文件读写均需要满足如下规范: +All file reading and writing must meet the following specifications: + +* 使用UTF-8编码,强制所有的中文输出,均为UTF-8。 + Use UTF-8 encoding, and enforce all Chinese output to be UTF-8. +* 在PowerShell中读取含有中文的文件时,限制性** **`chcp 65001`并设置UTF-8输出。 + When reading Chinese files in PowerShell, use** **`chcp 65001` and set UTF-8 output. +* 读取时用** **`open(file, 'r', encoding='utf-8')`方式读取。 + Use** **`open(file, 'r', encoding='utf-8')` to read files. +* 不要使用shell脚本(如sed/awk)处理含中文的文件,优先使用Python(Python 3.x),如果Python环境无法满足需求,再考虑其他语言,最后才考虑PowerShell。 + Do not use shell scripts (like sed/awk) to handle files with Chinese characters. Prefer Python (Python 3.x), and if Python environment cannot meet the requirements, consider other languages, and only as a last resort consider PowerShell. + +## 代码规范 / Code Specifications + +所有代码增删查改均需要满足如下规范: +All code changes (addition, deletion, modification) must meet the following specifications: + +* 先阅读相关代码段落,预先评估代码修改量,如果发现改动文件过多,或者改动量很大,提前分成几个小部分进行修补,避免系统拒绝修补。 + First, read the relevant code sections, assess the extent of the changes, and if too many files are affected or the changes are too large, break them down into smaller parts to avoid rejection by the system. +* 代码按照逻辑顺序进行修补,避免改完之后又回头改。 + Make code changes in logical order to avoid having to go back and modify things again. +* 代码改动完毕后要重新整体阅读全链路,避免出现变量函数未定义未声明导致编译不通过。 + After code changes, review the entire system to ensure there are no undefined or undeclared variables or functions that could cause compilation errors. +* 代码优化精简的时候需要按照逻辑顺序对变量函数进行重排,方便维护者从上到下进行阅读。 + When optimizing and simplifying the code, rearrange variables and functions in logical order to make it easier for maintainers to read from top to bottom. +* 跨文件代码边界维护要清晰分明,高内聚低耦合。 + Maintain clear boundaries for cross-file code, ensuring high cohesion and low coupling. +* 在Python中,避免使用全局变量。优先选择函数或类封装,保持数据和功能分离。 + In Python, avoid using global variables. Prefer encapsulation in functions or classes to separate data and functionality. + +## 注释规范 / Commenting Specifications + +所有注释增删查改均需要满足如下规范: +All comment changes (addition, deletion, modification) must meet the following specifications: + +* 如果没有额外指定,请使用UTF-8编码的中文注释 + 相同格式的英文注释。 + If not otherwise specified, use UTF-8 encoded Chinese comments + corresponding English comments in the same format. +* 需要给出详细且必要的功能说明,增加可维护性,让不熟悉相关类型代码的人也能看懂。 + Provide detailed and necessary functional descriptions to increase maintainability, so that those unfamiliar with the relevant code can understand it. +* 使用docstring格式进行函数、类注释,确保清晰描述函数的功能、参数、返回值及可能的异常。 + Use docstring format for function and class comments, ensuring clear descriptions of the function's functionality, parameters, return values, and possible exceptions. + +```python +def example_function(param: int) -> str: + """ + 这是一个示例函数,接受一个整数作为输入,返回字符串。 + + 参数: + param (int): 输入的整数 + + 返回: + str: 返回一个简单的字符串,表示输入的平方值 + + This is a sample function that takes an integer as input and returns a string. + + Parameters: + param (int): The integer to input + + Return: + str: Returns a simple string representing the square of the input. + """ + + return f"The square is {param ** 2}" +``` + +## 总结汇报规范 / Summary Reporting Specifications + +所有的总结汇报均需要满足如下规范: +All summary reports must meet the following specifications: + +* 改动部分请加上具体文件的行号,如果涉及多个跨行的改动,给出相关段落,方便进行查找。 + Specify the line numbers of the changed parts, and provide relevant sections for easy search if multiple lines are involved. +* 对于Python项目,考虑到代码可能涉及模块导入、功能封装等,需要明确指出哪些模块或类的修改或新增影响了其他模块的功能。 + For Python projects, since the code may involve module imports and function encapsulation, clearly indicate which module or class changes or additions affect the functionality of other modules. + +## Python使用规范 / Python Usage Specifications + +在使用Python语言时均需要满足如下规范: +The following specifications must be met when using Python: + +* **类型注解 / Type Annotations** :尽量使用类型注解(Python 3.x),以增强代码可读性和静态检查工具的支持。例如,函数的输入和输出应该明确标注类型。 + **Type annotations** : Try to use type annotations (Python 3.x) to enhance code readability and static analysis tool support. For example, the input and output of functions should clearly annotate their types. + +```python + def add_numbers(a: int, b: int) -> int: + return a + b +``` + +* **避免使用过于宽泛的类型标注 / Avoid overly broad type annotations** :Python中不存在** **`any`类型,但要尽量避免过于宽泛的类型标注。 + Python does not have an** **`any` type, but avoid overly broad type annotations whenever possible. +* **操作用户文件规范 / User File Operations** :当使用代码操作用户系统中的文件时,要使用安全的方法,并注意权限。对于配置文件的存放位置应该局限在一个文件夹内,不要在用户的文件夹中到处存放零星文件。 + When manipulating user files, use secure methods and be mindful of permissions. The storage location for configuration files should be limited to a single folder, and avoid scattering files across the user's directories. +* **遵循PEP8规范 / Follow PEP8** :始终遵循Python的官方代码风格PEP8,并且使用自动化工具(如** **`black`)进行格式化。 + Always follow the official Python coding style PEP8 and use automation tools (like** **`black`) for formatting. +* **严格使用UTF-8 / Strict Use of UTF-8** :始终遵循Python的官方代码标准PEP686,始终使用 UTF-8 作为文件、标准输入输出和管道的默认编码。 + Always follow Python's official code standard PEP686, and use UTF-8 as the default encoding for files, standard input/output, and pipes. +* **注重安全性 / Focus on Security** :避免直接执行来自不可信来源的代码,如避免使用** **`eval()`或** **`exec()`等函数。使用适当的输入验证和参数化查询,避免SQL注入、XSS等安全漏洞。 + Avoid executing code from untrusted sources, such as using** **`eval()` or** **`exec()`. Use proper input validation and parameterized queries to avoid SQL injection, XSS, and other security vulnerabilities. + +```python + import sqlite3 + connection = sqlite3.connect('database.db') + cursor = connection.cursor() + + # 避免 SQL 注入,使用参数化查询 + cursor.execute("SELECT * FROM users WHERE username = ?", (username,)) +``` + +* **异常处理 / Exception Handling** :要优雅地处理可能的错误和异常,避免程序崩溃。优先使用Python标准库提供的异常机制。 + Handle potential errors and exceptions gracefully to avoid crashes. Use Python's standard exception mechanisms first. + +```python + try: + result = 10 / 0 + except ZeroDivisionError as e: + print(f"Error occurred: {e}") +``` + +## Python 3 环境配置与工具使用规范 / Python 3 Environment Setup and Tool Usage Specifications + +为了避免Python 3工具默认使用系统中的Python环境(可能导致许多不可预料的问题),请务必采用以下规范进行配置: + +* **使用虚拟环境 / Virtual Environment** :优先使用 `venv`或 `conda`等工具创建独立的Python环境,避免使用系统全局环境。 + Prefer using** **`venv` or** **`conda` to create isolated Python environments, avoiding the use of the system's global environment. +* **确保包管理一致性 / Ensure Package Management Consistency** :在项目中使用 `pip`来管理依赖,确保依赖版本的一致性,避免版本冲突和意外问题。 + Use** **`pip` to manage dependencies in the project, ensuring version consistency and avoiding conflicts and unexpected issues. +* **工具使用推荐 / Recommended Tool Usage** :为了避免依赖于系统环境的Python,建议使用虚拟环境中的解释器进行构建和运行。 + To avoid relying on the system environment's Python, it is recommended to use the interpreter in the virtual environment for builds and executions. + +## 多系统规范 / Multi-System Specifications + +### 1. 避免多系统之间的差异导致程序出现无法运行甚至安全漏洞 / Avoid System-Specific Differences Leading to Errors or Security Vulnerabilities + +- 在开发跨平台应用时,需避免代码中因操作系统差异(如Windows与Linux、macOS之间的差异)导致程序无法运行或出现安全漏洞。 + When developing cross-platform applications, avoid code differences that cause errors or security vulnerabilities due to differences between operating systems (e.g., Windows vs. Linux or macOS). + +- **路径问题**:文件路径的格式在不同操作系统间有所不同。确保使用跨平台兼容的路径分隔符,推荐使用Python的 `os.path`模块,或 `pathlib`模块来自动处理路径分隔符。 + **Path Issues**: File path formats differ across operating systems. Ensure the use of cross-platform compatible path separators. It is recommended to use Python's `os.path` or `pathlib` modules to automatically handle path separators. + + ``` + from pathlib import Path + + file_path = Path("some_folder") / "file.txt" # This works across all OS + ``` + +- **换行符问题**:Windows和类Unix系统的换行符不同。 + **Line Endings**: Line endings differ between Windows and Unix-based systems. + +### 2. 不同系统的文件存储策略和文件夹权限管理不同,需要提前预防 / Different Systems Have Different File Storage and Folder Permissions + +- 在设计涉及文件存储和访问的应用时,需注意不同操作系统对文件权限和路径访问的管理差异。Windows、Linux和macOS在文件权限、符号链接和隐藏文件的处理上有所不同。 + When designing applications that involve file storage and access, be aware of the differences in file permission and path access management across operating systems. Windows, Linux, and macOS handle file permissions, symlinks, and hidden files differently. +- **权限问题**:Linux和macOS有严格的文件权限控制,而Windows则使用ACL(访问控制列表)来管理权限。确保文件的读写权限适合所使用的操作系统,并且文件夹权限应在应用设计时进行适当配置。 + **Permission Issues**: Linux and macOS have strict file permission controls, while Windows uses ACLs (Access Control Lists) for permission management. Ensure that file read/write permissions are suitable for the operating system in use, and folder permissions should be appropriately configured during application design. + +### 3. 避免大量使用PowerShell代码 / Avoid Excessive Use of PowerShell Code + +- PowerShell主要是Windows环境下使用的脚本语言,避免在跨平台项目中广泛使用PowerShell。为了确保程序的兼容性,尽量使用Python脚本或其他语言。 + PowerShell is primarily used in Windows environments. Avoid using PowerShell extensively in cross-platform projects. To ensure compatibility, try to use Python scripts or other languages instead. + +- 如果必须使用PowerShell,请确保通过条件语句检查操作系统类型,并仅在Windows系统中执行相关命令。 + If PowerShell must be used, ensure that conditional statements are used to check the operating system and only execute related commands on Windows systems. + + ``` + import platform + + if platform.system() == "Windows": + # Execute PowerShell command + pass + ``` diff --git a/SuperPicky.spec b/SuperPicky.spec index 2c362f2..0f24fbe 100644 --- a/SuperPicky.spec +++ b/SuperPicky.spec @@ -5,6 +5,14 @@ import sys sys.path.append(os.path.abspath('.')) from constants import APP_VERSION + +def _env_or_none(name): + value = os.environ.get(name, "").strip() + return value or None + + +APP_VERSION = os.environ.get("SUPERPICKY_APP_VERSION", APP_VERSION) + # 获取当前工作目录 base_path = os.path.abspath('.') @@ -161,9 +169,9 @@ exe = EXE( console=False, disable_windowed_traceback=False, argv_emulation=False, - target_arch=None, - codesign_identity=None, - entitlements_file=None, + target_arch=_env_or_none("SUPERPICKY_TARGET_ARCH"), + codesign_identity=_env_or_none("SUPERPICKY_CODESIGN_IDENTITY"), + entitlements_file=_env_or_none("SUPERPICKY_ENTITLEMENTS_FILE"), icon=os.path.join(base_path, 'img', 'SuperPicky-V0.02.icns') if os.path.exists(os.path.join(base_path, 'img', 'SuperPicky-V0.02.icns')) else None, ) diff --git a/SuperPicky_full.spec b/SuperPicky_full.spec new file mode 100644 index 0000000..5f67072 --- /dev/null +++ b/SuperPicky_full.spec @@ -0,0 +1,5 @@ +# Full-package compatibility wrapper. +# We intentionally keep the legacy full build spec unchanged and expose a new +# explicit entrypoint so release automation can choose between full/lite builds. +exec(open("SuperPicky.spec", "r", encoding="utf-8").read()) + diff --git a/SuperPicky_lite.spec b/SuperPicky_lite.spec new file mode 100644 index 0000000..05f08ca --- /dev/null +++ b/SuperPicky_lite.spec @@ -0,0 +1,194 @@ +import os +import site +from PyInstaller.utils.hooks import collect_data_files, copy_metadata +import sys +sys.path.append(os.path.abspath('.')) +from constants import APP_VERSION + + +def _env_or_none(name): + value = os.environ.get(name, "").strip() + return value or None + + +def _optional_copy_metadata(package_name): + try: + return copy_metadata(package_name) + except Exception: + return [] + + +APP_VERSION = os.environ.get("SUPERPICKY_APP_VERSION", APP_VERSION) + +base_path = os.path.abspath('.') +sp = [p for p in site.getsitepackages() if os.path.isdir(p)] +site_packages = sp[0] if sp else site.getusersitepackages() + +ultralytics_base = site_packages +if not os.path.exists(os.path.join(ultralytics_base, 'ultralytics')): + try: + import ultralytics + ultralytics_base = os.path.dirname(os.path.dirname(ultralytics.__file__)) + except ImportError: + pass + +ultralytics_datas = collect_data_files('ultralytics') +imageio_datas = collect_data_files('imageio') +rawpy_datas = collect_data_files('rawpy') +pillow_heif_datas = collect_data_files('pillow_heif') + +all_datas = [ + # Lite keeps the AI runtime bundled for startup stability while still + # allowing the first-run flow to fetch missing resources on demand. + (os.path.join(base_path, 'exiftools_mac'), 'exiftools_mac'), + (os.path.join(base_path, 'img'), 'img'), + (os.path.join(base_path, 'locales'), 'locales'), + (os.path.join(base_path, 'locales', 'en.lproj'), 'en.lproj'), + (os.path.join(base_path, 'locales', 'zh-Hans.lproj'), 'zh-Hans.lproj'), + (os.path.join(base_path, 'models', 'yolo11l-seg.pt'), 'models'), + (os.path.join(base_path, 'birdid', 'data', 'bird_reference.sqlite'), 'birdid/data'), + (os.path.join(base_path, 'birdid', 'data', 'ebird_classid_mapping.json'), 'birdid/data'), + (os.path.join(base_path, 'birdid', 'data', 'ebird_regions.json'), 'birdid/data'), + (os.path.join(base_path, 'birdid', 'data', 'offline_ebird_data'), 'birdid/data/offline_ebird_data'), + (os.path.join(ultralytics_base, 'ultralytics/cfg'), 'ultralytics/cfg'), + (os.path.join(base_path, 'SuperBirdIDPlugin.lrplugin'), 'SuperBirdIDPlugin.lrplugin'), + (os.path.join(base_path, 'ioc'), 'ioc'), +] + +all_datas.extend(ultralytics_datas) +all_datas.extend(imageio_datas) +all_datas.extend(rawpy_datas) +all_datas.extend(pillow_heif_datas) +all_datas.extend(_optional_copy_metadata('imageio')) +all_datas.extend(_optional_copy_metadata('rawpy')) +all_datas.extend(_optional_copy_metadata('ultralytics')) +all_datas.extend(_optional_copy_metadata('pillow_heif')) +all_datas.extend(_optional_copy_metadata('pi_heif')) + +app_hiddenimports = [ + 'ultralytics', + 'torch', + 'torchvision', + 'torchvision.models', + 'torchvision.transforms', + 'torchvision.transforms.functional', + 'torchaudio', + 'PIL', + 'cv2', + 'numpy', + 'yaml', + 'matplotlib', + 'matplotlib.pyplot', + 'matplotlib.backends.backend_agg', + 'PySide6', + 'PySide6.QtCore', + 'PySide6.QtGui', + 'PySide6.QtWidgets', + 'timm', + 'timm.models', + 'timm.models.resnet', + 'imageio', + 'rawpy', + 'imagehash', + 'pywt', + 'pillow_heif', + 'pi_heif', + 'core', + 'core.burst_detector', + 'core.config_manager', + 'core.exposure_detector', + 'core.file_manager', + 'core.flight_detector', + 'core.focus_point_detector', + 'core.initialization_manager', + 'core.keypoint_detector', + 'core.photo_processor', + 'core.rating_engine', + 'core.source_probe', + 'core.stats_formatter', + 'multiprocessing', + 'multiprocessing.spawn', + 'tools.update_checker', + 'packaging', + 'packaging.version', + 'birdid', + 'birdid.bird_identifier', + 'birdid.ebird_country_filter', + 'birdid_server', + 'server_manager', + 'flask', + 'flask.json', + 'cryptography', + 'cryptography.fernet', + '_telemetry_build', + 'app_user_stat._telemetry_build', + 'app_user_stat', + 'app_user_stat.telemetry', + 'app_user_stat.consent_texts', + 'app_user_stat.consent_texts.en_US', + 'app_user_stat.consent_texts.zh_CN', +] + +a = Analysis( + ['main.py'], + pathex=[base_path], + binaries=[], + datas=all_datas, + hiddenimports=app_hiddenimports, + hookspath=[], + hooksconfig={}, + runtime_hooks=['pyi_rth_cv2.py'] if os.path.exists('pyi_rth_cv2.py') else [], + excludes=[ + 'PyQt5', 'PyQt6', 'tkinter', + 'polars', 'numba', 'llvmlite', 'pyarrow', 'facexlib', 'datasets', + ], + noarchive=False, + optimize=0, +) + +pyz = PYZ(a.pure) + +exe = EXE( + pyz, + a.scripts, + [], + exclude_binaries=True, + name='SuperPickyLite', + debug=False, + bootloader_ignore_signals=False, + strip=False, + upx=False, + console=False, + disable_windowed_traceback=False, + argv_emulation=False, + target_arch=_env_or_none("SUPERPICKY_TARGET_ARCH"), + codesign_identity=_env_or_none("SUPERPICKY_CODESIGN_IDENTITY"), + entitlements_file=_env_or_none("SUPERPICKY_ENTITLEMENTS_FILE"), + icon=os.path.join(base_path, 'img', 'SuperPicky-V0.02.icns') if os.path.exists(os.path.join(base_path, 'img', 'SuperPicky-V0.02.icns')) else None, +) + +coll = COLLECT( + exe, + a.binaries, + a.datas, + strip=False, + upx=False, + upx_exclude=[], + name='SuperPickyLite', +) + +app = BUNDLE( + coll, + name='SuperPickyLite.app', + icon=os.path.join(base_path, 'img', 'SuperPicky-V0.02.icns') if os.path.exists(os.path.join(base_path, 'img', 'SuperPicky-V0.02.icns')) else None, + bundle_identifier='com.jamesphotography.superpicky.lite', + info_plist={ + 'CFBundleName': 'SuperPickyLite', + 'CFBundleDisplayName': 'SuperPickyLite', + 'CFBundleVersion': APP_VERSION, + 'CFBundleShortVersionString': APP_VERSION, + 'NSHighResolutionCapable': True, + 'NSAppleEventsUsageDescription': '慧眼选鸟需要发送 AppleEvents 与其他应用通信。', + 'NSAppleScriptEnabled': False, + }, +) diff --git a/SuperPicky_lite_win.spec b/SuperPicky_lite_win.spec new file mode 100644 index 0000000..3cdfe74 --- /dev/null +++ b/SuperPicky_lite_win.spec @@ -0,0 +1,223 @@ +import os +import site +import sys +from PyInstaller.utils.hooks import collect_data_files, copy_metadata + +sys.path.append(os.path.abspath('.')) + +base_path = os.path.abspath('.') + + +def _optional_copy_metadata(package_name): + try: + return copy_metadata(package_name) + except Exception: + return [] + + +sp = [p for p in site.getsitepackages() if os.path.isdir(p)] +site_packages = sp[0] if sp else site.getusersitepackages() + +ultralytics_base = site_packages +if not os.path.exists(os.path.join(ultralytics_base, 'ultralytics')): + try: + import ultralytics + ultralytics_base = os.path.dirname(os.path.dirname(ultralytics.__file__)) + except ImportError: + pass + +ultralytics_datas = collect_data_files('ultralytics') +imageio_datas = collect_data_files('imageio') +rawpy_datas = collect_data_files('rawpy') +pillow_heif_datas = collect_data_files('pillow_heif') + +all_datas = [ + (os.path.join(base_path, 'exiftools_win'), 'exiftools_win'), + (os.path.join(base_path, 'img'), 'img'), + (os.path.join(base_path, 'locales'), 'locales'), + (os.path.join(base_path, 'ioc'), 'ioc'), + (os.path.join(base_path, 'models', 'yolo11l-seg.pt'), 'models'), + (os.path.join(base_path, 'birdid', 'data', 'bird_reference.sqlite'), 'birdid/data'), + (os.path.join(base_path, 'birdid', 'data', 'ebird_classid_mapping.json'), 'birdid/data'), + (os.path.join(base_path, 'birdid', 'data', 'ebird_regions.json'), 'birdid/data'), + (os.path.join(base_path, 'birdid', 'data', 'offline_ebird_data'), 'birdid/data/offline_ebird_data'), + (os.path.join(base_path, 'SuperBirdIDPlugin.lrplugin'), 'SuperBirdIDPlugin.lrplugin'), + (os.path.join(base_path, 'requirements_base.txt'), '.'), + (os.path.join(base_path, 'core', 'runtime_requirements.py'), 'core'), + (os.path.join(ultralytics_base, 'ultralytics', 'cfg'), 'ultralytics/cfg'), +] + +all_datas.extend(ultralytics_datas) +all_datas.extend(imageio_datas) +all_datas.extend(rawpy_datas) +all_datas.extend(pillow_heif_datas) +all_datas.extend(_optional_copy_metadata('imageio')) +all_datas.extend(_optional_copy_metadata('rawpy')) +all_datas.extend(_optional_copy_metadata('ultralytics')) +all_datas.extend(_optional_copy_metadata('pillow_heif')) +all_datas.extend(_optional_copy_metadata('pi_heif')) + +# Windows Lite 冻结包会在主程序最早期进入 `--runtime-bootstrap` 路径, +# 由打包后的可执行文件自身安装并校验 Torch 运行时。 +# PyInstaller 对这条链路里的标准库模块并不总能静态识别, +# 所以需要在这里集中声明,避免后续再零散追加到主 hiddenimports 列表。 +# The Windows Lite frozen build enters `--runtime-bootstrap` very early and +# installs/verifies the Torch runtime from the packaged executable itself. +# PyInstaller does not always detect the stdlib modules used along that path, +# so keep them centralized here instead of appending ad hoc entries later. +runtime_bootstrap_stdlib_hiddenimports = [ + 'argparse', + 'ast', + 'base64', + 'bisect', + 'cProfile', + 'concurrent', + 'copy', + 'csv', + 'ctypes', + 'dataclasses', + 'datetime', + 'difflib', + 'dis', + 'enum', + 'faulthandler', + 'fnmatch', + 'gc', + 'getpass', + 'glob', + 'gzip', + 'hashlib', + 'heapq', + 'inspect', + 'ipaddress', + 'linecache', + 'locale', + 'modulefinder', + 'numbers', + 'pickletools', + 'profile', + 'pprint', + 'pstats', + 'queue', + 'resource', + 'runpy', + 'shlex', + 'signal', + 'sqlite3', + 'statistics', + 'sysconfig', + 'tarfile', + 'timeit', + 'tokenize', + 'traceback', + 'unittest', + 'uuid', + 'weakref', + 'xml', + 'zipfile', +] + +app_hiddenimports = [ + 'ultralytics', + 'PIL', + 'cv2', + 'numpy', + 'yaml', + 'PySide6', + 'PySide6.QtCore', + 'PySide6.QtGui', + 'PySide6.QtWidgets', + 'imageio', + 'rawpy', + 'imagehash', + 'pywt', + 'pillow_heif', + 'core', + 'core.burst_detector', + 'core.config_manager', + 'core.exposure_detector', + 'core.file_manager', + 'core.flight_detector', + 'core.focus_point_detector', + 'core.initialization_manager', + 'core.keypoint_detector', + 'core.photo_processor', + 'core.rating_engine', + 'core.runtime_bootstrap', + 'core.source_probe', + 'core.stats_formatter', + 'multiprocessing', + 'multiprocessing.spawn', + 'tools.update_checker', + 'packaging', + 'packaging.version', + 'birdid', + 'birdid.bird_identifier', + 'birdid.ebird_country_filter', + 'birdid_server', + 'server_manager', + 'flask', + 'flask.json', + 'cryptography', + 'cryptography.fernet', + '_telemetry_build', + 'app_user_stat._telemetry_build', + 'app_user_stat', + 'app_user_stat.telemetry', + 'app_user_stat.consent_texts', + 'app_user_stat.consent_texts.en_US', + 'app_user_stat.consent_texts.zh_CN', +] + +a = Analysis( + ['main.py'], + pathex=[base_path], + binaries=[], + datas=all_datas, + hiddenimports=app_hiddenimports + runtime_bootstrap_stdlib_hiddenimports, + hookspath=[], + hooksconfig={}, + runtime_hooks=['pyi_rth_cv2.py'] if os.path.exists('pyi_rth_cv2.py') else [], + excludes=[ + 'torch', 'torchvision', 'torchaudio', 'timm', + 'PyQt5', 'PyQt6', 'tkinter', + 'polars', 'numba', 'llvmlite', 'pyarrow', 'facexlib', 'datasets', + ], + noarchive=False, + optimize=0, +) + +pyz = PYZ(a.pure) + +icon_path = os.path.join(base_path, 'img', 'icon.ico') +if not os.path.exists(icon_path): + icon_path = None + +exe = EXE( + pyz, + a.scripts, + [], + exclude_binaries=True, + name='SuperPicky', + debug=False, + bootloader_ignore_signals=False, + strip=False, + upx=False, + console=False, + disable_windowed_traceback=False, + argv_emulation=False, + target_arch=None, + codesign_identity=None, + entitlements_file=None, + icon=icon_path, +) + +coll = COLLECT( + exe, + a.binaries, + a.datas, + strip=False, + upx=False, + upx_exclude=[], + name='SuperPicky', +) diff --git a/advanced_config.py b/advanced_config.py index d8223ae..a953537 100644 --- a/advanced_config.py +++ b/advanced_config.py @@ -88,7 +88,32 @@ class AdvancedConfig: # 更新提醒控制 "ignored_update_version": None, # 跳过提醒的版本号,如 "4.3.0" "include_prerelease": False, # 是否接收 Beta/RC 更新提醒 - "auto_check_updates": True, # 启动时自动检查更新(含补丁) + "auto_check_updates": False, # 启动时自动检查更新(含补丁) + + # V4.3+: 轻量底包首启初始化状态 + "initialization_completed": False, + "initialization_manifest_version": "v1", + "initialization_in_progress": False, + "last_init_exit_reason": "none", + "last_init_mode": "none", + + # V4.3+: 运行时选择与能力探测 + "selected_runtime_variant": "auto", # auto | cpu | cuda | mac + "detected_cuda_capable": False, + "runtime_install_location_preference": None, # None | default | install + "resolved_runtime_dir": None, + + # V4.3+: 首启启用的功能集与资源记录 + "enabled_feature_set": [ + "core_detection", + "quality", + "keypoint", + "flight", + "birdid", + ], + "downloaded_resources": {}, + "resolved_source_map": {}, + "last_init_error": None, # 主界面复选框状态 "flight_check": False, # 飞鸟检测默认关闭(开启后速度较慢,用户可手动开启) @@ -368,12 +393,115 @@ def set_include_prerelease(self, value: bool): @property def auto_check_updates(self) -> bool: - return self.config.get("auto_check_updates", True) + return self.config.get("auto_check_updates", False) def set_auto_check_updates(self, value: bool): """设置启动时是否自动检查更新。""" self.config["auto_check_updates"] = bool(value) + # V4.3+: 首启初始化状态 getter/setter + def _set_init_config(self, key: str, value): + self.config[key] = value + + @property + def initialization_completed(self) -> bool: + return self.config.get("initialization_completed", False) + + def set_initialization_completed(self, value: bool): + self._set_init_config("initialization_completed", bool(value)) + + @property + def initialization_manifest_version(self) -> str: + return str(self.config.get("initialization_manifest_version", "v1")) + + def set_initialization_manifest_version(self, value: str): + self._set_init_config("initialization_manifest_version", str(value or "v1")) + + @property + def initialization_in_progress(self) -> bool: + return self.config.get("initialization_in_progress", False) + + def set_initialization_in_progress(self, value: bool): + self._set_init_config("initialization_in_progress", bool(value)) + + @property + def last_init_exit_reason(self) -> str: + value = str(self.config.get("last_init_exit_reason", "none") or "none") + return value if value in ("none", "interrupted", "failed") else "none" + + def set_last_init_exit_reason(self, value: str): + normalized = value if value in ("none", "interrupted", "failed") else "none" + self._set_init_config("last_init_exit_reason", normalized) + + @property + def last_init_mode(self) -> str: + value = str(self.config.get("last_init_mode", "none") or "none") + return value if value in ("none", "init", "repair") else "none" + + def set_last_init_mode(self, value: str): + normalized = value if value in ("none", "init", "repair") else "none" + self._set_init_config("last_init_mode", normalized) + + @property + def selected_runtime_variant(self) -> str: + return str(self.config.get("selected_runtime_variant", "auto")) + + def set_selected_runtime_variant(self, value: str): + if value in ("auto", "cpu", "cuda", "mac"): + self._set_init_config("selected_runtime_variant", value) + + @property + def detected_cuda_capable(self) -> bool: + return self.config.get("detected_cuda_capable", False) + + def set_detected_cuda_capable(self, value: bool): + self._set_init_config("detected_cuda_capable", bool(value)) + + @property + def runtime_install_location_preference(self): + value = self.config.get("runtime_install_location_preference", None) + return value if value in ("default", "install", None) else None + + def set_runtime_install_location_preference(self, value): + normalized = value if value in ("default", "install") else None + self._set_init_config("runtime_install_location_preference", normalized) + + @property + def resolved_runtime_dir(self): + value = self.config.get("resolved_runtime_dir", None) + return None if value in (None, "") else str(value) + + def set_resolved_runtime_dir(self, value): + self._set_init_config("resolved_runtime_dir", None if not value else str(value)) + + @property + def enabled_feature_set(self) -> list: + return list(self.config.get("enabled_feature_set", [])) + + def set_enabled_feature_set(self, features: list): + self._set_init_config("enabled_feature_set", list(features)) + + @property + def downloaded_resources(self) -> dict: + return dict(self.config.get("downloaded_resources", {})) + + def set_downloaded_resources(self, resources: dict): + self._set_init_config("downloaded_resources", dict(resources)) + + @property + def resolved_source_map(self) -> dict: + return dict(self.config.get("resolved_source_map", {})) + + def set_resolved_source_map(self, source_map: dict): + self._set_init_config("resolved_source_map", dict(source_map)) + + @property + def last_init_error(self): + return self.config.get("last_init_error", None) + + def set_last_init_error(self, value): + self._set_init_config("last_init_error", value if value is None else str(value)) + # 主界面复选框状态 getter/setter @property def flight_check(self): diff --git a/ai_model.py b/ai_model.py index b48fb43..94d75c8 100644 --- a/ai_model.py +++ b/ai_model.py @@ -17,7 +17,9 @@ def load_yolo_model(log_callback=None): """加载 YOLO 模型(使用最佳计算设备)""" - model_path = config.ai.get_model_path() + model_path = os.path.abspath(config.ai.get_model_path()) + if not os.path.exists(model_path): + raise FileNotFoundError(f"YOLO model file not found: {model_path}") model = YOLO(str(model_path)) # 使用统一的设备检测逻辑 @@ -431,4 +433,4 @@ def detect_and_draw_birds(image_path, model, output_path, dir, ui_settings, i18n # Mask processing failed, ignore pass - return found_bird, bird_result, bird_confidence, bird_sharpness, nima_score, bird_bbox, img_dims, bird_mask, bird_count \ No newline at end of file + return found_bird, bird_result, bird_confidence, bird_sharpness, nima_score, bird_bbox, img_dims, bird_mask, bird_count diff --git a/birdid/avonet_filter.py b/birdid/avonet_filter.py index 8ed383a..0e8d41c 100644 --- a/birdid/avonet_filter.py +++ b/birdid/avonet_filter.py @@ -14,148 +14,155 @@ import os import sqlite3 from typing import Set, List, Optional, Tuple +from config import get_install_scoped_resource_path from tools.i18n import t as _t -# 区域边界定义 (south, north, west, east) -# 格式: REGION_CODE: (南纬界, 北纬界, 西经界, 东经界) REGION_BOUNDS = { - # 全球 "GLOBAL": (-90, 90, -180, 180), - - # 六大洲 (宽泛定义,用于大范围检索) - "AF": (-35, 37, -17, 51), # 非洲 (Africa) - "AS": (-10, 81, 26, 170), # 亚洲 (Asia) - "EU": (34, 71, -25, 45), # 欧洲 (Europe) - "NA": (14, 83, -168, -52), # 北美洲 (North America) - "SA": (-56, 13, -81, -34), # 南美洲 (South America) - "OC": (-47, -10, 110, 180), # 大洋洲 (Oceania) - - # 亚太地区 - 国家 - "AU": (-44, -10, 112, 155), # 澳大利亚 - "NZ": (-47.5, -34, 166, 179), # 新西兰 - "CN": (18, 54, 73, 135), # 中国 - "JP": (24, 46, 122, 154), # 日本 - "KR": (33, 43, 124, 132), # 韩国 - "TW": (21.5, 25.5, 119, 122.5), # 台湾 - "HK": (22.1, 22.6, 113.8, 114.5), # 香港 - "TH": (5.5, 20.5, 97.5, 105.5), # 泰国 - "MY": (0.5, 7.5, 99.5, 119.5), # 马来西亚 - "SG": (1.1, 1.5, 103.6, 104.1), # 新加坡 - "ID": (-11, 6, 95, 141), # 印度尼西亚 - "PH": (4.5, 21, 116, 127), # 菲律宾 - "VN": (8, 23.5, 102, 110), # 越南 - "IN": (6, 36, 68, 98), # 印度 - "LK": (5, 10, 79, 82), # 斯里兰卡 - "NP": (26, 31, 80, 88), # 尼泊尔 - "MN": (41, 52, 87, 120), # 蒙古 - "RU": (41, 82, 19, 180), # 俄罗斯 - - # 美洲 - "US": (24, 49, -125, -66), # 美国本土 - "CA": (42, 83, -141, -52), # 加拿大 - "MX": (14, 33, -118, -86), # 墨西哥 - "BR": (-34, 5.5, -74, -34), # 巴西 - "AR": (-55, -21, -73, -53), # 阿根廷 - "CL": (-56, -17, -76, -66), # 智利 - "CO": (-4.5, 13, -79, -66), # 哥伦比亚 - "PE": (-18.5, 0, -81, -68), # 秘鲁 - "EC": (-5, 2, -81, -75), # 厄瓜多尔 - "CR": (8, 11.5, -86, -82.5), # 哥斯达黎加 - - # 欧洲 - "GB": (49, 61, -8, 2), # 英国 - "FR": (41, 51.5, -5, 10), # 法国 - "DE": (47, 55.5, 5.5, 15.5), # 德国 - "ES": (35.5, 44, -10, 4.5), # 西班牙 - "IT": (36, 47.5, 6.5, 18.5), # 意大利 - "NO": (57.5, 71.5, 4.5, 31.5), # 挪威 - "SE": (55, 69.5, 10.5, 24.5), # 瑞典 - "FI": (59.5, 70.5, 19.5, 31.5), # 芬兰 - "PL": (49, 55, 14, 24.5), # 波兰 - "TR": (35.5, 42.5, 25.5, 45), # 土耳其 - "PT": (36, 42, -10, -6), # 葡萄牙 - "NL": (50, 54, 3, 8), # 荷兰 - "CH": (45, 48, 5, 11), # 瑞士 - "GR": (34, 42, 19, 29), # 希腊 - "UA": (44, 53, 22, 41), # 乌克兰 - - # 非洲 - "MG": (-26, -11, 43, 51), # 马达加斯加 - "ZA": (-35, -22, 16.5, 33), # 南非 - "KE": (-5, 5, 33.5, 42), # 肯尼亚 - "TZ": (-12, -1, 29, 41), # 坦桑尼亚 - "EG": (22, 32, 24.5, 37), # 埃及 - "MA": (27, 36, -13, -1), # 摩洛哥 - - # 澳大利亚各州 - "AU-QLD": (-29, -10, 138, 154), # Queensland - "AU-NSW": (-37.5, -28, 141, 154), # New South Wales - "AU-VIC": (-39.2, -34, 141, 150), # Victoria - "AU-TAS": (-43.7, -39.5, 143.5, 148.5), # Tasmania - "AU-SA": (-38, -26, 129, 141), # South Australia - "AU-WA": (-35, -13.5, 112.5, 129), # Western Australia - "AU-NT": (-26, -10.5, 129, 138), # Northern Territory - "AU-ACT": (-35.95,-35.1, 148.75,149.4), # Australian Capital Territory - - # 美国各州 (south, north, west, east) - "US-AL": (30, 35, -88.5, -84.9), "US-AK": (51, 72, -168, -130), - "US-AZ": (31.3, 37, -114.8,-109), "US-AR": (33, 36.5, -94.6, -89.6), - "US-CA": (32.5, 42, -124.5,-114), "US-CO": (37, 41, -109, -102), - "US-CT": (40.9, 42.1, -73.7, -71.8),"US-DE": (38.4, 39.8, -75.8, -75), - "US-FL": (24.4, 31, -87.7, -80), "US-GA": (30.4, 35, -85.6, -80.8), - "US-HI": (18.9, 22.2, -160.3,-154.8),"US-ID": (42, 49, -117.2,-111), - "US-IL": (36.9, 42.5, -91.5, -87.5),"US-IN": (37.8, 41.8, -88.1, -84.8), - "US-IA": (40.4, 43.5, -96.6, -90.1),"US-KS": (37, 40, -102.1,-94.6), - "US-KY": (36.5, 39.2, -89.6, -81.9),"US-LA": (28.9, 33.1, -94.1, -88.8), - "US-ME": (43.1, 47.5, -71.1, -66.9),"US-MD": (37.9, 39.7, -79.5, -75), - "US-MA": (41.2, 42.9, -73.5, -69.9),"US-MI": (41.7, 48.3, -90.4, -82.4), - "US-MN": (43.5, 49.4, -97.2, -89.5),"US-MS": (30, 35, -91.7, -88.1), - "US-MO": (36, 40.6, -95.8, -89.1),"US-MT": (44.4, 49, -116.1,-104), - "US-NE": (40, 43, -104.1,-95.3),"US-NV": (35, 42, -120, -114), - "US-NH": (42.7, 45.3, -72.6, -70.7),"US-NJ": (38.9, 41.4, -75.6, -73.9), - "US-NM": (31.3, 37, -109.1,-103), "US-NY": (40.5, 45.1, -79.8, -71.9), - "US-NC": (33.8, 36.6, -84.3, -75.5),"US-ND": (45.9, 49, -104.1,-96.6), - "US-OH": (38.4, 42, -84.8, -80.5),"US-OK": (33.6, 37, -103, -94.4), - "US-OR": (41.9, 46.3, -124.6,-116.5),"US-PA": (39.7, 42.3, -80.5, -74.7), - "US-RI": (41.1, 42.1, -71.9, -71.1),"US-SC": (32, 35.2, -83.4, -78.5), - "US-SD": (42.5, 45.9, -104.1,-96.4),"US-TN": (35, 36.7, -90.3, -81.6), - "US-TX": (25.8, 36.5, -106.6,-93.5),"US-UT": (37, 42, -114.1,-109), - "US-VT": (42.7, 45.1, -73.4, -71.5),"US-VA": (36.5, 39.5, -83.7, -75.2), - "US-WA": (45.5, 49, -124.8,-116.9),"US-WV": (37.2, 40.6, -82.7, -77.7), - "US-WI": (42.5, 47.1, -92.9, -86.8),"US-WY": (41, 45, -111.1,-104), - - # 中国各省(south, north, west, east) - "CN-11": (39.4, 41.1, 115.4, 117.7), # 北京 - "CN-12": (38.6, 40.3, 116.7, 118.1), # 天津 - "CN-13": (36, 42.7, 113.5, 119.8), # 河北 - "CN-14": (34.6, 40.7, 110.2, 114.6), # 山西 - "CN-15": (37.5, 53.3, 97.2, 126.1), # 内蒙古 - "CN-21": (38.7, 43.5, 118.8, 125.7), # 辽宁 - "CN-22": (41.2, 46, 121.6, 131.3), # 吉林 - "CN-23": (43.4, 53.6, 121.1, 135.1), # 黑龙江 - "CN-31": (30.7, 31.9, 120.8, 122), # 上海 - "CN-32": (30.8, 35.1, 116.4, 121.9), # 江苏 - "CN-33": (27.1, 31.2, 118.1, 122.9), # 浙江 - "CN-34": (29.4, 34.7, 114.9, 119.9), # 安徽 - "CN-35": (23.5, 28.3, 115.8, 120.7), # 福建 - "CN-36": (24.5, 30.1, 113.6, 118.5), # 江西 - "CN-37": (34.4, 38.3, 114.8, 122.7), # 山东 - "CN-41": (31.4, 36.4, 110.4, 116.7), # 河南 - "CN-42": (29.1, 33.2, 108.4, 116.1), # 湖北 - "CN-43": (24.6, 30.1, 108.8, 114.3), # 湖南 - "CN-44": (20.2, 25.5, 109.7, 117.3), # 广东 - "CN-45": (20.9, 26.4, 104.5, 112.1), # 广西 - "CN-46": (18.1, 20.2, 108.4, 111.2), # 海南 - "CN-50": (28.2, 32.2, 105.3, 110.2), # 重庆 - "CN-51": (26, 34.3, 97.4, 108.5), # 四川 - "CN-52": (24.6, 29.2, 103.6, 109.6), # 贵州 - "CN-53": (21.1, 29.3, 97.5, 106.2), # 云南 - "CN-54": (26.8, 36.5, 78.4, 99.1), # 西藏 - "CN-61": (31.7, 39.6, 105.5, 111.3), # 陕西 - "CN-62": (32.6, 42.8, 92.4, 108.7), # 甘肃 - "CN-63": (31.6, 39.2, 89.4, 103.1), # 青海 - "CN-64": (35.2, 39.4, 104.3, 107.7), # 宁夏 - "CN-65": (34.3, 49.2, 73.5, 96.4), # 新疆 + "AF": (-35, 37, -17, 51), + "AS": (-10, 81, 26, 170), + "EU": (34, 71, -25, 45), + "NA": (14, 83, -168, -52), + "SA": (-56, 13, -81, -34), + "OC": (-47, -10, 110, 180), + "AU": (-44, -10, 112, 155), + "NZ": (-47.5, -34, 166, 179), + "CN": (18, 54, 73, 135), + "JP": (24, 46, 122, 154), + "KR": (33, 43, 124, 132), + "TW": (21.5, 25.5, 119, 122.5), + "HK": (22.1, 22.6, 113.8, 114.5), + "TH": (5.5, 20.5, 97.5, 105.5), + "MY": (0.5, 7.5, 99.5, 119.5), + "SG": (1.1, 1.5, 103.6, 104.1), + "ID": (-11, 6, 95, 141), + "PH": (4.5, 21, 116, 127), + "VN": (8, 23.5, 102, 110), + "IN": (6, 36, 68, 98), + "LK": (5, 10, 79, 82), + "NP": (26, 31, 80, 88), + "MN": (41, 52, 87, 120), + "RU": (41, 82, 19, 180), + "US": (24, 49, -125, -66), + "CA": (42, 83, -141, -52), + "MX": (14, 33, -118, -86), + "BR": (-34, 5.5, -74, -34), + "AR": (-55, -21, -73, -53), + "CL": (-56, -17, -76, -66), + "CO": (-4.5, 13, -79, -66), + "PE": (-18.5, 0, -81, -68), + "EC": (-5, 2, -81, -75), + "CR": (8, 11.5, -86, -82.5), + "GB": (49, 61, -8, 2), + "FR": (41, 51.5, -5, 10), + "DE": (47, 55.5, 5.5, 15.5), + "ES": (35.5, 44, -10, 4.5), + "IT": (36, 47.5, 6.5, 18.5), + "NO": (57.5, 71.5, 4.5, 31.5), + "SE": (55, 69.5, 10.5, 24.5), + "FI": (59.5, 70.5, 19.5, 31.5), + "PL": (49, 55, 14, 24.5), + "TR": (35.5, 42.5, 25.5, 45), + "PT": (36, 42, -10, -6), + "NL": (50, 54, 3, 8), + "CH": (45, 48, 5, 11), + "GR": (34, 42, 19, 29), + "UA": (44, 53, 22, 41), + "MG": (-26, -11, 43, 51), + "ZA": (-35, -22, 16.5, 33), + "KE": (-5, 5, 33.5, 42), + "TZ": (-12, -1, 29, 41), + "EG": (22, 32, 24.5, 37), + "MA": (27, 36, -13, -1), + "AU-QLD": (-29, -10, 138, 154), + "AU-NSW": (-37.5, -28, 141, 154), + "AU-VIC": (-39.2, -34, 141, 150), + "AU-TAS": (-43.7, -39.5, 143.5, 148.5), + "AU-SA": (-38, -26, 129, 141), + "AU-WA": (-35, -13.5, 112.5, 129), + "AU-NT": (-26, -10.5, 129, 138), + "AU-ACT": (-35.95, -35.1, 148.75, 149.4), + "US-AL": (30, 35, -88.5, -84.9), + "US-AK": (51, 72, -168, -130), + "US-AZ": (31.3, 37, -114.8, -109), + "US-AR": (33, 36.5, -94.6, -89.6), + "US-CA": (32.5, 42, -124.5, -114), + "US-CO": (37, 41, -109, -102), + "US-CT": (40.9, 42.1, -73.7, -71.8), + "US-DE": (38.4, 39.8, -75.8, -75), + "US-FL": (24.4, 31, -87.7, -80), + "US-GA": (30.4, 35, -85.6, -80.8), + "US-HI": (18.9, 22.2, -160.3, -154.8), + "US-ID": (42, 49, -117.2, -111), + "US-IL": (36.9, 42.5, -91.5, -87.5), + "US-IN": (37.8, 41.8, -88.1, -84.8), + "US-IA": (40.4, 43.5, -96.6, -90.1), + "US-KS": (37, 40, -102.1, -94.6), + "US-KY": (36.5, 39.2, -89.6, -81.9), + "US-LA": (28.9, 33.1, -94.1, -88.8), + "US-ME": (43.1, 47.5, -71.1, -66.9), + "US-MD": (37.9, 39.7, -79.5, -75), + "US-MA": (41.2, 42.9, -73.5, -69.9), + "US-MI": (41.7, 48.3, -90.4, -82.4), + "US-MN": (43.5, 49.4, -97.2, -89.5), + "US-MS": (30, 35, -91.7, -88.1), + "US-MO": (36, 40.6, -95.8, -89.1), + "US-MT": (44.4, 49, -116.1, -104), + "US-NE": (40, 43, -104.1, -95.3), + "US-NV": (35, 42, -120, -114), + "US-NH": (42.7, 45.3, -72.6, -70.7), + "US-NJ": (38.9, 41.4, -75.6, -73.9), + "US-NM": (31.3, 37, -109.1, -103), + "US-NY": (40.5, 45.1, -79.8, -71.9), + "US-NC": (33.8, 36.6, -84.3, -75.5), + "US-ND": (45.9, 49, -104.1, -96.6), + "US-OH": (38.4, 42, -84.8, -80.5), + "US-OK": (33.6, 37, -103, -94.4), + "US-OR": (41.9, 46.3, -124.6, -116.5), + "US-PA": (39.7, 42.3, -80.5, -74.7), + "US-RI": (41.1, 42.1, -71.9, -71.1), + "US-SC": (32, 35.2, -83.4, -78.5), + "US-SD": (42.5, 45.9, -104.1, -96.4), + "US-TN": (35, 36.7, -90.3, -81.6), + "US-TX": (25.8, 36.5, -106.6, -93.5), + "US-UT": (37, 42, -114.1, -109), + "US-VT": (42.7, 45.1, -73.4, -71.5), + "US-VA": (36.5, 39.5, -83.7, -75.2), + "US-WA": (45.5, 49, -124.8, -116.9), + "US-WV": (37.2, 40.6, -82.7, -77.7), + "US-WI": (42.5, 47.1, -92.9, -86.8), + "US-WY": (41, 45, -111.1, -104), + "CN-11": (39.4, 41.1, 115.4, 117.7), + "CN-12": (38.6, 40.3, 116.7, 118.1), + "CN-13": (36, 42.7, 113.5, 119.8), + "CN-14": (34.6, 40.7, 110.2, 114.6), + "CN-15": (37.5, 53.3, 97.2, 126.1), + "CN-21": (38.7, 43.5, 118.8, 125.7), + "CN-22": (41.2, 46, 121.6, 131.3), + "CN-23": (43.4, 53.6, 121.1, 135.1), + "CN-31": (30.7, 31.9, 120.8, 122), + "CN-32": (30.8, 35.1, 116.4, 121.9), + "CN-33": (27.1, 31.2, 118.1, 122.9), + "CN-34": (29.4, 34.7, 114.9, 119.9), + "CN-35": (23.5, 28.3, 115.8, 120.7), + "CN-36": (24.5, 30.1, 113.6, 118.5), + "CN-37": (34.4, 38.3, 114.8, 122.7), + "CN-41": (31.4, 36.4, 110.4, 116.7), + "CN-42": (29.1, 33.2, 108.4, 116.1), + "CN-43": (24.6, 30.1, 108.8, 114.3), + "CN-44": (20.2, 25.5, 109.7, 117.3), + "CN-45": (20.9, 26.4, 104.5, 112.1), + "CN-46": (18.1, 20.2, 108.4, 111.2), + "CN-50": (28.2, 32.2, 105.3, 110.2), + "CN-51": (26, 34.3, 97.4, 108.5), + "CN-52": (24.6, 29.2, 103.6, 109.6), + "CN-53": (21.1, 29.3, 97.5, 106.2), + "CN-54": (26.8, 36.5, 78.4, 99.1), + "CN-61": (31.7, 39.6, 105.5, 111.3), + "CN-62": (32.6, 42.8, 92.4, 108.7), + "CN-63": (31.6, 39.2, 89.4, 103.1), + "CN-64": (35.2, 39.4, 104.3, 107.7), + "CN-65": (34.3, 49.2, 73.5, 96.4), } @@ -176,14 +183,12 @@ def __init__(self, db_path: Optional[str] = None): db_path: avonet.db 的路径,如果为 None 则自动定位 """ if db_path is None: - # 自动定位数据库文件 db_path = self._find_database() self.db_path = db_path self._conn: Optional[sqlite3.Connection] = None - self._ebird_cls_map: Optional[dict] = None # eBird code -> class_id(懒加载) + self._ebird_cls_map: Optional[dict] = None - # 尝试连接数据库 if self.db_path and os.path.exists(self.db_path): try: self._conn = sqlite3.connect(self.db_path, check_same_thread=False) @@ -197,17 +202,12 @@ def _find_database(self) -> Optional[str]: 自动查找 avonet.db 文件 查找顺序: - 1. birdid/data/avonet.db (相对于当前文件) - 2. data/avonet.db (相对于当前工作目录) - 3. 常见安装位置 + 1. 统一安装目录资源路径 + 2. 兼容旧开发目录 """ - # 相对于当前模块的位置 - module_dir = os.path.dirname(os.path.abspath(__file__)) possible_paths = [ - os.path.join(module_dir, "data", "avonet.db"), - os.path.join(module_dir, "..", "data", "avonet.db"), - os.path.join(os.getcwd(), "birdid", "data", "avonet.db"), - os.path.join(os.getcwd(), "data", "avonet.db"), + str(get_install_scoped_resource_path(os.path.join("birdid", "data", "avonet.db"))), + os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "avonet.db"), ] for path in possible_paths: @@ -251,7 +251,6 @@ def get_species_by_gps(self, lat: float, lon: float) -> Set[int]: return set() try: - # 查询包含该GPS点的网格中的所有物种 query = """ SELECT DISTINCT sm.cls FROM distributions d @@ -306,7 +305,6 @@ def _get_species_by_bounds( return set() try: - # 查询与边界框重叠的所有网格中的物种 query = """ SELECT DISTINCT sm.cls FROM distributions d @@ -365,8 +363,6 @@ def __del__(self): """析构时关闭连接""" self.close() - # ==================== eBird 国家级回退 ==================== - def _load_ebird_cls_map(self) -> dict: """懒加载 ebird_classid_mapping.json,返回 ebird_code -> class_id 的反向映射""" if self._ebird_cls_map is not None: @@ -381,7 +377,7 @@ def _load_ebird_cls_map(self) -> dict: try: import json with open(map_path, "r", encoding="utf-8") as f: - raw = json.load(f) # {str(class_id): ebird_code} + raw = json.load(f) self._ebird_cls_map = {v: int(k) for k, v in raw.items()} except Exception as e: print(_t("logs.avonet_classid_failed", e=e)) @@ -394,10 +390,8 @@ def _detect_country_from_gps(self, lat: float, lon: float) -> Optional[str]: 根据 GPS 坐标离线判定国家代码(仅返回国家级,不含州级)。 优先匹配面积最小的边界框,避免大国遮蔽小国。 """ - # 大陆级/全球代码,跳过 _SKIP = {"GLOBAL", "AF", "AS", "EU", "NA", "SA", "OC"} - # 收集匹配的国家及其面积 candidates = [] for code, bounds in REGION_BOUNDS.items(): if code in _SKIP: @@ -410,7 +404,6 @@ def _detect_country_from_gps(self, lat: float, lon: float) -> Optional[str]: if not candidates: return None - # 返回面积最小的匹配(最具体的) candidates.sort() return candidates[0][1] @@ -431,7 +424,6 @@ def get_species_by_country_ebird( if not country_code: return set(), None - # 加载对应国家的 eBird 物种列表 module_dir = os.path.dirname(os.path.abspath(__file__)) species_file = os.path.join( module_dir, "data", "offline_ebird_data", @@ -450,7 +442,6 @@ def get_species_by_country_ebird( print(_t("logs.avonet_read_ebird_failed", code=country_code, e=e)) return set(), None - # 转换 eBird 代码 -> class_id cls_map = self._load_ebird_cls_map() class_ids: Set[int] = set() for code in ebird_codes: @@ -481,7 +472,6 @@ def _load_ebird_file(code: str) -> Optional[List[str]]: try: with open(path, "r", encoding="utf-8") as f: data = json.load(f) - # 支持两种格式:纯 list 或 {"species": [...]} if isinstance(data, list): return data return data.get("species", []) @@ -490,11 +480,9 @@ def _load_ebird_file(code: str) -> Optional[List[str]]: return None region_code = region_code.upper() - # 尝试加载州级数据 species_codes = _load_ebird_file(region_code) actual_region = region_code - # 州级无数据则回退到国家级 if not species_codes and "-" in region_code: country = region_code.split("-")[0] species_codes = _load_ebird_file(country) @@ -511,85 +499,3 @@ def _load_ebird_file(code: str) -> Optional[List[str]]: class_ids.add(cls_id) return class_ids, actual_region - - -if __name__ == "__main__": - print("=" * 60) - print("AvonetFilter 测试") - print("=" * 60) - - # 创建过滤器实例 - af = AvonetFilter() - - # 检查数据库是否可用 - print(f"\n数据库路径: {af.db_path}") - print(f"数据库可用: {af.is_available()}") - - if not af.is_available(): - print("错误: 数据库不可用,无法继续测试") - exit(1) - - # 测试 GPS 查询 - print("\n" + "-" * 40) - print("GPS 坐标查询测试") - print("-" * 40) - - test_locations = [ - ("吉隆坡 (马来西亚)", 3.0, 101.7), - ("悉尼 (澳大利亚)", -33.9, 151.2), - ("东京 (日本)", 35.7, 139.7), - ("伦敦 (英国)", 51.5, -0.1), - ] - - for name, lat, lon in test_locations: - species = af.get_species_by_gps(lat, lon) - print(f" {name}: {len(species)} 个物种") - if species: - sample = sorted(list(species))[:5] - print(f" 样例 class_ids: {sample}") - - # 测试区域查询 - print("\n" + "-" * 40) - print("区域代码查询测试") - print("-" * 40) - - test_regions = ["AU", "AU-SA", "CN", "JP"] - - for region in test_regions: - species = af.get_species_by_region(region) - bounds = af.get_region_bounds(region) - print(f" {region}: {len(species)} 个物种") - print(f" 边界: {bounds}") - if species: - sample = sorted(list(species))[:5] - print(f" 样例 class_ids: {sample}") - - # 显示支持的区域列表 - print("\n" + "-" * 40) - print("支持的区域代码") - print("-" * 40) - - regions = af.get_supported_regions() - print(f" 共 {len(regions)} 个区域:") - - # 按类别分组显示 - global_regions = [r for r in regions if r == "GLOBAL"] - au_states = [r for r in regions if r.startswith("AU-")] - au_country = [r for r in regions if r == "AU"] - asia = [r for r in regions if r in ["CN", "JP", "KR", "TW", "TH", "MY", "SG", "ID", "PH", "VN", "IN", "NZ"]] - americas = [r for r in regions if r in ["US", "CA", "MX", "BR", "AR", "CL", "CO", "PE", "EC", "CR"]] - europe = [r for r in regions if r in ["GB", "FR", "DE", "ES", "IT", "NO", "SE", "FI", "PL", "TR"]] - africa = [r for r in regions if r in ["ZA", "KE", "TZ", "EG", "MA"]] - - print(f" 全球: {global_regions}") - print(f" 澳大利亚: {au_country + au_states}") - print(f" 亚太: {asia}") - print(f" 美洲: {americas}") - print(f" 欧洲: {europe}") - print(f" 非洲: {africa}") - - # 关闭连接 - af.close() - print("\n" + "=" * 60) - print("测试完成") - print("=" * 60) diff --git a/birdid/bird_identifier.py b/birdid/bird_identifier.py index 1a8c2e3..95c9d52 100644 --- a/birdid/bird_identifier.py +++ b/birdid/bird_identifier.py @@ -1,7 +1,11 @@ #!/usr/bin/env python3 """ -鸟类识别核心模块 -从 SuperBirdID 移植,提供鸟类检测与分类识别功能 +鸟类识别核心模块。 +Core bird-identification module. + +从 SuperBirdID 移植,负责鸟类检测、分类与离线资源路径兼容。 +Ported from SuperBirdID and responsible for bird detection, classification, +and compatibility with offline resource paths. """ __version__ = "1.0.0" @@ -15,113 +19,141 @@ import io import os import sys -from typing import Optional, List, Dict, Tuple, Set +from typing import Any, Optional, List, Dict, Tuple, Set, cast from tools.i18n import t as _t -from config import get_best_device, get_lazy_registry - -# ==================== 设备配置 ==================== +from config import ( + get_best_device, + get_lazy_registry, + get_app_config_dir, + get_install_scoped_resource_path, + get_packaged_model_relative_path, + get_runtime_meipass, +) -CLASSIFIER_DEVICE = get_best_device() +CLASSIFIER_DEVICE = torch.device(str(get_best_device())) -# ==================== 可选依赖检测 ==================== +RESAMPLING_LANCZOS = Image.Resampling.LANCZOS -# RAW格式支持 try: import rawpy import imageio + RAW_SUPPORT = True except ImportError: + rawpy = cast(Any, None) + imageio = cast(Any, None) RAW_SUPPORT = False -# YOLO检测支持 try: from ultralytics import YOLO + YOLO_AVAILABLE = True except ImportError: + YOLO = cast(Any, None) YOLO_AVAILABLE = False -# ==================== 路径配置 ==================== - -# birdid 模块目录 BIRDID_DIR = os.path.dirname(os.path.abspath(__file__)) -# 项目根目录(code_updates overlay 场景下 __file__ 指向 code_updates/birdid/,需通过 sys.path 找真实根) + + def _find_project_root() -> str: candidate = os.path.dirname(BIRDID_DIR) - if os.path.exists(os.path.join(candidate, 'models', 'model20240824.pth')): + if os.path.exists(os.path.join(candidate, "models", "model20240824.pth")): return candidate for p in sys.path: - if p and os.path.isdir(p) and os.path.exists(os.path.join(p, 'models', 'model20240824.pth')): + if ( + p + and os.path.isdir(p) + and os.path.exists(os.path.join(p, "models", "model20240824.pth")) + ): return p - return candidate # 兜底 + return candidate + def _find_birdid_dir() -> str: - if os.path.exists(os.path.join(BIRDID_DIR, 'data', 'bird_reference.sqlite')): + if os.path.exists(os.path.join(BIRDID_DIR, "data", "bird_reference.sqlite")): return BIRDID_DIR for p in sys.path: if p and os.path.isdir(p): - candidate = os.path.join(p, 'birdid') - if os.path.exists(os.path.join(candidate, 'data', 'bird_reference.sqlite')): + candidate = os.path.join(p, "birdid") + if os.path.exists(os.path.join(candidate, "data", "bird_reference.sqlite")): return candidate - return BIRDID_DIR # 兜底 + return BIRDID_DIR + PROJECT_ROOT = _find_project_root() BIRDID_DIR = _find_birdid_dir() def get_birdid_path(relative_path: str) -> str: - """获取 birdid 模块内的资源路径""" - if getattr(sys, 'frozen', False): - # PyInstaller 打包环境 - return os.path.join(sys._MEIPASS, 'birdid', relative_path) + """ + 返回 `birdid/` 目录下的资源路径。 + Return a resource path under the `birdid/` directory. + + Windows Lite 构建需要从安装目录 `_internal` 读取资源,其余冻结环境仍跟随 + PyInstaller bundle 目录;源码环境则回退到仓库内的 `birdid/` 目录。 + Windows Lite builds read from the install-scoped `_internal` tree, other + frozen builds still follow the PyInstaller bundle, and source runs fall back + to the repository `birdid/` directory. + """ + if getattr(sys, "frozen", False) and sys.platform == "win32": + return str( + get_install_scoped_resource_path(os.path.join("birdid", relative_path)) + ) + if getattr(sys, "frozen", False): + meipass = get_runtime_meipass() + if meipass is not None: + return os.path.join(meipass, "birdid", relative_path) return os.path.join(BIRDID_DIR, relative_path) def get_project_path(relative_path: str) -> str: - """获取项目根目录下的资源路径""" - if getattr(sys, 'frozen', False): - return os.path.join(sys._MEIPASS, relative_path) + """ + 返回项目级资源路径。 + Return a project-level resource path. + + 这里统一兼容 Windows Lite 安装目录、普通 PyInstaller bundle 与源码目录, + 避免各调用方再自行拼接 `_MEIPASS` 路径。 + This helper centralizes path selection for Windows Lite installs, regular + PyInstaller bundles, and source checkouts so callers do not rebuild + `_MEIPASS`-based paths themselves. + """ + if getattr(sys, "frozen", False) and sys.platform == "win32": + packaged_relative_path = None + if relative_path.startswith("models/"): + packaged_relative_path = get_packaged_model_relative_path(relative_path) + return str( + get_install_scoped_resource_path( + relative_path, packaged_relative_path=packaged_relative_path + ) + ) + if getattr(sys, "frozen", False): + meipass = get_runtime_meipass() + if meipass is not None: + return os.path.join(meipass, relative_path) return os.path.join(PROJECT_ROOT, relative_path) def get_user_data_dir() -> str: - """获取用户数据目录""" - if sys.platform == 'darwin': - user_data_dir = os.path.expanduser('~/Documents/SuperPicky_Data') - elif sys.platform == 'win32': - user_data_dir = os.path.join(os.path.expanduser('~'), 'Documents', 'SuperPicky_Data') - else: - user_data_dir = os.path.join(os.path.expanduser('~'), 'Documents', 'SuperPicky_Data') + user_data_dir = str(get_app_config_dir()) os.makedirs(user_data_dir, exist_ok=True) return user_data_dir -# ==================== 模型路径 ==================== -# 鸟类识别专用模型和数据(在 birdid/ 目录下) -# OSEA ResNet34 模型(替代旧 birdid2024) -MODEL_PATH = get_project_path('models/model20240824.pth') -# 旧模型路径(保留作为回退) -MODEL_PATH_LEGACY = get_birdid_path('models/birdid2024.pt') -MODEL_PATH_ENC = get_birdid_path('models/birdid2024.pt.enc') -# OSEA 模型类别数 +MODEL_PATH = get_project_path("models/model20240824.pth") +MODEL_PATH_LEGACY = get_birdid_path("models/birdid2024.pt") +MODEL_PATH_ENC = get_birdid_path("models/birdid2024.pt.enc") OSEA_NUM_CLASSES = 11000 -DATABASE_PATH = get_birdid_path('data/bird_reference.sqlite') - -# YOLO 模型(共用项目根目录的模型) -YOLO_MODEL_PATH = get_project_path('models/yolo11l-seg.pt') +DATABASE_PATH = get_birdid_path("data/bird_reference.sqlite") +YOLO_MODEL_PATH = get_project_path("models/yolo11l-seg.pt") -# ==================== 全局变量(懒加载)==================== -# 已迁移至 config.get_lazy_registry() 统一管理 - -# ==================== 模型加密解密 ==================== def decrypt_model(encrypted_path: str, password: str) -> bytes: - """解密模型文件""" from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC - with open(encrypted_path, 'rb') as f: + with open(encrypted_path, "rb") as f: encrypted_data = f.read() salt = encrypted_data[:16] @@ -133,15 +165,11 @@ def decrypt_model(encrypted_path: str, password: str) -> bytes: length=32, salt=salt, iterations=100000, - backend=default_backend() + backend=default_backend(), ) key = kdf.derive(password.encode()) - cipher = Cipher( - algorithms.AES(key), - modes.CBC(iv), - backend=default_backend() - ) + cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend()) decryptor = cipher.decryptor() plaintext_padded = decryptor.update(ciphertext) + decryptor.finalize() @@ -150,15 +178,11 @@ def decrypt_model(encrypted_path: str, password: str) -> bytes: def _load_torchscript_from_bytes(model_data: bytes): - """Load TorchScript from bytes to avoid Windows non-ASCII temp path issues.""" buffer = io.BytesIO(model_data) - return torch.jit.load(buffer, map_location='cpu') + return torch.jit.load(buffer, map_location="cpu") -# ==================== 懒加载函数 ==================== - def get_classifier(): - """懒加载分类模型(OSEA ResNet34)""" registry = get_lazy_registry() def _factory(): @@ -166,11 +190,10 @@ def _factory(): if os.path.exists(MODEL_PATH): model = models.resnet34(num_classes=OSEA_NUM_CLASSES) - state_dict = torch.load(MODEL_PATH, map_location='cpu', weights_only=True) + state_dict = torch.load(MODEL_PATH, map_location="cpu", weights_only=True) model.load_state_dict(state_dict) - model = model.to(CLASSIFIER_DEVICE) + model = model.to(device=CLASSIFIER_DEVICE) model.eval() - print(f"[BirdID] OSEA ResNet34 model loaded, device: {CLASSIFIER_DEVICE}") return model SECRET_PASSWORD = "SuperBirdID_2024_AI_Model_Encryption_Key_v1" @@ -179,40 +202,38 @@ def _factory(): model = _load_torchscript_from_bytes(model_data) elif os.path.exists(MODEL_PATH_LEGACY): try: - model = torch.jit.load(MODEL_PATH_LEGACY, map_location='cpu') + model = torch.jit.load(MODEL_PATH_LEGACY, map_location="cpu") except RuntimeError as e: - if 'open file failed' not in str(e) or 'fopen' not in str(e): + if "open file failed" not in str(e) or "fopen" not in str(e): raise - with open(MODEL_PATH_LEGACY, 'rb') as f: + with open(MODEL_PATH_LEGACY, "rb") as f: model_data = f.read() model = _load_torchscript_from_bytes(model_data) else: raise RuntimeError(f"未找到分类模型: {MODEL_PATH} 或 {MODEL_PATH_LEGACY}") - model = model.to(CLASSIFIER_DEVICE) - model.eval() # noqa: model.eval() is a PyTorch API call, not Python eval() - print(_t("logs.birdid_fallback_model")) + model = model.to(device=CLASSIFIER_DEVICE) + model.eval() return model return registry.get_or_create("birdid.classifier", _factory) def get_bird_model(): - """获取识鸟模型(get_classifier 的别名,用于模型预加载)""" return get_classifier() def get_database_manager(): - """懒加载数据库管理器""" registry = get_lazy_registry() def _factory(): try: from birdid.bird_database_manager import BirdDatabaseManager + if os.path.exists(DATABASE_PATH): return BirdDatabaseManager(DATABASE_PATH) except Exception as e: - print(_t("logs.db_load_failed", e=e)) + pass return False result = registry.get_or_create("birdid.database_manager", _factory) @@ -220,40 +241,38 @@ def _factory(): def get_yolo_detector(): - """懒加载YOLO检测器""" if not YOLO_AVAILABLE: return None registry = get_lazy_registry() return registry.get_or_create( "birdid.yolo_detector", - lambda: YOLOBirdDetector(YOLO_MODEL_PATH) if os.path.exists(YOLO_MODEL_PATH) else None, + lambda: ( + YOLOBirdDetector(YOLO_MODEL_PATH) + if os.path.exists(YOLO_MODEL_PATH) + else None + ), ) def get_species_filter(): - """懒加载 AvonetFilter(单例模式)""" registry = get_lazy_registry() def _factory(): try: from birdid.avonet_filter import AvonetFilter + filt = AvonetFilter() if filt.is_available(): - print(_t("logs.avonet_loaded")) return filt except Exception as e: - print(_t("logs.avonet_init_failed", e=e)) + pass return None return registry.get_or_create("birdid.avonet_filter", _factory) -# ==================== YOLO 鸟类检测器 ==================== - class YOLOBirdDetector: - """YOLO 鸟类检测器""" - - def __init__(self, model_path: str = None): + def __init__(self, model_path: Optional[str] = None): if not YOLO_AVAILABLE: self.model = None return @@ -261,10 +280,14 @@ def __init__(self, model_path: str = None): if model_path is None: model_path = YOLO_MODEL_PATH + model_path = os.path.abspath(model_path) + if not os.path.exists(model_path): + self.model = None + return + try: self.model = YOLO(model_path) except Exception as e: - print(_t("logs.yolo_load_failed", e=e)) self.model = None def detect_and_crop_bird( @@ -272,26 +295,8 @@ def detect_and_crop_bird( image_input, confidence_threshold: float = 0.25, padding_ratio: float = 0.15, - fill_color: Tuple[int, int, int] = (0, 0, 0) + fill_color: Tuple[int, int, int] = (0, 0, 0), ) -> Tuple[Optional[Image.Image], str]: - """ - 检测并裁剪鸟类区域(智能正方形裁剪 + Letterboxing) - - 处理流程: - 1. YOLO 检测获取 bounding box - 2. 智能正方形扩展: max_side * (1 + padding_ratio) - 3. 边界限制: 裁剪区域不超出图片范围 - 4. Letterboxing: 如果裁剪后非正方形,用 fill_color 填充成正方形 - - Args: - image_input: 文件路径或 PIL Image - confidence_threshold: 置信度阈值 - padding_ratio: padding 比例(基于 bbox 最大边长),默认 0.15 (15%) - fill_color: Letterboxing 填充颜色,默认黑色 (0, 0, 0) - - Returns: - (裁剪后的正方形图像, 检测信息) 或 (None, 错误信息) - """ if self.model is None: return None, "YOLO模型未可用" @@ -315,29 +320,27 @@ def detect_and_crop_bird( confidence = box.conf[0].cpu().numpy() class_id = int(box.cls[0].cpu().numpy()) - # COCO 数据集中鸟类的 class_id 是 14 if class_id == 14: - detections.append({ - 'bbox': [int(x1), int(y1), int(x2), int(y2)], - 'confidence': float(confidence) - }) + detections.append( + { + "bbox": [int(x1), int(y1), int(x2), int(y2)], + "confidence": float(confidence), + } + ) if not detections: - return None, _t("logs.no_bird_detected") + return None, "未检测到鸟类" - best = max(detections, key=lambda x: x['confidence']) + best = max(detections, key=lambda x: x["confidence"]) img_width, img_height = image.size - # Phase 1: 获取 bbox - x1, y1, x2, y2 = best['bbox'] + x1, y1, x2, y2 = best["bbox"] bbox_width = x2 - x1 bbox_height = y2 - y1 - # Phase 2: 智能正方形扩展 (基于最大边长 + padding_ratio) max_side = max(bbox_width, bbox_height) target_side = int(max_side * (1 + padding_ratio)) - # 以 bbox 中心为基准扩展 cx = (x1 + x2) // 2 cy = (y1 + y2) // 2 half = target_side // 2 @@ -347,7 +350,6 @@ def detect_and_crop_bird( sq_x2 = cx + half sq_y2 = cy + half - # Phase 3: 边界限制 crop_x1 = max(0, sq_x1) crop_y1 = max(0, sq_y1) crop_x2 = min(img_width, sq_x2) @@ -356,10 +358,9 @@ def detect_and_crop_bird( cropped = image.crop((crop_x1, crop_y1, crop_x2, crop_y2)) crop_w, crop_h = cropped.size - # Phase 4: Letterboxing (如果裁剪后非正方形) if crop_w != crop_h: sq_size = max(crop_w, crop_h) - square = Image.new('RGB', (sq_size, sq_size), fill_color) + square = Image.new("RGB", (sq_size, sq_size), fill_color) paste_x = (sq_size - crop_w) // 2 paste_y = (sq_size - crop_h) // 2 square.paste(cropped, (paste_x, paste_y)) @@ -373,68 +374,79 @@ def detect_and_crop_bird( return None, f"检测失败: {e}" -# ==================== 图像加载 ==================== - def load_image(image_path: str) -> Image.Image: - """ - 加载图像,支持标准格式和 RAW 格式 - 对 RAW 文件优先提取内嵌 JPEG 预览图(更适合 YOLO 检测) - """ if not os.path.exists(image_path): raise FileNotFoundError(f"文件不存在: {image_path}") ext = os.path.splitext(image_path)[1].lower() raw_extensions = [ - '.cr2', '.cr3', '.nef', '.nrw', '.arw', '.srf', '.dng', - '.raf', '.orf', '.rw2', '.pef', '.srw', '.raw', '.rwl', - '.3fr', '.fff', '.erf', '.mef', '.mos', '.mrw', '.x3f', - '.hif', '.heif', '.heic', # Sony HIF / HEIF + ".cr2", + ".cr3", + ".nef", + ".nrw", + ".arw", + ".srf", + ".dng", + ".raf", + ".orf", + ".rw2", + ".pef", + ".srw", + ".raw", + ".rwl", + ".3fr", + ".fff", + ".erf", + ".mef", + ".mos", + ".mrw", + ".x3f", + ".hif", + ".heif", + ".heic", ] - # HEIF 格式(rawpy 不支持):直接补 pillow-heif 路径 - heif_extensions = {'.hif', '.heif', '.heic'} + heif_extensions = {".hif", ".heif", ".heic"} if ext in raw_extensions: if ext in heif_extensions: return _load_heif(image_path) if RAW_SUPPORT: + thumb_format_enum = getattr(rawpy, "ThumbFormat", None) + jpeg_thumb_format = getattr(thumb_format_enum, "JPEG", None) + bitmap_thumb_format = getattr(thumb_format_enum, "BITMAP", None) + rawpy_internal = getattr(rawpy, "_rawpy", None) + unsupported_error = getattr( + rawpy_internal, "LibRawFileUnsupportedError", None + ) try: with rawpy.imread(image_path) as raw: - # 优先尝试提取内嵌的 JPEG 预览图 try: thumb = raw.extract_thumb() - if thumb.format == rawpy.ThumbFormat.JPEG: - # 直接使用内嵌的 JPEG + if thumb.format == jpeg_thumb_format: from io import BytesIO + img = Image.open(BytesIO(thumb.data)).convert("RGB") - print(_t("logs.raw_embedded_jpeg", w=img.size[0], h=img.size[1])) return img - elif thumb.format == rawpy.ThumbFormat.BITMAP: - # 位图格式 + elif thumb.format == bitmap_thumb_format: img = Image.fromarray(thumb.data).convert("RGB") - print(_t("logs.raw_embedded_bitmap", w=img.size[0], h=img.size[1])) return img except Exception as e: - print(_t("logs.raw_preview_failed", e=e)) - - # 如果无法提取预览,使用半尺寸后处理 + pass + rgb = raw.postprocess( use_camera_wb=True, output_bps=8, no_auto_bright=False, auto_bright_thr=0.01, - half_size=True # 使用半尺寸,加快处理 + half_size=True, ) img = Image.fromarray(rgb) - print(_t("logs.raw_half_size", w=img.size[0], h=img.size[1])) return img - except rawpy._rawpy.LibRawFileUnsupportedError: - # LibRaw 不支持的格式(如 Sony A7M5 NeXt/Compressed RAW 2) - # 回退:使用 exiftool -b -JpgFromRaw 提取相机内嵌 JPEG - print(f"[RAW] rawpy 不支持此 RAW 格式,尝试 ExifTool JpgFromRaw 回退...") - return _load_raw_via_exiftool(image_path) except Exception as e: + if unsupported_error is not None and isinstance(e, unsupported_error): + return _load_raw_via_exiftool(image_path) raise Exception(f"RAW处理失败: {e}") else: raise ImportError("需要安装 rawpy 来处理 RAW 格式") @@ -442,19 +454,19 @@ def load_image(image_path: str) -> Image.Image: return Image.open(image_path).convert("RGB") -def _load_raw_via_exiftool(image_path: str) -> "Image.Image": +def _load_raw_via_exiftool(image_path: str) -> Image.Image: """ - 使用 ExifTool 从 RAW 文件提取内嵌 JPEG。 - 用于 LibRaw 不支持的格式(如 Sony A7M5 NeXt/Compressed RAW 2)。 - 按优先级依次尝试:JpgFromRaw → PreviewImage → ThumbnailImage + 使用 ExifTool 从 RAW 文件提取可解码预览图。 + Extract a decodable preview image from a RAW file via ExifTool. """ import subprocess from io import BytesIO - # 查找 exiftool(优先使用打包内的版本) possible_paths = [] if getattr(sys, "frozen", False): - possible_paths.append(os.path.join(sys._MEIPASS, "exiftools_mac", "exiftool")) + meipass = get_runtime_meipass() + if meipass is not None: + possible_paths.append(os.path.join(meipass, "exiftools_mac", "exiftool")) possible_paths += [ os.path.join(PROJECT_ROOT, "exiftools_mac", "exiftool"), "/opt/homebrew/bin/exiftool", @@ -463,165 +475,157 @@ def _load_raw_via_exiftool(image_path: str) -> "Image.Image": ] exiftool = next((p for p in possible_paths if os.path.isfile(p)), "exiftool") - # 依次尝试各种嵌入图像标签 for tag in ["-JpgFromRaw", "-PreviewImage", "-ThumbnailImage"]: try: result = subprocess.run( - [exiftool, "-b", tag, image_path], - capture_output=True, timeout=15 + [exiftool, "-b", tag, image_path], capture_output=True, timeout=15 ) if result.returncode == 0 and result.stdout and len(result.stdout) > 1000: img = Image.open(BytesIO(result.stdout)).convert("RGB") - print(f"[RAW] ExifTool {tag} 提取成功: {img.size[0]}x{img.size[1]}") return img except Exception as e: - print(f"[RAW] ExifTool {tag} 失败: {e}") continue raise Exception( f"\u6682\u4e0d\u652f\u6301\u6b64 RAW \u683c\u5f0f\uff08{os.path.basename(image_path)}\uff09\u3002" "Sony A7M5 \u7b49\u76f8\u673a\u7684 NeXt/Compressed RAW 2 \u683c\u5f0f\u76ee\u524d\u7b2c\u4e09\u65b9\u5e93\u5c1a\u672a\u5b8c\u6574\u652f\u6301\uff0c" - "\u5c06\u5728\u540e\u7eed\u7248\u672c\u4e2d\u4fee\u590d\u3002\u5efa\u8bae\u4e34\u65f6\u4f7f\u7528\u65e0\u538b\u7f29 RAW \u6216 JPEG \u683c\u5f0f\u62cd\u6444\u3002" + "\u5c06\u0627\u5728\u540e\u7eed\u7248\u672c\u4e2d\u4fee\u590d\u3002\u5efa\u8bae\u4e34\u65f6\u4f7f\u7528\u65e0\u538b\u7f29 RAW \u6216 JPEG \u683c\u5f0f\u62cd\u6444\u3002" ) -def _load_heif(image_path: str) -> "Image.Image": - """ - \u4f7f\u7528 pillow-heif \u89e3\u7801 HEIF/HIF \u6587\u4ef6\uff08Sony HIF \u3001\u82f9\u679c HEIC \u7b49\uff09\u4e3a PIL Image\u3002 - """ +def _load_heif(image_path: str) -> Image.Image: try: import pillow_heif + heif_file = pillow_heif.read_heif(image_path) + if heif_file.data is None: + raise ValueError("HEIF 解码结果缺少像素数据") img = Image.frombytes( heif_file.mode, heif_file.size, heif_file.data, "raw", ).convert("RGB") - print(f"[HEIF] pillow-heif \u89e3\u7801\u6210\u529f: {img.size[0]}x{img.size[1]}") return img except ImportError: raise Exception( - "\u8bf7\u5b89\u88c5 pillow-heif \u6765\u652f\u6301 HIF/HEIC \u683c\u5f0f\uff1a pip install pillow-heif" + "请安装 pillow-heif 来支持 HIF/HEIC 格式: pip install pillow-heif" ) except Exception as e: - raise Exception(f"HEIF \u89e3\u7801\u5931\u8d25 ({os.path.basename(image_path)}): {e}") + raise Exception(f"HEIF 解码失败 ({os.path.basename(image_path)}): {e}") -# ==================== GPS 提取 ==================== -def extract_gps_from_exif(image_path: str) -> Tuple[Optional[float], Optional[float], str]: - """ - 从图像 EXIF 提取 GPS 坐标 - 支持 RAW 文件(使用 exiftool) - - Returns: - (纬度, 经度, 信息) 或 (None, None, 错误信息) - """ +def extract_gps_from_exif( + image_path: str, +) -> Tuple[Optional[float], Optional[float], str]: import subprocess import json as json_module - - # 首先尝试使用 exiftool(支持 RAW 格式) + try: - # 查找 exiftool exiftool_paths = [ - '/usr/local/bin/exiftool', - '/opt/homebrew/bin/exiftool', - 'exiftool', # 在 PATH 中查找 + "/usr/local/bin/exiftool", + "/opt/homebrew/bin/exiftool", + "exiftool", ] - + exiftool_path = None for path in exiftool_paths: try: - result = subprocess.run([path, '-ver'], capture_output=True, text=False, timeout=5) + result = subprocess.run( + [path, "-ver"], capture_output=True, text=False, timeout=5 + ) if result.returncode == 0: - # 解码输出 stdout_bytes = result.stdout - # 尝试多种编码解码 decoded_output = None - for encoding in ['utf-8', 'gbk', 'gb2312', 'latin-1']: + for encoding in ["utf-8", "gbk", "gb2312", "latin-1"]: try: decoded_output = stdout_bytes.decode(encoding) break except UnicodeDecodeError: continue - + if decoded_output is None: - # 如果所有编码都失败,使用 latin-1 作为最后手段(不会失败) - decoded_output = stdout_bytes.decode('latin-1') - - # 检查是否成功获取版本 + decoded_output = stdout_bytes.decode("latin-1") + if decoded_output.strip(): exiftool_path = path break except: continue - + if exiftool_path: - # 使用 exiftool 提取 GPS 信息 result = subprocess.run( - [exiftool_path, '-j', '-GPSLatitude', '-GPSLongitude', '-GPSLatitudeRef', '-GPSLongitudeRef', image_path], + [ + exiftool_path, + "-j", + "-GPSLatitude", + "-GPSLongitude", + "-GPSLatitudeRef", + "-GPSLongitudeRef", + image_path, + ], capture_output=True, - text=False, # 使用 bytes 模式,避免自动解码 - timeout=10 + text=False, + timeout=10, ) - + if result.returncode == 0 and result.stdout: stdout_bytes = result.stdout - # 尝试多种编码解码 decoded_output = None - for encoding in ['utf-8', 'gbk', 'gb2312', 'latin-1']: + for encoding in ["utf-8", "gbk", "gb2312", "latin-1"]: try: decoded_output = stdout_bytes.decode(encoding) break except UnicodeDecodeError: continue - + if decoded_output is None: - # 如果所有编码都失败,使用 latin-1 作为最后手段(不会失败) - decoded_output = stdout_bytes.decode('latin-1') - + decoded_output = stdout_bytes.decode("latin-1") + data = json_module.loads(decoded_output) if data and len(data) > 0: gps_data = data[0] - - lat_str = gps_data.get('GPSLatitude', '') - lon_str = gps_data.get('GPSLongitude', '') - lat_ref = gps_data.get('GPSLatitudeRef', 'N') - lon_ref = gps_data.get('GPSLongitudeRef', 'E') - + + lat_str = gps_data.get("GPSLatitude", "") + lon_str = gps_data.get("GPSLongitude", "") + lat_ref = gps_data.get("GPSLatitudeRef", "N") + lon_ref = gps_data.get("GPSLongitudeRef", "E") + if lat_str and lon_str: - # 解析度分秒格式,如 "27 deg 25' 0.53\" S" + def parse_dms(dms_str): import re - match = re.search(r'(\d+)\s*deg\s*(\d+)\'\s*([\d.]+)"?', str(dms_str)) + + match = re.search( + r'(\d+)\s*deg\s*(\d+)\'\s*([\d.]+)"?', str(dms_str) + ) if match: - d, m, s = float(match.group(1)), float(match.group(2)), float(match.group(3)) - return d + m/60 + s/3600 - # 尝试直接作为数字解析 + d, m, s = ( + float(match.group(1)), + float(match.group(2)), + float(match.group(3)), + ) + return d + m / 60 + s / 3600 try: return float(dms_str) except: return None - + lat = parse_dms(lat_str) lon = parse_dms(lon_str) - + if lat is not None and lon is not None: - # 处理南纬 (S 或 South) - if lat_ref and lat_ref.upper().startswith('S'): + if lat_ref and lat_ref.upper().startswith("S"): lat = -lat - # 处理西经 (W 或 West) - if lon_ref and lon_ref.upper().startswith('W'): + if lon_ref and lon_ref.upper().startswith("W"): lon = -lon - print(_t("logs.gps_extracted", lat=f"{lat:.6f}", lon=f"{lon:.6f}")) return lat, lon, f"GPS: {lat:.6f}, {lon:.6f}" except Exception as e: - print(_t("logs.gps_failed", e=e)) - - # 回退到 PIL(仅支持 JPEG 等常规格式) + pass + try: image = Image.open(image_path) - exif_data = image._getexif() + exif_data = image.getexif() if not exif_data: return None, None, "无EXIF数据" @@ -641,18 +645,22 @@ def parse_dms(dms_str): def convert_to_degrees(coord, ref): d, m, s = coord decimal = d + (m / 60.0) + (s / 3600.0) - if ref in ['S', 'W']: + if ref in ["S", "W"]: decimal = -decimal return decimal lat = None lon = None - if 'GPSLatitude' in gps_info and 'GPSLatitudeRef' in gps_info: - lat = convert_to_degrees(gps_info['GPSLatitude'], gps_info['GPSLatitudeRef']) + if "GPSLatitude" in gps_info and "GPSLatitudeRef" in gps_info: + lat = convert_to_degrees( + gps_info["GPSLatitude"], gps_info["GPSLatitudeRef"] + ) - if 'GPSLongitude' in gps_info and 'GPSLongitudeRef' in gps_info: - lon = convert_to_degrees(gps_info['GPSLongitude'], gps_info['GPSLongitudeRef']) + if "GPSLongitude" in gps_info and "GPSLongitudeRef" in gps_info: + lon = convert_to_degrees( + gps_info["GPSLongitude"], gps_info["GPSLongitudeRef"] + ) if lat is not None and lon is not None: return lat, lon, f"GPS: {lat:.6f}, {lon:.6f}" @@ -663,24 +671,20 @@ def convert_to_degrees(coord, ref): return None, None, f"GPS解析失败: {e}" -# ==================== 图像预处理 ==================== - def smart_resize(image: Image.Image, target_size: int = 224) -> Image.Image: - """智能图像尺寸调整""" width, height = image.size max_dim = max(width, height) if max_dim < 1000: - return image.resize((target_size, target_size), Image.LANCZOS) + return image.resize((target_size, target_size), RESAMPLING_LANCZOS) - resized = image.resize((256, 256), Image.LANCZOS) + resized = image.resize((256, 256), RESAMPLING_LANCZOS) left = (256 - target_size) // 2 top = (256 - target_size) // 2 return resized.crop((left, top, left + target_size, top + target_size)) def apply_enhancement(image: Image.Image, method: str = "unsharp_mask") -> Image.Image: - """应用图像增强""" if method == "unsharp_mask": return image.filter(ImageFilter.UnsharpMask()) elif method == "edge_enhance_more": @@ -694,72 +698,51 @@ def apply_enhancement(image: Image.Image, method: str = "unsharp_mask") -> Image return image -# ==================== OSEA 预处理 ==================== - -# CenterCrop 预处理: Resize(256) + CenterCrop(224) + ImageNet Normalize -# 用于原始大图(未经 YOLO 裁剪) -OSEA_TRANSFORM = transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), -]) - -# 直接缩放预处理: Resize(224, 224) with Lanczos + ImageNet Normalize -# 用于 YOLO 裁剪后的正方形图片(已经过 Letterboxing 处理) -# 使用 Lanczos 插值保证高质量缩放 -OSEA_TRANSFORM_DIRECT = transforms.Compose([ - transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.LANCZOS), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), -]) - +OSEA_TRANSFORM = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] +) + +OSEA_TRANSFORM_DIRECT = transforms.Compose( + [ + transforms.Resize( + (224, 224), interpolation=transforms.InterpolationMode.LANCZOS + ), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] +) -# ==================== 核心识别函数 ==================== def predict_bird( image: Image.Image, top_k: int = 5, species_class_ids: Optional[Set[int]] = None, is_yolo_cropped: bool = False, - name_format: str = None + name_format: Optional[str] = None, ) -> List[Dict]: - """ - 识别鸟类(OSEA ResNet34) - - Args: - image: PIL Image 对象 - top_k: 返回前 K 个结果 - species_class_ids: 区域物种的 class_id 集合(用于过滤) - is_yolo_cropped: 图片是否经过 YOLO 裁剪(用于选择预处理方式) - - Returns: - 识别结果列表 [{cn_name, en_name, confidence, ebird_code, ...}, ...] - """ model = get_classifier() db_manager = get_database_manager() - # 根据是否经过 YOLO 裁剪选择预处理方式 - # - YOLO 裁剪后: 直接 Resize(224,224),避免 CenterCrop 丢失特征 - # - 原始大图: Resize(256) + CenterCrop(224),鸟在中心时效果更好 - if image.mode != 'RGB': - image = image.convert('RGB') + if image.mode != "RGB": + image = image.convert("RGB") transform = OSEA_TRANSFORM_DIRECT if is_yolo_cropped else OSEA_TRANSFORM - input_tensor = transform(image).unsqueeze(0).to(CLASSIFIER_DEVICE) + transformed_tensor = cast(torch.Tensor, transform(image)) + input_tensor = transformed_tensor.unsqueeze(0).to(CLASSIFIER_DEVICE) - # 推理 with torch.no_grad(): output = model(input_tensor)[0] - # 截取有效类别数(模型输出可能多于实际物种数) num_classes = min(10964, output.shape[0]) output = output[:num_classes] - # Softmax(温度=0.9 更平滑:降低过高置信度,避免 99%+ 输出) TEMPERATURE = 0.9 best_probs = torch.nn.functional.softmax(output / TEMPERATURE, dim=0) - # 获取 top-k 结果 k = min(100 if species_class_ids else top_k, len(best_probs)) top_probs, top_indices = torch.topk(best_probs, k) @@ -767,7 +750,6 @@ def predict_bird( for i in range(len(top_indices)): class_id = top_indices[i].item() confidence = top_probs[i].item() * 100 - # 置信度阈值:使用区域过滤时降低阈值以保留更多候选 min_confidence = 0.3 if species_class_ids else 1.0 if confidence < min_confidence: continue @@ -778,54 +760,57 @@ def predict_bird( ebird_code = None description = None - # 优先从数据库获取 if db_manager: info = db_manager.get_bird_by_class_id(class_id) if info: - cn_name = info.get('chinese_simplified') - en_name = info.get('english_name') - scientific_name = info.get('scientific_name') - ebird_code = info.get('ebird_code') - description = info.get('short_description_zh') + cn_name = info.get("chinese_simplified") + en_name = info.get("english_name") + scientific_name = info.get("scientific_name") + ebird_code = info.get("ebird_code") + description = info.get("short_description_zh") if not cn_name: cn_name = f"Unknown (ID: {class_id})" en_name = f"Unknown (ID: {class_id})" - # AviList name format override if name_format and name_format != "default" and db_manager: avilist_info = db_manager.get_avilist_names_by_class_id(class_id) - if avilist_info and avilist_info.get('match_type') != 'no_match': + if avilist_info and avilist_info.get("match_type") != "no_match": if name_format == "scientific": - en_name = avilist_info.get('scientific_name_avilist') or scientific_name or en_name + en_name = ( + avilist_info.get("scientific_name_avilist") + or scientific_name + or en_name + ) else: - # Map format to column: avilist/clements/birdlife col = f"en_name_{name_format}" alt_name = avilist_info.get(col) - # Fallback chain: selected -> avilist -> keep default if alt_name: en_name = alt_name - elif name_format != "avilist" and avilist_info.get('en_name_avilist'): - en_name = avilist_info['en_name_avilist'] + elif name_format != "avilist" and avilist_info.get( + "en_name_avilist" + ): + en_name = avilist_info["en_name_avilist"] - # Avonet 地理过滤 region_match = False if species_class_ids: if class_id in species_class_ids: region_match = True else: - continue # 不在区域物种列表中,跳过 - - results.append({ - 'class_id': class_id, - 'cn_name': cn_name, - 'en_name': en_name, - 'scientific_name': scientific_name, - 'confidence': confidence, - 'ebird_code': ebird_code, - 'region_match': region_match, - 'description': description or '' - }) + continue + + results.append( + { + "class_id": class_id, + "cn_name": cn_name, + "en_name": en_name, + "scientific_name": scientific_name, + "confidence": confidence, + "ebird_code": ebird_code, + "region_match": region_match, + "description": description or "", + } + ) if len(results) >= top_k: break @@ -838,205 +823,157 @@ def identify_bird( use_yolo: bool = True, use_gps: bool = True, use_ebird: bool = True, - country_code: str = None, - region_code: str = None, + country_code: Optional[str] = None, + region_code: Optional[str] = None, top_k: int = 5, - name_format: str = None, - preloaded_crop=None, # PIL Image,主流水线已裁剪好时传入,跳过重复 YOLO + name_format: Optional[str] = None, + preloaded_crop: Optional[Image.Image] = None, ) -> Dict: - """ - 端到端鸟类识别 - - Args: - image_path: 图像路径(仍用于 GPS 提取) - use_yolo: 是否使用 YOLO 裁剪(preloaded_crop 存在时忽略) - use_gps: 是否使用 GPS 自动检测区域 - use_ebird: 是否启用 eBird 区域过滤 - country_code: 手动指定国家代码(如 "AU") - region_code: 手动指定区域代码(如 "AU-SA") - top_k: 返回前 K 个结果 - preloaded_crop: 预裁剪的鸟类区域 PIL Image,由调用方传入时跳过 YOLO - - Returns: - 识别结果字典 - """ result = { - 'success': False, - 'image_path': image_path, - 'results': [], - 'yolo_info': None, - 'gps_info': None, - 'ebird_info': None, - 'error': None + "success": False, + "image_path": image_path, + "results": [], + "yolo_info": None, + "gps_info": None, + "ebird_info": None, + "error": None, } try: - # 若调用方已提供裁剪好的鸟类区域,直接使用,跳过图像加载和 YOLO + is_yolo_cropped = False if preloaded_crop is not None: image = preloaded_crop is_yolo_cropped = True - result['yolo_info'] = {'preloaded': True} + result["yolo_info"] = {"preloaded": True} else: - # 加载图像 image = load_image(image_path) - # YOLO 裁剪(preloaded_crop 存在时已跳过) - if preloaded_crop is None: - is_yolo_cropped = False - print(f"[YOLO] use_yolo={use_yolo}, YOLO_AVAILABLE={YOLO_AVAILABLE}") if preloaded_crop is None and use_yolo and YOLO_AVAILABLE: width, height = image.size - print(f"[YOLO] image size: {width}x{height}") if max(width, height) > 640: detector = get_yolo_detector() - print(f"[YOLO] detector={detector is not None}") if detector: cropped, info = detector.detect_and_crop_bird(image) - print(f"[YOLO] detect result: cropped={cropped is not None}, info={info}") if cropped: image = cropped - result['yolo_info'] = info - result['cropped_image'] = cropped # square-cropped PIL Image + result["yolo_info"] = info + result["cropped_image"] = cropped is_yolo_cropped = True - print(f"[YOLO] ✅ Bird region cropped") else: - print(f"[YOLO] ⚠️ No bird detected") - # strict mode: no bird found, short-circuit - result['success'] = True - result['results'] = [] - result['yolo_info'] = {'bird_count': 0} + result["success"] = True + result["results"] = [] + result["yolo_info"] = {"bird_count": 0} return result - else: - print(f"[YOLO] Image too small, skipping crop") - else: - print(f"[YOLO] YOLO not enabled or unavailable") - # Avonet 地理过滤 species_class_ids = None - lat = lon = None # GPS 坐标(供后续回退使用) - species_filter = None # 物种过滤器(供后续回退使用) + lat = lon = None + species_filter = None - if use_ebird: # 参数名保持兼容,实际使用 Avonet + if use_ebird: try: species_filter = get_species_filter() - if not species_filter: - print(_t("logs.avonet_unavailable")) - else: - # 优先使用 GPS 坐标 + if species_filter: if use_gps: lat, lon, gps_msg = extract_gps_from_exif(image_path) if lat and lon: - result['gps_info'] = { - 'latitude': lat, - 'longitude': lon, - 'info': gps_msg + result["gps_info"] = { + "latitude": lat, + "longitude": lon, + "info": gps_msg, } - species_class_ids = species_filter.get_species_by_gps(lat, lon) - if species_class_ids: - print(f"[Avonet] GPS ({lat:.2f}, {lon:.2f}): {len(species_class_ids)} species") + species_class_ids = species_filter.get_species_by_gps( + lat, lon + ) - # 回退到区域代码(优先 eBird 离线物种列表,其次 Avonet 边界) if species_class_ids is None and (region_code or country_code): effective_region = region_code or country_code - # 优先使用 eBird 离线物种 JSON(精确到州/省) try: - ebird_ids, actual_region = species_filter.get_species_by_region_ebird(effective_region) + ebird_ids, actual_region = ( + species_filter.get_species_by_region_ebird( + effective_region + ) + ) if ebird_ids: species_class_ids = ebird_ids - print(f"[eBird] Region {effective_region}: {len(species_class_ids)} species (offline JSON)") except Exception as _e: - print(f"[eBird] State filter failed: {_e}") - # 如果 eBird 数据不可用,回退到 Avonet 边界框查询 + pass if not species_class_ids: - species_class_ids = species_filter.get_species_by_region(effective_region) - if species_class_ids: - print(f"[Avonet] Region {effective_region}: {len(species_class_ids)} species (bounds)") + species_class_ids = species_filter.get_species_by_region( + effective_region + ) - # 记录过滤信息 if species_class_ids: - result['ebird_info'] = { # 保持键名兼容 - 'enabled': True, - 'species_count': len(species_class_ids), - 'data_source': 'avonet.db (offline)', - 'region_code': region_code or country_code if not result.get('gps_info') else None + result["ebird_info"] = { + "enabled": True, + "species_count": len(species_class_ids), + "data_source": "avonet.db (offline)", + "region_code": ( + region_code or country_code + if not result.get("gps_info") + else None + ), } except Exception as e: - print(f"[Avonet] Filter init failed: {e}") + pass - # 执行识别 results = predict_bird( image, top_k=top_k, species_class_ids=species_class_ids, is_yolo_cropped=is_yolo_cropped, - name_format=name_format + name_format=name_format, ) - # GPS 过滤无匹配时,先尝试 eBird 国家级回退,再全局 if not results and species_class_ids: - print(f"[Avonet] ⚠️ No match after GPS filter ({len(species_class_ids)} species), trying eBird country fallback") - - # 第一步:eBird 国家级回退 country_cls_ids = None country_cc = None if lat is not None and lon is not None and species_filter is not None: try: - country_cls_ids, country_cc = species_filter.get_species_by_country_ebird(lat, lon) + country_cls_ids, country_cc = ( + species_filter.get_species_by_country_ebird(lat, lon) + ) except Exception as _e: - print(f"[eBird] Country fallback failed: {_e}") + pass if country_cls_ids: - print(f"[eBird] Trying country fallback: {country_cc} ({len(country_cls_ids)} species)") results = predict_bird( image, top_k=top_k, species_class_ids=country_cls_ids, is_yolo_cropped=is_yolo_cropped, - name_format=name_format + name_format=name_format, ) if results: - if not result.get('ebird_info'): - result['ebird_info'] = {} - result['ebird_info']['country_fallback'] = True - result['ebird_info']['country_code'] = country_cc + if not result.get("ebird_info"): + result["ebird_info"] = {} + result["ebird_info"]["country_fallback"] = True + result["ebird_info"]["country_code"] = country_cc - # 第二步:仍无结果 → 全局模式 if not results: - print(f"[Avonet] ⚠️ Country fallback still no match, switching to global mode") results = predict_bird( image, top_k=top_k, species_class_ids=None, is_yolo_cropped=is_yolo_cropped, - name_format=name_format + name_format=name_format, ) - if results and result.get('ebird_info'): - result['ebird_info']['gps_fallback'] = True + if results and result.get("ebird_info"): + result["ebird_info"]["gps_fallback"] = True - result['success'] = True - result['results'] = results + result["success"] = True + result["results"] = results except Exception as e: - result['error'] = str(e) + result["error"] = str(e) return result -# ==================== 便捷函数 ==================== - def quick_identify(image_path: str, top_k: int = 3) -> List[Dict]: - """ - 快速识别(简化接口) - - Returns: - 识别结果列表 - """ result = identify_bird(image_path, top_k=top_k) - return result.get('results', []) - + return result.get("results", []) -# ==================== 测试 ==================== if __name__ == "__main__": print("BirdIdentifier 模块测试") diff --git a/birdid/osea_classifier.py b/birdid/osea_classifier.py index f5b6b5e..7171619 100644 --- a/birdid/osea_classifier.py +++ b/birdid/osea_classifier.py @@ -14,14 +14,17 @@ import os import sqlite3 -import sys from pathlib import Path from typing import Dict, List, Optional, Set import torch from PIL import Image from torchvision import models, transforms -from config import get_best_device +from config import ( + get_best_device, + get_install_scoped_resource_path, + get_packaged_model_relative_path, +) def _torch_load_compat(path: str, *, map_location: str, weights_only: bool): @@ -73,11 +76,14 @@ def _load_osea_checkpoint(model_path: str): return _torch_load_compat(model_path, map_location="cpu", weights_only=True) except Exception as e: if _should_retry_without_weights_only(e): - print("[OSEA] weights_only=True 加载失败,回退 weights_only=False(仅限可信模型)") - return _torch_load_compat(model_path, map_location="cpu", weights_only=False) + print( + "[OSEA] weights_only=True 加载失败,回退 weights_only=False(仅限可信模型)" + ) + return _torch_load_compat( + model_path, map_location="cpu", weights_only=False + ) raise -# ==================== 路径配置 ==================== def _get_birdid_dir() -> Path: """获取 birdid 模块目录""" @@ -90,38 +96,36 @@ def _get_project_root() -> Path: def _get_resource_path(relative_path: str) -> Path: - """获取资源路径 (支持 PyInstaller 打包)""" - if getattr(sys, 'frozen', False): - base = Path(sys._MEIPASS) - else: - base = _get_project_root() - return base / relative_path - - -# ==================== 设备配置 ==================== + """获取资源路径 (支持安装目录约束的打包场景)""" + packaged_relative_path = None + if relative_path.startswith("models/"): + packaged_relative_path = get_packaged_model_relative_path(relative_path) + return get_install_scoped_resource_path( + relative_path, packaged_relative_path=packaged_relative_path + ) DEVICE = get_best_device() -# ==================== 预处理 transforms ==================== - -CENTER_CROP_TRANSFORM = transforms.Compose([ - transforms.Resize(256), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), -]) +CENTER_CROP_TRANSFORM = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] +) -BASELINE_TRANSFORM = transforms.Compose([ - transforms.Resize((224, 224)), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), -]) +BASELINE_TRANSFORM = transforms.Compose( + [ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] +) -# ==================== OSEA 分类器 ==================== - class OSEAClassifier: """ OSEA ResNet34 鸟类分类器 @@ -152,7 +156,9 @@ def __init__( """ self.device = device or DEVICE self.use_center_crop = use_center_crop - self.transform = CENTER_CROP_TRANSFORM if use_center_crop else BASELINE_TRANSFORM + self.transform = ( + CENTER_CROP_TRANSFORM if use_center_crop else BASELINE_TRANSFORM + ) self.model_path = model_path or str(_get_resource_path(self.DEFAULT_MODEL_PATH)) self.model = self._load_model() @@ -199,13 +205,15 @@ def _load_bird_info(self) -> List[List[str]]: conn.close() num_classes = 10964 - bird_info: List[List[str]] = [['Unknown', 'Unknown', ''] for _ in range(num_classes)] + bird_info: List[List[str]] = [ + ["Unknown", "Unknown", ""] for _ in range(num_classes) + ] for class_id, cn_name, en_name, scientific_name in rows: if 0 <= class_id < num_classes: bird_info[class_id] = [ - cn_name or 'Unknown', - en_name or 'Unknown', - scientific_name or '', + cn_name or "Unknown", + en_name or "Unknown", + scientific_name or "", ] return bird_info @@ -228,8 +236,8 @@ def predict( Returns: 识别结果列表 [{cn_name, en_name, scientific_name, confidence, class_id}, ...] """ - if image.mode != 'RGB': - image = image.convert('RGB') + if image.mode != "RGB": + image = image.convert("RGB") input_tensor = self.transform(image).unsqueeze(0).to(self.device) @@ -237,7 +245,7 @@ def predict( with torch.no_grad(): output = self.model(input_tensor)[0] - output = output[:self.num_classes] + output = output[: self.num_classes] probs = torch.nn.functional.softmax(output / temperature, dim=0) k = min(100 if ebird_species_set else top_k, self.num_classes) @@ -257,16 +265,16 @@ def predict( en_name = info[1] scientific_name = info[2] if len(info) > 2 else None - ebird_match = False - - results.append({ - 'class_id': class_id, - 'cn_name': cn_name, - 'en_name': en_name, - 'scientific_name': scientific_name, - 'confidence': confidence, - 'ebird_match': ebird_match, - }) + results.append( + { + "class_id": class_id, + "cn_name": cn_name, + "en_name": en_name, + "scientific_name": scientific_name, + "confidence": confidence, + "ebird_match": False, + } + ) if len(results) >= top_k: break @@ -286,8 +294,8 @@ def predict_with_tta( TTA 策略: 原图 + 水平翻转取平均 推理时间翻倍,但可能提高准确率 """ - if image.mode != 'RGB': - image = image.convert('RGB') + if image.mode != "RGB": + image = image.convert("RGB") input1 = self.transform(image).unsqueeze(0).to(self.device) @@ -296,8 +304,8 @@ def predict_with_tta( self.model.eval() with torch.no_grad(): - output1 = self.model(input1)[0][:self.num_classes] - output2 = self.model(input2)[0][:self.num_classes] + output1 = self.model(input1)[0][: self.num_classes] + output2 = self.model(input2)[0][: self.num_classes] avg_output = (output1 + output2) / 2 probs = torch.nn.functional.softmax(avg_output / temperature, dim=0) @@ -319,14 +327,16 @@ def predict_with_tta( en_name = info[1] scientific_name = info[2] if len(info) > 2 else None - results.append({ - 'class_id': class_id, - 'cn_name': cn_name, - 'en_name': en_name, - 'scientific_name': scientific_name, - 'confidence': confidence, - 'ebird_match': False, - }) + results.append( + { + "class_id": class_id, + "cn_name": cn_name, + "en_name": en_name, + "scientific_name": scientific_name, + "confidence": confidence, + "ebird_match": False, + } + ) if len(results) >= top_k: break @@ -334,8 +344,6 @@ def predict_with_tta( return results -# ==================== 全局单例 ==================== - _osea_classifier: Optional[OSEAClassifier] = None @@ -347,8 +355,6 @@ def get_osea_classifier() -> OSEAClassifier: return _osea_classifier -# ==================== 便捷函数 ==================== - def osea_predict(image: Image.Image, top_k: int = 5) -> List[Dict]: """快速 OSEA 预测""" classifier = get_osea_classifier() @@ -358,12 +364,11 @@ def osea_predict(image: Image.Image, top_k: int = 5) -> List[Dict]: def osea_predict_file(image_path: str, top_k: int = 5) -> List[Dict]: """OSEA 预测 (从文件路径)""" from birdid.bird_identifier import load_image + image = load_image(image_path) return osea_predict(image, top_k=top_k) -# ==================== 测试 ==================== - if __name__ == "__main__": import argparse @@ -374,6 +379,7 @@ def osea_predict_file(image_path: str, top_k: int = 5) -> List[Dict]: args = parser.parse_args() from birdid.bird_identifier import load_image + image = load_image(args.image) classifier = OSEAClassifier() diff --git a/build_release.bat b/build_release.bat index e358fe9..1ac6092 100644 --- a/build_release.bat +++ b/build_release.bat @@ -1,404 +1,46 @@ @echo off -chcp 65001 >nul -setlocal EnableExtensions EnableDelayedExpansion +setlocal EnableExtensions -set "APP_NAME=SuperPicky" -set "SPEC_FILE=SuperPicky_win64.spec" -set "ROOT_DIR=%~dp0" -set "ROOT_DIR=%ROOT_DIR:~0,-1%" -cd /d "%ROOT_DIR%" +set "SCRIPT_DIR=%~dp0" +set "PYTHON_EXE=%SCRIPT_DIR%.venv\Scripts\python.exe" +if not exist "%PYTHON_EXE%" set "PYTHON_EXE=python" -set "VERSION_ARG=" -set "ZIP_COPY_DIR=" - -if "!OUT_DIST_DIR!"=="" set "OUT_DIST_DIR=dist" - -set "BUILD_ZIP=1" - -call :parse_args %* -if errorlevel 1 exit /b 1 -if defined SHOW_HELP goto :show_help - -goto :start - -:show_help -echo SuperPicky Windows build script -echo. -echo Usage: -echo %~nx0 [version] [zip_copy_dir] -echo. -echo version 版本号 ^(如 4.0.6^),用于 ZIP 文件名;缺省则从 ui/about_dialog.py 读取 -echo zip_copy_dir 目标目录 ^(如 E:\_SuperPickyVersions^);若指定则复制 SuperPicky 为 SuperPicky_版本号 并打 zip -echo. -exit /b 0 - -:parse_args -:parse_args_loop -if "%~1"=="" exit /b 0 - -if /i "%~1"=="--help" ( - set "SHOW_HELP=1" - exit /b 0 -) -if /i "%~1"=="-h" ( - set "SHOW_HELP=1" - exit /b 0 -) -if "%VERSION_ARG%"=="" ( - set "VERSION_ARG=%~1" -) else if "!ZIP_COPY_DIR!"=="" ( - set "ZIP_COPY_DIR=%~1" -) else ( - echo [WARNING] Ignored extra argument: %~1 -) -shift -goto :parse_args_loop - -:start -echo. -echo [========================================] -echo Step 0: Clean old build files -echo [========================================] - -rem Set Inno Setup directory -set "INNO_DIR=%ROOT_DIR%\inno" - -rem Clean old build directories -if exist "%ROOT_DIR%\build_dist" rd /s /q "%ROOT_DIR%\build_dist" >nul 2>&1 -if exist "%ROOT_DIR%\build_dist_cpu" rd /s /q "%ROOT_DIR%\build_dist_cpu" >nul 2>&1 -if exist "%ROOT_DIR%\build_dist_cuda" rd /s /q "%ROOT_DIR%\build_dist_cuda" >nul 2>&1 -if exist "%ROOT_DIR%\dist" rd /s /q "%ROOT_DIR%\dist" >nul 2>&1 -if exist "%ROOT_DIR%\dist_cpu" rd /s /q "%ROOT_DIR%\dist_cpu" >nul 2>&1 -if exist "%ROOT_DIR%\dist_cuda" rd /s /q "%ROOT_DIR%\dist_cuda" >nul 2>&1 -if exist "%ROOT_DIR%\output" rd /s /q "%ROOT_DIR%\output" >nul 2>&1 - -echo [SUCCESS] Cleaned old build files - -echo. -echo [========================================] -echo Step 1: Environment check -echo [========================================] - -if not exist "%SPEC_FILE%" ( - echo [ERROR] Missing spec file: %SPEC_FILE% - exit /b 1 -) - -echo [SUCCESS] Spec file found: %SPEC_FILE% - -if "!PYTHON_EXE!"=="" set "PYTHON_EXE=python" -rem Prefer current env Python (e.g. activated venv): use first "python" in PATH -if "!PYTHON_EXE!"=="python" ( - where python >nul 2>nul && for /f "tokens=*" %%i in ('python -c "import sys; print(sys.executable)" 2^>nul') do set "PYTHON_EXE=%%i" -) -if "!PYTHON_EXE!"=="" set "PYTHON_EXE=python" -call :check_python "!PYTHON_EXE!" "default" -if errorlevel 1 exit /b 1 - -echo. -echo [========================================] -echo Step 1: Resolve version -echo [========================================] - -set "VERSION=4.0.5_sp3" -if not "!VERSION_ARG!"=="" ( - set "VERSION=!VERSION_ARG!" - echo [SUCCESS] Use version from args: !VERSION! -) else ( - for /f "usebackq delims=" %%i in (`powershell -NoProfile -Command "$c=Get-Content -Path 'ui/about_dialog.py' -Raw -Encoding UTF8; if($c -match 'v([0-9A-Za-z._-]+)'){ $matches[1] }"`) do ( - set "VERSION=%%i" - ) - if "!VERSION!"=="" set "VERSION=0.0.0" - echo [SUCCESS] Detected version: v!VERSION! -) - -echo. -echo [========================================] -echo Step 1.5: Inject build metadata -echo [========================================] - -set "COMMIT_HASH=unknown" -rem 优先从 Python 代码读取 COMMIT_HASH(保证跨平台一致) -for /f "tokens=*" %%i in ('"%PYTHON_EXE%" -c "exec('try:\n from core.build_info_local import COMMIT_HASH\nexcept ImportError:\n from core.build_info import COMMIT_HASH\nprint(COMMIT_HASH or chr(0))')" 2^>nul') do set "COMMIT_HASH=%%i" -if "%COMMIT_HASH%"=="" for /f "tokens=*" %%i in ('git rev-parse --short HEAD 2^>nul') do set "COMMIT_HASH=%%i" -if "%COMMIT_HASH%"=="" set "COMMIT_HASH=unknown" -echo [INFO] Commit hash: %COMMIT_HASH% - -set "BUILD_INFO_FILE=core\build_info.py" -set "BUILD_INFO_BACKUP=core\build_info.py.backup" -if exist "%BUILD_INFO_FILE%" copy /y "%BUILD_INFO_FILE%" "%BUILD_INFO_BACKUP%" >nul - -powershell -NoProfile -Command "(Get-Content -Path '%BUILD_INFO_FILE%' -Raw -Encoding UTF8) -replace 'COMMIT_HASH\s*=\s*.*', 'COMMIT_HASH = \"%COMMIT_HASH%\"' | Set-Content -Path '%BUILD_INFO_FILE%' -Encoding UTF8" -if errorlevel 1 ( - echo [ERROR] Failed to inject build info - call :restore_build_info >nul - exit /b 1 -) +if /I "%~1"=="--help" goto :show_help +if /I "%~1"=="-h" goto :show_help -echo [SUCCESS] Build info injected +set "FIRST_ARG=%~1" -call :build_single -set "RET=%ERRORLEVEL%" -call :restore_build_info >nul -exit /b %RET% +if "%~1"=="" goto :run_default +if "%FIRST_ARG:~0,2%"=="--" goto :run_passthrough +if not "%~3"=="" goto :show_positional_error -:check_python -set "CHECK_PY=%~1" -set "CHECK_LABEL=%~2" - -echo [INFO] Checking Python (%CHECK_LABEL%): %CHECK_PY% -"%CHECK_PY%" -c "import sys; print(sys.executable)" >nul 2>nul -if errorlevel 1 ( - echo [ERROR] Python not available: %CHECK_PY% - exit /b 1 -) -for /f "tokens=*" %%i in ('"%CHECK_PY%" -c "import sys; print(sys.executable)" 2^>nul') do set "_PY_RESOLVED=%%i" -echo [SUCCESS] Python (%CHECK_LABEL%): !_PY_RESOLVED! - -echo [INFO] Checking PyInstaller (%CHECK_LABEL%)... -"%CHECK_PY%" -c "import PyInstaller" >nul 2>nul -if errorlevel 1 ( - echo [ERROR] PyInstaller missing in %CHECK_LABEL% environment - exit /b 1 -) -echo [SUCCESS] PyInstaller is available (%CHECK_LABEL%) -exit /b 0 - -:build_with_python -set "B_PY=%~1" -set "B_WORK=%~2" -set "B_DIST=%~3" -set "B_LABEL=%~4" - -echo. -echo [========================================] -echo Build: %B_LABEL% -echo [========================================] - -if exist "%B_WORK%" rd /s /q "%B_WORK%" -if exist "%B_DIST%" rd /s /q "%B_DIST%" - -"%B_PY%" -m PyInstaller "%SPEC_FILE%" --clean --noconfirm --workpath "%B_WORK%" --distpath "%B_DIST%" -set "PYI_RC=%ERRORLEVEL%" -echo [INFO] PyInstaller process rc (%B_LABEL%): %PYI_RC% -if not "%PYI_RC%"=="0" ( - echo [WARNING] PyInstaller returned non-zero [%B_LABEL%]: %PYI_RC% -) - -if not exist "%B_DIST%\%APP_NAME%\SuperPicky.exe" ( - echo [ERROR] Missing output exe: %B_DIST%\%APP_NAME%\SuperPicky.exe - exit /b 1 -) - -echo [SUCCESS] Build completed (%B_LABEL%) -exit /b 0 - -rem Copy folder C_SRC into C_DST using robocopy (more reliable than xcopy for deep trees). -:copy_dir -set "C_SRC=%~1" -set "C_DST=%~2" - -if not exist "%C_SRC%" ( - echo [ERROR] Copy source not found: %C_SRC% - exit /b 1 -) - -if not exist "%C_DST%" mkdir "%C_DST%" -if errorlevel 1 ( - echo [ERROR] Failed to create target dir: %C_DST% - exit /b 1 -) - -robocopy "%C_SRC%" "%C_DST%" /E /R:2 /W:1 /NFL /NDL /NJH /NJS /NP >nul -set "COPY_RC=%ERRORLEVEL%" -if !COPY_RC! GEQ 8 ( - echo [ERROR] Failed to copy to %C_DST% ^(robocopy exit code !COPY_RC!^) - exit /b 1 -) - -echo [SUCCESS] Copied directory: %C_SRC% -^> %C_DST% -exit /b 0 - -rem Zip folder Z_SRC into Z_OUT. Archive contains one top-level folder (e.g. SuperPicky\) so unzip gives one dir. -:zip_dir -set "Z_SRC=%~1" -set "Z_OUT=%~2" - -if not exist "%Z_SRC%" ( - echo [ERROR] Zip source not found: %Z_SRC% - exit /b 1 -) - -if exist "%Z_OUT%" del /q "%Z_OUT%" >nul 2>&1 - -where 7z >nul 2>&1 -if not errorlevel 1 ( - 7z a -tzip "%Z_OUT%" "%Z_SRC%" -r >nul - if errorlevel 1 ( - echo [ERROR] Failed to create zip with 7z: %Z_OUT% - exit /b 1 - ) -) else ( - powershell -NoProfile -Command "Compress-Archive -Path '%Z_SRC%' -DestinationPath '%Z_OUT%' -Force" - if errorlevel 1 ( - echo [ERROR] Failed to create zip with Compress-Archive: %Z_OUT% - exit /b 1 - ) -) - -echo [SUCCESS] Created zip: %Z_OUT% -exit /b 0 +set "VERSION_ARG=" +set "COPY_DIR_ARG=" -:build_single -set "WORK_DIR=%ROOT_DIR%\build_!OUT_DIST_DIR!" -set "DIST_DIR=%ROOT_DIR%\!OUT_DIST_DIR!" +if not "%~1"=="" set "VERSION_ARG=--version %~1" +if not "%~2"=="" set "COPY_DIR_ARG=--copy-dir %~2" -call :build_with_python "%PYTHON_EXE%" "!WORK_DIR!" "!DIST_DIR!" "release" -set "BUILD_RC=%ERRORLEVEL%" -echo [INFO] build_with_python rc: !BUILD_RC! -if !BUILD_RC! NEQ 0 exit /b !BUILD_RC! +:run_default +"%PYTHON_EXE%" "%SCRIPT_DIR%build_release_win.py" --build-type cpu %VERSION_ARG% %COPY_DIR_ARG% +exit /b %ERRORLEVEL% -rem Default: always create one release ZIP -if "%BUILD_ZIP%"=="1" ( - set "ZIP_NAME=!APP_NAME!_v!VERSION!_Win64.zip" - - rem Remove Inno Setup files before creating zip - if exist "!DIST_DIR!\!APP_NAME!\SuperPicky.iss" del /q "!DIST_DIR!\!APP_NAME!\SuperPicky.iss" >nul 2>&1 - if exist "!DIST_DIR!\!APP_NAME!\ChineseSimplified.isl" del /q "!DIST_DIR!\!APP_NAME!\ChineseSimplified.isl" >nul 2>&1 - - call :zip_dir "!DIST_DIR!\!APP_NAME!" "!DIST_DIR!\!ZIP_NAME!" - if errorlevel 1 exit /b 1 - - rem Restore Inno Setup files after creating zip - if exist "%INNO_DIR%\SuperPicky.iss" ( - copy /y "%INNO_DIR%\SuperPicky.iss" "!DIST_DIR!\!APP_NAME!\SuperPicky.iss" >nul - rem Update version in iss file - powershell -NoProfile -Command "(Get-Content -Path '!DIST_DIR!\!APP_NAME!\SuperPicky.iss' -Raw -Encoding UTF8) -replace 'VersionInfoVersion=.*', 'VersionInfoVersion=!VERSION!' | Set-Content -Path '!DIST_DIR!\!APP_NAME!\SuperPicky.iss' -Encoding UTF8" - ) - if exist "%INNO_DIR%\ChineseSimplified.isl" ( - copy /y "%INNO_DIR%\ChineseSimplified.isl" "!DIST_DIR!\!APP_NAME!\ChineseSimplified.isl" >nul - ) - - if defined ZIP_COPY_DIR ( - set "TARGET_SUBDIR=%APP_NAME%_!VERSION!" - set "TARGET_DIR=!ZIP_COPY_DIR!\!TARGET_SUBDIR!" - if not exist "!ZIP_COPY_DIR!" mkdir "!ZIP_COPY_DIR!" - if errorlevel 1 ( - echo [ERROR] Failed to create copy root dir: !ZIP_COPY_DIR! - exit /b 1 - ) - if exist "!TARGET_DIR!" rd /s /q "!TARGET_DIR!" - if exist "!TARGET_DIR!" ( - echo [ERROR] Failed to clean old target dir: !TARGET_DIR! - exit /b 1 - ) - call :copy_dir "%DIST_DIR%\%APP_NAME%" "!TARGET_DIR!" - if errorlevel 1 exit /b 1 - - rem Copy Inno Setup files to target directory - if exist "%INNO_DIR%\SuperPicky.iss" ( - copy /y "%INNO_DIR%\SuperPicky.iss" "!TARGET_DIR!\SuperPicky.iss" >nul - if errorlevel 1 ( - echo [ERROR] Failed to copy SuperPicky.iss to target directory - exit /b 1 - ) - echo [SUCCESS] Copied SuperPicky.iss to !TARGET_DIR! - - rem Update version in iss file - powershell -NoProfile -Command "(Get-Content -Path '!TARGET_DIR!\SuperPicky.iss' -Raw -Encoding UTF8) -replace 'VersionInfoVersion=.*', 'VersionInfoVersion=!VERSION!' | Set-Content -Path '!TARGET_DIR!\SuperPicky.iss' -Encoding UTF8" - if errorlevel 1 ( - echo [ERROR] Failed to update version in SuperPicky.iss in target directory - exit /b 1 - ) - echo [SUCCESS] Updated version in SuperPicky.iss to !VERSION! in target directory - ) - - if exist "%INNO_DIR%\ChineseSimplified.isl" ( - copy /y "%INNO_DIR%\ChineseSimplified.isl" "!TARGET_DIR!\ChineseSimplified.isl" >nul - if errorlevel 1 ( - echo [ERROR] Failed to copy ChineseSimplified.isl to target directory - exit /b 1 - ) - echo [SUCCESS] Copied ChineseSimplified.isl to !TARGET_DIR! - ) - - rem Remove Inno Setup files before creating zip - if exist "!TARGET_DIR!\SuperPicky.iss" del /q "!TARGET_DIR!\SuperPicky.iss" >nul 2>&1 - if exist "!TARGET_DIR!\ChineseSimplified.isl" del /q "!TARGET_DIR!\ChineseSimplified.isl" >nul 2>&1 - - call :zip_dir "!TARGET_DIR!" "!ZIP_COPY_DIR!\!TARGET_SUBDIR!.zip" - if errorlevel 1 exit /b 1 - - rem Restore Inno Setup files after creating zip - if exist "%INNO_DIR%\SuperPicky.iss" ( - copy /y "%INNO_DIR%\SuperPicky.iss" "!TARGET_DIR!\SuperPicky.iss" >nul - rem Update version in iss file - powershell -NoProfile -Command "(Get-Content -Path '!TARGET_DIR!\SuperPicky.iss' -Raw -Encoding UTF8) -replace 'VersionInfoVersion=.*', 'VersionInfoVersion=!VERSION!' | Set-Content -Path '!TARGET_DIR!\SuperPicky.iss' -Encoding UTF8" - ) - if exist "%INNO_DIR%\ChineseSimplified.isl" ( - copy /y "%INNO_DIR%\ChineseSimplified.isl" "!TARGET_DIR!\ChineseSimplified.isl" >nul - ) - - echo [SUCCESS] Copied !TARGET_SUBDIR! + created !ZIP_COPY_DIR!\!TARGET_SUBDIR!.zip - ) -) else ( - set "ZIP_NAME=" - echo [INFO] ZIP creation skipped ^(--no-zip^) -) +:run_passthrough +"%PYTHON_EXE%" "%SCRIPT_DIR%build_release_win.py" --build-type cpu %* +exit /b %ERRORLEVEL% +:show_help +echo SuperPicky Windows compatibility wrapper echo. -echo [========================================] -echo Step 4: Copy Inno Setup files -echo [========================================] - -set "INNO_DIR=%ROOT_DIR%\inno" -set "OUTPUT_EXE_DIR=%DIST_DIR%\%APP_NAME%" - -if exist "%INNO_DIR%\SuperPicky.iss" ( - copy /y "%INNO_DIR%\SuperPicky.iss" "%OUTPUT_EXE_DIR%\SuperPicky.iss" >nul - if errorlevel 1 ( - echo [ERROR] Failed to copy SuperPicky.iss - exit /b 1 - ) - echo [SUCCESS] Copied SuperPicky.iss to %OUTPUT_EXE_DIR% - - rem Update version in iss file - powershell -NoProfile -Command "(Get-Content -Path '%OUTPUT_EXE_DIR%\SuperPicky.iss' -Raw -Encoding UTF8) -replace 'VersionInfoVersion=.*', 'VersionInfoVersion=%VERSION%' | Set-Content -Path '%OUTPUT_EXE_DIR%\SuperPicky.iss' -Encoding UTF8" - if errorlevel 1 ( - echo [ERROR] Failed to update version in SuperPicky.iss - exit /b 1 - ) - echo [SUCCESS] Updated version in SuperPicky.iss to %VERSION% -) else ( - echo [WARNING] SuperPicky.iss not found in %INNO_DIR% -) - -if exist "%INNO_DIR%\ChineseSimplified.isl" ( - copy /y "%INNO_DIR%\ChineseSimplified.isl" "%OUTPUT_EXE_DIR%\ChineseSimplified.isl" >nul - if errorlevel 1 ( - echo [ERROR] Failed to copy ChineseSimplified.isl - exit /b 1 - ) - echo [SUCCESS] Copied ChineseSimplified.isl to %OUTPUT_EXE_DIR% -) else ( - echo [WARNING] ChineseSimplified.isl not found in %INNO_DIR% -) - +echo Usage: +echo %~nx0 [version] [copy_dir] +echo %~nx0 [build_release_win.py options] echo. -echo [========================================] -echo Build finished -echo [========================================] -echo EXE: %DIST_DIR%\%APP_NAME%\SuperPicky.exe -if defined ZIP_NAME ( -echo ZIP: %DIST_DIR%\%ZIP_NAME% -if defined ZIP_COPY_DIR echo Copy: %ZIP_COPY_DIR%\%APP_NAME%_%VERSION% + .zip -) else ( -echo ZIP: ^(skipped^) -) +echo This wrapper forwards to build_release_win.py --build-type cpu. +echo If the first argument starts with --, all arguments are passed through directly. exit /b 0 -:restore_build_info -if exist "%BUILD_INFO_BACKUP%" ( - move /y "%BUILD_INFO_BACKUP%" "%BUILD_INFO_FILE%" >nul -) -exit /b 0 +:show_positional_error +echo [ERROR] Positional compatibility mode only accepts [version] [copy_dir]. +echo [ERROR] If you need extra options such as --debug or --help, use explicit option mode. +echo [ERROR] Example: build_release.bat --debug --help +exit /b 1 \ No newline at end of file diff --git a/build_release.sh b/build_release.sh index 7d004a3..bd62425 100755 --- a/build_release.sh +++ b/build_release.sh @@ -1,378 +1,34 @@ -#!/bin/bash -# SuperPicky - 打包、签名和公证脚本 -# 作者: James Zhen Yu -# 版本: 1.0 -# -# 用法: -# ./build_release.sh --test # 仅打包和签名(跳过公证) -# ./build_release.sh --release # 完整流程:打包、签名、公证 -# ./build_release.sh --help # 显示帮助 +#!/bin/sh +set -eu -set -e # 遇到错误立即退出 - -# ============================================ -# 配置参数 -# ============================================ -APP_NAME="SuperPicky" -BUNDLE_ID="com.jamesphotography.superpicky" -DEVELOPER_ID="Developer ID Application: James Zhen Yu (JWR6FDB52H)" -APPLE_ID="james@jamesphotography.com.au" -TEAM_ID="JWR6FDB52H" -KEYCHAIN_ITEM="SuperPicky-Notarize" - -# 颜色输出 -RED='\033[0;31m' -GREEN='\033[0;32m' -YELLOW='\033[1;33m' -BLUE='\033[0;34m' -CYAN='\033[0;36m' -NC='\033[0m' # No Color - -# ============================================ -# 辅助函数 -# ============================================ -log_info() { - echo -e "${BLUE}[INFO]${NC} $1" -} - -log_success() { - echo -e "${GREEN}[SUCCESS]${NC} $1" -} - -log_warning() { - echo -e "${YELLOW}[WARNING]${NC} $1" -} - -log_error() { - echo -e "${RED}[ERROR]${NC} $1" -} - -log_step() { - echo -e "\n${CYAN}========================================${NC}" - echo -e "${CYAN}步骤$1: $2${NC}" - echo -e "${CYAN}========================================${NC}" -} +SCRIPT_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) +MODE_ARGS="" show_help() { - echo "SuperPicky 构建脚本" - echo "" - echo "用法: $0 [选项]" + echo "SuperPicky macOS compatibility wrapper" echo "" - echo "选项:" - echo " --test 仅打包和签名,跳过公证(用于快速测试)" - echo " --release 完整流程:打包、签名、公证、装订" - echo " --help 显示此帮助信息" - echo "" - echo "首次使用前,需要配置 Keychain:" - echo " security add-generic-password -a \"${APPLE_ID}\" \\" - echo " -s \"${KEYCHAIN_ITEM}\" -w \"你的App-Specific-Password\"" + echo "Usage:" + echo " ./build_release.sh --test [extra build_release_mac.py args]" + echo " ./build_release.sh --release [extra build_release_mac.py args]" echo "" + echo "This wrapper forwards to build_release_mac.py --build-type full." + echo "Use --release to append --notarize." } -# ============================================ -# 参数解析 -# ============================================ -MODE="" -if [ $# -eq 0 ]; then - show_help - exit 0 -fi - -case "$1" in - --test) - MODE="test" - ;; - --release) - MODE="release" - ;; - --help|-h) - show_help - exit 0 - ;; - *) - log_error "未知选项: $1" - show_help - exit 1 - ;; -esac - -# ============================================ -# 步骤0: 环境检查 -# ============================================ -log_step "0" "环境检查" - -# 检查开发者证书 -log_info "检查开发者证书..." -if ! security find-identity -v -p codesigning | grep -q "${DEVELOPER_ID}"; then - log_error "未找到开发者证书: ${DEVELOPER_ID}" - log_info "请确保已在 Keychain 中安装有效的开发者证书" - exit 1 -fi -log_success "开发者证书已就绪" - -# 检查 Keychain 密码(仅 release 模式) -if [ "$MODE" = "release" ]; then - log_info "检查 Keychain 中的 App-Specific Password..." - if ! security find-generic-password -a "${APPLE_ID}" -s "${KEYCHAIN_ITEM}" -w &>/dev/null; then - log_error "未在 Keychain 中找到 App-Specific Password" - log_info "请运行以下命令添加密码:" - echo "" - echo " security add-generic-password -a \"${APPLE_ID}\" \\" - echo " -s \"${KEYCHAIN_ITEM}\" -w \"你的密码\"" - echo "" - exit 1 - fi - log_success "Keychain 密码已配置" -fi - -# 检查 PyInstaller -log_info "检查 PyInstaller..." -if ! command -v pyinstaller &>/dev/null; then - # 尝试从虚拟环境 - if [ -f ".venv/bin/pyinstaller" ]; then - PYINSTALLER=".venv/bin/pyinstaller" - else - log_error "未找到 PyInstaller,请先安装: pip install pyinstaller" - exit 1 - fi -else - PYINSTALLER="pyinstaller" -fi -log_success "PyInstaller 已就绪" - -# 检查 entitlements.plist -if [ ! -f "entitlements.plist" ]; then - log_error "未找到 entitlements.plist 文件" - exit 1 -fi - -# ============================================ -# 步骤1: 提取版本号 -# ============================================ -log_step "1" "提取版本号" - -VERSION=$(grep 'APP_VERSION' constants.py | grep -oE '"[0-9]+\.[0-9]+\.[0-9]+"' | tr -d '"' | head -1) -if [ -z "$VERSION" ]; then - log_error "无法从 constants.py 提取版本号" - exit 1 -fi -log_success "检测到版本: v${VERSION}" - -# ============================================ -# 步骤1.5: 检测 CPU 架构 -# ============================================ -log_info "检测 CPU 架构..." -ARCH=$(uname -m) -if [ "$ARCH" = "arm64" ]; then - ARCH_SUFFIX="arm64" - log_success "检测到 Apple Silicon (arm64)" -elif [ "$ARCH" = "x86_64" ]; then - ARCH_SUFFIX="intel" - log_success "检测到 Intel (x86_64)" -else - ARCH_SUFFIX="$ARCH" - log_warning "未知架构: $ARCH" -fi - -# 设置输出文件名(包含架构信息) -if [ "$MODE" = "test" ]; then - DMG_NAME="${APP_NAME}_v${VERSION}_${ARCH_SUFFIX}_test.dmg" -else - DMG_NAME="${APP_NAME}_v${VERSION}_${ARCH_SUFFIX}.dmg" -fi -DMG_PATH="dist/${DMG_NAME}" - -# ============================================ -# 步骤2: 清理旧文件 -# ============================================ -log_step "2" "清理旧文件" - -rm -rf build dist -mkdir -p dist -log_success "清理完成" - -# ============================================ -# 步骤2.5: 注入 Git Commit Hash -# ============================================ -log_step "2.5" "注入构建信息" - -# 从 Python 代码读取 Commit Hash(保证跨平台一致) -COMMIT_HASH=$(python3 -c " -try: - from core.build_info_local import COMMIT_HASH -except ImportError: - from core.build_info import COMMIT_HASH -print(COMMIT_HASH or 'unknown') -") -log_info "Commit Hash: ${COMMIT_HASH}" - -# 备份原始 build_info.py -BUILD_INFO_FILE="core/build_info.py" -BUILD_INFO_BACKUP="${BUILD_INFO_FILE}.backup" -cp "${BUILD_INFO_FILE}" "${BUILD_INFO_BACKUP}" - -# 注入 commit hash -sed -i.tmp "s/COMMIT_HASH = None/COMMIT_HASH = \"${COMMIT_HASH}\"/" "${BUILD_INFO_FILE}" -rm -f "${BUILD_INFO_FILE}.tmp" # macOS sed 的临时文件 - -log_success "构建信息已注入" - -# ============================================ -# 步骤3: PyInstaller 打包 -# ============================================ -log_step "3" "PyInstaller 打包" - -log_info "正在打包应用..." -${PYINSTALLER} SuperPicky.spec --clean --noconfirm - -# 恢复原始 build_info.py -if [ -f "${BUILD_INFO_BACKUP}" ]; then - mv "${BUILD_INFO_BACKUP}" "${BUILD_INFO_FILE}" - log_info "已恢复原始 build_info.py" -fi - -if [ ! -d "dist/${APP_NAME}.app" ]; then - log_error "打包失败!未找到 dist/${APP_NAME}.app" - exit 1 -fi -log_success "打包完成" - -# ============================================ -# 步骤4: 深度代码签名 -# ============================================ -log_step "4" "深度代码签名" - -# 签名所有嵌入的二进制文件和库 -log_info "签名嵌入的框架和库..." -find "dist/${APP_NAME}.app/Contents" -type f \( -name "*.dylib" -o -name "*.so" \) -print0 | while IFS= read -r -d '' file; do - codesign --force --sign "${DEVELOPER_ID}" --timestamp --options runtime "$file" 2>/dev/null || true -done - -# 签名可执行文件 -find "dist/${APP_NAME}.app/Contents/MacOS" -type f -perm +111 -print0 | while IFS= read -r -d '' file; do - codesign --force --sign "${DEVELOPER_ID}" --timestamp --options runtime "$file" 2>/dev/null || true -done - -# 签名主应用 -log_info "签名主应用..." -codesign --force --deep --sign "${DEVELOPER_ID}" \ - --timestamp \ - --options runtime \ - --entitlements entitlements.plist \ - "dist/${APP_NAME}.app" - -# 验证签名 -log_info "验证代码签名..." -codesign --verify --deep --strict --verbose=2 "dist/${APP_NAME}.app" -log_success "代码签名完成" - -# ============================================ -# 步骤5: 创建 DMG 安装包 -# ============================================ -log_step "5" "创建 DMG 安装包" - -# 创建临时 DMG 文件夹 -TEMP_DMG_DIR="dist/dmg_temp" -rm -rf "${TEMP_DMG_DIR}" -mkdir -p "${TEMP_DMG_DIR}" - -# 复制应用到临时文件夹 -cp -R "dist/${APP_NAME}.app" "${TEMP_DMG_DIR}/" - -# 创建 Applications 快捷方式 -ln -s /Applications "${TEMP_DMG_DIR}/Applications" - -# 创建 DMG -log_info "使用 hdiutil 创建 DMG..." -hdiutil create -volname "${APP_NAME}" -srcfolder "${TEMP_DMG_DIR}" -ov -format UDZO "${DMG_PATH}" - -# 清理临时文件夹 -rm -rf "${TEMP_DMG_DIR}" -log_success "DMG 创建完成: ${DMG_PATH}" - -# ============================================ -# 步骤6: 签名 DMG -# ============================================ -log_step "6" "签名 DMG" - -codesign --force --sign "${DEVELOPER_ID}" --timestamp "${DMG_PATH}" -codesign --verify --verbose=2 "${DMG_PATH}" -log_success "DMG 签名完成" - -# ============================================ -# 步骤7: 公证(仅 release 模式) -# ============================================ -if [ "$MODE" = "release" ]; then - log_step "7" "Apple 公证" - - # 从 Keychain 获取密码 - APP_PASSWORD=$(security find-generic-password -a "${APPLE_ID}" -s "${KEYCHAIN_ITEM}" -w) - - log_info "提交到 Apple 公证服务..." - log_info "(这可能需要几分钟时间)" - - NOTARIZE_OUTPUT=$(xcrun notarytool submit "${DMG_PATH}" \ - --apple-id "${APPLE_ID}" \ - --password "${APP_PASSWORD}" \ - --team-id "${TEAM_ID}" \ - --wait \ - --output-format json 2>&1) - - echo "${NOTARIZE_OUTPUT}" - - # 检查公证结果 - if echo "${NOTARIZE_OUTPUT}" | grep -Eq '"status"[[:space:]]*:[[:space:]]*"Accepted"'; then - log_success "公证成功!" - - # 步骤8: 装订公证票据 - log_step "8" "装订公证票据" - xcrun stapler staple "${DMG_PATH}" - xcrun stapler validate "${DMG_PATH}" - log_success "票据装订完成" - else - log_error "公证失败!" - - # 提取 RequestUUID 并获取详细日志 - REQUEST_UUID=$(echo "${NOTARIZE_OUTPUT}" | sed -n 's/.*"id"[[:space:]]*:[[:space:]]*"\([^"]*\)".*/\1/p' | head -1) - if [ -n "${REQUEST_UUID}" ]; then - log_info "获取详细公证日志..." - xcrun notarytool log "${REQUEST_UUID}" \ - --apple-id "${APPLE_ID}" \ - --password "${APP_PASSWORD}" \ - --team-id "${TEAM_ID}" - fi - exit 1 - fi -else - log_info "测试模式:跳过公证步骤" -fi - -# ============================================ -# 完成报告 -# ============================================ -echo "" -echo -e "${GREEN}========================================${NC}" -echo -e "${GREEN}构建完成!${NC}" -echo -e "${GREEN}========================================${NC}" -echo "" -echo -e "应用: ${CYAN}dist/${APP_NAME}.app${NC}" -echo -e "DMG: ${CYAN}${DMG_PATH}${NC}" -echo -e "架构: ${CYAN}${ARCH_SUFFIX}${NC}" -echo "" - -if [ "$MODE" = "release" ]; then - echo -e "状态: ${GREEN}已公证,可分发${NC}" - echo "" - echo "下一步:" - echo " 1. 测试 DMG 安装包" - echo " 2. 上传到 GitHub Releases" - echo "" - echo "注意: 如需构建其他架构版本,请在对应架构的 Mac 上重新运行此脚本" -else - echo -e "状态: ${YELLOW}已签名(未公证)${NC}" - echo "" - echo "注意: 测试模式下未进行公证,用户首次打开需要右键菜单" - echo "发布正式版本请使用: ./build_release.sh --release" -fi - -echo -e "${GREEN}========================================${NC}" +if [ "$#" -gt 0 ]; then + case "$1" in + --help|-h) + show_help + exit 0 + ;; + --release) + MODE_ARGS="--notarize" + shift + ;; + --test) + shift + ;; + esac +fi + +exec python3 "$SCRIPT_DIR/build_release_mac.py" --build-type full $MODE_ARGS "$@" \ No newline at end of file diff --git a/build_release_all.bat b/build_release_all.bat index c7e6971..36cf6ef 100644 --- a/build_release_all.bat +++ b/build_release_all.bat @@ -1,3 +1,8 @@ +@echo off +setlocal EnableExtensions -call build_release_cpu.bat %1 -:: call build_release_cuda.bat %1 \ No newline at end of file +call "%~dp0build_release_cpu.bat" %* +if errorlevel 1 exit /b %ERRORLEVEL% + +call "%~dp0build_release_lite_win.bat" %* +exit /b %ERRORLEVEL% \ No newline at end of file diff --git a/build_release_cpu.bat b/build_release_cpu.bat index d60bec4..16a9abb 100644 --- a/build_release_cpu.bat +++ b/build_release_cpu.bat @@ -1,19 +1,5 @@ @echo off setlocal EnableExtensions -set "VERSION_INPUT=%~1" -if "%VERSION_INPUT%"=="" ( - set "VERSION_ARG=Win64_CPU" -) else ( - set "VERSION_ARG=%VERSION_INPUT%Win64_CPU" -) - -call "%~dp0.venv\Scripts\activate.bat" -if errorlevel 1 exit /b 1 - -set "OUT_DIST_DIR=dist_cpu" -call "%~dp0build_release.bat" "%VERSION_ARG%" "output" -set "RET=%ERRORLEVEL%" - -call "%~dp0.venv\Scripts\deactivate.bat" >nul 2>&1 -exit /b %RET% +call "%~dp0build_release.bat" "%~1" "output" +exit /b %ERRORLEVEL% diff --git a/build_release_cuda.bat b/build_release_cuda.bat index 47bb43a..1e20d9f 100644 --- a/build_release_cuda.bat +++ b/build_release_cuda.bat @@ -1,19 +1,12 @@ @echo off setlocal EnableExtensions -set "VERSION_INPUT=%~1" -if "%VERSION_INPUT%"=="" ( - set "VERSION_ARG=_Win64_CUDA" -) else ( - set "VERSION_ARG=%VERSION_INPUT%_Win64_CUDA" -) +set "SCRIPT_DIR=%~dp0" +set "PYTHON_EXE=%SCRIPT_DIR%.venv\Scripts\python.exe" +if not exist "%PYTHON_EXE%" set "PYTHON_EXE=python" -call "%~dp0.venv\Scripts\activate.bat" -if errorlevel 1 exit /b 1 +set "VERSION_ARG=" +if not "%~1"=="" set "VERSION_ARG=--version %~1" -set "OUT_DIST_DIR=dist_cuda" -call "%~dp0build_release.bat" "%VERSION_ARG%" "output\win64_cuda" -set "RET=%ERRORLEVEL%" - -call "%~dp0.venv\Scripts\deactivate.bat" >nul 2>&1 -exit /b %RET% +"%PYTHON_EXE%" "%SCRIPT_DIR%build_release_win.py" --build-type cuda %VERSION_ARG% --copy-dir output\win64_cuda +exit /b %ERRORLEVEL% diff --git a/build_release_full_mac.sh b/build_release_full_mac.sh new file mode 100755 index 0000000..4fe5394 --- /dev/null +++ b/build_release_full_mac.sh @@ -0,0 +1,5 @@ +#!/bin/sh +set -eu + +SCRIPT_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) +exec python3 "$SCRIPT_DIR/build_release_mac.py" --build-type full "$@" diff --git a/build_release_lite_mac.sh b/build_release_lite_mac.sh new file mode 100755 index 0000000..0434f40 --- /dev/null +++ b/build_release_lite_mac.sh @@ -0,0 +1,5 @@ +#!/bin/sh +set -eu + +SCRIPT_DIR=$(CDPATH= cd -- "$(dirname -- "$0")" && pwd) +exec python3 "$SCRIPT_DIR/build_release_mac.py" --build-type lite "$@" diff --git a/build_release_lite_win.bat b/build_release_lite_win.bat new file mode 100644 index 0000000..a4ce8b2 --- /dev/null +++ b/build_release_lite_win.bat @@ -0,0 +1,10 @@ +@echo off +setlocal EnableExtensions + +set "PYTHON_EXE=%~dp0.venv\Scripts\python.exe" +if not exist "%PYTHON_EXE%" ( + set "PYTHON_EXE=python" +) + +"%PYTHON_EXE%" "%~dp0build_release_win.py" --build-type lite %* +exit /b %ERRORLEVEL% diff --git a/build_release_mac.py b/build_release_mac.py new file mode 100644 index 0000000..95f4e81 --- /dev/null +++ b/build_release_mac.py @@ -0,0 +1,1018 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +SuperPicky macOS 构建脚本 / SuperPicky macOS build script. + +支持 full 与 lite 两种构建类型,并可选执行 Developer ID 签名。 +Supports both full and lite builds with optional Developer ID signing. +""" + +from __future__ import annotations + +import argparse +import ast +import importlib.metadata +import json +import logging +import os +import platform +import re +import secrets +import shutil +import subprocess +import sys +import tempfile +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable, Sequence + +from packaging.requirements import Requirement + + +ROOT_DIR = Path(__file__).resolve().parent +APP_NAME = "SuperPicky" +LITE_APP_NAME = "SuperPickyLite" +BUILD_INFO_FILE = ROOT_DIR / "core" / "build_info.py" +DOWNLOAD_MODELS_SCRIPT = ROOT_DIR / "scripts" / "download_models.py" +FULL_SPEC_FILE = ROOT_DIR / "SuperPicky_full.spec" +LITE_SPEC_FILE = ROOT_DIR / "SuperPicky_lite.spec" +REQUIREMENTS_MAC_FILE = ROOT_DIR / "requirements_mac.txt" +ENTITLEMENTS_FILE = ROOT_DIR / "entitlements.plist" +DMG_README_FILE = ROOT_DIR / "resources" / "DMG_README.txt" + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class BuildPaths: + """ + 构建路径集合 / Build path collection. + """ + + label: str + work_dir: Path + dist_dir: Path + app_dir: Path + dmg_path: Path + + +@dataclass(frozen=True) +class BuildConfig: + """ + 构建配置 / Build configuration. + """ + + build_type: str + arch: str + copy_dir: Path | None + debug: bool + app_version: str + commit_hash: str + sign_p12: Path | None + sign_p12_password_env: str + sign_identity: str | None + release_channel: str + notarize: bool + apple_id: str | None + apple_password_env: str + team_id: str | None + notary_keychain_profile: str | None + + +@dataclass +class SigningContext: + """ + 签名上下文 / Signing context. + """ + + keychain_path: Path + keychain_password: str + imported_p12_path: Path + identity: str + + +def configure_logging(debug: bool) -> None: + """ + 配置 UTF-8 日志输出 / Configure UTF-8 logging output. + """ + + if hasattr(sys.stdout, "reconfigure"): + sys.stdout.reconfigure(encoding="utf-8", errors="strict") # pyright: ignore[reportAttributeAccessIssue] + if hasattr(sys.stderr, "reconfigure"): + sys.stderr.reconfigure(encoding="utf-8", errors="strict") # pyright: ignore[reportAttributeAccessIssue] + + logger.setLevel(logging.DEBUG if debug else logging.INFO) + logger.propagate = False + + formatter = logging.Formatter("[%(levelname)s] %(message)s") + handler = logging.StreamHandler(sys.stdout) + handler.setLevel(logging.DEBUG if debug else logging.INFO) + handler.setFormatter(formatter) + + logger.handlers.clear() + logger.addHandler(handler) + + +def log_step(title: str) -> None: + """ + 记录步骤标题 / Log a step title. + """ + + logger.info("[========================================]") + logger.info(title) + logger.info("[========================================]") + + +def log_verbose(message: str, *args) -> None: + """ + 仅在调试模式输出详细日志 / Emit verbose logs only in debug mode. + """ + + logger.debug(message, *args) + + +def detect_host_arch() -> str: + """ + 规范化当前主机架构 / Normalize the current host architecture. + """ + + machine = platform.machine().lower() + return {"amd64": "x86_64", "x86_64": "x86_64", "arm64": "arm64", "aarch64": "arm64"}.get(machine, machine) + + +def optional_text(value: str | None) -> str | None: + """ + 规范化可选字符串 / Normalize optional text values. + """ + + if value is None: + return None + normalized = value.strip() + return normalized or None + + +def parse_args() -> argparse.Namespace: + """ + 解析命令行参数 / Parse command-line arguments. + """ + + parser = argparse.ArgumentParser(description="SuperPicky macOS 构建脚本") + parser.add_argument("--build-type", choices=["full", "lite"], required=True, help="构建类型:full 或 lite") + parser.add_argument( + "--arch", + choices=["arm64", "x86_64"], + default=detect_host_arch(), + help="目标架构,默认使用当前主机架构", + ) + parser.add_argument("--version", help="覆盖构建版本号,例如 4.2.5") + parser.add_argument("--copy-dir", help="复制最终产物的目标目录") + parser.add_argument("--debug", action="store_true", help="输出调试日志") + parser.add_argument("--sign-p12", help="Developer ID 证书 .p12 文件路径") + parser.add_argument( + "--sign-p12-password-env", + default="MACOS_CERTIFICATE_PWD", + help="读取 .p12 密码的环境变量名(默认: MACOS_CERTIFICATE_PWD)", + ) + parser.add_argument("--sign-identity", help="可选,显式指定 Developer ID Application identity") + parser.add_argument("--notarize", action="store_true", help="提交 Apple 公证并自动 staple DMG") + parser.add_argument("--apple-id", help="Apple notarization 使用的 Apple ID") + parser.add_argument("--team-id", help="Apple notarization 使用的 Team ID") + parser.add_argument( + "--apple-password-env", + default="APPLE_APP_PASSWORD", + help="读取 notarization 密码的环境变量名(默认: APPLE_APP_PASSWORD)", + ) + parser.add_argument("--notary-keychain-profile", help="可选,使用 notarytool keychain profile 进行认证") + return parser.parse_args() + + +def run_command( + command: Sequence[str], + *, + cwd: Path = ROOT_DIR, + check: bool = True, + capture_output: bool = False, + env: dict[str, str] | None = None, + label: str | None = None, +) -> subprocess.CompletedProcess[str]: + """ + 运行外部命令 / Run an external command. + """ + + if logger.isEnabledFor(logging.DEBUG): + logger.debug("执行命令: %s", " ".join(command)) + + result = subprocess.run( + list(command), + cwd=str(cwd), + text=True, + capture_output=capture_output, + env=env, + ) + + if check and result.returncode != 0: + if capture_output: + if result.stdout: + logger.error(result.stdout.strip()) + if result.stderr: + logger.error(result.stderr.strip()) + raise RuntimeError(f"{label or '命令执行'}失败,返回码: {result.returncode}") + + return result + + +def remove_path(path: Path) -> None: + """ + 删除文件或目录 / Remove a file or directory. + """ + + if path.is_dir() and not path.is_symlink(): + shutil.rmtree(path, ignore_errors=True) + elif path.exists() or path.is_symlink(): + path.unlink(missing_ok=True) + + +def copy_tree(src: Path, dst: Path) -> None: + """ + 复制目录 / Copy a directory tree. + """ + + if not src.exists(): + raise FileNotFoundError(f"复制源目录不存在: {src}") + remove_path(dst) + shutil.copytree(src, dst, symlinks=True) + + +def copy_file(src: Path, dst: Path) -> None: + """ + 复制文件 / Copy a file. + """ + + if not src.exists(): + raise FileNotFoundError(f"复制源文件不存在: {src}") + dst.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(src, dst) + + +def read_app_version() -> str: + """ + 从 constants.py 读取版本号 / Read version from constants.py. + """ + + content = (ROOT_DIR / "constants.py").read_text(encoding="utf-8") + match = re.search(r'APP_VERSION\s*=\s*["\']([0-9A-Za-z._-]+)["\']', content) + return match.group(1) if match else "0.0.0" + + +def get_commit_hash() -> str: + """ + 获取当前提交哈希 / Get the current commit hash. + """ + + try: + result = run_command( + ["git", "rev-parse", "--short=7", "HEAD"], + capture_output=True, + label="获取提交哈希", + ) + return result.stdout.strip() or "unknown" + except Exception: + content = BUILD_INFO_FILE.read_text(encoding="utf-8") + match = re.search(r'COMMIT_HASH\s*=\s*"([^"]*)"', content) + return match.group(1) if match else "unknown" + + +def parse_release_channel() -> str: + """ + 根据 RELEASE_TAG 判断发布渠道 / Infer release channel from RELEASE_TAG. + """ + + release_tag = os.environ.get("RELEASE_TAG", "") + if release_tag and "-rc" in release_tag.lower(): + return "nightly" + return "official" + + +def inject_build_info(commit_hash: str, release_channel: str) -> Path | None: + """ + 注入构建信息并返回备份路径 / Inject build info and return the backup path. + """ + + log_step("步骤 1: 注入构建元数据") + if not BUILD_INFO_FILE.exists(): + logger.warning("未找到 build_info.py,跳过注入") + return None + + backup_path = BUILD_INFO_FILE.with_suffix(".py.backup") + shutil.copy2(BUILD_INFO_FILE, backup_path) + + content = BUILD_INFO_FILE.read_text(encoding="utf-8-sig") + updated = re.sub( + r'COMMIT_HASH\s*=\s*".*"', + f'COMMIT_HASH = "{commit_hash}"', + content, + count=1, + ) + updated = re.sub( + r'RELEASE_CHANNEL\s*=\s*".*"', + f'RELEASE_CHANNEL = "{release_channel}"', + updated, + count=1, + ) + BUILD_INFO_FILE.write_text(updated, encoding="utf-8") + log_verbose("[成功] 已写入 COMMIT_HASH=%s RELEASE_CHANNEL=%s", commit_hash, release_channel) + return backup_path + + +def restore_build_info(backup_path: Path | None) -> None: + """ + 恢复构建信息文件 / Restore the build info file. + """ + + if backup_path and backup_path.exists(): + shutil.move(str(backup_path), str(BUILD_INFO_FILE)) + + +def spec_file_for(build_type: str) -> Path: + """ + 返回构建类型对应的 spec 文件 / Return the spec file for a build type. + """ + + if build_type == "lite": + return LITE_SPEC_FILE + return FULL_SPEC_FILE + + +def app_name_for(build_type: str) -> str: + """ + 返回构建类型对应的应用名 / Return the app name for a build type. + """ + + return LITE_APP_NAME if build_type == "lite" else APP_NAME + + +def artifact_name_for(build_type: str) -> str: + """ + 返回发布产物名称前缀 / Return the artifact name prefix for releases. + """ + + return "SuperPicky_Lite" if build_type == "lite" else APP_NAME + + +def display_name_for(build_type: str) -> str: + """ + 返回面向用户的展示名称 / Return the user-facing display name. + """ + + return "SuperPicky Lite" if build_type == "lite" else APP_NAME + + +def get_build_paths(build_type: str, arch: str, app_version: str, commit_hash: str) -> BuildPaths: + """ + 生成构建路径 / Build output paths. + """ + + label = f"{build_type}_{arch}" + app_name = app_name_for(build_type) + artifact_name = artifact_name_for(build_type) + dist_dir = ROOT_DIR / f"dist_{label}" + dmg_name = f"{artifact_name}_v{app_version}_{arch}_{commit_hash}.dmg" + return BuildPaths( + label=label, + work_dir=ROOT_DIR / f"build_dist_{label}", + dist_dir=dist_dir, + app_dir=dist_dir / f"{app_name}.app", + dmg_path=dist_dir / dmg_name, + ) + + +def ensure_macos_host() -> None: + """ + 确保当前系统为 macOS / Ensure the current host is macOS. + """ + + if sys.platform != "darwin": + raise RuntimeError("build_release_mac.py 只能在 macOS 上运行") + + +def ensure_arch_matches(target_arch: str) -> None: + """ + 确保目标架构与当前机器匹配 / Ensure target architecture matches the host. + """ + + normalized = detect_host_arch() + if normalized != target_arch: + raise RuntimeError( + f"当前机器架构为 {normalized},不能直接构建 {target_arch}。" + "请在对应架构的 macOS 环境中运行此脚本。" + ) + + +def _iter_requirement_lines(requirements_file: Path) -> Iterable[tuple[Path, str]]: + """ + 递归展开 requirements 文件 / Recursively expand requirements files. + """ + + for raw_line in requirements_file.read_text(encoding="utf-8").splitlines(): + line = raw_line.strip() + if not line or line.startswith("#"): + continue + if line.startswith("-r "): + nested_path = (requirements_file.parent / line[3:].strip()).resolve() + yield from _iter_requirement_lines(nested_path) + continue + if line.startswith("--requirement "): + nested_path = (requirements_file.parent / line.split(None, 1)[1].strip()).resolve() + yield from _iter_requirement_lines(nested_path) + continue + if line.startswith("-"): + logger.debug("跳过未处理的 requirements 条目: %s", line) + continue + yield requirements_file, line + + +def validate_python_environment() -> None: + """ + 检查当前 Python 环境是否满足 requirements_mac.txt / Validate the current Python environment. + """ + + log_step("步骤 2: 检查 Python 构建环境") + + missing_packages: list[str] = [] + version_conflicts: list[str] = [] + + for source_file, requirement_text in _iter_requirement_lines(REQUIREMENTS_MAC_FILE): + requirement = Requirement(requirement_text) + try: + installed_version = importlib.metadata.version(requirement.name) + except importlib.metadata.PackageNotFoundError: + missing_packages.append(f"{requirement.name} ({source_file.name})") + continue + if requirement.specifier and installed_version not in requirement.specifier: + version_conflicts.append( + f"{requirement.name}=={installed_version} 不满足 {requirement.specifier} ({source_file.name})" + ) + + if missing_packages or version_conflicts: + details = "\n".join([*missing_packages, *version_conflicts]) + raise RuntimeError( + "当前 Python 环境未满足 requirements_mac.txt。\n" + "请先执行 `python -m pip install -r requirements_mac.txt`。\n" + f"{details}" + ) + + run_command([sys.executable, "-c", "import PyInstaller; print(PyInstaller.__version__)"], label="PyInstaller 检查") + log_verbose("[成功] 当前 Python 环境满足 macOS 构建要求") + + +def load_required_models() -> list[dict[str, str]]: + """ + 从 download_models.py 解析模型清单 / Parse the required model list from download_models.py. + """ + + fallback = [ + {"filename": "model20240824.pth", "dest_dir": "models"}, + {"filename": "superFlier_efficientnet.pth", "dest_dir": "models"}, + {"filename": "cub200_keypoint_resnet50_slim.pth", "dest_dir": "models"}, + {"filename": "avonet.db", "dest_dir": "birdid/data"}, + {"filename": "cfanet_iaa_ava_res50-3cd62bb3.pth", "dest_dir": "models"}, + {"filename": "yolo11l-seg.pt", "dest_dir": "models"}, + ] + + if not DOWNLOAD_MODELS_SCRIPT.exists(): + logger.warning("未找到 download_models.py,使用默认模型列表") + return fallback + + try: + module_ast = ast.parse(DOWNLOAD_MODELS_SCRIPT.read_text(encoding="utf-8")) + models = None + for node in module_ast.body: + if isinstance(node, ast.Assign): + for target in node.targets: + if isinstance(target, ast.Name) and target.id == "MODELS_TO_DOWNLOAD": + models = ast.literal_eval(node.value) + break + if models is not None: + break + if models is None: + raise RuntimeError("download_models.py 中未找到 MODELS_TO_DOWNLOAD") + return [{"filename": item["filename"], "dest_dir": item["dest_dir"]} for item in models] + except BaseException as exc: + if isinstance(exc, KeyboardInterrupt): + raise + logger.warning("无法解析 download_models.py,使用默认模型列表: %s", exc) + return fallback + + +REQUIRED_MODELS = load_required_models() + + +def ensure_models() -> None: + """ + 检查 full 构建所需模型并在缺失时下载 / Ensure required models for the full build. + """ + + log_step("步骤 3: 检查并下载模型文件") + missing_paths = [ + ROOT_DIR / model["dest_dir"] / model["filename"] + for model in REQUIRED_MODELS + if not (ROOT_DIR / model["dest_dir"] / model["filename"]).exists() + ] + + if not missing_paths: + log_verbose("[成功] 所有模型文件已就绪") + return + + logger.warning("检测到 %d 个缺失模型,开始下载", len(missing_paths)) + run_command([sys.executable, str(DOWNLOAD_MODELS_SCRIPT)], label="模型下载") + + remaining = [path for path in missing_paths if not path.exists()] + if remaining: + missing_text = "\n".join(str(path) for path in remaining) + raise RuntimeError(f"模型下载后仍有缺失:\n{missing_text}") + log_verbose("[成功] 所有模型文件已就绪") + + +def clean_build_outputs(paths: BuildPaths) -> None: + """ + 清理构建目录 / Clean build outputs. + """ + + log_step("步骤 4: 清理旧的构建目录") + remove_path(paths.work_dir) + remove_path(paths.dist_dir) + log_verbose("[成功] 已清理 %s 和 %s", paths.work_dir, paths.dist_dir) + + +def build_environment(config: BuildConfig) -> dict[str, str]: + """ + 生成 PyInstaller 环境变量 / Build PyInstaller environment variables. + """ + + env = os.environ.copy() + env["SUPERPICKY_TARGET_ARCH"] = config.arch + env["SUPERPICKY_APP_VERSION"] = config.app_version + env["SUPERPICKY_CODESIGN_IDENTITY"] = "" + env["SUPERPICKY_ENTITLEMENTS_FILE"] = "" + return env + + +def build_bundle(config: BuildConfig, paths: BuildPaths) -> None: + """ + 执行 PyInstaller 构建 / Run the PyInstaller build. + """ + + log_step("步骤 5: 执行 PyInstaller 构建") + spec_file = spec_file_for(config.build_type) + if config.build_type == "lite": + log_verbose("[信息] macOS Lite 当前采用内置 Torch/Torchvision/Timm 的单包运行时策略") + pyinstaller_command = [ + sys.executable, + "-m", + "PyInstaller", + str(spec_file), + "--clean", + "--noconfirm", + f"--workpath={paths.work_dir}", + f"--distpath={paths.dist_dir}", + ] + logger.info("启动 PyInstaller 构建:开始") + logger.info("PyInstaller 参数:%s", " ".join(str(item) for item in pyinstaller_command[2:])) + run_command( + pyinstaller_command, + capture_output=not logger.isEnabledFor(logging.DEBUG), + env=build_environment(config), + label="PyInstaller 构建", + ) + + if not paths.app_dir.exists(): + raise FileNotFoundError(f"构建完成后未找到 .app: {paths.app_dir}") + logger.info("PyInstaller 构建成功!") + logger.info("构建产物位置:%s", paths.app_dir) + log_verbose("[成功] 已完成 %s 构建: %s", config.build_type, paths.app_dir) + + +def organize_app_bundle_resources(app_dir: Path) -> None: + """ + 将资源移至 .app 的 Resources 目录 / Move resources into the app bundle Resources directory. + """ + + log_step("步骤 6: 整理 .app 资源目录") + macos_dir = app_dir / "Contents" / "MacOS" + resources_dir = app_dir / "Contents" / "Resources" + resources_dir.mkdir(parents=True, exist_ok=True) + + for resource_name in ("SuperBirdIDPlugin.lrplugin", "en.lproj", "zh-Hans.lproj"): + source_path = macos_dir / resource_name + destination_path = resources_dir / resource_name + if source_path.exists(): + remove_path(destination_path) + shutil.move(str(source_path), str(destination_path)) + log_verbose("[成功] 已移动资源到 Resources: %s", resource_name) + + +def create_dmg(config: BuildConfig, paths: BuildPaths) -> None: + """ + 生成 DMG 安装镜像 / Create a DMG installer image. + """ + + log_step("步骤 7: 生成 DMG") + staging_dir = paths.dist_dir / "dmg_staging" + remove_path(staging_dir) + staging_dir.mkdir(parents=True, exist_ok=True) + + staged_app = staging_dir / paths.app_dir.name + copy_tree(paths.app_dir, staged_app) + + plugin_dir = paths.app_dir / "Contents" / "Resources" / "SuperBirdIDPlugin.lrplugin" + if plugin_dir.exists(): + copy_tree(plugin_dir, staging_dir / plugin_dir.name) + + if DMG_README_FILE.exists(): + copy_file(DMG_README_FILE, staging_dir / "README.txt") + + applications_link = staging_dir / "Applications" + if not applications_link.exists(): + os.symlink("/Applications", applications_link) + + paths.dmg_path.parent.mkdir(parents=True, exist_ok=True) + remove_path(paths.dmg_path) + run_command( + [ + "hdiutil", + "create", + "-volname", + f"{display_name_for(config.build_type)} {config.app_version}", + "-srcfolder", + str(staging_dir), + "-ov", + "-format", + "UDZO", + str(paths.dmg_path), + ], + label="生成 DMG", + ) + + remove_path(staging_dir) + log_verbose("[成功] 已生成 DMG: %s", paths.dmg_path) + + +def prepare_signing(config: BuildConfig) -> SigningContext | None: + """ + 如果提供 .p12,则导入临时 keychain / Import a temporary keychain when a .p12 is provided. + """ + + if config.sign_p12 is None: + return None + + log_step("步骤 8: 导入签名证书") + if not config.sign_p12.exists(): + raise FileNotFoundError(f"未找到签名证书文件: {config.sign_p12}") + + password = os.environ.get(config.sign_p12_password_env, "") + if not password: + raise RuntimeError(f"环境变量 {config.sign_p12_password_env} 未设置,无法导入 .p12") + + temp_dir = Path(tempfile.mkdtemp(prefix="superpicky_sign_", dir=str(ROOT_DIR))) + imported_p12_path = temp_dir / config.sign_p12.name + shutil.copy2(config.sign_p12, imported_p12_path) + + keychain_path = temp_dir / "build.keychain-db" + keychain_password = secrets.token_hex(16) + + run_command(["security", "create-keychain", "-p", keychain_password, str(keychain_path)], label="创建临时 keychain") + run_command(["security", "set-keychain-settings", "-lut", "21600", str(keychain_path)], label="配置 keychain") + run_command(["security", "unlock-keychain", "-p", keychain_password, str(keychain_path)], label="解锁 keychain") + run_command( + [ + "security", + "import", + str(imported_p12_path), + "-k", + str(keychain_path), + "-P", + password, + "-T", + "/usr/bin/codesign", + ], + label="导入 .p12", + ) + run_command( + [ + "security", + "set-key-partition-list", + "-S", + "apple-tool:,apple:", + "-s", + "-k", + keychain_password, + str(keychain_path), + ], + label="配置 keychain 访问权限", + ) + + identity = config.sign_identity or discover_signing_identity(keychain_path) + log_verbose("[成功] 已加载签名 identity: %s", identity) + return SigningContext( + keychain_path=keychain_path, + keychain_password=keychain_password, + imported_p12_path=imported_p12_path, + identity=identity, + ) + + +def discover_signing_identity(keychain_path: Path) -> str: + """ + 从 keychain 中解析 Developer ID Application identity / Resolve Developer ID Application identity from keychain. + """ + + result = run_command( + ["security", "find-identity", "-v", "-p", "codesigning", str(keychain_path)], + capture_output=True, + label="解析签名 identity", + ) + pattern = re.compile(r'"(Developer ID Application:[^"]+)"') + for line in result.stdout.splitlines(): + match = pattern.search(line) + if match: + return match.group(1) + raise RuntimeError("未在 .p12 对应 keychain 中找到 Developer ID Application identity") + + +def iter_signable_files(contents_dir: Path) -> list[Path]: + """ + 枚举需要优先签名的文件 / Enumerate files that should be signed first. + """ + + signable: list[Path] = [] + for path in contents_dir.rglob("*"): + if not path.is_file(): + continue + if path.suffix in {".dylib", ".so"} or os.access(path, os.X_OK): + signable.append(path) + signable.sort(key=lambda item: len(item.parts), reverse=True) + return signable + + +def codesign_path( + path: Path, + identity: str, + *, + entitlements: Path | None = None, + keychain_path: Path | None = None, + use_runtime: bool = False, +) -> None: + """ + 对指定路径执行 codesign / Sign a path with codesign. + """ + + command = ["codesign", "--force", "--sign", identity] + if identity != "-": + command.append("--timestamp") + if use_runtime: + command.extend(["--options", "runtime"]) + if entitlements is not None: + command.extend(["--entitlements", str(entitlements)]) + if keychain_path is not None: + command.extend(["--keychain", str(keychain_path)]) + command.append(str(path)) + run_command(command, label=f"签名 {path.name}") + + +def sign_app_bundle(app_dir: Path, signing_context: SigningContext | None) -> None: + """ + 对 .app 执行签名并验证 / Sign and verify the app bundle. + """ + + log_step("步骤 9: 签名并验证 .app") + identity = signing_context.identity if signing_context else "-" + keychain_path = signing_context.keychain_path if signing_context else None + entitlements = ENTITLEMENTS_FILE if signing_context and ENTITLEMENTS_FILE.exists() else None + + for nested_path in iter_signable_files(app_dir / "Contents"): + codesign_path(nested_path, identity, keychain_path=keychain_path, use_runtime=True) + + codesign_path(app_dir, identity, entitlements=entitlements, keychain_path=keychain_path, use_runtime=True) + verify_command = ["codesign", "--verify", "--deep", "--strict", "--verbose=2", str(app_dir)] + run_command(verify_command, label="校验 .app 签名") + log_verbose("[成功] .app 签名校验通过") + + +def sign_dmg(dmg_path: Path, signing_context: SigningContext | None) -> None: + """ + 如有证书则签名 DMG / Sign the DMG when a certificate is provided. + """ + + if signing_context is None: + log_verbose("[信息] 未提供 .p12,跳过 DMG 签名") + return + + log_step("步骤 10: 签名 DMG") + codesign_path(dmg_path, signing_context.identity, keychain_path=signing_context.keychain_path, use_runtime=False) + run_command(["codesign", "--verify", "--verbose=2", str(dmg_path)], label="校验 DMG 签名") + log_verbose("[成功] DMG 签名校验通过") + + +def notary_auth_arguments(config: BuildConfig) -> list[str]: + """ + 构造 notarytool 认证参数 / Build notarytool authentication arguments. + """ + + if config.notary_keychain_profile: + return ["--keychain-profile", config.notary_keychain_profile] + + if not config.apple_id: + raise RuntimeError("启用 --notarize 时必须提供 Apple ID 或设置 APPLE_ID 环境变量") + if not config.team_id: + raise RuntimeError("启用 --notarize 时必须提供 Team ID 或设置 MACOS_TEAM_ID/TEAM_ID 环境变量") + + password = os.environ.get(config.apple_password_env, "").strip() + if not password: + raise RuntimeError(f"启用 --notarize 时必须设置环境变量 {config.apple_password_env}") + + return [ + "--apple-id", + config.apple_id, + "--password", + password, + "--team-id", + config.team_id, + ] + + +def notarize_dmg(dmg_path: Path, config: BuildConfig) -> None: + """ + 公证并装订 DMG / Notarize and staple the DMG. + """ + + if not config.notarize: + log_verbose("[信息] 未启用 --notarize,跳过 Apple 公证") + return + + log_step("步骤 11: Apple 公证并装订") + auth_args = notary_auth_arguments(config) + submit_command = [ + "xcrun", + "notarytool", + "submit", + str(dmg_path), + *auth_args, + "--wait", + "--output-format", + "json", + ] + result = run_command(submit_command, capture_output=True, label="Apple 公证") + output = result.stdout.strip() + if output: + logger.info(output) + + status = "" + request_id = "" + if output: + try: + payload = json.loads(output) + except json.JSONDecodeError: + payload = None + if isinstance(payload, dict): + status = str(payload.get("status", "")).strip().lower() + request_id = str(payload.get("id", "")).strip() + elif "Accepted" in output: + status = "accepted" + + if status != "accepted": + if request_id: + log_verbose("[信息] 公证失败,尝试读取详细日志: %s", request_id) + log_result = run_command( + ["xcrun", "notarytool", "log", request_id, *auth_args], + capture_output=True, + check=False, + label="读取公证日志", + ) + if log_result.stdout: + logger.error(log_result.stdout.strip()) + if log_result.stderr: + logger.error(log_result.stderr.strip()) + raise RuntimeError("Apple 公证失败") + + run_command(["xcrun", "stapler", "staple", str(dmg_path)], label="装订公证票据") + run_command(["xcrun", "stapler", "validate", str(dmg_path)], label="验证公证票据") + log_verbose("[成功] DMG 公证与装订完成") + + +def publish_artifacts(paths: BuildPaths, config: BuildConfig) -> tuple[Path, Path]: + """ + 输出最终产物位置 / Publish final artifact locations. + """ + + if config.copy_dir is None: + return paths.app_dir, paths.dmg_path + + config.copy_dir.mkdir(parents=True, exist_ok=True) + destination_app = config.copy_dir / paths.app_dir.name + destination_dmg = config.copy_dir / paths.dmg_path.name + copy_tree(paths.app_dir, destination_app) + copy_file(paths.dmg_path, destination_dmg) + log_verbose("[成功] 已复制最终产物到: %s", config.copy_dir) + return destination_app, destination_dmg + + +def cleanup_signing_context(signing_context: SigningContext | None) -> None: + """ + 清理临时 keychain 和证书文件 / Clean up the temporary keychain and certificate file. + """ + + if signing_context is None: + return + + parent_dir = signing_context.keychain_path.parent + run_command(["security", "delete-keychain", str(signing_context.keychain_path)], check=False) + remove_path(parent_dir) + + +def create_config(args: argparse.Namespace) -> BuildConfig: + """ + 根据参数创建构建配置 / Create build configuration from CLI args. + """ + + return BuildConfig( + build_type=args.build_type, + arch=args.arch, + copy_dir=Path(args.copy_dir).resolve() if args.copy_dir else None, + debug=args.debug, + app_version=args.version or read_app_version(), + commit_hash=get_commit_hash(), + sign_p12=Path(args.sign_p12).resolve() if args.sign_p12 else None, + sign_p12_password_env=args.sign_p12_password_env, + sign_identity=optional_text(args.sign_identity), + release_channel=parse_release_channel(), + notarize=args.notarize, + apple_id=optional_text(args.apple_id) or optional_text(os.environ.get("APPLE_ID")), + apple_password_env=args.apple_password_env, + team_id=( + optional_text(args.team_id) + or optional_text(os.environ.get("MACOS_TEAM_ID")) + or optional_text(os.environ.get("TEAM_ID")) + ), + notary_keychain_profile=( + optional_text(args.notary_keychain_profile) + or optional_text(os.environ.get("NOTARY_KEYCHAIN_PROFILE")) + ), + ) + + +def run_build(config: BuildConfig) -> None: + """ + 执行完整构建流程 / Run the complete build flow. + """ + + ensure_macos_host() + ensure_arch_matches(config.arch) + validate_python_environment() + + if config.build_type == "full": + ensure_models() + + paths = get_build_paths(config.build_type, config.arch, config.app_version, config.commit_hash) + clean_build_outputs(paths) + build_bundle(config, paths) + organize_app_bundle_resources(paths.app_dir) + + signing_context: SigningContext | None = None + try: + signing_context = prepare_signing(config) + if config.notarize and signing_context is None and not config.sign_identity: + raise RuntimeError("启用 --notarize 时必须提供 --sign-p12 或 --sign-identity 以完成正式签名") + sign_app_bundle(paths.app_dir, signing_context) + create_dmg(config, paths) + sign_dmg(paths.dmg_path, signing_context) + notarize_dmg(paths.dmg_path, config) + final_app, final_dmg = publish_artifacts(paths, config) + logger.info("[========================================]") + logger.info("构建完成") + logger.info("[========================================]") + logger.info("构建类型: %s", config.build_type) + logger.info("目标架构: %s", config.arch) + logger.info(".app: %s", final_app) + logger.info(".dmg: %s", final_dmg) + finally: + cleanup_signing_context(signing_context) + + +def main() -> None: + """ + 程序入口 / Program entrypoint. + """ + + args = parse_args() + configure_logging(args.debug) + config = create_config(args) + backup_path = inject_build_info(config.commit_hash, config.release_channel) + try: + run_build(config) + finally: + restore_build_info(backup_path) + + +if __name__ == "__main__": + main() diff --git a/build_release_win.py b/build_release_win.py index 1190a39..b3fa00a 100644 --- a/build_release_win.py +++ b/build_release_win.py @@ -16,10 +16,12 @@ import ast import hashlib import logging +import os import re import shutil import subprocess import sys +import zipfile from dataclasses import dataclass from pathlib import Path from typing import Sequence @@ -31,10 +33,12 @@ BUILD_INFO_FILE = ROOT_DIR / "core" / "build_info.py" DOWNLOAD_MODELS_SCRIPT = ROOT_DIR / "scripts" / "download_models.py" SPEC_FILE = ROOT_DIR / "SuperPicky_win64.spec" +LITE_SPEC_FILE = ROOT_DIR / "SuperPicky_lite_win.spec" CPU_VENV_DIR = ROOT_DIR / ".venv" CUDA_VENV_DIR = ROOT_DIR / ".venv-cuda" DEFAULT_PATCH_OUTPUT_ROOT = ROOT_DIR / "output" STANDARD_INNO_TEMPLATE = INNO_DIR / "SuperPicky.iss" +LITE_INNO_TEMPLATE = INNO_DIR / "SuperPicky-lite.iss" PATCH_INNO_TEMPLATE = INNO_DIR / "SuperPicky_CUDA_Patch.iss" INNO_LANGUAGE_FILE = INNO_DIR / "ChineseSimplified.isl" CPU_REQUIREMENTS_FILE = ROOT_DIR / "requirements.txt" @@ -42,6 +46,7 @@ PATCH_MANIFEST_RELATIVE_PATH = Path("_internal") / "cuda_patch_manifest.txt" CPU_INSTALLER_STAGING_DIRNAME = "installer_cpu" CUDA_INSTALLER_STAGING_DIRNAME = "installer_cuda" +LITE_INSTALLER_STAGING_DIRNAME = "installer_lite" CUDA_PATCH_PORTABLE_DIRNAME = "cuda_patch" CUDA_PATCH_INSTALLER_STAGING_DIRNAME = "cuda_patch_installer" @@ -68,9 +73,9 @@ class BuildConfig: def configure_logging(debug: bool) -> None: if hasattr(sys.stdout, "reconfigure"): - sys.stdout.reconfigure(encoding="utf-8", errors="strict") + sys.stdout.reconfigure(encoding="utf-8", errors="strict") # pyright: ignore[reportAttributeAccessIssue] if hasattr(sys.stderr, "reconfigure"): - sys.stderr.reconfigure(encoding="utf-8", errors="strict") + sys.stderr.reconfigure(encoding="utf-8", errors="strict") # pyright: ignore[reportAttributeAccessIssue] logger.setLevel(logging.DEBUG if debug else logging.INFO) logger.propagate = False @@ -90,13 +95,17 @@ def log_step(title: str) -> None: logger.info("[========================================]") +def log_verbose(message: str, *args) -> None: + logger.debug(message, *args) + + def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="SuperPicky Windows 构建脚本") parser.add_argument( "--build-type", - choices=["cpu", "cuda", "cuda-patch"], - default="cpu", - help="构建类型:cpu, cuda, cuda-patch (默认: cpu)", + choices=["cpu", "cuda", "cuda-patch", "lite"], + default="lite", + help="构建类型:cpu, cuda, cuda-patch, lite (默认: lite)", ) parser.add_argument("--version", help="覆盖基础版本号,例如 4.2.0") parser.add_argument("--copy-dir", help="复制最终产物的目标目录") @@ -172,6 +181,7 @@ def load_required_models() -> list[dict[str, str]]: {"filename": "cub200_keypoint_resnet50_slim.pth", "dest_dir": "models"}, {"filename": "avonet.db", "dest_dir": "birdid/data"}, {"filename": "cfanet_iaa_ava_res50-3cd62bb3.pth", "dest_dir": "models"}, + {"filename": "yolo11l-seg.pt", "dest_dir": "models"}, ] if not DOWNLOAD_MODELS_SCRIPT.exists(): @@ -191,7 +201,7 @@ def load_required_models() -> list[dict[str, str]]: break if models is None: raise RuntimeError("download_models.py 中未找到 MODELS_TO_DOWNLOAD") - logger.info("[成功] 已从 download_models.py 加载模型列表") + log_verbose("[成功] 已从 download_models.py 加载模型列表") return [{"filename": item["filename"], "dest_dir": item["dest_dir"]} for item in models] except BaseException as exc: if isinstance(exc, KeyboardInterrupt): @@ -216,7 +226,7 @@ def ensure_models(python_exe: Path) -> None: log_step("步骤 0: 检查并下载模型文件") missing = find_missing_models() if not missing: - logger.info("[成功] 所有模型文件已就绪") + log_verbose("[成功] 所有模型文件已就绪") return logger.warning("缺失 %d 个模型文件,开始下载", len(missing)) @@ -231,7 +241,7 @@ def ensure_models(python_exe: Path) -> None: logger.error("仍然缺失: %s", path) raise RuntimeError("模型下载后仍有缺失") - logger.info("[成功] 所有模型文件已就绪") + log_verbose("[成功] 所有模型文件已就绪") def read_app_version() -> str: @@ -279,7 +289,7 @@ def inject_build_info(commit_hash: str, release_channel: str = "official") -> Pa count=1, ) BUILD_INFO_FILE.write_text(updated, encoding="utf-8") - logger.info("[成功] 已写入 COMMIT_HASH=%s RELEASE_CHANNEL=%s", commit_hash, release_channel) + log_verbose("[成功] 已写入 COMMIT_HASH=%s RELEASE_CHANNEL=%s", commit_hash, release_channel) return backup_path @@ -288,16 +298,24 @@ def restore_build_info(backup_path: Path | None) -> None: shutil.move(str(backup_path), str(BUILD_INFO_FILE)) -def ensure_spec_file() -> None: - if not SPEC_FILE.exists(): - raise FileNotFoundError(f"缺少 spec 文件: {SPEC_FILE}") +def spec_file_for(build_type: str) -> Path: + if build_type == "lite": + return LITE_SPEC_FILE + return SPEC_FILE + + +def ensure_spec_file(build_type: str) -> None: + spec_file = spec_file_for(build_type) + if not spec_file.exists(): + raise FileNotFoundError(f"缺少 spec 文件: {spec_file}") + log_verbose("[信息] %s 构建将使用 spec: %s", build_type, spec_file.name) def check_python_environment(python_exe: Path, label: str) -> None: - logger.info("[信息] 检查 Python 环境 (%s): %s", label, python_exe) + log_verbose("[信息] 检查 Python 环境 (%s): %s", label, python_exe) run_command([str(python_exe), "-c", "import sys; print(sys.executable)"], label=f"{label} Python 检查") run_command([str(python_exe), "-c", "import PyInstaller"], label=f"{label} PyInstaller 检查") - logger.info("[成功] %s 环境可用", label) + log_verbose("[成功] %s 环境可用", label) def python_in_venv(venv_dir: Path) -> Path: @@ -315,7 +333,7 @@ def ensure_virtual_environment( venv_python = python_in_venv(venv_dir) if not venv_python.exists(): - logger.info("[信息] 创建 %s 虚拟环境: %s", label, venv_dir) + logger.info("创建 %s 虚拟环境...", label) run_command([str(bootstrap_python), "-m", "venv", str(venv_dir)], label=f"创建 {label} 虚拟环境") run_command([str(venv_python), "-m", "pip", "install", "--upgrade", "pip"], label=f"升级 {label} 环境 pip") @@ -348,99 +366,126 @@ def ensure_cuda_environment(bootstrap_python: Path) -> Path: def clean_build_outputs() -> None: log_step("步骤 2: 清理旧的构建目录") - for label in ("cpu", "cuda", "cuda_patch"): + for label in ("cpu", "cuda", "cuda_patch", "lite"): paths = get_build_paths(label) remove_path(paths.work_dir) remove_path(paths.dist_dir) remove_path(ROOT_DIR / "build_dist") remove_path(ROOT_DIR / "dist") - logger.info("[成功] 已清理构建目录") + log_verbose("[成功] 已清理构建目录") -def build_bundle(python_exe: Path, build_paths: BuildPaths) -> None: +def build_bundle(python_exe: Path, build_paths: BuildPaths, spec_file: Path) -> None: log_step(f"步骤 3: 构建 {build_paths.label.upper()} 版本") remove_path(build_paths.work_dir) remove_path(build_paths.dist_dir) + pyinstaller_command = [ + str(python_exe), + "-m", + "PyInstaller", + str(spec_file), + "--clean", + "--noconfirm", + f"--workpath={build_paths.work_dir}", + f"--distpath={build_paths.dist_dir}", + ] + logger.info("启动 PyInstaller 构建:开始") + logger.info("PyInstaller 参数:%s", " ".join(str(item) for item in pyinstaller_command[2:])) run_command( - [ - str(python_exe), - "-m", - "PyInstaller", - str(SPEC_FILE), - "--clean", - "--noconfirm", - f"--workpath={build_paths.work_dir}", - f"--distpath={build_paths.dist_dir}", - ], + pyinstaller_command, + capture_output=not logger.isEnabledFor(logging.DEBUG), label=f"{build_paths.label} PyInstaller 构建", ) exe_path = build_paths.bundle_dir / f"{APP_NAME}.exe" if not exe_path.exists(): raise FileNotFoundError(f"构建完成后未找到可执行文件: {exe_path}") - logger.info("[成功] %s 构建完成", build_paths.label.upper()) - - -def find_7z_executable() -> str: - candidates = [ - shutil.which("7z"), - shutil.which("7zz"), - shutil.which("7za"), - str(Path("C:/Program Files/7-Zip/7z.exe")), - str(Path("C:/Program Files (x86)/7-Zip/7z.exe")), - ] - for candidate in candidates: - if not candidate: - continue - if Path(candidate).exists(): - return candidate - raise FileNotFoundError("未找到 7z 可执行文件,请先安装 7-Zip 或确保 7z/7zz/7za 已加入 PATH") + logger.info("PyInstaller 构建成功!") + logger.info("构建产物位置:%s", exe_path) + log_verbose("[成功] %s 构建完成", build_paths.label.upper()) def create_zip_archive(source_dir: Path, archive_path: Path) -> None: + """ + 使用标准库创建 ZIP 包 / Create ZIP archives with the Python standard library. + """ + archive_path.parent.mkdir(parents=True, exist_ok=True) archive_path.unlink(missing_ok=True) - seven_zip = find_7z_executable() - run_command( - [ - seven_zip, - "a", - "-tzip", - "-mx=9", - "-mm=LZMA", - "-md=256m", - "-mfb=128", - "-mmt=on", - str(archive_path), - source_dir.name, - ], - cwd=source_dir.parent, - label="ZIP 压缩", - ) + with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=9) as archive: + for file_path in sorted(source_dir.rglob("*")): + if file_path.is_dir(): + continue + archive.write(file_path, arcname=str(Path(source_dir.name) / file_path.relative_to(source_dir))) def archive_name_for(label: str, app_version: str, commit_hash: str) -> str: return f"{APP_NAME}_Win64_{app_version}_{commit_hash}_{label}.zip" -def update_inno_content(content: str, *, app_version: str, commit_hash: str, patch: bool) -> str: - version_value = f"{app_version}-{commit_hash}" - if patch: - output_base = f"SuperPicky_CUDA_Patch_Win64_{app_version}_{commit_hash}" - else: - output_base = f"SuperPicky_Setup_Win64_{app_version}_{commit_hash}" +def normalize_version(version: str) -> str: + """确保版本号以 'v' 前缀开头。 + + Ensure version string starts with 'v' prefix. + + 参数 / Parameters: + version (str): 原始版本号,例如 "4.2.0" 或 "v4.2.0" + + 返回 / Return: + str: 带 'v' 前缀的版本号,例如 "v4.2.0" + """ + return version if version.startswith("v") else f"v{version}" + + +def update_inno_content(content: str, *, app_version: str, commit_hash: str) -> str: + """替换 ISS 模板中的 #define 预处理器变量,注入版本号和提交哈希。 - content = re.sub(r"(?m)^AppVersion=.*$", f"AppVersion={version_value}", content) - content = re.sub(r"(?m)^OutputBaseFilename=.*$", f"OutputBaseFilename={output_base}", content) + Replace #define preprocessor variables in ISS template with version and commit hash. + + 参数 / Parameters: + content (str): ISS 模板原始内容 + app_version (str): 应用版本号(将自动添加 'v' 前缀) + commit_hash (str): Git 提交哈希 + + 返回 / Return: + str: 替换后的 ISS 内容 + """ + versioned = normalize_version(app_version) + content = re.sub( + r'(?m)^(#define\s+MyAppVersion\s+").*?(")\s*$', + rf'\g<1>{versioned}\2', + content, + ) + content = re.sub( + r'(?m)^(#define\s+MyAppCommitHash\s+").*?(")\s*$', + rf'\g<1>{commit_hash}\2', + content, + ) return content -def write_inno_script(template_path: Path, destination_path: Path, *, app_version: str, commit_hash: str, patch: bool) -> None: +def write_inno_script( + template_path: Path, + destination_path: Path, + *, + app_version: str, + commit_hash: str, +) -> None: + """读取 ISS 模板,注入版本号和哈希后写入目标路径。 + + Read ISS template, inject version and hash, write to destination. + + 参数 / Parameters: + template_path (Path): ISS 模板文件路径 + destination_path (Path): 输出 ISS 文件路径 + app_version (str): 应用版本号 + commit_hash (str): Git 提交哈希 + """ content = template_path.read_text(encoding="utf-8") destination_path.parent.mkdir(parents=True, exist_ok=True) destination_path.write_text( - update_inno_content(content, app_version=app_version, commit_hash=commit_hash, patch=patch), + update_inno_content(content, app_version=app_version, commit_hash=commit_hash), encoding="utf-8", ) @@ -450,22 +495,55 @@ def installer_staging_dir_name(label: str) -> str: return CPU_INSTALLER_STAGING_DIRNAME if label == "cuda": return CUDA_INSTALLER_STAGING_DIRNAME + if label == "lite": + return LITE_INSTALLER_STAGING_DIRNAME raise ValueError(f"不支持的标准安装包标签: {label}") +def inno_template_for(label: str) -> Path: + """根据构建标签返回对应的 ISS 模板路径。 + + Return the ISS template path for the given build label. + + 参数 / Parameters: + label (str): 构建标签,"lite" 或其他(Full/CPU/CUDA) + + 返回 / Return: + Path: ISS 模板文件路径 + """ + if label == "lite": + return LITE_INNO_TEMPLATE + return STANDARD_INNO_TEMPLATE + + def prepare_standard_installer_staging(source_bundle_dir: Path, staging_root: Path, config: BuildConfig, *, label: str) -> Path: + """准备标准安装包的 staging 目录,包含构建产物、ISS 脚本和依赖资源。 + + Prepare standard installer staging directory with build artifacts, ISS script and dependencies. + + 参数 / Parameters: + source_bundle_dir (Path): PyInstaller 构建产物目录 + staging_root (Path): staging 根目录 + config (BuildConfig): 构建配置 + label (str): 构建标签("cpu", "cuda", "lite") + + 返回 / Return: + Path: 生成的 ISS 脚本路径 + """ staging_dir = staging_root / installer_staging_dir_name(label) copy_tree(source_bundle_dir, staging_dir) + template = inno_template_for(label) + iss_filename = template.name write_inno_script( - STANDARD_INNO_TEMPLATE, - staging_dir / "SuperPicky.iss", + template, + staging_dir / iss_filename, app_version=config.app_version, commit_hash=config.commit_hash, - patch=False, ) copy_file(INNO_LANGUAGE_FILE, staging_dir / INNO_LANGUAGE_FILE.name) - logger.info("[成功] 已准备标准安装包脚本目录: %s", staging_dir) - return staging_dir / "SuperPicky.iss" + copy_tree(ROOT_DIR / "img", staging_dir / "img") + log_verbose("[成功] 已准备标准安装包脚本目录: %s", staging_dir) + return staging_dir / iss_filename def publish_standard_build( @@ -489,12 +567,12 @@ def publish_standard_build( if not config.no_zip: zip_path = artifact_root / archive_name_for(label, config.app_version, config.commit_hash) create_zip_archive(zip_source_dir, zip_path) - logger.info("[成功] 已创建 ZIP 压缩包: %s", zip_path) + log_verbose("[成功] 已创建 ZIP 压缩包: %s", zip_path) else: zip_path = None - logger.info("[信息] 跳过 ZIP 压缩包创建 (--no-zip)") + log_verbose("[信息] 跳过 ZIP 压缩包创建 (--no-zip)") - logger.info("[成功] 已准备目录: %s", final_bundle_dir) + log_verbose("[成功] 已准备目录: %s", final_bundle_dir) return final_bundle_dir, zip_path, installer_script_path @@ -555,9 +633,9 @@ def prepare_patch_directory(cpu_bundle: Path, cuda_bundle: Path, config: BuildCo shutil.copy2(cuda_file, destination) manifest_path = write_patch_manifest(patch_dir, copied_patch_files) - logger.info("[成功] 已导出差异文件: %d 个不同文件, %d 个 CUDA 独有文件", different_count, cuda_only_count) - logger.info("[成功] 已写入补丁清单: %s", manifest_path) - logger.info("[成功] 补丁目录: %s", patch_dir) + log_verbose("[成功] 已导出差异文件: %d 个不同文件, %d 个 CUDA 独有文件", different_count, cuda_only_count) + log_verbose("[成功] 已写入补丁清单: %s", manifest_path) + log_verbose("[成功] 补丁目录: %s", patch_dir) return patch_dir @@ -572,14 +650,17 @@ def prepare_patch_installer_staging(portable_patch_dir: Path, config: BuildConfi staging_dir / PATCH_INNO_TEMPLATE.name, app_version=config.app_version, commit_hash=config.commit_hash, - patch=True, ) - logger.info("[成功] 已准备 CUDA 补丁安装包脚本目录: %s", staging_dir) + log_verbose("[成功] 已准备 CUDA 补丁安装包脚本目录: %s", staging_dir) return staging_dir / PATCH_INNO_TEMPLATE.name def ensure_inno_templates() -> None: - for path in (STANDARD_INNO_TEMPLATE, PATCH_INNO_TEMPLATE, INNO_LANGUAGE_FILE): + """检查所有 Inno Setup 模板和依赖文件是否存在。 + + Verify all Inno Setup templates and dependency files exist. + """ + for path in (STANDARD_INNO_TEMPLATE, LITE_INNO_TEMPLATE, PATCH_INNO_TEMPLATE, INNO_LANGUAGE_FILE): if not path.exists(): raise FileNotFoundError(f"缺少 Inno 相关文件: {path}") @@ -593,7 +674,7 @@ def resolve_final_root(build_type: str, copy_dir: Path | None) -> Path | None: def build_single_target(config: BuildConfig, label: str, python_exe: Path) -> tuple[BuildPaths, Path, Path | None, Path]: check_python_environment(python_exe, label.upper()) build_paths = get_build_paths(label) - build_bundle(python_exe, build_paths) + build_bundle(python_exe, build_paths, spec_file_for(label if label == "lite" else config.build_type)) final_root = resolve_final_root(config.build_type, config.copy_dir) final_bundle, zip_path, installer_script_path = publish_standard_build( label=label, @@ -623,6 +704,20 @@ def run_cpu_or_cuda_build(config: BuildConfig) -> None: logger.info("安装包脚本: %s", installer_script_path) +def run_lite_build(config: BuildConfig) -> None: + bootstrap_python = Path(sys.executable) + build_python = ensure_cpu_environment(bootstrap_python) + + clean_build_outputs() + _, final_bundle, zip_path, installer_script_path = build_single_target(config, "lite", build_python) + logger.info("[========================================]") + logger.info("Lite 构建完成") + logger.info("[========================================]") + logger.info("可执行文件: %s", final_bundle / f"{APP_NAME}.exe") + logger.info("压缩文件: %s", zip_path if zip_path else "(已跳过)") + logger.info("安装包脚本: %s", installer_script_path) + + def run_cuda_patch_build(config: BuildConfig) -> None: bootstrap_python = Path(sys.executable) cpu_python = ensure_cpu_environment(bootstrap_python) @@ -633,7 +728,7 @@ def run_cuda_patch_build(config: BuildConfig) -> None: cuda_python = ensure_cuda_environment(bootstrap_python) cuda_paths = get_build_paths("cuda") - build_bundle(cuda_python, cuda_paths) + build_bundle(cuda_python, cuda_paths, spec_file_for("cuda")) patch_dir = prepare_patch_directory(cpu_paths.bundle_dir, cuda_paths.bundle_dir, config) patch_installer_script = prepare_patch_installer_staging(patch_dir, config) @@ -644,15 +739,15 @@ def run_cuda_patch_build(config: BuildConfig) -> None: config.commit_hash, ) create_zip_archive(patch_dir, patch_zip) - logger.info("[成功] 已创建 CUDA 补丁 ZIP 压缩包: %s", patch_zip) + log_verbose("[成功] 已创建 CUDA 补丁 ZIP 压缩包: %s", patch_zip) else: patch_zip = None - logger.info("[信息] 跳过 CUDA 补丁 ZIP 压缩包创建 (--no-zip)") + log_verbose("[信息] 跳过 CUDA 补丁 ZIP 压缩包创建 (--no-zip)") log_step("步骤 7: 清理 CUDA 中间产物") remove_path(cuda_paths.work_dir) remove_path(cuda_paths.dist_dir) - logger.info("[成功] 已清理 CUDA 中间目录") + log_verbose("[成功] 已清理 CUDA 中间目录") logger.info("[========================================]") logger.info("CUDA Patch 构建完成") @@ -666,12 +761,23 @@ def run_cuda_patch_build(config: BuildConfig) -> None: def create_config(args: argparse.Namespace) -> BuildConfig: + """根据命令行参数创建构建配置。版本号自动添加 'v' 前缀。 + + Create build config from CLI arguments. Version is auto-prefixed with 'v'. + + 参数 / Parameters: + args (argparse.Namespace): 解析后的命令行参数 + + 返回 / Return: + BuildConfig: 构建配置对象 + """ + raw_version = args.version or read_app_version() return BuildConfig( build_type=args.build_type, copy_dir=Path(args.copy_dir).resolve() if args.copy_dir else None, no_zip=args.no_zip, debug=args.debug, - app_version=args.version or read_app_version(), + app_version=normalize_version(raw_version), commit_hash=get_commit_hash(), ) @@ -681,16 +787,17 @@ def main() -> None: configure_logging(args.debug) config = create_config(args) - ensure_spec_file() + ensure_spec_file(config.build_type) ensure_inno_templates() - import os as _os - _tag = _os.environ.get("RELEASE_TAG", "") - _channel = "nightly" if _tag and "-rc" in _tag.lower() else "official" - backup_path = inject_build_info(config.commit_hash, _channel) + release_tag = os.environ.get("RELEASE_TAG", "") + release_channel = "nightly" if release_tag and "-rc" in release_tag.lower() else "official" + backup_path = inject_build_info(config.commit_hash, release_channel) try: if config.build_type == "cuda-patch": run_cuda_patch_build(config) + elif config.build_type == "lite": + run_lite_build(config) else: run_cpu_or_cuda_build(config) finally: diff --git a/config.py b/config.py index e8c82ad..3a932d2 100644 --- a/config.py +++ b/config.py @@ -1,28 +1,18 @@ """ -SuperPicky 配置管理模块。 -SuperPicky configuration management module. +SuperPicky 配置管理模块 / SuperPicky configuration management module. 本文件负责静态常量、路径约定、轻量运行时覆盖入口与共享懒加载注册器。 This file owns static constants, path conventions, lightweight runtime overrides, and the shared lazy registry. -维护分层: -Maintenance layering: -- `config.py`:公共读取入口与基础配置。 - `config.py`: shared read entry points and foundational configuration. -- `advanced_config.py`:高级持久化配置的默认值、读写、迁移与 UI 对接。 - `advanced_config.py`: defaults, persistence, migration, and UI integration for advanced persistent settings. -- 用户配置文件默认位于 `get_app_config_dir() / "advanced_config.json"`。 - The user config file defaults to `get_app_config_dir() / "advanced_config.json"`. - -文档入口: -Documentation entry points: -- 维护指南:当前文件本身。 - Maintenance guide: this file itself. -- 中英对照配置指南:TODO - 待补正式路径。 - Bilingual configuration guide: TODO - add the final path later. +维护分层 / Maintenance layering: +- `config.py`:公共读取入口与基础配置 / shared read entry points and foundational configuration. +- `advanced_config.py`:高级持久化配置的默认值、读写、迁移与 UI 对接 / defaults, persistence, migration, and UI integration for advanced persistent settings. +- 用户配置文件默认位于 `get_app_config_dir() / "advanced_config.json"` / The user config file defaults to `get_app_config_dir() / "advanced_config.json"`. """ import json +import importlib +import logging import os import platform import sys @@ -31,35 +21,162 @@ from pathlib import Path from typing import Any, Callable, Dict, List, Optional -import torch +logger = logging.getLogger(__name__) +# Torch is intentionally imported lazily. +# macOS Lite frozen builds bundle Torch inside the app, and importing it at +# module-load time makes it easier to bind to a wrong partial path before the +# frozen runtime is fully settled. +torch = None -# ========================= -# 基础路径工具 -# ========================= -# 这一层只定义路径约定,不负责真实配置读写。 -# This layer only defines path conventions and does not implement actual config read/write behavior. +class _FallbackDevice: + def __init__(self, device_type: str): + self.type = device_type + + def __str__(self) -> str: + return self.type + + +def _get_torch_module(): + """Lazily (re)load torch so lightweight init can install it at runtime.""" + global torch + if torch is not None: + return torch + try: + torch = importlib.import_module("torch") + except Exception: + torch = None + return torch + + +def get_app_install_dir() -> Path: + """ + 返回应用安装根目录 / Return the application install root. + + Windows Lite 打包场景下,运行时、模型和数据库必须固定落在该目录内。 + In Windows Lite builds, runtime files, models, and databases must stay under this directory. + """ + if getattr(sys, "frozen", False): + executable = Path(sys.executable).resolve() + if sys.platform == "darwin" and executable.parent.name == "MacOS": + return executable.parents[2] + return executable.parent + return Path(__file__).resolve().parent + + +def get_runtime_meipass() -> Optional[str]: + """ + 返回 PyInstaller 注入的 `_MEIPASS` 路径字符串。 + Return the `_MEIPASS` path string injected by PyInstaller. + + 这是运行时动态属性,静态类型检查器并不知道它一定存在, + 所以所有调用方都应通过此函数统一访问,而不是直接读取 `sys._MEIPASS`。 + This is a runtime-only dynamic attribute that static type checkers do not + know about, so callers should go through this helper instead of touching + `sys._MEIPASS` directly. + """ + meipass = getattr(sys, "_MEIPASS", None) + if isinstance(meipass, str) and meipass: + return meipass + return None + + +def get_runtime_app_root() -> Optional[str]: + """ + 返回补丁覆盖层记录的真实应用根目录字符串。 + Return the real application root string recorded for the patch overlay. + + 在线补丁覆盖层会优先导入用户目录中的模块,导致 `__file__` 可能指向 + `code_updates/`。这里统一读取主入口注入的真实根目录,避免各模块自行 + 读取 `sys._SUPERPICKY_APP_ROOT` 触发静态告警。 + The patch overlay may cause `__file__` to point at `code_updates/`, so this + helper reads the real app root injected by the main entrypoint and avoids + direct `sys._SUPERPICKY_APP_ROOT` access across modules. + """ + app_root = getattr(sys, "_SUPERPICKY_APP_ROOT", None) + if isinstance(app_root, str) and app_root: + return app_root + return None + + +def set_runtime_app_root(app_root: str) -> str: + """ + 写入补丁覆盖层共享的真实应用根目录。 + Persist the real application root shared by the patch overlay. + + 这里使用 `setattr` 写入运行时动态属性,既保留现有打包/补丁行为, + 也避免直接赋值 `sys._SUPERPICKY_APP_ROOT` 触发 Pylance 属性告警。 + This helper uses `setattr` to preserve the existing runtime contract while + avoiding direct `sys._SUPERPICKY_APP_ROOT` assignments that trip Pylance. + """ + setattr(sys, "_SUPERPICKY_APP_ROOT", app_root) + return app_root + + +def get_bundled_resource_dir() -> Path: + """返回静态打包资源根目录 / Return the root directory for bundled static resources.""" + if getattr(sys, "frozen", False): + if sys.platform == "darwin": + executable = Path(sys.executable).resolve() + if executable.parent.name == "MacOS": + return executable.parents[1] / "Resources" + meipass = get_runtime_meipass() + if meipass is not None: + return Path(meipass) + app_root = get_runtime_app_root() + if app_root is not None: + return Path(app_root) + return get_app_install_dir() + + +def get_app_internal_dir() -> Path: + """ + 返回应用内部运行目录 / Return the application internal runtime directory. + + Windows one-dir 打包产物使用安装目录下的 `_internal/`。 + Other environments fall back to the bundled resource directory. + """ + if getattr(sys, "frozen", False) and sys.platform == "win32": + return get_app_install_dir() / "_internal" + return get_bundled_resource_dir() + + +def get_install_scoped_resource_path( + relative_path: str, *, packaged_relative_path: Optional[str] = None +) -> Path: + """ + 返回安装目录约束下的资源路径 / Return a resource path constrained to the install directory when required. + + Windows Lite 打包环境下,模型/数据库/运行时等可变资源必须位于安装目录。 + Other environments keep using the bundled resource layout. + """ + if getattr(sys, "frozen", False) and sys.platform == "win32": + target_relative_path = packaged_relative_path or relative_path + return get_app_internal_dir() / target_relative_path + return get_bundled_resource_dir() / relative_path + + +def get_packaged_model_relative_path(relative_path: str) -> str: + """返回 Windows Lite 打包环境下模型的内部相对路径 / Return the packaged relative path for models in Windows Lite builds.""" + normalized = relative_path.replace("\\", "/") + if normalized.startswith("models/"): + return "models/" + normalized.split("/", 1)[1] + return normalized def resource_path(relative_path: str) -> str: """ - 返回打包资源路径,兼容开发环境与 PyInstaller。 - Return a packaged resource path compatible with development mode and PyInstaller. + 返回打包资源路径,兼容开发环境与 PyInstaller / Return a packaged resource path compatible with development mode and PyInstaller. `relative_path` 是资源相对路径,例如 `models/yolo11l-seg.pt`。 - `relative_path` is the resource-relative path, for example `models/yolo11l-seg.pt`. - 这里只用于内置资源定位,不能拿来定位用户配置或用户数据。 This is only for bundled resource lookup and must not be used for user config or user data paths. """ - meipass = getattr(sys, '_MEIPASS', None) - if isinstance(meipass, str): - return os.path.join(meipass, relative_path) - return os.path.join(os.path.abspath('.'), relative_path) + return str(get_bundled_resource_dir() / relative_path) -def get_app_config_dir(app_name: str = 'SuperPicky') -> Path: +def get_app_config_dir(app_name: str = "SuperPicky") -> Path: """ 返回跨平台应用配置目录(存放 advanced_config.json、补丁等程序配置)。 Return the cross-platform application config directory. @@ -71,57 +188,50 @@ def get_app_config_dir(app_name: str = 'SuperPicky') -> Path: 用途:advanced_config.json、code_updates/(补丁目录)等程序级配置。 """ - if sys.platform == 'darwin': - return Path.home() / 'Library' / 'Application Support' / app_name - if sys.platform == 'win32': - return Path.home() / 'AppData' / 'Local' / app_name - return Path.home() / '.config' / app_name + if sys.platform == "darwin": + return Path.home() / "Library" / "Application Support" / app_name + if sys.platform == "win32": + return Path.home() / "AppData" / "Local" / app_name + return Path.home() / ".config" / app_name -def get_app_data_dir(app_name: str = 'SuperPicky') -> Path: +def get_app_data_dir(app_name: str = "SuperPicky") -> Path: """ 返回跨平台用户数据目录(存放 birdid 设置等用户产物)。 Return the cross-platform user data directory. - ⚠️ 与 get_app_config_dir() 完全不同的路径,请勿混用: - 所有平台:~/Documents/SuperPicky_Data/ + ⚠️ 现已统一使用标准配置目录,与 get_app_config_dir() 返回相同路径: + macOS : ~/Library/Application Support/SuperPicky/ + Windows: ~/AppData/Local/SuperPicky/ + Linux : ~/.config/SuperPicky/ 用途:birdid_dock_settings.json 等用户可见的数据文件。 切勿用于存放补丁或程序内部配置(应使用 get_app_config_dir())。 """ - return Path.home() / 'Documents' / f'{app_name}_Data' + return get_app_config_dir(app_name) -def get_patch_dir(app_name: str = 'SuperPicky') -> Path: - """ - 返回在线补丁目录。 - Return the online patch directory. +def get_patch_dir(app_name: str = "SuperPicky") -> Path: + """返回在线补丁目录 / Return the online patch directory.""" + return get_app_config_dir(app_name) / "code_updates" - 补丁目录派生自配置目录。 - The patch directory is derived from the config directory. - """ - return get_app_config_dir(app_name) / 'code_updates' +def get_birdid_settings_path(app_name: str = "SuperPicky") -> Path: + """返回 BirdID Dock 设置文件路径 / Return the BirdID Dock settings file path.""" + return get_app_data_dir(app_name) / "birdid_dock_settings.json" -def get_birdid_settings_path(app_name: str = 'SuperPicky') -> Path: - """ - 返回 BirdID Dock 设置文件路径。 - Return the BirdID Dock settings file path. - 该文件属于用户数据,因此放在 app data 目录下。 - This file belongs to user data, so it lives under the app data directory. +def get_birdname_settings_path(app_name: str = "SuperPicky") -> Path: """ - return get_app_data_dir(app_name) / 'birdid_dock_settings.json' - + 返回 BirdName IOC 设置文件路径 / Return the BirdName IOC settings file path. -# ========================= -# 可覆盖配置(ENV + 配置文件) -# ========================= + 该文件属于全局用户配置,应统一收敛到标准配置目录下的 ioc/ 子目录。 + This file belongs to global user configuration and should live under the standard config directory's ioc/ subdirectory. + """ + settings_dir = get_app_config_dir(app_name) / "ioc" + settings_dir.mkdir(parents=True, exist_ok=True) + return settings_dir / "birdname_settings.ini" -# 这里只读取覆盖值,不定义高级配置 schema。 -# This layer only reads override values and does not define the advanced config schema. -# 优先级:ENV > advanced_config.json > 默认值。 -# Priority: ENV > advanced_config.json > default value. _override_cache: Optional[Dict[str, Any]] = None _override_lock = threading.RLock() @@ -140,13 +250,13 @@ def _load_override_file() -> Dict[str, Any]: if _override_cache is not None: return _override_cache - cfg_path = get_app_config_dir() / 'advanced_config.json' + cfg_path = get_app_config_dir() / "advanced_config.json" if not cfg_path.exists(): _override_cache = {} return _override_cache try: - _override_cache = json.loads(cfg_path.read_text(encoding='utf-8')) + _override_cache = json.loads(cfg_path.read_text(encoding="utf-8")) except Exception: _override_cache = {} return _override_cache if _override_cache is not None else {} @@ -154,8 +264,7 @@ def _load_override_file() -> Dict[str, Any]: def _parse_bool(value: Optional[str], default: bool) -> bool: """ - 把字符串值解析为布尔值。 - Parse a string-like value into a boolean. + 把字符串值解析为布尔值 / Parse a string-like value into a boolean. 支持 `1/true/yes/on` 和 `0/false/no/off`,否则返回默认值。 Supports `1/true/yes/on` and `0/false/no/off`, otherwise returns the default. @@ -163,23 +272,22 @@ def _parse_bool(value: Optional[str], default: bool) -> bool: if value is None: return default norm = str(value).strip().lower() - if norm in {'1', 'true', 'yes', 'on'}: + if norm in {"1", "true", "yes", "on"}: return True - if norm in {'0', 'false', 'no', 'off'}: + if norm in {"0", "false", "no", "off"}: return False return default def _env_or_override(name: str, override_key: Optional[str], default: Any) -> Any: """ - 按 ENV > JSON > 默认值 的优先级读取覆盖值。 - Read an override using the priority order ENV > JSON > default. + 按 ENV > JSON > 默认值 的优先级读取覆盖值 / Read an override using the priority order ENV > JSON > default. 这里不做类型转换,调用方自行转成 `int`、`float` 或 `str`。 No type conversion is done here; callers should convert to `int`, `float`, or `str` themselves. """ env_value = os.getenv(name) - if env_value is not None and str(env_value).strip() != '': + if env_value is not None and str(env_value).strip() != "": return env_value if override_key: @@ -190,217 +298,182 @@ def _env_or_override(name: str, override_key: Optional[str], default: Any) -> An return default -# ========================= -# 静态常量分层 -# ========================= - -# 这些 dataclass 用来按领域收拢常量。 -# These dataclasses group constants by domain. - - @dataclass class FileConfig: """ - 文件处理相关静态配置。 - Static configuration related to file handling. + 文件处理相关静态配置 / Static configuration related to file handling. 这些列表会被 RAW/JPG 分类逻辑直接消费。 These lists are consumed directly by RAW/JPG classification logic. """ - # RAW_EXTENSIONS:被视为 RAW 的扩展名列表。 - # RAW_EXTENSIONS: extensions treated as RAW files. - RAW_EXTENSIONS: List[str] = field(default_factory=lambda: [ - '.nef', '.cr2', '.cr3', '.arw', '.raf', - '.orf', '.rw2', '.pef', '.dng', '.3fr', '.iiq' - ]) - # JPG_EXTENSIONS:被视为 JPG/JPEG 的扩展名列表。 - # JPG_EXTENSIONS: extensions treated as JPG/JPEG files. - JPG_EXTENSIONS: List[str] = field(default_factory=lambda: ['.jpg', '.jpeg']) + RAW_EXTENSIONS: List[str] = field( + default_factory=lambda: [ + ".nef", + ".cr2", + ".cr3", + ".arw", + ".raf", + ".orf", + ".rw2", + ".pef", + ".dng", + ".3fr", + ".iiq", + ] + ) + JPG_EXTENSIONS: List[str] = field(default_factory=lambda: [".jpg", ".jpeg"]) @dataclass class DirectoryConfig: """ - 输出目录与报告文件命名配置。 - Naming configuration for output directories and report files. + 输出目录与报告文件命名配置 / Naming configuration for output directories and report files. 修改这些值会影响结果目录结构与报告文件名。 Changing these values affects result folder layout and report filenames. """ - # 高质量照片目录。 - # Directory for excellent photos. - EXCELLENT_DIR: str = '优秀' - # 普通保留照片目录。 - # Directory for standard keepers. - STANDARD_DIR: str = '标准' - # 无鸟或废片目录。 - # Directory for no-bird or rejected photos. - NO_BIRDS_DIR: str = '没鸟' - # 内部临时目录。 - # Internal temporary directory. - TEMP_DIR: str = '_temp' - # 特定工作流使用的 Redbox 目录。 - # Redbox directory for specific workflows. - REDBOX_DIR: str = 'Redbox' - # 裁切临时目录。 - # Temporary crop directory. - CROP_TEMP_DIR: str = '.crop_temp' - - # 旧算法优秀目录。 - # Old-algorithm excellent directory. - OLD_ALGORITHM_EXCELLENT: str = '老算法优秀' - # 新算法优秀目录。 - # New-algorithm excellent directory. - NEW_ALGORITHM_EXCELLENT: str = '新算法优秀' - # 双算法共同优秀目录。 - # Intersection directory for both algorithms. - BOTH_ALGORITHMS_EXCELLENT: str = '双算法优秀' - # 算法差异目录。 - # Directory for algorithm-difference samples. - ALGORITHM_DIFF_DIR: str = '算法差异' - - # 处理日志文件名。 - # Processing log filename. - LOG_FILE: str = '.process_log.txt' - # 主报告 SQLite 文件名。 - # Primary SQLite report filename. - REPORT_FILE: str = '.report.db' - # 算法对比 CSV 文件名。 - # Algorithm comparison CSV filename. - COMPARISON_REPORT_FILE: str = '.algorithm_comparison.csv' + EXCELLENT_DIR: str = "优秀" + STANDARD_DIR: str = "标准" + NO_BIRDS_DIR: str = "没鸟" + TEMP_DIR: str = "_temp" + REDBOX_DIR: str = "Redbox" + CROP_TEMP_DIR: str = ".crop_temp" + OLD_ALGORITHM_EXCELLENT: str = "老算法优秀" + NEW_ALGORITHM_EXCELLENT: str = "新算法优秀" + BOTH_ALGORITHMS_EXCELLENT: str = "双算法优秀" + ALGORITHM_DIFF_DIR: str = "算法差异" + LOG_FILE: str = ".process_log.txt" + REPORT_FILE: str = ".report.db" + COMPARISON_REPORT_FILE: str = ".algorithm_comparison.csv" @dataclass class AIConfig: """ - AI 模型与推理相关静态配置。 - Static configuration related to AI models and inference. + AI 模型与推理相关静态配置 / Static configuration related to AI models and inference. 这些值服务于模型定位与基础推理行为,不替代高级用户参数。 These values support model lookup and baseline inference behavior and do not replace advanced user-facing parameters. """ - # 主模型相对路径。 - # Relative path to the main model. - MODEL_FILE: str = 'models/yolo11l-seg.pt' - # “鸟”类别的 class id。 - # Class id for the "bird" category. + MODEL_FILE: str = "models/yolo11l-seg.pt" BIRD_CLASS_ID: int = 14 - # 推理目标尺寸。 - # Target inference image size. TARGET_IMAGE_SIZE: int = 1024 - # 主体居中判断默认阈值。 - # Default threshold for centered-subject checks. CENTER_THRESHOLD: float = 0.15 - # 锐度归一化策略标识,默认不指定。 - # Sharpness normalization strategy marker, unset by default. SHARPNESS_NORMALIZATION: Optional[str] = None def get_model_path(self) -> str: """ - 返回主模型的实际可访问路径。 - Return the actual accessible path to the main model. + 返回主模型的实际可访问路径 / Return the actual accessible path to the main model. 调用方不应自行拼 PyInstaller 临时目录。 Callers should not manually stitch together PyInstaller temporary paths. """ - return resource_path(self.MODEL_FILE) + return str( + get_install_scoped_resource_path( + self.MODEL_FILE, + packaged_relative_path=get_packaged_model_relative_path( + self.MODEL_FILE + ), + ) + ) @dataclass class UIConfig: """ - UI 展示层静态常量。 - Static constants for the UI presentation layer. + UI 展示层静态常量 / Static constants for the UI presentation layer. 这里只放显示缩放和进度边界,不放业务阈值。 This group is for display scaling and progress bounds, not business thresholds. """ - # 置信度百分比缩放。 - # Percentage scale for confidence display. CONFIDENCE_SCALE: float = 100.0 - # 面积显示缩放。 - # Display scale for area values. AREA_SCALE: float = 1000.0 - # 锐度显示缩放。 - # Display scale for sharpness values. SHARPNESS_SCALE: int = 20 - # 进度条最小值。 - # Minimum progress-bar value. PROGRESS_MIN: int = 0 - # 进度条最大值。 - # Maximum progress-bar value. PROGRESS_MAX: int = 100 - # 默认提示音次数。 - # Default completion beep count. BEEP_COUNT: int = 3 @dataclass class CSVConfig: """ - CSV 报告结构配置。 - CSV report structure configuration. + CSV 报告结构配置 / CSV report structure configuration. `HEADERS` 定义导出列顺序,兼容性要求较高。 `HEADERS` defines export-column order and carries relatively high compatibility requirements. """ - # 报告列名顺序定义。 - # Ordered definition of report header names. - HEADERS: List[str] = field(default_factory=lambda: [ - 'filename', 'found_bird', 'AI score', 'bird_centre_x', - 'bird_centre_y', 'bird_area', 's_bird_area', - 'laplacian_var', 'sobel_var', 'fft_high_freq', 'contrast', - 'edge_density', 'background_complexity', 'motion_blur', - 'normalized_new', 'composite_score', 'result_new', - 'dominant_bool', 'centred_bool', 'sharp_bool', 'class_id' - ]) + HEADERS: List[str] = field( + default_factory=lambda: [ + "filename", + "found_bird", + "AI score", + "bird_centre_x", + "bird_centre_y", + "bird_area", + "s_bird_area", + "laplacian_var", + "sobel_var", + "fft_high_freq", + "contrast", + "edge_density", + "background_complexity", + "motion_blur", + "normalized_new", + "composite_score", + "result_new", + "dominant_bool", + "centred_bool", + "sharp_bool", + "class_id", + ] + ) @dataclass class ServerConfig: """ - BirdID 服务默认配置。 - Default configuration for the BirdID service. + BirdID 服务默认配置 / Default configuration for the BirdID service. 这些值可被 ENV 覆盖,并影响绑定地址、启动等待和健康检查节奏。 These values can be overridden by ENV and affect bind address, startup wait, and health-check cadence. """ - # 默认监听地址。ENV: SUPERPICKY_SERVER_HOST - # Default bind host. ENV: SUPERPICKY_SERVER_HOST - HOST: str = '127.0.0.1' - # 默认端口。ENV: SUPERPICKY_SERVER_PORT - # Default port. ENV: SUPERPICKY_SERVER_PORT + HOST: str = "127.0.0.1" PORT: int = 5156 - # 单次健康检查超时。ENV: SUPERPICKY_SERVER_HEALTH_TIMEOUT - # Single health-check timeout. ENV: SUPERPICKY_SERVER_HEALTH_TIMEOUT HEALTH_TIMEOUT_SECONDS: float = 2.0 - # 启动就绪总等待时间。ENV: SUPERPICKY_SERVER_STARTUP_WAIT - # Total startup readiness wait time. ENV: SUPERPICKY_SERVER_STARTUP_WAIT STARTUP_WAIT_SECONDS: float = 10.0 - # 健康检查轮询间隔。ENV: SUPERPICKY_SERVER_POLL_INTERVAL - # Health-check polling interval. ENV: SUPERPICKY_SERVER_POLL_INTERVAL POLL_INTERVAL_SECONDS: float = 0.5 @classmethod - def load(cls) -> 'ServerConfig': + def load(cls) -> "ServerConfig": """ - 按覆盖优先级构造 ServerConfig。 - Build ServerConfig using the override priority rules. + 按覆盖优先级构造 ServerConfig / Build ServerConfig using the override priority rules. 类型转换统一在这里做,避免调用方重复解析 ENV。 Type coercion is centralized here so callers do not repeat ENV parsing. """ - host = str(_env_or_override('SUPERPICKY_SERVER_HOST', None, cls.HOST)) - port = int(_env_or_override('SUPERPICKY_SERVER_PORT', None, cls.PORT)) - health_timeout = float(_env_or_override('SUPERPICKY_SERVER_HEALTH_TIMEOUT', None, cls.HEALTH_TIMEOUT_SECONDS)) - startup_wait = float(_env_or_override('SUPERPICKY_SERVER_STARTUP_WAIT', None, cls.STARTUP_WAIT_SECONDS)) - poll = float(_env_or_override('SUPERPICKY_SERVER_POLL_INTERVAL', None, cls.POLL_INTERVAL_SECONDS)) + host = str(_env_or_override("SUPERPICKY_SERVER_HOST", None, cls.HOST)) + port = int(_env_or_override("SUPERPICKY_SERVER_PORT", None, cls.PORT)) + health_timeout = float( + _env_or_override( + "SUPERPICKY_SERVER_HEALTH_TIMEOUT", None, cls.HEALTH_TIMEOUT_SECONDS + ) + ) + startup_wait = float( + _env_or_override( + "SUPERPICKY_SERVER_STARTUP_WAIT", None, cls.STARTUP_WAIT_SECONDS + ) + ) + poll = float( + _env_or_override( + "SUPERPICKY_SERVER_POLL_INTERVAL", None, cls.POLL_INTERVAL_SECONDS + ) + ) return cls( HOST=host, PORT=port, @@ -413,47 +486,50 @@ def load(cls) -> 'ServerConfig': @dataclass class EndpointConfig: """ - 远程服务端点默认配置。 - Default configuration for remote service endpoints. + 远程服务端点默认配置 / Default configuration for remote service endpoints. 这些 URL 会影响下载页、eBird 查询与 Nominatim 反查等网络行为。 These URLs affect network behavior such as download pages, eBird queries, and Nominatim reverse lookups. """ - # 镜像或资源下载基础地址。 - # Base URL for mirrors or downloadable resources. - MIRROR_BASE_URL: str = 'http://1.119.150.179:59080/superpicky' - # 给用户打开的下载页面地址。 - # Download page URL opened for the user. - UPDATE_DOWNLOAD_PAGE: str = 'https://superpicky.jamesphotography.com.au/#download' - # eBird API 根地址。 - # Root URL for the eBird API. - EBIRD_API_BASE: str = 'https://api.ebird.org/v2' - # Nominatim 反向地理编码接口。 - # Reverse geocoding endpoint for Nominatim. - NOMINATIM_REVERSE_URL: str = 'https://nominatim.openstreetmap.org/reverse' + MIRROR_BASE_URL: str = "http://1.119.150.179:59080/superpicky" + UPDATE_DOWNLOAD_PAGE: str = "https://superpicky.jamesphotography.com.au/#download" + EBIRD_API_BASE: str = "https://api.ebird.org/v2" + NOMINATIM_REVERSE_URL: str = "https://nominatim.openstreetmap.org/reverse" @classmethod - def load(cls) -> 'EndpointConfig': + def load(cls) -> "EndpointConfig": """ - 按覆盖优先级构造 EndpointConfig。 - Build EndpointConfig using the override priority rules. + 按覆盖优先级构造 EndpointConfig / Build EndpointConfig using the override priority rules. 统一入口便于未来继续扩展 ENV 覆盖。 A unified entry point makes future ENV override expansion easier. """ return cls( - MIRROR_BASE_URL=str(_env_or_override('SUPERPICKY_MIRROR_BASE_URL', None, cls.MIRROR_BASE_URL)), - UPDATE_DOWNLOAD_PAGE=str(_env_or_override('SUPERPICKY_DOWNLOAD_PAGE', None, cls.UPDATE_DOWNLOAD_PAGE)), - EBIRD_API_BASE=str(_env_or_override('SUPERPICKY_EBIRD_API_BASE', None, cls.EBIRD_API_BASE)), - NOMINATIM_REVERSE_URL=str(_env_or_override('SUPERPICKY_NOMINATIM_REVERSE_URL', None, cls.NOMINATIM_REVERSE_URL)), + MIRROR_BASE_URL=str( + _env_or_override( + "SUPERPICKY_MIRROR_BASE_URL", None, cls.MIRROR_BASE_URL + ) + ), + UPDATE_DOWNLOAD_PAGE=str( + _env_or_override( + "SUPERPICKY_DOWNLOAD_PAGE", None, cls.UPDATE_DOWNLOAD_PAGE + ) + ), + EBIRD_API_BASE=str( + _env_or_override("SUPERPICKY_EBIRD_API_BASE", None, cls.EBIRD_API_BASE) + ), + NOMINATIM_REVERSE_URL=str( + _env_or_override( + "SUPERPICKY_NOMINATIM_REVERSE_URL", None, cls.NOMINATIM_REVERSE_URL + ) + ), ) class Config: """ - 主配置聚合类。 - Main configuration aggregation class. + 主配置聚合类 / Main configuration aggregation class. 这是项目中最常用的统一读取入口,用来整合不同层次的配置。 This is the most common unified read entry point used to aggregate different configuration layers. @@ -461,8 +537,7 @@ class Config: def __init__(self): """ - 构造统一配置对象。 - Construct the unified configuration object. + 构造统一配置对象 / Construct the unified configuration object. 初始化时建立静态常量分组,并加载服务与端点配置。 Initialization builds static config groups and loads service and endpoint configuration. @@ -477,25 +552,23 @@ def __init__(self): def get_directory_names(self) -> Dict[str, str]: """ - 返回常用输出目录名映射。 - Return a mapping of commonly used output directory names. + 返回常用输出目录名映射 / Return a mapping of commonly used output directory names. 适合 UI 展示和流程内统一引用目录名。 Useful for UI display and for consistent directory references inside processing flows. """ return { - 'excellent': self.directory.EXCELLENT_DIR, - 'standard': self.directory.STANDARD_DIR, - 'no_birds': self.directory.NO_BIRDS_DIR, - 'temp': self.directory.TEMP_DIR, - 'redbox': self.directory.REDBOX_DIR, - 'crop_temp': self.directory.CROP_TEMP_DIR, + "excellent": self.directory.EXCELLENT_DIR, + "standard": self.directory.STANDARD_DIR, + "no_birds": self.directory.NO_BIRDS_DIR, + "temp": self.directory.TEMP_DIR, + "redbox": self.directory.REDBOX_DIR, + "crop_temp": self.directory.CROP_TEMP_DIR, } def is_raw_file(self, filename: str) -> bool: """ - 判断文件名是否属于 RAW 扩展名集合。 - Check whether a filename belongs to the RAW extension set. + 判断文件名是否属于 RAW 扩展名集合 / Check whether a filename belongs to the RAW extension set. 这里只按扩展名判断,不检查内容或 MIME。 This only checks file extensions and does not inspect content or MIME type. @@ -505,8 +578,7 @@ def is_raw_file(self, filename: str) -> bool: def is_jpg_file(self, filename: str) -> bool: """ - 判断文件名是否属于 JPG/JPEG 扩展名集合。 - Check whether a filename belongs to the JPG/JPEG extension set. + 判断文件名是否属于 JPG/JPEG 扩展名集合 / Check whether a filename belongs to the JPG/JPEG extension set. 这是轻量判断入口,不做内容探测。 This is a lightweight classification entry and does not inspect file contents. @@ -515,36 +587,25 @@ def is_jpg_file(self, filename: str) -> bool: return ext.lower() in self.file.JPG_EXTENSIONS -# ========================= -# 懒加载资源注册器 -# ========================= - -# 这个注册器用于跨模块共享可缓存、可复用、构造成本高的对象。 -# This registry is for cacheable, reusable, high-construction-cost objects shared across modules. -# 不适合用来存放短生命周期业务状态。 -# It is not suitable for short-lived business state. - _MISSING = object() class LazyRegistry: """ - 线程安全懒加载注册器。 - Thread-safe lazy registry. + 线程安全区加载注册器 / Thread-safe lazy registry. 目标是避免重复初始化重量级对象,并提供统一的共享入口。 Its goal is to avoid repeated heavy initialization and provide a unified sharing entry point. """ def __init__(self): - """初始化内部存储和锁。 Initialize the internal storage and lock.""" + """初始化内部存储和锁 / Initialize the internal storage and lock.""" self._values: Dict[str, Any] = {} self._lock = threading.RLock() def get_or_create(self, key: str, factory: Callable[[], Any]) -> Any: """ - 读取缓存对象,不存在时在锁内创建并缓存。 - Read a cached object and create/cache it inside the lock if it is missing. + 读取缓存对象,不存在时在锁内创建并缓存 / Read a cached object and create/cache it inside the lock if it is missing. 采用无锁快速读取加锁内二次检查,避免并发重复创建。 This uses a fast unlocked read plus a locked second check to avoid duplicate concurrent construction. @@ -560,17 +621,13 @@ def get_or_create(self, key: str, factory: Callable[[], Any]) -> Any: return value def get(self, key: str, default: Any = None) -> Any: - """ - 读取缓存值,不触发创建。 - Read a cached value without triggering creation. - """ + """读取缓存值,不触发创建 / Read a cached value without triggering creation.""" with self._lock: return self._values.get(key, default) def set(self, key: str, value: Any) -> None: """ - 显式设置缓存值。 - Explicitly set a cached value. + 显式设置缓存值 / Explicitly set a cached value. 不要用它存放临时业务状态。 Do not use this to stash temporary business state. @@ -580,8 +637,7 @@ def set(self, key: str, value: Any) -> None: def clear(self, key: str) -> None: """ - 清除单个缓存项。 - Clear a single cached item. + 清除单个缓存项 / Clear a single cached item. 适合测试隔离或强制下次重建。 Useful for test isolation or forcing the next access to rebuild the object. @@ -591,8 +647,7 @@ def clear(self, key: str) -> None: def clear_all(self) -> None: """ - 清空所有缓存项。 - Clear all cached items. + 清空所有缓存项 / Clear all cached items. 若缓存对象持有外部资源,应先确保存在显式关闭流程。 If cached objects hold external resources, make sure an explicit shutdown flow exists first. @@ -606,8 +661,7 @@ def clear_all(self) -> None: def get_lazy_registry() -> LazyRegistry: """ - 返回全局懒加载注册器实例。 - Return the global lazy-registry instance. + 返回全局懒加载注册器实例 / Return the global lazy-registry instance. 调用方应通过此入口获取共享注册器,避免自行新建导致缓存割裂。 Callers should use this entry point to get the shared registry and avoid cache fragmentation caused by creating their own registries. @@ -615,20 +669,9 @@ def get_lazy_registry() -> LazyRegistry: return _lazy_registry -# ========================= -# 设备选择 -# ========================= - -# 设备选择逻辑必须集中,避免不同模块各自判断 CUDA/MPS/CPU。 -# Device selection must stay centralized so different modules do not each make their own CUDA/MPS/CPU decisions. -# 若打包版与源码行为不同,优先怀疑打包环境中的 Torch/CUDA 运行时差异。 -# If packaged behavior differs from source behavior, suspect Torch/CUDA runtime differences in the packaged environment first. - - def get_best_device(): """ - 返回当前环境下最合适的 Torch 设备对象。 - Return the most appropriate Torch device object for the current environment. + 返回当前环境下最合适的 Torch 设备对象 / Return the most appropriate Torch device object for the current environment. 顺序为:macOS 先 MPS 再 CPU,其他平台先 CUDA 再 CPU。 The order is: on macOS use MPS then CPU, on other platforms use CUDA then CPU. @@ -637,19 +680,157 @@ def get_best_device(): On any detection failure, conservatively fall back to CPU. """ try: + torch_module = _get_torch_module() + if torch_module is None: + return _FallbackDevice("cpu") system = platform.system() - if system == 'Darwin': - if torch.backends.mps.is_available(): - return torch.device('mps') - return torch.device('cpu') - - if torch.cuda.is_available(): - return torch.device('cuda') - return torch.device('cpu') + if system == "Darwin": + if torch_module.backends.mps.is_available(): + return torch_module.device("mps") + return torch_module.device("cpu") + + if torch_module.cuda.is_available(): + return torch_module.device("cuda") + return torch_module.device("cpu") except Exception: - return torch.device('cpu') + torch_module = _get_torch_module() + return ( + torch_module.device("cpu") + if torch_module is not None + else _FallbackDevice("cpu") + ) + + +def migrate_old_data() -> bool: + """ + 迁移旧路径数据到新路径 / Migrate old path data to new path. + + 检测 ~/Documents/SuperPicky_Data 目录是否存在数据, + 如果存在则迁移到 get_app_config_dir() 返回的标准配置目录。 + + Returns: + bool: 迁移是否成功(如果没有旧数据也返回 True) + """ + try: + old_data_dir = Path.home() / "Documents" / "SuperPicky_Data" + new_data_dir = get_app_config_dir() + + if not old_data_dir.exists() or not old_data_dir.is_dir(): + return True + + files = list(old_data_dir.iterdir()) + if not files: + return True + + logger.info(f"检测到旧数据目录: {old_data_dir}") + logger.info(f"开始迁移到新目录: {new_data_dir}") + + new_data_dir.mkdir(parents=True, exist_ok=True) + + copied_files = [] + for file_path in files: + try: + dest_path = new_data_dir / file_path.name + if file_path.is_file(): + import shutil + + shutil.copy2(file_path, dest_path) + copied_files.append(file_path.name) + elif file_path.is_dir(): + import shutil + + shutil.copytree(file_path, dest_path, dirs_exist_ok=True) + copied_files.append(file_path.name) + except Exception as e: + logger.error(f"复制文件失败 {file_path.name}: {e}") + return False + + logger.info(f"成功迁移 {len(copied_files)} 个文件/目录") + + for file_name in copied_files: + try: + old_path = old_data_dir / file_name + if old_path.exists(): + if old_path.is_file(): + old_path.unlink() + elif old_path.is_dir(): + import shutil + + shutil.rmtree(old_path) + except Exception as e: + logger.warning(f"删除旧文件失败 {file_name}: {e}") + + try: + if old_data_dir.exists() and old_data_dir.is_dir(): + import shutil + + shutil.rmtree(old_data_dir) + logger.info(f"已删除旧数据目录: {old_data_dir}") + except Exception as e: + logger.warning(f"删除旧目录失败: {e}") + + logger.info("数据迁移完成") + return True + + except Exception as e: + logger.error(f"数据迁移失败: {e}") + return False + + +def migrate_legacy_ioc_settings(app_name: str = "SuperPicky") -> bool: + """ + 迁移旧的用户主目录 IOC 设置到标准配置目录。 + Migrate legacy IOC settings from the user home directory to the standard config directory. + + 仅处理 ~/.superpicky/ioc/birdname_settings.ini 这类全局配置残留, + 不涉及照片目录中的 .superpicky 工作文件。 + """ + try: + import shutil + + old_settings_path = ( + Path.home() / ".superpicky" / "ioc" / "birdname_settings.ini" + ) + new_settings_path = get_birdname_settings_path(app_name) + + if not old_settings_path.exists() or not old_settings_path.is_file(): + return True + + if new_settings_path.exists(): + logger.info(f"检测到新的 IOC 配置已存在,保留新路径: {new_settings_path}") + return True + + shutil.copy2(old_settings_path, new_settings_path) + logger.info(f"已迁移 IOC 配置: {old_settings_path} -> {new_settings_path}") + + try: + old_settings_path.unlink() + except Exception as e: + logger.warning(f"删除旧 IOC 配置失败: {e}") + return True + + old_ioc_dir = old_settings_path.parent + old_superpicky_dir = old_ioc_dir.parent + try: + if ( + old_ioc_dir.exists() + and old_ioc_dir.is_dir() + and not any(old_ioc_dir.iterdir()) + ): + old_ioc_dir.rmdir() + if ( + old_superpicky_dir.exists() + and old_superpicky_dir.is_dir() + and not any(old_superpicky_dir.iterdir()) + ): + old_superpicky_dir.rmdir() + except Exception as e: + logger.warning(f"清理旧 IOC 目录失败: {e}") + + return True + except Exception as e: + logger.error(f"IOC 配置迁移失败: {e}") + return False -# 全局配置实例,供多数模块直接 import 使用。 -# Global configuration instance intended for direct import by most modules. config = Config() diff --git a/core/batch_processor.py b/core/batch_processor.py index 880130c..439e712 100644 --- a/core/batch_processor.py +++ b/core/batch_processor.py @@ -9,10 +9,10 @@ import os import json import time -from typing import List, Dict, Optional, Callable +from typing import Callable, Dict, List, Optional, Sequence, Union from dataclasses import dataclass, field -from core.recursive_scanner import scan_recursive, is_processed, count_photos +from core.recursive_scanner import DEFAULT_SCAN_MAX_DEPTH, ScannedDirectory, count_photos, is_processed, scan_directories @dataclass @@ -39,7 +39,7 @@ def __init__( root_dir: str, settings, # ProcessingSettings skip_existing: bool = False, - max_depth: int = 10, + max_depth: int = DEFAULT_SCAN_MAX_DEPTH, log_fn: Optional[Callable[[str], None]] = None, ): self.root_dir = os.path.abspath(root_dir) @@ -48,13 +48,13 @@ def __init__( self.max_depth = max_depth self.log = log_fn or print - def scan(self) -> List[str]: - """扫描并返回待处理的原子目录列表""" - return scan_recursive(self.root_dir, self.max_depth) + def scan(self) -> List[ScannedDirectory]: + """扫描并返回待处理的原子目录摘要列表""" + return scan_directories(self.root_dir, self.max_depth) def process( self, - dirs: List[str], + dirs: Sequence[Union[str, ScannedDirectory]], organize_files: bool = True, cleanup_temp: bool = True, ) -> BatchResult: @@ -71,16 +71,30 @@ def process( """ from core.photo_processor import PhotoProcessor, ProcessingCallbacks - result = BatchResult(total_dirs=len(dirs)) + normalized_dirs: List[ScannedDirectory] = [] + for entry in dirs: + if isinstance(entry, ScannedDirectory): + normalized_dirs.append(entry) + continue + normalized_dirs.append( + ScannedDirectory( + path=entry, + depth=-1, + photo_count=count_photos(entry), + ) + ) + + result = BatchResult(total_dirs=len(normalized_dirs)) batch_start = time.time() - for i, dir_path in enumerate(dirs, 1): + for i, scanned_dir in enumerate(normalized_dirs, 1): + dir_path = scanned_dir.path dir_name = os.path.relpath(dir_path, self.root_dir) - photo_count = count_photos(dir_path) + photo_count = scanned_dir.photo_count # 增量跳过 if self.skip_existing and is_processed(dir_path): - self.log(f"\n⏭️ [{i}/{len(dirs)}] 跳过已处理: {dir_name} ({photo_count} 张)") + self.log(f"\n⏭️ [{i}/{len(normalized_dirs)}] 跳过已处理: {dir_name} ({photo_count} 张)") result.skipped_dirs += 1 result.dir_results.append({ 'dir': dir_name, @@ -90,7 +104,7 @@ def process( continue self.log(f"\n{'━' * 60}") - self.log(f"📂 [{i}/{len(dirs)}] 处理: {dir_name} ({photo_count} 张)") + self.log(f"📂 [{i}/{len(normalized_dirs)}] 处理: {dir_name} ({photo_count} 张)") self.log(f"{'━' * 60}") dir_start = time.time() diff --git a/core/build_info.py b/core/build_info.py index 7eb1dfe..e1fab1b 100644 --- a/core/build_info.py +++ b/core/build_info.py @@ -1,15 +1,15 @@ # -*- coding: utf-8 -*- """ 构建信息 -此文件在发布构建时由 build_release.sh 自动修改,以注入 git commit hash +此文件在发布构建时由 Python 构建脚本自动修改,以注入 git commit hash 和 release channel """ # 在打包时会被替换为实际的 commit hash -COMMIT_HASH = "be2f41a3" +COMMIT_HASH = "6f2049e" # 发布渠道:CI 打包时自动注入("nightly" = RC 预发布,"official" = 正式版) # 本地开发默认 "dev",不触发更新检查 -RELEASE_CHANNEL = "dev" +RELEASE_CHANNEL = "official" diff --git a/core/flight_detector.py b/core/flight_detector.py index 433057d..f3f17f4 100644 --- a/core/flight_detector.py +++ b/core/flight_detector.py @@ -1,69 +1,77 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -Flight Detector - 飞版检测模块 -使用 EfficientNet-B3 模型检测鸟类是否处于飞行状态 +Flight Detector - 飞版检测模块。 +Flight Detector module. -V3.4 新增功能 +使用 EfficientNet-B3 模型检测鸟类是否处于飞行。 +Uses an EfficientNet-B3 model to determine whether a bird is in flight. """ from pathlib import Path from dataclasses import dataclass -from typing import Optional, Union +from typing import Optional, Union, cast import numpy as np import torch import torch.nn as nn from torchvision import models, transforms from PIL import Image -from config import get_best_device +from config import ( + get_best_device, + get_install_scoped_resource_path, + get_packaged_model_relative_path, + get_runtime_app_root, + get_runtime_meipass, +) @dataclass class FlightResult: """飞版检测结果""" - is_flying: bool # 是否飞行 - confidence: float # 置信度 (0-1) + is_flying: bool + confidence: float class FlightDetector: """ 飞版检测器 - - 使用 EfficientNet-B3 二分类模型判断鸟类是否处于飞行状态。 - 模型训练自 superFlier 项目,使用 RMSprop + last_block freeze 策略。 + 使用 EfficientNet-B3 二分类模型判断鸟类是否处于飞行状态 """ - - # 模型配置 - IMAGE_SIZE = 384 # 训练时的输入尺寸 - THRESHOLD = 0.5 # 默认分类阈值 + + IMAGE_SIZE = 384 + THRESHOLD = 0.5 def __init__(self, model_path: Optional[str] = None): """ 初始化检测器 - + Args: model_path: 模型文件路径,如果为 None 则使用默认路径 """ - self.model = None - self.device = None + self.model: Optional[nn.Module] = None + self.device: Optional[torch.device] = None self.model_loaded = False - - # 确定模型路径(支持 PyInstaller 打包) + if model_path is None: import sys - if hasattr(sys, '_MEIPASS'): - # PyInstaller 打包后的路径 - self.model_path = Path(sys._MEIPASS) / "models" / "superFlier_efficientnet.pth" + if getattr(sys, 'frozen', False) and sys.platform == 'win32': + self.model_path = get_install_scoped_resource_path( + "models/superFlier_efficientnet.pth", + packaged_relative_path=get_packaged_model_relative_path("models/superFlier_efficientnet.pth"), + ) else: - # 开发环境:优先使用 main.py 注入的真实 app 根目录(补丁覆盖层兼容) - project_root = Path(getattr(sys, '_SUPERPICKY_APP_ROOT', - str(Path(__file__).parent.parent))) - self.model_path = project_root / "models" / "superFlier_efficientnet.pth" + meipass = get_runtime_meipass() + if meipass is not None: + self.model_path = Path(meipass) / "models" / "superFlier_efficientnet.pth" + else: + project_root = get_runtime_app_root() + if project_root is None: + project_root = str(Path(__file__).parent.parent) + self.model_path = Path(project_root) / "models" / "superFlier_efficientnet.pth" else: self.model_path = Path(model_path) - - # 图像预处理(与训练时一致) + self.transform = transforms.Compose([ transforms.Resize((self.IMAGE_SIZE, self.IMAGE_SIZE)), transforms.ToTensor(), @@ -76,39 +84,38 @@ def __init__(self, model_path: Optional[str] = None): def _build_model(self) -> nn.Module: """ 构建 EfficientNet-B3 模型结构 - - 必须与训练时的结构完全一致: - - 使用 Dropout(0.2) - - 输出层为 Linear(in_features, 1) + Sigmoid """ - model = models.efficientnet_b3(weights=None) # 不需要预训练权重 - in_features = model.classifier[1].in_features - - # 替换分类头(与 grid_search.py 中的 DROPOUT=0.2 一致) - model.classifier = nn.Sequential( + model = cast(nn.Module, models.efficientnet_b3(weights=None)) + classifier = cast(nn.Sequential, getattr(model, "classifier")) + classifier_linear = cast(nn.Linear, classifier[1]) + in_features = classifier_linear.in_features + + setattr( + model, + "classifier", + nn.Sequential( nn.Dropout(0.2), nn.Linear(in_features, 1), nn.Sigmoid() + ), ) - + return model def load_model(self) -> None: """ 加载模型权重 - + Raises: FileNotFoundError: 模型文件不存在 RuntimeError: 模型加载失败 """ if not self.model_path.exists(): raise FileNotFoundError(f"飞版检测模型未找到: {self.model_path}") - - self.device = get_best_device() - - # 构建并加载模型 + + self.device = torch.device(str(get_best_device())) self.model = self._build_model() - + try: state_dict = torch.load( self.model_path, @@ -118,44 +125,38 @@ def load_model(self) -> None: self.model.load_state_dict(state_dict) except Exception as e: raise RuntimeError(f"加载飞版检测模型失败: {e}") - - self.model.to(self.device) + + self.model.to(device=self.device) self.model.eval() self.model_loaded = True def detect( - self, + self, image: Union[np.ndarray, Image.Image, str], - threshold: float = None + threshold: Optional[float] = None ) -> FlightResult: """ 检测图像中的鸟是否处于飞行状态 - + Args: - image: 输入图像,支持以下格式: - - numpy.ndarray (BGR 或 RGB,由 OpenCV 或其他库读取) - - PIL.Image - - str (图像文件路径) + image: 输入图像,支持 numpy.ndarray、PIL.Image 或文件路径 threshold: 分类阈值,默认使用 self.THRESHOLD (0.5) - + Returns: FlightResult: 包含 is_flying 和 confidence - + Raises: RuntimeError: 模型未加载 """ if not self.model_loaded: raise RuntimeError("飞版检测模型未加载,请先调用 load_model()") - + if threshold is None: threshold = self.THRESHOLD - - # 处理不同输入类型 + if isinstance(image, str): - # 文件路径 pil_image = Image.open(image).convert('RGB') elif isinstance(image, np.ndarray): - # numpy 数组(假设是 BGR,需要转换) import cv2 if len(image.shape) == 3 and image.shape[2] == 3: rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) @@ -166,14 +167,16 @@ def detect( pil_image = image.convert('RGB') else: raise ValueError(f"不支持的图像类型: {type(image)}") - - # 预处理 - image_tensor = self.transform(pil_image).unsqueeze(0).to(self.device) - # 推理 + transformed_tensor = cast(torch.Tensor, self.transform(pil_image)) + image_tensor = transformed_tensor.unsqueeze(0).to(self.device) + + if self.model is None: + raise RuntimeError("飞版检测模型尚未初始化") + with torch.no_grad(): prob = self.model(image_tensor).item() - del image_tensor # 立即释放 MPS/CUDA 显存,避免长批次累积 + del image_tensor return FlightResult( is_flying=prob > threshold, @@ -183,33 +186,32 @@ def detect( def detect_batch( self, images: list, - threshold: float = None, + threshold: Optional[float] = None, batch_size: int = 8 ) -> list: """ 批量检测多张图像 - + Args: images: 图像列表(支持混合类型) threshold: 分类阈值 batch_size: 批处理大小 - + Returns: list[FlightResult]: 检测结果列表 """ if not self.model_loaded: raise RuntimeError("飞版检测模型未加载,请先调用 load_model()") - + if threshold is None: threshold = self.THRESHOLD - + results = [] - - # 分批处理 + for i in range(0, len(images), batch_size): batch = images[i:i + batch_size] batch_tensors = [] - + for img in batch: if isinstance(img, str): pil_image = Image.open(img).convert('RGB') @@ -221,40 +223,41 @@ def detect_batch( pil_image = img.convert('RGB') else: continue - - batch_tensors.append(self.transform(pil_image)) - + + batch_tensors.append(cast(torch.Tensor, self.transform(pil_image))) + if not batch_tensors: continue - - # 组合为批次 + + if self.device is None: + raise RuntimeError("飞版检测设备尚未初始化") batch_tensor = torch.stack(batch_tensors).to(self.device) - - # 推理 + + if self.model is None: + raise RuntimeError("飞版检测模型尚未初始化") + model = self.model with torch.no_grad(): - probs = self.model(batch_tensor).squeeze().cpu().numpy() - - # 处理单个元素的情况 + probs = model(batch_tensor).squeeze().cpu().numpy() # type: ignore + if probs.ndim == 0: probs = [probs.item()] - + for prob in probs: results.append(FlightResult( is_flying=prob > threshold, confidence=float(prob) )) - + return results -# 全局单例(延迟初始化) _flight_detector_instance: Optional[FlightDetector] = None def get_flight_detector() -> FlightDetector: """ 获取全局飞版检测器实例(单例模式) - + Returns: FlightDetector: 全局检测器实例 """ diff --git a/core/initialization_manager.py b/core/initialization_manager.py new file mode 100644 index 0000000..a43e993 --- /dev/null +++ b/core/initialization_manager.py @@ -0,0 +1,1436 @@ +# -*- coding: utf-8 -*- +""" +First-run initialization manager for lightweight builds. + +The old first-run onboarding path is intentionally preserved elsewhere for +full-package compatibility. This manager only takes over when runtime or +required resources are missing. + +轻量级构建的首次运行初始化管理器。 + +旧的首次运行引导路径在其他地方保留,以实现完整包兼容性。 +此管理器仅在运行时或所需资源缺失时接管。 +""" + +from __future__ import annotations + +import importlib +import importlib.util +import logging +import os +import shutil +import subprocess +import sys +import tempfile +import threading +import time +import urllib.parse +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, Optional + +from PySide6.QtCore import QObject, Signal + +from advanced_config import get_advanced_config +from config import ( + get_app_config_dir, + get_app_internal_dir, + get_bundled_resource_dir, +) +from core.initialization_progress import ( + InitializationProgressEvent, + PROGRESS_KIND_DOWNLOAD, + PROGRESS_KIND_RUNTIME, + STAGE_DOWNLOADING, + STAGE_PREPARING_RUNTIME, + parse_pip_raw_progress_line, +) +from core.runtime_requirements import RuntimeRequirements, get_runtime_requirements +from core.source_probe import pick_best_source, probe_sources +from scripts.download_models import ( + download_resource, + resolve_download_plan, + resolve_resource_destination_dir, +) + +logging.basicConfig(level=logging.INFO) + + +PIPY_SOURCES = [ + {"name": "cernet", "url": "https://mirrors.cernet.edu.cn/pypi/web/simple"}, + {"name": "official", "url": "https://pypi.org/simple"}, +] + +FULL_FEATURE_SET = ("core_detection", "quality", "keypoint", "flight", "birdid") + +STAGE_NOT_STARTED = "not_started" +STAGE_PROBING = "probing_sources" +STAGE_CHECKING_UPDATES = "checking_updates" +STAGE_PREPARING_RUNTIME = "preparing_runtime" +STAGE_DOWNLOADING = "downloading_resources" +STAGE_VERIFYING = "verifying" +STAGE_READY = "ready" +STAGE_FAILED = "failed" + + +class InitializationInterrupted(RuntimeError): + """用户主动中断初始化 / User-requested initialization interruption.""" + + +@dataclass +class RuntimeSelection: + variant: str + detected_cuda_capable: bool + reason: str + + +@dataclass(frozen=True) +class RuntimeInstallLocation: + key: str + runtime_dir: Path + free_bytes: Optional[int] + writable: bool + + +@dataclass +class ResourceProgressState: + """ + Aggregate state for one resource inside the download phase. + + 下载阶段中单个资源的聚合状态。 + """ + + ratio: float = 0.0 + bytes_done: int | None = None + bytes_total: int | None = None + is_terminal: bool = False + last_logged_bucket: int = -1 + last_logged_message: str | None = None + last_logged_source: str | None = None + + +class InitializationManager(QObject): + """ + Coordinate runtime repair, resource preparation, and structured progress events. + + 负责协调运行时修复、资源准备以及结构化进度事件的初始化管理器。 + """ + + stage_changed = Signal(str, str) + progress_event = Signal(object) + progress_changed = Signal(int, str, int, int) + item_status_changed = Signal(str, str, str) + finished = Signal(bool, object) + + def __init__(self, parent=None): + """ + 初始化初始化管理器。 + + Initialize the initialization manager. + + 参数 Parameters: + parent: 父 QObject 对象 + """ + super().__init__(parent) + self.config = get_advanced_config() + self._thread: Optional[threading.Thread] = None + self._last_options: Optional[dict] = None + self._last_mode: str = "init" + self._project_root = self._resolve_project_root() + self._runtime_dir = self.resolve_runtime_dir( + self.config.runtime_install_location_preference + ) + self._source_map: Dict[str, str] = {} + self._cancel_requested = threading.Event() + self._active_process: Optional[subprocess.Popen[str]] = None + self._resource_progress: dict[str, ResourceProgressState] = {} + self._resource_progress_item_count = 0 + + self._ensure_hf_endpoint_configured() + logging.info("初始化管理器已创建,项目根目录: %s", self._project_root) + + def _resolve_project_root(self) -> Path: + """ + 解析项目根目录。 + + Resolve project root directory. + + 返回 Returns: + Path: 项目根目录路径 + """ + if getattr(sys, "frozen", False) and sys.platform == "win32": + return get_app_internal_dir() + return Path(__file__).resolve().parent.parent + + def _ensure_hf_endpoint_configured(self) -> None: + """ + 确保 Hugging Face 端点环境变量已正确设置。 + + Ensure Hugging Face endpoint environment variables are properly configured. + """ + hf_mirror_endpoint = "https://hf-mirror.com" + + if ( + "HF_ENDPOINT" not in os.environ + or os.environ["HF_ENDPOINT"] != hf_mirror_endpoint + ): + os.environ["HF_ENDPOINT"] = hf_mirror_endpoint + logging.info("已设置 HF_ENDPOINT = %s", hf_mirror_endpoint) + + env_vars = { + "HF_HUB_DISABLE_TELEMETRY": "1", + "HF_HUB_DISABLE_XET": "1", + "DO_NOT_TRACK": "1", + } + + for key, value in env_vars.items(): + if key not in os.environ or os.environ[key] != value: + os.environ[key] = value + logging.debug("已设置 %s = %s", key, value) + + def _resolve_runtime_requirements_path(self, runtime_variant: str) -> Path: + """Resolve runtime requirements file path for backward compatibility.""" + requirements = get_runtime_requirements(runtime_variant) # pyright: ignore[reportArgumentType] + requirements_content = requirements.to_requirements_string( + include_indexes=False, + package_urls=self._selected_torch_package_urls(runtime_variant), + ) + + temp_file = tempfile.NamedTemporaryFile( + mode="w", + suffix=".txt", + prefix=f"requirements_{runtime_variant}_", + delete=False, + encoding="utf-8", + ) + try: + temp_file.write(requirements_content) + temp_file.close() + return Path(temp_file.name) + except Exception: + temp_file.close() + Path(temp_file.name).unlink(missing_ok=True) + raise + + def _runtime_requirements(self, runtime_variant: str) -> RuntimeRequirements: + """ + Return the unified runtime requirement definition for one variant. + + 返回指定运行时变体的统一依赖定义。 + """ + return get_runtime_requirements(runtime_variant) # pyright: ignore[reportArgumentType] + + def _selected_torch_package_urls(self, runtime_variant: str) -> dict[str, str]: + """ + Build direct wheel references for Torch packages on Windows runtime installs. + + 为 Windows 运行时安装构建 Torch 系列包的直链引用。 + """ + if runtime_variant not in ("cpu", "cuda"): + return {} + primary_source = self._source_map.get("torch_primary", "").strip() + if not primary_source or sys.platform != "win32": + return {} + + requirements = self._runtime_requirements(runtime_variant) + python_tag = f"cp{sys.version_info.major}{sys.version_info.minor}" + abi_tag = python_tag + platform_tag = "win_amd64" + source_base = primary_source.rstrip("/") + + package_versions = { + "torch": requirements.torch_version, + "torchvision": requirements.torchvision_version, + "torchaudio": requirements.torchaudio_version, + } + selected_urls: dict[str, str] = {} + for package_name, version in package_versions.items(): + normalized_version = (version or "").strip() + if not normalized_version: + continue + filename = ( + f"{package_name}-{normalized_version}-{python_tag}-{abi_tag}-{platform_tag}.whl" + ) + quoted_filename = urllib.parse.quote(filename) + selected_urls[package_name] = ( + f"{package_name} @ {source_base}/{quoted_filename}" + ) + return selected_urls + + @staticmethod + def _torch_source_candidates(runtime_variant: str) -> list[dict[str, str]]: + """ + Build Torch wheel source candidates from the shared runtime requirements. + + 基于统一运行时依赖定义构建 Torch wheel 源候选列表。 + """ + requirements = get_runtime_requirements(runtime_variant) # pyright: ignore[reportArgumentType] + candidates: list[dict[str, str]] = [] + for index, url in enumerate(requirements.extra_index_urls): + lowered = url.lower() + if "mirror" in lowered or "nju" in lowered: + name = f"mirror-{index}" + elif "download.pytorch.org" in lowered: + name = f"official-{index}" + else: + name = f"torch-{index}" + candidates.append({"name": name, "url": url}) + return candidates + + @staticmethod + def _normalize_features(selected_features: Optional[Iterable[str]]) -> list[str]: + features = list(selected_features or FULL_FEATURE_SET) + if "core_detection" not in features: + features.insert(0, "core_detection") + return features + + def _save_config(self, **updates) -> None: + setters = { + "initialization_completed": self.config.set_initialization_completed, + "initialization_in_progress": self.config.set_initialization_in_progress, + "selected_runtime_variant": self.config.set_selected_runtime_variant, + "detected_cuda_capable": self.config.set_detected_cuda_capable, + "runtime_install_location_preference": self.config.set_runtime_install_location_preference, + "resolved_runtime_dir": self.config.set_resolved_runtime_dir, + "enabled_feature_set": self.config.set_enabled_feature_set, + "downloaded_resources": self.config.set_downloaded_resources, + "resolved_source_map": self.config.set_resolved_source_map, + "last_init_error": self.config.set_last_init_error, + "last_init_exit_reason": self.config.set_last_init_exit_reason, + "last_init_mode": self.config.set_last_init_mode, + "is_first_run": self.config.set_is_first_run, + } + for key, value in updates.items(): + setter = setters.get(key) + if setter is not None: + setter(value) + self.config.save() + + def _emit_item_status(self, resource_id: str, status: str, detail: str) -> None: + self.item_status_changed.emit(resource_id, status, detail) + + def _emit_progress_event(self, event: InitializationProgressEvent) -> None: + """ + Emit the new structured progress event and the deprecated legacy signal. + + 发出新的结构化进度事件,并兼容发出旧版信号。 + """ + self.progress_event.emit(event) + ratio = event.normalized_ratio() + if ratio is None: + return + self.progress_changed.emit( + int(round(ratio * 100.0)), + event.message, + event.item_index or 0, + event.item_count or 0, + ) + + def _emit_phase_completion( + self, + progress_kind: str, + message: str, + *, + bytes_done: int | None = None, + bytes_total: int | None = None, + ) -> None: + """ + Emit an explicit terminal progress event for one visual phase. + + 为单个视觉阶段发出显式终态进度事件。 + """ + stage = ( + STAGE_PREPARING_RUNTIME + if progress_kind == PROGRESS_KIND_RUNTIME + else STAGE_DOWNLOADING + ) + self._emit_progress_event( + InitializationProgressEvent( + stage=stage, + progress_kind=progress_kind, + message=message, + ratio=1.0, + bytes_done=bytes_done, + bytes_total=bytes_total, + is_terminal=True, + ) + ) + + def _installation_root(self) -> Path: + if getattr(sys, "frozen", False): + executable = Path(sys.executable).resolve() + if sys.platform == "darwin" and executable.parent.name == "MacOS": + return executable.parents[2] + return executable.parent + return self._project_root + + def _runtime_install_locations(self) -> dict[str, Path]: + install_runtime_dir = self._installation_root() / "runtime_env" + if self._requires_install_local_runtime(): + install_runtime_dir = get_app_internal_dir() / "runtime_env" + return { + "default": get_app_config_dir() / "runtime_env", + "install": install_runtime_dir, + } + + def _requires_install_local_runtime(self) -> bool: + return getattr(sys, "frozen", False) and sys.platform == "win32" + + def _uses_bundled_runtime(self) -> bool: + return getattr(sys, "frozen", False) and sys.platform == "darwin" + + @staticmethod + def _existing_probe_path(path: Path) -> Path: + current = path + while not current.exists() and current != current.parent: + current = current.parent + return current + + def _free_bytes_for_path(self, path: Path) -> Optional[int]: + probe_path = self._existing_probe_path(path) + try: + return shutil.disk_usage(probe_path).free + except Exception: + return None + + def _writable_probe_dir(self, path: Path) -> Path: + probe_root = path if path.exists() else path.parent + probe_dir = self._existing_probe_path(probe_root) + return probe_dir if probe_dir.is_dir() else probe_dir.parent + + def _is_runtime_dir_writable(self, path: Path) -> bool: + try: + probe_dir = self._writable_probe_dir(path) + with tempfile.TemporaryDirectory(dir=probe_dir, prefix="sp_runtime_probe_"): + pass + return True + except Exception: + return False + + def get_runtime_install_location_options(self) -> list[RuntimeInstallLocation]: + if self._requires_install_local_runtime(): + install_dir = self._runtime_install_locations()["install"] + return [ + RuntimeInstallLocation( + key="install", + runtime_dir=install_dir, + free_bytes=self._free_bytes_for_path(install_dir), + writable=self._is_runtime_dir_writable(install_dir), + ) + ] + + options = [] + for key, runtime_dir in self._runtime_install_locations().items(): + options.append( + RuntimeInstallLocation( + key=key, + runtime_dir=runtime_dir, + free_bytes=self._free_bytes_for_path(runtime_dir), + writable=self._is_runtime_dir_writable(runtime_dir), + ) + ) + return options + + def choose_runtime_install_location( + self, preferred_key: Optional[str] = None + ) -> RuntimeInstallLocation: + if self._requires_install_local_runtime(): + install_dir = self._runtime_install_locations()["install"] + return RuntimeInstallLocation( + "install", + install_dir, + self._free_bytes_for_path(install_dir), + self._is_runtime_dir_writable(install_dir), + ) + + options = [ + item + for item in self.get_runtime_install_location_options() + if item.writable + ] + if not options: + default_dir = self._runtime_install_locations()["default"] + return RuntimeInstallLocation("default", default_dir, None, True) + + by_key = {item.key: item for item in options} + if preferred_key in by_key: + return by_key[preferred_key] + + comparable = [item for item in options if item.free_bytes is not None] + if comparable: + return max( + comparable, + key=lambda item: (item.free_bytes or -1, item.key == "default"), + ) + return by_key.get("default", options[0]) + + def resolve_runtime_dir(self, preferred_key: Optional[str] = None) -> Path: + return self.choose_runtime_install_location(preferred_key).runtime_dir + + def runtime_display_dir(self, preferred_key: Optional[str] = None) -> Path: + if self._uses_bundled_runtime(): + return get_bundled_resource_dir() + if self._requires_install_local_runtime(): + return get_app_internal_dir() / "runtime_env" + return self.resolve_runtime_dir(preferred_key) + + def start(self, options: dict, mode: str = "init") -> None: + normalized_options = dict(options) + normalized_options["features"] = self._normalize_features( + normalized_options.get("features") + ) + normalized_options["runtime_install_location"] = ( + self.choose_runtime_install_location( + normalized_options.get("runtime_install_location") + or self.config.runtime_install_location_preference + ).key + ) + self._last_options = normalized_options + self._last_mode = mode + if self._thread and self._thread.is_alive(): + return + self._cancel_requested.clear() + self._thread = threading.Thread( + target=self._run, args=(dict(normalized_options), mode), daemon=True + ) + self._thread.start() + + def start_initialization(self, options: dict) -> None: + self.start(options, mode="init") + + def start_repair(self, options: dict) -> None: + self.start(options, mode="repair") + + def retry_failed(self) -> None: + if self._last_options is not None: + self.start(self._last_options, mode=self._last_mode) + + def resume_pending(self) -> None: + if self._last_options is not None: + self.start(self._last_options, mode=self._last_mode) + + def cancel(self) -> None: + self._cancel_requested.set() + process = self._active_process + if process is not None and process.poll() is None: + try: + process.terminate() + except Exception: + pass + + def is_ready_for_main_ui( + self, selected_features: Optional[Iterable[str]] = None + ) -> bool: + return self._has_runtime_available() and self._resources_available( + selected_features + ) + + def needs_initialization( + self, selected_features: Optional[Iterable[str]] = None + ) -> bool: + return not self.is_ready_for_main_ui(selected_features) + + def check_runtime_health(self) -> bool: + """ + 检查运行时健康状态。 + + Check runtime health status. + + 返回 Returns: + bool: 运行时是否健康 + """ + runtime_available = self._has_runtime_available() + import_ok = self._runtime_import_ok() + + logging.info("运行时健康检查: 可用=%s, 导入=%s", runtime_available, import_ok) + + return runtime_available and import_ok + + def check_resource_health( + self, selected_features: Optional[Iterable[str]] + ) -> Dict[str, bool]: + """ + 检查资源健康状态。 + + Check resource health status. + + 参数 Parameters: + selected_features (Optional[Iterable[str]]): 选定的功能特性 + + 返回 Returns: + Dict[str, bool]: 资源 ID 到健康状态的映射 + """ + plan = resolve_download_plan(self._normalize_features(selected_features)) + health_status = { + item["resource_id"]: self._resource_item_available(item) for item in plan + } + + healthy_count = sum(1 for status in health_status.values() if status) + logging.info("资源健康检查: %d/%d 资源可用", healthy_count, len(health_status)) + + return health_status + + def repair_runtime_if_needed(self, runtime_variant: str) -> bool: + """ + 如果需要,修复运行时环境。 + + Repair runtime environment if needed. + + 参数 Parameters: + runtime_variant (str): 运行时变体(cpu/cuda/mac) + + 返回 Returns: + bool: 是否执行了修复 + """ + if self.check_runtime_health(): + self._emit_item_status("runtime", "done", "Runtime already healthy") + self._emit_phase_completion( + PROGRESS_KIND_RUNTIME, + f"{runtime_variant} runtime already available", + ) + logging.info("运行时环境健康,无需修复") + return False + + if self._uses_bundled_runtime(): + raise RuntimeError( + "Bundled macOS Lite Torch runtime is unavailable; rebuild the app bundle." + ) + + logging.info("运行时环境需要修复,开始准备 %s 运行时...", runtime_variant) + self._emit_stage( + STAGE_PREPARING_RUNTIME, f"Preparing {runtime_variant} runtime..." + ) + self._cleanup_partial_runtime() + self._purge_pip_cache_if_needed() + + start_time = time.perf_counter() + try: + self._prepare_runtime(runtime_variant) + self._emit_phase_completion( + PROGRESS_KIND_RUNTIME, + f"{runtime_variant} runtime ready", + ) + elapsed = time.perf_counter() - start_time + logging.info("运行时环境修复完成,耗时 %.2f 秒", elapsed) + return True + except Exception as exc: + elapsed = time.perf_counter() - start_time + logging.error("运行时环境修复失败,耗时 %.2f 秒: %s", elapsed, exc) + raise + + def repair_resources_if_needed( + self, selected_features: Optional[Iterable[str]] + ) -> bool: + """ + 如果需要,修复资源文件。 + + Repair resource files if needed. + + 参数 Parameters: + selected_features (Optional[Iterable[str]]): 选定的功能特性 + + 返回 Returns: + bool: 是否执行了修复 + """ + plan = resolve_download_plan(self._normalize_features(selected_features)) + pending = [item for item in plan if not self._resource_item_available(item)] + total_items = max(1, len(pending)) + + if not pending: + self._emit_item_status("resources", "done", "Resources already healthy") + logging.info("所有资源已就绪,无需修复") + return False + + self._resource_progress = { + item["resource_id"]: ResourceProgressState() for item in pending + } + self._resource_progress_item_count = total_items + logging.info("需要修复 %d 个资源文件", len(pending)) + self._emit_stage(STAGE_DOWNLOADING, "Downloading required resources...") + + start_time = time.perf_counter() + success_count = 0 + + for index, resource in enumerate(pending, start=1): + label = resource["filename"] + resource_id = resource["resource_id"] + + logging.info("开始下载资源 [%d/%d]: %s", index, total_items, label) + + self._emit_item_status(resource_id, "running", f"Preparing {label}") + + try: + download_resource( + resource, + project_root=self._project_root, + progress_cb=self._resource_progress_cb(index, total_items, resource_id), + ) + success_count += 1 + self._emit_item_status(resource_id, "done", f"{label} ready") + logging.info("资源 [%s] 下载成功", resource_id) + except Exception as exc: + logging.error("资源 [%s] 下载失败: %s", resource_id, exc) + self._emit_item_status(resource_id, "error", f"{label} failed: {exc}") + raise + + elapsed = time.perf_counter() - start_time + logging.info( + "资源修复完成: %d/%d 成功,总耗时 %.2f 秒", + success_count, + total_items, + elapsed, + ) + self._emit_phase_completion( + PROGRESS_KIND_DOWNLOAD, + "All required resources ready", + ) + + return True + + def detect_runtime_selection( + self, preferred_variant: str = "auto" + ) -> RuntimeSelection: + if sys.platform == "darwin": + if preferred_variant in ("cpu", "mac"): + return RuntimeSelection("mac", False, "macOS runtime") + return RuntimeSelection("mac", False, "macOS runtime") + + detected_cuda = self._detect_cuda_capable() + if preferred_variant == "cuda" and detected_cuda: + return RuntimeSelection("cuda", True, "user requested CUDA") + if preferred_variant == "cuda" and not detected_cuda: + return RuntimeSelection( + "cpu", False, "CUDA unavailable, falling back to CPU" + ) + if preferred_variant == "cpu": + return RuntimeSelection("cpu", detected_cuda, "user requested CPU") + if detected_cuda: + return RuntimeSelection("cuda", True, "detected NVIDIA/CUDA support") + return RuntimeSelection("cpu", False, "default CPU runtime") + + def _run(self, options: dict, mode: str) -> None: + try: + self._raise_if_cancelled() + selected_features = self._normalize_features(options.get("features")) + self._last_mode = mode + runtime_location = self.choose_runtime_install_location( + options.get("runtime_install_location") + or self.config.runtime_install_location_preference + ) + self._runtime_dir = runtime_location.runtime_dir + self._save_config( + initialization_in_progress=(mode == "init"), + last_init_error=None, + last_init_exit_reason="none", + last_init_mode=mode, + runtime_install_location_preference=runtime_location.key, + resolved_runtime_dir=str(runtime_location.runtime_dir), + ) + + runtime_choice = self.detect_runtime_selection( + options.get("runtime_variant", "auto") + ) + if mode == "init": + self._save_config( + selected_runtime_variant=runtime_choice.variant, + detected_cuda_capable=runtime_choice.detected_cuda_capable, + runtime_install_location_preference=runtime_location.key, + resolved_runtime_dir=str(runtime_location.runtime_dir), + enabled_feature_set=selected_features, + ) + + self._raise_if_cancelled() + self._emit_stage(STAGE_PROBING, "Probing download sources...") + self._source_map = self._resolve_best_sources(runtime_choice.variant) + self._emit_item_status( + "source_probe", "done", f"PyPI -> {self._source_map['pypi_primary']}" + ) + if self._source_map.get("torch_primary"): + self._emit_item_status( + "source_probe", "done", f"Torch -> {self._source_map['torch_primary']}" + ) + else: + self._emit_item_status( + "source_probe", "done", "Torch -> bundled runtime" + ) + self._save_config(resolved_source_map=self._source_map) + + if options.get("auto_update_enabled", True): + self._emit_stage(STAGE_CHECKING_UPDATES, "Checking updates...") + self._check_updates_if_enabled() + else: + try: + from tools.patch_manager import safe_clear_patch + + cleared, clear_message = safe_clear_patch() + clear_status = "done" if cleared else "warning" + self._emit_item_status("updates", clear_status, clear_message) + except Exception as exc: + self._emit_item_status( + "updates", "warning", f"Patch cleanup skipped: {exc}" + ) + self._emit_item_status( + "updates", "skipped", "Automatic updates disabled by user" + ) + + self._raise_if_cancelled() + self.repair_runtime_if_needed(runtime_choice.variant) + self._raise_if_cancelled() + self.repair_resources_if_needed(selected_features) + + self._emit_stage(STAGE_VERIFYING, "Verifying resources...") + if not self.is_ready_for_main_ui(selected_features): + raise RuntimeError( + "Initialization completed with missing runtime or resources" + ) + + success_updates: dict[str, object] = { + "initialization_in_progress": False, + "last_init_mode": "none", + } + if mode == "init": + success_updates.update( + initialization_completed=True, + last_init_exit_reason="none", + is_first_run=False, + downloaded_resources={ + item["resource_id"]: True + for item in resolve_download_plan(selected_features) + }, + ) + self._save_config(**success_updates) + final_message = ( + "Initialization completed" + if mode == "init" + else "Environment repair completed" + ) + self._emit_stage(STAGE_READY, final_message) + self.finished.emit( + True, + { + "runtime_variant": runtime_choice.variant, + "source_map": self._source_map, + "mode": mode, + }, + ) + except InitializationInterrupted: + self._cleanup_partial_runtime() + self._purge_pip_cache_if_needed() + self._save_config( + initialization_in_progress=False, + last_init_error=None, + last_init_exit_reason="interrupted", + last_init_mode=mode, + ) + self.finished.emit(False, {"interrupted": True, "mode": mode}) + except Exception as exc: + self._save_config( + initialization_in_progress=False, + last_init_error=str(exc), + last_init_exit_reason="failed", + last_init_mode=mode, + ) + self._emit_stage(STAGE_FAILED, str(exc)) + self.finished.emit(False, {"error": str(exc), "mode": mode}) + + def _resource_progress_cb( + self, + item_index: int, + total_items: int, + fallback_resource_id: str, + ): + """ + Adapt per-resource download events into one aggregated download stream. + + 将单个资源的下载事件适配为统一的聚合下载进度流。 + """ + + def _callback(event: InitializationProgressEvent) -> None: + resource_id = event.resource_id or fallback_resource_id + enriched_event = InitializationProgressEvent( + stage=event.stage, + progress_kind=event.progress_kind, + message=event.message, + ratio=event.ratio, + bytes_done=event.bytes_done, + bytes_total=event.bytes_total, + item_index=item_index - 1, + item_count=total_items, + resource_id=resource_id, + source=event.source, + is_terminal=event.is_terminal, + ) + aggregate_event = self._update_resource_aggregate(enriched_event) + if self._should_emit_resource_log(enriched_event): + self._emit_item_status(resource_id, "progress", enriched_event.message) + self._emit_progress_event(aggregate_event) + + return _callback + + def _update_resource_aggregate( + self, + event: InitializationProgressEvent, + ) -> InitializationProgressEvent: + """ + Fold one resource event into the cross-resource aggregate download ratio. + + 将单个资源事件折叠到跨资源的聚合下载进度中。 + """ + resource_id = event.resource_id or f"resource-{len(self._resource_progress)}" + ratio = event.normalized_ratio() + bytes_total = event.bytes_total if event.bytes_total and event.bytes_total > 0 else None + + state = self._resource_progress.get(resource_id) + if state is None: + state = ResourceProgressState() + self._resource_progress[resource_id] = state + + if ratio is not None: + state.ratio = max(state.ratio, ratio) + elif event.is_terminal: + state.ratio = 1.0 + + if event.bytes_done is not None: + state.bytes_done = max(state.bytes_done or 0, event.bytes_done) + if bytes_total is not None: + state.bytes_total = max(state.bytes_total or 0, bytes_total) + state.is_terminal = state.is_terminal or event.is_terminal or state.ratio >= 1.0 + + total_items = max(1, self._resource_progress_item_count or len(self._resource_progress)) + completed_ratio_sum = sum(item.ratio for item in self._resource_progress.values()) + overall_ratio = completed_ratio_sum / total_items + + known_total = 0 + known_done = 0 + all_have_bytes = bool(self._resource_progress) + for item in self._resource_progress.values(): + if item.bytes_total is None: + all_have_bytes = False + continue + known_total += item.bytes_total + known_done += min(item.bytes_done or 0, item.bytes_total) + + aggregate_terminal = ( + len(self._resource_progress) >= self._resource_progress_item_count + and all(item.is_terminal for item in self._resource_progress.values()) + ) + return InitializationProgressEvent( + stage=STAGE_DOWNLOADING, + progress_kind=event.progress_kind, + message=event.message, + ratio=min(1.0, max(0.0, overall_ratio)), + bytes_done=known_done if all_have_bytes else None, + bytes_total=known_total if all_have_bytes else None, + item_index=event.item_index, + item_count=event.item_count, + resource_id=event.resource_id, + source=event.source, + is_terminal=aggregate_terminal, + ) + + def _should_emit_resource_log(self, event: InitializationProgressEvent) -> bool: + """ + Throttle noisy per-byte download events into human-readable milestone logs. + + 将高频字节级下载事件节流为更适合阅读的里程碑日志。 + """ + resource_id = event.resource_id + if not resource_id: + return False + + state = self._resource_progress.get(resource_id) + if state is None: + return True + + ratio = event.normalized_ratio() + message = event.message + source_changed = bool(event.source and event.source != state.last_logged_source) + terminal_message = event.is_terminal or any( + token in message.lower() + for token in ("failed", "validated", "downloaded", "already present", "copied from local fallback") + ) + + should_log = False + if state.last_logged_message is None: + should_log = True + elif terminal_message or source_changed: + should_log = True + elif ratio is not None: + bucket = min(10, int(ratio * 10.0)) + if bucket > state.last_logged_bucket: + should_log = True + elif message != state.last_logged_message: + should_log = True + + if should_log: + if ratio is not None: + state.last_logged_bucket = max( + state.last_logged_bucket, + min(10, int(ratio * 10.0)), + ) + state.last_logged_message = message + state.last_logged_source = event.source or state.last_logged_source + + return should_log + + def _emit_stage(self, stage: str, message: str) -> None: + self.stage_changed.emit(stage, message) + + def _check_updates_if_enabled(self) -> None: + try: + from tools.update_checker import UpdateChecker + + self._raise_if_cancelled() + checker = UpdateChecker() + checker.check_for_updates() + self._emit_item_status("updates", "done", "Update probe finished") + except Exception as exc: + self._emit_item_status("updates", "warning", f"Update probe skipped: {exc}") + + def _resolve_best_sources(self, runtime_variant: str) -> Dict[str, str]: + pypi_results = probe_sources("pypi", PIPY_SOURCES) + best_pypi = self._pick_preferred_source(pypi_results) + + pypi_primary = best_pypi.url if best_pypi else PIPY_SOURCES[0]["url"] + pypi_fallback = self._resolve_fallback_url(pypi_results, pypi_primary) + + torch_primary = "" + torch_fallback = "" + torch_sources = self._torch_source_candidates(runtime_variant) + if torch_sources: + torch_results = probe_sources(f"torch-{runtime_variant}", torch_sources) + best_torch = self._pick_preferred_source(torch_results) + torch_primary = best_torch.url if best_torch else torch_sources[0]["url"] + torch_fallback = self._resolve_fallback_url(torch_results, torch_primary) + + selected = { + "pypi_primary": pypi_primary, + "pypi_fallback": pypi_fallback, + "torch_primary": torch_primary, + "torch_fallback": torch_fallback, + } + return selected + + @staticmethod + def _pick_preferred_source(results): + successful = [item for item in results if item.ok] + if not successful: + return None + + non_official = [ + item for item in successful if "official" not in item.name.lower() + ] + if non_official: + return pick_best_source(non_official) + return pick_best_source(successful) + + @staticmethod + def _resolve_fallback_url(results, primary_url: str) -> str: + successful = [item for item in results if item.ok and item.url != primary_url] + if not successful: + return primary_url + + non_official = [ + item for item in successful if "official" not in item.name.lower() + ] + if non_official: + fallback = pick_best_source(non_official) + return fallback.url if fallback else primary_url + + fallback = pick_best_source(successful) + return fallback.url if fallback else primary_url + + def _prepare_runtime(self, runtime_variant: str) -> None: + """ + Create or repair the runtime environment for the selected variant. + + 为当前选择的运行时变体创建或修复运行环境。 + """ + if self._uses_bundled_runtime(): + raise RuntimeError( + "Bundled macOS Lite Torch runtime is unavailable; runtime installation is disabled." + ) + + if self._use_packaged_runtime_bootstrap(): + self._prepare_runtime_with_packaged_bootstrap(runtime_variant) + return + + python_cmd = self._resolve_python_command() + if not self._runtime_dir.exists(): + self._run_subprocess( + [*python_cmd, "-m", "venv", str(self._runtime_dir)], + "Create runtime venv", + ) + + pip_executable = ( + self._runtime_dir + / ("Scripts" if os.name == "nt" else "bin") + / ("pip.exe" if os.name == "nt" else "pip") + ) + requirements_file = self._resolve_runtime_requirements_path(runtime_variant) + runtime_requirements = self._runtime_requirements(runtime_variant) + install_cmd = [ + str(pip_executable), + "install", + "--no-cache-dir", + "--progress-bar", + "raw", + "-r", + str(requirements_file), + "-i", + self._source_map["pypi_primary"], + "--extra-index-url", + self._source_map["pypi_fallback"], + ] + if runtime_requirements.extra_index_urls: + install_cmd.extend(["--extra-index-url", self._source_map["torch_primary"]]) + if ( + self._source_map["torch_fallback"] + and self._source_map["torch_fallback"] != self._source_map["torch_primary"] + ): + install_cmd.extend( + ["--extra-index-url", self._source_map["torch_fallback"]] + ) + self._run_subprocess( + install_cmd, + f"Install {runtime_variant} runtime", + progress_stage=STAGE_PREPARING_RUNTIME, + progress_kind=PROGRESS_KIND_RUNTIME, + ) + self._inject_runtime_site_packages() + self._verify_runtime_import(runtime_variant) + + def _use_packaged_runtime_bootstrap(self) -> bool: + return getattr(sys, "frozen", False) and sys.platform == "win32" + + def _runtime_site_packages_candidates(self) -> list[Path]: + candidates: list[Path] = [ + self._runtime_dir / "site-packages", + self._runtime_dir / "Lib" / "site-packages", + ] + lib_dir = self._runtime_dir / "lib" + if lib_dir.exists(): + candidates.extend(sorted(lib_dir.glob("python*/site-packages"))) + version_tag = f"python{sys.version_info.major}.{sys.version_info.minor}" + candidates.append(self._runtime_dir / "lib" / version_tag / "site-packages") + + unique_candidates: list[Path] = [] + seen_paths: set[Path] = set() + for candidate in candidates: + if candidate in seen_paths: + continue + seen_paths.add(candidate) + unique_candidates.append(candidate) + return unique_candidates + + def _runtime_python_executable(self) -> Path: + if os.name == "nt": + return self._runtime_dir / "Scripts" / "python.exe" + return self._runtime_dir / "bin" / "python3" + + def _prepare_runtime_with_packaged_bootstrap(self, runtime_variant: str) -> None: + requirements_file = self._resolve_runtime_requirements_path(runtime_variant) + runtime_site_packages = self._runtime_dir / "site-packages" + runtime_site_packages.mkdir(parents=True, exist_ok=True) + runtime_requirements = self._runtime_requirements(runtime_variant) + + command = [ + str(Path(sys.executable).resolve()), + "--runtime-bootstrap", + "--runtime-dir", + str(self._runtime_dir), + "--requirements", + str(requirements_file), + "--index-url", + self._source_map["pypi_primary"], + "--extra-index-url", + self._source_map["pypi_fallback"], + ] + if runtime_requirements.extra_index_urls: + command.extend(["--extra-index-url", self._source_map["torch_primary"]]) + if ( + self._source_map["torch_fallback"] + and self._source_map["torch_fallback"] != self._source_map["torch_primary"] + ): + command.extend( + ["--extra-index-url", self._source_map["torch_fallback"]] + ) + + self._run_subprocess( + command, + f"Install {runtime_variant} runtime", + progress_stage=STAGE_PREPARING_RUNTIME, + progress_kind=PROGRESS_KIND_RUNTIME, + ) + self._inject_runtime_site_packages() + self._verify_runtime_import(runtime_variant) + + def _run_subprocess( + self, + command: list[str], + label: str, + *, + progress_stage: str | None = None, + progress_kind: str | None = None, + ) -> None: + """ + Run a subprocess while forwarding structured progress updates when possible. + + 运行子进程,并在可能时转发结构化进度更新。 + """ + popen_kwargs = { + "stdout": subprocess.PIPE, + "stderr": subprocess.STDOUT, + "text": True, + "encoding": "utf-8", + "errors": "replace", + } + if os.name == "nt": + popen_kwargs["creationflags"] = getattr(subprocess, "CREATE_NO_WINDOW", 0) + + process = subprocess.Popen(command, **popen_kwargs) + self._active_process = process + latest_progress_bytes: tuple[int, int | None] | None = None + try: + assert process.stdout is not None + for line in process.stdout: + self._raise_if_cancelled() + text = line.strip() + if text: + parsed = parse_pip_raw_progress_line(text) + if ( + parsed is not None + and progress_stage is not None + and progress_kind is not None + ): + current, total = parsed + latest_progress_bytes = (current, total if total > 0 else None) + self._emit_progress_event( + InitializationProgressEvent( + stage=progress_stage, + progress_kind=progress_kind, + message=f"{label}: raw progress {current}/{total}", + ratio=(current / total) if total > 0 else None, + bytes_done=current, + bytes_total=total if total > 0 else None, + ) + ) + continue + self.item_status_changed.emit("runtime", "progress", f"{label}: {text}") + return_code = process.wait() + if self._cancel_requested.is_set(): + raise InitializationInterrupted("Initialization interrupted by user") + if return_code != 0: + raise RuntimeError(f"{label} failed with exit code {return_code}") + if progress_stage is not None and progress_kind is not None: + done_bytes = latest_progress_bytes[0] if latest_progress_bytes else None + total_bytes = latest_progress_bytes[1] if latest_progress_bytes else None + self._emit_progress_event( + InitializationProgressEvent( + stage=progress_stage, + progress_kind=progress_kind, + message=f"{label} completed", + ratio=1.0, + bytes_done=total_bytes or done_bytes, + bytes_total=total_bytes, + is_terminal=True, + ) + ) + finally: + self._active_process = None + + def _resolve_python_command(self) -> list[str]: + if os.environ.get("VIRTUAL_ENV") and shutil.which("python"): + return [shutil.which("python") or "python"] + + candidates = [ + ( + [sys.executable] + if sys.executable and not getattr(sys, "frozen", False) + else None + ), + [shutil.which("python3")] if shutil.which("python3") else None, + [shutil.which("python")] if shutil.which("python") else None, + ["py", "-3"] if shutil.which("py") else None, + ] + for candidate in candidates: + if not candidate: + continue + try: + subprocess.run( + [*candidate, "-c", "import sys; print(sys.executable)"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=True, + text=True, + ) + return candidate + except Exception: + continue + raise RuntimeError("Unable to find a Python interpreter for runtime bootstrap") + + def _inject_runtime_site_packages(self) -> None: + if self._uses_bundled_runtime(): + return + importlib.invalidate_caches() + for candidate in self._runtime_site_packages_candidates(): + if candidate.exists(): + path = str(candidate) + if path not in sys.path: + sys.path.insert(0, path) + + def _verify_runtime_import(self, runtime_variant: str) -> None: + try: + if self._uses_bundled_runtime(): + torch_module = importlib.import_module("torch") + torch_version = getattr(torch_module, "__version__", "unknown") + self._emit_item_status( + "runtime", + "done", + f"Bundled Torch import OK: {torch_version} ({runtime_variant})", + ) + return + importlib.invalidate_caches() + self._inject_runtime_site_packages() + runtime_python = self._runtime_python_executable() + if runtime_python.exists(): + result = subprocess.run( + [ + str(runtime_python), + "-c", + ( + "import torch, sys; " + "print(torch.__version__); " + "print(sys.executable)" + ), + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + encoding="utf-8", + errors="replace", + check=True, + ) + runtime_lines = [line.strip() for line in result.stdout.splitlines() if line.strip()] + torch_version = runtime_lines[0] if runtime_lines else "unknown" + runtime_executable = runtime_lines[1] if len(runtime_lines) > 1 else str(runtime_python) + self._emit_item_status( + "runtime", + "done", + f"Torch import OK: {torch_version} ({runtime_variant}) via {runtime_executable}", + ) + return + torch_module = importlib.import_module("torch") + torch_version = getattr(torch_module, "__version__", "unknown") + self._emit_item_status( + "runtime", + "done", + f"Torch import OK: {torch_version} ({runtime_variant})", + ) + except Exception as exc: + raise RuntimeError( + f"Runtime installed but Torch import failed: {exc}" + ) from exc + + def _runtime_import_ok(self) -> bool: + try: + if self._uses_bundled_runtime(): + importlib.invalidate_caches() + importlib.import_module("torch") + return True + self._inject_runtime_site_packages() + importlib.invalidate_caches() + importlib.import_module("torch") + return True + except Exception: + return False + + def _raise_if_cancelled(self) -> None: + if self._cancel_requested.is_set(): + raise InitializationInterrupted("Initialization interrupted by user") + + def _cleanup_partial_runtime(self) -> None: + if self._uses_bundled_runtime(): + return + runtime_dir = self._runtime_dir + if not runtime_dir.exists(): + return + + removable_paths = [ + runtime_dir / "site-packages", + runtime_dir / "Lib" / "site-packages", + runtime_dir / "runtime_install_manifest.json", + ] + for candidate in removable_paths: + try: + if candidate.is_dir(): + shutil.rmtree(candidate, ignore_errors=True) + else: + candidate.unlink(missing_ok=True) + except Exception as exc: + logging.warning("清理运行时残留失败: %s (%s)", candidate, exc) + + def _pip_cache_roots(self) -> list[Path]: + roots: list[Path] = [] + if sys.platform == "win32": + local_app_data = os.environ.get("LOCALAPPDATA") + if local_app_data: + roots.append(Path(local_app_data) / "pip" / "Cache") + else: + roots.append(Path.home() / ".cache" / "pip") + return roots + + def _purge_pip_cache_if_needed(self) -> None: + for cache_root in self._pip_cache_roots(): + if not cache_root.exists(): + continue + for relative_name in ("http-v2", "http", "wheels", "selfcheck"): + candidate = cache_root / relative_name + try: + if candidate.is_dir(): + shutil.rmtree(candidate, ignore_errors=True) + else: + candidate.unlink(missing_ok=True) + except Exception as exc: + logging.warning("清理 pip 缓存失败: %s (%s)", candidate, exc) + + def _has_runtime_available(self) -> bool: + if self._uses_bundled_runtime(): + importlib.invalidate_caches() + return importlib.util.find_spec("torch") is not None + if importlib.util.find_spec("torch") is not None: + return True + self._inject_runtime_site_packages() + return importlib.util.find_spec("torch") is not None + + def _resources_available(self, selected_features: Optional[Iterable[str]]) -> bool: + features = self._normalize_features(selected_features) + plan = resolve_download_plan(features) + return all( + self._resource_item_available(item) + for item in plan + if item.get("required") or selected_features + ) + + def _resource_item_available(self, item: dict) -> bool: + path = ( + resolve_resource_destination_dir(self._project_root, item) + / item["filename"] + ) + return path.exists() + + def _detect_cuda_capable(self) -> bool: + if sys.platform != "win32": + return False + try: + result = subprocess.run( + ["nvidia-smi", "-L"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + timeout=4, + creationflags=getattr(subprocess, "CREATE_NO_WINDOW", 0), + ) + return result.returncode == 0 and bool(result.stdout.strip()) + except Exception: + return False diff --git a/core/initialization_progress.py b/core/initialization_progress.py new file mode 100644 index 0000000..8f24777 --- /dev/null +++ b/core/initialization_progress.py @@ -0,0 +1,504 @@ +# -*- coding: utf-8 -*- +""" +Initialization progress models for lightweight onboarding flows. + +This module centralizes structured progress events, stage-to-phase mapping, +and the deterministic progress animation policy shared by the welcome dialog +and the repair dialog. The model keeps UI timing logic out of Qt widgets so it +can be verified with small, repeatable tests. + +轻量化初始化流程的进度模型。 + +此模块集中定义结构化进度事件、阶段到动画阶段的映射关系,以及欢迎向导与修复对话框共享的确定性进度动画策略。 +这样可以把 UI 的时间推进逻辑从 Qt 控件中抽离出来,便于使用小型、可重复的测试进行验证。 +""" + +from __future__ import annotations + +import math +import re +from dataclasses import dataclass + + +PROGRESS_KIND_RUNTIME = "runtime_install" +PROGRESS_KIND_DOWNLOAD = "resource_download" + +STAGE_PROBING = "probing_sources" +STAGE_CHECKING_UPDATES = "checking_updates" +STAGE_PREPARING_RUNTIME = "preparing_runtime" +STAGE_DOWNLOADING = "downloading_resources" +STAGE_VERIFYING = "verifying" +STAGE_READY = "ready" +STAGE_FAILED = "failed" + +PIP_RAW_PROGRESS_PATTERN = re.compile(r"^Progress\s+(?P\d+)\s+of\s+(?P\d+)$") + + +@dataclass(frozen=True) +class InitializationProgressEvent: + """ + Structured progress payload emitted by initialization subsystems. + + 初始化子系统发出的结构化进度负载。 + + Attributes: + stage: Pipeline stage that owns this event. + progress_kind: Logical progress stream, such as runtime install or resource download. + message: Human-readable status text for logs or labels. + ratio: Normalized completion ratio in [0.0, 1.0] when known. + bytes_done: Downloaded or processed bytes when known. + bytes_total: Expected total bytes when known. + item_index: Zero-based item index in a multi-item batch. + item_count: Total item count in a multi-item batch. + resource_id: Resource identifier for per-item tracking. + source: Active mirror or backend source name. + is_terminal: Whether the current item or phase has completed. + """ + + stage: str + progress_kind: str + message: str + ratio: float | None = None + bytes_done: int | None = None + bytes_total: int | None = None + item_index: int | None = None + item_count: int | None = None + resource_id: str | None = None + source: str | None = None + is_terminal: bool = False + + def normalized_ratio(self) -> float | None: + """ + Return a clamped completion ratio when one can be inferred. + + 当可以推导出完成比时,返回一个夹紧后的进度比例。 + """ + if self.ratio is not None: + return max(0.0, min(1.0, float(self.ratio))) + if self.bytes_total and self.bytes_total > 0 and self.bytes_done is not None: + return max(0.0, min(1.0, float(self.bytes_done) / float(self.bytes_total))) + return None + + +@dataclass(frozen=True) +class ProgressSnapshot: + """ + Render-ready progress state returned by the animation model. + + 动画模型返回的可直接用于渲染的进度状态。 + """ + + display_percent: int + actual_percent: int + target_percent: int + display_value: float + actual_value: float + target_value: float + active_phase: str | None + is_finishing: bool + is_settled: bool + suggested_interval_ms: int + + +@dataclass(frozen=True) +class ProgressPhaseProfile: + """ + Visual progress allocation and timing policy for one logical phase. + + 单个逻辑阶段的视觉进度分配与时间策略。 + """ + + key: str + start_percent: float + end_percent: float + min_duration_seconds: float + max_duration_seconds: float + + @property + def span(self) -> float: + """Return the percentage span owned by this phase.""" + return self.end_percent - self.start_percent + + +def phase_from_stage(stage: str) -> str | None: + """ + Map an initialization stage to the owning visual phase. + + 将初始化阶段映射到对应的视觉动画阶段。 + """ + if stage == STAGE_PREPARING_RUNTIME: + return PROGRESS_KIND_RUNTIME + if stage == STAGE_DOWNLOADING: + return PROGRESS_KIND_DOWNLOAD + return None + + +def parse_pip_raw_progress_line(line: str) -> tuple[int, int] | None: + """ + Parse `pip --progress-bar raw` lines into byte counters. + + 解析 `pip --progress-bar raw` 输出行,提取字节级进度。 + """ + match = PIP_RAW_PROGRESS_PATTERN.match(line.strip()) + if not match: + return None + return int(match.group("current")), int(match.group("total")) + + +class InitializationProgressModel: + """ + Deterministic progress animation model for long-running initialization tasks. + + 长耗时初始化任务的确定性进度动画模型。 + + The model uses fixed visual budgets per phase and lets time drive the bar. + Real progress is still tracked for logging and phase completion, but it no + longer directly drives the visible percentage. + + 该模型为每个阶段分配固定视觉预算,并以时间作为进度条主驱动。 + 真实进度仍用于日志与阶段完成判定,但不再直接驱动可见百分比。 + """ + + PHASES: dict[str, ProgressPhaseProfile] = { + PROGRESS_KIND_RUNTIME: ProgressPhaseProfile( + key=PROGRESS_KIND_RUNTIME, + start_percent=0.0, + end_percent=30.0, + min_duration_seconds=2.0, + max_duration_seconds=420.0, + ), + PROGRESS_KIND_DOWNLOAD: ProgressPhaseProfile( + key=PROGRESS_KIND_DOWNLOAD, + start_percent=30.0, + end_percent=99.0, + min_duration_seconds=250.0, + max_duration_seconds=420.0, + ), + } + + def __init__( + self, + *, + seed: int = 17, + min_visible_progress: int = 4, + suggested_interval_ms: int = 80, + ) -> None: + """ + Initialize the animation model with deterministic timing parameters. + + 使用确定性时间参数初始化动画模型。 + """ + self._seed = float(seed) + self._min_visible_progress = min_visible_progress + self._suggested_interval_ms = suggested_interval_ms + self.reset(0.0) + + def reset(self, now: float = 0.0) -> ProgressSnapshot: + """ + Reset the model to its initial idle state. + + 将模型重置为初始空闲状态。 + """ + self._display_percent = 0.0 + self._actual_percent = 0.0 + self._target_percent = 0.0 + self._active_phase: str | None = None + self._phase_started_at = now + self._phase_target_seconds = 0.0 + self._phase_complete_requested = False + self._phase_completion_started_at: float | None = None + self._phase_completion_from = 0.0 + self._phase_completion_duration = 0.0 + self._phase_completion_target = 0.0 + self._phase_was_observed = False + self._finish_started_at: float | None = None + self._finish_duration = 0.0 + self._finish_from = 0.0 + self._settled = False + return self.snapshot() + + def on_stage_changed(self, stage: str, now: float) -> ProgressSnapshot: + """ + React to a stage transition emitted by the initialization manager. + + 响应初始化管理器发出的阶段切换事件。 + """ + phase_key = phase_from_stage(stage) + if phase_key is not None: + self._activate_phase(phase_key, now) + elif stage == STAGE_VERIFYING: + self._actual_percent = max(self._actual_percent, 99.0) + self._target_percent = max(self._target_percent, 99.0) + return self.advance(now) + + def on_progress_event( + self, + event: InitializationProgressEvent, + now: float, + ) -> ProgressSnapshot: + """ + Update actual progress from a structured subsystem event. + + 使用结构化子系统事件更新真实进度。 + """ + phase_key = event.progress_kind if event.progress_kind in self.PHASES else phase_from_stage(event.stage) + if phase_key is not None: + self._activate_phase( + phase_key, + now, + bytes_total=event.bytes_total, + item_count=event.item_count, + event=event, + ) + phase = self.PHASES[phase_key] + ratio = event.normalized_ratio() + if ratio is not None: + actual = phase.start_percent + (phase.span * ratio) + self._actual_percent = max(self._actual_percent, actual) + if event.is_terminal: + self._actual_percent = max(self._actual_percent, phase.end_percent) + return self.advance(now) + + def on_finished(self, success: bool, now: float) -> ProgressSnapshot: + """ + Enter the success settle animation or stop immediately on failure. + + 成功时进入收尾动画,失败时立即停止动画推进。 + """ + if not success: + self._settled = True + return self.snapshot() + + self._actual_percent = max(self._actual_percent, 100.0) + self._target_percent = max(self._target_percent, 100.0) + self._finish_started_at = now + self._finish_from = max(self._display_percent, min(self._target_percent, 99.5)) + remaining = max(0.0, 100.0 - self._finish_from) + self._finish_duration = min(2.0, max(1.0, 1.0 + (remaining / 60.0))) + self._settled = False + return self.advance(now) + + def advance(self, now: float) -> ProgressSnapshot: + """ + Advance the animation to the specified monotonic time. + + 将动画推进到指定的单调时间点。 + """ + if self._finish_started_at is not None: + elapsed = max(0.0, now - self._finish_started_at) + ratio = min(1.0, elapsed / max(0.001, self._finish_duration)) + eased = 1.0 - (1.0 - ratio) ** 3 + self._display_percent = self._finish_from + ((100.0 - self._finish_from) * eased) + if ratio >= 1.0: + self._display_percent = 100.0 + self._settled = True + return self.snapshot() + + if self._active_phase is None: + return self.snapshot() + + if self._phase_completion_started_at is not None: + elapsed = max(0.0, now - self._phase_completion_started_at) + ratio = min(1.0, elapsed / max(0.001, self._phase_completion_duration)) + eased = 1.0 - (1.0 - ratio) ** 3 + target = self._phase_completion_target + self._display_percent = self._phase_completion_from + ( + (target - self._phase_completion_from) * eased + ) + self._target_percent = max(self._target_percent, self._display_percent) + if ratio >= 1.0: + self._display_percent = target + self._target_percent = max(self._target_percent, target) + self._phase_completion_started_at = None + return self.snapshot() + + time_target = self._compute_time_target(now) + desired = time_target + self._target_percent = max(self._target_percent, desired) + self._display_percent = max(self._display_percent, desired) + if self._phase_complete_requested and self._should_begin_phase_completion(now): + self._start_phase_completion(now) + return self.snapshot() + + def snapshot(self) -> ProgressSnapshot: + """ + Return an immutable render snapshot for the current model state. + + 返回当前模型状态的不可变渲染快照。 + """ + display_percent = int(round(self._display_percent)) + actual_percent = int(round(self._actual_percent)) + target_percent = int(round(self._target_percent)) + + if 0 < display_percent < self._min_visible_progress: + display_percent = self._min_visible_progress + + return ProgressSnapshot( + display_percent=max(0, min(100, display_percent)), + actual_percent=max(0, min(100, actual_percent)), + target_percent=max(0, min(100, target_percent)), + display_value=max(0.0, min(100.0, self._display_percent)), + actual_value=max(0.0, min(100.0, self._actual_percent)), + target_value=max(0.0, min(100.0, self._target_percent)), + active_phase=self._active_phase, + is_finishing=self._finish_started_at is not None and not self._settled, + is_settled=self._settled, + suggested_interval_ms=self._suggested_interval_ms, + ) + + def _activate_phase( + self, + phase_key: str, + now: float, + *, + bytes_total: int | None = None, + item_count: int | None = None, + event: InitializationProgressEvent | None = None, + ) -> None: + """ + Enter or retune a visual phase when new information arrives. + + 在新的信息到达时进入或重新调整视觉阶段。 + """ + profile = self.PHASES[phase_key] + is_new_phase = self._active_phase != phase_key + if is_new_phase: + self._active_phase = phase_key + self._phase_started_at = now + self._phase_complete_requested = False + self._phase_completion_started_at = None + self._phase_completion_from = max(self._display_percent, profile.start_percent) + self._phase_completion_duration = 0.0 + self._phase_completion_target = profile.end_percent + self._phase_was_observed = False + self._display_percent = max(self._display_percent, profile.start_percent) + self._actual_percent = max(self._actual_percent, profile.start_percent) + + self._phase_target_seconds = self._choose_target_duration( + profile, + bytes_total=bytes_total, + item_count=item_count, + event=event, + ) + if event is not None: + self._phase_was_observed = self._phase_was_observed or ( + event.bytes_total is not None + or event.bytes_done is not None + or event.normalized_ratio() is not None + ) + if event.is_terminal: + self._phase_complete_requested = True + + def _choose_target_duration( + self, + profile: ProgressPhaseProfile, + *, + bytes_total: int | None, + item_count: int | None, + event: InitializationProgressEvent | None, + ) -> float: + """ + Choose a phase duration inside the configured long-task window. + + 在配置好的长任务窗口内选择当前阶段的目标时长。 + """ + if profile.key == PROGRESS_KIND_RUNTIME: + if event is not None and event.is_terminal and not self._phase_was_observed: + return 4.0 + self._small_duration_jitter(profile.key) + base = 400.0 + self._large_duration_jitter(profile.key) + if bytes_total and bytes_total > 0: + base += min(18.0, bytes_total / float(1024 ** 3) * 6.0) + return min(profile.max_duration_seconds, max(profile.min_duration_seconds, base)) + + base = 400.0 + self._large_duration_jitter(profile.key) + if bytes_total and bytes_total > 0: + size_gib = bytes_total / float(1024 ** 3) + base += min(22.0, size_gib * 8.0) + if item_count and item_count > 1: + base += min(18.0, float(item_count - 1) * 3.0) + return min(profile.max_duration_seconds, max(profile.min_duration_seconds, base)) + + def _compute_time_target(self, now: float) -> float: + """ + Compute the monotonic time-driven target for the active phase. + + 计算当前活动阶段的单调时间驱动目标值。 + """ + assert self._active_phase is not None + profile = self.PHASES[self._active_phase] + elapsed = max(0.0, now - self._phase_started_at) + duration = max(1.0, self._phase_target_seconds) + progress_ratio = min(0.985, elapsed / duration) + jitter = self._bounded_jitter(profile, elapsed, progress_ratio) + raw_ratio = min(0.985, max(0.0, progress_ratio + jitter)) + target = profile.start_percent + (profile.span * raw_ratio) + return max(self._display_percent, target) + + def _bounded_jitter( + self, + profile: ProgressPhaseProfile, + elapsed: float, + progress_ratio: float, + ) -> float: + """ + Return a small deterministic oscillation for natural-looking speed changes. + + 返回一个小幅、确定性的振荡量,用于制造更自然的速度变化。 + """ + damping = max(0.18, 1.0 - progress_ratio) + amplitude = 0.014 * damping + if profile.key == PROGRESS_KIND_DOWNLOAD: + amplitude += 0.004 + + wave = ( + math.sin((elapsed * 0.10) + self._seed) + + 0.45 * math.sin((elapsed * 0.27) + (self._seed * 1.7)) + + 0.2 * math.sin((elapsed * 0.51) + (self._seed * 2.3)) + ) + jitter = (wave / 1.65) * amplitude + return min(0.018, max(-0.014, jitter)) + + def _small_duration_jitter(self, phase_key: str) -> float: + """ + Return a deterministic short-duration jitter in seconds. + + 返回一个确定性的短时长抖动(秒)。 + """ + offset = self._seed + (0.37 if phase_key == PROGRESS_KIND_RUNTIME else 0.91) + return math.sin(offset) * 1.8 + + def _large_duration_jitter(self, phase_key: str) -> float: + """ + Return a deterministic long-duration jitter in seconds. + + 返回一个确定性的长时长抖动(秒)。 + """ + offset = self._seed + (1.13 if phase_key == PROGRESS_KIND_RUNTIME else 2.41) + return math.sin(offset) * 18.0 + + def _should_begin_phase_completion(self, now: float) -> bool: + """ + Decide whether the phase may start its completion settle animation. + + 判断阶段是否可以开始进入完成收敛动画。 + """ + if self._active_phase is None: + return False + elapsed = max(0.0, now - self._phase_started_at) + return elapsed >= self._phase_target_seconds * 0.92 + + def _start_phase_completion(self, now: float) -> None: + """ + Start a short settle animation into the phase end percentage. + + 启动一个短暂的收敛动画,把进度推进到当前阶段的终点百分比。 + """ + if self._active_phase is None: + return + profile = self.PHASES[self._active_phase] + self._phase_completion_started_at = now + self._phase_completion_from = max(self._display_percent, profile.start_percent) + remaining = max(0.0, profile.end_percent - self._phase_completion_from) + self._phase_completion_duration = min(1.8, max(0.8, 0.6 + (remaining / 45.0))) + self._phase_completion_target = profile.end_percent + self._phase_complete_requested = False diff --git a/core/keypoint_detector.py b/core/keypoint_detector.py index 6be315d..d1f4eb9 100644 --- a/core/keypoint_detector.py +++ b/core/keypoint_detector.py @@ -1,6 +1,10 @@ """ -关键点检测器模块 -使用 CUB-200 训练的 ResNet50 模型检测鸟类关键点(左眼、右眼、喙) +关键点检测器模块。 +Keypoint detector module. + +使用 CUB-200 训练的 ResNet50 模型检测鸟类关键点(左眼、右眼、喙)。 +Uses a CUB-200-trained ResNet50 model to detect bird keypoints (left eye, +right eye, and beak). """ import os @@ -13,26 +17,30 @@ import numpy as np import cv2 from dataclasses import dataclass -from typing import Optional, Tuple -from config import get_best_device +from typing import Any, Optional, Tuple, cast +from config import ( + get_best_device, + get_install_scoped_resource_path, + get_packaged_model_relative_path, + get_runtime_app_root, + get_runtime_meipass, +) @dataclass class KeypointResult: """关键点检测结果""" - left_eye: Tuple[float, float] # (x, y) 归一化坐标 + left_eye: Tuple[float, float] right_eye: Tuple[float, float] beak: Tuple[float, float] - left_eye_vis: float # 可见性概率 0-1 + left_eye_vis: float right_eye_vis: float beak_vis: float - - # 派生属性 - both_eyes_hidden: bool # 双眼是否都不可见(保留兼容) - all_keypoints_hidden: bool # 所有关键点(双眼+鸟喙)都不可见 - best_eye_visibility: float # 双眼中较高的置信度 max(左眼, 右眼) - visible_eye: Optional[str] # 'left', 'right', 'both', None - head_sharpness: float # 头部区域锐度 + both_eyes_hidden: bool + all_keypoints_hidden: bool + best_eye_visibility: float + visible_eye: Optional[str] + head_sharpness: float class PartLocalizer(nn.Module): @@ -40,9 +48,10 @@ class PartLocalizer(nn.Module): def __init__(self, num_parts=3, hidden_dim=512, dropout=0.2): super().__init__() self.num_parts = num_parts - self.backbone = models.resnet50(weights=None) - in_features = self.backbone.fc.in_features - self.backbone.fc = nn.Identity() + self.backbone = cast(nn.Module, models.resnet50(weights=None)) + backbone_fc = cast(nn.Linear, getattr(self.backbone, "fc")) + in_features = backbone_fc.in_features + setattr(self.backbone, "fc", nn.Identity()) self.head = nn.Sequential( nn.Linear(in_features, hidden_dim), @@ -67,26 +76,35 @@ def forward(self, x): class KeypointDetector: """鸟类关键点检测器""" - # 默认配置 IMG_SIZE = 416 - VISIBILITY_THRESHOLD = 0.3 # 至少一个关键点可见性需≥0.3才不算"全部不可见" - RADIUS_MULTIPLIER = 1.2 # 有喙时的半径系数 - NO_BEAK_RADIUS_RATIO = 0.15 # 无喙时用检测框的15% + VISIBILITY_THRESHOLD = 0.3 + RADIUS_MULTIPLIER = 1.2 + NO_BEAK_RADIUS_RATIO = 0.15 @staticmethod def _get_default_model_path() -> str: - """获取默认模型路径(支持 PyInstaller 打包)""" + """ + 获取默认模型路径并兼容冻结环境。 + Resolve the default model path while remaining compatible with frozen builds. + """ import sys - if hasattr(sys, '_MEIPASS'): - # PyInstaller 打包后的路径 - return os.path.join(sys._MEIPASS, 'models', 'cub200_keypoint_resnet50_slim.pth') - else: - # 开发环境:优先使用 main.py 注入的真实 app 根目录(补丁覆盖层兼容) - project_root = getattr(sys, '_SUPERPICKY_APP_ROOT', - os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - return os.path.join(project_root, 'models', 'cub200_keypoint_resnet50_slim.pth') + if getattr(sys, 'frozen', False) and sys.platform == 'win32': + return str( + get_install_scoped_resource_path( + 'models/cub200_keypoint_resnet50_slim.pth', + packaged_relative_path=get_packaged_model_relative_path('models/cub200_keypoint_resnet50_slim.pth'), + ) + ) + meipass = get_runtime_meipass() + if meipass is not None: + return os.path.join(meipass, 'models', 'cub200_keypoint_resnet50_slim.pth') + + project_root = get_runtime_app_root() + if project_root is None: + project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + return os.path.join(project_root, 'models', 'cub200_keypoint_resnet50_slim.pth') - def __init__(self, model_path: str = None): + def __init__(self, model_path: Optional[str] = None): """ 初始化关键点检测器 @@ -94,9 +112,8 @@ def __init__(self, model_path: str = None): model_path: 模型文件路径,默认使用自动检测的路径 """ self.model_path = model_path or self._get_default_model_path() - # 使用统一的设备检测逻辑 - self.device = get_best_device() - self.model = None + self.device = torch.device(str(get_best_device())) + self.model: Optional[PartLocalizer] = None self.transform = transforms.Compose([ transforms.Resize((self.IMG_SIZE, self.IMG_SIZE)), transforms.ToTensor(), @@ -112,27 +129,33 @@ def load_model(self): raise FileNotFoundError(f"关键点模型不存在: {self.model_path}") self.model = PartLocalizer() - checkpoint = torch.load(self.model_path, map_location=self.device, weights_only=True) + checkpoint = torch.load( + self.model_path, + map_location=self.device, + weights_only=True, + ) if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint: self.model.load_state_dict(checkpoint['model_state_dict']) else: self.model.load_state_dict(checkpoint) - self.model.to(self.device) - - # V4.0.5: 启用 FP16 半精度推理,提速约 30-50% - # MPS 和 CUDA 都支持 FP16 + self.model.to(device=self.device) + if self.device.type in ('mps', 'cuda'): self.model = self.model.half() self._use_fp16 = True else: self._use_fp16 = False - + self.model.eval() - def detect(self, bird_crop: np.ndarray, box: Tuple[int, int, int, int] = None, - seg_mask: np.ndarray = None) -> Optional[KeypointResult]: + def detect( + self, + bird_crop: np.ndarray, + box: Optional[Tuple[int, int, int, int]] = None, + seg_mask: Optional[np.ndarray] = None, + ) -> Optional[KeypointResult]: """ 检测鸟类关键点并计算头部锐度 @@ -148,23 +171,24 @@ def detect(self, bird_crop: np.ndarray, box: Tuple[int, int, int, int] = None, if bird_crop is None or bird_crop.size == 0: return None - - # 转换为PIL并进行推理 + pil_crop = Image.fromarray(bird_crop) - tensor = self.transform(pil_crop).unsqueeze(0).to(self.device) - - # V4.0.5: 使用 FP16 和 inference_mode 优化推理 + transformed_tensor = cast(torch.Tensor, self.transform(pil_crop)) + tensor = transformed_tensor.unsqueeze(0).to(self.device) + if hasattr(self, '_use_fp16') and self._use_fp16: tensor = tensor.half() - + + if self.model is None: + raise RuntimeError("关键点检测模型尚未初始化") + with torch.inference_mode(): coords, vis = self.model(tensor) coords = coords[0].cpu().numpy() vis = vis[0].cpu().numpy() - del tensor # 立即释放 MPS/CUDA 显存,避免长批次累积 - - # 解析结果 + del tensor + left_eye = (float(coords[0, 0]), float(coords[0, 1])) right_eye = (float(coords[1, 0]), float(coords[1, 1])) beak = (float(coords[2, 0]), float(coords[2, 1])) @@ -172,15 +196,12 @@ def detect(self, bird_crop: np.ndarray, box: Tuple[int, int, int, int] = None, left_eye_vis = float(vis[0]) right_eye_vis = float(vis[1]) beak_vis = float(vis[2]) - - # 判断可见性 + left_visible = left_eye_vis >= self.VISIBILITY_THRESHOLD right_visible = right_eye_vis >= self.VISIBILITY_THRESHOLD beak_visible = beak_vis >= self.VISIBILITY_THRESHOLD - - # 保留旧属性(兼容) + both_eyes_hidden = not left_visible and not right_visible - # 新逻辑:只有当双眼和鸟喙都不可见时才算"全部不可见" all_keypoints_hidden = not left_visible and not right_visible and not beak_visible if left_visible and right_visible: @@ -191,8 +212,7 @@ def detect(self, bird_crop: np.ndarray, box: Tuple[int, int, int, int] = None, visible_eye = 'right' else: visible_eye = None - - # 计算头部锐度 + head_sharpness = 0.0 if visible_eye is not None: head_sharpness = self._calculate_head_sharpness( @@ -200,10 +220,9 @@ def detect(self, bird_crop: np.ndarray, box: Tuple[int, int, int, int] = None, left_eye_vis, right_eye_vis, beak_visible, box, seg_mask ) - - # V3.8: 计算双眼中较高的置信度,用于评分封顶逻辑 + best_eye_visibility = max(left_eye_vis, right_eye_vis) - + return KeypointResult( left_eye=left_eye, right_eye=right_eye, @@ -227,25 +246,19 @@ def _calculate_head_sharpness( left_eye_vis: float, right_eye_vis: float, beak_visible: bool, - box: Tuple[int, int, int, int] = None, - seg_mask: np.ndarray = None + box: Optional[Tuple[int, int, int, int]] = None, + seg_mask: Optional[np.ndarray] = None ) -> float: """ 计算头部区域锐度 - - 使用眼睛为圆心,眼喙距离×1.2为半径,与seg掩码取交集 """ h, w = bird_crop.shape[:2] - # 如果双眼都不可见(如鸟侧面、头部转向等): - # 模型坐标仍然大致指向头部位置,用置信度较高的那只眼做 fallback - # 沿用与正常流程完全相同的"圆形区域 Sobel"算法,结果 ×0.8 作为惩罚 - # 这样与正常眼睛检测的锐度值在同一量级,不会因用全身 Sobel 而虚高 if left_eye_vis < self.VISIBILITY_THRESHOLD and right_eye_vis < self.VISIBILITY_THRESHOLD: eye = left_eye if left_eye_vis >= right_eye_vis else right_eye eye_px = (int(eye[0] * w), int(eye[1] * h)) beak_px = (int(beak[0] * w), int(beak[1] * h)) - if beak_vis >= self.VISIBILITY_THRESHOLD: + if beak_visible: radius = int(self._distance(eye_px, beak_px) * self.RADIUS_MULTIPLIER) elif box is not None: box_size = max(box[2], box[3]) @@ -259,111 +272,86 @@ def _calculate_head_sharpness( head_mask = cv2.bitwise_and(circle_mask, seg_mask) else: head_mask = circle_mask - LOW_VIS_PENALTY = 0.8 # 眼睛不可见时降分但不误杀 + LOW_VIS_PENALTY = 0.8 return self._calculate_sharpness(bird_crop, head_mask) * LOW_VIS_PENALTY - - # 选择眼睛:用更远离喙的那只眼 if left_eye_vis >= self.VISIBILITY_THRESHOLD and right_eye_vis >= self.VISIBILITY_THRESHOLD: - # 两眼都可见,选更远离喙的 left_dist = self._distance(left_eye, beak) right_dist = self._distance(right_eye, beak) eye = left_eye if left_dist >= right_dist else right_eye elif left_eye_vis >= self.VISIBILITY_THRESHOLD: eye = left_eye else: - # 只有一只眼可见(右眼) eye = right_eye - - # 转换为像素坐标 + eye_px = (int(eye[0] * w), int(eye[1] * h)) beak_px = (int(beak[0] * w), int(beak[1] * h)) - - # 计算半径 + if beak_visible: radius = int(self._distance(eye_px, beak_px) * self.RADIUS_MULTIPLIER) elif box is not None: - # 无喙时用检测框的15% - # box 格式是 (x, y, w, h),所以 box[2]=width, box[3]=height box_size = max(box[2], box[3]) radius = int(box_size * self.NO_BEAK_RADIUS_RATIO) else: - # 最后fallback:用裁剪区域的15% radius = int(max(w, h) * self.NO_BEAK_RADIUS_RATIO) - - # 确保半径合理 + radius = max(10, min(radius, min(w, h) // 2)) - - # 创建圆形掩码 + circle_mask = np.zeros((h, w), dtype=np.uint8) cv2.circle(circle_mask, eye_px, radius, 255, -1) - - # 如果有seg掩码,取交集 + if seg_mask is not None: - # seg_mask可能需要裁剪到bird_crop区域 - # 这里假设bird_crop已经是裁剪后的,seg_mask也已相应处理 if seg_mask.shape[:2] == (h, w): head_mask = cv2.bitwise_and(circle_mask, seg_mask) else: head_mask = circle_mask else: head_mask = circle_mask - - # 计算锐度 + return self._calculate_sharpness(bird_crop, head_mask) def _calculate_sharpness(self, image: np.ndarray, mask: np.ndarray) -> float: """ 计算掩码区域的锐度(Tenengrad + 对数归一化) - - V3.7 改动: 使用 Tenengrad (Sobel梯度) 替代 Laplacian以减少噪点干扰 - 并使用对数归一化将结果映射到 0-1000 范围 """ if mask.sum() == 0: return 0.0 - - # 转灰度 + if len(image.shape) == 3: gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) else: gray = image - - # Tenengrad 算子 (Sobel梯度平方和) + gx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3) gy = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3) gradient_magnitude = gx ** 2 + gy ** 2 - - # 只取掩码区域的平均值 + mask_pixels = mask > 0 if mask_pixels.sum() == 0: return 0.0 - + raw_sharpness = float(gradient_magnitude[mask_pixels].mean()) - - # 对数归一化到 0-1000 - # V4.0 修复: 降低 MIN_VAL,之前 1460 太高导致锐利照片也返回 0 - # 测试显示: 锐利照片梯度平均值约 800-2000 - MIN_VAL = 100.0 # 降低阈值,保留更多低锐度信息 + + MIN_VAL = 100.0 MAX_VAL = 154016.0 - + if raw_sharpness <= MIN_VAL: return 0.0 if raw_sharpness >= MAX_VAL: return 1000.0 - + log_val = math.log(raw_sharpness) - math.log(MIN_VAL) log_max = math.log(MAX_VAL) - math.log(MIN_VAL) - + return (log_val / log_max) * 1000.0 - + @staticmethod def _distance(p1: Tuple[float, float], p2: Tuple[float, float]) -> float: """计算两点距离""" return math.sqrt((p1[0] - p2[0])**2 + (p1[1] - p2[1])**2) -# 便捷函数 -_detector_instance = None +_detector_instance: Optional[KeypointDetector] = None def get_keypoint_detector() -> KeypointDetector: """获取全局关键点检测器实例""" diff --git a/core/recursive_scanner.py b/core/recursive_scanner.py index 4f5307e..f557e8e 100644 --- a/core/recursive_scanner.py +++ b/core/recursive_scanner.py @@ -7,7 +7,10 @@ """ import os -from typing import List, Set +import platform +from dataclasses import dataclass +from pathlib import PurePosixPath, PureWindowsPath +from typing import List, Optional, Set, Tuple from constants import RAW_EXTENSIONS, JPG_EXTENSIONS, HEIF_EXTENSIONS, RATING_FOLDER_NAMES, RATING_FOLDER_NAMES_EN @@ -17,6 +20,17 @@ # 星级目录名(中 + 英) _RATING_DIR_NAMES: Set[str] = set(RATING_FOLDER_NAMES.values()) | set(RATING_FOLDER_NAMES_EN.values()) +DEFAULT_SCAN_MAX_DEPTH = 16 + + +@dataclass(frozen=True) +class ScannedDirectory: + """扫描结果条目""" + + path: str + depth: int + photo_count: int + def is_excluded(dirname: str) -> bool: """判断目录是否应被排除(非用户照片目录)""" @@ -31,17 +45,137 @@ def is_excluded(dirname: str) -> bool: return False +def _scan_directory_once(dir_path: str) -> Tuple[int, List[str]]: + """单次扫描目录,返回直接照片数量与可继续扫描的子目录。""" + photo_count = 0 + child_dirs: List[str] = [] + + try: + with os.scandir(dir_path) as entries: + for entry in entries: + if entry.is_file(follow_symlinks=False): + ext = os.path.splitext(entry.name)[1].lower() + if ext in _PHOTO_EXTENSIONS: + photo_count += 1 + continue + + if not entry.is_dir(follow_symlinks=False): + continue + if is_excluded(entry.name): + continue + + child_dirs.append(entry.path) + except (FileNotFoundError, NotADirectoryError, PermissionError): + return 0, [] + + child_dirs.sort(key=lambda value: os.path.basename(value).casefold()) + return photo_count, child_dirs + + +def _scan_directories_dfs(root: str, max_depth: int) -> List[ScannedDirectory]: + root = os.path.abspath(root) + if max_depth < 0: + return [] + + result: List[ScannedDirectory] = [] + stack: List[Tuple[str, int]] = [(root, 0)] + while stack: + dir_path, depth = stack.pop() + photo_count, child_dirs = _scan_directory_once(dir_path) + if photo_count > 0: + result.append(ScannedDirectory(path=dir_path, depth=depth, photo_count=photo_count)) + if depth >= max_depth: + continue + for child_dir in reversed(child_dirs): + stack.append((child_dir, depth + 1)) + result.sort(key=lambda item: item.path.casefold()) + return result + + +def _is_windows_path(path: str) -> bool: + drive, _ = os.path.splitdrive(path) + return bool(drive) or "\\" in path + + +def _is_subpath(candidate_parts: Tuple[str, ...], protected_parts: Tuple[str, ...]) -> bool: + if len(candidate_parts) < len(protected_parts): + return False + return candidate_parts[:len(protected_parts)] == protected_parts + + +def is_dangerous_root( + root: str, + platform_name: Optional[str] = None, + home_dir: Optional[str] = None, +) -> Tuple[bool, str]: + """判断根目录是否属于危险目录。""" + platform_name = (platform_name or platform.system()).lower() + home_dir = os.path.expanduser(home_dir or "~") + + if platform_name.startswith("win") or _is_windows_path(root): + normalized = str(PureWindowsPath(os.path.realpath(os.path.abspath(root)))) + root_path = PureWindowsPath(normalized) + anchor = root_path.anchor.rstrip("\\/") + current = normalized.rstrip("\\/") + if anchor and current.lower() == anchor.lower(): + return True, "磁盘根目录 / Drive root" + + protected_paths = [ + PureWindowsPath(os.path.realpath(os.environ.get("SystemRoot", "C:\\Windows"))), + PureWindowsPath(os.path.realpath("C:\\Program Files")), + PureWindowsPath(os.path.realpath("C:\\Program Files (x86)")), + PureWindowsPath(os.path.realpath(os.path.join(home_dir, "AppData"))), + ] + root_parts = tuple(part.casefold() for part in root_path.parts) + for protected in protected_paths: + protected_parts = tuple(part.casefold() for part in protected.parts) + if _is_subpath(root_parts, protected_parts): + return True, f"受保护的系统或设置目录 / Protected path: {protected}" + return False, "" + + normalized = str(PurePosixPath(os.path.realpath(os.path.abspath(root)))) + root_path = PurePosixPath(normalized) + root_parts = tuple(root_path.parts) + + # 严格相等判断:仅当路径恰好是文件系统根目录时才阻断。 + # 不能将 "/" 放入 protected_paths 用 _is_subpath() 做前缀匹配, + # 因为所有绝对路径的第一个 part 均为 "/",会导致误判一切目录。 + # Strict equality check: only block when the path is exactly the filesystem root. + # Do NOT put "/" into protected_paths for _is_subpath() prefix matching, + # because every absolute path starts with "/" as its first part, + # which would cause all directories to be falsely flagged. + if normalized == "/": + return True, "文件系统根目录 / Filesystem root" + + # 对受保护路径同样执行 realpath 解析,以处理符号链接。 + # 例如 macOS 上 /etc -> /private/etc,/var -> /private/var, + # 若不解析则与已 realpath 处理的 normalized 无法匹配。 + # Resolve protected paths with realpath too, to handle symlinks. + # e.g. on macOS: /etc -> /private/etc, /var -> /private/var. + _raw_protected = [ + "/usr", + "/etc", + "/var", + "/System", + "/Library", + os.path.join(home_dir, "Library"), + ] + protected_paths = [ + PurePosixPath(os.path.realpath(p)) for p in _raw_protected + ] + for protected in protected_paths: + protected_parts = tuple(protected.parts) + if _is_subpath(root_parts, protected_parts): + return True, f"受保护的系统或设置目录 / Protected path: {protected}" + if normalized in ("/home", os.path.realpath("/home")): + return True, "系统用户根目录 / System user root" + return False, "" + + def has_photos(dir_path: str) -> bool: """判断目录是否直接包含至少 1 个照片文件""" - try: - for entry in os.scandir(dir_path): - if entry.is_file(follow_symlinks=False): - ext = os.path.splitext(entry.name)[1].lower() - if ext in _PHOTO_EXTENSIONS: - return True - except PermissionError: - pass - return False + photo_count, _ = _scan_directory_once(dir_path) + return photo_count > 0 def is_processed(dir_path: str) -> bool: @@ -49,56 +183,34 @@ def is_processed(dir_path: str) -> bool: return os.path.exists(os.path.join(dir_path, '.superpicky', 'report.db')) -def scan_recursive(root: str, max_depth: int = 10) -> List[str]: +def scan_directories( + root: str, + max_depth: int = DEFAULT_SCAN_MAX_DEPTH, +) -> List[ScannedDirectory]: + """扫描根目录,返回包含照片的目录摘要列表。""" + return _scan_directories_dfs(root, max_depth) + + +def scan_dfs(root: str, max_depth: int = DEFAULT_SCAN_MAX_DEPTH) -> List[ScannedDirectory]: + """使用 DFS 扫描根目录。""" + return scan_directories(root, max_depth=max_depth) + + +def scan_recursive(root: str, max_depth: int = DEFAULT_SCAN_MAX_DEPTH) -> List[str]: """ 递归扫描根目录,返回所有原子目录(包含照片的非排除目录)的绝对路径列表。 Args: root: 根目录路径 - max_depth: 最大递归深度(默认 10) + max_depth: 最大递归深度(默认 16) Returns: 原子目录绝对路径列表,按字母排序 """ - result: List[str] = [] - - # 根目录本身如果包含照片,也加入列表 - if has_photos(root): - result.append(root) - - def _scan(dir_path: str, depth: int): - if depth > max_depth: - return - try: - entries = sorted(os.scandir(dir_path), key=lambda e: e.name) - except PermissionError: - return - - for entry in entries: - if not entry.is_dir(follow_symlinks=False): - continue - if is_excluded(entry.name): - continue - - if has_photos(entry.path): - result.append(entry.path) - - # 即使当前目录有照片,也继续扫描子目录 - _scan(entry.path, depth + 1) - - _scan(root, 0) - return result + return [item.path for item in scan_dfs(root, max_depth=max_depth)] def count_photos(dir_path: str) -> int: """统计目录中直接包含的照片文件数量""" - count = 0 - try: - for entry in os.scandir(dir_path): - if entry.is_file(follow_symlinks=False): - ext = os.path.splitext(entry.name)[1].lower() - if ext in _PHOTO_EXTENSIONS: - count += 1 - except PermissionError: - pass + count, _ = _scan_directory_once(dir_path) return count diff --git a/core/runtime_bootstrap.py b/core/runtime_bootstrap.py new file mode 100644 index 0000000..74b175c --- /dev/null +++ b/core/runtime_bootstrap.py @@ -0,0 +1,170 @@ +# -*- coding: utf-8 -*- +""" +Packaged runtime bootstrap helper. + +This entrypoint is designed for frozen Windows lightweight builds. It runs in a +separate hidden process mode of the packaged executable and installs runtime +dependencies into an app-local site-packages directory without relying on any +system Python interpreter. +""" + +from __future__ import annotations + +import argparse +import os +import json +import sys +import importlib +from datetime import datetime, timezone +from pathlib import Path + + +def _parse_args(argv: list[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--runtime-bootstrap", action="store_true") + parser.add_argument("--runtime-dir", required=True) + parser.add_argument("--requirements", required=True) + parser.add_argument("--index-url", default=None) + parser.add_argument("--extra-index-url", action="append", default=[]) + return parser.parse_args(argv) + + +def _ensure_utf8_stdio() -> None: + for stream_name in ("stdout", "stderr"): + stream = getattr(sys, stream_name, None) + reconfigure = getattr(stream, "reconfigure", None) + if callable(reconfigure): + reconfigure(encoding="utf-8", errors="replace") + + +def _build_pip_args(args: argparse.Namespace, site_packages_dir: Path) -> list[str]: + """ + Build the pip command line for the packaged runtime bootstrap. + + 为打包运行时引导流程构建 pip 命令行参数。 + """ + pip_args = [ + "install", + "--disable-pip-version-check", + "--no-warn-script-location", + "--no-cache-dir", + "--progress-bar", + "raw", + "--use-deprecated=legacy-certs", + "--upgrade", + "--target", + str(site_packages_dir), + "-r", + str(Path(args.requirements).resolve()), + ] + if args.index_url: + pip_args.extend(["-i", args.index_url]) + for extra_index_url in args.extra_index_url: + pip_args.extend(["--extra-index-url", extra_index_url]) + return pip_args + + +def _write_manifest(runtime_dir: Path, site_packages_dir: Path, args: argparse.Namespace) -> None: + manifest = { + "generated_at": datetime.now(timezone.utc).isoformat(), + "runtime_dir": str(runtime_dir), + "site_packages_dir": str(site_packages_dir), + "requirements": str(Path(args.requirements).resolve()), + "index_url": args.index_url, + "extra_index_urls": list(args.extra_index_url), + "python_version": sys.version, + "bootstrap_executable": sys.executable, + } + manifest_path = runtime_dir / "runtime_install_manifest.json" + manifest_path.write_text(json.dumps(manifest, indent=2, ensure_ascii=False), encoding="utf-8") + + +def _bundled_root() -> Path: + meipass = getattr(sys, "_MEIPASS", None) + if isinstance(meipass, str) and meipass: + return Path(meipass) + if getattr(sys, "frozen", False) and sys.platform == "win32": + return Path(sys.executable).resolve().parent / "_internal" + return Path(__file__).resolve().parent.parent + + +def _configure_ca_bundle() -> Path | None: + cert_path = _bundled_root() / "certifi" / "cacert.pem" + if not cert_path.exists(): + return None + os.environ.setdefault("PIP_CERT", str(cert_path)) + os.environ.setdefault("SSL_CERT_FILE", str(cert_path)) + os.environ.setdefault("REQUESTS_CA_BUNDLE", str(cert_path)) + return cert_path + + +def _patch_pip_for_frozen_bootstrap() -> None: + """ + Patch pip vendored distlib so it can run from a PyInstaller-frozen process. + + distlib expects a standard loader/resource finder pair and Windows launcher + executables inside the distlib package. In a frozen app, the PyInstaller + loader type is unknown to distlib, and launcher stubs are not needed for + our app-local runtime bootstrap. We register a fallback finder and disable + launcher generation on Windows. + """ + os.environ.setdefault("PIP_DISABLE_PIP_VERSION_CHECK", "1") + os.environ.setdefault("PIP_USE_DEPRECATED", "legacy-certs") + cert_path = _configure_ca_bundle() + + from pip._vendor.distlib import resources as distlib_resources + + distlib_pkg = importlib.import_module("pip._vendor.distlib") + loader = getattr(distlib_pkg, "__loader__", None) + if loader is not None: + finder_registry = getattr(distlib_resources, "_finder_registry", {}) + if type(loader) not in finder_registry: + distlib_resources.register_finder(loader, distlib_resources.ResourceFinder) + + import pip._vendor.distlib.scripts as distlib_scripts + import pip._vendor.certifi as pip_certifi + import pip._internal.cli.index_command as pip_index_command + + if not getattr(distlib_scripts.ScriptMaker.__init__, "_superpicky_patched", False): + original_init = distlib_scripts.ScriptMaker.__init__ + + def _patched_init(self, source_dir, target_dir, add_launchers=True, dry_run=False, fileop=None): + # We do not need .exe launcher stubs in the app-local target dir. + return original_init(self, source_dir, target_dir, add_launchers=False, dry_run=dry_run, fileop=fileop) + + _patched_init._superpicky_patched = True # type: ignore[attr-defined] + distlib_scripts.ScriptMaker.__init__ = _patched_init + + # Frozen bootstrap can resolve vendored certifi resources incorrectly under + # truststore. Force pip to skip the truststore SSL context path and fall + # back to its legacy certificate behavior. + pip_index_command._create_truststore_ssl_context = lambda: None + if cert_path is not None: + pip_certifi.where = lambda: str(cert_path) + + +def run_runtime_bootstrap(argv: list[str]) -> int: + _ensure_utf8_stdio() + args = _parse_args(argv) + runtime_dir = Path(args.runtime_dir).resolve() + site_packages_dir = runtime_dir / "site-packages" + site_packages_dir.mkdir(parents=True, exist_ok=True) + + _patch_pip_for_frozen_bootstrap() + + from pip._internal.cli.main import main as pip_main + + pip_args = _build_pip_args(args, site_packages_dir) + print(f"[runtime-bootstrap] target={site_packages_dir}") + exit_code = int(pip_main(pip_args)) + if exit_code != 0: + return exit_code + + if str(site_packages_dir) not in sys.path: + sys.path.insert(0, str(site_packages_dir)) + + import torch # noqa: F401 + + _write_manifest(runtime_dir, site_packages_dir, args) + print("[runtime-bootstrap] torch import verified") + return 0 diff --git a/core/runtime_requirements.py b/core/runtime_requirements.py new file mode 100644 index 0000000..5873f5d --- /dev/null +++ b/core/runtime_requirements.py @@ -0,0 +1,171 @@ +# -*- coding: utf-8 -*- +""" +Runtime requirements manager for lightweight builds. + +This module provides a unified interface for managing platform-specific runtime +dependencies across CPU, CUDA, and macOS builds. It consolidates the previously +separate requirements_runtime_*.txt files into a single Python module with +type-safe configuration access. + +轻量化构建的运行时依赖管理模块。 + +该模块为 CPU、CUDA 和 macOS 构建统一描述运行时依赖, +避免把平台差异散落在多个 requirements 文本文件与调用点之间。 +""" + +from __future__ import annotations + +import sys +from dataclasses import dataclass +from typing import Literal + + +PlatformType = Literal["cpu", "cuda", "mac"] + + +@dataclass(frozen=True) +class RuntimeRequirements: + """ + Runtime dependency configuration for a specific platform. + + 单个平台对应的运行时依赖配置。 + """ + + torch_version: str + torchvision_version: str + torchaudio_version: str + timm_version: str + extra_index_urls: list[str] + index_url: str | None = None + + @staticmethod + def _format_pinned_requirement(package_name: str, version: str) -> str: + """ + Return a pinned requirement only when a version is provided. + + 仅在存在版本号时返回带固定版本的依赖声明。 + """ + + normalized_version = version.strip() + if not normalized_version: + return package_name + return f"{package_name}=={normalized_version}" + + def to_requirements_string( + self, + *, + include_indexes: bool = True, + package_urls: dict[str, str] | None = None, + ) -> str: + """ + Convert configuration to pip requirements file format. + + 将配置转换为 pip requirements 文件格式。 + """ + lines = [] + package_urls = package_urls or {} + if include_indexes and self.index_url: + lines.append(f"--index-url {self.index_url}") + if include_indexes: + for url in self.extra_index_urls: + lines.append(f"--extra-index-url {url}") + lines.append( + package_urls.get( + "torch", + self._format_pinned_requirement("torch", self.torch_version), + ) + ) + lines.append( + package_urls.get( + "torchvision", + self._format_pinned_requirement("torchvision", self.torchvision_version), + ) + ) + lines.append( + package_urls.get( + "torchaudio", + self._format_pinned_requirement("torchaudio", self.torchaudio_version), + ) + ) + lines.append(f"timm{self.timm_version}") + return "\n".join(lines) + + +def get_cpu_requirements() -> RuntimeRequirements: + """Get runtime requirements for CPU builds. / 获取 CPU 构建的运行时依赖。""" + return RuntimeRequirements( + torch_version="2.7.1+cpu", + torchvision_version="0.22.1+cpu", + torchaudio_version="2.7.1+cpu", + timm_version=">=0.9.0", + extra_index_urls=[ + "https://mirror.nju.edu.cn/pytorch/whl/cpu/", + "https://download.pytorch.org/whl/cpu", + ], + ) + + +def get_cuda_requirements() -> RuntimeRequirements: + """Get runtime requirements for CUDA builds. / 获取 CUDA 构建的运行时依赖。""" + return RuntimeRequirements( + torch_version="2.7.1+cu118", + torchvision_version="0.22.1+cu118", + torchaudio_version="2.7.1+cu118", + timm_version=">=0.9.0", + extra_index_urls=[ + "https://mirror.nju.edu.cn/pytorch/whl/cu118/", + "https://download.pytorch.org/whl/cu118", + ], + ) + + +def get_mac_requirements() -> RuntimeRequirements: + """Get runtime requirements for macOS builds. / 获取 macOS 构建的运行时依赖。""" + return RuntimeRequirements( + torch_version="2.8.0", + torchvision_version="", + torchaudio_version="", + timm_version=">=0.9.0", + extra_index_urls=[], + ) + + +def detect_platform() -> PlatformType: + """Detect the current platform type. / 检测当前平台类型。""" + if sys.platform == "darwin": + return "mac" + if sys.platform == "win32": + return "cuda" + return "cpu" + + +def get_runtime_requirements(platform: PlatformType | None = None) -> RuntimeRequirements: + """ + Get runtime requirements for the specified or detected platform. + + Args: + platform: Platform type ('cpu', 'cuda', 'mac'). If None, auto-detects. + 平台类型;若为 None,则自动检测。 + + Returns: + RuntimeRequirements: Platform-specific dependency configuration. + 对应平台的依赖配置。 + + Raises: + ValueError: If platform type is invalid. + 当平台类型非法时抛出。 + """ + if platform is None: + platform = detect_platform() + + requirements_getters = { + "cpu": get_cpu_requirements, + "cuda": get_cuda_requirements, + "mac": get_mac_requirements, + } + + getter = requirements_getters.get(platform) + if getter is None: + raise ValueError(f"Unsupported platform: {platform}") + + return getter() diff --git a/core/source_probe.py b/core/source_probe.py new file mode 100644 index 0000000..6a24e6b --- /dev/null +++ b/core/source_probe.py @@ -0,0 +1,221 @@ +# -*- coding: utf-8 -*- +""" +HTTP source probe helpers for initialization. + +Notes: +- We intentionally do not use ICMP ping as the primary selection mechanism. +- Some networks block ping while HTTPS still works normally. +- Selection is based on real HTTP responsiveness and cached for the current run. + +HTTP 源探测辅助工具,用于初始化。 + +注意事项: +- 我们有意不使用 ICMP ping 作为主要选择机制。 +- 某些网络阻止 ping,但 HTTPS 仍然正常工作。 +- 选择基于真实的 HTTP 响应能力,并在当前运行中缓存。 +""" + +from __future__ import annotations + +import logging +import time +import urllib.request +from dataclasses import dataclass +from typing import Dict, Iterable, List, Optional + +logging.basicConfig(level=logging.INFO) + + +DEFAULT_TIMEOUT_SECONDS = 4.0 +LARGE_FILE_TIMEOUT_SECONDS = 10.0 + + +@dataclass +class ProbeResult: + """ + 源探测结果数据类。 + + Source probe result dataclass. + + 属性 Attributes: + name (str): 源名称 + url (str): 源 URL + ok (bool): 探测是否成功 + total_ms (float): 总响应时间(毫秒) + first_byte_ms (float): 首字节响应时间(毫秒) + error (Optional[str]): 错误信息(如果失败) + status_code (Optional[int]): HTTP 状态码 + response_headers (Optional[Dict[str, str]]): 响应头 + """ + + name: str + url: str + ok: bool + total_ms: float + first_byte_ms: float + error: Optional[str] = None + status_code: Optional[int] = None + response_headers: Optional[Dict[str, str]] = None + + +_PROBE_CACHE: Dict[str, List[ProbeResult]] = {} + + +def _normalize_probe_url(url: str) -> str: + """ + 标准化化探测 URL。 + + Normalize probe URL. + + 参数 Parameters: + url (str): 原始 URL + + 返回 Returns: + str: 标准化后的 URL + """ + if url.endswith("/simple"): + return url.rstrip("/") + "/pip/" + return url + + +def probe_url( + name: str, url: str, timeout: float = DEFAULT_TIMEOUT_SECONDS +) -> ProbeResult: + """ + 探测单个 URL 的响应能力。 + + Probe the responsiveness of a single URL. + + 参数 Parameters: + name (str): 源名称 + url (str): 要探测的 URL + timeout (float): 超时时间(秒) + + 返回 Returns: + ProbeResult: 探测结果 + """ + start = time.perf_counter() + request = urllib.request.Request( + _normalize_probe_url(url), + headers={"User-Agent": "SuperPicky-InitProbe/1.0"}, + method="GET", + ) + try: + with urllib.request.urlopen(request, timeout=timeout) as response: + first_byte_start = time.perf_counter() + response.read(256) + first_byte_ms = (time.perf_counter() - first_byte_start) * 1000.0 + + status_code = response.getcode() + response_headers = dict(response.headers.items()) + + total_ms = (time.perf_counter() - start) * 1000.0 + + logging.info( + "源探测成功: %s (%s) - 状态码: %d, 总耗时: %.2f ms, 首字节: %.2f ms", + name, + url, + status_code, + total_ms, + first_byte_ms, + ) + + return ProbeResult( + name=name, + url=url, + ok=True, + total_ms=total_ms, + first_byte_ms=first_byte_ms, + status_code=status_code, + response_headers=response_headers, + ) + except Exception as exc: + total_ms = (time.perf_counter() - start) * 1000.0 + error_msg = f"{type(exc).__name__}: {exc}" + + logging.warning( + "源探测失败: %s (%s) - 错误: %s, 耗时: %.2f ms", + name, + url, + error_msg, + total_ms, + ) + + return ProbeResult( + name=name, + url=url, + ok=False, + total_ms=total_ms, + first_byte_ms=0.0, + error=error_msg, + status_code=None, + response_headers=None, + ) + + +def probe_sources( + group_name: str, sources: Iterable[dict], timeout: float = DEFAULT_TIMEOUT_SECONDS +) -> List[ProbeResult]: + """ + 探测一组源并返回结果。 + + Probe a group of sources and return results. + + 参数 Parameters: + group_name (str): 源组名称(用于缓存) + sources (Iterable[dict]): 源列表,每个源包含 name 和 url + timeout (float): 超时时间(秒) + + 返回 Returns: + List[ProbeResult]: 探测结果列表 + """ + if group_name in _PROBE_CACHE: + logging.info("使用缓存的探测结果: %s", group_name) + return list(_PROBE_CACHE[group_name]) + + sources_list = list(sources) + logging.info("开始探测源组: %s,共 %d 个源", group_name, len(sources_list)) + results: List[ProbeResult] = [] + for source in sources_list: + results.append(probe_url(source["name"], source["url"], timeout=timeout)) + + _PROBE_CACHE[group_name] = list(results) + + successful_count = sum(1 for item in results if item.ok) + logging.info( + "源组 %s 探测完成: %d/%d 成功", group_name, successful_count, len(results) + ) + + return results + + +def pick_best_source(results: Iterable[ProbeResult]) -> Optional[ProbeResult]: + """ + 从探测结果中选择最佳源。 + + Select the best source from probe results. + + 参数 Parameters: + results (Iterable[ProbeResult]): 探测结果列表 + + 返回 Returns: + Optional[ProbeResult]: 最佳源,如果没有成功的源则返回 None + """ + successful = [item for item in results if item.ok] + if not successful: + logging.warning("没有可用的源") + return None + + best = min(successful, key=lambda item: (item.total_ms, item.first_byte_ms)) + logging.info( + "选择最佳源: %s (%s) - 总耗时: %.2f ms, 首字节: %.2f ms", + best.name, + best.url, + best.total_ms, + best.first_byte_ms, + ) + return best + + +def clear_probe_cache() -> None: + _PROBE_CACHE.clear() diff --git a/inno/SuperPicky-lite.iss b/inno/SuperPicky-lite.iss new file mode 100644 index 0000000..fedb530 --- /dev/null +++ b/inno/SuperPicky-lite.iss @@ -0,0 +1,64 @@ +; Script generated by the Inno Setup Script Wizard. +; SEE THE DOCUMENTATION FOR DETAILS ON CREATING INNO SETUP SCRIPT FILES! +; Non-commercial use only + +#define MyAppName "SuperPicky" +#define MyAppVersion "unknown" +#define MyAppPublisher "JamesPhotography" +#define MyAppURL "superpicky.app" +#define MyAppExeName "SuperPicky.exe" +#define MyAppCommitHash "unknown" +#define OutputBaseFilename "SuperPicky_Setup_Lite_Win64_" + MyAppVersion + "_" + MyAppCommitHash + +[Setup] +; NOTE: The value of AppId uniquely identifies this application. Do not use the same AppId value in installers for other applications. +; (To generate a new GUID, click Tools | Generate GUID inside the IDE.) +AppId={{14FA9904-CE97-4FAC-84F3-3A7A705590FB} +AppName={#MyAppName} +AppVersion={#MyAppVersion} +AppVerName={#MyAppName} {#MyAppVersion} +AppPublisher={#MyAppPublisher} +AppPublisherURL={#MyAppURL} +AppSupportURL={#MyAppURL} +AppUpdatesURL={#MyAppURL} +DefaultDirName={autopf}\SuperPicky +UninstallDisplayIcon={app}\{#MyAppExeName} +; "ArchitecturesAllowed=x64compatible" specifies that Setup cannot run +; on anything but x64 and Windows 11 on Arm. +ArchitecturesAllowed=x64compatible +; "ArchitecturesInstallIn64BitMode=x64compatible" requests that the +; install be done in "64-bit mode" on x64 or Windows 11 on Arm, +; meaning it should use the native 64-bit Program Files directory and +; the 64-bit view of the registry. +ArchitecturesInstallIn64BitMode=x64compatible +DisableProgramGroupPage=yes +; Remove the following line to run in administrative install mode (install for all users). +PrivilegesRequired=lowest +OutputDir=output +OutputBaseFilename={#OutputBaseFilename} +SetupIconFile=img\icon.ico +SolidCompression=yes +WizardStyle=modern dynamic + +[Languages] +Name: "chinesesimplified"; MessagesFile: "ChineseSimplified.isl" +Name: "english"; MessagesFile: "compiler:Default.isl" + +[Tasks] +Name: "desktopicon"; Description: "{cm:CreateDesktopIcon}"; GroupDescription: "{cm:AdditionalIcons}"; Flags: checkablealone + +[Files] +Source: "{#MyAppExeName}"; DestDir: "{app}"; Flags: ignoreversion +Source: "_internal\*"; DestDir: "{app}\_internal"; Flags: ignoreversion recursesubdirs createallsubdirs +; NOTE: Don't use "Flags: ignoreversion" on any shared system files + +[Icons] +Name: "{autoprograms}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}" +Name: "{autodesktop}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}"; Tasks: desktopicon + +[Run] +Filename: "{app}\{#MyAppExeName}"; Description: "{cm:LaunchProgram,{#StringChange(MyAppName, '&', '&&')}}"; Flags: nowait postinstall skipifsilent +Filename: "https://superpicky.app/"; Description: "访问项目网站"; Flags: postinstall skipifsilent shellexec + +[UninstallDelete] +Type: filesandordirs; Name: "{app}\_internal" \ No newline at end of file diff --git a/inno/SuperPicky.iss b/inno/SuperPicky.iss index 7489414..a3d00a5 100644 --- a/inno/SuperPicky.iss +++ b/inno/SuperPicky.iss @@ -1,407 +1,62 @@ +; SuperPicky Full 安装脚本 +; SuperPicky Full installer script +; Non-commercial use only + +#define MyAppName "SuperPicky" +#define MyAppVersion "unknown" +#define MyAppPublisher "JamesPhotography" +#define MyAppURL "superpicky.app" +#define MyAppExeName "SuperPicky.exe" +#define MyAppCommitHash "unknown" +#define OutputBaseFilename "SuperPicky_Setup_Full_Win64_" + MyAppVersion + "_" + MyAppCommitHash + [Setup] -AppId=SuperPicky -AppName=SuperPicky -AppVersion=4.2.0-113b079 +AppId={{B7E3F2A1-8D4C-4F5A-9E6B-1C2D3E4F5A6B} +AppName={#MyAppName} +AppVersion={#MyAppVersion} +AppVerName={#MyAppName} {#MyAppVersion} +AppPublisher={#MyAppPublisher} +AppPublisherURL={#MyAppURL} +AppSupportURL={#MyAppURL} +AppUpdatesURL={#MyAppURL} DefaultDirName={autopf}\SuperPicky -DefaultGroupName=SuperPicky -AppPublisherURL=https://superpicky.app/ -OutputBaseFilename=SuperPicky_Setup_Win64_4.2.0_113b079 +UninstallDisplayIcon={app}\{#MyAppExeName} +ArchitecturesAllowed=x64compatible +ArchitecturesInstallIn64BitMode=x64compatible +DisableProgramGroupPage=yes +PrivilegesRequired=lowest +OutputDir=output +OutputBaseFilename={#OutputBaseFilename} +SetupIconFile=img\icon.ico Compression=lzma2/ultra64 LZMAUseSeparateProcess=yes LZMADictionarySize=1048576 LZMANumFastBytes=273 SolidCompression=yes -CreateAppDir=yes -DirExistsWarning=no -UninstallDisplayIcon={app}\SuperPicky.exe -SetupIconFile=_internal\img\icon.ico WizardStyle=modern -DisableProgramGroupPage=yes -DisableDirPage=no -DisableReadyPage=no -DisableFinishedPage=no -VersionInfoCompany=https://superpicky.app/ -WizardImageFile=_internal\img\icon.png -WizardSmallImageFile=_internal\img\icon.png -AlwaysShowComponentsList=no -AlwaysShowGroupOnReadyPage=no -ArchitecturesAllowed=x64compatible -ArchitecturesInstallIn64BitMode=x64compatible -PrivilegesRequired=admin +WizardImageFile=img\icon.png +WizardSmallImageFile=img\icon.png CloseApplications=yes RestartApplications=no -UsePreviousAppDir=no - -[Registry] -Root: HKLM64; Subkey: "SOFTWARE\SuperPicky"; ValueType: string; ValueName: "InstallDir"; ValueData: "{app}"; Flags: uninsdeletevalue -Root: HKLM64; Subkey: "SOFTWARE\SuperPicky"; ValueType: string; ValueName: "Version"; ValueData: "{#SetupSetting('AppVersion')}"; Flags: uninsdeletevalue -Root: HKLM64; Subkey: "SOFTWARE\SuperPicky"; ValueType: string; ValueName: "UninstallString"; ValueData: """{uninstallexe}"""; Flags: uninsdeletevalue -Root: HKLM64; Subkey: "SOFTWARE\SuperPicky"; ValueType: string; ValueName: "CUDA_Patch_Installed"; ValueData: "0"; Flags: uninsdeletevalue -Root: HKLM64; Subkey: "SOFTWARE\SuperPicky"; ValueType: string; ValueName: "CUDA_Patch_Version"; ValueData: ""; Flags: uninsdeletevalue -Root: HKLM64; Subkey: "SOFTWARE\SuperPicky"; ValueType: string; ValueName: "CUDA_Patch_TargetDir"; ValueData: ""; Flags: uninsdeletevalue -Root: HKLM64; Subkey: "SOFTWARE\SuperPicky"; ValueType: string; ValueName: "CUDA_Patch_FileList"; ValueData: ""; Flags: uninsdeletevalue -Root: HKLM64; Subkey: "SOFTWARE\SuperPicky"; ValueType: string; ValueName: "CUDA_Patch_InstalledAt"; ValueData: ""; Flags: uninsdeletevalue - -[Code] -const - AppRegistryKey = 'SOFTWARE\SuperPicky'; - UninstallKeyAppId = 'Software\Microsoft\Windows\CurrentVersion\Uninstall\SuperPicky'; - UninstallKeyLegacy = 'Software\Microsoft\Windows\CurrentVersion\Uninstall\SuperPicky_is1'; - PatchManifestRelativePath = '_internal\cuda_patch_manifest.txt'; - -var - PreviousInstallDir: string; - PreviousUninstallString: string; - PatchCleanupWarnings: string; - -function QueryStringValue(const RootKey: Integer; const SubKey, ValueName: string; var Value: string): Boolean; -begin - Result := RegQueryStringValue(RootKey, SubKey, ValueName, Value) and (Trim(Value) <> ''); -end; - -function QueryInstallDir(var Value: string): Boolean; -begin - Result := - QueryStringValue(HKLM64, AppRegistryKey, 'InstallDir', Value) or - QueryStringValue(HKLM64, UninstallKeyAppId, 'Inno Setup: App Path', Value) or - QueryStringValue(HKLM64, UninstallKeyLegacy, 'Inno Setup: App Path', Value) or - QueryStringValue(HKLM, AppRegistryKey, 'InstallDir', Value) or - QueryStringValue(HKLM, UninstallKeyAppId, 'Inno Setup: App Path', Value) or - QueryStringValue(HKLM, UninstallKeyLegacy, 'Inno Setup: App Path', Value) or - QueryStringValue(HKCU, AppRegistryKey, 'InstallDir', Value) or - QueryStringValue(HKCU, UninstallKeyAppId, 'Inno Setup: App Path', Value) or - QueryStringValue(HKCU, UninstallKeyLegacy, 'Inno Setup: App Path', Value); -end; - -function QueryUninstallString(var Value: string): Boolean; -begin - Result := - QueryStringValue(HKLM64, AppRegistryKey, 'UninstallString', Value) or - QueryStringValue(HKLM64, UninstallKeyAppId, 'UninstallString', Value) or - QueryStringValue(HKLM64, UninstallKeyLegacy, 'UninstallString', Value) or - QueryStringValue(HKLM, AppRegistryKey, 'UninstallString', Value) or - QueryStringValue(HKLM, UninstallKeyAppId, 'UninstallString', Value) or - QueryStringValue(HKLM, UninstallKeyLegacy, 'UninstallString', Value) or - QueryStringValue(HKCU, AppRegistryKey, 'UninstallString', Value) or - QueryStringValue(HKCU, UninstallKeyAppId, 'UninstallString', Value) or - QueryStringValue(HKCU, UninstallKeyLegacy, 'UninstallString', Value); -end; - -function QueryPatchValue(const ValueName: string; var Value: string): Boolean; -begin - Result := - QueryStringValue(HKLM64, AppRegistryKey, ValueName, Value) or - QueryStringValue(HKLM, AppRegistryKey, ValueName, Value) or - QueryStringValue(HKCU, AppRegistryKey, ValueName, Value); -end; - -procedure LoadPreviousInstallState; -begin - PreviousInstallDir := ''; - PreviousUninstallString := ''; - QueryInstallDir(PreviousInstallDir); - QueryUninstallString(PreviousUninstallString); -end; - -function NormalizePath(const Value: string): string; -begin - Result := Trim(Value); - StringChangeEx(Result, '/', '\', True); - while (Length(Result) > 3) and (Result[Length(Result)] = '\') do - Delete(Result, Length(Result), 1); - Result := Uppercase(Result); -end; - -function PathsEqual(const A, B: string): Boolean; -begin - Result := NormalizePath(A) = NormalizePath(B); -end; - -function IsPathUnderBase(const BaseDir, CandidatePath: string): Boolean; -var - NormalizedBase: string; - NormalizedCandidate: string; -begin - NormalizedBase := NormalizePath(BaseDir); - NormalizedCandidate := NormalizePath(CandidatePath); - if (NormalizedBase = '') or (NormalizedCandidate = '') then - begin - Result := False; - exit; - end; - - Result := Pos(AddBackslash(NormalizedBase), AddBackslash(NormalizedCandidate)) = 1; -end; - -function ExtractCommand(const CommandLine: string; var Executable, Parameters: string): Boolean; -var - TrimmedLine: string; - QuoteEnd: Integer; - SpacePos: Integer; -begin - Result := False; - Executable := ''; - Parameters := ''; - TrimmedLine := Trim(CommandLine); - if TrimmedLine = '' then - exit; - - if TrimmedLine[1] = '"' then - begin - Delete(TrimmedLine, 1, 1); - QuoteEnd := Pos('"', TrimmedLine); - if QuoteEnd = 0 then - exit; - Executable := Copy(TrimmedLine, 1, QuoteEnd - 1); - Parameters := Trim(Copy(TrimmedLine, QuoteEnd + 1, MaxInt)); - end - else - begin - SpacePos := Pos(' ', TrimmedLine); - if SpacePos = 0 then - Executable := TrimmedLine - else - begin - Executable := Copy(TrimmedLine, 1, SpacePos - 1); - Parameters := Trim(Copy(TrimmedLine, SpacePos + 1, MaxInt)); - end; - end; - - Result := Executable <> ''; -end; - -function EnsureSilentUninstallParams(const ExistingParams: string): string; -var - UpperParams: string; -begin - Result := Trim(ExistingParams); - UpperParams := Uppercase(Result); - if Pos('/VERYSILENT', UpperParams) = 0 then - Result := Trim(Result + ' /VERYSILENT'); - if Pos('/SUPPRESSMSGBOXES', UpperParams) = 0 then - Result := Trim(Result + ' /SUPPRESSMSGBOXES'); - if Pos('/NORESTART', UpperParams) = 0 then - Result := Trim(Result + ' /NORESTART'); - if Pos('/SP-', UpperParams) = 0 then - Result := Trim(Result + ' /SP-'); -end; - -function RunPreviousUninstaller(): Boolean; -var - UninstallExe: string; - UninstallParams: string; - ResultCode: Integer; -begin - Result := True; - if PreviousUninstallString = '' then - exit; - - if not ExtractCommand(PreviousUninstallString, UninstallExe, UninstallParams) then - begin - Result := False; - exit; - end; - - if not FileExists(UninstallExe) then - begin - Result := False; - exit; - end; - - UninstallParams := EnsureSilentUninstallParams(UninstallParams); - if not Exec(UninstallExe, UninstallParams, ExtractFileDir(UninstallExe), SW_SHOWNORMAL, ewWaitUntilTerminated, ResultCode) then - begin - Result := False; - exit; - end; - - Result := ResultCode = 0; -end; -procedure AppendPatchCleanupWarning(const MessageText: string); -begin - if Trim(MessageText) = '' then - exit; - - if PatchCleanupWarnings <> '' then - PatchCleanupWarnings := PatchCleanupWarnings + #13#10; - PatchCleanupWarnings := PatchCleanupWarnings + MessageText; -end; - -function IsSafeRelativePatchPath(const RelativePath: string): Boolean; -var - NormalizedPath: string; -begin - NormalizedPath := Trim(RelativePath); - StringChangeEx(NormalizedPath, '/', '\', True); - if NormalizedPath = '' then - begin - Result := False; - exit; - end; - - Result := - (Pos('..', NormalizedPath) = 0) and - (Pos(':', NormalizedPath) = 0) and - (NormalizedPath[1] <> '\') and - (NormalizedPath[1] <> '/'); -end; - -procedure RemoveEmptyParentDirs(StartingDir, AppDir: string); -begin - StartingDir := Trim(StartingDir); - AppDir := Trim(AppDir); - - while IsPathUnderBase(AppDir, StartingDir) and (not PathsEqual(StartingDir, AppDir)) do - begin - if not RemoveDir(StartingDir) then - exit; - StartingDir := ExtractFileDir(StartingDir); - end; -end; - -procedure ClearPatchRegistryValues; -begin - RegDeleteValue(HKLM64, AppRegistryKey, 'CUDA_Patch_Installed'); - RegDeleteValue(HKLM64, AppRegistryKey, 'CUDA_Patch_Version'); - RegDeleteValue(HKLM64, AppRegistryKey, 'CUDA_Patch_TargetDir'); - RegDeleteValue(HKLM64, AppRegistryKey, 'CUDA_Patch_FileList'); - RegDeleteValue(HKLM64, AppRegistryKey, 'CUDA_Patch_InstalledAt'); - RegDeleteValue(HKLM, AppRegistryKey, 'CUDA_Patch_Installed'); - RegDeleteValue(HKLM, AppRegistryKey, 'CUDA_Patch_Version'); - RegDeleteValue(HKLM, AppRegistryKey, 'CUDA_Patch_TargetDir'); - RegDeleteValue(HKLM, AppRegistryKey, 'CUDA_Patch_FileList'); - RegDeleteValue(HKLM, AppRegistryKey, 'CUDA_Patch_InstalledAt'); - RegDeleteValue(HKCU, AppRegistryKey, 'CUDA_Patch_Installed'); - RegDeleteValue(HKCU, AppRegistryKey, 'CUDA_Patch_Version'); - RegDeleteValue(HKCU, AppRegistryKey, 'CUDA_Patch_TargetDir'); - RegDeleteValue(HKCU, AppRegistryKey, 'CUDA_Patch_FileList'); - RegDeleteValue(HKCU, AppRegistryKey, 'CUDA_Patch_InstalledAt'); -end; - -function ResolvePatchManifestPath(const AppDir: string): string; -begin - Result := AddBackslash(AppDir) + PatchManifestRelativePath; - QueryPatchValue('CUDA_Patch_FileList', Result); - if not IsPathUnderBase(AppDir, Result) then - Result := AddBackslash(AppDir) + PatchManifestRelativePath; -end; - -procedure CleanupCudaPatchArtifacts; -var - AppDir: string; - ManifestPath: string; - TargetDir: string; - PatchInstalledFlag: string; - Lines: TArrayOfString; - RelativePath: string; - FullPath: string; - I: Integer; -begin - AppDir := ExpandConstant('{app}'); - ManifestPath := ResolvePatchManifestPath(AppDir); - TargetDir := ''; - PatchInstalledFlag := ''; - QueryPatchValue('CUDA_Patch_TargetDir', TargetDir); - QueryPatchValue('CUDA_Patch_Installed', PatchInstalledFlag); - - if (TargetDir <> '') and (not PathsEqual(TargetDir, AppDir)) then - AppendPatchCleanupWarning('检测到旧补丁记录的目标目录与当前卸载目录不一致,已仅清理当前安装目录内的补丁痕迹。'); - - if FileExists(ManifestPath) then - begin - if LoadStringsFromFile(ManifestPath, Lines) then - begin - for I := 0 to GetArrayLength(Lines) - 1 do - begin - RelativePath := Trim(Lines[I]); - if RelativePath <> '' then - begin - if not IsSafeRelativePatchPath(RelativePath) then - AppendPatchCleanupWarning('已跳过异常的补丁清单项: ' + RelativePath) - else - begin - StringChangeEx(RelativePath, '/', '\', True); - FullPath := AddBackslash(AppDir) + RelativePath; - if not IsPathUnderBase(AppDir, FullPath) then - AppendPatchCleanupWarning('已跳过目录外路径: ' + RelativePath) - else if FileExists(FullPath) then - begin - if DeleteFile(FullPath) then - RemoveEmptyParentDirs(ExtractFileDir(FullPath), AppDir) - else - AppendPatchCleanupWarning('无法删除 CUDA 补丁文件: ' + RelativePath); - end; - end; - end; - end; - end - else - AppendPatchCleanupWarning('无法读取 CUDA 补丁清单,部分补丁文件可能需要手动删除。'); - - if DeleteFile(ManifestPath) then - RemoveEmptyParentDirs(ExtractFileDir(ManifestPath), AppDir) - else if FileExists(ManifestPath) then - AppendPatchCleanupWarning('无法删除 CUDA 补丁清单文件。'); - end - else if Trim(PatchInstalledFlag) = '1' then - AppendPatchCleanupWarning('未找到 CUDA 补丁清单,部分补丁文件可能需要手动清理。'); - - ClearPatchRegistryValues; -end; - -function InitializeSetup(): Boolean; -begin - LoadPreviousInstallState; - Result := True; -end; - -procedure InitializeWizard; -begin - if PreviousInstallDir <> '' then - WizardForm.DirEdit.Text := PreviousInstallDir; -end; - -function PrepareToInstall(var NeedsRestart: Boolean): String; -begin - Result := ''; - if PreviousUninstallString = '' then - exit; - - WizardForm.StatusLabel.Caption := '正在卸载旧版本,请稍候...'; - - if not RunPreviousUninstaller() then - Result := '无法自动卸载旧版本,请关闭程序后手动卸载再重试。'; -end; +[Languages] +Name: "chinesesimplified"; MessagesFile: "ChineseSimplified.isl" +Name: "english"; MessagesFile: "compiler:Default.isl" -procedure CurUninstallStepChanged(CurUninstallStep: TUninstallStep); -begin - if CurUninstallStep = usUninstall then - CleanupCudaPatchArtifacts - else if (CurUninstallStep = usPostUninstall) and (PatchCleanupWarnings <> '') then - SuppressibleMsgBox( - 'SuperPicky 卸载时已尽力清理 CUDA 补丁文件,但仍有部分内容可能需要手动删除:' + #13#10#13#10 + PatchCleanupWarnings, - mbInformation, - MB_OK, - IDOK - ); -end; +[Tasks] +Name: "desktopicon"; Description: "{cm:CreateDesktopIcon}"; GroupDescription: "{cm:AdditionalIcons}"; Flags: checkablealone [Files] -Source: "SuperPicky.exe"; DestDir: "{app}"; Flags: ignoreversion +Source: "{#MyAppExeName}"; DestDir: "{app}"; Flags: ignoreversion Source: "_internal\*"; DestDir: "{app}\_internal"; Flags: ignoreversion recursesubdirs createallsubdirs [Icons] -Name: "{group}\SuperPicky"; Filename: "{app}\SuperPicky.exe" -Name: "{commondesktop}\SuperPicky"; Filename: "{app}\SuperPicky.exe"; Tasks: desktopicon - -[Tasks] -Name: "desktopicon"; Description: "{cm:CreateDesktopIcon}"; GroupDescription: "{cm:AdditionalIcons}"; Flags: checkablealone +Name: "{autoprograms}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}" +Name: "{autodesktop}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}"; Tasks: desktopicon [Run] -Filename: "{app}\SuperPicky.exe"; Description: "{cm:LaunchProgram,SuperPicky}"; Flags: nowait postinstall skipifsilent +Filename: "{app}\{#MyAppExeName}"; Description: "{cm:LaunchProgram,{#StringChange(MyAppName, '&', '&&')}}"; Flags: nowait postinstall skipifsilent Filename: "https://superpicky.app/"; Description: "访问项目网站"; Flags: postinstall skipifsilent shellexec -[Languages] -Name: "chinesesimplified"; MessagesFile: "ChineseSimplified.isl" -Name: "english"; MessagesFile: "compiler:Default.isl" +[UninstallDelete] +Type: filesandordirs; Name: "{app}\_internal" diff --git a/inno/SuperPicky_CUDA_Patch.iss b/inno/SuperPicky_CUDA_Patch.iss index d67e9b7..8f2c539 100644 --- a/inno/SuperPicky_CUDA_Patch.iss +++ b/inno/SuperPicky_CUDA_Patch.iss @@ -1,11 +1,26 @@ +; SuperPicky CUDA 补丁安装脚本 +; SuperPicky CUDA Patch installer script +; Non-commercial use only + +#define MyAppName "SuperPicky" +#define MyAppVersion "unknown" +#define MyAppPublisher "JamesPhotography" +#define MyAppURL "superpicky.app" +#define MyAppExeName "SuperPicky.exe" +#define MyAppCommitHash "unknown" +#define OutputBaseFilename "SuperPicky_CUDA_Patch_Win64_{#MyAppVersion}_{#MyAppCommitHash}" + [Setup] AppId=SuperPicky.CUDAPatch -AppName=SuperPicky CUDA Patch -AppVersion=4.2.0-113b079 +AppName={#MyAppName} CUDA Patch +AppVersion={#MyAppVersion} +AppVerName={#MyAppName} CUDA Patch {#MyAppVersion} +AppPublisher={#MyAppPublisher} +AppPublisherURL={#MyAppURL} DefaultDirName={autopf}\SuperPicky DefaultGroupName=SuperPicky -AppPublisherURL=https://superpicky.app/ -OutputBaseFilename=SuperPicky_CUDA_Patch_Win64_4.2.0_113b079 +OutputDir=output +OutputBaseFilename={#OutputBaseFilename} Compression=lzma2/ultra64 LZMAUseSeparateProcess=yes LZMADictionarySize=1048576 @@ -15,16 +30,13 @@ CreateAppDir=yes Uninstallable=no SetupIconFile=img\icon.ico WizardStyle=modern +WizardImageFile=img\icon.png +WizardSmallImageFile=img\icon.png DisableProgramGroupPage=yes DisableDirPage=no DisableReadyPage=no DisableFinishedPage=no DirExistsWarning=no -VersionInfoCompany=https://superpicky.app/ -WizardImageFile=img\icon.png -WizardSmallImageFile=img\icon.png -AlwaysShowComponentsList=no -AlwaysShowGroupOnReadyPage=no ArchitecturesAllowed=x64compatible ArchitecturesInstallIn64BitMode=x64compatible PrivilegesRequired=admin @@ -151,11 +163,11 @@ begin end; [Files] -Source: "SuperPicky.exe"; DestDir: "{app}"; Flags: ignoreversion +Source: "{#MyAppExeName}"; DestDir: "{app}"; Flags: ignoreversion Source: "_internal\*"; DestDir: "{app}\_internal"; Flags: ignoreversion recursesubdirs createallsubdirs [Run] -Filename: "{app}\SuperPicky.exe"; Description: "{cm:LaunchProgram,SuperPicky}"; Flags: nowait postinstall skipifsilent +Filename: "{app}\{#MyAppExeName}"; Description: "{cm:LaunchProgram,{#StringChange(MyAppName, '&', '&&')}}"; Flags: nowait postinstall skipifsilent [Languages] Name: "chinesesimplified"; MessagesFile: "ChineseSimplified.isl" diff --git a/locales/en_US.json b/locales/en_US.json index 4060f4d..7d9adf9 100644 --- a/locales/en_US.json +++ b/locales/en_US.json @@ -25,7 +25,8 @@ "language": "Language", "lang_zh": "简体中文", "lang_en": "English", - "check_update": "Online Update...", + "check_update": "Auto Update...", + "environment_repair": "Environment Repair...", "background_mode": "Run in Background", "background_mode_title": "Background Mode", "background_mode_msg": "App will enter background mode\n\n• Bird ID service continues running\n• Lightroom plugin works normally\n• Reopen app to restore interface\n\nNote: Server uses ~250MB memory", @@ -103,6 +104,7 @@ "note_burst": "📸 Smart Burst: Merged to highest-rated photo's folder (Min 4 photos)" }, "messages": { + "initialization_required": "Initialization is not finished yet. Complete first-run resource setup before processing.", "select_dir_first": "Please select a source folder first", "processing": "Culling in progress, please wait...", "reset_confirm": "⚠️ Reset will clear all EXIF tags and temporary files. Continue?", @@ -438,8 +440,10 @@ "ram_low": "⚠️ Low RAM: {free} GB — model loading may be slow", "exiftool_error_title": "ExifTool Unavailable", "exiftool_error_msg": "ExifTool could not be started. Star ratings will not be written to EXIF metadata.\n\nError: {error}\n\nPlease reinstall the app or check for file corruption.", + "dangerous_dir_title": "Protected Directory", + "dangerous_dir_msg": "For safety, this directory cannot be scanned:\n{directory}\n\nReason: {reason}\n\nPlease choose a specific photo directory instead of a system directory, settings directory, or drive root.", "no_photos_title": "No Photos Found", - "no_photos_msg": "No supported photo files (JPG / RAW / HEIF) found in:\n{directory}", + "no_photos_msg": "No supported photo files (JPG / RAW / HEIF) were found in this directory or its subdirectories:\n{directory}", "models_still_loading": "⚠️ Models are still loading in background — first photo may process slowly" }, "stats": { @@ -842,6 +846,16 @@ "update_center_btn_clear_patch": "Clear Patch", "update_center_checking": "Checking..." }, + "repair": { + "window_title": "Environment Repair", + "summary": "Check and repair the private Python runtime, Torch, models, and database resources required by the current installation. This does not rerun first-launch preferences.", + "start": "Checking current environment...", + "running": "Repairing environment...", + "success": "Environment repair completed", + "failed": "Environment repair failed", + "retry": "Retry", + "log_retry": "Starting environment repair..." + }, "about": { "window_title": "About James", "subtitle": "AI Bird Photo Culling Tool", @@ -1021,6 +1035,74 @@ "custom": "Custom", "current_label": "Current: {level}" }, + "onboarding": { + "window_title": "Welcome to SuperPicky", + "previous": "Previous", + "next": "Next", + "finish": "Finish", + "start_initialization": "Start Initialization", + "welcome_badge": "WELCOME", + "welcome_title": "Welcome to SuperPicky", + "welcome_subtitle": "Finish setup in three quick steps and start culling.", + "welcome_feature_title": "A smoother default workflow for bird photographers", + "welcome_feature_desc": "We'll set your update preference and photography level now, and you can change both later in Settings.", + "lite_welcome_title": "Welcome to SuperPicky", + "lite_welcome_subtitle": "The lightweight build will prepare its runtime, models, and databases on first launch.", + "lite_welcome_hint": "You only need to confirm your preferences. The rest of the setup will run automatically.", + "update_title": "Enable Auto Update", + "update_subtitle": "If you turn this off, initialization will also skip automatic patch and component update checks.", + "update_enabled_title": "Allow auto update", + "update_enabled_desc": "Initialization can check patches now, and later startup can check automatically too", + "update_disabled_title": "Do not auto update", + "update_disabled_desc": "Neither initialization nor normal startup will check automatically", + "update_hint": "You can change this later in the Update Center.", + "skill_title": "Choose Your Photography Level", + "skill_subtitle": "This applies a recommended threshold preset first. You can still adjust it later in Settings.", + "skill_hint": "For the first run, the default recommendation is usually the safest choice.", + "features_title": "AI Models and Resources", + "features_subtitle": "Initialization will automatically install the following AI models and database resources. No manual selection is required.", + "features_hint": "Windows Lite keeps its runtime under the app installation directory. macOS Lite uses the bundled runtime.", + "feature_core_detection_label": "Core culling - Install the main detection model and base classification capability", + "feature_quality_label": "Aesthetic scoring - Download the TOPIQ quality model", + "feature_keypoint_label": "Keypoint detection - Download the keypoint model for finer detail checks", + "feature_flight_label": "Flight detection - Download the flight recognition model", + "feature_birdid_label": "Bird identification - Download Bird ID databases and related resources", + "model_core_detection_title": "Core Culling Model", + "model_core_detection_desc": "Installs the main detection model and base classification capability used to find birds and run the core culling flow.", + "model_quality_title": "Aesthetic Scoring Model", + "model_quality_desc": "Installs the TOPIQ quality model used to assess the overall visual quality and aesthetic appeal of photos.", + "model_keypoint_title": "Keypoint Detection Model", + "model_keypoint_desc": "Installs the keypoint model used to analyze bird details and important body landmarks for finer judgments.", + "model_flight_title": "Flight Detection Model", + "model_flight_desc": "Installs the flight recognition model used to detect flying poses and support flight-specific culling behavior.", + "model_birdid_title": "Bird Identification Resources", + "model_birdid_desc": "Installs Bird ID databases and related resources used for species identification and metadata assistance later.", + "runtime_status_title": "Runtime Check Status", + "runtime_status_hint": "If the runtime check does not pass, initialization will repair the missing runtime automatically.", + "runtime_status_path": "Runtime directory: {path}", + "runtime_status_item_ready": "✅ Runtime check passed", + "runtime_status_item_pending": "✅ Runtime will be completed during initialization", + "runtime_status_item_variant": "✅ Target runtime: {variant}", + "runtime_status_item_source": "✅ Detection result: {detail}", + "runtime_status_result_ready": "Complete", + "runtime_status_result_pending": "Pending repair", + "runtime_status_policy_windows": "✅ Windows Lite installs its runtime into _internal/runtime_env under the app installation directory", + "runtime_status_policy_mac": "✅ macOS Lite uses the runtime bundled with the app and does not download it separately", + "runtime_hint_cpu": "CPU runtime will be installed automatically.", + "runtime_hint_cuda": "CUDA support was detected. Initialization will prefer the CUDA runtime.", + "runtime_hint_mac": "This is macOS. The Lite build will use the bundled runtime directly.", + "runtime_check_passed": "Runtime check passed.", + "initialization_title": "Initializing", + "initialization_waiting": "Waiting to start...", + "initialization_failed": "Initialization failed", + "initialization_interrupted": "Initialization was interrupted. The next launch will try to repair the environment automatically.", + "close_confirm_title": "Initialization Is Not Finished", + "close_confirm_message": "Exiting now will interrupt initialization. On the next launch, SuperPicky will clean up leftovers and try to repair the environment automatically.", + "close_confirm_exit": "Exit and Repair Later", + "close_confirm_continue": "Continue Initialization", + "log_start": "Starting initialization...", + "log_retry": "Retrying initialization..." + }, "browser": { "title": "Selection Results Browser", "no_db": "Report database not found", diff --git a/locales/zh_CN.json b/locales/zh_CN.json index 9fdcbf8..fcb313d 100644 --- a/locales/zh_CN.json +++ b/locales/zh_CN.json @@ -25,7 +25,8 @@ "language": "界面语言", "lang_zh": "简体中文", "lang_en": "English", - "check_update": "在线更新...", + "check_update": "自动更新...", + "environment_repair": "环境修复...", "background_mode": "后台运行 (保持识鸟服务)", "background_mode_title": "后台模式", "background_mode_msg": "应用将进入后台模式\n\n• 识鸟服务继续在后台运行\n• Lightroom 插件可以正常使用\n• 再次打开应用可恢复界面\n\n提示:服务器内存占用约 250MB", @@ -103,6 +104,7 @@ "note_burst": "📸 连拍策略:自动合并至最高分照片所在目录(少于4张不分组)" }, "messages": { + "initialization_required": "初始化尚未完成。请先完成首次资源准备后再开始处理。", "select_dir_first": "请先选择照片目录", "processing": "正在处理中,请稍候...", "reset_confirm": "⚠️ 重置将清除所有EXIF标记和临时文件,是否继续?", @@ -438,8 +440,10 @@ "ram_low": "⚠️ 可用内存仅 {free} GB,模型加载可能较慢", "exiftool_error_title": "ExifTool 不可用", "exiftool_error_msg": "无法启动 ExifTool,照片评分将无法写入 EXIF 元数据。\n\n错误信息:{error}\n\n请重新安装应用或检查软件完整性。", + "dangerous_dir_title": "目录受保护", + "dangerous_dir_msg": "出于安全考虑,不能扫描这个目录:\n{directory}\n\n原因:{reason}\n\n请选择更具体的照片目录,而不是系统目录、设置目录或磁盘根目录。", "no_photos_title": "未找到照片", - "no_photos_msg": "目录中没有支持的照片文件(JPG / RAW / HEIF):\n{directory}", + "no_photos_msg": "目录及其子目录中没有支持的照片文件(JPG / RAW / HEIF):\n{directory}", "models_still_loading": "⚠️ 模型仍在后台预加载,处理时首张照片可能稍慢" }, "stats": { @@ -816,6 +820,16 @@ "update_center_btn_clear_patch": "清除补丁", "update_center_checking": "检查中..." }, + "repair": { + "window_title": "环境修复", + "summary": "将检查并补齐当前环境所需的 Python 运行时、Torch、模型和数据库资源。该过程不会重跑首次启动偏好设置。", + "start": "准备检查当前环境...", + "running": "环境修复中...", + "success": "环境修复完成", + "failed": "环境修复失败", + "retry": "重试", + "log_retry": "开始环境修复..." + }, "about": { "window_title": "关于", "subtitle": "AI 智能选片助手", @@ -1020,6 +1034,74 @@ "custom": "自选", "current_label": "当前: {level}" }, + "onboarding": { + "window_title": "欢迎使用慧眼选鸟", + "previous": "上一页", + "next": "下一页", + "finish": "完成", + "start_initialization": "开始初始化", + "welcome_badge": "WELCOME", + "welcome_title": "欢迎使用“慧眼选鸟”", + "welcome_subtitle": "与君初相识,犹如故人归。请坐和放宽,好东西就要来了。", + "welcome_feature_title": "为鸟类摄影准备的默认流程", + "welcome_feature_desc": "我们先帮您设好更新偏好和摄影等级,后续也能随时在设置中修改。", + "lite_welcome_title": "欢迎使用 SuperPicky", + "lite_welcome_subtitle": "与君初相识,犹如故人归。请坐和放宽,好东西就要来了。", + "lite_welcome_hint": "轻量版会在首次启动时帮你完成运行时、模型和数据库准备。你只需要确认偏好,其余步骤会自动完成。", + "update_title": "是否启用自动更新", + "update_subtitle": "如果关闭自动更新,初始化阶段也不会自动检查补丁或组件更新。", + "update_enabled_title": "允许自动更新", + "update_enabled_desc": "初始化阶段可检查补丁,日后也可启动时自动检查", + "update_disabled_title": "不自动更新", + "update_disabled_desc": "初始化和日常启动都不会自动检查更新", + "update_hint": "这个选项之后可以在“在线更新”里随时更改。", + "skill_title": "选择摄影等级", + "skill_subtitle": "这会先设置一套推荐阈值,后续依然可以在设置中调整。", + "skill_hint": "建议第一次先用默认推荐值。", + "features_title": "AI 模型与资源", + "features_subtitle": "初始化会自动安装以下 AI 模型与数据库组件,无需手动选择。", + "features_hint": "Windows Lite 会将运行时固定补全到软件安装目录;macOS Lite 使用随包提供的运行时。", + "feature_core_detection_label": "基础筛选 - 安装主检测模型和基础分类能力", + "feature_quality_label": "美学评分 - 下载 TOPIQ 质量评分模型", + "feature_keypoint_label": "关键点检测 - 下载关键点模型,提升细节判断", + "feature_flight_label": "飞鸟检测 - 下载飞鸟识别模型", + "feature_birdid_label": "鸟种识别 - 下载识鸟数据库和相关资源", + "model_core_detection_title": "核心筛选模型", + "model_core_detection_desc": "安装主检测模型和基础分类能力,用于识别画面中的鸟类主体并完成基础筛选。", + "model_quality_title": "美学评分模型", + "model_quality_desc": "安装 TOPIQ 质量评分模型,用于评估照片的整体观感和美学质量。", + "model_keypoint_title": "关键点检测模型", + "model_keypoint_desc": "安装关键点模型,用于分析鸟体细节和关键部位,提升精细判断能力。", + "model_flight_title": "飞鸟检测模型", + "model_flight_desc": "安装飞鸟识别模型,用于识别飞行姿态并支持飞鸟场景的筛选逻辑。", + "model_birdid_title": "鸟种识别资源", + "model_birdid_desc": "安装识鸟数据库和相关资源,用于后续鸟种识别和元数据辅助能力。", + "runtime_status_title": "运行时检测状态", + "runtime_status_hint": "如果检测未通过,初始化会自动补全缺失的运行时环境。", + "runtime_status_path": "运行时目录:{path}", + "runtime_status_item_ready": "✅ 运行时检测通过", + "runtime_status_item_pending": "✅ 运行时将在初始化过程中自动补全", + "runtime_status_item_variant": "✅ 当前目标运行时:{variant}", + "runtime_status_item_source": "✅ 检测结论:{detail}", + "runtime_status_result_ready": "完整", + "runtime_status_result_pending": "待补全", + "runtime_status_policy_windows": "✅ Windows Lite 会把运行时固定安装到软件安装目录下的 _internal/runtime_env", + "runtime_status_policy_mac": "✅ macOS Lite 使用随应用提供的运行时,不再单独下载安装", + "runtime_hint_cpu": "将自动安装 CPU 运行时。", + "runtime_hint_cuda": "检测到可用 CUDA,初始化会优先安装 CUDA 运行时。", + "runtime_hint_mac": "当前系统为 macOS,Lite 版本会直接使用随包提供的运行时。", + "runtime_check_passed": "运行时检测已通过。", + "initialization_title": "初始化中", + "initialization_waiting": "等待开始...", + "initialization_failed": "初始化失败", + "initialization_interrupted": "初始化已中断,下次启动会自动尝试修复环境。", + "close_confirm_title": "初始化尚未完成", + "close_confirm_message": "现在退出会中断当前初始化。软件下次启动时会自动清理残留并尝试修复环境。", + "close_confirm_exit": "退出并稍后修复", + "close_confirm_continue": "继续初始化", + "log_start": "开始初始化...", + "log_retry": "重试初始化..." + }, "browser": { "title": "选鸟结果浏览器", "no_db": "未找到报告数据库", diff --git a/main.py b/main.py index 9399959..ed83cf7 100644 --- a/main.py +++ b/main.py @@ -1,52 +1,115 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- """ -SuperPicky - PySide6 版本入口点 -Version: 4.0.6 - Country Selection Simplification +SuperPicky - PySide6 版本入口点。 +SuperPicky - PySide6 application entrypoint. + +本模块负责最早期的进程初始化、补丁覆盖层注入、运行时自举分流与 Qt 应用启动。 +This module owns early process initialization, patch overlay injection, runtime +bootstrap dispatch, and Qt application startup. """ import sys import os - -# V3.9.3: 修复 macOS PyInstaller 打包后的多进程问题 -# 必须在所有其他导入之前设置 import multiprocessing -if sys.platform == 'darwin': - multiprocessing.set_start_method('spawn', force=True) -# V3.9.4: 防止 PyInstaller 打包后 spawn 模式创建重复进程/窗口 -# 这是 macOS PyInstaller 的标准做法 +from config import ( + get_runtime_app_root, + get_runtime_meipass, + migrate_legacy_ioc_settings, + migrate_old_data, + set_runtime_app_root, +) + +# macOS 的 PyInstaller GUI 进程必须在其他重量级导入前强制使用 `spawn`。 +# macOS PyInstaller GUI processes must force `spawn` before any heavy imports. +if sys.platform == "darwin": + multiprocessing.set_start_method("spawn", force=True) + +# 冻结环境下提前启用 `freeze_support()`,避免子进程重复拉起完整 GUI。 +# Enable `freeze_support()` early so frozen subprocesses do not re-launch the GUI. multiprocessing.freeze_support() -# 确保模块路径正确 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) -# 在线补丁层:优先加载用户数据目录下的 code_updates/(覆盖内置模块) + +def _should_enable_patch_overlay() -> bool: + """ + 仅允许打包环境启用在线补丁覆盖层。 + Only allow the online patch overlay in packaged environments. + """ + return bool(getattr(sys, "frozen", False)) + + def _inject_patch_path(): + """ + 注入在线补丁目录并记录真实应用根目录。 + Inject the online patch directory and record the real application root. + + 补丁覆盖层会把用户配置目录下的 `code_updates/` 放到 `sys.path` 最前面, + 同时保存真实应用根目录,供被覆盖模块继续定位模型、图标和 exiftool。 + The patch overlay prepends `code_updates/` to `sys.path` and stores the real + app root so overridden modules can still resolve models, icons, and exiftool. + """ if sys.platform == "darwin": - _patch_dir = os.path.join(os.path.expanduser("~"), "Library", "Application Support", "SuperPicky", "code_updates") + _patch_dir = os.path.join( + os.path.expanduser("~"), + "Library", + "Application Support", + "SuperPicky", + "code_updates", + ) elif sys.platform == "win32": - _patch_dir = os.path.join(os.path.expanduser("~"), "AppData", "Local", "SuperPicky", "code_updates") + _patch_dir = os.path.join( + os.path.expanduser("~"), "AppData", "Local", "SuperPicky", "code_updates" + ) else: - _patch_dir = os.path.join(os.path.expanduser("~"), ".config", "SuperPicky", "code_updates") - if os.path.isdir(_patch_dir) and _patch_dir not in sys.path: + _patch_dir = os.path.join( + os.path.expanduser("~"), ".config", "SuperPicky", "code_updates" + ) + if _should_enable_patch_overlay() and os.path.isdir(_patch_dir) and _patch_dir not in sys.path: sys.path.insert(0, _patch_dir) - # 记录真实 app 根目录,供补丁中的模块查找资源文件(模型、exiftool 等) - if not hasattr(sys, '_SUPERPICKY_APP_ROOT'): - if hasattr(sys, '_MEIPASS'): - sys._SUPERPICKY_APP_ROOT = sys._MEIPASS + if get_runtime_app_root() is None: + if getattr(sys, "frozen", False) and sys.platform == "win32": + set_runtime_app_root(os.path.dirname(os.path.abspath(sys.executable))) else: - sys._SUPERPICKY_APP_ROOT = os.path.dirname(os.path.abspath(__file__)) + meipass = get_runtime_meipass() + if meipass is not None: + set_runtime_app_root(meipass) + else: + set_runtime_app_root(os.path.dirname(os.path.abspath(__file__))) + + _inject_patch_path() -# Fix Windows console encoding: default cp1252 cannot render emoji/CJK characters, -# causing UnicodeEncodeError crashes on print(). Reconfigure to UTF-8 with replacement -# fallback so all log output survives regardless of the console codepage. + +def _run_runtime_bootstrap_if_requested(): + """ + 在请求时进入运行时自举流程并立即退出当前主入口。 + Enter the runtime bootstrap flow when requested and exit this main entrypoint. + """ + if "--runtime-bootstrap" not in sys.argv[1:]: + return + from core.runtime_bootstrap import run_runtime_bootstrap + + raise SystemExit(run_runtime_bootstrap(sys.argv[1:])) + + +_run_runtime_bootstrap_if_requested() + if sys.platform == "win32": import io def _ensure_utf8_stream(stream): - # PyInstaller windowed mode (`console=False`) may set stdout/stderr to None. + """ + 为 Windows 控制台流兜底成 UTF-8 文本输出。 + Ensure a Windows console stream falls back to UTF-8 text output. + + PyInstaller 的无控制台模式可能把 `stdout/stderr` 设为 `None`, + 而普通控制台也可能仍是非 UTF-8 编码,这里统一兜底避免日志写崩。 + PyInstaller windowed mode may set `stdout/stderr` to `None`, and regular + consoles may still use a non-UTF-8 code page, so normalize both cases here. + """ if stream is None: return open(os.devnull, "w", encoding="utf-8", errors="replace") @@ -77,99 +140,99 @@ def _ensure_utf8_stream(stream): from ui.styles import APP_TOOLTIP_STYLE from tools.system_logger import setup_error_logging -# 尽早捕获未处理异常,写入 superpicky.log(或 config dir fallback) +# 尽早接管未处理异常,确保源码和冻结包都能留下可诊断日志。 +# Install logging early so both source runs and frozen builds preserve diagnostics. setup_error_logging() -# 内存监视器(开发调试用):设置环境变量 SP_MEMORY_MONITOR=1 启用 -# 例:SP_MEMORY_MONITOR=1 python main.py -# 日志写入 /memory_monitor.log +# 启动阶段先完成遗留数据迁移,避免后续模块读到旧路径状态。 +# Finish legacy data migration before later modules observe stale paths. +migrate_old_data() +migrate_legacy_ioc_settings() + _memory_monitor = None if os.environ.get("SP_MEMORY_MONITOR") == "1": from tools.memory_monitor import MemoryMonitor + _memory_monitor = MemoryMonitor(interval=30) -# V3.9.3: 全局窗口引用,防止重复创建 _main_window = None def main(): - """主函数""" + """ + 启动 Qt 应用并创建主窗口。 + Start the Qt application and create the main window. + """ global _main_window - # Fix: macOS GUI launch (double-click / Dock) sets CWD to read-only '/'. - # YOLO attempts to create a 'runs/' dir relative to CWD, which fails with - # [Errno 30] Read-only file system on Intel Macs (CPU inference path). - # Switch to the user home dir so any YOLO cache writes succeed. - if sys.platform == 'darwin': - safe_cwd = os.path.expanduser('~') + # macOS 双击启动 GUI 时 cwd 可能是只读根目录 `/`,需要切回用户目录。 + # macOS GUI launches may start from the read-only `/`, so switch to the home dir. + if sys.platform == "darwin": + safe_cwd = os.path.expanduser("~") os.chdir(safe_cwd) - - # V3.9.3: 检查是否已有 QApplication 实例 + app = QApplication.instance() if app is None: app = QApplication(sys.argv) - else: - print("⚠️ 检测到已存在的 QApplication 实例") - - # 设置应用属性 - # V4.0.5: 动态设置应用名称 + elif not isinstance(app, QApplication): + raise RuntimeError("检测到非 QApplication 的 Qt 应用实例,无法继续启动 GUI。") + from constants import APP_VERSION from core.build_info import COMMIT_HASH - + commit_hash = COMMIT_HASH - if commit_hash == "154984fd": # 默认占位符 - try: - import subprocess - hash_short = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD']).strip().decode('utf-8') - commit_hash = hash_short - except: - pass + if commit_hash == "154984fd": + try: + import subprocess + + hash_short = ( + subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]) + .strip() + .decode("utf-8") + ) + commit_hash = hash_short + except: + pass app.setApplicationName("SuperPicky") app.setApplicationDisplayName(f"慧眼选鸟v{APP_VERSION} ({commit_hash})") app.setOrganizationName("JamesPhotography") app.setOrganizationDomain("jamesphotography.com.au") - # 防止隐藏主窗口(切到结果浏览器时)触发 Qt 自动退出 - # 退出由托盘菜单"退出"或 _quit_app() 显式控制,统一走 aboutToQuit 清理 + # 主窗口会在托盘与子窗口之间显隐切换,不能依赖“最后一个窗口关闭即退出”。 + # The main window may hide while tray or child windows remain active, so do not + # couple process lifetime to the last visible top-level window. app.setQuitOnLastWindowClosed(False) - - # 设置应用图标 + icon_path = os.path.join(os.path.dirname(__file__), "img", "icon.png") if os.path.exists(icon_path): app.setWindowIcon(QIcon(icon_path)) - - # V4.1: Windows 高 DPI 缩放策略 - # Qt6/PySide6 已默认启用 HiDPI,但 RoundingPolicy 默认为 RoundPreferFloor, - # 在 Windows 125%/150% 等非整数缩放下会导致文字/边框轻微模糊。 - # PassThrough 允许使用精确的小数缩放因子,避免像素取整问题。 + + # Windows 非整数 DPI 缩放下使用 PassThrough,避免字体和边框被提前取整。 + # Use PassThrough on fractional Windows DPI scales to avoid premature rounding. if sys.platform == "win32": from PySide6.QtCore import Qt + app.setHighDpiScaleFactorRoundingPolicy( Qt.HighDpiScaleFactorRoundingPolicy.PassThrough ) - - # V4.1: 在 QApplication 级别设置 QToolTip 样式 - # macOS 上 QToolTip 是顶层窗口,不继承 QMainWindow 的样式, - # 在系统浅色模式下会被系统接管为毛玻璃浅色背景 → 文字不可见 + + # QToolTip 属于顶层窗口,需要在 QApplication 级别统一覆盖样式。 + # QToolTip is a top-level window, so its style must be applied at QApplication level. app.setStyleSheet(APP_TOOLTIP_STYLE) - - # V3.9.3: 防止重复创建窗口 + if _main_window is None: _main_window = SuperPickyMainWindow() _main_window.show() bootstrap_telemetry(_main_window, on_ready=_main_window.run_startup_prompts) - # 统一退出清理:无论通过 X / 托盘 / Cmd+Q 退出,都会经由 aboutToQuit 信号 app.aboutToQuit.connect(_main_window._cleanup_on_quit) if _memory_monitor is not None: _memory_monitor.start() app.aboutToQuit.connect(_memory_monitor.stop) else: - print("⚠️ 检测到已存在的主窗口实例") _main_window.raise_() _main_window.activateWindow() - - # 运行事件循环 + sys.exit(app.exec()) diff --git a/requirements_cuda.txt b/requirements_cuda.txt index 1f8b08f..6872cdf 100644 --- a/requirements_cuda.txt +++ b/requirements_cuda.txt @@ -4,6 +4,7 @@ -r requirements_base.txt # PyTorch CUDA 11.8 wheels +--extra-index-url https://mirror.nju.edu.cn/pytorch/whl/cu118/ --extra-index-url https://download.pytorch.org/whl/cu118 torch==2.7.1+cu118 torchvision==0.22.1+cu118 diff --git a/scripts/ci_release.py b/scripts/ci_release.py new file mode 100644 index 0000000..3962dcb --- /dev/null +++ b/scripts/ci_release.py @@ -0,0 +1,358 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +CI 发布辅助脚本 / CI release helper utilities. + +集中处理 GitHub Actions 中的发布元数据、资产整理、补丁打包和临时文件操作, +避免在 workflow 中堆积大量平台相关 shell 逻辑。 + +This module centralizes release metadata handling, asset collection, patch +packaging, and temporary file operations for GitHub Actions workflows so the +workflow can stay small and shell-agnostic. +""" + +from __future__ import annotations + +import argparse +import base64 +import glob +import json +import os +import shutil +import sys +import zipfile +from datetime import datetime, timezone +from pathlib import Path +from typing import Iterable + + +ROOT_DIR = Path(__file__).resolve().parent.parent +PATCH_ITEMS = ( + "constants.py", + "advanced_config.py", + "ai_model.py", + "birdid_server.py", + "birdid_cli.py", + "iqa_scorer.py", + "post_adjustment_engine.py", + "server_manager.py", + "superpicky_cli.py", + "topiq_model.py", + "tools", + "core", + "ui", + "birdid", + "locales", +) +PATCH_EXCLUDED_DIRS = {"__pycache__"} +PATCH_EXCLUDED_SUFFIXES = {".pyc", ".pyo"} + + +def configure_stdio() -> None: + """ + 强制标准输出为 UTF-8 / Force UTF-8 stdio when possible. + """ + + for stream in (sys.stdout, sys.stderr): + reconfigure = getattr(stream, "reconfigure", None) + if callable(reconfigure): + reconfigure(encoding="utf-8", errors="strict") + + +def optional_text(value: str | None) -> str | None: + """ + 规范化可选字符串 / Normalize optional strings. + """ + + if value is None: + return None + normalized = value.strip() + return normalized or None + + +def repo_path(raw_path: str | Path) -> Path: + """ + 将相对仓库路径解析为绝对路径 / Resolve repository-relative paths. + """ + + path = Path(raw_path) + if path.is_absolute(): + return path + return ROOT_DIR / path + + +def write_github_outputs(values: dict[str, str], output_path: str | None = None) -> None: + """ + 写入 GitHub Actions step outputs / Write GitHub Actions step outputs. + """ + + target = optional_text(output_path) or optional_text(os.environ.get("GITHUB_OUTPUT")) + if not target: + return + + path = Path(target) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("a", encoding="utf-8") as handle: + for key, value in values.items(): + handle.write(f"{key}={value}\n") + + +def infer_release_tag(event_name: str | None, input_version: str | None, ref_name: str | None) -> str: + """ + 计算 release tag / Resolve the release tag. + """ + + raw_version = input_version if optional_text(event_name) == "workflow_dispatch" else ref_name + normalized = optional_text(raw_version) + if not normalized: + raise RuntimeError("Release version is required.") + return normalized if normalized.startswith("v") else f"v{normalized}" + + +def cmd_resolve_metadata(args: argparse.Namespace) -> int: + """ + 解析 release 元数据 / Resolve release metadata. + """ + + tag = infer_release_tag(args.event_name, args.input_version, args.ref_name) + values = {"tag": tag, "name": f"SuperPicky {tag}"} + write_github_outputs(values, args.github_output) + print(json.dumps(values, ensure_ascii=False)) + return 0 + + +def ensure_single_match(pattern: str) -> Path: + """ + 确保 glob 模式只匹配一个文件 / Ensure a glob matches exactly one file. + """ + + matches = [Path(item) for item in glob.glob(pattern) if Path(item).is_file()] + if len(matches) != 1: + raise RuntimeError(f"Expected exactly one asset for pattern '{pattern}', found {len(matches)}.") + return matches[0] + + +def cmd_collect_assets(args: argparse.Namespace) -> int: + """ + 收集 release 资产 / Collect release assets. + """ + + output_dir = repo_path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + copied_files: list[str] = [] + for pattern in args.pattern: + source_file = ensure_single_match(str(repo_path(pattern))) + destination = output_dir / source_file.name + shutil.copy2(source_file, destination) + copied_files.append(destination.name) + + print(json.dumps({"output_dir": str(output_dir), "files": copied_files}, ensure_ascii=False)) + return 0 + + +def read_app_version() -> str: + """ + 读取应用版本号 / Read the application version. + """ + + constants_path = ROOT_DIR / "constants.py" + content = constants_path.read_text(encoding="utf-8") + marker = 'APP_VERSION = ' + for line in content.splitlines(): + if marker not in line: + continue + _, raw_value = line.split(marker, 1) + return raw_value.strip().strip('"').strip("'") + raise RuntimeError("Unable to read APP_VERSION from constants.py") + + +def infer_release_channel(tag: str) -> str: + """ + 根据 tag 判断渠道 / Infer release channel from tag. + """ + + return "nightly" if "-rc" in tag.lower() else "official" + + +def iter_patch_files() -> Iterable[tuple[Path, Path]]: + """ + 枚举补丁文件 / Enumerate patch files. + """ + + for item in PATCH_ITEMS: + source_path = ROOT_DIR / item + if not source_path.exists(): + continue + if source_path.is_file(): + if source_path.name == "main.py" or source_path.suffix in PATCH_EXCLUDED_SUFFIXES: + continue + yield source_path, source_path.relative_to(ROOT_DIR) + continue + + for file_path in sorted(source_path.rglob("*")): + if not file_path.is_file(): + continue + if any(part in PATCH_EXCLUDED_DIRS for part in file_path.parts): + continue + if file_path.suffix in PATCH_EXCLUDED_SUFFIXES: + continue + yield file_path, file_path.relative_to(ROOT_DIR) + + +def write_patch_zip(zip_path: Path) -> None: + """ + 创建补丁 ZIP / Create the patch ZIP. + """ + + zip_path.parent.mkdir(parents=True, exist_ok=True) + zip_path.unlink(missing_ok=True) + + with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=9) as archive: + for file_path, relative_path in iter_patch_files(): + archive.write(file_path, arcname=str(relative_path).replace("\\", "/")) + + +def write_patch_meta(meta_path: Path, patch_version: str, base_version: str, release_channel: str) -> None: + """ + 写入补丁元数据 / Write patch metadata. + """ + + payload = { + "patch_version": patch_version, + "base_version": base_version, + "release_channel": release_channel, + "target_channels": [release_channel], + "applied_at": datetime.now(timezone.utc).isoformat(), + } + meta_path.parent.mkdir(parents=True, exist_ok=True) + meta_path.write_text(json.dumps(payload, ensure_ascii=False, indent=2) + "\n", encoding="utf-8") + + +def cmd_build_patch(args: argparse.Namespace) -> int: + """ + 生成补丁 ZIP 和 patch_meta.json / Generate patch ZIP and patch_meta.json. + """ + + patch_version = infer_release_tag("workflow_dispatch", args.patch_version, args.patch_version) + release_channel = infer_release_channel(patch_version) + base_version = args.base_version or read_app_version() + output_dir = repo_path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + zip_path = output_dir / f"code_patch_{patch_version}.zip" + meta_path = output_dir / "patch_meta.json" + write_patch_zip(zip_path) + write_patch_meta(meta_path, patch_version, base_version, release_channel) + + values = {"patch_zip": str(zip_path), "patch_meta": str(meta_path)} + write_github_outputs(values, args.github_output) + print(json.dumps(values, ensure_ascii=False)) + return 0 + + +def decode_secret_value(env_name: str) -> str: + """ + 读取环境变量中的 secret / Read a secret value from the environment. + """ + + value = os.environ.get(env_name, "") + if not value: + raise RuntimeError(f"Environment variable {env_name} is required.") + return value + + +def cmd_materialize_secret_file(args: argparse.Namespace) -> int: + """ + 将 secret 落盘为文件 / Materialize a secret into a file. + """ + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + + raw_value = decode_secret_value(args.env_name) + data = base64.b64decode(raw_value) if args.decode_base64 else raw_value.encode("utf-8") + output_path.write_bytes(data) + + values = {"materialized_path": str(output_path)} + write_github_outputs(values, args.github_output) + print(json.dumps(values, ensure_ascii=False)) + return 0 + + +def remove_path(path: Path) -> None: + """ + 删除文件或目录 / Remove a file or directory. + """ + + if path.is_dir() and not path.is_symlink(): + shutil.rmtree(path, ignore_errors=True) + elif path.exists() or path.is_symlink(): + path.unlink(missing_ok=True) + + +def cmd_cleanup_paths(args: argparse.Namespace) -> int: + """ + 清理路径 / Clean up files or directories. + """ + + for raw_path in args.path: + remove_path(Path(raw_path)) + return 0 + + +def build_parser() -> argparse.ArgumentParser: + """ + 创建命令行解析器 / Build the command-line parser. + """ + + parser = argparse.ArgumentParser(description="SuperPicky CI 发布辅助脚本") + subparsers = parser.add_subparsers(dest="command", required=True) + + resolve_parser = subparsers.add_parser("resolve-metadata", help="解析 release tag 和 name") + resolve_parser.add_argument("--event-name", default=os.environ.get("GITHUB_EVENT_NAME")) + resolve_parser.add_argument("--input-version", default=os.environ.get("INPUT_VERSION")) + resolve_parser.add_argument("--ref-name", default=os.environ.get("GITHUB_REF_NAME")) + resolve_parser.add_argument("--github-output", help="可选,显式指定 GITHUB_OUTPUT 文件路径") + resolve_parser.set_defaults(func=cmd_resolve_metadata) + + collect_parser = subparsers.add_parser("collect-assets", help="按 glob 收集 release 资产") + collect_parser.add_argument("--output-dir", required=True, help="资产输出目录") + collect_parser.add_argument("--pattern", action="append", required=True, help="需要匹配的文件 glob,可重复指定") + collect_parser.set_defaults(func=cmd_collect_assets) + + patch_parser = subparsers.add_parser("build-patch", help="生成 code patch ZIP 与 patch_meta.json") + patch_parser.add_argument("--output-dir", required=True, help="补丁输出目录") + patch_parser.add_argument("--patch-version", required=True, help="补丁版本号,例如 v4.2.0 或 4.2.0") + patch_parser.add_argument("--base-version", help="可选,显式指定 base version") + patch_parser.add_argument("--github-output", help="可选,显式指定 GITHUB_OUTPUT 文件路径") + patch_parser.set_defaults(func=cmd_build_patch) + + secret_parser = subparsers.add_parser("materialize-secret-file", help="将环境变量写入文件") + secret_parser.add_argument("--env-name", required=True, help="secret 所在环境变量名") + secret_parser.add_argument("--output", required=True, help="输出文件路径") + secret_parser.add_argument("--decode-base64", action="store_true", help="按 Base64 解码后写入") + secret_parser.add_argument("--github-output", help="可选,显式指定 GITHUB_OUTPUT 文件路径") + secret_parser.set_defaults(func=cmd_materialize_secret_file) + + cleanup_parser = subparsers.add_parser("cleanup-paths", help="删除文件或目录") + cleanup_parser.add_argument("--path", action="append", required=True, help="待删除的路径,可重复指定") + cleanup_parser.set_defaults(func=cmd_cleanup_paths) + + return parser + + +def main() -> int: + """ + 脚本入口 / Script entrypoint. + """ + + configure_stdio() + parser = build_parser() + args = parser.parse_args() + return int(args.func(args)) + + +if __name__ == "__main__": + raise SystemExit(main()) \ No newline at end of file diff --git a/scripts/download_models.py b/scripts/download_models.py index 224bb86..f990126 100644 --- a/scripts/download_models.py +++ b/scripts/download_models.py @@ -1,113 +1,807 @@ +""" +Model and resource download helpers for lightweight initialization. + +This module prepares model files and local fallback resources needed by the +welcome onboarding flow. It emits structured progress events so callers can +aggregate real byte progress, item-level progress, and source retry state +without scraping ad-hoc log text. + +轻量化初始化所需的模型与资源下载辅助模块。 + +此模块负责准备欢迎引导流程所需的模型文件与本地回退资源,并发出结构化进度事件, +以便调用方能够聚合真实字节进度、条目级进度以及镜像重试状态,而不必再解析零散日志文本。 +""" + +import hashlib +import importlib +import logging import os +import random import sys -import logging +import time +from pathlib import Path +from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, cast + + +def _reconfigure_text_stream(stream: object) -> None: + """Use UTF-8 output when the active stream implementation supports it.""" + reconfigure = getattr(stream, "reconfigure", None) + if callable(reconfigure): + reconfigure(encoding="utf-8", errors="strict") + + +_reconfigure_text_stream(sys.stdout) +_reconfigure_text_stream(sys.stderr) -if hasattr(sys.stdout, "reconfigure"): - sys.stdout.reconfigure(encoding="utf-8", errors="strict") -if hasattr(sys.stderr, "reconfigure"): - sys.stderr.reconfigure(encoding="utf-8", errors="strict") +HF_MIRROR_ENDPOINT = "https://hf-mirror.com" +HF_OFFICIAL_ENDPOINT = "https://huggingface.co" +os.environ["HF_ENDPOINT"] = HF_MIRROR_ENDPOINT +os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1" +os.environ["HF_HUB_DISABLE_XET"] = "1" +os.environ["DO_NOT_TRACK"] = "1" try: from huggingface_hub import hf_hub_download except ImportError: - print("Error: huggingface_hub is not installed. Please run `pip install huggingface_hub tqdm` first.") - sys.exit(1) + hf_hub_download = None + +try: + from tqdm.auto import tqdm as tqdm_base +except ImportError: + tqdm_base = None + +try: + from core.source_probe import pick_best_source, probe_sources +except Exception: + pick_best_source = None + probe_sources = None + +from core.initialization_progress import ( + InitializationProgressEvent, + PROGRESS_KIND_DOWNLOAD, + STAGE_DOWNLOADING, +) logging.basicConfig( level=logging.INFO, - format='%(asctime)s - %(levelname)s - %(message)s', + format="%(asctime)s - %(levelname)s - %(message)s", stream=sys.stdout, ) -# Define the models and their destination directories relative to the project root +DOWNLOAD_ENDPOINTS = [ + ("hf-mirror", HF_MIRROR_ENDPOINT), + ("official", HF_OFFICIAL_ENDPOINT), +] + MODELS_TO_DOWNLOAD = [ { + "resource_id": "classification_model", "category": "Classification", "repo_id": "jamesphotography/SuperPicky-models", "filename": "model20240824.pth", "dest_dir": "models", + "packaged_dest_dir": "models", + "feature_tags": ["core_detection", "birdid"], + "required": True, + "sha256": None, }, { + "resource_id": "flight_model", "category": "Flight Detection", "repo_id": "jamesphotography/SuperPicky-models", "filename": "superFlier_efficientnet.pth", "dest_dir": "models", + "packaged_dest_dir": "models", + "feature_tags": ["flight"], + "required": False, + "sha256": None, }, { + "resource_id": "keypoint_model", "category": "Keypoint Detection", "repo_id": "jamesphotography/SuperPicky-models", "filename": "cub200_keypoint_resnet50_slim.pth", "dest_dir": "models", + "packaged_dest_dir": "models", + "feature_tags": ["keypoint"], + "required": False, + "sha256": None, }, { + "resource_id": "avonet_database", "category": "Database", "repo_id": "jamesphotography/SuperPicky-models", "filename": "avonet.db", "dest_dir": "birdid/data", + "feature_tags": ["birdid"], + "required": False, + "sha256": None, }, { + "resource_id": "quality_model", "category": "Quality Assessment", "repo_id": "chaofengc/IQA-PyTorch-Weights", "filename": "cfanet_iaa_ava_res50-3cd62bb3.pth", "dest_dir": "models", - } + "packaged_dest_dir": "models", + "feature_tags": ["quality"], + "required": False, + "sha256": None, + }, + { + "resource_id": "yolo_segmentation", + "category": "Segmentation", + "repo_id": "jamesphotography/SuperPicky-models", + "filename": "yolo11l-seg.pt", + "dest_dir": "models", + "packaged_dest_dir": "models", + "feature_tags": ["core_detection"], + "required": True, + "sha256": None, + }, ] +OPTIONAL_LOCAL_RESOURCES = [ + { + "resource_id": "bird_reference_sqlite", + "filename": "bird_reference.sqlite", + "dest_dir": "birdid/data", + "feature_tags": ["birdid"], + "required": False, + "sha256": None, + "copy_only": True, + }, + { + "resource_id": "birdname_db", + "filename": "birdname.db", + "dest_dir": "ioc", + "feature_tags": ["birdid"], + "required": False, + "sha256": None, + "copy_only": True, + }, +] + + +def get_project_root() -> Path: + script_dir = os.path.dirname(os.path.abspath(__file__)) + return Path(os.path.abspath(os.path.join(script_dir, ".."))) + + +def _format_download_error(exc: Exception) -> str: + message = str(exc).strip() + if not message: + message = repr(exc) + return f"{type(exc).__name__}: {message}" + + +def _sha256_file(file_path: Path, chunk_size: int = 1024 * 1024) -> str: + digest = hashlib.sha256() + with file_path.open("rb") as handle: + while True: + chunk = handle.read(chunk_size) + if not chunk: + break + digest.update(chunk) + return digest.hexdigest() + + +def verify_resource(resource: Dict[str, Any], file_path: Path) -> bool: + expected_sha256 = resource.get("sha256") + if not expected_sha256: + return file_path.exists() + return file_path.exists() and _sha256_file(file_path) == expected_sha256.lower() + + +def _resolve_hf_endpoints() -> List[Tuple[str, str]]: + if probe_sources is None or pick_best_source is None: + return list(DOWNLOAD_ENDPOINTS) + + probe_input = [{"name": name, "url": endpoint} for name, endpoint in DOWNLOAD_ENDPOINTS] + results = probe_sources("huggingface-models", probe_input) + successful = [item for item in results if item.ok] + if not successful: + return list(DOWNLOAD_ENDPOINTS) + + non_official = [item for item in successful if "official" not in item.name.lower()] + preferred = non_official or successful + ordered_results = sorted(preferred, key=lambda item: (item.total_ms, item.first_byte_ms)) + if non_official: + return [(item.name, item.url) for item in ordered_results] + + return [(item.name, item.url) for item in ordered_results] + + +def _resource_matches_selection(resource: Dict[str, Any], selected: set[str]) -> bool: + if resource.get("required"): + return True + feature_tags = set(resource.get("feature_tags", [])) + return not selected or bool(feature_tags & selected) + + +def _iter_selected_resources( + resources: Iterable[Dict[str, Any]], + selected_features: Optional[Iterable[str]], +) -> Iterator[Dict[str, Any]]: + selected = set(selected_features or []) + for item in resources: + if _resource_matches_selection(item, selected): + yield dict(item) + + +def resolve_download_plan( + selected_features: Optional[Iterable[str]] = None, + *, + include_optional_local: bool = True, +) -> List[Dict[str, Any]]: + plan = list(_iter_selected_resources(MODELS_TO_DOWNLOAD, selected_features)) + if include_optional_local: + plan.extend(_iter_selected_resources(OPTIONAL_LOCAL_RESOURCES, selected_features)) + return plan + + +def _emit_resource_progress( + progress_cb: Optional[Callable[[InitializationProgressEvent], None]], + event: InitializationProgressEvent, +) -> None: + if progress_cb: + progress_cb(event) + + +def _build_resource_progress_event( + resource: Dict[str, Any], + message: str, + *, + ratio: float | None = None, + bytes_done: int | None = None, + bytes_total: int | None = None, + source: str | None = None, + is_terminal: bool = False, +) -> InitializationProgressEvent: + """ + Create a structured progress payload for one resource update. + + 为单个资源更新创建结构化进度负载。 + """ + return InitializationProgressEvent( + stage=STAGE_DOWNLOADING, + progress_kind=PROGRESS_KIND_DOWNLOAD, + message=message, + ratio=ratio, + bytes_done=bytes_done, + bytes_total=bytes_total, + resource_id=resource.get("resource_id"), + source=source, + is_terminal=is_terminal, + ) + + +def resolve_resource_destination_dir(project_root: Path, resource: Dict[str, Any]) -> Path: + dest_dir = resource["dest_dir"] + if getattr(sys, "frozen", False) and sys.platform == "win32": + dest_dir = resource.get("packaged_dest_dir", dest_dir) + return project_root / dest_dir + + +def _copy_local_resource( + resource: Dict[str, Any], + project_root: Path, + progress_cb: Optional[Callable[[InitializationProgressEvent], None]] = None, +) -> Optional[Path]: + """ + Copy a packaged local fallback resource into the expected destination. + + 将打包时附带的本地回退资源复制到目标目录。 + """ + filename = resource["filename"] + dest_dir = resolve_resource_destination_dir(project_root, resource) + dest_dir.mkdir(parents=True, exist_ok=True) + destination = dest_dir / filename + + if destination.exists(): + existing_size = destination.stat().st_size + _emit_resource_progress( + progress_cb, + _build_resource_progress_event( + resource, + f"{filename} already present", + ratio=1.0, + bytes_done=existing_size, + bytes_total=existing_size, + is_terminal=True, + ), + ) + return destination + + local_candidates = [ + resolve_resource_destination_dir(project_root, resource) / filename, + project_root / "resources" / resource["dest_dir"] / filename, + ] + for candidate in local_candidates: + if candidate.exists(): + if candidate.resolve() != destination.resolve(): + destination.write_bytes(candidate.read_bytes()) + copied_size = destination.stat().st_size + _emit_resource_progress( + progress_cb, + _build_resource_progress_event( + resource, + f"{filename} copied from local fallback", + ratio=1.0, + bytes_done=copied_size, + bytes_total=copied_size, + is_terminal=True, + ), + ) + return destination + return None + + +def _estimate_remote_file_size(repo_id: str, filename: str) -> int | None: + """ + Estimate remote file size with `hf_hub_download(..., dry_run=True)`. + + 通过 `hf_hub_download(..., dry_run=True)` 估算远端文件大小。 + """ + global hf_hub_download + if hf_hub_download is None: + return None + + for _source_name, endpoint in _resolve_hf_endpoints(): + try: + _configure_hf_client_for_endpoint(endpoint) + dry_run_info = cast( + Any, + hf_hub_download( + repo_id=repo_id, + filename=filename, + endpoint=endpoint, + dry_run=True, + ), + ) + file_size = getattr(dry_run_info, "file_size", None) + if isinstance(file_size, int) and file_size > 0: + return file_size + except Exception: + continue + return None + + +def _build_download_tqdm_class( + resource: Dict[str, Any], + source_name: str, + expected_bytes: int | None, + progress_cb: Optional[Callable[[InitializationProgressEvent], None]], +): + """ + Create a tqdm subclass that forwards byte-level download updates. + + 创建一个把字节级下载更新转发为结构化事件的 tqdm 子类。 + """ + if progress_cb is None or tqdm_base is None: + return None + + class ResourceDownloadTqdm(tqdm_base): + """ + Progress tracker used internally by `hf_hub_download`. + + `hf_hub_download` 内部使用的进度跟踪器。 + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + total = getattr(self, "total", None) or expected_bytes + if isinstance(total, (int, float)) and total > 0: + total = int(total) + else: + total = expected_bytes + self._superpicky_total = total + self._superpicky_last_n = 0 + _emit_resource_progress( + progress_cb, + _build_resource_progress_event( + resource, + f"{resource['filename']}: downloading from {source_name}", + ratio=0.0 if total else None, + bytes_done=0, + bytes_total=total, + source=source_name, + ), + ) + + def update(self, n=1): + result = super().update(n) + current = int(getattr(self, "n", self._superpicky_last_n)) + total = getattr(self, "total", None) or self._superpicky_total + if isinstance(total, (int, float)) and total > 0: + total = int(total) + ratio = current / total + else: + total = None + ratio = None + self._superpicky_last_n = current + _emit_resource_progress( + progress_cb, + _build_resource_progress_event( + resource, + f"{resource['filename']}: downloading from {source_name}", + ratio=ratio, + bytes_done=current, + bytes_total=total, + source=source_name, + is_terminal=bool(total and current >= total), + ), + ) + return result + + def close(self): + current = int(getattr(self, "n", self._superpicky_last_n)) + total = getattr(self, "total", None) or self._superpicky_total + if isinstance(total, (int, float)) and total > 0: + total = int(total) + ratio = min(1.0, current / total) + else: + total = None + ratio = None + _emit_resource_progress( + progress_cb, + _build_resource_progress_event( + resource, + f"{resource['filename']}: download stream closed", + ratio=ratio, + bytes_done=current, + bytes_total=total, + source=source_name, + is_terminal=bool(total and current >= total), + ), + ) + return super().close() + + return ResourceDownloadTqdm + + +def _download_with_fallback( + resource: Dict[str, Any], + repo_id: str, + filename: str, + full_dest_dir: str, + *, + expected_bytes: int | None = None, + progress_cb: Optional[Callable[[InitializationProgressEvent], None]] = None, +) -> Optional[str]: + """ + 使用回退机制下载文件,支持重试和源切换。 + + Download file with fallback mechanism, supporting retry and source switching. + + 参数 Parameters: + repo_id (str): Hugging Face 仓库 ID + filename (str): 要下载的文件名 + full_dest_dir (str): 目标目录路径 + expected_bytes (int | None): 预估文件大小 + progress_cb (Optional[Callable[[InitializationProgressEvent], None]]): 进度回调函数 + + 返回 Returns: + Optional[str]: 下载的文件路径,失败时返回 None + """ + global hf_hub_download + if hf_hub_download is None: + try: + from huggingface_hub import hf_hub_download as _hf_hub_download + + hf_hub_download = _hf_hub_download + except Exception as exc: + raise RuntimeError(f"huggingface_hub is not installed yet: {exc}") from exc + + errors = [] + endpoints = _resolve_hf_endpoints() + max_retries = 3 # 每个源的最大重试次数 + + for index, (source_name, endpoint) in enumerate(endpoints): + logging.info("尝试从 %s (%s) 下载 %s", source_name, endpoint, filename) + + for retry_count in range(max_retries): + _emit_resource_progress( + progress_cb, + _build_resource_progress_event( + resource, + f"{filename}: connecting {source_name} ({retry_count + 1}/{max_retries})", + ratio=0.0 if expected_bytes else None, + bytes_done=0, + bytes_total=expected_bytes, + source=source_name, + ), + ) + + start_time = time.perf_counter() + try: + _configure_hf_client_for_endpoint(endpoint) + download_kwargs: Dict[str, Any] = { + "repo_id": repo_id, + "filename": filename, + "local_dir": full_dest_dir, + "local_dir_use_symlinks": False, + "endpoint": endpoint, + } + tqdm_class = _build_download_tqdm_class( + resource, + source_name, + expected_bytes, + progress_cb, + ) + if tqdm_class is not None: + download_kwargs["tqdm_class"] = tqdm_class + try: + download_kwargs["resume_download"] = True + except Exception: + pass + + downloaded_path = cast(Any, hf_hub_download)(**download_kwargs) + elapsed_time = time.perf_counter() - start_time + + path_obj = Path(downloaded_path) + file_size = path_obj.stat().st_size if path_obj.exists() else expected_bytes + _emit_resource_progress( + progress_cb, + _build_resource_progress_event( + resource, + f"{filename}: downloaded via {source_name}", + ratio=1.0, + bytes_done=file_size, + bytes_total=file_size, + source=source_name, + is_terminal=True, + ), + ) + + logging.info( + "%s 已通过 %s 下载完成,耗时 %.2f 秒", + filename, + source_name, + elapsed_time + ) + return downloaded_path + + except Exception as exc: + elapsed_time = time.perf_counter() - start_time + error_text = _format_download_error(exc) + errors.append(f"{source_name} (尝试 {retry_count + 1}): {error_text}") + + logging.warning( + "%s 通过 %s 下载失败 (尝试 %d/%d): %s (耗时 %.2f 秒)", + filename, + source_name, + retry_count + 1, + max_retries, + error_text, + elapsed_time + ) + + _emit_resource_progress( + progress_cb, + _build_resource_progress_event( + resource, + f"{filename}: {source_name} failed ({retry_count + 1}/{max_retries})", + ratio=0.0 if expected_bytes else None, + bytes_done=0, + bytes_total=expected_bytes, + source=source_name, + ), + ) + + if retry_count < max_retries - 1: + base_delay = 2 ** retry_count + jitter = base_delay * 0.25 * (random.random() * 2 - 1) + delay = max(0.5, base_delay + jitter) + logging.info("等待 %.2f 秒后重试...", delay) + time.sleep(delay) + else: + if index < len(endpoints) - 1: + next_source_name = endpoints[index + 1][0] + logging.info("(" + next_source_name + ") 切换到下一个源下载 %s...", filename) + + logging.error( + "所有下载源均失败: %s 来自 %s。详细信息: %s", + filename, + repo_id, + " | ".join(errors), + ) + return None + + +def _configure_hf_client_for_endpoint(endpoint: str) -> None: + """ + 强制 huggingface_hub 在当前尝试中保持选定的端点。 + + 官方文档说明 `HF_ENDPOINT` 会在导入时读取,因此这里同时设置环境变量与运行期常量, + 避免中国网络下已经导入过的客户端偷偷回退到默认官方端点。 + + Force huggingface_hub to stay on the selected endpoint for the current attempt. + + The official documentation states that `HF_ENDPOINT` is read during import, + so we set both environment variables and runtime constants here to prevent + the already-imported client from silently falling back to the default official endpoint + under Chinese network conditions. + + 参数 Parameters: + endpoint (str): 要使用的 Hugging Face 端点 URL + The Hugging Face endpoint URL to use + """ + os.environ["HF_ENDPOINT"] = endpoint + os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1" + os.environ["HF_HUB_DISABLE_XET"] = "1" + os.environ["DO_NOT_TRACK"] = "1" + + try: + constants_module = importlib.import_module("huggingface_hub.constants") + if hasattr(constants_module, "ENDPOINT"): + constants_module.ENDPOINT = endpoint + logging.debug("已设置 huggingface_hub.constants.ENDPOINT = %s", endpoint) + except Exception as exc: + logging.debug("设置 huggingface_hub.constants.ENDPOINT 失败: %s", exc) + + try: + file_download_module = importlib.import_module("huggingface_hub.file_download") + if hasattr(file_download_module, "ENDPOINT"): + file_download_module.ENDPOINT = endpoint + logging.debug("已设置 huggingface_hub.file_download.ENDPOINT = %s", endpoint) + except Exception as exc: + logging.debug("设置 huggingface_hub.file_download.ENDPOINT 失败: %s", exc) + + +def download_resource( + resource: Dict[str, Any], + *, + project_root: Optional[Path] = None, + progress_cb: Optional[Callable[[InitializationProgressEvent], None]] = None, +) -> Path: + """ + 下载并验证资源文件。 + + Download and verify resource file. + + 参数 Parameters: + resource (Dict[str, Any]): 资源元数据字典 + project_root (Optional[Path]): 项目根目录 + progress_cb (Optional[Callable[[InitializationProgressEvent], None]]): 进度回调函数 + + 返回 Returns: + Path: 下载的文件路径 + + 异常 Raises: + FileNotFoundError: 本地回退资源未找到 + RuntimeError: 下载失败或完整性验证失败 + """ + project_root = project_root or get_project_root() + + if resource.get("copy_only"): + copied = _copy_local_resource(resource, project_root, progress_cb=progress_cb) + if copied is None: + raise FileNotFoundError(f"Local fallback resource not found: {resource['filename']}") + return copied + + repo_id = resource["repo_id"] + filename = resource["filename"] + resource_id = resource.get("resource_id", "unknown") + full_dest_dir = resolve_resource_destination_dir(project_root, resource) + full_dest_dir.mkdir(parents=True, exist_ok=True) + + logging.info( + "开始下载资源 [%s]: %s 来自仓库 %s", + resource_id, + filename, + repo_id + ) + expected_bytes = _estimate_remote_file_size(repo_id, filename) + _emit_resource_progress( + progress_cb, + _build_resource_progress_event( + resource, + f"Preparing download for {filename}", + ratio=0.0 if expected_bytes else None, + bytes_done=0, + bytes_total=expected_bytes, + ), + ) + + download_start_time = time.perf_counter() + downloaded_path = _download_with_fallback( + resource=resource, + repo_id=repo_id, + filename=filename, + full_dest_dir=str(full_dest_dir), + expected_bytes=expected_bytes, + progress_cb=progress_cb, + ) + download_elapsed = time.perf_counter() - download_start_time + + if not downloaded_path: + logging.error( + "资源 [%s] 下载失败: %s 来自 %s,总耗时 %.2f 秒", + resource_id, + filename, + repo_id, + download_elapsed + ) + raise RuntimeError(f"Failed to download {filename} from {repo_id}") + + path_obj = Path(downloaded_path) + file_size = path_obj.stat().st_size if path_obj.exists() else 0 + + logging.info( + "资源 [%s] 下载文件大小: %d 字节 (%.2f MB)", + resource_id, + file_size, + file_size / (1024 * 1024) + ) + + if not verify_resource(resource, path_obj): + path_obj.unlink(missing_ok=True) + logging.error( + "资源 [%s] 完整性验证失败: %s", + resource_id, + filename + ) + raise RuntimeError(f"Integrity verification failed for {filename}") + + logging.info( + "资源 [%s] 下载并验证成功: %s,总耗时 %.2f 秒,文件大小 %.2f MB", + resource_id, + filename, + download_elapsed, + file_size / (1024 * 1024) + ) + _emit_resource_progress( + progress_cb, + _build_resource_progress_event( + resource, + f"Validated {filename}", + ratio=1.0, + bytes_done=file_size, + bytes_total=file_size, + is_terminal=True, + ), + ) + return path_obj + + def main(): """ Downloads required models and database files from Hugging Face Hub. Ensures files are placed in the correct directories for the application to function. """ logging.info("Starting model download process...") - - # Ensure we're running from the project root (where this script is located in an expected directory) - script_dir = os.path.dirname(os.path.abspath(__file__)) - project_root = os.path.abspath(os.path.join(script_dir, "..")) - - # Change to project root to simplify path handling if run from elsewhere + if hf_hub_download is None: + print("Error: huggingface_hub is not installed. Please run `pip install huggingface_hub tqdm` first.") + sys.exit(1) + + project_root = get_project_root() os.chdir(project_root) - logging.info(f"Working directory set to: {project_root}") + logging.info("Working directory set to: %s", project_root) + plan = resolve_download_plan( + {"core_detection", "quality", "keypoint", "flight", "birdid"}, + include_optional_local=False, + ) success_count = 0 - total_models = len(MODELS_TO_DOWNLOAD) - - for item in MODELS_TO_DOWNLOAD: - repo_id = item["repo_id"] - filename = item["filename"] - dest_dir = item["dest_dir"] - category = item["category"] - - full_dest_dir = os.path.join(project_root, dest_dir) - full_dest_path = os.path.join(full_dest_dir, filename) - - logging.info(f"[{category}] Retrieving {filename}...") - - # Ensure destination directory exists - os.makedirs(full_dest_dir, exist_ok=True) - + + for item in plan: + logging.info("[%s] Retrieving %s...", item.get("category", "Resource"), item["filename"]) try: - # Download file using huggingface_hub. It handles caching automatically. - # We use local_dir to bypass symlink behaviors and put it right where we want it. - # If the file already exists and is the correct size/hash, it won't redownload. - downloaded_path = hf_hub_download( - repo_id=repo_id, - filename=filename, - local_dir=full_dest_dir, - local_dir_use_symlinks=False - ) - logging.info(f"✓ Successfully downloaded/verified: {os.path.basename(downloaded_path)}") + downloaded_path = download_resource(item, project_root=project_root) + logging.info("✓ Successfully downloaded/verified: %s", os.path.basename(downloaded_path)) success_count += 1 - except Exception as e: - logging.error(f"✗ Failed to download {filename} from {repo_id}: {str(e)}") + except Exception as exc: + logging.error("✗ Failed to prepare %s: %s", item["filename"], _format_download_error(exc)) - if success_count == total_models: - logging.info(f"All {total_models} files are ready.") - logging.info("Application is ready to run.") + if success_count == len(plan): + logging.info("All %s files are ready.", len(plan)) + logging.info("Application resources are ready to run.") sys.exit(0) - else: - logging.error(f"Only {success_count}/{total_models} files were successfully downloaded.") - logging.error("Please check your internet connection and verify the files exist in the specified Hugging Face repositories.") - sys.exit(1) + + logging.error("Only %s/%s files were successfully prepared.", success_count, len(plan)) + sys.exit(1) + if __name__ == "__main__": main() diff --git a/scripts/verify_patch_cleanup_regression.py b/scripts/verify_patch_cleanup_regression.py new file mode 100644 index 0000000..dd33fb6 --- /dev/null +++ b/scripts/verify_patch_cleanup_regression.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Manual regression checker for residual patch cleanup. + +用于手工验证“残留 code_updates + 初始化关闭自动更新”场景的回归检查脚本。 +""" + +from __future__ import annotations + +import json +import sys +import tempfile +from pathlib import Path +from unittest.mock import patch + +PROJECT_ROOT = Path(__file__).resolve().parent.parent +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + +from core.initialization_manager import ( + InitializationManager, + RuntimeInstallLocation, + RuntimeSelection, +) +from constants import APP_VERSION +from tools import patch_manager + + +def _create_residual_patch_environment(config_dir: Path) -> tuple[Path, Path]: + patch_dir = config_dir / "code_updates" + meta_path = config_dir / "patch_meta.json" + (patch_dir / "core").mkdir(parents=True, exist_ok=True) + (patch_dir / "core" / "legacy_override.py").write_text( + "PATCH_MARKER = 'stale-overlay'\n", + encoding="utf-8", + ) + meta_path.write_text( + json.dumps( + { + "patch_version": "v-stale", + "base_version": APP_VERSION, + "release_channel": "official", + "target_channels": ["official"], + }, + ensure_ascii=False, + indent=2, + ), + encoding="utf-8", + ) + return patch_dir, meta_path + + +def run_manual_regression_check() -> int: + with tempfile.TemporaryDirectory() as temp_dir: + temp_root = Path(temp_dir) + config_dir = temp_root / "config" + runtime_dir = temp_root / "runtime" + config_dir.mkdir(parents=True, exist_ok=True) + runtime_dir.mkdir(parents=True, exist_ok=True) + patch_dir, meta_path = _create_residual_patch_environment(config_dir) + + print(f"[setup] residual patch dir: {patch_dir}") + print(f"[setup] residual patch meta: {meta_path}") + print(f"[guard] source patch block reason: {patch_manager.get_patch_runtime_block_reason()}") + + manager = InitializationManager() + update_events: list[tuple[str, str, str]] = [] + + def _record_item(resource_id: str, status: str, detail: str) -> None: + if resource_id == "updates": + update_events.append((resource_id, status, detail)) + print(f"[updates] {status}: {detail}") + + with patch.object(patch_manager, "get_app_config_dir", return_value=config_dir), patch.object( + patch_manager, "shared_get_patch_dir", return_value=patch_dir + ), patch.object( + manager, + "choose_runtime_install_location", + return_value=RuntimeInstallLocation("default", runtime_dir, None, True), + ), patch.object( + manager, + "detect_runtime_selection", + return_value=RuntimeSelection("cpu", False, "manual-check"), + ), patch.object( + manager, + "_normalize_features", + side_effect=lambda features: list(features or ["core_detection"]), + ), patch.object(manager, "_save_config"), patch.object( + manager, + "_resolve_best_sources", + return_value={ + "pypi_primary": "https://example.invalid/simple", + "pypi_fallback": "", + "torch_primary": "", + "torch_fallback": "", + }, + ), patch.object(manager, "repair_runtime_if_needed"), patch.object( + manager, "repair_resources_if_needed" + ), patch.object(manager, "is_ready_for_main_ui", return_value=True), patch.object( + manager, "_raise_if_cancelled" + ), patch.object(manager, "_emit_stage"), patch.object( + manager, "_emit_item_status", side_effect=_record_item + ): + manager._run( + { + "features": ["core_detection"], + "auto_update_enabled": False, + "runtime_variant": "cpu", + "runtime_install_location": "default", + }, + mode="repair", + ) + + if patch_dir.exists() or meta_path.exists(): + print("[result] FAILED: residual patch artifacts still exist") + return 1 + + if not any(status == "done" and "补丁环境已清除" in detail for _, status, detail in update_events): + print("[result] FAILED: patch cleanup event was not emitted") + return 1 + + if not any(status == "skipped" and "Automatic updates disabled by user" in detail for _, status, detail in update_events): + print("[result] FAILED: auto-update disabled event was not emitted") + return 1 + + print("[result] PASS: residual patch environment was cleared during disabled-update initialization") + return 0 + + +if __name__ == "__main__": + raise SystemExit(run_manual_regression_check()) \ No newline at end of file diff --git a/superpicky_cli.py b/superpicky_cli.py index e8a5dce..baa2d55 100644 --- a/superpicky_cli.py +++ b/superpicky_cli.py @@ -37,6 +37,8 @@ import sys import os from pathlib import Path +from types import SimpleNamespace +from core.recursive_scanner import DEFAULT_SCAN_MAX_DEPTH from tools.i18n import t # 确保模块路径正确 @@ -196,8 +198,8 @@ def cmd_process(args): # V4.1: Crop save_crop=args.save_crop, birdid_use_ebird=True, - birdid_country_code=getattr(args, 'birdid_country', None), - birdid_region_code=getattr(args, 'birdid_region', None), + birdid_country_code=getattr(args, 'birdid_country', None) or "", + birdid_region_code=getattr(args, 'birdid_region', None) or "", birdid_confidence_threshold=getattr(args, 'birdid_threshold', 70.0) ) @@ -787,28 +789,35 @@ def cmd_identify(args): def cmd_batch(args): """递归批量处理子目录""" - from core.recursive_scanner import scan_recursive, count_photos, is_processed + from core.recursive_scanner import is_dangerous_root, is_processed, scan_directories from core.batch_processor import BatchProcessor from core.photo_processor import ProcessingSettings from advanced_config import get_advanced_config print_banner() print(f"\n📂 批量处理: {args.directory}") + + is_dangerous, reason = is_dangerous_root(args.directory) + if is_dangerous: + print(f"\n❌ {t('health.dangerous_dir_title')}") + print(t("health.dangerous_dir_msg", directory=args.directory, reason=reason)) + return 1 # 扫描 - dirs = scan_recursive(args.directory, max_depth=args.max_depth) + scan_results = scan_directories(args.directory, max_depth=args.max_depth) - if not dirs: - print("\n❌ 未找到包含照片的子目录") + if not scan_results: + print(f"\n❌ {t('health.no_photos_title')}") + print(t("health.no_photos_msg", directory=args.directory)) return 1 # 预览 - print(f"\n🔍 找到 {len(dirs)} 个待处理目录:") + print(f"\n🔍 找到 {len(scan_results)} 个待处理目录:") total_photos = 0 - for i, d in enumerate(dirs, 1): - rel = os.path.relpath(d, args.directory) - n = count_photos(d) - processed = is_processed(d) + for i, scanned_dir in enumerate(scan_results, 1): + rel = os.path.relpath(scanned_dir.path, args.directory) + n = scanned_dir.photo_count + processed = is_processed(scanned_dir.path) status = " (已处理)" if processed else "" print(f" {i:3d}. {rel}/ ({n} 张){status}") total_photos += n @@ -821,7 +830,7 @@ def cmd_batch(args): # 确认 if not args.yes: - confirm = input(f"\n确定处理这 {len(dirs)} 个目录? [y/N]: ") + confirm = input(f"\n确定处理这 {len(scan_results)} 个目录? [y/N]: ") if confirm.lower() not in ['y', 'yes']: print("❌ 已取消") return 1 @@ -848,8 +857,8 @@ def cmd_batch(args): auto_identify=auto_identify, save_crop=getattr(args, 'save_crop', False), birdid_use_ebird=True, - birdid_country_code=getattr(args, 'birdid_country', None), - birdid_region_code=getattr(args, 'birdid_region', None), + birdid_country_code=getattr(args, 'birdid_country', None) or "", + birdid_region_code=getattr(args, 'birdid_region', None) or "", birdid_confidence_threshold=getattr(args, 'birdid_threshold', 70.0), ) @@ -862,7 +871,7 @@ def cmd_batch(args): ) result = processor.process( - dirs=dirs, + dirs=scan_results, organize_files=args.organize, cleanup_temp=not adv_config.keep_temp_files, ) @@ -907,11 +916,7 @@ def cmd_batch_reset(args): print(f"🔄 [{i}/{len(processed_dirs)}] 重置: {rel}/") # 创建一个模拟的 args 对象给 cmd_reset - class ResetArgs: - pass - reset_args = ResetArgs() - reset_args.directory = d - reset_args.yes = True # 已经确认过了 + reset_args = SimpleNamespace(directory=d, yes=True) try: ret = cmd_reset(reset_args) @@ -1097,8 +1102,8 @@ def main(): help='跳过已处理的目录') p_batch.add_argument('--dry-run', action='store_true', help='仅列出待处理目录,不执行') - p_batch.add_argument('--max-depth', type=int, default=10, - help='最大递归深度 (默认: 10)') + p_batch.add_argument('--max-depth', type=int, default=DEFAULT_SCAN_MAX_DEPTH, + help=f'最大递归深度 (默认: {DEFAULT_SCAN_MAX_DEPTH})') p_batch.add_argument('-y', '--yes', action='store_true', help='跳过确认提示') p_batch.add_argument('-q', '--quiet', action='store_true') diff --git a/tools/patch_manager.py b/tools/patch_manager.py index 5ac4de6..d67fc33 100644 --- a/tools/patch_manager.py +++ b/tools/patch_manager.py @@ -13,8 +13,10 @@ """ import sys +import os import json import ssl +import stat import shutil import zipfile import tempfile @@ -55,6 +57,77 @@ def get_patch_dir() -> Path: return shared_get_patch_dir() +def get_patch_runtime_channel() -> str: + """返回当前运行环境的发布渠道。""" + try: + from core.build_info import RELEASE_CHANNEL + + if RELEASE_CHANNEL in ("official", "nightly"): + return RELEASE_CHANNEL + except Exception: + pass + return "dev" + + +def get_patch_runtime_block_reason() -> Optional[str]: + """返回当前环境禁止在线补丁的原因;允许时返回 None。""" + if not getattr(sys, "frozen", False): + return "源码运行环境禁用在线补丁" + + channel = get_patch_runtime_channel() + if channel not in ("nightly", "official"): + return f"{channel} 渠道禁用在线补丁" + + return None + + +def _normalize_patch_channels(meta: dict) -> set[str]: + channels: set[str] = set() + + for key in ("target_channels", "channels"): + value = meta.get(key) + if isinstance(value, list): + channels.update( + str(item).strip().lower() + for item in value + if str(item).strip() + ) + + for key in ("target_channel", "channel", "release_channel"): + value = meta.get(key) + if isinstance(value, str) and value.strip(): + channels.add(value.strip().lower()) + + return channels + + +def validate_patch_metadata(meta: dict, current_app_version: str) -> Tuple[bool, str]: + """校验当前运行环境与补丁元数据是否允许应用。""" + blocked_reason = get_patch_runtime_block_reason() + if blocked_reason: + return False, blocked_reason + + if not isinstance(meta, dict): + return False, "补丁元数据格式无效" + + base_version = str(meta.get("base_version", "")).strip() + if not base_version: + return False, "补丁元数据缺少 base_version" + if base_version != current_app_version: + return False, f"补丁 base_version={base_version} 与当前版本 {current_app_version} 不匹配" + + patch_version = str(meta.get("patch_version", "")).strip() + if not patch_version: + return False, "补丁元数据缺少 patch_version" + + target_channels = _normalize_patch_channels(meta) + current_channel = get_patch_runtime_channel() + if target_channels and current_channel not in target_channels: + return False, f"补丁渠道限制为 {sorted(target_channels)},当前渠道为 {current_channel}" + + return True, "ok" + + def _get_local_meta_path() -> Path: return _get_app_data_dir() / "patch_meta.json" @@ -138,6 +211,37 @@ def _download_to_temp(url: str, timeout: int = 60) -> Optional[Path]: return None +def _make_path_writable(path: str) -> None: + try: + os.chmod(path, stat.S_IWRITE | stat.S_IREAD) + except Exception: + pass + + +def _remove_tree_safely(path: Path) -> None: + def _onerror(func, target, _exc_info): + _make_path_writable(target) + func(target) + + shutil.rmtree(path, onerror=_onerror) + + +def safe_clear_patch() -> Tuple[bool, str]: + """安全清理补丁目录与本地元数据。""" + patch_dir = get_patch_dir() + meta_path = _get_local_meta_path() + + try: + if patch_dir.exists(): + _remove_tree_safely(patch_dir) + if meta_path.exists(): + _make_path_writable(str(meta_path)) + meta_path.unlink(missing_ok=True) + return True, "补丁环境已清除" + except Exception as exc: + return False, f"补丁环境清理失败: {exc}" + + def apply_patch_file(zip_path: Path, meta: dict) -> bool: """ 解压 zip 到 code_updates/ 目录并写入 patch_meta.json。 @@ -151,9 +255,16 @@ def apply_patch_file(zip_path: Path, meta: dict) -> bool: """ patch_dir = get_patch_dir() try: + from constants import APP_VERSION + + valid, reason = validate_patch_metadata(meta, APP_VERSION) + if not valid: + print(f"[PatchManager] 已拒绝应用补丁: {reason}") + return False + # 先清空旧补丁 if patch_dir.exists(): - shutil.rmtree(patch_dir) + _remove_tree_safely(patch_dir) patch_dir.mkdir(parents=True, exist_ok=True) with zipfile.ZipFile(zip_path, "r") as zf: @@ -171,13 +282,8 @@ def apply_patch_file(zip_path: Path, meta: dict) -> bool: def clear_patch() -> None: """清除当前补丁(回滚到内置版本)""" - patch_dir = get_patch_dir() - meta_path = _get_local_meta_path() - if patch_dir.exists(): - shutil.rmtree(patch_dir, ignore_errors=True) - if meta_path.exists(): - meta_path.unlink(missing_ok=True) - print("[PatchManager] 补丁已清除") + _success, message = safe_clear_patch() + print(f"[PatchManager] {message}") def check_and_apply_patch_from_gitcode( @@ -209,9 +315,9 @@ def check_and_apply_patch_from_gitcode( if not remote_meta: return False, "拉取 GitCode patch_meta.json 失败" - base_version = remote_meta.get("base_version", "") - if base_version != current_app_version: - return False, f"补丁 base_version={base_version} 与当前版本 {current_app_version} 不匹配" + valid, reason = validate_patch_metadata(remote_meta, current_app_version) + if not valid: + return False, reason remote_patch_version = remote_meta.get("patch_version", "") local_meta = read_local_meta() @@ -267,9 +373,9 @@ def check_and_apply_patch_from_mirror( if not remote_meta: return False, "镜像服务器不可用" - base_version = remote_meta.get("base_version", "") - if base_version != current_app_version: - return False, f"补丁 base_version={base_version} 与当前版本 {current_app_version} 不匹配" + valid, reason = validate_patch_metadata(remote_meta, current_app_version) + if not valid: + return False, reason remote_patch_version = remote_meta.get("patch_version", "") local_meta = read_local_meta() @@ -323,9 +429,9 @@ def check_and_apply_patch( return False, "拉取 patch_meta.json 失败" # 3. 检查 base_version 是否匹配当前应用版本 - base_version = remote_meta.get("base_version", "") - if base_version != current_app_version: - return False, f"补丁 base_version={base_version} 与当前版本 {current_app_version} 不匹配" + valid, reason = validate_patch_metadata(remote_meta, current_app_version) + if not valid: + return False, reason # 4. 对比本地补丁版本 remote_patch_version = remote_meta.get("patch_version", "") diff --git a/tools/update_checker.py b/tools/update_checker.py index 367770c..650ef44 100644 --- a/tools/update_checker.py +++ b/tools/update_checker.py @@ -71,6 +71,13 @@ def get_version_channel(ver: str) -> str: return 'dev' +def _mark_patch_check_skipped(update_info: Dict, reason: str) -> None: + """在更新结果中标记补丁检查被安全跳过。""" + update_info['patch_applied'] = False + update_info['patch_skipped'] = True + update_info['patch_message'] = reason + + class UpdateChecker: """更新检测器""" @@ -206,17 +213,25 @@ def check_for_updates(self, timeout: int = 10, include_prerelease: bool = False) # 没有整包更新时,检查是否有补丁 if not has_update: try: - from tools.patch_manager import check_and_apply_patch - patched, msg = check_and_apply_patch( - data.get('assets', []), - self.current_version, + from tools.patch_manager import ( + check_and_apply_patch, + get_patch_runtime_block_reason, ) - update_info['patch_applied'] = patched - update_info['patch_message'] = msg - if patched: - from tools.patch_manager import read_local_meta - meta = read_local_meta() - update_info['patch_version'] = meta.get('patch_version') if meta else None + + blocked_reason = get_patch_runtime_block_reason() + if blocked_reason: + _mark_patch_check_skipped(update_info, blocked_reason) + else: + patched, msg = check_and_apply_patch( + data.get('assets', []), + self.current_version, + ) + update_info['patch_applied'] = patched + update_info['patch_message'] = msg + if patched: + from tools.patch_manager import read_local_meta + meta = read_local_meta() + update_info['patch_version'] = meta.get('patch_version') if meta else None except Exception as e: update_info['patch_message'] = f'补丁检查异常: {e}' @@ -280,14 +295,22 @@ def _check_from_gitcode(self) -> Tuple[bool, Optional[Dict]]: if not has_update: try: - from tools.patch_manager import check_and_apply_patch_from_gitcode - patched, msg = check_and_apply_patch_from_gitcode(gitcode_links, self.current_version) - update_info['patch_applied'] = patched - update_info['patch_message'] = msg - if patched: - from tools.patch_manager import read_local_meta - meta = read_local_meta() - update_info['patch_version'] = meta.get('patch_version') if meta else None + from tools.patch_manager import ( + check_and_apply_patch_from_gitcode, + get_patch_runtime_block_reason, + ) + + blocked_reason = get_patch_runtime_block_reason() + if blocked_reason: + _mark_patch_check_skipped(update_info, blocked_reason) + else: + patched, msg = check_and_apply_patch_from_gitcode(gitcode_links, self.current_version) + update_info['patch_applied'] = patched + update_info['patch_message'] = msg + if patched: + from tools.patch_manager import read_local_meta + meta = read_local_meta() + update_info['patch_version'] = meta.get('patch_version') if meta else None except Exception as e: update_info['patch_message'] = f'GitCode 补丁检查异常: {e}' @@ -338,14 +361,22 @@ def _check_from_mirror(self) -> Tuple[bool, Optional[Dict]]: if not has_update: try: - from tools.patch_manager import check_and_apply_patch_from_mirror - patched, msg = check_and_apply_patch_from_mirror(self.current_version) - update_info['patch_applied'] = patched - update_info['patch_message'] = msg - if patched: - from tools.patch_manager import read_local_meta - meta = read_local_meta() - update_info['patch_version'] = meta.get('patch_version') if meta else None + from tools.patch_manager import ( + check_and_apply_patch_from_mirror, + get_patch_runtime_block_reason, + ) + + blocked_reason = get_patch_runtime_block_reason() + if blocked_reason: + _mark_patch_check_skipped(update_info, blocked_reason) + else: + patched, msg = check_and_apply_patch_from_mirror(self.current_version) + update_info['patch_applied'] = patched + update_info['patch_message'] = msg + if patched: + from tools.patch_manager import read_local_meta + meta = read_local_meta() + update_info['patch_version'] = meta.get('patch_version') if meta else None except Exception as e: update_info['patch_message'] = f'镜像补丁检查异常: {e}' diff --git a/ui/birdid_dock.py b/ui/birdid_dock.py index 5c341b7..37eb5cc 100644 --- a/ui/birdid_dock.py +++ b/ui/birdid_dock.py @@ -1,12 +1,16 @@ #!/usr/bin/env python3 """ -鸟类识别停靠面板 -可停靠在主窗口边缘的识鸟功能面板 -风格与 SuperPicky 主窗口统一 +鸟类识别停靠面板。 +Bird-identification dock panel. + +可停靠在主窗口边缘,负责识鸟入口、区域筛选与结果展示。 +Dockable beside the main window and responsible for the BirdID entry, +region filtering, and result presentation. """ import os import sys +from typing import Any, Optional, cast from PySide6.QtWidgets import ( QDockWidget, QWidget, QVBoxLayout, QHBoxLayout, @@ -19,24 +23,47 @@ from PySide6.QtGui import QPixmap, QDragEnterEvent, QDropEvent, QFont from ui.styles import COLORS, FONTS - +from config import ( + get_app_config_dir, + get_install_scoped_resource_path, + get_runtime_meipass, +) from tools.i18n import get_i18n +ALIGN_CENTER = Qt.AlignmentFlag.AlignCenter +ALIGN_RIGHT_VCENTER = ( + Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignVCenter +) +LEFT_BUTTON = Qt.MouseButton.LeftButton +POINTING_HAND_CURSOR = Qt.CursorShape.PointingHandCursor +SIZE_POLICY_EXPANDING = QSizePolicy.Policy.Expanding +SIZE_POLICY_PREFERRED = QSizePolicy.Policy.Preferred +ALLOWED_DOCK_AREAS = ( + Qt.DockWidgetArea.LeftDockWidgetArea | Qt.DockWidgetArea.RightDockWidgetArea +) +USER_ROLE = int(Qt.ItemDataRole.UserRole) +KEEP_ASPECT_RATIO = Qt.AspectRatioMode.KeepAspectRatio +SMOOTH_TRANSFORMATION = Qt.TransformationMode.SmoothTransformation + def get_birdid_data_path(relative_path: str) -> str: - """获取 birdid/data 目录下的资源路径""" + """ + 获取 `birdid/data` 目录下的资源路径。 + Return a resource path under `birdid/data`. + """ + if getattr(sys, 'frozen', False) and sys.platform == 'win32': + return str(get_install_scoped_resource_path(os.path.join('birdid', 'data', relative_path))) if getattr(sys, 'frozen', False): - return os.path.join(sys._MEIPASS, 'birdid', 'data', relative_path) + meipass = get_runtime_meipass() + if meipass is not None: + return os.path.join(meipass, 'birdid', 'data', relative_path) base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) return os.path.join(base_dir, 'birdid', 'data', relative_path) def get_settings_path() -> str: - """获取设置文件路径""" - if sys.platform == 'darwin': - settings_dir = os.path.expanduser('~/Documents/SuperPicky_Data') - else: - settings_dir = os.path.join(os.path.expanduser('~'), 'Documents', 'SuperPicky_Data') + """获取设置文件路径(统一使用标准配置目录)""" + settings_dir = str(get_app_config_dir()) os.makedirs(settings_dir, exist_ok=True) return os.path.join(settings_dir, 'birdid_dock_settings.json') @@ -48,8 +75,8 @@ class IdentifyWorker(QThread): def __init__(self, image_path: str, top_k: int = 5, use_gps: bool = True, use_ebird: bool = True, - country_code: str = None, region_code: str = None, - name_format: str = None): + country_code: Optional[str] = None, region_code: Optional[str] = None, + name_format: Optional[str] = None): super().__init__() self.image_path = image_path self.top_k = top_k @@ -100,7 +127,7 @@ def __init__(self): """) layout = QVBoxLayout(self) - layout.setAlignment(Qt.AlignCenter) + layout.setAlignment(ALIGN_CENTER) layout.setSpacing(8) # 图标 - + 号 @@ -111,12 +138,12 @@ def __init__(self): color: {COLORS['text_tertiary']}; background: transparent; """) - icon_label.setAlignment(Qt.AlignCenter) + icon_label.setAlignment(ALIGN_CENTER) layout.addWidget(icon_label) # 提示文字 hint_label = QLabel(self.i18n.t("birdid.drag_hint")) - hint_label.setAlignment(Qt.AlignCenter) + hint_label.setAlignment(ALIGN_CENTER) hint_label.setWordWrap(True) hint_label.setStyleSheet(f""" color: {COLORS['text_tertiary']}; @@ -139,7 +166,7 @@ def dropEvent(self, event: QDropEvent): def mousePressEvent(self, event): - if event.button() == Qt.LeftButton: + if event.button() == LEFT_BUTTON: self.selectFile() def selectFile(self): @@ -187,7 +214,7 @@ def __init__(self, rank: int, cn_name: str, en_name: str, confidence: float): self.i18n = get_i18n() self._selected = False - self.setCursor(Qt.PointingHandCursor) + self.setCursor(POINTING_HAND_CURSOR) self._update_style() # 外层水平布局:左侧色条 + 内容 @@ -239,7 +266,7 @@ def __init__(self, rank: int, cn_name: str, en_name: str, confidence: float): color: {COLORS['text_primary']}; background: transparent; """) - self.name_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Preferred) + self.name_label.setSizePolicy(SIZE_POLICY_EXPANDING, SIZE_POLICY_PREFERRED) layout.addWidget(self.name_label, 1) @@ -260,7 +287,7 @@ def __init__(self, rank: int, cn_name: str, en_name: str, confidence: float): background: transparent; """) self.conf_label.setFixedWidth(40) - self.conf_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter) + self.conf_label.setAlignment(ALIGN_RIGHT_VCENTER) layout.addWidget(self.conf_label) outer_layout.addWidget(content_widget, 1) @@ -357,7 +384,7 @@ def __init__(self, parent=None): self.i18n = get_i18n() super().__init__(self.i18n.t("birdid.title").upper(), parent) self.setObjectName("BirdIDDock") - self.setAllowedAreas(Qt.LeftDockWidgetArea | Qt.RightDockWidgetArea) + self.setAllowedAreas(ALLOWED_DOCK_AREAS) self.setMinimumWidth(280) # 使用自定义标题栏以控制按钮位置 @@ -510,7 +537,6 @@ def _on_float_changed(self, floating: bool): self._float_btn.setToolTip(self.i18n.t("birdid.float_panel")) def _load_regions_data(self) -> dict: - """加载 eBird 区域数据""" regions_path = get_birdid_data_path('ebird_regions.json') if os.path.exists(regions_path): try: @@ -521,35 +547,20 @@ def _load_regions_data(self) -> dict: return {'countries': []} def _build_country_list(self) -> dict: - """构建国家列表 {显示名称: 代码} - - V4.4: 简化下拉菜单,只显示约 15 项 - - 自动定位 (Auto GPS) - - 全球模式 (Global) - - 分隔符 - - Top 10 常用国家 (按英文首字母 A-Z) - - 分隔符 - - "更多国家..." 选项 - """ from collections import OrderedDict t = self.i18n.t is_english = self.i18n.current_lang.startswith('en') - # 使用 OrderedDict 保持插入顺序 country_list = OrderedDict() - # === 第一部分:特殊选项 === country_list[t("birdid.country_auto_gps")] = None country_list[t("birdid.country_global")] = "GLOBAL" - # === 分隔符 1 === country_list["─" * 15] = "SEP1" - # === 第二部分:Top 10 常用国家 (按英文首字母 A-Z 排序) === top10_codes = ['AU', 'BR', 'CN', 'GB', 'HK', 'ID', 'JP', 'MY', 'TW', 'US'] - # 国家代码到 i18n 键的映射 (Top 10) top10_i18n = { 'AU': 'birdid.country_au', 'BR': 'birdid.country_br', @@ -563,18 +574,15 @@ def _build_country_list(self) -> dict: 'US': 'birdid.country_us', } - # 构建 code -> region_data 映射 code_to_region = {} for region in self.regions_data.get('countries', []): code_to_region[region.get('code')] = region - # 添加 Top 10 (已按英文首字母排序) for code in top10_codes: i18n_key = top10_i18n.get(code) if i18n_key: display_name = t(i18n_key) else: - # 回退到 regions_data region = code_to_region.get(code, {}) if is_english: display_name = region.get('name', code) @@ -582,10 +590,8 @@ def _build_country_list(self) -> dict: display_name = region.get('name_cn') or region.get('name', code) country_list[display_name] = code - # === 分隔符 2 === - country_list["─" * 15 + " "] = "SEP2" # 添加空格使 key 不同 + country_list["─" * 15 + " "] = "SEP2" - # === "更多国家..." 选项 === country_list[t("birdid.country_more")] = "MORE" return country_list @@ -604,7 +610,7 @@ def _populate_country_combo(self): if code in ("SEP1", "SEP2"): idx = self.country_combo.count() - 1 # 获取模型中的 item 并设置为不可选 - model = self.country_combo.model() + model = cast(Any, self.country_combo.model()) item = model.item(idx) if item: item.setEnabled(False) @@ -612,7 +618,6 @@ def _populate_country_combo(self): item.setSelectable(False) def _load_settings(self) -> dict: - """加载设置""" settings_path = get_settings_path() if os.path.exists(settings_path): try: @@ -622,18 +627,15 @@ def _load_settings(self) -> dict: pass return { 'use_ebird': True, - 'auto_identify': False, # 选片时自动识别,默认关闭 + 'auto_identify': False, 'selected_country': self.i18n.t('birdid.country_auto_gps'), 'selected_region': self.i18n.t('birdid.region_entire_country') } def _save_settings(self): - """保存设置""" - # V4.0.4: 同时保存 country_code,避免读取时需要硬编码映射 country_display = self.country_combo.currentText() country_code = self.country_list.get(country_display) - - # 解析 region_code + region_display = self.region_combo.currentText() region_code = None if region_display and region_display != self.i18n.t('birdid.region_entire_country'): @@ -641,14 +643,14 @@ def _save_settings(self): match = re.search(r'\(([A-Z]{2}-[A-Z0-9]+)\)', region_display) if match: region_code = match.group(1) - + self.settings = { 'use_ebird': self.ebird_checkbox.isChecked(), 'auto_identify': self.auto_identify_checkbox.isChecked(), 'selected_country': country_display, - 'country_code': country_code, # V4.0.4: 直接存储代码 + 'country_code': country_code, 'selected_region': region_display, - 'region_code': region_code # V4.0.4: 直接存储代码 + 'region_code': region_code } try: settings_path = get_settings_path() @@ -658,98 +660,76 @@ def _save_settings(self): print(f"保存设置失败: {e}") def _apply_settings(self): - """应用保存的设置""" - # 设置标志,防止在应用设置时触发保存 self._applying_settings = True - + self.ebird_checkbox.setChecked(self.settings.get('use_ebird', True)) self.auto_identify_checkbox.setChecked(self.settings.get('auto_identify', False)) - - # V4.0.4: 优先使用 country_code 来匹配,而不是 selected_country 文本 + country_code = self.settings.get('country_code') saved_country = self.settings.get('selected_country', self.i18n.t('birdid.country_auto_gps')) - + matched = False if country_code: - # 通过 country_code 找到对应的显示名称 for display_name, code in self.country_list.items(): if code == country_code: idx = self.country_combo.findText(display_name) if idx >= 0: self.country_combo.setCurrentIndex(idx) matched = True - print(f"[DEBUG] Matched via country_code={country_code}: {display_name}") break - + if not matched: - # 回退:使用文本匹配 idx = self.country_combo.findText(saved_country) if idx >= 0: self.country_combo.setCurrentIndex(idx) - print(f"[DEBUG] Matched via text: {saved_country}") else: - # 如果都找不到,可能是从"更多国家"选的,需要动态添加 if country_code and country_code not in [None, "GLOBAL", "MORE"]: - # 从 regions_data 获取国家名称 for country in self.regions_data.get('countries', []): if country.get('code') == country_code: display_name = saved_country or country.get('name_cn') or country.get('name') - # 添加到列表 t = self.i18n.t more_idx = self.country_combo.findText(t("birdid.country_more")) if more_idx >= 0: self.country_combo.insertItem(more_idx, display_name) self.country_list[display_name] = country_code self.country_combo.setCurrentText(display_name) - print(f"[DEBUG] 动态添加国家: {display_name} ({country_code})") break - - # 等待 _on_country_changed 填充区域列表后再设置区域 - # 使用 QTimer 延迟设置 + saved_region = self.settings.get('selected_region', self.i18n.t('birdid.region_entire_country')) QTimer.singleShot(100, lambda: self._apply_saved_region(saved_region)) def _apply_saved_region(self, saved_region: str): - """延迟应用保存的区域设置""" idx = self.region_combo.findText(saved_region) if idx >= 0: self.region_combo.setCurrentIndex(idx) - # 设置完成后解除标志 self._applying_settings = False def _on_country_changed(self, country_display: str): - """国家选择变化时更新区域列表""" country_code = self.country_list.get(country_display) - # 忽略分隔符 if country_code in ("SEP1", "SEP2"): return - # 处理"更多国家"选项 (已移除,保留兼容性) if country_code == "MORE": self._show_more_countries_dialog() return - # 设置标志,防止在填充区域列表时触发 _on_region_changed self._updating_regions = True self.region_combo.clear() self.region_combo.addItem(self.i18n.t("birdid.region_entire_country")) - # 支持省/州的国家列表 _STATE_COUNTRIES = {"AU", "US", "CN"} is_english = self.i18n.current_lang.startswith('en') show_region = False if country_code and country_code not in (None, "GLOBAL"): - # 查找该国家的区域列表 for country in self.regions_data.get('countries', []): if country.get('code') == country_code: if country.get('has_regions') and country.get('regions'): for region in country['regions']: region_code = region.get('code', '') - # 中文界面显示中文名,英文界面显示英文名 if is_english: region_name = region.get('name', region_code) else: @@ -758,49 +738,34 @@ def _on_country_changed(self, country_display: str): show_region = country_code in _STATE_COUNTRIES break - # 显示/隐藏省州行 if hasattr(self, '_region_row'): self._region_row.setVisible(show_region) self._updating_regions = False - # 只有当不是在应用设置时才保存 if not getattr(self, '_applying_settings', False): self._save_settings() - # 如果已有图片,重新识别(应用新的国家/地区过滤) self._reidentify_if_needed() def _on_region_changed(self, region_display: str): - """区域选择变化时保存设置并重新识别""" - # 如果正在更新区域列表或正在应用设置,不触发保存 if getattr(self, '_updating_regions', False) or getattr(self, '_applying_settings', False): return - + self._save_settings() - - # 如果已有图片,重新识别 + self._reidentify_if_needed() def _show_more_countries_dialog(self): - """显示更多国家选择对话框 - 显示大洲和其他国家,支持搜索 - - V4.4: 只显示不在 Top 10 中的区域(大洲 + 其他国家) - - 大洲项目前面加 🌍 前缀 - - 按英文名 A-Z 排序 - """ from PySide6.QtWidgets import QDialog, QListWidget, QDialogButtonBox, QListWidgetItem, QLineEdit t = self.i18n.t is_english = self.i18n.current_lang.startswith('en') - # Top 10 国家代码(已在下拉菜单中) top10_codes = {'AU', 'BR', 'CN', 'GB', 'HK', 'ID', 'JP', 'MY', 'TW', 'US', 'GLOBAL'} - # 大洲代码 continent_codes = {'AF', 'AS', 'EU', 'NA', 'SA', 'OC'} - # 大洲 i18n 映射 continent_i18n = { 'AF': 'birdid.continent_af', 'AS': 'birdid.continent_as', @@ -810,7 +775,6 @@ def _show_more_countries_dialog(self): 'OC': 'birdid.continent_oc', } - # 其他国家 i18n 映射 other_country_i18n = { 'AR': 'birdid.country_ar', 'CA': 'birdid.country_ca', @@ -938,50 +902,51 @@ def _show_more_countries_dialog(self): for _, display, code, name_en in other_regions: item = QListWidgetItem(display) - item.setData(Qt.UserRole, code) - item.setData(Qt.UserRole + 1, name_en) # 用于搜索 + item.setData(USER_ROLE, code) + item.setData(USER_ROLE + 1, name_en) # 用于搜索 list_widget.addItem(item) layout.addWidget(list_widget) - # 搜索过滤功能 def filter_countries(text): text = text.lower() for i in range(list_widget.count()): item = list_widget.item(i) display_name = item.text().lower() - en_name = (item.data(Qt.UserRole + 1) or "").lower() + en_name = (item.data(USER_ROLE + 1) or "").lower() visible = text in display_name or text in en_name item.setHidden(not visible) search_input.textChanged.connect(filter_countries) - button_box = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) + button_box = QDialogButtonBox( + QDialogButtonBox.StandardButton.Ok + | QDialogButtonBox.StandardButton.Cancel + ) button_box.accepted.connect(dialog.accept) button_box.rejected.connect(dialog.reject) layout.addWidget(button_box) - if dialog.exec() == QDialog.Accepted: + if dialog.exec() == QDialog.DialogCode.Accepted: selected = list_widget.currentItem() if selected: - code = selected.data(Qt.UserRole) + code_data = selected.data(USER_ROLE) + if not isinstance(code_data, str): + return + code = code_data display = selected.text() - # 添加到下拉菜单并选中 existing = [self.country_combo.itemText(i) for i in range(self.country_combo.count())] if display not in existing: - # 在"更多国家"之前插入 idx = self.country_combo.findText(t("birdid.country_more")) if idx >= 0: self.country_combo.insertItem(idx, display) self.country_list[display] = code self.country_combo.setCurrentText(display) else: - # 用户取消,恢复到之前的选择 saved = self.settings.get('selected_country', t('birdid.country_auto_gps')) self.country_combo.setCurrentText(saved) def _setup_ui(self): - """设置界面""" container = QWidget() container.setStyleSheet(f"background-color: {COLORS['bg_void']};") @@ -989,21 +954,16 @@ def _setup_ui(self): layout.setContentsMargins(12, 12, 12, 12) layout.setSpacing(12) - # 创建QStackedWidget管理两个面板 self.stacked_widget = QStackedWidget() self.stacked_widget.setStyleSheet("background: transparent;") - - # ===== 面板1: 鸟类识别 ===== + self.identify_panel = QWidget() identify_layout = QVBoxLayout(self.identify_panel) identify_layout.setContentsMargins(0, 0, 0, 0) identify_layout.setSpacing(12) - # 拖放区域 self.drop_area = DropArea() self.drop_area.fileDropped.connect(self.on_file_dropped) - - # ===== 国家/区域过滤 ===== filter_frame = QFrame() filter_frame.setStyleSheet(f""" QFrame {{ @@ -1063,8 +1023,7 @@ def _setup_ui(self): self.country_combo.currentTextChanged.connect(self._on_country_changed) country_row.addWidget(self.country_combo, 1) filter_layout.addLayout(country_row) - - # 省/州选择行(仅 AU/US/CN 可见) + self._region_row = QWidget() self._region_row.setStyleSheet("background: transparent;") region_row_layout = QHBoxLayout(self._region_row) @@ -1089,7 +1048,7 @@ def _setup_ui(self): font-size: 11px; }} QComboBox:hover {{ border-color: {COLORS['accent']}; }} - QComboBox::drop-down {{ border: none; }} + QComboBox::drop {{ border: none; }} QComboBox QAbstractItemView {{ background-color: {COLORS['bg_elevated']}; border: 1px solid {COLORS['border']}; @@ -1106,28 +1065,23 @@ def _setup_ui(self): """) self.region_combo.currentTextChanged.connect(self._on_region_changed) region_row_layout.addWidget(self.region_combo, 1) - self._region_row.hide() # 默认隐藏,选 AU/US/CN 时显示 + self._region_row.hide() filter_layout.addWidget(self._region_row) - - # V4.2: 移除 eBird 过滤开关(默认启用,选择"全球"可禁用) - # V4.2: 移除自动识别开关(已移到主界面的"识鸟"按钮) - # 保留隐藏的 checkbox 以兼容设置保存/加载 self.ebird_checkbox = QCheckBox() - self.ebird_checkbox.setChecked(True) # 默认启用 + self.ebird_checkbox.setChecked(True) self.ebird_checkbox.hide() - + self.auto_identify_checkbox = QCheckBox() self.auto_identify_checkbox.setChecked(False) self.auto_identify_checkbox.hide() - + identify_layout.addWidget(filter_frame) identify_layout.addWidget(self.drop_area) - # 图片预览(初始隐藏,支持拖放替换) self.preview_label = DropPreviewLabel() - self.preview_label.setAlignment(Qt.AlignCenter) + self.preview_label.setAlignment(ALIGN_CENTER) self.preview_label.setMinimumHeight(100) - self.preview_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Preferred) + self.preview_label.setSizePolicy(SIZE_POLICY_EXPANDING, SIZE_POLICY_PREFERRED) self.preview_label.setStyleSheet(f""" background-color: {COLORS['bg_elevated']}; border-radius: 10px; @@ -1135,23 +1089,21 @@ def _setup_ui(self): """) self.preview_label.fileDropped.connect(self.on_file_dropped) self.preview_label.hide() - self._current_pixmap = None # 保存原始 pixmap 用于自适应缩放 - self._result_crop_pixmap = None # 保存识别完成的裁剪图,用于结果卡片点击恢复 + self._current_pixmap = None + self._result_crop_pixmap = None identify_layout.addWidget(self.preview_label) - # 文件名显示 self.filename_label = QLabel() self.filename_label.setStyleSheet(f""" font-size: 11px; color: {COLORS['text_tertiary']}; font-family: {FONTS['mono']}; """) - self.filename_label.setAlignment(Qt.AlignCenter) + self.filename_label.setAlignment(ALIGN_CENTER) self.filename_label.setWordWrap(True) self.filename_label.hide() identify_layout.addWidget(self.filename_label) - # 进度条 self.progress = QProgressBar() self.progress.setRange(0, 0) self.progress.setMaximumHeight(3) @@ -1169,8 +1121,6 @@ def _setup_ui(self): """) self.progress.hide() identify_layout.addWidget(self.progress) - - # 结果区域 self.results_frame = QFrame() self.results_frame.setStyleSheet(f""" QFrame {{ @@ -1193,7 +1143,7 @@ def _setup_ui(self): self.results_scroll = QScrollArea() self.results_scroll.setWidgetResizable(True) - self.results_scroll.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) + self.results_scroll.setSizePolicy(SIZE_POLICY_EXPANDING, SIZE_POLICY_EXPANDING) self.results_scroll.setStyleSheet(f""" QScrollArea {{ border: none; @@ -1211,7 +1161,6 @@ def _setup_ui(self): results_layout.addWidget(self.results_scroll) self.results_frame.hide() - # 占位区:初始可见,有结果时隐藏 self.placeholder_frame = QFrame() self.placeholder_frame.setStyleSheet(f""" QFrame {{ @@ -1221,24 +1170,22 @@ def _setup_ui(self): }} """) ph_layout = QVBoxLayout(self.placeholder_frame) - ph_layout.setAlignment(Qt.AlignCenter) + ph_layout.setAlignment(ALIGN_CENTER) ph_label = QLabel(self.i18n.t("birdid.drag_photo_hint")) - ph_label.setAlignment(Qt.AlignCenter) + ph_label.setAlignment(ALIGN_CENTER) ph_label.setStyleSheet(f""" color: {COLORS['text_muted']}; font-size: 12px; background: transparent; """) ph_layout.addWidget(ph_label) - identify_layout.addWidget(self.placeholder_frame, 1) # stretch=1,与 results_frame 同级 + identify_layout.addWidget(self.placeholder_frame, 1) - identify_layout.addWidget(self.results_frame, 1) # stretch=1,填满剩余空间 + identify_layout.addWidget(self.results_frame, 1) - # 操作按钮 btn_layout = QHBoxLayout() btn_layout.setSpacing(8) - # 选择图片按钮 - 次级样式 self.btn_new = QPushButton(self.i18n.t("birdid.btn_select")) self.btn_new.setStyleSheet(f""" QPushButton {{ @@ -1257,7 +1204,6 @@ def _setup_ui(self): self.btn_new.clicked.connect(self.drop_area.selectFile) btn_layout.addWidget(self.btn_new) - # 截图识别按钮 self.btn_screenshot = QPushButton(self.i18n.t("birdid.btn_screenshot")) self.btn_screenshot.setStyleSheet(f""" QPushButton {{ @@ -1278,28 +1224,23 @@ def _setup_ui(self): identify_layout.addLayout(btn_layout) - # 状态标签(隐藏,保留变量用于内部状态追踪) self.status_label = QLabel("") self.status_label.hide() - - # ===== 面板2: 查询鸟名 ===== + from ui.birdname_search_widget import BirdNameSearchWidget self.search_panel = BirdNameSearchWidget() - - # 将两个面板添加到stacked_widget + self.stacked_widget.addWidget(self.identify_panel) self.stacked_widget.addWidget(self.search_panel) - - # 将stacked_widget添加到主布局 + layout.addWidget(self.stacked_widget) self.setWidget(container) def _show_qimage_preview(self, qimage): - """显示 QImage 预览""" from PySide6.QtGui import QImage - + pixmap = QPixmap.fromImage(qimage) if not pixmap.isNull(): self._current_pixmap = pixmap @@ -1308,7 +1249,6 @@ def _show_qimage_preview(self, qimage): QTimer.singleShot(50, self._scale_preview) def on_file_dropped(self, file_path: str): - """处理文件拖放""" if not os.path.exists(file_path): self.status_label.setText(self.i18n.t("birdid.file_not_found_short")) self.status_label.setStyleSheet(f"font-size: 11px; color: {COLORS['error']};") @@ -1318,47 +1258,33 @@ def on_file_dropped(self, file_path: str): self.status_label.setText(self.i18n.t("birdid.analyzing")) self.status_label.setStyleSheet(f"font-size: 11px; color: {COLORS['accent']};") - # 显示文件名 filename = os.path.basename(file_path) self.filename_label.setText(filename) self.filename_label.show() - # 显示预览 self.show_preview(file_path) - # 清空之前的结果 self.clear_results() - # 显示进度 self.progress.show() self.results_frame.hide() - - # 启动识别 self._start_identify(file_path) def _reidentify_if_needed(self): - """当国家/地区改变时,如果有当前图片,重新识别""" if hasattr(self, 'current_image_path') and self.current_image_path: if os.path.exists(self.current_image_path): - print(f"[调试] 国家/地区已改变,重新识别: {self.current_image_path}") self.status_label.setText(self.i18n.t("birdid.re_identifying")) self.status_label.setStyleSheet(f"font-size: 11px; color: {COLORS['accent']};") - - # 清空之前的结果 + self.clear_results() - - # 显示进度 + self.progress.show() self.results_frame.hide() - - # 重新启动识别 self._start_identify(self.current_image_path) def _start_identify(self, file_path: str): - """启动识别(供文件拖放和粘贴共用)""" - # 如果有正在运行的识别任务,先等待它完成或断开连接 if hasattr(self, 'worker') and self.worker is not None: try: self.worker.finished.disconnect() @@ -1366,33 +1292,30 @@ def _start_identify(self, file_path: str): except: pass if self.worker.isRunning(): - self.worker.wait(1000) # 最多等待1秒 + self.worker.wait(1000) self.worker = None - - # 获取过滤设置 + use_ebird = self.ebird_checkbox.isChecked() - use_gps = True # GPS 自动检测始终启用 - + use_gps = True + country_code = None region_code = None - + country_display = self.country_combo.currentText() country_code_raw = self.country_list.get(country_display) - + if country_code_raw and country_code_raw not in ("GLOBAL", "MORE"): country_code = country_code_raw - - # 检查是否选择了具体区域 + region_display = self.region_combo.currentText() if region_display != self.i18n.t("birdid.region_entire_country"): - # 从 "South Australia (AU-SA)" 提取 AU-SA import re match = re.search(r'\(([A-Z]{2}-[A-Z0-9]+)\)', region_display) if match: region_code = match.group(1) - # 启动识别 from advanced_config import get_advanced_config + advanced_config = get_advanced_config() self.worker = IdentifyWorker( file_path, top_k=5, @@ -1400,14 +1323,13 @@ def _start_identify(self, file_path: str): use_ebird=use_ebird, country_code=country_code, region_code=region_code, - name_format=get_advanced_config().name_format, + name_format=advanced_config.name_format, ) self.worker.finished.connect(self.on_identify_finished) self.worker.error.connect(self.on_identify_error) self.worker.start() def show_preview(self, file_path: str): - """显示图片预览""" try: ext = os.path.splitext(file_path)[1].lower() raw_extensions = ['.nef', '.cr2', '.cr3', '.arw', '.raf', '.orf', '.rw2', '.dng'] @@ -1427,28 +1349,24 @@ def show_preview(self, file_path: str): self._current_pixmap = pixmap self.drop_area.hide() self.preview_label.show() - # 延迟缩放,确保布局完成 QTimer.singleShot(50, self._scale_preview) except Exception as e: print(f"预览加载失败: {e}") def _scale_preview(self): - """根据面板宽度缩放预览图""" if self._current_pixmap is None: return - # 获取容器宽度(减去边距和 padding) container = self.widget() if container: - available_width = container.width() - 24 - 16 # 边距 + padding + available_width = container.width() - 24 - 16 else: available_width = self.width() - 40 if available_width < 100: available_width = 256 - # 限制最大高度 max_height = 280 scaled = self._current_pixmap.scaled( available_width, max_height, - Qt.KeepAspectRatio, Qt.SmoothTransformation + KEEP_ASPECT_RATIO, SMOOTH_TRANSFORMATION ) self.preview_label.setPixmap(scaled) @@ -1458,42 +1376,37 @@ def resizeEvent(self, event): if self._current_pixmap is not None and self.preview_label.isVisible(): self._scale_preview() - # 对焦状态键映射(photo_processor 内部值 → i18n key) _FOCUS_STATUS_I18N = { 'BEST': 'rating_engine.focus_best', 'GOOD': 'rating_engine.focus_good', 'BAD': 'rating_engine.focus_bad', 'WORST': 'rating_engine.focus_worst', } - # 对焦状态颜色 _FOCUS_STATUS_COLOR = { - 'BEST': COLORS['focus_best'], # 绿 — 精焦 - 'GOOD': COLORS['focus_good'], # 琥珀 — 合焦 - 'BAD': COLORS['focus_bad'], # 近白灰 — 失焦 - 'WORST': COLORS['focus_worst'], # 灰 — 脱焦 + 'BEST': COLORS['focus_best'], + 'GOOD': COLORS['focus_good'], + 'BAD': COLORS['focus_bad'], + 'WORST': COLORS['focus_worst'], } def update_crop_preview(self, debug_img, focus_status=None): - """ - V4.2: 接收选片过程中的裁剪预览图像并显示,同时在结果区更新对焦状态文字 - Args: - debug_img: BGR numpy 数组 (带标注的鸟类裁剪图) - focus_status: 对焦状态键 "BEST"/"GOOD"/"BAD"/"WORST" 或 None - """ try: import cv2 from PySide6.QtGui import QImage - # BGR -> RGB rgb_img = cv2.cvtColor(debug_img, cv2.COLOR_BGR2RGB) h, w, ch = rgb_img.shape bytes_per_line = ch * w - # numpy -> QImage -> QPixmap - q_img = QImage(rgb_img.data, w, h, bytes_per_line, QImage.Format_RGB888) + q_img = QImage( + rgb_img.data, + w, + h, + bytes_per_line, + QImage.Format.Format_RGB888, + ) pixmap = QPixmap.fromImage(q_img) - # 保存并显示预览 self._current_pixmap = pixmap self.preview_label.show() self._scale_preview() @@ -1501,7 +1414,6 @@ def update_crop_preview(self, debug_img, focus_status=None): except Exception as e: print(f"[BirdIDDock] 预览更新失败: {e}") - # 更新结果区:清空旧内容,显示当前对焦状态 self.clear_results() self.placeholder_frame.hide() self.results_frame.show() @@ -1509,12 +1421,11 @@ def update_crop_preview(self, debug_img, focus_status=None): if focus_status and focus_status in self._FOCUS_STATUS_I18N: i18n_key = self._FOCUS_STATUS_I18N[focus_status] raw_text = self.i18n.t(i18n_key) - # i18n 值带前缀标点(",精焦" / ", Critical Focus"),去掉它 display_text = raw_text.lstrip(",, ").strip() color = self._FOCUS_STATUS_COLOR.get(focus_status, COLORS['text_secondary']) focus_label = QLabel(display_text) - focus_label.setAlignment(Qt.AlignCenter) + focus_label.setAlignment(ALIGN_CENTER) focus_label.setStyleSheet(f""" color: {color}; font-size: 15px; @@ -1527,16 +1438,9 @@ def update_crop_preview(self, debug_img, focus_status=None): self.results_layout.addStretch() def show_completion_message(self, stats: dict): - """ - V4.2: 处理完成后显示统计摘要,隐藏预览图 - Args: - stats: photo_processor 返回的统计字典 - """ - # 隐藏预览图 self.preview_label.hide() self._current_pixmap = None - # 清空结果,切换到结果区显示完成信息 self.clear_results() self.placeholder_frame.hide() self.results_frame.show() @@ -1600,18 +1504,16 @@ def pct(n): self.results_layout.addStretch() def clear_results(self): - """清空结果区域""" while self.results_layout.count(): item = self.results_layout.takeAt(0) - if item.widget(): - item.widget().deleteLater() + widget = item.widget() # pyright: ignore[reportOptionalMemberAccess] + if widget is not None: + widget.deleteLater() def on_identify_finished(self, result: dict): - """识别完成""" self.progress.hide() t = self.i18n.t - # === 构建状态信息 === info_lines = [] # 1. YOLO 检测状态 @@ -1685,7 +1587,6 @@ def on_identify_finished(self, result: dict): """) self.results_layout.addWidget(info_label) - # 断崖式领先判断:#1 与 #2 差距 >= 80% 时只显示 #1,否则显示 Top 3 if len(results) >= 2: gap = results[0].get('confidence', 0) - results[1].get('confidence', 0) show_count = 1 if gap >= 80 else min(3, len(results)) @@ -1707,15 +1608,19 @@ def on_identify_finished(self, result: dict): self.results_layout.addStretch() - # 用 YOLO 裁剪图替换预览(正方形) cropped_pil = result.get('cropped_image') if cropped_pil: try: from PySide6.QtGui import QImage rgb = cropped_pil.convert('RGB') data = rgb.tobytes('raw', 'RGB') - q_img = QImage(data, rgb.width, rgb.height, - rgb.width * 3, QImage.Format_RGB888) + q_img = QImage( + data, + rgb.width, + rgb.height, + rgb.width * 3, + QImage.Format.Format_RGB888, + ) pixmap = QPixmap.fromImage(q_img) if not pixmap.isNull(): self._current_pixmap = pixmap @@ -1724,14 +1629,11 @@ def on_identify_finished(self, result: dict): except Exception as _e: print(f"[BirdIDDock] 裁剪图预览更新失败: {_e}") - # 保存结果 self.identify_results = results - # 状态显示选中的候选 self._update_status_label() def _show_info_panel(self, info_lines: list): - """显示纯信息面板(无结果卡片时使用)""" self.results_frame.show() self.placeholder_frame.hide() info_label = QLabel('\n'.join(info_lines)) @@ -1747,34 +1649,28 @@ def _show_info_panel(self, info_lines: list): self.results_layout.addWidget(info_label) def on_identify_error(self, error_msg: str): - """识别出错""" self.progress.hide() self.status_label.setText(self.i18n.t("birdid.error_prefix") + error_msg[:30]) self.status_label.setStyleSheet(f"font-size: 11px; color: {COLORS['error']};") - + def on_result_card_clicked(self, rank: int): - """点击结果卡片:切换选中状态 + 复制鸟名到剪贴板""" index = rank - 1 if index < 0 or index >= len(self.result_cards): return - # 切换选中状态 if hasattr(self, 'result_cards'): for card in self.result_cards: card.set_selected(False) self.result_cards[index].set_selected(True) self.selected_index = index - # 更新状态标签 self._update_status_label() - # 恢复 YOLO 裁剪预览 if getattr(self, '_result_crop_pixmap', None): self._current_pixmap = self._result_crop_pixmap self._scale_preview() - # ── 复制鸟名到剪贴板 ────────────────────────────────────── - if hasattr(self, 'identify_results') and 0 <= index < len(self.identify_results): + if isinstance(self.identify_results, list) and 0 <= index < len(self.identify_results): result = self.identify_results[index] is_en = self.i18n.current_lang.startswith('en') bird_name = result.get('en_name', '') if is_en else result.get('cn_name', '') @@ -1783,7 +1679,6 @@ def on_result_card_clicked(self, rank: int): QApplication.clipboard().setText(bird_name) - # 视觉反馈:卡片名称标签短暂变色 card = self.result_cards[index] original_style = card.name_label.styleSheet() card.name_label.setStyleSheet(f""" @@ -1795,19 +1690,15 @@ def on_result_card_clicked(self, rank: int): QTimer.singleShot(600, lambda: card.name_label.setStyleSheet(original_style)) def _update_status_label(self): - """更新状态标签,显示当前选中的候选""" - if hasattr(self, 'selected_index') and hasattr(self, 'identify_results'): + if hasattr(self, 'selected_index') and isinstance(self.identify_results, list): if 0 <= self.selected_index < len(self.identify_results): selected = self.identify_results[self.selected_index] self.status_label.setText(f"✓ {selected['cn_name']} ({selected['confidence']:.0f}%)") self.status_label.setStyleSheet(f"font-size: 11px; color: {COLORS['success']};") - def _switch_tab(self, index: int): - """切换标签页""" if index == 0: - # 切换到鸟类识别 self.tab_identify.setChecked(True) self.tab_identify.setStyleSheet(f""" QPushButton {{ @@ -1878,18 +1769,14 @@ def _switch_tab(self, index: int): self.stacked_widget.setCurrentIndex(1) def _take_screenshot(self): - """调用系统截图工具,截图后加载识别""" if sys.platform == 'darwin': self._take_screenshot_mac() elif sys.platform == 'win32': self._take_screenshot_win() def _take_screenshot_mac(self): - """macOS: 隐藏主窗口后再启动 screencapture,避免覆盖层被遮挡 - 用 Popen 非阻塞启动,Qt 主线程轮询进程退出,避免阻塞事件循环""" import tempfile - # 先检查屏幕录制权限:快速做一次非交互截图测试 import subprocess as _sp _test_file = os.path.join(tempfile.gettempdir(), 'birdid_sc_test.png') try: @@ -1909,7 +1796,7 @@ def _take_screenshot_mac(self): is_en = self.i18n.current_lang.startswith('en') msg = QMessageBox(self) - msg.setIcon(QMessageBox.Warning) + msg.setIcon(QMessageBox.Icon.Warning) msg.setWindowTitle(self.i18n.t("birdid.title")) if is_en: @@ -1919,16 +1806,16 @@ def _take_screenshot_mac(self): "Tap \"Open Settings\" — find this app and flip the switch on.\n" "Then come back and try again!" ) - open_btn = msg.addButton(" Open Settings ", QMessageBox.AcceptRole) - msg.addButton("Later", QMessageBox.RejectRole) + open_btn = msg.addButton(" Open Settings ", QMessageBox.ButtonRole.AcceptRole) + msg.addButton("Later", QMessageBox.ButtonRole.RejectRole) else: msg.setText("需要屏幕录制权限") msg.setInformativeText( "截图识鸟功能需要「屏幕录制」权限才能工作。\n\n" "点击下方按钮一键跳转设置页,为本应用开启权限后即可使用。" ) - open_btn = msg.addButton(" 打开系统设置 ", QMessageBox.AcceptRole) - msg.addButton("稍后再说", QMessageBox.RejectRole) + open_btn = msg.addButton(" 打开系统设置 ", QMessageBox.ButtonRole.AcceptRole) + msg.addButton("稍后再说", QMessageBox.ButtonRole.RejectRole) msg.setStyleSheet(f""" QMessageBox {{ @@ -1957,7 +1844,7 @@ def _take_screenshot_mac(self): if msg.clickedButton() == open_btn: import subprocess as _open_sp - # macOS URL Scheme 直接跳转到「屏幕录制」权限页面 + _open_sp.Popen([ 'open', 'x-apple.systempreferences:com.apple.preference.security?Privacy_ScreenCapture' ]) @@ -1971,22 +1858,18 @@ def _take_screenshot_mac(self): except Exception: pass - # 找到顶层主窗口并隐藏,让 screencapture 覆盖层能正常显示 self._sc_main_win = self.window() if self._sc_main_win: self._sc_main_win.hide() - # 等待 300ms 让窗口动画完成后再启动截图 QTimer.singleShot(300, self._launch_screencapture_mac) def _launch_screencapture_mac(self): - """延迟启动 screencapture(非阻塞)""" import subprocess print(f"[Screenshot] 启动 screencapture, 目标文件: {self._sc_tmp_file}") try: - # 非阻塞启动 — Qt 事件循环继续运行,screencapture UI 才能正常显示 self._sc_proc = subprocess.Popen( ['screencapture', '-i', '-s', self._sc_tmp_file], stdout=subprocess.PIPE, @@ -1994,28 +1877,22 @@ def _launch_screencapture_mac(self): ) print(f"[Screenshot] screencapture 进程已启动, PID: {self._sc_proc.pid}") except FileNotFoundError: - # screencapture 不可用,恢复窗口 if getattr(self, '_sc_main_win', None): self._sc_main_win.show() self._sc_main_win.raise_() self._show_screenshot_error("screencapture 不可用") return - - # 停止上次残留的轮询 if hasattr(self, '_sc_poll_timer') and self._sc_poll_timer is not None: self._sc_poll_timer.stop() - # 轮询进程退出,每 200ms 检查一次,最多等待 120 秒 self._sc_poll_count = 0 self._sc_poll_timer = QTimer(self) self._sc_poll_timer.timeout.connect(self._poll_screencapture_done) self._sc_poll_timer.start(200) def _poll_screencapture_done(self): - """轮询 screencapture 进程是否退出""" self._sc_poll_count += 1 - # 超时保护(120 秒) if self._sc_poll_count > 600: print("[Screenshot] ⚠️ 超时 (120s),停止轮询") self._sc_poll_timer.stop() @@ -2042,27 +1919,23 @@ def _poll_screencapture_done(self): self._sc_poll_timer.stop() self._sc_proc = None - # 先恢复主窗口 main_win = getattr(self, '_sc_main_win', None) if main_win: main_win.show() main_win.raise_() main_win.activateWindow() - # 用户取消时不会生成文件 file_exists = os.path.exists(self._sc_tmp_file) if file_exists: file_size = os.path.getsize(self._sc_tmp_file) print(f"[Screenshot] ✅ 截图文件存在, 大小: {file_size} bytes, 路径: {self._sc_tmp_file}") if file_size > 0: - # 稍等 100ms 让窗口完全显示后再加载 QTimer.singleShot(100, lambda: self.on_file_dropped(self._sc_tmp_file)) else: print("[Screenshot] ⚠️ 截图文件为空 (0 bytes),可能缺少屏幕录制权限") self._show_screenshot_error("截图文件为空,请检查系统偏好设置 > 隐私与安全 > 屏幕录制 权限") else: print(f"[Screenshot] ❌ 截图文件不存在 (用户可能取消了截图)") - # 列出临时目录中的相关文件用于调试 import glob tmp_dir = os.path.dirname(self._sc_tmp_file) related = glob.glob(os.path.join(tmp_dir, 'birdid_*')) @@ -2070,44 +1943,35 @@ def _poll_screencapture_done(self): print(f"[Screenshot] 临时目录中的相关文件: {related}") def _load_screenshot_from_clipboard(self): - """从剪贴板读取截图并保存为临时文件(Windows 模式备用)""" import tempfile clipboard = QApplication.clipboard() image = clipboard.image() if image is None or image.isNull(): return tmp_file = os.path.join(tempfile.gettempdir(), 'birdid_screenshot.png') - if image.save(tmp_file, 'PNG'): + if image.save(tmp_file, b'PNG'): self.on_file_dropped(tmp_file) else: self._show_screenshot_error("截图保存失败") def _show_screenshot_error(self, msg: str): - """显示截图错误提示""" self.status_label.setText(msg) self.status_label.setStyleSheet(f"font-size: 11px; color: {COLORS['error']};") self.status_label.show() - - def _take_screenshot_win(self): - """Windows: 隐藏主窗口后发送 Win+Shift+S,轮询剪贴板等待图像""" - # 先清空剪贴板 try: QApplication.clipboard().clear() except Exception: pass - # 隐藏主窗口 self._sc_main_win = self.window() if self._sc_main_win: self._sc_main_win.hide() - # 等待 300ms 让窗口动画完成后再发送快捷键 QTimer.singleShot(300, self._launch_snip_win) def _launch_snip_win(self): - """发送 Win+Shift+S 唤起截图工具""" import ctypes KEYEVENTF_KEYUP = 0x0002 @@ -2115,7 +1979,13 @@ def _launch_snip_win(self): VK_SHIFT = 0x10 VK_S = 0x53 - keybd = ctypes.windll.user32.keybd_event + windll = getattr(ctypes, "windll", None) + if windll is None: + self._restore_win_window() + self.status_label.setText("当前环境不支持 Windows 截图快捷键") + self.status_label.setStyleSheet(f"font-size: 11px; color: {COLORS['error']};") + return + keybd = windll.user32.keybd_event try: keybd(VK_LWIN, 0, 0, 0) keybd(VK_SHIFT, 0, 0, 0) @@ -2124,20 +1994,17 @@ def _launch_snip_win(self): keybd(VK_SHIFT, 0, KEYEVENTF_KEYUP, 0) keybd(VK_LWIN, 0, KEYEVENTF_KEYUP, 0) except Exception as e: - # 发送失败,直接恢复窗口 self._restore_win_window() self.status_label.setText(f"截图快捷键发送失败: {e}") self.status_label.setStyleSheet(f"font-size: 11px; color: {COLORS['error']};") return - # 轮询剪贴板,每 500ms 检查一次 self._screenshot_poll_count = 0 self._screenshot_timer = QTimer(self) self._screenshot_timer.timeout.connect(self._poll_clipboard_for_screenshot) self._screenshot_timer.start(500) def _restore_win_window(self): - """恢复 Windows 主窗口""" main_win = getattr(self, '_sc_main_win', None) if main_win: main_win.show() @@ -2145,12 +2012,10 @@ def _restore_win_window(self): main_win.activateWindow() def _poll_clipboard_for_screenshot(self): - """轮询剪贴板,检测到图像后恢复窗口并加载""" import tempfile self._screenshot_poll_count += 1 - # 超时 60 秒自动放弃 if self._screenshot_poll_count > 120: self._screenshot_timer.stop() self._restore_win_window() @@ -2168,15 +2033,13 @@ def _poll_clipboard_for_screenshot(self): return tmp_file = os.path.join(tempfile.gettempdir(), 'birdid_screenshot.png') - if image.save(tmp_file, 'PNG'): - # 先恢复窗口,再加载图片 + if image.save(tmp_file, b'PNG'): self._restore_win_window() QTimer.singleShot(100, lambda: self.on_file_dropped(tmp_file)) else: self._restore_win_window() def reset_view(self): - """重置视图""" self.drop_area.show() self.preview_label.hide() self.filename_label.hide() @@ -2189,4 +2052,4 @@ def reset_view(self): self.current_image_path = None self.identify_results = None self._current_pixmap = None - self.clear_results() \ No newline at end of file + self.clear_results() diff --git a/ui/birdname_search_widget.py b/ui/birdname_search_widget.py index db3d7b1..df014b8 100644 --- a/ui/birdname_search_widget.py +++ b/ui/birdname_search_widget.py @@ -6,49 +6,38 @@ """ import os -import sys import sqlite3 import configparser from typing import List, Dict, Optional from PySide6.QtWidgets import ( - QWidget, QVBoxLayout, QHBoxLayout, QLabel, - QLineEdit, QPushButton, QComboBox, QScrollArea, - QFrame, QSizePolicy, QApplication + QWidget, + QVBoxLayout, + QHBoxLayout, + QLabel, + QLineEdit, + QPushButton, + QComboBox, + QScrollArea, + QFrame, + QSizePolicy, + QApplication, ) from PySide6.QtCore import Qt, Signal, QTimer -from PySide6.QtGui import QFont from ui.styles import COLORS, FONTS from tools.i18n import get_i18n +from config import get_birdname_settings_path, get_install_scoped_resource_path def get_birdname_db_path() -> str: """获取鸟类名称数据库路径""" - if getattr(sys, 'frozen', False): - # macOS .app bundle struct: executable is in Contents/MacOS/ - # PyInstaller puts datas in Contents/Resources/ - if sys.platform == 'darwin': - app_contents = os.path.dirname(os.path.dirname(sys.executable)) - res_dir = os.path.join(app_contents, 'Resources') - path_in_res = os.path.join(res_dir, 'ioc', 'birdname.db') - if os.path.exists(path_in_res): - return path_in_res - - # Windows 或 one-dir / one-file 模式的回退 - return os.path.join(sys._MEIPASS, 'ioc', 'birdname.db') - else: - base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - return os.path.join(base_dir, 'ioc', 'birdname.db') + return str(get_install_scoped_resource_path(os.path.join("ioc", "birdname.db"))) def get_birdname_ini_path() -> str: """获取 ioc 目录下的 ini 配置文件路径(用户设置,保留在用户可写目录)""" - # 将持久化配置存储在用户的主目录 .superpicky/ 下,避免 macOS App 内沙盒只读权限问题 - user_home = os.path.expanduser('~') - app_data_dir = os.path.join(user_home, '.superpicky', 'ioc') - os.makedirs(app_data_dir, exist_ok=True) - return os.path.join(app_data_dir, 'birdname_settings.ini') + return str(get_birdname_settings_path()) def load_last_version() -> Optional[str]: @@ -58,10 +47,9 @@ def load_last_version() -> Optional[str]: return None try: cfg = configparser.ConfigParser() - cfg.read(ini_path, encoding='utf-8') - return cfg.get('settings', 'last_version_name', fallback=None) - except Exception as e: - print(f"读取版本设置失败: {e}") + cfg.read(ini_path, encoding="utf-8") + return cfg.get("settings", "last_version_name", fallback=None) + except Exception: return None @@ -70,28 +58,31 @@ def save_last_version(version_name: str): ini_path = get_birdname_ini_path() try: cfg = configparser.ConfigParser() - cfg['settings'] = {'last_version_name': version_name} - with open(ini_path, 'w', encoding='utf-8') as f: + cfg["settings"] = {"last_version_name": version_name} + with open(ini_path, "w", encoding="utf-8") as f: cfg.write(f) - except Exception as e: - print(f"保存版本设置失败: {e}") + except Exception: + pass class ClickableLabel(QLabel): """可点击复制的标签""" + clicked = Signal() def __init__(self, text: str, original_color: str, parent=None): super().__init__(text, parent) self.setCursor(Qt.PointingHandCursor) self.original_color = original_color - self.accent_color = COLORS['accent'] + self.accent_color = COLORS["accent"] def mousePressEvent(self, event): if event.button() == Qt.LeftButton: self.clicked.emit() self.setStyleSheet(f"color: {self.accent_color};") - QTimer.singleShot(500, lambda: self.setStyleSheet(f"color: {self.original_color};")) + QTimer.singleShot( + 500, lambda: self.setStyleSheet(f"color: {self.original_color};") + ) super().mousePressEvent(event) @@ -123,17 +114,19 @@ def __init__(self, bird_data: Dict, parent=None): layout.setContentsMargins(10, 6, 10, 6) layout.setSpacing(2) - cn_name = bird_data.get('chinese_name', '') + cn_name = bird_data.get("chinese_name", "") if cn_name: - cn_color = COLORS['text_primary'] + cn_color = COLORS["text_primary"] self.cn_label = ClickableLabel(cn_name, cn_color) - self.cn_label.setStyleSheet(f"color: {cn_color}; font-size: 13px; font-weight: 500;") + self.cn_label.setStyleSheet( + f"color: {cn_color}; font-size: 13px; font-weight: 500;" + ) self.cn_label.clicked.connect(lambda: self._copy_text(cn_name)) layout.addWidget(self.cn_label) - en_name = bird_data.get('english_name', '') + en_name = bird_data.get("english_name", "") if en_name: - en_color = COLORS['text_secondary'] + en_color = COLORS["text_secondary"] self.en_label = ClickableLabel(en_name, en_color) self.en_label.setStyleSheet(f"color: {en_color}; font-size: 11px;") self.en_label.clicked.connect(lambda: self._copy_text(en_name)) @@ -151,7 +144,7 @@ def __init__(self, parent=None): self.i18n = get_i18n() self.db_path = get_birdname_db_path() self.current_version_id = None - self._loading_versions = False # 防止加载时触发保存 + self._loading_versions = False self._setup_ui() self._load_versions() @@ -162,7 +155,6 @@ def _setup_ui(self): main_layout.setContentsMargins(0, 0, 0, 0) main_layout.setSpacing(8) - # ── 行1:标题 + 版本选择(固定高度)──────────────────────── title_row = QHBoxLayout() title_row.setSpacing(6) @@ -178,7 +170,9 @@ def _setup_ui(self): version_label = QLabel("请选择版本:") version_label.setFixedHeight(28) - version_label.setStyleSheet(f"color: {COLORS['text_secondary']}; font-size: 10px;") + version_label.setStyleSheet( + f"color: {COLORS['text_secondary']}; font-size: 10px;" + ) title_row.addWidget(version_label) self.version_combo = QComboBox() @@ -216,7 +210,6 @@ def _setup_ui(self): main_layout.addLayout(title_row) - # ── 行2:搜索框 + 清空按钮(固定高度)────────────────────── search_row = QHBoxLayout() search_row.setSpacing(6) @@ -265,7 +258,6 @@ def _setup_ui(self): main_layout.addLayout(search_row) - # ── 结果区域:始终占满剩余空间 ────────────────────────────── self.results_area = QFrame() self.results_area.setStyleSheet(f""" QFrame {{ @@ -279,7 +271,6 @@ def _setup_ui(self): results_area_layout.setContentsMargins(6, 6, 6, 6) results_area_layout.setSpacing(0) - # 空状态提示 self.empty_label = QLabel("请在上方输入关键词搜索") self.empty_label.setAlignment(Qt.AlignCenter) self.empty_label.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) @@ -291,7 +282,6 @@ def _setup_ui(self): """) results_area_layout.addWidget(self.empty_label) - # 滚动区域 self.results_scroll = QScrollArea() self.results_scroll.setWidgetResizable(True) self.results_scroll.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) @@ -331,7 +321,6 @@ def _setup_ui(self): main_layout.addWidget(self.results_area, 1) - # ── 统计标签 ────────────────────────────────────────────── self.stats_label = QLabel("") self.stats_label.setFixedHeight(16) self.stats_label.setStyleSheet(f""" @@ -348,11 +337,13 @@ def _load_versions(self): self.version_combo.setEnabled(False) return try: - self._loading_versions = True # 加载期间屏蔽保存 + self._loading_versions = True conn = sqlite3.connect(self.db_path) cursor = conn.cursor() - cursor.execute("SELECT version_id, version_name FROM versions ORDER BY created_at DESC") + cursor.execute( + "SELECT version_id, version_name FROM versions ORDER BY created_at DESC" + ) versions = cursor.fetchall() conn.close() @@ -361,11 +352,9 @@ def _load_versions(self): self.version_combo.setEnabled(False) return - # 填充下拉框 for version_id, version_name in versions: self.version_combo.addItem(version_name, version_id) - # 还原上次选择的版本 last_name = load_last_version() restored = False if last_name: @@ -375,13 +364,11 @@ def _load_versions(self): self.current_version_id = self.version_combo.itemData(idx) restored = True - # 找不到上次记录则默认选第一项 if not restored: self.version_combo.setCurrentIndex(0) self.current_version_id = versions[0][0] - except Exception as e: - print(f"加载版本列表失败: {e}") + except Exception: self.version_combo.addItem("加载失败") self.version_combo.setEnabled(False) finally: @@ -393,7 +380,6 @@ def _on_version_changed(self, index: int): return self.current_version_id = self.version_combo.itemData(index) - # 只有用户主动切换时才保存(加载期间跳过) if not self._loading_versions: save_last_version(self.version_combo.currentText()) @@ -408,7 +394,7 @@ def _on_search_text_changed(self, text: str): return if self.current_version_id is None: return - if hasattr(self, '_search_timer'): + if hasattr(self, "_search_timer"): self._search_timer.stop() self._search_timer = QTimer() self._search_timer.setSingleShot(True) @@ -449,18 +435,25 @@ def _perform_search(self, query: str): """ params = ( self.current_version_id, - f'%{query}%', f'%{query}%', f'%{query}%', - f'%{query}%', f'%{query}%', - f'%{query_lower}%', f'%{query_lower}%', - query, query, query, query_lower, - f'{query}%', f'{query}%' + f"%{query}%", + f"%{query}%", + f"%{query}%", + f"%{query}%", + f"%{query}%", + f"%{query_lower}%", + f"%{query_lower}%", + query, + query, + query, + query_lower, + f"{query}%", + f"{query}%", ) cursor.execute(sql, params) results = cursor.fetchall() self._display_results(results) conn.close() - except Exception as e: - print(f"搜索失败: {e}") + except Exception: self._clear_results() def _display_results(self, results: List): @@ -477,12 +470,12 @@ def _display_results(self, results: List): for row in results: bird_data = { - 'bird_id': row['bird_id'], - 'chinese_name': row['chinese_name'], - 'english_name': row['english_name'], - 'latin_name': row['latin_name'], - 'pinyin_name': row['pinyin_name'], - 'abbreviation': row['abbreviation'] + "bird_id": row["bird_id"], + "chinese_name": row["chinese_name"], + "english_name": row["english_name"], + "latin_name": row["latin_name"], + "pinyin_name": row["pinyin_name"], + "abbreviation": row["abbreviation"], } self.results_layout.addWidget(BirdResultCard(bird_data)) @@ -502,4 +495,4 @@ def _clear_results(self): def _clear_search(self): self.search_input.clear() self._clear_results() - self.search_input.setFocus() \ No newline at end of file + self.search_input.setFocus() diff --git a/ui/custom_dialogs.py b/ui/custom_dialogs.py index 7ef6248..46b1294 100644 --- a/ui/custom_dialogs.py +++ b/ui/custom_dialogs.py @@ -116,13 +116,14 @@ def _setup_ui(self, title, message, buttons): btn_layout = QHBoxLayout() btn_layout.setSpacing(12) - # 如果有多个按钮,左侧添加弹性空间 + button_width = 180 if len(buttons) > 1 else 120 + if len(buttons) > 1: btn_layout.addStretch() for btn_text, btn_value, btn_style in buttons: btn = QPushButton(btn_text) - btn.setMinimumWidth(100) + btn.setFixedWidth(button_width) btn.setMinimumHeight(40) btn.setCursor(Qt.PointingHandCursor) @@ -169,6 +170,9 @@ def _setup_ui(self, title, message, buttons): btn.clicked.connect(lambda checked, v=btn_value: self._on_button_clicked(v)) btn_layout.addWidget(btn) + if len(buttons) > 1: + btn_layout.addStretch() + layout.addLayout(btn_layout) # 调整大小以适应内容 diff --git a/ui/main_window.py b/ui/main_window.py index 1ed9d6a..30e2d37 100644 --- a/ui/main_window.py +++ b/ui/main_window.py @@ -8,14 +8,16 @@ import sys import threading import subprocess +from types import SimpleNamespace from pathlib import Path def get_resource_path(relative_path): """获取资源文件路径(兼容 PyInstaller 打包环境)""" # PyInstaller 打包后会设置 _MEIPASS - if hasattr(sys, '_MEIPASS'): - return os.path.join(sys._MEIPASS, relative_path) + meipass = getattr(sys, "_MEIPASS", None) + if isinstance(meipass, str): + return os.path.join(meipass, relative_path) # 开发环境 return os.path.join(os.path.dirname(os.path.dirname(__file__)), relative_path) @@ -24,6 +26,7 @@ def get_resource_path(relative_path): QLabel, QLineEdit, QPushButton, QSlider, QProgressBar, QTextEdit, QGroupBox, QCheckBox, QMenuBar, QMenu, QFileDialog, QMessageBox, QSizePolicy, QFrame, QSpacerItem, + QDialog, QSystemTrayIcon, QApplication # V4.0: 系统托盘图标 ) from PySide6.QtCore import Qt, Signal, QObject, Slot, QTimer, QPropertyAnimation, QEasingCurve, QMimeData, QThread @@ -31,13 +34,15 @@ def get_resource_path(relative_path): from tools.i18n import get_i18n, set_primary_language from advanced_config import get_advanced_config -from config import config as app_config +from config import config as app_config, get_app_config_dir from ui.styles import ( GLOBAL_STYLE, TITLE_STYLE, SUBTITLE_STYLE, VERSION_STYLE, VALUE_STYLE, COLORS, FONTS, LOG_COLORS, PROGRESS_INFO_STYLE, PROGRESS_PERCENT_STYLE ) from ui.custom_dialogs import StyledMessageBox from ui.skill_level_dialog import SkillLevelDialog, SKILL_PRESETS, get_skill_level_thresholds +from ui.welcome_onboarding_dialog import EnvironmentRepairDialog, WelcomeOnboardingDialog +from core.initialization_manager import InitializationManager # V3.9: 支持拖放的目录输入框 @@ -86,13 +91,14 @@ class WorkerSignals(QObject): class WorkerThread(threading.Thread): """处理线程""" - def __init__(self, dir_path, ui_settings, signals, i18n=None, resume=False): + def __init__(self, dir_path, ui_settings, signals, i18n=None, resume=False, scan_results=None): super().__init__(daemon=True) self.dir_path = dir_path self.ui_settings = ui_settings self.signals = signals - self.i18n = i18n + self.i18n = i18n or get_i18n() self.resume = resume + self.scan_results = list(scan_results) if scan_results is not None else None self._stop_event = threading.Event() self._active_processor = None self.caffeinate_process = None @@ -197,13 +203,8 @@ def process_files(self): try: import json import re - import sys as sys_module - import os - if sys_module.platform == 'darwin': - birdid_settings_dir = os.path.expanduser('~/Documents/SuperPicky_Data') - else: - birdid_settings_dir = os.path.join(os.path.expanduser('~'), 'Documents', 'SuperPicky_Data') + birdid_settings_dir = str(get_app_config_dir()) birdid_settings_path = os.path.join(birdid_settings_dir, 'birdid_dock_settings.json') if os.path.exists(birdid_settings_path): @@ -266,8 +267,8 @@ def process_files(self): # BirdID 设置 auto_identify=birdid_auto_identify, birdid_use_ebird=birdid_use_ebird, - birdid_country_code=birdid_country_code, - birdid_region_code=birdid_region_code, + birdid_country_code=birdid_country_code or "", + birdid_region_code=birdid_region_code or "", birdid_confidence_threshold=float(birdid_confidence_threshold), # V4.2 ) @@ -369,13 +370,18 @@ def crop_preview_callback(debug_img, focus_status=None): ) # Detect batch mode: check for subdirectories with photos - from core.recursive_scanner import scan_recursive, has_photos - sub_dirs = scan_recursive(self.dir_path, max_depth=5) + from core.recursive_scanner import DEFAULT_SCAN_MAX_DEPTH, scan_directories + + scan_results = self.scan_results + if scan_results is None: + scan_results = scan_directories(self.dir_path, max_depth=DEFAULT_SCAN_MAX_DEPTH) - if len(sub_dirs) <= 1: + sub_dirs = [item.path for item in scan_results] + + if len(scan_results) <= 1: # Single directory mode (original behavior) # 若扫描到的实际目录与根目录不同(根目录无图片、子目录有图片),使用实际目录 - single_dir = sub_dirs[0] if sub_dirs else self.dir_path + single_dir = scan_results[0].path if scan_results else self.dir_path processor = PhotoProcessor( dir_path=single_dir, settings=settings, @@ -410,21 +416,11 @@ def crop_preview_callback(debug_img, focus_status=None): adv_config = get_advanced_config() log_callback(f"\n{'='*56}", "info") - log_callback(f" \U0001f4c2 Batch mode: {len(sub_dirs)} directories detected", "info") + log_callback(f" \U0001f4c2 Batch mode: {len(scan_results)} directories detected", "info") log_callback(f"{'='*56}", "info") # Count total photos across all dirs for progress - from constants import IMAGE_EXTENSIONS - _photo_exts = set(e.lower() for e in IMAGE_EXTENSIONS) - total_all = 0 - dir_photo_counts = {} - for d in sub_dirs: - count = 0 - for f in os.listdir(d): - if os.path.splitext(f)[1].lower() in _photo_exts: - count += 1 - dir_photo_counts[d] = count - total_all += count + total_all = sum(item.photo_count for item in scan_results) processed_so_far = 0 aggregated = { @@ -438,9 +434,10 @@ def crop_preview_callback(debug_img, focus_status=None): import time as _time aggregated['start_time'] = _time.time() - for idx, sub_dir in enumerate(sub_dirs, 1): + for idx, scanned_dir in enumerate(scan_results, 1): + sub_dir = scanned_dir.path rel = os.path.relpath(sub_dir, self.dir_path) - n_photos = dir_photo_counts.get(sub_dir, 0) + n_photos = scanned_dir.photo_count if n_photos == 0: continue @@ -639,13 +636,14 @@ def __init__(self): self._setup_ui() self._setup_birdid_dock() # V4.0: 识鸟停靠面板 self._show_initial_help() + self._init_manager = InitializationManager(self) # 连接重置信号 # 连接重置信号 self.reset_log_signal.connect(self._log) # 修复Crash: 确保日志信号连接到主线程槽 # noinspection PyUnresolvedReferences - self.log_signal.connect(self._log, Qt.QueuedConnection) + self.log_signal.connect(self._log, Qt.ConnectionType.QueuedConnection) self.reset_complete_signal.connect(self._on_reset_complete) self.reset_error_signal.connect(self._on_reset_error) @@ -659,7 +657,10 @@ def __init__(self): # V4.0.1: 启动时检查更新(延迟2秒,避免阻塞UI,没有更新时不弹窗) from advanced_config import get_advanced_config as _get_cfg_startup - if _get_cfg_startup().auto_check_updates: + # Keep the legacy startup auto-update path for full installs. + # Lightweight initialization owns first-run update probing and must + # completely skip automatic update work when the user disables it. + if _get_cfg_startup().auto_check_updates and self._skip_until_initialized("首次初始化尚未完成,暂不检查更新。"): QTimer.singleShot(2000, lambda: self._check_for_updates(silent=True)) # V4.2: 启动时预加载所有模型(延迟3秒,后台加载不阻塞UI) @@ -679,10 +680,10 @@ def __init__(self): # V4.2: 使用默认窗口大小,不最大化 # self.showMaximized() # 注释掉这行,使用默认大小 - # V4.3: 首次运行时显示水平选择对话框(延迟500ms,确保UI已完成渲染) - if self.config.is_first_run: - QTimer.singleShot(500, self._show_first_run_skill_level_dialog) - else: + # 首次启动欢迎向导由 run_startup_prompts 统一调度,避免重复弹窗。 + # NOTE: onboarding 只替代“首次启动设置流程”,不替代后续手动设置入口。 + # 因此这里仅在非首次运行时预先应用已保存的等级阈值,不在 __init__ 里直接弹窗。 + if not self.config.is_first_run: # 非首次运行:根据保存的水平设置滑块 self._apply_skill_level_thresholds(self.config.skill_level) @@ -763,6 +764,14 @@ def _setup_menu(self): skill_level_action = QAction(self.i18n.t("skill_level.section_title") + "...", self) skill_level_action.triggered.connect(self._show_skill_level_dialog) settings_menu.addAction(skill_level_action) + + update_action = QAction(self.i18n.t("menu.check_update"), self) + update_action.triggered.connect(self._show_update_center) + settings_menu.addAction(update_action) + + repair_action = QAction(self.i18n.t("menu.environment_repair"), self) + repair_action.triggered.connect(self._show_environment_repair_dialog) + settings_menu.addAction(repair_action) settings_menu.addSeparator() @@ -788,13 +797,6 @@ def _setup_menu(self): # 帮助菜单 help_menu = menubar.addMenu(self.i18n.t("menu.help")) - # 在线更新 - update_action = QAction(self.i18n.t("menu.check_update"), self) - update_action.triggered.connect(self._show_update_center) - help_menu.addAction(update_action) - - help_menu.addSeparator() - # 关于 about_action = QAction(self.i18n.t("menu.about"), self) about_action.triggered.connect(self._show_about) @@ -873,7 +875,7 @@ def _setup_birdid_dock(self): from .birdid_dock import BirdIDDockWidget self.birdid_dock = BirdIDDockWidget(self) - self.addDockWidget(Qt.RightDockWidgetArea, self.birdid_dock) + self.addDockWidget(Qt.DockWidgetArea.RightDockWidgetArea, self.birdid_dock) # 设置 dock 初始宽度为最小值,让主区域更宽 self.birdid_dock.setFixedWidth(280) @@ -961,19 +963,22 @@ def _show_main_window(self): # macOS: 恢复 Dock 图标 if sys.platform == 'darwin': try: - from AppKit import NSApp, NSApplicationActivationPolicyRegular - NSApp.setActivationPolicy_(NSApplicationActivationPolicyRegular) + import importlib + appkit = importlib.import_module("AppKit") + appkit.NSApp.setActivationPolicy_(appkit.NSApplicationActivationPolicyRegular) print("✅ 已恢复 Dock 图标") - except ImportError: + except Exception: pass - except Exception as e: - print(f"⚠️ 恢复 Dock 图标失败: {e}") self.show() self.raise_() self.activateWindow() # 确保窗口获得焦点 - self.setWindowState(self.windowState() & ~Qt.WindowMinimized | Qt.WindowActive) + self.setWindowState( + self.windowState() + & ~Qt.WindowState.WindowMinimized + | Qt.WindowState.WindowActive + ) def _quit_app(self): """完全退出应用(清理由 aboutToQuit 信号统一处理)""" @@ -1032,7 +1037,7 @@ def _minimize_to_tray(self): self, self.i18n.t("menu.background_mode_title"), self.i18n.t("menu.background_mode_msg"), - QMessageBox.Ok + QMessageBox.StandardButton.Ok ) # 3. 设置后台模式标志,然后退出 GUI @@ -1050,10 +1055,7 @@ def _on_birdid_check_changed(self, state): """识鸟开关状态变化 - 同步到 BirdID Dock 设置""" import json try: - if sys.platform == 'darwin': - settings_dir = os.path.expanduser('~/Documents/SuperPicky_Data') - else: - settings_dir = os.path.join(os.path.expanduser('~'), 'Documents', 'SuperPicky_Data') + settings_dir = str(get_app_config_dir()) os.makedirs(settings_dir, exist_ok=True) settings_path = os.path.join(settings_dir, 'birdid_dock_settings.json') @@ -1102,7 +1104,12 @@ def _create_header_section(self, parent_layout): icon_inner_layout.setContentsMargins(2, 2, 2, 2) icon_label = QLabel() - pixmap = QPixmap(icon_path).scaled(44, 44, Qt.KeepAspectRatio, Qt.SmoothTransformation) + pixmap = QPixmap(icon_path).scaled( + 44, + 44, + Qt.AspectRatioMode.KeepAspectRatio, + Qt.TransformationMode.SmoothTransformation, + ) icon_label.setPixmap(pixmap) icon_inner_layout.addWidget(icon_label) brand_layout.addWidget(icon_container) @@ -1145,7 +1152,7 @@ def _create_header_section(self, parent_layout): version_label = QLabel(version_text) version_label.setStyleSheet(VERSION_STYLE) - version_label.setAlignment(Qt.AlignRight | Qt.AlignBottom) + version_label.setAlignment(Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignBottom) header_layout.addWidget(version_label) @@ -1242,10 +1249,7 @@ def _create_parameters_section(self, parent_layout): birdid_saved_state = False try: import json - if sys.platform == 'darwin': - settings_dir = os.path.expanduser('~/Documents/SuperPicky_Data') - else: - settings_dir = os.path.join(os.path.expanduser('~'), 'Documents', 'SuperPicky_Data') + settings_dir = str(get_app_config_dir()) settings_path = os.path.join(settings_dir, 'birdid_dock_settings.json') if os.path.exists(settings_path): with open(settings_path, 'r', encoding='utf-8') as f: @@ -1293,7 +1297,7 @@ def _create_parameters_section(self, parent_layout): sharp_label.setStyleSheet(f"color: {COLORS['text_secondary']}; font-size: 13px; min-width: 80px;") sharp_layout.addWidget(sharp_label) - self.sharp_slider = QSlider(Qt.Horizontal) + self.sharp_slider = QSlider(Qt.Orientation.Horizontal) self.sharp_slider.setRange(200, 600) # 新范围 200-600 self.sharp_slider.setValue(400) # 新默认值 self.sharp_slider.setSingleStep(10) # V4.0: 更精细的调节(键盘方向键) @@ -1304,7 +1308,7 @@ def _create_parameters_section(self, parent_layout): self.sharp_value = QLabel("400") # 新默认值 self.sharp_value.setStyleSheet(VALUE_STYLE) self.sharp_value.setFixedWidth(50) - self.sharp_value.setAlignment(Qt.AlignRight | Qt.AlignVCenter) + self.sharp_value.setAlignment(Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignVCenter) sharp_layout.addWidget(self.sharp_value) sliders_layout.addLayout(sharp_layout) @@ -1317,7 +1321,7 @@ def _create_parameters_section(self, parent_layout): nima_label.setStyleSheet(f"color: {COLORS['text_secondary']}; font-size: 13px; min-width: 80px;") nima_layout.addWidget(nima_label) - self.nima_slider = QSlider(Qt.Horizontal) + self.nima_slider = QSlider(Qt.Orientation.Horizontal) self.nima_slider.setRange(40, 70) # 新范围 4.0-7.0 self.nima_slider.setValue(50) # 默认值 5.0 self.nima_slider.valueChanged.connect(self._on_nima_changed) @@ -1326,7 +1330,7 @@ def _create_parameters_section(self, parent_layout): self.nima_value = QLabel("5.0") # 默认值 self.nima_value.setStyleSheet(VALUE_STYLE) self.nima_value.setFixedWidth(50) - self.nima_value.setAlignment(Qt.AlignRight | Qt.AlignVCenter) + self.nima_value.setAlignment(Qt.AlignmentFlag.AlignRight | Qt.AlignmentFlag.AlignVCenter) nima_layout.addWidget(self.nima_value) sliders_layout.addLayout(nima_layout) @@ -1405,7 +1409,7 @@ def _create_status_banner(self, parent_layout): """创建状态条(进度条下方,按钮上方)""" self._status_banner = QLabel(self.i18n.t("labels.support_format_hint")) self._status_banner.setFixedHeight(32) - self._status_banner.setAlignment(Qt.AlignCenter) + self._status_banner.setAlignment(Qt.AlignmentFlag.AlignCenter) self._status_banner.setStyleSheet(f""" QLabel {{ background-color: {COLORS['bg_card']}; @@ -1501,7 +1505,7 @@ def _browse_directory(self): self, self.i18n.t("labels.select_photo_dir"), "", - QFileDialog.ShowDirsOnly + QFileDialog.Option.ShowDirsOnly ) if directory: self._handle_directory_selection(directory) @@ -1537,7 +1541,7 @@ def _handle_directory_selection(self, directory): def _check_directory_health(self, directory: str): """检查目标目录的磁盘空间和写权限,结果输出到 UI 日志。""" - import shutil, os + import shutil try: usage = shutil.disk_usage(directory) free_gb = usage.free / (1024 ** 3) @@ -1838,6 +1842,9 @@ def _update_status(self, text, color=None): @Slot() def _start_processing(self): """开始处理""" + if not self._require_initialization_for_processing(): + return + if not self.directory_path: StyledMessageBox.warning( self, @@ -1945,15 +1952,25 @@ def _start_processing(self): return # 2. 照片数量预扫描(阻断型) + scan_results = None try: - import os as _os - from constants import IMAGE_EXTENSIONS - _ext_set = set(e.lower() for e in IMAGE_EXTENSIONS) - _photo_count = sum( - 1 for _e in _os.scandir(self.directory_path) - if _e.is_file() and _os.path.splitext(_e.name)[1].lower() in _ext_set - ) - if _photo_count == 0: + from core.recursive_scanner import DEFAULT_SCAN_MAX_DEPTH, is_dangerous_root, scan_directories + + is_dangerous, reason = is_dangerous_root(self.directory_path) + if is_dangerous: + StyledMessageBox.warning( + self, + self.i18n.t("health.dangerous_dir_title"), + self.i18n.t( + "health.dangerous_dir_msg", + directory=self.directory_path, + reason=reason, + ), + ) + return + + scan_results = scan_directories(self.directory_path, max_depth=DEFAULT_SCAN_MAX_DEPTH) + if not scan_results: StyledMessageBox.warning( self, self.i18n.t("health.no_photos_title"), @@ -2010,7 +2027,8 @@ def _start_processing(self): ui_settings, self.worker_signals, self.i18n, - resume=resume_processing + resume=resume_processing, + scan_results=scan_results, ) self.worker.start() @@ -2179,11 +2197,7 @@ def emit_log(msg): emit_log(f"\n\U0001f504 [{idx}/{len(sub_dirs_to_reset)}] {rel}/") try: # Reuse CLI reset logic - class _ResetArgs: - pass - _args = _ResetArgs() - _args.directory = sub_dir - _args.yes = True + _args = SimpleNamespace(directory=sub_dir, yes=True) from superpicky_cli import cmd_reset as _cli_reset _cli_reset(_args) emit_log(f" \u2705 {rel}/ reset done") @@ -2479,6 +2493,9 @@ def _toggle_birdid_dock(self, checked): def _auto_start_birdid_server(self): """自动启动识鸟 API 服务器(使用服务器管理器) - 在后台线程中运行""" + if not self._skip_until_initialized("首次初始化尚未完成,暂不启动识鸟 API 服务器。"): + return + import threading def start_server_task(): @@ -2543,7 +2560,7 @@ def _log(self, message, tag=None): print(message) cursor = self.log_text.textCursor() - cursor.movePosition(QTextCursor.End) + cursor.movePosition(QTextCursor.MoveOperation.End) # 根据标签选择颜色 if tag == "error": @@ -2754,6 +2771,9 @@ def closeEvent(self, event): def _preload_all_models(self): """后台预加载所有AI模型(不阻塞UI)""" + if not self._skip_until_initialized("首次初始化尚未完成,跳过模型预加载。"): + return + import threading def _emit_and_log(msg, level="info"): @@ -3023,6 +3043,7 @@ def _check(): has_update, info = checker.check_for_updates( include_prerelease=_cfg.include_prerelease ) + info = info or {} if has_update: text = f"{self.i18n.t('update.update_center_result_has_update')} V{info.get('version','')}" color = COLORS['accent'] @@ -3075,6 +3096,12 @@ def _do_clear(): layout.addLayout(btn_row) dialog.exec() + def _show_environment_repair_dialog(self): + """显示环境修复对话框,复用初始化修复逻辑但不走首启欢迎页。""" + dialog = EnvironmentRepairDialog(self.i18n, self.config, self) + dialog.start_repair() + dialog.exec() + def _check_for_updates(self, silent=False): """检查更新 @@ -3210,7 +3237,12 @@ def _show_update_result_dialog(self, has_update: bool, update_info): QPushButton:hover {{ background-color: {COLORS['accent_hover']}; }} """) from PySide6.QtWidgets import QApplication - restart_btn.clicked.connect(lambda: (dialog.accept(), QApplication.instance().quit())) + def _restart_app(): + dialog.accept() + app = QApplication.instance() + if app is not None: + app.quit() + restart_btn.clicked.connect(_restart_app) btn_row.addWidget(restart_btn) btn_row.addSpacing(8) @@ -3418,25 +3450,82 @@ def _on_skip(): def _show_skill_level_dialog(self): """菜单打开水平选择对话框""" + # 保留此手动入口:onboarding 只负责首启流程,后续用户仍可在设置菜单中单独调整摄影等级。 dialog = SkillLevelDialog(self.i18n, self) dialog.level_selected.connect(self._on_skill_level_selected) dialog.exec() def _show_first_run_skill_level_dialog(self): - """首次运行:显示水平选择对话框""" - dialog = SkillLevelDialog(self.i18n, self) - dialog.level_selected.connect(self._on_skill_level_selected) + """首次运行:显示轻量欢迎向导。""" + # Safety guard: onboarding 只允许作为首启流程出现。 + # 如果未来旧代码路径误调用这里,非首次运行时直接跳过,避免重复打断用户。 + # NOTE: + # We intentionally keep this legacy entrypoint. The dialog now embeds + # lightweight-package initialization, while full packages can still use + # the same onboarding shell as a compatibility path. + if not self.config.is_first_run and self._initialization_ready(): + return + + dialog = WelcomeOnboardingDialog(self.i18n, self) + dialog.onboarding_completed.connect(self._on_welcome_onboarding_completed) dialog.exec() + def _initialization_ready(self) -> bool: + return self._init_manager.is_ready_for_main_ui() + + def _skip_until_initialized(self, log_message: str) -> bool: + if self._initialization_ready(): + return True + self.log_signal.emit(log_message, "info") + return False + + def _require_initialization_for_processing(self) -> bool: + if self._initialization_ready(): + return True + StyledMessageBox.warning( + self, + self.i18n.t("messages.hint"), + self.i18n.t("messages.initialization_required"), + ) + self._show_first_run_skill_level_dialog() + return False + + def _resume_post_initialization_flow(self): + """初始化完成后补触发被首启门禁跳过的后台流程。""" + if not self._initialization_ready(): + return + + self.config = get_advanced_config() + self._apply_skill_level_thresholds(self.config.skill_level) + self._update_skill_level_label(self.config.skill_level) + + # 首次轻量初始化完成后,这些任务之前可能被跳过,这里补一次。 + QTimer.singleShot(200, self._preload_all_models) + QTimer.singleShot(400, self._auto_start_birdid_server) + if self.config.auto_check_updates: + QTimer.singleShot(600, lambda: self._check_for_updates(silent=True)) + def run_startup_prompts(self): """在启动统计同意流程结束后继续启动期弹窗/预设应用。""" if self._startup_prompts_ran: return + # Centralized first-run gating: 所有首启提示都从这里统一进入。 + # 这样 telemetry / consent 完成后只会决策一次,避免 onboarding 被其他启动路径重复触发。 self._startup_prompts_ran = True - if self.config.is_first_run: + needs_init = self._init_manager.needs_initialization() + if ( + needs_init + and not self.config.is_first_run + and self.config.last_init_exit_reason == "interrupted" + and self.config.last_init_mode == "repair" + ): + self._show_environment_repair_dialog() + return + if self.config.is_first_run or needs_init: self._show_first_run_skill_level_dialog() else: + # 非首次运行不再进入 onboarding,只恢复上次保存的摄影等级阈值。 self._apply_skill_level_thresholds(self.config.skill_level) def _on_skill_level_selected(self, level_key: str): @@ -3453,6 +3542,25 @@ def _on_skill_level_selected(self, level_key: str): self._update_skill_level_label(level_key) print(self.i18n.t("logs.skill_level_selected", level=level_key)) + + def _on_welcome_onboarding_completed(self, level_key: str, auto_update_enabled: bool): + """处理首次启动欢迎向导完成。""" + # Keep signal payload order stable: (level_key, auto_update_enabled) + # 这里同时负责首启设置持久化与立即生效,避免状态已保存但主界面仍停留在旧阈值。 + self.config.set_skill_level(level_key) + self.config.set_auto_check_updates(auto_update_enabled) + self.config.set_is_first_run(False) + self.config.set_initialization_completed(self._initialization_ready()) + self.config.save() + + self._apply_skill_level_thresholds(level_key) + self._update_skill_level_label(level_key) + self._resume_post_initialization_flow() + + print( + f"[onboarding] first-run setup saved: " + f"skill_level={level_key}, auto_check_updates={auto_update_enabled}" + ) def _apply_skill_level_thresholds(self, level_key: str): """应用水平预设的阈值到滑块""" diff --git a/ui/welcome_onboarding_dialog.py b/ui/welcome_onboarding_dialog.py new file mode 100644 index 0000000..f2e0117 --- /dev/null +++ b/ui/welcome_onboarding_dialog.py @@ -0,0 +1,1081 @@ +# -*- coding: utf-8 -*- +""" +SuperPicky onboarding and initialization dialogs. + +This module contains the first-run welcome wizard, the environment repair +dialog, and the lightweight Qt widgets that render initialization progress. +The actual long-task animation policy lives in `core.initialization_progress` +so the GUI layer stays thin and testable. + +SuperPicky 首次启动欢迎向导与初始化对话框。 + +此模块包含首次运行欢迎向导、环境修复对话框,以及负责渲染初始化进度的轻量 Qt 组件。 +实际的长任务动画策略位于 `core.initialization_progress`,从而保持 GUI 层足够薄且可测试。 +""" + +import os +import sys +import time +from dataclasses import dataclass +from typing import Callable, Mapping, Protocol, cast + +from PySide6.QtCore import Qt, QObject, QTimer, Signal +from PySide6.QtGui import QColor, QPainter, QPen +from PySide6.QtWidgets import ( + QApplication, + QCheckBox, + QDialog, + QFrame, + QHBoxLayout, + QLabel, + QPushButton, + QStackedWidget, + QTextEdit, + QVBoxLayout, + QWidget, +) + +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if PROJECT_ROOT not in sys.path: + sys.path.insert(0, PROJECT_ROOT) + +from advanced_config import get_advanced_config +from core.initialization_progress import ( + InitializationProgressEvent, + InitializationProgressModel, +) +from core.initialization_manager import InitializationManager +from ui.custom_dialogs import StyledMessageBox +from ui.skill_level_dialog import SkillLevelCard +from ui.styles import COLORS, FONTS + +UPDATE_OPTION_KEYS = ("enabled", "disabled") +SKILL_LEVEL_KEYS = ("beginner", "intermediate", "master") +FULL_FEATURE_SET = ("core_detection", "quality", "keypoint", "flight", "birdid") + +SELECTABLE_CARD_TITLE_STYLE = f""" + color: {COLORS['text_primary']}; + font-size: 15px; + font-weight: 600; + background: transparent; + border: none; + border-radius: 0; + padding: 0; +""" + +SELECTABLE_CARD_DESC_STYLE = f""" + color: {COLORS['text_secondary']}; + font-size: 12px; + background: transparent; + border: none; + border-radius: 0; + padding: 0; +""" + +SELECTABLE_CARD_SELECTED_STYLE = f""" + QFrame#updateOptionCard {{ + background-color: {COLORS['bg_elevated']}; + border: 2px solid {COLORS['accent']}; + border-radius: 8px; + }} +""" + +SELECTABLE_CARD_UNSELECTED_STYLE = f""" + QFrame#updateOptionCard {{ + background-color: {COLORS['bg_elevated']}; + border: 1px solid transparent; + border-radius: 8px; + }} + QFrame#updateOptionCard:hover {{ + border-color: {COLORS['border']}; + }} +""" + +DIALOG_STYLE = f""" + QDialog {{ + background-color: {COLORS['bg_primary']}; + border-radius: 14px; + }} + QLabel {{ + color: {COLORS['text_primary']}; + background: transparent; + font-family: {FONTS['sans']}; + }} + QPushButton {{ + background-color: {COLORS['accent']}; + color: {COLORS['bg_void']}; + border: none; + border-radius: 8px; + padding: 10px 20px; + font-size: 14px; + font-weight: 600; + font-family: {FONTS['sans']}; + }} + QPushButton:hover {{ + background-color: {COLORS['accent_hover']}; + }} + QPushButton:pressed {{ + background-color: {COLORS['accent_pressed']}; + }} + QPushButton#secondary {{ + background-color: {COLORS['bg_card']}; + color: {COLORS['text_secondary']}; + border: 1px solid {COLORS['border']}; + }} + QPushButton#secondary:hover {{ + background-color: {COLORS['bg_elevated']}; + color: {COLORS['text_primary']}; + border-color: {COLORS['text_tertiary']}; + }} + QPushButton:disabled {{ + background-color: {COLORS['bg_card']}; + color: {COLORS['text_muted']}; + border: 1px solid {COLORS['border_subtle']}; + }} + QCheckBox {{ + color: {COLORS['text_primary']}; + font-size: 13px; + spacing: 8px; + }} + QTextEdit {{ + background-color: {COLORS['bg_card']}; + color: {COLORS['text_secondary']}; + border: 1px solid {COLORS['border']}; + border-radius: 8px; + padding: 8px; + font-family: {FONTS['sans']}; + font-size: 12px; + }} +""" + +PAGE_TITLE_STYLE = f""" + QLabel {{ + color: {COLORS['text_primary']}; + font-size: 24px; + font-weight: 700; + }} +""" + +BODY_SUBTITLE_STYLE = f""" + QLabel {{ + color: {COLORS['text_secondary']}; + font-size: 13px; + }} +""" + +HINT_STYLE = f""" + QLabel {{ + color: {COLORS['text_tertiary']}; + font-size: 12px; + }} +""" + +DOT_ACTIVE_STYLE = f"background-color: {COLORS['accent']}; border-radius: 5px;" +DOT_INACTIVE_STYLE = f"background-color: {COLORS['border']}; border-radius: 5px;" +ALIGN_CENTER = Qt.AlignmentFlag.AlignCenter +POINTING_HAND_CURSOR = Qt.CursorShape.PointingHandCursor + + +class _SelectableCardLike(Protocol): + def set_selected(self, selected: bool) -> None: + ... + + +class _PostInitializationFlowHost(Protocol): + def _resume_post_initialization_flow(self) -> None: + ... + + +@dataclass(frozen=True) +class _NavState: + prev_enabled: bool + next_text: str + next_enabled: bool + background_visible: bool + retry_visible: bool + + +class SelectableCard(QFrame): + clicked = Signal(str) + + def __init__(self, option_key: str, title: str, description: str, parent=None): + super().__init__(parent) + self.option_key = option_key + self._selected = False + self.setObjectName("updateOptionCard") + self.setCursor(POINTING_HAND_CURSOR) + self.setFixedSize(260, 150) + + layout = QVBoxLayout(self) + layout.setContentsMargins(14, 10, 14, 10) + layout.setSpacing(6) + layout.setAlignment(ALIGN_CENTER) + + self.title_label = QLabel(title) + self.title_label.setAlignment(ALIGN_CENTER) + self.title_label.setStyleSheet(SELECTABLE_CARD_TITLE_STYLE) + layout.addWidget(self.title_label) + + self.desc_label = QLabel(description) + self.desc_label.setWordWrap(True) + self.desc_label.setAlignment(ALIGN_CENTER) + self.desc_label.setStyleSheet(SELECTABLE_CARD_DESC_STYLE) + layout.addWidget(self.desc_label) + self._apply_style() + + def set_selected(self, selected: bool): + if self._selected == selected: + return + self._selected = selected + self._apply_style() + + def _apply_style(self): + self.setStyleSheet( + SELECTABLE_CARD_SELECTED_STYLE if self._selected else SELECTABLE_CARD_UNSELECTED_STYLE + ) + + def mousePressEvent(self, event): + self.clicked.emit(self.option_key) + super().mousePressEvent(event) + + +class LockedFeatureCheckBox(QCheckBox): + def __init__(self, text: str, parent=None): + super().__init__(text, parent) + self.setChecked(True) + self.setFocusPolicy(Qt.FocusPolicy.NoFocus) + + def nextCheckState(self) -> None: + return + + def mousePressEvent(self, event) -> None: + event.accept() + + def mouseReleaseEvent(self, event) -> None: + event.accept() + + def keyPressEvent(self, event) -> None: + event.accept() + + +class StatusBulletLabel(QLabel): + def __init__(self, text: str = "", parent=None): + super().__init__(text, parent) + self.setWordWrap(True) + self.setAlignment(Qt.AlignmentFlag.AlignLeft | Qt.AlignmentFlag.AlignVCenter) + self.setStyleSheet( + f""" + color: {COLORS['text_primary']}; + font-size: 13px; + font-weight: 600; + background: transparent; + padding-left: 4px; + """ + ) + + +class RoundedProgressBar(QWidget): + """ + Lightweight rounded progress bar with floating-point fill support. + + 支持浮点填充进度的轻量圆角进度条。 + """ + + def __init__(self, parent=None): + super().__init__(parent) + self._minimum = 0 + self._maximum = 100 + self._value = 0.0 + self.setMinimumHeight(14) + + def setRange(self, minimum: int, maximum: int) -> None: + self._minimum = minimum + self._maximum = max(minimum + 1, maximum) + self.update() + + def setValue(self, value: float) -> None: + """ + Update the rendered progress value with sub-percent precision. + + 使用亚百分比精度更新渲染进度值。 + """ + bounded = float(max(self._minimum, min(self._maximum, value))) + if abs(bounded - self._value) < 0.02: + return + self._value = bounded + self.update() + + def setTextVisible(self, _visible: bool) -> None: + # Kept for compatibility with the previous QProgressBar calls. + return + + def paintEvent(self, _event) -> None: + painter = QPainter(self) + painter.setRenderHint(QPainter.RenderHint.Antialiasing, True) + + rect = self.rect().adjusted(0, 0, -1, -1) + radius = min(rect.height() / 2.0, 8.0) + + painter.setPen(QPen(QColor(COLORS["border"]), 1)) + painter.setBrush(QColor(COLORS["bg_card"])) + painter.drawRoundedRect(rect, radius, radius) + + span = max(1, self._maximum - self._minimum) + progress_ratio = (self._value - self._minimum) / span + if progress_ratio <= 0: + return + + fill_width = max(rect.height(), int(rect.width() * progress_ratio)) + fill_width = min(rect.width(), fill_width) + fill_rect = rect.adjusted(1, 1, -(rect.width() - fill_width), -1) + + painter.setPen(Qt.PenStyle.NoPen) + painter.setBrush(QColor(COLORS["accent"])) + painter.drawRoundedRect(fill_rect, radius - 1, radius - 1) + + +class InitializationProgressBinder(QObject): + """ + Thin Qt adapter that renders structured initialization progress events. + + 将结构化初始化进度事件渲染到 Qt 控件上的薄适配层。 + + The binder owns no hard-coded timing policy. Instead, it forwards stage + changes and progress events into the shared pure-Python model and only + handles Qt timer scheduling plus success/failure callbacks. + + 该适配层不再持有硬编码的时间策略,而是把阶段变化与进度事件转发给共享的纯 Python 模型, + 自身只负责 Qt 定时器调度以及成功/失败回调。 + """ + + def __init__( + self, + manager: InitializationManager, + *, + set_stage_text: Callable[[str], None], + set_progress_value: Callable[[float], None], + append_log: Callable[[str], None], + on_success: Callable[[object], None], + on_failure: Callable[[object], None], + parent=None, + ): + super().__init__(parent) + self._manager = manager + self._set_stage_text = set_stage_text + self._set_progress_value = set_progress_value + self._append_log = append_log + self._on_success = on_success + self._on_failure = on_failure + self._model = InitializationProgressModel() + self._pending_success_summary: object | None = None + self._desired_progress = 0.0 + self._rendered_progress = 0.0 + self._last_animation_tick = time.monotonic() + self._progress_timer = QTimer(self) + self._progress_timer.setInterval(16) + self._progress_timer.timeout.connect(self._advance_progress_animation) + manager.stage_changed.connect(self._handle_stage_changed) + manager.progress_event.connect(self._handle_progress_event) + manager.item_status_changed.connect(self._handle_item_status_changed) + manager.finished.connect(self._handle_finished) + + def reset(self) -> None: + """ + Clear the current animation state before a new run begins. + + 在新一轮初始化开始前清空当前动画状态。 + """ + self._pending_success_summary = None + now = time.monotonic() + self._model.reset(now) + self._desired_progress = 0.0 + self._rendered_progress = 0.0 + self._last_animation_tick = now + self._progress_timer.stop() + self._push_progress(0.0) + + def _push_progress(self, value: float) -> None: + """ + Clamp and forward the displayed progress to the bound widget. + + 夹紧并转发显示进度到绑定控件。 + """ + clamped = max(0.0, min(100.0, value)) + self._rendered_progress = clamped + self._set_progress_value(clamped) + + def _apply_snapshot(self, *, now: float | None = None) -> None: + """ + Advance the pure-Python model and render its latest snapshot. + + 推进纯 Python 模型并渲染其最新快照。 + """ + snapshot = self._model.advance(time.monotonic() if now is None else now) + self._desired_progress = snapshot.display_value + + if self._pending_success_summary is not None and snapshot.is_settled: + self._push_progress(100.0) + summary = self._pending_success_summary + self._pending_success_summary = None + self._progress_timer.stop() + self._on_success(summary) + return + + if snapshot.is_finishing or snapshot.active_phase is not None: + if not self._progress_timer.isActive(): + self._last_animation_tick = time.monotonic() if now is None else now + self._progress_timer.start() + return + + if self._progress_timer.isActive(): + self._progress_timer.stop() + + def _handle_stage_changed(self, stage: str, message: str) -> None: + """ + Update the stage label and synchronize the animation model. + + 更新阶段标签并同步动画模型。 + """ + now = time.monotonic() + self._set_stage_text(message) + self._append_log(f"[{stage}] {message}") + self._model.on_stage_changed(stage, now) + self._apply_snapshot(now=now) + + def _handle_progress_event(self, event: InitializationProgressEvent) -> None: + """ + Feed a structured progress event into the shared animation model. + + 将结构化进度事件送入共享动画模型。 + """ + now = time.monotonic() + self._model.on_progress_event(event, now) + self._apply_snapshot(now=now) + + def _handle_item_status_changed(self, resource_id: str, status: str, detail: str) -> None: + if resource_id in {"updates", "runtime"}: + self._append_log(f"{resource_id}: {detail}") + return + self._append_log(f"{resource_id} [{status}] {detail}") + + def _handle_finished(self, success: bool, summary: object) -> None: + """ + Start the success settle animation or fail immediately. + + 成功时启动收尾动画,失败时立即结束。 + """ + now = time.monotonic() + if success: + self._pending_success_summary = summary + self._model.on_finished(True, now) + self._apply_snapshot(now=now) + return + self._pending_success_summary = None + self._progress_timer.stop() + self._on_failure(summary) + + def _advance_progress_animation(self) -> None: + """ + Advance the animation from the Qt timer tick with smooth interpolation. + + 以平滑插值方式推进每一帧动画。 + """ + now = time.monotonic() + self._apply_snapshot(now=now) + + dt = max(0.001, now - self._last_animation_tick) + self._last_animation_tick = now + delta = self._desired_progress - self._rendered_progress + if abs(delta) < 0.015: + self._push_progress(self._desired_progress) + if self._pending_success_summary is None and self._desired_progress >= 99.999: + self._progress_timer.stop() + return + + # Use critically damped tracking: larger gaps move faster, small gaps ease in. + # This removes the stair-step feel of integer updates while preserving monotonicity. + # 使用接近临界阻尼的追踪方式:差距大时移动更快,差距小时自然缓入, + # 从而消除整数跳格的顿挫感,同时保持单调前进。 + smoothing = 1.0 - pow(0.0025, dt) + min_step = 0.045 + min(0.18, abs(delta) * 0.12) + step = max(min_step, abs(delta) * smoothing) + next_value = self._rendered_progress + min(abs(delta), step) + self._push_progress(min(next_value, self._desired_progress)) + + +class EnvironmentRepairDialog(QDialog): + def __init__(self, i18n, config, parent=None): + super().__init__(parent) + self.i18n = i18n + self.config = config + self.manager = InitializationManager(self) + self._repair_running = False + self._closing_after_interrupt = False + self._setup_ui() + self._progress = InitializationProgressBinder( + self.manager, + set_stage_text=self.stage_label.setText, + set_progress_value=self.progress_bar.setValue, + append_log=self.log_view.append, + on_success=self._on_repair_success, + on_failure=self._on_repair_failure, + parent=self, + ) + self.retry_btn.clicked.connect(self.start_repair) + + def _setup_ui(self) -> None: + self.setWindowTitle(self.i18n.t("repair.window_title")) + self.setMinimumWidth(520) + self.setMinimumHeight(420) + self.setStyleSheet(DIALOG_STYLE) + + layout = QVBoxLayout(self) + layout.setContentsMargins(24, 24, 24, 20) + layout.setSpacing(12) + + title = QLabel(self.i18n.t("repair.window_title")) + title.setStyleSheet(f"font-size: 18px; font-weight: 600; color: {COLORS['text_primary']};") + layout.addWidget(title) + + summary = QLabel(self.i18n.t("repair.summary")) + summary.setWordWrap(True) + summary.setStyleSheet(f"color: {COLORS['text_secondary']}; font-size: 13px;") + layout.addWidget(summary) + + self.stage_label = QLabel(self.i18n.t("repair.start")) + self.stage_label.setStyleSheet(f"color: {COLORS['text_secondary']}; font-size: 12px;") + layout.addWidget(self.stage_label) + + self.progress_bar = RoundedProgressBar() + self.progress_bar.setRange(0, 100) + self.progress_bar.setValue(0) + self.progress_bar.setTextVisible(False) + layout.addWidget(self.progress_bar) + + self.log_view = QTextEdit() + self.log_view.setReadOnly(True) + layout.addWidget(self.log_view, 1) + + btn_row = QHBoxLayout() + btn_row.addStretch() + + self.retry_btn = QPushButton(self.i18n.t("repair.retry")) + self.retry_btn.setObjectName("secondary") + self.retry_btn.hide() + btn_row.addWidget(self.retry_btn) + + self.close_btn = QPushButton(self.i18n.t("update.close")) + self.close_btn.setObjectName("secondary") + self.close_btn.clicked.connect(self.reject) + btn_row.addWidget(self.close_btn) + layout.addLayout(btn_row) + + def _repair_options(self) -> dict: + return { + "runtime_variant": self.config.selected_runtime_variant or "auto", + "runtime_install_location": self.config.runtime_install_location_preference, + "features": list(FULL_FEATURE_SET), + "auto_update_enabled": self.config.auto_check_updates, + } + + def start_repair(self) -> None: + self.retry_btn.hide() + self._repair_running = True + self._closing_after_interrupt = False + self._progress.reset() + self.stage_label.setText(self.i18n.t("repair.running")) + self.log_view.append(self.i18n.t("repair.log_retry")) + self.manager.start_repair(self._repair_options()) + + def _on_repair_success(self, _summary: object) -> None: + self._repair_running = False + self.stage_label.setText(self.i18n.t("repair.success")) + self.log_view.append(f"[done] {self.i18n.t('repair.success')}") + parent = self.parent() + if parent is not None and hasattr(parent, "_resume_post_initialization_flow"): + cast(_PostInitializationFlowHost, parent)._resume_post_initialization_flow() + + def _on_repair_failure(self, summary: object) -> None: + self._repair_running = False + if isinstance(summary, dict) and summary.get("interrupted"): + if not self._closing_after_interrupt: + self.stage_label.setText(self.i18n.t("onboarding.initialization_interrupted")) + self.log_view.append(self.i18n.t("onboarding.initialization_interrupted")) + return + self.retry_btn.show() + error_text = ( + summary.get("error", self.i18n.t("repair.failed")) + if isinstance(summary, dict) + else self.i18n.t("repair.failed") + ) + self.stage_label.setText(error_text) + self.log_view.append(f"[failed] {error_text}") + + def _confirm_interrupt_repair(self) -> bool: + reply = StyledMessageBox.question( + self, + self.i18n.t("onboarding.close_confirm_title"), + self.i18n.t("onboarding.close_confirm_message"), + yes_text=self.i18n.t("onboarding.close_confirm_exit"), + no_text=self.i18n.t("onboarding.close_confirm_continue"), + ) + return reply == StyledMessageBox.Yes + + def reject(self) -> None: + if self._repair_running and not self._closing_after_interrupt: + if not self._confirm_interrupt_repair(): + return + self._closing_after_interrupt = True + self.manager.cancel() + super().reject() + + def closeEvent(self, event) -> None: + if self._repair_running and not self._closing_after_interrupt: + if not self._confirm_interrupt_repair(): + event.ignore() + return + self._closing_after_interrupt = True + self.manager.cancel() + super().closeEvent(event) + + +class WelcomeOnboardingDialog(QDialog): + onboarding_completed = Signal(str, bool) + + def __init__(self, i18n, parent=None): + super().__init__(parent) + self.i18n = i18n + self.config = get_advanced_config() + self.current_page = 0 + self.selected_level = self.config.skill_level or "intermediate" + self.auto_update_enabled = self.config.auto_check_updates + self._dots: list[QLabel] = [] + self._skill_cards: dict[str, SkillLevelCard] = {} + self._update_cards: dict[str, SelectableCard] = {} + self._feature_boxes: dict[str, LockedFeatureCheckBox] = {} + self._runtime_status_labels: list[QLabel] = [] + self._initialization_complete = False + self._initialization_running = False + self._closing_after_interrupt = False + + self.initialization_manager = InitializationManager(self) + self.selected_runtime_install_location = ( + self.initialization_manager.choose_runtime_install_location().key + ) + + self.setModal(True) + self.setWindowTitle(self.i18n.t("onboarding.window_title")) + self.setFixedSize(640, 520) + self.setStyleSheet(DIALOG_STYLE) + + self._setup_ui() + self._progress = InitializationProgressBinder( + self.initialization_manager, + set_stage_text=self.stage_label.setText, + set_progress_value=self.progress_bar.setValue, + append_log=self.log_view.append, + on_success=self._on_initialization_succeeded, + on_failure=self._on_initialization_failed, + parent=self, + ) + self._sync_defaults() + self._set_current_page(0, force=True) + + def get_selected_options(self) -> dict: + return { + "skill_level": self.selected_level, + "auto_update_enabled": self.auto_update_enabled, + "runtime_variant": self.config.selected_runtime_variant or "auto", + "runtime_install_location": self.selected_runtime_install_location, + "features": list(FULL_FEATURE_SET), + } + + def _create_page_widget(self) -> tuple[QWidget, QVBoxLayout]: + page = QWidget() + layout = QVBoxLayout(page) + layout.setContentsMargins(16, 4, 16, 4) + layout.setSpacing(12) + return page, layout + + def _create_text_label(self, text: str, style: str, *, word_wrap: bool = True) -> QLabel: + label = QLabel(text) + label.setAlignment(ALIGN_CENTER) + label.setWordWrap(word_wrap) + label.setStyleSheet(style) + return label + + def _create_nav_button(self, text: str, handler: Callable[[], None], *, secondary: bool = False) -> QPushButton: + button = QPushButton(text) + if secondary: + button.setObjectName("secondary") + button.setFixedSize(120, 38) + button.clicked.connect(handler) + return button + + def _create_card_row(self, cards: list[QWidget], *, spacing: int = 12) -> QHBoxLayout: + row = QHBoxLayout() + row.setSpacing(spacing) + row.setAlignment(ALIGN_CENTER) + for card in cards: + row.addWidget(card) + return row + + def _page_count(self) -> int: + return self.stack.count() + + def _is_initialization_page(self, page_index: int) -> bool: + return page_index == self._page_count() - 1 + + def _is_preparation_page(self, page_index: int) -> bool: + return page_index == self._page_count() - 2 + + def _nav_state_for_page(self, page_index: int) -> _NavState: + is_init_page = self._is_initialization_page(page_index) + if self._initialization_complete and is_init_page: + next_text = self.i18n.t("onboarding.finish") + elif self._is_preparation_page(page_index) and self._preparation_can_finish(): + next_text = self.i18n.t("onboarding.finish") + elif self._is_preparation_page(page_index): + next_text = self.i18n.t("onboarding.start_initialization") + else: + next_text = self.i18n.t("onboarding.next") + return _NavState( + prev_enabled=page_index > 0 and not is_init_page, + next_text=next_text, + next_enabled=not is_init_page or self._initialization_complete, + background_visible=False, + retry_visible=is_init_page and not self._initialization_complete and self.retry_btn.isVisible(), + ) + + def _apply_nav_state(self, state: _NavState) -> None: + self.prev_btn.setEnabled(state.prev_enabled) + self.next_btn.setText(state.next_text) + self.next_btn.setEnabled(state.next_enabled) + self.retry_btn.setVisible(state.retry_visible) + + def _refresh_nav_state(self) -> None: + self._apply_nav_state(self._nav_state_for_page(self.current_page)) + + def _setup_ui(self): + root = QVBoxLayout(self) + root.setContentsMargins(24, 24, 24, 24) + root.setSpacing(18) + + self.stack = QStackedWidget() + for page_builder in ( + self._build_welcome_page, + self._build_update_page, + self._build_skill_level_page, + self._build_feature_page, + self._build_runtime_status_page, + self._build_initialization_page, + ): + self.stack.addWidget(page_builder()) + root.addWidget(self.stack, 1) + + dots_layout = QHBoxLayout() + dots_layout.setSpacing(10) + dots_layout.setAlignment(ALIGN_CENTER) + for _ in range(self._page_count()): + dot = QLabel() + dot.setFixedSize(10, 10) + dots_layout.addWidget(dot) + self._dots.append(dot) + root.addLayout(dots_layout) + + nav_layout = QHBoxLayout() + nav_layout.setAlignment(ALIGN_CENTER) + nav_layout.setSpacing(12) + + self.prev_btn = self._create_nav_button(self.i18n.t("onboarding.previous"), self._go_previous, secondary=True) + nav_layout.addWidget(self.prev_btn) + + self.retry_btn = self._create_nav_button(self.i18n.t("repair.retry"), self._retry_initialization, secondary=True) + self.retry_btn.hide() + nav_layout.addWidget(self.retry_btn) + + self.next_btn = self._create_nav_button(self.i18n.t("onboarding.next"), self._go_next) + nav_layout.addWidget(self.next_btn) + + root.addLayout(nav_layout) + + def _build_welcome_page(self) -> QWidget: + page, layout = self._create_page_widget() + layout.addStretch() + layout.addWidget(self._create_text_label(self.i18n.t("onboarding.lite_welcome_title"), PAGE_TITLE_STYLE)) + layout.addWidget( + self._create_text_label(self.i18n.t("onboarding.lite_welcome_subtitle"), BODY_SUBTITLE_STYLE) + ) + layout.addWidget(self._create_text_label(self.i18n.t("onboarding.lite_welcome_hint"), HINT_STYLE)) + layout.addStretch() + return page + + def _build_update_page(self) -> QWidget: + page, layout = self._create_page_widget() + layout.addWidget(self._create_text_label(self.i18n.t("onboarding.update_title"), PAGE_TITLE_STYLE)) + layout.addWidget( + self._create_text_label(self.i18n.t("onboarding.update_subtitle"), BODY_SUBTITLE_STYLE) + ) + cards = [] + for option_key in UPDATE_OPTION_KEYS: + card = SelectableCard( + option_key, + self.i18n.t(f"onboarding.update_{option_key}_title"), + self.i18n.t(f"onboarding.update_{option_key}_desc"), + ) + card.clicked.connect(self._on_update_option_clicked) + self._update_cards[option_key] = card + cards.append(card) + layout.addLayout(self._create_card_row(cards)) + layout.addStretch() + return page + + def _build_skill_level_page(self) -> QWidget: + page, layout = self._create_page_widget() + layout.addWidget(self._create_text_label(self.i18n.t("onboarding.skill_title"), PAGE_TITLE_STYLE)) + layout.addWidget(self._create_text_label(self.i18n.t("onboarding.skill_subtitle"), BODY_SUBTITLE_STYLE)) + cards = [] + for level_key in SKILL_LEVEL_KEYS: + card = SkillLevelCard(level_key, self.i18n) + card.clicked.connect(self._on_skill_level_clicked) + self._skill_cards[level_key] = card + cards.append(card) + layout.addLayout(self._create_card_row(cards)) + layout.addWidget(self._create_text_label(self.i18n.t("onboarding.skill_hint"), HINT_STYLE)) + layout.addStretch() + return page + + def _build_feature_page(self) -> QWidget: + page, layout = self._create_page_widget() + layout.addWidget(self._create_text_label(self.i18n.t("onboarding.features_title"), PAGE_TITLE_STYLE)) + layout.addWidget(self._create_text_label(self.i18n.t("onboarding.features_subtitle"), BODY_SUBTITLE_STYLE)) + for feature_key in FULL_FEATURE_SET: + checkbox = LockedFeatureCheckBox(self.i18n.t(f"onboarding.feature_{feature_key}_label")) + self._feature_boxes[feature_key] = checkbox + layout.addWidget(checkbox) + layout.addStretch() + return page + + def _build_runtime_status_page(self) -> QWidget: + page, layout = self._create_page_widget() + layout.addWidget( + self._create_text_label(self.i18n.t("onboarding.runtime_status_title"), PAGE_TITLE_STYLE) + ) + self.runtime_status_label = self._create_text_label("", BODY_SUBTITLE_STYLE) + layout.addWidget(self.runtime_status_label) + status_layout = QVBoxLayout() + status_layout.setContentsMargins(0, 10, 0, 0) + status_layout.setSpacing(8) + for _ in range(5): + label = StatusBulletLabel() + self._runtime_status_labels.append(label) + status_layout.addWidget(label) + layout.addLayout(status_layout) + layout.addStretch() + self._refresh_runtime_status_page() + return page + + def _build_initialization_page(self) -> QWidget: + page, layout = self._create_page_widget() + layout.addWidget(self._create_text_label(self.i18n.t("onboarding.initialization_title"), PAGE_TITLE_STYLE)) + self.stage_label = self._create_text_label(self.i18n.t("onboarding.initialization_waiting"), BODY_SUBTITLE_STYLE) + layout.addWidget(self.stage_label) + self.progress_bar = RoundedProgressBar() + self.progress_bar.setRange(0, 100) + self.progress_bar.setValue(0) + self.progress_bar.setTextVisible(False) + layout.addWidget(self.progress_bar) + self.log_view = QTextEdit() + self.log_view.setReadOnly(True) + layout.addWidget(self.log_view, 1) + return page + + def _runtime_hint_text(self) -> str: + if self.initialization_manager.check_runtime_health(): + return self.i18n.t("onboarding.runtime_check_passed") + runtime_selection = self.initialization_manager.detect_runtime_selection( + self.config.selected_runtime_variant or "auto" + ) + if runtime_selection.variant == "cuda": + return self.i18n.t("onboarding.runtime_hint_cuda") + if runtime_selection.variant == "mac": + return self.i18n.t("onboarding.runtime_hint_mac") + return self.i18n.t("onboarding.runtime_hint_cpu") + + def _runtime_status_lines(self) -> list[str]: + runtime_ready = self.initialization_manager.check_runtime_health() + runtime_selection = self.initialization_manager.detect_runtime_selection( + self.config.selected_runtime_variant or "auto" + ) + resolved_runtime_dir = self.initialization_manager.runtime_display_dir( + self.selected_runtime_install_location + ) + install_policy = ( + self.i18n.t("onboarding.runtime_status_policy_windows") + if sys.platform == "win32" + else self.i18n.t("onboarding.runtime_status_policy_mac") + ) + health_line = ( + self.i18n.t("onboarding.runtime_status_item_ready") + if runtime_ready + else self.i18n.t("onboarding.runtime_status_item_pending") + ) + variant_line = self.i18n.t( + "onboarding.runtime_status_item_variant", + variant=runtime_selection.variant.upper(), + ) + source_line = self.i18n.t( + "onboarding.runtime_status_item_source", + detail=( + self.i18n.t("onboarding.runtime_status_result_ready") + if runtime_ready + else self.i18n.t("onboarding.runtime_status_result_pending") + ), + ) + path_line = self.i18n.t("onboarding.runtime_status_path", path=str(resolved_runtime_dir)) + return [health_line, variant_line, install_policy, source_line, path_line] + + def _refresh_runtime_status_page(self) -> None: + lines = self._runtime_status_lines() + self.runtime_status_label.setText(self._runtime_hint_text()) + for label, text in zip(self._runtime_status_labels, lines): + label.setText(text) + + def _apply_single_selection(self, cards: Mapping[str, _SelectableCardLike], selected_key: str): + for key, card in cards.items(): + card.set_selected(key == selected_key) + + def _sync_defaults(self): + self._set_auto_update_enabled(self.auto_update_enabled, force=True) + self._set_skill_level(self.selected_level, force=True) + + def _set_auto_update_enabled(self, enabled: bool, *, force: bool = False): + if not force and self.auto_update_enabled == enabled: + return + self.auto_update_enabled = enabled + self._apply_single_selection(self._update_cards, "enabled" if enabled else "disabled") + + def _set_skill_level(self, level_key: str, *, force: bool = False): + if not force and self.selected_level == level_key: + return + self.selected_level = level_key + self._apply_single_selection(self._skill_cards, level_key) + + def _preparation_can_finish(self) -> bool: + return not self.initialization_manager.needs_initialization(FULL_FEATURE_SET) + + def _set_current_page(self, page_index: int, *, force: bool = False): + if not 0 <= page_index < self._page_count(): + return + if not force and self.current_page == page_index: + return + + self.current_page = page_index + self.stack.setCurrentIndex(page_index) + if self._is_preparation_page(page_index): + self._refresh_runtime_status_page() + self._refresh_nav_state() + for index, dot in enumerate(self._dots): + dot.setStyleSheet(DOT_ACTIVE_STYLE if index == page_index else DOT_INACTIVE_STYLE) + + def _start_initialization(self): + self._initialization_complete = False + self._initialization_running = True + self._closing_after_interrupt = False + self._set_current_page(self._page_count() - 1) + self._progress.reset() + self.log_view.append(self.i18n.t("onboarding.log_start")) + self._refresh_nav_state() + self.initialization_manager.start(self.get_selected_options()) + + def _complete_onboarding(self): + self.onboarding_completed.emit(self.selected_level, self.auto_update_enabled) + self.accept() + + def _on_update_option_clicked(self, option_key: str): + self._set_auto_update_enabled(option_key == "enabled") + + def _on_skill_level_clicked(self, level_key: str): + self._set_skill_level(level_key) + + def _go_previous(self): + self._set_current_page(self.current_page - 1) + + def _go_next(self): + if self._is_initialization_page(self.current_page): + if self._initialization_complete: + self._complete_onboarding() + return + if self._is_preparation_page(self.current_page): + if self._preparation_can_finish(): + self._complete_onboarding() + return + self._start_initialization() + return + self._set_current_page(self.current_page + 1) + + def _retry_initialization(self): + self.retry_btn.hide() + self.log_view.append(self.i18n.t("onboarding.log_retry")) + self._progress.reset() + self._initialization_running = True + self._closing_after_interrupt = False + self._refresh_nav_state() + self.initialization_manager.retry_failed() + + def _on_initialization_succeeded(self, _summary: object) -> None: + self._initialization_complete = True + self._initialization_running = False + self.next_btn.setText(self.i18n.t("onboarding.finish")) + self.next_btn.setEnabled(True) + self.retry_btn.hide() + self._refresh_runtime_status_page() + self._refresh_nav_state() + QApplication.processEvents() + + def _on_initialization_failed(self, summary: object) -> None: + self._initialization_complete = False + self._initialization_running = False + if isinstance(summary, dict) and summary.get("interrupted"): + if not self._closing_after_interrupt: + self.stage_label.setText(self.i18n.t("onboarding.initialization_interrupted")) + self.log_view.append(self.i18n.t("onboarding.initialization_interrupted")) + self._refresh_nav_state() + return + self.retry_btn.show() + self.next_btn.setEnabled(False) + error_text = ( + summary.get("error", self.i18n.t("onboarding.initialization_failed")) + if isinstance(summary, dict) + else self.i18n.t("onboarding.initialization_failed") + ) + self.stage_label.setText(error_text) + self.log_view.append(f"[failed] {error_text}") + self._refresh_nav_state() + + def _confirm_interrupt_initialization(self) -> bool: + reply = StyledMessageBox.question( + self, + self.i18n.t("onboarding.close_confirm_title"), + self.i18n.t("onboarding.close_confirm_message"), + yes_text=self.i18n.t("onboarding.close_confirm_exit"), + no_text=self.i18n.t("onboarding.close_confirm_continue"), + ) + return reply == StyledMessageBox.Yes + + def reject(self) -> None: + if self._initialization_running and not self._closing_after_interrupt: + if not self._confirm_interrupt_initialization(): + return + self._closing_after_interrupt = True + self.initialization_manager.cancel() + super().reject() + + def closeEvent(self, event) -> None: + if self._initialization_running and not self._closing_after_interrupt: + if not self._confirm_interrupt_initialization(): + event.ignore() + return + self._closing_after_interrupt = True + self.initialization_manager.cancel() + super().closeEvent(event) diff --git a/workflows/dev_docs/project_structure.md b/workflows/dev_docs/project_structure.md index 0d2a6b8..943dbfd 100644 --- a/workflows/dev_docs/project_structure.md +++ b/workflows/dev_docs/project_structure.md @@ -117,7 +117,7 @@ ExifTool 二进制文件(macOS 和 Windows)。 | `SuperPicky.spec` | PyInstaller 打包配置 | | `entitlements.plist` | macOS 权限声明 | | `create_pkg_dmg_v4.0.0.sh` | PKG/DMG 构建脚本 | -| `build_release.sh` | 发布构建脚本 | +| `build_release_mac.py` / `build_release_win.py` | 发布构建主脚本 | | `requirements.txt` | Python 依赖 | --- diff --git a/workflows/intel-build.md b/workflows/intel-build.md index c228fb7..2991b11 100644 --- a/workflows/intel-build.md +++ b/workflows/intel-build.md @@ -31,10 +31,10 @@ rm -rf build/ dist/ ### 步骤 2: 运行打包脚本 ```bash # 开发测试 (不公证,快速打包) -./build_release.sh +python3 build_release_mac.py --build-type full --arch x86_64 # 正式发布 (签名+公证) -./build_release.sh --release +python3 build_release_mac.py --build-type full --arch x86_64 --notarize --sign-p12 /path/to/certificate.p12 --sign-p12-password-env MACOS_CERTIFICATE_PWD --apple-id "your@email.com" --team-id "YOUR_TEAM_ID" ``` ## 常见问题处理