From ca89d8c52c90e8c5b2b6478cf2902dbb44b9cc06 Mon Sep 17 00:00:00 2001 From: nikhilb2 Date: Wed, 6 May 2026 23:17:27 +0200 Subject: [PATCH] refactor(invoices): extract payload processing to InvoiceProcessor service (closes #351) --- backend/src/api/routes/invoices.py | 328 +---------- backend/src/services/invoice_processor.py | 532 +++++++++++++++++ .../api/test_invoice_dues_allocations.py | 6 +- .../tests/api/test_invoice_reference_notes.py | 2 +- backend/tests/api/test_invoice_tax_split.py | 15 +- .../api/test_invoice_update_inventory.py | 10 +- .../tests/services/test_invoice_processor.py | 545 ++++++++++++++++++ 7 files changed, 1106 insertions(+), 332 deletions(-) create mode 100644 backend/src/services/invoice_processor.py create mode 100644 backend/tests/services/test_invoice_processor.py diff --git a/backend/src/api/routes/invoices.py b/backend/src/api/routes/invoices.py index 00925d1..b65dab4 100644 --- a/backend/src/api/routes/invoices.py +++ b/backend/src/api/routes/invoices.py @@ -1,29 +1,22 @@ from io import BytesIO from datetime import date, datetime -from collections import defaultdict -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException, Query from fastapi.responses import StreamingResponse -from sqlalchemy import case, func, or_ -from sqlalchemy.orm import Session -from sqlalchemy.orm import joinedload -from decimal import Decimal, ROUND_HALF_UP +from sqlalchemy import case, func +from sqlalchemy.orm import Session, joinedload +from decimal import Decimal import weasyprint -from fastapi import Query - from src.db.session import get_db -from src.models.buyer import Buyer as Ledger from src.models.company_account import CompanyAccount from src.models.company import CompanyProfile from src.models.invoice import Invoice, InvoiceItem -from src.models.inventory import Inventory from src.models.product import Product from src.models.user import User from src.schemas.invoice import InvoiceCreate, InvoiceOut, PaginatedInvoiceOut from src.api.deps import get_active_company, get_current_user -from src.services.series import generate_next_number from src.services.financial_year import get_active_fy, get_fy_for_date from src.services.invoice_payments import build_invoice_payment_summaries from src.services.pdf_templates import ( @@ -31,288 +24,11 @@ _build_multi_copy_invoice_html, _copy_label, ) -from src.services.gst_tax_service import money as _money, is_interstate_supply, assign_item_tax_split, assign_invoice_tax_totals +from src.services.invoice_processor import InvoiceProcessor router = APIRouter() - - -def _generate_next_number( - db: Session, - voucher_type: str, - financial_year_id: int | None = None, - invoice_date: date | None = None, - active_financial_year_id: int | None = None, - company_id: int | None = None, -) -> str: - return generate_next_number( - db, - voucher_type, - financial_year_id, - invoice_date, - active_financial_year_id, - company_id=company_id, - ) - - -def _require_ledger(db: Session, ledger_id: int, company_id: int | None) -> Ledger: - query = db.query(Ledger).filter(Ledger.id == ledger_id) - if company_id is not None: - query = query.filter(or_(Ledger.company_id == company_id, Ledger.company_id.is_(None))) - ledger = query.first() - if not ledger: - raise HTTPException(status_code=404, detail=f"Ledger {ledger_id} not found") - return ledger - - -def _change_inventory_quantity( - db: Session, - product_id: int, - quantity_delta: Decimal, - *, - company_id: int | None, - context: str, -) -> None: - query = db.query(Inventory).filter(Inventory.product_id == product_id) - if company_id is not None: - query = query.filter(or_(Inventory.company_id == company_id, Inventory.company_id.is_(None))) - inventory = query.first() - if not inventory: - inventory = Inventory(company_id=company_id, product_id=product_id, quantity=0) - db.add(inventory) - db.flush() - - inventory.quantity = Decimal(str(inventory.quantity or 0)) + quantity_delta - if Decimal(str(inventory.quantity or 0)) < 0: - raise HTTPException(status_code=400, detail=f"Insufficient inventory while {context}") - - - - -def _reverse_existing_invoice_inventory(db: Session, invoice: Invoice) -> None: - for item in invoice.items: - product_query = db.query(Product).filter(Product.id == item.product_id) - if invoice.company_id is not None: - product_query = product_query.filter(or_(Product.company_id == invoice.company_id, Product.company_id.is_(None))) - product = product_query.first() - if not product: - raise HTTPException(status_code=404, detail=f"Product {item.product_id} not found") - if not product.maintain_inventory: - continue - - reverse_delta = item.quantity if invoice.voucher_type == "sales" else -item.quantity - _change_inventory_quantity( - db, - item.product_id, - reverse_delta, - company_id=invoice.company_id, - context=f"reversing invoice {invoice.id}", - ) - - -def _inventory_effect_for_voucher_type(quantity: float, voucher_type: str) -> Decimal: - quantity_value = Decimal(str(quantity)) - return -quantity_value if voucher_type == "sales" else quantity_value - - -def _apply_inventory_delta_for_invoice_update( - db: Session, - invoice: Invoice, - payload: InvoiceCreate, - *, - company_id: int | None, -) -> None: - existing_effect_by_product: dict[int, Decimal] = defaultdict(lambda: Decimal("0")) - for item in invoice.items: - existing_effect_by_product[item.product_id] += _inventory_effect_for_voucher_type( - item.quantity, - invoice.voucher_type, - ) - - next_effect_by_product: dict[int, Decimal] = defaultdict(lambda: Decimal("0")) - for item in payload.items: - next_effect_by_product[item.product_id] += _inventory_effect_for_voucher_type( - item.quantity, - payload.voucher_type, - ) - - for product_id in set(existing_effect_by_product) | set(next_effect_by_product): - quantity_delta = next_effect_by_product[product_id] - existing_effect_by_product[product_id] - if quantity_delta == 0: - continue - - product_query = db.query(Product).filter(Product.id == product_id) - if company_id is not None: - product_query = product_query.filter(or_(Product.company_id == company_id, Product.company_id.is_(None))) - product = product_query.first() - if not product: - raise HTTPException(status_code=404, detail=f"Product {product_id} not found") - if not product.maintain_inventory: - continue - - _change_inventory_quantity( - db, - product_id, - quantity_delta, - company_id=company_id, - context=f"editing invoice {invoice.id}", - ) - - -def _apply_payload_to_invoice( - db: Session, - invoice: Invoice, - payload: InvoiceCreate, - active_company: CompanyProfile | None = None, - created_by: int | None = None, - financial_year_id: int | None = None, - active_financial_year_id: int | None = None, - regenerate_number: bool = True, - apply_inventory_changes: bool = True, -) -> None: - company = active_company or db.query(CompanyProfile).order_by(CompanyProfile.id.asc()).first() - company_id = company.id if company else None - ledger = _require_ledger(db, payload.ledger_id, company_id) - - invoice.company_id = company_id - invoice.ledger_id = ledger.id - invoice.ledger_name = ledger.name - invoice.ledger_address = ledger.address - invoice.ledger_gst = ledger.gst - invoice.ledger_phone = ledger.phone_number - invoice.company_name = company.name if company else None - invoice.company_address = company.address if company else None - invoice.company_gst = company.gst if company else None - invoice.company_phone = company.phone_number if company else None - invoice.company_email = company.email if company else None - invoice.company_website = company.website if company else None - invoice.company_currency_code = company.currency_code if company else None - invoice.company_bank_name = company.bank_name if company else None - invoice.company_branch_name = company.branch_name if company else None - invoice.company_account_name = company.account_name if company else None - invoice.company_account_number = company.account_number if company else None - invoice.company_ifsc_code = company.ifsc_code if company else None - invoice.voucher_type = payload.voucher_type - invoice.supplier_invoice_number = payload.supplier_invoice_number - invoice.reference_notes = payload.reference_notes - if created_by is not None: - invoice.created_by = created_by - if financial_year_id is not None: - invoice.financial_year_id = financial_year_id - - if payload.invoice_date is not None: - invoice.invoice_date = datetime.combine(payload.invoice_date, datetime.min.time()) - - invoice.due_date = datetime.combine(payload.due_date, datetime.min.time()) if payload.due_date is not None else None - - invoice.tax_inclusive = payload.tax_inclusive - invoice.apply_round_off = payload.apply_round_off - if regenerate_number: - invoice.invoice_number = _generate_next_number( - db, invoice.voucher_type, financial_year_id, payload.invoice_date, - active_financial_year_id, - company_id=company_id, - ) - - if not payload.items: - raise HTTPException(status_code=400, detail="Invoice must have at least one line item") - - interstate_supply = is_interstate_supply(invoice.company_gst, invoice.ledger_gst) - - taxable_total = Decimal("0") - created_items: list[InvoiceItem] = [] - for item in payload.items: - quantity_value = Decimal(str(item.quantity)) - if quantity_value <= 0: - raise HTTPException(status_code=400, detail="Item quantity must be greater than zero") - - product_query = db.query(Product).filter(Product.id == item.product_id) - if company_id is not None: - product_query = product_query.filter(or_(Product.company_id == company_id, Product.company_id.is_(None))) - product = product_query.first() - if not product: - raise HTTPException(status_code=404, detail=f"Product {item.product_id} not found") - - if not product.allow_decimal and quantity_value != quantity_value.to_integral_value(): - raise HTTPException( - status_code=400, - detail=f"Quantity for {product.name} must be a whole number", - ) - - if apply_inventory_changes and product.maintain_inventory: - inventory_query = db.query(Inventory).filter(Inventory.product_id == item.product_id) - if company_id is not None: - inventory_query = inventory_query.filter(or_(Inventory.company_id == company_id, Inventory.company_id.is_(None))) - inventory = inventory_query.first() - if payload.voucher_type == "sales" and (not inventory or Decimal(str(inventory.quantity or 0)) < quantity_value): - raise HTTPException(status_code=400, detail=f"Insufficient inventory for {product.name}") - - quantity_delta = _inventory_effect_for_voucher_type(item.quantity, payload.voucher_type) - _change_inventory_quantity( - db, - item.product_id, - quantity_delta, - company_id=company_id, - context=f"applying invoice {invoice.id or 'new'}", - ) - - # Use custom unit_price if provided, otherwise use product price. - # GST rate is snapshotted from the product at invoice time. - unit_price = Decimal(str(item.unit_price)) if item.unit_price is not None else Decimal(str(product.price)) - gst_rate = Decimal(str(product.gst_rate or 0)) - - if payload.tax_inclusive: - # Entered price already includes tax; back-calculate taxable amount - line_total = _money(unit_price * quantity_value) - taxable_amount = _money(line_total / (1 + gst_rate / Decimal("100"))) - tax_amount = _money(line_total - taxable_amount) - else: - taxable_amount = _money(unit_price * quantity_value) - tax_amount = _money(taxable_amount * gst_rate / Decimal("100")) - line_total = _money(taxable_amount + tax_amount) - - taxable_total += taxable_amount - - invoice_item = InvoiceItem( - invoice_id=invoice.id, - product_id=product.id, - quantity=float(quantity_value), - hsn_sac=product.hsn_sac, - unit_price=float(unit_price), - gst_rate=float(gst_rate), - taxable_amount=float(taxable_amount), - tax_amount=float(tax_amount), - line_total=float(line_total), - description=item.description, - ) - created_items.append(invoice_item) - db.add(invoice_item) - - taxable_total = _money(taxable_total) - invoice.taxable_amount = float(taxable_total) - - assign_item_tax_split( - created_items, - interstate_supply=interstate_supply, - ) - - tax_total = assign_invoice_tax_totals( - invoice, - created_items, - interstate_supply=interstate_supply, - ) - raw_total = _money(taxable_total + tax_total) - if invoice.apply_round_off: - rounded_total = raw_total.quantize(Decimal("1"), rounding=ROUND_HALF_UP) - round_off_amount = _money(rounded_total - raw_total) - invoice.round_off_amount = float(round_off_amount) - invoice.total_amount = float(_money(rounded_total)) - else: - invoice.round_off_amount = 0 - invoice.total_amount = float(raw_total) - - def _to_invoice_out( invoice: Invoice, *, @@ -355,8 +71,8 @@ def create_invoice( ) db.add(invoice) db.flush() - _apply_payload_to_invoice( - db, + processor = InvoiceProcessor(db) + processor.apply_payload( invoice, payload, active_company, @@ -583,8 +299,8 @@ def update_invoice( try: active_fy = get_active_fy(db, company_id=active_company.id) - _apply_inventory_delta_for_invoice_update( - db, + processor = InvoiceProcessor(db) + processor.apply_inventory_delta_for_update( invoice, payload, company_id=active_company.id, @@ -594,8 +310,7 @@ def update_invoice( db.delete(item) db.flush() - _apply_payload_to_invoice( - db, + processor.apply_payload( invoice, payload, active_company, @@ -710,7 +425,8 @@ def cancel_invoice( raise HTTPException(status_code=400, detail="Invoice is already cancelled") try: - _reverse_existing_invoice_inventory(db, invoice) + processor = InvoiceProcessor(db) + processor.reverse_inventory(invoice) invoice.status = "cancelled" db.commit() db.refresh(invoice) @@ -743,24 +459,8 @@ def restore_invoice( raise HTTPException(status_code=400, detail="Invoice is not cancelled") try: - # Re-apply inventory changes (reverse the reversal) - for item in invoice.items: - product_query = db.query(Product).filter(Product.id == item.product_id) - product_query = product_query.filter(or_(Product.company_id == active_company.id, Product.company_id.is_(None))) - product = product_query.first() - if not product: - raise HTTPException(status_code=404, detail=f"Product {item.product_id} not found") - if not product.maintain_inventory: - continue - - restore_delta = -item.quantity if invoice.voucher_type == "sales" else item.quantity - _change_inventory_quantity( - db, - item.product_id, - restore_delta, - company_id=active_company.id, - context=f"restoring invoice {invoice.id}", - ) + processor = InvoiceProcessor(db) + processor.restore_inventory(invoice, company_id=active_company.id) invoice.status = "active" db.commit() db.refresh(invoice) diff --git a/backend/src/services/invoice_processor.py b/backend/src/services/invoice_processor.py new file mode 100644 index 0000000..96708f5 --- /dev/null +++ b/backend/src/services/invoice_processor.py @@ -0,0 +1,532 @@ +""" +InvoiceProcessor service — encapsulates invoice payload application, inventory +delta calculation, and ledger/item validation that used to live in the route +handler. Separating this logic makes it independently testable and reusable +outside of the HTTP layer. +""" + +from collections import defaultdict +from datetime import datetime +from decimal import Decimal, ROUND_HALF_UP + +from fastapi import HTTPException +from sqlalchemy import or_ +from sqlalchemy.orm import Session + +from src.models.buyer import Buyer as Ledger +from src.models.company import CompanyProfile +from src.models.inventory import Inventory +from src.models.invoice import Invoice, InvoiceItem +from src.models.product import Product +from src.schemas.invoice import InvoiceCreate +from src.services.gst_tax_service import ( + assign_invoice_tax_totals, + assign_item_tax_split, + is_interstate_supply, + money as _money, +) +from src.services.series import generate_next_number + + +# --------------------------------------------------------------------------- +# Module-level helpers (pure / stateless) +# --------------------------------------------------------------------------- + +def inventory_effect_for_voucher_type(quantity: float, voucher_type: str) -> Decimal: + """Return the signed inventory delta for *quantity* depending on voucher type. + + Sales reduce stock (negative), purchases increase it (positive). + """ + quantity_value = Decimal(str(quantity)) + return -quantity_value if voucher_type == "sales" else quantity_value + + +def change_inventory_quantity( + db: Session, + product_id: int, + quantity_delta: Decimal, + *, + company_id: int | None, + context: str, +) -> None: + """Apply *quantity_delta* to the inventory row for *product_id*. + + Creates the inventory row if it does not exist yet. Raises 400 if the + resulting quantity would drop below zero. + """ + query = db.query(Inventory).filter(Inventory.product_id == product_id) + if company_id is not None: + query = query.filter( + or_(Inventory.company_id == company_id, Inventory.company_id.is_(None)) + ) + inventory = query.first() + if not inventory: + inventory = Inventory(company_id=company_id, product_id=product_id, quantity=0) + db.add(inventory) + db.flush() + + inventory.quantity = Decimal(str(inventory.quantity or 0)) + quantity_delta + if Decimal(str(inventory.quantity or 0)) < 0: + raise HTTPException( + status_code=400, + detail=f"Insufficient inventory while {context}", + ) + + +# --------------------------------------------------------------------------- +# InvoiceProcessor class +# --------------------------------------------------------------------------- + +class InvoiceProcessor: + """Encapsulates invoice payload application and inventory management.""" + + def __init__(self, db: Session) -> None: + self.db = db + + # ------------------------------------------------------------------ + # Ledger helpers + # ------------------------------------------------------------------ + + def require_ledger(self, ledger_id: int, company_id: int | None) -> Ledger: + """Fetch the ledger by *ledger_id*, scoped to *company_id*. + + Raises 404 if not found. + """ + query = self.db.query(Ledger).filter(Ledger.id == ledger_id) + if company_id is not None: + query = query.filter( + or_(Ledger.company_id == company_id, Ledger.company_id.is_(None)) + ) + ledger = query.first() + if not ledger: + raise HTTPException( + status_code=404, detail=f"Ledger {ledger_id} not found" + ) + return ledger + + # ------------------------------------------------------------------ + # Inventory helpers + # ------------------------------------------------------------------ + + def reverse_inventory(self, invoice: Invoice) -> None: + """Undo the inventory effect of an existing invoice's line items. + + Used when cancelling an invoice. + """ + for item in invoice.items: + product_query = self.db.query(Product).filter( + Product.id == item.product_id + ) + if invoice.company_id is not None: + product_query = product_query.filter( + or_( + Product.company_id == invoice.company_id, + Product.company_id.is_(None), + ) + ) + product = product_query.first() + if not product: + raise HTTPException( + status_code=404, + detail=f"Product {item.product_id} not found", + ) + if not product.maintain_inventory: + continue + + reverse_delta = ( + Decimal(str(item.quantity)) + if invoice.voucher_type == "sales" + else -Decimal(str(item.quantity)) + ) + change_inventory_quantity( + self.db, + item.product_id, + reverse_delta, + company_id=invoice.company_id, + context=f"reversing invoice {invoice.id}", + ) + + def restore_inventory(self, invoice: Invoice, *, company_id: int) -> None: + """Re-apply the inventory effect of a previously-cancelled invoice. + + Used when restoring an invoice. + """ + for item in invoice.items: + product_query = self.db.query(Product).filter( + Product.id == item.product_id + ) + product_query = product_query.filter( + or_( + Product.company_id == company_id, + Product.company_id.is_(None), + ) + ) + product = product_query.first() + if not product: + raise HTTPException( + status_code=404, + detail=f"Product {item.product_id} not found", + ) + if not product.maintain_inventory: + continue + + restore_delta = ( + -Decimal(str(item.quantity)) + if invoice.voucher_type == "sales" + else Decimal(str(item.quantity)) + ) + change_inventory_quantity( + self.db, + item.product_id, + restore_delta, + company_id=company_id, + context=f"restoring invoice {invoice.id}", + ) + + def apply_inventory_delta_for_update( + self, + invoice: Invoice, + payload: InvoiceCreate, + *, + company_id: int | None, + ) -> None: + """Compute and apply the *net* inventory delta when editing an invoice. + + Compares the existing items against the incoming payload items and only + adjusts the difference, avoiding full reverse-then-reapply churn. + """ + existing_effect_by_product: dict[int, Decimal] = defaultdict( + lambda: Decimal("0") + ) + for item in invoice.items: + existing_effect_by_product[item.product_id] += ( + inventory_effect_for_voucher_type( + item.quantity, + invoice.voucher_type, + ) + ) + + next_effect_by_product: dict[int, Decimal] = defaultdict( + lambda: Decimal("0") + ) + for item in payload.items: + next_effect_by_product[item.product_id] += ( + inventory_effect_for_voucher_type( + item.quantity, + payload.voucher_type, + ) + ) + + for product_id in set(existing_effect_by_product) | set( + next_effect_by_product + ): + quantity_delta = ( + next_effect_by_product[product_id] + - existing_effect_by_product[product_id] + ) + if quantity_delta == 0: + continue + + product_query = self.db.query(Product).filter( + Product.id == product_id + ) + if company_id is not None: + product_query = product_query.filter( + or_( + Product.company_id == company_id, + Product.company_id.is_(None), + ) + ) + product = product_query.first() + if not product: + raise HTTPException( + status_code=404, + detail=f"Product {product_id} not found", + ) + if not product.maintain_inventory: + continue + + change_inventory_quantity( + self.db, + product_id, + quantity_delta, + company_id=company_id, + context=f"editing invoice {invoice.id}", + ) + + # ------------------------------------------------------------------ + # Item validation + # ------------------------------------------------------------------ + + def validate_items( + self, + items: list, + company_id: int | None, + voucher_type: str, + apply_inventory_changes: bool = True, + invoice_id: int | None = None, + ) -> list[tuple]: + """Validate each line item and return a list of (item_schema, product, quantity_decimal) + tuples, raising 400/404 errors for invalid data. + """ + if not items: + raise HTTPException( + status_code=400, + detail="Invoice must have at least one line item", + ) + + validated: list[tuple] = [] + for item in items: + quantity_value = Decimal(str(item.quantity)) + if quantity_value <= 0: + raise HTTPException( + status_code=400, + detail="Item quantity must be greater than zero", + ) + + product_query = self.db.query(Product).filter( + Product.id == item.product_id + ) + if company_id is not None: + product_query = product_query.filter( + or_( + Product.company_id == company_id, + Product.company_id.is_(None), + ) + ) + product = product_query.first() + if not product: + raise HTTPException( + status_code=404, + detail=f"Product {item.product_id} not found", + ) + + if not product.allow_decimal and quantity_value != quantity_value.to_integral_value(): + raise HTTPException( + status_code=400, + detail=f"Quantity for {product.name} must be a whole number", + ) + + if apply_inventory_changes and product.maintain_inventory: + inventory_query = self.db.query(Inventory).filter( + Inventory.product_id == item.product_id + ) + if company_id is not None: + inventory_query = inventory_query.filter( + or_( + Inventory.company_id == company_id, + Inventory.company_id.is_(None), + ) + ) + inventory = inventory_query.first() + if voucher_type == "sales" and ( + not inventory + or Decimal(str(inventory.quantity or 0)) < quantity_value + ): + raise HTTPException( + status_code=400, + detail=f"Insufficient inventory for {product.name}", + ) + + validated.append((item, product, quantity_value)) + return validated + + # ------------------------------------------------------------------ + # Totals calculation + # ------------------------------------------------------------------ + + def calculate_totals( + self, + validated_items: list[tuple], + tax_inclusive: bool, + ) -> list[dict]: + """Compute per-line tax and total amounts. + + Accepts the output of :meth:`validate_items` and returns a list of + dicts with the calculated fields for each line item. + """ + results = [] + for item_schema, product, quantity_value in validated_items: + unit_price = ( + Decimal(str(item_schema.unit_price)) + if item_schema.unit_price is not None + else Decimal(str(product.price)) + ) + gst_rate = Decimal(str(product.gst_rate or 0)) + + if tax_inclusive: + line_total = _money(unit_price * quantity_value) + taxable_amount = _money( + line_total / (1 + gst_rate / Decimal("100")) + ) + tax_amount = _money(line_total - taxable_amount) + else: + taxable_amount = _money(unit_price * quantity_value) + tax_amount = _money( + taxable_amount * gst_rate / Decimal("100") + ) + line_total = _money(taxable_amount + tax_amount) + + results.append( + { + "item_schema": item_schema, + "product": product, + "quantity_value": quantity_value, + "unit_price": unit_price, + "gst_rate": gst_rate, + "taxable_amount": taxable_amount, + "tax_amount": tax_amount, + "line_total": line_total, + } + ) + return results + + # ------------------------------------------------------------------ + # Main payload application + # ------------------------------------------------------------------ + + def apply_payload( + self, + invoice: Invoice, + payload: InvoiceCreate, + active_company: CompanyProfile | None = None, + created_by: int | None = None, + financial_year_id: int | None = None, + active_financial_year_id: int | None = None, + regenerate_number: bool = True, + apply_inventory_changes: bool = True, + ) -> None: + """Apply *payload* data onto *invoice*, updating all scalar fields, + creating new InvoiceItem rows, and recalculating totals. + + This is the primary entry-point used by both create and update routes. + """ + company = active_company or ( + self.db.query(CompanyProfile) + .order_by(CompanyProfile.id.asc()) + .first() + ) + company_id = company.id if company else None + ledger = self.require_ledger(payload.ledger_id, company_id) + + # Snapshot company / ledger fields onto the invoice record + invoice.company_id = company_id + invoice.ledger_id = ledger.id + invoice.ledger_name = ledger.name + invoice.ledger_address = ledger.address + invoice.ledger_gst = ledger.gst + invoice.ledger_phone = ledger.phone_number + invoice.company_name = company.name if company else None + invoice.company_address = company.address if company else None + invoice.company_gst = company.gst if company else None + invoice.company_phone = company.phone_number if company else None + invoice.company_email = company.email if company else None + invoice.company_website = company.website if company else None + invoice.company_currency_code = company.currency_code if company else None + invoice.company_bank_name = company.bank_name if company else None + invoice.company_branch_name = company.branch_name if company else None + invoice.company_account_name = company.account_name if company else None + invoice.company_account_number = company.account_number if company else None + invoice.company_ifsc_code = company.ifsc_code if company else None + invoice.voucher_type = payload.voucher_type + invoice.supplier_invoice_number = payload.supplier_invoice_number + invoice.reference_notes = payload.reference_notes + if created_by is not None: + invoice.created_by = created_by + if financial_year_id is not None: + invoice.financial_year_id = financial_year_id + + if payload.invoice_date is not None: + invoice.invoice_date = datetime.combine( + payload.invoice_date, datetime.min.time() + ) + + invoice.due_date = ( + datetime.combine(payload.due_date, datetime.min.time()) + if payload.due_date is not None + else None + ) + + invoice.tax_inclusive = payload.tax_inclusive + invoice.apply_round_off = payload.apply_round_off + + if regenerate_number: + invoice.invoice_number = generate_next_number( + self.db, + invoice.voucher_type, + financial_year_id, + payload.invoice_date, + active_financial_year_id, + company_id=company_id, + ) + + # Validate items and check inventory availability + validated = self.validate_items( + payload.items, + company_id, + payload.voucher_type, + apply_inventory_changes=apply_inventory_changes, + invoice_id=invoice.id, + ) + + interstate_supply = is_interstate_supply( + invoice.company_gst, invoice.ledger_gst + ) + + # Apply inventory changes for new items + if apply_inventory_changes: + for item_schema, product, quantity_value in validated: + if product.maintain_inventory: + quantity_delta = inventory_effect_for_voucher_type( + item_schema.quantity, payload.voucher_type + ) + change_inventory_quantity( + self.db, + item_schema.product_id, + quantity_delta, + company_id=company_id, + context=f"applying invoice {invoice.id or 'new'}", + ) + + # Calculate per-line totals + line_results = self.calculate_totals(validated, payload.tax_inclusive) + + # Create InvoiceItem ORM objects + taxable_total = Decimal("0") + created_items: list[InvoiceItem] = [] + for result in line_results: + taxable_total += result["taxable_amount"] + invoice_item = InvoiceItem( + invoice_id=invoice.id, + product_id=result["product"].id, + quantity=float(result["quantity_value"]), + hsn_sac=result["product"].hsn_sac, + unit_price=float(result["unit_price"]), + gst_rate=float(result["gst_rate"]), + taxable_amount=float(result["taxable_amount"]), + tax_amount=float(result["tax_amount"]), + line_total=float(result["line_total"]), + description=result["item_schema"].description, + ) + created_items.append(invoice_item) + self.db.add(invoice_item) + + taxable_total = _money(taxable_total) + invoice.taxable_amount = float(taxable_total) + + assign_item_tax_split(created_items, interstate_supply=interstate_supply) + + tax_total = assign_invoice_tax_totals( + invoice, created_items, interstate_supply=interstate_supply + ) + raw_total = _money(taxable_total + tax_total) + if invoice.apply_round_off: + rounded_total = raw_total.quantize( + Decimal("1"), rounding=ROUND_HALF_UP + ) + round_off_amount = _money(rounded_total - raw_total) + invoice.round_off_amount = float(round_off_amount) + invoice.total_amount = float(_money(rounded_total)) + else: + invoice.round_off_amount = 0 + invoice.total_amount = float(raw_total) diff --git a/backend/tests/api/test_invoice_dues_allocations.py b/backend/tests/api/test_invoice_dues_allocations.py index 91ea72d..b709627 100644 --- a/backend/tests/api/test_invoice_dues_allocations.py +++ b/backend/tests/api/test_invoice_dues_allocations.py @@ -86,7 +86,7 @@ def test_dues_endpoint_excludes_fully_paid_invoices(client): invoice_numbers = iter(["INV-000001", "INV-000002"]) with patch( - "src.api.routes.invoices._generate_next_number", + "src.services.invoice_processor.generate_next_number", side_effect=lambda *args, **kwargs: next(invoice_numbers), ), patch( "src.api.routes.payments.generate_next_number", @@ -141,7 +141,7 @@ def test_unpaid_invoices_returns_oldest_first_suggestions(client): invoice_numbers = iter(["INV-000101", "INV-000102"]) with patch( - "src.api.routes.invoices._generate_next_number", + "src.services.invoice_processor.generate_next_number", side_effect=lambda *args, **kwargs: next(invoice_numbers), ): ledger_id = _create_ledger(client) @@ -182,7 +182,7 @@ def test_ledger_statement_payment_entry_includes_invoice_allocations(client): today = datetime.utcnow().date() with patch( - "src.api.routes.invoices._generate_next_number", + "src.services.invoice_processor.generate_next_number", return_value="INV-000401", ), patch( "src.api.routes.payments.generate_next_number", diff --git a/backend/tests/api/test_invoice_reference_notes.py b/backend/tests/api/test_invoice_reference_notes.py index fb823b2..5a464f4 100644 --- a/backend/tests/api/test_invoice_reference_notes.py +++ b/backend/tests/api/test_invoice_reference_notes.py @@ -50,7 +50,7 @@ def _add_inventory(client, product_id: int, quantity: int): def test_sales_invoice_reference_notes_round_trip(client): - with patch("src.api.routes.invoices._generate_next_number", return_value="SAL-000001"): + with patch("src.services.invoice_processor.generate_next_number", return_value="SAL-000001"): ledger_id = _create_ledger(client, name="Sales Ledger", gst="27ABCDE9999F1Z5") product_id = _create_product(client) _add_inventory(client, product_id=product_id, quantity=20) diff --git a/backend/tests/api/test_invoice_tax_split.py b/backend/tests/api/test_invoice_tax_split.py index ce347a8..69d0c22 100644 --- a/backend/tests/api/test_invoice_tax_split.py +++ b/backend/tests/api/test_invoice_tax_split.py @@ -14,7 +14,8 @@ os.environ.setdefault("DATABASE_URL", "sqlite:///./test.db") -from src.api.routes.invoices import _apply_payload_to_invoice, _build_invoice_html +from src.services.invoice_processor import InvoiceProcessor +from src.services.pdf_templates import _build_invoice_html from src.db.base import Base from src.models.buyer import Buyer from src.models.company import CompanyProfile @@ -107,8 +108,7 @@ def test_intrastate_odd_paise_total_tax_is_adjusted_and_split_equally(db_session items=[InvoiceItemCreate(product_id=product.id, quantity=1, unit_price=100.03)], ) - _apply_payload_to_invoice( - db_session, + InvoiceProcessor(db_session).apply_payload( invoice, payload, created_by=user.id, @@ -152,8 +152,7 @@ def test_pdf_unit_price_always_displays_tax_inclusive_value(db_session): items=[InvoiceItemCreate(product_id=product.id, quantity=2, unit_price=100.00)], ) - _apply_payload_to_invoice( - db_session, + InvoiceProcessor(db_session).apply_payload( invoice, payload, created_by=user.id, @@ -187,8 +186,7 @@ def test_intrastate_three_line_case_keeps_itemwise_cgst_sgst_equal(db_session): ], ) - _apply_payload_to_invoice( - db_session, + InvoiceProcessor(db_session).apply_payload( invoice, payload, created_by=user.id, @@ -233,8 +231,7 @@ def test_interstate_item_tax_is_stored_as_igst_and_rendered_in_pdf(db_session): items=[InvoiceItemCreate(product_id=product.id, quantity=1, unit_price=100.00)], ) - _apply_payload_to_invoice( - db_session, + InvoiceProcessor(db_session).apply_payload( invoice, payload, created_by=user.id, diff --git a/backend/tests/api/test_invoice_update_inventory.py b/backend/tests/api/test_invoice_update_inventory.py index 5f36148..7ec9da4 100644 --- a/backend/tests/api/test_invoice_update_inventory.py +++ b/backend/tests/api/test_invoice_update_inventory.py @@ -68,7 +68,7 @@ def _create_invoice(client, ledger_id: int, voucher_type: str, product_id: int, def test_update_purchase_invoice_succeeds_when_current_inventory_is_zero(client, db_session): invoice_numbers = iter(["PUR-000001", "SAL-000001"]) with patch( - "src.api.routes.invoices._generate_next_number", + "src.services.invoice_processor.generate_next_number", side_effect=lambda *args, **kwargs: next(invoice_numbers), ): purchase_ledger_id = _create_ledger(client, name="Purchase Ledger", gst="27ABCDE1234F1Z5") @@ -115,7 +115,7 @@ def test_update_purchase_invoice_succeeds_when_current_inventory_is_zero(client, def test_sales_invoice_allows_untracked_product_without_inventory(client, db_session): - with patch("src.api.routes.invoices._generate_next_number", return_value="SAL-UNTRACKED-001"): + with patch("src.services.invoice_processor.generate_next_number", return_value="SAL-UNTRACKED-001"): sales_ledger_id = _create_ledger(client, name="Sales Ledger Untracked", gst="27ABCDE9999F1Z5") product_id = _create_product(client, sku="UPD-INV-UNTRACK-1", maintain_inventory=False) @@ -136,7 +136,7 @@ def test_sales_invoice_allows_untracked_product_without_inventory(client, db_ses def test_update_sales_invoice_with_untracked_product_does_not_create_inventory(client, db_session): - with patch("src.api.routes.invoices._generate_next_number", return_value="SAL-UNTRACKED-002"): + with patch("src.services.invoice_processor.generate_next_number", return_value="SAL-UNTRACKED-002"): sales_ledger_id = _create_ledger(client, name="Sales Ledger Update", gst="27ABCDE8888F1Z5") product_id = _create_product(client, sku="UPD-INV-UNTRACK-2", maintain_inventory=False) @@ -165,7 +165,7 @@ def test_update_sales_invoice_with_untracked_product_does_not_create_inventory(c def test_invoice_rejects_decimal_quantity_for_whole_number_product(client): - with patch("src.api.routes.invoices._generate_next_number", return_value="SAL-WHOLE-001"): + with patch("src.services.invoice_processor.generate_next_number", return_value="SAL-WHOLE-001"): ledger_id = _create_ledger(client, name="Sales Whole Qty", gst="27ABCDE1234F1Z5") product_id = _create_product(client, sku="UPD-INV-WHOLE-QTY", allow_decimal=False, unit="Pieces") @@ -185,7 +185,7 @@ def test_invoice_rejects_decimal_quantity_for_whole_number_product(client): def test_invoice_accepts_decimal_quantity_for_decimal_enabled_product(client): - with patch("src.api.routes.invoices._generate_next_number", return_value="PUR-DEC-001"): + with patch("src.services.invoice_processor.generate_next_number", return_value="PUR-DEC-001"): ledger_id = _create_ledger(client, name="Purchase Decimal Qty", gst="27ABCDE1234F1Z5") product_id = _create_product(client, sku="UPD-INV-DEC-QTY", allow_decimal=True, unit="Kg") diff --git a/backend/tests/services/test_invoice_processor.py b/backend/tests/services/test_invoice_processor.py new file mode 100644 index 0000000..0008cc9 --- /dev/null +++ b/backend/tests/services/test_invoice_processor.py @@ -0,0 +1,545 @@ +""" +Unit tests for src.services.invoice_processor. + +These tests use an in-memory SQLite database (the same one configured in +conftest.py via autouse fixture) so they exercise real ORM queries without +needing a running Postgres instance. +""" +from decimal import Decimal + +import pytest +from fastapi import HTTPException + +from src.models.buyer import Buyer +from src.models.company import CompanyProfile +from src.models.inventory import Inventory +from src.models.invoice import Invoice, InvoiceItem +from src.models.product import Product +from src.models.user import User, UserRole +from src.schemas.invoice import InvoiceCreate, InvoiceItemCreate +from src.services.invoice_processor import ( + InvoiceProcessor, + change_inventory_quantity, + inventory_effect_for_voucher_type, +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def make_company(db, *, gst="27AABCU9603R1ZX") -> CompanyProfile: + company = CompanyProfile( + name="Test Co", + address="123 Test St", + gst=gst, + phone_number="9999999999", + currency_code="INR", + ) + db.add(company) + db.flush() + return company + + +def make_ledger(db, company_id, *, gst="29AABCU9603R1ZX") -> Buyer: + ledger = Buyer( + company_id=company_id, + name="Test Ledger", + address="456 Ledger Ave", + gst=gst, + phone_number="8888888888", + ) + db.add(ledger) + db.flush() + return ledger + + +def make_product( + db, + company_id, + *, + price=100, + gst_rate=18, + maintain_inventory=True, + allow_decimal=False, + sku="P001", +) -> Product: + product = Product( + company_id=company_id, + sku=sku, + name="Widget", + price=price, + gst_rate=gst_rate, + maintain_inventory=maintain_inventory, + allow_decimal=allow_decimal, + ) + db.add(product) + db.flush() + return product + + +def make_inventory(db, product_id, company_id, quantity=50) -> Inventory: + inv = Inventory( + product_id=product_id, + company_id=company_id, + quantity=quantity, + ) + db.add(inv) + db.flush() + return inv + + +def make_user(db) -> User: + user = User( + email="test@example.com", + full_name="Test User", + hashed_password="hashed", + role=UserRole.admin, + ) + db.add(user) + db.flush() + return user + + +def make_invoice(db, company_id, user_id=None) -> Invoice: + if user_id is None: + user = make_user(db) + user_id = user.id + invoice = Invoice(total_amount=0, company_id=company_id, created_by=user_id) + db.add(invoice) + db.flush() + return invoice + + +# --------------------------------------------------------------------------- +# inventory_effect_for_voucher_type (pure function) +# --------------------------------------------------------------------------- + +class TestInventoryEffectForVoucherType: + def test_sales_returns_negative(self): + assert inventory_effect_for_voucher_type(5, "sales") == Decimal("-5") + + def test_purchase_returns_positive(self): + assert inventory_effect_for_voucher_type(3, "purchase") == Decimal("3") + + def test_fractional_quantity(self): + assert inventory_effect_for_voucher_type(1.5, "sales") == Decimal("-1.5") + + +# --------------------------------------------------------------------------- +# change_inventory_quantity +# --------------------------------------------------------------------------- + +class TestChangeInventoryQuantity: + def test_reduces_stock(self, db_session): + company = make_company(db_session) + product = make_product(db_session, company.id) + make_inventory(db_session, product.id, company.id, quantity=10) + + change_inventory_quantity( + db_session, product.id, Decimal("-3"), company_id=company.id, context="test" + ) + + inv = db_session.query(Inventory).filter_by(product_id=product.id).first() + assert Decimal(str(inv.quantity)) == Decimal("7") + + def test_creates_inventory_row_if_missing(self, db_session): + company = make_company(db_session) + product = make_product(db_session, company.id) + + change_inventory_quantity( + db_session, product.id, Decimal("5"), company_id=company.id, context="test" + ) + + inv = db_session.query(Inventory).filter_by(product_id=product.id).first() + assert inv is not None + assert Decimal(str(inv.quantity)) == Decimal("5") + + def test_raises_on_negative_stock(self, db_session): + company = make_company(db_session) + product = make_product(db_session, company.id) + make_inventory(db_session, product.id, company.id, quantity=2) + + with pytest.raises(HTTPException) as exc_info: + change_inventory_quantity( + db_session, + product.id, + Decimal("-5"), + company_id=company.id, + context="selling", + ) + assert exc_info.value.status_code == 400 + assert "Insufficient inventory" in exc_info.value.detail + + +# --------------------------------------------------------------------------- +# InvoiceProcessor.require_ledger +# --------------------------------------------------------------------------- + +class TestRequireLedger: + def test_returns_ledger_when_found(self, db_session): + company = make_company(db_session) + ledger = make_ledger(db_session, company.id) + processor = InvoiceProcessor(db_session) + result = processor.require_ledger(ledger.id, company.id) + assert result.id == ledger.id + + def test_raises_404_when_not_found(self, db_session): + company = make_company(db_session) + processor = InvoiceProcessor(db_session) + with pytest.raises(HTTPException) as exc_info: + processor.require_ledger(9999, company.id) + assert exc_info.value.status_code == 404 + + +# --------------------------------------------------------------------------- +# InvoiceProcessor.validate_items +# --------------------------------------------------------------------------- + +class TestValidateItems: + def test_raises_on_empty_items(self, db_session): + company = make_company(db_session) + processor = InvoiceProcessor(db_session) + with pytest.raises(HTTPException) as exc_info: + processor.validate_items([], company.id, "sales") + assert exc_info.value.status_code == 400 + assert "at least one line item" in exc_info.value.detail + + def test_raises_on_zero_quantity(self, db_session): + company = make_company(db_session) + product = make_product(db_session, company.id) + items = [InvoiceItemCreate(product_id=product.id, quantity=0)] + processor = InvoiceProcessor(db_session) + with pytest.raises(HTTPException) as exc_info: + processor.validate_items(items, company.id, "sales") + assert exc_info.value.status_code == 400 + assert "greater than zero" in exc_info.value.detail + + def test_raises_on_missing_product(self, db_session): + company = make_company(db_session) + items = [InvoiceItemCreate(product_id=99999, quantity=1)] + processor = InvoiceProcessor(db_session) + with pytest.raises(HTTPException) as exc_info: + processor.validate_items(items, company.id, "sales") + assert exc_info.value.status_code == 404 + + def test_raises_on_decimal_quantity_for_whole_number_product(self, db_session): + company = make_company(db_session) + product = make_product(db_session, company.id, allow_decimal=False) + items = [InvoiceItemCreate(product_id=product.id, quantity=1.5)] + processor = InvoiceProcessor(db_session) + with pytest.raises(HTTPException) as exc_info: + processor.validate_items(items, company.id, "sales", apply_inventory_changes=False) + assert exc_info.value.status_code == 400 + assert "whole number" in exc_info.value.detail + + def test_raises_insufficient_inventory_for_sales(self, db_session): + company = make_company(db_session) + product = make_product(db_session, company.id, maintain_inventory=True) + make_inventory(db_session, product.id, company.id, quantity=2) + items = [InvoiceItemCreate(product_id=product.id, quantity=5)] + processor = InvoiceProcessor(db_session) + with pytest.raises(HTTPException) as exc_info: + processor.validate_items(items, company.id, "sales", apply_inventory_changes=True) + assert exc_info.value.status_code == 400 + assert "Insufficient inventory" in exc_info.value.detail + + def test_returns_validated_tuples(self, db_session): + company = make_company(db_session) + product = make_product(db_session, company.id, maintain_inventory=False) + items = [InvoiceItemCreate(product_id=product.id, quantity=3)] + processor = InvoiceProcessor(db_session) + result = processor.validate_items(items, company.id, "sales", apply_inventory_changes=False) + assert len(result) == 1 + item_schema, prod, qty = result[0] + assert prod.id == product.id + assert qty == Decimal("3") + + +# --------------------------------------------------------------------------- +# InvoiceProcessor.calculate_totals +# --------------------------------------------------------------------------- + +class TestCalculateTotals: + def _make_validated(self, product, quantity, unit_price=None): + item_schema = InvoiceItemCreate( + product_id=product.id, + quantity=float(quantity), + unit_price=float(unit_price) if unit_price else None, + ) + return [(item_schema, product, Decimal(str(quantity)))] + + def test_tax_exclusive_calculation(self, db_session): + company = make_company(db_session) + product = make_product(db_session, company.id, price=100, gst_rate=18) + validated = self._make_validated(product, 2) + processor = InvoiceProcessor(db_session) + results = processor.calculate_totals(validated, tax_inclusive=False) + r = results[0] + assert r["taxable_amount"] == Decimal("200.00") + assert r["tax_amount"] == Decimal("36.00") + assert r["line_total"] == Decimal("236.00") + + def test_tax_inclusive_back_calculation(self, db_session): + company = make_company(db_session) + # Price of 118 includes 18% GST → taxable = 100, tax = 18 + product = make_product(db_session, company.id, price=118, gst_rate=18) + validated = self._make_validated(product, 1) + processor = InvoiceProcessor(db_session) + results = processor.calculate_totals(validated, tax_inclusive=True) + r = results[0] + assert r["line_total"] == Decimal("118.00") + assert r["taxable_amount"] == Decimal("100.00") + assert r["tax_amount"] == Decimal("18.00") + + def test_custom_unit_price_overrides_product_price(self, db_session): + company = make_company(db_session) + product = make_product(db_session, company.id, price=100, gst_rate=0) + validated = self._make_validated(product, 1, unit_price=200) + processor = InvoiceProcessor(db_session) + results = processor.calculate_totals(validated, tax_inclusive=False) + assert results[0]["taxable_amount"] == Decimal("200.00") + + def test_zero_gst_rate(self, db_session): + company = make_company(db_session) + product = make_product(db_session, company.id, price=100, gst_rate=0) + validated = self._make_validated(product, 1) + processor = InvoiceProcessor(db_session) + results = processor.calculate_totals(validated, tax_inclusive=False) + r = results[0] + assert r["tax_amount"] == Decimal("0.00") + assert r["taxable_amount"] == r["line_total"] + + +# --------------------------------------------------------------------------- +# InvoiceProcessor.apply_payload (integration-style, uses DB) +# --------------------------------------------------------------------------- + +class TestApplyPayload: + def _make_payload(self, ledger_id, product_id, quantity=1, voucher_type="sales"): + return InvoiceCreate( + ledger_id=ledger_id, + voucher_type=voucher_type, + items=[InvoiceItemCreate(product_id=product_id, quantity=quantity)], + ) + + def test_creates_invoice_items_and_total(self, db_session): + company = make_company(db_session) + ledger = make_ledger(db_session, company.id) + product = make_product(db_session, company.id, price=100, gst_rate=18, maintain_inventory=False) + invoice = make_invoice(db_session, company.id) + + payload = self._make_payload(ledger.id, product.id, quantity=2) + processor = InvoiceProcessor(db_session) + processor.apply_payload( + invoice, payload, company, apply_inventory_changes=False + ) + + assert invoice.ledger_id == ledger.id + assert invoice.total_amount == pytest.approx(236.0) + assert invoice.taxable_amount == pytest.approx(200.0) + items = db_session.query(InvoiceItem).filter_by(invoice_id=invoice.id).all() + assert len(items) == 1 + assert items[0].quantity == 2.0 + + def test_snapshots_company_fields(self, db_session): + company = make_company(db_session) + ledger = make_ledger(db_session, company.id) + product = make_product(db_session, company.id, maintain_inventory=False) + invoice = make_invoice(db_session, company.id) + + payload = self._make_payload(ledger.id, product.id) + processor = InvoiceProcessor(db_session) + processor.apply_payload(invoice, payload, company, apply_inventory_changes=False) + + assert invoice.company_name == company.name + assert invoice.ledger_name == ledger.name + assert invoice.company_gst == company.gst + + def test_raises_on_missing_ledger(self, db_session): + company = make_company(db_session) + product = make_product(db_session, company.id, maintain_inventory=False) + invoice = make_invoice(db_session, company.id) + + payload = self._make_payload(99999, product.id) + processor = InvoiceProcessor(db_session) + with pytest.raises(HTTPException) as exc_info: + processor.apply_payload(invoice, payload, company, apply_inventory_changes=False) + assert exc_info.value.status_code == 404 + + def test_apply_round_off(self, db_session): + company = make_company(db_session) + ledger = make_ledger(db_session, company.id) + # Price that yields a fractional total: 100 * 5% GST = 105.00, no fractions + # Use 101 * 5% = 106.05 → round off to 106 + product = make_product(db_session, company.id, price=101, gst_rate=5, maintain_inventory=False) + invoice = make_invoice(db_session, company.id) + + payload = InvoiceCreate( + ledger_id=ledger.id, + voucher_type="sales", + apply_round_off=True, + items=[InvoiceItemCreate(product_id=product.id, quantity=1)], + ) + processor = InvoiceProcessor(db_session) + processor.apply_payload(invoice, payload, company, apply_inventory_changes=False) + + assert invoice.total_amount == pytest.approx(106.0) + # round_off_amount should be small (−0.05 in this case) + assert abs(invoice.round_off_amount) < 1.0 + + +# --------------------------------------------------------------------------- +# InvoiceProcessor.apply_inventory_delta_for_update +# --------------------------------------------------------------------------- + +class TestApplyInventoryDeltaForUpdate: + def test_net_delta_for_quantity_increase(self, db_session): + company = make_company(db_session) + product = make_product(db_session, company.id, maintain_inventory=True) + make_inventory(db_session, product.id, company.id, quantity=10) + + # Simulate an existing invoice with 2 units sold + invoice = make_invoice(db_session, company.id) + invoice.voucher_type = "sales" + item = InvoiceItem( + invoice_id=invoice.id, + product_id=product.id, + quantity=2.0, + unit_price=100, + gst_rate=0, + taxable_amount=200, + tax_amount=0, + line_total=200, + ) + db_session.add(item) + db_session.flush() + invoice.items # load relationship + + # New payload wants 5 units → delta should reduce stock by 3 more + payload = InvoiceCreate( + ledger_id=1, + voucher_type="sales", + items=[InvoiceItemCreate(product_id=product.id, quantity=5)], + ) + processor = InvoiceProcessor(db_session) + processor.apply_inventory_delta_for_update(invoice, payload, company_id=company.id) + + inv = db_session.query(Inventory).filter_by(product_id=product.id).first() + # Started at 10, sold 2 originally (outside), now delta = -5 - (-2) = -3 more + assert Decimal(str(inv.quantity)) == Decimal("7") + + def test_no_change_when_quantity_same(self, db_session): + company = make_company(db_session) + product = make_product(db_session, company.id, maintain_inventory=True) + make_inventory(db_session, product.id, company.id, quantity=10) + + invoice = make_invoice(db_session, company.id) + invoice.voucher_type = "sales" + item = InvoiceItem( + invoice_id=invoice.id, + product_id=product.id, + quantity=3.0, + unit_price=100, + gst_rate=0, + taxable_amount=300, + tax_amount=0, + line_total=300, + ) + db_session.add(item) + db_session.flush() + + payload = InvoiceCreate( + ledger_id=1, + voucher_type="sales", + items=[InvoiceItemCreate(product_id=product.id, quantity=3)], + ) + processor = InvoiceProcessor(db_session) + processor.apply_inventory_delta_for_update(invoice, payload, company_id=company.id) + + inv = db_session.query(Inventory).filter_by(product_id=product.id).first() + assert Decimal(str(inv.quantity)) == Decimal("10") # unchanged + + +# --------------------------------------------------------------------------- +# InvoiceProcessor.reverse_inventory / restore_inventory +# --------------------------------------------------------------------------- + +class TestReverseRestoreInventory: + def test_reverse_increases_stock_for_sales_invoice(self, db_session): + company = make_company(db_session) + product = make_product(db_session, company.id, maintain_inventory=True) + make_inventory(db_session, product.id, company.id, quantity=5) + + invoice = make_invoice(db_session, company.id) + invoice.voucher_type = "sales" + item = InvoiceItem( + invoice_id=invoice.id, + product_id=product.id, + quantity=3.0, + unit_price=100, + gst_rate=0, + taxable_amount=300, + tax_amount=0, + line_total=300, + ) + db_session.add(item) + db_session.flush() + + processor = InvoiceProcessor(db_session) + processor.reverse_inventory(invoice) + + inv = db_session.query(Inventory).filter_by(product_id=product.id).first() + assert Decimal(str(inv.quantity)) == Decimal("8") # 5 + 3 (reversed) + + def test_restore_reduces_stock_for_sales_invoice(self, db_session): + company = make_company(db_session) + product = make_product(db_session, company.id, maintain_inventory=True) + make_inventory(db_session, product.id, company.id, quantity=8) + + invoice = make_invoice(db_session, company.id) + invoice.voucher_type = "sales" + item = InvoiceItem( + invoice_id=invoice.id, + product_id=product.id, + quantity=3.0, + unit_price=100, + gst_rate=0, + taxable_amount=300, + tax_amount=0, + line_total=300, + ) + db_session.add(item) + db_session.flush() + + processor = InvoiceProcessor(db_session) + processor.restore_inventory(invoice, company_id=company.id) + + inv = db_session.query(Inventory).filter_by(product_id=product.id).first() + assert Decimal(str(inv.quantity)) == Decimal("5") # 8 - 3 (re-applied) + + def test_skips_products_without_inventory_tracking(self, db_session): + company = make_company(db_session) + product = make_product(db_session, company.id, maintain_inventory=False) + + invoice = make_invoice(db_session, company.id) + invoice.voucher_type = "sales" + item = InvoiceItem( + invoice_id=invoice.id, + product_id=product.id, + quantity=10.0, + unit_price=100, + gst_rate=0, + taxable_amount=1000, + tax_amount=0, + line_total=1000, + ) + db_session.add(item) + db_session.flush() + + processor = InvoiceProcessor(db_session) + # Neither should raise nor create inventory rows + processor.reverse_inventory(invoice) + + inv = db_session.query(Inventory).filter_by(product_id=product.id).first() + assert inv is None