diff --git a/doc/api.rst b/doc/api.rst index fc55017606..7996c0ba92 100755 --- a/doc/api.rst +++ b/doc/api.rst @@ -209,6 +209,7 @@ spikeinterface.preprocessing .. autofunction:: get_motion_parameters_preset .. autofunction:: load_motion_info .. autofunction:: save_motion_info + .. autofunction:: decimate .. autofunction:: depth_order .. autofunction:: detect_bad_channels .. autofunction:: detect_and_interpolate_bad_channels diff --git a/src/spikeinterface/preprocessing/_decimation_tools.py b/src/spikeinterface/preprocessing/_decimation_tools.py new file mode 100644 index 0000000000..fc5147c55e --- /dev/null +++ b/src/spikeinterface/preprocessing/_decimation_tools.py @@ -0,0 +1,175 @@ +""" +Helpers for splitting a (potentially large) integer decimation factor into several balanced +sub-factors, so that anti-aliased decimation can be applied as multiple stable scipy.signal.decimate +passes. Shared by DecimateRecording and ResampleRecording. +""" + +import math +import warnings + +import numpy as np + +from spikeinterface.core import get_chunk_with_margin + +# scipy.signal.decimate uses an order-8 Chebyshev type I IIR filter by default, and its +# documentation recommends decimating in several balanced steps rather than a single step +# for downsampling factors larger than this value. +_MAX_SINGLE_PASS_DECIMATION = 13 + + +def _prime_factors(n): + """ + Return the prime factors of a positive integer `n` (ascending, with multiplicity). + + Examples + -------- + >>> _prime_factors(60) + [2, 2, 3, 5] + >>> _prime_factors(17) + [17] + """ + factors = [] + divisor = 2 + while divisor * divisor <= n: + while n % divisor == 0: + factors.append(divisor) + n //= divisor + divisor += 1 + if n > 1: + factors.append(n) + return factors + + +def _greedy_pack(primes_desc, num_bins): + """ + Greedily pack `primes_desc` (largest first) into `num_bins` bins, keeping each bin's + product <= `_MAX_SINGLE_PASS_DECIMATION` and the bins as balanced as possible. + + Returns the list of bin products, or None if some prime cannot be placed (i.e. `num_bins` + is too small to keep every bin <= the single-pass limit). + + Examples + -------- + Pack the prime factors of 48 into two balanced bins (6 and 8): + + >>> _greedy_pack([3, 2, 2, 2, 2], 2) + [6, 8] + + Two bins cannot hold 2 ** 7 = 128 without a bin exceeding the single-pass limit of 13: + + >>> _greedy_pack([2, 2, 2, 2, 2, 2, 2], 2) is None + True + """ + bins = [1] * num_bins + for prime in primes_desc: + fitting = [i for i in range(num_bins) if bins[i] * prime <= _MAX_SINGLE_PASS_DECIMATION] + if not fitting: + return None + # Place into the smallest fitting bin (ties broken by index, for determinism). + target = min(fitting, key=lambda i: (bins[i], i)) + bins[target] *= prime + return bins + + +def get_balanced_decimation_factors(decimation_factor): + """ + Split `decimation_factor` into sub-factors, each <= 13, as balanced as possible (so their + products are close), for stable multi-pass anti-aliased decimation. + + scipy recommends decimating in several balanced steps rather than one large step when the + factor exceeds 13 (e.g. 48 -> [8, 6] rather than [12, 4]). The product of the returned + factors always equals `decimation_factor`. + + If `decimation_factor` has a prime factor greater than 13 (e.g. a large prime such as 17), + no valid split exists and `[decimation_factor]` is returned; it is the caller's + responsibility to handle this (e.g., warn that a single, potentially unstable, pass will + be used). + """ + if decimation_factor <= _MAX_SINGLE_PASS_DECIMATION: + return [decimation_factor] + + primes = _prime_factors(decimation_factor) + if max(primes) > _MAX_SINGLE_PASS_DECIMATION: + # If a prime factor > 13 cannot be split into sub-13 factors... + return [decimation_factor] + + primes_desc = sorted(primes, reverse=True) + # Minimum number of passes so that, ideally, each pass decimates by <= 13. + num_passes = max(1, math.ceil(math.log(decimation_factor) / math.log(_MAX_SINGLE_PASS_DECIMATION))) + while num_passes <= len(primes_desc): + bins = _greedy_pack(primes_desc, num_passes) + if bins is not None: + return sorted(bins, reverse=True) + num_passes += 1 + # Fallback: one prime per pass (always valid since every prime is <= 13). + return primes_desc + + +def get_antialiased_decimated_traces( + parent_segment, + start_frame, + end_frame, + channel_indices, + decimation_factor, + decimation_factors, + margin, + dtype, + decimation_offset=0, +): + """ + Fetch a margined chunk from `parent_segment` and decimate it by `decimation_factor`, applied + as a cascade of the balanced `decimation_factors` passes of ``scipy.signal.decimate``. + + The margin is rounded up to a multiple of the total `decimation_factor` so that + ``left_margin // decimation_factor`` is exact; combined with scipy's default + ``zero_phase=True`` (output sample i maps to filtered input sample i * factor), this keeps the + downsampled traces aligned across chunks (a chunked read matches a full read). Exactly + ``end_frame - start_frame`` decimated samples are returned. + + Parameters + ---------- + parent_segment : BaseRecordingSegment + The parent segment to read (full-rate) traces from. + start_frame, end_frame : int + Output (decimated) frame range to return. + channel_indices : slice | list | np.ndarray | None + Channels to read, forwarded to the parent segment. + decimation_factor : int + The total decimation factor (the product of `decimation_factors`). + decimation_factors : list[int] + The per-pass sub-factors (each <= 13), e.g. from `get_balanced_decimation_factors`. + margin : int + Margin in parent samples used to limit anti-aliasing filter edge effects. Rounded up + internally to a multiple of `decimation_factor`. + dtype : np.dtype | str + Output dtype. The decimation runs in float32 and the result is cast to `dtype`. + decimation_offset : int, default: 0 + Index of the first parent frame, applied to the first output sample only. + """ + from scipy import signal + + q = decimation_factor + parent_start_frame = decimation_offset + start_frame * q + parent_end_frame = parent_start_frame + (end_frame - start_frame) * q + # Round the margin up to a multiple of q so that left_margin // q is exact. + margin = int(np.ceil(margin / q) * q) + parent_traces, left_margin, right_margin = get_chunk_with_margin( + parent_segment, + parent_start_frame, + parent_end_frame, + channel_indices, + margin, + add_reflect_padding=True, + dtype=np.float32, + ) + decimated_traces = parent_traces + for sub_q in decimation_factors: + decimated_traces = signal.decimate(decimated_traces, q=sub_q, axis=0) + if np.any(np.isnan(decimated_traces)): + warnings.warn( + f"`scipy.signal.decimate` produced NaNs while decimating by {q}. " + f"Consider a different decimation factor." + ) + start_drop = left_margin // q + n_out = end_frame - start_frame + return decimated_traces[start_drop : start_drop + n_out].astype(dtype) diff --git a/src/spikeinterface/preprocessing/decimate.py b/src/spikeinterface/preprocessing/decimate.py index da66cd9c3f..2987712533 100644 --- a/src/spikeinterface/preprocessing/decimate.py +++ b/src/spikeinterface/preprocessing/decimate.py @@ -1,3 +1,5 @@ +import warnings + import numpy as np from spikeinterface.core.core_tools import ( define_function_handling_dict_from_class, @@ -5,16 +7,23 @@ from .basepreprocessor import BasePreprocessor from .filter import fix_dtype +from ._decimation_tools import ( + _MAX_SINGLE_PASS_DECIMATION, + get_balanced_decimation_factors, + get_antialiased_decimated_traces, +) from spikeinterface.core import BaseRecordingSegment class DecimateRecording(BasePreprocessor): """ - Decimate the recording extractor traces using array slicing + Decimate the recording extractor traces. - Important: This uses simple array slicing for decimation rather than eg scipy.decimate. - This might introduce aliasing, or skip across signal of interest. - Consider spikeinterface.preprocessing.ResampleRecording for safe resampling. + By default this uses simple array slicing + (``[::]``), which is fast but applies no + anti-aliasing filter and so might introduce aliasing, or skip across signal of interest. Set + `antialias=True` to low-pass filter before downsampling using ``scipy.signal.decimate`` (the + same anti-aliased decimation used by ``spikeinterface.preprocessing.ResampleRecording``). Parameters ---------- @@ -29,12 +38,25 @@ class DecimateRecording(BasePreprocessor): to ensure that the decimated recording has at least one frame. Consider combining DecimateRecording with FrameSliceRecording for fine control on the recording start and end frames. The same decimation offset is applied to all segments from the parent recording. + antialias : bool, default: False + If True, apply an anti-aliasing low-pass filter before downsampling, using + ``scipy.signal.decimate``. When `decimation_factor` exceeds 13, the decimation is + automatically performed in several balanced sub-13 passes (e.g. a factor of 48 is applied + as 8 then 6), as scipy recommends, to keep the IIR anti-aliasing filter stable. If False + (the default), traces are downsampled by plain array slicing with no filtering, and + `margin_ms` is ignored. + margin_ms : float, default: 100.0 + Margin in ms used on each side of every chunk to limit edge effects of the anti-aliasing + filter. Only used when `antialias=True`. The margin is internally rounded up to a whole + number of output samples so the filtered, downsampled traces stay aligned across chunks. + dtype : dtype or None, default: None + The dtype of the returned traces. If None, the dtype of the parent recording is used. Returns ------- decimate_recording: DecimateRecording - The decimated recording extractor object. The full traces of the child recording segment - correspond to the traces of the parent segment as follows: + The decimated recording extractor object. With `antialias=False` the full traces of the + child recording segment correspond to the traces of the parent segment as follows: ``` = [::]``` """ @@ -44,6 +66,9 @@ def __init__( recording, decimation_factor, decimation_offset=0, + antialias=False, + margin_ms=100.0, + dtype=None, ): # Original sampling frequency self._orig_samp_freq = recording.get_sampling_frequency() @@ -63,7 +88,21 @@ def __init__( self._decimation_offset = decimation_offset decimated_sampling_frequency = self._orig_samp_freq / self._decimation_factor - BasePreprocessor.__init__(self, recording, sampling_frequency=decimated_sampling_frequency) + # fix_dtype doesn't always returns the str, make sure it does + dtype = fix_dtype(recording, dtype).str + + antialias_factors = get_balanced_decimation_factors(decimation_factor) + if antialias and decimation_factor > _MAX_SINGLE_PASS_DECIMATION and antialias_factors == [decimation_factor]: + warnings.warn( + f"`decimation_factor`={decimation_factor} cannot be split into anti-aliasing passes of <= 13 " + f"(it has a prime factor > 13). A single `scipy.signal.decimate` pass will be used, which may be " + f"unstable. Consider a `decimation_factor` without large prime factors." + ) + + # Margin (in parent samples) to limit anti-aliasing filter edge effects. + margin = int(margin_ms * self._orig_samp_freq / 1000) + + BasePreprocessor.__init__(self, recording, sampling_frequency=decimated_sampling_frequency, dtype=dtype) for parent_segment in recording.segments: self.add_recording_segment( @@ -74,6 +113,9 @@ def __init__( decimation_factor, decimation_offset, self._dtype, + antialias, + margin, + antialias_factors, ) ) @@ -81,6 +123,9 @@ def __init__( recording=recording, decimation_factor=decimation_factor, decimation_offset=decimation_offset, + antialias=antialias, + margin_ms=margin_ms, + dtype=dtype, ) @@ -93,6 +138,9 @@ def __init__( decimation_factor, decimation_offset, dtype, + antialias=False, + margin=0, + antialias_factors=None, ): if parent_recording_segment.time_vector is not None: time_vector = parent_recording_segment.time_vector[decimation_offset::decimation_factor] @@ -113,6 +161,9 @@ def __init__( self._decimation_factor = decimation_factor self._decimation_offset = decimation_offset self._dtype = dtype + self._antialias = antialias + self._margin = margin + self._antialias_factors = antialias_factors if antialias_factors is not None else [decimation_factor] def get_num_samples(self): parent_n_samp = self._parent_segment.get_num_samples() @@ -120,18 +171,30 @@ def get_num_samples(self): return int(np.ceil((parent_n_samp - self._decimation_offset) / self._decimation_factor)) def get_traces(self, start_frame, end_frame, channel_indices): - # Account for offset and end when querying parent traces - parent_start_frame = self._decimation_offset + start_frame * self._decimation_factor - parent_end_frame = parent_start_frame + (end_frame - start_frame) * self._decimation_factor - - # And now we can decimate without offsetting - return self._parent_segment.get_traces( - parent_start_frame, - parent_end_frame, + if not self._antialias: + # Simple array slicing, no anti-aliasing filter. + parent_start_frame = self._decimation_offset + start_frame * self._decimation_factor + parent_end_frame = parent_start_frame + (end_frame - start_frame) * self._decimation_factor + return self._parent_segment.get_traces( + parent_start_frame, + parent_end_frame, + channel_indices, + )[ + :: self._decimation_factor + ].astype(self._dtype) + + # Anti-aliased decimation as a cascade of balanced scipy.signal.decimate passes. + return get_antialiased_decimated_traces( + self._parent_segment, + start_frame, + end_frame, channel_indices, - )[ - :: self._decimation_factor - ].astype(self._dtype) + self._decimation_factor, + self._antialias_factors, + self._margin, + self._dtype, + decimation_offset=self._decimation_offset, + ) decimate = define_function_handling_dict_from_class(source_class=DecimateRecording, name="decimate") diff --git a/src/spikeinterface/preprocessing/resample.py b/src/spikeinterface/preprocessing/resample.py index 902bd6d176..8dbe0a549b 100644 --- a/src/spikeinterface/preprocessing/resample.py +++ b/src/spikeinterface/preprocessing/resample.py @@ -8,6 +8,11 @@ from .basepreprocessor import BasePreprocessor from .filter import fix_dtype +from ._decimation_tools import ( + _MAX_SINGLE_PASS_DECIMATION, + get_balanced_decimation_factors, + get_antialiased_decimated_traces, +) from spikeinterface.core import get_chunk_with_margin, BaseRecordingSegment @@ -15,17 +20,19 @@ class ResampleRecording(BasePreprocessor): """ Resample the recording extractor traces. - If the original sampling rate is multiple of the resample_rate, it will use - the signal.decimate method from scipy. In other cases, it uses signal.resample. In the - later case, the resulting signal can have issues on the edges, mainly on the - rightmost. + If the parent sampling rate is an exact integer multiple of `resample_rate`, the + ``signal.decimate`` method from scipy is used (anti-aliased decimation). In other cases + ``signal.resample`` is used, in which case the resulting signal can have issues on the edges, + mainly on the rightmost. See Notes for a caveat on how the integer multiple is detected. Parameters ---------- recording : Recording The recording extractor to be re-referenced - resample_rate : int - The resampling frequency + resample_rate : int | float + The resampling frequency. Integer ratios (parent_rate / resample_rate) use + ``scipy.signal.decimate``; non-integer ratios use ``scipy.signal.resample`` (FFT-based), + which can have edge effects, mainly on the rightmost samples. margin_ms : float, default: 100.0 Margin in ms for computations, will be used to decrease edge effects. dtype : dtype or None, default: None @@ -38,6 +45,19 @@ class ResampleRecording(BasePreprocessor): resample_recording : ResampleRecording The resampled recording extractor object. + Notes + ----- + The (anti-aliased) decimation path is selected by an exact check, + ``parent_rate % resample_rate == 0``. This only detects an integer downsampling factor when + both rates make that modulo exactly zero. If either the parent rate or `resample_rate` is a + non-integer float (i.e. ``float(int(x)) != float(x)``), a conceptually integer ratio can go + undetected and silently fall back to the FFT-based ``scipy.signal.resample`` path. For example, + decimating a 625 Hz recording by a factor of 6 means a target of 104.1666... Hz, and + ``625 % 104.1666... != 0``, so the integer-decimation path is not taken (whereas a factor of 5, + i.e. a 125 Hz target, is detected since ``625 % 125 == 0``). To force anti-aliased integer + decimation by a known factor regardless of the rates, use + ``spikeinterface.preprocessing.DecimateRecording`` with ``antialias=True``. + """ def __init__( @@ -48,13 +68,24 @@ def __init__( dtype=None, skip_checks=False, ): - # Floating point resampling rates can lead to unexpected results, avoid actively - msg = "Non integer resampling rates can lead to unexpected results." - assert isinstance(resample_rate, (int, np.integer)), msg - # Original sampling frequency self._orig_samp_freq = recording.get_sampling_frequency() self._resample_rate = resample_rate self._sampling_frequency = resample_rate + # When the parent rate is an exact integer multiple of resample_rate, get_traces uses + # anti-aliased decimation. Large factors are split into several balanced sub-13 passes + # (see _decimation_tools); only an unsplittable factor (a prime > 13) falls back to a + # single, potentially unstable, pass and warns. + if self._orig_samp_freq % resample_rate == 0: + decimation_factor = int(self._orig_samp_freq / resample_rate) + decimation_factors = get_balanced_decimation_factors(decimation_factor) + if decimation_factors == [decimation_factor] and decimation_factor > _MAX_SINGLE_PASS_DECIMATION: + warnings.warn( + f"Resampling by an integer factor of {decimation_factor} cannot be split into " + f"anti-aliasing passes of <= 13 (it has a prime factor > 13); a single " + f"`scipy.signal.decimate` pass will be used, which may be unstable." + ) + else: + decimation_factors = None # fix_dtype not always returns the str, make sure it does dtype = fix_dtype(recording, dtype).str # Ensure that the requested resample rate is doable: @@ -73,6 +104,7 @@ def __init__( recording.get_sampling_frequency(), margin, dtype, + decimation_factors, ) ) @@ -93,12 +125,16 @@ def __init__( parent_rate, margin, dtype, + decimation_factors=None, ): self._resample_rate = resample_rate self._parent_segment = parent_recording_segment self._parent_rate = parent_rate self._margin = margin self._dtype = dtype + # Per-pass integer decimation factors when the ratio is an exact integer, else None + # (non-integer ratio -> FFT-based scipy.signal.resample). + self._decimation_factors = decimation_factors # Compute time_vector or t_start, following the pattern from DecimateRecordingSegment. # Do not use BasePreprocessorSegment because we have to reset the sampling rate! @@ -129,7 +165,22 @@ def get_num_samples(self): return int(self._parent_segment.get_num_samples() / self._parent_rate * self._resample_rate) def get_traces(self, start_frame, end_frame, channel_indices): - # get parent traces with margin + if self._decimation_factors is not None: + decimation_factor = int(self._parent_rate / self._resample_rate) + return get_antialiased_decimated_traces( + self._parent_segment, + start_frame, + end_frame, + channel_indices, + decimation_factor, + self._decimation_factors, + self._margin, + self._dtype, + ) + + # Non-integer ratio: FFT-based resampling with proportional margins. + from scipy import signal + parent_start_frame, parent_end_frame = [ int((frame / self._resample_rate) * self._parent_rate) for frame in [start_frame, end_frame] ] @@ -147,23 +198,9 @@ def get_traces(self, start_frame, end_frame, channel_indices): int((margin / self._parent_rate) * self._resample_rate) for margin in [left_margin, right_margin] ] - # get the size for the resampled traces in case of resample: + # get the size for the resampled traces num = int((end_frame + right_margin_rs) - (start_frame - left_margin_rs)) - - # Decimate can misbehave on some cases, while resample always looks nice enough. - # Check which method to use: - from scipy import signal - - if np.mod(self._parent_rate, self._resample_rate) == 0: - # Ratio between sampling frequencies - q = int(self._parent_rate / self._resample_rate) - # Decimate can have issues for some cases, returning NaNs - resampled_traces = signal.decimate(parent_traces, q=q, axis=0) - # If that's the case, use signal.resample - if np.any(np.isnan(resampled_traces)): - resampled_traces = signal.resample(parent_traces, num, axis=0) - else: - resampled_traces = signal.resample(parent_traces, num, axis=0) + resampled_traces = signal.resample(parent_traces, num, axis=0) # now take care of the edges resampled_traces = resampled_traces[left_margin_rs : num - right_margin_rs] diff --git a/src/spikeinterface/preprocessing/tests/test_decimate.py b/src/spikeinterface/preprocessing/tests/test_decimate.py index 141345ca46..21b040cdb1 100644 --- a/src/spikeinterface/preprocessing/tests/test_decimate.py +++ b/src/spikeinterface/preprocessing/tests/test_decimate.py @@ -2,8 +2,9 @@ from spikeinterface import NumpyRecording -from spikeinterface.core import generate_recording -from spikeinterface.preprocessing.decimate import DecimateRecording +from spikeinterface.core import generate_recording, load +from spikeinterface.preprocessing.decimate import DecimateRecording, decimate, get_balanced_decimation_factors +from spikeinterface.preprocessing.tests.test_resample import create_sinusoidal_traces import numpy as np @@ -45,7 +46,8 @@ def test_decimate(num_segments, decimation_offset, decimation_factor): ) -def test_decimate_with_times(): +@pytest.mark.parametrize("antialias", [False, True]) +def test_decimate_with_times(antialias): rec = generate_recording(durations=[5, 10]) # test with times @@ -55,7 +57,7 @@ def test_decimate_with_times(): decimation_factor = 2 decimation_offset = 1 - decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) + decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset, antialias=antialias) for segment_index in range(rec.get_num_segments()): assert np.allclose( @@ -68,7 +70,7 @@ def test_decimate_with_times(): t_starts = [10, 20] for t_start, rec_segment in zip(t_starts, rec.segments): rec_segment.t_start = t_start - decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset) + decimated_rec = DecimateRecording(rec, decimation_factor, decimation_offset=decimation_offset, antialias=antialias) for segment_index in range(rec.get_num_segments()): assert np.allclose( decimated_rec.get_times(segment_index), @@ -76,5 +78,115 @@ def test_decimate_with_times(): ) +@pytest.mark.parametrize( + "decimation_factor, expected", + [ + (1, [1]), + (7, [7]), + (13, [13]), + (48, [8, 6]), + (50, [10, 5]), + (60, [10, 6]), + (100, [10, 10]), + (17, [17]), # prime > 13: cannot be split + (23, [23]), # prime > 13: cannot be split + ], +) +def test_balanced_decimation_factors(decimation_factor, expected): + factors = get_balanced_decimation_factors(decimation_factor) + assert factors == expected + # The product of the sub-factors always reconstructs the requested factor. + assert int(np.prod(factors)) == decimation_factor + # Every pass is <= 13 unless the factor is an unsplittable prime > 13. + if len(factors) > 1: + assert all(f <= 13 for f in factors) + + +@pytest.mark.parametrize("decimation_factor", [6, 10, 48]) +def test_decimate_antialias_by_chunks(decimation_factor): + # Mirror test_resample_by_chunks: chunked reads must match a full read once the + # anti-aliasing margins are accounted for. Factor 48 exercises the internal multi-pass. + sampling_frequency = int(3e4) + duration = 30 + traces, _ = create_sinusoidal_traces(sampling_frequency, duration, freqs_n=10, max_freq=1000, dtype=np.float32) + parent_rec = NumpyRecording(traces, sampling_frequency) + rms = np.sqrt(np.mean(parent_rec.get_traces() ** 2)) + decimated_rate = sampling_frequency / decimation_factor + + for margin_ms in [100, 1000]: + rec2 = DecimateRecording(parent_rec, decimation_factor, antialias=True, margin_ms=margin_ms) + chunk_size = int(decimated_rate * 2) # ~2 seconds of the decimated signal + rec3 = rec2.save(format="memory", chunk_size=chunk_size, n_jobs=1, progress_bar=False) + + traces2 = rec2.get_traces() + traces3 = rec3.get_traces() + + # Drop the first and last chunk before comparing (as in test_resample_by_chunks). + sl = slice(chunk_size, -chunk_size) + error_mean = np.sqrt(np.mean((traces2[sl] - traces3[sl]) ** 2)) + error_max = np.sqrt(np.max((traces2[sl] - traces3[sl]) ** 2)) + + assert error_mean / rms < 0.01 + assert error_max / rms < 0.05 + + +@pytest.mark.parametrize("decimation_factor", [6, 10]) +@pytest.mark.parametrize("decimation_offset", [0, 1, 5]) +def test_decimate_antialias_with_offset(decimation_factor, decimation_offset): + sampling_frequency = 30000 + # max_freq below every tested Nyquist, so anti-aliasing barely changes the signal. + traces, _ = create_sinusoidal_traces(sampling_frequency, duration=5, freqs_n=6, max_freq=500, dtype=np.float32) + parent_rec = NumpyRecording(traces, sampling_frequency) + + dec_aa = DecimateRecording( + parent_rec, decimation_factor, decimation_offset=decimation_offset, antialias=True, dtype="float32" + ) + dec_plain = DecimateRecording( + parent_rec, decimation_factor, decimation_offset=decimation_offset, antialias=False, dtype="float32" + ) + + # The anti-aliasing path returns the same number of samples as plain slicing. + parent_n = parent_rec.get_num_samples() + expected_n = int(np.ceil((parent_n - decimation_offset) / decimation_factor)) + assert dec_aa.get_num_samples() == expected_n + assert dec_aa.get_num_samples() == dec_plain.get_num_samples() + + # With only sub-Nyquist content, anti-aliased and plain-sliced traces stay aligned. + corr = np.corrcoef(dec_aa.get_traces().ravel(), dec_plain.get_traces().ravel())[0, 1] + assert corr > 0.95 + + +def test_decimate_antialias_multipass(): + sampling_frequency = 30000 + decimation_factor = 48 + traces, _ = create_sinusoidal_traces(sampling_frequency, duration=10, freqs_n=8, max_freq=200, dtype=np.float32) + parent_rec = NumpyRecording(traces, sampling_frequency) + + dec = decimate(parent_rec, decimation_factor, antialias=True) + + # Multi-pass happens internally: a single DecimateRecording carries the full factor. + assert isinstance(dec, DecimateRecording) + assert dec._kwargs["decimation_factor"] == decimation_factor + + segment = dec.segments[0] + assert int(np.prod(segment._antialias_factors)) == decimation_factor + assert all(f <= 13 for f in segment._antialias_factors) + + parent_n = parent_rec.get_num_samples() + assert dec.get_num_samples() == int(np.ceil(parent_n / decimation_factor)) + + # Provenance round-trips and reproduces the traces. + dec_loaded = load(dec.to_dict()) + np.testing.assert_allclose(dec_loaded.get_traces(), dec.get_traces()) + + +def test_decimate_antialias_large_prime_warns(): + rec = generate_recording(durations=[2.0], num_channels=2) + with pytest.warns(UserWarning, match="prime factor > 13"): + dec = DecimateRecording(rec, 17, antialias=True) + # The unsplittable factor falls back to a single pass. + assert dec.segments[0]._antialias_factors == [17] + + if __name__ == "__main__": test_decimate()