From 3b1ea12ec08036b23f6c72bfe566cf6c63a27437 Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 17 Mar 2026 13:32:53 +0000
Subject: [PATCH 1/2] Initial plan
From 2640a0aa1678b9f0e4bdcaded19c739c425187ab Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Tue, 17 Mar 2026 13:46:29 +0000
Subject: [PATCH 2/2] Resolve conflicts: prefer copilot-vibe-code (latest) with
lint fixes applied
Co-authored-by: art-test-stack <110672812+art-test-stack@users.noreply.github.com>
---
pkg/mybib/arxiv.py | 42 +++----
pkg/mybib/bibtex.py | 77 ++++++------
pkg/mybib/categories.py | 52 ++++----
pkg/mybib/cli.py | 260 ++++++++++++++++++----------------------
pkg/mybib/db_storage.py | 127 +++++++++++---------
pkg/mybib/graph.py | 143 +++++++++++-----------
pkg/mybib/markdown.py | 47 ++++----
pkg/mybib/metadata.py | 71 ++++++-----
pkg/mybib/models.py | 26 ++--
pkg/mybib/scholar.py | 117 +++++++++---------
pkg/mybib/storage.py | 40 ++++---
pkg/mybib/ui.py | 37 +++---
pkg/mybib/utils.py | 30 +++--
tests/test_arxiv.py | 54 +++++----
tests/test_markdown.py | 55 +++++----
tests/test_metadata.py | 74 +++++++-----
tests/test_storage.py | 76 +++++++-----
17 files changed, 703 insertions(+), 625 deletions(-)
diff --git a/pkg/mybib/arxiv.py b/pkg/mybib/arxiv.py
index bf3d3ad..ab67841 100644
--- a/pkg/mybib/arxiv.py
+++ b/pkg/mybib/arxiv.py
@@ -1,54 +1,54 @@
"""Fetch metadata from arXiv API."""
import sys
-import requests
import xml.etree.ElementTree as ET
+import requests
+
def fetch_arxiv_metadata(arxiv_id: str) -> dict:
"""Fetch metadata from arXiv API.
-
+
Args:
arxiv_id: arXiv identifier (e.g., '2301.00001')
-
+
Returns:
Dictionary with keys: title, authors, journal, year, doi, link, arxiv_id
-
+
Raises:
SystemExit: If API call fails or no entry found
"""
url = f"http://export.arxiv.org/api/query?id_list={arxiv_id}"
response = requests.get(url)
-
+
if response.status_code != 200:
print(f"Error fetching arxiv metadata: {response.status_code}")
sys.exit(1)
-
+
root = ET.fromstring(response.content)
ns = {
- 'atom': 'http://www.w3.org/2005/Atom',
- 'arxiv': 'http://arxiv.org/schemas/atom'
+ "atom": "http://www.w3.org/2005/Atom",
+ "arxiv": "http://arxiv.org/schemas/atom",
}
-
- entry = root.find('atom:entry', ns)
+
+ entry = root.find("atom:entry", ns)
if entry is None:
print("No entry found for this arxiv ID.")
sys.exit(1)
-
- title = entry.find('atom:title', ns).text.strip().replace('\n', ' ')
- authors = ', '.join(
- author.find('atom:name', ns).text
- for author in entry.findall('atom:author', ns)
+
+ title = entry.find("atom:title", ns).text.strip().replace("\n", " ")
+ authors = ", ".join(
+ author.find("atom:name", ns).text for author in entry.findall("atom:author", ns)
)
- published = entry.find('atom:published', ns).text
+ published = entry.find("atom:published", ns).text
year = int(published[:4])
-
- doi_elem = entry.find('arxiv:doi', ns)
+
+ doi_elem = entry.find("arxiv:doi", ns)
doi = doi_elem.text if doi_elem is not None else arxiv_id
-
- journal_elem = entry.find('arxiv:journal_ref', ns)
+
+ journal_elem = entry.find("arxiv:journal_ref", ns)
journal = journal_elem.text if journal_elem is not None else "arXiv"
-
+
return {
"title": title,
"authors": authors,
diff --git a/pkg/mybib/bibtex.py b/pkg/mybib/bibtex.py
index 6fc0462..2c2eb16 100644
--- a/pkg/mybib/bibtex.py
+++ b/pkg/mybib/bibtex.py
@@ -5,75 +5,74 @@
def generate_bibtex(df: pd.DataFrame) -> str:
"""Generate BibTeX entries from a DataFrame of references.
-
+
Args:
df: DataFrame with columns: Title, Authors, Journal, Year, DOI, Link
-
+
Returns:
String containing BibTeX formatted entries
"""
if df.empty:
return "% No references found\n"
-
+
# Handle cases where DataFrame column names might not exactly match
# Create a mapping of expected columns to actual columns
col_mapping = {}
-
+
for col in df.columns:
col_lower = col.lower().strip()
- if col_lower == 'title':
- col_mapping['title'] = col
- elif col_lower == 'authors':
- col_mapping['authors'] = col
- elif col_lower == 'journal':
- col_mapping['journal'] = col
- elif col_lower == 'year':
- col_mapping['year'] = col
- elif col_lower == 'doi':
- col_mapping['doi'] = col
- elif col_lower == 'link':
- col_mapping['link'] = col
- elif col_lower == 'url':
- col_mapping['link'] = col
-
+ if col_lower == "title":
+ col_mapping["title"] = col
+ elif col_lower == "authors":
+ col_mapping["authors"] = col
+ elif col_lower == "journal":
+ col_mapping["journal"] = col
+ elif col_lower == "year":
+ col_mapping["year"] = col
+ elif col_lower == "doi":
+ col_mapping["doi"] = col
+ elif col_lower == "link":
+ col_mapping["link"] = col
+ elif col_lower == "url":
+ col_mapping["link"] = col
+
entries = []
-
+
for _, row in df.iterrows():
# Extract values using mapped column names
- title = str(row.get(col_mapping.get('title', 'Title'), "")).strip()
- authors = str(row.get(col_mapping.get('authors', 'Authors'), "")).strip()
- journal = str(row.get(col_mapping.get('journal', 'Journal'), "")).strip()
- year = str(row.get(col_mapping.get('year', 'Year'), "")).strip()
- doi = str(row.get(col_mapping.get('doi', 'DOI'), "")).strip()
- link = str(row.get(col_mapping.get('link', 'Link'), "")).strip()
-
+ title = str(row.get(col_mapping.get("title", "Title"), "")).strip()
+ authors = str(row.get(col_mapping.get("authors", "Authors"), "")).strip()
+ journal = str(row.get(col_mapping.get("journal", "Journal"), "")).strip()
+ year = str(row.get(col_mapping.get("year", "Year"), "")).strip()
+ doi = str(row.get(col_mapping.get("doi", "DOI"), "")).strip()
+ link = str(row.get(col_mapping.get("link", "Link"), "")).strip()
+
# Skip entries with missing critical fields
if not title:
continue
-
+
# Use DOI as the key, or create one from title if DOI is missing
if doi and doi != "nan":
key = doi
else:
# Fallback: create key from title
key = title.lower().replace(" ", "_").replace(":", "")[:30]
-
+
# Build BibTeX entry
entry = f"@article{{{key},\n"
- entry += f' title={{{title}}},\n'
-
+ entry += f" title={{{title}}},\n"
+
if authors:
- entry += f' author={{{authors}}},\n'
+ entry += f" author={{{authors}}},\n"
if journal:
- entry += f' journal={{{journal}}},\n'
+ entry += f" journal={{{journal}}},\n"
if year and year != "nan":
- entry += f' year={{{year}}},\n'
+ entry += f" year={{{year}}},\n"
if link and link != "nan":
- entry += f' url={{{link}}}\n'
-
+ entry += f" url={{{link}}}\n"
+
entry += "}\n"
-
+
entries.append(entry)
-
- return "\n".join(entries)
+ return "\n".join(entries)
diff --git a/pkg/mybib/categories.py b/pkg/mybib/categories.py
index b7bbb14..4f30982 100644
--- a/pkg/mybib/categories.py
+++ b/pkg/mybib/categories.py
@@ -1,16 +1,15 @@
"""Category management for bibliography."""
import json
-from pathlib import Path
from typing import Dict, List, Tuple
def load_categories(file_path: str = "categories.json") -> Dict[str, str]:
"""Load category mappings from file.
-
+
Args:
file_path: Path to categories JSON file
-
+
Returns:
Dictionary mapping category ID to category name
"""
@@ -22,9 +21,11 @@ def load_categories(file_path: str = "categories.json") -> Dict[str, str]:
return {}
-def save_categories(categories: Dict[str, str], file_path: str = "categories.json") -> None:
+def save_categories(
+ categories: Dict[str, str], file_path: str = "categories.json"
+) -> None:
"""Save category mappings to file.
-
+
Args:
categories: Dictionary mapping category ID to category name
file_path: Path to categories JSON file
@@ -33,65 +34,70 @@ def save_categories(categories: Dict[str, str], file_path: str = "categories.jso
json.dump(categories, f, indent=2, sort_keys=True)
-def get_or_create_category(name: str, categories: Dict[str, str] = None) -> Tuple[str, Dict[str, str]]:
+def get_or_create_category(
+ name: str, categories: Dict[str, str] = None
+) -> Tuple[str, Dict[str, str]]:
"""Get category ID for given name, creating if needed.
-
+
Uses lowercase normalization to group similar categories.
-
+
Args:
- name: Category name
+ name: Category name
categories: Existing categories dict (loads from file if not provided)
-
+
Returns:
Tuple of (category_id, updated_categories_dict)
"""
if categories is None:
categories = load_categories()
-
+
# Normalize category name
normalized = name.lower().strip()
-
+
# Check if category already exists (case-insensitive)
for cat_id, cat_name in categories.items():
if cat_name.lower() == normalized:
return cat_id, categories
-
+
# Create new category
- new_id = str(max(int(cat_id) for cat_id in categories.keys() if cat_id.isdigit()) + 1 if categories else 1)
+ new_id = str(
+ max(int(cat_id) for cat_id in categories.keys() if cat_id.isdigit()) + 1
+ if categories
+ else 1
+ )
categories[new_id] = name
-
+
return new_id, categories
def list_categories(categories: Dict[str, str] = None) -> List[Tuple[str, str]]:
"""List all categories sorted by ID.
-
+
Args:
categories: Category mapping dict (loads from file if not provided)
-
+
Returns:
List of (id, name) tuples sorted by ID
"""
if categories is None:
categories = load_categories()
-
+
return sorted(
- categories.items(),
- key=lambda x: int(x[0]) if x[0].isdigit() else float('inf')
+ categories.items(), key=lambda x: int(x[0]) if x[0].isdigit() else float("inf")
)
def get_category_name(cat_id: str, categories: Dict[str, str] = None) -> str:
"""Get category name by ID.
-
+
Args:
cat_id: Category ID
categories: Category mapping dict (loads from file if not provided)
-
+
Returns:
Category name, or empty string if not found
"""
if categories is None:
categories = load_categories()
-
+
return categories.get(str(cat_id), "")
diff --git a/pkg/mybib/cli.py b/pkg/mybib/cli.py
index f80657c..299c3f9 100644
--- a/pkg/mybib/cli.py
+++ b/pkg/mybib/cli.py
@@ -4,32 +4,36 @@
import sys
from .arxiv import fetch_arxiv_metadata
-from .storage import add_reference, load_references
-from .markdown import make_markdown_table, make_markdown_tables_by_category
from .bibtex import generate_bibtex
+from .categories import (
+ get_or_create_category,
+ list_categories,
+ load_categories,
+ save_categories,
+)
from .graph import build_citation_graph, export_graph_html
+from .markdown import make_markdown_table, make_markdown_tables_by_category
from .scholar import search_and_confirm_article
-from .categories import load_categories, get_or_create_category, list_categories, save_categories
+from .storage import add_reference, load_references
from .ui import (
+ api_progress,
+ confirm_action,
console,
- print_success,
+ display_reference_preview,
print_error,
print_info,
+ print_success,
print_warning,
- api_progress,
- confirm_action,
- display_reference_preview,
- create_references_table,
)
def prompt_for_category(title: str, category_arg: str = None) -> str:
"""Prompt user to select or create a category.
-
+
Args:
title: Article title for context
category_arg: Pre-specified category (used if provided)
-
+
Returns:
Category name
"""
@@ -39,21 +43,22 @@ def prompt_for_category(title: str, category_arg: str = None) -> str:
cat_id, categories = get_or_create_category(category_arg, categories)
save_categories(categories)
return categories[cat_id]
-
+
# Show existing categories and allow selection or creation
categories = load_categories()
cat_list = list_categories(categories)
-
+
console.print("\n[bold]Available categories:[/]")
for cat_id, cat_name in cat_list:
console.print(f" {cat_id}: {cat_name}")
-
+
# Prompt for selection
while True:
choice = console.input(
- f"\n[bold]Select category ID for '{title}'[/] (or enter new category name): "
+ f"\n[bold]Select category ID for '{title}'[/] "
+ "(or enter new category name): "
).strip()
-
+
if choice.isdigit() and choice in categories:
return categories[choice]
elif choice:
@@ -68,21 +73,21 @@ def prompt_for_category(title: str, category_arg: str = None) -> str:
def handle_add_arxiv(args) -> None:
"""Handle the add-arxiv command.
-
+
Args:
args: Parsed command-line arguments
"""
# Extract arxiv ID from URL
arxiv_id = args.arxiv_url.split("/")[-1]
-
+
# Fetch metadata from arXiv with progress indicator
print_info(f"Fetching metadata for arXiv ID: {arxiv_id}")
with api_progress():
metadata = fetch_arxiv_metadata(arxiv_id)
-
+
# Get category using new category system
- category = prompt_for_category(metadata['title'], args.category)
-
+ category = prompt_for_category(metadata["title"], args.category)
+
# Show reference preview
preview_data = {
"title": metadata["title"],
@@ -94,14 +99,14 @@ def handle_add_arxiv(args) -> None:
console.print()
display_reference_preview(preview_data)
console.print()
-
+
# Confirm before adding
if not confirm_action(
f"Add '[bold cyan]{metadata['title']}[/]' to category '[yellow]{category}[/]'?"
):
print_warning("Aborted.")
sys.exit(0)
-
+
# Add reference to storage
add_reference(
title=metadata["title"],
@@ -119,33 +124,33 @@ def handle_add_arxiv(args) -> None:
def handle_add_scholar(args) -> None:
"""Handle the add-scholar command to search Google Scholar.
-
+
Args:
args: Parsed command-line arguments
"""
title = args.title
url = args.url
-
+
print_info("Searching Google Scholar for your article...")
-
+
# If no title provided, try to extract from URL or abort
if not title and not url:
print_error("Either --title or --url must be provided")
sys.exit(1)
-
+
# Search query: use title if provided, else use URL
search_query = title if title else url
-
+
# Search and get confirmation from user
metadata = search_and_confirm_article(search_query)
-
+
if not metadata:
print_error("Could not find or confirm article on Google Scholar")
sys.exit(1)
-
+
# Get category using new category system
- category = prompt_for_category(metadata['title'], args.category)
-
+ category = prompt_for_category(metadata["title"], args.category)
+
# Show reference preview
preview_data = {
"title": metadata["title"],
@@ -157,14 +162,14 @@ def handle_add_scholar(args) -> None:
console.print()
display_reference_preview(preview_data)
console.print()
-
+
# Confirm before adding
if not confirm_action(
f"Add '[bold cyan]{metadata['title']}[/]' to category '[yellow]{category}[/]'?"
):
print_warning("Aborted.")
sys.exit(0)
-
+
# Add reference to storage
add_reference(
title=metadata["title"],
@@ -182,28 +187,30 @@ def handle_add_scholar(args) -> None:
def handle_add_manual(args) -> None:
"""Handle the add command for manual reference entry.
-
+
If only title is provided, searches Google Scholar automatically.
All fields except title are optional.
-
+
Args:
args: Parsed command-line arguments
"""
# Check if we need to search Google Scholar
# If only title and category are provided (other fields are None), search Scholar
- has_manual_metadata = any([
- args.authors,
- args.journal,
- args.year,
- args.doi,
- args.link,
- ])
-
+ has_manual_metadata = any(
+ [
+ args.authors,
+ args.journal,
+ args.year,
+ args.doi,
+ args.link,
+ ]
+ )
+
if not has_manual_metadata:
# Only title provided, search Google Scholar
print_info("Searching Google Scholar for your article...")
metadata = search_and_confirm_article(args.title)
-
+
if not metadata:
print_error("Could not find or confirm article on Google Scholar")
sys.exit(1)
@@ -217,10 +224,10 @@ def handle_add_manual(args) -> None:
"doi": args.doi,
"link": args.link or "",
}
-
+
# Get category using new category system
- category = prompt_for_category(metadata['title'], args.category)
-
+ category = prompt_for_category(metadata["title"], args.category)
+
# Show reference preview
preview_data = {
"title": metadata["title"],
@@ -232,12 +239,14 @@ def handle_add_manual(args) -> None:
console.print()
display_reference_preview(preview_data)
console.print()
-
+
# Confirm before adding
- if not confirm_action(f"Add '[bold cyan]{metadata['title']}[/]' to [yellow]{category}[/]?"):
+ if not confirm_action(
+ f"Add '[bold cyan]{metadata['title']}[/]' to [yellow]{category}[/]?"
+ ):
print_warning("Aborted.")
sys.exit(0)
-
+
add_reference(
title=metadata["title"],
authors=metadata.get("authors") or None,
@@ -254,17 +263,17 @@ def handle_add_manual(args) -> None:
def handle_markdown(args) -> None:
"""Handle the markdown command to generate markdown tables.
-
+
Args:
args: Parsed command-line arguments
"""
print_info("Generating markdown tables...")
-
+
if args.by_category:
table = make_markdown_tables_by_category(args.file)
else:
table = make_markdown_table(args.file)
-
+
if args.output:
with open(args.output, "w") as f:
f.write(table)
@@ -275,14 +284,14 @@ def handle_markdown(args) -> None:
def handle_bibtex(args) -> None:
"""Handle the bibtex command to generate BibTeX file.
-
+
Args:
args: Parsed command-line arguments
"""
print_info("Generating BibTeX...")
df = load_references(args.file)
bibtex_content = generate_bibtex(df)
-
+
if args.output:
with open(args.output, "w") as f:
f.write(bibtex_content)
@@ -293,19 +302,19 @@ def handle_bibtex(args) -> None:
def handle_graph(args) -> None:
"""Handle the graph command to build and visualize citations.
-
+
Args:
args: Parsed command-line arguments
"""
df = load_references(args.file)
-
+
if df.empty:
print_error(f"No references found in {args.file}")
sys.exit(1)
-
+
print_info(f"Building citation graph from {len(df)} references...")
graph = build_citation_graph(df, output_references=args.verbose)
-
+
output_file = args.output or "citation_graph.html"
export_graph_html(graph, output_file)
print_success(f"Citation graph exported to {output_file}")
@@ -313,16 +322,16 @@ def handle_graph(args) -> None:
def handle_db_init(args) -> None:
"""Handle database initialization.
-
+
Args:
args: Parsed command-line arguments
"""
from .db_storage import DatabaseStorage
-
+
print_info(f"Initializing database: {args.db_url}")
-
+
try:
- storage = DatabaseStorage(args.db_url)
+ DatabaseStorage(args.db_url)
print_success("Database initialized successfully!")
except Exception as e:
print_error(f"Failed to initialize database: {e}")
@@ -331,24 +340,24 @@ def handle_db_init(args) -> None:
def handle_db_migrate(args) -> None:
"""Handle migration from CSV to database.
-
+
Args:
args: Parsed command-line arguments
"""
from .db_storage import DatabaseStorage
-
+
print_info(f"Migrating from {args.file} to {args.db_url}")
-
+
try:
storage = DatabaseStorage(args.db_url)
stats = storage.migrate_from_csv(args.file)
-
+
console.print("\n[bold]Migration Statistics:[/]")
console.print(f" Total: {stats['total']}")
console.print(f" Added: {stats['added']}")
console.print(f" Duplicates: {stats['duplicates']}")
console.print(f" Errors: {stats['errors']}")
-
+
print_success("Migration completed!")
except Exception as e:
print_error(f"Failed to migrate database: {e}")
@@ -357,14 +366,14 @@ def handle_db_migrate(args) -> None:
def handle_db_export(args) -> None:
"""Handle export from database to CSV.
-
+
Args:
args: Parsed command-line arguments
"""
from .db_storage import DatabaseStorage
-
+
print_info(f"Exporting from {args.db_url} to {args.output}")
-
+
try:
storage = DatabaseStorage(args.db_url)
count = storage.export_to_csv(args.output)
@@ -378,68 +387,56 @@ def main() -> None:
"""Main CLI entry point."""
parser = argparse.ArgumentParser(
prog="mybib",
- description="📚 Manage research paper references with ease. Similar to gh, poetry, and uv!",
+ description="📚 Manage research paper references with ease. "
+ "Similar to gh, poetry, and uv!",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
mybib add-arxiv https://arxiv.org/abs/2301.00001 --category ML
- mybib add --title "My Paper" --authors "Author Name" --journal "Nature" --year 2024 --doi "10.xxxx/xxxxx" --category Science
+ mybib add --title "My Paper" --authors "Author Name" --journal "Nature" \
+ --year 2024 --doi "10.xxxx/xxxxx" --category Science
mybib markdown --file references.csv --output README.md
mybib bibtex --file references.csv --output references.bib
mybib graph --file references.csv --output graph.html
- """
+ """,
)
subparsers = parser.add_subparsers(dest="command", help="Available commands")
# add-arxiv command
add_arxiv_parser = subparsers.add_parser(
- "add-arxiv",
- help="Add a reference from an arXiv URL"
- )
- add_arxiv_parser.add_argument(
- "arxiv_url",
- help="The arXiv URL (e.g., https://arxiv.org/abs/2301.00001)"
+ "add-arxiv", help="Add a reference from an arXiv URL"
)
add_arxiv_parser.add_argument(
- "--category",
- help="Category for the reference"
+ "arxiv_url", help="The arXiv URL (e.g., https://arxiv.org/abs/2301.00001)"
)
+ add_arxiv_parser.add_argument("--category", help="Category for the reference")
add_arxiv_parser.add_argument(
"--file",
default="references.csv",
- help="CSV file path (default: references.csv)"
+ help="CSV file path (default: references.csv)",
)
add_arxiv_parser.set_defaults(func=handle_add_arxiv)
# add-scholar command
add_scholar_parser = subparsers.add_parser(
- "add-scholar",
- help="Add a reference from Google Scholar"
- )
- add_scholar_parser.add_argument(
- "--title",
- help="Article title to search for"
- )
- add_scholar_parser.add_argument(
- "--url",
- help="Article URL to search for"
- )
- add_scholar_parser.add_argument(
- "--category",
- help="Category for the reference"
+ "add-scholar", help="Add a reference from Google Scholar"
)
+ add_scholar_parser.add_argument("--title", help="Article title to search for")
+ add_scholar_parser.add_argument("--url", help="Article URL to search for")
+ add_scholar_parser.add_argument("--category", help="Category for the reference")
add_scholar_parser.add_argument(
"--file",
default="references.csv",
- help="CSV file path (default: references.csv)"
+ help="CSV file path (default: references.csv)",
)
add_scholar_parser.set_defaults(func=handle_add_scholar)
# add command
add_parser = subparsers.add_parser(
"add",
- help="Add a reference manually (or search Google Scholar if only title provided)"
+ help="Add a reference manually (or search Google Scholar"
+ " if only title provided)",
)
add_parser.add_argument("--title", required=True, help="Article title (required)")
add_parser.add_argument("--authors", help="Comma-separated author names")
@@ -451,118 +448,97 @@ def main() -> None:
add_parser.add_argument(
"--file",
default="references.csv",
- help="CSV file path (default: references.csv)"
+ help="CSV file path (default: references.csv)",
)
add_parser.set_defaults(func=handle_add_manual)
# markdown command
md_parser = subparsers.add_parser(
- "markdown",
- help="Generate markdown tables from references"
+ "markdown", help="Generate markdown tables from references"
)
md_parser.add_argument(
"--file",
default="references.csv",
- help="CSV file path (default: references.csv)"
- )
- md_parser.add_argument(
- "--by-category",
- action="store_true",
- help="Split tables by category"
+ help="CSV file path (default: references.csv)",
)
md_parser.add_argument(
- "--output",
- help="Output file (e.g., README.md)"
+ "--by-category", action="store_true", help="Split tables by category"
)
+ md_parser.add_argument("--output", help="Output file (e.g., README.md)")
md_parser.set_defaults(func=handle_markdown)
# bibtex command
bibtex_parser = subparsers.add_parser(
- "bibtex",
- help="Generate BibTeX file from references"
+ "bibtex", help="Generate BibTeX file from references"
)
bibtex_parser.add_argument(
"--file",
default="references.csv",
- help="CSV file path (default: references.csv)"
- )
- bibtex_parser.add_argument(
- "--output",
- help="Output file (e.g., references.bib)"
+ help="CSV file path (default: references.csv)",
)
+ bibtex_parser.add_argument("--output", help="Output file (e.g., references.bib)")
bibtex_parser.set_defaults(func=handle_bibtex)
# graph command
graph_parser = subparsers.add_parser(
- "graph",
- help="Build and visualize citation graph"
+ "graph", help="Build and visualize citation graph"
)
graph_parser.add_argument(
"--file",
default="references.csv",
- help="CSV file path (default: references.csv)"
+ help="CSV file path (default: references.csv)",
)
graph_parser.add_argument(
- "--output",
- help="Output HTML file (default: citation_graph.html)"
+ "--output", help="Output HTML file (default: citation_graph.html)"
)
graph_parser.add_argument(
- "--verbose",
- action="store_true",
- help="Show verbose output references"
+ "--verbose", action="store_true", help="Show verbose output references"
)
graph_parser.set_defaults(func=handle_graph)
# db-init command
db_init_parser = subparsers.add_parser(
- "db-init",
- help="Initialize database for bibliography management"
+ "db-init", help="Initialize database for bibliography management"
)
db_init_parser.add_argument(
"--db-url",
default="sqlite:///bibliography.db",
- help="Database URL (default: sqlite:///bibliography.db)"
+ help="Database URL (default: sqlite:///bibliography.db)",
)
db_init_parser.set_defaults(func=handle_db_init)
# db-migrate command
db_migrate_parser = subparsers.add_parser(
- "db-migrate",
- help="Migrate references from CSV to database"
+ "db-migrate", help="Migrate references from CSV to database"
)
db_migrate_parser.add_argument(
"--file",
default="references.csv",
- help="CSV file path (default: references.csv)"
+ help="CSV file path (default: references.csv)",
)
db_migrate_parser.add_argument(
"--db-url",
default="sqlite:///bibliography.db",
- help="Database URL (default: sqlite:///bibliography.db)"
+ help="Database URL (default: sqlite:///bibliography.db)",
)
db_migrate_parser.set_defaults(func=handle_db_migrate)
# db-export command
db_export_parser = subparsers.add_parser(
- "db-export",
- help="Export database references to CSV"
- )
- db_export_parser.add_argument(
- "--output",
- required=True,
- help="Output CSV file"
+ "db-export", help="Export database references to CSV"
)
+ db_export_parser.add_argument("--output", required=True, help="Output CSV file")
db_export_parser.add_argument(
"--db-url",
default="sqlite:///bibliography.db",
- help="Database URL (default: sqlite:///bibliography.db)"
+ help="Database URL (default: sqlite:///bibliography.db)",
)
db_export_parser.set_defaults(func=handle_db_export)
args = parser.parse_args()
# Execute the appropriate handler or show help
- if hasattr(args, 'func'):
+ if hasattr(args, "func"):
try:
args.func(args)
except KeyboardInterrupt:
@@ -573,4 +549,4 @@ def main() -> None:
print_error(f"An error occurred: {e}")
sys.exit(1)
else:
- parser.print_help()
\ No newline at end of file
+ parser.print_help()
diff --git a/pkg/mybib/db_storage.py b/pkg/mybib/db_storage.py
index cc73075..7ee8f09 100644
--- a/pkg/mybib/db_storage.py
+++ b/pkg/mybib/db_storage.py
@@ -1,8 +1,10 @@
"""Database storage adapter for bibliography management."""
-from typing import Optional, List, Dict
+from typing import Dict, List, Optional
+
from sqlalchemy.exc import IntegrityError
-from .models import Reference, Category, create_db_engine, init_db, get_session
+
+from .models import Category, Reference, create_db_engine, get_session, init_db
class DatabaseStorage:
@@ -10,7 +12,7 @@ class DatabaseStorage:
def __init__(self, db_url: str = "sqlite:///bibliography.db"):
"""Initialize database storage.
-
+
Args:
db_url: Database connection URL
"""
@@ -30,7 +32,7 @@ def add_reference(
scholar_id: str = None,
) -> Optional[Reference]:
"""Add a reference to the database.
-
+
Args:
title: Article title
authors: Comma-separated author names
@@ -41,12 +43,12 @@ def add_reference(
category_name: Category name
arxiv_id: arXiv identifier
scholar_id: Google Scholar ID
-
+
Returns:
Created Reference object or None if duplicate
"""
session = get_session(self.engine)
-
+
try:
# Check for duplicate DOI
if doi:
@@ -54,7 +56,7 @@ def add_reference(
if existing:
session.close()
return None
-
+
# Get or create category
category = None
if category_name:
@@ -63,7 +65,7 @@ def add_reference(
category = Category(name=category_name)
session.add(category)
session.flush()
-
+
# Create reference
reference = Reference(
title=title,
@@ -76,13 +78,13 @@ def add_reference(
scholar_id=scholar_id,
category_id=category.id if category else None,
)
-
+
session.add(reference)
session.commit()
session.close()
-
+
return reference
-
+
except IntegrityError:
session.rollback()
session.close()
@@ -96,72 +98,68 @@ def get_references(
self, category_id: int = None, year: int = None, order_by: str = None
) -> List[Reference]:
"""Get references from database with optional filtering.
-
+
Args:
category_id: Filter by category ID
year: Filter by year
order_by: Field to order by (e.g., "year", "-year", "title")
-
+
Returns:
List of Reference objects
"""
session = get_session(self.engine)
query = session.query(Reference)
-
+
if category_id:
query = query.filter_by(category_id=category_id)
-
+
if year:
query = query.filter_by(year=year)
-
+
# Ordering
if order_by:
reverse = order_by.startswith("-")
field = order_by[1:] if reverse else order_by
-
+
if hasattr(Reference, field):
col = getattr(Reference, field)
query = query.order_by(col.desc() if reverse else col)
else:
# Default ordering: category, then year descending
query = query.order_by(Category.name, Reference.year.desc())
-
+
results = query.all()
session.close()
-
+
return results
def add_category(self, name: str, description: str = None) -> Optional[Category]:
"""Add a category to the database.
-
+
Args:
name: Category name
description: Optional description
-
+
Returns:
Created Category object or None if duplicate
"""
session = get_session(self.engine)
-
+
try:
# Check for existing category (case-insensitive)
- existing = (
- session.query(Category)
- .filter(Category.name.ilike(name))
- .first()
- )
-
+ existing = session.query(Category).filter(Category.name.ilike(name)).first()
+
if existing:
session.close()
return existing
-
+
category = Category(name=name, description=description)
session.add(category)
session.commit()
session.close()
-
+
return category
-
+
except IntegrityError:
session.rollback()
session.close()
@@ -173,37 +171,37 @@ def add_category(self, name: str, description: str = None) -> Optional[Category]
def get_categories(self) -> List[Category]:
"""Get all categories.
-
+
Returns:
List of Category objects ordered by name
"""
session = get_session(self.engine)
categories = session.query(Category).order_by(Category.name).all()
session.close()
-
+
return categories
def migrate_from_csv(self, csv_file: str) -> Dict[str, int]:
"""Migrate references from CSV file to database.
-
+
Args:
csv_file: Path to CSV file
-
+
Returns:
Dictionary with migration statistics
"""
import pandas as pd
-
+
df = pd.read_csv(csv_file, dtype={"ArxivID": str})
df = df.fillna("")
-
+
stats = {
"total": len(df),
"added": 0,
"duplicates": 0,
"errors": 0,
}
-
+
for _, row in df.iterrows():
try:
result = self.add_reference(
@@ -216,48 +214,61 @@ def migrate_from_csv(self, csv_file: str) -> Dict[str, int]:
category_name=row.get("Category", ""),
arxiv_id=row.get("ArxivID", ""),
)
-
+
if result:
stats["added"] += 1
else:
stats["duplicates"] += 1
-
+
except Exception as e:
stats["errors"] += 1
print(f"Error migrating {row.get('Title', 'Unknown')}: {e}")
-
+
return stats
def export_to_csv(self, csv_file: str) -> int:
"""Export database references to CSV file.
-
+
Args:
csv_file: Path to output CSV file
-
+
Returns:
Number of references exported
"""
import pandas as pd
-
+
session = get_session(self.engine)
references = session.query(Reference).all()
session.close()
-
+
data = []
for ref in references:
- data.append({
- "Title": ref.title,
- "Authors": ref.authors or "",
- "Journal": ref.journal or "",
- "Year": ref.year or "",
- "DOI": ref.doi or "",
- "Link": ref.link or "",
- "Category": ref.category.name if ref.category else "",
- "ArxivID": ref.arxiv_id or "",
- })
-
+ data.append(
+ {
+ "Title": ref.title,
+ "Authors": ref.authors or "",
+ "Journal": ref.journal or "",
+ "Year": ref.year or "",
+ "DOI": ref.doi or "",
+ "Link": ref.link or "",
+ "Category": ref.category.name if ref.category else "",
+ "ArxivID": ref.arxiv_id or "",
+ }
+ )
+
df = pd.DataFrame(data)
- df = df[["Title", "Authors", "Journal", "Year", "DOI", "Link", "Category", "ArxivID"]]
+ df = df[
+ [
+ "Title",
+ "Authors",
+ "Journal",
+ "Year",
+ "DOI",
+ "Link",
+ "Category",
+ "ArxivID",
+ ]
+ ]
df.to_csv(csv_file, index=False)
-
+
return len(data)
diff --git a/pkg/mybib/graph.py b/pkg/mybib/graph.py
index ba2fabc..1ca3473 100644
--- a/pkg/mybib/graph.py
+++ b/pkg/mybib/graph.py
@@ -1,105 +1,111 @@
"""Citation graph building and visualization using Crossref API."""
+import time
+
import networkx as nx
-from typing import Optional
import pandas as pd
import requests
-import time
def query_crossref_references(doi: str, max_retries: int = 3) -> list[str]:
"""Query Crossref API for references cited by a given DOI.
-
+
Args:
doi: The DOI of the paper to query (e.g., "10.1234/example")
max_retries: Maximum number of retry attempts
-
+
Returns:
List of DOIs cited by the paper
"""
# Ensure DOI doesn't have 'https://doi.org/' prefix
- doi_clean = doi.strip().lstrip('https://doi.org/').lstrip('http://doi.org/')
-
+ doi_clean = doi.strip().lstrip("https://doi.org/").lstrip("http://doi.org/")
+
url = f"https://api.crossref.org/works/{doi_clean}"
headers = {"User-Agent": "MyBible (mailto:test@example.com)"}
-
+
for attempt in range(max_retries):
try:
response = requests.get(url, headers=headers, timeout=10)
response.raise_for_status()
-
+
data = response.json()
- if data.get('status') == 'ok':
- work = data.get('message', {})
- references = work.get('reference', [])
-
+ if data.get("status") == "ok":
+ work = data.get("message", {})
+ references = work.get("reference", [])
+
# Extract DOIs from references
reference_dois = []
for ref in references:
- if isinstance(ref, dict) and 'DOI' in ref:
- reference_dois.append(ref['DOI'].lower())
-
+ if isinstance(ref, dict) and "DOI" in ref:
+ reference_dois.append(ref["DOI"].lower())
+
return reference_dois
-
+
except requests.exceptions.RequestException as e:
if attempt < max_retries - 1:
# Exponential backoff
- wait_time = 2 ** attempt
+ wait_time = 2**attempt
time.sleep(wait_time)
else:
print(f"Warning: Failed to query DOI {doi_clean}: {e}")
-
+
return []
-def build_citation_graph(df: pd.DataFrame,
- output_references: bool = False) -> nx.DiGraph:
+def build_citation_graph(
+ df: pd.DataFrame, output_references: bool = False
+) -> nx.DiGraph:
"""Build a citation graph from stored research papers.
-
+
Creates a directed graph where:
- Nodes represent papers (identified by DOI)
- Edges represent citations (paper A -> paper B means A cites B)
-
+
Args:
df: DataFrame with columns including 'DOI' and 'Title'
output_references: If True, print progress information
-
+
Returns:
A NetworkX directed graph
"""
graph = nx.DiGraph()
-
+
# Add all papers as nodes
paper_dois = {} # Maps clean DOI to original DOI
for idx, row in df.iterrows():
- doi = str(row['DOI']).strip()
- if not doi or doi.lower() == 'nan':
+ doi = str(row["DOI"]).strip()
+ if not doi or doi.lower() == "nan":
continue
-
+
# Normalize DOI for use as node identifier
- doi_clean = doi.lstrip('https://doi.org/').lstrip('http://doi.org/').lower()
+ doi_clean = doi.lstrip("https://doi.org/").lstrip("http://doi.org/").lower()
paper_dois[doi_clean] = doi
-
- title = row.get('Title', 'Unknown')
- graph.add_node(doi_clean,
- title=title,
- original_doi=doi,
- authors=row.get('Authors', ''),
- year=row.get('Year', ''))
-
+
+ title = row.get("Title", "Unknown")
+ graph.add_node(
+ doi_clean,
+ title=title,
+ original_doi=doi,
+ authors=row.get("Authors", ""),
+ year=row.get("Year", ""),
+ )
+
if output_references:
print(f"Added {len(graph.nodes())} papers to graph")
-
+
# Query Crossref API for each paper's references
stored_dois = set(paper_dois.keys())
-
+
for idx, (doi_clean, doi_original) in enumerate(paper_dois.items()):
if output_references:
- print(f"Querying references for paper {idx + 1}/{len(paper_dois)}: {doi_clean}")
-
+ print(
+ f"Querying references for paper "
+ f"{idx + 1}/{len(paper_dois)}: {doi_clean}"
+ )
+
# Query Crossref for references
reference_dois = query_crossref_references(doi_original)
-
+
# Add edges only for references in our stored papers
for ref_doi in reference_dois:
if ref_doi in stored_dois:
@@ -107,19 +113,22 @@ def build_citation_graph(df: pd.DataFrame,
graph.add_edge(doi_clean, ref_doi)
if output_references:
print(f" -> Edge: {doi_clean} cites {ref_doi}")
-
+
# Rate limiting: Crossref API recommends polite usage
time.sleep(0.5)
-
+
if output_references:
- print(f"Citation graph built with {len(graph.nodes())} nodes and {len(graph.edges())} edges")
-
+ print(
+ f"Citation graph built with {len(graph.nodes())} nodes "
+ f"and {len(graph.edges())} edges"
+ )
+
return graph
def export_graph_html(graph: nx.DiGraph, output_file: str) -> None:
"""Export citation graph as interactive HTML using pyvis.
-
+
Args:
graph: NetworkX directed graph
output_file: Path to output HTML file
@@ -127,36 +136,34 @@ def export_graph_html(graph: nx.DiGraph, output_file: str) -> None:
try:
from pyvis.network import Network
except ImportError:
- raise ImportError("pyvis is required for graph visualization. Install it with: pip install pyvis")
-
+ raise ImportError(
+ "pyvis is required for graph visualization. "
+ "Install it with: pip install pyvis"
+ )
+
import json
-
+
# Create pyvis network
- net = Network(directed=True, height='750px', width='100%')
-
+ net = Network(directed=True, height="750px", width="100%")
+
# Add nodes with labels
for node in graph.nodes(data=True):
node_id = node[0]
node_data = node[1]
-
- title = node_data.get('title', 'Unknown')
- label = title[:50] + '...' if len(title) > 50 else title
-
+
+ title = node_data.get("title", "Unknown")
+ label = title[:50] + "..." if len(title) > 50 else title
+
# Create hover title with full information
hover_title = f"{title}\nDOI: {node_data.get('original_doi', '')}"
-
- net.add_node(
- node_id,
- label=label,
- title=hover_title,
- color='#FF6E63'
- )
-
+
+ net.add_node(node_id, label=label, title=hover_title, color="#FF6E63")
+
# Add edges
for edge in graph.edges():
source, target = edge
net.add_edge(source, target)
-
+
# Configure physics settings as JSON string
options = {
"physics": {
@@ -165,12 +172,12 @@ def export_graph_html(graph: nx.DiGraph, output_file: str) -> None:
"gravitationalConstant": -26000,
"centralGravity": 0.3,
"springLength": 200,
- "springConstant": 0.04
- }
+ "springConstant": 0.04,
+ },
}
}
net.set_options(json.dumps(options))
-
+
# Write HTML file
net.write_html(output_file)
- print(f"Citation graph exported to {output_file}")
\ No newline at end of file
+ print(f"Citation graph exported to {output_file}")
diff --git a/pkg/mybib/markdown.py b/pkg/mybib/markdown.py
index 3f6b2ce..a3aa413 100644
--- a/pkg/mybib/markdown.py
+++ b/pkg/mybib/markdown.py
@@ -6,81 +6,84 @@
def _prepare_references_for_markdown(file_path: str = "references.csv"):
"""Prepare references DataFrame for markdown output.
-
+
Args:
file_path: Path to the CSV file
-
+
Returns:
Processed DataFrame ready for markdown conversion
"""
df = load_references(file_path)
-
+
if df.empty:
return df
-
+
df["Authors"] = df["Authors"].apply(reform_names)
df["DOI"] = df.apply(lambda row: f"[{row.get('DOI', 'unknown')}]", axis=1)
# Ensure ArxivID stays as string for display (prevents float conversion)
if "ArxivID" in df.columns:
df["ArxivID"] = df["ArxivID"].astype(str)
df = df.sort_values(by=["Category", "Year"], ascending=[True, False])
-
+
return df
def make_markdown_table(file_path: str = "references.csv") -> str:
"""Generate a markdown table from all references.
-
+
Args:
file_path: Path to the CSV file
-
+
Returns:
Markdown formatted table as string
"""
df = _prepare_references_for_markdown(file_path)
-
+
if df.empty:
return "No references found."
-
+
markdown_table = df.to_markdown(index=False)
return markdown_table
def make_markdown_tables_by_category(file_path: str = "references.csv") -> str:
"""Generate markdown tables separated by category.
-
+
Args:
file_path: Path to the CSV file
-
+
Returns:
Markdown formatted tables with category headers and footer links
"""
df = _prepare_references_for_markdown(file_path)
-
+
if df.empty:
return "No references found."
-
+
output = []
footer = []
-
+
for category, group in df.groupby("Category"):
output.append(f"## {category}\n")
# Format ArxivID as string to avoid float conversion by to_markdown
display_group = group.copy()
if "ArxivID" in display_group.columns:
- display_group["ArxivID"] = display_group["ArxivID"].astype(str).str.replace("nan", "")
-
- table = display_group.drop(columns=["Category", "Link"]).to_markdown(index=False)
+ display_group["ArxivID"] = (
+ display_group["ArxivID"].astype(str).str.replace("nan", "")
+ )
+
+ table = display_group.drop(columns=["Category", "Link"]).to_markdown(
+ index=False
+ )
output.append(table)
output.append("\n")
-
+
# Collect footer links
footer.extend(
- f"{doi}: {link}"
- for doi, link in zip(group["DOI"], group["Link"])
+ f"{doi}: {link}" for doi, link in zip(group["DOI"], group["Link"])
)
-
+
# Sort and deduplicate footer
footer = sorted(set(footer), key=lambda x: x.split(":")[0])
-
+
return "\n".join(output) + "\n".join(footer)
diff --git a/pkg/mybib/metadata.py b/pkg/mybib/metadata.py
index 75aa77d..973d148 100644
--- a/pkg/mybib/metadata.py
+++ b/pkg/mybib/metadata.py
@@ -2,23 +2,23 @@
import re
import sys
+
import requests
-from urllib.parse import urlparse
from . import arxiv
def fetch_metadata(url: str) -> dict:
"""Fetch metadata from a URL, automatically detecting the source.
-
+
Supports:
- arXiv URLs (arxiv.org)
- DOI URLs (doi.org) and DOI patterns
- Generic URLs with HTML metadata extraction
-
+
Args:
url: URL or DOI string to fetch metadata from
-
+
Returns:
Dictionary with standardized fields:
{
@@ -29,12 +29,12 @@ def fetch_metadata(url: str) -> dict:
'doi': str or None,
'link': str
}
-
+
Raises:
SystemExit: If URL is invalid or metadata cannot be extracted
"""
url = url.strip()
-
+
# Detect source and route to appropriate handler
if _is_arxiv_url(url):
return _fetch_arxiv_metadata(url)
@@ -62,7 +62,7 @@ def _is_doi_pattern(url: str) -> bool:
def _extract_arxiv_id(url: str) -> str:
"""Extract arXiv ID from URL.
-
+
Examples:
- https://arxiv.org/abs/2301.00001 -> 2301.00001
- https://arxiv.org/pdf/2301.00001.pdf -> 2301.00001
@@ -71,7 +71,7 @@ def _extract_arxiv_id(url: str) -> str:
match = re.search(r"/(?:abs|pdf)/(\d{4}\.\d{4,5})", url)
if match:
return match.group(1)
-
+
print(f"Error: Could not extract arXiv ID from {url}")
sys.exit(1)
@@ -95,27 +95,31 @@ def _normalize_doi(url_or_doi: str) -> str:
def _fetch_crossref_metadata(url_or_doi: str) -> dict:
"""Fetch metadata from Crossref API using DOI."""
doi = _normalize_doi(url_or_doi)
-
+
# Crossref API endpoint
crossref_url = f"https://api.crossref.org/works/{doi}"
-
+
try:
response = requests.get(crossref_url, timeout=10)
response.raise_for_status()
except requests.RequestException as e:
print(f"Error fetching DOI metadata: {e}")
sys.exit(1)
-
+
data = response.json()
if data.get("status") != "ok" or not data.get("message"):
print(f"Error: Could not fetch metadata for DOI {doi}")
sys.exit(1)
-
+
work = data["message"]
-
+
# Extract fields with fallbacks
- title = work.get("title", [""])[0] if isinstance(work.get("title"), list) else work.get("title", "Unknown")
-
+ title = (
+ work.get("title", [""])[0]
+ if isinstance(work.get("title"), list)
+ else work.get("title", "Unknown")
+ )
+
# Authors
authors = ""
if "author" in work:
@@ -124,29 +128,29 @@ def _fetch_crossref_metadata(url_or_doi: str) -> dict:
for a in work["author"]
]
authors = ", ".join(author_list)
-
+
# Journal
journal = work.get("container-title", "Unknown")
if isinstance(journal, list):
journal = journal[0] if journal else "Unknown"
-
+
# Year
year = None
if "issued" in work:
date_parts = work["issued"].get("date-parts", [[None]])[0]
if date_parts:
year = date_parts[0]
-
+
# Link (prefer DOI link, fallback to URL)
link = work.get("URL", f"https://doi.org/{doi}")
-
+
return {
"title": title,
"authors": authors,
"journal": journal,
"year": year,
"doi": doi,
- "link": link
+ "link": link,
}
@@ -158,41 +162,41 @@ def _fetch_generic_metadata(url: str) -> dict:
except requests.RequestException as e:
print(f"Error fetching URL: {e}")
sys.exit(1)
-
+
html = response.text
-
+
# Extract title from
tag or og:title meta tag
title = _extract_html_meta(html, ["og:title", "twitter:title"])
if not title:
title_match = re.search(r"]*>([^<]+)", html, re.IGNORECASE)
title = title_match.group(1).strip() if title_match else "Unknown"
-
+
# Extract author
authors = _extract_html_meta(html, ["author", "article:author"])
-
+
# Try to extract DOI
doi = None
doi_match = re.search(r"10\.\S+/\S+", html)
if doi_match:
doi = doi_match.group(0)
-
+
return {
"title": title or "Unknown",
"authors": authors or "Unknown",
"journal": "Unknown",
"year": None,
"doi": doi,
- "link": url
+ "link": url,
}
def _extract_html_meta(html: str, meta_names: list) -> str:
"""Extract value from HTML meta tags.
-
+
Args:
html: HTML content as string
meta_names: List of meta tag names to search for
-
+
Returns:
Content of the first matching meta tag, or None if not found
"""
@@ -202,17 +206,20 @@ def _extract_html_meta(html: str, meta_names: list) -> str:
match = re.search(pattern, html, re.IGNORECASE)
if match:
return match.group(1).strip()
-
+
# Try name attribute
pattern = rf'"
+ return (
+ f""
+ )
def create_db_engine(db_url: str = "sqlite:///bibliography.db"):
"""Create database engine.
-
+
Args:
db_url: Database URL (default: SQLite)
-
+
Returns:
SQLAlchemy engine
"""
@@ -80,7 +84,7 @@ def create_db_engine(db_url: str = "sqlite:///bibliography.db"):
def init_db(engine):
"""Initialize database tables.
-
+
Args:
engine: SQLAlchemy engine
"""
@@ -89,10 +93,10 @@ def init_db(engine):
def get_session(engine):
"""Get database session.
-
+
Args:
engine: SQLAlchemy engine
-
+
Returns:
SQLAlchemy session
"""
diff --git a/pkg/mybib/scholar.py b/pkg/mybib/scholar.py
index 81ad328..28471e2 100644
--- a/pkg/mybib/scholar.py
+++ b/pkg/mybib/scholar.py
@@ -1,15 +1,15 @@
"""Fetch metadata from Google Scholar using SerpAPI."""
-import sys
import os
+import sys
+from typing import Dict, List, Optional
+
import requests
-import json
-from typing import Optional, Dict, List
-from urllib.parse import urlencode
# Suppress SSL warnings if needed
try:
import urllib3
+
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
except ImportError:
pass
@@ -17,15 +17,15 @@
def search_google_scholar(query: str, max_results: int = 5) -> List[Dict]:
"""Search Google Scholar for articles matching a query.
-
+
Args:
query: Search query (title, authors, or keywords)
max_results: Maximum number of results to return (1-20)
-
+
Returns:
- List of results with keys: position, title, result_id, link, snippet,
+ List of results with keys: position, title, result_id, link, snippet,
publication_info, inline_links, authors, year, etc.
-
+
Raises:
SystemExit: If API call fails or key is missing
"""
@@ -34,7 +34,7 @@ def search_google_scholar(query: str, max_results: int = 5) -> List[Dict]:
print("Error: SERPAPI_KEY environment variable not set.")
print("Get a free API key at https://serpapi.com")
sys.exit(1)
-
+
params = {
"engine": "google_scholar",
"q": query,
@@ -42,21 +42,21 @@ def search_google_scholar(query: str, max_results: int = 5) -> List[Dict]:
"num": min(max_results, 20),
"hl": "en",
}
-
+
url = "https://serpapi.com/search"
-
+
try:
response = requests.get(url, params=params, timeout=10)
response.raise_for_status()
data = response.json()
-
+
if "error" in data:
print(f"Error from SerpAPI: {data['error']}")
sys.exit(1)
-
+
organic_results = data.get("organic_results", [])
return organic_results
-
+
except requests.exceptions.RequestException as e:
print(f"Error querying Google Scholar: {e}")
sys.exit(1)
@@ -64,27 +64,27 @@ def search_google_scholar(query: str, max_results: int = 5) -> List[Dict]:
def get_scholar_cite_link(result_id: str) -> Optional[str]:
"""Get the Google Scholar cite link for a result.
-
+
Args:
result_id: The result ID from search results
-
+
Returns:
The SerpAPI cite API link, or None if not available
"""
api_key = os.environ.get("SERPAPI_KEY")
if not api_key:
return None
-
+
# Format: https://serpapi.com/search.json?engine=google_scholar_cite&q={result_id}&api_key={api_key}
return f"https://serpapi.com/search.json?engine=google_scholar_cite&q={result_id}&api_key={api_key}"
def fetch_bibtex_from_scholar(result_id: str) -> Optional[str]:
"""Fetch BibTeX citation from Google Scholar using SerpAPI cite API.
-
+
Args:
result_id: The result ID from a Google Scholar search result
-
+
Returns:
BibTeX string, or None if unable to fetch
"""
@@ -92,36 +92,36 @@ def fetch_bibtex_from_scholar(result_id: str) -> Optional[str]:
if not api_key:
print("Error: SERPAPI_KEY environment variable not set.")
return None
-
+
params = {
"engine": "google_scholar_cite",
"q": result_id,
"api_key": api_key,
}
-
+
url = "https://serpapi.com/search"
-
+
try:
response = requests.get(url, params=params, timeout=10)
response.raise_for_status()
data = response.json()
-
+
if "error" in data:
print(f"Error from SerpAPI: {data['error']}")
return None
-
+
# SerpAPI returns citations in different formats
# The BibTeX is typically in the response as a string or in a structured format
citations = data.get("citations", {})
-
+
# Try to find BibTeX format
bibtex = citations.get("bibtex")
if bibtex:
return bibtex
-
+
# If we can't find BibTeX, we might need to construct it from available data
return None
-
+
except requests.exceptions.RequestException as e:
print(f"Error fetching BibTeX from Google Scholar: {e}")
return None
@@ -129,15 +129,15 @@ def fetch_bibtex_from_scholar(result_id: str) -> Optional[str]:
def extract_metadata_from_result(result: Dict) -> Dict:
"""Extract standardized metadata from a Google Scholar search result.
-
+
Args:
result: A single result from search_google_scholar()
-
+
Returns:
Dictionary with keys: title, authors, journal, year, doi, link, scholar_id
"""
import re
-
+
metadata = {
"title": result.get("title", ""),
"authors": "",
@@ -148,29 +148,29 @@ def extract_metadata_from_result(result: Dict) -> Dict:
"result_id": result.get("result_id", ""),
"scholar_id": result.get("result_id", ""),
}
-
+
# Extract publication info
pub_info = result.get("publication_info", {})
if isinstance(pub_info, dict):
summary = pub_info.get("summary", "")
metadata["journal"] = summary
-
+
# Try to extract year from summary - look for 4-digit years (1900-2099)
# Use a more specific pattern to avoid partial matches
- year_matches = re.findall(r'\b(19\d{2}|20\d{2})\b', summary)
+ year_matches = re.findall(r"\b(19\d{2}|20\d{2})\b", summary)
if year_matches:
# Take the most likely year (prefer 20xx over 19xx if available)
for year_str in year_matches:
metadata["year"] = int(year_str)
# Prefer years starting with 20
- if year_str.startswith('20'):
+ if year_str.startswith("20"):
break
-
+
# Try to extract DOI from summary (pattern: 10.xxxx/xxxx)
- doi_match = re.search(r'\b(10\.\S+/\S+)\b', summary)
+ doi_match = re.search(r"\b(10\.\S+/\S+)\b", summary)
if doi_match:
metadata["doi"] = doi_match.group(1)
-
+
# Extract authors
authors_list = pub_info.get("authors", [])
if authors_list:
@@ -184,13 +184,13 @@ def extract_metadata_from_result(result: Dict) -> Dict:
metadata["authors"] = ", ".join(author_names)
else:
metadata["authors"] = str(authors_list)
-
+
# Try to extract year from other fields if not found
if metadata["year"] is None:
- title_year_match = re.search(r'\((\d{4})\)', metadata["title"])
+ title_year_match = re.search(r"\((\d{4})\)", metadata["title"])
if title_year_match:
metadata["year"] = int(title_year_match.group(1))
-
+
# Try to extract DOI from result directly if not found in summary
if metadata["doi"] is None:
if "doi" in result:
@@ -201,50 +201,51 @@ def extract_metadata_from_result(result: Dict) -> Dict:
# Check for DOI link in inline_links
for key, value in inline_links.items():
if "doi.org" in str(value):
- doi_match = re.search(r'doi\.org/(.+?)(?:\s|$)', str(value))
+ doi_match = re.search(r"doi\.org/(.+?)(?:\s|$)", str(value))
if doi_match:
metadata["doi"] = doi_match.group(1)
break
-
- # Note: We do NOT use scholar_id as a fallback for DOI because they are different identifiers.
+
+ # Note: We do NOT use scholar_id as a fallback for DOI because they are
+ # different identifiers.
# Scholar ID is Google Scholar's internal identifier, not a DOI.
# If no real DOI is found, leave it as None.
-
+
return metadata
def search_and_confirm_article(title: str, max_attempts: int = 3) -> Optional[Dict]:
"""Search for an article on Google Scholar and get user confirmation.
-
+
Args:
title: Article title to search for
max_attempts: Maximum attempts to find a match
-
+
Returns:
Metadata dictionary if found and confirmed, None otherwise
"""
- from .ui import console, confirm_action, display_reference_preview
-
+ from .ui import confirm_action, console, display_reference_preview
+
console.print(f"[cyan]Searching Google Scholar for: {title}[/]")
-
+
results = search_google_scholar(title, max_results=5)
-
+
if not results:
console.print("[red]No results found on Google Scholar[/]")
return None
-
+
# Show first result and ask for confirmation
first_result = results[0]
metadata = extract_metadata_from_result(first_result)
-
+
console.print()
console.print("[yellow]Found:[/]")
display_reference_preview(metadata)
console.print()
-
+
if confirm_action("Is this the correct article?"):
return metadata
-
+
# If not confirmed, show other results
if len(results) > 1:
for i, result in enumerate(results[1:], 1):
@@ -253,12 +254,12 @@ def search_and_confirm_article(title: str, max_attempts: int = 3) -> Optional[Di
result_metadata = extract_metadata_from_result(result)
display_reference_preview(result_metadata)
console.print()
-
- if confirm_action(f"Is this the correct article?"):
+
+ if confirm_action("Is this the correct article?"):
return result_metadata
-
+
if i >= max_attempts - 1:
break
-
+
console.print("[red]No matching article confirmed[/]")
return None
diff --git a/pkg/mybib/storage.py b/pkg/mybib/storage.py
index c14b4bd..6ffc787 100644
--- a/pkg/mybib/storage.py
+++ b/pkg/mybib/storage.py
@@ -1,9 +1,10 @@
"""CSV storage for bibliography references."""
-import pandas as pd
import sys
from pathlib import Path
+import pandas as pd
+
def add_reference(
title: str,
@@ -18,7 +19,7 @@ def add_reference(
file_path: str = "references.csv",
) -> None:
"""Add a reference to the CSV file.
-
+
Args:
title: Article title
authors: Comma-separated author names
@@ -30,13 +31,13 @@ def add_reference(
arxiv_id: arXiv identifier (optional)
scholar_id: Google Scholar result ID (optional, used as DOI fallback)
file_path: Path to the CSV file
-
+
Raises:
SystemExit: If reference already exists
"""
# Use scholar_id as DOI fallback if DOI not provided
final_doi = doi if doi else scholar_id
-
+
new_reference = {
"Title": title,
"Authors": authors,
@@ -48,9 +49,9 @@ def add_reference(
"ArxivID": arxiv_id or "",
}
row = pd.DataFrame([new_reference])
-
+
file_exists = Path(file_path).exists()
-
+
if file_exists:
df_existing = pd.read_csv(file_path)
# Convert ArxivID to string for comparison
@@ -59,16 +60,14 @@ def add_reference(
existing_df = df_existing
# Normalize DOI for comparison: convert to string, strip whitespace, lowercase
existing_dois = set(
- str(d).strip().lower()
- for d in existing_df["DOI"].to_list()
- if pd.notna(d)
+ str(d).strip().lower() for d in existing_df["DOI"].to_list() if pd.notna(d)
)
normalized_doi = str(final_doi).strip().lower() if final_doi else ""
-
+
if normalized_doi and normalized_doi in existing_dois:
print("Reference already exists in the CSV file.")
sys.exit(0)
-
+
# Ensure ArxivID is treated as string
row["ArxivID"] = row["ArxivID"].astype(str)
row.to_csv(file_path, mode="a", index=False, header=not file_exists)
@@ -76,10 +75,10 @@ def add_reference(
def load_references(file_path: str = "references.csv") -> pd.DataFrame:
"""Load references from CSV file.
-
+
Args:
file_path: Path to the CSV file
-
+
Returns:
DataFrame with reference data
"""
@@ -93,9 +92,18 @@ def load_references(file_path: str = "references.csv") -> pd.DataFrame:
df["ArxivID"] = ""
except FileNotFoundError:
df = pd.DataFrame(
- columns=["Title", "Authors", "Journal", "Year", "DOI", "Link", "Category", "ArxivID"]
+ columns=[
+ "Title",
+ "Authors",
+ "Journal",
+ "Year",
+ "DOI",
+ "Link",
+ "Category",
+ "ArxivID",
+ ]
)
df = df.astype({"ArxivID": str})
df.to_csv(file_path, index=False)
-
- return df
\ No newline at end of file
+
+ return df
diff --git a/pkg/mybib/ui.py b/pkg/mybib/ui.py
index 1645c09..6e61e9b 100644
--- a/pkg/mybib/ui.py
+++ b/pkg/mybib/ui.py
@@ -1,14 +1,17 @@
"""UI utilities using rich for enhanced terminal output."""
from contextlib import contextmanager
-from pathlib import Path
from typing import Generator
+import pandas as pd
from rich.console import Console
-from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeRemainingColumn
-from rich.table import Table
+from rich.progress import (
+ Progress,
+ SpinnerColumn,
+ TextColumn,
+)
from rich.prompt import Confirm
-import pandas as pd
+from rich.table import Table
# Create a global console instance for consistent styling
console = Console()
@@ -37,7 +40,7 @@ def print_info(message: str) -> None:
@contextmanager
def progress_context(description: str = "Processing") -> Generator:
"""Context manager for showing progress during operations.
-
+
Usage:
with progress_context("Fetching metadata"):
# Long-running operation
@@ -48,7 +51,7 @@ def progress_context(description: str = "Processing") -> Generator:
TextColumn("[progress.description]{task.description}"),
console=console,
) as progress:
- task_id = progress.add_task(description, total=None)
+ progress.add_task(description, total=None)
try:
yield
finally:
@@ -63,7 +66,7 @@ def api_progress() -> Generator:
TextColumn("[progress.description]{task.description}"),
console=console,
) as progress:
- task_id = progress.add_task("[cyan]Fetching from API...", total=None)
+ progress.add_task("[cyan]Fetching from API...", total=None)
try:
yield
finally:
@@ -72,11 +75,11 @@ def api_progress() -> Generator:
def confirm_action(prompt: str, default: bool = True) -> bool:
"""Show a confirmation prompt with rich styling.
-
+
Args:
prompt: The question to ask
default: Default response if user just presses enter (default: True)
-
+
Returns:
True if confirmed, False otherwise
"""
@@ -85,10 +88,10 @@ def confirm_action(prompt: str, default: bool = True) -> bool:
def create_references_table(df: pd.DataFrame) -> Table:
"""Create a rich Table from a DataFrame of references.
-
+
Args:
df: DataFrame with columns: Title, Authors, Journal, Year, Category
-
+
Returns:
A rich Table object
"""
@@ -100,14 +103,14 @@ def create_references_table(df: pd.DataFrame) -> Table:
padding=(0, 1),
show_lines=False,
)
-
+
# Add columns
table.add_column("Title", style="cyan", width=40, overflow="fold")
table.add_column("Authors", style="green", width=30, overflow="fold")
table.add_column("Year", justify="center", width=6)
table.add_column("Category", style="yellow", width=15)
table.add_column("Journal", style="blue", width=20, overflow="fold")
-
+
# Add rows
for _, row in df.iterrows():
table.add_row(
@@ -117,22 +120,22 @@ def create_references_table(df: pd.DataFrame) -> Table:
str(row.get("Category", "")),
str(row.get("Journal", ""))[:20],
)
-
+
return table
def display_reference_preview(metadata: dict) -> None:
"""Display a nice preview of a reference before adding.
-
+
Args:
metadata: Dictionary with title, authors, journal, year, category
"""
table = Table(show_header=False, box=None, padding=(0, 2))
-
+
table.add_row("[bold cyan]Title:[/]", metadata.get("title", "N/A"))
table.add_row("[bold cyan]Authors:[/]", metadata.get("authors", "N/A"))
table.add_row("[bold cyan]Journal:[/]", metadata.get("journal", "N/A"))
table.add_row("[bold cyan]Year:[/]", str(metadata.get("year", "N/A")))
table.add_row("[bold cyan]DOI:[/]", metadata.get("doi", "N/A"))
-
+
console.print(table)
diff --git a/pkg/mybib/utils.py b/pkg/mybib/utils.py
index 42d255e..28d863a 100644
--- a/pkg/mybib/utils.py
+++ b/pkg/mybib/utils.py
@@ -3,42 +3,54 @@
def reform_names(authors_str: str) -> str:
"""Format author names for display.
-
+
Converts full author lists to abbreviated form:
- Single author or team: Last name only
- Two authors: "LastName1 and LastName2"
- Three+ authors: "FirstAuthor et al."
- Team names (contain "Team"): Entity name only without "et al."
-
+
Args:
authors_str: Comma-separated string of author names or "Authors et al." format
-
+
Returns:
Formatted author string
"""
if not authors_str or not isinstance(authors_str, str):
return ""
-
+
authors_str = authors_str.strip()
-
+
# Handle "X et al." format - extract just the first author
if " et al." in authors_str:
first_author = authors_str.split(" et al.")[0].strip()
# Extract last name from first author
first_author_last = first_author.split()[-1]
return f"{first_author_last} et al."
-
+
# Check if it's a team name (contains "Team", "AI", etc.)
- if any(team_keyword in authors_str for team_keyword in ["Team", "team", "-Ai", "Mistral", "Meta", "OpenAI", "DeepMind", "Google"]):
+ if any(
+ team_keyword in authors_str
+ for team_keyword in [
+ "Team",
+ "team",
+ "-Ai",
+ "Mistral",
+ "Meta",
+ "OpenAI",
+ "DeepMind",
+ "Google",
+ ]
+ ):
# Just return the entity name as is, or extract last part
parts = authors_str.split(",")
if len(parts) > 0:
return parts[0].strip()
return authors_str
-
+
# Split by comma
authors = [a.strip() for a in authors_str.split(",")]
-
+
if len(authors) > 2:
first_author_last_name = authors[0].split()[-1]
return f"{first_author_last_name} et al."
diff --git a/tests/test_arxiv.py b/tests/test_arxiv.py
index 1ba2d66..530cbb1 100644
--- a/tests/test_arxiv.py
+++ b/tests/test_arxiv.py
@@ -1,10 +1,10 @@
"""Tests for arXiv metadata fetching module."""
-import pytest
import tempfile
from pathlib import Path
-from unittest.mock import patch, Mock
-import sys
+from unittest.mock import Mock, patch
+
+import pytest
from pkg.mybib import arxiv
@@ -86,12 +86,12 @@ def sample_arxiv_response_no_journal():
class TestFetchArxivMetadata:
"""Test fetching metadata from arXiv API."""
- @patch('pkg.mybib.arxiv.requests.get')
+ @patch("pkg.mybib.arxiv.requests.get")
def test_fetch_arxiv_metadata_success(self, mock_get, sample_arxiv_response):
"""Test successfully fetching arXiv metadata."""
mock_response = Mock()
mock_response.status_code = 200
- mock_response.content = sample_arxiv_response.encode('utf-8')
+ mock_response.content = sample_arxiv_response.encode("utf-8")
mock_get.return_value = mock_response
result = arxiv.fetch_arxiv_metadata("1706.03762")
@@ -106,12 +106,14 @@ def test_fetch_arxiv_metadata_success(self, mock_get, sample_arxiv_response):
assert result["link"] == "https://arxiv.org/abs/1706.03762"
assert result["arxiv_id"] == "1706.03762"
- @patch('pkg.mybib.arxiv.requests.get')
- def test_fetch_arxiv_metadata_multiple_authors(self, mock_get, sample_arxiv_response):
+ @patch("pkg.mybib.arxiv.requests.get")
+ def test_fetch_arxiv_metadata_multiple_authors(
+ self, mock_get, sample_arxiv_response
+ ):
"""Test that multiple authors are properly parsed."""
mock_response = Mock()
mock_response.status_code = 200
- mock_response.content = sample_arxiv_response.encode('utf-8')
+ mock_response.content = sample_arxiv_response.encode("utf-8")
mock_get.return_value = mock_response
result = arxiv.fetch_arxiv_metadata("1706.03762")
@@ -122,12 +124,12 @@ def test_fetch_arxiv_metadata_multiple_authors(self, mock_get, sample_arxiv_resp
assert "Noam Shazeer" in authors
assert "Parmar Aidan" in authors
- @patch('pkg.mybib.arxiv.requests.get')
+ @patch("pkg.mybib.arxiv.requests.get")
def test_fetch_arxiv_metadata_no_doi(self, mock_get, sample_arxiv_response_no_doi):
"""Test handling of metadata without DOI (uses arxiv_id as fallback)."""
mock_response = Mock()
mock_response.status_code = 200
- mock_response.content = sample_arxiv_response_no_doi.encode('utf-8')
+ mock_response.content = sample_arxiv_response_no_doi.encode("utf-8")
mock_get.return_value = mock_response
result = arxiv.fetch_arxiv_metadata("2003.06123")
@@ -137,12 +139,14 @@ def test_fetch_arxiv_metadata_no_doi(self, mock_get, sample_arxiv_response_no_do
assert result["doi"] == "2003.06123" # Falls back to arxiv_id
assert result["journal"] == "arXiv"
- @patch('pkg.mybib.arxiv.requests.get')
- def test_fetch_arxiv_metadata_no_journal(self, mock_get, sample_arxiv_response_no_journal):
+ @patch("pkg.mybib.arxiv.requests.get")
+ def test_fetch_arxiv_metadata_no_journal(
+ self, mock_get, sample_arxiv_response_no_journal
+ ):
"""Test handling of metadata without journal reference."""
mock_response = Mock()
mock_response.status_code = 200
- mock_response.content = sample_arxiv_response_no_journal.encode('utf-8')
+ mock_response.content = sample_arxiv_response_no_journal.encode("utf-8")
mock_get.return_value = mock_response
result = arxiv.fetch_arxiv_metadata("2201.04000")
@@ -151,8 +155,8 @@ def test_fetch_arxiv_metadata_no_journal(self, mock_get, sample_arxiv_response_n
assert result["journal"] == "arXiv" # Default journal
assert result["year"] == 2022
- @patch('pkg.mybib.arxiv.sys.exit', side_effect=SystemExit)
- @patch('pkg.mybib.arxiv.requests.get')
+ @patch("pkg.mybib.arxiv.sys.exit", side_effect=SystemExit)
+ @patch("pkg.mybib.arxiv.requests.get")
def test_fetch_arxiv_metadata_http_error(self, mock_get, mock_exit):
"""Test handling of HTTP errors."""
mock_response = Mock()
@@ -162,8 +166,8 @@ def test_fetch_arxiv_metadata_http_error(self, mock_get, mock_exit):
with pytest.raises(SystemExit):
arxiv.fetch_arxiv_metadata("1706.03762")
- @patch('pkg.mybib.arxiv.sys.exit', side_effect=SystemExit)
- @patch('pkg.mybib.arxiv.requests.get')
+ @patch("pkg.mybib.arxiv.sys.exit", side_effect=SystemExit)
+ @patch("pkg.mybib.arxiv.requests.get")
def test_fetch_arxiv_metadata_not_found(self, mock_get, mock_exit):
"""Test handling when no entry is found."""
empty_response = """
@@ -171,18 +175,18 @@ def test_fetch_arxiv_metadata_not_found(self, mock_get, mock_exit):
"""
mock_response = Mock()
mock_response.status_code = 200
- mock_response.content = empty_response.encode('utf-8')
+ mock_response.content = empty_response.encode("utf-8")
mock_get.return_value = mock_response
with pytest.raises(SystemExit):
arxiv.fetch_arxiv_metadata("9999.99999")
- @patch('pkg.mybib.arxiv.requests.get')
+ @patch("pkg.mybib.arxiv.requests.get")
def test_fetch_arxiv_metadata_url_formation(self, mock_get, sample_arxiv_response):
"""Test that the correct URL is being formed."""
mock_response = Mock()
mock_response.status_code = 200
- mock_response.content = sample_arxiv_response.encode('utf-8')
+ mock_response.content = sample_arxiv_response.encode("utf-8")
mock_get.return_value = mock_response
arxiv.fetch_arxiv_metadata("1706.03762")
@@ -193,7 +197,7 @@ def test_fetch_arxiv_metadata_url_formation(self, mock_get, sample_arxiv_respons
assert "export.arxiv.org/api/query" in called_url
assert "1706.03762" in called_url
- @patch('pkg.mybib.arxiv.requests.get')
+ @patch("pkg.mybib.arxiv.requests.get")
def test_fetch_arxiv_metadata_whitespace_in_title(self, mock_get):
"""Test that whitespace in title is normalized."""
response_with_newline = """
@@ -213,7 +217,7 @@ def test_fetch_arxiv_metadata_whitespace_in_title(self, mock_get):
"""
mock_response = Mock()
mock_response.status_code = 200
- mock_response.content = response_with_newline.encode('utf-8')
+ mock_response.content = response_with_newline.encode("utf-8")
mock_get.return_value = mock_response
result = arxiv.fetch_arxiv_metadata("1706.03762")
@@ -223,7 +227,7 @@ def test_fetch_arxiv_metadata_whitespace_in_title(self, mock_get):
assert "Attention Is All" in result["title"]
assert "You Need" in result["title"]
- @patch('pkg.mybib.arxiv.requests.get')
+ @patch("pkg.mybib.arxiv.requests.get")
def test_fetch_arxiv_metadata_single_author(self, mock_get):
"""Test handling of single author."""
single_author_response = """
@@ -241,14 +245,14 @@ def test_fetch_arxiv_metadata_single_author(self, mock_get):
"""
mock_response = Mock()
mock_response.status_code = 200
- mock_response.content = single_author_response.encode('utf-8')
+ mock_response.content = single_author_response.encode("utf-8")
mock_get.return_value = mock_response
result = arxiv.fetch_arxiv_metadata("2105.00001")
assert result["authors"] == "Alice Author"
- @patch('pkg.mybib.arxiv.requests.get')
+ @patch("pkg.mybib.arxiv.requests.get")
def test_fetch_arxiv_metadata_connection_error(self, mock_get):
"""Test handling of connection errors."""
mock_get.side_effect = Exception("Connection error")
diff --git a/tests/test_markdown.py b/tests/test_markdown.py
index ea125bc..ff87da7 100644
--- a/tests/test_markdown.py
+++ b/tests/test_markdown.py
@@ -1,11 +1,12 @@
"""Tests for markdown generation module."""
-import pytest
-import pandas as pd
import tempfile
from pathlib import Path
-from pkg.mybib import storage, markdown
+import pandas as pd
+import pytest
+
+from pkg.mybib import markdown, storage
@pytest.fixture
@@ -27,7 +28,7 @@ def csv_with_references(temp_csv):
"Year": 2023,
"DOI": "10.1234/ml.2023",
"Link": "https://example.com/paper1",
- "Category": "Machine Learning"
+ "Category": "Machine Learning",
},
{
"Title": "Deep Neural Networks",
@@ -36,7 +37,7 @@ def csv_with_references(temp_csv):
"Year": 2022,
"DOI": "10.5678/nn.2022",
"Link": "https://example.com/paper2",
- "Category": "Machine Learning"
+ "Category": "Machine Learning",
},
{
"Title": "Computer Vision Survey",
@@ -45,8 +46,8 @@ def csv_with_references(temp_csv):
"Year": 2024,
"DOI": "10.9999/cv.2024",
"Link": "https://example.com/paper3",
- "Category": "Computer Vision"
- }
+ "Category": "Computer Vision",
+ },
]
df = pd.DataFrame(references)
@@ -72,7 +73,7 @@ def test_make_markdown_table_single_reference(self, temp_csv):
doi="10.0000/test.2023",
link="https://test.com",
category="Testing",
- file_path=temp_csv
+ file_path=temp_csv,
)
result = markdown.make_markdown_table(temp_csv)
@@ -140,7 +141,7 @@ def test_make_markdown_table_is_valid_markdown(self, temp_csv):
doi="10.1111/paper1.2023",
link="https://link1.com",
category="Cat1",
- file_path=temp_csv
+ file_path=temp_csv,
)
storage.add_reference(
@@ -151,13 +152,13 @@ def test_make_markdown_table_is_valid_markdown(self, temp_csv):
doi="10.2222/paper2.2022",
link="https://link2.com",
category="Cat1",
- file_path=temp_csv
+ file_path=temp_csv,
)
result = markdown.make_markdown_table(temp_csv)
# Check for markdown table structure
- lines = result.split('\n')
+ lines = result.split("\n")
assert len(lines) >= 4 # Header, separator, at least 2 data rows
assert "|" in lines[0] # Header has pipes
assert "-" in lines[1] # Separator has dashes
@@ -181,7 +182,7 @@ def test_make_markdown_tables_by_category_single_category(self, temp_csv):
doi="10.1111/ml1.2023",
link="https://link1.com",
category="Machine Learning",
- file_path=temp_csv
+ file_path=temp_csv,
)
storage.add_reference(
@@ -192,7 +193,7 @@ def test_make_markdown_tables_by_category_single_category(self, temp_csv):
doi="10.1111/ml2.2022",
link="https://link2.com",
category="Machine Learning",
- file_path=temp_csv
+ file_path=temp_csv,
)
result = markdown.make_markdown_tables_by_category(temp_csv)
@@ -201,7 +202,9 @@ def test_make_markdown_tables_by_category_single_category(self, temp_csv):
assert "ML Paper 1" in result
assert "ML Paper 2" in result
- def test_make_markdown_tables_by_category_multiple_categories(self, csv_with_references):
+ def test_make_markdown_tables_by_category_multiple_categories(
+ self, csv_with_references
+ ):
"""Test generating markdown with multiple categories."""
result = markdown.make_markdown_tables_by_category(csv_with_references)
@@ -220,7 +223,7 @@ def test_make_markdown_tables_by_category_excludes_category_column(self, temp_cs
doi="10.0000/test.2023",
link="https://test.com",
category="Testing",
- file_path=temp_csv
+ file_path=temp_csv,
)
result = markdown.make_markdown_tables_by_category(temp_csv)
@@ -230,7 +233,9 @@ def test_make_markdown_tables_by_category_excludes_category_column(self, temp_cs
# But the actual category values shouldn't appear in the table
# (only in the section header)
- def test_make_markdown_tables_by_category_has_footer_links(self, csv_with_references):
+ def test_make_markdown_tables_by_category_has_footer_links(
+ self, csv_with_references
+ ):
"""Test that footer contains DOI links."""
result = markdown.make_markdown_tables_by_category(csv_with_references)
@@ -251,7 +256,7 @@ def test_make_markdown_tables_by_category_footer_uniqueness(self, temp_csv):
doi="10.1111/unique.2023",
link="https://link1.com",
category="Category1",
- file_path=temp_csv
+ file_path=temp_csv,
)
storage.add_reference(
@@ -262,7 +267,7 @@ def test_make_markdown_tables_by_category_footer_uniqueness(self, temp_csv):
doi="10.2222/unique2.2022",
link="https://link2.com",
category="Category2",
- file_path=temp_csv
+ file_path=temp_csv,
)
result = markdown.make_markdown_tables_by_category(temp_csv)
@@ -271,7 +276,9 @@ def test_make_markdown_tables_by_category_footer_uniqueness(self, temp_csv):
assert result.count("https://link1.com") >= 1
assert result.count("https://link2.com") >= 1
- def test_make_markdown_tables_by_category_sorted_categories(self, csv_with_references):
+ def test_make_markdown_tables_by_category_sorted_categories(
+ self, csv_with_references
+ ):
"""Test that categories appear in the output."""
result = markdown.make_markdown_tables_by_category(csv_with_references)
@@ -282,7 +289,9 @@ def test_make_markdown_tables_by_category_sorted_categories(self, csv_with_refer
assert ml_pos >= 0 or cv_pos >= 0
assert "## Computer Vision" in result or "## Machine Learning" in result
- def test_make_markdown_tables_by_category_within_category_sorting(self, csv_with_references):
+ def test_make_markdown_tables_by_category_within_category_sorting(
+ self, csv_with_references
+ ):
"""Test that within each category, papers are sorted by year descending."""
result = markdown.make_markdown_tables_by_category(csv_with_references)
@@ -294,8 +303,8 @@ def test_make_markdown_tables_by_category_within_category_sorting(self, csv_with
if basics_pos > 0 and nn_pos > 0:
# Get the category that comes before them
ml_section_start = result.rfind("## Machine Learning")
- cv_section_start = result.rfind("## Computer Vision")
-
+ result.rfind("## Computer Vision")
+
# Check they're in the ML section and ordered by year
if ml_section_start > 0 and ml_section_start < min(basics_pos, nn_pos):
# If next category marker exists, they should be before it
@@ -313,7 +322,7 @@ def test_make_markdown_tables_by_category_link_column_excluded(self, temp_csv):
doi="10.0000/test.2023",
link="https://special-link.com/unique",
category="Testing",
- file_path=temp_csv
+ file_path=temp_csv,
)
result = markdown.make_markdown_tables_by_category(temp_csv)
diff --git a/tests/test_metadata.py b/tests/test_metadata.py
index 676e3fd..95a84a3 100644
--- a/tests/test_metadata.py
+++ b/tests/test_metadata.py
@@ -1,13 +1,14 @@
"""Tests for metadata extraction module."""
import sys
-import pytest
import tempfile
from pathlib import Path
-from unittest.mock import patch, MagicMock
+from unittest.mock import patch
+
+import pytest
# Add parent directory to path
-sys.path.insert(0, '../pkg')
+sys.path.insert(0, "../pkg")
from mybib import metadata
@@ -28,19 +29,19 @@ def temp_file(temp_dir):
class TestSourceDetection:
"""Test source detection functions."""
-
+
def test_arxiv_url_detection(self):
"""Test arXiv URL detection."""
assert metadata._is_arxiv_url("https://arxiv.org/abs/2301.00001")
assert metadata._is_arxiv_url("http://arxiv.org/pdf/2301.00001.pdf")
assert not metadata._is_arxiv_url("https://doi.org/10.1234/example")
-
+
def test_doi_url_detection(self):
"""Test DOI URL detection."""
assert metadata._is_doi_url("https://doi.org/10.1145/1234567")
assert metadata._is_doi_url("http://doi.org/10.1234/example")
assert not metadata._is_doi_url("https://arxiv.org/abs/2301.00001")
-
+
def test_doi_pattern_detection(self):
"""Test DOI pattern detection."""
assert metadata._is_doi_pattern("10.1145/1234567")
@@ -50,12 +51,18 @@ def test_doi_pattern_detection(self):
class TestIdExtraction:
"""Test ID extraction functions."""
-
+
def test_arxiv_id_extraction(self):
"""Test arXiv ID extraction from URLs."""
- assert metadata._extract_arxiv_id("https://arxiv.org/abs/2301.00001") == "2301.00001"
- assert metadata._extract_arxiv_id("https://arxiv.org/pdf/2301.00001v2.pdf") == "2301.00001"
-
+ assert (
+ metadata._extract_arxiv_id("https://arxiv.org/abs/2301.00001")
+ == "2301.00001"
+ )
+ assert (
+ metadata._extract_arxiv_id("https://arxiv.org/pdf/2301.00001v2.pdf")
+ == "2301.00001"
+ )
+
def test_arxiv_id_extraction_invalid(self):
"""Test arXiv ID extraction with invalid URL."""
with pytest.raises(SystemExit):
@@ -64,11 +71,14 @@ def test_arxiv_id_extraction_invalid(self):
class TestDoiNormalization:
"""Test DOI normalization."""
-
+
def test_normalize_doi_from_url(self):
"""Test DOI extraction from URL."""
- assert metadata._normalize_doi("https://doi.org/10.1145/1234567") == "10.1145/1234567"
-
+ assert (
+ metadata._normalize_doi("https://doi.org/10.1145/1234567")
+ == "10.1145/1234567"
+ )
+
def test_normalize_doi_already_normalized(self):
"""Test DOI that's already normalized."""
assert metadata._normalize_doi("10.1145/1234567") == "10.1145/1234567"
@@ -76,51 +86,51 @@ def test_normalize_doi_already_normalized(self):
class TestHtmlMetaExtraction:
"""Test HTML meta tag extraction."""
-
+
def test_extract_html_meta_with_property(self):
"""Test extraction of meta tag with property attribute."""
html = ''
result = metadata._extract_html_meta(html, ["og:title"])
assert result == "Test Title"
-
+
def test_extract_html_meta_with_name(self):
"""Test extraction of meta tag with name attribute."""
html = ''
result = metadata._extract_html_meta(html, ["author"])
assert result == "John Doe"
-
+
def test_extract_html_meta_not_found(self):
"""Test extraction when meta tag is not found."""
- html = ''
+ html = ""
result = metadata._extract_html_meta(html, ["og:title", "author"])
assert result is None
class TestFunctionRouting:
"""Test that fetch_metadata routes to correct handler."""
-
- @patch('mybib.metadata._fetch_arxiv_metadata')
+
+ @patch("mybib.metadata._fetch_arxiv_metadata")
def test_route_to_arxiv(self, mock_arxiv):
"""Test routing to arXiv handler."""
mock_arxiv.return_value = {"title": "arxiv paper"}
metadata.fetch_metadata("https://arxiv.org/abs/2301.00001")
mock_arxiv.assert_called_once()
-
- @patch('mybib.metadata._fetch_crossref_metadata')
+
+ @patch("mybib.metadata._fetch_crossref_metadata")
def test_route_to_doi_url(self, mock_doi):
"""Test routing to DOI URL handler."""
mock_doi.return_value = {"title": "doi paper"}
metadata.fetch_metadata("https://doi.org/10.1145/1234567")
mock_doi.assert_called_once()
-
- @patch('mybib.metadata._fetch_crossref_metadata')
+
+ @patch("mybib.metadata._fetch_crossref_metadata")
def test_route_to_doi_pattern(self, mock_doi):
"""Test routing to DOI pattern handler."""
mock_doi.return_value = {"title": "doi paper"}
metadata.fetch_metadata("10.1145/1234567")
mock_doi.assert_called_once()
-
- @patch('mybib.metadata._fetch_generic_metadata')
+
+ @patch("mybib.metadata._fetch_generic_metadata")
def test_route_to_generic(self, mock_generic):
"""Test routing to generic URL handler."""
mock_generic.return_value = {"title": "generic page"}
@@ -130,8 +140,8 @@ def test_route_to_generic(self, mock_generic):
class TestReturnFormat:
"""Test that all handlers return standardized format."""
-
- @patch('mybib.metadata._fetch_generic_metadata')
+
+ @patch("mybib.metadata._fetch_generic_metadata")
def test_return_format_has_required_fields(self, mock_generic):
"""Test that returned dict has all required fields."""
mock_generic.return_value = {
@@ -140,10 +150,10 @@ def test_return_format_has_required_fields(self, mock_generic):
"journal": "Journal",
"year": 2023,
"doi": "10.1234/test",
- "link": "https://example.com"
+ "link": "https://example.com",
}
result = metadata.fetch_metadata("https://example.com")
-
+
required_fields = {"title", "authors", "journal", "year", "doi", "link"}
assert required_fields.issubset(result.keys())
@@ -154,7 +164,7 @@ def demo_arxiv_metadata():
print("\n=== Demo: arXiv Metadata ===")
url = "https://arxiv.org/abs/2301.00001"
print(f"URL: {url}")
- print(f"Expected: Metadata for arXiv paper with automatic source detection")
+ print("Expected: Metadata for arXiv paper with automatic source detection")
# Note: Actual API call commented out to avoid network dependency in tests
# result = metadata.fetch_metadata(url)
# print(f"Result: {result}")
@@ -165,7 +175,7 @@ def demo_doi_metadata():
print("\n=== Demo: DOI Metadata ===")
url = "https://doi.org/10.1145/1234567"
print(f"URL: {url}")
- print(f"Expected: Metadata from Crossref API for DOI")
+ print("Expected: Metadata from Crossref API for DOI")
# Note: Actual API call commented out to avoid network dependency in tests
# result = metadata.fetch_metadata(url)
# print(f"Result: {result}")
@@ -176,7 +186,7 @@ def demo_generic_metadata():
print("\n=== Demo: Generic URL Metadata ===")
url = "https://example.com/article"
print(f"URL: {url}")
- print(f"Expected: Metadata extracted from HTML meta tags")
+ print("Expected: Metadata extracted from HTML meta tags")
# Note: Actual API call commented out to avoid network dependency in tests
# result = metadata.fetch_metadata(url)
# print(f"Result: {result}")
diff --git a/tests/test_storage.py b/tests/test_storage.py
index ccfed33..9e89376 100644
--- a/tests/test_storage.py
+++ b/tests/test_storage.py
@@ -1,11 +1,11 @@
"""Tests for storage module."""
-import pytest
-import pandas as pd
import tempfile
from pathlib import Path
from unittest.mock import patch
-import sys
+
+import pandas as pd
+import pytest
from pkg.mybib import storage
@@ -47,7 +47,7 @@ def test_add_reference_to_empty_file(self, temp_csv, sample_references):
link=sample_references["link"],
category=sample_references["category"],
arxiv_id=sample_references["arxiv_id"],
- file_path=temp_csv
+ file_path=temp_csv,
)
df = pd.read_csv(temp_csv)
@@ -67,7 +67,7 @@ def test_add_multiple_references(self, temp_csv, sample_references):
link=sample_references["link"],
category=sample_references["category"],
arxiv_id=sample_references["arxiv_id"],
- file_path=temp_csv
+ file_path=temp_csv,
)
# Add second reference
@@ -79,7 +79,7 @@ def test_add_multiple_references(self, temp_csv, sample_references):
doi="10.5678/example.2024",
link="https://example.com/paper2",
category="Deep Learning",
- file_path=temp_csv
+ file_path=temp_csv,
)
df = pd.read_csv(temp_csv)
@@ -98,7 +98,7 @@ def test_add_reference_preserves_headers(self, temp_csv, sample_references):
link=sample_references["link"],
category=sample_references["category"],
arxiv_id=sample_references["arxiv_id"],
- file_path=temp_csv
+ file_path=temp_csv,
)
storage.add_reference(
@@ -109,12 +109,21 @@ def test_add_reference_preserves_headers(self, temp_csv, sample_references):
doi="10.9999/another.2022",
link="https://example.com/paper3",
category="Research",
- file_path=temp_csv
+ file_path=temp_csv,
)
df = pd.read_csv(temp_csv)
# Check headers are correct
- expected_headers = ["Title", "Authors", "Journal", "Year", "DOI", "Link", "Category", "ArxivID"]
+ expected_headers = [
+ "Title",
+ "Authors",
+ "Journal",
+ "Year",
+ "DOI",
+ "Link",
+ "Category",
+ "ArxivID",
+ ]
assert list(df.columns) == expected_headers
@@ -131,11 +140,11 @@ def test_duplicate_doi_detection(self, temp_csv, sample_references):
doi=sample_references["doi"],
link=sample_references["link"],
category=sample_references["category"],
- file_path=temp_csv
+ file_path=temp_csv,
)
# Try to add duplicate with different title
- with patch('sys.exit') as mock_exit:
+ with patch("sys.exit") as mock_exit:
storage.add_reference(
title="Different Title",
authors="Different Authors",
@@ -144,7 +153,7 @@ def test_duplicate_doi_detection(self, temp_csv, sample_references):
doi=sample_references["doi"], # Same DOI
link="https://different.com",
category="Different Category",
- file_path=temp_csv
+ file_path=temp_csv,
)
mock_exit.assert_called_once_with(0)
@@ -158,11 +167,11 @@ def test_duplicate_detection_case_insensitive(self, temp_csv, sample_references)
doi=sample_references["doi"],
link=sample_references["link"],
category=sample_references["category"],
- file_path=temp_csv
+ file_path=temp_csv,
)
# Try to add duplicate with different case DOI
- with patch('sys.exit') as mock_exit:
+ with patch("sys.exit") as mock_exit:
storage.add_reference(
title="Another Title",
authors="Another Author",
@@ -171,7 +180,7 @@ def test_duplicate_detection_case_insensitive(self, temp_csv, sample_references)
doi=sample_references["doi"].upper(), # Different case
link="https://another.com",
category="Another Category",
- file_path=temp_csv
+ file_path=temp_csv,
)
mock_exit.assert_called_once_with(0)
@@ -185,11 +194,11 @@ def test_duplicate_detection_with_whitespace(self, temp_csv, sample_references):
doi=sample_references["doi"],
link=sample_references["link"],
category=sample_references["category"],
- file_path=temp_csv
+ file_path=temp_csv,
)
# Try to add duplicate with extra whitespace
- with patch('sys.exit') as mock_exit:
+ with patch("sys.exit") as mock_exit:
storage.add_reference(
title="Yet Another Title",
authors="Yet Another Author",
@@ -198,7 +207,7 @@ def test_duplicate_detection_with_whitespace(self, temp_csv, sample_references):
doi=f" {sample_references['doi']} ", # Whitespace added
link="https://yet-another.com",
category="Yet Another Category",
- file_path=temp_csv
+ file_path=temp_csv,
)
mock_exit.assert_called_once_with(0)
@@ -212,7 +221,7 @@ def test_no_duplicate_with_different_doi(self, temp_csv, sample_references):
doi=sample_references["doi"],
link=sample_references["link"],
category=sample_references["category"],
- file_path=temp_csv
+ file_path=temp_csv,
)
# Add reference with different DOI
@@ -224,7 +233,7 @@ def test_no_duplicate_with_different_doi(self, temp_csv, sample_references):
doi="10.1111/different.2024", # Different DOI
link="https://another.com",
category="Different Category",
- file_path=temp_csv
+ file_path=temp_csv,
)
df = pd.read_csv(temp_csv)
@@ -244,7 +253,7 @@ def test_load_references_from_existing_file(self, temp_csv, sample_references):
doi=sample_references["doi"],
link=sample_references["link"],
category=sample_references["category"],
- file_path=temp_csv
+ file_path=temp_csv,
)
df = storage.load_references(temp_csv)
@@ -257,9 +266,18 @@ def test_load_references_creates_empty_file(self, temp_csv):
Path(temp_csv).unlink(missing_ok=True)
df = storage.load_references(temp_csv)
-
+
assert df.empty
- assert list(df.columns) == ["Title", "Authors", "Journal", "Year", "DOI", "Link", "Category", "ArxivID"]
+ assert list(df.columns) == [
+ "Title",
+ "Authors",
+ "Journal",
+ "Year",
+ "DOI",
+ "Link",
+ "Category",
+ "ArxivID",
+ ]
assert Path(temp_csv).exists()
def test_load_references_preserves_data(self, temp_csv, sample_references):
@@ -271,10 +289,10 @@ def test_load_references_preserves_data(self, temp_csv, sample_references):
authors=f"Author {i}",
journal=f"Journal {i}",
year=2020 + i,
- doi=f"10.{i}/example.{2020+i}",
+ doi=f"10.{i}/example.{2020 + i}",
link=f"https://example.com/paper{i}",
category=f"Category {i}",
- file_path=temp_csv
+ file_path=temp_csv,
)
df = storage.load_references(temp_csv)
@@ -296,7 +314,7 @@ def test_add_reference_with_scholar_id_fallback(self, temp_csv):
link="https://scholar.com/paper",
category="Testing",
scholar_id="scholar_id_12345",
- file_path=temp_csv
+ file_path=temp_csv,
)
df = pd.read_csv(temp_csv)
@@ -314,7 +332,7 @@ def test_add_reference_prefers_doi_over_scholar_id(self, temp_csv):
link="https://example.com/paper",
category="Testing",
scholar_id="scholar_id_67890",
- file_path=temp_csv
+ file_path=temp_csv,
)
df = pd.read_csv(temp_csv)
@@ -336,7 +354,7 @@ def test_add_reference_with_arxiv_id(self, temp_csv):
link="https://arxiv.org/abs/2301.00001",
category="Machine Learning",
arxiv_id="2301.00001",
- file_path=temp_csv
+ file_path=temp_csv,
)
df = pd.read_csv(temp_csv, dtype={"ArxivID": str})
@@ -353,7 +371,7 @@ def test_add_reference_without_arxiv_id(self, temp_csv):
doi="10.1234/example.2023",
link="https://example.com/paper",
category="Research",
- file_path=temp_csv
+ file_path=temp_csv,
)
df = pd.read_csv(temp_csv, dtype={"ArxivID": str})