Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 95 additions & 3 deletions v1/src/core/csi_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = No
# Processing state
self.csi_history = deque(maxlen=self.max_history_size)
self.previous_detection_confidence = 0.0
self.signal_quality_history = deque(maxlen=self.max_history_size)
self.signal_quality_threshold = config.get('signal_quality_threshold', 0.35)

# Doppler cache: pre-computed mean phase per frame for O(1) append
self._phase_cache = deque(maxlen=self.max_history_size)
Expand Down Expand Up @@ -131,6 +133,19 @@ def preprocess_csi_data(self, csi_data: CSIData) -> CSIData:
return csi_data

try:
# Validate frame integrity
self._validate_csi_data(csi_data)

# Assess signal quality metrics
quality = self.assess_signal_quality(csi_data)
self._update_signal_quality_history(quality)

if quality['valid_ratio'] < self.signal_quality_threshold:
self.logger.warning(
"Low CSI signal quality (ratio %.2f) detected",
quality['valid_ratio']
)

# Remove noise from the signal
cleaned_data = self._remove_noise(csi_data)

Expand All @@ -140,6 +155,12 @@ def preprocess_csi_data(self, csi_data: CSIData) -> CSIData:
# Normalize amplitude values
normalized_data = self._normalize_amplitude(windowed_data)

# Attach quality metadata for downstream consumers
normalized_data.metadata = {
**normalized_data.metadata,
'signal_quality': quality
}

return normalized_data

except Exception as e:
Expand Down Expand Up @@ -172,7 +193,8 @@ def extract_features(self, csi_data: CSIData) -> Optional[CSIFeatures]:

# Extract Doppler and frequency features
doppler_shift, power_spectral_density = self._extract_doppler_features(csi_data)


quality = csi_data.metadata.get('signal_quality') if csi_data.metadata else None
return CSIFeatures(
amplitude_mean=amplitude_mean,
amplitude_variance=amplitude_variance,
Expand All @@ -181,7 +203,10 @@ def extract_features(self, csi_data: CSIData) -> Optional[CSIFeatures]:
doppler_shift=doppler_shift,
power_spectral_density=power_spectral_density,
timestamp=datetime.now(timezone.utc),
metadata={'processing_params': self.config}
metadata={
'processing_params': self.config,
**({'signal_quality': quality} if quality is not None else {})
}
)

except Exception as e:
Expand Down Expand Up @@ -464,4 +489,71 @@ def _apply_temporal_smoothing(self, raw_confidence: float) -> float:
(1 - self.smoothing_factor) * raw_confidence)

self.previous_detection_confidence = smoothed_confidence
return smoothed_confidence
return smoothed_confidence

def _validate_csi_data(self, csi_data: CSIData) -> None:
"""Validate CSI frame dimensions and safety."""
if csi_data.amplitude.ndim != 2 or csi_data.phase.ndim != 2:
raise CSIProcessingError("CSI data must be 2D amplitude/phase")
if np.any(np.isnan(csi_data.amplitude)) or np.any(np.isnan(csi_data.phase)):
raise CSIProcessingError("CSI data contains NaNs")
if csi_data.num_subcarriers <= 0 or csi_data.num_antennas <= 0:
raise CSIProcessingError("Invalid CSI dimensions")

def assess_signal_quality(self, csi_data: CSIData) -> Dict[str, Any]:
"""Compute signal quality metrics used by the pipeline."""
amplitude = csi_data.amplitude
with np.errstate(divide='ignore'):
amplitude_db = 20 * np.log10(np.clip(amplitude, 1e-12, None))

mean_db = float(np.nanmean(amplitude_db))
valid_ratio = float(np.nanmean(amplitude_db > (mean_db - 10)))
correlation = float(np.nanmean(np.abs(np.corrcoef(amplitude + 1e-9))))
temporal_stability = self._calculate_temporal_stability(amplitude)

return {
'snr': float(csi_data.snr),
'mean_db': mean_db,
'valid_ratio': valid_ratio,
'correlation': correlation,
'temporal_stability': temporal_stability,
'subcarrier_count': csi_data.num_subcarriers,
'antenna_count': csi_data.num_antennas
}

def _calculate_temporal_stability(self, amplitude: np.ndarray) -> float:
"""Score stability relative to the previous frame."""
if not self.csi_history:
return 1.0
last_amplitude = self.csi_history[-1].amplitude
if last_amplitude.shape != amplitude.shape:
return 0.0
diff = np.linalg.norm(amplitude - last_amplitude)
base = np.linalg.norm(last_amplitude) + 1e-9
ratio = diff / base
return float(max(0.0, 1.0 - min(1.0, ratio)))

    def _update_signal_quality_history(self, quality: Dict[str, Any]) -> None:
        """Record one per-frame quality summary.

        Appends to ``signal_quality_history``, a bounded deque, so the
        oldest entry is discarded automatically once the window is full.
        """
        self.signal_quality_history.append(quality)

def get_signal_quality_trend(self) -> Dict[str, float]:
"""Return aggregate quality metrics over the recent horizon."""
if not self.signal_quality_history:
return {}

total = len(self.signal_quality_history)
summary = {
'avg_snr': 0.0,
'avg_valid_ratio': 0.0,
'avg_correlation': 0.0,
'avg_temporal_stability': 0.0
}

for entry in self.signal_quality_history:
summary['avg_snr'] += entry.get('snr', 0.0)
summary['avg_valid_ratio'] += entry.get('valid_ratio', 0.0)
summary['avg_correlation'] += entry.get('correlation', 0.0)
summary['avg_temporal_stability'] += entry.get('temporal_stability', 0.0)

return {key: value / total for key, value in summary.items()}
45 changes: 43 additions & 2 deletions v1/src/core/phase_sanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
from typing import Dict, Any, Optional, Tuple
from datetime import datetime, timezone
from collections import deque
from scipy import signal


Expand Down Expand Up @@ -41,6 +42,9 @@ def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = No
self.enable_noise_filtering = config.get('enable_noise_filtering', False)
self.noise_threshold = config.get('noise_threshold', 0.05)
self.phase_range = config.get('phase_range', (-np.pi, np.pi))
self.temporal_window = config.get('temporal_window', 5)
self.noise_history_size = config.get('noise_history_size', 128)
self.noise_profile = deque(maxlen=self.noise_history_size)

# Statistics tracking
self._total_processed = 0
Expand Down Expand Up @@ -262,6 +266,41 @@ def _apply_low_pass_filter(self, phase_data: np.ndarray, threshold: float) -> np
filtered_data[i, :] = signal.filtfilt(b, a, phase_data[i, :])

return filtered_data

def _apply_temporal_consistency(self, phase_data: np.ndarray) -> np.ndarray:
"""Reduce sudden frame-to-frame jumps."""
if self.temporal_window <= 1 or phase_data.shape[1] < 2:
return phase_data

adjusted = phase_data.copy()
diffs = np.abs(np.diff(phase_data, axis=1))
threshold = self.outlier_threshold / 2

for idx in range(1, phase_data.shape[1]):
jump_mask = diffs[:, idx - 1] > threshold
if np.any(jump_mask):
adjusted[jump_mask, idx] = adjusted[jump_mask, idx - 1]
return adjusted

def _apply_phase_noise_filter(self, phase_data: np.ndarray) -> np.ndarray:
"""Apply conservative noise filtering and track profile."""
if not self.enable_noise_filtering:
return phase_data

filtered = signal.medfilt(phase_data, kernel_size=(1, 3))
noise_level = float(np.nanmean(np.abs(phase_data - filtered)))
self.noise_profile.append(noise_level)
return filtered

def get_noise_profile(self) -> Dict[str, Any]:
"""Return the recent noise metrics."""
if not self.noise_profile:
return {'avg_noise': 0.0}
avg_noise = float(sum(self.noise_profile) / len(self.noise_profile))
return {
'avg_noise': avg_noise,
'samples': len(self.noise_profile)
}

def sanitize_phase(self, phase_data: np.ndarray) -> np.ndarray:
"""Sanitize phase data through complete pipeline.
Expand All @@ -281,9 +320,11 @@ def sanitize_phase(self, phase_data: np.ndarray) -> np.ndarray:
# Validate input data
self.validate_phase_data(phase_data)

# Apply complete sanitization pipeline
# Apply complete sanitization pipeline with temporal consistency and noise tracking
sanitized_data = self.unwrap_phase(phase_data)
sanitized_data = self._apply_temporal_consistency(sanitized_data)
sanitized_data = self.remove_outliers(sanitized_data)
sanitized_data = self._apply_phase_noise_filter(sanitized_data)
sanitized_data = self.smooth_phase(sanitized_data)
sanitized_data = self.filter_noise(sanitized_data)

Expand Down Expand Up @@ -344,4 +385,4 @@ def reset_statistics(self) -> None:
"""Reset sanitization statistics."""
self._total_processed = 0
self._outliers_removed = 0
self._sanitization_errors = 0
self._sanitization_errors = 0
122 changes: 86 additions & 36 deletions v1/src/hardware/csi_extractor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""CSI data extraction from WiFi hardware using Test-Driven Development approach."""

import asyncio
import contextlib
import struct
import numpy as np
from datetime import datetime, timezone
Expand Down Expand Up @@ -301,6 +302,10 @@ def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = No
# State management
self.is_connected = False
self.is_streaming = False
self._udp_transport: Optional[asyncio.BaseTransport] = None
self._udp_protocol = None
self._udp_queue: Optional[asyncio.Queue] = None
self._udp_keepalive_task: Optional[asyncio.Task] = None

# Create appropriate parser
if self.hardware_type == 'esp32':
Expand Down Expand Up @@ -446,14 +451,31 @@ def stop_streaming(self) -> None:
self.is_streaming = False

async def _establish_hardware_connection(self) -> bool:
"""Establish connection to hardware (to be implemented by subclasses)."""
# Placeholder implementation for testing
return True
"""Establish connection to the CSI source."""
if self.hardware_type == 'esp32':
host = self.config.get('aggregator_host', '0.0.0.0')
port = self.config.get('aggregator_port', 5005)
await self._setup_udp_endpoint(host, port)
return True
elif self.hardware_type == 'router':
# Router clients rely on RouterInterface; nothing to open here
return True

self.logger.warning(f"Unsupported hardware type for connection: {self.hardware_type}")
return False

async def _close_hardware_connection(self) -> None:
"""Close hardware connection (to be implemented by subclasses)."""
# Placeholder implementation for testing
pass
"""Close any open hardware transport resources."""
if self._udp_transport:
self._udp_transport.close()
self._udp_transport = None
if self._udp_keepalive_task:
self._udp_keepalive_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._udp_keepalive_task
self._udp_keepalive_task = None
self._udp_queue = None
self._udp_protocol = None

async def _read_raw_data(self) -> bytes:
"""Read raw data from hardware.
Expand All @@ -475,42 +497,70 @@ async def _read_udp_data(self) -> bytes:
Raises:
CSIExtractionError: If read times out or connection fails.
"""
host = self.config.get('aggregator_host', '0.0.0.0')
port = self.config.get('aggregator_port', 5005)
if not self._udp_queue:
raise CSIExtractionError("UDP endpoint is not initialized")

try:
data = await asyncio.wait_for(self._udp_queue.get(), timeout=self.timeout)
return data
except asyncio.TimeoutError:
host = self.config.get('aggregator_host', '0.0.0.0')
port = self.config.get('aggregator_port', 5005)
raise CSIExtractionError(
f"UDP read timed out after {self.timeout}s. "
f"Ensure the aggregator is running and sending to {host}:{port}."
)

async def _setup_udp_endpoint(self, host: str, port: int) -> None:
"""Set up a UDP listener for the aggregator stream."""
loop = asyncio.get_event_loop()

# Create UDP endpoint if not already cached
if not hasattr(self, '_udp_transport'):
self._udp_future: asyncio.Future = loop.create_future()
queue = asyncio.Queue(maxsize=max(1, self.buffer_size * 2))
self._udp_queue = queue

class _UDPProtocol(asyncio.DatagramProtocol):
def __init__(self, queue: asyncio.Queue):
self.queue = queue

def datagram_received(self, data, addr):
if self.queue.full():
try:
self.queue.get_nowait()
except asyncio.QueueEmpty:
pass
try:
self.queue.put_nowait(data)
except asyncio.QueueFull:
pass

def error_received(self, exc):
logging.getLogger(__name__).error(f"UDP error: {exc}")

transport, protocol = await loop.create_datagram_endpoint(
lambda: _UDPProtocol(queue),
local_addr=(host, port),
)

class _UdpProtocol(asyncio.DatagramProtocol):
def __init__(self, future):
self._future = future
self._udp_transport = transport
self._udp_protocol = protocol

def datagram_received(self, data, addr):
if not self._future.done():
self._future.set_result(data)
keepalive_message = self.config.get('aggregator_keepalive_message')
keepalive_interval = self.config.get('aggregator_keepalive_interval', 5.0)

def error_received(self, exc):
if not self._future.done():
self._future.set_exception(exc)
if isinstance(keepalive_message, str):
keepalive_message = keepalive_message.encode('utf-8')

transport, protocol = await loop.create_datagram_endpoint(
lambda: _UdpProtocol(self._udp_future),
local_addr=(host, port),
if keepalive_message:
self._udp_keepalive_task = asyncio.create_task(
self._udp_keepalive_loop(host, port, keepalive_message, keepalive_interval)
)
self._udp_transport = transport
self._udp_protocol = protocol

try:
data = await asyncio.wait_for(self._udp_future, timeout=self.timeout)
# Reset future for next read
self._udp_future = loop.create_future()
self._udp_protocol._future = self._udp_future
return data
except asyncio.TimeoutError:
raise CSIExtractionError(
f"UDP read timed out after {self.timeout}s. "
f"Ensure the aggregator is running and sending to {host}:{port}."
)
async def _udp_keepalive_loop(self, host: str, port: int, message: bytes, interval: float):
"""Periodically ping the aggregator to keep the socket open."""
while True:
try:
if self._udp_transport and message:
self._udp_transport.sendto(message, (host, port))
await asyncio.sleep(interval)
except asyncio.CancelledError:
break
Loading