# Copyright 2023, 2024 Johannes Reding, Julius Zimmermann
# SPDX-License-Identifier: GPL-3.0-or-later
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
import matplotlib.pyplot as plt
import numpy as np
from scipy.fft import rfft, rfftfreq
from .utilities import adjust_cutoff_frequency, retrieve_time_domain_signal_from_fft
@dataclass
class FrequencyDomainSignal:
    """Store information for frequency domain signal.

    The field descriptions below are inferred from the field names and
    should be confirmed against the code that builds this container.

    Attributes
    ----------
    frequencies : np.ndarray
        Frequencies of the spectrum (presumably in Hz -- TODO confirm).
    amplitudes : np.ndarray
        Spectral amplitudes, presumably one entry per frequency
        (may be complex FFT coefficients -- TODO confirm).
    current_controlled : bool
        Flag for current-controlled stimulation (presumably as opposed
        to voltage-controlled -- TODO confirm).
    base_frequency : float
        Base frequency of the underlying signal.
    cutoff_frequency : float
        Cutoff frequency associated with the spectrum.
    signal_length : int
        Length associated with the signal (presumably the number of
        time-domain samples -- TODO confirm).
    octave_band_approximation : bool
        Flag for the octave-band approximation; defaults to ``False``.
    """

    frequencies: np.ndarray
    amplitudes: np.ndarray
    current_controlled: bool
    base_frequency: float
    cutoff_frequency: float
    signal_length: int
    octave_band_approximation: bool = False
[docs]class TimeDomainSignal(ABC):
"""Template for Signals.
Parameters
----------
frequency : float
Frequency [Hz] of the signal.
pulse_width : float
Relative pulse width of one period.
counter_pulse_width: float
Relative width of counter pulse of one period.
inter_pulse_width: float
Relative width between pulse and counter pulse of one period.
Notes
-----
Amplitudes are relative: the primary pulse has amplitude 1.0 and the
counter pulse amplitude is given as a fraction (e.g. 0.5 means half
the primary amplitude). The actual voltage or current scaling is
applied externally by the volume conductor model.
"""
def __init__(
self,
frequency: float,
pulse_width: float,
inter_pulse_width: float | None = 0.0,
counter_pulse_width: float | None = 0.0,
counter_pulse_amplitude: float | None = 1.0,
) -> None:
if np.isclose(frequency, 0):
raise ValueError("Frequency must be greater than zero.")
self._frequency = frequency
self._pulse_width = pulse_width
self._inter_pulse_width = inter_pulse_width
self._counter_pulse_width = counter_pulse_width
# the values here are relative amplitudes
# e.g., if the signal is 1V and the counter_pulse_amplitude is 0.5
# its amplitude will be 0.5V (same for 1mA)
self._amplitude = 1.0
self._counter_amplitude = counter_pulse_amplitude
@property
def amplitude(self) -> float:
"""Return signal amplitude."""
return self._amplitude
@amplitude.setter
def amplitude(self, value) -> None:
"""Set amplitude value."""
self._amplitude = value
@property
def counter_amplitude(self) -> float:
"""Get amplitude of counterpulse."""
return self._counter_amplitude
@counter_amplitude.setter
def counter_amplitude(self, value) -> None:
self._counter_amplitude = value
@property
def frequency(self) -> float:
"""Return frequency of signal.
Returns
-------
float
"""
return self._frequency
@frequency.setter
def frequency(self, value):
"""Set frequency of signal."""
self._frequency = value
[docs] def get_adjusted_cutoff_frequency(self, cutoff_frequency: float) -> float:
"""Adjust cutoff frequency to signal frequency.
Double the cutoff frequency to account for FFT and actually sample until there.
"""
return adjust_cutoff_frequency(2.0 * cutoff_frequency, self.frequency)
[docs] def get_fft_spectrum(
self, cutoff_frequency: float
) -> tuple[np.ndarray, np.ndarray, int]:
"""FFT spectrum of time-domain signal.
Parameters
----------
cutoff_frequency: float
Highest considered frequency.
"""
cutoff_frequency = self.get_adjusted_cutoff_frequency(cutoff_frequency)
dt = 1.0 / cutoff_frequency
# required length for frequency
timesteps = int(cutoff_frequency / self.frequency)
time_domain_signal = self.get_time_domain_signal(dt, timesteps)
return (
rfftfreq(len(time_domain_signal), d=dt),
np.asarray(rfft(time_domain_signal)),
len(time_domain_signal),
)
[docs] def retrieve_time_domain_signal(
self, fft_signal: np.ndarray, cutoff_frequency: float, signal_length: int
) -> tuple[np.ndarray, np.ndarray]:
"""Compute time-domain signal by FFT."""
return retrieve_time_domain_signal_from_fft(
fft_signal, cutoff_frequency, self.frequency, signal_length
)
[docs] @abstractmethod
def get_time_domain_signal(self, dt: float, timesteps: int) -> np.ndarray:
"""Time-domain signal for given timestep."""
pass
[docs] def plot_time_domain_signal(self, cutoff_frequency, output_path, show=False):
"""Plot signal and export to PDF."""
cutoff_frequency = adjust_cutoff_frequency(
2.0 * cutoff_frequency, self.frequency
)
dt = 1.0 / cutoff_frequency
# required length for frequency
timesteps = int(cutoff_frequency / self.frequency)
time_domain_signal = self.get_time_domain_signal(dt, timesteps)
plt.plot(dt * np.arange(0, timesteps), time_domain_signal)
plt.xlabel("Time / s")
plt.ylabel("Signal / arb. u.")
plt.savefig(os.path.join(output_path, "time_domain_signal.pdf"))
if show:
plt.show()
else:
plt.close()
[docs] def get_active_time(self) -> float:
"""Return time during which the stimulator is active."""
return self._pulse_width + self._inter_pulse_width + self._counter_pulse_width