diff --git a/cmd/wl/main_test.go b/cmd/wl/main_test.go index 445ae07..b921609 100644 --- a/cmd/wl/main_test.go +++ b/cmd/wl/main_test.go @@ -624,7 +624,7 @@ func TestGetTrackerNotFound(t *testing.T) { defer server.Close() err := runGet(getOpts{ - magnetURI: "magnet:?xt=urn:btih:deadbeef&dn=missing", + magnetURI: "magnet:?xt=urn:btih:deadbeefdeadbeefdeadbeefdeadbeefdeadbeef&dn=missing", trackerURL: server.URL, }) if err == nil { @@ -637,7 +637,7 @@ func TestGetTrackerNotFound(t *testing.T) { func TestGetTrackerDown(t *testing.T) { err := runGet(getOpts{ - magnetURI: "magnet:?xt=urn:btih:deadbeef&dn=unreachable", + magnetURI: "magnet:?xt=urn:btih:deadbeefdeadbeefdeadbeefdeadbeefdeadbeef&dn=unreachable", trackerURL: "http://127.0.0.1:1", }) if err == nil { diff --git a/internal/client/downloader.go b/internal/client/downloader.go index 829aaed..8e3b9ec 100644 --- a/internal/client/downloader.go +++ b/internal/client/downloader.go @@ -2,12 +2,12 @@ package client import ( "context" + "crypto/rand" "crypto/sha1" + "encoding/binary" "fmt" "log" - "math/rand" "strings" - "time" ) const ( @@ -133,6 +133,14 @@ func downloadPiece(ctx context.Context, p *PeerConn, index int, size int, expect if len(msg.Payload) < 8 { continue } + // Validate the block belongs where we expect before writing it, so a + // stray/duplicate/reordered block can't land at the wrong offset. + // Requests are issued one block at a time, in order. + blkIndex := binary.BigEndian.Uint32(msg.Payload[0:4]) + blkBegin := binary.BigEndian.Uint32(msg.Payload[4:8]) + if blkIndex != uint32(index) || blkBegin != uint32(downloaded) { + continue + } block := msg.Payload[8:] if downloaded+len(block) > size { return nil, fmt.Errorf("received block too large") @@ -181,11 +189,16 @@ func formatPieceSize(n int) string { } func generatePeerID() string { - r := rand.New(rand.NewSource(time.Now().UnixNano())) const charset = "abcdefghijklmnopqrstuvwxyz0123456789" b := make([]byte, 12) + // crypto/rand so two clients started in the same instant don't collide. + if _, err := rand.Read(b); err != nil { + // rand.Read never returns an error on supported platforms, but fall + // back to a fixed-but-valid suffix rather than panicking. + return "-WL0020-aaaaaaaaaaaa" + } for i := range b { - b[i] = charset[r.Intn(len(charset))] + b[i] = charset[int(b[i])%len(charset)] } return "-WL0020-" + string(b) } diff --git a/internal/client/storage.go b/internal/client/storage.go index 0c6d8f2..d8032a3 100644 --- a/internal/client/storage.go +++ b/internal/client/storage.go @@ -95,5 +95,11 @@ func (s *Storage) WritePiece(pieceIndex int, pieceLength int, data []byte) error currentPos = fileEnd } + // Every byte of the piece must have mapped to a file; otherwise the + // piece-to-file offset math is wrong and we'd silently drop data. + if dataOffset != bytesToWrite { + return fmt.Errorf("piece %d: wrote %d of %d bytes (offset mapping mismatch)", pieceIndex, dataOffset, bytesToWrite) + } + return nil } diff --git a/internal/torrent/magnet.go b/internal/torrent/magnet.go index f1cbccc..3ac1569 100644 --- a/internal/torrent/magnet.go +++ b/internal/torrent/magnet.go @@ -40,10 +40,21 @@ func ParseMagnet(uri string) (Magnet, error) { for _, xt := range params["xt"] { switch { case strings.HasPrefix(xt, "urn:btih:"): - m.InfoHashV1 = strings.ToLower(xt[len("urn:btih:"):]) + v1 := strings.ToLower(xt[len("urn:btih:"):]) + // This implementation is hex-only (no base32): a v1 info hash is + // the 40-char hex of a 20-byte SHA-1. + if !isHex(v1, 40) { + return Magnet{}, fmt.Errorf("invalid v1 info hash in magnet: %q", v1) + } + m.InfoHashV1 = v1 case strings.HasPrefix(xt, "urn:btmh:1220"): - // Multihash: 0x12 = SHA-256, 0x20 = 32 bytes - m.InfoHashV2 = strings.ToLower(xt[len("urn:btmh:1220"):]) + // Multihash prefix 1220: 0x12 = SHA-256, 0x20 = 32 bytes, so the + // remainder must be the 64-char hex of a 32-byte digest. + v2 := strings.ToLower(xt[len("urn:btmh:1220"):]) + if !isHex(v2, 64) { + return Magnet{}, fmt.Errorf("invalid v2 info hash in magnet: %q", v2) + } + m.InfoHashV2 = v2 } } @@ -56,3 +67,16 @@ func ParseMagnet(uri string) (Magnet, error) { return m, nil } + +// isHex reports whether s is exactly n lowercase hex characters. +func isHex(s string, n int) bool { + if len(s) != n { + return false + } + for _, c := range s { + if (c < '0' || c > '9') && (c < 'a' || c > 'f') { + return false + } + } + return true +} diff --git a/internal/torrent/magnet_test.go b/internal/torrent/magnet_test.go index c4b7727..0998dff 100644 --- a/internal/torrent/magnet_test.go +++ b/internal/torrent/magnet_test.go @@ -4,18 +4,25 @@ import ( "testing" ) +// Valid-length hex fixtures: v1 is 40 hex chars (20-byte SHA-1), v2 is 64 (32-byte SHA-256). +const ( + v1Hex = "0123456789abcdef0123456789abcdef01234567" + v1HexUpper = "0123456789ABCDEF0123456789ABCDEF01234567" + v2Hex = "fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210" +) + func TestParseMagnetHybrid(t *testing.T) { t.Parallel() - uri := "magnet:?xt=urn:btih:abc123def456&xt=urn:btmh:1220fedcba9876543210&dn=TestFile&tr=http://tracker:8080/announce" + uri := "magnet:?xt=urn:btih:" + v1Hex + "&xt=urn:btmh:1220" + v2Hex + "&dn=TestFile&tr=http://tracker:8080/announce" m, err := ParseMagnet(uri) if err != nil { t.Fatal(err) } - if m.InfoHashV1 != "abc123def456" { - t.Errorf("v1 hash = %q, want abc123def456", m.InfoHashV1) + if m.InfoHashV1 != v1Hex { + t.Errorf("v1 hash = %q, want %s", m.InfoHashV1, v1Hex) } - if m.InfoHashV2 != "fedcba9876543210" { - t.Errorf("v2 hash = %q, want fedcba9876543210", m.InfoHashV2) + if m.InfoHashV2 != v2Hex { + t.Errorf("v2 hash = %q, want %s", m.InfoHashV2, v2Hex) } if m.DisplayName != "TestFile" { t.Errorf("display name = %q, want TestFile", m.DisplayName) @@ -27,25 +34,25 @@ func TestParseMagnetHybrid(t *testing.T) { func TestParseMagnetV1Only(t *testing.T) { t.Parallel() - uri := "magnet:?xt=urn:btih:AABBCCDD&dn=V1Only" + uri := "magnet:?xt=urn:btih:" + v1Hex + "&dn=V1Only" m, err := ParseMagnet(uri) if err != nil { t.Fatal(err) } - if m.InfoHashV1 != "aabbccdd" { - t.Errorf("v1 hash = %q, want aabbccdd", m.InfoHashV1) + if m.InfoHashV1 != v1Hex { + t.Errorf("v1 hash = %q, want %s", m.InfoHashV1, v1Hex) } if m.InfoHashV2 != "" { t.Errorf("v2 hash should be empty, got %q", m.InfoHashV2) } - if m.BestHash() != "aabbccdd" { - t.Errorf("BestHash = %q, want aabbccdd", m.BestHash()) + if m.BestHash() != v1Hex { + t.Errorf("BestHash = %q, want %s", m.BestHash(), v1Hex) } } func TestParseMagnetV2Only(t *testing.T) { t.Parallel() - uri := "magnet:?xt=urn:btmh:1220abcdef1234567890" + uri := "magnet:?xt=urn:btmh:1220" + v2Hex m, err := ParseMagnet(uri) if err != nil { t.Fatal(err) @@ -53,29 +60,29 @@ func TestParseMagnetV2Only(t *testing.T) { if m.InfoHashV1 != "" { t.Errorf("v1 hash should be empty, got %q", m.InfoHashV1) } - if m.InfoHashV2 != "abcdef1234567890" { - t.Errorf("v2 hash = %q, want abcdef1234567890", m.InfoHashV2) + if m.InfoHashV2 != v2Hex { + t.Errorf("v2 hash = %q, want %s", m.InfoHashV2, v2Hex) } - if m.BestHash() != "abcdef1234567890" { + if m.BestHash() != v2Hex { t.Errorf("BestHash should prefer v2") } } func TestParseMagnetBestHashPrefersV2(t *testing.T) { t.Parallel() - uri := "magnet:?xt=urn:btih:v1hash&xt=urn:btmh:1220v2hash" + uri := "magnet:?xt=urn:btih:" + v1Hex + "&xt=urn:btmh:1220" + v2Hex m, err := ParseMagnet(uri) if err != nil { t.Fatal(err) } - if m.BestHash() != "v2hash" { + if m.BestHash() != v2Hex { t.Errorf("BestHash should prefer v2, got %q", m.BestHash()) } } func TestParseMagnetMultipleTrackers(t *testing.T) { t.Parallel() - uri := "magnet:?xt=urn:btih:abc&tr=http://one/announce&tr=http://two/announce" + uri := "magnet:?xt=urn:btih:" + v1Hex + "&tr=http://one/announce&tr=http://two/announce" m, err := ParseMagnet(uri) if err != nil { t.Fatal(err) @@ -103,7 +110,7 @@ func TestParseMagnetNotMagnet(t *testing.T) { func TestParseMagnetNoDisplayName(t *testing.T) { t.Parallel() - uri := "magnet:?xt=urn:btih:abc123" + uri := "magnet:?xt=urn:btih:" + v1Hex m, err := ParseMagnet(uri) if err != nil { t.Fatal(err) @@ -116,12 +123,33 @@ func TestParseMagnetNoDisplayName(t *testing.T) { func TestParseMagnetCaseInsensitive(t *testing.T) { t.Parallel() // Hashes should be lowercased - uri := "magnet:?xt=urn:btih:AABBCCDD" + uri := "magnet:?xt=urn:btih:" + v1HexUpper m, err := ParseMagnet(uri) if err != nil { t.Fatal(err) } - if m.InfoHashV1 != "aabbccdd" { + if m.InfoHashV1 != v1Hex { t.Errorf("expected lowercase hash, got %q", m.InfoHashV1) } } + +func TestParseMagnetRejectsMalformedHash(t *testing.T) { + t.Parallel() + tests := []struct { + name string + uri string + }{ + {"v1 too short", "magnet:?xt=urn:btih:abc123"}, + {"v1 non-hex", "magnet:?xt=urn:btih:zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz"}, + {"v2 too short", "magnet:?xt=urn:btmh:1220abcdef"}, + {"v2 non-hex", "magnet:?xt=urn:btmh:1220" + "g123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + if _, err := ParseMagnet(tt.uri); err == nil { + t.Errorf("ParseMagnet(%q) = nil error, want rejection", tt.uri) + } + }) + } +} diff --git a/internal/tracker/registry.go b/internal/tracker/registry.go index 6ae7d8d..4eb67b5 100644 --- a/internal/tracker/registry.go +++ b/internal/tracker/registry.go @@ -11,6 +11,7 @@ import ( "os" "sort" "strconv" + "strings" "time" "weightless/internal/torrent" @@ -212,10 +213,23 @@ func HandleTorrentDownload(w http.ResponseWriter, r *http.Request) { } w.Header().Set("Content-Type", "application/x-bittorrent") - w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s.torrent"`, name)) + // Sanitize the DB-sourced name before interpolating into the header so a + // quote or control char can't break out of the quoted-string / inject. + w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s.torrent"`, sanitizeFilename(name))) w.Write(data) } +// sanitizeFilename drops characters that would break out of a quoted-string +// HTTP header value (quotes, backslashes, control chars including CR/LF). +func sanitizeFilename(name string) string { + return strings.Map(func(r rune) rune { + if r < 0x20 || r == '"' || r == '\\' || r == 0x7f { + return -1 + } + return r + }, name) +} + func HandleSearch(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") @@ -264,14 +278,18 @@ func HandleSearch(w http.ResponseWriter, r *http.Request) { } } - // Sorting — seeders is memory-only, so DB sorts by created_at and we re-sort after + // Sorting — seeder counts are memory-only (not in SQLite), so they can't be + // sorted or paginated in SQL. For that path we scan the matching set (capped) + // and sort + paginate in memory below. sortCol := "created_at" + sortBySeeders := false switch q.Get("sort") { case "completions": sortCol = "completions" case "seeders": - sortCol = "created_at" // DB fallback; post-query sort below + sortBySeeders = true } + const seedersSortScanCap = 1000 // Get total count var total int @@ -284,7 +302,11 @@ func HandleSearch(w http.ResponseWriter, r *http.Request) { // Main query query := "SELECT " + registryCols + " FROM registry" + where - query += fmt.Sprintf(" ORDER BY %s DESC LIMIT %d OFFSET %d", sortCol, limit, offset) + if sortBySeeders { + query += fmt.Sprintf(" ORDER BY %s DESC LIMIT %d", sortCol, seedersSortScanCap) + } else { + query += fmt.Sprintf(" ORDER BY %s DESC LIMIT %d OFFSET %d", sortCol, limit, offset) + } rows, err := DB.Query(query, args...) if err != nil { @@ -309,11 +331,23 @@ func HandleSearch(w http.ResponseWriter, r *http.Request) { fillSwarmStats(results) - // Seeders are memory-only (not in SQLite), so sort post-query - if q.Get("sort") == "seeders" { + // Seeders are memory-only (not in SQLite), so sort then paginate in memory. + if sortBySeeders { + if total > seedersSortScanCap { + log.Printf("seeders sort scanned %d of %d matching rows; ranking may be incomplete", seedersSortScanCap, total) + } sort.Slice(results, func(i, j int) bool { return results[i].Seeders > results[j].Seeders }) + start := offset + if start > len(results) { + start = len(results) + } + end := start + limit + if end > len(results) { + end = len(results) + } + results = results[start:end] } if err := json.NewEncoder(w).Encode(results); err != nil { diff --git a/internal/tracker/state.go b/internal/tracker/state.go index 4ce1356..3d398dc 100644 --- a/internal/tracker/state.go +++ b/internal/tracker/state.go @@ -177,10 +177,11 @@ func (s *SwarmState) GetPeers(hash, excludeID string, limit int) []string { if id == excludeID { continue } - addrs = append(addrs, p.Addr) + // Check before appending so a limit of 0 (numwant=0) returns no peers. if len(addrs) >= limit { break } + addrs = append(addrs, p.Addr) } return addrs } diff --git a/internal/tracker/state_test.go b/internal/tracker/state_test.go index 91454ff..450b658 100644 --- a/internal/tracker/state_test.go +++ b/internal/tracker/state_test.go @@ -4,6 +4,7 @@ import ( "net/http" "net/http/httptest" "os" + "strconv" "strings" "testing" "time" @@ -139,6 +140,32 @@ func TestStatePruneMemory(t *testing.T) { } } +func TestStateGetPeersLimit(t *testing.T) { + State.mu.Lock() + State.Peers = make(map[string]map[string]*Peer) + State.mu.Unlock() + + hash := "limithash" + now := time.Now().Unix() + for i := 0; i < 5; i++ { + id := "p" + strconv.Itoa(i) + State.UpdatePeer(hash, id, &Peer{Addr: id + ":1", UpdatedAt: now}) + } + + // limit 0 (numwant=0) must return no peers, not one. + if got := State.GetPeers(hash, "none", 0); len(got) != 0 { + t.Errorf("limit 0: expected 0 peers, got %d (%v)", len(got), got) + } + // A positive limit caps the result. + if got := State.GetPeers(hash, "none", 3); len(got) != 3 { + t.Errorf("limit 3: expected 3 peers, got %d", len(got)) + } + // A limit above the swarm size returns all peers. + if got := State.GetPeers(hash, "none", 10); len(got) != 5 { + t.Errorf("limit 10: expected 5 peers, got %d", len(got)) + } +} + func TestMetricsHandler(t *testing.T) { req := httptest.NewRequest("GET", "/metrics", nil) w := httptest.NewRecorder()