diff --git a/.github/workflows/build_and_publish.yml b/.github/workflows/build_and_publish.yml index c9486ad..5320018 100644 --- a/.github/workflows/build_and_publish.yml +++ b/.github/workflows/build_and_publish.yml @@ -24,4 +24,4 @@ jobs: run: uv build - name: Publish to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 \ No newline at end of file + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.gitignore b/.gitignore index 11e6330..b57bb18 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,8 @@ .eggs/* .tox/* -*/__pycache__/* +__pycache__/ *.egg-info/ build/* dist/* */*/AGENTS.md -**/.DS_Store \ No newline at end of file +**/.DS_Store diff --git a/LICENSE b/LICENSE index 506e313..7f2ce26 100644 --- a/LICENSE +++ b/LICENSE @@ -19,4 +19,3 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - diff --git a/README.md b/README.md index 7526934..fd87053 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,10 @@ uv pip install sdcpy ## Usage +The main class is `SDCAnalysis`, which takes two time series as input and performs the SDC analysis. If only one time series is provided, it will be compared with itself. + +In this example we generate two synthetic time series with a transient pattern between indices 63-169 to showcase the detection of transient correlations. + ```python import numpy as np import pandas as pd @@ -46,14 +50,87 @@ ts1 = pd.Series([tc_signal(i) for i in range(250)]) ts2 = pd.Series([tc_signal(i) for i in range(250)]) # Run SDC analysis -sdc = SDCAnalysis(ts1, ts2, fragment_size=50, n_permutations=99) +sdc = SDCAnalysis( + ts1=ts1, # First time series (pd.Series or np.ndarray) + ts2=ts2, # Second time series (pd.Series or np.ndarray) + fragment_size=50, # Size of the sliding window fragment + n_permutations=99, # Number of permutations for significance testing + method="pearson", # Correlation method ("pearson", "spearman" or Callable) + two_tailed=True, # Whether to use two-tailed test + permutations=True, # Whether to compute p-values via permutation + min_lag=-np.inf, # Minimum lag to compute + max_lag=np.inf, # Maximum lag to compute + max_memory_gb=2.0, # Max memory usage for vectorized ops before chunking +) # Generate 2-way SDC combi plot -fig = sdc.combi_plot(xlabel="TS1", ylabel="TS2") -fig.savefig("sdc_plot.png", dpi=150, bbox_inches="tight") +fig = sdc.combi_plot( + xlabel="$TS_1$", # Label for top axis (TS1) + ylabel="$TS_2$", # Label for left axis (TS2) + title=None, # Plot title (None = auto-generated) + max_r=None, # Max correlation for color scale (None = auto) + date_fmt="%Y-%m-%d", # Date format for axes + align="center", # Alignment of heatmap cells ("center", "left", "right") + min_lag=-np.inf, # Start of lag range to plot + max_lag=np.inf, # End of lag range to plot + fontsize=9, # Base font size + figsize=(7, 7), # Figure size (width, height) + show_colorbar=True, # Whether to show the colorbar + show_ts2=True, # Whether to show TS2 time series panel + dpi=250, # Resolution for saving/displaying +) + +``` + + + +### Access Detailed Results + +You can access the detailed SDC results DataFrame via `sdc.sdc_df`. This contains the coordinates of each fragment, lag, correlation, and p-value. 
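+
+Since `sdc_df` is a regular `pandas.DataFrame`, it can be filtered with standard pandas tools. As a minimal sketch (the `0.05` cutoff is illustrative and matches the default `alpha` used by the plotting methods), keep only the significant positive synchronies at non-negative lags:
+
+```python
+significant = sdc.sdc_df.query("p_value < 0.05 and r > 0 and lag >= 0")
+```
+
+The first few rows of the full DataFrame look like this: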
+ +```python +sdc.sdc_df.head() ``` - +```text + start_1 stop_1 start_2 stop_2 lag r p_value + 0.0 50.0 0.0 50.0 0.0 0.2267 0.0495 + 0.0 50.0 1.0 51.0 -1.0 -0.0047 1.0000 + 0.0 50.0 2.0 52.0 -2.0 0.0579 0.6238 + 0.0 50.0 3.0 53.0 -3.0 0.0602 0.6337 + 0.0 50.0 4.0 54.0 -4.0 -0.1660 0.2475 +``` + +### Correlation by Value Range + +To check if synchronies occur in specific value ranges of the time series, use `get_ranges_df()`. This computes statistics of Positive/Non-significant/Negative correlations binned by the value of the time series and returns a `pandas.DataFrame` with the results. + +```python +sdc.get_ranges_df( + ts=1, # Which TS to bin by (1 or 2) + bin_size=0.5, # Size of value bins + agg_func="mean", # Aggregation function for fragment values + alpha=0.05, # Significance threshold + min_bin=None, # Manual lower bound for bins (None = auto) + max_bin=None, # Manual upper bound for bins (None = auto) + min_lag=0, # Minimum lag to include in stats + max_lag=10, # Maximum lag to include in stats +) +``` + +```text + cat_value direction counts n freq label +(-0.6, 0.0] Positive 504 1232 0.4091 40.9 % +(-0.6, 0.0] Negative 0 1232 0.0000 0.0 % +(-0.6, 0.0] NS 728 1232 0.5909 59.1 % + (0.0, 0.5] Positive 218 883 0.2469 24.7 % + (0.0, 0.5] Negative 10 883 0.0111 1.1 % + (0.0, 0.5] NS 655 883 0.7418 74.2 % + (0.5, 1.0] Positive 1 19 0.0526 5.3 % + (0.5, 1.0] Negative 0 19 0.0000 0.0 % + (0.5, 1.0] NS 18 19 0.9474 94.7 % + (1.0, 1.5] Positive 0 0 0.0000 0.0 % +``` See [examples/basic_usage.py](examples/basic_usage.py) for a complete example with synthetic data showing transient correlations. diff --git a/examples/basic_usage.py b/examples/basic_usage.py index 190c915..84b18bf 100644 --- a/examples/basic_usage.py +++ b/examples/basic_usage.py @@ -31,5 +31,5 @@ def tc_signal(i): # Generate combination plot fig = sdc.combi_plot(xlabel="$TS_1$", ylabel="$TS_2$") - fig.savefig("sdc_example.png", dpi=300, bbox_inches="tight") + fig.savefig("sdc_example.png", bbox_inches="tight") print("Saved: sdc_example.png") diff --git a/pyproject.toml b/pyproject.toml index a080197..342f2ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "sdcpy" -version = "0.5.2" +version = "0.7.0" description = "Scale Dependent Correlation in Python" readme = "README.md" license = "MIT" @@ -88,7 +88,7 @@ ignore = [ known-first-party = ["sdcpy"] [tool.bumpversion] -current_version = "0.5.0" +current_version = "0.7.0" commit = true tag = true diff --git a/sdc_example.png b/sdc_example.png index e9b436c..558efa8 100644 Binary files a/sdc_example.png and b/sdc_example.png differ diff --git a/sdcpy/core.py b/sdcpy/core.py index 0edbac6..b04c809 100644 --- a/sdcpy/core.py +++ b/sdcpy/core.py @@ -14,7 +14,51 @@ "spearman": lambda x, y: stats.spearmanr(x, y), } -CONSTANT_WARNING = {"pearson": stats.ConstantInputWarning, "spearman": stats.ConstantInputWarning} + +# Default maximum memory threshold (in GB) for full vectorized computation. +# Above this, chunked processing is used automatically. +DEFAULT_MAX_MEMORY_GB = 2.0 + + +def _estimate_vectorized_memory( + n1: int, n2: int, n_permutations: int, dtype: np.dtype = np.float64 +) -> float: + """ + Estimate peak memory usage (in GB) for the fully vectorized SDC computation. 
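+
+    A back-of-the-envelope form of the estimate implemented below, assuming
+    float64 (8 bytes per element) and n_root = round(sqrt(n_permutations)):
+
+        total_bytes ~= (n_root**2 + 4) * n1 * n2 * 8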
+ + The dominant memory consumers are: + - Permuted correlation matrices: (n_root^2, n1, n2) where n_root = sqrt(n_permutations) + - Fragment matrices: (n1, fragment_size) + (n2, fragment_size) + - Correlation matrix: (n1, n2) + - Grid/lag matrices: 3 * (n1, n2) + + Parameters + ---------- + n1 + Number of fragments from first time series + n2 + Number of fragments from second time series + n_permutations + Number of permutations for the randomization test + dtype + Data type (default float64 = 8 bytes) + + Returns + ------- + float + Estimated peak memory usage in gigabytes + """ + bytes_per_element = np.dtype(dtype).itemsize + n_root = int(np.sqrt(n_permutations).round()) + n_actual_perms = n_root * n_root + + # Main memory consumers + perm_matrices = n_actual_perms * n1 * n2 * bytes_per_element + corr_matrix = n1 * n2 * bytes_per_element + grid_matrices = 3 * n1 * n2 * bytes_per_element # start_1_grid, start_2_grid, lag_matrix + + total_bytes = perm_matrices + corr_matrix + grid_matrices + return total_bytes / (1024**3) # Convert to GB def generate_correlation_map(x: np.ndarray, y: np.ndarray, method: str = "pearson") -> np.ndarray: @@ -94,6 +138,7 @@ def compute_sdc( permutations: bool = True, min_lag: int = -np.inf, max_lag: int = np.inf, + max_memory_gb: float = DEFAULT_MAX_MEMORY_GB, ) -> pd.DataFrame: """ Computes scale dependent correlation (https://doi.org/10.1007/s00382-005-0106-4) matrix among two time series @@ -124,6 +169,10 @@ def compute_sdc( Lower limit of the lags between ts1 and ts2 that will be computed. max_lag Upper limit of the lags between ts1 and ts2 that will be computed. + max_memory_gb + Maximum memory (in GB) to use for full vectorized computation. If estimated memory exceeds + this limit, the computation falls back to chunked processing which uses constant memory but + may be slightly slower. Default is 2.0 GB. Returns ------- @@ -147,6 +196,7 @@ def compute_sdc( permutations, min_lag, max_lag, + max_memory_gb, ) else: # Fall back to original loop-based implementation for custom callables @@ -173,8 +223,16 @@ def _compute_sdc_vectorized( permutations: bool, min_lag: int, max_lag: int, + max_memory_gb: float = DEFAULT_MAX_MEMORY_GB, ) -> pd.DataFrame: - """Vectorized implementation for built-in correlation methods.""" + """Vectorized implementation for built-in correlation methods. + + Parameters + ---------- + max_memory_gb + Maximum memory (in GB) to use for full vectorization. If estimated + memory exceeds this, chunked processing is used automatically. 
+ """ n1 = len(ts1) - fragment_size n2 = len(ts2) - fragment_size @@ -210,44 +268,86 @@ def _compute_sdc_vectorized( n_root = int(np.sqrt(n_permutations).round()) n_actual_perms = n_root * n_root - # OPTIMIZED: Pre-shuffle all fragments and compute full permuted correlation matrices - # This is much faster than shuffling per-pair because we compute n_root^2 full - # correlation matrices instead of n_valid individual permutation tests - - # Pre-compute shuffled versions of all fragments - # Shape: (n_root, n_fragments, fragment_size) - shuffled_frags1 = np.array( - [shuffle_along_axis(frags1.copy(), axis=1) for _ in range(n_root)] - ) - shuffled_frags2 = np.array( - [shuffle_along_axis(frags2.copy(), axis=1) for _ in range(n_root)] - ) - - # Compute permuted correlation matrices for all combinations of shuffled fragments - # Shape: (n_root, n_root, n1, n2) - perm_corr_matrices = np.zeros((n_root, n_root, n1, n2)) - for i in tqdm(range(n_root), desc="Computing permutations", leave=False): - for j in range(n_root): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - perm_corr_matrices[i, j] = generate_correlation_map( - shuffled_frags1[i], shuffled_frags2[j], method=method - ) - - # Reshape to (n_actual_perms, n1, n2) - perm_corrs_flat = perm_corr_matrices.reshape(n_actual_perms, n1, n2) - - # Compute p-values vectorized - if two_tailed: - # Count how many abs(perm) >= abs(observed) for each position - abs_observed = np.abs(corr_matrix) - abs_perms = np.abs(perm_corrs_flat) - counts = (abs_perms >= abs_observed[np.newaxis, :, :]).sum(axis=0) + # Estimate memory and decide strategy + estimated_memory = _estimate_vectorized_memory(n1, n2, n_permutations) + use_chunked = estimated_memory > max_memory_gb + + if use_chunked: + warnings.warn( + f"Estimated memory ({estimated_memory:.2f} GB) exceeds limit " + f"({max_memory_gb:.2f} GB). 
Using chunked processing.", + UserWarning, + stacklevel=3, + ) + # Chunked approach: accumulate counts without storing all permutation matrices + counts = np.zeros((n1, n2), dtype=np.int32) + + # Pre-compute shuffled versions of all fragments + shuffled_frags1 = np.array( + [shuffle_along_axis(frags1.copy(), axis=1) for _ in range(n_root)] + ) + shuffled_frags2 = np.array( + [shuffle_along_axis(frags2.copy(), axis=1) for _ in range(n_root)] + ) + + # Process one permutation pair at a time, accumulating counts + abs_observed = np.abs(corr_matrix) if two_tailed else None + with tqdm( + total=n_actual_perms, desc="Computing permutations (chunked)", leave=False + ) as pbar: + for i in range(n_root): + for j in range(n_root): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + perm_corr = generate_correlation_map( + shuffled_frags1[i], shuffled_frags2[j], method=method + ) + if two_tailed: + counts += (np.abs(perm_corr) >= abs_observed).astype(np.int32) + else: + counts += (perm_corr >= corr_matrix).astype(np.int32) + pbar.update(n_root) + + # P-value: (count + 1) / (n_perms + 1) for proper permutation test + p_value_matrix = (counts + 1) / (n_actual_perms + 1) else: - counts = (perm_corrs_flat >= corr_matrix[np.newaxis, :, :]).sum(axis=0) - - # P-value: (count + 1) / (n_perms + 1) for proper permutation test - p_value_matrix = (counts + 1) / (n_actual_perms + 1) + # Full vectorized approach: store all permutation matrices + # Pre-compute shuffled versions of all fragments + # Shape: (n_root, n_fragments, fragment_size) + shuffled_frags1 = np.array( + [shuffle_along_axis(frags1.copy(), axis=1) for _ in range(n_root)] + ) + shuffled_frags2 = np.array( + [shuffle_along_axis(frags2.copy(), axis=1) for _ in range(n_root)] + ) + + # Compute permuted correlation matrices for all combinations of shuffled fragments + # Shape: (n_root, n_root, n1, n2) + perm_corr_matrices = np.zeros((n_root, n_root, n1, n2)) + with tqdm(total=n_actual_perms, desc="Computing permutations", leave=False) as pbar: + for i in range(n_root): + for j in range(n_root): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + perm_corr_matrices[i, j] = generate_correlation_map( + shuffled_frags1[i], shuffled_frags2[j], method=method + ) + pbar.update(n_root) + + # Reshape to (n_actual_perms, n1, n2) + perm_corrs_flat = perm_corr_matrices.reshape(n_actual_perms, n1, n2) + + # Compute p-values vectorized + if two_tailed: + # Count how many abs(perm) >= abs(observed) for each position + abs_observed = np.abs(corr_matrix) + abs_perms = np.abs(perm_corrs_flat) + counts = (abs_perms >= abs_observed[np.newaxis, :, :]).sum(axis=0) + else: + counts = (perm_corrs_flat >= corr_matrix[np.newaxis, :, :]).sum(axis=0) + + # P-value: (count + 1) / (n_perms + 1) for proper permutation test + p_value_matrix = (counts + 1) / (n_actual_perms + 1) # Extract p-values for valid entries p_values = p_value_matrix[valid_mask] @@ -291,7 +391,7 @@ def _compute_sdc_loop( """Original loop-based implementation for custom callable methods.""" method_fun = method n_iterations = (len(ts1) - fragment_size) * (len(ts2) - fragment_size) - int(np.sqrt(n_permutations).round()) + sdc_array = np.empty(shape=(n_iterations, 7)) sdc_array[:] = np.nan i = 0 diff --git a/sdcpy/plotting.py b/sdcpy/plotting.py index d1c43c4..21ae0ee 100644 --- a/sdcpy/plotting.py +++ b/sdcpy/plotting.py @@ -6,11 +6,13 @@ import matplotlib.pyplot as plt import numpy as np import pandas as pd +import plotnine as p9 import seaborn as sns from matplotlib.ticker 
import MaxNLocator if TYPE_CHECKING: from matplotlib.figure import Figure as MplFigure + from plotnine.ggplot import ggplot def _determine_frequency_info(index: pd.Index) -> tuple[str, float, str]: @@ -89,6 +91,47 @@ def _determine_frequency_info(index: pd.Index) -> tuple[str, float, str]: return "periods", max(1, freq_mult), "D" +def _prepare_plot_index(index: pd.Index) -> tuple[pd.Index, bool]: + """ + Prepare an index for plotting, handling different index types appropriately. + + Parameters + ---------- + index : pd.Index + The index to prepare for plotting. + + Returns + ------- + tuple[pd.Index, bool] + A tuple of (prepared_index, is_datetime) where: + - prepared_index is the index ready for matplotlib plotting + - is_datetime indicates if the index should be treated as datetime + + Raises + ------ + ValueError + If the index is string-based but cannot be converted to datetime. + """ + # Check if already datetime + if pd.api.types.is_datetime64_any_dtype(index): + return index, True + + # Check if numeric (integer or float) + if pd.api.types.is_numeric_dtype(index): + return index, False + + # Assume string-like, try to convert to datetime + try: + converted = pd.to_datetime(index) + return converted, True + except (ValueError, TypeError) as e: + raise ValueError( + f"Could not convert index to datetime. " + f"Index should be either numeric (integer), datetime, or datetime-parseable strings. " + f"Original error: {e}" + ) from e + + def combi_plot( ts1: pd.Series, ts2: pd.Series, @@ -114,8 +157,8 @@ def combi_plot( show_ts2: bool = True, metric_label: str = None, n_ticks: int = 6, - figsize: tuple = (6, 6), - dpi: int = 150, + figsize: tuple = (7, 7), + dpi: int = 250, **kwargs, ) -> "MplFigure": """ @@ -171,9 +214,9 @@ def combi_plot( Label for the correlation metric. Defaults to method name. n_ticks : int, default 6 Number of ticks to show on axes. - figsize : tuple, default (6, 6) + figsize : tuple, default (7, 7) Figure size. - dpi : int, default 150 + dpi : int, default 250 Figure resolution. **kwargs Additional keyword arguments passed to plt.figure(). 
@@ -226,17 +269,18 @@ def combi_plot( left_offset = 0 if align == "left" else offset right_offset = 0 if align == "right" else offset - # Check if index is datetime for formatting - is_datetime_index = pd.api.types.is_datetime64_any_dtype(ts1.index) + # Prepare indexes for plotting (handles integer, datetime, and string indexes) + ts1_plot_index, ts1_is_datetime = _prepare_plot_index(ts1.index) + ts2_plot_index, ts2_is_datetime = _prepare_plot_index(ts2.index) - # Calculate offset using timedelta for datetime indexes, integer for others - if is_datetime_index: + # Create offset based on index type + if ts1_is_datetime: timedelta_offset = pd.to_timedelta(left_offset * freq_mult, unit=freq_unit) else: - timedelta_offset = left_offset # Use integer offset for non-datetime + timedelta_offset = left_offset # Integer offset for numeric indexes - is_datetime_index = pd.api.types.is_datetime64_any_dtype(ts1.index) - if is_datetime_index and date_fmt: + # Create date formatter if date_fmt is provided and index is datetime-like + if date_fmt and ts1_is_datetime: date_format = mdates.DateFormatter(date_fmt) else: date_format = None @@ -276,12 +320,12 @@ def combi_plot( # Time series 1 (top) ts1_ax = fig.add_subplot(gs[1, hm_cols]) - ts1_ax.plot(ts1, color="black", linewidth=1) + ts1_ax.plot(ts1_plot_index, ts1.values, color="black", linewidth=1) # Time series 2 (left) if show_ts2: ts2_ax = fig.add_subplot(gs[2:4, 0]) - ts2_ax.plot(ts2.values, ts2.index, color="black", linewidth=1) + ts2_ax.plot(ts2.values, ts2_plot_index, color="black", linewidth=1) # Heatmap hm = fig.add_subplot(gs[2:4, hm_cols]) @@ -355,8 +399,8 @@ def combi_plot( # Format TS1 axis ts1_ax.xaxis.set_label_position("top") - ts1_ax.set_xlim(ts1.index[0], ts1.index[-1]) - ts1_ax.grid(True, which="major", axis="x", linestyle="--", alpha=0.5) + ts1_ax.set_xlim(ts1_plot_index[0], ts1_plot_index[-1]) + ts1_ax.grid(True, which="major", axis="both", linestyle="--", alpha=0.5) ts1_ax.set_xlabel(xlabel, fontsize=label_fontsize) ts1_ax.xaxis.set_major_locator(MaxNLocator(nbins=n_ticks, prune="both")) if date_format: @@ -376,8 +420,8 @@ def combi_plot( # Format TS2 axis if show_ts2: - ts2_ax.set_ylim(ts2.index[0], ts2.index[-1]) - ts2_ax.grid(True, which="major", axis="y", linestyle="--", alpha=0.5) + ts2_ax.set_ylim(ts2_plot_index[0], ts2_plot_index[-1]) + ts2_ax.grid(True, which="major", axis="both", linestyle="--", alpha=0.5) ts2_ax.invert_xaxis() ts2_ax.invert_yaxis() ts2_ax.set_ylabel(ylabel, fontsize=label_fontsize) @@ -405,12 +449,12 @@ def combi_plot( gs.update(wspace=wspace, hspace=hspace) - # Max correlations scatter plots + # Max correlations line plots colors = {"Max $r$": "#A81529", "Min $r$ (abs)": "#144E8A"} if min_lag < 0: mc1 = fig.add_subplot(gs[-1, hm_cols]) - mc1_data = ( + mc1_base = ( sdc_df.query("p_value < @alpha") .query("(lag <= @max_lag) & (lag >= @min_lag)") .groupby("date_1") @@ -420,35 +464,42 @@ def combi_plot( ) .rename(columns={"r_max": "Max $r$", "r_min": "Min $r$ (abs)"}) .reset_index() - .melt("date_1") - .assign(date_1=lambda dd: dd.date_1 + timedelta_offset) - .assign(color=lambda dd: dd.variable.map(colors)) - .dropna(subset=["value"]) ) - if len(mc1_data) > 0: - mc1_data.plot.scatter( - x="date_1", - y="value", - c="color", - ax=mc1, - alpha=0.7, - colorbar=False, - linewidths=0, + # Apply offset based on index type + if ts1_is_datetime: + mc1_data = ( + mc1_base.assign(date_1=lambda dd: pd.to_datetime(dd.date_1) + timedelta_offset) + .set_index("date_1") + .sort_index() ) + else: + mc1_data = ( + 
mc1_base.assign(date_1=lambda dd: dd.date_1 + timedelta_offset) + .set_index("date_1") + .sort_index() + ) + if len(mc1_data) > 0: + # Plot each correlation type as a separate line + for col_name, color in colors.items(): + if col_name in mc1_data.columns: + data = mc1_data[col_name].dropna() + if len(data) > 0: + mc1.plot(data.index, data.values, color=color, alpha=1, linewidth=1.3) plt.setp(mc1.get_xticklabels(), visible=False) mc1.set_xlabel("") mc1.set_ylabel("Max |corr|", fontsize=label_fontsize) mc1.yaxis.set_label_position("right") - mc1.set_xlim(ts1.index[0], ts1.index[-1]) + mc1.set_xlim(ts1_plot_index[0], ts1_plot_index[-1]) mc1.set_ylim(0, 1.05) - mc1.grid(True, which="major") + mc1.xaxis.set_major_locator(MaxNLocator(nbins=n_ticks, prune="both")) # Match ts1 ticks + mc1.grid(True, which="major", axis="both", linestyle="--", alpha=0.5) mc1.set_yticks([0, 0.5, 1]) mc1.tick_params(axis="y", labelsize=tick_fontsize) mc1.tick_params(axis="x", bottom=False, top=False, labelbottom=False, labeltop=False) if max_lag > 0 and mc2_col is not None: mc2 = fig.add_subplot(gs[2:4, mc2_col]) - mc2_data = ( + mc2_base = ( sdc_df.query("p_value < @alpha") .query("(lag <= @max_lag) & (lag >= @min_lag)") .groupby("date_2") @@ -458,29 +509,36 @@ def combi_plot( ) .rename(columns={"r_max": "Max $r$", "r_min": "Min $r$ (abs)"}) .reset_index() - .melt("date_2") - .assign(date_2=lambda dd: dd.date_2 + timedelta_offset) - .assign(color=lambda dd: dd.variable.map(colors)) - .dropna(subset=["value"]) ) - if len(mc2_data) > 0: - mc2_data.plot.scatter( - x="value", - y="date_2", - c="color", - ax=mc2, - alpha=0.7, - colorbar=False, - linewidths=0, + # Apply offset based on index type + if ts2_is_datetime: + mc2_data = ( + mc2_base.assign(date_2=lambda dd: pd.to_datetime(dd.date_2) + timedelta_offset) + .set_index("date_2") + .sort_index() + ) + else: + mc2_data = ( + mc2_base.assign(date_2=lambda dd: dd.date_2 + timedelta_offset) + .set_index("date_2") + .sort_index() ) + if len(mc2_data) > 0: + # Plot each correlation type as a separate line (x=value, y=date for vertical orientation) + for col_name, color in colors.items(): + if col_name in mc2_data.columns: + data = mc2_data[col_name].dropna() + if len(data) > 0: + mc2.plot(data.values, data.index, color=color, alpha=1, linewidth=1.3) plt.setp(mc2.get_yticklabels(), visible=False) mc2.set_xlabel("Max |corr|", fontsize=label_fontsize) mc2.xaxis.set_label_position("top") mc2.set_ylabel("") - mc2.grid(True, which="major") mc2.set_xlim(1.05, 0) mc2.set_xticks([0, 0.5, 1]) # Match mc1's y-axis breaks - mc2.set_ylim(ts2.index[-1], ts2.index[0]) + mc2.set_ylim(ts2_plot_index[-1], ts2_plot_index[0]) + mc2.yaxis.set_major_locator(MaxNLocator(nbins=n_ticks, prune="both")) # Match ts2 ticks + mc2.grid(True, which="major", axis="both", linestyle="--", alpha=0.5) # Move x-axis ticks to top mc2.tick_params( axis="x", @@ -506,3 +564,61 @@ def combi_plot( fig.suptitle(title) return fig + + +def plot_range_comparison( + ranges_df: pd.DataFrame, + xlabel: str = "", + figsize: tuple[int, int] = (7, 3), + add_text_label: bool = True, +) -> "ggplot": + """ + Create a bar chart showing correlation directions by value ranges. + + Parameters + ---------- + ranges_df : pd.DataFrame + DataFrame from SDCAnalysis.get_ranges_df() with columns: + cat_value, direction, counts, n, freq, label. + xlabel : str, default="" + Label for the x-axis. + figsize : tuple[int, int], default=(7, 3) + Figure size as (width, height). 
+ add_text_label : bool, default=True + Whether to add percentage labels on the bars. + + Returns + ------- + ggplot + A plotnine ggplot object. + """ + fig = ( + p9.ggplot(ranges_df) + + p9.aes("cat_value", "counts", fill="direction") + + p9.geom_col(alpha=0.8) + + p9.theme(figure_size=figsize, axis_text_x=p9.element_text(rotation=45)) + + p9.scale_fill_manual(["#3f7f93", "#da3b46", "#4d4a4a"]) + + p9.labs(x=xlabel, y="Number of Comparisons", fill="R") + ) + + if add_text_label: + positive_data = ranges_df.loc[(ranges_df.direction == "Positive") & (ranges_df.counts > 0)] + if len(positive_data) > 0: + fig += p9.geom_text( + p9.aes(label="label", x="cat_value", y="n + max(n) * .15"), + inherit_aes=False, + size=9, + data=positive_data, + color="#3f7f93", + ) + negative_data = ranges_df.loc[(ranges_df.direction == "Negative") & (ranges_df.counts > 0)] + if len(negative_data) > 0: + fig += p9.geom_text( + p9.aes(label="label", x="cat_value", y="n + max(n) * .05"), + inherit_aes=False, + size=9, + data=negative_data, + color="#da3b46", + ) + + return fig diff --git a/sdcpy/scale_dependent_correlation.py b/sdcpy/scale_dependent_correlation.py index 90e5c01..4ce46e3 100755 --- a/sdcpy/scale_dependent_correlation.py +++ b/sdcpy/scale_dependent_correlation.py @@ -2,15 +2,13 @@ from typing import TYPE_CHECKING, Callable, Optional, Union -import matplotlib.pyplot as plt import numpy as np import pandas as pd -import plotnine as p9 -import seaborn as sns from sdcpy.core import compute_sdc from sdcpy.io import load_from_excel, save_to_excel from sdcpy.plotting import combi_plot as _combi_plot +from sdcpy.plotting import plot_range_comparison as _plot_range_comparison if TYPE_CHECKING: from matplotlib.figure import Figure as MplFigure @@ -36,20 +34,16 @@ def __init__( sdc_df: Optional[pd.DataFrame] = None, min_lag: int = -np.inf, max_lag: int = np.inf, + max_memory_gb: float = 2.0, ): self.way = ( "one-way" if ts2 is None else "two-way" ) # One-way SDC inferred if no ts2 is provided ts2 = ts1.copy() if self.way == "one-way" else ts2 - # TODO: As mentioned in (#4), we should make if not isinstance(ts1, pd.Series): - ts1 = pd.Series( - ts1, index=pd.date_range(start="2000-01-01", periods=len(ts1), freq="D") - ) + ts1 = pd.Series(ts1) if not isinstance(ts2, pd.Series): - ts2 = pd.Series( - ts2, index=pd.date_range(start="2000-01-01", periods=len(ts2), freq="D") - ) + ts2 = pd.Series(ts2) min_date = max(ts1.index.min(), ts2.index.min()) max_date = min(ts1.index.max(), ts2.index.max()) self.ts1 = ts1[min_date:max_date] @@ -71,6 +65,7 @@ def __init__( permutations, min_lag, max_lag, + max_memory_gb, ).assign( date_1=lambda dd: dd.start_1.map(self.ts1.reset_index().to_dict()["date_1"]), date_2=lambda dd: dd.start_2.map(self.ts2.reset_index().to_dict()["date_2"]), @@ -82,32 +77,6 @@ def __init__( ) self.method = method - def two_way_plot(self, alpha: float = 0.05, **kwargs) -> "ggplot": - """Plot two-way SDC heatmap using plotnine.""" - fragment_size = int(self.sdc_df.iloc[0]["stop_1"] - self.sdc_df.iloc[0]["start_1"]) - f = ( - self.sdc_df.loc[lambda dd: dd.p_value < alpha] - .assign(r_str=lambda dd: dd["r"].apply(lambda x: "$r > 0$" if x > 0 else "$r < 0$")) - .pipe( - lambda dd: p9.ggplot(dd) - + p9.aes("start_1", "start_2", fill="r_str", alpha="abs(r)") - + p9.geom_tile() - + p9.scale_fill_manual(["#da2421", "black"]) - + p9.scale_y_reverse() - + p9.theme(**kwargs) - + p9.guides(alpha=False) - + p9.labs( - x="$X_i$", - y="$Y_j$", - fill="$r$", - title=f"Two-Way SDC plot for $S = {fragment_size}$" - + r" 
and $\alpha =$" - + f"{alpha}", - ) - ) - ) - return f - def to_excel(self, filename: str): save_to_excel( self.sdc_df, @@ -133,67 +102,125 @@ def from_excel(cls, filename: str): def get_ranges_df( self, - bin_size: int = 3, + bin_size: Union[int, float] = 1, alpha: float = 0.05, - min_bin=None, - max_bin=None, - threshold: float = 0.0, + min_bin: Optional[Union[int, float]] = None, + max_bin: Optional[Union[int, float]] = None, ts: int = 1, - ): + agg_func: str = "mean", + min_lag: int = -np.inf, + max_lag: int = np.inf, + ) -> pd.DataFrame: + """ + Compute correlation direction statistics by value ranges. + + For each SDC comparison, computes the aggregate value (mean by default) of ts1 or ts2 + during that fragment, bins those values, then counts how many correlations in each + bin were positive, negative, or not significant. + + Parameters + ---------- + bin_size : Union[int, float], default=1 + Width of each value bin. + alpha : float, default=0.05 + Significance level for classifying correlations. + min_bin : Optional[Union[int, float]], optional + Lower bound for binning. Defaults to floor(min(ts)) aligned to bin_size. + max_bin : Optional[Union[int, float]], optional + Upper bound for binning. Defaults to ceil(max(ts)) aligned to bin_size. + ts : int, default=1 + Which time series to analyze (1 for ts1, 2 for ts2). + agg_func : str, default="mean" + Aggregation function to summarize values in each fragment. + Options: "mean", "median", "min", "max". + min_lag : int, default=-np.inf + Minimum lag to consider. + max_lag : int, default=np.inf + Maximum lag to consider. + + Returns + ------- + pd.DataFrame + DataFrame with columns: + - cat_value: categorical bin (e.g., "(0, 3]") + - direction: "Positive", "Negative", or "NS" (not significant) + - counts: number of comparisons in this bin with this direction + - n: total comparisons in this bin + - freq: proportion (counts / n) + - label: formatted percentage string + """ ts_series = self.ts1 if ts == 1 else self.ts2 - min_bin = int(np.floor(ts_series.min())) if min_bin is None else min_bin - max_bin = int(np.ceil(ts_series.max())) if max_bin is None else max_bin - name = ts_series.name + + # Compute rolling aggregate for fragments + # This gives the aggregate value for each fragment starting at each index + if agg_func == "mean": + fragment_values = ts_series.rolling(window=self.fragment_size, min_periods=1).mean() + elif agg_func == "median": + fragment_values = ts_series.rolling(window=self.fragment_size, min_periods=1).median() + elif agg_func == "min": + fragment_values = ts_series.rolling(window=self.fragment_size, min_periods=1).min() + elif agg_func == "max": + fragment_values = ts_series.rolling(window=self.fragment_size, min_periods=1).max() + else: + raise ValueError( + f"Unknown agg_func: {agg_func}. Use 'mean', 'median', 'min', or 'max'." 
+ ) + + # Create lookup from date to fragment aggregate value + fragment_values_df = fragment_values.reset_index() + fragment_values_df.columns = [f"date_{ts}", "fragment_value"] + + # Join sdc_df with fragment values first df = ( self.sdc_df.dropna() - .assign( - date_range=lambda dd: dd[f"date_{ts}"].apply( - lambda x: pd.date_range(x, x + pd.to_timedelta(self.fragment_size, unit="days")) - ) - )[["r", "p_value", "date_range"]] - .explode("date_range") - .rename(columns={"date_range": "date"}) - .reset_index() - .rename(columns={"index": "comparison_id"}) - .merge(ts_series.reset_index().rename(columns={f"date_{ts}": "date", name: "value"})) - .assign( + .query("lag >= @min_lag & lag <= @max_lag") + .merge(fragment_values_df, on=f"date_{ts}", how="left") + ) + + # Compute bin bounds from the *filtered* fragment aggregates + # This ensures we don't create empty bins for ranges that were filtered out + current_values = df["fragment_value"] + + # Snap min/max to grid defined by bin_size + if min_bin is None: + min_val = current_values.min() + min_bin = np.floor(min_val / bin_size) * bin_size + + if max_bin is None: + max_val = current_values.max() + max_bin = np.ceil(max_val / bin_size) * bin_size + + # Assign categories using data-dependent bins + # Use np.arange to support float bin_size + # Add small epsilon to max range to ensure inclusion due to floating point precision + df = ( + df.assign( cat_value=lambda dd: pd.cut( - dd.value, bins=list(range(min_bin, max_bin + bin_size, bin_size)), precision=0 + dd.fragment_value, + bins=np.arange(min_bin, max_bin + bin_size + 1e-10, bin_size), + precision=1, # Improved precision for float bins + include_lowest=True, ) ) - .groupby(["comparison_id"]) - .apply(lambda dd: dd.cat_value.value_counts(True), include_groups=False) - .loc[lambda x: x > threshold] - .reset_index() - .rename(columns={"level_1": "cat_value"}, errors="ignore") - .drop(columns=["proportion"], errors="ignore") - .merge( - self.sdc_df.reset_index().rename(columns={"index": "comparison_id"})[ - ["r", "p_value", "comparison_id"] - ] - ) .assign(significant=lambda dd: dd.p_value < alpha) .assign( - direction=lambda dd: ( - dd.significant.astype(int) * ((dd.r > 0).astype(int) + 1) - ).replace({0: "NS", 1: "Negative", 2: "Positive"}) + direction=lambda dd: np.where( + ~dd.significant, + "NS", + np.where(dd.r > 0, "Positive", "Negative"), + ) ) .assign( direction=lambda dd: pd.Categorical( dd.direction, categories=["Positive", "Negative", "NS"], ordered=True ) ) - .groupby("cat_value") - .apply( - lambda dd: dd["direction"].value_counts().rename("counts").reset_index(), - include_groups=False, - ) - .reset_index() - .drop(columns="level_1") - .rename(columns={"index": "direction"}) + .groupby(["cat_value", "direction"], observed=False) + .size() + .reset_index(name="counts") .pipe( lambda dd: dd.merge( - dd.groupby("cat_value", as_index=False)["counts"] + dd.groupby("cat_value", as_index=False, observed=False)["counts"] .sum() .rename(columns={"counts": "n"}), on="cat_value", @@ -211,63 +238,20 @@ def plot_range_comparison( figsize: tuple[int, int] = (7, 3), add_text_label: bool = True, **kwargs, - ): - df = self.get_ranges_df(**kwargs) - fig = ( - p9.ggplot(df) - + p9.aes("cat_value", "counts", fill="direction") - + p9.geom_col(alpha=0.8) - + p9.theme(figure_size=figsize, axis_text_x=p9.element_text(rotation=45)) - + p9.scale_fill_manual(["#3f7f93", "#da3b46", "#4d4a4a"]) - + p9.labs(x=xlabel, y="Number of Comparisons", fill="R") - ) - - if add_text_label: - if df.loc[df.direction 
== "Positive"].loc[df.counts > 0].size > 0: - fig += p9.geom_text( - p9.aes(label="label", x="cat_value", y="n + max(n) * .15"), - inherit_aes=False, - size=9, - data=df.loc[df.direction == "Positive"].loc[df.counts > 0], - color="#3f7f93", - ) - if df.loc[df.direction == "Negative"].loc[df.counts > 0].size > 0: - fig += p9.geom_text( - p9.aes(label="label", x="cat_value", y="n + max(n) * .05"), - inherit_aes=False, - size=9, - data=df.loc[df.direction == "Negative"].loc[df.counts > 0], - color="#da3b46", - ) - - return fig + ) -> "ggplot": + """ + Create a bar chart showing correlation directions by value ranges. - def plot_consecutive(self, alpha: float = 0.05, **kwargs) -> "ggplot": - f = ( - self.sdc_df.loc[lambda dd: dd.p_value < alpha] - # Here I make groups of consecutive significant values and report the longest for each lag. - .groupby("lag", as_index=True) - .apply( - lambda gdf: gdf.sort_values("start_1") - .assign(group=lambda dd: (dd.start_1 != dd.start_1.shift(1) + 1).cumsum()) - .groupby(["group"]) - .size() - .max(), - include_groups=False, - ) - .rename("Max Consecutive steps") - .reset_index() - .pipe( - lambda dd: p9.ggplot(dd) - + p9.aes("lag", "Max Consecutive steps") - + p9.geom_col() - + p9.theme(**kwargs) - + p9.labs(x="Lag [days]") - ) + See `sdcpy.plotting.plot_range_comparison` for full parameter documentation. + """ + df = self.get_ranges_df(**kwargs) + return _plot_range_comparison( + ranges_df=df, + xlabel=xlabel, + figsize=figsize, + add_text_label=add_text_label, ) - return f - def combi_plot( self, alpha: float = 0.05, @@ -294,7 +278,7 @@ def combi_plot( **kwargs, ) -> "MplFigure": """ - Create a combination plot showing SDC analysis results. + Create a combined plot showing two-way SDC analysis results. See `sdcpy.plotting.combi_plot` for full parameter documentation. 
""" @@ -327,61 +311,3 @@ def combi_plot( dpi=dpi, **kwargs, ) - - def dominant_lags_plot(self, alpha: float = 0.05, ylabel: str = "", **kwargs) -> "MplFigure": - fig, ax = plt.subplots(**kwargs) - df = ( - self.sdc_df.loc[lambda dd: dd.p_value < alpha] - .groupby("date_1") - .apply( - lambda dd: dd.loc[ - lambda ddd: ((ddd.r == ddd.r.max()) & (ddd.r > 0)) - | ((ddd.r == ddd.r.min()) & (ddd.r < 0)) - ], - include_groups=False, - ) - .reset_index(level=0) - .groupby(["date_1"]) - .apply( - lambda dd: dd.loc[dd["lag"].abs() == dd["lag"].abs().min()], include_groups=False - ) - .reset_index(level=0) - .assign( - date_1=lambda dd: dd.date_1 + pd.to_timedelta(self.fragment_size // 2, unit="days") - ) - .assign(lag=lambda dd: dd.lag.abs()) - ) - self.ts1.plot(ax=ax, color="black") - ax2 = ax.twinx() - sns.scatterplot( - data=df, - x="date_1", - y="r", - hue="lag", - legend="full", - alpha=0.7, - ax=ax2, - palette="inferno_r", - ) - handles, labels = ax2.get_legend_handles_labels() - ax2.legend( - bbox_to_anchor=(1.3, 1.05), - ncol=1, - frameon=True, - columnspacing=0.2, - handles=[h for i, h in enumerate(handles[1:]) if i % 3 == 1], - labels=[label for i, label in enumerate(labels[1:]) if i % 3 == 1], - title="Lag", - ) - - ax2.set_yticks([-1, -0.5, 0, 0.5, 1]) - ax2.grid(which="major") - ax2.set_xlabel("") - ax.set_xlabel("") - ax.set_ylabel(ylabel if ylabel else "Value") - ax2.set_ylabel("Max/Min r") - - return fig - - def single_shift_plot(self, shift: int) -> "MplFigure": - raise NotImplementedError("single_shift_plot is not yet implemented") diff --git a/tests/conftest.py b/tests/conftest.py index ff902b2..4c48257 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,9 +5,16 @@ import pytest +def _convert_to_string_index(ts: pd.Series) -> pd.Series: + """Convert a Series with DatetimeIndex to one with string datetime index.""" + result = ts.copy() + result.index = ts.index.strftime("%Y-%m-%d") + return result + + @pytest.fixture def random_ts_pair(): - """Two random time series with daily frequency.""" + """Two random time series with daily frequency (DatetimeIndex).""" np.random.seed(42) dates = pd.date_range("2020-01-01", periods=100, freq="D") ts1 = pd.Series(np.random.randn(100), index=dates, name="ts1") @@ -17,6 +24,28 @@ def random_ts_pair(): return ts1, ts2 +@pytest.fixture +def string_datetime_ts_pair(): + """Two random time series with string-based datetime index (object dtype).""" + np.random.seed(42) + dates = pd.date_range("2020-01-01", periods=100, freq="D") + string_dates = dates.strftime("%Y-%m-%d") + ts1 = pd.Series(np.random.randn(100), index=string_dates, name="ts1") + ts1.index.name = "date_1" + ts2 = pd.Series(np.random.randn(100), index=string_dates, name="ts2") + ts2.index.name = "date_2" + return ts1, ts2 + + +@pytest.fixture(params=["datetime", "string"]) +def ts_pair_any_index(request, random_ts_pair, string_datetime_ts_pair): + """Time series pair with either datetime or string index (parameterized).""" + if request.param == "datetime": + return random_ts_pair + else: + return string_datetime_ts_pair + + @pytest.fixture def correlated_ts_pair(): """ts2 is a noisy lagged copy of ts1 (lag=5 days).""" diff --git a/tests/test_core.py b/tests/test_core.py index 24fae9b..7fb46fd 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -194,3 +194,81 @@ def test_lag_detection(self, correlated_ts_pair): best_lag = mean_by_lag.idxmax() # The synthetic data has lag of 5, but we relaxed to ±10 due to noise assert abs(best_lag - 5) <= 10 + + +class 
TestMemoryManagement: + """Tests for memory-aware chunking functionality.""" + + def test_memory_estimation_returns_positive(self): + """Memory estimation should return a positive value.""" + from sdcpy.core import _estimate_vectorized_memory + + mem = _estimate_vectorized_memory(n1=100, n2=100, n_permutations=99) + assert mem > 0 + assert isinstance(mem, float) + + def test_memory_estimation_scales_with_size(self): + """Larger inputs should require more memory.""" + from sdcpy.core import _estimate_vectorized_memory + + small_mem = _estimate_vectorized_memory(n1=100, n2=100, n_permutations=99) + large_mem = _estimate_vectorized_memory(n1=1000, n2=1000, n_permutations=99) + assert large_mem > small_mem * 50 # Should scale with n^2 + + def test_memory_estimation_scales_with_permutations(self): + """More permutations should require more memory.""" + from sdcpy.core import _estimate_vectorized_memory + + few_perms = _estimate_vectorized_memory(n1=100, n2=100, n_permutations=9) + many_perms = _estimate_vectorized_memory(n1=100, n2=100, n_permutations=99) + assert many_perms > few_perms + + def test_chunked_mode_triggers_with_tiny_memory_limit(self, numpy_ts_pair): + """Setting very low max_memory_gb should trigger chunked mode.""" + ts1, ts2 = numpy_ts_pair + + # This should trigger a warning about using chunked processing + with pytest.warns(UserWarning, match="Using chunked processing"): + result = compute_sdc( + ts1, ts2, fragment_size=10, n_permutations=9, max_memory_gb=0.0000001 + ) + + # Should still produce valid results + assert len(result) > 0 + assert result["r"].between(-1, 1).all() + assert result["p_value"].between(0, 1).all() + + def test_chunked_and_full_produce_same_r_values(self, numpy_ts_pair): + """Chunked and full modes should produce identical correlation values. + + Note: When permutations=False, both modes use the same code path for + correlation computation, so they produce identical results. 
+ """ + ts1, ts2 = numpy_ts_pair + + # Use permutations=False since we're testing correlation values, not p-values + # Both should produce identical r values regardless of memory mode + result_full = compute_sdc( + ts1, ts2, fragment_size=10, n_permutations=9, permutations=False, max_memory_gb=100.0 + ) + + result_chunked = compute_sdc( + ts1, + ts2, + fragment_size=10, + n_permutations=9, + permutations=False, + max_memory_gb=0.0000001, + ) + + # Correlation values should be identical (no permutations involved) + np.testing.assert_array_almost_equal(result_full["r"].values, result_chunked["r"].values) + + def test_max_memory_gb_parameter_exists(self, numpy_ts_pair): + """max_memory_gb parameter should be accepted without error.""" + ts1, ts2 = numpy_ts_pair + # Just verify the parameter is accepted + result = compute_sdc( + ts1, ts2, fragment_size=10, n_permutations=9, max_memory_gb=4.0, permutations=False + ) + assert len(result) > 0 diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 10505e6..52bca8b 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -2,7 +2,6 @@ import matplotlib import matplotlib.pyplot as plt -import pytest from sdcpy import SDCAnalysis @@ -13,26 +12,17 @@ class TestSDCAnalysisPlotting: """Smoke tests for SDCAnalysis plotting methods.""" - def test_two_way_plot(self, random_ts_pair): - """two_way_plot should not crash.""" - from plotnine.ggplot import ggplot - - ts1, ts2 = random_ts_pair - sdc = SDCAnalysis(ts1, ts2, fragment_size=10, n_permutations=9) - result = sdc.two_way_plot() - assert isinstance(result, ggplot) - - def test_combi_plot(self, random_ts_pair): - """combi_plot should return a matplotlib Figure.""" - ts1, ts2 = random_ts_pair + def test_combi_plot(self, ts_pair_any_index): + """combi_plot should return a matplotlib Figure (both index types).""" + ts1, ts2 = ts_pair_any_index sdc = SDCAnalysis(ts1, ts2, fragment_size=10, n_permutations=9) - result = sdc.combi_plot() + result = sdc.combi_plot(date_fmt="%Y-%m") assert isinstance(result, plt.Figure) plt.close(result) - def test_combi_plot_with_params(self, random_ts_pair): - """combi_plot should accept various parameters.""" - ts1, ts2 = random_ts_pair + def test_combi_plot_with_params(self, ts_pair_any_index): + """combi_plot should accept various parameters (both index types).""" + ts1, ts2 = ts_pair_any_index sdc = SDCAnalysis(ts1, ts2, fragment_size=10, n_permutations=9) result = sdc.combi_plot( alpha=0.1, @@ -42,6 +32,7 @@ def test_combi_plot_with_params(self, random_ts_pair): max_r=0.5, min_lag=-20, max_lag=20, + date_fmt="%Y-%m", ) assert isinstance(result, plt.Figure) plt.close(result) @@ -56,23 +47,6 @@ def test_combi_plot_alignment(self, random_ts_pair): assert isinstance(result, plt.Figure) plt.close(result) - def test_plot_consecutive(self, correlated_ts_pair): - """plot_consecutive should return a plotnine ggplot.""" - from plotnine.ggplot import ggplot - - ts1, ts2 = correlated_ts_pair - sdc = SDCAnalysis(ts1, ts2, fragment_size=10, n_permutations=49) - result = sdc.plot_consecutive(alpha=0.5) # Higher alpha for random data - assert isinstance(result, ggplot) - - def test_dominant_lags_plot(self, binned_value_ts_pair): - """dominant_lags_plot should return a matplotlib Figure.""" - ts1, ts2 = binned_value_ts_pair - sdc = SDCAnalysis(ts1, ts2, fragment_size=10, n_permutations=49) - result = sdc.dominant_lags_plot(alpha=0.5) - assert isinstance(result, plt.Figure) - plt.close(result) - def test_get_ranges_df(self, binned_value_ts_pair): """get_ranges_df should 
return a DataFrame.""" import pandas as pd @@ -170,13 +144,32 @@ def test_combi_plot_figsize(self, random_ts_pair): assert result.get_size_inches()[1] == 10 plt.close(result) + def test_combi_plot_string_datetime_index(self): + """combi_plot should work with string-based datetime indexes.""" + import numpy as np + import pandas as pd -class TestSingleShiftPlot: - """Tests for the unimplemented single_shift_plot method.""" + # Create time series with string-based (object dtype) datetime index + np.random.seed(42) + dates = pd.date_range("2005-01-01", periods=100, freq="D") + # Convert to strings to simulate user data with object dtype index + string_dates = dates.strftime("%Y-%m-%d") + + ts1 = pd.Series(np.random.randn(100), index=string_dates, name="ts1") + ts1.index.name = "date" + ts2 = pd.Series(np.random.randn(100), index=string_dates, name="ts2") + ts2.index.name = "date" - def test_raises_not_implemented(self, random_ts_pair): - """single_shift_plot should raise NotImplementedError.""" - ts1, ts2 = random_ts_pair sdc = SDCAnalysis(ts1, ts2, fragment_size=10, n_permutations=9) - with pytest.raises(NotImplementedError): - sdc.single_shift_plot(shift=5) + result = sdc.combi_plot(date_fmt="%Y-%m") + assert isinstance(result, plt.Figure) + plt.close(result) + + def test_combi_plot_integer_index(self, numpy_ts_pair): + """combi_plot should work with integer indexes (no datetime conversion).""" + ts1, ts2 = numpy_ts_pair + # Numpy arrays will be converted to Series with integer index by SDCAnalysis + sdc = SDCAnalysis(ts1, ts2, fragment_size=10, n_permutations=9) + result = sdc.combi_plot() + assert isinstance(result, plt.Figure) + plt.close(result) diff --git a/tests/test_ranges_df.py b/tests/test_ranges_df.py new file mode 100644 index 0000000..25f64ac --- /dev/null +++ b/tests/test_ranges_df.py @@ -0,0 +1,29 @@ +import numpy as np +import pandas as pd + +from sdcpy import SDCAnalysis + + +def test_get_ranges_df_lag_filtering(): + """Should filter ranges_df based on min_lag and max_lag.""" + # Create simple predictable data + ts1 = pd.Series(np.arange(20)) + ts2 = pd.Series(np.arange(20)) + + sdc = SDCAnalysis(ts1, ts2, fragment_size=5, n_permutations=9) + + # Get ranges with lag filtering + ranges_df = sdc.get_ranges_df(min_lag=-2, max_lag=2) + + # Verify we got some data + assert len(ranges_df) > 0 + + # Manually check the logic + # The sdc_df should be filtered before merging + expected_count = sdc.sdc_df.query("lag >= -2 & lag <= 2").shape[0] + # ranges_df is aggregated by cat_value and direction, so sum of counts should match total number of valid comparisons + assert ranges_df["counts"].sum() == expected_count + + # Test extreme filtering + ranges_df_strict = sdc.get_ranges_df(min_lag=0, max_lag=0) + assert ranges_df_strict["counts"].sum() == sdc.sdc_df.query("lag == 0").shape[0] diff --git a/tests/test_sdc_analysis.py b/tests/test_sdc_analysis.py index 0be5f26..2165083 100644 --- a/tests/test_sdc_analysis.py +++ b/tests/test_sdc_analysis.py @@ -22,8 +22,9 @@ def test_numpy_array_input(self, numpy_ts_pair): """Should accept numpy arrays and auto-generate dates.""" ts1, ts2 = numpy_ts_pair sdc = SDCAnalysis(ts1, ts2, fragment_size=10, n_permutations=9) - # Should have created date indices starting from 2000-01-01 - assert pd.Timestamp("2000-01-01") in sdc.ts1.index + # Should have created integer indices + assert 0 in sdc.ts1.index + assert pd.api.types.is_integer_dtype(sdc.ts1.index) assert len(sdc.sdc_df) > 0 def test_one_way_sdc(self, random_ts_pair): diff --git a/uv.lock 
b/uv.lock
index 97efef9..86e6cdb 100644
--- a/uv.lock
+++ b/uv.lock
@@ -2264,7 +2264,7 @@ wheels = [
 
 [[package]]
 name = "sdcpy"
-version = "0.5.0"
+version = "0.7.0"
 source = { editable = "." }
 dependencies = [
     { name = "matplotlib", version = "3.9.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },