Source code for ossdbs.stimulation_signals.utilities

# Copyright 2023, 2024 Julius Zimmermann
# SPDX-License-Identifier: GPL-3.0-or-later

import logging

import numpy as np
from scipy.fft import irfft

_logger = logging.getLogger(__name__)


def adjust_cutoff_frequency(cutoff_frequency, frequency):
    """Round ``cutoff_frequency`` down to a multiple of ``frequency``.

    Parameters
    ----------
    cutoff_frequency: float
        Frequency to be adjusted.
    frequency: float
        Stimulation (base) frequency whose multiple is returned.

    Returns
    -------
    float
        ``cutoff_frequency`` minus its remainder modulo ``frequency``, i.e. the
        largest multiple of ``frequency`` not exceeding ``cutoff_frequency``
        (for positive inputs, up to floating-point rounding).
    """
    remainder = cutoff_frequency % frequency
    return cutoff_frequency - remainder
def get_timesteps(
    cutoff_frequency: float, base_frequency: float, n_frequencies: int
) -> np.ndarray:
    """Return the sampling time points of a reconstructed time signal.

    The sampling rate is twice ``cutoff_frequency``, rounded down to a
    multiple of ``base_frequency`` via ``adjust_cutoff_frequency``.

    Parameters
    ----------
    cutoff_frequency: float
        Highest considered frequency.
    base_frequency: float
        Frequency of the time-domain signal.
    n_frequencies: int
        Number of time points to generate.

    Returns
    -------
    np.ndarray
        ``n_frequencies`` equidistant time points starting at 0 and spaced
        by the inverse of the adjusted sampling rate.

    Notes
    -----
    NOTE(review): if ``2 * cutoff_frequency < base_frequency`` the adjusted
    rate is 0 and the inversion divides by zero (ZeroDivisionError for plain
    Python floats).
    """
    # Doubling the cutoff makes the sampling reach up to the cutoff frequency.
    sampling_rate = adjust_cutoff_frequency(2.0 * cutoff_frequency, base_frequency)
    return np.arange(n_frequencies) * (1.0 / sampling_rate)
def retrieve_time_domain_signal_from_fft(
    fft_signal: np.ndarray,
    cutoff_frequency: float,
    base_frequency: float,
    signal_length: int,
) -> tuple[np.ndarray, np.ndarray]:
    """Transform a one-sided spectrum back to the time domain.

    Parameters
    ----------
    fft_signal: np.ndarray
        Frequency-domain signal, in the one-sided layout accepted by
        ``scipy.fft.irfft``.
    cutoff_frequency: float
        Highest considered frequency.
    base_frequency: float
        Frequency of the time-domain signal (often 130 Hz).
    signal_length: int
        Length of the original time-domain signal.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        The time points (from ``get_timesteps``) and the real-valued
        time-domain signal. For a 1-D ``fft_signal`` both contain
        ``signal_length`` entries.
    """
    time_signal = irfft(fft_signal, n=signal_length)
    # get_timesteps doubles the cutoff frequency so sampling extends up to it
    time_points = get_timesteps(cutoff_frequency, base_frequency, len(time_signal))
    return time_points, time_signal
def reconstruct_time_signals(
    freq_domain_signal: np.ndarray, signal_length: int
) -> np.ndarray:
    """Transform frequency-domain data back to the time domain along axis 0.

    Parameters
    ----------
    freq_domain_signal: np.ndarray
        One-sided frequency-domain data; the inverse transform runs along
        the first axis.
    signal_length: int
        Length of the initial time-domain signal (output length along axis 0).

    Returns
    -------
    np.ndarray
        Real-valued time-domain data.
    """
    return irfft(
        freq_domain_signal,
        n=signal_length,
        axis=0,
        workers=-1,  # negative value: use as many workers as os.cpu_count()
    )
def get_octave_band_indices(frequencies: np.ndarray) -> np.ndarray:
    """Return indices of octave band frequencies.

    The octave band centre indices are the powers of two ``1, 2, 4, ...``
    that are valid indices into ``frequencies``, i.e. that do not exceed
    ``len(frequencies) - 1``. Index 0 is never included.

    Parameters
    ----------
    frequencies: np.ndarray
        Frequency array; only its length is used.

    Returns
    -------
    np.ndarray
        Integer array of octave band centre indices. Empty if ``frequencies``
        has fewer than two entries (no non-zero index exists).
    """
    # Clamp at 0 so that empty / single-entry input yields zero octaves
    # instead of log2(0) = -inf (OverflowError) or log2(-1) = nan.
    highest_index = max(len(frequencies) - 1, 0)
    # Exact integer form of int(log2(highest_index)) + 1 for highest_index >= 1;
    # avoids relying on floating-point log2 for the floor.
    n_octaves = highest_index.bit_length()
    return 2 ** np.arange(n_octaves)
def get_minimum_octave_band_index(freq_idx: int) -> int:
    """Return the index of the lowest frequency in the octave band.

    The lower band edge lies a factor of sqrt(2) below the centre index
    ``freq_idx``; it is rounded to the nearest integer (ties to even).
    """
    lower_edge = freq_idx / np.sqrt(2)
    return int(np.rint(lower_edge))
def get_maximum_octave_band_index(freq_idx: int) -> int:
    """Return the index of the highest frequency in the octave band.

    The upper band edge lies a factor of sqrt(2) above the centre index
    ``freq_idx``; it is rounded to the nearest integer (ties to even).
    """
    upper_edge = freq_idx * np.sqrt(2)
    return int(np.rint(upper_edge))
def get_indices_in_octave_band(
    freq_idx: int, frequency_indices: list, cutoff_frequency_index: int
) -> list | np.ndarray:
    """Get indices of frequencies in octave band.

    Parameters
    ----------
    freq_idx: int
        Centre index of the octave band. Must occur exactly once in
        ``frequency_indices``.
    frequency_indices: list
        Centre indices of all octave bands (list or array). Presumably sorted
        ascending, e.g. the output of ``get_octave_band_indices`` -- confirm
        with callers.
    cutoff_frequency_index: int
        Upper edge used for the band of the last entry in
        ``frequency_indices``.

    Returns
    -------
    list | np.ndarray
        Consecutive integer indices belonging to the band. If the computed
        range is empty, the list ``[freq_idx]`` is returned instead.

    Raises
    ------
    ValueError
        If ``freq_idx`` does not occur exactly once in ``frequency_indices``.

    Notes
    -----
    We start evaluating from the bottom. I.e., it is checked if there is an
    overlap with frequencies from the octave band below (already computed).
    The minimum frequencies are increased until there is no overlap with
    the previous band.
    """
    # asarray makes the comparison elementwise for plain lists as well;
    # `int == list` would otherwise be a scalar False and always raise below.
    centres = np.asarray(frequency_indices)
    min_freq = get_minimum_octave_band_index(freq_idx)
    max_freq = get_maximum_octave_band_index(freq_idx)

    matches = np.argwhere(centres == freq_idx)
    if matches.shape != (1, 1):
        raise ValueError("Wrong frequencies for band evaluation supplied")
    list_index = int(matches[0, 0])

    # Decide by list position, not by freq_idx > 0: for the first centre,
    # centres[list_index - 1] would wrap around to the LAST element.
    if list_index > 0:
        prev_centre = centres[list_index - 1]
        max_of_prev_band = get_maximum_octave_band_index(prev_centre)
        # catch if the lower band edge coincides with the previous centre
        if min_freq == prev_centre:
            min_freq = freq_idx
    else:
        # lowest band: nothing below to overlap with
        max_of_prev_band = -1

    if list_index < len(centres) - 1:
        # catch if the octave band frequency is equal to another center frequency
        if max_freq == centres[list_index + 1]:
            max_freq = freq_idx
    else:
        # the highest band extends up to the cutoff
        max_freq = cutoff_frequency_index

    # start just above the previous band's upper edge to avoid overlap
    min_freq = max(min_freq, max_of_prev_band + 1)

    band_indices = np.arange(min_freq, max_freq + 1)
    if len(band_indices) == 0:
        band_indices = [freq_idx]
    _logger.debug("Band indices from %s to %s", band_indices[0], band_indices[-1])
    return band_indices