diff --git a/litresearch.toml.example b/litresearch.toml.example index a4a1501..98340dc 100644 --- a/litresearch.toml.example +++ b/litresearch.toml.example @@ -95,6 +95,19 @@ abstract_fallback = true # PDFs should be named: {paper_id}.pdf or {doi}.pdf (slashes replaced with underscores) # inject_pdf_dir = "/path/to/pdfs" +# ============================================================================ +# Query Expansion (Optional) +# ============================================================================ + +# Enable iterative query expansion after initial enrichment +enable_query_expansion = true + +# Maximum number of expansion queries to generate +max_expansion_queries = 2 + +# Number of top candidates (by citation count) to sample for expansion analysis +expansion_candidate_sample = 30 + # ============================================================================ # Citation Expansion (Optional) # ============================================================================ @@ -105,6 +118,12 @@ expand_citations = false # Minimum number of cross-references required to include a paper min_cross_refs = 3 +# Enable foundational paper detection (papers cited by many candidates) +enable_foundational_detection = true + +# Number of foundational papers to identify +foundational_papers_count = 5 + # ============================================================================ # Zotero Export (Optional) # ============================================================================ diff --git a/src/litresearch/cli.py b/src/litresearch/cli.py index 56e886c..896503a 100644 --- a/src/litresearch/cli.py +++ b/src/litresearch/cli.py @@ -20,7 +20,6 @@ def _build_settings( top_n: int | None = None, output_dir: str | None = None, threshold: int | None = None, - inject_pdf_dir: str | None = None, ) -> Settings: """Load settings and apply CLI overrides.""" overrides = { @@ -30,7 +29,6 @@ def _build_settings( "top_n": top_n, "output_dir": output_dir, "screening_threshold": threshold, - "inject_pdf_dir": inject_pdf_dir, }.items() if value is not None } @@ -51,9 +49,6 @@ def config() -> None: console.print(f"screening_threshold={settings.screening_threshold}") console.print(f"top_n={settings.top_n}") console.print(f"max_results_per_query={settings.max_results_per_query}") - console.print(f"pdf_first_pages={settings.pdf_first_pages}") - console.print(f"pdf_last_pages={settings.pdf_last_pages}") - console.print(f"inject_pdf_dir={settings.inject_pdf_dir}") console.print(f"output_dir={settings.output_dir}") console.print(f"s2_api_key_configured={bool(settings.s2_api_key)}") console.print(f"llm_api_key_configured={settings.has_llm_api_key}") @@ -73,19 +68,6 @@ def run( bool, typer.Option("--overwrite", help="Overwrite existing output directory."), ] = False, - inject_pdfs: Annotated[ - Path | None, - typer.Option( - "--inject-pdfs", help="Directory containing PDFs to inject by paper_id or DOI" - ), - ] = None, - stop_after_screening: Annotated[ - bool, - typer.Option( - "--stop-after-screening", - help="Stop after screening to review papers needing PDFs before analysis", - ), - ] = False, ) -> None: """Run the literature research pipeline.""" settings = _build_settings( @@ -93,15 +75,12 @@ def run( top_n=top_n, output_dir=output_dir, threshold=threshold, - inject_pdf_dir=str(inject_pdfs) if inject_pdfs is not None else None, ) state = run_pipeline( questions, settings, overwrite=overwrite, - inject_pdfs_dir=inject_pdfs, - stop_after_screening=stop_after_screening, ) if state.screened_papers_completed and not state.analyses: console.print( @@ -121,12 +100,6 @@ def resume( int | None, typer.Option("--threshold", help="Override the screening threshold."), ] = None, - inject_pdfs: Annotated[ - Path | None, - typer.Option( - "--inject-pdfs", help="Directory containing PDFs to inject by paper_id or DOI" - ), - ] = None, ) -> None: """Resume the literature research pipeline from saved state.""" settings = _build_settings( @@ -134,10 +107,9 @@ def resume( top_n=top_n, output_dir=output_dir, threshold=threshold, - inject_pdf_dir=str(inject_pdfs) if inject_pdfs is not None else None, ) - state = run_pipeline([], settings, resume_path=Path(state_file), inject_pdfs_dir=inject_pdfs) + state = run_pipeline([], settings, resume_path=Path(state_file)) console.print(f"[green]Resume complete.[/green] Output: {state.output_dir}") diff --git a/src/litresearch/config.py b/src/litresearch/config.py index 7a47bbc..67a9945 100644 --- a/src/litresearch/config.py +++ b/src/litresearch/config.py @@ -47,7 +47,7 @@ def settings_customise_sources( max_retries: int = 3 retry_base_delay: float = 1.0 llm_timeout: int = 120 - default_model: str = "openai/gpt-4o-mini" + default_model: str = "openai/gpt-5.4-mini" screening_selection_mode: Literal["top_percent", "threshold", "top_k"] = "top_percent" screening_top_percent: float = 0.3 # 0-1; used when screening_selection_mode=top_percent screening_top_k: int | None = None # used when screening_selection_mode=top_k @@ -59,9 +59,16 @@ def settings_customise_sources( discovery_sources: list[str] = ["s2"] openalex_email: str | None = None + # Query expansion + enable_query_expansion: bool = True + max_expansion_queries: int = 2 + expansion_candidate_sample: int = 30 + # Citation expansion expand_citations: bool = False min_cross_refs: int = 3 + enable_foundational_detection: bool = True + foundational_papers_count: int = 5 # Zotero export zotero_library_id: str | None = None @@ -71,12 +78,6 @@ def settings_customise_sources( zotero_tag: str | None = None zotero_export: bool = False - pdf_first_pages: int = 4 - pdf_last_pages: int = 2 - pdf_extraction_mode: Literal["budget", "pages"] = "budget" - pdf_token_budget: int = 4000 - abstract_fallback: bool = True - inject_pdf_dir: str | None = None output_dir: str = "output" @computed_field diff --git a/src/litresearch/exporters/zotero.py b/src/litresearch/exporters/zotero.py index a382547..22e2884 100644 --- a/src/litresearch/exporters/zotero.py +++ b/src/litresearch/exporters/zotero.py @@ -1,6 +1,5 @@ """Zotero export integration.""" -from pathlib import Path from typing import Any from rich.console import Console @@ -73,9 +72,6 @@ def export_to_zotero( if paper.doi: item["DOI"] = paper.doi - if paper.open_access_pdf_url: - item["url"] = paper.open_access_pdf_url - if collection_key: item["collections"] = [collection_key] @@ -90,17 +86,6 @@ def create_item(payload: dict[str, Any] = item) -> dict[str, Any]: if result.get("successful"): successful += 1 - - if paper.pdf_path: - try: - pdf_full_path = Path(paper.pdf_path) - if pdf_full_path.exists(): - item_key = list(result["successful"].values())[0]["key"] - zot.attachment_simple([str(pdf_full_path)], item_key) - except Exception as exc: # noqa: BLE001 - console.print( - f"[yellow]Failed to attach PDF for {paper.title}:[/yellow] {exc}" - ) else: failed.append(f"{paper.title}: {result.get('failed', 'Unknown error')}") diff --git a/src/litresearch/models.py b/src/litresearch/models.py index f4192a1..964a2a6 100644 --- a/src/litresearch/models.py +++ b/src/litresearch/models.py @@ -27,7 +27,6 @@ class S2PaperLike(Protocol): citationCount: int | None venue: str | None externalIds: dict[str, str] | None - openAccessPdf: dict[str, str] | None citationStyles: dict[str, str] | None @@ -57,26 +56,14 @@ class Paper(BaseModel): citation_count: int = 0 venue: str | None = None doi: str | None = None - open_access_pdf_url: str | None = None bibtex: str | None = None source: Literal["s2", "openalex", "both", "citation_expansion"] = "s2" - pdf_path: str | None = None - pdf_status: Literal["not_attempted", "downloaded", "unavailable", "user_provided"] = ( - "not_attempted" - ) - data_completeness: Literal["full", "abstract_only", "metadata_only"] = "full" - - @property - def pdf_downloaded(self) -> bool: - """Backwards-compatible indicator for downloaded or provided PDFs.""" - return self.pdf_status in {"downloaded", "user_provided"} or self.pdf_path is not None @classmethod def from_s2(cls, s2_paper: S2PaperLike) -> "Paper": """Create a normalized paper model from a Semantic Scholar paper object.""" external_ids = s2_paper.externalIds or {} - open_access_pdf = s2_paper.openAccessPdf or {} citation_styles = s2_paper.citationStyles or {} authors = s2_paper.authors or [] @@ -90,7 +77,6 @@ def from_s2(cls, s2_paper: S2PaperLike) -> "Paper": citation_count=s2_paper.citationCount or 0, venue=html.unescape(s2_paper.venue) if s2_paper.venue else None, doi=external_ids.get("DOI"), - open_access_pdf_url=open_access_pdf.get("url"), bibtex=citation_styles.get("bibtex"), source="s2", ) @@ -141,13 +127,11 @@ class RunMetrics(BaseModel): total_analyzed: int = 0 total_exported: int = 0 citation_expanded: int = 0 + expansion_queries_generated: int = 0 + foundational_papers: int = 0 sources: dict[str, int] = Field(default_factory=dict) - pdfs_downloaded: int = 0 - pdfs_user_provided: int = 0 - pdfs_unavailable: int = 0 - class PipelineState(BaseModel): """Serializable pipeline state for fresh runs and resume.""" @@ -159,7 +143,9 @@ class PipelineState(BaseModel): screening_results: list[ScreeningResult] = Field(default_factory=list) analyses: list[AnalysisResult] = Field(default_factory=list) ranked_paper_ids: list[str] = Field(default_factory=list) + foundational_paper_ids: list[str] = Field(default_factory=list) screened_papers_completed: bool = False + query_expansion_run: bool = False current_stage: str output_dir: str created_at: str diff --git a/src/litresearch/pdf.py b/src/litresearch/pdf.py deleted file mode 100644 index 6ab1641..0000000 --- a/src/litresearch/pdf.py +++ /dev/null @@ -1,90 +0,0 @@ -"""PDF download and extraction helpers.""" - -from io import BytesIO - -import httpx -from pypdf import PdfReader -from rich.console import Console - -console = Console() - - -def extract_text( - pdf_bytes: bytes, - token_budget: int = 4000, - keywords: list[str] | None = None, -) -> str | None: - """Extract text from PDF with token budget and keyword scoring. - - Args: - pdf_bytes: Raw PDF bytes - token_budget: Maximum tokens to extract (approx 4 chars per token) - keywords: List of keywords to prioritize when selecting pages - - Returns: - Extracted text or None if extraction fails - """ - try: - reader = PdfReader(BytesIO(pdf_bytes)) - except Exception: # noqa: BLE001 - return None - - page_count = len(reader.pages) - if page_count == 0: - return None - - pages: list[tuple[int, str]] = [] - for i in range(page_count): - try: - text = reader.pages[i].extract_text() or "" - if text.strip(): - pages.append((i, text.strip())) - except Exception: # noqa: BLE001 - continue - - if not pages: - return None - - if keywords and len(pages) > 1: - keyword_set = {keyword.lower() for keyword in keywords} - scored_pages: list[tuple[int, int, str]] = [] - for idx, text in pages: - text_lower = text.lower() - score = sum(1 for keyword in keyword_set if keyword in text_lower) - if idx == 0: - score += 1 - scored_pages.append((score, idx, text)) - - scored_pages.sort(key=lambda item: (-item[0], item[1])) - pages = [(idx, text) for _, idx, text in scored_pages] - - max_chars = token_budget * 4 - parts: list[str] = [] - total_chars = 0 - - for idx, text in pages: - page_header = f"\n--- Page {idx + 1} ---\n" - chunk = page_header + text - - if total_chars + len(chunk) > max_chars and parts: - break - - parts.append(chunk) - total_chars += len(chunk) - - if total_chars >= max_chars: - break - - return "\n".join(parts).strip() if parts else None - - -def download_pdf(url: str) -> bytes | None: - """Download a PDF and return its bytes on success.""" - try: - response = httpx.get(url, follow_redirects=True, timeout=30.0) - response.raise_for_status() - except Exception as exc: # noqa: BLE001 - console.print(f"[yellow]Failed to download PDF:[/yellow] {url} ({exc})") - return None - - return response.content diff --git a/src/litresearch/pipeline.py b/src/litresearch/pipeline.py index 622c1aa..37a3aba 100644 --- a/src/litresearch/pipeline.py +++ b/src/litresearch/pipeline.py @@ -15,10 +15,10 @@ discovery, enrichment, export, + query_expansion, query_gen, ranking, ) -from litresearch.stages.analysis import PauseForPDFsError console = Console() @@ -72,16 +72,9 @@ def _populate_aggregate_metrics(metrics: RunMetrics, state: PipelineState) -> Ru "total_analyzed": len(state.analyses), "total_exported": len(state.ranked_paper_ids), "citation_expanded": source_counts.get("citation_expansion", 0), + "expansion_queries_generated": metrics.expansion_queries_generated, + "foundational_papers": len(state.foundational_paper_ids), "sources": source_counts, - "pdfs_downloaded": sum( - 1 for paper in state.candidates if paper.pdf_status == "downloaded" - ), - "pdfs_user_provided": sum( - 1 for paper in state.candidates if paper.pdf_status == "user_provided" - ), - "pdfs_unavailable": sum( - 1 for paper in state.candidates if paper.pdf_status == "unavailable" - ), } ) @@ -91,8 +84,6 @@ def run_pipeline( settings: Settings, resume_path: Path | None = None, overwrite: bool = False, - inject_pdfs_dir: Path | None = None, - stop_after_screening: bool = False, ) -> PipelineState: """Run the configured pipeline from scratch or from a saved state.""" start_time = time.perf_counter() @@ -137,10 +128,6 @@ def run_pipeline( else: metrics = RunMetrics(run_id=f"run-{uuid.uuid4().hex[:12]}", started_at=started_at) - effective_inject_pdfs_dir = inject_pdfs_dir - if effective_inject_pdfs_dir is None and settings.inject_pdf_dir: - effective_inject_pdfs_dir = Path(settings.inject_pdf_dir) - for stage_name in STAGE_ORDER[start_index:]: console.print(f"[bold blue]Running stage:[/bold blue] {stage_name}") started = time.perf_counter() @@ -151,14 +138,7 @@ def run_pipeline( ) stage_runner = STAGES[stage_name] try: - if stage_name == "analysis": - state = stage_runner( - state, - settings, - inject_pdfs_dir=effective_inject_pdfs_dir, - stop_after_screening=stop_after_screening, - ) - elif stage_name == "export": + if stage_name == "export": state = stage_runner(state, settings, run_metrics=metrics) else: state = stage_runner(state, settings) @@ -174,21 +154,6 @@ def run_pipeline( state.save(state_path) metrics = _populate_aggregate_metrics(metrics, state) _write_metrics(output_dir, metrics) - except PauseForPDFsError as pause_exc: - # Not a failure - user chose to pause for manual PDF injection - # Preserve screening results and mark screening as completed - checkpoint_state = state.model_copy( - update={ - "screening_results": pause_exc.screening_results, - "screened_papers_completed": True, - "updated_at": _timestamp(), - } - ) - checkpoint_state.save(state_path) - console.print("\n[bold yellow]Pipeline paused at screening checkpoint.[/bold yellow]") - console.print(f"State saved to: {state_path}") - state = checkpoint_state - return state except Exception as exc: # noqa: BLE001 stage_metrics = stage_metrics.model_copy( update={ @@ -208,6 +173,89 @@ def run_pipeline( elapsed = time.perf_counter() - started console.print(f"[green]Completed[/green] {stage_name} in {elapsed:.2f}s") + # --- Post-enrichment: Iterative Query Expansion --- + if ( + stage_name == "enrichment" + and settings.enable_query_expansion + and not state.query_expansion_run + ): + queries_before = len(state.search_queries) + + console.print("[bold blue]Running stage:[/bold blue] query_expansion") + exp_started = time.perf_counter() + exp_stage_metrics = StageMetrics( + name="query_expansion", + started_at=_timestamp(), + input_count=len(state.candidates), + ) + try: + state = query_expansion.run(state, settings) + queries_generated = len(state.search_queries) - queries_before + exp_stage_metrics = exp_stage_metrics.model_copy( + update={ + "completed_at": _timestamp(), + "duration_seconds": time.perf_counter() - exp_started, + "output_count": queries_generated, + } + ) + metrics = metrics.model_copy( + update={ + "stages": [*metrics.stages, exp_stage_metrics], + "expansion_queries_generated": queries_generated, + } + ) + state = state.model_copy(update={"updated_at": _timestamp()}) + state.save(state_path) + metrics = _populate_aggregate_metrics(metrics, state) + _write_metrics(output_dir, metrics) + + if queries_generated > 0: + console.print( + f"[green]Generated {queries_generated} expansion queries," + f" re-running discovery and enrichment...[/green]" + ) + for sub_stage in ["discovery", "enrichment"]: + console.print( + f"[bold blue]Running stage (expansion):[/bold blue] {sub_stage}" + ) + sub_started = time.perf_counter() + sub_metrics = StageMetrics( + name=f"{sub_stage} (expansion)", + started_at=_timestamp(), + input_count=_stage_count(sub_stage, state), + ) + try: + state = STAGES[sub_stage](state, settings) + sub_metrics = sub_metrics.model_copy( + update={ + "completed_at": _timestamp(), + "duration_seconds": time.perf_counter() - sub_started, + "output_count": _stage_count(sub_stage, state), + } + ) + metrics = metrics.model_copy( + update={"stages": [*metrics.stages, sub_metrics]} + ) + state = state.model_copy(update={"updated_at": _timestamp()}) + state.save(state_path) + metrics = _populate_aggregate_metrics(metrics, state) + _write_metrics(output_dir, metrics) + sub_elapsed = time.perf_counter() - sub_started + console.print( + f"[green]Completed[/green] {sub_stage} (expansion)" + f" in {sub_elapsed:.2f}s" + ) + except Exception as sub_exc: # noqa: BLE001 + console.print( + f"[yellow]Expansion {sub_stage} failed" + f" ({sub_exc}), continuing...[/yellow]" + ) + break + else: + console.print("[dim]Query expansion generated no new queries[/dim]") + except Exception as exc: # noqa: BLE001 + console.print(f"[yellow]Query expansion failed ({exc}), continuing...[/yellow]") + metrics = _populate_aggregate_metrics(metrics, state) metrics = metrics.model_copy( update={ diff --git a/src/litresearch/prompts/analysis.md b/src/litresearch/prompts/analysis.md index 5a0a836..aa8a5c7 100644 --- a/src/litresearch/prompts/analysis.md +++ b/src/litresearch/prompts/analysis.md @@ -26,4 +26,4 @@ Your task is to produce a structured analysis of a paper using the provided meta The user will provide: - research questions - paper metadata -- extracted PDF text or a note that only abstract-level information is available +- extracted paper text or a note that only abstract-level information is available diff --git a/src/litresearch/prompts/query_expansion.md b/src/litresearch/prompts/query_expansion.md new file mode 100644 index 0000000..9789b74 --- /dev/null +++ b/src/litresearch/prompts/query_expansion.md @@ -0,0 +1,20 @@ +You are assisting with academic literature research. The user has already run an initial search and collected candidate papers. Your task is to identify underexplored directions, gaps, or promising angles that are not well covered by the current results. + +## Instructions +- Review the research questions and the abstracts of the initial candidate papers. +- Identify 1-2 specific sub-topics, methodological angles, or related concepts that are missing or underrepresented. +- Generate concise, targeted search queries that would help fill these gaps. +- Each query should be a standalone search string suitable for academic databases like Semantic Scholar. +- Return JSON only. + +## Output Format +{ + "queries": [ + {"query": "search query text", "facet": "short label for the angle"} + ] +} + +## Input +The user will provide: +- research questions +- a sample of initial candidate paper abstracts diff --git a/src/litresearch/prompts/screening_fallback.md b/src/litresearch/prompts/screening_fallback.md index 51278b9..9184ff0 100644 --- a/src/litresearch/prompts/screening_fallback.md +++ b/src/litresearch/prompts/screening_fallback.md @@ -6,8 +6,6 @@ This paper does NOT have an abstract available. You must screen based on availab - Title (always available) - Venue name (if available) - Citation count and publication year (metadata signals) -- Any available PDF text excerpt - ## Scoring Guidance (BE CONSERVATIVE - bias toward inclusion) - 80-100: title/venue strongly suggests direct relevance to the research questions - 60-79: title/venue suggests likely relevance @@ -30,4 +28,4 @@ This paper does NOT have an abstract available. You must screen based on availab ## Input The user will provide: - research questions -- all available signals (title, venue, authors, year, citation count, any PDF excerpt) +- all available signals (title, venue, authors, year, citation count) diff --git a/src/litresearch/sources/openalex.py b/src/litresearch/sources/openalex.py index 92cf24f..a7f3f55 100644 --- a/src/litresearch/sources/openalex.py +++ b/src/litresearch/sources/openalex.py @@ -45,7 +45,6 @@ def search_papers(self, query: str, limit: int = 20) -> list[dict[str, Any]]: params={ "search": query, "per_page": min(limit, 200), - "filter": "has_pdf:true", }, headers=self.headers, timeout=self.timeout, @@ -72,9 +71,6 @@ def work_to_paper(work: dict[str, Any]) -> Paper | None: if isinstance(doi, str) and doi.startswith("https://doi.org/"): doi = doi[16:] - oa_info = work.get("open_access", {}) or {} - oa_url = oa_info.get("oa_url") if oa_info.get("is_oa") else None - primary_location = work.get("primary_location", {}) or {} source = primary_location.get("source", {}) if primary_location else {} venue = source.get("display_name") if source else None @@ -94,7 +90,6 @@ def work_to_paper(work: dict[str, Any]) -> Paper | None: citation_count=work.get("cited_by_count", 0), venue=venue, doi=doi if isinstance(doi, str) else None, - open_access_pdf_url=oa_url, bibtex=None, source="openalex", ) diff --git a/src/litresearch/stages/analysis.py b/src/litresearch/stages/analysis.py index 02ca7f8..013e2b1 100644 --- a/src/litresearch/stages/analysis.py +++ b/src/litresearch/stages/analysis.py @@ -2,8 +2,6 @@ import math import re -from pathlib import Path -from typing import Literal from pydantic import BaseModel, Field from rich.console import Console @@ -12,9 +10,8 @@ from litresearch.config import Settings from litresearch.llm import LLMError, call_llm from litresearch.models import AnalysisResult, Paper, PipelineState, ScreeningResult -from litresearch.pdf import download_pdf, extract_text from litresearch.prompts import load_prompt -from litresearch.utils import parse_llm_json, safe_filename +from litresearch.utils import parse_llm_json console = Console() @@ -84,108 +81,30 @@ def _build_keywords(questions: list[str], title: str) -> list[str]: return unique -def _injected_pdf_path(paper: Paper, inject_pdfs_dir: Path | None) -> Path | None: - if inject_pdfs_dir is None: - return None - - inject_dir_resolved = inject_pdfs_dir.resolve() - - for candidate in [safe_filename(paper.paper_id)]: - candidate_path = (inject_dir_resolved / f"{candidate}.pdf").resolve() - if ( - inject_dir_resolved not in candidate_path.parents - and candidate_path != inject_dir_resolved - ): - continue - if candidate_path.exists(): - return candidate_path - - if paper.doi: - for candidate in [safe_filename(paper.doi)]: - candidate_path = (inject_dir_resolved / f"{candidate}.pdf").resolve() - if ( - inject_dir_resolved not in candidate_path.parents - and candidate_path != inject_dir_resolved - ): - continue - if candidate_path.exists(): - return candidate_path - - return None - - -def _screening_pdf_excerpt( - paper: Paper, - questions: list[str], - settings: Settings, - inject_pdfs_dir: Path | None, -) -> str | None: - keywords = _build_keywords(questions, paper.title) - - injected_path = _injected_pdf_path(paper, inject_pdfs_dir) - if injected_path is not None: - try: - pdf_bytes = injected_path.read_bytes() - except Exception: # noqa: BLE001 - pdf_bytes = None - if pdf_bytes is not None: - return extract_text( - pdf_bytes, token_budget=settings.pdf_token_budget, keywords=keywords - ) - - if paper.open_access_pdf_url: - pdf_bytes = download_pdf(paper.open_access_pdf_url) - if pdf_bytes is not None: - return extract_text( - pdf_bytes, token_budget=settings.pdf_token_budget, keywords=keywords - ) - - return None - - def _screen_paper( paper: Paper, questions: list[str], settings: Settings, prompt: str, - fallback_prompt: str, - pdf_excerpt: str | None = None, ) -> ScreeningResult | None: - if paper.abstract: - selected_prompt = prompt - user_content = "\n".join( - [ - "Research questions:", - *[f"- {question}" for question in questions], - "", - f"Title: {paper.title}", - f"Authors: {', '.join(paper.authors) or 'Unknown'}", - f"Year: {paper.year or 'Unknown'}", - f"Venue: {paper.venue or 'Unknown'}", - f"Abstract: {paper.abstract}", - ] - ) - else: - if not settings.abstract_fallback: - return None - selected_prompt = fallback_prompt - user_content = "\n".join( - [ - "Research questions:", - *[f"- {question}" for question in questions], - "", - "Available signals:", - f"- Title: {paper.title}", - f"- Authors: {', '.join(paper.authors) or 'Unknown'}", - f"- Year: {paper.year or 'Unknown'}", - f"- Venue: {paper.venue or 'Unknown'}", - f"- Citation count: {paper.citation_count}", - f"- PDF excerpt: {pdf_excerpt or 'Unavailable'}", - ] - ) + if not paper.abstract: + return None + + user_content = "\n".join( + [ + "Research questions:", + *[f"- {question}" for question in questions], + "", + f"Title: {paper.title}", + f"Authors: {', '.join(paper.authors) or 'Unknown'}", + f"Year: {paper.year or 'Unknown'}", + f"Venue: {paper.venue or 'Unknown'}", + f"Abstract: {paper.abstract}", + ] + ) try: - response = call_llm(settings, selected_prompt, user_content) + response = call_llm(settings, prompt, user_content) except LLMError as exc: console.print(f"[yellow]Screening failed:[/yellow] {paper.title} ({exc})") return None @@ -203,59 +122,8 @@ def _analyze_paper( questions: list[str], settings: Settings, prompt: str, - output_dir: str, - inject_pdfs_dir: Path | None, ) -> tuple[AnalysisResult | None, Paper]: - papers_dir = Path(output_dir) / "papers" - keywords = _build_keywords(questions, paper.title) - - pdf_text: str | None = None - pdf_path: str | None = None - pdf_status: Literal["downloaded", "unavailable", "user_provided"] = "unavailable" - - injected_path = _injected_pdf_path(paper, inject_pdfs_dir) - if injected_path is not None: - try: - pdf_bytes = injected_path.read_bytes() - except Exception as exc: # noqa: BLE001 - console.print(f"[yellow]Failed to read injected PDF:[/yellow] {injected_path} ({exc})") - pdf_bytes = None - if pdf_bytes is not None: - papers_dir.mkdir(parents=True, exist_ok=True) - target_path = papers_dir / f"{safe_filename(paper.paper_id)}.pdf" - target_path.write_bytes(pdf_bytes) - pdf_path = str(target_path) - pdf_status = "user_provided" - pdf_text = extract_text( - pdf_bytes, token_budget=settings.pdf_token_budget, keywords=keywords - ) - elif paper.open_access_pdf_url: - pdf_bytes = download_pdf(paper.open_access_pdf_url) - if pdf_bytes is not None: - papers_dir.mkdir(parents=True, exist_ok=True) - target_path = papers_dir / f"{safe_filename(paper.paper_id)}.pdf" - target_path.write_bytes(pdf_bytes) - pdf_path = str(target_path) - pdf_status = "downloaded" - pdf_text = extract_text( - pdf_bytes, token_budget=settings.pdf_token_budget, keywords=keywords - ) - - data_completeness: Literal["full", "abstract_only", "metadata_only"] = "metadata_only" - if paper.abstract and pdf_text: - data_completeness = "full" - elif paper.abstract: - data_completeness = "abstract_only" - - updated_paper = paper.model_copy( - update={ - "pdf_status": pdf_status, - "pdf_path": pdf_path, - "data_completeness": data_completeness, - } - ) - - extracted_text = pdf_text or "Only abstract-level information is available." + extracted_text = paper.abstract or "Only abstract-level information is available." user_content = "\n".join( [ "Research questions:", @@ -276,47 +144,23 @@ def _analyze_paper( response = call_llm(settings, prompt, user_content) except LLMError as exc: console.print(f"[yellow]Analysis failed:[/yellow] {paper.title} ({exc})") - return (None, updated_paper) + return (None, paper) payload = parse_llm_json(response, _AnalysisPayload, console=console) if payload is None: console.print(f"[yellow]JSON parse failed:[/yellow] {paper.title}") - return (None, updated_paper) + return (None, paper) - return (AnalysisResult(paper_id=paper.paper_id, **payload), updated_paper) - - -class PauseForPDFsError(Exception): - """Raised when pipeline should pause after screening for manual PDF injection.""" - - def __init__( - self, - papers_needing_pdfs: list[Paper], - state_path: str, - screening_results: list[ScreeningResult] | None = None, - ) -> None: - self.papers_needing_pdfs = papers_needing_pdfs - self.state_path = state_path - self.screening_results = screening_results or [] - super().__init__(f"{len(papers_needing_pdfs)} papers need manual PDFs") + return (AnalysisResult(paper_id=paper.paper_id, **payload), paper) def run( state: PipelineState, settings: Settings, - inject_pdfs_dir: Path | None = None, - stop_after_screening: bool = False, ) -> PipelineState: """Screen candidate papers and analyze the relevant ones.""" screening_prompt = load_prompt("screening") - screening_fallback_prompt = load_prompt("screening_fallback") analysis_prompt = load_prompt("analysis") - if inject_pdfs_dir is not None and not inject_pdfs_dir.exists(): - console.print( - "[yellow]Inject PDFs directory not found:[/yellow] " - f"{inject_pdfs_dir}. Continuing without injection." - ) - inject_pdfs_dir = None papers_by_id = {paper.paper_id: paper for paper in state.candidates} @@ -338,19 +182,11 @@ def run( screening_results: list[ScreeningResult] = [] screened_papers: list[tuple[Paper, ScreeningResult, int]] = [] for index, paper in enumerate(track(state.candidates, description="Screening papers")): - pdf_excerpt = None - if not paper.abstract: - pdf_excerpt = _screening_pdf_excerpt( - paper, state.questions, settings, inject_pdfs_dir - ) - screening_result = _screen_paper( paper, state.questions, settings, screening_prompt, - screening_fallback_prompt, - pdf_excerpt=pdf_excerpt, ) if screening_result is None: continue @@ -360,41 +196,6 @@ def run( passed_papers = _select_papers_for_analysis(screened_papers, settings) - # Check if we should stop after screening for manual PDF injection - if stop_after_screening: - papers_needing_pdfs = [ - paper - for paper in passed_papers - if paper.pdf_status in ("unavailable", "not_attempted") - and not paper.open_access_pdf_url - and not _injected_pdf_path(paper, inject_pdfs_dir) - ] - - if papers_needing_pdfs: - console.print( - "\n[bold yellow]" - f"{len(papers_needing_pdfs)} papers passed screening but need PDFs:[/bold yellow]" - ) - for i, paper in enumerate(papers_needing_pdfs[:10], 1): - console.print(f" {i}. {paper.title}") - console.print(f" ID: {paper.paper_id}") - if paper.doi: - console.print(f" DOI: {paper.doi}") - console.print() - - if len(papers_needing_pdfs) > 10: - console.print(f" ... and {len(papers_needing_pdfs) - 10} more\n") - - console.print("[bold]Options:[/bold]") - console.print(" 1. Source these PDFs manually, then resume with:") - console.print( - f" litresearch resume {state.output_dir}/state.json --inject-pdfs " - ) - console.print(" 2. Continue without PDFs (analysis will use abstracts only):") - console.print(f" litresearch resume {state.output_dir}/state.json\n") - - raise PauseForPDFsError(papers_needing_pdfs, state.output_dir, screening_results) - analyses: list[AnalysisResult] = [] for paper in track(passed_papers, description="Analyzing papers"): analysis_result, updated_paper = _analyze_paper( @@ -402,8 +203,6 @@ def run( state.questions, settings, analysis_prompt, - state.output_dir, - inject_pdfs_dir, ) papers_by_id[paper.paper_id] = updated_paper if analysis_result is not None: diff --git a/src/litresearch/stages/citation_expansion.py b/src/litresearch/stages/citation_expansion.py index b7a7659..185f7a1 100644 --- a/src/litresearch/stages/citation_expansion.py +++ b/src/litresearch/stages/citation_expansion.py @@ -38,7 +38,6 @@ def _paper_from_cited_data(cited: dict[str, Any]) -> Paper | None: authors.append(name) external_ids = _as_dict(cited.get("externalIds") or cited.get("external_ids") or {}) - open_access_pdf = _as_dict(cited.get("openAccessPdf") or cited.get("open_access_pdf") or {}) return Paper( paper_id=paper_id, @@ -49,7 +48,6 @@ def _paper_from_cited_data(cited: dict[str, Any]) -> Paper | None: citation_count=cited.get("citationCount") or cited.get("citation_count") or 0, venue=cited.get("venue"), doi=external_ids.get("DOI"), - open_access_pdf_url=open_access_pdf.get("url"), source="citation_expansion", ) @@ -75,6 +73,7 @@ def run(state: PipelineState, settings: Settings) -> PipelineState: top_paper_ids = state.ranked_paper_ids[: settings.top_n] existing_ids = {paper.paper_id for paper in state.candidates} + in_set_reference_counts: dict[str, int] = {} reference_counts: dict[str, int] = {} reference_papers: dict[str, Paper] = {} @@ -114,6 +113,7 @@ def fetch_references(*, current_paper_id: str = paper_id) -> Any: continue if ref_id in existing_ids: + in_set_reference_counts[ref_id] = in_set_reference_counts.get(ref_id, 0) + 1 continue reference_counts[ref_id] = reference_counts.get(ref_id, 0) + 1 @@ -138,9 +138,19 @@ def fetch_references(*, current_paper_id: str = paper_id) -> Any: console.print(f"[green]Found {len(expanded_papers)} frequently referenced works[/green]") + if settings.enable_foundational_detection: + foundational = sorted(in_set_reference_counts.items(), key=lambda x: -x[1])[ + : settings.foundational_papers_count + ] + foundational_paper_ids = [pid for pid, _ in foundational] + console.print(f"[green]Found {len(foundational_paper_ids)} foundational papers[/green]") + else: + foundational_paper_ids = [] + return state.model_copy( update={ "candidates": [*state.candidates, *expanded_papers], + "foundational_paper_ids": foundational_paper_ids, "current_stage": "citation_expansion", } ) diff --git a/src/litresearch/stages/discovery.py b/src/litresearch/stages/discovery.py index a775aff..a59820e 100644 --- a/src/litresearch/stages/discovery.py +++ b/src/litresearch/stages/discovery.py @@ -27,7 +27,6 @@ "year", "citationCount", "venue", - "openAccessPdf", "externalIds", "citationStyles", ] @@ -68,8 +67,6 @@ def _metadata_score(paper: Paper) -> int: score = 0 if paper.abstract: score += 4 - if paper.open_access_pdf_url: - score += 4 if paper.doi: score += 2 if paper.authors: @@ -116,7 +113,6 @@ def _merge_papers(existing: Paper, incoming: Paper) -> Paper: "citation_count": max(existing.citation_count, incoming.citation_count), "venue": primary.venue or secondary.venue, "doi": _normalize_doi(primary.doi) or _normalize_doi(secondary.doi), - "open_access_pdf_url": primary.open_access_pdf_url or secondary.open_access_pdf_url, "bibtex": primary.bibtex or secondary.bibtex, "corpus_id": primary.corpus_id if primary.corpus_id is not None diff --git a/src/litresearch/stages/enrichment.py b/src/litresearch/stages/enrichment.py index 2bb27ab..84de8a4 100644 --- a/src/litresearch/stages/enrichment.py +++ b/src/litresearch/stages/enrichment.py @@ -19,7 +19,6 @@ "year", "citationCount", "venue", - "openAccessPdf", "externalIds", "citationStyles", ] diff --git a/src/litresearch/stages/export.py b/src/litresearch/stages/export.py index 91310e6..5e95d6a 100644 --- a/src/litresearch/stages/export.py +++ b/src/litresearch/stages/export.py @@ -3,15 +3,12 @@ from pathlib import Path from rich.console import Console -from rich.progress import track from litresearch.config import Settings from litresearch.exporters.zotero import export_to_zotero from litresearch.llm import LLMError, call_llm from litresearch.models import AnalysisResult, Paper, PipelineState, RunMetrics -from litresearch.pdf import download_pdf from litresearch.prompts import load_prompt -from litresearch.utils import safe_filename console = Console() @@ -27,8 +24,6 @@ def _format_ris_entry(paper: Paper) -> str: lines.append(f"JO - {paper.venue}") if paper.doi: lines.append(f"DO - {paper.doi}") - if paper.open_access_pdf_url: - lines.append(f"UR - {paper.open_access_pdf_url}") lines.append("ER -") return "\n".join(lines) @@ -72,11 +67,9 @@ def run( settings: Settings, run_metrics: RunMetrics | None = None, ) -> PipelineState: - """Write reports, reference files, JSON data, and PDFs.""" + """Write reports, reference files, and JSON data.""" output_dir = Path(state.output_dir) output_dir.mkdir(parents=True, exist_ok=True) - papers_dir = output_dir / "papers" - papers_dir.mkdir(parents=True, exist_ok=True) papers_by_id = {paper.paper_id: paper for paper in state.candidates} analyses_by_id = {analysis.paper_id: analysis for analysis in state.analyses} @@ -91,13 +84,6 @@ def run( synthesis = _build_synthesis(state.questions, top_analyses, settings) - # Identify papers needing manual PDF sourcing - papers_needing_pdfs = [ - paper - for paper in top_papers - if paper.pdf_status in ("unavailable", "not_attempted") and paper.paper_id in analyses_by_id - ] - report_lines = [ "# Literature Research Report", "", @@ -128,46 +114,24 @@ def run( "", ] ) - report_lines.extend(["## Synthesis", "", synthesis]) - - # Add section for papers needing manual PDFs - if papers_needing_pdfs: - report_lines.extend( - [ - "", - "## Papers Requiring Manual PDF Sourcing", - "", - "The following high-relevance papers were analyzed using abstracts only. ", - "To improve analysis quality, you can manually source these PDFs and re-run:", - "", - ] - ) - for paper in papers_needing_pdfs: - analysis = analyses_by_id.get(paper.paper_id) + report_lines.extend(["", "## Foundational Papers", ""]) + if state.foundational_paper_ids: + for paper_id in state.foundational_paper_ids: + paper = papers_by_id.get(paper_id) + if paper is None: + continue report_lines.extend( [ f"### {paper.title}", - f"- **Paper ID**: `{paper.paper_id}`", - f"- **Authors**: {', '.join(paper.authors) or 'Unknown'}", - f"- **Year**: {paper.year or 'Unknown'}", - f"- **Venue**: {paper.venue or 'Unknown'}", - f"- **DOI**: {paper.doi or 'N/A'}", - f"- **Relevance Score**: {analysis.relevance_score if analysis else 'N/A'}", + f"- Authors: {', '.join(paper.authors) or 'Unknown'}", + f"- Year: {paper.year or 'Unknown'}", + f"- Venue: {paper.venue or 'Unknown'}", "", ] ) - report_lines.extend( - [ - "### How to Add These PDFs", - "", - "1. Source the PDFs via your institutional access, contacting authors,", - " or other means", - "2. Save them to a directory with filenames matching the Paper ID", - " (e.g., `abc123.pdf`)", - "3. Re-run with: `litresearch run 'your question' --inject-pdfs `", - "", - ] - ) + else: + report_lines.append("No foundational papers identified.") + report_lines.extend(["## Synthesis", "", synthesis]) (output_dir / "report.md").write_text("\n".join(report_lines).strip() + "\n", encoding="utf-8") @@ -206,29 +170,9 @@ def run( ris_content + ("\n" if ris_content else ""), encoding="utf-8" ) - updated_candidates = list(state.candidates) - for paper in track(top_papers, description="Downloading PDFs"): - if not paper.open_access_pdf_url: - continue - if paper.pdf_status in ("downloaded", "user_provided"): - continue - pdf_bytes = download_pdf(paper.open_access_pdf_url) - if pdf_bytes is None: - continue - target_path = papers_dir / f"{safe_filename(paper.paper_id)}.pdf" - target_path.write_bytes(pdf_bytes) - updated_paper = paper.model_copy( - update={ - "pdf_status": "downloaded", - "pdf_path": str(target_path), - } - ) - papers_by_id[paper.paper_id] = updated_paper - - updated_candidates = [papers_by_id[paper.paper_id] for paper in state.candidates] updated_state = state.model_copy( update={ - "candidates": updated_candidates, + "candidates": list(state.candidates), "current_stage": "export", } ) @@ -237,29 +181,6 @@ def run( encoding="utf-8", ) - # Write papers needing PDFs as JSON for programmatic access - if papers_needing_pdfs: - import json - - needing_pdfs_data = [] - for paper in papers_needing_pdfs: - analysis = analyses_by_id.get(paper.paper_id) - needing_pdfs_data.append( - { - "paper_id": paper.paper_id, - "title": paper.title, - "authors": paper.authors, - "year": paper.year, - "venue": paper.venue, - "doi": paper.doi, - "relevance_score": analysis.relevance_score if analysis else None, - } - ) - (output_dir / "papers_needing_pdfs.json").write_text( - json.dumps(needing_pdfs_data, indent=2) + "\n", - encoding="utf-8", - ) - if run_metrics is not None: (output_dir / "metrics.json").write_text( run_metrics.model_dump_json(indent=2) + "\n", diff --git a/src/litresearch/stages/query_expansion.py b/src/litresearch/stages/query_expansion.py new file mode 100644 index 0000000..2c57f63 --- /dev/null +++ b/src/litresearch/stages/query_expansion.py @@ -0,0 +1,87 @@ +"""Stage: iterative query expansion.""" + +from __future__ import annotations + +from rich.console import Console + +from litresearch.config import Settings +from litresearch.llm import LLMError, call_llm +from litresearch.models import PipelineState, SearchQuery +from litresearch.prompts import load_prompt +from litresearch.utils import parse_llm_json + +console = Console() + + +def _build_expansion_input(state: PipelineState, sample_size: int) -> str: + """Build the user prompt input from research questions and candidate abstracts.""" + parts: list[str] = [] + parts.append("Research questions:") + for question in state.questions: + parts.append(f"- {question}") + + # Sample top papers by citation count for the overview + sorted_candidates = sorted(state.candidates, key=lambda p: p.citation_count, reverse=True) + sample = sorted_candidates[:sample_size] + + parts.append(f"\nInitial candidate papers (sample of {len(sample)}, top by citations):") + for i, paper in enumerate(sample, start=1): + abstract = paper.abstract or "(no abstract)" + if len(abstract) > 500: + abstract = abstract[:500] + "..." + parts.append( + f"\n{i}. {paper.title} ({paper.year or 'n/a'}, {paper.venue or 'n/a'})\n {abstract}" + ) + + return "\n".join(parts) + + +def run(state: PipelineState, settings: Settings) -> PipelineState: + """Generate expansion search queries from initial candidate overview.""" + if state.query_expansion_run: + return state + + if not state.candidates: + console.print("[dim]No candidates available for query expansion[/dim]") + return state.model_copy(update={"query_expansion_run": True}) + + prompt = load_prompt("query_expansion") + user_content = _build_expansion_input(state, sample_size=settings.expansion_candidate_sample) + + try: + response = call_llm(settings, prompt, user_content) + except LLMError as exc: + console.print(f"[yellow]Query expansion LLM call failed, skipping:[/yellow] {exc}") + return state.model_copy(update={"query_expansion_run": True}) + + payload = parse_llm_json(response, console=console) + if payload is None: + console.print("[yellow]Query expansion returned invalid JSON, skipping[/yellow]") + return state.model_copy(update={"query_expansion_run": True}) + + queries_raw = payload.get("queries", []) + if not isinstance(queries_raw, list) or len(queries_raw) == 0: + console.print("[dim]Query expansion generated no queries[/dim]") + return state.model_copy(update={"query_expansion_run": True}) + + new_queries: list[SearchQuery] = [] + for item in queries_raw: + if not isinstance(item, dict): + continue + query_text = str(item.get("query", "")).strip() + facet_label = str(item.get("facet", "expansion")).strip() + if query_text: + new_queries.append(SearchQuery(query=query_text, facet=facet_label)) + + max_queries = settings.max_expansion_queries + if len(new_queries) > max_queries: + new_queries = new_queries[:max_queries] + + console.print(f"[green]Generated {len(new_queries)} expansion queries[/green]") + + return state.model_copy( + update={ + "search_queries": [*state.search_queries, *new_queries], + "query_expansion_run": True, + } + ) diff --git a/tests/unit/test_analysis.py b/tests/unit/test_analysis.py index 2cd0fd6..39c0f03 100644 --- a/tests/unit/test_analysis.py +++ b/tests/unit/test_analysis.py @@ -1,8 +1,6 @@ -import json - from litresearch.config import Settings from litresearch.models import AnalysisResult, Paper, PipelineState, ScreeningResult -from litresearch.stages.analysis import PauseForPDFsError, _injected_pdf_path, run +from litresearch.stages.analysis import run def _make_state(tmp_path, *, papers=None, screening_results=None, screened_papers_completed=False): @@ -18,113 +16,6 @@ def _make_state(tmp_path, *, papers=None, screening_results=None, screened_paper ) -def test_injected_pdf_path_rejects_path_traversal(tmp_path) -> None: - """Test that path traversal attempts are rejected in PDF injection.""" - inject_dir = tmp_path / "pdfs" - inject_dir.mkdir() - - # Create a safe PDF file - safe_paper = Paper( - paper_id="safe_paper", - title="Safe Paper", - abstract="Abstract", - authors=[], - year=2024, - citation_count=10, - source="s2", - ) - (inject_dir / "safe_paper.pdf").write_bytes(b"%PDF-1.0") - - # Test that safe path works - result = _injected_pdf_path(safe_paper, inject_dir) - assert result is not None - assert result.name == "safe_paper.pdf" - - # Test path traversal attempt with malicious paper_id - malicious_paper = Paper( - paper_id="../../../etc/passwd", - title="Malicious Paper", - abstract="Abstract", - authors=[], - year=2024, - citation_count=0, - source="s2", - ) - result = _injected_pdf_path(malicious_paper, inject_dir) - assert result is None - - # Test path traversal with null bytes - null_byte_paper = Paper( - paper_id="safe\x00../../../etc/passwd", - title="Null Byte Paper", - abstract="Abstract", - authors=[], - year=2024, - citation_count=0, - source="s2", - ) - result = _injected_pdf_path(null_byte_paper, inject_dir) - assert result is None - - -def test_analysis_saves_pdf_and_marks_candidate_downloaded(tmp_path, monkeypatch) -> None: - state = PipelineState( - questions=["q"], - candidates=[ - Paper( - paper_id="p1", - title="One", - abstract="abstract", - open_access_pdf_url="https://example.com/p1.pdf", - ) - ], - ranked_paper_ids=[], - current_stage="enrichment", - output_dir=str(tmp_path), - created_at="2026-03-09T16:00:00Z", - updated_at="2026-03-09T16:00:00Z", - ) - - import litresearch.stages.analysis as analysis_stage - - monkeypatch.setattr(analysis_stage, "load_prompt", lambda _name: "prompt") - monkeypatch.setattr( - analysis_stage, - "_screen_paper", - lambda paper, questions, settings, prompt, screening_fallback_prompt, pdf_excerpt=None: ( - ScreeningResult( - paper_id=paper.paper_id, - relevance_score=100, - rationale="fit", - ) - ), - ) - monkeypatch.setattr(analysis_stage, "download_pdf", lambda _url: b"%PDF-1.0") - monkeypatch.setattr(analysis_stage, "extract_text", lambda *_args, **_kwargs: "body") - monkeypatch.setattr( - analysis_stage, - "call_llm", - lambda settings, system_prompt, user_content: json.dumps( - { - "summary": "summary", - "key_findings": ["finding"], - "methodology": "experiment", - "relevance_score": 80, - "relevance_rationale": "fit", - } - ), - ) - - updated_state = run(state, Settings()) - - assert updated_state.candidates[0].pdf_status == "downloaded" - # pdf_path may be absolute or relative depending on implementation - assert updated_state.candidates[0].pdf_path - assert "p1.pdf" in updated_state.candidates[0].pdf_path - assert (tmp_path / "papers" / "p1.pdf").read_bytes() == b"%PDF-1.0" - assert len(updated_state.analyses) == 1 - - def test_analysis_skips_screening_when_already_completed(tmp_path, monkeypatch) -> None: """When screened_papers_completed is True and screening_results exist, screening is skipped.""" import litresearch.stages.analysis as analysis_stage @@ -156,7 +47,7 @@ def _fail_if_called(*args, **kwargs): monkeypatch.setattr( analysis_stage, "_analyze_paper", - lambda paper, questions, settings, prompt, output_dir, inject_pdfs_dir=None: ( + lambda paper, questions, settings, prompt: ( AnalysisResult( paper_id=paper.paper_id, summary="summary", @@ -168,9 +59,6 @@ def _fail_if_called(*args, **kwargs): paper, ), ) - import litresearch.models as models - - monkeypatch.setattr(models, "AnalysisResult", models.AnalysisResult, raising=False) settings = Settings(screening_selection_mode="top_k", screening_top_k=2) updated = run(state, settings) @@ -180,37 +68,40 @@ def _fail_if_called(*args, **kwargs): assert updated.screened_papers_completed is True -def test_pause_for_pdfs_carries_screening_results(tmp_path, monkeypatch) -> None: - """PauseForPDFsError includes screening_results for checkpoint state.""" - import litresearch.stages.analysis as analysis_stage - +def test_paper_without_abstract_is_skipped(tmp_path, monkeypatch) -> None: + """Papers without abstract are skipped (return None from screening).""" papers = [ - Paper( - paper_id="p1", - title="P1", - abstract="a", - citation_count=10, - pdf_status="not_attempted", - ), + Paper(paper_id="p1", title="P1", abstract=None, citation_count=10), + Paper(paper_id="p2", title="P2", abstract="a", citation_count=5), ] state = _make_state(tmp_path, papers=papers) - settings = Settings(screening_selection_mode="top_k", screening_top_k=1) + + import litresearch.stages.analysis as analysis_stage monkeypatch.setattr(analysis_stage, "load_prompt", lambda _name: "prompt") monkeypatch.setattr( analysis_stage, - "_screen_paper", - lambda paper, questions, settings, prompt, fb_prompt, pdf_excerpt=None: ScreeningResult( - paper_id=paper.paper_id, - relevance_score=90, - rationale="fit", + "call_llm", + lambda *a, **kw: '{"relevance_score": 90, "rationale": "fit"}', + ) + monkeypatch.setattr( + analysis_stage, + "_analyze_paper", + lambda paper, questions, settings, prompt: ( + AnalysisResult( + paper_id=paper.paper_id, + summary="summary", + key_findings=["finding"], + methodology="experiment", + relevance_score=90, + relevance_rationale="fit", + ), + paper, ), ) - try: - run(state, settings, stop_after_screening=True) - except PauseForPDFsError as exc: - assert len(exc.screening_results) == 1 - assert exc.screening_results[0].paper_id == "p1" - else: - raise AssertionError("Expected PauseForPDFsError") + settings = Settings(screening_selection_mode="top_k", screening_top_k=1) + updated = run(state, settings) + + assert len(updated.screening_results) == 1 + assert updated.screening_results[0].paper_id == "p2" diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index a92d8d0..05c3d3a 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -44,5 +44,4 @@ def test_resume_help_shows_expected_options() -> None: assert "final top-N cutoff" in output assert "output directory" in output assert "screening threshold" in output - assert "--inject-pdfs" in output - assert "Directory containing PDFs" in output + assert "Resume the literature research pipeline" in output diff --git a/tests/unit/test_export.py b/tests/unit/test_export.py index e31768c..70e9922 100644 --- a/tests/unit/test_export.py +++ b/tests/unit/test_export.py @@ -13,7 +13,6 @@ def minimal_state(tmp_path) -> PipelineState: Paper( paper_id="p1", title="One", - open_access_pdf_url="https://example.com/p1.pdf", ) ], analyses=[ @@ -37,7 +36,6 @@ def minimal_state(tmp_path) -> PipelineState: def test_export_writes_report(minimal_state, monkeypatch, tmp_path) -> None: monkeypatch.setattr("litresearch.stages.export.load_prompt", lambda _: "") monkeypatch.setattr("litresearch.stages.export.call_llm", lambda *a, **kw: "synthesis") - monkeypatch.setattr("litresearch.stages.export.download_pdf", lambda _: None) run(minimal_state, Settings()) diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index 0cd1e89..c834e5c 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -37,7 +37,6 @@ def test_paper_from_s2_normalizes_fields() -> None: citationCount=12, venue="ICSE", externalIds={"DOI": "10.1234/example"}, - openAccessPdf={"url": "https://example.com/paper.pdf"}, citationStyles={"bibtex": "@article{example}"}, ) @@ -46,7 +45,6 @@ def test_paper_from_s2_normalizes_fields() -> None: assert paper.paper_id == "paper-123" assert paper.authors == ["Ada Lovelace", "Alan Turing"] assert paper.doi == "10.1234/example" - assert paper.open_access_pdf_url == "https://example.com/paper.pdf" assert paper.bibtex == "@article{example}" diff --git a/tests/unit/test_pdf.py b/tests/unit/test_pdf.py deleted file mode 100644 index f584ab1..0000000 --- a/tests/unit/test_pdf.py +++ /dev/null @@ -1,43 +0,0 @@ -import litresearch.pdf as pdf - - -class FakePage: - def __init__(self, text: str): - self._text = text - - def extract_text(self) -> str: - return self._text - - -class FakeReader: - def __init__(self, _stream) -> None: - self.pages = [ - FakePage("Page 1 text"), - FakePage("Page 2 text"), - FakePage("Page 3 text"), - FakePage("Page 4 text"), - FakePage("Page 5 text"), - ] - - -class BrokenReader: - def __init__(self, _stream) -> None: - raise ValueError("bad pdf") - - -def test_extract_text_returns_none_on_invalid_pdf(monkeypatch) -> None: - monkeypatch.setattr(pdf, "PdfReader", BrokenReader) - - assert pdf.extract_text(b"not a pdf") is None - - -def test_extract_text_returns_all_pages_within_budget(monkeypatch) -> None: - monkeypatch.setattr(pdf, "PdfReader", FakeReader) - - text = pdf.extract_text(b"%PDF", token_budget=10000) - assert text is not None - assert "Page 1 text" in text - assert "Page 2 text" in text - assert "Page 3 text" in text - assert "Page 4 text" in text - assert "Page 5 text" in text diff --git a/tests/unit/test_sources_openalex.py b/tests/unit/test_sources_openalex.py index c48e500..31caa9f 100644 --- a/tests/unit/test_sources_openalex.py +++ b/tests/unit/test_sources_openalex.py @@ -94,7 +94,6 @@ def test_work_to_paper_converts_correctly(self) -> None: assert paper.year == 2023 assert paper.citation_count == 100 assert paper.doi == "10.1234/test" - assert paper.open_access_pdf_url == "https://example.com/test.pdf" assert paper.venue == "Nature" assert paper.source == "openalex" diff --git a/tests/unit/test_stages_citation_expansion.py b/tests/unit/test_stages_citation_expansion.py index 4375120..783bd4b 100644 --- a/tests/unit/test_stages_citation_expansion.py +++ b/tests/unit/test_stages_citation_expansion.py @@ -215,3 +215,438 @@ def test_respects_rate_limit(self, tmp_path: pytest.TempPathFactory) -> None: mock_sleep.assert_called_once() assert mock_sleep.call_args.args[0] == pytest.approx(0.8) + + def test_foundational_detection_enabled(self, tmp_path: pytest.TempPathFactory) -> None: + """Test that foundational papers are detected when references overlap candidates.""" + settings = Settings( + expand_citations=True, + s2_api_key=None, + s2_timeout=10, + top_n=2, + min_cross_refs=1, + enable_foundational_detection=True, + foundational_papers_count=3, + ) + existing_paper = Paper( + paper_id="existing1", + title="Existing Paper", + abstract="Abstract", + authors=["Author One"], + year=2023, + citation_count=80, + source="s2", + ) + state = PipelineState( + questions=["Test?"], + candidates=[existing_paper], + ranked_paper_ids=["paper1", "paper2"], + current_stage="ranking", + output_dir=str(tmp_path), + created_at="2024-01-01", + updated_at="2024-01-01", + ) + + # Both top papers reference "existing1" — that makes it foundational + mock_references = SimpleNamespace( + items=[ + SimpleNamespace( + citedPaper=SimpleNamespace( + paperId="existing1", + title="Existing Paper", + year=2023, + citationCount=80, + authors=[SimpleNamespace(name="Author One")], + ) + ), + ] + ) + + mock_scholar = MagicMock() + mock_scholar.get_paper_references.return_value = mock_references + + with patch( + "litresearch.stages.citation_expansion.SemanticScholar", + return_value=mock_scholar, + ): + result = run(state, settings) + + # "existing1" is cited twice (once by each top paper) — should be foundational + assert "existing1" in result.foundational_paper_ids + assert len(result.foundational_paper_ids) == 1 + + def test_foundational_detection_disabled(self, tmp_path: pytest.TempPathFactory) -> None: + """Test that foundational detection is skipped when disabled.""" + settings = Settings( + expand_citations=True, + s2_api_key=None, + s2_timeout=10, + top_n=2, + min_cross_refs=1, + enable_foundational_detection=False, + ) + existing_paper = Paper( + paper_id="existing1", + title="Existing Paper", + abstract="Abstract", + authors=[], + year=2023, + citation_count=80, + source="s2", + ) + state = PipelineState( + questions=["Test?"], + candidates=[existing_paper], + ranked_paper_ids=["paper1", "paper2"], + current_stage="ranking", + output_dir=str(tmp_path), + created_at="2024-01-01", + updated_at="2024-01-01", + ) + + mock_references = SimpleNamespace( + items=[ + SimpleNamespace( + citedPaper=SimpleNamespace( + paperId="existing1", + title="Existing Paper", + year=2023, + citationCount=80, + authors=[], + ) + ), + ] + ) + + mock_scholar = MagicMock() + mock_scholar.get_paper_references.return_value = mock_references + + with patch( + "litresearch.stages.citation_expansion.SemanticScholar", + return_value=mock_scholar, + ): + result = run(state, settings) + + assert result.foundational_paper_ids == [] + + def test_foundational_paper_ids_sorted_by_count( + self, + tmp_path: pytest.TempPathFactory, + ) -> None: + """Test that foundational papers are sorted by citation count (descending).""" + settings = Settings( + expand_citations=True, + s2_api_key=None, + s2_timeout=10, + top_n=3, + min_cross_refs=1, + enable_foundational_detection=True, + foundational_papers_count=3, + ) + existing_a = Paper( + paper_id="existing_a", + title="Paper A", + abstract="A", + authors=[], + year=2023, + citation_count=80, + source="s2", + ) + existing_b = Paper( + paper_id="existing_b", + title="Paper B", + abstract="B", + authors=[], + year=2023, + citation_count=70, + source="s2", + ) + existing_c = Paper( + paper_id="existing_c", + title="Paper C", + abstract="C", + authors=[], + year=2023, + citation_count=60, + source="s2", + ) + state = PipelineState( + questions=["Test?"], + candidates=[existing_a, existing_b, existing_c], + ranked_paper_ids=["paper1", "paper2", "paper3"], + current_stage="ranking", + output_dir=str(tmp_path), + created_at="2024-01-01", + updated_at="2024-01-01", + ) + + # paper1 references all three, paper2 references b and c, paper3 references only c + calls = [ + SimpleNamespace( + items=[ + SimpleNamespace( + citedPaper=SimpleNamespace( + paperId="existing_a", + title="Paper A", + year=2023, + citationCount=80, + authors=[], + ) + ), + SimpleNamespace( + citedPaper=SimpleNamespace( + paperId="existing_b", + title="Paper B", + year=2023, + citationCount=70, + authors=[], + ) + ), + SimpleNamespace( + citedPaper=SimpleNamespace( + paperId="existing_c", + title="Paper C", + year=2023, + citationCount=60, + authors=[], + ) + ), + ] + ), + SimpleNamespace( + items=[ + SimpleNamespace( + citedPaper=SimpleNamespace( + paperId="existing_b", + title="Paper B", + year=2023, + citationCount=70, + authors=[], + ) + ), + SimpleNamespace( + citedPaper=SimpleNamespace( + paperId="existing_c", + title="Paper C", + year=2023, + citationCount=60, + authors=[], + ) + ), + ] + ), + SimpleNamespace( + items=[ + SimpleNamespace( + citedPaper=SimpleNamespace( + paperId="existing_c", + title="Paper C", + year=2023, + citationCount=60, + authors=[], + ) + ), + ] + ), + ] + + mock_scholar = MagicMock() + mock_scholar.get_paper_references.side_effect = calls + + with patch( + "litresearch.stages.citation_expansion.SemanticScholar", + return_value=mock_scholar, + ): + result = run(state, settings) + + # A cited 1 time, B cited 2 times, C cited 3 times → order: C, B, A + assert result.foundational_paper_ids == ["existing_c", "existing_b", "existing_a"] + + def test_foundational_respects_count_setting( + self, + tmp_path: pytest.TempPathFactory, + ) -> None: + """Test that foundational_papers_count limits results.""" + settings = Settings( + expand_citations=True, + s2_api_key=None, + s2_timeout=10, + top_n=3, + min_cross_refs=1, + enable_foundational_detection=True, + foundational_papers_count=2, + ) + existing_a = Paper( + paper_id="existing_a", + title="Paper A", + abstract="A", + authors=[], + year=2023, + citation_count=80, + source="s2", + ) + existing_b = Paper( + paper_id="existing_b", + title="Paper B", + abstract="B", + authors=[], + year=2023, + citation_count=70, + source="s2", + ) + existing_c = Paper( + paper_id="existing_c", + title="Paper C", + abstract="C", + authors=[], + year=2023, + citation_count=60, + source="s2", + ) + state = PipelineState( + questions=["Test?"], + candidates=[existing_a, existing_b, existing_c], + ranked_paper_ids=["paper1", "paper2", "paper3"], + current_stage="ranking", + output_dir=str(tmp_path), + created_at="2024-01-01", + updated_at="2024-01-01", + ) + + # All three papers referenced the same way as the sort test + calls = [ + SimpleNamespace( + items=[ + SimpleNamespace( + citedPaper=SimpleNamespace( + paperId="existing_a", + title="A", + year=2023, + citationCount=80, + authors=[], + ), + ), + SimpleNamespace( + citedPaper=SimpleNamespace( + paperId="existing_b", + title="B", + year=2023, + citationCount=70, + authors=[], + ), + ), + SimpleNamespace( + citedPaper=SimpleNamespace( + paperId="existing_c", + title="C", + year=2023, + citationCount=60, + authors=[], + ), + ), + ] + ), + SimpleNamespace( + items=[ + SimpleNamespace( + citedPaper=SimpleNamespace( + paperId="existing_b", + title="B", + year=2023, + citationCount=70, + authors=[], + ), + ), + SimpleNamespace( + citedPaper=SimpleNamespace( + paperId="existing_c", + title="C", + year=2023, + citationCount=60, + authors=[], + ), + ), + ] + ), + SimpleNamespace( + items=[ + SimpleNamespace( + citedPaper=SimpleNamespace( + paperId="existing_c", + title="C", + year=2023, + citationCount=60, + authors=[], + ), + ), + ] + ), + ] + + mock_scholar = MagicMock() + mock_scholar.get_paper_references.side_effect = calls + + with patch( + "litresearch.stages.citation_expansion.SemanticScholar", + return_value=mock_scholar, + ): + result = run(state, settings) + + # foundational_papers_count=2 → only top 2 by count + assert result.foundational_paper_ids == ["existing_c", "existing_b"] + + def test_foundational_paper_ids_in_state(self, tmp_path: pytest.TempPathFactory) -> None: + """Test that foundational_paper_ids is stored in returned state.""" + settings = Settings( + expand_citations=True, + s2_api_key=None, + s2_timeout=10, + top_n=1, + min_cross_refs=1, + enable_foundational_detection=True, + ) + existing_paper = Paper( + paper_id="existing1", + title="Existing Paper", + abstract="Abstract", + authors=[], + year=2023, + citation_count=80, + source="s2", + ) + state = PipelineState( + questions=["Test?"], + candidates=[existing_paper], + ranked_paper_ids=["paper1"], + current_stage="ranking", + output_dir=str(tmp_path), + created_at="2024-01-01", + updated_at="2024-01-01", + ) + + mock_references = SimpleNamespace( + items=[ + SimpleNamespace( + citedPaper=SimpleNamespace( + paperId="existing1", + title="Existing Paper", + year=2023, + citationCount=80, + authors=[], + ) + ), + ] + ) + + mock_scholar = MagicMock() + mock_scholar.get_paper_references.return_value = mock_references + + with patch( + "litresearch.stages.citation_expansion.SemanticScholar", + return_value=mock_scholar, + ): + result = run(state, settings) + + assert result.foundational_paper_ids == ["existing1"] + # Verify it's serializable + json_data = result.model_dump_json() + assert "foundational_paper_ids" in json_data diff --git a/tests/unit/test_stages_query_expansion.py b/tests/unit/test_stages_query_expansion.py new file mode 100644 index 0000000..b98d692 --- /dev/null +++ b/tests/unit/test_stages_query_expansion.py @@ -0,0 +1,179 @@ +"""Tests for query expansion stage.""" + +import json +from unittest.mock import patch + +from litresearch.config import Settings +from litresearch.llm import LLMError +from litresearch.models import Paper, PipelineState +from litresearch.stages.query_expansion import run + + +def _dummy_state(tmp_path) -> PipelineState: + """Build a state with candidates for testing.""" + return PipelineState( + questions=["How do LLMs affect developer productivity?"], + search_queries=[], + candidates=[ + Paper( + paper_id="p1", + title="Paper One", + abstract="An abstract about LLMs and coding.", + citation_count=50, + year=2024, + venue="ICSE", + ), + Paper( + paper_id="p2", + title="Paper Two", + abstract="Deep learning for software engineering.", + citation_count=30, + year=2023, + venue="FSE", + ), + ], + current_stage="enrichment", + output_dir=str(tmp_path), + created_at="2024-01-01", + updated_at="2024-01-01", + ) + + +class TestQueryExpansion: + def test_skips_when_already_run(self, tmp_path) -> None: + """Stage returns state unchanged when query_expansion_run is True.""" + settings = Settings(default_model="test-model") + state = _dummy_state(tmp_path).model_copy(update={"query_expansion_run": True}) + + # Should not call LLM at all + result = run(state, settings) + + assert result is state # same object, unchanged + assert result.query_expansion_run is True + + def test_generates_queries_from_candidates(self, tmp_path) -> None: + """Stage generates expansion queries from candidate abstracts.""" + settings = Settings( + default_model="test-model", + max_expansion_queries=2, + expansion_candidate_sample=30, + ) + state = _dummy_state(tmp_path) + + mock_response = json.dumps( + { + "queries": [ + { + "query": "developer experience LLM code completion", + "facet": "developer experience", + }, + { + "query": "LLM debugging assistance studies", + "facet": "debugging", + }, + ] + } + ) + + with patch( + "litresearch.stages.query_expansion.call_llm", + return_value=mock_response, + ): + result = run(state, settings) + + assert result.query_expansion_run is True + assert len(result.search_queries) == 2 + assert result.search_queries[0].query == "developer experience LLM code completion" + assert result.search_queries[0].facet == "developer experience" + assert result.search_queries[1].query == "LLM debugging assistance studies" + assert result.search_queries[1].facet == "debugging" + # Original state unchanged + assert len(state.search_queries) == 0 + + def test_handles_llm_failure_gracefully(self, tmp_path) -> None: + """Stage returns state with query_expansion_run=True on LLM failure.""" + settings = Settings(default_model="test-model") + state = _dummy_state(tmp_path) + + with patch( + "litresearch.stages.query_expansion.call_llm", + side_effect=LLMError("API error"), + ): + result = run(state, settings) + + # Must return a PipelineState, not raise + assert isinstance(result, PipelineState) + assert result.query_expansion_run is True + # State is otherwise unchanged (no new queries) + assert len(result.search_queries) == len(state.search_queries) + + def test_handles_invalid_json(self, tmp_path) -> None: + """Stage handles malformed JSON response gracefully.""" + settings = Settings(default_model="test-model") + state = _dummy_state(tmp_path) + + with patch( + "litresearch.stages.query_expansion.call_llm", + return_value="not json", + ): + result = run(state, settings) + + assert result.query_expansion_run is True + assert len(result.search_queries) == len(state.search_queries) + + def test_handles_empty_queries(self, tmp_path) -> None: + """Stage handles JSON with empty queries list.""" + settings = Settings(default_model="test-model") + state = _dummy_state(tmp_path) + + with patch( + "litresearch.stages.query_expansion.call_llm", + return_value=json.dumps({"queries": []}), + ): + result = run(state, settings) + + assert result.query_expansion_run is True + assert len(result.search_queries) == len(state.search_queries) + + def test_skips_when_no_candidates(self, tmp_path) -> None: + """Stage skips expansion when there are no candidates.""" + settings = Settings(default_model="test-model") + state = PipelineState( + questions=["Test question"], + current_stage="enrichment", + output_dir=str(tmp_path), + created_at="2024-01-01", + updated_at="2024-01-01", + ) + + result = run(state, settings) + + assert result.query_expansion_run is True + assert len(result.search_queries) == 0 + + def test_respects_max_expansion_queries(self, tmp_path) -> None: + """Stage caps the number of queries to max_expansion_queries.""" + settings = Settings( + default_model="test-model", + max_expansion_queries=1, + ) + state = _dummy_state(tmp_path) + + mock_response = json.dumps( + { + "queries": [ + {"query": "query one", "facet": "facet1"}, + {"query": "query two", "facet": "facet2"}, + {"query": "query three", "facet": "facet3"}, + ] + } + ) + + with patch( + "litresearch.stages.query_expansion.call_llm", + return_value=mock_response, + ): + result = run(state, settings) + + assert len(result.search_queries) == 1 + assert result.search_queries[0].query == "query one" diff --git a/tests/unit/test_stages_screening.py b/tests/unit/test_stages_screening.py index 4aa227a..84875c2 100644 --- a/tests/unit/test_stages_screening.py +++ b/tests/unit/test_stages_screening.py @@ -1,7 +1,6 @@ """Tests for screening and analysis stage.""" from collections.abc import Callable -from pathlib import Path from unittest.mock import patch import pytest @@ -34,8 +33,6 @@ def _stub( questions: list[str], settings: Settings, prompt: str, - output_dir: str, - inject_pdfs_dir: Path | None = None, ) -> tuple[AnalysisResult | None, Paper]: analyzed_ids.append(paper.paper_id) return ( @@ -52,36 +49,6 @@ def _stub( return _stub - def test_paper_without_abstract_gets_zero_score(self, tmp_path, monkeypatch) -> None: - """Test that papers without abstract get screening result with score 0.""" - settings = Settings(default_model="test-model", screening_selection_mode="top_percent") - - paper_no_abstract = Paper( - paper_id="123", - title="Test Paper", - authors=["Author"], - year=2024, - abstract=None, - ) - - state = self._state_with_papers(tmp_path, [paper_no_abstract]) - - monkeypatch.setattr("litresearch.stages.analysis.load_prompt", lambda _name: "prompt") - monkeypatch.setattr( - "litresearch.stages.analysis._screen_paper", - lambda paper, questions, settings, prompt, fb_prompt, pdf_excerpt=None: ScreeningResult( - paper_id=paper.paper_id, - relevance_score=0, - rationale="no abstract available", - ), - ) - - result = run(state, settings) - - assert len(result.screening_results) == 1 - assert result.screening_results[0].relevance_score == 0 - assert "no abstract available" in result.screening_results[0].rationale - def test_top_percent_selection_analyzes_global_top_share(self, tmp_path, monkeypatch) -> None: """Test global top-percent selection after screening.""" settings = Settings(screening_selection_mode="top_percent", screening_top_percent=0.4) @@ -98,7 +65,7 @@ def test_top_percent_selection_analyzes_global_top_share(self, tmp_path, monkeyp monkeypatch.setattr("litresearch.stages.analysis.load_prompt", lambda _name: "prompt") monkeypatch.setattr( "litresearch.stages.analysis._screen_paper", - lambda paper, questions, settings, prompt, fb_prompt, pdf_excerpt=None: ScreeningResult( + lambda paper, questions, settings, prompt: ScreeningResult( paper_id=paper.paper_id, relevance_score=scores[paper.paper_id], rationale="fit", @@ -127,7 +94,7 @@ def test_top_k_selection_uses_tiebreakers(self, tmp_path, monkeypatch) -> None: monkeypatch.setattr("litresearch.stages.analysis.load_prompt", lambda _name: "prompt") monkeypatch.setattr( "litresearch.stages.analysis._screen_paper", - lambda paper, questions, settings, prompt, fb_prompt, pdf_excerpt=None: ScreeningResult( + lambda paper, questions, settings, prompt: ScreeningResult( paper_id=paper.paper_id, relevance_score=scores[paper.paper_id], rationale="fit", @@ -156,7 +123,7 @@ def test_threshold_selection_mode_still_supported(self, tmp_path, monkeypatch) - monkeypatch.setattr("litresearch.stages.analysis.load_prompt", lambda _name: "prompt") monkeypatch.setattr( "litresearch.stages.analysis._screen_paper", - lambda paper, questions, settings, prompt, fb_prompt, pdf_excerpt=None: ScreeningResult( + lambda paper, questions, settings, prompt: ScreeningResult( paper_id=paper.paper_id, relevance_score=scores[paper.paper_id], rationale="fit", @@ -179,7 +146,7 @@ def test_invalid_top_percent_raises_value_error(self, tmp_path, monkeypatch) -> monkeypatch.setattr("litresearch.stages.analysis.load_prompt", lambda _name: "prompt") monkeypatch.setattr( "litresearch.stages.analysis._screen_paper", - lambda paper, questions, settings, prompt, fb_prompt, pdf_excerpt=None: ScreeningResult( + lambda paper, questions, settings, prompt: ScreeningResult( paper_id=paper.paper_id, relevance_score=90, rationale="fit", @@ -204,6 +171,6 @@ def test_json_parse_failure_skips_paper(self) -> None: settings = Settings(default_model="test-model") with patch("litresearch.stages.analysis.call_llm", return_value="invalid json"): - result = _screen_paper(paper, ["question"], settings, "prompt", "fallback_prompt") + result = _screen_paper(paper, ["question"], settings, "prompt") assert result is None