Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ spikeinterface.preprocessing
.. autofunction:: get_motion_parameters_preset
.. autofunction:: load_motion_info
.. autofunction:: save_motion_info
.. autofunction:: decimate
.. autofunction:: depth_order
.. autofunction:: detect_bad_channels
.. autofunction:: detect_and_interpolate_bad_channels
Expand Down
175 changes: 175 additions & 0 deletions src/spikeinterface/preprocessing/_decimation_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
"""
Helpers for splitting a (potentially large) integer decimation factor into several balanced
sub-factors, so that anti-aliased decimation can be applied as multiple stable scipy.signal.decimate
passes. Shared by DecimateRecording and ResampleRecording.
"""

import math
import warnings

import numpy as np

from spikeinterface.core import get_chunk_with_margin

# scipy.signal.decimate uses an order-8 Chebyshev type I IIR filter by default, and its
# documentation recommends decimating in several balanced steps rather than a single step
# for downsampling factors larger than this value.
_MAX_SINGLE_PASS_DECIMATION = 13


def _prime_factors(n):
"""
Return the prime factors of a positive integer `n` (ascending, with multiplicity).

Examples
--------
>>> _prime_factors(60)
[2, 2, 3, 5]
>>> _prime_factors(17)
[17]
"""
factors = []
divisor = 2
while divisor * divisor <= n:
while n % divisor == 0:
factors.append(divisor)
n //= divisor
divisor += 1
if n > 1:
factors.append(n)
return factors


def _greedy_pack(primes_desc, num_bins):
"""
Greedily pack `primes_desc` (largest first) into `num_bins` bins, keeping each bin's
product <= `_MAX_SINGLE_PASS_DECIMATION` and the bins as balanced as possible.

Returns the list of bin products, or None if some prime cannot be placed (i.e. `num_bins`
is too small to keep every bin <= the single-pass limit).

Examples
--------
Pack the prime factors of 48 into two balanced bins (6 and 8):

>>> _greedy_pack([3, 2, 2, 2, 2], 2)
[6, 8]

Two bins cannot hold 2 ** 7 = 128 without a bin exceeding the single-pass limit of 13:

>>> _greedy_pack([2, 2, 2, 2, 2, 2, 2], 2) is None
True
"""
bins = [1] * num_bins
for prime in primes_desc:
fitting = [i for i in range(num_bins) if bins[i] * prime <= _MAX_SINGLE_PASS_DECIMATION]
if not fitting:
return None
# Place into the smallest fitting bin (ties broken by index, for determinism).
target = min(fitting, key=lambda i: (bins[i], i))
bins[target] *= prime
return bins


def get_balanced_decimation_factors(decimation_factor):
"""
Split `decimation_factor` into sub-factors, each <= 13, as balanced as possible (so their
products are close), for stable multi-pass anti-aliased decimation.

scipy recommends decimating in several balanced steps rather than one large step when the
factor exceeds 13 (e.g. 48 -> [8, 6] rather than [12, 4]). The product of the returned
factors always equals `decimation_factor`.

If `decimation_factor` has a prime factor greater than 13 (e.g. a large prime such as 17),
no valid split exists and `[decimation_factor]` is returned; it is the caller's
responsibility to handle this (e.g., warn that a single, potentially unstable, pass will
be used).
"""
if decimation_factor <= _MAX_SINGLE_PASS_DECIMATION:
return [decimation_factor]

primes = _prime_factors(decimation_factor)
if max(primes) > _MAX_SINGLE_PASS_DECIMATION:
# If a prime factor > 13 cannot be split into sub-13 factors...
return [decimation_factor]

primes_desc = sorted(primes, reverse=True)
# Minimum number of passes so that, ideally, each pass decimates by <= 13.
num_passes = max(1, math.ceil(math.log(decimation_factor) / math.log(_MAX_SINGLE_PASS_DECIMATION)))
while num_passes <= len(primes_desc):
bins = _greedy_pack(primes_desc, num_passes)
if bins is not None:
return sorted(bins, reverse=True)
num_passes += 1
# Fallback: one prime per pass (always valid since every prime is <= 13).
return primes_desc


def get_antialiased_decimated_traces(
parent_segment,
start_frame,
end_frame,
channel_indices,
decimation_factor,
decimation_factors,
margin,
dtype,
decimation_offset=0,
):
"""
Fetch a margined chunk from `parent_segment` and decimate it by `decimation_factor`, applied
as a cascade of the balanced `decimation_factors` passes of ``scipy.signal.decimate``.

The margin is rounded up to a multiple of the total `decimation_factor` so that
``left_margin // decimation_factor`` is exact; combined with scipy's default
``zero_phase=True`` (output sample i maps to filtered input sample i * factor), this keeps the
downsampled traces aligned across chunks (a chunked read matches a full read). Exactly
``end_frame - start_frame`` decimated samples are returned.

Parameters
----------
parent_segment : BaseRecordingSegment
The parent segment to read (full-rate) traces from.
start_frame, end_frame : int
Output (decimated) frame range to return.
channel_indices : slice | list | np.ndarray | None
Channels to read, forwarded to the parent segment.
decimation_factor : int
The total decimation factor (the product of `decimation_factors`).
decimation_factors : list[int]
The per-pass sub-factors (each <= 13), e.g. from `get_balanced_decimation_factors`.
margin : int
Margin in parent samples used to limit anti-aliasing filter edge effects. Rounded up
internally to a multiple of `decimation_factor`.
dtype : np.dtype | str
Output dtype. The decimation runs in float32 and the result is cast to `dtype`.
decimation_offset : int, default: 0
Index of the first parent frame, applied to the first output sample only.
"""
from scipy import signal

q = decimation_factor
parent_start_frame = decimation_offset + start_frame * q
parent_end_frame = parent_start_frame + (end_frame - start_frame) * q
# Round the margin up to a multiple of q so that left_margin // q is exact.
margin = int(np.ceil(margin / q) * q)
parent_traces, left_margin, right_margin = get_chunk_with_margin(
parent_segment,
parent_start_frame,
parent_end_frame,
channel_indices,
margin,
add_reflect_padding=True,
dtype=np.float32,
)
decimated_traces = parent_traces
for sub_q in decimation_factors:
decimated_traces = signal.decimate(decimated_traces, q=sub_q, axis=0)
if np.any(np.isnan(decimated_traces)):
warnings.warn(
f"`scipy.signal.decimate` produced NaNs while decimating by {q}. "
f"Consider a different decimation factor."
)
start_drop = left_margin // q
n_out = end_frame - start_frame
return decimated_traces[start_drop : start_drop + n_out].astype(dtype)
99 changes: 81 additions & 18 deletions src/spikeinterface/preprocessing/decimate.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,29 @@
import warnings

import numpy as np
from spikeinterface.core.core_tools import (
define_function_handling_dict_from_class,
)

from .basepreprocessor import BasePreprocessor
from .filter import fix_dtype
from ._decimation_tools import (
_MAX_SINGLE_PASS_DECIMATION,
get_balanced_decimation_factors,
get_antialiased_decimated_traces,
)
from spikeinterface.core import BaseRecordingSegment


class DecimateRecording(BasePreprocessor):
"""
Decimate the recording extractor traces using array slicing
Decimate the recording extractor traces.

Important: This uses simple array slicing for decimation rather than eg scipy.decimate.
This might introduce aliasing, or skip across signal of interest.
Consider spikeinterface.preprocessing.ResampleRecording for safe resampling.
By default this uses simple array slicing
(``<parent_traces>[<decimation_offset>::<decimation_factor>]``), which is fast but applies no
anti-aliasing filter and so might introduce aliasing, or skip across signal of interest. Set
`antialias=True` to low-pass filter before downsampling using ``scipy.signal.decimate`` (the
same anti-aliased decimation used by ``spikeinterface.preprocessing.ResampleRecording``).

Parameters
----------
Expand All @@ -29,12 +38,25 @@ class DecimateRecording(BasePreprocessor):
to ensure that the decimated recording has at least one frame. Consider combining DecimateRecording
with FrameSliceRecording for fine control on the recording start and end frames.
The same decimation offset is applied to all segments from the parent recording.
antialias : bool, default: False
If True, apply an anti-aliasing low-pass filter before downsampling, using
``scipy.signal.decimate``. When `decimation_factor` exceeds 13, the decimation is
automatically performed in several balanced sub-13 passes (e.g. a factor of 48 is applied
as 8 then 6), as scipy recommends, to keep the IIR anti-aliasing filter stable. If False
(the default), traces are downsampled by plain array slicing with no filtering, and
`margin_ms` is ignored.
margin_ms : float, default: 100.0
Margin in ms used on each side of every chunk to limit edge effects of the anti-aliasing
filter. Only used when `antialias=True`. The margin is internally rounded up to a whole
number of output samples so the filtered, downsampled traces stay aligned across chunks.
dtype : dtype or None, default: None
The dtype of the returned traces. If None, the dtype of the parent recording is used.

Returns
-------
decimate_recording: DecimateRecording
The decimated recording extractor object. The full traces of the child recording segment
correspond to the traces of the parent segment as follows:
The decimated recording extractor object. With `antialias=False` the full traces of the
child recording segment correspond to the traces of the parent segment as follows:
```<decimated_traces> = <parent_traces>[<decimation_offset>::<decimation_factor>]```

"""
Expand All @@ -44,6 +66,9 @@ def __init__(
recording,
decimation_factor,
decimation_offset=0,
antialias=False,
margin_ms=100.0,
dtype=None,
):
# Original sampling frequency
self._orig_samp_freq = recording.get_sampling_frequency()
Expand All @@ -63,7 +88,21 @@ def __init__(
self._decimation_offset = decimation_offset
decimated_sampling_frequency = self._orig_samp_freq / self._decimation_factor

BasePreprocessor.__init__(self, recording, sampling_frequency=decimated_sampling_frequency)
# fix_dtype doesn't always returns the str, make sure it does
dtype = fix_dtype(recording, dtype).str

antialias_factors = get_balanced_decimation_factors(decimation_factor)
if antialias and decimation_factor > _MAX_SINGLE_PASS_DECIMATION and antialias_factors == [decimation_factor]:
warnings.warn(
f"`decimation_factor`={decimation_factor} cannot be split into anti-aliasing passes of <= 13 "
f"(it has a prime factor > 13). A single `scipy.signal.decimate` pass will be used, which may be "
f"unstable. Consider a `decimation_factor` without large prime factors."
)

# Margin (in parent samples) to limit anti-aliasing filter edge effects.
margin = int(margin_ms * self._orig_samp_freq / 1000)

BasePreprocessor.__init__(self, recording, sampling_frequency=decimated_sampling_frequency, dtype=dtype)

for parent_segment in recording.segments:
self.add_recording_segment(
Expand All @@ -74,13 +113,19 @@ def __init__(
decimation_factor,
decimation_offset,
self._dtype,
antialias,
margin,
antialias_factors,
)
)

self._kwargs = dict(
recording=recording,
decimation_factor=decimation_factor,
decimation_offset=decimation_offset,
antialias=antialias,
margin_ms=margin_ms,
dtype=dtype,
)


Expand All @@ -93,6 +138,9 @@ def __init__(
decimation_factor,
decimation_offset,
dtype,
antialias=False,
margin=0,
antialias_factors=None,
):
if parent_recording_segment.time_vector is not None:
time_vector = parent_recording_segment.time_vector[decimation_offset::decimation_factor]
Expand All @@ -113,25 +161,40 @@ def __init__(
self._decimation_factor = decimation_factor
self._decimation_offset = decimation_offset
self._dtype = dtype
self._antialias = antialias
self._margin = margin
self._antialias_factors = antialias_factors if antialias_factors is not None else [decimation_factor]

def get_num_samples(self):
parent_n_samp = self._parent_segment.get_num_samples()
assert self._decimation_offset < parent_n_samp # Sanity check (already enforced). Formula changes otherwise
return int(np.ceil((parent_n_samp - self._decimation_offset) / self._decimation_factor))

def get_traces(self, start_frame, end_frame, channel_indices):
# Account for offset and end when querying parent traces
parent_start_frame = self._decimation_offset + start_frame * self._decimation_factor
parent_end_frame = parent_start_frame + (end_frame - start_frame) * self._decimation_factor

# And now we can decimate without offsetting
return self._parent_segment.get_traces(
parent_start_frame,
parent_end_frame,
if not self._antialias:
# Simple array slicing, no anti-aliasing filter.
parent_start_frame = self._decimation_offset + start_frame * self._decimation_factor
parent_end_frame = parent_start_frame + (end_frame - start_frame) * self._decimation_factor
return self._parent_segment.get_traces(
parent_start_frame,
parent_end_frame,
channel_indices,
)[
:: self._decimation_factor
].astype(self._dtype)

# Anti-aliased decimation as a cascade of balanced scipy.signal.decimate passes.
return get_antialiased_decimated_traces(
self._parent_segment,
start_frame,
end_frame,
channel_indices,
)[
:: self._decimation_factor
].astype(self._dtype)
self._decimation_factor,
self._antialias_factors,
self._margin,
self._dtype,
decimation_offset=self._decimation_offset,
)


decimate = define_function_handling_dict_from_class(source_class=DecimateRecording, name="decimate")
Loading
Loading