diff --git a/README.md b/README.md index 4e945bb..febb946 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ [![License](https://img.shields.io/badge/License-UNLICENSE-blue.svg)](https://raw.githubusercontent.com/7i/shorter/master/UNLICENSE) [![License](https://img.shields.io/badge/License-0BSD-blue.svg)](https://raw.githubusercontent.com/7i/shorter/master/LICENSE) # shorter -URL shortener with pastebin and file upload functions +URL shortener with pastebin and QR code support ## WIP @@ -27,55 +27,156 @@ go get github.com/7i/shorter shorter /path/to/config ``` -## Examples -A deployed version of shorter is accessable at [7i.se](http://7i.se) +## Features + +### URL shortening + +Create a short link via the web UI or the quick-add GET syntax: -create a temporary link to "https://www.example.com" via a GET request that is as short as possible: ```bash +# Quick-add: shortest available key 7i.se?https://www.example.com -or -7i.se/?https://www.example.com + +# Quick-add: custom key +7i.se/mykey?https://www.example.com ``` -create a temporary link to "https://www.example.com" via a GET request using the key "KeyToExample": -```bash -7i.se/KeyToExample?https://www.example.com -or -7i.se/KeyToExample/?https://www.example.com + +Four key buckets, each with a configurable timeout: + +| Length | Example | Default timeout | +|--------|---------|-----------------| +| 1 char | `7i.se/a` | 24 h | +| 2 chars | `7i.se/ab` | 7 d | +| 3 chars | `7i.se/abc` | 60 d | +| Custom (4–64 chars) | `7i.se/mykey` | 30 d | + +Append `~` to any key to preview where it points without consuming an access (`7i.se/a~`). + +### Pastebin + +Submit a text blob via the web UI (`requestType=text`). Large blobs are transparently gzip-compressed before storage. + +### QR codes + +``` +GET /qr/{key} +``` + +Returns a 256×256 PNG QR code encoding the full short URL for the given key. + +### JSON API + +#### Shorten a URL + +``` +POST /api/v1/shorten +Content-Type: application/json + +{ + "url": "https://www.example.com", + "len": "1", // "1", "2", "3", or "custom" + "key": "mykey", // optional, required when len=custom + "x_times": 5 // optional: delete after N accesses (0 = unlimited) +} +``` + +Response `201 Created`: +```json +{ + "key": "a", + "short_url": "https://7i.se/a", + "expires": "Mon 2025-01-01 12:00 UTC" +} +``` + +#### Look up a key (no access consumed) + +``` +GET /api/v1/lookup/{key} ``` +Response `200 OK`: +```json +{ + "key": "a", + "link_type": "url", + "url": "https://www.example.com", + "expires": "Mon 2025-01-01 12:00 UTC", + "access_count": 42, + "times_remaining": -1 +} +``` + +`times_remaining: -1` means unlimited accesses. + +### Persistence + +Links are stored in a [bbolt](https://github.com/etcd-io/bbolt) embedded database (`shorterdata/shorter.db`). They survive server restarts and are pruned automatically on startup when expired. + +### Admin endpoint + +``` +GET /listactive~ +``` + +HTTP Basic Auth required. Password is verified as `sha256(password + Salt) == HashSHA256` using constant-time comparison. + +### Rate limiting + +POST requests are rate-limited to **30 per minute per IP** (sliding window). Exceeding the limit returns `429 Too Many Requests`. + +### Blocklist + +Domains can be blocked via: +- `BlockedDomains` list in the config file +- An optional newline-delimited `BlocklistFile` (lines starting with `#` are comments) + +Blocked URLs return `403 Forbidden`. + +## Security + +- TLS 1.2+ with AEAD-only cipher suites (AES-GCM, ChaCha20-Poly1305), X25519 curve preferred +- `X-Content-Type-Options: nosniff` and `Referrer-Policy: no-referrer` on all responses +- Optional `Strict-Transport-Security`, `Content-Security-Policy`, and `Report-To` headers +- Decompression bomb protection: 20 MiB hard cap on gzip decompression +- All URL inputs validated; only `http://` and `https://` schemes accepted +- Concurrent-safe link storage with no data races (verified with `-race`) + ## TODO - [x] Implement shortening of URLs - - [x] 1 char long - configurabe timeout - - [x] 2 chars long - configurabe timeout - - [x] 3 chars long - configurabe timeout + - [x] 1 char long - configurable timeout + - [x] 2 chars long - configurable timeout + - [x] 3 chars long - configurable timeout - [x] make timeouts configurable - [x] temporary word bindings (7i.se/coolthing) - - [x] quick add link via get request with syntax 7i.se?https://example.com - - [x] quick add word bindings link via get request with syntax 7i.se/coolthing?https://example.com where coolthing is the key - - [ ] optional removal of link after N accesses -- [x] Add functionality to print where a link is pointing by adding ~ at the end of the link e.g. 7i.se/a~ will display where 7i.se/a is pointing to + - [x] quick add link via GET request with syntax 7i.se?https://example.com + - [x] quick add word bindings link via GET request with syntax 7i.se/coolthing?https://example.com + - [x] optional removal of link after N accesses (x_times) +- [x] Add functionality to print where a link is pointing by adding ~ at the end of the link - [x] Add config file that specifies relevant options - [x] Pastebin functionality with same timeouts as above -- [x] Move to ssl with Let's Encrypt -- [ ] Save all active links in a database file instead of gob files -- [ ] Add support for subdomains with diffrent configs e.g. d1.7i.se - - [ ] Add password/client cert protected subdomain management e.g. d1.7i.se/admin - - [ ] Let the user managing a subdomain specify generic links and set timeouts, including "no timeout" for the shortened links, text-blobs and files. +- [x] Move to SSL with Let's Encrypt +- [x] Save all active links in a database file (bbolt) +- [x] JSON REST API for programmatic shortening and lookup +- [x] QR code generation per short link +- [x] Click analytics (access count per link) +- [x] URL deduplication (repeated submissions return existing key) +- [x] Per-IP rate limiting +- [x] Blocklist support (config + file) - [x] Enable CSP - - [x] Move all js and css to seperate files and modify html/template files to use these + - [x] Move all js and css to separate files and modify html/template files to use these - [ ] Setup a CSP report collector -- [ ] Use blocklists for known malware sites, integrate with: +- [ ] Add support for subdomains with different configs e.g. d1.7i.se + - [ ] Add password/client cert protected subdomain management + - [ ] Let the user managing a subdomain specify generic links and set timeouts +- [ ] Integrate with external malware/blocklist feeds: - [ ] https://www.stopbadware.org/firefox - [ ] https://www.malwaredomainlist.com - [ ] https://isc.sans.edu/suspicious_domains.html - - [ ] https://zeltser.com/malicious-ip-blocklists/ - - [ ] if linking to a page that redirects, follow redirects only for 5 levels and display error if redirected more times -- [ ] Include report form to take down links that breaks terms of usage - - [ ] implement capcha for submitting reports to take down links -- [x] Create Terms of usage +- [ ] Include report form to take down links that break terms of use +- [x] Create Terms of use ## License The `shorter` project is dual-licensed to the [public domain](UNLICENSE) and under a [zero-clause BSD license](LICENSE). You may choose either license to govern your use of `shorter`. - diff --git a/api.go b/api.go new file mode 100644 index 0000000..6511f39 --- /dev/null +++ b/api.go @@ -0,0 +1,178 @@ +package main + +import ( + "encoding/json" + "net/http" + "strings" + "time" +) + +// API request / response types + +type apiShortenRequest struct { + URL string `json:"url"` + Key string `json:"key"` // optional custom key (4-64 chars) + Len string `json:"len"` // "1", "2", "3", or "custom" + XTimes int `json:"x_times"` // max accesses; omit or 0 for unlimited +} + +type apiShortenResponse struct { + Key string `json:"key"` + ShortURL string `json:"short_url"` + Expires string `json:"expires"` +} + +type apiLookupResponse struct { + Key string `json:"key"` + LinkType string `json:"link_type"` + URL string `json:"url,omitempty"` + Expires string `json:"expires"` + AccessCount int64 `json:"access_count"` + TimesRemaining int `json:"times_remaining"` // -1 = unlimited +} + +func handleAPI(mux *http.ServeMux) { + mux.HandleFunc("/api/v1/shorten", apiShorten) + mux.HandleFunc("/api/v1/lookup/", apiLookup) +} + +func writeJSON(w http.ResponseWriter, status int, v interface{}) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(status) + json.NewEncoder(w).Encode(v) +} + +// apiShorten handles POST /api/v1/shorten +// Body: {"url":"https://...","len":"1","x_times":0} +// Returns: {"key":"ab","short_url":"https://host/ab","expires":"..."} +func apiShorten(w http.ResponseWriter, r *http.Request) { + if !validRequest(r) { + writeJSON(w, http.StatusForbidden, map[string]string{"error": "forbidden"}) + return + } + if r.Method != http.MethodPost { + writeJSON(w, http.StatusMethodNotAllowed, map[string]string{"error": "method not allowed"}) + return + } + if !rateLimitAllow(r.RemoteAddr) { + writeJSON(w, http.StatusTooManyRequests, map[string]string{"error": "rate limit exceeded"}) + return + } + + var req apiShortenRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid JSON body"}) + return + } + + if !validURL(req.URL) { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": "invalid URL; only http and https are allowed"}) + return + } + if isBlocklisted(req.URL) { + writeJSON(w, http.StatusForbidden, map[string]string{"error": "URL is not allowed"}) + return + } + + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + + // Dedup: return existing unlimited short link for the same URL if no custom key requested. + if req.Key == "" && req.XTimes <= 0 { + if existing := findExistingURL(r.Host, req.URL); existing != nil { + writeJSON(w, http.StatusOK, apiShortenResponse{ + Key: existing.Key, + ShortURL: scheme + "://" + r.Host + "/" + existing.Key, + Expires: existing.Timeout.Format(dateFormat), + }) + return + } + } + + // Choose the target LinkLen bucket. + var ll *LinkLen + switch req.Len { + case "2": + ll = &domainLinkLens[r.Host].LinkLen2 + case "3": + ll = &domainLinkLens[r.Host].LinkLen3 + case "custom": + if !validate(req.Key) || len(req.Key) < 4 || len(req.Key) > maxKeyLen { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": errInvalidCustomKey}) + return + } + ll = &domainLinkLens[r.Host].LinkCustom + default: + ll = &domainLinkLens[r.Host].LinkLen1 + } + + xTimes := req.XTimes + if xTimes < 1 { + xTimes = -1 + } else if xTimes > config.LinkAccessMaxNr { + xTimes = config.LinkAccessMaxNr + } + + ll.Mutex.RLock() + timeout := ll.Timeout + ll.Mutex.RUnlock() + + lnk := &Link{ + Key: req.Key, + LinkType: "url", + Data: req.URL, + Times: xTimes, + Timeout: time.Now().Add(timeout), + } + key, err := ll.Add(lnk) + if err != nil { + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + + writeJSON(w, http.StatusCreated, apiShortenResponse{ + Key: key, + ShortURL: scheme + "://" + r.Host + "/" + key, + Expires: lnk.Timeout.Format(dateFormat), + }) +} + +// apiLookup handles GET /api/v1/lookup/{key} +// Returns metadata for the key without consuming an access. +func apiLookup(w http.ResponseWriter, r *http.Request) { + if !validRequest(r) { + writeJSON(w, http.StatusForbidden, map[string]string{"error": "forbidden"}) + return + } + if r.Method != http.MethodGet { + writeJSON(w, http.StatusMethodNotAllowed, map[string]string{"error": "method not allowed"}) + return + } + + key := strings.TrimPrefix(r.URL.Path, "/api/v1/lookup/") + key = strings.TrimSuffix(key, "/") + if !validate(key) || len(key) == 0 { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": errInvalidKey}) + return + } + + lnk, _ := lookupLink(r.Host, key) + if lnk == nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "key not found"}) + return + } + + resp := apiLookupResponse{ + Key: key, + LinkType: lnk.LinkType, + Expires: lnk.Timeout.Format(dateFormat), + AccessCount: lnk.AccessCount, + TimesRemaining: lnk.Times, + } + if lnk.LinkType == "url" { + resp.URL = lnk.Data + } + writeJSON(w, http.StatusOK, resp) +} diff --git a/db.go b/db.go index b8989f2..af341ba 100644 --- a/db.go +++ b/db.go @@ -3,286 +3,208 @@ package main import ( "bytes" "encoding/gob" - "io/ioutil" + "encoding/json" "os" "path/filepath" + "sort" "time" + + bbolt "go.etcd.io/bbolt" ) -// Fugly solution, TODO switch to real DB like bolt -func setupDB() { +var linkLenTypes = []string{"len1", "len2", "len3", "custom"} +// setupDB opens (or creates) the bbolt database, prunes expired entries, and +// restores surviving links into memory. Falls back to legacy gob files if the +// database cannot be opened. +func setupDB() { if logger != nil { - logger.Println("Reading in links and data from db") - } - for _, domain := range config.DomainNames { - restoreLinkLen(&domainLinkLens[domain].LinkLen1, "len1", domain) - restoreLinkLen(&domainLinkLens[domain].LinkLen2, "len2", domain) - restoreLinkLen(&domainLinkLens[domain].LinkLen3, "len3", domain) - restoreLinkLen(&domainLinkLens[domain].LinkCustom, "custom", domain) + logger.Println("Opening bbolt database") } -} -func restoreLinkLen(l *LinkLen, typ, domain string) { - var backupLinkLen []Link - fileName := "backupdb-" + domain + "-" + typ + ".gob" + dbPath := filepath.Join(config.BaseDir, "shorter.db") + var err error + boltDB, err = bbolt.Open(dbPath, 0600, &bbolt.Options{Timeout: 2 * time.Second}) + if err != nil { + if logger != nil { + logger.Println("Failed to open bbolt DB, falling back to gob backup:", err) + } + for _, domain := range config.DomainNames { + restoreGob(&domainLinkLens[domain].LinkLen1, "len1", domain) + restoreGob(&domainLinkLens[domain].LinkLen2, "len2", domain) + restoreGob(&domainLinkLens[domain].LinkLen3, "len3", domain) + restoreGob(&domainLinkLens[domain].LinkCustom, "custom", domain) + } + return + } - d, err := ioutil.ReadFile(filepath.Join(config.BaseDir, domain, fileName)) - if err != nil && logger != nil { - logger.Println(err, "ReadFile - Skipping "+fileName) - } else { - buf := bytes.NewBuffer(d) - dec := gob.NewDecoder(buf) - err := dec.Decode(&backupLinkLen) - if err != nil && logger != nil { - logger.Println(err, "Unmarshal - Skipping"+fileName) - } else { - if len(backupLinkLen) > 0 && backupLinkLen[0].Key != "" { - l.NextClear = &backupLinkLen[0] - l.EndClear = &backupLinkLen[len(backupLinkLen)-1] - l.Links = len(backupLinkLen) - l.LinkMap[backupLinkLen[0].Key] = &backupLinkLen[0] - delete(l.FreeMap, backupLinkLen[0].Key) + // Ensure all domain/type buckets exist. + if err = boltDB.Update(func(tx *bbolt.Tx) error { + for _, domain := range config.DomainNames { + b, err := tx.CreateBucketIfNotExists([]byte(domain)) + if err != nil { + return err } - for i := 1; i < len(backupLinkLen); i++ { - if backupLinkLen[i].Key != "" { - l.LinkMap[backupLinkLen[i].Key] = &backupLinkLen[i] - backupLinkLen[i-1].NextClear = &backupLinkLen[i] - delete(l.FreeMap, backupLinkLen[i].Key) + for _, t := range linkLenTypes { + if _, err = b.CreateBucketIfNotExists([]byte(t)); err != nil { + return err } } } + return nil + }); err != nil && logger != nil { + logger.Println("Failed to create bbolt buckets:", err) } -} - -// part 2 of the fugly solution -func BackupRoutine() { - - for { - time.Sleep(time.Minute * 30) + // Restore links, pruning any that have already expired. + now := time.Now() + boltDB.Update(func(tx *bbolt.Tx) error { for _, domain := range config.DomainNames { - saveBackup(&domainLinkLens[domain].LinkLen1, "len1", domain) - saveBackup(&domainLinkLens[domain].LinkLen2, "len2", domain) - saveBackup(&domainLinkLens[domain].LinkLen3, "len3", domain) - saveBackup(&domainLinkLens[domain].LinkCustom, "custom", domain) + b := tx.Bucket([]byte(domain)) + if b == nil { + return nil + } + for _, typ := range linkLenTypes { + tb := b.Bucket([]byte(typ)) + if tb == nil { + continue + } + ll := llForType(domain, typ) + + var active []Link + var toDelete [][]byte + + tb.ForEach(func(k, v []byte) error { + var lnk Link + if json.Unmarshal(v, &lnk) != nil { + toDelete = append(toDelete, append([]byte(nil), k...)) + return nil + } + if !lnk.Timeout.After(now) { + toDelete = append(toDelete, append([]byte(nil), k...)) + return nil + } + active = append(active, lnk) + return nil + }) + for _, k := range toDelete { + tb.Delete(k) + } + + // Sort ascending by Timeout so NextClear linked list stays ordered. + sort.Slice(active, func(i, j int) bool { + return active[i].Timeout.Before(active[j].Timeout) + }) + for i := range active { + restoreLinkToMemory(ll, &active[i]) + } + } } + return nil + }) - logger.Println("Finished saving new backup") + if logger != nil { + logger.Println("bbolt restore complete") } } -func saveBackup(l *LinkLen, typ, domain string) { - var err error - var backupLinkLen []Link - filename := "backupdb-" + domain + "-" + typ + ".gob" +func llForType(domain, typ string) *LinkLen { + switch typ { + case "len1": + return &domainLinkLens[domain].LinkLen1 + case "len2": + return &domainLinkLens[domain].LinkLen2 + case "len3": + return &domainLinkLens[domain].LinkLen3 + default: + return &domainLinkLens[domain].LinkCustom + } +} - if l == nil { - logger.Println("*LinkLen is nil, skipping ", filename) - return +// restoreLinkToMemory inserts lnk directly into ll's in-memory structures +// without triggering a DB write. Links must be inserted in ascending Timeout order. +func restoreLinkToMemory(l *LinkLen, lnk *Link) { + lnk.NextClear = nil + l.LinkMap[lnk.Key] = lnk + if l.FreeMap != nil { + delete(l.FreeMap, lnk.Key) + } else { + l.Links++ } if l.NextClear == nil { - logger.Println("l.NextClear is nil, skipping ", filename) - return + l.NextClear = lnk + l.EndClear = lnk + } else { + l.EndClear.NextClear = lnk + l.EndClear = lnk } +} - l.Mutex.Lock() - - next := *l.NextClear - - stop := false - for !stop { - backupLinkLen = append(backupLinkLen, next) - if next.NextClear != nil { - next = *next.NextClear - } else { - stop = true - } +// saveLinkToDB persists lnk to bbolt. Called after a successful Add(), outside +// the LinkLen mutex so it doesn't block readers. +func saveLinkToDB(domain, typ string, lnk *Link) { + if boltDB == nil { + return } - l.Mutex.Unlock() - - var backupBuffer bytes.Buffer - enc := gob.NewEncoder(&backupBuffer) - err = enc.Encode(backupLinkLen) + data, err := json.Marshal(lnk) if err != nil { - logger.Println(err, "Error while saving backup in enc.Encode()") - } - - backupLinkLen = nil - - if err = os.WriteFile(filepath.Join(config.BaseDir, domain, filename), backupBuffer.Bytes(), 0644); err != nil && logger != nil { - logger.Println(err, "failed to save DB") + if logger != nil { + logger.Println("saveLinkToDB marshal error:", err) + } + return } - - logger.Println("Backed up:", filename) -} - -// New BoltDB restore -//startRestoreDB(&domainLinkLens[domain].LinkLen1, domain, "linkLen1") -//startRestoreDB(&domainLinkLens[domain].LinkLen2, domain, "linkLen2") -//startRestoreDB(&domainLinkLens[domain].LinkLen3, domain, "linkLen3") -//startRestoreDB(&domainLinkLens[domain].LinkCustom, domain, "linkCustom") -// -// -// -// -// -// -// -// -// -// -// -// -// -// -// -// -// new bolt implementation of backup ============ -// Config2 type - -/* -type Config2 struct { - Height float64 `json:"height"` - Birthday time.Time `json:"birthday"` -} - -// Entry type -type Entry struct { - Calories int `json:"calories"` - Food string `json:"food"` -} - -func startRestoreDB(l *LinkLen, domain, linkLen string) { - db := setupDB2(domain) - defer db.Close() - restoreDBLinkLen(db, domain) -} - -// restoreDBLinkLen will read out all links for all linkLen from the bolt.DB for the specified domain and populate the domainLinkLens map -func restoreDBLinkLen(db *bolt.DB, domain string) { - err := db.View(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte(domain)).Bucket([]byte("linkLen1")) - b.ForEach(func(k, v []byte) error { - // TODO Restore entries from DB HERE - - var lnk Link - err := json.Unmarshal(v, &lnk) - if err != nil { - logger.Fatalln("Unable to restore link,", err) - } - l1 := &domainLinkLens[domain].LinkLen1 - - l1.Add(&lnk) - - - - if len(backupLinkLen) > 0 && backupLinkLen[0].Key != "" { - l.NextClear = &backupLinkLen[0] - l.EndClear = &backupLinkLen[len(backupLinkLen)-1] - l.Links = len(backupLinkLen) - l.LinkMap[backupLinkLen[0].Key] = &backupLinkLen[0] - delete(l.FreeMap, backupLinkLen[0].Key) - } - for i := 1; i < len(backupLinkLen); i++ { - if backupLinkLen[i].Key != "" { - l.LinkMap[backupLinkLen[i].Key] = &backupLinkLen[i] - backupLinkLen[i-1].NextClear = &backupLinkLen[i] - delete(l.FreeMap, backupLinkLen[i].Key) - } - } - - fmt.Println(string(k), string(v)) + if err = boltDB.Update(func(tx *bbolt.Tx) error { + b := tx.Bucket([]byte(domain)) + if b == nil { return nil - }) - return nil - }) - if err != nil { - log.Fatal(err) + } + return b.Bucket([]byte(typ)).Put([]byte(lnk.Key), data) + }); err != nil && logger != nil { + logger.Println("saveLinkToDB write error:", err) } } -func setupDB2(domain string) *bolt.DB { - - db, err := bolt.Open(filepath.Join(config.BaseDir, domain, domain+".db"), 0600, nil) - - if err != nil && logger != nil { - logger.Fatalln("could not open db,", err) +// deleteLinkFromDB removes a key from bbolt. Called by TimeoutManager when a +// link expires and by recordAccess when Times reaches zero. +func deleteLinkFromDB(domain, typ, key string) { + if boltDB == nil { + return } - err = db.Update(func(tx *bolt.Tx) error { - root, err := tx.CreateBucketIfNotExists([]byte(domain)) - if err != nil { - logger.Fatalln("could not create root bucket:", err) - } - _, err = root.CreateBucketIfNotExists([]byte("linkLen1")) - if err != nil { - logger.Fatalln("could not create linkLen1 bucket:", err) - } - _, err = root.CreateBucketIfNotExists([]byte("linkLen2")) - if err != nil { - logger.Fatalln("could not create linkLen2 bucket:", err) - } - _, err = root.CreateBucketIfNotExists([]byte("linkLen3")) - if err != nil { - logger.Fatalln("could not create linkLen3 bucket:", err) - } - _, err = root.CreateBucketIfNotExists([]byte("linkCustom")) - if err != nil { - logger.Fatalln("could not create linkCustom bucket:", err) + if err := boltDB.Update(func(tx *bbolt.Tx) error { + b := tx.Bucket([]byte(domain)) + if b == nil { + return nil } - return nil - }) - if err != nil { - logger.Fatalln("could not set up buckets,", err) + return b.Bucket([]byte(typ)).Delete([]byte(key)) + }); err != nil && logger != nil { + logger.Println("deleteLinkFromDB error:", err) } - logger.Println("") - fmt.Println("DB Setup for", domain, "Done") - return db } -func setConfig2(db *bolt.DB, Config2 Config2) error { - confBytes, err := json.Marshal(Config2) +// restoreGob is the legacy fallback: restores links from a gob backup file. +func restoreGob(l *LinkLen, typ, domain string) { + fileName := "backupdb-" + domain + "-" + typ + ".gob" + d, err := os.ReadFile(filepath.Join(config.BaseDir, domain, fileName)) if err != nil { - return fmt.Errorf("could not marshal Config2 json: %v", err) + if logger != nil { + logger.Println(err, "restoreGob - skipping "+fileName) + } + return } - err = db.Update(func(tx *bolt.Tx) error { - err = tx.Bucket([]byte("DB")).Put([]byte("CONFIG"), confBytes) - if err != nil { - return fmt.Errorf("could not set Config2: %v", err) + var links []Link + if err = gob.NewDecoder(bytes.NewBuffer(d)).Decode(&links); err != nil { + if logger != nil { + logger.Println(err, "restoreGob - decode error "+fileName) } - return nil + return + } + now := time.Now() + sort.Slice(links, func(i, j int) bool { + return links[i].Timeout.Before(links[j].Timeout) }) - fmt.Println("Set Config2") - return err -} - -func addWeight(db *bolt.DB, weight string, date time.Time) error { - err := db.Update(func(tx *bolt.Tx) error { - err := tx.Bucket([]byte("DB")).Bucket([]byte("WEIGHT")).Put([]byte(date.Format(time.RFC3339)), []byte(weight)) - if err != nil { - return fmt.Errorf("could not insert weight: %v", err) + for i := range links { + if links[i].Key != "" && links[i].Timeout.After(now) { + restoreLinkToMemory(l, &links[i]) } - return nil - }) - fmt.Println("Added Weight") - return err -} - -func addEntry(db *bolt.DB, calories int, food string, date time.Time) error { - entry := Entry{Calories: calories, Food: food} - entryBytes, err := json.Marshal(entry) - if err != nil { - return fmt.Errorf("could not marshal entry json: %v", err) } - err = db.Update(func(tx *bolt.Tx) error { - err := tx.Bucket([]byte("DB")).Bucket([]byte("ENTRIES")).Put([]byte(date.Format(time.RFC3339)), entryBytes) - if err != nil { - return fmt.Errorf("could not insert entry: %v", err) - } - - return nil - }) - fmt.Println("Added Entry") - return err } -*/ diff --git a/defines.go b/defines.go index 2c82c52..0389333 100644 --- a/defines.go +++ b/defines.go @@ -3,47 +3,47 @@ package main import ( "html/template" "log" + + bbolt "go.etcd.io/bbolt" ) const ( // charset consists of alphanumeric characters with some characters removed due to them being to similar in some fonts. charset = "abcdefghijkmnopqrstuvwxyz23456789ABCDEFGHJKLMNPQRSTUVWXYZ" - // charset consists of characters that are valid for custom keys. + // customKeyCharset consists of characters that are valid for custom keys. customKeyCharset = "abcdefghijklmnopqrstuvwxyzåäö0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZÅÄÖ-_" // dateFormat specifies the format in which date and time is represented. dateFormat = "Mon 2006-01-02 15:04 MST" - // errServerError contains the generic error message users will se when somthing goes wrong + // errServerError contains the generic error message users will see when something goes wrong errServerError = "Internal Server Error" errInvalidKey = "Invalid key" errInvalidKeyUsed = "Invalid key, key is already in use" errInvalidCustomKey = "Invalid Custom Key was provided, valid characters are:\n" + customKeyCharset errNotImplemented = "Not Implemented" errLowRAM = "No Space available, new space will be available as old links become invalid" - // Do not try to gzip data that is less than minSizeToGzip + // Do not try to gzip data that is less than minSizeToGzip bytes minSizeToGzip = 128 - // Max key length for custom links + // maxKeyLen is the maximum length of a custom key maxKeyLen = 64 + // maxDecompressedSize caps how many bytes returnDecompressed / decompress will inflate. + maxDecompressedSize = 20 << 20 // 20 MiB ) var ( - // logSep is a 128bit random number together with a configured log separator string to make it harder to forge log entry's + // logSep is a 128-bit random value combined with config.LogSep to make log entries hard to forge. logSep string - // Server config variable + // config holds the parsed server configuration. config Config - // linkLen1, linkLen2 and linkLen3 will contain all data related to their respective key length and linkCustom will contain all data related to custom keys. + // domainLinkLens contains per-domain link buckets for all key lengths. domainLinkLens map[string]*LinkLens - // If we want to log errors logger will write these to a file specified in the config + // logger writes structured entries to the configured log file; nil means logging is disabled. logger *log.Logger - // ImageMap is used in handlers.go to map requests to imagedata + // ImageMap maps "domain-logo" / "domain-favicon" keys to their PNG bytes. ImageMap map[string][]byte - // TextBlobs is a temporary map until saving to DB is implemented - TextBlobs map[string][]byte - // BackupLinkLen is used to repopulate the database after loading backuped data from a file - BackupLinkLen1 []Link - BackupLinkLen2 []Link - BackupLinkLen3 []Link - BackupLinkLenC []Link templateMap map[string]*template.Template + + // boltDB is the persistent store; nil means DB unavailable (gob fallback used). + boltDB *bbolt.DB ) diff --git a/go.mod b/go.mod index c4d1067..57d987e 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,11 @@ module github.com/7i/shorter -go 1.17 +go 1.23 require ( github.com/kr/pretty v0.3.1 + github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e + go.etcd.io/bbolt v1.4.3 golang.org/x/crypto v0.4.0 gopkg.in/yaml.v2 v2.4.0 ) @@ -12,5 +14,6 @@ require ( github.com/kr/text v0.2.0 // indirect github.com/rogpeppe/go-internal v1.9.0 // indirect golang.org/x/net v0.4.0 // indirect + golang.org/x/sys v0.29.0 // indirect golang.org/x/text v0.5.0 // indirect ) diff --git a/go.sum b/go.sum index c290ec4..d481bb9 100644 --- a/go.sum +++ b/go.sum @@ -1,44 +1,34 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0= +github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e/go.mod h1:XV66xRDqSt+GTGFMVlhk3ULuV0y9ZmzeVGR4mloJI3M= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.etcd.io/bbolt v1.4.3 h1:dEadXpI6G79deX5prL3QRNP6JB8UxVkqo4UPnHaNXJo= +go.etcd.io/bbolt v1.4.3/go.mod h1:tKQlpPaYCVFctUIgFKFnAlvbmB3tpy1vkTnDWohtc0E= golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8= golang.org/x/crypto v0.4.0/go.mod h1:3quD/ATkf6oY+rnes5c3ExXTbLc8mueNue5/DoinL80= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= golang.org/x/net v0.4.0 h1:Q5QPcMlvfxFTAPV0+07Xz/MpK9NTXu2VDUuy0FeMfaU= golang.org/x/net v0.4.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM= golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/handlers.go b/handlers.go index 84892dd..b228567 100644 --- a/handlers.go +++ b/handlers.go @@ -3,12 +3,13 @@ package main import ( "bytes" "compress/gzip" + "encoding/base64" "fmt" "html" - "io/ioutil" "log" "net/http" "net/url" + "os" "path/filepath" "strconv" "strings" @@ -52,6 +53,10 @@ func handleRequests(w http.ResponseWriter, r *http.Request) { // If the user tries to submit data via POST if r.Method == http.MethodPost { + if !rateLimitAllow(r.RemoteAddr) { + http.Error(w, "Too Many Requests", http.StatusTooManyRequests) + return + } err := r.ParseMultipartForm(config.MaxFileSize) if err != nil { logErrors(w, r, errServerError, http.StatusInternalServerError, "Error: "+url.QueryEscape(err.Error())) @@ -91,11 +96,7 @@ func handleRequests(w http.ResponseWriter, r *http.Request) { logErrors(w, r, errInvalidCustomKey, http.StatusInternalServerError, "") return } - - if _, used := domainLinkLens[r.Host].LinkCustom.LinkMap[customKey]; used { - http.Error(w, errInvalidKeyUsed, http.StatusInternalServerError) - return - } + // Duplicate-key check is performed atomically inside Add() under the write lock. } // Handle different request types @@ -103,11 +104,36 @@ func handleRequests(w http.ResponseWriter, r *http.Request) { switch requestType { case "url": formURL := r.Form.Get("url") - valid := validURL(formURL) - if !valid { + if !validURL(formURL) { logErrors(w, r, "Invalid url, only \"http://\" and \"https://\" url schemes are allowed.", http.StatusInternalServerError, "") return } + if isBlocklisted(formURL) { + logErrors(w, r, "URL is not allowed", http.StatusForbidden, "Blocked URL: "+url.QueryEscape(formURL)) + return + } + // Dedup: if the same unlimited URL is already active and no custom key + // was requested, return the existing short link instead of creating a new one. + if customKey == "" { + if existing := findExistingURL(r.Host, formURL); existing != nil { + w.Header().Add("Content-Type", "text/html; charset=utf-8") + t, ok := templateMap[r.Host+"#showLink"] + if !ok { + http.Error(w, errServerError, http.StatusInternalServerError) + return + } + tmplArgs := showLinkVars{Domain: scheme + "://" + r.Host, Data: scheme + "://" + r.Host + "/" + existing.Key, Timeout: existing.Timeout.Format("Mon 2006-01-02 15:04 MST")} + if err := t.ExecuteTemplate(w, "showLink.tmpl", tmplArgs); err != nil { + if logger != nil { + logger.Println("ERROR executing showLink.tmpl for host:", r.Host, "err:", err) + } + http.Error(w, errServerError, http.StatusInternalServerError) + return + } + logOK(r, http.StatusOK) + return + } + } currentLinkLen.Mutex.RLock() currentLinkLenTimeout := currentLinkLen.Timeout currentLinkLen.Mutex.RUnlock() @@ -116,26 +142,28 @@ func handleRequests(w http.ResponseWriter, r *http.Request) { showLnk := &Link{Key: customKey, LinkType: "url", Data: formURL, IsCompressed: isCompressed, Times: xTimes, Timeout: time.Now().Add(currentLinkLenTimeout)} key, err := currentLinkLen.Add(showLnk) - if err == nil { - w.Header().Add("Content-Type", "text/html; charset=utf-8") - logger.Println("requesting template :", r.Host+"showLink") - t, ok := templateMap[r.Host+"#showLink"] - if !ok { - logger.Println("ERROR getting template template :", r.Host+"showLink") - http.Error(w, errServerError, http.StatusInternalServerError) - return + if err != nil { + logErrors(w, r, err.Error(), http.StatusInternalServerError, url.QueryEscape(err.Error())) + return + } + w.Header().Add("Content-Type", "text/html; charset=utf-8") + t, ok := templateMap[r.Host+"#showLink"] + if !ok { + if logger != nil { + logger.Println("ERROR getting template showLink for host:", r.Host) } - - tmplArgs := showLinkVars{Domain: scheme + "://" + r.Host, Data: scheme + "://" + r.Host + "/" + key, Timeout: showLnk.Timeout.Format("Mon 2006-01-02 15:04 MST")} - - err = t.ExecuteTemplate(w, "showLink.tmpl", tmplArgs) - if err != nil { - logger.Println("ERROR executing template template showLink.tmpl for host :", r.Host, "with args: ", tmplArgs, "with the error: ", err) - http.Error(w, errServerError, http.StatusInternalServerError) + http.Error(w, errServerError, http.StatusInternalServerError) + return + } + tmplArgs := showLinkVars{Domain: scheme + "://" + r.Host, Data: scheme + "://" + r.Host + "/" + key, Timeout: showLnk.Timeout.Format("Mon 2006-01-02 15:04 MST")} + if err = t.ExecuteTemplate(w, "showLink.tmpl", tmplArgs); err != nil { + if logger != nil { + logger.Println("ERROR executing showLink.tmpl for host:", r.Host, "err:", err) } - logOK(r, http.StatusOK) + http.Error(w, errServerError, http.StatusInternalServerError) return } + logOK(r, http.StatusOK) return case "text": if lowRAM() { @@ -159,23 +187,25 @@ func handleRequests(w http.ResponseWriter, r *http.Request) { showLnk := &Link{Key: customKey, LinkType: "text", Data: textBlob, IsCompressed: isCompressed, Times: xTimes, Timeout: time.Now().Add(currentLinkLenTimeout)} key, err := currentLinkLen.Add(showLnk) - if err == nil { - w.Header().Add("Content-Type", "text/html; charset=utf-8") - t, ok := templateMap[r.Host+"#showLink"] - if !ok { - http.Error(w, errServerError, http.StatusInternalServerError) - return - } - tmplArgs := showLinkVars{Domain: scheme + "://" + r.Host, Data: scheme + "://" + r.Host + "/" + key, Timeout: showLnk.Timeout.Format("Mon 2006-01-02 15:04 MST")} - - err = t.ExecuteTemplate(w, "showLink.tmpl", tmplArgs) - if err != nil { - logger.Println("ERROR executing template template showLink.tmpl for host :", r.Host, "with args: ", tmplArgs) - http.Error(w, errServerError, http.StatusInternalServerError) + if err != nil { + logErrors(w, r, err.Error(), http.StatusInternalServerError, url.QueryEscape(err.Error())) + return + } + w.Header().Add("Content-Type", "text/html; charset=utf-8") + t, ok := templateMap[r.Host+"#showLink"] + if !ok { + http.Error(w, errServerError, http.StatusInternalServerError) + return + } + tmplArgs := showLinkVars{Domain: scheme + "://" + r.Host, Data: scheme + "://" + r.Host + "/" + key, Timeout: showLnk.Timeout.Format("Mon 2006-01-02 15:04 MST")} + if err = t.ExecuteTemplate(w, "showLink.tmpl", tmplArgs); err != nil { + if logger != nil { + logger.Println("ERROR executing showLink.tmpl for host:", r.Host, "err:", err) } - logOK(r, http.StatusOK) + http.Error(w, errServerError, http.StatusInternalServerError) return } + logOK(r, http.StatusOK) return default: logErrors(w, r, errNotImplemented, http.StatusNotImplemented, "Error: Invalid requestType argument.") @@ -221,23 +251,24 @@ func handleGET(w http.ResponseWriter, r *http.Request) { // verify that key only consists of valid characters if !validate(key) { - logErrors(w, r, errInvalidKey, http.StatusInternalServerError, "") + logErrors(w, r, errInvalidKey, http.StatusBadRequest, "") + return + } + + // Admin endpoint — authenticated via HTTP Basic Auth, no query string required. + if key == "listactive~" { + listActiveLinks(w, r) return } // quick check if request is quickAddURL request if len(r.URL.RawQuery) > 0 { - if key == "listactive~" { - listActiveLinks(w, r) - return - } if validURL(r.URL.RawQuery) { quickAddURL(w, r, r.URL.RawQuery, key) return - } else { - logErrors(w, r, "Invalid Quick Add URL request", http.StatusInternalServerError, "Invalid Quick Add URL request, please use the following syntax: \""+r.Host+"?http://example.com/\". where http://example.com/ is your link.\nAlso note that only \"http://\" and \"https://\" url schemes are allowed.") - return } + logErrors(w, r, "Invalid Quick Add URL request", http.StatusInternalServerError, "Invalid Quick Add URL request, please use the following syntax: \""+r.Host+"?http://example.com/\". where http://example.com/ is your link.\nAlso note that only \"http://\" and \"https://\" url schemes are allowed.") + return } var showLink bool @@ -248,6 +279,10 @@ func handleGET(w http.ResponseWriter, r *http.Request) { // start by checking static key map if lnk, ok := config.StaticLinks[key]; ok { + if !validURL(lnk) { + http.Error(w, errInvalidKey, http.StatusInternalServerError) + return + } logOK(r, http.StatusPermanentRedirect) http.Redirect(w, r, lnk, http.StatusPermanentRedirect) return @@ -258,50 +293,20 @@ func handleGET(w http.ResponseWriter, r *http.Request) { scheme = "https" } - var lnk *Link - var ok bool - switch keylen := len(key); { - case keylen == 1: - domainLinkLens[r.Host].LinkLen1.Mutex.RLock() - if lnk, ok = domainLinkLens[r.Host].LinkLen1.LinkMap[key]; !ok { - domainLinkLens[r.Host].LinkLen1.Mutex.RUnlock() - http.Error(w, errInvalidKey, http.StatusInternalServerError) - return - } - domainLinkLens[r.Host].LinkLen1.Mutex.RUnlock() - case keylen == 2: - domainLinkLens[r.Host].LinkLen2.Mutex.RLock() - if lnk, ok = domainLinkLens[r.Host].LinkLen2.LinkMap[key]; !ok { - domainLinkLens[r.Host].LinkLen2.Mutex.RUnlock() - http.Error(w, errInvalidKey, http.StatusInternalServerError) - return - } - domainLinkLens[r.Host].LinkLen2.Mutex.RUnlock() - case keylen == 3: - domainLinkLens[r.Host].LinkLen3.Mutex.RLock() - if lnk, ok = domainLinkLens[r.Host].LinkLen3.LinkMap[key]; !ok { - domainLinkLens[r.Host].LinkLen3.Mutex.RUnlock() - http.Error(w, errInvalidKey, http.StatusInternalServerError) - return - } - domainLinkLens[r.Host].LinkLen3.Mutex.RUnlock() - case keylen > 3 && keylen < maxKeyLen: - // key is validated previously - domainLinkLens[r.Host].LinkCustom.Mutex.RLock() - if lnk, ok = domainLinkLens[r.Host].LinkCustom.LinkMap[key]; !ok { - domainLinkLens[r.Host].LinkCustom.Mutex.RUnlock() - http.Error(w, errInvalidKey, http.StatusInternalServerError) - return - } - domainLinkLens[r.Host].LinkCustom.Mutex.RUnlock() - default: - http.Error(w, errInvalidKey, http.StatusInternalServerError) + lnk, ll := lookupLink(r.Host, key) + if lnk == nil { + http.Error(w, errInvalidKey, http.StatusNotFound) return } - if lnk == nil { - http.Error(w, errInvalidKey, http.StatusInternalServerError) - return + // Record the access: increment AccessCount and, for limited links, consume one use. + // showLink (info-only, ~ suffix) requests are not counted as accesses. + if !showLink { + lnk = recordAccess(ll, key) + if lnk == nil { + http.Error(w, errInvalidKey, http.StatusNotFound) + return + } } switch lnk.LinkType { @@ -309,18 +314,28 @@ func handleGET(w http.ResponseWriter, r *http.Request) { if showLink { logOK(r, http.StatusOK) w.Header().Add("Content-Type", "text/plain; charset=utf-8") - fmt.Fprint(w, r.Host+"/"+key+"\n\nis pointing to \n\n"+html.EscapeString(lnk.Data)) + timesInfo := "unlimited" + if lnk.Times >= 0 { + timesInfo = strconv.Itoa(lnk.Times) + " remaining" + } + fmt.Fprintf(w, "%s/%s\n\npoints to\n\n%s\n\nAccesses: %d Uses: %s Expires: %s", + r.Host, key, html.EscapeString(lnk.Data), + lnk.AccessCount, timesInfo, lnk.Timeout.Format(dateFormat)) return } w.Header().Add("Content-Type", "text/html; charset=utf-8") t, ok := templateMap[r.Host+"#showLink"] if !ok { http.Error(w, errServerError, http.StatusInternalServerError) + return } tmplArgs := showLinkVars{Domain: scheme + "://" + r.Host, Data: lnk.Data, Timeout: lnk.Timeout.Format("Mon 2006-01-02 15:04 MST")} - err := t.ExecuteTemplate(w, "showLink.tmpl", tmplArgs) - if err != nil { + if err := t.ExecuteTemplate(w, "showLink.tmpl", tmplArgs); err != nil { + if logger != nil { + logger.Println("ERROR executing showLink.tmpl:", err) + } http.Error(w, errServerError, http.StatusInternalServerError) + return } logOK(r, http.StatusTemporaryRedirect) return @@ -328,19 +343,25 @@ func handleGET(w http.ResponseWriter, r *http.Request) { w.Header().Add("Content-Type", "text/plain; charset=utf-8") if showLink { logOK(r, http.StatusOK) - fmt.Fprint(w, r.Host+"/"+key+"\n\nis pointing to a "+r.Host+" Text dump") + fmt.Fprintf(w, "%s/%s\n\npoints to a text blob\n\nAccesses: %d Expires: %s", + r.Host, key, lnk.AccessCount, lnk.Timeout.Format(dateFormat)) return } if lnk.IsCompressed { if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") { + // Data is stored as base64(gzip); decode to raw gzip bytes for the client. + raw, err := base64.StdEncoding.DecodeString(lnk.Data) + if err != nil { + http.Error(w, errServerError, http.StatusInternalServerError) + return + } w.Header().Add("content-encoding", "gzip") logOK(r, http.StatusOK) - fmt.Fprint(w, lnk.Data) - return - } else { - returnDecompressed(lnk, w, r) // defined in misc.go + w.Write(raw) return } + returnDecompressed(lnk, w, r) + return } logOK(r, http.StatusOK) fmt.Fprint(w, lnk.Data) @@ -351,7 +372,7 @@ func handleGET(w http.ResponseWriter, r *http.Request) { } func handleCSS(mux *http.ServeMux) { - f, err := ioutil.ReadFile(filepath.Join(config.BaseDir, "css", "shorter.css")) + f, err := os.ReadFile(filepath.Join(config.BaseDir, "css", "shorter.css")) if err != nil { log.Fatalln("Missing shorter.css in Template dir/css/") } @@ -375,7 +396,7 @@ func getSingleFileHandler(f []byte, mimeType string) (handleFile func(w http.Res if validRequest(r) { w.Header().Add("Content-Type", mimeType) w.Header().Add("Cache-Control", "max-age=2592000, public") - if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") && tryGzip && false { + if strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") && tryGzip { w.Header().Add("content-encoding", "gzip") fmt.Fprintf(w, "%s", cf) return @@ -410,7 +431,7 @@ func handleImages(mux *http.ServeMux) { defaultFavicon := []byte{0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, 0x00, 0x00, 0x00, 0x0d, 0x49, 0x48, 0x44, 0x52, 0x00, 0x00, 0x00, 0x10, 0x00, 0x00, 0x00, 0x10, 0x08, 0x03, 0x00, 0x00, 0x00, 0x28, 0x2d, 0x0f, 0x53, 0x00, 0x00, 0x00, 0x9c, 0x50, 0x4c, 0x54, 0x45, 0x1f, 0x9b, 0xed, 0x1f, 0x9b, 0xef, 0x1e, 0x9a, 0xed, 0x1e, 0x9c, 0xed, 0x1f, 0x9b, 0xee, 0x1f, 0x9c, 0xef, 0x20, 0x9c, 0xee, 0x20, 0x9c, 0xee, 0x21, 0x9c, 0xee, 0x23, 0x9d, 0xee, 0x26, 0x9e, 0xee, 0x28, 0x9f, 0xee, 0x2a, 0xa0, 0xee, 0x2e, 0xa2, 0xef, 0x31, 0xa3, 0xef, 0x37, 0xa6, 0xef, 0x39, 0xa7, 0xef, 0x45, 0xac, 0xf0, 0x55, 0xb3, 0xf1, 0x5e, 0xb7, 0xf2, 0x62, 0xb9, 0xf3, 0x63, 0xb9, 0xf2, 0x65, 0xba, 0xf2, 0x6e, 0xbe, 0xf3, 0x77, 0xc2, 0xf4, 0x78, 0xc2, 0xf4, 0x81, 0xc7, 0xf5, 0x87, 0xc9, 0xf5, 0x8a, 0xca, 0xf5, 0x8b, 0xcb, 0xf5, 0x91, 0xce, 0xf6, 0x96, 0xd0, 0xf6, 0x99, 0xd1, 0xf6, 0x9b, 0xd2, 0xf6, 0x9b, 0xd2, 0xf7, 0x9d, 0xd3, 0xf7, 0x9f, 0xd4, 0xf7, 0xb9, 0xe0, 0xf9, 0xcb, 0xe7, 0xfa, 0xd7, 0xed, 0xfb, 0xda, 0xee, 0xfb, 0xdf, 0xf0, 0xfc, 0xe5, 0xf3, 0xfc, 0xe7, 0xf4, 0xfc, 0xeb, 0xf6, 0xfd, 0xed, 0xf7, 0xfd, 0xf0, 0xf8, 0xfd, 0xf1, 0xf8, 0xfd, 0xf2, 0xf9, 0xfd, 0xf5, 0xfa, 0xfe, 0xf9, 0xfc, 0xfe, 0xff, 0xff, 0xff, 0x7a, 0x52, 0xe8, 0x58, 0x00, 0x00, 0x00, 0x07, 0x74, 0x52, 0x4e, 0x53, 0x7d, 0x7d, 0x7e, 0x7e, 0xf8, 0xf8, 0xf9, 0x01, 0xb6, 0xcf, 0xc8, 0x00, 0x00, 0x00, 0x7e, 0x49, 0x44, 0x41, 0x54, 0x18, 0x57, 0x55, 0xcf, 0xc7, 0x12, 0x82, 0x40, 0x10, 0x84, 0xe1, 0x51, 0x59, 0x7f, 0xd7, 0x84, 0x62, 0x00, 0x23, 0x06, 0xcc, 0x71, 0x9d, 0xf7, 0x7f, 0x37, 0x2f, 0x50, 0x35, 0xf4, 0xad, 0xbf, 0xaa, 0x3e, 0xb4, 0xb4, 0x1c, 0x26, 0xae, 0x21, 0x6d, 0xdb, 0x21, 0x12, 0xdb, 0x26, 0xdb, 0x18, 0x21, 0x7f, 0x95, 0x59, 0xf1, 0xd1, 0x3d, 0xc2, 0x21, 0x84, 0x10, 0xc2, 0x4f, 0xdf, 0x03, 0x66, 0xf9, 0x88, 0x6a, 0x72, 0xd6, 0x0d, 0x24, 0x69, 0x5c, 0xc1, 0x5c, 0x9f, 0x7d, 0x38, 0xe9, 0xb4, 0x04, 0x7f, 0xd5, 0x25, 0x16, 0x32, 0xbd, 0x77, 0x2d, 0xf4, 0x1e, 0xba, 0xc0, 0xc2, 0x5a, 0x6f, 0xde, 0xc2, 0xf0, 0xab, 0x29, 0x16, 0x8e, 0x7a, 0xe9, 0x58, 0xf0, 0xbb, 0x22, 0x01, 0x80, 0xac, 0x18, 0x23, 0xb5, 0xb3, 0xe0, 0xa4, 0x59, 0x93, 0x48, 0xfe, 0x29, 0x72, 0x10, 0x99, 0xc7, 0x5c, 0x2b, 0x48, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4e, 0x44, 0xae, 0x42, 0x60, 0x82} for _, domain := range config.DomainNames { - logo, err := ioutil.ReadFile(filepath.Join(config.BaseDir, domain, "logo.png")) + logo, err := os.ReadFile(filepath.Join(config.BaseDir, domain, "logo.png")) if err != nil { if logger != nil { logger.Println("Missing /" + domain + "/logo.png in Template dir, fallback to default logo.png") @@ -420,7 +441,7 @@ func handleImages(mux *http.ServeMux) { ImageMap[domain+"-logo"] = logo } - favicon, err := ioutil.ReadFile(filepath.Join(config.BaseDir, domain, "favicon.png")) + favicon, err := os.ReadFile(filepath.Join(config.BaseDir, domain, "favicon.png")) if err != nil { if logger != nil { logger.Println("Missing /" + domain + "/favicon.png in Template dir, fallback to default favicon.png") @@ -438,7 +459,7 @@ func handleImages(mux *http.ServeMux) { // handleRobots will return the robots.txt located in the Template dir specified in the config file, if no robots.txt file is found we return a 404 error func handleRobots(mux *http.ServeMux) { - f, err := ioutil.ReadFile(filepath.Join(config.BaseDir, "robots.txt")) + f, err := os.ReadFile(filepath.Join(config.BaseDir, "robots.txt")) if err != nil { if logger != nil { logger.Println("Missing robots.txt in Template dir, fallback to returning 404 on requests for robots.txt") @@ -463,6 +484,11 @@ func handleRobots(mux *http.ServeMux) { } func quickAddURL(w http.ResponseWriter, r *http.Request, url, key string) { + if isBlocklisted(url) { + http.Error(w, "URL is not allowed", http.StatusForbidden) + return + } + var urlLink *LinkLen // Remove keys of invalid size, note that key has been validated to only contain valid characters previously @@ -483,10 +509,7 @@ func quickAddURL(w http.ResponseWriter, r *http.Request, url, key string) { continue } urlLink = &domainLinkLens[r.Host].LinkCustom - if _, used := urlLink.LinkMap[key]; used { - http.Error(w, errInvalidKeyUsed, http.StatusInternalServerError) - return - } + // Duplicate-key check is handled atomically inside Add() under the write lock. case 1: urlLink = &domainLinkLens[r.Host].LinkLen1 case 2: @@ -500,9 +523,7 @@ func quickAddURL(w http.ResponseWriter, r *http.Request, url, key string) { linkTimeout := urlLink.Timeout urlLink.Mutex.RUnlock() - isCompressed := false - - showLink := &Link{Key: key, LinkType: "url", Data: url, IsCompressed: isCompressed, Times: -1, Timeout: time.Now().Add(linkTimeout)} + showLink := &Link{Key: key, LinkType: "url", Data: url, Times: -1, Timeout: time.Now().Add(linkTimeout)} _, err := urlLink.Add(showLink) if err == nil { w.Header().Add("Content-Type", "text/html; charset=utf-8") @@ -511,14 +532,20 @@ func quickAddURL(w http.ResponseWriter, r *http.Request, url, key string) { http.Error(w, errServerError, http.StatusInternalServerError) return } - tmplArgs := showLinkVars{Domain: scheme + "://" + r.Host, Data: showLink.Data, Timeout: showLink.Timeout.Format("Mon 2006-01-02 15:04 MST")} - err = t.ExecuteTemplate(w, "showLink.tmpl", tmplArgs) - if err != nil { + if err = t.ExecuteTemplate(w, "showLink.tmpl", tmplArgs); err != nil { http.Error(w, errServerError, http.StatusInternalServerError) + return } logOK(r, http.StatusOK) return } + // If the custom-key slot was already taken, stop immediately rather than + // falling through to auto-assigned length buckets. + if i == 0 { + logErrors(w, r, err.Error(), http.StatusInternalServerError, err.Error()) + return + } } + logErrors(w, r, errServerError, http.StatusInternalServerError, "quickAddURL: all key lengths exhausted") } diff --git a/letsencrypt.go b/letsencrypt.go index 3d3d032..74d49a9 100644 --- a/letsencrypt.go +++ b/letsencrypt.go @@ -27,24 +27,30 @@ func getServer(mux *http.ServeMux) (server *http.Server) { Email: config.Email, } tlsConf := &tls.Config{ - Rand: rand.Reader, - Time: time.Now, - NextProtos: []string{acme.ALPNProto, "http/1.1"}, // add http2.NextProtoTLS? - MinVersion: tls.VersionTLS12, - CurvePreferences: []tls.CurveID{tls.CurveP521, tls.CurveP384, tls.CurveP256}, - GetCertificate: m.GetCertificate, - PreferServerCipherSuites: true, + Rand: rand.Reader, + Time: time.Now, + NextProtos: []string{acme.ALPNProto, "http/1.1"}, + MinVersion: tls.VersionTLS12, + // X25519 first: fastest handshake and forward-secret; P-256 as fallback. + CurvePreferences: []tls.CurveID{tls.X25519, tls.CurveP256}, + GetCertificate: m.GetCertificate, + // Only AEAD cipher suites for TLS 1.2; TLS 1.3 cipher selection is automatic. CipherSuites: []uint16{ - tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, - tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, }, } server = &http.Server{ - Addr: config.TLSAddressPort, - Handler: mux, - TLSConfig: tlsConf, + Addr: config.TLSAddressPort, + Handler: mux, + TLSConfig: tlsConf, + ReadTimeout: 15 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, // https://blog.bracebin.com/achieving-perfect-ssl-labs-score-with-go TLSNextProto: make(map[string]func(*http.Server, *tls.Conn, http.Handler), 0), } diff --git a/misc.go b/misc.go index 35d2a45..2e3e9a5 100644 --- a/misc.go +++ b/misc.go @@ -4,9 +4,11 @@ import ( "bytes" "compress/gzip" "crypto/sha256" + "crypto/subtle" + "encoding/base64" "encoding/hex" + "errors" "fmt" - "go/build" "html/template" "io" "net/http" @@ -18,18 +20,26 @@ import ( "sync" ) -// validate validates if string s contains only characters in charset. validate is not a crypto related function so no need for constant time +// customKeyCharSet is a pre-built lookup for O(k) key validation instead of O(k·n). +var customKeyCharSet = func() map[rune]struct{} { + m := make(map[rune]struct{}, len([]rune(customKeyCharset))) + for _, r := range customKeyCharset { + m[r] = struct{}{} + } + return m +}() + +// validate reports whether s contains only characters from customKeyCharset. +// A trailing '~' (info-view suffix) is stripped before checking. func validate(s string) bool { if len(s) == 0 { return true } - if s[len(s)-1] == '~' { s = s[:len(s)-1] } - for _, char := range s { - if !strings.Contains(customKeyCharset, string(char)) { + if _, ok := customKeyCharSet[char]; !ok { return false } } @@ -58,6 +68,7 @@ func initLinkLensDomain(domain string) { FreeMap: make(map[string]bool), Timeout: config.Clear1Duration, Domain: domain, + Type: "len1", } domainLinkLens[domain].LinkLen2 = LinkLen{ @@ -66,6 +77,7 @@ func initLinkLensDomain(domain string) { FreeMap: make(map[string]bool), Timeout: config.Clear2Duration, Domain: domain, + Type: "len2", } domainLinkLens[domain].LinkLen3 = LinkLen{ @@ -74,6 +86,7 @@ func initLinkLensDomain(domain string) { FreeMap: make(map[string]bool), Timeout: config.Clear3Duration, Domain: domain, + Type: "len3", } domainLinkLens[domain].LinkCustom = LinkLen{ @@ -81,6 +94,7 @@ func initLinkLensDomain(domain string) { LinkMap: make(map[string]*Link), Timeout: config.ClearCustomLinksDuration, Domain: domain, + Type: "custom", } domainLinkLens[domain].LinkLen1.Mutex.Lock() @@ -104,7 +118,159 @@ func initLinkLensDomain(domain string) { } } +// lookupLink returns a snapshot copy of the Link and its owning LinkLen for the +// given host+key. The copy is taken under the read lock, so callers may read any +// field of the returned *Link without holding a lock. Returns (nil, ll) if the key +// is not found (ll is still set so callers can pass it to recordAccess). +func lookupLink(host, key string) (*Link, *LinkLen) { + dl, ok := domainLinkLens[host] + if !ok { + return nil, nil + } + var ll *LinkLen + switch keylen := len(key); { + case keylen == 1: + ll = &dl.LinkLen1 + case keylen == 2: + ll = &dl.LinkLen2 + case keylen == 3: + ll = &dl.LinkLen3 + case keylen > 3 && keylen < maxKeyLen: + ll = &dl.LinkCustom + default: + return nil, nil + } + ll.Mutex.RLock() + lnkPtr := ll.LinkMap[key] + var snap Link + if lnkPtr != nil { + snap = *lnkPtr // copy while holding the read lock + } + ll.Mutex.RUnlock() + if lnkPtr == nil { + return nil, ll + } + return &snap, ll +} + +// recordAccess increments AccessCount and, for limited-use links (Times > 0), +// decrements Times under the write lock. If Times reaches zero the link is deleted +// from the map but a snapshot is returned so the caller can serve the final response. +// Returns nil if the link no longer exists (expired between lookupLink and now). +func recordAccess(ll *LinkLen, key string) *Link { + ll.Mutex.Lock() + defer ll.Mutex.Unlock() + lnk, ok := ll.LinkMap[key] + if !ok { + return nil + } + lnk.AccessCount++ + if lnk.Times > 0 { + lnk.Times-- + if lnk.Times <= 0 { + snapshot := *lnk + delete(ll.LinkMap, key) + if ll.FreeMap != nil { + ll.FreeMap[key] = true + } else { + ll.Links-- + } + domain, typ := ll.Domain, ll.Type + go deleteLinkFromDB(domain, typ, key) + return &snapshot + } + } + snapshot := *lnk + return &snapshot +} + +// findExistingURL returns the first active, unlimited-use URL link that points +// to targetURL for the given host, or nil if none exists. +// Only unlimited links (Times == -1) are matched to avoid silently consuming uses. +func findExistingURL(host, targetURL string) *Link { + dl, ok := domainLinkLens[host] + if !ok { + return nil + } + search := func(ll *LinkLen) *Link { + ll.Mutex.RLock() + defer ll.Mutex.RUnlock() + for _, lnk := range ll.LinkMap { + if lnk.LinkType == "url" && lnk.Data == targetURL && lnk.Times == -1 { + return lnk + } + } + return nil + } + if lnk := search(&dl.LinkLen1); lnk != nil { + return lnk + } + if lnk := search(&dl.LinkLen2); lnk != nil { + return lnk + } + if lnk := search(&dl.LinkLen3); lnk != nil { + return lnk + } + return search(&dl.LinkCustom) +} + +var blocklist = make(map[string]struct{}) + +// loadBlocklist populates the blocklist from the config's BlockedDomains list +// and, if BlocklistFile is set, from a newline-delimited file (# comments OK). +func loadBlocklist() { + for _, domain := range config.BlockedDomains { + blocklist[strings.ToLower(strings.TrimSpace(domain))] = struct{}{} + } + if config.BlocklistFile == "" { + return + } + data, err := os.ReadFile(config.BlocklistFile) + if err != nil { + if logger != nil { + logger.Println("Failed to load blocklist file:", err) + } + return + } + for _, line := range strings.Split(string(data), "\n") { + line = strings.ToLower(strings.TrimSpace(line)) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + blocklist[line] = struct{}{} + } + if logger != nil { + logger.Println("Loaded", len(blocklist), "entries into blocklist") + } +} + +// isBlocklisted returns true if the hostname of rawURL (or any parent domain) +// appears in the blocklist. +func isBlocklisted(rawURL string) bool { + if len(blocklist) == 0 { + return false + } + parsed, err := url.Parse(rawURL) + if err != nil { + return false + } + host := strings.ToLower(parsed.Hostname()) + for host != "" { + if _, blocked := blocklist[host]; blocked { + return true + } + idx := strings.IndexByte(host, '.') + if idx < 0 { + break + } + host = host[idx+1:] + } + return false +} + func addHeaders(w http.ResponseWriter, r *http.Request) { + w.Header().Add("X-Content-Type-Options", "nosniff") + w.Header().Add("Referrer-Policy", "no-referrer") if config.ReportTo != "" { w.Header().Add("Report-To", strings.ReplaceAll(config.ReportTo, "###DomainNames###", r.Host)) } @@ -151,27 +317,28 @@ func findFolderDefaultLocations(folder string) (path string) { if _, err := os.Stat(filepath.Join(".", folder)); !os.IsNotExist(err) { return filepath.Join(".", folder) } - possibleDirs := os.Getenv("GOPATH") - if possibleDirs == "" { - possibleDirs = build.Default.GOPATH + gopath := os.Getenv("GOPATH") + if gopath == "" { + // Fallback: use the default GOPATH convention ($HOME/go). + if home, err := os.UserHomeDir(); err == nil { + gopath = filepath.Join(home, "go") + } } - var dirs []string + sep := ":" if runtime.GOOS == "windows" { - dirs = strings.Split(possibleDirs, ";") - } else { - dirs = strings.Split(possibleDirs, ":") + sep = ";" } - for _, dir := range dirs { - if _, err := os.Stat(filepath.Join(dir, "src", "github.com", "7i", "shorter", folder)); !os.IsNotExist(err) { - // Found - return filepath.Join(dir, "src", "github.com", "7i", "shorter", folder) + for _, dir := range strings.Split(gopath, sep) { + candidate := filepath.Join(dir, "src", "github.com", "7i", "shorter", folder) + if _, err := os.Stat(candidate); !os.IsNotExist(err) { + return candidate } } - return "" } -func compress(data string) (compressedData string, err error) { +// compress gzips data and returns it base64-encoded so it is safe for JSON storage. +func compress(data string) (string, error) { var buf bytes.Buffer zw := gzip.NewWriter(&buf) if _, err := io.Copy(zw, strings.NewReader(data)); err != nil { @@ -180,41 +347,54 @@ func compress(data string) (compressedData string, err error) { if err := zw.Close(); err != nil { return "", err } - return buf.String(), nil + return base64.StdEncoding.EncodeToString(buf.Bytes()), nil } -func decompress(data string) (decompressedData string, err error) { - var buf bytes.Buffer - zw, err := gzip.NewReader(strings.NewReader(data)) +// decompress decodes a base64-encoded gzip payload produced by compress(). +// Returns an error if the decompressed size exceeds maxDecompressedSize. +func decompress(data string) (string, error) { + raw, err := base64.StdEncoding.DecodeString(data) + if err != nil { + return "", err + } + zr, err := gzip.NewReader(bytes.NewReader(raw)) if err != nil { return "", err } - if _, err := io.Copy(&buf, zw); err != nil { + var buf bytes.Buffer + if _, err = io.Copy(&buf, io.LimitReader(zr, maxDecompressedSize+1)); err != nil { return "", err } - if err := zw.Close(); err != nil { + if err = zr.Close(); err != nil { return "", err } + if int64(buf.Len()) > maxDecompressedSize { + return "", errors.New("decompressed size exceeds limit") + } return buf.String(), nil } func returnDecompressed(lnk *Link, w http.ResponseWriter, r *http.Request) { if lnk == nil { - logErrors(w, r, errServerError, http.StatusInternalServerError, "Error: invalid lnk in request to returnDecompressed().") + logErrors(w, r, errServerError, http.StatusInternalServerError, "Error: nil lnk in returnDecompressed") + return + } + raw, err := base64.StdEncoding.DecodeString(lnk.Data) + if err != nil { + logErrors(w, r, errServerError, http.StatusInternalServerError, "Error: base64 decode failed in returnDecompressed") return } - dataReader, err := gzip.NewReader(strings.NewReader(lnk.Data)) - if err == nil { - fmt.Println("ERROR in lnk.Data, misc.go line 203", lnk.Data) // DEBUG - logErrors(w, r, errServerError, http.StatusInternalServerError, "Error: invalid lnk.Data in request to returnDecompressed().") + dataReader, err := gzip.NewReader(bytes.NewReader(raw)) + if err != nil { + logErrors(w, r, errServerError, http.StatusInternalServerError, "Error: gzip open failed in returnDecompressed") return } - if _, err = io.Copy(w, dataReader); err != nil { - logErrors(w, r, errServerError, http.StatusInternalServerError, "Error: while decompresing in request to returnDecompressed().") + if _, err = io.Copy(w, io.LimitReader(dataReader, maxDecompressedSize+1)); err != nil { + logErrors(w, r, errServerError, http.StatusInternalServerError, "Error: decompression failed in returnDecompressed") return } if err = dataReader.Close(); err != nil { - logErrors(w, r, errServerError, http.StatusInternalServerError, "Error: closing dataReader in request to returnDecompressed().") + logErrors(w, r, errServerError, http.StatusInternalServerError, "Error: closing dataReader in returnDecompressed") return } logOK(r, http.StatusOK) @@ -234,50 +414,63 @@ func logOK(r *http.Request, statusCode int) { } } -// fugly temp function +// listActiveLinks serves the admin link-list endpoint. +// Authentication uses HTTP Basic Auth; the password is sha256(password+Salt) == HashSHA256. func listActiveLinks(w http.ResponseWriter, r *http.Request) { - ba := sha256.Sum256([]byte(r.URL.RawQuery + config.Salt)) - pwd := hex.EncodeToString(ba[:]) - if pwd == config.HashSHA256 { + _, password, ok := r.BasicAuth() + if !ok { + w.Header().Set("WWW-Authenticate", `Basic realm="shorter admin"`) + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + ba := sha256.Sum256([]byte(password + config.Salt)) + computed := hex.EncodeToString(ba[:]) + if subtle.ConstantTimeCompare([]byte(computed), []byte(config.HashSHA256)) == 1 { w.Header().Add("Content-Type", "text/plain") - resp := "" for _, domain := range config.DomainNames { - resp += "Domain: " + domain + "\n" - resp += "Linklen 1:\n" - resp += getActiveList(&domainLinkLens[domain].LinkLen1) - resp += "Linklen 2:\n" - resp += getActiveList(&domainLinkLens[domain].LinkLen2) - resp += "Linklen 3:\n" - resp += getActiveList(&domainLinkLens[domain].LinkLen3) - resp += "Custome Links:\n" - resp += getActiveList(&domainLinkLens[domain].LinkCustom) + fmt.Fprintf(w, "Domain: %s\nLinklen 1:\n%sLinklen 2:\n%sLinklen 3:\n%sCustom Links:\n%s", + domain, + getActiveList(&domainLinkLens[domain].LinkLen1), + getActiveList(&domainLinkLens[domain].LinkLen2), + getActiveList(&domainLinkLens[domain].LinkLen3), + getActiveList(&domainLinkLens[domain].LinkCustom), + ) } - - fmt.Fprint(w, resp) } else { - http.Error(w, errServerError, http.StatusInternalServerError) + w.Header().Set("WWW-Authenticate", `Basic realm="shorter admin"`) + http.Error(w, "Unauthorized", http.StatusUnauthorized) } } -func getActiveList(l *LinkLen) (resp string) { - l.Mutex.Lock() - next := *l.NextClear - stop := false - for !stop { - resp += "Domain: " + l.Domain + " Key: " + next.Key + " LinkType: " + next.LinkType + " IsCompressed: " + fmt.Sprintf("%v", next.IsCompressed) + "Timeout:" + next.Timeout.String() + "Data: " +func getActiveList(l *LinkLen) string { + l.Mutex.RLock() + defer l.Mutex.RUnlock() + if l.NextClear == nil { + return "" + } + var sb strings.Builder + next := l.NextClear + for next != nil { + sb.WriteString("Domain: ") + sb.WriteString(l.Domain) + sb.WriteString(" Key: ") + sb.WriteString(next.Key) + sb.WriteString(" LinkType: ") + sb.WriteString(next.LinkType) + sb.WriteString(" IsCompressed: ") + sb.WriteString(fmt.Sprintf("%v", next.IsCompressed)) + sb.WriteString(" Timeout: ") + sb.WriteString(next.Timeout.String()) + sb.WriteString(" Data: ") if next.IsCompressed { - resp += url.QueryEscape(next.Data) + "\n" + sb.WriteString("\n") } else { - resp += next.Data + "\n" - } - if next.NextClear != nil { - next = *next.NextClear - } else { - stop = true + sb.WriteString(url.QueryEscape(next.Data)) + sb.WriteByte('\n') } + next = next.NextClear } - l.Mutex.Unlock() - return + return sb.String() } func initTemplates() { @@ -305,7 +498,9 @@ func loadTemplate(templateName, defaultTmplStr string) { } templateMap[domain+"#"+templateName] = defaultTmpl } else { - logger.Println("Template key value: ", domain+"#"+templateName) + if logger != nil { + logger.Println("Template key value: ", domain+"#"+templateName) + } templateMap[domain+"#"+templateName] = tmpl } } diff --git a/qr.go b/qr.go new file mode 100644 index 0000000..32ed860 --- /dev/null +++ b/qr.go @@ -0,0 +1,59 @@ +package main + +import ( + "net/http" + "strings" + + qrcode "github.com/skip2/go-qrcode" +) + +// handleQR registers the /qr/ route, which returns a QR code PNG for a short link. +func handleQR(mux *http.ServeMux) { + mux.HandleFunc("/qr/", qrHandler) +} + +// qrHandler serves GET /qr/{key} — returns a 256×256 PNG QR code encoding the +// full short URL for the given key. The image is cached for 24 hours. +func qrHandler(w http.ResponseWriter, r *http.Request) { + if !validRequest(r) { + http.Error(w, errServerError, http.StatusInternalServerError) + return + } + if r.Method != http.MethodGet { + http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) + return + } + + key := strings.TrimPrefix(r.URL.Path, "/qr/") + key = strings.TrimSuffix(key, "/") + if len(key) == 0 || !validate(key) { + http.Error(w, errInvalidKey, http.StatusBadRequest) + return + } + + lnk, _ := lookupLink(r.Host, key) + if lnk == nil { + // Also check static links. + if _, ok := config.StaticLinks[key]; !ok { + http.Error(w, errInvalidKey, http.StatusNotFound) + return + } + } + + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + shortURL := scheme + "://" + r.Host + "/" + key + + png, err := qrcode.Encode(shortURL, qrcode.Medium, 256) + if err != nil { + http.Error(w, errServerError, http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "image/png") + w.Header().Set("Cache-Control", "max-age=86400, public") + w.Write(png) + logOK(r, http.StatusOK) +} diff --git a/ratelimit.go b/ratelimit.go new file mode 100644 index 0000000..73e4772 --- /dev/null +++ b/ratelimit.go @@ -0,0 +1,56 @@ +package main + +import ( + "net" + "sync" + "time" +) + +const ( + rateLimitMax = 30 // max POST requests per IP per window + rateLimitWindow = time.Minute // sliding window duration +) + +type ipRate struct { + count int + windowEnd time.Time +} + +var ( + rateMu sync.Mutex + rateClients = make(map[string]*ipRate) +) + +func init() { + go func() { + ticker := time.NewTicker(5 * time.Minute) + for range ticker.C { + now := time.Now() + rateMu.Lock() + for ip, r := range rateClients { + if now.After(r.windowEnd) { + delete(rateClients, ip) + } + } + rateMu.Unlock() + } + }() +} + +// rateLimitAllow returns true if the request is within the rate limit for the given remoteAddr. +func rateLimitAllow(remoteAddr string) bool { + ip, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + ip = remoteAddr + } + now := time.Now() + rateMu.Lock() + defer rateMu.Unlock() + cl, ok := rateClients[ip] + if !ok || now.After(cl.windowEnd) { + rateClients[ip] = &ipRate{count: 1, windowEnd: now.Add(rateLimitWindow)} + return true + } + cl.count++ + return cl.count <= rateLimitMax +} diff --git a/shorter.go b/shorter.go index 709d51e..132e5ac 100644 --- a/shorter.go +++ b/shorter.go @@ -5,7 +5,6 @@ import ( "encoding/hex" "flag" "fmt" - "io/ioutil" "log" "net/http" "os" @@ -21,7 +20,7 @@ func main() { var err error // accept if we specify the path to the config directly without a flag, e.g. shorter /path/to/config if len(os.Args) == 2 { - conf, err = ioutil.ReadFile(os.Args[1]) + conf, err = os.ReadFile(os.Args[1]) if err != nil { log.Fatalln("Invalid config file:\n", err) } @@ -30,11 +29,11 @@ func main() { var confFile string // confDir specifies the path to config file. flag.StringVar(&confFile, "config", filepath.Join(".", "config"), "path to the config file") flag.Parse() - conf, err = ioutil.ReadFile(confFile) + conf, err = os.ReadFile(confFile) if err != nil { configPath := findFolderDefaultLocations("shorterdata") if configPath != "" { - conf, err = ioutil.ReadFile(filepath.Join(configPath, "config")) + conf, err = os.ReadFile(filepath.Join(configPath, "config")) if err != nil { log.Fatalln("Invalid config file:\n", err) } @@ -81,8 +80,12 @@ func main() { logger = nil } else { defer f.Close() - // Write out server config on startup if logging is enabled - f.WriteString("Loaded config:\n" + fmt.Sprintf("%# v", pretty.Formatter(config)) + "\nLog Separator: " + logSep + "\n") + // Write out server config on startup if logging is enabled. + // Sensitive fields are redacted so the log file cannot be used to derive credentials. + logConfig := config + logConfig.Salt = "[REDACTED]" + logConfig.HashSHA256 = "[REDACTED]" + f.WriteString("Loaded config:\n" + fmt.Sprintf("%# v", pretty.Formatter(logConfig)) + "\nLog Separator: " + logSep + "\n") logger = log.New(f, logSep+"\n", log.LstdFlags) } } @@ -90,14 +93,9 @@ func main() { // init linkLen1, linkLen2, linkLen3 and fill each freeMap with all valid keys for each len. Defined in misc.go initLinkLens() - // TODO: find better solution, maybe waitgroup so all TimeoutManager have started before starting the server - time.Sleep(time.Millisecond * 500) + loadBlocklist() setupDB() - go BackupRoutine() - - // TODO: find better solution, maybe waitgroup - time.Sleep(time.Millisecond * 500) initTemplates() @@ -106,6 +104,8 @@ func main() { handleCSS(mux) // defined in handlers.go handleImages(mux) // defined in handlers.go handleRobots(mux) // defined in handlers.go + handleQR(mux) // defined in qr.go + handleAPI(mux) // defined in api.go handleRoot(mux) // defined in handlers.go // Start server @@ -114,7 +114,14 @@ func main() { } // if NoTLS is set only start a http server if config.NoTLS { - log.Fatalln(http.ListenAndServe(config.AddressPort, mux)) + srv := &http.Server{ + Addr: config.AddressPort, + Handler: mux, + ReadTimeout: 15 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, + } + log.Fatalln(srv.ListenAndServe()) } server := getServer(mux) // defined in letsencrypt.go // Using LetsEncrypt, no premade cert and key files needed diff --git a/types.go b/types.go index 99618d2..846cb3e 100644 --- a/types.go +++ b/types.go @@ -55,6 +55,10 @@ type Config struct { Email string `yaml:"Email"` // StaticLinks contains a list of static keys that will no time out StaticLinks map[string]string `yaml:"StaticLinks"` + // BlockedDomains lists hostnames (and their subdomains) that may not be shortened + BlockedDomains []string `yaml:"BlockedDomains"` + // BlocklistFile is an optional path to a newline-delimited file of blocked domains + BlocklistFile string `yaml:"BlocklistFile"` // Salt is used as the Salt for the password for special requests Salt string `yaml:"Salt"` // HashSHA256 is the sha256 hash of the password and Salt used for special requests @@ -74,6 +78,7 @@ type Link struct { Data string `json:"Data"` IsCompressed bool `json:"IsCompressed"` Times int `json:"Times"` + AccessCount int64 `json:"AccessCount"` Timeout time.Time `json:"Timeout"` NextClear *Link `json:"NextClear"` } @@ -87,6 +92,7 @@ type LinkLen struct { EndClear *Link `json:"EndClear"` // last element in linked list Timeout time.Duration `json:"Timeout"` Domain string `json:"Domain"` + Type string `json:"Type"` // bbolt bucket name: "len1","len2","len3","custom" } type LinkLens struct { @@ -118,11 +124,18 @@ func (l *LinkLen) Add(lnk *Link) (key string, err error) { isCustomLink := false if l.FreeMap == nil { if len(lnk.Key) < 4 || len(lnk.Key) >= maxKeyLen || !validate(lnk.Key) { - logger.Println("AddKey: invalid parameter key, key can only be > 4 or < " + strconv.Itoa(maxKeyLen)) + if logger != nil { + logger.Println("AddKey: invalid key length or charset, key:", url.QueryEscape(lnk.Key)) + } return "", errors.New("Error: key can only be of length > 4 and < " + strconv.Itoa(maxKeyLen) + " and only use the following characters:\n" + customKeyCharset) } isCustomLink = true key = lnk.Key + // Authoritative duplicate check under the write lock, preventing the data race + // that would occur if callers checked LinkMap before acquiring the lock. + if _, exists := l.LinkMap[key]; exists { + return "", errors.New(errInvalidKeyUsed) + } } // Formatted output for the log @@ -208,6 +221,12 @@ func (l *LinkLen) Add(lnk *Link) (key string, err error) { logstr = append(logstr, "\n Added key:"+url.QueryEscape(key)+"\n l.NextClear.Key: "+url.QueryEscape(l.NextClear.Key)+"\n l.EndClear.Key: "+url.QueryEscape(l.EndClear.Key)) logger.Println(strings.Join(logstr, "")) } + // Snapshot lnk while still holding the write lock so the goroutine + // doesn't race with a concurrent recordAccess or Add() on lnk's fields. + lnkSnap := *lnk + lnkSnap.NextClear = nil + domain, typ := l.Domain, l.Type + go saveLinkToDB(domain, typ, &lnkSnap) return key, nil } @@ -250,20 +269,19 @@ func (l *LinkLen) TimeoutManager() { } delete(l.LinkMap, keyToClear) if l.FreeMap != nil { - // Links of specific length l.FreeMap[keyToClear] = true if logger != nil { logger.Println("Finished clearing nextClear of length:", len(keyToClear), "\ncurrently using:", len(l.LinkMap), "keys\ncurrent free keys:", len(l.FreeMap)) } } else { - // Custom links l.Links-- if logger != nil { logger.Println("Finished clearing nextClear for custom link\ncurrently using:", l.Links, "keys\ncurrent free keys:", config.MaxCustomLinks-l.Links) } } - + domain, typ := l.Domain, l.Type l.Mutex.Unlock() + go deleteLinkFromDB(domain, typ, keyToClear) l.Mutex.RLock() } l.Mutex.RUnlock()