diff --git a/src/dns_benchmark/analysis.py b/src/dns_benchmark/analysis.py index 3c66cb8..6d7c14d 100644 --- a/src/dns_benchmark/analysis.py +++ b/src/dns_benchmark/analysis.py @@ -50,13 +50,21 @@ def _create_dataframe(self) -> pd.DataFrame: "record_type": result.record_type, "latency_ms": result.latency_ms, "status": result.status.value, + # True for SUCCESS only — used for success rate reporting "success": result.status == QueryStatus.SUCCESS, + # True for SUCCESS or DNSSEC_FAILED — query completed at network + # level so latency is valid and should be included in stats. + "completed": result.status + in ( + QueryStatus.SUCCESS, + QueryStatus.DNSSEC_FAILED, + ), "answers_count": len(result.answers), "ttl": result.ttl or 0, "error_message": result.error_message or "", "attempt_number": result.attempt_number, "cache_hit": result.cache_hit, - "interation": result.iteration, + "iteration": result.iteration, "query_id": result.query_id, "protocol": result.protocol.value, "dnssec_validated": result.dnssec_validated, @@ -75,7 +83,7 @@ def get_resolver_statistics(self) -> List[ResolverStats]: # Basic counts total_queries = len(resolver_data) - successful_queries = len(resolver_data[resolver_data["success"] == True]) + successful_queries = len(resolver_data[resolver_data["completed"] == True]) success_rate = ( (successful_queries / total_queries) * 100 if total_queries > 0 else 0 ) @@ -86,7 +94,7 @@ def get_resolver_statistics(self) -> List[ResolverStats]: else 0.0 ) # Latency statistics (only for successful queries) - successful_latencies = resolver_data[resolver_data["success"] == True][ + successful_latencies = resolver_data[resolver_data["completed"] == True][ "latency_ms" ] @@ -143,12 +151,12 @@ def get_resolver_statistics(self) -> List[ResolverStats]: def get_overall_statistics(self) -> Dict[str, Any]: """Get overall benchmark statistics.""" total_queries = len(self.df) - successful_queries = len(self.df[self.df["success"] == True]) + successful_queries = len(self.df[self.df["completed"] == True]) overall_success_rate = ( (successful_queries / total_queries) * 100 if total_queries > 0 else 0 ) - successful_latencies = self.df[self.df["success"] == True]["latency_ms"] + successful_latencies = self.df[self.df["completed"] == True]["latency_ms"] if len(successful_latencies) > 0: overall_avg_latency = float(successful_latencies.mean()) @@ -195,15 +203,15 @@ def get_domain_statistics(self) -> List[Dict[str, Any]]: for domain in self.df["domain"].unique(): dd = self.df[self.df["domain"] == domain] total = len(dd) - success = len(dd[dd["success"] == True]) + success = len(dd[dd["completed"] == True]) rate = (success / total) * 100 if total > 0 else 0.0 - latencies = dd[dd["success"] == True]["latency_ms"] + latencies = dd[dd["completed"] == True]["latency_ms"] # Find fastest and slowest resolvers for this domain if len(latencies) > 0: - fastest_idx = dd[dd["success"] == True]["latency_ms"].idxmin() - slowest_idx = dd[dd["success"] == True]["latency_ms"].idxmax() + fastest_idx = dd[dd["completed"] == True]["latency_ms"].idxmin() + slowest_idx = dd[dd["completed"] == True]["latency_ms"].idxmax() fastest_resolver = dd.loc[fastest_idx, "resolver_name"] slowest_resolver = dd.loc[slowest_idx, "resolver_name"] else: @@ -233,9 +241,9 @@ def get_record_type_statistics(self) -> List[Dict[str, Any]]: for rt in self.df["record_type"].unique(): rt_df = self.df[self.df["record_type"] == rt] total = len(rt_df) - success = len(rt_df[rt_df["success"] == True]) + success = len(rt_df[rt_df["completed"] == True]) rate = (success / total) * 100 if total > 0 else 0.0 - latencies = rt_df[rt_df["success"] == True]["latency_ms"] + latencies = rt_df[rt_df["completed"] == True]["latency_ms"] rt_stats.append( { "record_type": rt, @@ -261,9 +269,9 @@ def get_protocol_statistics(self) -> List[Dict[str, Any]]: for proto in self.df["protocol"].unique(): proto_df = self.df[self.df["protocol"] == proto] total = len(proto_df) - success = int(proto_df["success"].sum()) + success = int(proto_df["completed"].sum()) rate = (success / total) * 100 if total > 0 else 0.0 - latencies = proto_df[proto_df["success"] == True]["latency_ms"] + latencies = proto_df[proto_df["completed"] == True]["latency_ms"] dnssec_validated = int(proto_df["dnssec_validated"].sum()) proto_stats.append( { diff --git a/src/dns_benchmark/cli.py b/src/dns_benchmark/cli.py index f70a9a2..adfbcc5 100644 --- a/src/dns_benchmark/cli.py +++ b/src/dns_benchmark/cli.py @@ -1,5 +1,6 @@ import asyncio import json +import math import os import time from datetime import datetime @@ -15,6 +16,7 @@ from dns_benchmark.analysis import BenchmarkAnalyzer from dns_benchmark.core import ( DNSQueryEngine, + DNSQueryResult, DomainManager, QueryProtocol, QueryStatus, @@ -489,8 +491,23 @@ def benchmark( except Exception as e: click.echo(error(f"Error loading domains: {e}")) return + # New - if dnssec_validate: + try: + protocol, doh_urls = _resolve_protocol_and_doh_urls( + doh=doh, + dot=dot, + doh_url=doh_url, + resolvers=resolver_list, + ) + except click.UsageError: + raise + + # New + # Only warn about DNSSEC-signed domains when using defaults — for custom + # domain files we have no way to know which are signed without querying, + # so stay silent to avoid noisy false-positive warnings. + if dnssec_validate and use_defaults: signed = { d["domain"] for d in DomainManager.DOMAINS_DATABASE @@ -525,15 +542,20 @@ def benchmark( click.echo(info(f"- Total queries: {total_queries}")) if use_cache: click.echo(info("- Cache enabled: queries may be reused across iterations")) + # New if protocol != QueryProtocol.PLAIN: click.echo(info(f"- Protocol: {protocol.value.upper()}")) + if dnssec_validate: click.echo( - info("- DNSSEC validation: enforced (queries fail if AD flag absent)") + info( + "- DNSSEC: enforced — DO bit set, AD flag required " + "(note: success rate reflects network success, not DNSSEC outcome)" + ) ) else: - click.echo(info("- DNSSEC: passive (AD flag collected, not enforced)")) + click.echo(info("- DNSSEC: off (DO bit not set, AD flag not collected)")) # Show warmup message if (warmup or warmup_fast) and not quiet: @@ -549,6 +571,7 @@ def benchmark( feedback_manager.increment_run() start_time = time.time() + # New try: engine = DNSQueryEngine( @@ -556,7 +579,11 @@ def benchmark( timeout=timeout, max_retries=retries, enable_cache=use_cache, - enable_dnssec=True, # always collect AD flag - always True + # DO bit is only set when --dnssec-validate is passed. + # enable_dnssec=True sets the DO bit (requests RRSIG records). + # enforce_dnssec=True fails queries where the AD flag is absent. + # Both are off by default to avoid latency overhead on normal benchmarks. + enable_dnssec=dnssec_validate, enforce_dnssec=dnssec_validate, ) @@ -579,9 +606,10 @@ def _progress_cb(completed: int, total: int) -> None: pass engine.set_progress_callback(_progress_cb) - # New - results = asyncio.run( - engine.run_benchmark( + + # Single coroutine to avoid closed event loop from two asyncio.run calls + async def _run() -> List[DNSQueryResult]: + results = await engine.run_benchmark( resolvers=resolver_list, domains=domain_list, record_types=record_type_list, @@ -592,7 +620,10 @@ def _progress_cb(completed: int, total: int) -> None: protocol=protocol, doh_urls=doh_urls, ) - ) + await engine.close() + return results + + results = asyncio.run(_run()) if progress_bar: progress_bar.close() @@ -616,7 +647,7 @@ def _progress_cb(completed: int, total: int) -> None: f"Fastest resolver: {overall_stats['fastest_resolver']}", f"Slowest resolver: {overall_stats['slowest_resolver']}", f"Protocol: {protocol.value.upper()}", - f"DNSSEC validated: {sum(1 for r in results if r.dnssec_validated)} / {len(results)} queries", + f"DNSSEC AD validated: {sum(1 for r in results if r.dnssec_validated)} / {len(results)} queries", ] # Add iteration info if multiple iterations if iterations > 1: @@ -697,8 +728,11 @@ def _progress_cb(completed: int, total: int) -> None: ) if export_progress: export_progress.update(1) - except RuntimeError as e: - click.echo(error(f"Error during benchmark: {e}")) + except Exception as e: + # PDF export is non-fatal — warn and keep progress consistent + click.echo(error(f"PDF export failed: {e}")) + if export_progress: + export_progress.update(1) # JSON export now tracked in progress if json_output: @@ -721,7 +755,8 @@ def _progress_cb(completed: int, total: int) -> None: finally: if export_progress: export_progress.close() - + except click.UsageError: + raise except KeyboardInterrupt: click.echo(warning("\nBenchmark interrupted by user")) # Still show feedback prompt since benchmark was started @@ -743,7 +778,19 @@ def _progress_cb(completed: int, total: int) -> None: # ====================== Top Resolvers Command @cli.command() -# -------- +@click.option("--doh", is_flag=True, default=False, help="Use DNS-over-HTTPS.") +@click.option("--dot", is_flag=True, default=False, help="Use DNS-over-TLS.") +@click.option( + "--doh-url", + default=None, + help="Comma-separated DoH URLs, one per resolver (required if resolver not in db).", +) +@click.option( + "--dnssec-validate", + is_flag=True, + default=False, + help="Fail queries where DNSSEC AD flag is not set.", +) @click.option("--limit", "-n", default=10, help="Number of top resolvers to display") @click.option( "--metric", @@ -771,6 +818,10 @@ def _progress_cb(completed: int, total: int) -> None: ) @click.option("--quiet", is_flag=True, help="Suppress progress output") def top( + doh: bool, + dot: bool, + doh_url: Optional[str], + dnssec_validate: bool, limit: int, metric: str, domains: Optional[str], @@ -808,16 +859,15 @@ def top( ) ) else: - # Use all available resolvers for comprehensive ranking all_resolvers = ResolverManager.get_all_resolvers() resolver_list = [{"name": r["name"], "ip": r["ip"]} for r in all_resolvers] if not quiet: click.echo(success(f"Testing {len(resolver_list)} resolvers")) - # Get domains + # Get domains — supports both file path and inline comma-separated list if domains: try: - domain_list = DomainManager.load_domains_from_file(domains) + domain_list = DomainManager.parse_domains_input(domains) except Exception as e: click.echo(error(f"Error loading domains: {e}")) return @@ -827,10 +877,31 @@ def top( # Parse record types record_type_list = [rt.strip().upper() for rt in record_types.split(",")] - # Run benchmark + # Resolve protocol and DoH URLs early — fail fast before any queries + try: + protocol, doh_urls = _resolve_protocol_and_doh_urls( + doh=doh, + dot=dot, + doh_url=doh_url, + resolvers=resolver_list, + ) + except click.UsageError: + raise + total_queries = len(resolver_list) * len(domain_list) * len(record_type_list) if not quiet: click.echo(info(f"Running {total_queries} queries...")) + if protocol != QueryProtocol.PLAIN: + click.echo(info(f"Protocol: {protocol.value.upper()}")) + if dnssec_validate: + click.echo( + info( + "DNSSEC: enforced — DO bit set, AD flag required " + "(note: success rate reflects network success, not DNSSEC outcome)" + ) + ) + else: + click.echo(info("DNSSEC: off (DO bit not set, AD flag not collected)")) progress_bar = None if not quiet: @@ -841,6 +912,8 @@ def top( max_concurrent_queries=max_concurrent, timeout=timeout, enable_cache=False, + enable_dnssec=dnssec_validate, + enforce_dnssec=dnssec_validate, ) if progress_bar: @@ -855,28 +928,31 @@ def _progress_cb(completed: int, total: int) -> None: engine.set_progress_callback(_progress_cb) - results = asyncio.run( - engine.run_benchmark( + # Single coroutine to avoid closed event loop from two asyncio.run calls + async def _run() -> List[DNSQueryResult]: + results = await engine.run_benchmark( resolvers=resolver_list, domains=domain_list, record_types=record_type_list, warmup_fast=True, + protocol=protocol, + doh_urls=doh_urls, ) - ) + await engine.close() + return results + + results = asyncio.run(_run()) if progress_bar: progress_bar.close() duration = time.time() - start_time - if not quiet: click.echo(success(f"Benchmark completed in {duration:.2f} seconds")) # Analyze and rank analyzer = BenchmarkAnalyzer(results) resolver_stats_list = analyzer.get_resolver_statistics() - - # Convert list of ResolverStats objects to dict for easier lookup resolver_stats = {stats.resolver_name: stats for stats in resolver_stats_list} # Calculate ranking score based on metric @@ -886,13 +962,10 @@ def _progress_cb(completed: int, total: int) -> None: if stats.successful_queries > 0 and stats.avg_latency is not None: score = -stats.avg_latency else: - score = float("-inf") # push failed resolvers to bottom - + score = float("-inf") elif metric == "success": - # Higher is better score = stats.success_rate else: # reliability (combined) - # Weighted score: 60% success rate, 40% speed (normalized) if stats.successful_queries > 0 and stats.avg_latency not in (None, 0): latency_score = max(0, 100 - (stats.avg_latency / 5)) score = (stats.success_rate * 0.6) + (latency_score * 0.4) @@ -901,10 +974,7 @@ def _progress_cb(completed: int, total: int) -> None: scored_resolvers.append((name, stats, score)) - # Sort by score (descending) scored_resolvers.sort(key=lambda x: x[2], reverse=True) - - # Display top N top_resolvers = scored_resolvers[:limit] if not quiet: @@ -918,7 +988,6 @@ def _progress_cb(completed: int, total: int) -> None: if rank == 1 else "🥈" if rank == 2 else "🥉" if rank == 3 else f"{rank}." ) - click.echo(Fore.CYAN + f"{medal} {name}" + Style.RESET_ALL) latency_str = ( f"{stats.avg_latency:.2f} ms" @@ -934,7 +1003,6 @@ def _progress_cb(completed: int, total: int) -> None: click.echo( f" Queries: {stats.successful_queries}/{stats.total_queries}" ) - if metric == "reliability": click.echo( f" Reliability Score: {Fore.YELLOW}{score:.2f}/100{Style.RESET_ALL}" @@ -951,6 +1019,7 @@ def _progress_cb(completed: int, total: int) -> None: "timestamp": datetime.now().isoformat(), "metric": metric, "category": category, + "protocol": protocol.value, "top_resolvers": [ { "rank": i + 1, @@ -959,7 +1028,6 @@ def _progress_cb(completed: int, total: int) -> None: "success_rate": stats.success_rate, "successful_queries": stats.successful_queries, "total_queries": stats.total_queries, - # "score": score, } for i, (name, stats, score) in enumerate(top_resolvers) ], @@ -980,7 +1048,6 @@ def _progress_cb(completed: int, total: int) -> None: "Success Rate (%)", "Successful", "Total", - # "Score" ] ) for i, (name, stats, score) in enumerate(top_resolvers, 1): @@ -996,7 +1063,6 @@ def _progress_cb(completed: int, total: int) -> None: f"{stats.success_rate:.1f}", stats.successful_queries, stats.total_queries, - # f"{score:.2f}" ] ) @@ -1009,7 +1075,6 @@ def _progress_cb(completed: int, total: int) -> None: if category: f.write(f"Category: {category}\n") f.write("\n" + "=" * 60 + "\n\n") - for rank, (name, stats, score) in enumerate(top_resolvers, 1): f.write(f"{rank}. {name}\n") f.write( @@ -1021,10 +1086,8 @@ def _progress_cb(completed: int, total: int) -> None: f.write( f" Queries: {stats.successful_queries}/{stats.total_queries}\n" ) - # if metric == "reliability": - # f.write(f" Score: {score:.2f}/100\n") f.write("\n") - # Add summary note if any resolvers had no successful queries + failed_resolvers = [ s for s in resolver_stats.values() if s.successful_queries == 0 ] @@ -1034,10 +1097,11 @@ def _progress_cb(completed: int, total: int) -> None: "⚠️ Some resolvers returned no successful queries and were excluded from ranking" ) ) - if not quiet: click.echo(success(f"Results saved to: {output_path}")) + except click.UsageError: + raise except KeyboardInterrupt: if progress_bar: progress_bar.close() @@ -1051,7 +1115,19 @@ def _progress_cb(completed: int, total: int) -> None: # ======================= Compare @cli.command() -# ---------- +@click.option("--doh", is_flag=True, default=False, help="Use DNS-over-HTTPS.") +@click.option("--dot", is_flag=True, default=False, help="Use DNS-over-TLS.") +@click.option( + "--doh-url", + default=None, + help="Comma-separated DoH URLs, one per resolver (required if resolver not in db).", +) +@click.option( + "--dnssec-validate", + is_flag=True, + default=False, + help="Fail queries where DNSSEC AD flag is not set.", +) @click.argument("resolvers", nargs=-1, required=True) @click.option("--domains", "-d", help="Text file with domain list") @click.option( @@ -1067,6 +1143,10 @@ def _progress_cb(completed: int, total: int) -> None: @click.option("--quiet", is_flag=True, help="Suppress progress output") @click.option("--show-details", is_flag=True, help="Show detailed per-domain breakdown") def compare( + doh: bool, + dot: bool, + doh_url: Optional[str], + dnssec_validate: bool, resolvers: Tuple[str], domains: Optional[str], record_types: str, @@ -1094,17 +1174,13 @@ def compare( resolver_list = [] for resolver_input in resolvers: - # Try to match by name first matched = False for r in all_resolvers: if r["name"].lower() == resolver_input.lower(): resolver_list.append({"name": r["name"], "ip": r["ip"]}) matched = True break - - # If no name match, assume it's an IP if not matched: - # Validate IP format (basic check) if "." in resolver_input or ":" in resolver_input: resolver_list.append({"name": resolver_input, "ip": resolver_input}) else: @@ -1119,10 +1195,10 @@ def compare( success(f"Comparing: {', '.join([r['name'] for r in resolver_list])}") ) - # Get domains + # Get domains — supports both file path and inline comma-separated list if domains: try: - domain_list = DomainManager.load_domains_from_file(domains) + domain_list = DomainManager.parse_domains_input(domains) except Exception as e: click.echo(error(f"Error loading domains: {e}")) return @@ -1132,7 +1208,17 @@ def compare( # Parse record types record_type_list = [rt.strip().upper() for rt in record_types.split(",")] - # Run benchmark + # Resolve protocol and DoH URLs early — fail fast before any queries + try: + protocol, doh_urls = _resolve_protocol_and_doh_urls( + doh=doh, + dot=dot, + doh_url=doh_url, + resolvers=resolver_list, + ) + except click.UsageError: + raise + total_queries = ( len(resolver_list) * len(domain_list) * len(record_type_list) * iterations ) @@ -1140,6 +1226,17 @@ def compare( click.echo( info(f"Running {total_queries} queries across {iterations} iterations...") ) + if protocol != QueryProtocol.PLAIN: + click.echo(info(f"Protocol: {protocol.value.upper()}")) + if dnssec_validate: + click.echo( + info( + "DNSSEC: enforced — DO bit set, AD flag required " + "(note: success rate reflects network success, not DNSSEC outcome)" + ) + ) + else: + click.echo(info("DNSSEC: off (DO bit not set, AD flag not collected)")) progress_bar = None if not quiet: @@ -1150,6 +1247,8 @@ def compare( max_concurrent_queries=max_concurrent, timeout=timeout, enable_cache=False, + enable_dnssec=dnssec_validate, + enforce_dnssec=dnssec_validate, ) if progress_bar: @@ -1164,15 +1263,21 @@ def _progress_cb(completed: int, total: int) -> None: engine.set_progress_callback(_progress_cb) - results = asyncio.run( - engine.run_benchmark( + # Single coroutine to avoid closed event loop from two asyncio.run calls + async def _run() -> List[DNSQueryResult]: + results = await engine.run_benchmark( resolvers=resolver_list, domains=domain_list, record_types=record_type_list, iterations=iterations, warmup_fast=True, + protocol=protocol, + doh_urls=doh_urls, ) - ) + await engine.close() + return results + + results = asyncio.run(_run()) if progress_bar: progress_bar.close() @@ -1180,15 +1285,11 @@ def _progress_cb(completed: int, total: int) -> None: # Analyze analyzer = BenchmarkAnalyzer(results) resolver_stats_list = analyzer.get_resolver_statistics() - - # Convert list of ResolverStats objects to dict for easier lookup resolver_stats = {stats.resolver_name: stats for stats in resolver_stats_list} - # Display comparison if not quiet: click.echo(success("📊 Comparison Results:\n")) - # Header click.echo( Fore.CYAN + f"{'Resolver':<20} {'Avg Latency':<15} {'Success Rate':<15} {'Queries':<10}" @@ -1196,9 +1297,14 @@ def _progress_cb(completed: int, total: int) -> None: ) click.echo("-" * 65) - # Sort by latency for display + # Guard against nan avg_latency from resolvers with zero successes sorted_stats = sorted( - resolver_stats.items(), key=lambda x: x[1].avg_latency + resolver_stats.items(), + key=lambda x: ( + x[1].avg_latency + if x[1].avg_latency is not None and not math.isnan(x[1].avg_latency) + else float("inf") + ), ) for name, stats in sorted_stats: @@ -1212,7 +1318,6 @@ def _progress_cb(completed: int, total: int) -> None: if stats.success_rate >= 95 else Fore.YELLOW if stats.success_rate >= 80 else Fore.RED ) - click.echo( f"{name:<20} " f"{latency_color}{stats.avg_latency:>6.2f} ms{Style.RESET_ALL}{'':>4} " @@ -1220,7 +1325,6 @@ def _progress_cb(completed: int, total: int) -> None: f"{stats.successful_queries}/{stats.total_queries}" ) - # Winner click.echo() fastest = min(sorted_stats, key=lambda x: x[1].avg_latency) most_reliable = max(sorted_stats, key=lambda x: x[1].success_rate) @@ -1238,17 +1342,15 @@ def _progress_cb(completed: int, total: int) -> None: + f"{most_reliable[0]} ({most_reliable[1].success_rate:.1f}%)" ) - # Per-domain details if requested if show_details: click.echo(success("📋 Per-Domain Breakdown:\n")) domain_stats = analyzer.get_domain_statistics() - for dom_stat in domain_stats[:10]: # Limit to first 10 + for dom_stat in domain_stats[:10]: domain = dom_stat["domain"] click.echo(Fore.CYAN + f"\n{domain}:" + Style.RESET_ALL) for name in [r["name"] for r in resolver_list]: - # Find results for this resolver+domain domain_results = [ r for r in results @@ -1273,11 +1375,10 @@ def _progress_cb(completed: int, total: int) -> None: ext = output_path.suffix.lower() if ext == ".json": - import json - export_data = { "timestamp": datetime.now().isoformat(), "iterations": iterations, + "protocol": protocol.value, "comparison": [ { "resolver": name, @@ -1292,13 +1393,14 @@ def _progress_cb(completed: int, total: int) -> None: } with open(output_path, "w") as f: json.dump(export_data, f, indent=2) - - else: # csv + else: CSVExporter.export_summary_statistics(analyzer, str(output_path)) if not quiet: click.echo(success(f"Comparison saved to: {output_path}")) + except click.UsageError: + raise except KeyboardInterrupt: if progress_bar: progress_bar.close() @@ -1312,7 +1414,19 @@ def _progress_cb(completed: int, total: int) -> None: # ==================== Monitoring Command @cli.command() -# --------- +@click.option("--doh", is_flag=True, default=False, help="Use DNS-over-HTTPS.") +@click.option("--dot", is_flag=True, default=False, help="Use DNS-over-TLS.") +@click.option( + "--doh-url", + default=None, + help="Comma-separated DoH URLs, one per resolver (required if resolver not in db).", +) +@click.option( + "--dnssec-validate", + is_flag=True, + default=False, + help="Fail queries where DNSSEC AD flag is not set.", +) @click.option("--resolvers", "-r", help="JSON file with resolver list") @click.option("--domains", "-d", help="Text file with domain list") @click.option( @@ -1341,6 +1455,10 @@ def _progress_cb(completed: int, total: int) -> None: "--use-defaults", is_flag=True, help="Use default resolvers and sample domains" ) def monitoring( + doh: bool, + dot: bool, + doh_url: Optional[str], + dnssec_validate: bool, resolvers: Optional[str], domains: Optional[str], interval: int, @@ -1368,7 +1486,7 @@ def monitoring( click.echo(success(f"Monitoring {len(resolver_list)} default resolvers")) elif resolvers: try: - resolver_list = ResolverManager.load_resolvers_from_file(resolvers) + resolver_list = ResolverManager.parse_resolvers_input(resolvers) click.echo(success(f"Monitoring {len(resolver_list)} resolvers")) except Exception as e: click.echo(error(f"Error loading resolvers: {e}")) @@ -1379,25 +1497,40 @@ def monitoring( # Load domains if use_defaults: - # Use a smaller set for monitoring domain_list = DomainManager.get_sample_domains()[:5] elif domains: try: - domain_list = DomainManager.load_domains_from_file(domains) + domain_list = DomainManager.parse_domains_input(domains) except Exception as e: click.echo(error(f"Error loading domains: {e}")) return else: domain_list = DomainManager.get_sample_domains()[:5] + # Resolve protocol and DoH URLs early — fail fast before monitoring starts + try: + protocol, doh_urls = _resolve_protocol_and_doh_urls( + doh=doh, + dot=dot, + doh_url=doh_url, + resolvers=resolver_list, + ) + except click.UsageError: + raise + click.echo(info(f"Testing against {len(domain_list)} domains")) click.echo(info(f"Check interval: {interval}s")) if duration > 0: click.echo(info(f"Duration: {duration}s")) + if protocol != QueryProtocol.PLAIN: + click.echo(info(f"Protocol: {protocol.value.upper()}")) + if dnssec_validate: + click.echo(info("DNSSEC: enforced (AD flag required)")) + else: + click.echo(info("DNSSEC: off")) click.echo(info(f"Latency alert threshold: {alert_latency} ms")) click.echo(info(f"Failure rate alert threshold: {alert_failure_rate}%\n")) - # Setup output log log_file = None if output: log_file = open(output, "a") @@ -1410,71 +1543,77 @@ def monitoring( start_time = time.time() iteration = 0 + # Engine is created once outside the loop so DoT/DoH connections are + # reused across check intervals — avoids repeated TLS handshakes every check + engine = DNSQueryEngine( + max_concurrent_queries=50, + timeout=5.0, + enable_cache=False, + enable_dnssec=dnssec_validate, + enforce_dnssec=dnssec_validate, + ) + try: while True: iteration += 1 check_time = datetime.now().strftime("%H:%M:%S") - click.echo( Fore.CYAN + f"[{check_time}] Check #{iteration}" + Style.RESET_ALL ) - # Run quick benchmark - engine = DNSQueryEngine( - max_concurrent_queries=50, - timeout=5.0, - enable_cache=False, - ) - - results = asyncio.run( - engine.run_benchmark( + async def _run() -> List[DNSQueryResult]: + results = await engine.run_benchmark( resolvers=resolver_list, domains=domain_list, record_types=["A"], warmup=False, + protocol=protocol, + doh_urls=doh_urls, ) - ) + # Do NOT close engine here — it is reused on next interval + return results + + results = asyncio.run(_run()) - # Analyze analyzer = BenchmarkAnalyzer(results) resolver_stats_list = analyzer.get_resolver_statistics() - # Check for alerts alerts = [] for stats in resolver_stats_list: - if stats.avg_latency > alert_latency: + if stats.avg_latency and stats.avg_latency > alert_latency: alerts.append( f"⚠️ {stats.resolver_name}: High latency ({stats.avg_latency:.2f} ms)" ) - failure_rate = 100 - stats.success_rate if failure_rate > alert_failure_rate: alerts.append( f"⚠️ {stats.resolver_name}: High failure rate ({failure_rate:.1f}%)" ) - # Display results for stats in resolver_stats_list: latency_indicator = ( "🟢" - if stats.avg_latency < 50 - else "🟡" if stats.avg_latency < 100 else "🔴" + if stats.avg_latency and stats.avg_latency < 50 + else "🟡" if stats.avg_latency and stats.avg_latency < 100 else "🔴" ) success_indicator = ( "🟢" if stats.success_rate >= 95 else "🟡" if stats.success_rate >= 80 else "🔴" ) - - status_line = f" {stats.resolver_name:<20} {latency_indicator} {stats.avg_latency:>6.2f} ms {success_indicator} {stats.success_rate:>5.1f}%" + status_line = ( + f" {stats.resolver_name:<20} " + f"{latency_indicator} {stats.avg_latency:>6.2f} ms " + f"{success_indicator} {stats.success_rate:>5.1f}%" + ) click.echo(status_line) if log_file: log_file.write( - f"[{check_time}] {stats.resolver_name}: {stats.avg_latency:.2f} ms, {stats.success_rate:.1f}% success\n" + f"[{check_time}] {stats.resolver_name}: " + f"{stats.avg_latency:.2f} ms, {stats.success_rate:.1f}% success\n" ) - # Display alerts if alerts: click.echo() for alert in alerts: @@ -1487,12 +1626,10 @@ def monitoring( if log_file: log_file.flush() - # Check duration if duration > 0 and (time.time() - start_time) >= duration: click.echo(success("Monitoring duration completed")) break - # Wait for next interval time.sleep(interval) except KeyboardInterrupt: @@ -1501,6 +1638,13 @@ def monitoring( click.echo(error(f"Error during monitoring: {e}")) raise finally: + # Use a fresh event loop for cleanup since the previous one may be closed + try: + loop = asyncio.new_event_loop() + loop.run_until_complete(engine.close()) + loop.close() + except Exception: + pass # best-effort cleanup — don't crash on exit if log_file: log_file.write( f"Monitoring ended: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n" diff --git a/src/dns_benchmark/core.py b/src/dns_benchmark/core.py index 01d7a4d..5b7de22 100644 --- a/src/dns_benchmark/core.py +++ b/src/dns_benchmark/core.py @@ -11,7 +11,7 @@ from dataclasses import asdict, dataclass, field from enum import Enum from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, cast +from typing import Any, Callable, Dict, List, Optional, Tuple, cast import click import dns.asyncresolver @@ -19,13 +19,11 @@ import dns.flags import dns.message import dns.name - -# import dns.query import dns.rdatatype import httpx import idna -from dns_benchmark.utils.messages import warning +from dns_benchmark.utils.messages import error, warning class QueryStatus(Enum): @@ -36,10 +34,11 @@ class QueryStatus(Enum): CONNECTION_REFUSED = "connection_refused" UNKNOWN_ERROR = "unknown_error" DNSSEC_FAILED = "dnssec_failed" + TLS_ERROR = "tls_error" class QueryProtocol(Enum): - PLAIN = "plain" + PLAIN = "plain" # traditional DNS over UDP/TCP, dnspython will handle protocol selection and fallback DOH = "doh" DOT = "dot" @@ -104,12 +103,15 @@ def __init__( self.enable_dnssec = enable_dnssec self.enforce_dnssec = enforce_dnssec - async def _ensure_async_primitives(self) -> None: - """Create asyncio primitives when running inside an event loop.""" - if self.semaphore is None: - self.semaphore = asyncio.Semaphore(self.max_concurrent_queries) - if self._lock is None: - self._lock = asyncio.Lock() + # Shared DoH clients and DoT connections, one per resolver IP. + # Reusing these avoids repeated TLS handshakes — biggest latency win + # for encrypted protocols. Cleaned up via engine.close(). + # NOT thread-safe — safe only because asyncio is single-threaded. + # Do not access from threads without adding locks. + self._doh_clients: Dict[str, httpx.AsyncClient] = {} + self._dot_connections: Dict[ + str, Tuple[asyncio.StreamReader, asyncio.StreamWriter] + ] = {} def set_progress_callback(self, callback: Callable[[int, int], None]) -> None: """Set callback for progress updates with completed/total counts.""" @@ -119,6 +121,20 @@ def _get_cache_key(self, resolver_ip: str, domain: str, record_type: str) -> str """Generate cache key for query.""" return f"{resolver_ip}:{domain}:{record_type}" + def _validate_resolver(self, resolver: Dict[str, str]) -> None: + """Validate resolver configuration.""" + if "ip" not in resolver: + raise ValueError(f"Resolver missing 'ip' key: {resolver}") + if "name" not in resolver: + raise ValueError(f"Resolver missing 'name' key: {resolver}") + + async def _ensure_async_primitives(self) -> None: + """Create asyncio primitives when running inside an event loop.""" + if self.semaphore is None: + self.semaphore = asyncio.Semaphore(self.max_concurrent_queries) + if self._lock is None: + self._lock = asyncio.Lock() + async def _update_progress(self) -> None: """Thread-safe progress update.""" await self._ensure_async_primitives() @@ -128,12 +144,63 @@ async def _update_progress(self) -> None: if self.progress_callback: self.progress_callback(self.query_counter, self.total_queries) - def _validate_resolver(self, resolver: Dict[str, str]) -> None: - """Validate resolver configuration.""" - if "ip" not in resolver: - raise ValueError(f"Resolver missing 'ip' key: {resolver}") - if "name" not in resolver: - raise ValueError(f"Resolver missing 'name' key: {resolver}") + async def _get_doh_client(self, resolver_ip: str) -> httpx.AsyncClient: + """Return cached AsyncClient for this resolver, creating if needed.""" + if resolver_ip not in self._doh_clients: + self._doh_clients[resolver_ip] = httpx.AsyncClient( + http2=True, + timeout=self.timeout, + verify=True, + ) + return self._doh_clients[resolver_ip] + + async def _get_dot_connection( + self, + resolver_ip: str, + port: int = 853, + ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: + """Return a cached DoT connection for this resolver, creating if needed. + + If the cached connection is dead (writer closing), it is evicted and + a fresh connection is opened. This avoids repeated TLS handshakes + across queries to the same resolver. + """ + existing = self._dot_connections.get(resolver_ip) + if existing: + reader, writer = existing + if not writer.is_closing(): + return reader, writer + # Dead connection — evict and fall through to reconnect + del self._dot_connections[resolver_ip] + + ssl_ctx = ssl.create_default_context() + ssl_ctx.verify_mode = ssl.CERT_REQUIRED + ssl_ctx.check_hostname = True + + reader, writer = await asyncio.wait_for( + asyncio.open_connection(resolver_ip, port, ssl=ssl_ctx), + timeout=self.timeout, + ) + self._dot_connections[resolver_ip] = (reader, writer) + return reader, writer + + async def close(self) -> None: + """Close all shared DoH clients and DoT connections. + + Must be awaited after run_benchmark completes — especially important + in FastAPI where connections are reused across requests. + """ + for client in self._doh_clients.values(): + await client.aclose() + self._doh_clients.clear() + + for _reader, writer in self._dot_connections.values(): + writer.close() + try: + await writer.wait_closed() + except Exception: + pass + self._dot_connections.clear() async def query_single( self, @@ -398,35 +465,26 @@ async def query_single_doh( assert self.semaphore is not None start_time = time.time() - + client = await self._get_doh_client(resolver_ip) for attempt in range(self.max_retries + 1): try: async with self.semaphore: start_time = time.time() - - # Build DNS wire-format query qname = dns.name.from_text(domain) rdtype = dns.rdatatype.from_text(record_type) request = dns.message.make_query(qname, rdtype) if self.enable_dnssec: request.use_edns(ednsflags=dns.flags.DO) wire = request.to_wire() - - async with httpx.AsyncClient( - http2=True, - timeout=self.timeout, - verify=True, # enforce TLS — never disable - ) as client: - response_raw = await client.post( - doh_url, - content=wire, - headers={ - "Content-Type": "application/dns-message", - "Accept": "application/dns-message", - }, - ) - response_raw.raise_for_status() - + response_raw = await client.post( + doh_url, + content=wire, + headers={ + "Content-Type": "application/dns-message", + "Accept": "application/dns-message", + }, + ) + response_raw.raise_for_status() end_time = time.time() latency_ms = (end_time - start_time) * 1000 @@ -472,8 +530,8 @@ async def query_single_doh( domain=domain, record_type=record_type, start_time=start_time, - end_time=time.time(), - latency_ms=(time.time() - start_time) * 1000, + end_time=end_time, + latency_ms=(end_time - start_time) * 1000, status=QueryStatus.TIMEOUT, answers=[], ttl=None, @@ -489,6 +547,34 @@ async def query_single_doh( self.retry_backoff_base**attempt * self.retry_backoff_multiplier ) + except httpx.HTTPStatusError as e: + if attempt == self.max_retries: + end_time = time.time() + async with self._lock: # type: ignore[union-attr] + self.failed_resolvers[resolver_ip] += 1 + result = DNSQueryResult( + resolver_ip=resolver_ip, + resolver_name=resolver_name, + domain=domain, + record_type=record_type, + start_time=start_time, + end_time=end_time, + latency_ms=(end_time - start_time) * 1000, + status=QueryStatus.SERVFAIL, + answers=[], + ttl=None, + error_message=f"HTTP {e.response.status_code}", + attempt_number=attempt + 1, + cache_hit=False, + iteration=iteration, + protocol=QueryProtocol.DOH, + ) + await self._update_progress() + return result + await asyncio.sleep( + self.retry_backoff_base**attempt * self.retry_backoff_multiplier + ) + except Exception as e: if attempt == self.max_retries: end_time = time.time() @@ -544,8 +630,12 @@ async def query_single_dot( port: int = 853, iteration: int = 1, ) -> DNSQueryResult: - """Execute a single DNS-over-TLS query.""" + """Execute a single DNS-over-TLS query. + Reuses a pooled TLS connection per resolver to avoid handshake overhead + on every query. Connection is evicted from the pool on any error so the + next query gets a fresh connection. + """ await self._ensure_async_primitives() assert self.semaphore is not None @@ -556,43 +646,30 @@ async def query_single_dot( async with self.semaphore: start_time = time.time() - # Build wire-format query with length prefix (RFC 7858) qname = dns.name.from_text(domain) rdtype = dns.rdatatype.from_text(record_type) request = dns.message.make_query(qname, rdtype) if self.enable_dnssec: request.use_edns(ednsflags=dns.flags.DO) wire = request.to_wire() - # 2-byte length prefix required by DoT spec + # 2-byte length prefix required by RFC 7858 prefixed = struct.pack("!H", len(wire)) + wire - ssl_ctx = ssl.create_default_context() - # enforce cert validation — never bypass for security tool - ssl_ctx.verify_mode = ssl.CERT_REQUIRED - ssl_ctx.check_hostname = True + # Reuse pooled connection — no TLS handshake if already open + reader, writer = await self._get_dot_connection(resolver_ip, port) - reader, writer = await asyncio.wait_for( - asyncio.open_connection(resolver_ip, port, ssl=ssl_ctx), - timeout=self.timeout, - ) - try: - writer.write(prefixed) - await writer.drain() + writer.write(prefixed) + await writer.drain() - # Read 2-byte length prefix then full message - raw_len = await asyncio.wait_for( - reader.readexactly(2), timeout=self.timeout - ) - msg_len = struct.unpack("!H", raw_len)[0] - raw_msg = await asyncio.wait_for( - reader.readexactly(msg_len), timeout=self.timeout - ) - finally: - writer.close() - try: - await writer.wait_closed() - except Exception: - pass + # Read 2-byte length prefix then full message + raw_len = await asyncio.wait_for( + reader.readexactly(2), timeout=self.timeout + ) + msg_len = struct.unpack("!H", raw_len)[0] + raw_msg = await asyncio.wait_for( + reader.readexactly(msg_len), timeout=self.timeout + ) + # Do NOT close writer — connection is pooled and reused end_time = time.time() latency_ms = (end_time - start_time) * 1000 @@ -629,7 +706,10 @@ async def query_single_dot( return result except asyncio.TimeoutError: + # Evict connection — may be in a bad state after timeout + self._dot_connections.pop(resolver_ip, None) if attempt == self.max_retries: + end_time = time.time() async with self._lock: # type: ignore[union-attr] self.failed_resolvers[resolver_ip] += 1 result = DNSQueryResult( @@ -638,8 +718,8 @@ async def query_single_dot( domain=domain, record_type=record_type, start_time=start_time, - end_time=time.time(), - latency_ms=(time.time() - start_time) * 1000, + end_time=end_time, + latency_ms=(end_time - start_time) * 1000, status=QueryStatus.TIMEOUT, answers=[], ttl=None, @@ -656,7 +736,9 @@ async def query_single_dot( ) except ssl.SSLError as e: - # SSL errors are not retryable + # SSL errors are not retryable — evict and return immediately + self._dot_connections.pop(resolver_ip, None) + end_time = time.time() async with self._lock: # type: ignore[union-attr] self.failed_resolvers[resolver_ip] += 1 result = DNSQueryResult( @@ -665,9 +747,9 @@ async def query_single_dot( domain=domain, record_type=record_type, start_time=start_time, - end_time=time.time(), - latency_ms=(time.time() - start_time) * 1000, - status=QueryStatus.CONNECTION_REFUSED, + end_time=end_time, + latency_ms=(end_time - start_time) * 1000, + status=QueryStatus.TLS_ERROR, answers=[], ttl=None, error_message=f"TLS error: {e}", @@ -680,7 +762,10 @@ async def query_single_dot( return result except Exception as e: + # Evict connection on any unknown error before retrying + self._dot_connections.pop(resolver_ip, None) if attempt == self.max_retries: + end_time = time.time() async with self._lock: # type: ignore[union-attr] self.failed_resolvers[resolver_ip] += 1 result = DNSQueryResult( @@ -689,8 +774,8 @@ async def query_single_dot( domain=domain, record_type=record_type, start_time=start_time, - end_time=time.time(), - latency_ms=(time.time() - start_time) * 1000, + end_time=end_time, + latency_ms=(end_time - start_time) * 1000, status=QueryStatus.UNKNOWN_ERROR, answers=[], ttl=None, @@ -758,11 +843,14 @@ async def run_benchmark( if not record_types: record_types = ["A"] - # Warmup handling (warmup_fast takes precedence) + # Warmup uses same protocol as benchmark so connection overhead is + # representative. warmup_fast takes precedence over warmup. if warmup_fast: - warmup_results = await self._run_fast_warmup(resolvers) + warmup_results = await self._run_fast_warmup(resolvers, protocol, doh_urls) elif warmup: - warmup_results = await self._run_warmup(resolvers, domains, record_types) + warmup_results = await self._run_warmup( + resolvers, domains, record_types, protocol, doh_urls + ) else: warmup_results = [] @@ -775,7 +863,7 @@ async def run_benchmark( ) ) - # Reset counters for actual benchmark + # Reset counters after warmup so progress tracks benchmark queries only self.query_counter = 0 self.total_queries = ( len(resolvers) * len(domains) * len(record_types) * iterations @@ -788,6 +876,12 @@ async def run_benchmark( for record_type in record_types: if protocol == QueryProtocol.DOH: url = (doh_urls or {}).get(resolver["ip"], "") + if not url: + click.echo( + error( + f"No DoH URL configured for resolver {resolver['ip']} ({resolver['name']})" + ) + ) task = self.query_single_doh( resolver_ip=resolver["ip"], resolver_name=resolver["name"], @@ -823,6 +917,8 @@ async def _run_warmup( resolvers: List[Dict[str, str]], domains: List[str], record_types: List[str], + protocol: QueryProtocol = QueryProtocol.PLAIN, + doh_urls: Optional[Dict[str, str]] = None, ) -> List[DNSQueryResult]: """Run full warmup queries (all combinations). @@ -832,39 +928,80 @@ async def _run_warmup( for resolver in resolvers: for domain in domains: for record_type in record_types: - task = self.query_single( - resolver_ip=resolver["ip"], - resolver_name=resolver["name"], - domain=domain, - record_type=record_type, - use_cache=False, - iteration=0, # Mark as warmup - ) + if protocol == QueryProtocol.DOH: + url = (doh_urls or {}).get(resolver["ip"], "") + task = self.query_single_doh( + resolver_ip=resolver["ip"], + resolver_name=resolver["name"], + domain=domain, + doh_url=url, + record_type=record_type, + iteration=0, # Mark as warmup + ) + elif protocol == QueryProtocol.DOT: + task = self.query_single_dot( + resolver_ip=resolver["ip"], + resolver_name=resolver["name"], + domain=domain, + record_type=record_type, + iteration=0, + ) + else: + task = self.query_single( + resolver_ip=resolver["ip"], + resolver_name=resolver["name"], + domain=domain, + record_type=record_type, + use_cache=False, + iteration=0, + ) tasks.append(task) return await asyncio.gather(*tasks) async def _run_fast_warmup( self, resolvers: List[Dict[str, str]], + protocol: QueryProtocol = QueryProtocol.PLAIN, + doh_urls: Optional[Dict[str, str]] = None, probe_domain: str = "example.com", record_type: str = "A", ) -> List[DNSQueryResult]: """Lightweight warmup: one query per resolver. Uses a known-good domain to verify resolver connectivity. + Respects the active protocol so warmup overhead matches benchmark overhead. Does not update progress counters or cache results. """ - tasks = [ - self.query_single( - resolver_ip=r["ip"], - resolver_name=r["name"], - domain=probe_domain, - record_type=record_type, - use_cache=False, - iteration=0, # Mark as warmup - ) - for r in resolvers - ] + tasks = [] + for r in resolvers: + if protocol == QueryProtocol.DOH: + url = (doh_urls or {}).get(r["ip"], "") + task = self.query_single_doh( + resolver_ip=r["ip"], + resolver_name=r["name"], + domain=probe_domain, + doh_url=url, + record_type=record_type, + iteration=0, # Mark as warmup + ) + elif protocol == QueryProtocol.DOT: + task = self.query_single_dot( + resolver_ip=r["ip"], + resolver_name=r["name"], + domain=probe_domain, + record_type=record_type, + iteration=0, + ) + else: + task = self.query_single( + resolver_ip=r["ip"], + resolver_name=r["name"], + domain=probe_domain, + record_type=record_type, + use_cache=False, + iteration=0, + ) + tasks.append(task) return await asyncio.gather(*tasks) def clear_cache(self) -> None: @@ -1277,7 +1414,10 @@ def get_default_resolvers() -> List[Dict[str, str]]: {"name": "Google", "ip": "8.8.8.8"}, {"name": "Quad9", "ip": "9.9.9.9"}, {"name": "OpenDNS", "ip": "208.67.222.222"}, - {"name": "Comodo", "ip": "8.26.56.26"}, + # No doh endpoints for these, so commented out for now. + # If you know valid DoH URLs, + # please open a Pull Request to add them back in. + # {"name": "Comodo", "ip": "8.26.56.26"}, ] @staticmethod diff --git a/tests/test_protocols.py b/tests/test_protocols.py index 016554a..1c3ac55 100644 --- a/tests/test_protocols.py +++ b/tests/test_protocols.py @@ -276,7 +276,7 @@ async def test_query_single_dot_tls_error(engine: DNSQueryEngine) -> None: domain="google.com", ) - assert result.status == QueryStatus.CONNECTION_REFUSED + assert result.status == QueryStatus.TLS_ERROR assert result.protocol == QueryProtocol.DOT assert "TLS error" in (result.error_message or "")