diff --git a/.gitignore b/.gitignore index 3a786ce..3e56785 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,4 @@ tests/ *.tgz frontend/node_modules/ frontend/dist/ +.claude/ diff --git a/Cargo.lock b/Cargo.lock index c45af5e..25d8cd8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -859,7 +859,7 @@ dependencies = [ "libc", "percent-encoding", "pin-project-lite", - "socket2 0.5.10", + "socket2 0.6.2", "tokio", "tower-service", "tracing", @@ -1466,6 +1466,21 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "psl" +version = "2.1.190" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66fed3dc7578357ff12137c75eac73413b6aba9a7204916c19f2a0e9e1e920e0" +dependencies = [ + "psl-types", +] + +[[package]] +name = "psl-types" +version = "2.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33cb294fe86a74cbcf50d4445b37da762029549ebeea341421c7c70370f86cac" + [[package]] name = "quote" version = "1.0.44" @@ -1654,7 +1669,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -1869,6 +1884,7 @@ dependencies = [ "futures", "hickory-resolver", "neo4rs", + "psl", "regex", "reqwest", "thiserror 2.0.18", @@ -2006,7 +2022,7 @@ dependencies = [ "getrandom 0.3.4", "once_cell", "rustix", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] diff --git a/docs/api-reference.md b/docs/api-reference.md index 41dbfc5..db82200 100644 --- a/docs/api-reference.md +++ b/docs/api-reference.md @@ -20,13 +20,14 @@ Start a new crawl from a given URL. |-------|------|----------|-------------| | `url` | string | Yes | The URL to crawl (must be http or https) | | `depth` | integer | Yes | Maximum link depth to follow (1–5, where 1 = root only) | +| `targeted` | boolean | No | When `true`, only follow links within the same registered domain (eTLD+1) as the root URL. Defaults to `false`. | **Example:** ```bash curl -X POST http://localhost:8080/api/v1/crawls \ -H 'Content-Type: application/json' \ - -d '{"url": "https://example.com", "depth": 2}' + -d '{"url": "https://example.com", "depth": 2, "targeted": true}' ``` **Response:** `201 Created` @@ -84,7 +85,8 @@ curl "http://localhost:8080/api/v1/crawls?status=running&limit=10" "total": 42, "completed": 40, "failed": 2, - "cancelled": 0 + "cancelled": 0, + "targeted": true } ], "total": 1, @@ -128,7 +130,8 @@ curl http://localhost:8080/api/v1/crawls/d262a3e7-19de-437f-b0a4-cf1d689b1caf "failed": 60, "cancelled": 0, "root_url": "https://example.com", - "requested_depth": 3 + "requested_depth": 3, + "targeted": false } ``` diff --git a/docs/project-vision.md b/docs/project-vision.md new file mode 100644 index 0000000..f00127b --- /dev/null +++ b/docs/project-vision.md @@ -0,0 +1,47 @@ +# Web crawler vision +Create a free, open-source, deployable platform for Red & Blue teams that want to discover the web attack surface of their applications. + +## About +This file should be used as general guidelines for development. When design decisions are made, this doc should define the "spirit" of those decisions. + +## My philosophy +1. Don't reinvent the wheel - There is code written by smarter people than you. Be humble and use well-established code and tools. +2. Open Source - This platform should be open and transparent for everyone to contribute, share, and use. +3. Respect others - Use this platform for the betterment of software and products. Make the world better than you found it. +4. Have fun - The process of creating things should be fun. There will be chores, but enjoy the process. + + +## Design Principles (Derived from above) +These principles are a collection of coding and design rules I personally came across and found to work. A lot of this is based on other people's design principles. + +--- + +### Don't reinvent the wheel + +#### Adopt mainstream tools +Use well-established tools from other open-source projects. Only create custom tools when it's absolutely necessary. + +#### Keep it simple stupid +Keep the project as simple as possible. The more moving parts, the less scalable it becomes, and the more things break. + +### Open Source + +#### All source code is public +The project vision is to be an open source platform for blue & red teams, anyone can contribute. + +#### All source code should be free for individuals +This platform should always be free for individuals, and for the foreseeable future, for anyone. The code license should reflect that. + +### Respect others + +#### Respectful crawling +Rate limiting, robots.txt awareness, and polite user-agent strings by default. The tool should be hard to misuse for DoS or abuse. + +### Have fun + +#### Visualization graph should be fun to use and explore +The visuals and tools for exploring the graph should be fun for the user, possibly gamified. + +#### Project theme should be fun +The theme of this project should be cartoony, playful, and fun. The main theme is cobweb (as it's a crawler). + diff --git a/feeder/src/job.rs b/feeder/src/job.rs index c13e16b..5d619af 100644 --- a/feeder/src/job.rs +++ b/feeder/src/job.rs @@ -1,4 +1,4 @@ -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use neo4rs::{query, Graph}; @@ -16,6 +16,8 @@ pub struct UrlJob { pub current_depth: i64, pub attempts: Option, pub crawl_id: String, + pub targeted: bool, + pub target_domain: String, } /// Represents a child node to be created in Neo4j. @@ -28,6 +30,8 @@ struct ChildNode { current_depth: i64, request_time: String, crawl_id: String, + targeted: bool, + target_domain: String, } /// Atomically fetches and claims a single URL job from Neo4j. @@ -64,6 +68,8 @@ pub async fn fetch_job(graph: &Graph, stale_timeout: i64) -> Result("attempts").ok(), crawl_id: node.get("crawl_id").unwrap_or_default(), + targeted: node.get::("targeted").unwrap_or(false), + target_domain: node.get::("target_domain").unwrap_or_default(), })) } None => Ok(None), @@ -110,23 +116,23 @@ async fn validate_job( tracing::warn!("Request failed: {} -- Attempts: {} -- Error: {}", full_url, attempts, e); - if attempts >= config.max_attempts { - tracing::error!( - "Failure limit reached! Giving up on {} after {} attempts.", - full_url, - attempts - ); + // 4xx errors are permanent — fail immediately without retry + let is_permanent = matches!(e, CrawlerError::HttpStatus { status, .. } if (400..500).contains(&status)); + + if is_permanent || attempts >= config.max_attempts { + if !is_permanent { + tracing::error!( + "Failure limit reached! Giving up on {} after {} attempts.", + full_url, + attempts + ); + } update_job_status(graph, job, "FAILED", Some(attempts)).await?; } else { - // Fix: reset to PENDING so other feeders can retry + // Reset to PENDING so other feeders can retry update_job_status(graph, job, "PENDING", Some(attempts)).await?; } - // Return permanent failures (4xx) as immediate failure - if matches!(e, CrawlerError::HttpStatus { status, .. } if (400..500).contains(&status)) { - update_job_status(graph, job, "FAILED", Some(attempts)).await?; - } - Ok(None) } } @@ -181,7 +187,8 @@ async fn batch_create_children( ON CREATE SET c.ip = $ip, c.domain = $domain, \ c.job_status = CASE WHEN $cur_depth = $req_depth THEN 'COMPLETED' ELSE 'PENDING' END, \ c.requested_depth = $req_depth, \ - c.current_depth = $cur_depth, c.request_time = $req_time \ + c.current_depth = $cur_depth, c.request_time = $req_time, \ + c.targeted = $targeted, c.target_domain = $target_domain \ MERGE (p)-[:Lead]->(c)", ) .param("pname", parent.name.as_str()) @@ -194,7 +201,9 @@ async fn batch_create_children( .param("http_type", child.http_type.as_str()) .param("req_depth", child.requested_depth) .param("cur_depth", child.current_depth) - .param("req_time", child.request_time.as_str()), + .param("req_time", child.request_time.as_str()) + .param("targeted", child.targeted) + .param("target_domain", child.target_domain.as_str()), ) .await?; } @@ -279,11 +288,24 @@ pub async fn feeding( None => return Ok(false), }; - // Step 2: Extract URLs from HTML + // Step 2: Extract URLs from HTML and normalize once let extracted_urls = crawler::extract_urls(&page_data.html); + let mut normalized_map: HashMap = HashMap::new(); + for url in &extracted_urls { + let (norm_name, http_type) = url_normalize::normalize_url(url); + let upper_key = format!("{}{}", http_type, norm_name).to_uppercase(); + normalized_map.entry(upper_key).or_insert((norm_name, http_type)); + } + + // Step 2b: Filter by target domain when targeted + if job.targeted && !job.target_domain.is_empty() { + normalized_map.retain(|_, (norm_name, _)| { + url_normalize::is_same_registered_domain(norm_name, &job.target_domain) + }); + } // Step 3: Deduplicate against existing DB nodes (server-side) - let upper_urls: HashSet = extracted_urls.iter().map(|u| u.to_uppercase()).collect(); + let upper_urls: HashSet = normalized_map.keys().cloned().collect(); let new_urls = filter_new_urls(graph, &upper_urls, &job.crawl_id).await?; if new_urls.is_empty() { @@ -292,10 +314,10 @@ pub async fn feeding( return Ok(true); } - // Step 4: Normalize, DNS resolve in parallel, build child list + // Step 4: DNS resolve in parallel, build child list let normalized: HashSet<(String, String)> = new_urls .iter() - .map(|u| url_normalize::normalize_url(u)) + .filter_map(|key| normalized_map.get(key).cloned()) .collect(); let request_time = format!("{:?}", page_data.elapsed); @@ -303,6 +325,9 @@ pub async fn feeding( let current_depth = job.current_depth; let crawl_id = job.crawl_id.clone(); + let targeted = job.targeted; + let target_domain = job.target_domain.clone(); + let dns_futures: Vec<_> = normalized .iter() .map(|(name, http_type)| { @@ -310,6 +335,7 @@ pub async fn feeding( let http_type = http_type.clone(); let req_time = request_time.clone(); let cid = crawl_id.clone(); + let td = target_domain.clone(); async move { match dns::get_network_stats(resolver, &name, config.max_dns_depth).await { Ok(stats) => Some(ChildNode { @@ -321,6 +347,8 @@ pub async fn feeding( current_depth: current_depth + 1, request_time: req_time, crawl_id: cid, + targeted, + target_domain: td, }), Err(e) => { tracing::error!("URL: {} -- FAILED: {}", name, e); diff --git a/feeder/src/main.rs b/feeder/src/main.rs index 7217e96..6ce4988 100644 --- a/feeder/src/main.rs +++ b/feeder/src/main.rs @@ -123,6 +123,8 @@ async fn main() -> anyhow::Result<()> { current_depth: url_job.current_depth, attempts: url_job.attempts, crawl_id: url_job.crawl_id.clone(), + targeted: url_job.targeted, + target_domain: url_job.target_domain.clone(), }); // Check for shutdown after claiming but before processing. diff --git a/frontend/package-lock.json b/frontend/package-lock.json index 213b8a8..6b8917d 100644 --- a/frontend/package-lock.json +++ b/frontend/package-lock.json @@ -18,6 +18,7 @@ "@tanstack/react-query": "^5.62.0", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", + "d3-force": "^3.0.0", "lucide-react": "^0.460.0", "react": "^18.3.1", "react-dom": "^18.3.1", @@ -30,6 +31,7 @@ }, "devDependencies": { "@eslint/js": "^9.15.0", + "@types/d3-force": "^3.0.10", "@types/react": "^18.3.12", "@types/react-dom": "^18.3.1", "@vitejs/plugin-react": "^4.3.4", @@ -2355,6 +2357,13 @@ "integrity": "sha512-NcV1JjO5oDzoK26oMzbILE6HW7uVXOHLQvHshBUW4UMdZGfiY6v5BeQwh9a9tCzv+CeefZQHJt5SRgK154RtiA==", "license": "MIT" }, + "node_modules/@types/d3-force": { + "version": "3.0.10", + "resolved": "https://registry.npmjs.org/@types/d3-force/-/d3-force-3.0.10.tgz", + "integrity": "sha512-ZYeSaCF3p73RdOKcjj+swRlZfnYpK1EbaDiYICEEp5Q6sUiqFaFQ9qgoshp5CzIyyb/yD09kD9o2zEltCexlgw==", + "dev": true, + "license": "MIT" + }, "node_modules/@types/d3-interpolate": { "version": "3.0.4", "resolved": "https://registry.npmjs.org/@types/d3-interpolate/-/d3-interpolate-3.0.4.tgz", @@ -3257,6 +3266,20 @@ "node": ">=12" } }, + "node_modules/d3-force": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/d3-force/-/d3-force-3.0.0.tgz", + "integrity": "sha512-zxV/SsA+U4yte8051P4ECydjD/S+qeYtnaIyAs9tgHCqfguma/aAQDjo85A9Z6EKhBirHRJHXIgJUlffT4wdLg==", + "license": "ISC", + "dependencies": { + "d3-dispatch": "1 - 3", + "d3-quadtree": "1 - 3", + "d3-timer": "1 - 3" + }, + "engines": { + "node": ">=12" + } + }, "node_modules/d3-force-3d": { "version": "3.0.6", "resolved": "https://registry.npmjs.org/d3-force-3d/-/d3-force-3d-3.0.6.tgz", diff --git a/frontend/package.json b/frontend/package.json index a78fb94..b5d1251 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -21,6 +21,7 @@ "@tanstack/react-query": "^5.62.0", "class-variance-authority": "^0.7.1", "clsx": "^2.1.1", + "d3-force": "^3.0.0", "lucide-react": "^0.460.0", "react": "^18.3.1", "react-dom": "^18.3.1", @@ -33,6 +34,7 @@ }, "devDependencies": { "@eslint/js": "^9.15.0", + "@types/d3-force": "^3.0.10", "@types/react": "^18.3.12", "@types/react-dom": "^18.3.1", "@vitejs/plugin-react": "^4.3.4", diff --git a/frontend/src/components/GraphView.tsx b/frontend/src/components/GraphView.tsx index eb7ac59..0d6463c 100644 --- a/frontend/src/components/GraphView.tsx +++ b/frontend/src/components/GraphView.tsx @@ -1,8 +1,10 @@ -import { useRef, useCallback, useMemo } from "react"; +import { useState, useRef, useCallback, useMemo, useEffect } from "react"; import ForceGraph2D, { type ForceGraphMethods, + type LinkObject, type NodeObject, } from "react-force-graph-2d"; +import { forceRadial } from "d3-force"; import type { GraphData } from "../types/api"; interface GraphViewProps { @@ -11,6 +13,7 @@ interface GraphViewProps { interface CrawlNode { label: string; + domain: string; depth: number; status: string; nodeType: string; @@ -30,15 +33,33 @@ export function GraphView({ data }: GraphViewProps) { const fgRef = useRef> | undefined>( undefined ); + const [selectedNode, setSelectedNode] = useState(null); + const containerRef = useRef(null); + + const needsRecenter = useRef(true); + const [containerWidth, setContainerWidth] = useState(0); + + useEffect(() => { + const el = containerRef.current; + if (!el) return; + const observer = new ResizeObserver((entries) => { + setContainerWidth(entries[0].contentRect.width); + }); + observer.observe(el); + return () => observer.disconnect(); + }, []); const graphData = useMemo(() => { const nodes = data.nodes.map((n) => ({ id: n.id, label: n.label, + domain: n.domain, depth: n.depth, status: n.status, nodeType: n.node_type, - val: n.node_type === "ROOT" ? 3 : 1, + val: { ROOT: 4, 1: 2.5, 2: 1.5 }[n.node_type === "ROOT" ? "ROOT" : n.depth] ?? 1, + // Pin root node at origin for stable centering + ...(n.node_type === "ROOT" ? { fx: 0, fy: 0 } : {}), })); const links = data.edges.map((e) => ({ @@ -49,6 +70,22 @@ export function GraphView({ data }: GraphViewProps) { return { nodes, links }; }, [data]); + const { neighborIds, connectedLinks } = useMemo(() => { + if (!selectedNode) return { neighborIds: new Set(), connectedLinks: new Set() }; + const nIds = new Set(); + const cLinks = new Set(); + graphData.links.forEach((link) => { + const src = typeof link.source === "object" ? (link.source as NodeObject).id : link.source; + const tgt = typeof link.target === "object" ? (link.target as NodeObject).id : link.target; + if (src === selectedNode || tgt === selectedNode) { + nIds.add(src as string); + nIds.add(tgt as string); + cLinks.add(`${src}->${tgt}`); + } + }); + return { neighborIds: nIds, connectedLinks: cLinks }; + }, [selectedNode, graphData]); + const activeStatuses = useMemo(() => { const statuses = new Set(); data.nodes.forEach((n) => { @@ -58,17 +95,60 @@ export function GraphView({ data }: GraphViewProps) { return Object.entries(STATUS_COLORS).filter(([s]) => statuses.has(s)); }, [data]); + useEffect(() => { + const fg = fgRef.current; + if (!fg) return; + + const ringSpacing = 120; + + // Radial force: push nodes into concentric rings by depth + fg.d3Force( + "radial", + forceRadial( + (node: NodeObject) => ((node as CrawlNode).depth ?? 0) * ringSpacing, + 0, + 0 + ).strength(0.8) + ); + + // Link distance based on depth + fg.d3Force("link")?.distance( + (link: LinkObject) => { + const src = link.source as NodeObject; + const tgt = link.target as NodeObject; + return 30 + Math.abs((tgt.depth ?? 0) - (src.depth ?? 0)) * 60; + } + ); + + // Stronger charge to spread nodes within rings + fg.d3Force("charge")?.strength(-80); + + needsRecenter.current = true; + fg.d3ReheatSimulation(); + }, [graphData]); + const handleEngineStop = useCallback(() => { - if (fgRef.current) { - fgRef.current.zoomToFit(400); - } + const fg = fgRef.current; + if (!fg || !needsRecenter.current) return; + needsRecenter.current = false; + + // Root is pinned at (0,0). Center on it and zoom to fit all nodes. + fg.centerAt(0, 0); + fg.zoomToFit(400, 40); }, []); const nodeColor = useCallback( (node: NodeObject) => { - return STATUS_COLORS[node.status || ""] || "#9ca3af"; + const base = STATUS_COLORS[node.status || ""] || "#9ca3af"; + if (!selectedNode) return base; + if (node.id === selectedNode || neighborIds.has(node.id as string)) return base; + // Dim unrelated nodes: parse hex to rgba with low opacity + const r = parseInt(base.slice(1, 3), 16); + const g = parseInt(base.slice(3, 5), 16); + const b = parseInt(base.slice(5, 7), 16); + return `rgba(${r},${g},${b},0.2)`; }, - [] + [selectedNode, neighborIds] ); const nodeLabel = useCallback( @@ -88,6 +168,7 @@ export function GraphView({ data }: GraphViewProps) { return (
@@ -97,13 +178,46 @@ export function GraphView({ data }: GraphViewProps) { nodeColor={nodeColor} nodeLabel={nodeLabel} nodeRelSize={6} - linkColor={() => "rgba(255,255,255,0.15)"} + onNodeClick={(node: NodeObject) => { + setSelectedNode(node.id === selectedNode ? null : (node.id as string)); + }} + onBackgroundClick={() => setSelectedNode(null)} + nodeCanvasObjectMode={() => selectedNode ? ("after" as const) : undefined} + nodeCanvasObject={(node: NodeObject, ctx, globalScale) => { + if (node.id !== selectedNode) return; + const r = Math.sqrt(node.val ?? 1) * 6 + 2; + ctx.beginPath(); + ctx.arc(node.x!, node.y!, r, 0, 2 * Math.PI); + ctx.strokeStyle = "#ffffff"; + ctx.lineWidth = 2 / globalScale; + ctx.stroke(); + }} + linkColor={(link: LinkObject) => { + if (selectedNode) { + const src = typeof link.source === "object" ? (link.source as NodeObject).id : link.source; + const tgt = typeof link.target === "object" ? (link.target as NodeObject).id : link.target; + const key = `${src}->${tgt}`; + return connectedLinks.has(key) ? "rgba(255,255,255,0.6)" : "rgba(255,255,255,0.03)"; + } + const depth = Math.max( + (link.source as NodeObject)?.depth ?? 0, + (link.target as NodeObject)?.depth ?? 0 + ); + const opacity = Math.max(0.05, 0.25 - depth * 0.05); + return `rgba(255,255,255,${opacity})`; + }} + linkWidth={(link: LinkObject) => { + if (!selectedNode) return 0.5; + const src = typeof link.source === "object" ? (link.source as NodeObject).id : link.source; + const tgt = typeof link.target === "object" ? (link.target as NodeObject).id : link.target; + return (src === selectedNode || tgt === selectedNode) ? 2 : 0.5; + }} linkDirectionalArrowLength={3} linkDirectionalArrowRelPos={1} backgroundColor="#111827" onEngineStop={handleEngineStop} cooldownTicks={100} - width={undefined} + width={containerWidth || undefined} height={600} />
@@ -117,6 +231,28 @@ export function GraphView({ data }: GraphViewProps) {
))}
+ {selectedNode && (() => { + const node = graphData.nodes.find((n) => n.id === selectedNode); + if (!node) return null; + return ( +
+
+ {node.label} + +
+
Domain: {node.domain}
+
Depth: {node.depth}
+
Status: {node.status}
+
Type: {node.nodeType}
+
Connections: {neighborIds.size > 0 ? neighborIds.size - 1 : 0}
+
+ ); + })()} ); } diff --git a/frontend/src/lib/api.ts b/frontend/src/lib/api.ts index fc26efa..ef20276 100644 --- a/frontend/src/lib/api.ts +++ b/frontend/src/lib/api.ts @@ -19,12 +19,13 @@ async function fetchJSON(url: string, init?: RequestInit): Promise { export async function createCrawl( url: string, - depth: number + depth: number, + targeted?: boolean ): Promise { return fetchJSON(`${BASE}/crawls`, { method: "POST", headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ url, depth }), + body: JSON.stringify({ url, depth, ...(targeted ? { targeted } : {}) }), }); } diff --git a/frontend/src/pages/CrawlDetail.tsx b/frontend/src/pages/CrawlDetail.tsx index 3a150b3..f02f7da 100644 --- a/frontend/src/pages/CrawlDetail.tsx +++ b/frontend/src/pages/CrawlDetail.tsx @@ -103,7 +103,13 @@ export default function CrawlDetail() { {crawl.root_url.toLowerCase()}

- Depth: {crawl.requested_depth} | ID: {id} + Depth: {crawl.requested_depth} + {crawl.targeted && ( + + Targeted + + )} + {" "}| ID: {id}

@@ -227,6 +233,12 @@ export default function CrawlDetail() {
Requested Depth
{crawl.requested_depth}
+
+
Scope
+
+ {crawl.targeted ? "Targeted" : "Unrestricted"} +
+
Status
diff --git a/frontend/src/pages/CrawlList.tsx b/frontend/src/pages/CrawlList.tsx index f5938cd..6a91b71 100644 --- a/frontend/src/pages/CrawlList.tsx +++ b/frontend/src/pages/CrawlList.tsx @@ -97,6 +97,11 @@ export default function CrawlList() { depth {crawl.requested_depth} + {crawl.targeted && ( + + Targeted + + )}
diff --git a/frontend/src/pages/NewCrawl.tsx b/frontend/src/pages/NewCrawl.tsx index 5098ae5..e49a57f 100644 --- a/frontend/src/pages/NewCrawl.tsx +++ b/frontend/src/pages/NewCrawl.tsx @@ -11,6 +11,7 @@ import { Input } from "../components/ui/input"; const schema = z.object({ url: z.string().url("Please enter a valid URL"), depth: z.number().min(1).max(5), + targeted: z.boolean(), }); type FormData = z.infer; @@ -28,7 +29,7 @@ export default function NewCrawl() { formState: { errors }, } = useForm({ resolver: zodResolver(schema), - defaultValues: { url: "", depth: 2 }, + defaultValues: { url: "", depth: 2, targeted: false }, }); const depth = watch("depth"); @@ -37,7 +38,7 @@ export default function NewCrawl() { setSubmitting(true); setError(""); try { - const result = await createCrawl(data.url, data.depth); + const result = await createCrawl(data.url, data.depth, data.targeted || undefined); navigate(`/crawls/${result.crawl_id}`); } catch (err) { setError(err instanceof Error ? err.message : "Failed to start crawl"); @@ -101,6 +102,32 @@ export default function NewCrawl() { )} +
+ + +
+

What to expect diff --git a/frontend/src/types/api.ts b/frontend/src/types/api.ts index 95f04f0..ea1a25b 100644 --- a/frontend/src/types/api.ts +++ b/frontend/src/types/api.ts @@ -13,6 +13,7 @@ export interface CrawlProgress { failed: number; root_url: string; requested_depth: number; + targeted: boolean; } export interface CrawlListItem { @@ -23,6 +24,7 @@ export interface CrawlListItem { total: number; completed: number; failed: number; + targeted: boolean; } export interface CrawlListResponse { diff --git a/manager/src/models/crawl.rs b/manager/src/models/crawl.rs index 8ed1e2f..1dc3f84 100644 --- a/manager/src/models/crawl.rs +++ b/manager/src/models/crawl.rs @@ -4,6 +4,8 @@ use serde::{Deserialize, Serialize}; pub struct CrawlRequest { pub url: String, pub depth: i64, + #[serde(default)] + pub targeted: Option, } #[derive(Serialize)] @@ -24,6 +26,7 @@ pub struct CrawlProgress { pub cancelled: i64, pub root_url: String, pub requested_depth: i64, + pub targeted: bool, } #[derive(Serialize)] @@ -36,6 +39,7 @@ pub struct CrawlListItem { pub completed: i64, pub failed: i64, pub cancelled: i64, + pub targeted: bool, } #[derive(Serialize)] diff --git a/manager/src/routes/crawl.rs b/manager/src/routes/crawl.rs index b766967..067f044 100644 --- a/manager/src/routes/crawl.rs +++ b/manager/src/routes/crawl.rs @@ -43,6 +43,23 @@ pub async fn create_crawl( // 1. Normalize root URL let (root_name, http_type) = url_normalize::normalize_url(&req.url); + let targeted = req.targeted.unwrap_or(false); + + // 1b. Compute target domain for targeted crawls + let target_domain = if targeted { + match url_normalize::registered_domain(&root_name) { + Some(rd) => rd, + None => { + return ( + StatusCode::BAD_REQUEST, + Json(json!({"error": "Cannot determine registered domain for targeted crawl (bare public suffix or invalid host)"})), + ) + .into_response(); + } + } + } else { + String::new() + }; // 2. Fetch page HTML let page_data = match crawler::get_page_data(&state.client, &req.url).await { @@ -85,10 +102,20 @@ pub async fn create_crawl( // 6. Resolve DNS for each extracted URL in parallel let request_time = format!("{:?}", page_data.elapsed); - let dns_futures: Vec<_> = extracted_urls + // 6a. Normalize extracted URLs and filter by target domain if targeted + let normalized_urls: Vec<(String, String)> = extracted_urls + .iter() + .map(|url| url_normalize::normalize_url(url)) + .filter(|(norm_name, _)| { + !targeted || url_normalize::is_same_registered_domain(norm_name, &target_domain) + }) + .collect(); + + let dns_futures: Vec<_> = normalized_urls .iter() - .map(|url| { - let (norm_name, child_http_type) = url_normalize::normalize_url(url); + .map(|(norm_name, child_http_type)| { + let norm_name = norm_name.clone(); + let child_http_type = child_http_type.clone(); let resolver = &state.resolver; let max_depth = state.config.max_dns_depth; async move { @@ -117,6 +144,8 @@ pub async fn create_crawl( depth: req.depth, request_time: &request_time, children: &children, + targeted, + target_domain: &target_domain, }; if let Err(e) = crawl_service::create_crawl_graph(&state.graph, ¶ms).await { diff --git a/manager/src/services/crawl_service.rs b/manager/src/services/crawl_service.rs index 62fbff2..193902a 100644 --- a/manager/src/services/crawl_service.rs +++ b/manager/src/services/crawl_service.rs @@ -11,6 +11,8 @@ pub struct CreateCrawlParams<'a> { pub depth: i64, pub request_time: &'a str, pub children: &'a [(String, String, String, String)], + pub targeted: bool, + pub target_domain: &'a str, } /// Create ROOT node and child URL nodes in a single transaction with crawl_id. @@ -25,7 +27,8 @@ pub async fn create_crawl_graph( query( "CREATE (:ROOT {name: $name, ip: $ip, domain: $domain, http_type: $http_type, \ requested_depth: $req_depth, current_depth: 0, request_time: $req_time, \ - crawl_id: $crawl_id, created_at: datetime()})", + crawl_id: $crawl_id, created_at: datetime(), \ + targeted: $targeted, target_domain: $target_domain})", ) .param("name", params.root_name) .param("ip", params.root_ip) @@ -33,7 +36,9 @@ pub async fn create_crawl_graph( .param("http_type", params.http_type) .param("req_depth", params.depth) .param("req_time", params.request_time) - .param("crawl_id", params.crawl_id), + .param("crawl_id", params.crawl_id) + .param("targeted", params.targeted) + .param("target_domain", params.target_domain), ) .await?; @@ -46,7 +51,8 @@ pub async fn create_crawl_graph( ON CREATE SET c.ip = $ip, c.domain = $domain, \ c.job_status = CASE WHEN 1 = $req_depth THEN 'COMPLETED' ELSE 'PENDING' END, \ c.requested_depth = $req_depth, \ - c.current_depth = 1, c.request_time = $req_time \ + c.current_depth = 1, c.request_time = $req_time, \ + c.targeted = $targeted, c.target_domain = $target_domain \ MERGE (root)-[:Lead]->(c)", ) .param("crawl_id", params.crawl_id) @@ -55,7 +61,9 @@ pub async fn create_crawl_graph( .param("ip", child_ip.as_str()) .param("domain", child_domain.as_str()) .param("http_type", child_http_type.as_str()) - .param("req_time", params.request_time), + .param("req_time", params.request_time) + .param("targeted", params.targeted) + .param("target_domain", params.target_domain), ) .await?; } @@ -83,6 +91,7 @@ pub async fn get_crawl_progress( sum(CASE WHEN u.job_status = 'FAILED' THEN 1 ELSE 0 END) AS failed, \ sum(CASE WHEN u.job_status = 'CANCELLED' THEN 1 ELSE 0 END) AS cancelled \ RETURN r.name AS root_url, r.requested_depth AS depth, r.http_type AS http_type, \ + r.targeted AS targeted, \ total, completed, pending, in_progress, failed, cancelled", ) .param("crawl_id", crawl_id), @@ -113,6 +122,8 @@ pub async fn get_crawl_progress( "running".to_string() }; + let targeted: bool = row.get::("targeted").unwrap_or(false); + Ok(Some(CrawlProgress { crawl_id: crawl_id.to_string(), status, @@ -124,6 +135,7 @@ pub async fn get_crawl_progress( cancelled, root_url: format!("{}{}", http_type, url), requested_depth: depth, + targeted, })) } None => Ok(None), @@ -159,6 +171,7 @@ pub async fn list_crawls( UNWIND items[$offset..($offset + $limit)] AS item \ RETURN item.r.crawl_id AS crawl_id, item.r.name AS root_url, \ item.r.http_type AS http_type, item.r.requested_depth AS depth, \ + item.r.targeted AS targeted, \ item.total AS total, item.completed AS completed, item.failed AS failed, item.cancelled AS cancelled, item.status AS status, \ total_count" } else { @@ -178,6 +191,7 @@ pub async fn list_crawls( UNWIND items[$offset..($offset + $limit)] AS item \ RETURN item.r.crawl_id AS crawl_id, item.r.name AS root_url, \ item.r.http_type AS http_type, item.r.requested_depth AS depth, \ + item.r.targeted AS targeted, \ item.total AS total, item.completed AS completed, item.failed AS failed, item.cancelled AS cancelled, item.status AS status, \ total_count" }; @@ -208,6 +222,7 @@ pub async fn list_crawls( completed: row.get("completed")?, failed: row.get("failed")?, cancelled: row.get("cancelled")?, + targeted: row.get::("targeted").unwrap_or(false), }); } diff --git a/shared/Cargo.toml b/shared/Cargo.toml index 393e47b..9a0bc99 100644 --- a/shared/Cargo.toml +++ b/shared/Cargo.toml @@ -13,6 +13,7 @@ regex = { workspace = true } thiserror = { workspace = true } tracing = { workspace = true } futures = { workspace = true } +psl = "2" [dev-dependencies] tokio = { workspace = true } diff --git a/shared/src/url_normalize.rs b/shared/src/url_normalize.rs index fd54467..dfff237 100644 --- a/shared/src/url_normalize.rs +++ b/shared/src/url_normalize.rs @@ -1,3 +1,5 @@ +use psl::Psl; + /// Normalizes a URL by uppercasing, removing protocol and www prefix. /// /// Returns (normalized_name, protocol). @@ -21,6 +23,37 @@ pub fn normalize_url(url: &str) -> (String, String) { (name, proto.to_string()) } +/// Extracts the registered domain (eTLD+1) from a normalized name. +/// +/// The input should be an uppercase normalized name (no protocol, no `www.`). +/// Ports are stripped before lookup. Returns uppercase eTLD+1. +/// +/// # Examples +/// - `"EXAMPLE.COM"` -> `Some("EXAMPLE.COM")` +/// - `"BLOG.EXAMPLE.CO.UK"` -> `Some("EXAMPLE.CO.UK")` +/// - `"EXAMPLE.COM:8080"` -> `Some("EXAMPLE.COM")` +/// - `"COM"` (bare TLD) -> `None` +pub fn registered_domain(normalized_name: &str) -> Option { + // Strip port if present + let host = normalized_name.split(':').next().unwrap_or(normalized_name); + // psl requires lowercase input + let lower = host.to_lowercase(); + let domain = psl::List.domain(lower.as_bytes())?; + let domain_str = std::str::from_utf8(domain.as_bytes()).ok()?; + Some(domain_str.to_uppercase()) +} + +/// Checks if a normalized name belongs to the same registered domain as the target. +/// +/// Both inputs should be uppercase. The target should already be a registered domain +/// (output of `registered_domain()`). +pub fn is_same_registered_domain(normalized_name: &str, target_domain: &str) -> bool { + match registered_domain(normalized_name) { + Some(rd) => rd == target_domain, + None => false, + } +} + #[cfg(test)] mod tests { use super::*; @@ -66,4 +99,69 @@ mod tests { assert_eq!(name, "SUBDOMAIN.WWW.EXAMPLE.COM"); assert_eq!(proto, "HTTPS://"); } + + #[test] + fn test_registered_domain_simple() { + assert_eq!(registered_domain("EXAMPLE.COM"), Some("EXAMPLE.COM".to_string())); + } + + #[test] + fn test_registered_domain_subdomain() { + assert_eq!(registered_domain("BLOG.EXAMPLE.COM"), Some("EXAMPLE.COM".to_string())); + } + + #[test] + fn test_registered_domain_deep_subdomain() { + assert_eq!(registered_domain("A.B.C.EXAMPLE.COM"), Some("EXAMPLE.COM".to_string())); + } + + #[test] + fn test_registered_domain_co_uk() { + assert_eq!(registered_domain("BLOG.EXAMPLE.CO.UK"), Some("EXAMPLE.CO.UK".to_string())); + } + + #[test] + fn test_registered_domain_with_port() { + assert_eq!(registered_domain("EXAMPLE.COM:8080"), Some("EXAMPLE.COM".to_string())); + } + + #[test] + fn test_registered_domain_bare_tld() { + assert_eq!(registered_domain("COM"), None); + } + + #[test] + fn test_registered_domain_bare_public_suffix() { + assert_eq!(registered_domain("GITHUB.IO"), None); + } + + #[test] + fn test_registered_domain_localhost() { + assert_eq!(registered_domain("LOCALHOST"), None); + } + + #[test] + fn test_is_same_registered_domain_match() { + assert!(is_same_registered_domain("BLOG.EXAMPLE.COM", "EXAMPLE.COM")); + } + + #[test] + fn test_is_same_registered_domain_exact() { + assert!(is_same_registered_domain("EXAMPLE.COM", "EXAMPLE.COM")); + } + + #[test] + fn test_is_same_registered_domain_no_match() { + assert!(!is_same_registered_domain("GOOGLE.COM", "EXAMPLE.COM")); + } + + #[test] + fn test_is_same_registered_domain_with_port() { + assert!(is_same_registered_domain("API.EXAMPLE.COM:3000", "EXAMPLE.COM")); + } + + #[test] + fn test_is_same_registered_domain_co_uk() { + assert!(is_same_registered_domain("SHOP.EXAMPLE.CO.UK", "EXAMPLE.CO.UK")); + } }