diff --git a/README.md b/README.md
index 12fcf42..ef646ea 100644
--- a/README.md
+++ b/README.md
@@ -211,6 +211,7 @@
 Build the binary, place it and the config, then install the service:
 
 ```bash
 go build -o /usr/local/bin/helm ./cmd/helm
+install -d /etc/helm
 cp config.yml /etc/helm/config.yml
 ```
@@ -225,6 +226,7 @@ ExecStart=/usr/local/bin/helm /etc/helm/config.yml
 WorkingDirectory=/var/lib/helm
 User=helm
 Group=helm
+Environment=TZ=UTC
 Restart=on-failure
 RestartSec=5s
 
@@ -232,6 +234,8 @@ RestartSec=5s
 WantedBy=multi-user.target
 ```
 
+> `TZ` controls the timezone Helm uses for reminder/recurrence scheduling. Change `UTC` to your local tz (e.g. `America/New_York`) so scheduled events fire at the expected wall-clock time. `tzdata` must be installed on the host.
+
 ```bash
 useradd -r -s /sbin/nologin helm
 mkdir -p /var/lib/helm/data
diff --git a/cmd/helm/main.go b/cmd/helm/main.go
index bdb0181..04c8f34 100644
--- a/cmd/helm/main.go
+++ b/cmd/helm/main.go
@@ -10,6 +10,7 @@ import (
 	"net/http"
 	"os"
 	"os/signal"
+	"sync"
 	"syscall"
 	"time"
 
@@ -34,6 +35,8 @@ func main() {
 		log.Fatalf("config: %v", err)
 	}
 
+	log.Printf("timezone: %s (set TZ env var to override)", time.Local.String())
+
 	database, err := db.Open(cfg.Storage.DBPath)
 	if err != nil {
 		log.Fatalf("database: %v", err)
@@ -48,14 +51,17 @@
 		log.Fatalf("attachments dir: %v", err)
 	}
 
+	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
+	defer stop()
+
 	b := broker.New()
-	stopReminders := reminder.StartScheduler(database, b)
+	stopReminders := reminder.StartScheduler(ctx, database, b)
 	defer stopReminders()
 
-	stopRecurrence := recurrence.StartScheduler(database)
+	stopRecurrence := recurrence.StartScheduler(ctx, database)
 	defer stopRecurrence()
 
-	stopCalDAV := startCalDAVScheduler(database, cfg.Auth.Secret)
+	stopCalDAV := startCalDAVScheduler(ctx, database, cfg.Auth.Secret)
 	defer stopCalDAV()
 
 	var uiFS fs.FS
@@ -70,9 +76,6 @@ func main() {
 	addr := fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)
 	srv := &http.Server{Addr: addr, Handler: router}
 
-	ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
-	defer stop()
-
 	go func() {
 		log.Printf("helm listening on http://%s", addr)
 		if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
@@ -91,13 +94,14 @@ func main() {
 }
 
 // startCalDAVScheduler syncs all non-local calendar sources every 15 minutes.
-// Returns a cancel function.
-func startCalDAVScheduler(database *sql.DB, secret string) func() {
-	ticker := time.NewTicker(15 * time.Minute)
-	done := make(chan struct{})
+// The scheduler stops when parent ctx is cancelled or when the returned stop function is
+// invoked. Stop blocks until the ticker goroutine and any in-flight sync goroutines return.
+func startCalDAVScheduler(parent context.Context, database *sql.DB, secret string) func() {
+	ctx, cancel := context.WithCancel(parent)
+	var wg sync.WaitGroup
 
 	syncAll := func() {
-		rows, err := database.Query(
+		rows, err := database.QueryContext(ctx,
 			`SELECT id, name, url, username, password_enc FROM calendar_sources WHERE is_local = 0`,
 		)
 		if err != nil {
@@ -116,25 +120,36 @@
 			src.Username = username.String
 			src.PasswordEnc = passwordEnc.String
 
+			wg.Add(1)
 			go func(s caldav.CalendarSource) {
+				defer wg.Done()
 				if err := caldav.SyncSource(database, s, secret); err != nil {
 					log.Printf("caldav scheduler: source %d: %v", s.ID, err)
 				}
 			}(src)
 		}
+		if err := rows.Err(); err != nil {
+			log.Printf("caldav scheduler: iterate sources: %v", err)
+		}
 	}
 
+	wg.Add(1)
 	go func() {
+		defer wg.Done()
+		ticker := time.NewTicker(15 * time.Minute)
+		defer ticker.Stop()
 		for {
 			select {
 			case <-ticker.C:
 				syncAll()
-			case <-done:
-				ticker.Stop()
+			case <-ctx.Done():
 				return
 			}
 		}
 	}()
 
-	return func() { close(done) }
+	return func() {
+		cancel()
+		wg.Wait()
+	}
 }
diff --git a/internal/api/handlers/attachments.go b/internal/api/handlers/attachments.go
index 30c15a3..7f2d181 100644
--- a/internal/api/handlers/attachments.go
+++ b/internal/api/handlers/attachments.go
@@ -164,6 +164,10 @@ func ListAttachments(db *sql.DB) http.HandlerFunc {
 			}
 			attachments = append(attachments, a)
 		}
+		if err := rows.Err(); err != nil {
+			respondError(w, http.StatusInternalServerError, "row iteration failed")
+			return
+		}
 		respond(w, http.StatusOK, attachments)
 	}
 }
diff --git a/internal/api/handlers/bookmarks.go b/internal/api/handlers/bookmarks.go
index ec0353b..c1b50cf 100644
--- a/internal/api/handlers/bookmarks.go
+++ b/internal/api/handlers/bookmarks.go
@@ -54,6 +54,10 @@ func ListBookmarkCollections(db *sql.DB) http.HandlerFunc {
 			}
 			collections = append(collections, c)
 		}
+		if err := rows.Err(); err != nil {
+			respondError(w, http.StatusInternalServerError, "row iteration failed")
+			return
+		}
 		respond(w, http.StatusOK, collections)
 	}
 }
@@ -142,6 +146,10 @@ func ListBookmarks(db *sql.DB) http.HandlerFunc {
 			bookmarks = append(bookmarks, bm)
 			ids = append(ids, bm.ID)
 		}
+		if err := rows.Err(); err != nil {
+			respondError(w, http.StatusInternalServerError, "row iteration failed")
+			return
+		}
 		tagMap := batchGetEntityTags(db, "bookmark", ids)
 		for i := range bookmarks {
 			if tags, ok := tagMap[bookmarks[i].ID]; ok {
diff --git a/internal/api/handlers/calendar.go b/internal/api/handlers/calendar.go
index bbfdfdf..ecc8999 100644
--- a/internal/api/handlers/calendar.go
+++ b/internal/api/handlers/calendar.go
@@ -65,6 +65,10 @@ func ListCalendarSources(db *sql.DB) http.HandlerFunc {
 			}
 			sources = append(sources, s)
 		}
+		if err := rows.Err(); err != nil {
+			respondError(w, http.StatusInternalServerError, "row iteration failed")
+			return
+		}
 		respond(w, http.StatusOK, sources)
 	}
 }
@@ -221,6 +225,10 @@ func ListCalendarEvents(db *sql.DB) http.HandlerFunc {
 			}
 			events = append(events, e)
 		}
+		if err := rows.Err(); err != nil {
+			respondError(w, http.StatusInternalServerError, "row iteration failed")
+			return
+		}
 		respond(w, http.StatusOK, events)
 	}
 }
diff --git a/internal/api/handlers/clipboard.go b/internal/api/handlers/clipboard.go
index c71eca2..a80ee81 100644
--- a/internal/api/handlers/clipboard.go
+++ b/internal/api/handlers/clipboard.go
@@ -60,6 +60,10 @@ func ListClipboardItems(db *sql.DB) http.HandlerFunc {
 			items = append(items, item)
 			ids = append(ids, item.ID)
 		}
+		if err := rows.Err(); err != nil {
+			respondError(w, http.StatusInternalServerError, "row iteration failed")
+			return
+		}
 		tagMap := batchGetEntityTags(db, "clipboard", ids)
 		for i := range items {
 			if tags, ok := tagMap[items[i].ID]; ok {
diff --git a/internal/api/handlers/memos.go b/internal/api/handlers/memos.go
index 8e79971..efc106c 100644
--- a/internal/api/handlers/memos.go
+++ b/internal/api/handlers/memos.go
@@ -63,6 +63,10 @@ func ListMemos(db *sql.DB) http.HandlerFunc {
 			memos = append(memos, m)
 			ids = append(ids, m.ID)
 		}
+		if err := rows.Err(); err != nil {
+			respondError(w, http.StatusInternalServerError, "row iteration failed")
+			return
+		}
 		tagMap := batchGetEntityTags(db, "memo", ids)
 		for i := range memos {
 			if tags, ok := tagMap[memos[i].ID]; ok {
diff --git a/internal/api/handlers/notes.go b/internal/api/handlers/notes.go
index ea0f317..4458f70 100644
--- a/internal/api/handlers/notes.go
+++ b/internal/api/handlers/notes.go
@@ -52,6 +52,10 @@ func ListNoteFolders(db *sql.DB) http.HandlerFunc {
 			}
 			folders = append(folders, f)
 		}
+		if err := rows.Err(); err != nil {
+			respondError(w, http.StatusInternalServerError, "row iteration failed")
+			return
+		}
 		respond(w, http.StatusOK, folders)
 	}
 }
@@ -140,6 +144,10 @@ func ListNotes(db *sql.DB) http.HandlerFunc {
 			notes = append(notes, n)
 			ids = append(ids, n.ID)
 		}
+		if err := rows.Err(); err != nil {
+			respondError(w, http.StatusInternalServerError, "row iteration failed")
+			return
+		}
 		tagMap := batchGetEntityTags(db, "note", ids)
 		for i := range notes {
 			if tags, ok := tagMap[notes[i].ID]; ok {
diff --git a/internal/api/handlers/reminders.go b/internal/api/handlers/reminders.go
index 7b59424..1119f4e 100644
--- a/internal/api/handlers/reminders.go
+++ b/internal/api/handlers/reminders.go
@@ -40,6 +40,10 @@ func ListReminders(db *sql.DB) http.HandlerFunc {
 			}
 			reminders = append(reminders, rem)
 		}
+		if err := rows.Err(); err != nil {
+			respondError(w, http.StatusInternalServerError, "row iteration failed")
+			return
+		}
 		respond(w, http.StatusOK, reminders)
 	}
 }
diff --git a/internal/api/handlers/tags.go b/internal/api/handlers/tags.go
index 3669afb..8dfb629 100644
--- a/internal/api/handlers/tags.go
+++ b/internal/api/handlers/tags.go
@@ -65,6 +65,10 @@ func ListTags(db *sql.DB) http.HandlerFunc {
 			}
 			tags = append(tags, t)
 		}
+		if err := rows.Err(); err != nil {
+			respondError(w, http.StatusInternalServerError, "row iteration failed")
+			return
+		}
 		respond(w, http.StatusOK, tags)
 	}
 }
diff --git a/internal/api/handlers/todos.go b/internal/api/handlers/todos.go
index a019d10..fde489b 100644
--- a/internal/api/handlers/todos.go
+++ b/internal/api/handlers/todos.go
@@ -77,6 +77,10 @@ func ListTodoLists(db *sql.DB) http.HandlerFunc {
 			}
 			lists = append(lists, l)
 		}
+		if err := rows.Err(); err != nil {
+			respondError(w, http.StatusInternalServerError, "row iteration failed")
+			return
+		}
 		respond(w, http.StatusOK, lists)
 	}
 }
@@ -180,6 +184,10 @@ func ListTodos(db *sql.DB) http.HandlerFunc {
 			todos = append(todos, t)
 			ids = append(ids, t.ID)
 		}
+		if err := rows.Err(); err != nil {
+			respondError(w, http.StatusInternalServerError, "row iteration failed")
+			return
+		}
 		tagMap := batchGetEntityTags(db, "todo", ids)
 		subtaskMap := batchGetSubtasks(db, ids)
 		for i := range todos {
diff --git a/internal/caldav/sync.go b/internal/caldav/sync.go
index d35af2d..8e65614 100644
--- a/internal/caldav/sync.go
+++ b/internal/caldav/sync.go
@@ -2,14 +2,26 @@ package caldav
 
 import (
 	"database/sql"
+	"errors"
 	"fmt"
 	"log"
+	"sync"
 	"time"
 
 	"github.com/lerko/helm/internal/crypto"
"github.com/lerko/helm/internal/httpclient" ) +// ErrSyncInProgress is returned when a concurrent sync is already running for the same source. +var ErrSyncInProgress = errors.New("sync already in progress for source") + +var syncLocks sync.Map // map[int64]*sync.Mutex + +func lockFor(id int64) *sync.Mutex { + v, _ := syncLocks.LoadOrStore(id, &sync.Mutex{}) + return v.(*sync.Mutex) +} + // validateCalDAVURL defers to the shared SSRF policy: https-only + no // private/loopback destinations. func validateCalDAVURL(rawURL string) error { @@ -27,7 +39,14 @@ type CalendarSource struct { // SyncSource fetches events from a remote CalDAV source and upserts them into the DB. // It skips events whose etag is unchanged, and deletes DB events not present in the remote response. +// Returns ErrSyncInProgress if a concurrent sync for the same source is already running. func SyncSource(db *sql.DB, source CalendarSource, secret string) error { + mu := lockFor(source.ID) + if !mu.TryLock() { + return ErrSyncInProgress + } + defer mu.Unlock() + if err := validateCalDAVURL(source.URL); err != nil { return fmt.Errorf("source %d URL rejected: %w", source.ID, err) } @@ -154,6 +173,9 @@ func deleteStaleEvents(db *sql.DB, sourceID int64, keepUIDs map[string]struct{}) toDelete = append(toDelete, uid) } } + if err := rows.Err(); err != nil { + return err + } rows.Close() for _, uid := range toDelete { diff --git a/internal/caldav/sync_test.go b/internal/caldav/sync_test.go new file mode 100644 index 0000000..22ee2ce --- /dev/null +++ b/internal/caldav/sync_test.go @@ -0,0 +1,54 @@ +package caldav + +import ( + "errors" + "strings" + "testing" +) + +func TestSyncSource_ConcurrentCallsReturnErrSyncInProgress(t *testing.T) { + // Acquire the per-source lock manually to simulate an in-flight sync. + mu := lockFor(999) + mu.Lock() + defer mu.Unlock() + + src := CalendarSource{ + ID: 999, + Name: "concurrent-test", + URL: "https://example.invalid/", + } + + err := SyncSource(nil, src, "secret") + if !errors.Is(err, ErrSyncInProgress) { + t.Fatalf("expected ErrSyncInProgress, got %v", err) + } +} + +func TestSyncSource_PrivateIPRejected(t *testing.T) { + // Fresh source ID so the lock is free and the URL check runs. + src := CalendarSource{ + ID: 1001, + Name: "ssrf-test", + URL: "http://127.0.0.1/", + } + + err := SyncSource(nil, src, "secret") + if err == nil { + t.Fatal("expected SSRF rejection, got nil") + } + if !strings.Contains(err.Error(), "rejected") { + t.Errorf("expected rejection error, got: %v", err) + } +} + +func TestSyncSource_MetadataEndpointRejected(t *testing.T) { + src := CalendarSource{ + ID: 1002, + Name: "aws-metadata", + URL: "https://169.254.169.254/latest/meta-data/", + } + err := SyncSource(nil, src, "secret") + if err == nil { + t.Fatal("expected SSRF rejection, got nil") + } +} diff --git a/internal/recurrence/scheduler.go b/internal/recurrence/scheduler.go index e704a09..3593e05 100644 --- a/internal/recurrence/scheduler.go +++ b/internal/recurrence/scheduler.go @@ -8,10 +8,13 @@ import ( ) // StartScheduler polls for due recurrences every hour and spawns new todo copies. -// Returns a cancel function that stops the scheduler. -func StartScheduler(db *sql.DB) func() { - ctx, cancel := context.WithCancel(context.Background()) +// The scheduler stops when parent is cancelled or when the returned stop function is invoked. +// The stop function blocks until the goroutine returns. 
+func StartScheduler(parent context.Context, db *sql.DB) func() {
+	ctx, cancel := context.WithCancel(parent)
+	done := make(chan struct{})
 	go func() {
+		defer close(done)
 		ticker := time.NewTicker(1 * time.Hour)
 		defer ticker.Stop()
 		for {
@@ -19,15 +22,18 @@
 			case <-ctx.Done():
 				return
 			case <-ticker.C:
-				spawnDue(db)
+				spawnDue(ctx, db)
 			}
 		}
 	}()
-	return cancel
+	return func() {
+		cancel()
+		<-done
+	}
 }
 
-func spawnDue(db *sql.DB) {
-	rows, err := db.QueryContext(context.Background(), `
+func spawnDue(ctx context.Context, db *sql.DB) {
+	rows, err := db.QueryContext(ctx, `
 		SELECT tr.id, tr.todo_id, tr.rrule, tr.next_occurrence
 		FROM todo_recurrences tr
 		WHERE tr.next_occurrence <= CURRENT_TIMESTAMP
@@ -55,20 +61,18 @@
 	rows.Close()
 
 	for _, r := range due {
-		if err := spawnTodo(db, r.id, r.todoID, r.rrule, r.nextOccurrence); err != nil {
+		if err := spawnTodo(ctx, db, r.id, r.todoID, r.rrule, r.nextOccurrence); err != nil {
 			log.Printf("recurrence scheduler: spawn todo %d: %v", r.todoID, err)
 		}
 	}
 }
 
-func spawnTodo(db *sql.DB, recurrenceID, parentID int64, rrule string, nextOcc time.Time) error {
+func spawnTodo(ctx context.Context, db *sql.DB, recurrenceID, parentID int64, rrule string, nextOcc time.Time) error {
 	freq, interval, err := ParseRRule(rrule)
 	if err != nil {
 		return err
 	}
 
-	ctx := context.Background()
-
 	var (
 		listID sql.NullInt64
 		title  string
diff --git a/internal/reminder/scheduler.go b/internal/reminder/scheduler.go
index 4dc4d2e..d42bb2c 100644
--- a/internal/reminder/scheduler.go
+++ b/internal/reminder/scheduler.go
@@ -18,11 +18,14 @@ type payload struct {
 }
 
 // StartScheduler polls for due reminders every 30s and publishes them to the broker.
-// Returns a cancel function that stops the scheduler.
-func StartScheduler(db *sql.DB, b *broker.Broker) func() {
-	ctx, cancel := context.WithCancel(context.Background())
+// The scheduler stops when parent is cancelled or when the returned stop function is invoked.
+// The stop function blocks until the goroutine returns.
+func StartScheduler(parent context.Context, db *sql.DB, b *broker.Broker) func() {
+	ctx, cancel := context.WithCancel(parent)
+	done := make(chan struct{})
 
 	go func() {
+		defer close(done)
 		ticker := time.NewTicker(30 * time.Second)
 		defer ticker.Stop()
 
@@ -31,17 +34,18 @@
 			case <-ctx.Done():
 				return
 			case <-ticker.C:
-				fire(db, b)
+				fire(ctx, db, b)
 			}
 		}
 	}()
 
-	return cancel
+	return func() {
+		cancel()
+		<-done
+	}
 }
 
-func fire(db *sql.DB, b *broker.Broker) {
-	ctx := context.Background()
-
+func fire(ctx context.Context, db *sql.DB, b *broker.Broker) {
 	tx, err := db.BeginTx(ctx, nil)
 	if err != nil {
 		log.Printf("reminder scheduler: begin tx: %v", err)
diff --git a/internal/reminder/scheduler_test.go b/internal/reminder/scheduler_test.go
new file mode 100644
index 0000000..35d5893
--- /dev/null
+++ b/internal/reminder/scheduler_test.go
@@ -0,0 +1,101 @@
+package reminder
+
+import (
+	"context"
+	"database/sql"
+	"testing"
+	"time"
+
+	"github.com/lerko/helm/internal/broker"
+	"github.com/lerko/helm/internal/db"
+)
+
+func openTestDB(t *testing.T) *sql.DB {
+	t.Helper()
+	tmp := t.TempDir() + "/test.db"
+	d, err := db.Open(tmp)
+	if err != nil {
+		t.Fatalf("open: %v", err)
+	}
+	if err := db.Migrate(d); err != nil {
+		t.Fatalf("migrate: %v", err)
+	}
+	t.Cleanup(func() { d.Close() })
+	return d
+}
+
+func TestStartScheduler_CancelStopsGoroutine(t *testing.T) {
+	d := openTestDB(t)
+	b := broker.New()
+
+	ctx, cancel := context.WithCancel(context.Background())
+	stop := StartScheduler(ctx, d, b)
+
+	done := make(chan struct{})
+	go func() {
+		cancel()
+		stop()
+		close(done)
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(2 * time.Second):
+		t.Fatal("scheduler did not stop within 2s after ctx cancel")
+	}
+}
+
+func TestStartScheduler_StopFuncAlone(t *testing.T) {
+	d := openTestDB(t)
+	b := broker.New()
+
+	stop := StartScheduler(context.Background(), d, b)
+	done := make(chan struct{})
+	go func() {
+		stop()
+		close(done)
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(2 * time.Second):
+		t.Fatal("stop func did not return within 2s")
+	}
+}
+
+func TestFire_MarksReminderSentAndPublishes(t *testing.T) {
+	d := openTestDB(t)
+	b := broker.New()
+
+	ch := b.Subscribe("test-client")
+	defer b.Unsubscribe("test-client")
+
+	// Insert a due reminder (5 minutes in the past).
+	past := time.Now().UTC().Add(-5 * time.Minute).Format("2006-01-02 15:04:05")
+	_, err := d.Exec(
+		`INSERT INTO reminders (user_id, entity_type, entity_id, remind_at, is_sent) VALUES (1, 'todo', 42, ?, 0)`,
+		past,
+	)
+	if err != nil {
+		t.Fatalf("insert: %v", err)
+	}
+
+	fire(context.Background(), d, b)
+
+	select {
+	case msg := <-ch:
+		if msg == "" {
+			t.Error("received empty message")
+		}
+	case <-time.After(1 * time.Second):
+		t.Fatal("expected publish after fire, got nothing")
+	}
+
+	var sent int
+	if err := d.QueryRow(`SELECT is_sent FROM reminders WHERE entity_id = 42`).Scan(&sent); err != nil {
+		t.Fatalf("query: %v", err)
+	}
+	if sent != 1 {
+		t.Errorf("is_sent = %d, want 1", sent)
+	}
+}
diff --git a/web/src/App.tsx b/web/src/App.tsx
index 472b016..8b31096 100644
--- a/web/src/App.tsx
+++ b/web/src/App.tsx
@@ -3,7 +3,7 @@ import { QueryClient, QueryClientProvider } from '@tanstack/react-query'
 import Shell, { type Page } from './components/layout/Shell'
 import LoginPage from './components/LoginPage'
 import { isAuthenticated, clearToken } from './lib/auth'
-import { apiFetch } from './lib/api'
+import { apiFetch, authEvents } from './lib/api'
 import { startSSE, type ReminderEvent, type MutationEvent } from './lib/sse'
 import MemosWidget from './components/widgets/MemosWidget'
 import TodosWidget from './components/widgets/TodosWidget'
@@ -99,6 +99,16 @@ export default function App() {
     }
   }, [])
 
+  useEffect(() => {
+    const handler = () => {
+      setAuthed(false)
+      setPages(null)
+      showBanner('SESSION EXPIRED — PLEASE LOG IN')
+    }
+    authEvents.addEventListener('unauth', handler)
+    return () => authEvents.removeEventListener('unauth', handler)
+  }, [showBanner])
+
   useEffect(() => {
     if (!authed) return
     apiFetch('/api/config/pages')
diff --git a/web/src/lib/api.ts b/web/src/lib/api.ts
index fb98de6..e25d31a 100644
--- a/web/src/lib/api.ts
+++ b/web/src/lib/api.ts
@@ -1,5 +1,7 @@
 import { clearToken, getToken } from './auth'
 
+export const authEvents = new EventTarget()
+
 export async function apiFetch(path: string, options: RequestInit = {}): Promise {
   const token = getToken()
   const res = await fetch(path, {
@@ -13,7 +15,7 @@ export async function apiFetch(path: string, options: RequestInit = {}): Prom
 
   if (res.status === 401) {
     clearToken()
-    window.location.reload()
+    authEvents.dispatchEvent(new Event('unauth'))
     throw new Error('unauthenticated')
   }
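
All three schedulers in this patch converge on the same lifecycle shape: the goroutine owns its ticker, the context derives from the signal-aware parent, and the returned stop function blocks until the goroutine has actually exited. A minimal standalone sketch of that shape; the names `startScheduler` and `doWork` are illustrative and not part of the Helm codebase:

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// startScheduler runs doWork on every tick until parent is cancelled or the
// returned stop function is called. The stop function blocks until the
// goroutine has exited, so callers know no tick is in flight after it returns.
func startScheduler(parent context.Context, every time.Duration, doWork func(context.Context)) func() {
	ctx, cancel := context.WithCancel(parent)
	done := make(chan struct{})
	go func() {
		defer close(done)
		ticker := time.NewTicker(every)
		defer ticker.Stop() // the goroutine owns the ticker, so stopping it here is race-free
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				doWork(ctx)
			}
		}
	}()
	return func() {
		cancel()
		<-done // block until the goroutine observes cancellation and returns
	}
}

func main() {
	stop := startScheduler(context.Background(), 10*time.Millisecond, func(ctx context.Context) {
		fmt.Println("tick")
	})
	time.Sleep(35 * time.Millisecond)
	stop() // after this returns, no further ticks can run
}
```

Because stop blocks, the `defer stopReminders()` / `defer stopRecurrence()` / `defer stopCalDAV()` lines in `main` guarantee that no scheduler goroutine is still running by the time `main` returns.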
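`SyncSource` now guards each calendar source with a lazily created per-ID mutex, and `TryLock` (available since Go 1.18) makes an overlapping sync fail fast with `ErrSyncInProgress` rather than queue behind the running one. A standalone sketch of the same guard under those assumptions; `withSourceLock`, `locks`, and `errBusy` are illustrative names, not the patch's identifiers:

```go
package main

import (
	"errors"
	"fmt"
	"sync"
	"time"
)

var errBusy = errors.New("sync already in progress for source")

// One mutex per source ID, created lazily. sync.Map makes concurrent first use
// of the same ID safe: LoadOrStore returns whichever mutex won the race.
var locks sync.Map // map[int64]*sync.Mutex

func lockFor(id int64) *sync.Mutex {
	v, _ := locks.LoadOrStore(id, &sync.Mutex{})
	return v.(*sync.Mutex)
}

// withSourceLock runs body for a source unless a sync for the same ID is
// already in flight, in which case it fails fast instead of queueing.
func withSourceLock(id int64, body func() error) error {
	mu := lockFor(id)
	if !mu.TryLock() { // TryLock requires Go 1.18+
		return errBusy
	}
	defer mu.Unlock()
	return body()
}

func main() {
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			err := withSourceLock(42, func() error {
				time.Sleep(50 * time.Millisecond) // simulate a slow sync
				return nil
			})
			fmt.Println(err) // the winner prints <nil>; overlapping calls print errBusy
		}()
	}
	wg.Wait()
}
```

Failing fast matters here because the CalDAV ticker fires every 15 minutes whether or not the previous round finished; queueing would let slow sources pile up goroutines instead of skipping a beat.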
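The `rows.Err()` checks added across the list handlers all exist for the same reason: `rows.Next()` returns false both when the result set is exhausted and when iteration fails mid-stream, so a loop that never consults `rows.Err()` can silently serve a truncated list with a 200 status. A minimal sketch of the pattern the handlers now follow; the `listNames` helper and `things` table are invented for illustration:

```go
package example

import (
	"context"
	"database/sql"
)

// listNames shows the iteration pattern the handlers use. Only rows.Err()
// distinguishes "no more rows" from "iteration failed partway through".
func listNames(ctx context.Context, db *sql.DB) ([]string, error) {
	rows, err := db.QueryContext(ctx, `SELECT name FROM things`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var names []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, err
		}
		names = append(names, name)
	}
	// Without this check, a connection dropped mid-iteration would surface
	// as a shorter-than-expected but otherwise "successful" result.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return names, nil
}
```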