From 3b1ea12ec08036b23f6c72bfe566cf6c63a27437 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 17 Mar 2026 13:32:53 +0000 Subject: [PATCH 1/2] Initial plan From 2640a0aa1678b9f0e4bdcaded19c739c425187ab Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 17 Mar 2026 13:46:29 +0000 Subject: [PATCH 2/2] Resolve conflicts: prefer copilot-vibe-code (latest) with lint fixes applied Co-authored-by: art-test-stack <110672812+art-test-stack@users.noreply.github.com> --- pkg/mybib/arxiv.py | 42 +++---- pkg/mybib/bibtex.py | 77 ++++++------ pkg/mybib/categories.py | 52 ++++---- pkg/mybib/cli.py | 260 ++++++++++++++++++---------------------- pkg/mybib/db_storage.py | 127 +++++++++++--------- pkg/mybib/graph.py | 143 +++++++++++----------- pkg/mybib/markdown.py | 47 ++++---- pkg/mybib/metadata.py | 71 ++++++----- pkg/mybib/models.py | 26 ++-- pkg/mybib/scholar.py | 117 +++++++++--------- pkg/mybib/storage.py | 40 ++++--- pkg/mybib/ui.py | 37 +++--- pkg/mybib/utils.py | 30 +++-- tests/test_arxiv.py | 54 +++++---- tests/test_markdown.py | 55 +++++---- tests/test_metadata.py | 74 +++++++----- tests/test_storage.py | 76 +++++++----- 17 files changed, 703 insertions(+), 625 deletions(-) diff --git a/pkg/mybib/arxiv.py b/pkg/mybib/arxiv.py index bf3d3ad..ab67841 100644 --- a/pkg/mybib/arxiv.py +++ b/pkg/mybib/arxiv.py @@ -1,54 +1,54 @@ """Fetch metadata from arXiv API.""" import sys -import requests import xml.etree.ElementTree as ET +import requests + def fetch_arxiv_metadata(arxiv_id: str) -> dict: """Fetch metadata from arXiv API. - + Args: arxiv_id: arXiv identifier (e.g., '2301.00001') - + Returns: Dictionary with keys: title, authors, journal, year, doi, link, arxiv_id - + Raises: SystemExit: If API call fails or no entry found """ url = f"http://export.arxiv.org/api/query?id_list={arxiv_id}" response = requests.get(url) - + if response.status_code != 200: print(f"Error fetching arxiv metadata: {response.status_code}") sys.exit(1) - + root = ET.fromstring(response.content) ns = { - 'atom': 'http://www.w3.org/2005/Atom', - 'arxiv': 'http://arxiv.org/schemas/atom' + "atom": "http://www.w3.org/2005/Atom", + "arxiv": "http://arxiv.org/schemas/atom", } - - entry = root.find('atom:entry', ns) + + entry = root.find("atom:entry", ns) if entry is None: print("No entry found for this arxiv ID.") sys.exit(1) - - title = entry.find('atom:title', ns).text.strip().replace('\n', ' ') - authors = ', '.join( - author.find('atom:name', ns).text - for author in entry.findall('atom:author', ns) + + title = entry.find("atom:title", ns).text.strip().replace("\n", " ") + authors = ", ".join( + author.find("atom:name", ns).text for author in entry.findall("atom:author", ns) ) - published = entry.find('atom:published', ns).text + published = entry.find("atom:published", ns).text year = int(published[:4]) - - doi_elem = entry.find('arxiv:doi', ns) + + doi_elem = entry.find("arxiv:doi", ns) doi = doi_elem.text if doi_elem is not None else arxiv_id - - journal_elem = entry.find('arxiv:journal_ref', ns) + + journal_elem = entry.find("arxiv:journal_ref", ns) journal = journal_elem.text if journal_elem is not None else "arXiv" - + return { "title": title, "authors": authors, diff --git a/pkg/mybib/bibtex.py b/pkg/mybib/bibtex.py index 6fc0462..2c2eb16 100644 --- a/pkg/mybib/bibtex.py +++ b/pkg/mybib/bibtex.py @@ -5,75 +5,74 @@ def generate_bibtex(df: pd.DataFrame) -> str: """Generate BibTeX entries from a DataFrame of references. - + Args: df: DataFrame with columns: Title, Authors, Journal, Year, DOI, Link - + Returns: String containing BibTeX formatted entries """ if df.empty: return "% No references found\n" - + # Handle cases where DataFrame column names might not exactly match # Create a mapping of expected columns to actual columns col_mapping = {} - + for col in df.columns: col_lower = col.lower().strip() - if col_lower == 'title': - col_mapping['title'] = col - elif col_lower == 'authors': - col_mapping['authors'] = col - elif col_lower == 'journal': - col_mapping['journal'] = col - elif col_lower == 'year': - col_mapping['year'] = col - elif col_lower == 'doi': - col_mapping['doi'] = col - elif col_lower == 'link': - col_mapping['link'] = col - elif col_lower == 'url': - col_mapping['link'] = col - + if col_lower == "title": + col_mapping["title"] = col + elif col_lower == "authors": + col_mapping["authors"] = col + elif col_lower == "journal": + col_mapping["journal"] = col + elif col_lower == "year": + col_mapping["year"] = col + elif col_lower == "doi": + col_mapping["doi"] = col + elif col_lower == "link": + col_mapping["link"] = col + elif col_lower == "url": + col_mapping["link"] = col + entries = [] - + for _, row in df.iterrows(): # Extract values using mapped column names - title = str(row.get(col_mapping.get('title', 'Title'), "")).strip() - authors = str(row.get(col_mapping.get('authors', 'Authors'), "")).strip() - journal = str(row.get(col_mapping.get('journal', 'Journal'), "")).strip() - year = str(row.get(col_mapping.get('year', 'Year'), "")).strip() - doi = str(row.get(col_mapping.get('doi', 'DOI'), "")).strip() - link = str(row.get(col_mapping.get('link', 'Link'), "")).strip() - + title = str(row.get(col_mapping.get("title", "Title"), "")).strip() + authors = str(row.get(col_mapping.get("authors", "Authors"), "")).strip() + journal = str(row.get(col_mapping.get("journal", "Journal"), "")).strip() + year = str(row.get(col_mapping.get("year", "Year"), "")).strip() + doi = str(row.get(col_mapping.get("doi", "DOI"), "")).strip() + link = str(row.get(col_mapping.get("link", "Link"), "")).strip() + # Skip entries with missing critical fields if not title: continue - + # Use DOI as the key, or create one from title if DOI is missing if doi and doi != "nan": key = doi else: # Fallback: create key from title key = title.lower().replace(" ", "_").replace(":", "")[:30] - + # Build BibTeX entry entry = f"@article{{{key},\n" - entry += f' title={{{title}}},\n' - + entry += f" title={{{title}}},\n" + if authors: - entry += f' author={{{authors}}},\n' + entry += f" author={{{authors}}},\n" if journal: - entry += f' journal={{{journal}}},\n' + entry += f" journal={{{journal}}},\n" if year and year != "nan": - entry += f' year={{{year}}},\n' + entry += f" year={{{year}}},\n" if link and link != "nan": - entry += f' url={{{link}}}\n' - + entry += f" url={{{link}}}\n" + entry += "}\n" - + entries.append(entry) - - return "\n".join(entries) + return "\n".join(entries) diff --git a/pkg/mybib/categories.py b/pkg/mybib/categories.py index b7bbb14..4f30982 100644 --- a/pkg/mybib/categories.py +++ b/pkg/mybib/categories.py @@ -1,16 +1,15 @@ """Category management for bibliography.""" import json -from pathlib import Path from typing import Dict, List, Tuple def load_categories(file_path: str = "categories.json") -> Dict[str, str]: """Load category mappings from file. - + Args: file_path: Path to categories JSON file - + Returns: Dictionary mapping category ID to category name """ @@ -22,9 +21,11 @@ def load_categories(file_path: str = "categories.json") -> Dict[str, str]: return {} -def save_categories(categories: Dict[str, str], file_path: str = "categories.json") -> None: +def save_categories( + categories: Dict[str, str], file_path: str = "categories.json" +) -> None: """Save category mappings to file. - + Args: categories: Dictionary mapping category ID to category name file_path: Path to categories JSON file @@ -33,65 +34,70 @@ def save_categories(categories: Dict[str, str], file_path: str = "categories.jso json.dump(categories, f, indent=2, sort_keys=True) -def get_or_create_category(name: str, categories: Dict[str, str] = None) -> Tuple[str, Dict[str, str]]: +def get_or_create_category( + name: str, categories: Dict[str, str] = None +) -> Tuple[str, Dict[str, str]]: """Get category ID for given name, creating if needed. - + Uses lowercase normalization to group similar categories. - + Args: - name: Category name + name: Category name categories: Existing categories dict (loads from file if not provided) - + Returns: Tuple of (category_id, updated_categories_dict) """ if categories is None: categories = load_categories() - + # Normalize category name normalized = name.lower().strip() - + # Check if category already exists (case-insensitive) for cat_id, cat_name in categories.items(): if cat_name.lower() == normalized: return cat_id, categories - + # Create new category - new_id = str(max(int(cat_id) for cat_id in categories.keys() if cat_id.isdigit()) + 1 if categories else 1) + new_id = str( + max(int(cat_id) for cat_id in categories.keys() if cat_id.isdigit()) + 1 + if categories + else 1 + ) categories[new_id] = name - + return new_id, categories def list_categories(categories: Dict[str, str] = None) -> List[Tuple[str, str]]: """List all categories sorted by ID. - + Args: categories: Category mapping dict (loads from file if not provided) - + Returns: List of (id, name) tuples sorted by ID """ if categories is None: categories = load_categories() - + return sorted( - categories.items(), - key=lambda x: int(x[0]) if x[0].isdigit() else float('inf') + categories.items(), key=lambda x: int(x[0]) if x[0].isdigit() else float("inf") ) def get_category_name(cat_id: str, categories: Dict[str, str] = None) -> str: """Get category name by ID. - + Args: cat_id: Category ID categories: Category mapping dict (loads from file if not provided) - + Returns: Category name, or empty string if not found """ if categories is None: categories = load_categories() - + return categories.get(str(cat_id), "") diff --git a/pkg/mybib/cli.py b/pkg/mybib/cli.py index f80657c..299c3f9 100644 --- a/pkg/mybib/cli.py +++ b/pkg/mybib/cli.py @@ -4,32 +4,36 @@ import sys from .arxiv import fetch_arxiv_metadata -from .storage import add_reference, load_references -from .markdown import make_markdown_table, make_markdown_tables_by_category from .bibtex import generate_bibtex +from .categories import ( + get_or_create_category, + list_categories, + load_categories, + save_categories, +) from .graph import build_citation_graph, export_graph_html +from .markdown import make_markdown_table, make_markdown_tables_by_category from .scholar import search_and_confirm_article -from .categories import load_categories, get_or_create_category, list_categories, save_categories +from .storage import add_reference, load_references from .ui import ( + api_progress, + confirm_action, console, - print_success, + display_reference_preview, print_error, print_info, + print_success, print_warning, - api_progress, - confirm_action, - display_reference_preview, - create_references_table, ) def prompt_for_category(title: str, category_arg: str = None) -> str: """Prompt user to select or create a category. - + Args: title: Article title for context category_arg: Pre-specified category (used if provided) - + Returns: Category name """ @@ -39,21 +43,22 @@ def prompt_for_category(title: str, category_arg: str = None) -> str: cat_id, categories = get_or_create_category(category_arg, categories) save_categories(categories) return categories[cat_id] - + # Show existing categories and allow selection or creation categories = load_categories() cat_list = list_categories(categories) - + console.print("\n[bold]Available categories:[/]") for cat_id, cat_name in cat_list: console.print(f" {cat_id}: {cat_name}") - + # Prompt for selection while True: choice = console.input( - f"\n[bold]Select category ID for '{title}'[/] (or enter new category name): " + f"\n[bold]Select category ID for '{title}'[/] " + "(or enter new category name): " ).strip() - + if choice.isdigit() and choice in categories: return categories[choice] elif choice: @@ -68,21 +73,21 @@ def prompt_for_category(title: str, category_arg: str = None) -> str: def handle_add_arxiv(args) -> None: """Handle the add-arxiv command. - + Args: args: Parsed command-line arguments """ # Extract arxiv ID from URL arxiv_id = args.arxiv_url.split("/")[-1] - + # Fetch metadata from arXiv with progress indicator print_info(f"Fetching metadata for arXiv ID: {arxiv_id}") with api_progress(): metadata = fetch_arxiv_metadata(arxiv_id) - + # Get category using new category system - category = prompt_for_category(metadata['title'], args.category) - + category = prompt_for_category(metadata["title"], args.category) + # Show reference preview preview_data = { "title": metadata["title"], @@ -94,14 +99,14 @@ def handle_add_arxiv(args) -> None: console.print() display_reference_preview(preview_data) console.print() - + # Confirm before adding if not confirm_action( f"Add '[bold cyan]{metadata['title']}[/]' to category '[yellow]{category}[/]'?" ): print_warning("Aborted.") sys.exit(0) - + # Add reference to storage add_reference( title=metadata["title"], @@ -119,33 +124,33 @@ def handle_add_arxiv(args) -> None: def handle_add_scholar(args) -> None: """Handle the add-scholar command to search Google Scholar. - + Args: args: Parsed command-line arguments """ title = args.title url = args.url - + print_info("Searching Google Scholar for your article...") - + # If no title provided, try to extract from URL or abort if not title and not url: print_error("Either --title or --url must be provided") sys.exit(1) - + # Search query: use title if provided, else use URL search_query = title if title else url - + # Search and get confirmation from user metadata = search_and_confirm_article(search_query) - + if not metadata: print_error("Could not find or confirm article on Google Scholar") sys.exit(1) - + # Get category using new category system - category = prompt_for_category(metadata['title'], args.category) - + category = prompt_for_category(metadata["title"], args.category) + # Show reference preview preview_data = { "title": metadata["title"], @@ -157,14 +162,14 @@ def handle_add_scholar(args) -> None: console.print() display_reference_preview(preview_data) console.print() - + # Confirm before adding if not confirm_action( f"Add '[bold cyan]{metadata['title']}[/]' to category '[yellow]{category}[/]'?" ): print_warning("Aborted.") sys.exit(0) - + # Add reference to storage add_reference( title=metadata["title"], @@ -182,28 +187,30 @@ def handle_add_scholar(args) -> None: def handle_add_manual(args) -> None: """Handle the add command for manual reference entry. - + If only title is provided, searches Google Scholar automatically. All fields except title are optional. - + Args: args: Parsed command-line arguments """ # Check if we need to search Google Scholar # If only title and category are provided (other fields are None), search Scholar - has_manual_metadata = any([ - args.authors, - args.journal, - args.year, - args.doi, - args.link, - ]) - + has_manual_metadata = any( + [ + args.authors, + args.journal, + args.year, + args.doi, + args.link, + ] + ) + if not has_manual_metadata: # Only title provided, search Google Scholar print_info("Searching Google Scholar for your article...") metadata = search_and_confirm_article(args.title) - + if not metadata: print_error("Could not find or confirm article on Google Scholar") sys.exit(1) @@ -217,10 +224,10 @@ def handle_add_manual(args) -> None: "doi": args.doi, "link": args.link or "", } - + # Get category using new category system - category = prompt_for_category(metadata['title'], args.category) - + category = prompt_for_category(metadata["title"], args.category) + # Show reference preview preview_data = { "title": metadata["title"], @@ -232,12 +239,14 @@ def handle_add_manual(args) -> None: console.print() display_reference_preview(preview_data) console.print() - + # Confirm before adding - if not confirm_action(f"Add '[bold cyan]{metadata['title']}[/]' to [yellow]{category}[/]?"): + if not confirm_action( + f"Add '[bold cyan]{metadata['title']}[/]' to [yellow]{category}[/]?" + ): print_warning("Aborted.") sys.exit(0) - + add_reference( title=metadata["title"], authors=metadata.get("authors") or None, @@ -254,17 +263,17 @@ def handle_add_manual(args) -> None: def handle_markdown(args) -> None: """Handle the markdown command to generate markdown tables. - + Args: args: Parsed command-line arguments """ print_info("Generating markdown tables...") - + if args.by_category: table = make_markdown_tables_by_category(args.file) else: table = make_markdown_table(args.file) - + if args.output: with open(args.output, "w") as f: f.write(table) @@ -275,14 +284,14 @@ def handle_markdown(args) -> None: def handle_bibtex(args) -> None: """Handle the bibtex command to generate BibTeX file. - + Args: args: Parsed command-line arguments """ print_info("Generating BibTeX...") df = load_references(args.file) bibtex_content = generate_bibtex(df) - + if args.output: with open(args.output, "w") as f: f.write(bibtex_content) @@ -293,19 +302,19 @@ def handle_bibtex(args) -> None: def handle_graph(args) -> None: """Handle the graph command to build and visualize citations. - + Args: args: Parsed command-line arguments """ df = load_references(args.file) - + if df.empty: print_error(f"No references found in {args.file}") sys.exit(1) - + print_info(f"Building citation graph from {len(df)} references...") graph = build_citation_graph(df, output_references=args.verbose) - + output_file = args.output or "citation_graph.html" export_graph_html(graph, output_file) print_success(f"Citation graph exported to {output_file}") @@ -313,16 +322,16 @@ def handle_graph(args) -> None: def handle_db_init(args) -> None: """Handle database initialization. - + Args: args: Parsed command-line arguments """ from .db_storage import DatabaseStorage - + print_info(f"Initializing database: {args.db_url}") - + try: - storage = DatabaseStorage(args.db_url) + DatabaseStorage(args.db_url) print_success("Database initialized successfully!") except Exception as e: print_error(f"Failed to initialize database: {e}") @@ -331,24 +340,24 @@ def handle_db_init(args) -> None: def handle_db_migrate(args) -> None: """Handle migration from CSV to database. - + Args: args: Parsed command-line arguments """ from .db_storage import DatabaseStorage - + print_info(f"Migrating from {args.file} to {args.db_url}") - + try: storage = DatabaseStorage(args.db_url) stats = storage.migrate_from_csv(args.file) - + console.print("\n[bold]Migration Statistics:[/]") console.print(f" Total: {stats['total']}") console.print(f" Added: {stats['added']}") console.print(f" Duplicates: {stats['duplicates']}") console.print(f" Errors: {stats['errors']}") - + print_success("Migration completed!") except Exception as e: print_error(f"Failed to migrate database: {e}") @@ -357,14 +366,14 @@ def handle_db_migrate(args) -> None: def handle_db_export(args) -> None: """Handle export from database to CSV. - + Args: args: Parsed command-line arguments """ from .db_storage import DatabaseStorage - + print_info(f"Exporting from {args.db_url} to {args.output}") - + try: storage = DatabaseStorage(args.db_url) count = storage.export_to_csv(args.output) @@ -378,68 +387,56 @@ def main() -> None: """Main CLI entry point.""" parser = argparse.ArgumentParser( prog="mybib", - description="📚 Manage research paper references with ease. Similar to gh, poetry, and uv!", + description="📚 Manage research paper references with ease. " + "Similar to gh, poetry, and uv!", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: mybib add-arxiv https://arxiv.org/abs/2301.00001 --category ML - mybib add --title "My Paper" --authors "Author Name" --journal "Nature" --year 2024 --doi "10.xxxx/xxxxx" --category Science + mybib add --title "My Paper" --authors "Author Name" --journal "Nature" \ + --year 2024 --doi "10.xxxx/xxxxx" --category Science mybib markdown --file references.csv --output README.md mybib bibtex --file references.csv --output references.bib mybib graph --file references.csv --output graph.html - """ + """, ) subparsers = parser.add_subparsers(dest="command", help="Available commands") # add-arxiv command add_arxiv_parser = subparsers.add_parser( - "add-arxiv", - help="Add a reference from an arXiv URL" - ) - add_arxiv_parser.add_argument( - "arxiv_url", - help="The arXiv URL (e.g., https://arxiv.org/abs/2301.00001)" + "add-arxiv", help="Add a reference from an arXiv URL" ) add_arxiv_parser.add_argument( - "--category", - help="Category for the reference" + "arxiv_url", help="The arXiv URL (e.g., https://arxiv.org/abs/2301.00001)" ) + add_arxiv_parser.add_argument("--category", help="Category for the reference") add_arxiv_parser.add_argument( "--file", default="references.csv", - help="CSV file path (default: references.csv)" + help="CSV file path (default: references.csv)", ) add_arxiv_parser.set_defaults(func=handle_add_arxiv) # add-scholar command add_scholar_parser = subparsers.add_parser( - "add-scholar", - help="Add a reference from Google Scholar" - ) - add_scholar_parser.add_argument( - "--title", - help="Article title to search for" - ) - add_scholar_parser.add_argument( - "--url", - help="Article URL to search for" - ) - add_scholar_parser.add_argument( - "--category", - help="Category for the reference" + "add-scholar", help="Add a reference from Google Scholar" ) + add_scholar_parser.add_argument("--title", help="Article title to search for") + add_scholar_parser.add_argument("--url", help="Article URL to search for") + add_scholar_parser.add_argument("--category", help="Category for the reference") add_scholar_parser.add_argument( "--file", default="references.csv", - help="CSV file path (default: references.csv)" + help="CSV file path (default: references.csv)", ) add_scholar_parser.set_defaults(func=handle_add_scholar) # add command add_parser = subparsers.add_parser( "add", - help="Add a reference manually (or search Google Scholar if only title provided)" + help="Add a reference manually (or search Google Scholar" + " if only title provided)", ) add_parser.add_argument("--title", required=True, help="Article title (required)") add_parser.add_argument("--authors", help="Comma-separated author names") @@ -451,118 +448,97 @@ def main() -> None: add_parser.add_argument( "--file", default="references.csv", - help="CSV file path (default: references.csv)" + help="CSV file path (default: references.csv)", ) add_parser.set_defaults(func=handle_add_manual) # markdown command md_parser = subparsers.add_parser( - "markdown", - help="Generate markdown tables from references" + "markdown", help="Generate markdown tables from references" ) md_parser.add_argument( "--file", default="references.csv", - help="CSV file path (default: references.csv)" - ) - md_parser.add_argument( - "--by-category", - action="store_true", - help="Split tables by category" + help="CSV file path (default: references.csv)", ) md_parser.add_argument( - "--output", - help="Output file (e.g., README.md)" + "--by-category", action="store_true", help="Split tables by category" ) + md_parser.add_argument("--output", help="Output file (e.g., README.md)") md_parser.set_defaults(func=handle_markdown) # bibtex command bibtex_parser = subparsers.add_parser( - "bibtex", - help="Generate BibTeX file from references" + "bibtex", help="Generate BibTeX file from references" ) bibtex_parser.add_argument( "--file", default="references.csv", - help="CSV file path (default: references.csv)" - ) - bibtex_parser.add_argument( - "--output", - help="Output file (e.g., references.bib)" + help="CSV file path (default: references.csv)", ) + bibtex_parser.add_argument("--output", help="Output file (e.g., references.bib)") bibtex_parser.set_defaults(func=handle_bibtex) # graph command graph_parser = subparsers.add_parser( - "graph", - help="Build and visualize citation graph" + "graph", help="Build and visualize citation graph" ) graph_parser.add_argument( "--file", default="references.csv", - help="CSV file path (default: references.csv)" + help="CSV file path (default: references.csv)", ) graph_parser.add_argument( - "--output", - help="Output HTML file (default: citation_graph.html)" + "--output", help="Output HTML file (default: citation_graph.html)" ) graph_parser.add_argument( - "--verbose", - action="store_true", - help="Show verbose output references" + "--verbose", action="store_true", help="Show verbose output references" ) graph_parser.set_defaults(func=handle_graph) # db-init command db_init_parser = subparsers.add_parser( - "db-init", - help="Initialize database for bibliography management" + "db-init", help="Initialize database for bibliography management" ) db_init_parser.add_argument( "--db-url", default="sqlite:///bibliography.db", - help="Database URL (default: sqlite:///bibliography.db)" + help="Database URL (default: sqlite:///bibliography.db)", ) db_init_parser.set_defaults(func=handle_db_init) # db-migrate command db_migrate_parser = subparsers.add_parser( - "db-migrate", - help="Migrate references from CSV to database" + "db-migrate", help="Migrate references from CSV to database" ) db_migrate_parser.add_argument( "--file", default="references.csv", - help="CSV file path (default: references.csv)" + help="CSV file path (default: references.csv)", ) db_migrate_parser.add_argument( "--db-url", default="sqlite:///bibliography.db", - help="Database URL (default: sqlite:///bibliography.db)" + help="Database URL (default: sqlite:///bibliography.db)", ) db_migrate_parser.set_defaults(func=handle_db_migrate) # db-export command db_export_parser = subparsers.add_parser( - "db-export", - help="Export database references to CSV" - ) - db_export_parser.add_argument( - "--output", - required=True, - help="Output CSV file" + "db-export", help="Export database references to CSV" ) + db_export_parser.add_argument("--output", required=True, help="Output CSV file") db_export_parser.add_argument( "--db-url", default="sqlite:///bibliography.db", - help="Database URL (default: sqlite:///bibliography.db)" + help="Database URL (default: sqlite:///bibliography.db)", ) db_export_parser.set_defaults(func=handle_db_export) args = parser.parse_args() # Execute the appropriate handler or show help - if hasattr(args, 'func'): + if hasattr(args, "func"): try: args.func(args) except KeyboardInterrupt: @@ -573,4 +549,4 @@ def main() -> None: print_error(f"An error occurred: {e}") sys.exit(1) else: - parser.print_help() \ No newline at end of file + parser.print_help() diff --git a/pkg/mybib/db_storage.py b/pkg/mybib/db_storage.py index cc73075..7ee8f09 100644 --- a/pkg/mybib/db_storage.py +++ b/pkg/mybib/db_storage.py @@ -1,8 +1,10 @@ """Database storage adapter for bibliography management.""" -from typing import Optional, List, Dict +from typing import Dict, List, Optional + from sqlalchemy.exc import IntegrityError -from .models import Reference, Category, create_db_engine, init_db, get_session + +from .models import Category, Reference, create_db_engine, get_session, init_db class DatabaseStorage: @@ -10,7 +12,7 @@ class DatabaseStorage: def __init__(self, db_url: str = "sqlite:///bibliography.db"): """Initialize database storage. - + Args: db_url: Database connection URL """ @@ -30,7 +32,7 @@ def add_reference( scholar_id: str = None, ) -> Optional[Reference]: """Add a reference to the database. - + Args: title: Article title authors: Comma-separated author names @@ -41,12 +43,12 @@ def add_reference( category_name: Category name arxiv_id: arXiv identifier scholar_id: Google Scholar ID - + Returns: Created Reference object or None if duplicate """ session = get_session(self.engine) - + try: # Check for duplicate DOI if doi: @@ -54,7 +56,7 @@ def add_reference( if existing: session.close() return None - + # Get or create category category = None if category_name: @@ -63,7 +65,7 @@ def add_reference( category = Category(name=category_name) session.add(category) session.flush() - + # Create reference reference = Reference( title=title, @@ -76,13 +78,13 @@ def add_reference( scholar_id=scholar_id, category_id=category.id if category else None, ) - + session.add(reference) session.commit() session.close() - + return reference - + except IntegrityError: session.rollback() session.close() @@ -96,72 +98,68 @@ def get_references( self, category_id: int = None, year: int = None, order_by: str = None ) -> List[Reference]: """Get references from database with optional filtering. - + Args: category_id: Filter by category ID year: Filter by year order_by: Field to order by (e.g., "year", "-year", "title") - + Returns: List of Reference objects """ session = get_session(self.engine) query = session.query(Reference) - + if category_id: query = query.filter_by(category_id=category_id) - + if year: query = query.filter_by(year=year) - + # Ordering if order_by: reverse = order_by.startswith("-") field = order_by[1:] if reverse else order_by - + if hasattr(Reference, field): col = getattr(Reference, field) query = query.order_by(col.desc() if reverse else col) else: # Default ordering: category, then year descending query = query.order_by(Category.name, Reference.year.desc()) - + results = query.all() session.close() - + return results def add_category(self, name: str, description: str = None) -> Optional[Category]: """Add a category to the database. - + Args: name: Category name description: Optional description - + Returns: Created Category object or None if duplicate """ session = get_session(self.engine) - + try: # Check for existing category (case-insensitive) - existing = ( - session.query(Category) - .filter(Category.name.ilike(name)) - .first() - ) - + existing = session.query(Category).filter(Category.name.ilike(name)).first() + if existing: session.close() return existing - + category = Category(name=name, description=description) session.add(category) session.commit() session.close() - + return category - + except IntegrityError: session.rollback() session.close() @@ -173,37 +171,37 @@ def add_category(self, name: str, description: str = None) -> Optional[Category] def get_categories(self) -> List[Category]: """Get all categories. - + Returns: List of Category objects ordered by name """ session = get_session(self.engine) categories = session.query(Category).order_by(Category.name).all() session.close() - + return categories def migrate_from_csv(self, csv_file: str) -> Dict[str, int]: """Migrate references from CSV file to database. - + Args: csv_file: Path to CSV file - + Returns: Dictionary with migration statistics """ import pandas as pd - + df = pd.read_csv(csv_file, dtype={"ArxivID": str}) df = df.fillna("") - + stats = { "total": len(df), "added": 0, "duplicates": 0, "errors": 0, } - + for _, row in df.iterrows(): try: result = self.add_reference( @@ -216,48 +214,61 @@ def migrate_from_csv(self, csv_file: str) -> Dict[str, int]: category_name=row.get("Category", ""), arxiv_id=row.get("ArxivID", ""), ) - + if result: stats["added"] += 1 else: stats["duplicates"] += 1 - + except Exception as e: stats["errors"] += 1 print(f"Error migrating {row.get('Title', 'Unknown')}: {e}") - + return stats def export_to_csv(self, csv_file: str) -> int: """Export database references to CSV file. - + Args: csv_file: Path to output CSV file - + Returns: Number of references exported """ import pandas as pd - + session = get_session(self.engine) references = session.query(Reference).all() session.close() - + data = [] for ref in references: - data.append({ - "Title": ref.title, - "Authors": ref.authors or "", - "Journal": ref.journal or "", - "Year": ref.year or "", - "DOI": ref.doi or "", - "Link": ref.link or "", - "Category": ref.category.name if ref.category else "", - "ArxivID": ref.arxiv_id or "", - }) - + data.append( + { + "Title": ref.title, + "Authors": ref.authors or "", + "Journal": ref.journal or "", + "Year": ref.year or "", + "DOI": ref.doi or "", + "Link": ref.link or "", + "Category": ref.category.name if ref.category else "", + "ArxivID": ref.arxiv_id or "", + } + ) + df = pd.DataFrame(data) - df = df[["Title", "Authors", "Journal", "Year", "DOI", "Link", "Category", "ArxivID"]] + df = df[ + [ + "Title", + "Authors", + "Journal", + "Year", + "DOI", + "Link", + "Category", + "ArxivID", + ] + ] df.to_csv(csv_file, index=False) - + return len(data) diff --git a/pkg/mybib/graph.py b/pkg/mybib/graph.py index ba2fabc..1ca3473 100644 --- a/pkg/mybib/graph.py +++ b/pkg/mybib/graph.py @@ -1,105 +1,111 @@ """Citation graph building and visualization using Crossref API.""" +import time + import networkx as nx -from typing import Optional import pandas as pd import requests -import time def query_crossref_references(doi: str, max_retries: int = 3) -> list[str]: """Query Crossref API for references cited by a given DOI. - + Args: doi: The DOI of the paper to query (e.g., "10.1234/example") max_retries: Maximum number of retry attempts - + Returns: List of DOIs cited by the paper """ # Ensure DOI doesn't have 'https://doi.org/' prefix - doi_clean = doi.strip().lstrip('https://doi.org/').lstrip('http://doi.org/') - + doi_clean = doi.strip().lstrip("https://doi.org/").lstrip("http://doi.org/") + url = f"https://api.crossref.org/works/{doi_clean}" headers = {"User-Agent": "MyBible (mailto:test@example.com)"} - + for attempt in range(max_retries): try: response = requests.get(url, headers=headers, timeout=10) response.raise_for_status() - + data = response.json() - if data.get('status') == 'ok': - work = data.get('message', {}) - references = work.get('reference', []) - + if data.get("status") == "ok": + work = data.get("message", {}) + references = work.get("reference", []) + # Extract DOIs from references reference_dois = [] for ref in references: - if isinstance(ref, dict) and 'DOI' in ref: - reference_dois.append(ref['DOI'].lower()) - + if isinstance(ref, dict) and "DOI" in ref: + reference_dois.append(ref["DOI"].lower()) + return reference_dois - + except requests.exceptions.RequestException as e: if attempt < max_retries - 1: # Exponential backoff - wait_time = 2 ** attempt + wait_time = 2**attempt time.sleep(wait_time) else: print(f"Warning: Failed to query DOI {doi_clean}: {e}") - + return [] -def build_citation_graph(df: pd.DataFrame, - output_references: bool = False) -> nx.DiGraph: +def build_citation_graph( + df: pd.DataFrame, output_references: bool = False +) -> nx.DiGraph: """Build a citation graph from stored research papers. - + Creates a directed graph where: - Nodes represent papers (identified by DOI) - Edges represent citations (paper A -> paper B means A cites B) - + Args: df: DataFrame with columns including 'DOI' and 'Title' output_references: If True, print progress information - + Returns: A NetworkX directed graph """ graph = nx.DiGraph() - + # Add all papers as nodes paper_dois = {} # Maps clean DOI to original DOI for idx, row in df.iterrows(): - doi = str(row['DOI']).strip() - if not doi or doi.lower() == 'nan': + doi = str(row["DOI"]).strip() + if not doi or doi.lower() == "nan": continue - + # Normalize DOI for use as node identifier - doi_clean = doi.lstrip('https://doi.org/').lstrip('http://doi.org/').lower() + doi_clean = doi.lstrip("https://doi.org/").lstrip("http://doi.org/").lower() paper_dois[doi_clean] = doi - - title = row.get('Title', 'Unknown') - graph.add_node(doi_clean, - title=title, - original_doi=doi, - authors=row.get('Authors', ''), - year=row.get('Year', '')) - + + title = row.get("Title", "Unknown") + graph.add_node( + doi_clean, + title=title, + original_doi=doi, + authors=row.get("Authors", ""), + year=row.get("Year", ""), + ) + if output_references: print(f"Added {len(graph.nodes())} papers to graph") - + # Query Crossref API for each paper's references stored_dois = set(paper_dois.keys()) - + for idx, (doi_clean, doi_original) in enumerate(paper_dois.items()): if output_references: - print(f"Querying references for paper {idx + 1}/{len(paper_dois)}: {doi_clean}") - + print( + f"Querying references for paper " + f"{idx + 1}/{len(paper_dois)}: {doi_clean}" + ) + # Query Crossref for references reference_dois = query_crossref_references(doi_original) - + # Add edges only for references in our stored papers for ref_doi in reference_dois: if ref_doi in stored_dois: @@ -107,19 +113,22 @@ def build_citation_graph(df: pd.DataFrame, graph.add_edge(doi_clean, ref_doi) if output_references: print(f" -> Edge: {doi_clean} cites {ref_doi}") - + # Rate limiting: Crossref API recommends polite usage time.sleep(0.5) - + if output_references: - print(f"Citation graph built with {len(graph.nodes())} nodes and {len(graph.edges())} edges") - + print( + f"Citation graph built with {len(graph.nodes())} nodes " + f"and {len(graph.edges())} edges" + ) + return graph def export_graph_html(graph: nx.DiGraph, output_file: str) -> None: """Export citation graph as interactive HTML using pyvis. - + Args: graph: NetworkX directed graph output_file: Path to output HTML file @@ -127,36 +136,34 @@ def export_graph_html(graph: nx.DiGraph, output_file: str) -> None: try: from pyvis.network import Network except ImportError: - raise ImportError("pyvis is required for graph visualization. Install it with: pip install pyvis") - + raise ImportError( + "pyvis is required for graph visualization. " + "Install it with: pip install pyvis" + ) + import json - + # Create pyvis network - net = Network(directed=True, height='750px', width='100%') - + net = Network(directed=True, height="750px", width="100%") + # Add nodes with labels for node in graph.nodes(data=True): node_id = node[0] node_data = node[1] - - title = node_data.get('title', 'Unknown') - label = title[:50] + '...' if len(title) > 50 else title - + + title = node_data.get("title", "Unknown") + label = title[:50] + "..." if len(title) > 50 else title + # Create hover title with full information hover_title = f"{title}\nDOI: {node_data.get('original_doi', '')}" - - net.add_node( - node_id, - label=label, - title=hover_title, - color='#FF6E63' - ) - + + net.add_node(node_id, label=label, title=hover_title, color="#FF6E63") + # Add edges for edge in graph.edges(): source, target = edge net.add_edge(source, target) - + # Configure physics settings as JSON string options = { "physics": { @@ -165,12 +172,12 @@ def export_graph_html(graph: nx.DiGraph, output_file: str) -> None: "gravitationalConstant": -26000, "centralGravity": 0.3, "springLength": 200, - "springConstant": 0.04 - } + "springConstant": 0.04, + }, } } net.set_options(json.dumps(options)) - + # Write HTML file net.write_html(output_file) - print(f"Citation graph exported to {output_file}") \ No newline at end of file + print(f"Citation graph exported to {output_file}") diff --git a/pkg/mybib/markdown.py b/pkg/mybib/markdown.py index 3f6b2ce..a3aa413 100644 --- a/pkg/mybib/markdown.py +++ b/pkg/mybib/markdown.py @@ -6,81 +6,84 @@ def _prepare_references_for_markdown(file_path: str = "references.csv"): """Prepare references DataFrame for markdown output. - + Args: file_path: Path to the CSV file - + Returns: Processed DataFrame ready for markdown conversion """ df = load_references(file_path) - + if df.empty: return df - + df["Authors"] = df["Authors"].apply(reform_names) df["DOI"] = df.apply(lambda row: f"[{row.get('DOI', 'unknown')}]", axis=1) # Ensure ArxivID stays as string for display (prevents float conversion) if "ArxivID" in df.columns: df["ArxivID"] = df["ArxivID"].astype(str) df = df.sort_values(by=["Category", "Year"], ascending=[True, False]) - + return df def make_markdown_table(file_path: str = "references.csv") -> str: """Generate a markdown table from all references. - + Args: file_path: Path to the CSV file - + Returns: Markdown formatted table as string """ df = _prepare_references_for_markdown(file_path) - + if df.empty: return "No references found." - + markdown_table = df.to_markdown(index=False) return markdown_table def make_markdown_tables_by_category(file_path: str = "references.csv") -> str: """Generate markdown tables separated by category. - + Args: file_path: Path to the CSV file - + Returns: Markdown formatted tables with category headers and footer links """ df = _prepare_references_for_markdown(file_path) - + if df.empty: return "No references found." - + output = [] footer = [] - + for category, group in df.groupby("Category"): output.append(f"## {category}\n") # Format ArxivID as string to avoid float conversion by to_markdown display_group = group.copy() if "ArxivID" in display_group.columns: - display_group["ArxivID"] = display_group["ArxivID"].astype(str).str.replace("nan", "") - - table = display_group.drop(columns=["Category", "Link"]).to_markdown(index=False) + display_group["ArxivID"] = ( + display_group["ArxivID"].astype(str).str.replace("nan", "") + ) + + table = display_group.drop(columns=["Category", "Link"]).to_markdown( + index=False + ) output.append(table) output.append("\n") - + # Collect footer links footer.extend( - f"{doi}: {link}" - for doi, link in zip(group["DOI"], group["Link"]) + f"{doi}: {link}" for doi, link in zip(group["DOI"], group["Link"]) ) - + # Sort and deduplicate footer footer = sorted(set(footer), key=lambda x: x.split(":")[0]) - + return "\n".join(output) + "\n".join(footer) diff --git a/pkg/mybib/metadata.py b/pkg/mybib/metadata.py index 75aa77d..973d148 100644 --- a/pkg/mybib/metadata.py +++ b/pkg/mybib/metadata.py @@ -2,23 +2,23 @@ import re import sys + import requests -from urllib.parse import urlparse from . import arxiv def fetch_metadata(url: str) -> dict: """Fetch metadata from a URL, automatically detecting the source. - + Supports: - arXiv URLs (arxiv.org) - DOI URLs (doi.org) and DOI patterns - Generic URLs with HTML metadata extraction - + Args: url: URL or DOI string to fetch metadata from - + Returns: Dictionary with standardized fields: { @@ -29,12 +29,12 @@ def fetch_metadata(url: str) -> dict: 'doi': str or None, 'link': str } - + Raises: SystemExit: If URL is invalid or metadata cannot be extracted """ url = url.strip() - + # Detect source and route to appropriate handler if _is_arxiv_url(url): return _fetch_arxiv_metadata(url) @@ -62,7 +62,7 @@ def _is_doi_pattern(url: str) -> bool: def _extract_arxiv_id(url: str) -> str: """Extract arXiv ID from URL. - + Examples: - https://arxiv.org/abs/2301.00001 -> 2301.00001 - https://arxiv.org/pdf/2301.00001.pdf -> 2301.00001 @@ -71,7 +71,7 @@ def _extract_arxiv_id(url: str) -> str: match = re.search(r"/(?:abs|pdf)/(\d{4}\.\d{4,5})", url) if match: return match.group(1) - + print(f"Error: Could not extract arXiv ID from {url}") sys.exit(1) @@ -95,27 +95,31 @@ def _normalize_doi(url_or_doi: str) -> str: def _fetch_crossref_metadata(url_or_doi: str) -> dict: """Fetch metadata from Crossref API using DOI.""" doi = _normalize_doi(url_or_doi) - + # Crossref API endpoint crossref_url = f"https://api.crossref.org/works/{doi}" - + try: response = requests.get(crossref_url, timeout=10) response.raise_for_status() except requests.RequestException as e: print(f"Error fetching DOI metadata: {e}") sys.exit(1) - + data = response.json() if data.get("status") != "ok" or not data.get("message"): print(f"Error: Could not fetch metadata for DOI {doi}") sys.exit(1) - + work = data["message"] - + # Extract fields with fallbacks - title = work.get("title", [""])[0] if isinstance(work.get("title"), list) else work.get("title", "Unknown") - + title = ( + work.get("title", [""])[0] + if isinstance(work.get("title"), list) + else work.get("title", "Unknown") + ) + # Authors authors = "" if "author" in work: @@ -124,29 +128,29 @@ def _fetch_crossref_metadata(url_or_doi: str) -> dict: for a in work["author"] ] authors = ", ".join(author_list) - + # Journal journal = work.get("container-title", "Unknown") if isinstance(journal, list): journal = journal[0] if journal else "Unknown" - + # Year year = None if "issued" in work: date_parts = work["issued"].get("date-parts", [[None]])[0] if date_parts: year = date_parts[0] - + # Link (prefer DOI link, fallback to URL) link = work.get("URL", f"https://doi.org/{doi}") - + return { "title": title, "authors": authors, "journal": journal, "year": year, "doi": doi, - "link": link + "link": link, } @@ -158,41 +162,41 @@ def _fetch_generic_metadata(url: str) -> dict: except requests.RequestException as e: print(f"Error fetching URL: {e}") sys.exit(1) - + html = response.text - + # Extract title from tag or og:title meta tag title = _extract_html_meta(html, ["og:title", "twitter:title"]) if not title: title_match = re.search(r"<title[^>]*>([^<]+)", html, re.IGNORECASE) title = title_match.group(1).strip() if title_match else "Unknown" - + # Extract author authors = _extract_html_meta(html, ["author", "article:author"]) - + # Try to extract DOI doi = None doi_match = re.search(r"10\.\S+/\S+", html) if doi_match: doi = doi_match.group(0) - + return { "title": title or "Unknown", "authors": authors or "Unknown", "journal": "Unknown", "year": None, "doi": doi, - "link": url + "link": url, } def _extract_html_meta(html: str, meta_names: list) -> str: """Extract value from HTML meta tags. - + Args: html: HTML content as string meta_names: List of meta tag names to search for - + Returns: Content of the first matching meta tag, or None if not found """ @@ -202,17 +206,20 @@ def _extract_html_meta(html: str, meta_names: list) -> str: match = re.search(pattern, html, re.IGNORECASE) if match: return match.group(1).strip() - + # Try name attribute pattern = rf'" + return ( + f"" + ) def create_db_engine(db_url: str = "sqlite:///bibliography.db"): """Create database engine. - + Args: db_url: Database URL (default: SQLite) - + Returns: SQLAlchemy engine """ @@ -80,7 +84,7 @@ def create_db_engine(db_url: str = "sqlite:///bibliography.db"): def init_db(engine): """Initialize database tables. - + Args: engine: SQLAlchemy engine """ @@ -89,10 +93,10 @@ def init_db(engine): def get_session(engine): """Get database session. - + Args: engine: SQLAlchemy engine - + Returns: SQLAlchemy session """ diff --git a/pkg/mybib/scholar.py b/pkg/mybib/scholar.py index 81ad328..28471e2 100644 --- a/pkg/mybib/scholar.py +++ b/pkg/mybib/scholar.py @@ -1,15 +1,15 @@ """Fetch metadata from Google Scholar using SerpAPI.""" -import sys import os +import sys +from typing import Dict, List, Optional + import requests -import json -from typing import Optional, Dict, List -from urllib.parse import urlencode # Suppress SSL warnings if needed try: import urllib3 + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) except ImportError: pass @@ -17,15 +17,15 @@ def search_google_scholar(query: str, max_results: int = 5) -> List[Dict]: """Search Google Scholar for articles matching a query. - + Args: query: Search query (title, authors, or keywords) max_results: Maximum number of results to return (1-20) - + Returns: - List of results with keys: position, title, result_id, link, snippet, + List of results with keys: position, title, result_id, link, snippet, publication_info, inline_links, authors, year, etc. - + Raises: SystemExit: If API call fails or key is missing """ @@ -34,7 +34,7 @@ def search_google_scholar(query: str, max_results: int = 5) -> List[Dict]: print("Error: SERPAPI_KEY environment variable not set.") print("Get a free API key at https://serpapi.com") sys.exit(1) - + params = { "engine": "google_scholar", "q": query, @@ -42,21 +42,21 @@ def search_google_scholar(query: str, max_results: int = 5) -> List[Dict]: "num": min(max_results, 20), "hl": "en", } - + url = "https://serpapi.com/search" - + try: response = requests.get(url, params=params, timeout=10) response.raise_for_status() data = response.json() - + if "error" in data: print(f"Error from SerpAPI: {data['error']}") sys.exit(1) - + organic_results = data.get("organic_results", []) return organic_results - + except requests.exceptions.RequestException as e: print(f"Error querying Google Scholar: {e}") sys.exit(1) @@ -64,27 +64,27 @@ def search_google_scholar(query: str, max_results: int = 5) -> List[Dict]: def get_scholar_cite_link(result_id: str) -> Optional[str]: """Get the Google Scholar cite link for a result. - + Args: result_id: The result ID from search results - + Returns: The SerpAPI cite API link, or None if not available """ api_key = os.environ.get("SERPAPI_KEY") if not api_key: return None - + # Format: https://serpapi.com/search.json?engine=google_scholar_cite&q={result_id}&api_key={api_key} return f"https://serpapi.com/search.json?engine=google_scholar_cite&q={result_id}&api_key={api_key}" def fetch_bibtex_from_scholar(result_id: str) -> Optional[str]: """Fetch BibTeX citation from Google Scholar using SerpAPI cite API. - + Args: result_id: The result ID from a Google Scholar search result - + Returns: BibTeX string, or None if unable to fetch """ @@ -92,36 +92,36 @@ def fetch_bibtex_from_scholar(result_id: str) -> Optional[str]: if not api_key: print("Error: SERPAPI_KEY environment variable not set.") return None - + params = { "engine": "google_scholar_cite", "q": result_id, "api_key": api_key, } - + url = "https://serpapi.com/search" - + try: response = requests.get(url, params=params, timeout=10) response.raise_for_status() data = response.json() - + if "error" in data: print(f"Error from SerpAPI: {data['error']}") return None - + # SerpAPI returns citations in different formats # The BibTeX is typically in the response as a string or in a structured format citations = data.get("citations", {}) - + # Try to find BibTeX format bibtex = citations.get("bibtex") if bibtex: return bibtex - + # If we can't find BibTeX, we might need to construct it from available data return None - + except requests.exceptions.RequestException as e: print(f"Error fetching BibTeX from Google Scholar: {e}") return None @@ -129,15 +129,15 @@ def fetch_bibtex_from_scholar(result_id: str) -> Optional[str]: def extract_metadata_from_result(result: Dict) -> Dict: """Extract standardized metadata from a Google Scholar search result. - + Args: result: A single result from search_google_scholar() - + Returns: Dictionary with keys: title, authors, journal, year, doi, link, scholar_id """ import re - + metadata = { "title": result.get("title", ""), "authors": "", @@ -148,29 +148,29 @@ def extract_metadata_from_result(result: Dict) -> Dict: "result_id": result.get("result_id", ""), "scholar_id": result.get("result_id", ""), } - + # Extract publication info pub_info = result.get("publication_info", {}) if isinstance(pub_info, dict): summary = pub_info.get("summary", "") metadata["journal"] = summary - + # Try to extract year from summary - look for 4-digit years (1900-2099) # Use a more specific pattern to avoid partial matches - year_matches = re.findall(r'\b(19\d{2}|20\d{2})\b', summary) + year_matches = re.findall(r"\b(19\d{2}|20\d{2})\b", summary) if year_matches: # Take the most likely year (prefer 20xx over 19xx if available) for year_str in year_matches: metadata["year"] = int(year_str) # Prefer years starting with 20 - if year_str.startswith('20'): + if year_str.startswith("20"): break - + # Try to extract DOI from summary (pattern: 10.xxxx/xxxx) - doi_match = re.search(r'\b(10\.\S+/\S+)\b', summary) + doi_match = re.search(r"\b(10\.\S+/\S+)\b", summary) if doi_match: metadata["doi"] = doi_match.group(1) - + # Extract authors authors_list = pub_info.get("authors", []) if authors_list: @@ -184,13 +184,13 @@ def extract_metadata_from_result(result: Dict) -> Dict: metadata["authors"] = ", ".join(author_names) else: metadata["authors"] = str(authors_list) - + # Try to extract year from other fields if not found if metadata["year"] is None: - title_year_match = re.search(r'\((\d{4})\)', metadata["title"]) + title_year_match = re.search(r"\((\d{4})\)", metadata["title"]) if title_year_match: metadata["year"] = int(title_year_match.group(1)) - + # Try to extract DOI from result directly if not found in summary if metadata["doi"] is None: if "doi" in result: @@ -201,50 +201,51 @@ def extract_metadata_from_result(result: Dict) -> Dict: # Check for DOI link in inline_links for key, value in inline_links.items(): if "doi.org" in str(value): - doi_match = re.search(r'doi\.org/(.+?)(?:\s|$)', str(value)) + doi_match = re.search(r"doi\.org/(.+?)(?:\s|$)", str(value)) if doi_match: metadata["doi"] = doi_match.group(1) break - - # Note: We do NOT use scholar_id as a fallback for DOI because they are different identifiers. + + # Note: We do NOT use scholar_id as a fallback for DOI because they are + # different identifiers. # Scholar ID is Google Scholar's internal identifier, not a DOI. # If no real DOI is found, leave it as None. - + return metadata def search_and_confirm_article(title: str, max_attempts: int = 3) -> Optional[Dict]: """Search for an article on Google Scholar and get user confirmation. - + Args: title: Article title to search for max_attempts: Maximum attempts to find a match - + Returns: Metadata dictionary if found and confirmed, None otherwise """ - from .ui import console, confirm_action, display_reference_preview - + from .ui import confirm_action, console, display_reference_preview + console.print(f"[cyan]Searching Google Scholar for: {title}[/]") - + results = search_google_scholar(title, max_results=5) - + if not results: console.print("[red]No results found on Google Scholar[/]") return None - + # Show first result and ask for confirmation first_result = results[0] metadata = extract_metadata_from_result(first_result) - + console.print() console.print("[yellow]Found:[/]") display_reference_preview(metadata) console.print() - + if confirm_action("Is this the correct article?"): return metadata - + # If not confirmed, show other results if len(results) > 1: for i, result in enumerate(results[1:], 1): @@ -253,12 +254,12 @@ def search_and_confirm_article(title: str, max_attempts: int = 3) -> Optional[Di result_metadata = extract_metadata_from_result(result) display_reference_preview(result_metadata) console.print() - - if confirm_action(f"Is this the correct article?"): + + if confirm_action("Is this the correct article?"): return result_metadata - + if i >= max_attempts - 1: break - + console.print("[red]No matching article confirmed[/]") return None diff --git a/pkg/mybib/storage.py b/pkg/mybib/storage.py index c14b4bd..6ffc787 100644 --- a/pkg/mybib/storage.py +++ b/pkg/mybib/storage.py @@ -1,9 +1,10 @@ """CSV storage for bibliography references.""" -import pandas as pd import sys from pathlib import Path +import pandas as pd + def add_reference( title: str, @@ -18,7 +19,7 @@ def add_reference( file_path: str = "references.csv", ) -> None: """Add a reference to the CSV file. - + Args: title: Article title authors: Comma-separated author names @@ -30,13 +31,13 @@ def add_reference( arxiv_id: arXiv identifier (optional) scholar_id: Google Scholar result ID (optional, used as DOI fallback) file_path: Path to the CSV file - + Raises: SystemExit: If reference already exists """ # Use scholar_id as DOI fallback if DOI not provided final_doi = doi if doi else scholar_id - + new_reference = { "Title": title, "Authors": authors, @@ -48,9 +49,9 @@ def add_reference( "ArxivID": arxiv_id or "", } row = pd.DataFrame([new_reference]) - + file_exists = Path(file_path).exists() - + if file_exists: df_existing = pd.read_csv(file_path) # Convert ArxivID to string for comparison @@ -59,16 +60,14 @@ def add_reference( existing_df = df_existing # Normalize DOI for comparison: convert to string, strip whitespace, lowercase existing_dois = set( - str(d).strip().lower() - for d in existing_df["DOI"].to_list() - if pd.notna(d) + str(d).strip().lower() for d in existing_df["DOI"].to_list() if pd.notna(d) ) normalized_doi = str(final_doi).strip().lower() if final_doi else "" - + if normalized_doi and normalized_doi in existing_dois: print("Reference already exists in the CSV file.") sys.exit(0) - + # Ensure ArxivID is treated as string row["ArxivID"] = row["ArxivID"].astype(str) row.to_csv(file_path, mode="a", index=False, header=not file_exists) @@ -76,10 +75,10 @@ def add_reference( def load_references(file_path: str = "references.csv") -> pd.DataFrame: """Load references from CSV file. - + Args: file_path: Path to the CSV file - + Returns: DataFrame with reference data """ @@ -93,9 +92,18 @@ def load_references(file_path: str = "references.csv") -> pd.DataFrame: df["ArxivID"] = "" except FileNotFoundError: df = pd.DataFrame( - columns=["Title", "Authors", "Journal", "Year", "DOI", "Link", "Category", "ArxivID"] + columns=[ + "Title", + "Authors", + "Journal", + "Year", + "DOI", + "Link", + "Category", + "ArxivID", + ] ) df = df.astype({"ArxivID": str}) df.to_csv(file_path, index=False) - - return df \ No newline at end of file + + return df diff --git a/pkg/mybib/ui.py b/pkg/mybib/ui.py index 1645c09..6e61e9b 100644 --- a/pkg/mybib/ui.py +++ b/pkg/mybib/ui.py @@ -1,14 +1,17 @@ """UI utilities using rich for enhanced terminal output.""" from contextlib import contextmanager -from pathlib import Path from typing import Generator +import pandas as pd from rich.console import Console -from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeRemainingColumn -from rich.table import Table +from rich.progress import ( + Progress, + SpinnerColumn, + TextColumn, +) from rich.prompt import Confirm -import pandas as pd +from rich.table import Table # Create a global console instance for consistent styling console = Console() @@ -37,7 +40,7 @@ def print_info(message: str) -> None: @contextmanager def progress_context(description: str = "Processing") -> Generator: """Context manager for showing progress during operations. - + Usage: with progress_context("Fetching metadata"): # Long-running operation @@ -48,7 +51,7 @@ def progress_context(description: str = "Processing") -> Generator: TextColumn("[progress.description]{task.description}"), console=console, ) as progress: - task_id = progress.add_task(description, total=None) + progress.add_task(description, total=None) try: yield finally: @@ -63,7 +66,7 @@ def api_progress() -> Generator: TextColumn("[progress.description]{task.description}"), console=console, ) as progress: - task_id = progress.add_task("[cyan]Fetching from API...", total=None) + progress.add_task("[cyan]Fetching from API...", total=None) try: yield finally: @@ -72,11 +75,11 @@ def api_progress() -> Generator: def confirm_action(prompt: str, default: bool = True) -> bool: """Show a confirmation prompt with rich styling. - + Args: prompt: The question to ask default: Default response if user just presses enter (default: True) - + Returns: True if confirmed, False otherwise """ @@ -85,10 +88,10 @@ def confirm_action(prompt: str, default: bool = True) -> bool: def create_references_table(df: pd.DataFrame) -> Table: """Create a rich Table from a DataFrame of references. - + Args: df: DataFrame with columns: Title, Authors, Journal, Year, Category - + Returns: A rich Table object """ @@ -100,14 +103,14 @@ def create_references_table(df: pd.DataFrame) -> Table: padding=(0, 1), show_lines=False, ) - + # Add columns table.add_column("Title", style="cyan", width=40, overflow="fold") table.add_column("Authors", style="green", width=30, overflow="fold") table.add_column("Year", justify="center", width=6) table.add_column("Category", style="yellow", width=15) table.add_column("Journal", style="blue", width=20, overflow="fold") - + # Add rows for _, row in df.iterrows(): table.add_row( @@ -117,22 +120,22 @@ def create_references_table(df: pd.DataFrame) -> Table: str(row.get("Category", "")), str(row.get("Journal", ""))[:20], ) - + return table def display_reference_preview(metadata: dict) -> None: """Display a nice preview of a reference before adding. - + Args: metadata: Dictionary with title, authors, journal, year, category """ table = Table(show_header=False, box=None, padding=(0, 2)) - + table.add_row("[bold cyan]Title:[/]", metadata.get("title", "N/A")) table.add_row("[bold cyan]Authors:[/]", metadata.get("authors", "N/A")) table.add_row("[bold cyan]Journal:[/]", metadata.get("journal", "N/A")) table.add_row("[bold cyan]Year:[/]", str(metadata.get("year", "N/A"))) table.add_row("[bold cyan]DOI:[/]", metadata.get("doi", "N/A")) - + console.print(table) diff --git a/pkg/mybib/utils.py b/pkg/mybib/utils.py index 42d255e..28d863a 100644 --- a/pkg/mybib/utils.py +++ b/pkg/mybib/utils.py @@ -3,42 +3,54 @@ def reform_names(authors_str: str) -> str: """Format author names for display. - + Converts full author lists to abbreviated form: - Single author or team: Last name only - Two authors: "LastName1 and LastName2" - Three+ authors: "FirstAuthor et al." - Team names (contain "Team"): Entity name only without "et al." - + Args: authors_str: Comma-separated string of author names or "Authors et al." format - + Returns: Formatted author string """ if not authors_str or not isinstance(authors_str, str): return "" - + authors_str = authors_str.strip() - + # Handle "X et al." format - extract just the first author if " et al." in authors_str: first_author = authors_str.split(" et al.")[0].strip() # Extract last name from first author first_author_last = first_author.split()[-1] return f"{first_author_last} et al." - + # Check if it's a team name (contains "Team", "AI", etc.) - if any(team_keyword in authors_str for team_keyword in ["Team", "team", "-Ai", "Mistral", "Meta", "OpenAI", "DeepMind", "Google"]): + if any( + team_keyword in authors_str + for team_keyword in [ + "Team", + "team", + "-Ai", + "Mistral", + "Meta", + "OpenAI", + "DeepMind", + "Google", + ] + ): # Just return the entity name as is, or extract last part parts = authors_str.split(",") if len(parts) > 0: return parts[0].strip() return authors_str - + # Split by comma authors = [a.strip() for a in authors_str.split(",")] - + if len(authors) > 2: first_author_last_name = authors[0].split()[-1] return f"{first_author_last_name} et al." diff --git a/tests/test_arxiv.py b/tests/test_arxiv.py index 1ba2d66..530cbb1 100644 --- a/tests/test_arxiv.py +++ b/tests/test_arxiv.py @@ -1,10 +1,10 @@ """Tests for arXiv metadata fetching module.""" -import pytest import tempfile from pathlib import Path -from unittest.mock import patch, Mock -import sys +from unittest.mock import Mock, patch + +import pytest from pkg.mybib import arxiv @@ -86,12 +86,12 @@ def sample_arxiv_response_no_journal(): class TestFetchArxivMetadata: """Test fetching metadata from arXiv API.""" - @patch('pkg.mybib.arxiv.requests.get') + @patch("pkg.mybib.arxiv.requests.get") def test_fetch_arxiv_metadata_success(self, mock_get, sample_arxiv_response): """Test successfully fetching arXiv metadata.""" mock_response = Mock() mock_response.status_code = 200 - mock_response.content = sample_arxiv_response.encode('utf-8') + mock_response.content = sample_arxiv_response.encode("utf-8") mock_get.return_value = mock_response result = arxiv.fetch_arxiv_metadata("1706.03762") @@ -106,12 +106,14 @@ def test_fetch_arxiv_metadata_success(self, mock_get, sample_arxiv_response): assert result["link"] == "https://arxiv.org/abs/1706.03762" assert result["arxiv_id"] == "1706.03762" - @patch('pkg.mybib.arxiv.requests.get') - def test_fetch_arxiv_metadata_multiple_authors(self, mock_get, sample_arxiv_response): + @patch("pkg.mybib.arxiv.requests.get") + def test_fetch_arxiv_metadata_multiple_authors( + self, mock_get, sample_arxiv_response + ): """Test that multiple authors are properly parsed.""" mock_response = Mock() mock_response.status_code = 200 - mock_response.content = sample_arxiv_response.encode('utf-8') + mock_response.content = sample_arxiv_response.encode("utf-8") mock_get.return_value = mock_response result = arxiv.fetch_arxiv_metadata("1706.03762") @@ -122,12 +124,12 @@ def test_fetch_arxiv_metadata_multiple_authors(self, mock_get, sample_arxiv_resp assert "Noam Shazeer" in authors assert "Parmar Aidan" in authors - @patch('pkg.mybib.arxiv.requests.get') + @patch("pkg.mybib.arxiv.requests.get") def test_fetch_arxiv_metadata_no_doi(self, mock_get, sample_arxiv_response_no_doi): """Test handling of metadata without DOI (uses arxiv_id as fallback).""" mock_response = Mock() mock_response.status_code = 200 - mock_response.content = sample_arxiv_response_no_doi.encode('utf-8') + mock_response.content = sample_arxiv_response_no_doi.encode("utf-8") mock_get.return_value = mock_response result = arxiv.fetch_arxiv_metadata("2003.06123") @@ -137,12 +139,14 @@ def test_fetch_arxiv_metadata_no_doi(self, mock_get, sample_arxiv_response_no_do assert result["doi"] == "2003.06123" # Falls back to arxiv_id assert result["journal"] == "arXiv" - @patch('pkg.mybib.arxiv.requests.get') - def test_fetch_arxiv_metadata_no_journal(self, mock_get, sample_arxiv_response_no_journal): + @patch("pkg.mybib.arxiv.requests.get") + def test_fetch_arxiv_metadata_no_journal( + self, mock_get, sample_arxiv_response_no_journal + ): """Test handling of metadata without journal reference.""" mock_response = Mock() mock_response.status_code = 200 - mock_response.content = sample_arxiv_response_no_journal.encode('utf-8') + mock_response.content = sample_arxiv_response_no_journal.encode("utf-8") mock_get.return_value = mock_response result = arxiv.fetch_arxiv_metadata("2201.04000") @@ -151,8 +155,8 @@ def test_fetch_arxiv_metadata_no_journal(self, mock_get, sample_arxiv_response_n assert result["journal"] == "arXiv" # Default journal assert result["year"] == 2022 - @patch('pkg.mybib.arxiv.sys.exit', side_effect=SystemExit) - @patch('pkg.mybib.arxiv.requests.get') + @patch("pkg.mybib.arxiv.sys.exit", side_effect=SystemExit) + @patch("pkg.mybib.arxiv.requests.get") def test_fetch_arxiv_metadata_http_error(self, mock_get, mock_exit): """Test handling of HTTP errors.""" mock_response = Mock() @@ -162,8 +166,8 @@ def test_fetch_arxiv_metadata_http_error(self, mock_get, mock_exit): with pytest.raises(SystemExit): arxiv.fetch_arxiv_metadata("1706.03762") - @patch('pkg.mybib.arxiv.sys.exit', side_effect=SystemExit) - @patch('pkg.mybib.arxiv.requests.get') + @patch("pkg.mybib.arxiv.sys.exit", side_effect=SystemExit) + @patch("pkg.mybib.arxiv.requests.get") def test_fetch_arxiv_metadata_not_found(self, mock_get, mock_exit): """Test handling when no entry is found.""" empty_response = """ @@ -171,18 +175,18 @@ def test_fetch_arxiv_metadata_not_found(self, mock_get, mock_exit): """ mock_response = Mock() mock_response.status_code = 200 - mock_response.content = empty_response.encode('utf-8') + mock_response.content = empty_response.encode("utf-8") mock_get.return_value = mock_response with pytest.raises(SystemExit): arxiv.fetch_arxiv_metadata("9999.99999") - @patch('pkg.mybib.arxiv.requests.get') + @patch("pkg.mybib.arxiv.requests.get") def test_fetch_arxiv_metadata_url_formation(self, mock_get, sample_arxiv_response): """Test that the correct URL is being formed.""" mock_response = Mock() mock_response.status_code = 200 - mock_response.content = sample_arxiv_response.encode('utf-8') + mock_response.content = sample_arxiv_response.encode("utf-8") mock_get.return_value = mock_response arxiv.fetch_arxiv_metadata("1706.03762") @@ -193,7 +197,7 @@ def test_fetch_arxiv_metadata_url_formation(self, mock_get, sample_arxiv_respons assert "export.arxiv.org/api/query" in called_url assert "1706.03762" in called_url - @patch('pkg.mybib.arxiv.requests.get') + @patch("pkg.mybib.arxiv.requests.get") def test_fetch_arxiv_metadata_whitespace_in_title(self, mock_get): """Test that whitespace in title is normalized.""" response_with_newline = """ @@ -213,7 +217,7 @@ def test_fetch_arxiv_metadata_whitespace_in_title(self, mock_get): """ mock_response = Mock() mock_response.status_code = 200 - mock_response.content = response_with_newline.encode('utf-8') + mock_response.content = response_with_newline.encode("utf-8") mock_get.return_value = mock_response result = arxiv.fetch_arxiv_metadata("1706.03762") @@ -223,7 +227,7 @@ def test_fetch_arxiv_metadata_whitespace_in_title(self, mock_get): assert "Attention Is All" in result["title"] assert "You Need" in result["title"] - @patch('pkg.mybib.arxiv.requests.get') + @patch("pkg.mybib.arxiv.requests.get") def test_fetch_arxiv_metadata_single_author(self, mock_get): """Test handling of single author.""" single_author_response = """ @@ -241,14 +245,14 @@ def test_fetch_arxiv_metadata_single_author(self, mock_get): """ mock_response = Mock() mock_response.status_code = 200 - mock_response.content = single_author_response.encode('utf-8') + mock_response.content = single_author_response.encode("utf-8") mock_get.return_value = mock_response result = arxiv.fetch_arxiv_metadata("2105.00001") assert result["authors"] == "Alice Author" - @patch('pkg.mybib.arxiv.requests.get') + @patch("pkg.mybib.arxiv.requests.get") def test_fetch_arxiv_metadata_connection_error(self, mock_get): """Test handling of connection errors.""" mock_get.side_effect = Exception("Connection error") diff --git a/tests/test_markdown.py b/tests/test_markdown.py index ea125bc..ff87da7 100644 --- a/tests/test_markdown.py +++ b/tests/test_markdown.py @@ -1,11 +1,12 @@ """Tests for markdown generation module.""" -import pytest -import pandas as pd import tempfile from pathlib import Path -from pkg.mybib import storage, markdown +import pandas as pd +import pytest + +from pkg.mybib import markdown, storage @pytest.fixture @@ -27,7 +28,7 @@ def csv_with_references(temp_csv): "Year": 2023, "DOI": "10.1234/ml.2023", "Link": "https://example.com/paper1", - "Category": "Machine Learning" + "Category": "Machine Learning", }, { "Title": "Deep Neural Networks", @@ -36,7 +37,7 @@ def csv_with_references(temp_csv): "Year": 2022, "DOI": "10.5678/nn.2022", "Link": "https://example.com/paper2", - "Category": "Machine Learning" + "Category": "Machine Learning", }, { "Title": "Computer Vision Survey", @@ -45,8 +46,8 @@ def csv_with_references(temp_csv): "Year": 2024, "DOI": "10.9999/cv.2024", "Link": "https://example.com/paper3", - "Category": "Computer Vision" - } + "Category": "Computer Vision", + }, ] df = pd.DataFrame(references) @@ -72,7 +73,7 @@ def test_make_markdown_table_single_reference(self, temp_csv): doi="10.0000/test.2023", link="https://test.com", category="Testing", - file_path=temp_csv + file_path=temp_csv, ) result = markdown.make_markdown_table(temp_csv) @@ -140,7 +141,7 @@ def test_make_markdown_table_is_valid_markdown(self, temp_csv): doi="10.1111/paper1.2023", link="https://link1.com", category="Cat1", - file_path=temp_csv + file_path=temp_csv, ) storage.add_reference( @@ -151,13 +152,13 @@ def test_make_markdown_table_is_valid_markdown(self, temp_csv): doi="10.2222/paper2.2022", link="https://link2.com", category="Cat1", - file_path=temp_csv + file_path=temp_csv, ) result = markdown.make_markdown_table(temp_csv) # Check for markdown table structure - lines = result.split('\n') + lines = result.split("\n") assert len(lines) >= 4 # Header, separator, at least 2 data rows assert "|" in lines[0] # Header has pipes assert "-" in lines[1] # Separator has dashes @@ -181,7 +182,7 @@ def test_make_markdown_tables_by_category_single_category(self, temp_csv): doi="10.1111/ml1.2023", link="https://link1.com", category="Machine Learning", - file_path=temp_csv + file_path=temp_csv, ) storage.add_reference( @@ -192,7 +193,7 @@ def test_make_markdown_tables_by_category_single_category(self, temp_csv): doi="10.1111/ml2.2022", link="https://link2.com", category="Machine Learning", - file_path=temp_csv + file_path=temp_csv, ) result = markdown.make_markdown_tables_by_category(temp_csv) @@ -201,7 +202,9 @@ def test_make_markdown_tables_by_category_single_category(self, temp_csv): assert "ML Paper 1" in result assert "ML Paper 2" in result - def test_make_markdown_tables_by_category_multiple_categories(self, csv_with_references): + def test_make_markdown_tables_by_category_multiple_categories( + self, csv_with_references + ): """Test generating markdown with multiple categories.""" result = markdown.make_markdown_tables_by_category(csv_with_references) @@ -220,7 +223,7 @@ def test_make_markdown_tables_by_category_excludes_category_column(self, temp_cs doi="10.0000/test.2023", link="https://test.com", category="Testing", - file_path=temp_csv + file_path=temp_csv, ) result = markdown.make_markdown_tables_by_category(temp_csv) @@ -230,7 +233,9 @@ def test_make_markdown_tables_by_category_excludes_category_column(self, temp_cs # But the actual category values shouldn't appear in the table # (only in the section header) - def test_make_markdown_tables_by_category_has_footer_links(self, csv_with_references): + def test_make_markdown_tables_by_category_has_footer_links( + self, csv_with_references + ): """Test that footer contains DOI links.""" result = markdown.make_markdown_tables_by_category(csv_with_references) @@ -251,7 +256,7 @@ def test_make_markdown_tables_by_category_footer_uniqueness(self, temp_csv): doi="10.1111/unique.2023", link="https://link1.com", category="Category1", - file_path=temp_csv + file_path=temp_csv, ) storage.add_reference( @@ -262,7 +267,7 @@ def test_make_markdown_tables_by_category_footer_uniqueness(self, temp_csv): doi="10.2222/unique2.2022", link="https://link2.com", category="Category2", - file_path=temp_csv + file_path=temp_csv, ) result = markdown.make_markdown_tables_by_category(temp_csv) @@ -271,7 +276,9 @@ def test_make_markdown_tables_by_category_footer_uniqueness(self, temp_csv): assert result.count("https://link1.com") >= 1 assert result.count("https://link2.com") >= 1 - def test_make_markdown_tables_by_category_sorted_categories(self, csv_with_references): + def test_make_markdown_tables_by_category_sorted_categories( + self, csv_with_references + ): """Test that categories appear in the output.""" result = markdown.make_markdown_tables_by_category(csv_with_references) @@ -282,7 +289,9 @@ def test_make_markdown_tables_by_category_sorted_categories(self, csv_with_refer assert ml_pos >= 0 or cv_pos >= 0 assert "## Computer Vision" in result or "## Machine Learning" in result - def test_make_markdown_tables_by_category_within_category_sorting(self, csv_with_references): + def test_make_markdown_tables_by_category_within_category_sorting( + self, csv_with_references + ): """Test that within each category, papers are sorted by year descending.""" result = markdown.make_markdown_tables_by_category(csv_with_references) @@ -294,8 +303,8 @@ def test_make_markdown_tables_by_category_within_category_sorting(self, csv_with if basics_pos > 0 and nn_pos > 0: # Get the category that comes before them ml_section_start = result.rfind("## Machine Learning") - cv_section_start = result.rfind("## Computer Vision") - + result.rfind("## Computer Vision") + # Check they're in the ML section and ordered by year if ml_section_start > 0 and ml_section_start < min(basics_pos, nn_pos): # If next category marker exists, they should be before it @@ -313,7 +322,7 @@ def test_make_markdown_tables_by_category_link_column_excluded(self, temp_csv): doi="10.0000/test.2023", link="https://special-link.com/unique", category="Testing", - file_path=temp_csv + file_path=temp_csv, ) result = markdown.make_markdown_tables_by_category(temp_csv) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 676e3fd..95a84a3 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -1,13 +1,14 @@ """Tests for metadata extraction module.""" import sys -import pytest import tempfile from pathlib import Path -from unittest.mock import patch, MagicMock +from unittest.mock import patch + +import pytest # Add parent directory to path -sys.path.insert(0, '../pkg') +sys.path.insert(0, "../pkg") from mybib import metadata @@ -28,19 +29,19 @@ def temp_file(temp_dir): class TestSourceDetection: """Test source detection functions.""" - + def test_arxiv_url_detection(self): """Test arXiv URL detection.""" assert metadata._is_arxiv_url("https://arxiv.org/abs/2301.00001") assert metadata._is_arxiv_url("http://arxiv.org/pdf/2301.00001.pdf") assert not metadata._is_arxiv_url("https://doi.org/10.1234/example") - + def test_doi_url_detection(self): """Test DOI URL detection.""" assert metadata._is_doi_url("https://doi.org/10.1145/1234567") assert metadata._is_doi_url("http://doi.org/10.1234/example") assert not metadata._is_doi_url("https://arxiv.org/abs/2301.00001") - + def test_doi_pattern_detection(self): """Test DOI pattern detection.""" assert metadata._is_doi_pattern("10.1145/1234567") @@ -50,12 +51,18 @@ def test_doi_pattern_detection(self): class TestIdExtraction: """Test ID extraction functions.""" - + def test_arxiv_id_extraction(self): """Test arXiv ID extraction from URLs.""" - assert metadata._extract_arxiv_id("https://arxiv.org/abs/2301.00001") == "2301.00001" - assert metadata._extract_arxiv_id("https://arxiv.org/pdf/2301.00001v2.pdf") == "2301.00001" - + assert ( + metadata._extract_arxiv_id("https://arxiv.org/abs/2301.00001") + == "2301.00001" + ) + assert ( + metadata._extract_arxiv_id("https://arxiv.org/pdf/2301.00001v2.pdf") + == "2301.00001" + ) + def test_arxiv_id_extraction_invalid(self): """Test arXiv ID extraction with invalid URL.""" with pytest.raises(SystemExit): @@ -64,11 +71,14 @@ def test_arxiv_id_extraction_invalid(self): class TestDoiNormalization: """Test DOI normalization.""" - + def test_normalize_doi_from_url(self): """Test DOI extraction from URL.""" - assert metadata._normalize_doi("https://doi.org/10.1145/1234567") == "10.1145/1234567" - + assert ( + metadata._normalize_doi("https://doi.org/10.1145/1234567") + == "10.1145/1234567" + ) + def test_normalize_doi_already_normalized(self): """Test DOI that's already normalized.""" assert metadata._normalize_doi("10.1145/1234567") == "10.1145/1234567" @@ -76,51 +86,51 @@ def test_normalize_doi_already_normalized(self): class TestHtmlMetaExtraction: """Test HTML meta tag extraction.""" - + def test_extract_html_meta_with_property(self): """Test extraction of meta tag with property attribute.""" html = '' result = metadata._extract_html_meta(html, ["og:title"]) assert result == "Test Title" - + def test_extract_html_meta_with_name(self): """Test extraction of meta tag with name attribute.""" html = '' result = metadata._extract_html_meta(html, ["author"]) assert result == "John Doe" - + def test_extract_html_meta_not_found(self): """Test extraction when meta tag is not found.""" - html = '' + html = "" result = metadata._extract_html_meta(html, ["og:title", "author"]) assert result is None class TestFunctionRouting: """Test that fetch_metadata routes to correct handler.""" - - @patch('mybib.metadata._fetch_arxiv_metadata') + + @patch("mybib.metadata._fetch_arxiv_metadata") def test_route_to_arxiv(self, mock_arxiv): """Test routing to arXiv handler.""" mock_arxiv.return_value = {"title": "arxiv paper"} metadata.fetch_metadata("https://arxiv.org/abs/2301.00001") mock_arxiv.assert_called_once() - - @patch('mybib.metadata._fetch_crossref_metadata') + + @patch("mybib.metadata._fetch_crossref_metadata") def test_route_to_doi_url(self, mock_doi): """Test routing to DOI URL handler.""" mock_doi.return_value = {"title": "doi paper"} metadata.fetch_metadata("https://doi.org/10.1145/1234567") mock_doi.assert_called_once() - - @patch('mybib.metadata._fetch_crossref_metadata') + + @patch("mybib.metadata._fetch_crossref_metadata") def test_route_to_doi_pattern(self, mock_doi): """Test routing to DOI pattern handler.""" mock_doi.return_value = {"title": "doi paper"} metadata.fetch_metadata("10.1145/1234567") mock_doi.assert_called_once() - - @patch('mybib.metadata._fetch_generic_metadata') + + @patch("mybib.metadata._fetch_generic_metadata") def test_route_to_generic(self, mock_generic): """Test routing to generic URL handler.""" mock_generic.return_value = {"title": "generic page"} @@ -130,8 +140,8 @@ def test_route_to_generic(self, mock_generic): class TestReturnFormat: """Test that all handlers return standardized format.""" - - @patch('mybib.metadata._fetch_generic_metadata') + + @patch("mybib.metadata._fetch_generic_metadata") def test_return_format_has_required_fields(self, mock_generic): """Test that returned dict has all required fields.""" mock_generic.return_value = { @@ -140,10 +150,10 @@ def test_return_format_has_required_fields(self, mock_generic): "journal": "Journal", "year": 2023, "doi": "10.1234/test", - "link": "https://example.com" + "link": "https://example.com", } result = metadata.fetch_metadata("https://example.com") - + required_fields = {"title", "authors", "journal", "year", "doi", "link"} assert required_fields.issubset(result.keys()) @@ -154,7 +164,7 @@ def demo_arxiv_metadata(): print("\n=== Demo: arXiv Metadata ===") url = "https://arxiv.org/abs/2301.00001" print(f"URL: {url}") - print(f"Expected: Metadata for arXiv paper with automatic source detection") + print("Expected: Metadata for arXiv paper with automatic source detection") # Note: Actual API call commented out to avoid network dependency in tests # result = metadata.fetch_metadata(url) # print(f"Result: {result}") @@ -165,7 +175,7 @@ def demo_doi_metadata(): print("\n=== Demo: DOI Metadata ===") url = "https://doi.org/10.1145/1234567" print(f"URL: {url}") - print(f"Expected: Metadata from Crossref API for DOI") + print("Expected: Metadata from Crossref API for DOI") # Note: Actual API call commented out to avoid network dependency in tests # result = metadata.fetch_metadata(url) # print(f"Result: {result}") @@ -176,7 +186,7 @@ def demo_generic_metadata(): print("\n=== Demo: Generic URL Metadata ===") url = "https://example.com/article" print(f"URL: {url}") - print(f"Expected: Metadata extracted from HTML meta tags") + print("Expected: Metadata extracted from HTML meta tags") # Note: Actual API call commented out to avoid network dependency in tests # result = metadata.fetch_metadata(url) # print(f"Result: {result}") diff --git a/tests/test_storage.py b/tests/test_storage.py index ccfed33..9e89376 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -1,11 +1,11 @@ """Tests for storage module.""" -import pytest -import pandas as pd import tempfile from pathlib import Path from unittest.mock import patch -import sys + +import pandas as pd +import pytest from pkg.mybib import storage @@ -47,7 +47,7 @@ def test_add_reference_to_empty_file(self, temp_csv, sample_references): link=sample_references["link"], category=sample_references["category"], arxiv_id=sample_references["arxiv_id"], - file_path=temp_csv + file_path=temp_csv, ) df = pd.read_csv(temp_csv) @@ -67,7 +67,7 @@ def test_add_multiple_references(self, temp_csv, sample_references): link=sample_references["link"], category=sample_references["category"], arxiv_id=sample_references["arxiv_id"], - file_path=temp_csv + file_path=temp_csv, ) # Add second reference @@ -79,7 +79,7 @@ def test_add_multiple_references(self, temp_csv, sample_references): doi="10.5678/example.2024", link="https://example.com/paper2", category="Deep Learning", - file_path=temp_csv + file_path=temp_csv, ) df = pd.read_csv(temp_csv) @@ -98,7 +98,7 @@ def test_add_reference_preserves_headers(self, temp_csv, sample_references): link=sample_references["link"], category=sample_references["category"], arxiv_id=sample_references["arxiv_id"], - file_path=temp_csv + file_path=temp_csv, ) storage.add_reference( @@ -109,12 +109,21 @@ def test_add_reference_preserves_headers(self, temp_csv, sample_references): doi="10.9999/another.2022", link="https://example.com/paper3", category="Research", - file_path=temp_csv + file_path=temp_csv, ) df = pd.read_csv(temp_csv) # Check headers are correct - expected_headers = ["Title", "Authors", "Journal", "Year", "DOI", "Link", "Category", "ArxivID"] + expected_headers = [ + "Title", + "Authors", + "Journal", + "Year", + "DOI", + "Link", + "Category", + "ArxivID", + ] assert list(df.columns) == expected_headers @@ -131,11 +140,11 @@ def test_duplicate_doi_detection(self, temp_csv, sample_references): doi=sample_references["doi"], link=sample_references["link"], category=sample_references["category"], - file_path=temp_csv + file_path=temp_csv, ) # Try to add duplicate with different title - with patch('sys.exit') as mock_exit: + with patch("sys.exit") as mock_exit: storage.add_reference( title="Different Title", authors="Different Authors", @@ -144,7 +153,7 @@ def test_duplicate_doi_detection(self, temp_csv, sample_references): doi=sample_references["doi"], # Same DOI link="https://different.com", category="Different Category", - file_path=temp_csv + file_path=temp_csv, ) mock_exit.assert_called_once_with(0) @@ -158,11 +167,11 @@ def test_duplicate_detection_case_insensitive(self, temp_csv, sample_references) doi=sample_references["doi"], link=sample_references["link"], category=sample_references["category"], - file_path=temp_csv + file_path=temp_csv, ) # Try to add duplicate with different case DOI - with patch('sys.exit') as mock_exit: + with patch("sys.exit") as mock_exit: storage.add_reference( title="Another Title", authors="Another Author", @@ -171,7 +180,7 @@ def test_duplicate_detection_case_insensitive(self, temp_csv, sample_references) doi=sample_references["doi"].upper(), # Different case link="https://another.com", category="Another Category", - file_path=temp_csv + file_path=temp_csv, ) mock_exit.assert_called_once_with(0) @@ -185,11 +194,11 @@ def test_duplicate_detection_with_whitespace(self, temp_csv, sample_references): doi=sample_references["doi"], link=sample_references["link"], category=sample_references["category"], - file_path=temp_csv + file_path=temp_csv, ) # Try to add duplicate with extra whitespace - with patch('sys.exit') as mock_exit: + with patch("sys.exit") as mock_exit: storage.add_reference( title="Yet Another Title", authors="Yet Another Author", @@ -198,7 +207,7 @@ def test_duplicate_detection_with_whitespace(self, temp_csv, sample_references): doi=f" {sample_references['doi']} ", # Whitespace added link="https://yet-another.com", category="Yet Another Category", - file_path=temp_csv + file_path=temp_csv, ) mock_exit.assert_called_once_with(0) @@ -212,7 +221,7 @@ def test_no_duplicate_with_different_doi(self, temp_csv, sample_references): doi=sample_references["doi"], link=sample_references["link"], category=sample_references["category"], - file_path=temp_csv + file_path=temp_csv, ) # Add reference with different DOI @@ -224,7 +233,7 @@ def test_no_duplicate_with_different_doi(self, temp_csv, sample_references): doi="10.1111/different.2024", # Different DOI link="https://another.com", category="Different Category", - file_path=temp_csv + file_path=temp_csv, ) df = pd.read_csv(temp_csv) @@ -244,7 +253,7 @@ def test_load_references_from_existing_file(self, temp_csv, sample_references): doi=sample_references["doi"], link=sample_references["link"], category=sample_references["category"], - file_path=temp_csv + file_path=temp_csv, ) df = storage.load_references(temp_csv) @@ -257,9 +266,18 @@ def test_load_references_creates_empty_file(self, temp_csv): Path(temp_csv).unlink(missing_ok=True) df = storage.load_references(temp_csv) - + assert df.empty - assert list(df.columns) == ["Title", "Authors", "Journal", "Year", "DOI", "Link", "Category", "ArxivID"] + assert list(df.columns) == [ + "Title", + "Authors", + "Journal", + "Year", + "DOI", + "Link", + "Category", + "ArxivID", + ] assert Path(temp_csv).exists() def test_load_references_preserves_data(self, temp_csv, sample_references): @@ -271,10 +289,10 @@ def test_load_references_preserves_data(self, temp_csv, sample_references): authors=f"Author {i}", journal=f"Journal {i}", year=2020 + i, - doi=f"10.{i}/example.{2020+i}", + doi=f"10.{i}/example.{2020 + i}", link=f"https://example.com/paper{i}", category=f"Category {i}", - file_path=temp_csv + file_path=temp_csv, ) df = storage.load_references(temp_csv) @@ -296,7 +314,7 @@ def test_add_reference_with_scholar_id_fallback(self, temp_csv): link="https://scholar.com/paper", category="Testing", scholar_id="scholar_id_12345", - file_path=temp_csv + file_path=temp_csv, ) df = pd.read_csv(temp_csv) @@ -314,7 +332,7 @@ def test_add_reference_prefers_doi_over_scholar_id(self, temp_csv): link="https://example.com/paper", category="Testing", scholar_id="scholar_id_67890", - file_path=temp_csv + file_path=temp_csv, ) df = pd.read_csv(temp_csv) @@ -336,7 +354,7 @@ def test_add_reference_with_arxiv_id(self, temp_csv): link="https://arxiv.org/abs/2301.00001", category="Machine Learning", arxiv_id="2301.00001", - file_path=temp_csv + file_path=temp_csv, ) df = pd.read_csv(temp_csv, dtype={"ArxivID": str}) @@ -353,7 +371,7 @@ def test_add_reference_without_arxiv_id(self, temp_csv): doi="10.1234/example.2023", link="https://example.com/paper", category="Research", - file_path=temp_csv + file_path=temp_csv, ) df = pd.read_csv(temp_csv, dtype={"ArxivID": str})