diff --git a/selfupdate/MIGRATION.md b/selfupdate/MIGRATION.md index 0d6d3a0..2e4c6bf 100644 --- a/selfupdate/MIGRATION.md +++ b/selfupdate/MIGRATION.md @@ -14,6 +14,7 @@ client := selfupdate.Client{ BinaryName: "agentsview", // or "msgvault" CurrentVersion: version, CacheDir: appCacheDir, + GitHubToken: selfupdate.EnvironmentGitHubToken(), // optional API fallback auth AllowUnsignedChecksums: true, // current CLI releases publish SHA256SUMS only } ``` @@ -27,6 +28,50 @@ Use `client.Install(ctx, info, selfupdate.InstallOptions{Progress: progress})` where the current command calls `PerformUpdate`. CLI output, config loading, confirmation prompts, and command wiring should stay in the application. +## Release Discovery + +By default, `Check` avoids unauthenticated `api.github.com` release discovery. +It follows `https://github.com///releases/latest` to the release +tag, constructs the conventional archive URL, and reads `SHA256SUMS` from the +release downloads. If that web path fails, it falls back to the GitHub REST API. +Set `GitHubToken` to authenticate only the API fallback request; kit never sends +that token to release asset or checksum download URLs. + +Set `ReleaseManifestURL` when a project publishes a static latest-release JSON +document, such as from a docs site or CDN. The smallest useful manifest only +needs the current release tag: + +```json +{ + "tag_name": "v1.2.3" +} +``` + +With only a tag, kit uses the same conventional release asset and `SHA256SUMS` +URLs as web redirect discovery. Projects with custom asset URLs can instead +publish the same compact shape as the GitHub release fields kit consumes: + +```json +{ + "tag_name": "v1.2.3", + "assets": [ + { + "name": "agentsview_1.2.3_darwin_arm64.tar.gz", + "size": 123456, + "browser_download_url": "https://github.com/kenn-io/agentsview/releases/download/v1.2.3/agentsview_1.2.3_darwin_arm64.tar.gz" + }, + { + "name": "SHA256SUMS", + "browser_download_url": "https://github.com/kenn-io/agentsview/releases/download/v1.2.3/SHA256SUMS" + } + ] +} +``` + +When `ReleaseManifestURL` is set, kit uses it directly instead of probing +GitHub's web or API endpoints. `GitHubWebBaseURL` and `GitHubAPIBaseURL` remain +available for tests and GitHub Enterprise installs. + Install verification fails closed by default and requires signed update metadata. The current agentsview and msgvault CLI release workflows publish archives plus `SHA256SUMS`, but not CLI update signatures or embedded public diff --git a/selfupdate/selfupdate.go b/selfupdate/selfupdate.go index f93bcfe..1a0edfb 100644 --- a/selfupdate/selfupdate.go +++ b/selfupdate/selfupdate.go @@ -11,9 +11,11 @@ import ( "crypto/sha256" "encoding/hex" "encoding/json" + "errors" "fmt" "io" "net/http" + "net/url" "os" "path/filepath" "regexp" @@ -26,6 +28,7 @@ import ( const ( defaultGitHubAPIBaseURL = "https://api.github.com" + defaultGitHubWebBaseURL = "https://github.com" defaultCacheFileName = "update_check.json" defaultCacheDuration = time.Hour defaultDevCacheDuration = 15 * time.Minute @@ -35,6 +38,11 @@ const ( legacyTarRegularType = byte(0) ) +var ( + errNonHTTPSRedirect = errors.New("redirect to non-HTTPS URL") + errChecksumAssetNotFound = errors.New("checksum asset not found") +) + // Release represents the subset of a GitHub release response used by Client. type Release struct { TagName string `json:"tag_name"` @@ -73,8 +81,11 @@ type Client struct { HTTPClient *http.Client Clock func() time.Time - GitHubAPIBaseURL string - UserAgent string + GitHubAPIBaseURL string + GitHubWebBaseURL string + ReleaseManifestURL string + GitHubToken string + UserAgent string CacheFileName string CacheDuration time.Duration @@ -133,6 +144,16 @@ func (c Client) Check(ctx context.Context, opts CheckOptions) (*Info, error) { if err := c.validateCheckConfig(); err != nil { return nil, err } + if c.ReleaseManifestURL != "" { + if err := requireHTTPSURL(c.ReleaseManifestURL, "release manifest URL"); err != nil { + return nil, fmt.Errorf("check for updates: %w", err) + } + } + if c.unsignedChecksumsAllowed() { + if err := c.validateUnsignedBaseURLs(); err != nil { + return nil, fmt.Errorf("check for updates: %w", err) + } + } currentVersion := c.CurrentVersion cleanVersion := strings.TrimPrefix(currentVersion, "v") @@ -148,28 +169,84 @@ func (c Client) Check(ctx context.Context, opts CheckOptions) (*Info, error) { if err != nil { return nil, fmt.Errorf("check for updates: %w", err) } - - _ = c.saveCache(release.TagName) + if err := c.validateUnsignedReleaseSource(release); err != nil { + return nil, fmt.Errorf("check for updates: %w", err) + } latestVersion := strings.TrimPrefix(release.TagName, "v") if !shouldOfferUpdate(latestVersion, cleanVersion, isDevBuild) { + _ = c.saveCache(release.TagName) return nil, nil } + webConventionalRelease := c.ReleaseManifestURL == "" && len(release.Assets) == 0 + if len(release.Assets) == 0 { + if err := c.addConventionalAssets(ctx, release, opts); err != nil { + if c.ReleaseManifestURL != "" { + return nil, fmt.Errorf("check for updates: %w", err) + } + apiRelease, apiErr := c.fetchLatestReleaseFromAPI(ctx) + if apiErr != nil { + return nil, fmt.Errorf("check for updates: %w (GitHub API fallback also failed: %w)", err, apiErr) + } + release = apiRelease + latestVersion = strings.TrimPrefix(release.TagName, "v") + if !shouldOfferUpdate(latestVersion, cleanVersion, isDevBuild) { + _ = c.saveCache(release.TagName) + return nil, nil + } + } + if err := c.validateUnsignedAssetURLs(release.Assets); err != nil { + return nil, fmt.Errorf("check for updates: %w", err) + } + } + _ = c.saveCache(release.TagName) + goos, goarch := platform(opts) assetName := c.platformAssetName(release, latestVersion, opts) - asset, checksumsAsset, signatureAsset := c.findAssets(release.Assets, assetName) + asset, checksumsAssets, signatureAsset := c.findAssets(release.Assets, assetName) if asset == nil { return nil, fmt.Errorf("no release asset for %s/%s", goos, goarch) } - var checksum string - if checksumsAsset != nil { - checksum, _ = c.fetchChecksumFromFile(ctx, checksumsAsset.BrowserDownloadURL, assetName) + checksum, err := c.fetchChecksumFromAssets(ctx, checksumsAssets, assetName) + if err != nil { + return nil, fmt.Errorf("check for updates: %w", err) } if checksum == "" { checksum = ExtractChecksum(release.Body, assetName) } + if checksum == "" && webConventionalRelease { + apiRelease, apiErr := c.fetchLatestReleaseFromAPI(ctx) + if apiErr != nil { + return nil, fmt.Errorf("check for updates: conventional release missing checksum for %s (GitHub API fallback also failed: %w)", assetName, apiErr) + } + release = apiRelease + latestVersion = strings.TrimPrefix(release.TagName, "v") + if !shouldOfferUpdate(latestVersion, cleanVersion, isDevBuild) { + _ = c.saveCache(release.TagName) + return nil, nil + } + if err := c.validateUnsignedReleaseSource(release); err != nil { + return nil, fmt.Errorf("check for updates: %w", err) + } + _ = c.saveCache(release.TagName) + assetName = c.platformAssetName(release, latestVersion, opts) + asset, checksumsAssets, signatureAsset = c.findAssets(release.Assets, assetName) + if asset == nil { + return nil, fmt.Errorf("no release asset for %s/%s", goos, goarch) + } + checksum, err = c.fetchChecksumFromAssets(ctx, checksumsAssets, assetName) + if err != nil { + return nil, fmt.Errorf("check for updates: %w", err) + } + if checksum == "" { + checksum = ExtractChecksum(release.Body, assetName) + } + if checksum == "" { + return nil, fmt.Errorf("check for updates: no checksum for %s", assetName) + } + } return &Info{ CurrentVersion: currentVersion, @@ -227,7 +304,7 @@ func (c Client) Install(ctx context.Context, info *Info, opts InstallOptions) er defer os.RemoveAll(tempDir) archivePath := filepath.Join(tempDir, assetName) - downloadChecksum, err := c.downloadFile(ctx, info.DownloadURL, archivePath, info.Size, opts.Progress) + downloadChecksum, err := c.downloadFile(ctx, info.DownloadURL, archivePath, info.Size, c.unsignedChecksumsAllowed(), opts.Progress) if err != nil { return fmt.Errorf("download: %w", err) } @@ -628,6 +705,17 @@ func FormatSize(bytes int64) string { return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp]) } +// EnvironmentGitHubToken returns a GitHub API token from GH_TOKEN or +// GITHUB_TOKEN, using the same precedence as the gh CLI. +func EnvironmentGitHubToken() string { + for _, key := range []string{"GH_TOKEN", "GITHUB_TOKEN"} { + if token := strings.TrimSpace(os.Getenv(key)); token != "" { + return token + } + } + return "" +} + type cachedCheck struct { CheckedAt time.Time `json:"checked_at"` Version string `json:"version"` @@ -653,6 +741,48 @@ func (c Client) httpClient() *http.Client { return &http.Client{Timeout: defaultHTTPTimeout} } +func (c Client) doHTTPRequest(req *http.Request, requireHTTPSRedirects bool) (*http.Response, error) { + client := c.httpClient() + if requireHTTPSRedirects { + if req.URL == nil || req.URL.Scheme != "https" { + rawURL := "" + if req.URL != nil { + rawURL = req.URL.Redacted() + } + return nil, fmt.Errorf("%w: %s", errNonHTTPSRedirect, rawURL) + } + client = c.httpClientRejectingHTTPSDowngrades() + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + if requireHTTPSRedirects && resp.Request != nil && resp.Request.URL != nil && resp.Request.URL.Scheme != "https" { + resp.Body.Close() + return nil, fmt.Errorf("%w: %s", errNonHTTPSRedirect, resp.Request.URL.Redacted()) + } + return resp, nil +} + +func (c Client) httpClientRejectingHTTPSDowngrades() *http.Client { + base := c.httpClient() + client := *base + originalCheckRedirect := client.CheckRedirect + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + if req.URL.Scheme != "https" { + return fmt.Errorf("%w: %s", errNonHTTPSRedirect, req.URL.Redacted()) + } + if originalCheckRedirect != nil { + return originalCheckRedirect(req, via) + } + if len(via) >= 10 { + return fmt.Errorf("stopped after 10 redirects") + } + return nil + } + return &client +} + func (c Client) now() time.Time { if c.Clock != nil { return c.Clock() @@ -667,6 +797,25 @@ func (c Client) apiBaseURL() string { return defaultGitHubAPIBaseURL } +func (c Client) webBaseURL() string { + if c.GitHubWebBaseURL != "" { + return strings.TrimRight(c.GitHubWebBaseURL, "/") + } + apiBase := c.apiBaseURL() + parsed, err := url.Parse(apiBase) + if err != nil || parsed.Host == "" { + return defaultGitHubWebBaseURL + } + if strings.EqualFold(parsed.Hostname(), "api.github.com") { + parsed.Host = "github.com" + } + parsed.Path = "" + parsed.RawPath = "" + parsed.RawQuery = "" + parsed.Fragment = "" + return strings.TrimRight(parsed.String(), "/") +} + func (c Client) userAgent() string { if c.UserAgent != "" { return c.UserAgent @@ -780,15 +929,181 @@ func (c Client) checkCache(currentVersion, cleanVersion string, isDevBuild bool) } func (c Client) fetchLatestRelease(ctx context.Context) (*Release, error) { + if c.ReleaseManifestURL != "" { + return c.fetchReleaseManifest(ctx) + } + release, err := c.fetchLatestReleaseFromWeb(ctx) + if err == nil { + return release, nil + } + apiRelease, apiErr := c.fetchLatestReleaseFromAPI(ctx) + if apiErr != nil { + return nil, fmt.Errorf("%w (GitHub API fallback also failed: %w)", err, apiErr) + } + return apiRelease, nil +} + +func (c Client) validateUnsignedBaseURLs() error { + if err := requireHTTPSURL(c.webBaseURL(), "GitHub web base URL"); err != nil { + return err + } + if err := requireHTTPSURL(c.apiBaseURL(), "GitHub API base URL"); err != nil { + return err + } + if c.ReleaseManifestURL != "" { + if err := requireHTTPSURL(c.ReleaseManifestURL, "release manifest URL"); err != nil { + return err + } + } + return nil +} + +func (c Client) validateUnsignedReleaseSource(release *Release) error { + if !c.unsignedChecksumsAllowed() { + return nil + } + if len(release.Assets) == 0 { + if err := requireHTTPSURL(c.webBaseURL(), "GitHub web base URL"); err != nil { + return err + } + } + return c.validateUnsignedAssetURLs(release.Assets) +} + +func (c Client) validateUnsignedAssetURLs(assets []Asset) error { + if !c.unsignedChecksumsAllowed() { + return nil + } + for _, asset := range assets { + if asset.BrowserDownloadURL == "" { + continue + } + if err := requireHTTPSURL(asset.BrowserDownloadURL, "release asset URL for "+asset.Name); err != nil { + return err + } + } + return nil +} + +func (c Client) unsignedChecksumsAllowed() bool { + return c.AllowUnsignedChecksums && !c.RequireSignature && len(c.TrustedPublicKeys) == 0 +} + +func (c Client) requireHTTPSForUnsignedChecksums() bool { + return c.unsignedChecksumsAllowed() +} + +func requireHTTPSURL(rawURL, label string) error { + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("%s is invalid: %w", label, err) + } + if u.Scheme != "https" { + return fmt.Errorf("%s must use https", label) + } + return nil +} + +func (c Client) fetchReleaseManifest(ctx context.Context) (*Release, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.ReleaseManifestURL, nil) + if err != nil { + return nil, err + } + req.Header.Set("User-Agent", c.userAgent()) + + resp, err := c.doHTTPRequest(req, true) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("release manifest returned %s", resp.Status) + } + + var release Release + if err := json.NewDecoder(resp.Body).Decode(&release); err != nil { + return nil, err + } + if release.TagName == "" { + return nil, fmt.Errorf("release manifest missing tag_name") + } + return &release, nil +} + +func (c Client) fetchLatestReleaseFromWeb(ctx context.Context) (*Release, error) { + pageURL := fmt.Sprintf("%s/%s/%s/releases/latest", c.webBaseURL(), c.Owner, c.Repo) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, pageURL, nil) + if err != nil { + return nil, err + } + req.Header.Set("User-Agent", c.userAgent()) + + resp, err := c.doHTTPRequest(req, c.requireHTTPSForUnsignedChecksums()) + if err != nil { + return nil, fmt.Errorf("fetch latest release page: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("latest release page returned %s", resp.Status) + } + + finalURL := req.URL + if resp.Request != nil && resp.Request.URL != nil { + finalURL = resp.Request.URL + } + tag, err := releaseTagFromURL(finalURL) + if err != nil { + return nil, err + } + return &Release{TagName: tag}, nil +} + +func (c Client) addConventionalAssets(ctx context.Context, release *Release, opts CheckOptions) error { + latestVersion := strings.TrimPrefix(release.TagName, "v") + assetName := c.platformAssetName(release, latestVersion, opts) + downloadBase := fmt.Sprintf("%s/%s/%s/releases/download/%s", c.webBaseURL(), c.Owner, c.Repo, release.TagName) + assetDownloadURL := downloadBase + "/" + assetName + size, err := c.fetchContentLength(ctx, assetDownloadURL) + if err != nil { + return err + } + + assets := []Asset{ + {Name: assetName, Size: size, BrowserDownloadURL: assetDownloadURL}, + } + for _, checksumAssetName := range c.checksumAssetNames() { + assets = append(assets, Asset{Name: checksumAssetName, BrowserDownloadURL: downloadBase + "/" + checksumAssetName}) + } + for _, signatureAssetName := range signatureAssetNames(assetName) { + signatureURL := downloadBase + "/" + signatureAssetName + if c.releaseAssetExists(ctx, signatureURL) { + assets = append(assets, Asset{Name: signatureAssetName, BrowserDownloadURL: signatureURL}) + } + } + release.Assets = assets + return nil +} + +func (c Client) fetchLatestReleaseFromAPI(ctx context.Context) (*Release, error) { url := fmt.Sprintf("%s/repos/%s/%s/releases/latest", c.apiBaseURL(), c.Owner, c.Repo) + token := strings.TrimSpace(c.GitHubToken) + if token != "" { + if err := requireHTTPSURL(c.apiBaseURL(), "GitHub API base URL"); err != nil { + return nil, err + } + } req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err } req.Header.Set("Accept", "application/vnd.github.v3+json") req.Header.Set("User-Agent", c.userAgent()) + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } - resp, err := c.httpClient().Do(req) + resp, err := c.doHTTPRequest(req, token != "" || c.requireHTTPSForUnsignedChecksums()) if err != nil { return nil, err } @@ -805,6 +1120,57 @@ func (c Client) fetchLatestRelease(ctx context.Context) (*Release, error) { return &release, nil } +func (c Client) fetchContentLength(ctx context.Context, rawURL string) (int64, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodHead, rawURL, nil) + if err != nil { + return 0, err + } + req.Header.Set("User-Agent", c.userAgent()) + + resp, err := c.doHTTPRequest(req, c.requireHTTPSForUnsignedChecksums()) + if err != nil { + return 0, err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return 0, fmt.Errorf("release asset returned %s", resp.Status) + } + if resp.ContentLength < 0 { + return 0, nil + } + return resp.ContentLength, nil +} + +func (c Client) releaseAssetExists(ctx context.Context, rawURL string) bool { + req, err := http.NewRequestWithContext(ctx, http.MethodHead, rawURL, nil) + if err != nil { + return false + } + req.Header.Set("User-Agent", c.userAgent()) + + resp, err := c.doHTTPRequest(req, c.requireHTTPSForUnsignedChecksums()) + if err != nil { + return false + } + defer resp.Body.Close() + + return resp.StatusCode == http.StatusOK +} + +func releaseTagFromURL(u *url.URL) (string, error) { + const marker = "/releases/tag/" + idx := strings.Index(u.Path, marker) + if idx < 0 { + return "", fmt.Errorf("latest release did not redirect to a tag (got %s)", u) + } + tag := u.Path[idx+len(marker):] + if tag == "" { + return "", fmt.Errorf("empty release tag in %s", u) + } + return tag, nil +} + func (c Client) fetchChecksumFromFile(ctx context.Context, url, assetName string) (string, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { @@ -812,13 +1178,16 @@ func (c Client) fetchChecksumFromFile(ctx context.Context, url, assetName string } req.Header.Set("User-Agent", c.userAgent()) - resp, err := c.httpClient().Do(req) + resp, err := c.doHTTPRequest(req, c.requireHTTPSForUnsignedChecksums()) if err != nil { return "", err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { + if resp.StatusCode == http.StatusNotFound { + return "", fmt.Errorf("%w: %s", errChecksumAssetNotFound, resp.Status) + } return "", fmt.Errorf("failed to fetch checksums: %s", resp.Status) } @@ -829,14 +1198,33 @@ func (c Client) fetchChecksumFromFile(ctx context.Context, url, assetName string return ExtractChecksum(string(body), assetName), nil } -func (c Client) downloadFile(ctx context.Context, url, dest string, totalSize int64, progress func(downloaded, total int64)) (string, error) { +func (c Client) fetchChecksumFromAssets(ctx context.Context, checksumsAssets []*Asset, assetName string) (string, error) { + for _, checksumsAsset := range checksumsAssets { + checksum, err := c.fetchChecksumFromFile(ctx, checksumsAsset.BrowserDownloadURL, assetName) + if errors.Is(err, errNonHTTPSRedirect) { + return "", err + } + if errors.Is(err, errChecksumAssetNotFound) { + continue + } + if err != nil { + return "", err + } + if checksum != "" { + return checksum, nil + } + } + return "", nil +} + +func (c Client) downloadFile(ctx context.Context, url, dest string, totalSize int64, requireHTTPSRedirects bool, progress func(downloaded, total int64)) (string, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return "", err } req.Header.Set("User-Agent", c.userAgent()) - resp, err := c.httpClient().Do(req) + resp, err := c.doHTTPRequest(req, requireHTTPSRedirects) if err != nil { return "", err } @@ -935,28 +1323,39 @@ func (c Client) platformAssetName(release *Release, version string, opts CheckOp return DefaultAssetName(req) } -func (c Client) findAssets(assets []Asset, assetName string) (asset *Asset, checksumsAsset *Asset, signatureAsset *Asset) { - checksumNames := map[string]struct{}{} - for _, name := range c.checksumAssetNames() { - checksumNames[name] = struct{}{} - } - signatureNames := map[string]struct{}{ - assetName + ".sha256.sig": {}, - assetName + ".sig": {}, - } +func (c Client) findAssets(assets []Asset, assetName string) (asset *Asset, checksumsAssets []*Asset, signatureAsset *Asset) { + checksumAssetsByName := map[string]*Asset{} + signatureAssetsByName := map[string]*Asset{} for i := range assets { a := &assets[i] if a.Name == assetName { asset = a } - if _, ok := checksumNames[a.Name]; ok { - checksumsAsset = a + if _, ok := checksumAssetsByName[a.Name]; !ok { + checksumAssetsByName[a.Name] = a + } + if _, ok := signatureAssetsByName[a.Name]; !ok { + signatureAssetsByName[a.Name] = a + } + } + for _, checksumAssetName := range c.checksumAssetNames() { + if checksumsAsset := checksumAssetsByName[checksumAssetName]; checksumsAsset != nil { + checksumsAssets = append(checksumsAssets, checksumsAsset) } - if _, ok := signatureNames[a.Name]; ok { - signatureAsset = a + } + for _, signatureAssetName := range signatureAssetNames(assetName) { + if signatureAsset = signatureAssetsByName[signatureAssetName]; signatureAsset != nil { + break } } - return asset, checksumsAsset, signatureAsset + return asset, checksumsAssets, signatureAsset +} + +func signatureAssetNames(assetName string) []string { + return []string{ + assetName + ".sha256.sig", + assetName + ".sig", + } } func (c Client) signaturePayload(info *Info) []byte { diff --git a/selfupdate/selfupdate_test.go b/selfupdate/selfupdate_test.go index 1e259fd..23a3c78 100644 --- a/selfupdate/selfupdate_test.go +++ b/selfupdate/selfupdate_test.go @@ -20,6 +20,9 @@ import ( "sync/atomic" "testing" "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const ( @@ -87,6 +90,763 @@ func TestCheckFindsUpdateAndChecksumAsset(t *testing.T) { } } +func TestCheckDiscoversReleaseThroughWebRedirectByDefault(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + require := require.New(t) + var apiRequests atomic.Int64 + var latestPageRequests atomic.Int64 + var checksumRequests atomic.Int64 + assetName := "tool_1.2.0_linux_amd64.tar.gz" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method + " " + r.URL.Path { + case "GET /repos/kenn/tool/releases/latest": + apiRequests.Add(1) + http.Error(w, "api should not be used before web discovery", http.StatusInternalServerError) + case "GET /kenn/tool/releases/latest": + latestPageRequests.Add(1) + http.Redirect(w, r, "/kenn/tool/releases/tag/v1.2.0", http.StatusFound) + case "GET /kenn/tool/releases/tag/v1.2.0": + _, _ = w.Write([]byte("release page")) + case "HEAD /kenn/tool/releases/download/v1.2.0/" + assetName: + w.Header().Set("Content-Length", "123") + w.WriteHeader(http.StatusOK) + case "GET /kenn/tool/releases/download/v1.2.0/SHA256SUMS": + checksumRequests.Add(1) + _, _ = fmt.Fprintf(w, "%s %s\n", testHash64, assetName) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client := Client{ + Owner: "kenn", + Repo: "tool", + BinaryName: "tool", + CurrentVersion: "v1.1.0", + CacheDir: t.TempDir(), + GitHubAPIBaseURL: server.URL, + GitHubWebBaseURL: server.URL, + Clock: func() time.Time { return time.Unix(100, 0) }, + } + + info, err := client.Check(context.Background(), CheckOptions{GOOS: "linux", GOARCH: "amd64"}) + require.NoError(err) + require.NotNil(info) + assert.Equal("v1.2.0", info.LatestVersion) + assert.Equal(assetName, info.AssetName) + assert.Equal(server.URL+"/kenn/tool/releases/download/v1.2.0/"+assetName, info.DownloadURL) + assert.Equal(testHash64, info.Checksum) + assert.Equal(int64(123), info.Size) + assert.Zero(apiRequests.Load()) + assert.Equal(int64(1), latestPageRequests.Load()) + assert.Equal(int64(1), checksumRequests.Load()) +} + +func TestCheckSkipsConventionalAssetProbeWhenWebTagIsCurrent(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + require := require.New(t) + var assetProbeRequests atomic.Int64 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method + " " + r.URL.Path { + case "GET /kenn/tool/releases/latest": + http.Redirect(w, r, "/kenn/tool/releases/tag/v1.2.0", http.StatusFound) + case "GET /kenn/tool/releases/tag/v1.2.0": + _, _ = w.Write([]byte("release page")) + case "HEAD /kenn/tool/releases/download/v1.2.0/tool_1.2.0_linux_amd64.tar.gz": + assetProbeRequests.Add(1) + http.Error(w, "already-current checks should not probe assets", http.StatusInternalServerError) + case "GET /repos/kenn/tool/releases/latest": + http.Error(w, "api fallback should not be needed", http.StatusInternalServerError) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client := Client{ + Owner: "kenn", + Repo: "tool", + BinaryName: "tool", + CurrentVersion: "v1.2.0", + GitHubAPIBaseURL: server.URL, + GitHubWebBaseURL: server.URL, + } + + info, err := client.Check(context.Background(), CheckOptions{GOOS: "linux", GOARCH: "amd64"}) + require.NoError(err) + assert.Nil(info) + assert.Zero(assetProbeRequests.Load()) +} + +func TestCheckUsesReleaseManifestBeforeNetworkDiscovery(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + require := require.New(t) + var apiRequests atomic.Int64 + var latestPageRequests atomic.Int64 + assetName := "tool_1.2.0_linux_amd64.tar.gz" + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/latest.json": + _ = json.NewEncoder(w).Encode(Release{ + TagName: "v1.2.0", + Assets: []Asset{ + {Name: assetName, Size: 123, BrowserDownloadURL: "https://example.invalid/tool"}, + {Name: "SHA256SUMS", BrowserDownloadURL: "https://" + r.Host + "/SHA256SUMS"}, + }, + }) + case "/SHA256SUMS": + _, _ = fmt.Fprintf(w, "%s %s\n", testHash64, assetName) + case "/repos/kenn/tool/releases/latest": + apiRequests.Add(1) + http.Error(w, "api should not be used when manifest is configured", http.StatusInternalServerError) + case "/kenn/tool/releases/latest": + latestPageRequests.Add(1) + http.Error(w, "web discovery should not be used when manifest is configured", http.StatusInternalServerError) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client := Client{ + Owner: "kenn", + Repo: "tool", + BinaryName: "tool", + CurrentVersion: "v1.1.0", + ReleaseManifestURL: server.URL + "/latest.json", + GitHubAPIBaseURL: server.URL, + GitHubWebBaseURL: server.URL, + HTTPClient: server.Client(), + } + + info, err := client.Check(context.Background(), CheckOptions{GOOS: "linux", GOARCH: "amd64"}) + require.NoError(err) + require.NotNil(info) + assert.Equal("v1.2.0", info.LatestVersion) + assert.Equal(assetName, info.AssetName) + assert.Equal("https://example.invalid/tool", info.DownloadURL) + assert.Equal(testHash64, info.Checksum) + assert.Zero(apiRequests.Load()) + assert.Zero(latestPageRequests.Load()) +} + +func TestCheckUsesManifestTagWithConventionalAssets(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + require := require.New(t) + assetName := "tool_1.2.0_linux_amd64.tar.gz" + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method + " " + r.URL.Path { + case "GET /latest.json": + _ = json.NewEncoder(w).Encode(Release{TagName: "v1.2.0"}) + case "HEAD /kenn/tool/releases/download/v1.2.0/" + assetName: + w.Header().Set("Content-Length", "123") + w.WriteHeader(http.StatusOK) + case "GET /kenn/tool/releases/download/v1.2.0/SHA256SUMS": + _, _ = fmt.Fprintf(w, "%s %s\n", testHash64, assetName) + case "GET /repos/kenn/tool/releases/latest", "GET /kenn/tool/releases/latest": + http.Error(w, "manifest should be enough", http.StatusInternalServerError) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client := Client{ + Owner: "kenn", + Repo: "tool", + BinaryName: "tool", + CurrentVersion: "v1.1.0", + ReleaseManifestURL: server.URL + "/latest.json", + GitHubAPIBaseURL: server.URL, + GitHubWebBaseURL: server.URL, + HTTPClient: server.Client(), + } + + info, err := client.Check(context.Background(), CheckOptions{GOOS: "linux", GOARCH: "amd64"}) + require.NoError(err) + require.NotNil(info) + assert.Equal(assetName, info.AssetName) + assert.Equal(server.URL+"/kenn/tool/releases/download/v1.2.0/"+assetName, info.DownloadURL) + assert.Equal(testHash64, info.Checksum) + assert.Equal(int64(123), info.Size) +} + +func TestCheckRejectsHTTPManifestWhenUnsignedChecksumsAllowed(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatalf("insecure manifest URL should be rejected before fetch") + })) + defer server.Close() + + client := Client{ + Owner: "kenn", + Repo: "tool", + BinaryName: "tool", + CurrentVersion: "v1.1.0", + ReleaseManifestURL: server.URL + "/latest.json", + AllowUnsignedChecksums: true, + } + + info, err := client.Check(context.Background(), CheckOptions{GOOS: "linux", GOARCH: "amd64"}) + require.Error(t, err) + assert.Nil(t, info) + assert.Contains(t, err.Error(), "release manifest URL must use https") +} + +func TestCheckRejectsHTTPManifest(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatalf("insecure manifest URL should be rejected before fetch") + })) + defer server.Close() + + client := Client{ + Owner: "kenn", + Repo: "tool", + BinaryName: "tool", + CurrentVersion: "v1.1.0", + ReleaseManifestURL: server.URL + "/latest.json", + } + + info, err := client.Check(context.Background(), CheckOptions{GOOS: "linux", GOARCH: "amd64"}) + require.Error(t, err) + assert.Nil(t, info) + assert.Contains(t, err.Error(), "release manifest URL must use https") +} + +func TestCheckRejectsHTTPManifestAssetWhenUnsignedChecksumsAllowed(t *testing.T) { + t.Parallel() + + assetName := "tool_1.2.0_linux_amd64.tar.gz" + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/latest.json": + _ = json.NewEncoder(w).Encode(Release{ + TagName: "v1.2.0", + Assets: []Asset{ + {Name: assetName, Size: 123, BrowserDownloadURL: "http://example.invalid/tool"}, + {Name: "SHA256SUMS", BrowserDownloadURL: "https://example.invalid/SHA256SUMS"}, + }, + }) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client := Client{ + Owner: "kenn", + Repo: "tool", + BinaryName: "tool", + CurrentVersion: "v1.1.0", + ReleaseManifestURL: server.URL + "/latest.json", + HTTPClient: server.Client(), + AllowUnsignedChecksums: true, + } + + info, err := client.Check(context.Background(), CheckOptions{GOOS: "linux", GOARCH: "amd64"}) + require.Error(t, err) + assert.Nil(t, info) + assert.Contains(t, err.Error(), "release asset URL for "+assetName+" must use https") +} + +func TestCheckRejectsHTTPSManifestRedirectToHTTPWhenUnsignedChecksumsAllowed(t *testing.T) { + t.Parallel() + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "http://example.invalid/latest.json", http.StatusFound) + })) + defer server.Close() + + client := Client{ + Owner: "kenn", + Repo: "tool", + BinaryName: "tool", + CurrentVersion: "v1.1.0", + ReleaseManifestURL: server.URL + "/latest.json", + HTTPClient: server.Client(), + AllowUnsignedChecksums: true, + } + + info, err := client.Check(context.Background(), CheckOptions{GOOS: "linux", GOARCH: "amd64"}) + require.Error(t, err) + assert.Nil(t, info) + assert.Contains(t, err.Error(), "redirect to non-HTTPS URL") +} + +func TestCheckRejectsHTTPSManifestRedirectToHTTP(t *testing.T) { + t.Parallel() + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "http://example.invalid/latest.json", http.StatusFound) + })) + defer server.Close() + + client := Client{ + Owner: "kenn", + Repo: "tool", + BinaryName: "tool", + CurrentVersion: "v1.1.0", + ReleaseManifestURL: server.URL + "/latest.json", + HTTPClient: server.Client(), + } + + info, err := client.Check(context.Background(), CheckOptions{GOOS: "linux", GOARCH: "amd64"}) + require.Error(t, err) + assert.Nil(t, info) + assert.Contains(t, err.Error(), "redirect to non-HTTPS URL") +} + +func TestCheckRejectsHTTPSChecksumRedirectToHTTPWhenUnsignedChecksumsAllowed(t *testing.T) { + t.Parallel() + + assetName := "tool_1.2.0_linux_amd64.tar.gz" + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/latest.json": + _ = json.NewEncoder(w).Encode(Release{ + TagName: "v1.2.0", + Assets: []Asset{ + {Name: assetName, Size: 123, BrowserDownloadURL: "https://example.invalid/tool"}, + {Name: "SHA256SUMS", BrowserDownloadURL: "https://" + r.Host + "/SHA256SUMS"}, + }, + }) + case "/SHA256SUMS": + http.Redirect(w, r, "http://example.invalid/SHA256SUMS", http.StatusFound) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client := Client{ + Owner: "kenn", + Repo: "tool", + BinaryName: "tool", + CurrentVersion: "v1.1.0", + ReleaseManifestURL: server.URL + "/latest.json", + HTTPClient: server.Client(), + AllowUnsignedChecksums: true, + } + + info, err := client.Check(context.Background(), CheckOptions{GOOS: "linux", GOARCH: "amd64"}) + require.Error(t, err) + assert.Nil(t, info) + assert.Contains(t, err.Error(), "redirect to non-HTTPS URL") +} + +func TestCheckUsesConventionalChecksumAndSignatureFallbacks(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + require := require.New(t) + assetName := "tool_1.2.0_linux_amd64.tar.gz" + var primaryChecksumRequests atomic.Int64 + var fallbackChecksumRequests atomic.Int64 + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method + " " + r.URL.Path { + case "GET /latest.json": + _ = json.NewEncoder(w).Encode(Release{TagName: "v1.2.0"}) + case "HEAD /kenn/tool/releases/download/v1.2.0/" + assetName: + w.Header().Set("Content-Length", "123") + w.WriteHeader(http.StatusOK) + case "HEAD /kenn/tool/releases/download/v1.2.0/" + assetName + ".sha256.sig": + http.NotFound(w, r) + case "HEAD /kenn/tool/releases/download/v1.2.0/" + assetName + ".sig": + w.WriteHeader(http.StatusOK) + case "GET /kenn/tool/releases/download/v1.2.0/SHA256SUMS": + primaryChecksumRequests.Add(1) + http.NotFound(w, r) + case "GET /kenn/tool/releases/download/v1.2.0/checksums.txt": + fallbackChecksumRequests.Add(1) + _, _ = fmt.Fprintf(w, "%s %s\n", testHash64, assetName) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client := Client{ + Owner: "kenn", + Repo: "tool", + BinaryName: "tool", + CurrentVersion: "v1.1.0", + ReleaseManifestURL: server.URL + "/latest.json", + GitHubAPIBaseURL: server.URL, + GitHubWebBaseURL: server.URL, + HTTPClient: server.Client(), + } + + info, err := client.Check(context.Background(), CheckOptions{GOOS: "linux", GOARCH: "amd64"}) + require.NoError(err) + require.NotNil(info) + assert.Equal(testHash64, info.Checksum) + assert.Equal(server.URL+"/kenn/tool/releases/download/v1.2.0/"+assetName+".sig", info.SignatureURL) + assert.Equal(int64(1), primaryChecksumRequests.Load()) + assert.Equal(int64(1), fallbackChecksumRequests.Load()) +} + +func TestCheckFallsBackToAPIWhenWebConventionalReleaseHasNoChecksum(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + require := require.New(t) + assetName := "tool_1.2.0_linux_amd64.tar.gz" + var apiRequests atomic.Int64 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method + " " + r.URL.Path { + case "GET /kenn/tool/releases/latest": + http.Redirect(w, r, "/kenn/tool/releases/tag/v1.2.0", http.StatusFound) + case "GET /kenn/tool/releases/tag/v1.2.0": + _, _ = w.Write([]byte("release page")) + case "HEAD /kenn/tool/releases/download/v1.2.0/" + assetName: + w.Header().Set("Content-Length", "123") + w.WriteHeader(http.StatusOK) + case "GET /kenn/tool/releases/download/v1.2.0/SHA256SUMS", "GET /kenn/tool/releases/download/v1.2.0/checksums.txt": + http.NotFound(w, r) + case "GET /repos/kenn/tool/releases/latest": + apiRequests.Add(1) + _ = json.NewEncoder(w).Encode(Release{ + TagName: "v1.2.0", + Body: fmt.Sprintf("%s %s\n", testHash64, assetName), + Assets: []Asset{ + {Name: assetName, Size: 456, BrowserDownloadURL: "https://example.invalid/tool"}, + }, + }) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client := Client{ + Owner: "kenn", + Repo: "tool", + BinaryName: "tool", + CurrentVersion: "v1.1.0", + GitHubAPIBaseURL: server.URL, + GitHubWebBaseURL: server.URL, + } + + info, err := client.Check(context.Background(), CheckOptions{GOOS: "linux", GOARCH: "amd64"}) + require.NoError(err) + require.NotNil(info) + assert.Equal(testHash64, info.Checksum) + assert.Equal("https://example.invalid/tool", info.DownloadURL) + assert.Equal(int64(456), info.Size) + assert.Equal(int64(1), apiRequests.Load()) +} + +func TestCheckRejectsHTTPAPIAssetAfterWebChecksumFallbackWhenUnsignedChecksumsAllowed(t *testing.T) { + t.Parallel() + + assetName := "tool_1.2.0_linux_amd64.tar.gz" + var apiRequests atomic.Int64 + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method + " " + r.URL.Path { + case "GET /kenn/tool/releases/latest": + http.Redirect(w, r, "/kenn/tool/releases/tag/v1.2.0", http.StatusFound) + case "GET /kenn/tool/releases/tag/v1.2.0": + _, _ = w.Write([]byte("release page")) + case "HEAD /kenn/tool/releases/download/v1.2.0/" + assetName: + w.Header().Set("Content-Length", "123") + w.WriteHeader(http.StatusOK) + case "GET /kenn/tool/releases/download/v1.2.0/SHA256SUMS", "GET /kenn/tool/releases/download/v1.2.0/checksums.txt": + http.NotFound(w, r) + case "GET /repos/kenn/tool/releases/latest": + apiRequests.Add(1) + _ = json.NewEncoder(w).Encode(Release{ + TagName: "v1.2.0", + Body: fmt.Sprintf("%s %s\n", testHash64, assetName), + Assets: []Asset{ + {Name: assetName, Size: 456, BrowserDownloadURL: "http://example.invalid/tool"}, + }, + }) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client := Client{ + Owner: "kenn", + Repo: "tool", + BinaryName: "tool", + CurrentVersion: "v1.1.0", + GitHubAPIBaseURL: server.URL, + GitHubWebBaseURL: server.URL, + HTTPClient: server.Client(), + AllowUnsignedChecksums: true, + } + + info, err := client.Check(context.Background(), CheckOptions{GOOS: "linux", GOARCH: "amd64"}) + require.Error(t, err) + assert.Nil(t, info) + assert.Contains(t, err.Error(), "release asset URL for "+assetName+" must use https") + assert.Equal(t, int64(1), apiRequests.Load()) +} + +func TestCheckRejectsHTTPWebBaseWhenUnsignedChecksumsAllowed(t *testing.T) { + t.Parallel() + + client := Client{ + Owner: "kenn", + Repo: "tool", + BinaryName: "tool", + CurrentVersion: "v1.1.0", + GitHubWebBaseURL: "http://example.invalid", + AllowUnsignedChecksums: true, + } + + info, err := client.Check(context.Background(), CheckOptions{GOOS: "linux", GOARCH: "amd64"}) + require.Error(t, err) + assert.Nil(t, info) + assert.Contains(t, err.Error(), "GitHub web base URL must use https") +} + +func TestCheckRejectsWebChecksumHTTPRedirectWhenUnsignedChecksumsAllowed(t *testing.T) { + t.Parallel() + + assetName := "tool_1.2.0_linux_amd64.tar.gz" + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method + " " + r.URL.Path { + case "GET /kenn/tool/releases/latest": + http.Redirect(w, r, "/kenn/tool/releases/tag/v1.2.0", http.StatusFound) + case "GET /kenn/tool/releases/tag/v1.2.0": + _, _ = w.Write([]byte("release page")) + case "HEAD /kenn/tool/releases/download/v1.2.0/" + assetName: + w.Header().Set("Content-Length", "123") + w.WriteHeader(http.StatusOK) + case "GET /kenn/tool/releases/download/v1.2.0/SHA256SUMS": + http.Redirect(w, r, "http://example.invalid/SHA256SUMS", http.StatusFound) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client := Client{ + Owner: "kenn", + Repo: "tool", + BinaryName: "tool", + CurrentVersion: "v1.1.0", + GitHubAPIBaseURL: server.URL, + GitHubWebBaseURL: server.URL, + HTTPClient: server.Client(), + AllowUnsignedChecksums: true, + } + + info, err := client.Check(context.Background(), CheckOptions{GOOS: "linux", GOARCH: "amd64"}) + require.Error(t, err) + assert.Nil(t, info) + assert.Contains(t, err.Error(), "redirect to non-HTTPS URL") +} + +func TestCheckRejectsWebLatestHTTPRedirectWhenUnsignedChecksumsAllowed(t *testing.T) { + t.Parallel() + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method + " " + r.URL.Path { + case "GET /kenn/tool/releases/latest": + http.Redirect(w, r, "http://example.invalid/kenn/tool/releases/tag/v1.2.0", http.StatusFound) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client := Client{ + Owner: "kenn", + Repo: "tool", + BinaryName: "tool", + CurrentVersion: "v1.1.0", + GitHubAPIBaseURL: server.URL, + GitHubWebBaseURL: server.URL, + HTTPClient: server.Client(), + AllowUnsignedChecksums: true, + } + + info, err := client.Check(context.Background(), CheckOptions{GOOS: "linux", GOARCH: "amd64"}) + require.Error(t, err) + assert.Nil(t, info) + assert.Contains(t, err.Error(), "redirect to non-HTTPS URL") +} + +func TestCheckSendsTokenOnlyToAPIFallback(t *testing.T) { + t.Parallel() + + assert := assert.New(t) + require := require.New(t) + assetName := "tool_1.2.0_linux_amd64.tar.gz" + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method + " " + r.URL.Path { + case "GET /kenn/tool/releases/latest": + assert.Empty(r.Header.Get("Authorization")) + http.Error(w, "web discovery unavailable", http.StatusInternalServerError) + case "GET /repos/kenn/tool/releases/latest": + assert.Equal("Bearer test-token", r.Header.Get("Authorization")) + _ = json.NewEncoder(w).Encode(Release{ + TagName: "v1.2.0", + Assets: []Asset{ + {Name: assetName, Size: 123, BrowserDownloadURL: "https://example.invalid/tool"}, + {Name: "SHA256SUMS", BrowserDownloadURL: "https://" + r.Host + "/SHA256SUMS"}, + }, + }) + case "GET /SHA256SUMS": + assert.Empty(r.Header.Get("Authorization")) + _, _ = fmt.Fprintf(w, "%s %s\n", testHash64, assetName) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client := Client{ + Owner: "kenn", + Repo: "tool", + BinaryName: "tool", + CurrentVersion: "v1.1.0", + GitHubAPIBaseURL: server.URL, + GitHubWebBaseURL: server.URL, + HTTPClient: server.Client(), + GitHubToken: "test-token", + } + + info, err := client.Check(context.Background(), CheckOptions{GOOS: "linux", GOARCH: "amd64"}) + require.NoError(err) + require.NotNil(info) + assert.Equal(testHash64, info.Checksum) +} + +func TestCheckRejectsTokenWithHTTPAPIBaseURL(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/kenn/tool/releases/latest": + http.Error(w, "web discovery unavailable", http.StatusInternalServerError) + case "/repos/kenn/tool/releases/latest": + t.Fatalf("token-bearing API request should be rejected before fetch") + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client := Client{ + Owner: "kenn", + Repo: "tool", + BinaryName: "tool", + CurrentVersion: "v1.1.0", + GitHubWebBaseURL: server.URL, + GitHubAPIBaseURL: server.URL, + GitHubToken: "test-token", + } + + info, err := client.Check(context.Background(), CheckOptions{GOOS: "linux", GOARCH: "amd64"}) + require.Error(t, err) + assert.Nil(t, info) + assert.Contains(t, err.Error(), "GitHub API base URL must use https") +} + +func TestCheckRejectsTokenAPIHTTPRedirect(t *testing.T) { + t.Parallel() + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/kenn/tool/releases/latest": + http.Error(w, "web discovery unavailable", http.StatusInternalServerError) + case "/repos/kenn/tool/releases/latest": + assert.Equal(t, "Bearer test-token", r.Header.Get("Authorization")) + http.Redirect(w, r, "http://example.invalid/repos/kenn/tool/releases/latest", http.StatusFound) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client := Client{ + Owner: "kenn", + Repo: "tool", + BinaryName: "tool", + CurrentVersion: "v1.1.0", + GitHubWebBaseURL: server.URL, + GitHubAPIBaseURL: server.URL, + HTTPClient: server.Client(), + GitHubToken: "test-token", + } + + info, err := client.Check(context.Background(), CheckOptions{GOOS: "linux", GOARCH: "amd64"}) + require.Error(t, err) + assert.Nil(t, info) + assert.Contains(t, err.Error(), "redirect to non-HTTPS URL") +} + +func TestCheckRejectsUnsignedAPIHTTPRedirectWithoutToken(t *testing.T) { + t.Parallel() + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.Method + " " + r.URL.Path { + case "GET /kenn/tool/releases/latest": + http.NotFound(w, r) + case "GET /repos/kenn/tool/releases/latest": + http.Redirect(w, r, "http://example.invalid/repos/kenn/tool/releases/latest", http.StatusFound) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client := Client{ + Owner: "kenn", + Repo: "tool", + BinaryName: "tool", + CurrentVersion: "v1.1.0", + GitHubAPIBaseURL: server.URL, + GitHubWebBaseURL: server.URL, + HTTPClient: server.Client(), + AllowUnsignedChecksums: true, + } + + info, err := client.Check(context.Background(), CheckOptions{GOOS: "linux", GOARCH: "amd64"}) + require.Error(t, err) + assert.Nil(t, info) + assert.Contains(t, err.Error(), "redirect to non-HTTPS URL") +} + +func TestEnvironmentGitHubToken(t *testing.T) { + assert := assert.New(t) + + t.Setenv("GH_TOKEN", "primary") + t.Setenv("GITHUB_TOKEN", "fallback") + assert.Equal("primary", EnvironmentGitHubToken()) + + t.Setenv("GH_TOKEN", "") + assert.Equal("fallback", EnvironmentGitHubToken()) + + t.Setenv("GITHUB_TOKEN", "") + assert.Empty(EnvironmentGitHubToken()) +} + func TestCheckUsesReleaseBodyChecksumFallback(t *testing.T) { t.Parallel() @@ -462,12 +1222,12 @@ func TestInstallRejectsMismatchedInfoRepository(t *testing.T) { func TestInstallRejectsDownloadLargerThanExpected(t *testing.T) { t.Parallel() - server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte("too large")) })) defer server.Close() - c := Client{BinaryName: "tool", AllowUnsignedChecksums: true} + c := Client{BinaryName: "tool", HTTPClient: server.Client(), AllowUnsignedChecksums: true} err := c.Install(context.Background(), &Info{ DownloadURL: server.URL, AssetName: "tool.tar.gz", @@ -479,6 +1239,45 @@ func TestInstallRejectsDownloadLargerThanExpected(t *testing.T) { } } +func TestInstallRejectsArchiveHTTPRedirectWhenUnsignedChecksumsAllowed(t *testing.T) { + t.Parallel() + + server := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "http://example.invalid/archive.tar.gz", http.StatusFound) + })) + defer server.Close() + + c := Client{BinaryName: "tool", HTTPClient: server.Client(), AllowUnsignedChecksums: true} + err := c.Install(context.Background(), &Info{ + DownloadURL: server.URL + "/archive.tar.gz", + AssetName: "tool.tar.gz", + Checksum: strings.Repeat("0", 64), + }, InstallOptions{DestinationPath: filepath.Join(t.TempDir(), "tool")}) + require.Error(t, err) + assert.Contains(t, err.Error(), "redirect to non-HTTPS URL") +} + +func TestInstallRejectsHTTPArchiveBeforeRequestWhenUnsignedChecksumsAllowed(t *testing.T) { + t.Parallel() + + var requests atomic.Int64 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requests.Add(1) + _, _ = w.Write([]byte("archive")) + })) + defer server.Close() + + c := Client{BinaryName: "tool", AllowUnsignedChecksums: true} + err := c.Install(context.Background(), &Info{ + DownloadURL: server.URL + "/archive.tar.gz", + AssetName: "tool.tar.gz", + Checksum: strings.Repeat("0", 64), + }, InstallOptions{DestinationPath: filepath.Join(t.TempDir(), "tool")}) + require.Error(t, err) + assert.Contains(t, err.Error(), "redirect to non-HTTPS URL") + assert.Equal(t, int64(0), requests.Load()) +} + func TestInstallArchive(t *testing.T) { t.Parallel() @@ -773,6 +1572,74 @@ func TestFetchChecksumFromFileLimitsResponseSize(t *testing.T) { } } +func TestFetchChecksumFromAssetsPropagatesCanceledContext(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatalf("canceled checksum request should not reach server") + })) + defer server.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + c := Client{BinaryName: "tool"} + checksum, err := c.fetchChecksumFromAssets(ctx, []*Asset{ + {Name: "SHA256SUMS", BrowserDownloadURL: server.URL + "/SHA256SUMS"}, + }, "tool.tar.gz") + require.Error(t, err) + assert.Empty(t, checksum) + assert.ErrorIs(t, err, context.Canceled) +} + +func TestFetchChecksumFromAssetsPropagatesOversizedChecksum(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/SHA256SUMS": + _, _ = io.Copy(w, io.LimitReader(strings.NewReader(strings.Repeat("x", maxChecksumBytes+1)), maxChecksumBytes+1)) + case "/checksums.txt": + _, _ = fmt.Fprintf(w, "%s tool.tar.gz\n", testHash64) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + c := Client{BinaryName: "tool"} + checksum, err := c.fetchChecksumFromAssets(context.Background(), []*Asset{ + {Name: "SHA256SUMS", BrowserDownloadURL: server.URL + "/SHA256SUMS"}, + {Name: "checksums.txt", BrowserDownloadURL: server.URL + "/checksums.txt"}, + }, "tool.tar.gz") + require.Error(t, err) + assert.Empty(t, checksum) +} + +func TestFetchChecksumFromAssetsFallsBackAfterMissingAsset(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/SHA256SUMS": + http.NotFound(w, r) + case "/checksums.txt": + _, _ = fmt.Fprintf(w, "%s tool.tar.gz\n", testHash64) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + c := Client{BinaryName: "tool"} + checksum, err := c.fetchChecksumFromAssets(context.Background(), []*Asset{ + {Name: "SHA256SUMS", BrowserDownloadURL: server.URL + "/SHA256SUMS"}, + {Name: "checksums.txt", BrowserDownloadURL: server.URL + "/checksums.txt"}, + }, "tool.tar.gz") + require.NoError(t, err) + assert.Equal(t, testHash64, checksum) +} + func TestSanitizeArchivePath(t *testing.T) { t.Parallel()