diff --git a/v1/src/core/csi_processor.py b/v1/src/core/csi_processor.py index c6e4fa92..7be7cca5 100644 --- a/v1/src/core/csi_processor.py +++ b/v1/src/core/csi_processor.py @@ -81,6 +81,8 @@ def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = No # Processing state self.csi_history = deque(maxlen=self.max_history_size) self.previous_detection_confidence = 0.0 + self.signal_quality_history = deque(maxlen=self.max_history_size) + self.signal_quality_threshold = config.get('signal_quality_threshold', 0.35) # Doppler cache: pre-computed mean phase per frame for O(1) append self._phase_cache = deque(maxlen=self.max_history_size) @@ -131,6 +133,19 @@ def preprocess_csi_data(self, csi_data: CSIData) -> CSIData: return csi_data try: + # Validate frame integrity + self._validate_csi_data(csi_data) + + # Assess signal quality metrics + quality = self.assess_signal_quality(csi_data) + self._update_signal_quality_history(quality) + + if quality['valid_ratio'] < self.signal_quality_threshold: + self.logger.warning( + "Low CSI signal quality (ratio %.2f) detected", + quality['valid_ratio'] + ) + # Remove noise from the signal cleaned_data = self._remove_noise(csi_data) @@ -140,6 +155,12 @@ def preprocess_csi_data(self, csi_data: CSIData) -> CSIData: # Normalize amplitude values normalized_data = self._normalize_amplitude(windowed_data) + # Attach quality metadata for downstream consumers + normalized_data.metadata = { + **normalized_data.metadata, + 'signal_quality': quality + } + return normalized_data except Exception as e: @@ -172,7 +193,8 @@ def extract_features(self, csi_data: CSIData) -> Optional[CSIFeatures]: # Extract Doppler and frequency features doppler_shift, power_spectral_density = self._extract_doppler_features(csi_data) - + + quality = csi_data.metadata.get('signal_quality') if csi_data.metadata else None return CSIFeatures( amplitude_mean=amplitude_mean, amplitude_variance=amplitude_variance, @@ -181,7 +203,10 @@ def extract_features(self, csi_data: CSIData) -> Optional[CSIFeatures]: doppler_shift=doppler_shift, power_spectral_density=power_spectral_density, timestamp=datetime.now(timezone.utc), - metadata={'processing_params': self.config} + metadata={ + 'processing_params': self.config, + **({'signal_quality': quality} if quality is not None else {}) + } ) except Exception as e: @@ -464,4 +489,71 @@ def _apply_temporal_smoothing(self, raw_confidence: float) -> float: (1 - self.smoothing_factor) * raw_confidence) self.previous_detection_confidence = smoothed_confidence - return smoothed_confidence \ No newline at end of file + return smoothed_confidence + + def _validate_csi_data(self, csi_data: CSIData) -> None: + """Validate CSI frame dimensions and safety.""" + if csi_data.amplitude.ndim != 2 or csi_data.phase.ndim != 2: + raise CSIProcessingError("CSI data must be 2D amplitude/phase") + if np.any(np.isnan(csi_data.amplitude)) or np.any(np.isnan(csi_data.phase)): + raise CSIProcessingError("CSI data contains NaNs") + if csi_data.num_subcarriers <= 0 or csi_data.num_antennas <= 0: + raise CSIProcessingError("Invalid CSI dimensions") + + def assess_signal_quality(self, csi_data: CSIData) -> Dict[str, Any]: + """Compute signal quality metrics used by the pipeline.""" + amplitude = csi_data.amplitude + with np.errstate(divide='ignore'): + amplitude_db = 20 * np.log10(np.clip(amplitude, 1e-12, None)) + + mean_db = float(np.nanmean(amplitude_db)) + valid_ratio = float(np.nanmean(amplitude_db > (mean_db - 10))) + correlation = float(np.nanmean(np.abs(np.corrcoef(amplitude + 1e-9)))) + temporal_stability = self._calculate_temporal_stability(amplitude) + + return { + 'snr': float(csi_data.snr), + 'mean_db': mean_db, + 'valid_ratio': valid_ratio, + 'correlation': correlation, + 'temporal_stability': temporal_stability, + 'subcarrier_count': csi_data.num_subcarriers, + 'antenna_count': csi_data.num_antennas + } + + def _calculate_temporal_stability(self, amplitude: np.ndarray) -> float: + """Score stability relative to the previous frame.""" + if not self.csi_history: + return 1.0 + last_amplitude = self.csi_history[-1].amplitude + if last_amplitude.shape != amplitude.shape: + return 0.0 + diff = np.linalg.norm(amplitude - last_amplitude) + base = np.linalg.norm(last_amplitude) + 1e-9 + ratio = diff / base + return float(max(0.0, 1.0 - min(1.0, ratio))) + + def _update_signal_quality_history(self, quality: Dict[str, Any]) -> None: + """Bookkeeping for recent signal quality summaries.""" + self.signal_quality_history.append(quality) + + def get_signal_quality_trend(self) -> Dict[str, float]: + """Return aggregate quality metrics over the recent horizon.""" + if not self.signal_quality_history: + return {} + + total = len(self.signal_quality_history) + summary = { + 'avg_snr': 0.0, + 'avg_valid_ratio': 0.0, + 'avg_correlation': 0.0, + 'avg_temporal_stability': 0.0 + } + + for entry in self.signal_quality_history: + summary['avg_snr'] += entry.get('snr', 0.0) + summary['avg_valid_ratio'] += entry.get('valid_ratio', 0.0) + summary['avg_correlation'] += entry.get('correlation', 0.0) + summary['avg_temporal_stability'] += entry.get('temporal_stability', 0.0) + + return {key: value / total for key, value in summary.items()} diff --git a/v1/src/core/phase_sanitizer.py b/v1/src/core/phase_sanitizer.py index 91482063..706df602 100644 --- a/v1/src/core/phase_sanitizer.py +++ b/v1/src/core/phase_sanitizer.py @@ -4,6 +4,7 @@ import logging from typing import Dict, Any, Optional, Tuple from datetime import datetime, timezone +from collections import deque from scipy import signal @@ -41,6 +42,9 @@ def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = No self.enable_noise_filtering = config.get('enable_noise_filtering', False) self.noise_threshold = config.get('noise_threshold', 0.05) self.phase_range = config.get('phase_range', (-np.pi, np.pi)) + self.temporal_window = config.get('temporal_window', 5) + self.noise_history_size = config.get('noise_history_size', 128) + self.noise_profile = deque(maxlen=self.noise_history_size) # Statistics tracking self._total_processed = 0 @@ -262,6 +266,41 @@ def _apply_low_pass_filter(self, phase_data: np.ndarray, threshold: float) -> np filtered_data[i, :] = signal.filtfilt(b, a, phase_data[i, :]) return filtered_data + + def _apply_temporal_consistency(self, phase_data: np.ndarray) -> np.ndarray: + """Reduce sudden frame-to-frame jumps.""" + if self.temporal_window <= 1 or phase_data.shape[1] < 2: + return phase_data + + adjusted = phase_data.copy() + diffs = np.abs(np.diff(phase_data, axis=1)) + threshold = self.outlier_threshold / 2 + + for idx in range(1, phase_data.shape[1]): + jump_mask = diffs[:, idx - 1] > threshold + if np.any(jump_mask): + adjusted[jump_mask, idx] = adjusted[jump_mask, idx - 1] + return adjusted + + def _apply_phase_noise_filter(self, phase_data: np.ndarray) -> np.ndarray: + """Apply conservative noise filtering and track profile.""" + if not self.enable_noise_filtering: + return phase_data + + filtered = signal.medfilt(phase_data, kernel_size=(1, 3)) + noise_level = float(np.nanmean(np.abs(phase_data - filtered))) + self.noise_profile.append(noise_level) + return filtered + + def get_noise_profile(self) -> Dict[str, Any]: + """Return the recent noise metrics.""" + if not self.noise_profile: + return {'avg_noise': 0.0} + avg_noise = float(sum(self.noise_profile) / len(self.noise_profile)) + return { + 'avg_noise': avg_noise, + 'samples': len(self.noise_profile) + } def sanitize_phase(self, phase_data: np.ndarray) -> np.ndarray: """Sanitize phase data through complete pipeline. @@ -281,9 +320,11 @@ def sanitize_phase(self, phase_data: np.ndarray) -> np.ndarray: # Validate input data self.validate_phase_data(phase_data) - # Apply complete sanitization pipeline + # Apply complete sanitization pipeline with temporal consistency and noise tracking sanitized_data = self.unwrap_phase(phase_data) + sanitized_data = self._apply_temporal_consistency(sanitized_data) sanitized_data = self.remove_outliers(sanitized_data) + sanitized_data = self._apply_phase_noise_filter(sanitized_data) sanitized_data = self.smooth_phase(sanitized_data) sanitized_data = self.filter_noise(sanitized_data) @@ -344,4 +385,4 @@ def reset_statistics(self) -> None: """Reset sanitization statistics.""" self._total_processed = 0 self._outliers_removed = 0 - self._sanitization_errors = 0 \ No newline at end of file + self._sanitization_errors = 0 diff --git a/v1/src/hardware/csi_extractor.py b/v1/src/hardware/csi_extractor.py index edb43325..94ab28bf 100644 --- a/v1/src/hardware/csi_extractor.py +++ b/v1/src/hardware/csi_extractor.py @@ -1,6 +1,7 @@ """CSI data extraction from WiFi hardware using Test-Driven Development approach.""" import asyncio +import contextlib import struct import numpy as np from datetime import datetime, timezone @@ -301,6 +302,10 @@ def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = No # State management self.is_connected = False self.is_streaming = False + self._udp_transport: Optional[asyncio.BaseTransport] = None + self._udp_protocol = None + self._udp_queue: Optional[asyncio.Queue] = None + self._udp_keepalive_task: Optional[asyncio.Task] = None # Create appropriate parser if self.hardware_type == 'esp32': @@ -446,14 +451,31 @@ def stop_streaming(self) -> None: self.is_streaming = False async def _establish_hardware_connection(self) -> bool: - """Establish connection to hardware (to be implemented by subclasses).""" - # Placeholder implementation for testing - return True + """Establish connection to the CSI source.""" + if self.hardware_type == 'esp32': + host = self.config.get('aggregator_host', '0.0.0.0') + port = self.config.get('aggregator_port', 5005) + await self._setup_udp_endpoint(host, port) + return True + elif self.hardware_type == 'router': + # Router clients rely on RouterInterface; nothing to open here + return True + + self.logger.warning(f"Unsupported hardware type for connection: {self.hardware_type}") + return False async def _close_hardware_connection(self) -> None: - """Close hardware connection (to be implemented by subclasses).""" - # Placeholder implementation for testing - pass + """Close any open hardware transport resources.""" + if self._udp_transport: + self._udp_transport.close() + self._udp_transport = None + if self._udp_keepalive_task: + self._udp_keepalive_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._udp_keepalive_task + self._udp_keepalive_task = None + self._udp_queue = None + self._udp_protocol = None async def _read_raw_data(self) -> bytes: """Read raw data from hardware. @@ -475,42 +497,70 @@ async def _read_udp_data(self) -> bytes: Raises: CSIExtractionError: If read times out or connection fails. """ - host = self.config.get('aggregator_host', '0.0.0.0') - port = self.config.get('aggregator_port', 5005) + if not self._udp_queue: + raise CSIExtractionError("UDP endpoint is not initialized") + + try: + data = await asyncio.wait_for(self._udp_queue.get(), timeout=self.timeout) + return data + except asyncio.TimeoutError: + host = self.config.get('aggregator_host', '0.0.0.0') + port = self.config.get('aggregator_port', 5005) + raise CSIExtractionError( + f"UDP read timed out after {self.timeout}s. " + f"Ensure the aggregator is running and sending to {host}:{port}." + ) + async def _setup_udp_endpoint(self, host: str, port: int) -> None: + """Set up a UDP listener for the aggregator stream.""" loop = asyncio.get_event_loop() - # Create UDP endpoint if not already cached - if not hasattr(self, '_udp_transport'): - self._udp_future: asyncio.Future = loop.create_future() + queue = asyncio.Queue(maxsize=max(1, self.buffer_size * 2)) + self._udp_queue = queue + + class _UDPProtocol(asyncio.DatagramProtocol): + def __init__(self, queue: asyncio.Queue): + self.queue = queue + + def datagram_received(self, data, addr): + if self.queue.full(): + try: + self.queue.get_nowait() + except asyncio.QueueEmpty: + pass + try: + self.queue.put_nowait(data) + except asyncio.QueueFull: + pass + + def error_received(self, exc): + logging.getLogger(__name__).error(f"UDP error: {exc}") + + transport, protocol = await loop.create_datagram_endpoint( + lambda: _UDPProtocol(queue), + local_addr=(host, port), + ) - class _UdpProtocol(asyncio.DatagramProtocol): - def __init__(self, future): - self._future = future + self._udp_transport = transport + self._udp_protocol = protocol - def datagram_received(self, data, addr): - if not self._future.done(): - self._future.set_result(data) + keepalive_message = self.config.get('aggregator_keepalive_message') + keepalive_interval = self.config.get('aggregator_keepalive_interval', 5.0) - def error_received(self, exc): - if not self._future.done(): - self._future.set_exception(exc) + if isinstance(keepalive_message, str): + keepalive_message = keepalive_message.encode('utf-8') - transport, protocol = await loop.create_datagram_endpoint( - lambda: _UdpProtocol(self._udp_future), - local_addr=(host, port), + if keepalive_message: + self._udp_keepalive_task = asyncio.create_task( + self._udp_keepalive_loop(host, port, keepalive_message, keepalive_interval) ) - self._udp_transport = transport - self._udp_protocol = protocol - try: - data = await asyncio.wait_for(self._udp_future, timeout=self.timeout) - # Reset future for next read - self._udp_future = loop.create_future() - self._udp_protocol._future = self._udp_future - return data - except asyncio.TimeoutError: - raise CSIExtractionError( - f"UDP read timed out after {self.timeout}s. " - f"Ensure the aggregator is running and sending to {host}:{port}." - ) \ No newline at end of file + async def _udp_keepalive_loop(self, host: str, port: int, message: bytes, interval: float): + """Periodically ping the aggregator to keep the socket open.""" + while True: + try: + if self._udp_transport and message: + self._udp_transport.sendto(message, (host, port)) + await asyncio.sleep(interval) + except asyncio.CancelledError: + break diff --git a/v1/src/hardware/router_interface.py b/v1/src/hardware/router_interface.py index 8aa25c76..31047cc9 100644 --- a/v1/src/hardware/router_interface.py +++ b/v1/src/hardware/router_interface.py @@ -1,7 +1,9 @@ """Router interface for WiFi-DensePose system using TDD approach.""" import asyncio +import base64 import logging +import struct from typing import Dict, Any, Optional import asyncssh from datetime import datetime, timezone @@ -19,6 +21,100 @@ class RouterConnectionError(Exception): pass +class RouterCSIParser: + """Parser for router CSI matrices (Atheros / Nexmon).""" + + class AtherosCSIFormat: + HEADER_SIZE = 25 + + @staticmethod + def parse_header(data: bytes) -> Dict[str, Any]: + if len(data) < RouterCSIParser.AtherosCSIFormat.HEADER_SIZE: + raise CSIParseError("Atheros header too short") + timestamp = struct.unpack(' int: + byte_offset = bit_offset // 8 + bit_shift = bit_offset % 8 + if byte_offset + 1 >= len(data): + return 0 + window = (data[byte_offset] << 8) | data[byte_offset + 1] + return (window >> (6 - bit_shift)) & 0x3FF + + @staticmethod + def parse_csi_data(data: bytes, header: Dict[str, Any]) -> np.ndarray: + start = RouterCSIParser.AtherosCSIFormat.HEADER_SIZE + length = header['csi_length'] + if len(data) < start + length: + raise CSIParseError("Atheros CSI payload truncated") + payload = data[start:start + length] + samples = [] + bit_offset = 0 + while bit_offset + 20 <= len(payload) * 8: + real = RouterCSIParser.AtherosCSIFormat._extract_10bit(payload, bit_offset) + imag = RouterCSIParser.AtherosCSIFormat._extract_10bit(payload, bit_offset + 10) + real = real - 512 if real > 511 else real + imag = imag - 512 if imag > 511 else imag + samples.append(complex(real, imag)) + bit_offset += 20 + if not samples: + raise CSIParseError("No complex samples recovered from Atheros payload") + tx = 3 if header['antenna_config'] == 0x07 else 2 + rx = 3 + num_subcarriers = len(samples) // (tx * rx) + if num_subcarriers == 0: + raise CSIParseError("Subcarrier count is zero") + matrix = np.array(samples[:tx * rx * num_subcarriers]) + matrix = matrix.reshape((tx * rx, num_subcarriers)) + return matrix + + def parse(self, raw_data: bytes) -> CSIData: + data = raw_data.strip() + if data.startswith(b'CSI_HEX:'): + data = bytes.fromhex(data.split(b':', 1)[1].strip().decode('utf-8')) + elif data.startswith(b'CSI_BASE64:'): + data = base64.b64decode(data.split(b':', 1)[1].strip()) + elif data.startswith(b'0x'): + data = bytes.fromhex(data[2:].decode('utf-8')) + + header = self.AtherosCSIFormat.parse_header(data) + matrix = self.AtherosCSIFormat.parse_csi_data(data, header) + amplitude = np.abs(matrix) + phase = np.angle(matrix) + return CSIData( + timestamp=datetime.now(tz=timezone.utc), + amplitude=amplitude, + phase=phase, + frequency=header['channel'] * 1e6, + bandwidth=20e6, + num_subcarriers=amplitude.shape[1], + num_antennas=amplitude.shape[0], + snr=float(header['rssi'] - header['noise']), + metadata={ + 'source': 'router', + 'router_channel': header['channel'], + 'mac_address': header['mac_address'] + } + ) + + class RouterInterface: """Interface for communicating with WiFi routers via SSH.""" @@ -50,6 +146,9 @@ def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = No # Connection state self.is_connected = False self.ssh_client = None + self.parser = RouterCSIParser() + self.csi_command = config.get('csi_command', 'cat /tmp/csi.bin') + self.output_encoding = config.get('output_encoding', 'binary') def _validate_config(self, config: Dict[str, Any]) -> None: """Validate configuration parameters. @@ -136,16 +235,17 @@ async def execute_command(self, command: str) -> str: async def get_csi_data(self) -> CSIData: """Retrieve CSI data from router. - + Returns: CSI data structure - + Raises: RouterConnectionError: If data retrieval fails """ try: - response = await self.execute_command("iwlist scan | grep CSI") - return self._parse_csi_response(response) + response = await self.execute_command(self.csi_command) + raw_bytes = self._decode_csi_output(response) + return self.parser.parse(raw_bytes) except Exception as e: raise RouterConnectionError(f"Failed to retrieve CSI data: {e}") @@ -198,28 +298,18 @@ async def health_check(self) -> bool: self.logger.error(f"Health check failed: {e}") return False - def _parse_csi_response(self, response: str) -> CSIData: - """Parse CSI response data. - - Args: - response: Raw response from router - - Returns: - Parsed CSI data - - Raises: - RouterConnectionError: Always in current state, because real CSI - parsing from router command output requires hardware-specific - format knowledge that must be implemented per router model. - """ - raise RouterConnectionError( - "Real CSI data parsing from router responses is not yet implemented. " - "Collecting CSI data from a router requires: " - "(1) a router with CSI-capable firmware (e.g., Atheros CSI Tool, Nexmon), " - "(2) proper hardware setup and configuration, and " - "(3) a parser for the specific binary/text format produced by the firmware. " - "See docs/hardware-setup.md for instructions on configuring your router for CSI collection." - ) + def _decode_csi_output(self, response: str) -> bytes: + payload = response.strip() + if self.output_encoding == 'hex' or payload.startswith('0x'): + normalized = payload[2:] if payload.startswith('0x') else payload + return bytes.fromhex(normalized) + if self.output_encoding == 'base64' or payload.startswith('CSI_BASE64:'): + _, encoded = payload.split(':', 1) if ':' in payload else ('', payload) + return base64.b64decode(encoded.strip()) + if payload.startswith('CSI_HEX:'): + _, encoded = payload.split(':', 1) + return bytes.fromhex(encoded.strip()) + return payload.encode('latin-1') def _parse_status_response(self, response: str) -> Dict[str, Any]: """Parse router status response. @@ -238,4 +328,4 @@ def _parse_status_response(self, response: str) -> Dict[str, Any]: 'wifi_status': 'active', 'uptime': '5 days, 3 hours', 'raw_response': response - } \ No newline at end of file + }