diff --git a/conftest.py b/conftest.py index 5bf7d74527..e326c60d71 100644 --- a/conftest.py +++ b/conftest.py @@ -8,6 +8,21 @@ def create_cache_folder(tmp_path_factory): return cache_folder +@pytest.fixture(scope="module") +def debug_plots(request): + """Return True if debug plots should be shown.""" + return request.config.getoption("--debug-plots") + + +def pytest_addoption(parser): + parser.addoption( + "--debug-plots", + action="store_true", + default=False, + help="Enable debug plots during tests", + ) + + def pytest_collection_modifyitems(config, items): """ This function marks (in the pytest sense) the tests according to their name and file_path location diff --git a/doc/api.rst b/doc/api.rst index fd71e06622..fc55017606 100755 --- a/doc/api.rst +++ b/doc/api.rst @@ -213,6 +213,9 @@ spikeinterface.preprocessing .. autofunction:: detect_bad_channels .. autofunction:: detect_and_interpolate_bad_channels .. autofunction:: detect_and_remove_bad_channels + .. autofunction:: detect_artifact_periods + .. autofunction:: detect_artifact_periods_by_envelope + .. autofunction:: detect_saturation_periods .. autofunction:: directional_derivative .. autofunction:: filter .. 
autofunction:: gaussian_filter diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index d569293823..2c38248c1a 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -495,16 +495,17 @@ def find_parents_of_type(list_of_parents, parent_type): return parents -def check_graph(nodes): +def check_graph(nodes, check_for_peak_source=True): """ Check that node list is orderd in a good (parents are before children) """ - node0 = nodes[0] - if not isinstance(node0, PeakSource): - raise ValueError( - "Peak pipeline graph must have as first element a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever" - ) + if check_for_peak_source: + node0 = nodes[0] + if not isinstance(node0, PeakSource): + raise ValueError( + "Peak pipeline graph must have as first element a PeakSource (PeakDetector or PeakRetriever or SpikeRetriever" + ) for i, node in enumerate(nodes): assert isinstance(node, PipelineNode), f"Node {node} is not an instance of PipelineNode" @@ -532,6 +533,7 @@ def run_node_pipeline( verbose=False, skip_after_n_peaks=None, recording_slices=None, + check_for_peak_source=True, ): """ Machinery to compute in parallel operations on peaks and traces. @@ -587,6 +589,8 @@ def run_node_pipeline( Optionaly give a list of slices to run the pipeline only on some chunks of the recording. It must be a list of (segment_index, frame_start, frame_stop). If None (default), the function iterates over the entire duration of the recording. + check_for_peak_source : bool, default True + Whether to check that the first node is a PeakSource (PeakDetector or PeakRetriever or Returns ------- @@ -595,7 +599,7 @@ def run_node_pipeline( If squeeze_output=True and only one output then directly np.array. 
""" - check_graph(nodes) + check_graph(nodes, check_for_peak_source=check_for_peak_source) job_kwargs = fix_job_kwargs(job_kwargs) assert all(isinstance(node, PipelineNode) for node in nodes) diff --git a/src/spikeinterface/preprocessing/__init__.py b/src/spikeinterface/preprocessing/__init__.py index de25944bd2..fd8d8fd787 100644 --- a/src/spikeinterface/preprocessing/__init__.py +++ b/src/spikeinterface/preprocessing/__init__.py @@ -20,6 +20,8 @@ PreprocessingPipeline, ) +from .detect_artifacts import detect_artifact_periods, detect_artifact_periods_by_envelope, detect_saturation_periods + # for snippets from .align_snippets import AlignSnippets from warnings import warn diff --git a/src/spikeinterface/preprocessing/detect_artifacts.py b/src/spikeinterface/preprocessing/detect_artifacts.py new file mode 100644 index 0000000000..f9978a6397 --- /dev/null +++ b/src/spikeinterface/preprocessing/detect_artifacts.py @@ -0,0 +1,596 @@ +from __future__ import annotations + +from typing import Literal + +import numpy as np + +from spikeinterface.core import BaseRecording +from spikeinterface.core.base import base_period_dtype +from spikeinterface.preprocessing.rectify import RectifyRecording +from spikeinterface.preprocessing.common_reference import CommonReferenceRecording +from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording +from spikeinterface.core.job_tools import fix_job_kwargs +from spikeinterface.core.recording_tools import get_noise_levels, get_random_data_chunks +from spikeinterface.core.node_pipeline import PeakDetector, base_peak_dtype, run_node_pipeline, PipelineNode + +artifact_dtype = base_period_dtype + +# this will be extend with channel boundaries if needed +# extended_artifact_dtype = artifact_dtype + [ +# # TODO +# ] + + +def _collapse_events(events: np.ndarray) -> np.ndarray: + """ + Collapse artifact events that were split across chunk boundaries. 
+ + When a chunk boundary falls within an artifact period the period is emitted + as two adjacent events whose ``end_sample_index`` / ``start_sample_index`` + values are equal. This function merges such pairs into a single record. + + Parameters + ---------- + events : np.ndarray + Array of artifact events with dtype ``artifact_dtype``, containing + ``"start_sample_index"``, ``"end_sample_index"``, and + ``"segment_index"`` fields. + + Returns + ------- + np.ndarray + Array of collapsed artifact events with the same dtype as ``events``. + """ + order = np.lexsort((events["start_sample_index"], events["segment_index"])) + events = events[order] + to_drop = np.zeros(events.size, dtype=bool) + + # compute if duplicate + for i in np.arange(events.size - 1): + same = events["end_sample_index"][i] == events["start_sample_index"][i + 1] + if same: + to_drop[i] = True + events["start_sample_index"][i + 1] = events["start_sample_index"][i] + collapsed_events = events[~to_drop] + return collapsed_events + + +## detect_period_artifacts_saturation zone +class _DetectSaturation(PipelineNode): + """ + A pipeline node for parallelised amplifier-saturation detection. + + When run with :func:`run_node_pipeline`, this node computes saturation + events for a given data chunk. See :func:`detect_saturation_periods` for + the full algorithm description and parameter semantics. + """ + + name = "detect_saturation" + preferred_mp_context = None + _compute_has_extended_signature = True + + def __init__( + self, + recording: BaseRecording, + saturation_threshold_uV: float, + diff_threshold_uV: float | None, + proportion: float, + ) -> None: + """ + Parameters + ---------- + recording : BaseRecording + The recording to process. + saturation_threshold_uV : float + Voltage saturation threshold in μV. + diff_threshold_uV : float | None + First-derivative threshold in μV/sample, or ``None`` to disable + derivative-based detection. 
+ proportion : float + Fraction of channels that must exceed the threshold for a sample to + be labelled as saturated (0 < proportion < 1). + """ + PipelineNode.__init__(self, recording, return_output=True) + + num_chans = recording.get_num_channels() + + self.diff_threshold_uV = diff_threshold_uV + thresh = np.full((num_chans,), saturation_threshold_uV) + # 0.98 is empirically determined as the true saturating point is + # slightly lower than the documented saturation point of the probe + self.sampling_frequency = recording.get_sampling_frequency() + self.proportion = proportion + self._dtype = np.dtype(artifact_dtype) + self.gain = recording.get_channel_gains() + self.offset = recording.get_channel_offsets() + + self.saturation_threshold_unscaled = (thresh - self.offset) / self.gain + + # do not apply offset when dealing with the derivative + if self.diff_threshold_uV is not None: + self.diff_threshold_unscaled = diff_threshold_uV / self.gain + else: + self.diff_threshold_unscaled = None + + def get_trace_margin(self) -> int: + """Return the number of margin samples required on each side of a chunk.""" + return 0 + + def get_dtype(self) -> np.dtype: + """Return the NumPy dtype of the output array produced by :meth:`compute`.""" + return self._dtype + + def compute( + self, + traces: np.ndarray, + start_frame: int, + end_frame: int, + segment_index: int, + max_margin: int, + ) -> tuple[np.ndarray]: + """ + Detect saturation events within a single chunk of raw traces. + + A sample is labelled as *saturated by value* when the fraction of + channels whose absolute amplitude exceeds + ``saturation_threshold_unscaled`` is greater than ``proportion``. + + Optionally, a sample is also labelled as *saturated by derivative* when + the fraction of channels whose forward-difference amplitude exceeds + ``diff_threshold_unscaled`` is greater than ``proportion``. + + Consecutive saturated samples are grouped into contiguous period events. 
+ + Parameters + ---------- + traces : np.ndarray + Raw trace data for the current chunk, shape ``(n_samples, n_channels)``. + start_frame : int + Index of the first sample of this chunk within its segment. + end_frame : int + Index one past the last sample of this chunk within its segment. + segment_index : int + Index of the segment to which this chunk belongs. + max_margin : int + Maximum trace margin (unused; kept for API compatibility). + + Returns + ------- + tuple[np.ndarray] + A one-element tuple containing an array of saturation events with + dtype ``artifact_dtype``. + """ + # cast to float32 to prevent overflow when applying thresholds in unscaled ADC units + traces = traces.astype("float32") + + saturation = np.mean(np.abs(traces) > self.saturation_threshold_unscaled, axis=1) + detected_by_value = saturation > self.proportion + + if self.diff_threshold_unscaled is not None: + # then compute the derivative of the voltage saturation + n_diff_saturated = np.mean(np.abs(np.diff(traces, axis=0)) >= self.diff_threshold_unscaled, axis=1) + + # Note this means the velocity is not checked for the last sample in the + # check because we are taking the forward derivative + n_diff_saturated = np.r_[n_diff_saturated, 0] + + # if either of those reaches more than the proportion of channels labels the sample as saturated + detected_by_diff = n_diff_saturated > self.proportion + saturation = np.logical_or(detected_by_value, detected_by_diff) + else: + saturation = detected_by_value + + intervals = np.flatnonzero(np.diff(saturation, prepend=False, append=False)) + n_events = len(intervals) // 2 # Number of saturation periods + events = np.zeros(n_events, dtype=artifact_dtype) + + for i, (start, stop) in enumerate(zip(intervals[::2], intervals[1::2])): + events[i]["start_sample_index"] = start + start_frame + events[i]["end_sample_index"] = stop + start_frame + events[i]["segment_index"] = segment_index + + return (events,) + + +def detect_saturation_periods( + 
recording: BaseRecording, + saturation_threshold_uV: float | None = None, + diff_threshold_uV: float | None = None, + proportion: float = 0.2, + job_kwargs: dict | None = None, +) -> np.ndarray: + """ + Detect amplifier saturation events (single- or multi-sample periods) in raw data. + + Saturation detection should be applied to the **raw** recording, before any + preprocessing. The returned periods can then be used to zero out (silence) + the corresponding samples **after** preprocessing has been performed. + + Saturation is identified in two complementary ways: + + 1. **By value**: a sample is saturated when the fraction of channels whose + absolute amplitude exceeds ``saturation_threshold_uV`` is greater than + ``proportion``. + 2. **By derivative**: a sample is saturated when the fraction of channels + whose forward-difference amplitude exceeds ``diff_threshold_uV`` is + greater than ``proportion``. + + If ``diff_threshold_uV`` is not ``None``, a sample is marked as saturated + if *either* criterion is met. + + Parameters + ---------- + recording : BaseRecording + The recording on which to detect saturation events. + saturation_threshold_uV : float | None, default: None + Voltage saturation threshold in μV. The appropriate value depends on + the probe and amplifier gain settings; for Neuropixels 1.0 probes IBL + recommend **1200 μV**. NP2 probes are harder to saturate than NP1. + If ``None``, the value is read from the ``"saturation_threshold_uV"`` + annotation of ``recording``. + diff_threshold_uV : float | None, default: None + First-derivative threshold in μV/sample. Periods where the + sample-to-sample voltage change exceeds this value in the required + fraction of channels are flagged as saturation. Pass ``None`` to + disable derivative-based detection and rely solely on + ``saturation_threshold_uV``. IBL use **300 μV/sample** for NP1 probes. 
+ proportion : float, default: 0.2 + Fraction of channels (0 < proportion < 1) that must exceed the + threshold for a sample to be considered saturated. + job_kwargs : dict | None, default: None + Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. + ``n_jobs``, ``chunk_duration``). + + Returns + ------- + np.ndarray + Array with dtype ``artifact_dtype`` describing each saturation period. + Fields: ``"start_sample_index"``, ``"end_sample_index"``, + ``"segment_index"``. + """ + if job_kwargs is None: + job_kwargs = {} + + job_kwargs = fix_job_kwargs(job_kwargs) + + # The saturation threshold can be specified in the recording annotations and loaded automatically + # for some acquisition systems (e.g., Neuropixels) + if "saturation_threshold_uV" in recording.get_annotation_keys() and saturation_threshold_uV is None: + saturation_threshold_uV = recording.get_annotation("saturation_threshold_uV") + + node0 = _DetectSaturation( + recording, + saturation_threshold_uV=saturation_threshold_uV, + diff_threshold_uV=diff_threshold_uV, + proportion=proportion, + ) + + saturation_periods = run_node_pipeline( + recording, [node0], job_kwargs=job_kwargs, job_name="detect saturation artifacts", check_for_peak_source=False + ) + num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] + return _collapse_events(saturation_periods) + + +## detect_artifact_periods_by_envelope zone +class _DetectThresholdCrossing(PeakDetector): + """ + A pipeline node that detects threshold crossings of a channel-aggregated envelope. + + Each crossing of the global median z-score above 1 is returned as an event + with a ``"front"`` flag indicating whether the crossing is a rising edge + (``True``) or a falling edge (``False``). Used internally by + :func:`detect_artifact_periods_by_envelope`. + + Attributes + ---------- + abs_thresholds : np.ndarray + Per-channel absolute amplitude thresholds in raw ADC units. 
+ """ + + name = "threshold_crossings" + preferred_mp_context = None + + def __init__( + self, + recording: BaseRecording, + mads: np.ndarray, + medians: np.ndarray, + detect_threshold: float = 5, + ) -> None: + """ + Parameters + ---------- + recording : BaseRecording + The (pre-processed envelope) recording to process. + detect_threshold : float, default: 5 + Detection threshold expressed as a multiple of the estimated noise + level per channel. + mads : np.ndarray + Pre-computed per-channel median absolute deviations in raw ADC units. + medians : np.ndarray + Pre-computed per-channel medians in raw ADC units. + noise_levels_kwargs : dict, default: {} + Additional keyword arguments forwarded to + :func:`~spikeinterface.core.get_noise_levels`. + """ + PeakDetector.__init__(self, recording, return_output=True) + self.abs_thresholds = (mads * detect_threshold)[np.newaxis, :] + self.medians = medians[np.newaxis, :] + # internal dtype + self._dtype = np.dtype([("sample_index", "int64"), ("segment_index", "int64"), ("front", "bool")]) + + def get_trace_margin(self) -> int: + """Return the number of margin samples required on each side of a chunk.""" + return 0 + + def get_dtype(self) -> np.dtype: + """Return the NumPy dtype of the output array produced by :meth:`compute`.""" + return self._dtype + + def compute( + self, + traces: np.ndarray, + start_frame: int, + end_frame: int, + segment_index: int, + max_margin: int, + ) -> tuple[np.ndarray]: + """ + Detect threshold crossings in a single chunk of envelope traces. + + The per-sample signal is the median z-score across channels: + ``z = median(traces / abs_thresholds, axis=1)``. Transitions of + ``z > 1`` are located and returned as crossing events. + + Parameters + ---------- + traces : np.ndarray + Envelope trace data for the current chunk, + shape ``(n_samples, n_channels)``. + start_frame : int + Index of the first sample of this chunk within its segment. 
+ end_frame : int + Index one past the last sample of this chunk within its segment. + segment_index : int + Index of the segment to which this chunk belongs. + max_margin : int + Maximum trace margin (unused; kept for API compatibility). + + Returns + ------- + tuple[np.ndarray] + A one-element tuple containing an array of threshold-crossing + events with fields ``"sample_index"``, ``"segment_index"``, and + ``"front"`` (``True`` for rising edge, ``False`` for falling edge). + """ + z = np.median((traces - self.medians) / self.abs_thresholds, axis=1) + threshold_mask = np.diff((z > 1) != 0, axis=0) + + indices = np.flatnonzero(threshold_mask) + threshold_crossings = np.zeros(indices.size, dtype=self._dtype) + threshold_crossings["sample_index"] = indices + threshold_crossings["segment_index"] = segment_index + threshold_crossings["front"][::2] = True + threshold_crossings["front"][1::2] = False + return (threshold_crossings,) + + +def detect_artifact_periods_by_envelope( + recording: BaseRecording, + detect_threshold: float = 5, + apply_envelope_common_reference: bool = False, + freq_max: float = 20.0, + seed: int | None = None, + job_kwargs: dict | None = None, + random_slices_kwargs: dict | None = None, + return_envelope: bool = False, +) -> np.ndarray | tuple[np.ndarray, BaseRecording]: + """ + Detect putative artifact periods as threshold crossings of a global channel envelope. + + The pipeline is: + + 1. Rectify the raw recording. + 2. Low-pass filter with a Gaussian filter up to ``freq_max`` Hz to produce + a smooth per-channel amplitude envelope. + 3. Apply a common-average reference so that only signals correlated across + channels (i.e. artefacts) survive. + 4. Estimate per-channel noise levels on the envelope. + 5. Detect samples where the median channel z-score exceeds + ``detect_threshold``, and convert contiguous runs into period records. 
+ + Parameters + ---------- + recording : BaseRecording + The recording extractor from which to detect artefact periods. + detect_threshold : float, default: 5 + Detection threshold as a multiple of the estimated per-channel noise + level of the envelope. + freq_max : float, default: 20.0 + Cut-off frequency (Hz) for the Gaussian low-pass filter applied to the + rectified signal when building the envelope. + seed : int | None, default: None + Random seed forwarded to :func:`~spikeinterface.core.get_noise_levels`. + If ``None``, ``get_noise_levels`` uses ``seed=0``. + job_kwargs : dict | None, default: None + Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. + ``n_jobs``, ``chunk_duration``). + random_slices_kwargs : dict | None, default: None + Additional keyword arguments forwarded to the ``random_slices_kwargs`` + argument of :func:`~spikeinterface.core.get_noise_levels`. + return_envelope : bool, default: False + If ``True``, also return the intermediate envelope recording so that it + can be inspected or plotted. + + Returns + ------- + artifacts : np.ndarray + Array with dtype ``artifact_dtype`` describing each detected artifact + period. Fields: ``"start_sample_index"``, ``"end_sample_index"``, + ``"segment_index"``. + envelope : BaseRecording + Only returned when ``return_envelope=True``. The processed envelope + recording (rectified → Gaussian-filtered → common-average referenced). 
+ """ + envelope = RectifyRecording(recording) + envelope = GaussianFilterRecording(envelope, freq_min=None, freq_max=freq_max) + if apply_envelope_common_reference: + envelope = CommonReferenceRecording(envelope) + + job_kwargs = fix_job_kwargs(job_kwargs) + if random_slices_kwargs is None: + random_slices_kwargs = {} + else: + random_slices_kwargs = random_slices_kwargs.copy() + random_slices_kwargs["seed"] = seed + random_data = get_random_data_chunks(envelope, **random_slices_kwargs) + medians = np.median(random_data, axis=0) + mad = np.median(np.abs(random_data - medians), axis=0) + mads = mad / 0.6745 + + node0 = _DetectThresholdCrossing( + envelope, + detect_threshold=detect_threshold, + mads=mads, + medians=medians, + ) + + threshold_crossings = run_node_pipeline( + envelope, + [node0], + job_kwargs, + job_name="detect artifact on envelope", + check_for_peak_source=False, + ) + + order = np.lexsort((threshold_crossings["sample_index"], threshold_crossings["segment_index"])) + threshold_crossings = threshold_crossings[order] + + artifacts = _transform_internal_dtype_to_artifact_dtype(threshold_crossings, recording) + + num_samples = [recording.get_num_samples(seg_index) for seg_index in range(recording.get_num_segments())] + artifacts = _collapse_events(artifacts) + + if return_envelope: + return artifacts, envelope + else: + return artifacts + + +def _transform_internal_dtype_to_artifact_dtype( + artifacts: np.ndarray, + recording: BaseRecording, +) -> np.ndarray: + """ + Convert threshold-crossing events to the standard ``artifact_dtype`` format. + + Threshold-crossing events are stored as individual rising/falling edge + records. This function pairs them up segment by segment to produce + contiguous period records. Edge cases at segment boundaries are handled: + + * If the first event in a segment is a falling edge, an implicit rising + edge at sample 0 is prepended. 
+ * If the last event in a segment is a rising edge, an implicit falling edge + at the last sample of the segment is appended. + + Parameters + ---------- + artifacts : np.ndarray + Array of threshold-crossing events with fields ``"sample_index"``, + ``"segment_index"``, and ``"front"`` (``True`` = rising edge). + Must be sorted by ``(segment_index, sample_index)``. + recording : BaseRecording + The original recording, used to determine the number of segments and + the number of samples per segment. + + Returns + ------- + np.ndarray + Array with dtype ``artifact_dtype`` containing the merged artifact + periods. Returns an empty array if no crossings are found. + """ + num_seg = recording.get_num_segments() + + final_artifacts = [] + for seg_index in range(num_seg): + mask = artifacts["segment_index"] == seg_index + sub_thr = artifacts[mask] + print(sub_thr) + if len(sub_thr) > 0: + if not sub_thr["front"][0]: + local_thr = np.zeros(1, dtype=np.dtype(base_period_dtype + [("front", "bool")])) + local_thr["sample_index"] = 0 + local_thr["front"] = True + sub_thr = np.hstack((local_thr, sub_thr)) + if sub_thr["front"][-1]: + local_thr = np.zeros(1, dtype=np.dtype(base_period_dtype + [("front", "bool")])) + local_thr["sample_index"] = recording.get_num_samples(seg_index) + local_thr["front"] = False + sub_thr = np.hstack((sub_thr, local_thr)) + + local_artifact = np.zeros(int(np.ceil(sub_thr.size / 2)), dtype=artifact_dtype) + local_artifact["start_sample_index"] = sub_thr["sample_index"][::2] + local_artifact["end_sample_index"] = sub_thr["sample_index"][1::2] + local_artifact["segment_index"] = seg_index + final_artifacts.append(local_artifact) + + if len(final_artifacts) > 0: + final_artifacts = np.concatenate(final_artifacts) + else: + final_artifacts = np.zeros(0, dtype=artifact_dtype) + return final_artifacts + + +_method_to_function = { + "envelope": detect_artifact_periods_by_envelope, + "saturation": detect_saturation_periods, +} + + +def 
detect_artifact_periods( + recording: BaseRecording, + method: Literal["envelope", "saturation"] = "envelope", + method_kwargs: dict | None = None, + job_kwargs: dict | None = None, +) -> np.ndarray: + """ + Detect artifact periods using one of several available methods. + + Available methods: + + * ``"envelope"``: detects artifacts as threshold crossings of a low-pass-filtered, rectified + channel envelope. + * ``"saturation"``: detects amplifier saturation events by a voltage threshold and/or a derivative threshold. + + See the documentation of each sub-function for a full description of their + parameters, which can be forwarded via ``method_kwargs``. + + Parameters + ---------- + recording : BaseRecording + The recording on which to detect artifact periods. + method : {"envelope", "saturation"}, default: "envelope" + Detection method to use. + method_kwargs : dict | None, default: None + Additional keyword arguments forwarded to the selected detection + function. Pass ``None`` to use that function's defaults. + job_kwargs : dict | None, default: None + Keyword arguments forwarded to :func:`run_node_pipeline` (e.g. + ``n_jobs``, ``chunk_duration``). + + Returns + ------- + np.ndarray + Array with dtype ``artifact_dtype`` describing each detected artifact + period. + """ + assert ( + method in _method_to_function + ), f"Method {method} not recognized. 
Valid methods are: {_method_to_function.keys()}" + if method_kwargs is None: + method_kwargs = dict() + + artifact_periods = _method_to_function[method](recording, job_kwargs=job_kwargs, **method_kwargs) + + return artifact_periods diff --git a/src/spikeinterface/preprocessing/preprocessing_classes.py b/src/spikeinterface/preprocessing/preprocessing_classes.py index ef75c595a3..47e3c0906b 100644 --- a/src/spikeinterface/preprocessing/preprocessing_classes.py +++ b/src/spikeinterface/preprocessing/preprocessing_classes.py @@ -48,7 +48,8 @@ from .depth_order import DepthOrderRecording, depth_order from .astype import AstypeRecording, astype from .unsigned_to_signed import UnsignedToSignedRecording, unsigned_to_signed -from .silence_artifacts import SilencedArtifactsRecording, silence_artifacts + +# from .silence_artifacts import SilencedArtifactsRecording, silence_artifacts _all_preprocesser_dict = { # filter stuff @@ -88,7 +89,7 @@ DirectionalDerivativeRecording: directional_derivative, AstypeRecording: astype, UnsignedToSignedRecording: unsigned_to_signed, - SilencedArtifactsRecording: silence_artifacts, + # SilencedArtifactsRecording: silence_artifacts, } # we control import in the preprocessing init by setting an __all__ diff --git a/src/spikeinterface/preprocessing/silence_artifacts.py b/src/spikeinterface/preprocessing/silence_artifacts.py deleted file mode 100644 index cbbb61f836..0000000000 --- a/src/spikeinterface/preprocessing/silence_artifacts.py +++ /dev/null @@ -1,222 +0,0 @@ -import numpy as np - -from spikeinterface.core.base import base_peak_dtype -from spikeinterface.core.core_tools import define_function_handling_dict_from_class -from spikeinterface.core.job_tools import split_job_kwargs, fix_job_kwargs -from spikeinterface.core.recording_tools import get_noise_levels -from spikeinterface.core.node_pipeline import PeakDetector -from spikeinterface.preprocessing.silence_periods import SilencedPeriodsRecording -from 
spikeinterface.preprocessing.rectify import RectifyRecording -from spikeinterface.preprocessing.common_reference import CommonReferenceRecording -from spikeinterface.preprocessing.filter_gaussian import GaussianFilterRecording - - -class DetectThresholdCrossing(PeakDetector): - - name = "threshold_crossings" - preferred_mp_context = None - - def __init__( - self, - recording, - detect_threshold=5, - noise_levels=None, - seed=None, - noise_levels_kwargs=dict(), - ): - PeakDetector.__init__(self, recording, return_output=True) - if noise_levels is None: - random_slices_kwargs = noise_levels_kwargs.pop("random_slices_kwargs", {}).copy() - random_slices_kwargs["seed"] = seed - noise_levels = get_noise_levels(recording, return_in_uV=False, random_slices_kwargs=random_slices_kwargs) - self.abs_thresholds = noise_levels * detect_threshold - self._dtype = np.dtype(base_peak_dtype + [("front", "bool")]) - - def get_trace_margin(self): - return 0 - - def get_dtype(self): - return self._dtype - - def compute(self, traces, start_frame, end_frame, segment_index, max_margin): - z = np.median(traces / self.abs_thresholds, 1) - threshold_mask = np.diff((z > 1) != 0, axis=0) - indices = np.flatnonzero(threshold_mask) - threshold_crossings = np.zeros(indices.size, dtype=self._dtype) - threshold_crossings["sample_index"] = indices - threshold_crossings["front"][::2] = True - threshold_crossings["front"][1::2] = False - return (threshold_crossings,) - - -def detect_period_artifacts_by_envelope( - recording, - detect_threshold=5, - min_duration_ms=50, - freq_max=20.0, - seed=None, - noise_levels=None, - **noise_levels_kwargs, -): - """ - Docstring for detect_period_artifacts. Function to detect putative artifact periods as threshold crossings of - a global envelope of the channels. - - Parameters - ---------- - recording : RecordingExtractor - The recording extractor to detect putative artifacts - detect_threshold : float, default: 5 - The threshold to detect artifacts. 
The threshold is computed as `detect_threshold * noise_level` - freq_max : float, default: 20 - The maximum frequency for the low pass filter used - min_duration_ms : float, default: 50 - The minimum duration for a threshold crossing to be considered as an artefact. - noise_levels : array - Noise levels if already computed - seed : int | None, default: None - Random seed for `get_noise_levels`. - If none, `get_noise_levels` uses `seed=0`. - **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function - - """ - - envelope = RectifyRecording(recording) - envelope = GaussianFilterRecording(envelope, freq_min=None, freq_max=freq_max) - envelope = CommonReferenceRecording(envelope) - - from spikeinterface.core.node_pipeline import ( - run_node_pipeline, - ) - - _, job_kwargs = split_job_kwargs(noise_levels_kwargs) - job_kwargs = fix_job_kwargs(job_kwargs) - - node0 = DetectThresholdCrossing( - recording, detect_threshold=detect_threshold, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs - ) - - threshold_crossings = run_node_pipeline( - recording, - [node0], - job_kwargs, - job_name="detect threshold crossings", - ) - - order = np.lexsort((threshold_crossings["sample_index"], threshold_crossings["segment_index"])) - threshold_crossings = threshold_crossings[order] - - periods = [] - fs = recording.sampling_frequency - max_duration_samples = int(min_duration_ms * fs / 1000) - num_seg = recording.get_num_segments() - - for seg_index in range(num_seg): - sub_periods = [] - mask = threshold_crossings["segment_index"] == seg_index - sub_thr = threshold_crossings[mask] - if len(sub_thr) > 0: - local_thr = np.zeros(1, dtype=np.dtype(base_peak_dtype + [("front", "bool")])) - if not sub_thr["front"][0]: - local_thr["sample_index"] = 0 - local_thr["front"] = True - sub_thr = np.hstack((local_thr, sub_thr)) - if sub_thr["front"][-1]: - local_thr["sample_index"] = recording.get_num_samples(seg_index) - local_thr["front"] = False - 
sub_thr = np.hstack((sub_thr, local_thr)) - - indices = np.flatnonzero(np.diff(sub_thr["front"])) - for i, j in zip(indices[:-1], indices[1:]): - if sub_thr["front"][i]: - start = sub_thr["sample_index"][i] - end = sub_thr["sample_index"][j] - if end - start > max_duration_samples: - sub_periods.append((start, end)) - - periods.append(sub_periods) - - return periods, envelope - - -class SilencedArtifactsRecording(SilencedPeriodsRecording): - """ - Silence user-defined periods from recording extractor traces. The code will construct - an enveloppe of the recording (as a low pass filtered version of the traces) and detect - threshold crossings to identify the periods to silence. The periods are then silenced either - on a per channel basis or across all channels by replacing the values by zeros or by - adding gaussian noise with the same variance as the one in the recordings - - Parameters - ---------- - recording : RecordingExtractor - The recording extractor to silence putative artifacts - detect_threshold : float, default: 5 - The threshold to detect artifacts. The threshold is computed as `detect_threshold * noise_level` - freq_max : float, default: 20 - The maximum frequency for the low pass filter used - min_duration_ms : float, default: 50 - The minimum duration for a threshold crossing to be considered as an artefact. - noise_levels : array - Noise levels if already computed - seed : int | None, default: None - Random seed for `get_noise_levels` and `NoiseGeneratorRecording`. - If none, `get_noise_levels` uses `seed=0` and `NoiseGeneratorRecording` generates a random seed using `numpy.random.default_rng`. - mode : "zeros" | "noise", default: "zeros" - Determines what periods are replaced by. Can be one of the following: - - - "zeros": Artifacts are replaced by zeros. 
- - - "noise": The periods are filled with a gaussion noise that has the - same variance that the one in the recordings, on a per channel - basis - **noise_levels_kwargs : Keyword arguments for `spikeinterface.core.get_noise_levels()` function - - Returns - ------- - silenced_recording : SilencedArtifactsRecording - The recording extractor after silencing detected artifacts - """ - - _precomputable_kwarg_names = ["list_periods"] - - def __init__( - self, - recording, - detect_threshold=5, - verbose=False, - freq_max=20.0, - min_duration_ms=50, - mode="zeros", - noise_levels=None, - seed=None, - list_periods=None, - **noise_levels_kwargs, - ): - - if list_periods is None: - list_periods, _ = detect_period_artifacts_by_envelope( - recording, - detect_threshold=detect_threshold, - min_duration_ms=min_duration_ms, - freq_max=freq_max, - seed=seed, - noise_levels=noise_levels, - **noise_levels_kwargs, - ) - - if verbose: - for i, periods in enumerate(list_periods): - total_time = np.sum([end - start for start, end in periods]) - percentage = 100 * total_time / recording.get_num_samples(i) - print(f"{percentage}% of segment {i} has been flagged as artifactual") - - SilencedPeriodsRecording.__init__( - self, recording, list_periods, mode=mode, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs - ) - - -# function for API -silence_artifacts = define_function_handling_dict_from_class( - source_class=SilencedArtifactsRecording, name="silence_artifacts" -) diff --git a/src/spikeinterface/preprocessing/silence_periods.py b/src/spikeinterface/preprocessing/silence_periods.py index 817df7031e..189b97ec87 100644 --- a/src/spikeinterface/preprocessing/silence_periods.py +++ b/src/spikeinterface/preprocessing/silence_periods.py @@ -6,6 +6,7 @@ from spikeinterface.core import get_noise_levels from spikeinterface.core.generate import NoiseGeneratorRecording from spikeinterface.core.job_tools import split_job_kwargs +from spikeinterface.core.base import base_period_dtype 
class SilencedPeriodsRecording(BasePreprocessor): @@ -46,7 +47,9 @@ class SilencedPeriodsRecording(BasePreprocessor): def __init__( self, recording, - list_periods, + periods=None, + # this is keep for backward compatibility + list_periods=None, mode="zeros", noise_levels=None, seed=None, @@ -54,25 +57,27 @@ def __init__( ): available_modes = ("zeros", "noise") num_seg = recording.get_num_segments() - if num_seg == 1: - if isinstance(list_periods, (list, np.ndarray)) and np.array(list_periods).ndim == 2: - # when unique segment accept list instead of list of list/arrays - list_periods = [list_periods] - # some checks - assert mode in available_modes, f"mode {mode} is not an available mode: {available_modes}" - assert isinstance(list_periods, list), "'list_periods' must be a list (one per segment)" - assert len(list_periods) == num_seg, "'list_periods' must have the same length as the number of segments" - assert all( - isinstance(list_periods[i], (list, np.ndarray)) for i in range(num_seg) - ), "Each element of 'list_periods' must be array-like" + # handle backward compatibility with previous version + if list_periods is not None: + assert periods is None + periods = _all_period_list_to_periods_vec(list_periods, num_seg) + else: + assert list_periods is None + if not isinstance(periods, np.ndarray): + raise ValueError(f"periods must be a np.array with dtype {base_period_dtype}") + + if periods.dtype.fields is None: + # this is the old format : list[list[int]] + periods = _all_period_list_to_periods_vec(periods, num_seg) + + # force order + order = np.lexsort((periods["start_sample_index"], periods["segment_index"])) + periods = periods[order] + _check_periods(periods, num_seg) - for periods in list_periods: - if len(periods) > 0: - assert np.all(np.diff(np.array(periods), axis=1) > 0), "t_stops should be larger than t_starts" - assert np.all( - periods[i][1] < periods[i + 1][0] for i in np.arange(len(periods) - 1) - ), "Intervals should not overlap" + # some checks 
+ assert mode in available_modes, f"mode {mode} is not an available mode: {available_modes}" if mode in ["noise"]: if noise_levels is None: @@ -96,16 +101,56 @@ def __init__( noise_generator = None BasePreprocessor.__init__(self, recording) + + seg_limits = np.searchsorted(periods["segment_index"], np.arange(num_seg + 1)) for seg_index, parent_segment in enumerate(recording._recording_segments): - periods = list_periods[seg_index] - periods = np.asarray(periods, dtype="int64") - periods = np.sort(periods, axis=0) - rec_segment = SilencedPeriodsRecordingSegment(parent_segment, periods, mode, noise_generator, seg_index) + i0 = seg_limits[seg_index] + i1 = seg_limits[seg_index + 1] + periods_in_seg = periods[i0:i1] + rec_segment = SilencedPeriodsRecordingSegment( + parent_segment, periods_in_seg, mode, noise_generator, seg_index + ) self.add_recording_segment(rec_segment) - self._kwargs = dict( - recording=recording, list_periods=list_periods, mode=mode, seed=seed, noise_levels=noise_levels - ) + self._kwargs = dict(recording=recording, periods=periods, mode=mode, seed=seed, noise_levels=noise_levels) + + +def _all_period_list_to_periods_vec(list_periods, num_seg): + if num_seg == 1: + if isinstance(list_periods, (list, np.ndarray)) and np.array(list_periods).ndim == 2: + # when unique segment accept list instead of list of list/arrays + list_periods = [list_periods] + size = sum(len(p) for p in list_periods) + periods = np.zeros(size, dtype=base_period_dtype) + start = 0 + for i in range(num_seg): + periods_in_seg = np.array(list_periods[i]) + stop = start + periods_in_seg.shape[0] + periods[start:stop]["segment_index"] = i + periods[start:stop]["start_sample_index"] = periods_in_seg[:, 0] + periods[start:stop]["end_sample_index"] = periods_in_seg[:, 1] + start = stop + return periods + + +def _check_periods(periods, num_seg): + # check dtype + if any(col not in np.dtype(base_period_dtype).fields for col in periods.dtype.fields): + raise ValueError(f"periods must be 
a np.array with dtype {base_period_dtype}") + + # check non overlap and non negative + seg_limits = np.searchsorted(periods["segment_index"], np.arange(num_seg + 1)) + for i in range(num_seg): + i0 = seg_limits[i] + i1 = seg_limits[i + 1] + periods_in_seg = periods[i0:i1] + if periods_in_seg.size == 0: + continue + if len(periods) > 0: + if np.any(periods_in_seg["start_sample_index"] > periods_in_seg["end_sample_index"]): + raise ValueError("end_sample_index should be larger than start_sample_index") + if np.any(periods_in_seg["start_sample_index"][1:] < periods_in_seg["end_sample_index"][:-1]): + raise ValueError("Intervals should not overlap") class SilencedPeriodsRecordingSegment(BasePreprocessorSegment): @@ -118,18 +163,20 @@ def __init__(self, parent_recording_segment, periods, mode, noise_generator, seg def get_traces(self, start_frame, end_frame, channel_indices): traces = self.parent_recording_segment.get_traces(start_frame, end_frame, channel_indices) - traces = traces.copy() + if self.periods.size > 0: new_interval = np.array([start_frame, end_frame]) - lower_index = np.searchsorted(self.periods[:, 1], new_interval[0]) - upper_index = np.searchsorted(self.periods[:, 0], new_interval[1]) + + lower_index = np.searchsorted(self.periods["end_sample_index"], new_interval[0]) + upper_index = np.searchsorted(self.periods["start_sample_index"], new_interval[1]) if upper_index > lower_index: - periods_in_interval = self.periods[lower_index:upper_index] + traces = traces.copy() + periods_in_interval = self.periods[lower_index:upper_index] for period in periods_in_interval: - onset = max(0, period[0] - start_frame) - offset = min(period[1] - start_frame, end_frame) + onset = max(0, period["start_sample_index"] - start_frame) + offset = min(period["end_sample_index"] - start_frame, end_frame) if self.mode == "zeros": traces[onset:offset, :] = 0 @@ -146,3 +193,46 @@ def get_traces(self, start_frame, end_frame, channel_indices): silence_periods = 
define_function_handling_dict_from_class( source_class=SilencedPeriodsRecording, name="silence_periods" ) + + +class DetectArtifactAndSilentPeriodsRecording(SilencedPeriodsRecording): + """ + Class doing artifact detection and silencing at the same time. + + See SilencedPeriodsRecording and detect_artifact_periods for details. + """ + + _precomputable_kwarg_names = ["periods"] + + def __init__( + self, + recording, + detect_artifact_method="envelope", + detect_artifact_kwargs=dict(), + periods=None, + mode="zeros", + noise_levels=None, + seed=None, + **noise_levels_kwargs, + ): + + if periods is None: + from spikeinterface.preprocessing import detect_artifact_periods + + periods = detect_artifact_periods( + recording, + method=detect_artifact_method, + method_kwargs=detect_artifact_kwargs, + job_kwargs=None, + ) + + SilencedPeriodsRecording.__init__( + self, recording, periods=periods, mode=mode, noise_levels=noise_levels, seed=seed, **noise_levels_kwargs + ) + # note self._kwargs["periods"] is done by SilencedPeriodsRecording and so the computation is done once + + +# function for API +detect_artifacts_and_silent_periods = define_function_handling_dict_from_class( + source_class=DetectArtifactAndSilentPeriodsRecording, name="detect_artifacts_and_silent_periods" +) diff --git a/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py new file mode 100644 index 0000000000..e0db644d87 --- /dev/null +++ b/src/spikeinterface/preprocessing/tests/test_detect_artifacts.py @@ -0,0 +1,243 @@ +import numpy as np + +from spikeinterface.core import generate_recording, NumpyRecording +from spikeinterface.preprocessing import ( + detect_artifact_periods, + detect_saturation_periods, + detect_artifact_periods_by_envelope, +) + + +def test_detect_artifact_by_envelope(debug_plots): + # one segment only + num_chans = 32 + sampling_frequency = 30000 + chunk_size = 30000 # This value is critical to ensure hard-coded start /
stops below + + # Generate some data in uV + sat_value = 1200 + noise_level = 10 + rng = np.random.default_rng(42) + data = noise_level * rng.uniform(low=-0.5, high=0.5, size=(150000, num_chans)) * 10 + + artifact_starts = rng.choice(np.arange(0, data.shape[0] - 1000), size=10, replace=False) + artifact_stops = artifact_starts + 100 + for start, stop in zip(artifact_starts, artifact_stops): + data[start:stop, :] = sat_value + + recording = NumpyRecording(data, sampling_frequency) + + artifacts, envelope = detect_artifact_periods_by_envelope( + recording, apply_envelope_common_reference=False, return_envelope=True, random_slices_kwargs={"seed": 2308} + ) + + if debug_plots: + import matplotlib + import matplotlib.pyplot as plt + + plt.plot(envelope.get_traces(), color="r", lw=3) + plt.title("data float") + plt.show() + + # it finds some artifacts + assert len(artifacts) > 0 + + +def test_detect_saturation_periods(debug_plots): + """ + This tests the saturation detection method. First a mock recording is created with + saturation events. Events may be single-sample or a multi-sample period. We create a multi-segment + recording with the stop-sample of each event offset by one, so the segments are distinguishable. + + Saturation detection is performed on chunked data (we set to 30k sample chunks) and so injected + events are hard-coded in order to cross a chunk boundary to test this case. + + The saturation detection function tests both a) saturation threshold exceeded + and b) first derivative (velocity) threshold exceeded. Because the forward + derivative is taken, the sample before the first saturated sample is also flagged. + Also, because of the way the mask is computed in the function, the sample after the + last saturated sample is flagged. 
+ """ + import scipy.signal + + num_chans = 32 + sampling_frequency = 30000 + chunk_size = 30000 # This value is critical to ensure hard-coded start / stops below + job_kwargs = {"chunk_size": chunk_size} + + # Generate some data in uV + sat_value = 1200 + diff_threshold_uV = 200 # 200 uV/sample + noise_level = 10 + rng = np.random.default_rng() + data = noise_level * rng.uniform(low=-0.5, high=0.5, size=(150000, num_chans)) * 10 + + # Design the Butterworth filter + sos = scipy.signal.butter(N=3, Wn=8000 / (sampling_frequency / 2), btype="low", output="sos") + + # Apply the filter to the data + data_seg_1 = scipy.signal.sosfiltfilt(sos, data, axis=0) + data_seg_2 = data_seg_1.copy() + + # Add test saturation at the start, end of recording + # as well as across and within chunks (30k samples). + # Two cases which are not tested are a single event + # exactly on the border, as it makes testing complex + # This was checked manually and any future breaking change + # on this function would be extremely unlikely only to break this case. 
+ all_starts = np.array([0, 29950, 45123, 90005, 149500]) + all_stops = np.array([1001, 30011, 45126, 90006, 149999]) + + second_seg_stop_offset = 10 + for start, stop in zip(all_starts, all_stops): + data_seg_1[start:stop, :] = sat_value + # differentiate the second segment for testing purposes + data_seg_2[start : stop + second_seg_stop_offset, :] = sat_value + + # Add slow artifact + start_slow_artifact = 6100 + stop_slow_artifact = 6300 + accepted_slope = diff_threshold_uV * 0.9 + start_rising_sample = int(np.floor(start_slow_artifact - sat_value / accepted_slope)) + stop_falling_sample = int(np.ceil(stop_slow_artifact + sat_value / accepted_slope)) + + offsets = [0, second_seg_stop_offset] + data_segs = [data_seg_1, data_seg_2] + for offset, data_seg in zip(offsets, data_segs): + start_rising = start_rising_sample + stop_rising = start_slow_artifact + start_falling = stop_slow_artifact + offset + stop_falling = stop_falling_sample + offset + data_seg[stop_rising:start_falling, :] = sat_value + data_seg[start_rising:stop_rising, :] = np.tile( + (accepted_slope * np.arange(stop_rising - start_rising))[:, None], (1, num_chans) + ) + data_seg[start_falling:stop_falling, :] = np.tile( + (sat_value - accepted_slope * np.arange(stop_falling - start_falling))[:, None], (1, num_chans) + ) + + # Add start and stop of slow artifact to start/stops + all_starts = np.sort(np.append(all_starts, start_slow_artifact)) + all_stops = np.clip(np.sort(np.append(all_stops, stop_slow_artifact)), a_min=0, a_max=data_seg_1.shape[0] - 1) + + gain = 2.34 # mimic NP1.0 + offset = 0 + + if debug_plots: + import matplotlib + import matplotlib.pyplot as plt + + plt.plot(data_seg_1) + plt.title("data float") + plt.show() + plt.plot(np.diff(data_seg_1, axis=0)) + plt.title("diff float") + plt.show() + + seg_1_int16 = np.clip(np.rint((data_seg_1 - offset) / gain), -32768, 32767).astype(np.int16) + seg_2_int16 = np.clip(np.rint((data_seg_2 - offset) / gain), -32768, 32767).astype(np.int16) + + 
if debug_plots: + plt.plot(seg_1_int16) + plt.title("data int") + plt.show() + plt.plot(np.diff(seg_1_int16, axis=0)) + plt.title("diff int") + plt.show() + + recording = NumpyRecording([seg_1_int16, seg_2_int16], sampling_frequency) + recording.set_channel_gains(gain) + recording.set_channel_offsets([offset] * num_chans) + + periods = detect_saturation_periods( + recording, + saturation_threshold_uV=sat_value * 0.98, + diff_threshold_uV=diff_threshold_uV, + job_kwargs=job_kwargs, + ) + + seg_1_periods = periods[np.where(periods["segment_index"] == 0)] + seg_2_periods = periods[np.where(periods["segment_index"] == 1)] + + # For the start times, all are one sample before the actual saturated + # period starts because the derivative threshold is exceeded at one + # sample before the saturation starts. Therefore this one-sample-offset + # on the start times is an implicit test that the derivative + # threshold is working properly. + tolerance_samples = 1 + offsets = np.array([0, second_seg_stop_offset]) + for seg_periods, offset in zip([seg_1_periods, seg_2_periods], offsets): + starts = seg_periods["start_sample_index"] + stops = seg_periods["end_sample_index"] + start_diffs = np.abs(starts - all_starts) + assert np.all(start_diffs <= tolerance_samples) + stop_diffs = np.abs(stops - np.clip(all_stops + offset, a_min=0, a_max=data_seg_1.shape[0] - 1)) + assert np.all(stop_diffs <= tolerance_samples) + + # Check that slow rising and falling phases are not in periods + # The ramp slope is 90% of diff_threshold_uV, so they should not be detected. 
+ for seg_periods, seg_offset in zip([seg_1_periods, seg_2_periods], offsets): + slow_period_idx = np.argmin(np.abs(seg_periods["start_sample_index"] - start_slow_artifact)) + slow_period = seg_periods[slow_period_idx] + assert ( + slow_period["start_sample_index"] >= start_rising_sample + tolerance_samples + ), "Slow artifact period starts in the rising phase" + assert ( + slow_period["end_sample_index"] <= stop_falling_sample + seg_offset - tolerance_samples + ), "Slow artifact period ends in the falling phase" + + # Just do a quick test that a threshold slightly over the sat value is not detected. + # In this case we only see the derivative threshold detection. We do not play around with this + # threshold because the derivative threshold is not easy to predict (the baseline sample is random). + periods_only_diff = detect_saturation_periods( + recording, + saturation_threshold_uV=sat_value * 1.02, + diff_threshold_uV=diff_threshold_uV, + job_kwargs=job_kwargs, + ) + assert abs(periods_only_diff["start_sample_index"][0] - 1000) <= tolerance_samples + assert abs(periods_only_diff["end_sample_index"][0] - 1001) <= tolerance_samples + + # Test that the same result is obtained with the detect_artifact_periods function with method="saturation" and the + # same parameters. 
+ periods_entry_function = detect_artifact_periods( + recording, + method="saturation", + method_kwargs=dict( + saturation_threshold_uV=sat_value * 0.98, + diff_threshold_uV=diff_threshold_uV, + ), + job_kwargs=job_kwargs, + ) + assert np.array_equal(periods, periods_entry_function) + + # Test that the same result is obtained with multiple jobs + job_kwargs = {"chunk_size": chunk_size, "n_jobs": 2} + periods_entry_function_parallel = detect_artifact_periods( + recording, + method="saturation", + method_kwargs=dict( + saturation_threshold_uV=sat_value * 0.98, + diff_threshold_uV=diff_threshold_uV, + ), + job_kwargs=job_kwargs, + ) + assert np.array_equal(periods, periods_entry_function_parallel) + + # Test that the same result is obtained with saturation_threshold_uV annotation + recording.annotate(saturation_threshold_uV=sat_value * 0.98) + periods_entry_with_annotation = detect_artifact_periods( + recording, + method="saturation", + method_kwargs=dict( + saturation_threshold_uV=None, + diff_threshold_uV=diff_threshold_uV, + ), + job_kwargs=job_kwargs, + ) + assert np.array_equal(periods, periods_entry_with_annotation) + + +if __name__ == "__main__": + # test_detect_artifact_by_envelope(True) + test_detect_saturation_periods(True) diff --git a/src/spikeinterface/preprocessing/tests/test_grouped_preprocessing.py b/src/spikeinterface/preprocessing/tests/test_grouped_preprocessing.py index c2cdbeb3db..861fa23f6f 100644 --- a/src/spikeinterface/preprocessing/tests/test_grouped_preprocessing.py +++ b/src/spikeinterface/preprocessing/tests/test_grouped_preprocessing.py @@ -40,7 +40,9 @@ def test_grouped_preprocessing(): sp_recording_1 = silence_periods(recording_1, list_periods=list_periods, mode=mode, seed=seed) sp_recording_2 = silence_periods(recording_2, list_periods=list_periods, mode=mode, seed=seed) - dict_of_silence_period_recordings = silence_periods(dict_of_recordings, list_periods, mode=mode, seed=seed) + dict_of_silence_period_recordings = silence_periods( + 
dict_of_recordings, list_periods=list_periods, mode=mode, seed=seed + ) check_recordings_equal(dict_of_silence_period_recordings["one"], sp_recording_1) check_recordings_equal(dict_of_silence_period_recordings["two"], sp_recording_2) diff --git a/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py b/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py deleted file mode 100644 index 2baa4bf1b3..0000000000 --- a/src/spikeinterface/preprocessing/tests/test_silence_artifacts.py +++ /dev/null @@ -1,16 +0,0 @@ -import pytest - -import numpy as np - -from spikeinterface.core import generate_recording -from spikeinterface.preprocessing import silence_artifacts - - -def test_silence_artifacts(): - # one segment only - rec = generate_recording(durations=[10.0, 10]) - new_rec = silence_artifacts(rec, detect_threshold=5, freq_max=5.0, min_duration_ms=50) - - -if __name__ == "__main__": - test_silence_artifacts() diff --git a/src/spikeinterface/preprocessing/tests/test_silence.py b/src/spikeinterface/preprocessing/tests/test_silence_periods.py similarity index 76% rename from src/spikeinterface/preprocessing/tests/test_silence.py rename to src/spikeinterface/preprocessing/tests/test_silence_periods.py index e7aee1a84d..44bd205f1b 100644 --- a/src/spikeinterface/preprocessing/tests/test_silence.py +++ b/src/spikeinterface/preprocessing/tests/test_silence_periods.py @@ -1,12 +1,11 @@ import pytest from spikeinterface.core import generate_recording - +from spikeinterface.core import get_noise_levels +from spikeinterface.core.base import base_period_dtype from spikeinterface.preprocessing import silence_periods -from spikeinterface.core import get_noise_levels - import numpy as np from pathlib import Path @@ -18,17 +17,20 @@ def test_silence(create_cache_folder): rec = generate_recording() - rec0 = silence_periods(rec, list_periods=[[[0, 1000], [5000, 6000]], []], mode="zeros", seed=2308) - rec0.save(verbose=False) + periods = np.array([(0, 0, 1000), (0, 
5000, 6000)], dtype=base_period_dtype) + rec0 = silence_periods(rec, periods=periods, mode="zeros", seed=2308) + rec0.save(format="memory", verbose=False) traces_in0 = rec0.get_traces(segment_index=0, start_frame=0, end_frame=1000) - traces_in1 = rec0.get_traces(segment_index=0, start_frame=5000, end_frame=6000) - traces_out0 = rec0.get_traces(segment_index=0, start_frame=2000, end_frame=3000) assert np.all(traces_in0 == 0) + traces_half0 = rec0.get_traces(segment_index=0, start_frame=900, end_frame=1100) + assert np.all(traces_half0[:100] == 0) + traces_in1 = rec0.get_traces(segment_index=0, start_frame=5000, end_frame=6000) assert np.all(traces_in1 == 0) + traces_out0 = rec0.get_traces(segment_index=0, start_frame=2000, end_frame=3000) assert not np.all(traces_out0 == 0) - rec1 = silence_periods(rec, list_periods=[[[0, 1000], [5000, 6000]], []], mode="noise", seed=2308) - rec1 = rec1.save(folder=cache_folder / "rec_w_noise", verbose=False, overwrite=True) + rec1 = silence_periods(rec, periods=periods, mode="noise", seed=2308) + rec1 = rec1.save(format="memory", verbose=False, overwrite=True) noise_levels = get_noise_levels(rec, return_in_uV=False) traces_in0 = rec1.get_traces(segment_index=0, start_frame=0, end_frame=1000) traces_in1 = rec1.get_traces(segment_index=0, start_frame=5000, end_frame=6000)