Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@

All notable changes to this project will be documented in this file.

## Unreleased

- Fixed wheel packaging so joke sound assets are included from `trushell/sounds/`.
- Updated README custom sound instructions to use the current `trushell/sounds/` path.
- Replaced a one-off Rich console print in `cli.py` with `typer.echo()`.
- Added double-checked database initialization to avoid taking the initialization lock after startup.

## 0.1.0 - Initial release

- Reorganized the repository into a proper Python package.
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ Tests: pytest tests/
Version: kept in sync between trushell/__init__.py and pyproject.toml

To add a custom sound for jokes, put an .mp3 or .wav file into
trushell/chronoterm/sounds/ – it will appear in the ‘settings’ menu.
trushell/sounds/ – it will appear in the ‘settings’ menu.


License
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ build-backend = "hatchling.build"

[tool.hatch.build]
packages = ["trushell"]
include = ["README.md", "LICENSE", "trushell/chronoterm/sounds/*"]
include = ["README.md", "LICENSE", "trushell/sounds/*"]

[tool.ruff]
line-length = 88
Expand Down
42 changes: 30 additions & 12 deletions tests/test_database.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from trushell.core.database import _create_table, get_all_todos, get_db_connection, insert_todo
from trushell.core.database import _ensure_initialized, get_all_todos, get_db_connection, insert_todo
from trushell.core.models import Todo


def test_get_db_connection_returns_fresh_connection(monkeypatch, tmp_path) -> None:
def _use_temp_database(monkeypatch, tmp_path):
db_path = tmp_path / "todos.db"
monkeypatch.setattr("trushell.core.database.DB_PATH", db_path)
monkeypatch.setattr("trushell.core.database._DB_PATH", db_path)
monkeypatch.setattr("trushell.core.database._INITIALIZED", False)
return db_path


def test_get_db_connection_returns_fresh_connection(monkeypatch, tmp_path) -> None:
_use_temp_database(monkeypatch, tmp_path)
conn_one = get_db_connection()
conn_two = get_db_connection()

Expand All @@ -16,10 +21,9 @@ def test_get_db_connection_returns_fresh_connection(monkeypatch, tmp_path) -> No


def test_insert_todo_assigns_sequential_positions(monkeypatch, tmp_path) -> None:
db_path = tmp_path / "todos.db"
monkeypatch.setattr("trushell.core.database.DB_PATH", db_path)
_use_temp_database(monkeypatch, tmp_path)

_create_table()
_ensure_initialized()
insert_todo(Todo(task="first", category="work"))
insert_todo(Todo(task="second", category="work"))

Expand All @@ -30,20 +34,18 @@ def test_insert_todo_assigns_sequential_positions(monkeypatch, tmp_path) -> None


def test_get_all_todos_works_with_local_connections(monkeypatch, tmp_path) -> None:
db_path = tmp_path / "todos.db"
monkeypatch.setattr("trushell.core.database.DB_PATH", db_path)
_use_temp_database(monkeypatch, tmp_path)

_create_table()
_ensure_initialized()
insert_todo(Todo(task="alpha", category="study"))

assert len(get_all_todos()) == 1


def test_get_all_todos_returns_rows_ordered_by_position(monkeypatch, tmp_path) -> None:
db_path = tmp_path / "todos.db"
monkeypatch.setattr("trushell.core.database.DB_PATH", db_path)
_use_temp_database(monkeypatch, tmp_path)

_create_table()
_ensure_initialized()
with get_db_connection() as conn:
conn.execute(
"INSERT INTO todos VALUES (?, ?, ?, ?, ?, ?)",
Expand All @@ -58,3 +60,19 @@ def test_get_all_todos_returns_rows_ordered_by_position(monkeypatch, tmp_path) -

assert [task.task for task in tasks] == ["first", "second"]
assert [task.position for task in tasks] == [0, 1]


def test_ensure_initialized_skips_lock_when_already_initialized(monkeypatch, tmp_path) -> None:
_use_temp_database(monkeypatch, tmp_path)
_ensure_initialized()

class FailingLock:
def __enter__(self):
raise AssertionError("lock should not be acquired after initialization")

def __exit__(self, exc_type, exc, tb):
return False

monkeypatch.setattr("trushell.core.database._INITIALIZE_LOCK", FailingLock())

_ensure_initialized()
4 changes: 1 addition & 3 deletions trushell/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import time
import typer
from pathlib import Path
from rich.console import Console
from textual.app import App, ComposeResult
from textual.binding import Binding
from textual.widgets import Footer, Header, TextArea
Expand All @@ -23,7 +22,6 @@
from .core.trukernel import EXIT_SENTINEL, get_kernel

app = typer.Typer(name="trushell", help="TruShell manifest-driven launcher.")
console = Console()


def app_with_lower() -> None:
Expand Down Expand Up @@ -290,7 +288,7 @@ def _handle_cd_command(raw_command: str) -> bool:

try:
os.chdir(target)
console.print(f"[green]→ {os.getcwd()}[/green]")
typer.echo(f"→ {os.getcwd()}")
except (FileNotFoundError, NotADirectoryError, PermissionError) as error:
typer.secho(f"❌ Cannot navigate: {error}", fg=typer.colors.RED)
except OSError as error:
Expand Down
48 changes: 27 additions & 21 deletions trushell/core/database.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import sqlite3
import threading
from pathlib import Path
from typing import List, Optional

Expand All @@ -10,6 +11,7 @@

# Global state to track initialization
_INITIALIZED = False
_INITIALIZE_LOCK = threading.Lock()
_DB_PATH: Optional[Path] = None


Expand All @@ -34,27 +36,31 @@ def _ensure_initialized() -> None:
if _INITIALIZED:
return

db_path = _get_db_path()

# Open a direct connection to initialize.
# We do NOT use get_db_connection() here to avoid recursion.
conn = sqlite3.connect(str(db_path), check_same_thread=False)
try:
cursor = conn.cursor()
cursor.execute(
"""CREATE TABLE IF NOT EXISTS todos (
task TEXT,
category TEXT,
date_added TEXT,
date_completed TEXT,
status INTEGER,
position INTEGER
)"""
)
conn.commit()
_INITIALIZED = True
finally:
conn.close()
with _INITIALIZE_LOCK:
if _INITIALIZED:
return

db_path = _get_db_path()

# Open a direct connection to initialize.
# We do NOT use get_db_connection() here to avoid recursion.
conn = sqlite3.connect(str(db_path), check_same_thread=False)
try:
cursor = conn.cursor()
cursor.execute(
"""CREATE TABLE IF NOT EXISTS todos (
task TEXT,
category TEXT,
date_added TEXT,
date_completed TEXT,
status INTEGER,
position INTEGER
)"""
)
conn.commit()
_INITIALIZED = True
finally:
conn.close()


def get_db_connection() -> sqlite3.Connection:
Expand Down
Loading