diff --git a/expense_tracker/app.py b/expense_tracker/app.py index 4f1c1ce..7c994d4 100644 --- a/expense_tracker/app.py +++ b/expense_tracker/app.py @@ -6,7 +6,10 @@ from expense_tracker.utils.migration import migrate_legacy_databases from expense_tracker.core.merchant_repository import MerchantCategoryRepository from expense_tracker.core.transaction_repository import TransactionRepository +from expense_tracker.services.merchant import MerchantCategoryService +from expense_tracker.services.transaction import TransactionService from expense_tracker.services.statistics import StatisticsService +from expense_tracker.utils.merchant_normalizer import normalize_merchant def main(): """Start the Expense Tracker application.""" @@ -24,6 +27,10 @@ def main(): merchant_repo = MerchantCategoryRepository( str(get_database_path("merchant_categories.db")) ) + merchant_service = MerchantCategoryService( + merchant_repo, transaction_repo, normalize_merchant + ) + transaction_service = TransactionService(transaction_repo, merchant_service) statistics_service = StatisticsService(transaction_repo) root = Tk() @@ -35,6 +42,6 @@ def main(): tb.Style("darkly") except Exception: ttk.Style() - MainWindow(root, transaction_repo, merchant_repo, statistics_service) + MainWindow(root, transaction_repo, transaction_service, statistics_service) root.focus_force() root.mainloop() diff --git a/expense_tracker/core/transaction_repository.py b/expense_tracker/core/transaction_repository.py index a04a22d..e2da19c 100644 --- a/expense_tracker/core/transaction_repository.py +++ b/expense_tracker/core/transaction_repository.py @@ -234,28 +234,6 @@ def get_monthly_net_income(self, start_date: date, end_date: date) -> float: result = row.fetchone() return result["net_income"] if result["net_income"] is not None else 0.0 - def get_top_spending_category(self, start_date: date, end_date: date) -> tuple[str, float] | None: - """ - Returns the category with the highest spending (sum of negative amounts) for a specific month. - Returns tuple of (category_name, total_spending) or None if no expenses exist. - """ - rows = self.conn.execute( - """ - SELECT category, SUM(ABS(amount)) as total - FROM transactions - WHERE date >= ? AND date < ? - AND amount < 0 - GROUP BY category - ORDER BY total DESC - LIMIT 1 - """, - (start_date.isoformat(), end_date.isoformat()), - ) - result = rows.fetchone() - if result is None: - return None - return (result["category"], result["total"]) - def get_transactions_for_date(self, target_date: date) -> list[Transaction]: """ Query transactions matching exact date. @@ -298,7 +276,7 @@ def get_latest_month_with_data(self) -> tuple[int, int]: return (result["year"], result["month"]) - def get_all_months_with_data(self) -> list[tuple[int, int]]: + def get_all_months_with_data(self) -> set[tuple[int, int]]: """ Returns a list of (year, month) tuples for all months that have transaction data. Ordered by year and month descending (most recent first). diff --git a/expense_tracker/gui/dialogs/add_expense.py b/expense_tracker/gui/dialogs/add_expense.py index f8e6e51..fc0409b 100644 --- a/expense_tracker/gui/dialogs/add_expense.py +++ b/expense_tracker/gui/dialogs/add_expense.py @@ -1,15 +1,16 @@ from datetime import date import tkinter as tk -from tkinter import ttk, messagebox +from tkinter import messagebox from expense_tracker.core.models import Transaction -from expense_tracker.core.transaction_repository import TransactionRepository +from expense_tracker.services.transaction import TransactionService +from expense_tracker.gui.dialogs.expense_form import build_expense_form, validate_amount class AddExpenseDialog(tk.Toplevel): - def __init__(self, master, repo: TransactionRepository): + def __init__(self, master, transaction_service: TransactionService): super().__init__(master) - self.repo = repo + self.transaction_service = transaction_service self.title("Add Expense") self.resizable(False, False) @@ -17,50 +18,20 @@ def __init__(self, master, repo: TransactionRepository): self.category_var = tk.StringVar() self.description_var = tk.StringVar() - self._build_form() - - def _build_form(self): - frame = ttk.Frame(self) - frame.pack(fill="both", padx=10, pady=10) - - # Amount - ttk.Label(frame, text="Amount (e.g. 12.50):").grid(row=0, column=0, sticky="w") - amount = ttk.Entry(frame, textvariable=self.amount_var, width=20) - amount.grid(row=1, column=0, sticky="w") - - # Category - ttk.Label(frame, text="Category:").grid(row=2, column=0, sticky="w") - category = ttk.Entry(frame, textvariable=self.category_var, width=20) - category.grid(row=3, column=0, sticky="w") - - # Description - ttk.Label(frame, text="Description:").grid(row=4, column=0, sticky="w") - description = ttk.Entry(frame, textvariable=self.description_var, width=20) - description.grid(row=5, column=0, sticky="w") - - # Buttons - button_frame = ttk.Frame(frame) - button_frame.grid(row=6, column=0, pady=10, sticky="e") - ttk.Button(button_frame, text="Add", command=self._on_add).pack( - side="right", padx=5 + build_expense_form( + self, + self.amount_var, + self.category_var, + self.description_var, + submit_text="Add", + on_submit=self._on_add, + on_cancel=self._on_cancel, ) - ttk.Button(button_frame, text="Cancel", command=self._on_cancel).pack( - side="right" - ) - - # Keyboard bindings self.bind("", lambda e: self._on_cancel()) def _on_add(self): - raw = self.amount_var.get() - if not raw: - messagebox.showerror("Error", "Amount is required.") - return - - try: - amount = float(raw) - except ValueError: - messagebox.showerror("Error", "Amount must be a valid number.") + amount = validate_amount(self.amount_var) + if amount is None: return try: @@ -71,7 +42,7 @@ def _on_add(self): category=self.category_var.get() or "Uncategorized", description=self.description_var.get() or "", ) - saved_transaction = self.repo.add_transaction(transaction) + saved_transaction = self.transaction_service.add_transaction(transaction) self.result = saved_transaction.id self.destroy() messagebox.showinfo( diff --git a/expense_tracker/gui/dialogs/edit_expense.py b/expense_tracker/gui/dialogs/edit_expense.py index 6a8c193..25bbf82 100644 --- a/expense_tracker/gui/dialogs/edit_expense.py +++ b/expense_tracker/gui/dialogs/edit_expense.py @@ -2,10 +2,8 @@ import tkinter as tk from tkinter import ttk, messagebox -from expense_tracker.core.transaction_repository import TransactionRepository -from expense_tracker.core.merchant_repository import MerchantCategoryRepository -from expense_tracker.services.merchant import MerchantCategoryService -from expense_tracker.utils.merchant_normalizer import normalize_merchant +from expense_tracker.services.transaction import TransactionService +from expense_tracker.gui.dialogs.expense_form import build_expense_form, validate_amount logger = logging.getLogger(__name__) @@ -14,17 +12,12 @@ class EditExpenseDialog(tk.Toplevel): def __init__( self, master, - repo: TransactionRepository, - merchant_repo: MerchantCategoryRepository, + transaction_service: TransactionService, transaction_id: int, ): super().__init__(master) - self.repo = repo - self.merchant_repo = merchant_repo + self.transaction_service = transaction_service self.transaction_id = transaction_id - self.merchant_service = MerchantCategoryService( - merchant_repo, repo, normalize_merchant - ) self.title("Edit Expense") self.resizable(False, False) @@ -40,66 +33,37 @@ def __init__( self.prev_data = None - self._build_form() - self._load_transaction_data() - - def _build_form(self): - frame = ttk.Frame(self) - frame.pack(fill="both", padx=10, pady=10) - - # Amount - ttk.Label(frame, text="Amount (e.g. 12.50):").grid(row=0, column=0, sticky="w") - amount = ttk.Entry(frame, textvariable=self.amount_var, width=20) - amount.grid(row=1, column=0, sticky="w") - - # Category - ttk.Label(frame, text="Category:").grid(row=2, column=0, sticky="w") - category = ttk.Entry(frame, textvariable=self.category_var, width=20) - category.grid(row=3, column=0, sticky="w") - - # Description - ttk.Label(frame, text="Description:").grid(row=4, column=0, sticky="w") - description = ttk.Entry(frame, textvariable=self.description_var, width=20) - description.grid(row=5, column=0, sticky="w") - - # Buttons - button_frame = ttk.Frame(frame) - button_frame.grid(row=6, column=0, pady=10, sticky="e") - ttk.Button(button_frame, text="Save", command=self._on_save).pack( - side="right", padx=5 - ) - ttk.Button(button_frame, text="Cancel", command=self._on_cancel).pack( - side="right" + build_expense_form( + self, + self.amount_var, + self.category_var, + self.description_var, + submit_text="Save", + on_submit=self._on_save, + on_cancel=self._on_cancel, ) - - # Keyboard bindings self.bind("", lambda e: self._on_cancel()) + self._load_transaction_data() + def _load_transaction_data(self): - self.prev_data = self.repo.get_transaction(self.transaction_id) + self.prev_data = self.transaction_service.get_transaction(self.transaction_id) if self.prev_data is not None: self.amount_var.set(str(self.prev_data.amount)) self.category_var.set(self.prev_data.category) self.description_var.set(self.prev_data.description) - # Only suggest category if the current category is "Uncategorized" + # Suggest a better category if currently uncategorized if self.prev_data.category == "Uncategorized": - suggested_category = self.merchant_repo.get_category( - normalize_merchant(self.prev_data.description) + suggested = self.transaction_service.suggest_category( + self.prev_data.description, self.prev_data.amount ) - if suggested_category: - self.category_var.set(suggested_category.category) + if suggested != "Uncategorized": + self.category_var.set(suggested) def _on_save(self): - raw = self.amount_var.get() - if not raw: - messagebox.showerror("Error", "Amount is required.") - return - - try: - amount = float(raw) - except ValueError: - messagebox.showerror("Error", "Amount must be a valid number.") + amount = validate_amount(self.amount_var) + if amount is None: return try: @@ -108,34 +72,22 @@ def _on_save(self): "category": self.category_var.get() or "Uncategorized", "description": self.description_var.get() or "", } - self.repo.update_transaction(self.transaction_id, data) - - # Check if we need to update merchant categories - if ( - self.prev_data is not None - and self.prev_data.category != data["category"] - ): - try: - self.merchant_service.update_category( - self.prev_data.description, data["category"] - ) - self.merchant_service.update_uncategorized_transactions() - messagebox.showinfo( - "Success", - f"Transaction {self.transaction_id} updated and related transactions recategorized.", - ) - except Exception as e: - messagebox.showerror( - "Error", f"Failed to update related transactions: {e}" - ) - self.destroy() + + categories_updated = self.transaction_service.update_transaction( + self.transaction_id, data + ) + + if categories_updated: + messagebox.showinfo( + "Success", + f"Transaction {self.transaction_id} updated and related transactions recategorized.", + ) else: - # No category change, just close the dialog self.result = self.transaction_id messagebox.showinfo( "Success", f"Transaction {self.transaction_id} updated." ) - self.destroy() + self.destroy() except Exception as e: messagebox.showerror("Error", f"Failed to update transaction: {e}") diff --git a/expense_tracker/gui/dialogs/expense_form.py b/expense_tracker/gui/dialogs/expense_form.py new file mode 100644 index 0000000..1fc2001 --- /dev/null +++ b/expense_tracker/gui/dialogs/expense_form.py @@ -0,0 +1,57 @@ +import tkinter as tk +from tkinter import ttk, messagebox + + +def build_expense_form( + parent: tk.Widget, + amount_var: tk.StringVar, + category_var: tk.StringVar, + description_var: tk.StringVar, + submit_text: str, + on_submit, + on_cancel, +) -> ttk.Frame: + """Build the shared Amount/Category/Description form used by Add and Edit dialogs.""" + frame = ttk.Frame(parent) + frame.pack(fill="both", padx=10, pady=10) + + # Amount + ttk.Label(frame, text="Amount (e.g. 12.50):").grid(row=0, column=0, sticky="w") + ttk.Entry(frame, textvariable=amount_var, width=20).grid( + row=1, column=0, sticky="w" + ) + + # Category + ttk.Label(frame, text="Category:").grid(row=2, column=0, sticky="w") + ttk.Entry(frame, textvariable=category_var, width=20).grid( + row=3, column=0, sticky="w" + ) + + # Description + ttk.Label(frame, text="Description:").grid(row=4, column=0, sticky="w") + ttk.Entry(frame, textvariable=description_var, width=20).grid( + row=5, column=0, sticky="w" + ) + + # Buttons + button_frame = ttk.Frame(frame) + button_frame.grid(row=6, column=0, pady=10, sticky="e") + ttk.Button(button_frame, text=submit_text, command=on_submit).pack( + side="right", padx=5 + ) + ttk.Button(button_frame, text="Cancel", command=on_cancel).pack(side="right") + + return frame + + +def validate_amount(amount_var: tk.StringVar) -> float | None: + """Validate and parse the amount field. Shows error messagebox on failure.""" + raw = amount_var.get() + if not raw: + messagebox.showerror("Error", "Amount is required.") + return None + try: + return float(raw) + except ValueError: + messagebox.showerror("Error", "Amount must be a valid number.") + return None diff --git a/expense_tracker/gui/dialogs/upload.py b/expense_tracker/gui/dialogs/upload.py index a6d8118..6c51521 100644 --- a/expense_tracker/gui/dialogs/upload.py +++ b/expense_tracker/gui/dialogs/upload.py @@ -1,30 +1,17 @@ import tkinter as tk from tkinter import ttk, messagebox -from datetime import datetime, date from expense_tracker.core.models import Transaction -from expense_tracker.core.transaction_repository import TransactionRepository -from expense_tracker.core.merchant_repository import MerchantCategoryRepository +from expense_tracker.services.transaction import TransactionService from expense_tracker.utils.extract import parse_bofa_statement_pdf -from expense_tracker.services.merchant import MerchantCategoryService -from expense_tracker.utils.merchant_normalizer import normalize_merchant class UploadDialog(tk.Toplevel): - def __init__( - self, - master, - repo: TransactionRepository, - merchant_repo: MerchantCategoryRepository, - ): + def __init__(self, master, transaction_service: TransactionService): super().__init__(master) - self.repo = repo - self.merchant_repo = merchant_repo + self.transaction_service = transaction_service self.title("Upload Bank Statement") self.resizable(False, False) - self.merchant_service = MerchantCategoryService( - merchant_repo, repo, normalize_merchant - ) self.file_var = tk.StringVar() @@ -66,22 +53,22 @@ def _on_upload(self): return try: - transactions = parse_bofa_statement_pdf(file_path) - for t in transactions: - transaction = Transaction( + raw_transactions = parse_bofa_statement_pdf(file_path) + transactions = [ + Transaction( id=None, - date=self._parse_date(t["date"]), + date=t["date"], amount=t["amount"], category="Uncategorized", description=t["description"], ) - transaction.category = self.merchant_service.categorize_merchant( - transaction.description, transaction.amount - ) - if self.repo.transaction_exists(transaction): - continue # Skip duplicates - self.repo.add_transaction(transaction) - messagebox.showinfo("Success", "Bank statement uploaded successfully.") + for t in raw_transactions + ] + imported = self.transaction_service.import_transactions(transactions) + messagebox.showinfo( + "Success", + f"Imported {imported} transaction(s) from bank statement.", + ) self.destroy() except Exception as e: messagebox.showerror("Error", f"Failed to upload bank statement: {e}") @@ -89,13 +76,3 @@ def _on_upload(self): def _on_cancel(self): self.file_var.set("") self.destroy() - - def _parse_date(self, raw_date) -> date: - if isinstance(raw_date, date): - return raw_date - for fmt in ("%Y-%m-%d", "%m/%d/%Y", "%m/%d/%y"): - try: - return datetime.strptime(raw_date, fmt).date() - except (ValueError, TypeError): - continue - raise ValueError(f"Unsupported date format: {raw_date}") diff --git a/expense_tracker/gui/main_window.py b/expense_tracker/gui/main_window.py index 8e64d81..12a55fa 100644 --- a/expense_tracker/gui/main_window.py +++ b/expense_tracker/gui/main_window.py @@ -6,10 +6,12 @@ class MainWindow(tk.Frame): - def __init__(self, master, transaction_repo, merchant_repo, statistics_service): + def __init__( + self, master, transaction_repo, transaction_service, statistics_service + ): super().__init__(master) self.transaction_repo = transaction_repo - self.merchant_repo = merchant_repo + self.transaction_service = transaction_service self.statistics_service = statistics_service self.master = master self._active_dialog: tk.Toplevel | None = None @@ -22,7 +24,7 @@ def __init__(self, master, transaction_repo, merchant_repo, statistics_service): # Create Transactions tab self.transactions_tab = TransactionsTab( - self.notebook, transaction_repo, merchant_repo, self + self.notebook, transaction_repo, transaction_service, self ) # Create Statistics tab diff --git a/expense_tracker/gui/tabs/transactions_tab.py b/expense_tracker/gui/tabs/transactions_tab.py index 2f36633..6c34d9a 100644 --- a/expense_tracker/gui/tabs/transactions_tab.py +++ b/expense_tracker/gui/tabs/transactions_tab.py @@ -4,16 +4,23 @@ from tkinter import ttk, messagebox from expense_tracker.core.transaction_repository import TransactionRepository +from expense_tracker.services.transaction import TransactionService from expense_tracker.gui.dialogs.add_expense import AddExpenseDialog from expense_tracker.gui.dialogs.edit_expense import EditExpenseDialog from expense_tracker.gui.dialogs.upload import UploadDialog class TransactionsTab(tk.Frame): - def __init__(self, master, transaction_repo, merchant_repo, main_window): + def __init__( + self, + master, + transaction_repo, + transaction_service: TransactionService, + main_window, + ): super().__init__(master) self.transaction_repo: TransactionRepository = transaction_repo - self.merchant_repo = merchant_repo + self.transaction_service = transaction_service self.main_window = main_window self._current_page = 0 self._page_size = 100 @@ -173,12 +180,10 @@ def _get_selected_ids(self) -> list[int]: return ids def _upload_statement(self): - self.main_window._open_dialog( - UploadDialog, self.transaction_repo, self.merchant_repo - ) + self.main_window._open_dialog(UploadDialog, self.transaction_service) def _add_transaction(self): - self.main_window._open_dialog(AddExpenseDialog, self.transaction_repo) + self.main_window._open_dialog(AddExpenseDialog, self.transaction_service) def _edit_transaction(self): transaction_ids = self._get_selected_ids() @@ -196,8 +201,7 @@ def _edit_transaction(self): self.main_window._open_dialog( EditExpenseDialog, - self.transaction_repo, - self.merchant_repo, + self.transaction_service, transaction_ids[0], ) @@ -216,7 +220,7 @@ def _delete_transaction(self): ) if confirm: - deleted = self.transaction_repo.delete_multiple_transactions( + deleted = self.transaction_service.delete_multiple_transactions( transaction_ids ) messagebox.showinfo( diff --git a/expense_tracker/services/__init__.py b/expense_tracker/services/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/expense_tracker/services/statistics.py b/expense_tracker/services/statistics.py index afb9b71..ceb985a 100644 --- a/expense_tracker/services/statistics.py +++ b/expense_tracker/services/statistics.py @@ -49,7 +49,8 @@ def get_monthly_metrics(self, year: int, month: int) -> MonthlyMetrics: """ start_date, end_date = self._get_month_date_range(year, month) net_income = self.transaction_repo.get_monthly_net_income(start_date, end_date) - top_category_data = self.transaction_repo.get_top_spending_category(start_date, end_date) + top_category_data = self.transaction_repo.get_spending_by_category(start_date, end_date) + top_category_data = top_category_data[0] if top_category_data else None if top_category_data: top_category, top_spending = top_category_data diff --git a/expense_tracker/services/transaction.py b/expense_tracker/services/transaction.py new file mode 100644 index 0000000..a96b73d --- /dev/null +++ b/expense_tracker/services/transaction.py @@ -0,0 +1,83 @@ +import logging +from dataclasses import replace + +from expense_tracker.core.models import Transaction +from expense_tracker.core.transaction_repository import TransactionRepository +from expense_tracker.services.merchant import MerchantCategoryService + +logger = logging.getLogger(__name__) + + +class TransactionService: + """ + Service layer for transaction write operations. + Ensures consistent auto-categorization across all entry points + (manual add, PDF import, edit). + """ + + def __init__( + self, + transaction_repo: TransactionRepository, + merchant_service: MerchantCategoryService, + ): + self.transaction_repo = transaction_repo + self.merchant_service = merchant_service + + def add_transaction(self, transaction: Transaction) -> Transaction: + """Add a transaction with auto-categorization. + + If the transaction's category is "Uncategorized", attempts to + categorize it via the merchant category service. + """ + if transaction.category == "Uncategorized": + category = self.merchant_service.categorize_merchant( + transaction.description, transaction.amount + ) + transaction = replace(transaction, category=category) + return self.transaction_repo.add_transaction(transaction) + + def get_transaction(self, transaction_id: int) -> Transaction | None: + return self.transaction_repo.get_transaction(transaction_id) + + def update_transaction(self, transaction_id: int, data: dict) -> bool: + """Update a transaction. If category changed, updates merchant mapping + and re-categorizes uncategorized transactions. + + Returns True if merchant categories were updated, False otherwise. + """ + prev = self.transaction_repo.get_transaction(transaction_id) + self.transaction_repo.update_transaction(transaction_id, data) + + new_category = data.get("category") + if prev and new_category and prev.category != new_category: + self.merchant_service.update_category(prev.description, new_category) + self.merchant_service.update_uncategorized_transactions() + return True + return False + + def suggest_category(self, description: str, amount: float) -> str: + """Suggest a category for a transaction based on merchant mappings.""" + return self.merchant_service.categorize_merchant(description, amount) + + def import_transactions(self, transactions: list[Transaction]) -> int: + """Import transactions with auto-categorization and duplicate detection. + + Returns the number of transactions imported (skipping duplicates). + """ + imported = 0 + for transaction in transactions: + category = self.merchant_service.categorize_merchant( + transaction.description, transaction.amount + ) + transaction = replace(transaction, category=category) + if self.transaction_repo.transaction_exists(transaction): + continue + self.transaction_repo.add_transaction(transaction) + imported += 1 + return imported + + def delete_transaction(self, transaction_id: int) -> None: + self.transaction_repo.delete_transaction(transaction_id) + + def delete_multiple_transactions(self, transaction_ids: list[int]) -> int: + return self.transaction_repo.delete_multiple_transactions(transaction_ids) diff --git a/expense_tracker/utils/date.py b/expense_tracker/utils/date.py new file mode 100644 index 0000000..122f474 --- /dev/null +++ b/expense_tracker/utils/date.py @@ -0,0 +1,11 @@ +from datetime import date, datetime + +def parse_date_from_str(raw_date) -> date: + if isinstance(raw_date, date): + return raw_date + for fmt in ("%Y-%m-%d", "%m/%d/%Y", "%m/%d/%y"): + try: + return datetime.strptime(raw_date, fmt).date() + except (ValueError, TypeError): + continue + raise ValueError(f"Unsupported date format: {raw_date}") \ No newline at end of file diff --git a/expense_tracker/utils/extract.py b/expense_tracker/utils/extract.py index a709e28..6d2e99f 100644 --- a/expense_tracker/utils/extract.py +++ b/expense_tracker/utils/extract.py @@ -1,6 +1,7 @@ import pdfplumber import re -from datetime import datetime + +from expense_tracker.utils.date import parse_date_from_str DATE_RX = re.compile(r"^\d{2}/\d{2}/\d{2}$") AMOUNT_RX = re.compile( @@ -45,7 +46,7 @@ def parse_bofa_page(page): continue rows.append( { - "date": _parse_date(tokens[0]), + "date": parse_date_from_str(tokens[0]), "description": desc, "amount": _parse_amount(tokens[amt_idx]), } @@ -53,15 +54,6 @@ def parse_bofa_page(page): return rows -def _parse_date(s): - for fmt in ("%m/%d/%y", "%m/%d/%Y"): - try: - return datetime.strptime(s, fmt).date().isoformat() - except Exception: - pass - return s - - def _parse_amount(s: str) -> float: s = s.replace("$", "").replace(",", "").strip() neg = s.startswith("(") and s.endswith(")") diff --git a/tests/core/test_repository.py b/tests/core/test_repository.py index 09b22a1..a9a1c98 100644 --- a/tests/core/test_repository.py +++ b/tests/core/test_repository.py @@ -889,123 +889,6 @@ def test_get_monthly_net_income_only_income(in_memory_repo): assert net_income == 2000.0 -def test_get_top_spending_category_with_multiple_categories(in_memory_repo): - repo: TransactionRepository = in_memory_repo - # Add expenses in different categories - repo.add_transaction( - Transaction( - id=None, - date=date.fromisoformat("2023-01-05"), - amount=-100.0, - category="Groceries", - description="Whole Foods", - ) - ) - repo.add_transaction( - Transaction( - id=None, - date=date.fromisoformat("2023-01-10"), - amount=-50.0, - category="Groceries", - description="Trader Joes", - ) - ) - repo.add_transaction( - Transaction( - id=None, - date=date.fromisoformat("2023-01-15"), - amount=-75.0, - category="Restaurants", - description="Dinner", - ) - ) - repo.add_transaction( - Transaction( - id=None, - date=date.fromisoformat("2023-01-20"), - amount=-25.0, - category="Transportation", - description="Uber", - ) - ) - - top_category = repo.get_top_spending_category(date(2023, 1, 1), date(2023, 2, 1)) - - # Groceries has highest spending: 100 + 50 = 150 - assert top_category is not None - assert top_category[0] == "Groceries" - assert top_category[1] == 150.0 - - -def test_get_top_spending_category_exclude_income(in_memory_repo): - repo: TransactionRepository = in_memory_repo - # Add income and expenses - repo.add_transaction( - Transaction( - id=None, - date=date.fromisoformat("2023-01-05"), - amount=2000.0, # Income (should be excluded) - category="Income", - description="Salary", - ) - ) - repo.add_transaction( - Transaction( - id=None, - date=date.fromisoformat("2023-01-10"), - amount=-50.0, - category="Food", - description="Groceries", - ) - ) - - top_category = repo.get_top_spending_category(date(2023, 1, 1), date(2023, 2, 1)) - - # Should return Food, not Income - assert top_category is not None - assert top_category[0] == "Food" - assert top_category[1] == 50.0 - - -def test_get_top_spending_category_no_expenses(in_memory_repo): - repo: TransactionRepository = in_memory_repo - # Add only income - repo.add_transaction( - Transaction( - id=None, - date=date.fromisoformat("2023-01-05"), - amount=1000.0, - category="Income", - description="Salary", - ) - ) - - top_category = repo.get_top_spending_category(date(2023, 1, 1), date(2023, 2, 1)) - - # Should return None when no expenses exist - assert top_category is None - - -def test_get_top_spending_category_empty_month(in_memory_repo): - repo: TransactionRepository = in_memory_repo - # Add transaction for different month - repo.add_transaction( - Transaction( - id=None, - date=date.fromisoformat("2023-02-10"), - amount=-50.0, - category="Food", - description="Groceries", - ) - ) - - # Query for month with no transactions - top_category = repo.get_top_spending_category(date(2023, 1, 1), date(2023, 2, 1)) - - # Should return None - assert top_category is None - - def test_get_latest_month_with_data(in_memory_repo): repo: TransactionRepository = in_memory_repo # Add transactions across different months diff --git a/tests/services/test_transaction.py b/tests/services/test_transaction.py new file mode 100644 index 0000000..1dec38d --- /dev/null +++ b/tests/services/test_transaction.py @@ -0,0 +1,256 @@ +from datetime import date + +import pytest + +from expense_tracker.core.models import Transaction, MerchantCategory +from expense_tracker.core.transaction_repository import TransactionRepository +from expense_tracker.core.merchant_repository import MerchantCategoryRepository +from expense_tracker.services.merchant import MerchantCategoryService +from expense_tracker.services.transaction import TransactionService +from expense_tracker.utils.merchant_normalizer import normalize_merchant + + +@pytest.fixture +def in_memory_repo(): + repo = TransactionRepository(":memory:") + yield repo + repo.conn.close() + + +@pytest.fixture +def merchant_repo(): + repo = MerchantCategoryRepository(":memory:") + yield repo + repo.conn.close() + + +@pytest.fixture +def transaction_service(in_memory_repo, merchant_repo): + merchant_service = MerchantCategoryService( + merchant_repo, in_memory_repo, normalize_merchant + ) + return TransactionService(in_memory_repo, merchant_service) + + +def test_add_transaction_auto_categorizes(transaction_service, merchant_repo): + """Uncategorized transactions get auto-categorized via merchant mappings.""" + merchant_repo.set_category(MerchantCategory("WHOLE FOODS", "Groceries")) + + txn = Transaction( + id=None, + date=date(2023, 1, 5), + amount=-50.0, + category="Uncategorized", + description="WHOLE FOODS MARKET #123", + ) + saved = transaction_service.add_transaction(txn) + + assert saved.category == "Groceries" + assert saved.id is not None + + +def test_add_transaction_preserves_explicit_category(transaction_service, merchant_repo): + """Transactions with an explicit category are not re-categorized.""" + merchant_repo.set_category(MerchantCategory("WHOLE FOODS", "Groceries")) + + txn = Transaction( + id=None, + date=date(2023, 1, 5), + amount=-50.0, + category="Food", + description="WHOLE FOODS MARKET #123", + ) + saved = transaction_service.add_transaction(txn) + + assert saved.category == "Food" + + +def test_add_transaction_income_auto_categorized(transaction_service): + """Positive amounts get categorized as Income.""" + txn = Transaction( + id=None, + date=date(2023, 1, 5), + amount=2000.0, + category="Uncategorized", + description="EMPLOYER DIRECT DEP", + ) + saved = transaction_service.add_transaction(txn) + + assert saved.category == "Income" + + +def test_add_transaction_no_mapping_stays_uncategorized(transaction_service): + """Transactions with no merchant mapping stay Uncategorized.""" + txn = Transaction( + id=None, + date=date(2023, 1, 5), + amount=-25.0, + category="Uncategorized", + description="RANDOM SHOP XYZ", + ) + saved = transaction_service.add_transaction(txn) + + assert saved.category == "Uncategorized" + + +def test_update_transaction_updates_merchant_categories( + transaction_service, in_memory_repo, merchant_repo +): + """Changing category on update should update merchant mapping and recategorize.""" + # Add an uncategorized transaction + txn1 = in_memory_repo.add_transaction( + Transaction( + id=None, + date=date(2023, 1, 5), + amount=-50.0, + category="Uncategorized", + description="WHOLE FOODS MARKET", + ) + ) + # Add another uncategorized transaction with similar description + txn2 = in_memory_repo.add_transaction( + Transaction( + id=None, + date=date(2023, 1, 10), + amount=-30.0, + category="Uncategorized", + description="WHOLE FOODS", + ) + ) + + # Update txn1's category — should trigger merchant mapping update + categories_updated = transaction_service.update_transaction( + txn1.id, {"category": "Groceries"} + ) + + assert categories_updated is True + + # txn2 should now be auto-categorized + updated_txn2 = in_memory_repo.get_transaction(txn2.id) + assert updated_txn2.category == "Groceries" + + +def test_update_transaction_no_category_change(transaction_service, in_memory_repo): + """Updating without changing category should not trigger merchant update.""" + txn = in_memory_repo.add_transaction( + Transaction( + id=None, + date=date(2023, 1, 5), + amount=-50.0, + category="Food", + description="Lunch", + ) + ) + + categories_updated = transaction_service.update_transaction( + txn.id, {"amount": -60.0, "category": "Food"} + ) + + assert categories_updated is False + updated = in_memory_repo.get_transaction(txn.id) + assert updated.amount == -60.0 + + +def test_import_transactions_categorizes_and_deduplicates( + transaction_service, in_memory_repo, merchant_repo +): + """Import should auto-categorize and skip duplicates.""" + merchant_repo.set_category(MerchantCategory("AMAZON", "Shopping")) + + # Pre-existing transaction (will be a duplicate) + in_memory_repo.add_transaction( + Transaction( + id=None, + date=date(2023, 1, 5), + amount=-50.0, + category="Shopping", + description="AMAZON.COM", + ) + ) + + transactions = [ + Transaction( + id=None, + date=date(2023, 1, 5), + amount=-50.0, + category="Uncategorized", + description="AMAZON.COM", + ), + Transaction( + id=None, + date=date(2023, 1, 10), + amount=-30.0, + category="Uncategorized", + description="AMAZON PRIME", + ), + Transaction( + id=None, + date=date(2023, 1, 15), + amount=2000.0, + category="Uncategorized", + description="EMPLOYER PAYROLL", + ), + ] + + imported = transaction_service.import_transactions(transactions) + + assert imported == 2 # first one skipped as duplicate + + # Check that imported transactions were categorized + all_txns = in_memory_repo.get_all_transactions() + assert len(all_txns) == 3 # 1 existing + 2 imported + + # Find the Amazon Prime transaction + prime_txn = next(t for t in all_txns if "PRIME" in t.description) + assert prime_txn.category == "Shopping" + + # Find the income transaction + income_txn = next(t for t in all_txns if "PAYROLL" in t.description) + assert income_txn.category == "Income" + + +def test_import_transactions_empty_list(transaction_service): + """Importing empty list returns 0.""" + imported = transaction_service.import_transactions([]) + assert imported == 0 + + +def test_suggest_category(transaction_service, merchant_repo): + """suggest_category delegates to merchant service.""" + merchant_repo.set_category(MerchantCategory("STARBUCKS", "Coffee")) + + result = transaction_service.suggest_category("STARBUCKS COFFEE #123", -5.0) + assert result == "Coffee" + + +def test_suggest_category_no_match(transaction_service): + """suggest_category returns Uncategorized when no match.""" + result = transaction_service.suggest_category("UNKNOWN MERCHANT", -10.0) + assert result == "Uncategorized" + + +def test_delete_multiple_transactions(transaction_service, in_memory_repo): + """delete_multiple_transactions delegates to repository.""" + txn1 = in_memory_repo.add_transaction( + Transaction( + id=None, + date=date(2023, 1, 5), + amount=-50.0, + category="Food", + description="Lunch", + ) + ) + txn2 = in_memory_repo.add_transaction( + Transaction( + id=None, + date=date(2023, 1, 10), + amount=-30.0, + category="Food", + description="Dinner", + ) + ) + + deleted = transaction_service.delete_multiple_transactions([txn1.id, txn2.id]) + assert deleted == 2 + assert in_memory_repo.get_transaction(txn1.id) is None + assert in_memory_repo.get_transaction(txn2.id) is None diff --git a/tests/utils/test_extract.py b/tests/utils/test_extract.py index 86e7f58..0871499 100644 --- a/tests/utils/test_extract.py +++ b/tests/utils/test_extract.py @@ -1,20 +1,12 @@ +from datetime import date + from expense_tracker.utils.extract import ( - _parse_date, _parse_amount, parse_bofa_page, parse_bofa_statement_pdf, ) from unittest.mock import patch, Mock - -def test_parse_date(): - assert _parse_date("11/08/23") == "2023-11-08" - assert _parse_date("01/01/2024") == "2024-01-01" - assert ( - _parse_date("invalid-date") == "invalid-date" - ) # Should return original string if parsing fails - - def test_parse_amount(): assert _parse_amount("100.00") == 100.0 assert _parse_amount("$50.50") == 50.5 @@ -60,12 +52,12 @@ def test_parse_bofa_page(): assert len(result) == 2 assert result[0] == { - "date": "2024-01-15", + "date": date(2024, 1, 15), "description": "Some Transaction", "amount": 123.45, } assert result[1] == { - "date": "2024-01-16", + "date": date(2024, 1, 16), "description": "Another One", "amount": -50.00, }