From 4b2beeed2707ea90317c3da122e8fbb96f66d304 Mon Sep 17 00:00:00 2001 From: Weronika Kombat Date: Thu, 5 Mar 2026 09:36:05 +0000 Subject: [PATCH] feat: add category support for RSS articles - Add Categories field to Article model - Extract categories from RSS/Atom feeds using gofeed - Store categories as comma-separated in database - Add migration for existing databases - Add --category flag to articles command - Display categories in article list output --- internal/cli/commands.go | 6 +- internal/controller/controller.go | 16 +-- internal/controller/controller_test.go | 4 +- internal/model/model.go | 1 + internal/rss/rss.go | 2 + internal/scanner/scanner.go | 1 + internal/scanner/scanner_test.go | 2 +- internal/storage/database.go | 176 ++++++++++++++----------- internal/storage/database_test.go | 10 +- 9 files changed, 121 insertions(+), 97 deletions(-) diff --git a/internal/cli/commands.go b/internal/cli/commands.go index 3b8323b..2919159 100644 --- a/internal/cli/commands.go +++ b/internal/cli/commands.go @@ -195,6 +195,7 @@ func newScanCommand() *cobra.Command { func newArticlesCommand() *cobra.Command { var showAll bool var blogName string + var category string cmd := &cobra.Command{ Use: "articles", @@ -205,7 +206,7 @@ func newArticlesCommand() *cobra.Command { return err } defer db.Close() - articles, blogNames, err := controller.GetArticles(db, showAll, blogName) + articles, blogNames, err := controller.GetArticles(db, showAll, blogName, category) if err != nil { printError(err) return markError(err) @@ -233,6 +234,7 @@ func newArticlesCommand() *cobra.Command { cmd.Flags().BoolVarP(&showAll, "all", "a", false, "Show all articles (including read)") cmd.Flags().StringVarP(&blogName, "blog", "b", "", "Filter by blog name") + cmd.Flags().StringVarP(&category, "category", "c", "", "Filter by category") return cmd } @@ -281,7 +283,7 @@ func newReadAllCommand() *cobra.Command { } defer db.Close() - articles, blogNames, err := controller.GetArticles(db, false, blogName) + articles, blogNames, err := controller.GetArticles(db, false, blogName, "") if err != nil { printError(err) return markError(err) diff --git a/internal/controller/controller.go b/internal/controller/controller.go index 6a22739..f04bb9d 100644 --- a/internal/controller/controller.go +++ b/internal/controller/controller.go @@ -43,7 +43,6 @@ func AddBlog(db *storage.Database, name string, url string, feedURL string, scra } else if existing != nil { return model.Blog{}, BlogAlreadyExistsError{Field: "URL", Value: url} } - blog := model.Blog{ Name: name, URL: url, @@ -65,7 +64,7 @@ func RemoveBlog(db *storage.Database, name string) error { return err } -func GetArticles(db *storage.Database, showAll bool, blogName string) ([]model.Article, map[int64]string, error) { +func GetArticles(db *storage.Database, showAll bool, blogName string, category string) ([]model.Article, map[int64]string, error) { var blogID *int64 if blogName != "" { blog, err := db.GetBlogByName(blogName) @@ -77,8 +76,11 @@ func GetArticles(db *storage.Database, showAll bool, blogName string) ([]model.A } blogID = &blog.ID } - - articles, err := db.ListArticles(!showAll, blogID) + var categoryPtr *string + if category != "" { + categoryPtr = &category + } + articles, err := db.ListArticles(!showAll, blogID, categoryPtr) if err != nil { return nil, nil, err } @@ -90,7 +92,6 @@ func GetArticles(db *storage.Database, showAll bool, blogName string) ([]model.A for _, blog := range blogs { blogNames[blog.ID] = blog.Name } - return articles, blogNames, nil } @@ -123,19 +124,16 @@ func MarkAllArticlesRead(db *storage.Database, blogName string) ([]model.Article } blogID = &blog.ID } - - articles, err := db.ListArticles(true, blogID) + articles, err := db.ListArticles(true, blogID, nil) if err != nil { return nil, err } - for _, article := range articles { _, err := db.MarkArticleRead(article.ID) if err != nil { return nil, err } } - return articles, nil } diff --git a/internal/controller/controller_test.go b/internal/controller/controller_test.go index d3d368b..e78e8ea 100644 --- a/internal/controller/controller_test.go +++ b/internal/controller/controller_test.go @@ -73,7 +73,7 @@ func TestGetArticlesFilters(t *testing.T) { t.Fatalf("add article: %v", err) } - articles, blogNames, err := GetArticles(db, false, "") + articles, blogNames, err := GetArticles(db, false, "", "") if err != nil { t.Fatalf("get articles: %v", err) } @@ -84,7 +84,7 @@ func TestGetArticlesFilters(t *testing.T) { t.Fatalf("expected blog name") } - if _, _, err := GetArticles(db, false, "Missing"); err == nil { + if _, _, err := GetArticles(db, false, "Missing", ""); err == nil { t.Fatalf("expected blog not found error") } } diff --git a/internal/model/model.go b/internal/model/model.go index dd0a2d8..07d0fee 100644 --- a/internal/model/model.go +++ b/internal/model/model.go @@ -19,4 +19,5 @@ type Article struct { PublishedDate *time.Time DiscoveredDate *time.Time IsRead bool + Categories []string } diff --git a/internal/rss/rss.go b/internal/rss/rss.go index a162811..3a4caf0 100644 --- a/internal/rss/rss.go +++ b/internal/rss/rss.go @@ -16,6 +16,7 @@ type FeedArticle struct { Title string URL string PublishedDate *time.Time + Categories []string } type FeedParseError struct { @@ -54,6 +55,7 @@ func ParseFeed(feedURL string, timeout time.Duration) ([]FeedArticle, error) { Title: title, URL: link, PublishedDate: pickPublishedDate(item), + Categories: item.Categories, }) } diff --git a/internal/scanner/scanner.go b/internal/scanner/scanner.go index e5aaeb1..68de08b 100644 --- a/internal/scanner/scanner.go +++ b/internal/scanner/scanner.go @@ -181,6 +181,7 @@ func convertFeedArticles(blogID int64, articles []rss.FeedArticle) []model.Artic URL: article.URL, PublishedDate: article.PublishedDate, IsRead: false, + Categories: article.Categories, }) } return result diff --git a/internal/scanner/scanner_test.go b/internal/scanner/scanner_test.go index 376a9a2..f23badf 100644 --- a/internal/scanner/scanner_test.go +++ b/internal/scanner/scanner_test.go @@ -49,7 +49,7 @@ func TestScanBlogRSS(t *testing.T) { t.Fatalf("expected rss source, got %s", result.Source) } - articles, err := db.ListArticles(false, nil) + articles, err := db.ListArticles(false, nil, nil) if err != nil { t.Fatalf("list articles: %v", err) } diff --git a/internal/storage/database.go b/internal/storage/database.go index 07573e6..ff84233 100644 --- a/internal/storage/database.go +++ b/internal/storage/database.go @@ -10,7 +10,6 @@ import ( "time" _ "modernc.org/sqlite" - "github.com/Hyaxia/blogwatcher/internal/model" ) @@ -37,7 +36,6 @@ func OpenDatabase(path string) (*Database, error) { return nil, err } } - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { return nil, err } @@ -47,12 +45,15 @@ func OpenDatabase(path string) (*Database, error) { if err != nil { return nil, err } - db := &Database{path: path, conn: conn} if err := db.init(); err != nil { _ = conn.Close() return nil, err } + if err := db.migrate(); err != nil { + _ = conn.Close() + return nil, err + } return db, nil } @@ -69,37 +70,68 @@ func (db *Database) Close() error { func (db *Database) init() error { schema := ` - CREATE TABLE IF NOT EXISTS blogs ( - id INTEGER PRIMARY KEY, - name TEXT NOT NULL, - url TEXT NOT NULL UNIQUE, - feed_url TEXT, - scrape_selector TEXT, - last_scanned TIMESTAMP - ); - CREATE TABLE IF NOT EXISTS articles ( - id INTEGER PRIMARY KEY, - blog_id INTEGER NOT NULL, - title TEXT NOT NULL, - url TEXT NOT NULL UNIQUE, - published_date TIMESTAMP, - discovered_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - is_read BOOLEAN DEFAULT FALSE, - FOREIGN KEY (blog_id) REFERENCES blogs(id) - ); + CREATE TABLE IF NOT EXISTS blogs ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + url TEXT NOT NULL UNIQUE, + feed_url TEXT, + scrape_selector TEXT, + last_scanned TIMESTAMP + ); + CREATE TABLE IF NOT EXISTS articles ( + id INTEGER PRIMARY KEY, + blog_id INTEGER NOT NULL, + title TEXT NOT NULL, + url TEXT NOT NULL UNIQUE, + published_date TIMESTAMP, + discovered_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + is_read BOOLEAN DEFAULT FALSE, + categories TEXT, + FOREIGN KEY (blog_id) REFERENCES blogs(id) + ); ` _, err := db.conn.Exec(schema) return err } +// migrate adds new columns for existing databases +func (db *Database) migrate() error { + // Check if categories column exists + var count int + err := db.conn.QueryRow( + "SELECT COUNT(*) FROM pragma_table_info('articles') WHERE name = 'categories'", + ).Scan(&count) + if err != nil { + return err + } + if count == 0 { + _, err = db.conn.Exec("ALTER TABLE articles ADD COLUMN categories TEXT") + if err != nil { + return err + } + } + return nil +} + +func categoriesToString(categories []string) *string { + if len(categories) == 0 { + return nil + } + s := strings.Join(categories, ",") + return &s +} + +func categoriesFromString(s *string) []string { + if s == nil || *s == "" { + return nil + } + return strings.Split(*s, ",") +} + func (db *Database) AddBlog(blog model.Blog) (model.Blog, error) { result, err := db.conn.Exec( - `INSERT INTO blogs (name, url, feed_url, scrape_selector, last_scanned) - VALUES (?, ?, ?, ?, ?)`, - blog.Name, - blog.URL, - nullIfEmpty(blog.FeedURL), - nullIfEmpty(blog.ScrapeSelector), + `INSERT INTO blogs (name, url, feed_url, scrape_selector, last_scanned) VALUES (?, ?, ?, ?, ?)`, + blog.Name, blog.URL, nullIfEmpty(blog.FeedURL), nullIfEmpty(blog.ScrapeSelector), formatTimePtr(blog.LastScanned), ) if err != nil { @@ -134,7 +166,6 @@ func (db *Database) ListBlogs() ([]model.Blog, error) { return nil, err } defer rows.Close() - var blogs []model.Blog for rows.Next() { blog, err := scanBlog(rows) @@ -151,12 +182,8 @@ func (db *Database) ListBlogs() ([]model.Blog, error) { func (db *Database) UpdateBlog(blog model.Blog) error { _, err := db.conn.Exec( `UPDATE blogs SET name = ?, url = ?, feed_url = ?, scrape_selector = ?, last_scanned = ? WHERE id = ?`, - blog.Name, - blog.URL, - nullIfEmpty(blog.FeedURL), - nullIfEmpty(blog.ScrapeSelector), - formatTimePtr(blog.LastScanned), - blog.ID, + blog.Name, blog.URL, nullIfEmpty(blog.FeedURL), nullIfEmpty(blog.ScrapeSelector), + formatTimePtr(blog.LastScanned), blog.ID, ) return err } @@ -184,14 +211,10 @@ func (db *Database) RemoveBlog(id int64) (bool, error) { func (db *Database) AddArticle(article model.Article) (model.Article, error) { result, err := db.conn.Exec( - `INSERT INTO articles (blog_id, title, url, published_date, discovered_date, is_read) - VALUES (?, ?, ?, ?, ?, ?)`, - article.BlogID, - article.Title, - article.URL, - formatTimePtr(article.PublishedDate), - formatTimePtr(article.DiscoveredDate), - article.IsRead, + `INSERT INTO articles (blog_id, title, url, published_date, discovered_date, is_read, categories) VALUES (?, ?, ?, ?, ?, ?, ?)`, + article.BlogID, article.Title, article.URL, + formatTimePtr(article.PublishedDate), formatTimePtr(article.DiscoveredDate), + article.IsRead, categoriesToString(article.Categories), ) if err != nil { return article, err @@ -212,7 +235,7 @@ func (db *Database) AddArticlesBulk(articles []model.Article) (int, error) { if err != nil { return 0, err } - stmt, err := _tx.Prepare(`INSERT INTO articles (blog_id, title, url, published_date, discovered_date, is_read) VALUES (?, ?, ?, ?, ?, ?)`) + stmt, err := _tx.Prepare(`INSERT INTO articles (blog_id, title, url, published_date, discovered_date, is_read, categories) VALUES (?, ?, ?, ?, ?, ?, ?)`) if err != nil { _ = _tx.Rollback() return 0, err @@ -221,12 +244,9 @@ func (db *Database) AddArticlesBulk(articles []model.Article) (int, error) { for _, article := range articles { _, err := stmt.Exec( - article.BlogID, - article.Title, - article.URL, - formatTimePtr(article.PublishedDate), - formatTimePtr(article.DiscoveredDate), - article.IsRead, + article.BlogID, article.Title, article.URL, + formatTimePtr(article.PublishedDate), formatTimePtr(article.DiscoveredDate), + article.IsRead, categoriesToString(article.Categories), ) if err != nil { _ = _tx.Rollback() @@ -240,12 +260,12 @@ func (db *Database) AddArticlesBulk(articles []model.Article) (int, error) { } func (db *Database) GetArticle(id int64) (*model.Article, error) { - row := db.conn.QueryRow(`SELECT id, blog_id, title, url, published_date, discovered_date, is_read FROM articles WHERE id = ?`, id) + row := db.conn.QueryRow(`SELECT id, blog_id, title, url, published_date, discovered_date, is_read, categories FROM articles WHERE id = ?`, id) return scanArticle(row) } func (db *Database) GetArticleByURL(url string) (*model.Article, error) { - row := db.conn.QueryRow(`SELECT id, blog_id, title, url, published_date, discovered_date, is_read FROM articles WHERE url = ?`, url) + row := db.conn.QueryRow(`SELECT id, blog_id, title, url, published_date, discovered_date, is_read, categories FROM articles WHERE url = ?`, url) return scanArticle(row) } @@ -267,7 +287,6 @@ func (db *Database) GetExistingArticleURLs(urls []string) (map[string]struct{}, if len(urls) == 0 { return result, nil } - chunkSize := 900 for start := 0; start < len(urls); start += chunkSize { end := start + chunkSize @@ -298,8 +317,8 @@ func (db *Database) GetExistingArticleURLs(urls []string) (map[string]struct{}, return result, nil } -func (db *Database) ListArticles(unreadOnly bool, blogID *int64) ([]model.Article, error) { - query := `SELECT id, blog_id, title, url, published_date, discovered_date, is_read FROM articles WHERE 1=1` +func (db *Database) ListArticles(unreadOnly bool, blogID *int64, category *string) ([]model.Article, error) { + query := `SELECT id, blog_id, title, url, published_date, discovered_date, is_read, categories FROM articles WHERE 1=1` var args []interface{} if unreadOnly { query += " AND is_read = 0" @@ -308,14 +327,16 @@ func (db *Database) ListArticles(unreadOnly bool, blogID *int64) ([]model.Articl query += " AND blog_id = ?" args = append(args, *blogID) } + if category != nil && *category != "" { + query += " AND categories LIKE ?" + args = append(args, "%"+*category+"%") + } query += " ORDER BY discovered_date DESC" - rows, err := db.conn.Query(query, args...) if err != nil { return nil, err } defer rows.Close() - var articles []model.Article for rows.Next() { article, err := scanArticle(rows) @@ -355,12 +376,12 @@ func (db *Database) MarkArticleUnread(id int64) (bool, error) { func scanBlog(scanner interface{ Scan(dest ...any) error }) (*model.Blog, error) { var ( - id int64 - name string - url string - feedURL sql.NullString + id int64 + name string + url string + feedURL sql.NullString scrapeSelector sql.NullString - lastScanned sql.NullString + lastScanned sql.NullString ) if err := scanner.Scan(&id, &name, &url, &feedURL, &scrapeSelector, &lastScanned); err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -368,7 +389,6 @@ func scanBlog(scanner interface{ Scan(dest ...any) error }) (*model.Blog, error) } return nil, err } - blog := &model.Blog{ ID: id, Name: name, @@ -386,27 +406,28 @@ func scanBlog(scanner interface{ Scan(dest ...any) error }) (*model.Blog, error) func scanArticle(scanner interface{ Scan(dest ...any) error }) (*model.Article, error) { var ( - id int64 - blogID int64 - title string - url string - publishedDate sql.NullString - discovered sql.NullString - isRead bool + id int64 + blogID int64 + title string + url string + publishedDate sql.NullString + discovered sql.NullString + isRead bool + categories sql.NullString ) - if err := scanner.Scan(&id, &blogID, &title, &url, &publishedDate, &discovered, &isRead); err != nil { + if err := scanner.Scan(&id, &blogID, &title, &url, &publishedDate, &discovered, &isRead, &categories); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, nil } return nil, err } - article := &model.Article{ - ID: id, - BlogID: blogID, - Title: title, - URL: url, - IsRead: isRead, + ID: id, + BlogID: blogID, + Title: title, + URL: url, + IsRead: isRead, + Categories: categoriesFromString(&categories.String), } if publishedDate.Valid { if parsed, err := parseTime(publishedDate.String); err == nil { @@ -418,7 +439,6 @@ func scanArticle(scanner interface{ Scan(dest ...any) error }) (*model.Article, article.DiscoveredDate = &parsed } } - return article, nil } @@ -454,4 +474,4 @@ func interfaceSlice(values []string) []interface{} { result[i] = value } return result -} +} \ No newline at end of file diff --git a/internal/storage/database_test.go b/internal/storage/database_test.go index acf871e..a298878 100644 --- a/internal/storage/database_test.go +++ b/internal/storage/database_test.go @@ -50,7 +50,7 @@ func TestDatabaseCreatesFileAndCRUD(t *testing.T) { t.Fatalf("expected 2 articles, got %d", count) } - list, err := db.ListArticles(false, nil) + list, err := db.ListArticles(false, nil, nil) if err != nil { t.Fatalf("list articles: %v", err) } @@ -265,7 +265,7 @@ func TestListArticlesFiltersAndOrdering(t *testing.T) { t.Fatalf("mark read: %v", err) } - all, err := db.ListArticles(false, nil) + all, err := db.ListArticles(false, nil, nil) if err != nil { t.Fatalf("list articles: %v", err) } @@ -276,7 +276,7 @@ func TestListArticlesFiltersAndOrdering(t *testing.T) { t.Fatalf("expected newest article first") } - unread, err := db.ListArticles(true, nil) + unread, err := db.ListArticles(true, nil, nil) if err != nil { t.Fatalf("list unread: %v", err) } @@ -285,7 +285,7 @@ func TestListArticlesFiltersAndOrdering(t *testing.T) { } blogID := blogB.ID - filtered, err := db.ListArticles(false, &blogID) + filtered, err := db.ListArticles(false, &blogID, nil) if err != nil { t.Fatalf("list by blog: %v", err) } @@ -325,7 +325,7 @@ func TestBulkInsertDuplicateRollbackAndEmpty(t *testing.T) { t.Fatalf("expected bulk insert to fail on duplicate url") } - articles, err := db.ListArticles(false, nil) + articles, err := db.ListArticles(false, nil, nil) if err != nil { t.Fatalf("list articles: %v", err) }