# Copyright 2023, 2024 Konstantin Butenko, Shruthi Chakravarthy
# Copyright 2023, 2024 Jan Philipp Payonk, Julius Zimmermann
# SPDX-License-Identifier: GPL-3.0-or-later
import importlib
import importlib.util
import json
import logging
import os

import numpy as np

from ossdbs.dielectric_model import (
    default_dielectric_parameters,
    dielectric_model_parameters,
    dielectric_models,
)
from ossdbs.electrodes import ELECTRODE_MODELS, ELECTRODE_PARAMETERS, ELECTRODES
from ossdbs.fem import (
    PRECONDITIONERS,
    SOLVERS,
    Mesh,
    VolumeConductor,
    VolumeConductorFloating,
    VolumeConductorFloatingImpedance,
    VolumeConductorNonFloating,
)
from ossdbs.model_geometry import BoundingBox, BrainGeometry, ModelGeometry
from ossdbs.point_analysis import Lattice, Pathway, VoxelLattice
from ossdbs.stimulation_signals import (
    FrequencyDomainSignal,
    RectangleSignal,
    TimeDomainSignal,
    TrapezoidSignal,
    TriangleSignal,
)
from ossdbs.utils.nifti1image import DiffusionTensorImage, MagneticResonanceImage

_logger = logging.getLogger(__name__)

# Pathway activation analysis (PAM) needs the NEURON Python package.
# Only look up its import spec here, so this module still loads (with a
# warning) when NEURON is missing. `importlib.util` is a submodule that is
# not guaranteed to be loaded by a plain `import importlib`; it is imported
# explicitly at the top of the file.
PAM_AVAILABLE = importlib.util.find_spec("neuron") is not None
if not PAM_AVAILABLE:
    _logger.warning("NEURON is not installed, disabling PAM analysis!")


def create_bounding_box(box_parameters: dict) -> BoundingBox:
    """Build the bounding box described by ``box_parameters``.

    Parameters
    ----------
    box_parameters : dict
        Must provide the entries ``"Dimension"`` and ``"Center"``. Each of
        them maps the keys ``"x[mm]"``, ``"y[mm]"`` and ``"z[mm]"`` to the
        respective component.

    Returns
    -------
    BoundingBox
        Box whose first corner is ``center - dimension / 2`` and whose
        second corner is the first corner plus ``dimension``.
    """
    dimension = box_parameters["Dimension"]
    midpoint = box_parameters["Center"]
    axes = ("x[mm]", "y[mm]", "z[mm]")
    extent = np.array([dimension[axis] for axis in axes])
    origin = np.array([midpoint[axis] for axis in axes])
    lower_corner = origin - extent / 2
    upper_corner = lower_corner + extent
    return BoundingBox(tuple(lower_corner), tuple(upper_corner))


def generate_electrodes(settings: dict):
    """Create the electrode geometries listed in ``settings["Electrodes"]``.

    Parameters
    ----------
    settings : dict
        Simulation settings. Reads ``"Electrodes"`` and ``"Mesh"`` and, if
        ``"ExportElectrode"`` is set, also ``"OutputPath"`` and
        ``"BrainRegion"``.

    Returns
    -------
    list
        One electrode object per entry of ``settings["Electrodes"]``, in
        the same order.

    Notes
    -----
    Electrodes whose name contains ``"Custom"`` are parametrised through
    their ``"CustomParameters"`` entry; all others are looked up by name
    in ``ELECTRODES``.
    """
    _logger.info("Generate electrode geometries")

    hp_settings = settings["Mesh"].get("HPRefinement", False)
    use_hp_refinement = hp_settings["Active"] if hp_settings else False

    def as_vector(entry):
        """Collect the x, y and z components of a settings entry."""
        return (entry["x[mm]"], entry["y[mm]"], entry["z[mm]"])

    electrodes = []
    for spec in settings["Electrodes"]:
        name = spec["Name"]
        placement = {
            "direction": as_vector(spec["Direction"]),
            "rotation": spec["Rotation[Degrees]"],
            "position": as_vector(spec["TipPosition"]),
        }

        if "Custom" in name:
            # Custom electrodes take their parameter set from the settings
            # instead of a predefined parameter collection.
            model = ELECTRODE_MODELS[name]
            parameter_class = ELECTRODE_PARAMETERS[model.__name__]
            electrode = model(
                parameters=parameter_class(**spec["CustomParameters"]),
                **placement,
            )
        else:
            electrode = ELECTRODES[name](**placement)

        if use_hp_refinement:
            electrode.set_hp_flag(electrode_parameters=spec)

        if "EncapsulationLayer" in spec:
            layer = spec["EncapsulationLayer"]
            electrode.encapsulation_thickness = layer["Thickness[mm]"]

        electrodes.append(electrode)

    if settings.get("ExportElectrode", False):
        # Exported electrodes are numbered starting at one.
        for number, electrode in enumerate(electrodes, start=1):
            electrode.export_electrode(
                settings["OutputPath"], settings["BrainRegion"], number
            )
    return electrodes


def prepare_dielectric_properties(settings: dict) -> dict:
    """Instantiate one dielectric model per material.

    Parameters
    ----------
    settings : dict
        Simulation settings. Reads ``"DielectricModel"`` (entries ``"Type"``
        and ``"CustomParameters"``) and the materials listed in
        ``settings["MaterialDistribution"]["MRIMapping"]``.

    Returns
    -------
    dict
        Maps each material to its dielectric model instance.

    Notes
    -----
    If ``"CustomParameters"`` is ``None``, the default parameters of the
    chosen model type are used; otherwise every material needs an entry in
    the custom parameters.
    """
    _logger.info("Prepare dielectric model")
    model_settings = settings["DielectricModel"]
    model_type = model_settings["Type"]
    custom_parameters = model_settings["CustomParameters"]

    model_class = dielectric_models[model_type]
    parameter_class = dielectric_model_parameters[model_type]
    defaults = default_dielectric_parameters[model_type]

    properties = {}
    for material in settings["MaterialDistribution"]["MRIMapping"]:
        parameters = (
            defaults[material]
            if custom_parameters is None
            else parameter_class(**custom_parameters[material])
        )
        properties[material] = model_class(parameters)
    return properties


def generate_brain_model(settings, rotate_initial_geo: bool = False):
    """Generate the OCC brain model from ``settings["BrainRegion"]``.

    Parameters
    ----------
    settings : dict
        Simulation settings; the ``"BrainRegion"`` entry provides the
        ``"Shape"`` and the bounding box description.
    rotate_initial_geo : bool, optional
        Forwarded unchanged to ``BrainGeometry``.

    Returns
    -------
    BrainGeometry
        The brain geometry.
    """
    region_settings = settings["BrainRegion"]
    shape = region_settings["Shape"]
    bounding_box = create_bounding_box(region_settings)
    return BrainGeometry(shape, bounding_box, rotate_initial_geo=rotate_initial_geo)


def generate_model_geometry(settings):
    """Assemble the full geometry from brain and electrodes.

    Parameters
    ----------
    settings : dict
        Simulation settings passed on to ``generate_brain_model`` and
        ``generate_electrodes``.

    Returns
    -------
    ModelGeometry
        Geometry comprising the brain and all electrodes.

    Notes
    -----
    If the first attempt raises a ``RuntimeError``, the brain is generated
    again with ``rotate_initial_geo=True`` and the construction is retried
    once with the same electrodes.
    """
    brain = generate_brain_model(settings)
    electrodes = generate_electrodes(settings)
    try:
        return ModelGeometry(brain, electrodes)
    except RuntimeError:
        _logger.warning(
            "Could not build geometry, trying again after rotation of initial geometry"
        )
        rotated_brain = generate_brain_model(settings, rotate_initial_geo=True)
        return ModelGeometry(rotated_brain, electrodes)


def build_brain_model(
    settings,
    mri_image: MagneticResonanceImage | None = None,
    rotate_initial_geo: bool = False,
) -> BrainGeometry:
    """Build geometry model of brain.

    Parameters
    ----------
    settings : dict
        Simulation settings. If they contain ``"BrainRegion"``, the brain
        is built from that entry and ``mri_image`` is not used.
    mri_image : MagneticResonanceImage, optional
        Image providing the bounding box (in voxel space), the
        transformation matrix and the translation. Required if the
        settings have no ``"BrainRegion"`` entry.
    rotate_initial_geo : bool, optional
        Forwarded to ``BrainGeometry`` in both branches.

    Returns
    -------
    BrainGeometry
        The brain geometry.

    Raises
    ------
    ValueError
        If neither a ``"BrainRegion"`` entry nor an MRI image is available.
    """
    # An explicit brain region in the settings takes precedence over the
    # MRI image.
    if "BrainRegion" in settings:
        _logger.debug("Generating model geometry for fixed brain region")
        region_parameters = settings["BrainRegion"]
        brain_region = create_bounding_box(region_parameters)
        shape = region_parameters["Shape"]
        return BrainGeometry(shape, brain_region, rotate_initial_geo=rotate_initial_geo)

    _logger.debug("Generating model geometry from MRI image")
    if mri_image is None:
        raise ValueError("Need to provide MRI image to build geo.")
    # attention: bounding box is given in voxel space!
    brain_region = mri_image.bounding_box
    shape = "Ellipsoid"
    # transformation to real space in geometry creation
    _logger.debug("Generate OCC model, passing transformation matrix from MRI image")
    return BrainGeometry(
        shape,
        brain_region,
        trafo_matrix=mri_image.trafo_matrix,
        translation=mri_image.translation,
        # Same keyword as in the fixed-region branch; it used to be
        # misspelled here.
        rotate_initial_geo=rotate_initial_geo,
    )


def set_custom_mesh_sizes(settings, model_geometry):
    """Apply the mesh sizes from the settings to the geometry.

    Parameters
    ----------
    settings : dict
        Simulation settings; ``settings["Mesh"]["MeshSize"]`` is handed to
        the geometry unchanged.
    model_geometry : ModelGeometry
        Geometry whose ``set_mesh_sizes`` method is called.
    """
    apply_mesh_sizes = model_geometry.set_mesh_sizes
    mesh_sizes = settings["Mesh"]["MeshSize"]
    apply_mesh_sizes(mesh_sizes)


def generate_mesh(settings):
    """Create or load the mesh described by the settings.

    Parameters
    ----------
    settings : dict
        Simulation settings. ``settings["Mesh"]`` controls loading
        (``"LoadMesh"``, ``"LoadPath"``), saving (``"SaveMesh"``,
        ``"SavePath"``) and, optionally, custom mesh sizes
        (``"MeshSize"``).

    Returns
    -------
    Mesh
        The loaded or newly generated mesh.

    Notes
    -----
    Attention! This mesh is not yet curved!

    A missing ``"MeshingHypothesis"`` entry is added to
    ``settings["Mesh"]`` with the value ``{"Type": "Default"}`` before a
    new mesh is generated.
    """
    model_geometry = generate_model_geometry(settings)
    # NOTE(review): this helper is neither defined nor imported in the part
    # of the file visible here - confirm where it comes from.
    set_contact_and_encapsulation_layer_properties(settings, model_geometry)

    if "MeshSize" in settings["Mesh"]:
        set_custom_mesh_sizes(settings, model_geometry)

    mesh_settings = settings["Mesh"]
    first_order = 1
    mesh = Mesh(model_geometry.geometry, first_order)

    if mesh_settings["LoadMesh"]:
        mesh.load_mesh(mesh_settings["LoadPath"])
    else:
        mesh_settings.setdefault("MeshingHypothesis", {"Type": "Default"})
        mesh.generate_mesh(mesh_settings)
        # HP refinement is applied right away for standalone mesh
        # generation (no subsequent bisection-based refinement expected).
        mesh.apply_hp_refinement()
        if mesh_settings["SaveMesh"]:
            mesh.save(mesh_settings["SavePath"])
    return mesh


def validate_solver_settings(settings: dict, model_geometry: ModelGeometry) -> None:
    """Adjust the solver settings in place to fit the floating mode.

    Parameters
    ----------
    settings : dict
        Simulation settings; ``settings["Solver"]`` may be modified.
    model_geometry : ModelGeometry
        Geometry that reports the floating mode.

    Notes
    -----
    BDDC preconditioner does not work well with the FloatingImpedance
    formulation. If both are combined (a missing ``"Preconditioner"``
    entry counts as ``"bddc"``), the preconditioner is replaced by
    ``"local"`` and ``"PreconditionerKwargs"`` is reset to an empty dict.
    """
    floating_mode = model_geometry.get_floating_mode()
    solver_settings = settings["Solver"]
    preconditioner = solver_settings.get("Preconditioner", "bddc")

    if floating_mode != "FloatingImpedance" or preconditioner != "bddc":
        return

    _logger.warning(
        "BDDC preconditioner is not compatible with FloatingImpedance "
        "formulation. Switching to 'local' preconditioner."
    )
    solver_settings["Preconditioner"] = "local"
    solver_settings["PreconditionerKwargs"] = {}


def prepare_solver(settings):
    """Instantiate the solver together with its preconditioner.

    Parameters
    ----------
    settings : dict
        Simulation settings. ``settings["Solver"]`` must provide
        ``"Type"``, ``"Preconditioner"``, ``"PreconditionerKwargs"``,
        ``"MaximumSteps"`` and ``"Precision"``.

    Returns
    -------
    object
        Instance of the solver class registered in ``SOLVERS`` under the
        requested type.
    """
    _logger.info("Preparing solver")
    solver_settings = settings["Solver"]
    solver_class = SOLVERS[solver_settings["Type"]]
    preconditioner_kwargs = solver_settings["PreconditionerKwargs"]
    preconditioner_class = PRECONDITIONERS[solver_settings["Preconditioner"]]
    return solver_class(
        precond_par=preconditioner_class(**preconditioner_kwargs),
        maxsteps=solver_settings["MaximumSteps"],
        precision=solver_settings["Precision"],
    )


def generate_point_models(settings: dict):
    """Collect all active point models.

    Parameters
    ----------
    settings : dict
        Simulation settings. ``settings["PointModel"]`` must contain the
        entries ``"Pathway"``, ``"Lattice"`` and ``"VoxelLattice"``, each
        with an ``"Active"`` flag.

    Returns
    -------
    list
        The active point models in the order pathway, lattice, voxel
        lattice. Empty if none is active.
    """
    point_settings = settings["PointModel"]
    point_models = []

    def millimetre_vector(entry):
        """Collect the x, y and z components given in millimetres."""
        return entry["x[mm]"], entry["y[mm]"], entry["z[mm]"]

    pathway_settings = point_settings["Pathway"]
    if pathway_settings["Active"]:
        file_name = pathway_settings["FileName"]
        _logger.info(f"Import neuron geometries stored in {file_name}")
        point_models.append(
            Pathway(file_name, export_field=pathway_settings["ExportField"])
        )

    lattice_settings = point_settings["Lattice"]
    if lattice_settings["Active"]:
        point_count = lattice_settings["Shape"]
        lattice_arguments = {
            "shape": (point_count["x"], point_count["y"], point_count["z"]),
            "center": millimetre_vector(lattice_settings["Center"]),
            "direction": millimetre_vector(lattice_settings["Direction"]),
            "distance": lattice_settings["PointDistance[mm]"],
            "collapse_vta": lattice_settings["CollapseVTA"],
            "export_field": lattice_settings["ExportField"],
        }
        point_models.append(Lattice(**lattice_arguments))

    voxel_settings = point_settings["VoxelLattice"]
    if voxel_settings["Active"]:
        _logger.info("from voxel lattice")
        # NOTE(review): the center is read from the "Lattice" settings, not
        # from "VoxelLattice" - confirm that this is intended.
        center = millimetre_vector(point_settings["Lattice"]["Center"])
        mri_image = MagneticResonanceImage(settings["MaterialDistribution"]["MRIPath"])
        affine = mri_image.affine
        header = mri_image.header
        voxel_count = voxel_settings["Shape"]
        shape = np.array([voxel_count["x"], voxel_count["y"], voxel_count["z"]])
        point_models.append(
            VoxelLattice(
                center,
                affine,
                shape,
                header,
                export_field=voxel_settings["ExportField"],
            )
        )

    return point_models


def generate_signal(settings) -> TimeDomainSignal:
    """Generate a time-domain signal (waveform).

    Parameters
    ----------
    settings : dict
        Simulation settings. ``settings["StimulationSignal"]`` selects the
        waveform through ``"Type"`` (``"Rectangle"``, ``"Triangle"`` or
        ``"Trapezoid"``) and provides its parameters;
        ``settings["OutputPath"]`` is passed to the plotting routine.

    Returns
    -------
    TimeDomainSignal
        The signal object of the requested type.

    Raises
    ------
    ValueError
        If the signal type is not one of the supported waveforms.

    Notes
    -----
    All widths are given in microseconds in the settings and are scaled by
    ``1e-6`` before they are handed to the signal class. As a side effect,
    ``plot_time_domain_signal`` is called on the new signal.
    """
    signal_settings = settings["StimulationSignal"]
    signal_type = signal_settings["Type"]

    # Fail early with a clear message; otherwise an unknown type would only
    # surface as an unbound local variable further down.
    supported_types = ("Rectangle", "Triangle", "Trapezoid")
    if signal_type not in supported_types:
        raise ValueError(
            f"Unsupported stimulation signal type {signal_type!r}; "
            f"expected one of {', '.join(supported_types)}."
        )

    us_to_s = 1e-6
    frequency = signal_settings["Frequency[Hz]"]
    pulse_width = us_to_s * signal_settings["PulseWidth[us]"]
    inter_pulse_width = us_to_s * signal_settings["InterPulseWidth[us]"]
    counter_pulse_width = us_to_s * signal_settings["CounterPulseWidth[us]"]

    if signal_type == "Trapezoid":
        # The trapezoid takes the width of its top as an additional
        # argument before the counter amplitude.
        signal = TrapezoidSignal(
            frequency,
            pulse_width,
            inter_pulse_width,
            counter_pulse_width,
            us_to_s * signal_settings["PulseTopWidth[us]"],
            signal_settings["CounterAmplitude"],
        )
    else:
        signal_class = RectangleSignal if signal_type == "Rectangle" else TriangleSignal
        signal = signal_class(
            frequency,
            pulse_width,
            inter_pulse_width,
            counter_pulse_width,
            signal_settings["CounterAmplitude"],
        )

    signal.plot_time_domain_signal(
        signal_settings["CutoffFrequency"], settings["OutputPath"]
    )
    return signal


def prepare_volume_conductor_model(
    settings, model_geometry, conductivity, solver
) -> VolumeConductor:
    """Create the volume conductor matching the floating mode.

    Parameters
    ----------
    settings : dict
        Simulation settings; reads ``"FEMOrder"``, ``"Mesh"`` and
        ``"OutputPath"``.
    model_geometry : ModelGeometry
        Geometry that reports the floating mode.
    conductivity
        Conductivity passed on to the volume conductor.
    solver
        Solver passed on to the volume conductor.

    Returns
    -------
    VolumeConductor
        ``VolumeConductorFloating`` for mode ``"Floating"``,
        ``VolumeConductorFloatingImpedance`` for ``"FloatingImpedance"``
        and ``VolumeConductorNonFloating`` for any other mode.
    """
    _logger.info("Generate volume conductor model")
    fem_order = settings["FEMOrder"]
    mesh_parameters = settings["Mesh"]
    floating_mode = model_geometry.get_floating_mode()
    output_path = settings["OutputPath"]
    _logger.info(f"Output path set to: {output_path}")

    floating_variants = {
        "Floating": (VolumeConductorFloating, "Floating mode selected"),
        "FloatingImpedance": (
            VolumeConductorFloatingImpedance,
            "FloatingImpedance mode selected",
        ),
    }
    non_floating = (VolumeConductorNonFloating, "Non floating mode selected")
    conductor_class, message = floating_variants.get(floating_mode, non_floating)

    _logger.debug(message)
    return conductor_class(
        model_geometry, conductivity, solver, fem_order, mesh_parameters, output_path
    )


def prepare_stimulation_signal(settings) -> FrequencyDomainSignal:
    """Build the frequency-domain description of the stimulation signal.

    Parameters
    ----------
    settings : dict
        Simulation settings; ``settings["StimulationSignal"]`` provides
        ``"Type"``, ``"CurrentControlled"`` and the type-specific entries.

    Returns
    -------
    FrequencyDomainSignal
        Frequencies, amplitudes and metadata of the signal.

    Notes
    -----
    For type ``"Multisine"`` the frequencies are taken from
    ``"ListOfFrequencies"`` with unit amplitudes, and the first frequency
    serves as both base and cutoff frequency. For every other type the
    spectrum is obtained from the time-domain signal; the octave band
    approximation is enabled if ``"SpectrumMode"`` is ``"OctaveBand"``.
    """
    signal_settings = settings["StimulationSignal"]
    is_multisine = signal_settings["Type"] == "Multisine"
    current_controlled = signal_settings["CurrentControlled"]

    if is_multisine:
        frequencies = signal_settings["ListOfFrequencies"]
        spectrum = {
            "frequencies": frequencies,
            "amplitudes": np.ones(len(frequencies)),
            "base_frequency": frequencies[0],
            "cutoff_frequency": frequencies[0],
            "signal_length": len(frequencies),
            "octave_band_approximation": False,
        }
    else:
        use_octave_bands = signal_settings["SpectrumMode"] == "OctaveBand"
        time_signal = generate_signal(settings)
        cutoff_frequency = signal_settings["CutoffFrequency"]
        base_frequency = time_signal.frequency
        frequencies, coefficients, signal_length = time_signal.get_fft_spectrum(
            cutoff_frequency
        )
        spectrum = {
            "frequencies": frequencies,
            "amplitudes": coefficients,
            "base_frequency": base_frequency,
            "cutoff_frequency": cutoff_frequency,
            "signal_length": signal_length,
            "octave_band_approximation": True if use_octave_bands else False,
        }

    return FrequencyDomainSignal(current_controlled=current_controlled, **spectrum)


def run_volume_conductor_model(
    settings, volume_conductor, frequency_domain_signal, truncation_time=None
):
    """Run the volume conductor model at all required frequencies.

    Solves the FEM system at each frequency in the signal spectrum,
    optionally computes impedance and currents, exports results, and
    reconstructs the time-domain solution. Afterwards the optional
    impedance analysis is run.

    Parameters
    ----------
    settings : dict
        Complete simulation settings dictionary. The entries
        ``"ComputeImpedance"``, ``"ComputeCurrents"``, ``"ExportVTK"``,
        ``"ExportVTKSubdivision"`` and ``"ExportFrequency"`` are optional.
    volume_conductor : VolumeConductor
        Prepared volume conductor model with mesh and FEM space.
    frequency_domain_signal : FrequencyDomainSignal
        Signal defining the frequencies and amplitudes to solve.
    truncation_time : float, optional
        If set, truncate the time-domain reconstruction to this duration.

    Returns
    -------
    object
        Whatever ``volume_conductor.run_full_analysis`` returns (the
        timings of the run).
    """
    _logger.info("Run volume conductor model")
    out_of_core = settings["OutOfCore"]

    compute_impedance = bool(settings.get("ComputeImpedance", False))
    if compute_impedance:
        _logger.info("Will compute impedance at each frequency")

    compute_currents = bool(settings.get("ComputeCurrents", False))
    if compute_currents:
        _logger.info("Will estimate currents at each frequency")

    export_vtk = settings.get("ExportVTK", False)
    if export_vtk:
        _logger.info("Will export solution to VTK")
    vtk_subdivision = int(settings.get("ExportVTKSubdivision", 0))

    # A missing entry means "no custom export frequency". Without this
    # default the name would be unbound when the key is absent.
    export_frequency = settings.get("ExportFrequency")
    if export_frequency is not None:
        _logger.info(f"Set custom export frequency to {export_frequency}.")

    point_models = generate_point_models(settings)

    vcm_timings = volume_conductor.run_full_analysis(
        frequency_domain_signal,
        compute_impedance,
        export_vtk,
        point_models=point_models,
        activation_threshold=settings["ActivationThresholdVTA[V-per-m]"],
        dielectric_threshold=settings.get("DielectricAccuracy", 0.01),
        out_of_core=out_of_core,
        export_frequency=export_frequency,
        adaptive_mesh_refinement_settings=settings["Mesh"]["AdaptiveMeshRefinement"],
        truncation_time=truncation_time,
        estimate_currents=compute_currents,
        vtk_subdivision=vtk_subdivision,
    )

    _run_impedance_analysis(settings, volume_conductor, frequency_domain_signal)
    return vcm_timings


def _run_impedance_analysis(
    settings, volume_conductor, frequency_domain_signal
) -> None:
    """Run the optional multicontact admittance / impedance analysis.

    The analysis only runs if the ``ImpedanceAnalysis`` block of the
    settings has a truthy ``Enabled`` entry. It is meant to produce
    ``admittance_matrix.csv`` and ``impedance_matrix.csv`` in the output
    directory and is decoupled from ``run_full_analysis``.

    Parameters
    ----------
    settings : dict
        Simulation settings; reads ``"ImpedanceAnalysis"`` and, if the
        analysis runs, ``"OutputPath"``.
    volume_conductor : VolumeConductor
        Volume conductor handed to the analyzer.
    frequency_domain_signal : FrequencyDomainSignal
        Provides the frequencies if the block does not list any.
    """
    analysis_settings = settings.get("ImpedanceAnalysis") or {}
    if not analysis_settings.get("Enabled"):
        return

    # Imported only when the analysis is requested.
    from ossdbs.fem.analysis import ImpedanceAnalyzer

    # Fall back to the signal's frequencies if none (or an empty list) are
    # configured.
    frequencies = (
        analysis_settings.get("Frequencies") or frequency_domain_signal.frequencies
    )
    include_floating = analysis_settings.get("IncludeFloating", True)

    _logger.info(
        f"Running impedance analysis at {len(frequencies)} frequency "
        f"point(s); include_floating={include_floating}"
    )
    analyzer = ImpedanceAnalyzer(volume_conductor, include_floating=include_floating)
    analyzer.compute(frequencies)
    analyzer.export(settings["OutputPath"])


def run_stim_sets(settings, geometry, conductivity, solver, frequency_domain_signal):
    """Run StimSets batch workflow: compute unit solutions per contact.

    For each non-ground contact, sets up a unit-current solve (1 A on
    that contact, all others floating, ground left untouched), builds a
    volume conductor, applies HP refinement, and runs the full FEM
    analysis. The output path of each run is ``settings["OutputPath"]``
    with the contact name appended.

    Parameters
    ----------
    settings : dict
        Complete simulation settings dictionary.
    geometry : ModelGeometry
        Geometry with electrode and contact definitions. Its contacts are
        updated in place for every solve.
    conductivity : ConductivityCF
        Conductivity coefficient function.
    solver : Solver
        Configured FEM solver.
    frequency_domain_signal : FrequencyDomainSignal
        Signal defining the frequencies and amplitudes to solve. Its
        ``current_controlled`` flag is switched on if it is not set.

    Raises
    ------
    ValueError
        If no active contact with a current of -1 is found to serve as
        ground.
    """
    _logger.info("Run StimSets volume conductor model")
    out_of_core = settings["OutOfCore"]
    if not frequency_domain_signal.current_controlled:
        _logger.warning(
            "StimSets requires current-controlled stimulation"
            ", thus the setting was switched on"
        )
        # Actually do what the warning announces.
        # NOTE(review): assumes the signal object allows attribute
        # assignment - confirm against FrequencyDomainSignal.
        frequency_domain_signal.current_controlled = True

    # VTK export and VTA threshold follow the settings.
    export_vtk = settings["ExportVTK"]
    vtk_subdivision = int(settings.get("ExportVTKSubdivision", 0))
    # no intermediate exports
    export_frequency = None
    activation_threshold = settings["ActivationThresholdVTA[V-per-m]"]

    # prepare point model
    point_models = generate_point_models(settings)

    # The ground is the active contact carrying -1 A; if several match,
    # the last one is used.
    ground_contact = None
    for contact in geometry.contacts:
        if np.isclose(contact.current, -1) and contact.active:
            ground_contact = contact.name
            _logger.info(f"Will skip ground contact {contact.name}")
    if ground_contact is None:
        raise ValueError(
            "No ground contact set. Choose one active contact with current -1."
        )

    for contact in geometry.contacts:
        if contact.name == ground_contact:
            continue

        # set current contact active, all other passive
        for upd_contact in geometry.contacts:
            # reset all voltages (including the ground contact)
            contact_idx = geometry.get_contact_index(upd_contact.name)
            geometry.update_contact(contact_idx, {"Voltage[V]": 0.0})
            # don't change ground
            if upd_contact.name == ground_contact:
                continue

            is_driven = contact.name == upd_contact.name
            if is_driven:
                new_state = {
                    "Floating": False,
                    "Active": True,
                    "Current[A]": 1.0,
                    "Voltage[V]": 1.0,
                }
            else:
                # NOTE(review): passive contacts get the voltage False
                # rather than 0.0 - kept as is, confirm that it is intended.
                new_state = {
                    "Floating": True,
                    "Active": False,
                    "Current[A]": 0.0,
                    "Voltage[V]": False,
                }
            # write new contact settings
            geometry.update_contact(contact_idx, new_state)

        volume_conductor = prepare_volume_conductor_model(
            settings, geometry, conductivity, solver
        )
        # The loaded mesh already has h-refinement (material bisection)
        # baked in. Apply HP refinement and rebuild the FEM space.
        volume_conductor.apply_hp_and_update_space()
        _logger.info(f"Running with contacts:\n{volume_conductor.contacts}")
        # Plain string concatenation: the contact name is appended
        # directly to the output path.
        volume_conductor.output_path = settings["OutputPath"] + contact.name
        vcm_timings = volume_conductor.run_full_analysis(
            frequency_domain_signal,
            export_vtk=export_vtk,
            point_models=point_models,
            activation_threshold=activation_threshold,
            out_of_core=out_of_core,
            export_frequency=export_frequency,
            adaptive_mesh_refinement_settings=settings["Mesh"][
                "AdaptiveMeshRefinement"
            ],
            vtk_subdivision=vtk_subdivision,
        )
        _logger.info(f"Timing for contact {contact.name}: {vcm_timings}")


def load_images(settings):
    """Load the MRI image and, if requested, the DTI image.

    Parameters
    ----------
    settings : dict
        Simulation settings; ``settings["MaterialDistribution"]`` provides
        ``"MRIPath"``, ``"DiffusionTensorActive"`` and, if the latter is
        set, ``"DTIPath"``.

    Returns
    -------
    tuple
        ``(mri_image, dti_image)``; ``dti_image`` is ``None`` if the
        diffusion tensor is not active.
    """
    _logger.info("Load MRI image")
    material_settings = settings["MaterialDistribution"]
    mri_path = material_settings["MRIPath"]
    _logger.debug(f"Input path: {mri_path}")
    mri_image = MagneticResonanceImage(mri_path)

    if not material_settings["DiffusionTensorActive"]:
        return mri_image, None

    _logger.info("Load DTI image")
    return mri_image, DiffusionTensorImage(material_settings["DTIPath"])


def run_PAM(settings):
    """Run pathway activation analysis.

    Parameters
    ----------
    settings : dict
        PAM settings. Reads ``"PathwayFile"``, ``"OutputPath"``,
        ``"Scaling"`` and ``"StimSets"`` (entries ``"Active"`` and
        ``"StimSetsFile"``). Without StimSets, ``"ScalingIndex"`` is read
        as well; with StimSets but without a file, ``"CurrentVector"`` is
        required. A missing ``"CurrentVector"`` entry is added to the
        settings with the value ``None``.

    Raises
    ------
    RuntimeError
        If NEURON is not installed.
    ValueError
        If StimSets is active but neither a StimSetsFile nor a
        CurrentVector is given.

    Notes
    -----
    With StimSets, the unit solutions are loaded from
    ``settings["OutputPath"] + "E1C<i>"`` for every contact ``i`` and
    superimposed per protocol. Currents from the protocols are scaled by
    ``1e-3`` (StimSets are given in mA) and NaN entries count as zero
    current.
    """
    if not PAM_AVAILABLE:
        raise RuntimeError("PAM not available! Please install NEURON!")
    # Imported here because it is only usable with NEURON installed.
    from ossdbs.axon_processing import get_neuron_model

    _logger.info("Running PAM")
    pathway_file = settings["PathwayFile"]
    pathway_solution_dir = settings["OutputPath"]
    time_domain_solution = os.path.join(
        settings["OutputPath"], "oss_time_result_PAM.h5"
    )
    with open(pathway_file) as fp:
        pathways_dict = json.load(fp)
    model_type = pathways_dict["Axon_Model_Type"]
    neuron_model = get_neuron_model(model_type, pathways_dict, pathway_solution_dir)

    if not settings["StimSets"]["Active"]:
        # Single solution computed directly for the requested stimulation.
        td_solution = neuron_model.load_solution(time_domain_solution)
        neuron_model.process_pathways(
            td_solution,
            scaling=settings["Scaling"],
            scaling_index=settings["ScalingIndex"],
        )
        return

    settings.setdefault("CurrentVector", None)
    if settings["StimSets"]["StimSetsFile"] is not None:
        _logger.info("Load current vectors form file.")
        # A file with a single protocol row is returned as a 0-d array by
        # genfromtxt; promote it so that it can be indexed like the rest.
        stim_protocols = np.atleast_1d(
            np.genfromtxt(
                settings["StimSets"]["StimSetsFile"],
                dtype=float,
                delimiter=",",
                names=True,
            )
        )
        n_stim_protocols = stim_protocols.shape[0]
        n_contacts = len(list(stim_protocols[0]))
    else:
        if settings["CurrentVector"] is None:
            raise ValueError("Provide either a StimSetsFile or a CurrentVector")
        n_stim_protocols = 1
        # load current from input file
        stim_protocols = [settings["CurrentVector"]]
        # assign contacts
        n_contacts = len(stim_protocols[0])

    # load unit solutions once
    _logger.info("Load unit solutions")
    time_domain_solution_files = [
        os.path.join(
            settings["OutputPath"] + f"E1C{contact_i + 1}",
            "oss_time_result_PAM.h5",
        )
        for contact_i in range(n_contacts)
    ]
    td_unit_solutions = neuron_model.load_unit_solutions(time_domain_solution_files)

    # go through stimulation protocols
    _logger.info("Running stimulation protocols")
    for protocol_i in range(n_stim_protocols):
        # get the scaling vector for the current
        scaling_vector = list(stim_protocols[protocol_i])
        # swap NaNs to zero current and convert to A (StimSets in mA)
        scaling_vector = [0 if np.isnan(x) else 1e-3 * x for x in scaling_vector]
        td_solution = neuron_model.superimpose_unit_solutions(
            td_unit_solutions, scaling_vector
        )
        # when using optimizer, scaling_index is not used
        from_current_vector = (
            settings["CurrentVector"] is not None
            and settings["StimSets"]["StimSetsFile"] is None
        )
        neuron_model.process_pathways(
            td_solution,
            scaling=settings["Scaling"],
            scaling_index=None if from_current_vector else protocol_i,
        )