# Copyright 2023, 2024 Konstantin Butenko, Jan Philipp Payonk
# Copyright 2023, 2024 Johannes Reding, Julius Zimmermann
# SPDX-License-Identifier: GPL-3.0-or-later
import json
import logging
import os
import time
from abc import ABC, abstractmethod
import ngsolve
import numpy as np
import pandas as pd
from ossdbs.fem.mesh import Mesh
from ossdbs.fem.solver import Solver
from ossdbs.model_geometry import Contacts, ModelGeometry
from ossdbs.point_analysis import Lattice, PointModel
from ossdbs.stimulation_signals import (
FrequencyDomainSignal,
get_indices_in_octave_band,
get_octave_band_indices,
get_timesteps,
reconstruct_time_signals,
)
from ossdbs.utils import have_dielectric_properties_changed
from ossdbs.utils.vtk_export import FieldSolution
from .conductivity import ConductivityCF
_logger = logging.getLogger(__name__)
[docs]class VolumeConductor(ABC):
"""Template class of a volume conductor.
Parameters
----------
geometry : ModelGeometry
Model geometry, brain with implanted electrodes
conductivity : ConductivityCF
Material information
solver : Solver
Solver (linear algebra part)
order: int
Order of solver and mesh (curved elements)
meshing_parameters: dict
Dictionary with setting for meshing
output_path: str
Path to store solution
"""
def __init__(
    self,
    geometry: ModelGeometry,
    conductivity: ConductivityCF,
    solver: Solver,
    order: int,
    meshing_parameters: dict,
    output_path: str = "Results",
) -> None:
    """Store the model ingredients and create or load the mesh.

    Parameters
    ----------
    geometry : ModelGeometry
        Brain model with implanted electrodes; its ``contacts`` seed the
        contact list and its ``geometry`` is meshed.
    conductivity : ConductivityCF
        Material information; ``is_complex`` selects real or complex mode.
    solver : Solver
        Solver for the linear-algebra part.
    order : int
        Order of the finite-element space and of the (curved) mesh.
    meshing_parameters : dict
        Reads ``LoadMesh``/``LoadPath``, ``SaveMesh``/``SavePath`` and the
        optional ``HPRefinement`` sub-dictionary (used if ``Active``).
    output_path : str
        Directory for results; created if it does not exist.
    """
    self._solver = solver
    self._order = order
    self._model_geometry = geometry
    self._contacts = Contacts(geometry.contacts)
    _logger.debug(f"Assigned base contacts with properties:\n {self._contacts}")
    self._conductivity_cf = conductivity
    self._complex = conductivity.is_complex

    # Results filled by run_full_analysis; None until a run has happened.
    self._impedances = None
    # voltage (current-controlled) or current (voltage-controlled)
    self._free_stimulation_variable = None
    self._stimulation_variable = None
    self._floating_potentials = None
    self._surface_impedances = None

    # The property setter creates the directory if necessary.
    self.output_path = output_path

    self._mesh = Mesh(self._model_geometry.geometry, self._order)
    mesh = self.mesh
    if meshing_parameters["LoadMesh"]:
        mesh.load_mesh(meshing_parameters["LoadPath"])
    else:
        mesh.generate_mesh(meshing_parameters)
    # load_mesh does not record the HP refinement parameters (generate_mesh
    # does), so they are set explicitly in both cases.
    hp_params = meshing_parameters.get("HPRefinement")
    if hp_params is not None and hp_params.get("Active"):
        mesh.set_hp_refinement_params(hp_params)
    if meshing_parameters["SaveMesh"]:
        mesh.save(meshing_parameters["SavePath"])

    # State of the most recent solve, used for post-processing.
    self._frequency = None
    self._sigma = None
    # frequency at which the VTK output is written
    self._export_frequency = None
    # VTA volume (mm^3) from direct ngsolve integration over the FEM mesh;
    # populated in _frequency_domain_exports when an ActivationThreshold
    # is configured.
    self._vta_volume = None
@abstractmethod
def compute_solution(self, frequency: float) -> None:
    """Solve the volume conductor problem at a single frequency.

    Concrete volume conductors must override this; the base
    implementation does nothing.

    Parameters
    ----------
    frequency: float
        Frequency (Hz) at which the solution is computed
    """
    return None
@abstractmethod
def update_space(self):
    """Rebuild the finite-element space, e.g. after the mesh has changed.

    Must be provided by concrete volume conductors.
    """
    return None
def prepare_mesh_refinements(self, material_mesh_refinement_steps: int = 0):
    """Run the complete refinement pipeline and rebuild the FE space.

    Material (bisection) refinement runs first, then HP refinement, and
    finally the space is updated for the refined mesh.

    Parameters
    ----------
    material_mesh_refinement_steps: int
        Forwarded unchanged to ``refine_mesh_by_material``.
    """
    steps = material_mesh_refinement_steps
    self.refine_mesh_by_material(steps)
    # The HP step has to follow the bisection-based material refinement.
    self.mesh.apply_hp_refinement()
    self.update_space()
def apply_h_refinements(self, material_mesh_refinement_steps: int = 0):
    """Run only the material-based (bisection) h-refinement.

    HP refinement and the space update are intentionally left out so
    that the h-refined mesh can be saved and reused (e.g. for StimSets).
    The HP step is then applied to every loaded mesh instance via
    ``apply_hp_and_update_space``.

    Parameters
    ----------
    material_mesh_refinement_steps: int
        Forwarded unchanged to ``refine_mesh_by_material``.
    """
    steps = material_mesh_refinement_steps
    self.refine_mesh_by_material(steps)
def apply_hp_and_update_space(self):
    """Finish the refinement pipeline on an already h-refined mesh.

    Applies the deferred HP refinement and then rebuilds the FE space.
    Intended to be called after loading an h-refined mesh.
    """
    mesh = self.mesh
    mesh.apply_hp_refinement()
    self.update_space()
# ruff: noqa: C901
def run_full_analysis(
    self,
    frequency_domain_signal: FrequencyDomainSignal,
    compute_impedance: bool = False,
    export_vtk: bool = False,
    point_models: list[PointModel] | None = None,
    activation_threshold: float | None = None,
    dielectric_threshold: float = 0.01,
    out_of_core: bool = False,
    export_frequency: float | None = None,
    adaptive_mesh_refinement_settings: dict | None = None,
    truncation_time: float | None = None,
    estimate_currents: bool | None = False,
    vtk_subdivision: int = 0,
) -> dict:
    """Run volume conductor model at all frequencies.

    Parameters
    ----------
    frequency_domain_signal: FrequencyDomainSignal
        Frequency-domain representation of stimulation signal
    compute_impedance: bool
        If True, the impedance will be computed at each frequency.
        It is forced to True for current-controlled stimulation with two
        active contacts, and forced to False unless exactly two contacts
        are active.
    export_vtk: bool
        VTK export for visualization in ParaView
    point_models: list[PointModel]
        list of PointModel to extract solution for VTA / PAM
    activation_threshold: float
        If VTA is estimated by threshold, provide it here.
        Its unit must be V/m!
    dielectric_threshold: float
        Threshold for accuracy of dielectric properties
    out_of_core: bool
        Indicate whether point model shall be done out-of-core
    export_frequency: float
        Frequency at which the VTK file should be exported.
        Otherwise, the middle of the computed frequencies is used.
    adaptive_mesh_refinement_settings: dict
        Perform adaptive mesh refinement (only at first frequency)
    truncation_time: float
        Time until which result will be written to hard drive
    estimate_currents: bool
        Get current estimate per contact by integration of normal component
    vtk_subdivision: int
        Element subdivision count forwarded to ``ngsolve.VTKOutput``.

    Returns
    -------
    dict
        Execution timings of the individual stages.

    Notes
    -----
    The volume conductor model is run at all frequencies
    and the time-domain signal is computed (if relevant).
    If the conductivity has not changed between two computed frequencies,
    the previous solution is reused and impedances / currents are copied
    from the previously computed frequency.
    """
    self.signal = frequency_domain_signal
    # set voltages for two-contact current controlled stimulation
    # check if current-controlled mode is correctly prepared
    if self.current_controlled:
        self.prepare_current_controlled_mode()
    if point_models is None:
        point_models = []
    timings = self.setup_timings_dict(export_vtk, point_models)
    dtype = complex if self.is_complex else float
    _do_AMR = self._resolve_amr_active(adaptive_mesh_refinement_settings)
    # unit amplitudes at all frequencies indicate multisine mode
    multisine_mode = np.all(np.isclose(self.signal.amplitudes, 1.0))
    n_frequencies = len(self.signal.frequencies)
    n_active = len(self.contacts.active)

    # always compute impedance for CC with 2 contacts
    if self.current_controlled and n_active == 2:
        _logger.info(
            "Set compute_impedance to True. "
            "Impedance calculation is required for 2 contacts."
        )
        compute_impedance = True
    if compute_impedance and n_active != 2:
        _logger.warning(
            "ComputeImpedance was requested but the configuration has "
            f"{n_active} active contacts (needs exactly 2). "
            "Disabling scalar impedance computation. For multicontact "
            "configurations, enable the ImpedanceAnalysis block in the "
            "input JSON."
        )
        compute_impedance = False

    if self.signal.octave_band_approximation:
        frequency_indices = get_octave_band_indices(self.signal.frequencies)
        # add DC component
        if not np.isclose(self.signal.amplitudes[0], 0.0):
            frequency_indices = np.insert(frequency_indices, 0, 0)
    else:
        frequency_indices = np.arange(n_frequencies)

    if export_frequency is None:
        middle_frequency_index = frequency_indices[len(frequency_indices) // 2]
        self._export_frequency = self.signal.frequencies[middle_frequency_index]
        _logger.info(f"Set export frequency to {self._export_frequency}")
    else:
        self._export_frequency = export_frequency

    if not multisine_mode:
        self._free_stimulation_variable = np.zeros(
            shape=(n_frequencies, n_active),
            dtype=complex,
        )
        self._stimulation_variable = np.zeros(
            shape=(n_frequencies, n_active),
            dtype=complex,
        )
    # Needed in both modes: the multisine branch below also writes
    # floating potentials.
    if len(self.contacts.floating) > 0:
        self._floating_potentials = np.zeros(
            shape=(n_frequencies, len(self.contacts.floating)),
            dtype=complex,
        )
    if compute_impedance:
        # scalar 1-D array, one entry per frequency; zero-initialised so
        # that entries not covered by any band are not uninitialised memory
        self._impedances = np.zeros(shape=(n_frequencies,), dtype=dtype)
    if estimate_currents:
        self._currents = {
            contact.name: np.zeros(shape=(n_frequencies,), dtype=dtype)
            for contact in self.contacts
        }

    for computing_idx, freq_idx in enumerate(frequency_indices):
        frequency = self.signal.frequencies[freq_idx]
        _logger.info(f"Computing at frequency: {frequency}")
        time_0 = time.time()
        # prepare storing at multiple frequencies
        if self.signal.octave_band_approximation:
            band_indices = get_indices_in_octave_band(
                freq_idx, frequency_indices, n_frequencies - 1
            )
            _logger.debug(
                f"""Band frequencies from {self.signal.frequencies[band_indices[0]]}
                to {self.signal.frequencies[band_indices[-1]]}"""
            )
        else:
            band_indices = [freq_idx]
        # check if conductivity has changed
        sigma_has_changed = self._has_sigma_changed(
            computing_idx, frequency_indices, threshold=dielectric_threshold
        )
        if sigma_has_changed:
            self.compute_solution(frequency)
            if compute_impedance:
                impedance = self.compute_impedance()
                self._impedances[band_indices] = impedance
            if estimate_currents:
                estimated_currents = self.estimate_currents()
                for contact in self.contacts:
                    self._currents[contact.name][band_indices] = estimated_currents[
                        contact.name
                    ]
            # refine only at first frequency
            if computing_idx == 0 and _do_AMR:
                # For voltage-controlled (non-CC) mode, use power as
                # the AMR convergence metric — computing the full
                # admittance matrix is unnecessary and very expensive.
                use_power_for_amr = not self.current_controlled
                if use_power_for_amr:
                    _logger.info("Using power instead of impedance in AMR")
                    amr_metric = self.compute_power()
                elif not compute_impedance:
                    try:
                        amr_metric = self.compute_impedance()
                    except NotImplementedError:
                        _logger.info("Using power instead of impedance in AMR")
                        amr_metric = self.compute_power()
                        use_power_for_amr = True
                else:
                    amr_metric = self._impedances[freq_idx]
                _logger.info(
                    "Number of elements before refinement: "
                    f"{self.mesh.ngsolvemesh.ne}"
                )
                error = 100.0
                refinements = 0
                tolerance = adaptive_mesh_refinement_settings["ErrorTolerance"]
                max_iterations = adaptive_mesh_refinement_settings["MaxIterations"]
                while error > tolerance and refinements < max_iterations:
                    self.adaptive_mesh_refinement()
                    # solve on refined mesh
                    self.compute_solution(frequency)
                    # check convergence
                    if use_power_for_amr:
                        new_amr_metric = self.compute_power()
                    else:
                        new_amr_metric = self.compute_impedance()
                    # error in percent
                    error = (
                        100.0 * abs(amr_metric - new_amr_metric) / abs(amr_metric)
                    )
                    # update variables for loop
                    refinements += 1
                    amr_metric = new_amr_metric
                    _logger.info(
                        f"Adaptive refinement step {refinements}, "
                        f"error {error:.3f}%."
                    )
                if compute_impedance:
                    # overwrite impedance values
                    if use_power_for_amr:
                        # recompute impedance on final refined mesh
                        self._impedances[band_indices] = self.compute_impedance()
                    else:
                        self._impedances[band_indices] = amr_metric
                if estimate_currents:
                    # recompute currents on the final refined mesh
                    estimated_currents = self.estimate_currents()
                    for contact in self.contacts:
                        self._currents[contact.name][band_indices] = (
                            estimated_currents[contact.name]
                        )
                _logger.info(
                    "Number of elements after refinement: "
                    f"{self.mesh.ngsolvemesh.ne}"
                )
                _logger.info(
                    "Adaptive mesh refinement converged after "
                    f"{refinements} refinement steps with an "
                    f"error of {error:.3f}%"
                )
        else:
            _logger.info(f"Skipped computation at {frequency} Hz")
            # Copy from the previously computed frequency. The result arrays
            # are indexed by frequency index, not by the position in
            # frequency_indices (these differ in octave-band mode).
            previous_freq_idx = frequency_indices[computing_idx - 1]
            if compute_impedance:
                impedance = self._impedances[previous_freq_idx]
                self._impedances[band_indices] = impedance
            if estimate_currents:
                for contact in self.contacts:
                    contact_currents = self._currents[contact.name]
                    contact_currents[band_indices] = contact_currents[
                        previous_freq_idx
                    ]
        # scale factor: is one for VC and depends on impedance for other case
        self._scale_factor = self.get_scale_factor(freq_idx)
        _logger.debug(f"Scale factor: {self._scale_factor}")
        if not multisine_mode:
            self._store_solution_at_contacts(band_indices)
        else:
            for contact_idx, contact in enumerate(self.contacts.floating):
                self._floating_potentials[freq_idx, contact_idx] = contact.voltage
        if _logger.getEffectiveLevel() == logging.DEBUG:
            estimated_currents = self.estimate_currents()
            _logger.debug(
                f"Estimated currents through contacts: {estimated_currents}"
            )
        time_1 = time.time()
        timings["ComputeSolution"].append(time_1 - time_0)
        time_0 = time_1

        _logger.info("Copy solution to point models")
        # initialise point models only after possible mesh change
        # AMR can change points in mesh
        if computing_idx == 0:
            for point_model in point_models:
                point_model.output_path = self.output_path
                point_model.prepare_VCM_specific_evaluation(
                    self.mesh, self.conductivity_cf
                )
                point_model.prepare_frequency_domain_data_structure(
                    n_frequencies, out_of_core
                )
                _logger.debug(
                    f"Points in point model: {point_model.coordinates.shape}"
                )
        # copy solution to point models
        self._process_frequency_domain_solution(band_indices, point_models)
        time_1 = time.time()
        timings["CopyValues"].append(time_1 - time_0)
        time_0 = time_1

        # export frequency-domain solution at one frequency
        if np.isclose(frequency, self._export_frequency):
            _logger.info(f"Exporting at {self._export_frequency}")
            # save vtk
            if export_vtk:
                self.vtk_export(freq_idx, multisine_mode, vtk_subdivision)
                time_1 = time.time()
                timings["VTKExport"].append(time_1 - time_0)
                time_0 = time_1
            # continue with frequency-domain exports
            self._frequency_domain_exports(
                point_models, freq_idx, activation_threshold
            )
            time_1 = time.time()
            timings["FieldExport"] = time_1 - time_0
            time_0 = time_1
        # reset surface impedances
        self._surface_impedances = None

    # save impedance at all frequencies to file!
    if compute_impedance:
        _logger.info("Saving impedance")
        df = pd.DataFrame(
            {
                "freq": self.signal.frequencies,
                "real": self.impedances.real,
                "imag": self.impedances.imag,
            }
        )
        df.to_csv(os.path.join(self.output_path, "impedance.csv"), index=False)
    if estimate_currents:
        df = pd.DataFrame({"freq": self.signal.frequencies})
        for contact in self.contacts:
            df[f"{contact.name}_real"] = self._currents[contact.name].real
            df[f"{contact.name}_imag"] = self._currents[contact.name].imag
        df.to_csv(os.path.join(self.output_path, "currents.csv"), index=False)
    # export floating voltages
    if self._floating_potentials is not None:
        df = pd.DataFrame({"freq": self.signal.frequencies})
        for contact_idx, contact in enumerate(self.contacts.floating):
            df[f"{contact.name}_real"] = self._floating_potentials[
                :, contact_idx
            ].real
            df[f"{contact.name}_imag"] = self._floating_potentials[
                :, contact_idx
            ].imag
        df.to_csv(
            os.path.join(self.output_path, "floating_potentials.csv"), index=False
        )

    # export time domain solution if a proper signal has been passed
    _logger.info("Launching reconstruction of time domain")
    if n_frequencies > 1 and not multisine_mode:
        _logger.info("Reconstructing time-domain signal.")
        timesteps = get_timesteps(
            self.signal.cutoff_frequency,
            self.signal.base_frequency,
            self.signal.signal_length,
        )
        truncation_index = None
        if truncation_time is not None:
            timestep = timesteps[1] - timesteps[0]
            truncation_index = round(truncation_time / timestep)
        for point_model_idx, point_model in enumerate(point_models):
            # skip point models that are not considered in time domain
            if not point_model.time_domain_conversion:
                continue
            (
                potential_in_time,
                Ex_in_time,
                Ey_in_time,
                Ez_in_time,
            ) = point_model.compute_solutions_in_time_domain(
                self.signal.signal_length, convert_field=point_model.export_field
            )
            point_model.create_time_result(
                timesteps,
                potential_in_time,
                Ex_in_time,
                Ey_in_time,
                Ez_in_time,
                truncation_index=truncation_index,
            )
            time_1 = time.time()
            timings[f"ReconstructTimeSignals_PointModel_{point_model_idx}"] = (
                time_1 - time_0
            )
            time_0 = time_1

    # close output-file
    # and write point model reports
    for point_model in point_models:
        point_model.close_output_file()
        try:
            point_model.export_point_model_information(
                os.path.join(point_model.output_path, point_model.name + ".json")
            )
        except NotImplementedError:
            # point model types without a report are skipped on purpose
            pass

    if n_frequencies > 1 and not multisine_mode:
        self.export_solution_at_contacts()

    self._save_report(timings)
    return timings
@property
def output_path(self) -> str:
"""Returns the path to output."""
return self._output_path
@output_path.setter
def output_path(self, path: str) -> None:
"""Set the path to write output.
Notes
-----
Creates directory if it doesn't exist.
"""
if not os.path.exists(path):
os.makedirs(path)
self._output_path = path
@property
def conductivity_cf(self) -> ConductivityCF:
"""Returns the coefficient function of the conductivity."""
return self._conductivity_cf
@property
def signal(self) -> FrequencyDomainSignal:
"""Returns the frequency-domain representation of stimulation signal."""
return self._signal
@signal.setter
def signal(self, new_signal: FrequencyDomainSignal) -> None:
self._check_signal(new_signal)
self._signal = new_signal
def _check_signal(self, new_signal: FrequencyDomainSignal) -> None:
"""Check the provided signal."""
sum_currents = 0
# check that floating conditions have been imposed correctly
floating_with_surface_impedance = 0
for contact in self.contacts.floating:
if contact.surface_impedance_model is not None:
floating_with_surface_impedance += 1
sum_currents += contact.current
if new_signal.current_controlled:
voltages_active = np.zeros(len(self.contacts.active))
for idx, contact in enumerate(self.contacts.active):
sum_currents += contact.current
voltages_active[idx] = contact.voltage
if not np.isclose(sum_currents, 0):
raise ValueError("The sum of all currents is not zero!")
# Mixed floating (some with surface impedance, some without) is
# rejected earlier in VolumeConductorFloatingImpedance.__init__,
# so that case never reaches here. Only the plain Floating case
# (no floating contact carries a surface impedance) needs the
# multipolar-ground check.
if len(self.contacts.floating) > 0 and floating_with_surface_impedance == 0:
active_contacts_grounded = np.isclose(voltages_active, 0.0)
if len(np.where(active_contacts_grounded)[0]) > 1:
raise ValueError(
"In multipolar current-controlled mode, "
"only one active contact has to be grounded!"
)
@property
def current_controlled(self) -> bool:
"""Return if stimulation is current-controlled."""
return self.signal.current_controlled
@property
def impedances(self) -> np.ndarray:
"""Scalar impedance per frequency (1-D array, length ``n_freq``).
Populated by ``run_full_analysis`` when ``compute_impedance`` is
True. Multicontact admittance / impedance matrices are produced
by ``ossdbs.fem.analysis.ImpedanceAnalyzer``, not here.
"""
return self._impedances
@property
def conductivity(self) -> ngsolve.CoefficientFunction:
"""Return conductivity of latest solution."""
return self._sigma
@property
def is_complex(self) -> bool:
"""Return the state of the data type for spaces. True if complex,
False otherwise.
Returns
-------
bool
"""
return self._complex
@is_complex.setter
def is_complex(self, value: bool) -> None:
"""If complex mode (EQS) is used or not."""
self._complex = value
@property
def model_geometry(self) -> ModelGeometry:
"""The underlying model geometry used for mesh generation."""
return self._model_geometry
@property
def mesh(self) -> Mesh:
"""The mesh used in computations."""
return self._mesh
@property
def solver(self) -> Solver:
"""The solver used in the VCM."""
return self._solver
@property
def contacts(self) -> Contacts:
"""A list of contacts in the VCM."""
return self._contacts
@property
def potential(self) -> ngsolve.GridFunction:
"""Return solution at most recent frequency."""
return self._potential
@property
def frequency(self) -> float:
"""Most recent frequency, not equal to the frequency of the signal!."""
return self._frequency
def evaluate_potential_at_points(self, lattice: np.ndarray) -> np.ndarray:
    """Return the electric potential at the given 3-D coordinates.

    Parameters
    ----------
    lattice : np.ndarray
        Nx3 numpy.ndarray of lattice points

    Notes
    -----
    Points outside the computational domain must be filtered out
    beforehand.
    """
    ngsolve_mesh = self.mesh.ngsolvemesh
    xs, ys, zs = lattice.T
    mesh_points = ngsolve_mesh(xs, ys, zs)
    return self.potential(mesh_points)
def evaluate_field_at_points(self, lattice: np.ndarray) -> np.ndarray:
    """Return the electric field components at the given 3-D coordinates.

    Parameters
    ----------
    lattice : np.ndarray
        Nx3 numpy.ndarray of lattice points

    Notes
    -----
    Points outside the computational domain must be filtered out
    beforehand.
    """
    ngsolve_mesh = self.mesh.ngsolvemesh
    xs, ys, zs = lattice.T
    mesh_points = ngsolve_mesh(xs, ys, zs)
    return self.electric_field(mesh_points)
@property
def current_density(self) -> ngsolve.GridFunction:
"""Return current density in A/mm^2."""
# scale to account for mm as length unit (not yet contained in conductivity)
return self.conductivity * self.electric_field
@property
def electric_field(self) -> ngsolve.GridFunction:
    """Electric field, i.e. the negative gradient of the potential."""
    potential_gradient = ngsolve.grad(self.potential)
    return -potential_gradient
def compute_power(self) -> complex:
    """Return the power in the domain, the mesh integral of conj(E) * J."""
    integrand = ngsolve.Conj(self.electric_field) * self.current_density
    # no separate mm conversion is applied: the original notes it is not
    # needed because of the integration
    return ngsolve.Integrate(integrand, self.mesh.ngsolvemesh)
def compute_impedance(self) -> complex:
    """Compute the scalar impedance of the most recent solution.

    For exactly two active contacts the impedance follows from the
    dissipated power, ``Z = U * conj(U) / P``. Here ``U`` is the voltage
    difference of the two active contacts. ``P`` is the volume power
    (``compute_power``) plus the power dissipated across the surface
    impedance of any active or floating contact that has one. This volume
    integration is preferred over integrating the normal current density,
    see [Zimmermann2021a]_.

    Multicontact configurations are not handled here; the full
    admittance-matrix analysis is a separate analysis tool.

    References
    ----------
    .. [Zimmermann2021a] Zimmermann, J., et al. (2021).
       Frontiers in Bioengineering and Biotechnology, 9, 765516.
       https://doi.org/10.3389/fbioe.2021.765516

    Returns
    -------
    complex
        Scalar impedance between the two active contacts.

    Raises
    ------
    NotImplementedError
        If the number of active contacts is not exactly two.
    """
    if len(self.contacts.active) != 2:
        raise NotImplementedError(
            "Scalar impedance requires exactly 2 active contacts "
            f"(got {len(self.contacts.active)}). For multicontact "
            "configurations, use the standalone admittance-matrix "
            "analysis tool."
        )

    def _surface_dissipation(contact, contact_voltage):
        # Power dissipated across the contact's surface impedance: the
        # admittance times |potential jump|^2, integrated over the contact.
        admittance = ngsolve.CF(1.0 / self._surface_impedances[contact.name])
        jump = (
            self.mesh.boundary_coefficients({contact.name: contact_voltage})
            - self.potential
        )
        return ngsolve.Integrate(
            admittance * jump * ngsolve.Conj(jump),
            mesh=self.mesh.ngsolvemesh,
            definedon=self.mesh.ngsolvemesh.Boundaries(contact.name),
        )

    power = self.compute_power()
    voltage_diff = 0
    for position, active_contact in enumerate(self.contacts.active):
        sign = (-1) ** position
        voltage_diff += sign * active_contact.voltage
        if active_contact.surface_impedance_model is not None:
            power += _surface_dissipation(active_contact, active_contact.voltage)
    # Floating contacts with a surface impedance also dissipate power;
    # their voltage holds the computed floating potential. Omitting this
    # term would under-count P and over-estimate |Z|.
    for floating_contact in self.contacts.floating:
        if floating_contact.surface_impedance_model is not None:
            power += _surface_dissipation(floating_contact, floating_contact.voltage)
    _logger.debug(f"Voltage drop for impedance: {voltage_diff}")
    _logger.debug(f"Power after surface imp: {power}")
    return voltage_diff * np.conj(voltage_diff) / power
def estimate_currents(self) -> dict:
    """Estimate the current of every contact from the normal current density.

    Returns
    -------
    dict
        Contact name -> integral of the (sign-flipped) normal component of
        the current density over that contact's boundary.

    Notes
    -----
    Meant for debugging purposes.
    If singularities are present, this method will not be accurate.
    """
    normal_vector = ngsolve.specialcf.normal(3)

    def _contact_current(boundary_name):
        # The boundary normal always points outwards, hence the minus sign.
        flux_density = -normal_vector * ngsolve.BoundaryFromVolumeCF(
            self.current_density
        )
        return ngsolve.Integrate(
            flux_density * ngsolve.ds(boundary_name), self.mesh.ngsolvemesh
        )

    return {contact.name: _contact_current(contact.name) for contact in self.contacts}
def vtk_export(
    self,
    freq_idx: int,
    multisine_mode: bool = False,
    subdivision: int = 0,
) -> None:
    """Write solution, conductivity and material distribution to VTK.

    Parameters
    ----------
    freq_idx: int
        Index of the frequency whose solution is exported
    multisine_mode: bool
        True in multisine mode; False when a rectangular pulse is used
    subdivision: int
        Element subdivision count forwarded to ngsolve.VTKOutput.
    """
    self.export_solution_to_vtk(freq_idx, multisine_mode, subdivision)
    # conductivity and material are frequency independent exports
    for export_static in (
        self.export_conductivity_to_vtk,
        self.export_material_distribution_to_vtk,
    ):
        export_static(subdivision)
def export_solution_to_vtk(
    self,
    freq_idx: int,
    multisine_mode: bool = False,
    subdivision: int = 0,
) -> None:
    """Write potential and electric field at one frequency to VTK.

    Parameters
    ----------
    freq_idx: int
        Index of the frequency whose solution is exported
    multisine_mode: bool
        True in multisine mode; False when a rectangular pulse is used
    subdivision: int
        Element subdivision count forwarded to ngsolve.VTKOutput.
    """
    ngmesh = self.mesh.ngsolvemesh
    # Outside multisine mode the reference solution (1 V voltage drop) is
    # exported unscaled.
    if multisine_mode:
        scale_factor = self._scale_factor * self.signal.amplitudes[freq_idx]
    else:
        scale_factor = 1.0

    def _save(field, field_name, file_stem):
        FieldSolution(
            scale_factor * field, field_name, ngmesh, self.is_complex
        ).save(os.path.join(self.output_path, file_stem), subdivision)

    _save(self.potential, "potential", "potential")
    _save(self.electric_field, "E_field", "E-field")
def export_conductivity_to_vtk(self, subdivision: int = 0) -> None:
    """Write the conductivity (and, for tensors, the DTI data) to VTK."""
    ngmesh = self.mesh.ngsolvemesh
    is_tensor = self.conductivity_cf.is_tensor

    def _to_paraview_tensor(cf):
        # ParaView naming convention for symmetric tensors:
        # (xx, yy, zz, xy, yz, xz), taken from the 9 tensor components.
        components = (cf[0], cf[4], cf[8], cf[1], cf[5], cf[2])
        return ngsolve.CoefficientFunction(components, dims=(6,))

    sigma = self.conductivity
    conductivity_export = _to_paraview_tensor(sigma) if is_tensor else sigma
    FieldSolution(
        conductivity_export, "conductivity", ngmesh, self.is_complex
    ).save(os.path.join(self.output_path, "conductivity"), subdivision)

    if is_tensor:
        dti_export = _to_paraview_tensor(
            self.conductivity_cf.dti_voxel_distribution
        )
        FieldSolution(dti_export, "dti", ngmesh, False).save(
            os.path.join(self.output_path, "dti"), subdivision
        )
def export_material_distribution_to_vtk(self, subdivision: int = 0) -> None:
    """Write the material distribution on the mesh to a VTK file."""
    ngmesh = self.mesh.ngsolvemesh
    material_cf = self.conductivity_cf.material_distribution(self.mesh)
    export = FieldSolution(material_cf, "material", ngmesh, False)
    export.save(os.path.join(self.output_path, "material"), subdivision)
def floating_values(self) -> dict:
    """Return the voltage of each floating contact, keyed by contact name."""
    return {contact.name: contact.voltage for contact in self.contacts.floating}
def h1_space(self, boundaries: list[str], is_complex: bool) -> ngsolve.H1:
    """Return an H1 space on the mesh with Dirichlet conditions.

    Parameters
    ----------
    boundaries : list of str
        Names of the boundaries that become Dirichlet boundaries.
    is_complex: bool
        Whether to use complex arithmetic

    Returns
    -------
    ngsolve.H1
        Space of the model's order with ``wb_withedges=False``.
    """
    dirichlet_pattern = "|".join(boundaries)
    return ngsolve.H1(
        mesh=self.mesh.ngsolvemesh,
        order=self._order,
        dirichlet=dirichlet_pattern,
        complex=is_complex,
        wb_withedges=False,
    )
def number_space(self) -> ngsolve.comp.NumberSpace:
    """Return a number space on the mesh.

    Returns
    -------
    ngsolve.NumberSpace
        Order-0 space with a single (global) DOF; complex if the model is.

    Notes
    -----
    TODO check if needed
    """
    ngmesh = self.mesh.ngsolvemesh
    return ngsolve.NumberSpace(mesh=ngmesh, order=0, complex=self.is_complex)
def flux_space(self) -> ngsolve.comp.HDiv:
    """Return an HDiv (flux) space on the mesh.

    Returns
    -------
    ngsolve.HDiv

    Notes
    -----
    The order is one below the model order but never below 1. The space
    is used by the a-posteriori error estimator of the adaptive mesh
    refinement.
    """
    ngmesh = self.mesh.ngsolvemesh
    hdiv_order = max(1, self._order - 1)
    return ngsolve.HDiv(mesh=ngmesh, order=hdiv_order, complex=self.is_complex)
def get_scale_factor(self, freq_idx: int) -> float:
    """Return the factor applied to the reference solution at ``freq_idx``.

    Notes
    -----
    The signal amplitude itself is not included; callers multiply it in.
    In voltage-controlled mode, and in current-controlled mode with other
    than two active contacts, the factor is 1. For current-controlled
    stimulation with two active contacts, Ohm's law U = Z * I is used with
    the stored impedance at ``freq_idx`` and the current of the first
    active contact, made positive by construction.
    """
    if not self.current_controlled:
        return 1.0
    _logger.debug("Scale solution for current_controlled mode")
    if len(self.contacts.active) != 2:
        return 1.0
    impedance = self.impedances[freq_idx]
    current = self.contacts.active[0].current
    # flip the sign so that the current is positive
    reference = current.real if self.is_complex else current
    current *= np.sign(reference)
    return 1.0 * (impedance * current)
def prepare_current_controlled_mode(self) -> None:
    """Check contacts and assign voltages if needed.

    Two active contacts: each non-grounded contact gets the voltage
    ``position + 1`` (V) so that a reference solution with a voltage drop
    exists; at most one of them may be grounded (0 V).
    One active contact: it must be grounded.
    No active contacts (all floating): nothing is assigned; the
    sum-of-potentials constraint provides the reference.

    Raises
    ------
    ValueError
        If both of two active contacts are grounded, if a single active
        contact is not grounded, or if more than two contacts are active.
    """
    if len(self.contacts.active) == 2:
        _logger.info("Overwrite voltage for current-controlled mode")
        ground_assigned = False
        for contact_idx, contact in enumerate(self.contacts.active):
            if np.isclose(self.contacts[contact.name].voltage, 0):
                if ground_assigned:
                    # the space after "0V)." was previously missing
                    raise ValueError(
                        "All active contacts have been grounded (voltage = 0V). "
                        "Choose only one ground."
                    )
                ground_assigned = True
            else:
                contact_voltage = float(contact_idx) + 1
                self.contacts[contact.name].voltage = contact_voltage
    elif len(self.contacts.active) == 0:
        # All contacts are floating (e.g. FloatingImpedance).
        # The Lagrange multiplier constrains the sum of floating
        # potentials to zero, providing the voltage reference.
        _logger.info(
            "No active contacts — all floating with surface "
            "impedance. Sum-of-potentials constraint is used."
        )
    elif len(self.contacts.active) == 1:
        for contact in self.contacts.active:
            if not np.isclose(contact.voltage, 0):
                raise ValueError(
                    "In multicontact current-controlled mode, "
                    "only ground voltage (0V) can be set on "
                    "active contacts!"
                )
    else:
        raise ValueError(
            "In multicontact current-controlled mode, "
            "currently only one active contact with fixed "
            "voltage can be used. "
            "Its voltage has to be 0V (ground)."
        )
[docs] def setup_timings_dict(
self, export_vtk: bool, point_models: list[PointModel]
) -> dict:
"""Setup dictionary to save execution times estimate."""
timings = {}
timings["ComputeSolution"] = []
# look at entire copying process only
timings["CopyValues"] = []
if export_vtk:
timings["VTKExport"] = []
for point_model_idx, _ in enumerate(point_models):
timings[f"ReconstructTimeSignals_PointModel_{point_model_idx}"] = 0.0
return timings
def _store_solution_at_contacts(self, band_indices: list | np.ndarray) -> None:
"""Save voltages / currents at given frequency band for all contacts."""
if self.current_controlled:
for contact_idx, contact in enumerate(self.contacts.active):
for freq_idx in band_indices:
scale_factor = self._scale_factor * self.signal.amplitudes[freq_idx]
self._free_stimulation_variable[freq_idx, contact_idx] = (
scale_factor * contact.voltage
)
self._stimulation_variable[freq_idx, contact_idx] = (
scale_factor * contact.current
)
else:
estimated_currents = self.estimate_currents()
for contact_idx, contact in enumerate(self.contacts.active):
for freq_idx in band_indices:
scale_factor = self._scale_factor * self.signal.amplitudes[freq_idx]
self._free_stimulation_variable[freq_idx, contact_idx] = (
scale_factor * estimated_currents[contact.name]
)
self._stimulation_variable[freq_idx, contact_idx] = (
scale_factor * contact.voltage
)
for contact_idx, contact in enumerate(self.contacts.floating):
for freq_idx in band_indices:
self._floating_potentials[freq_idx, contact_idx] = (
self._scale_factor * contact.voltage
)
def _copy_frequency_domain_solution(
    self,
    band_indices: list | np.ndarray,
    point_model: PointModel,
    potentials: np.ndarray,
    fields: np.ndarray,
) -> None:
    """Copy values to time-domain vector.

    For every frequency index of the band, the given potentials and fields
    are scaled by ``_scale_factor`` times that frequency's signal amplitude.
    They are then handed to
    ``point_model.copy_frequency_domain_solution_from_vcm``.

    Parameters
    ----------
    band_indices : list or np.ndarray
        Frequency indices that share the current solution.
    point_model : PointModel
        Receiver of the scaled values.
    potentials : np.ndarray
        Potential evaluated at the point model's points.
    fields : np.ndarray
        Field evaluated at the point model's points.
    """
    for freq_idx in band_indices:
        factor = self._scale_factor * self.signal.amplitudes[freq_idx]
        # a complex scale factor is needed for export of h5py files in
        # out-of-core mode; factors that are already complex pass through
        factor = factor if isinstance(factor, complex) else complex(factor)
        scaled_potentials = factor * potentials
        scaled_fields = factor * fields
        point_model.copy_frequency_domain_solution_from_vcm(
            freq_idx, scaled_potentials, scaled_fields
        )
def threshold_frequency_domain_Efield(
    self, scale_factor: float, activation_threshold: float
) -> float:
    """Determine volume of E-field above threshold at current frequency.

    Parameters
    ----------
    scale_factor : float
        Factor applied to ``self.electric_field`` before thresholding.
    activation_threshold : float
        Threshold in V/m that the field magnitude must strictly exceed.

    Returns
    -------
    float
        Integral of the supra-threshold indicator over the mesh, i.e. the
        volume in mesh units (the mesh is in mm).
    """
    # The mesh is in mm, so the field is in V/mm; 1e3 converts to V/m.
    # ngsolve.Norm yields the modulus |E|, which also covers complex
    # frequency-domain fields (for real fields it equals sqrt(E.E)). This
    # matches the complex-modulus convention of the exported magnitude.
    magnitude_v_per_m = 1e3 * ngsolve.Norm(scale_factor * self.electric_field)
    # indicator: 1 where the magnitude exceeds the threshold, else 0
    above_threshold = ngsolve.IfPos(magnitude_v_per_m - activation_threshold, 1, 0)
    # integrating the indicator over the mesh gives the volume
    return ngsolve.Integrate(above_threshold, mesh=self.mesh.ngsolvemesh)
def _has_sigma_changed(
    self, freq_idx: int, frequency_indices: np.ndarray, threshold: float = 0.01
) -> bool:
    """Check if conductivity has changed.

    Parameters
    ----------
    freq_idx : int
        Position in ``frequency_indices`` of the frequency about to be solved.
    frequency_indices : np.ndarray
        Frequency indices; multiplied by ``signal.base_frequency`` to get Hz.
    threshold : float
        Passed through to ``have_dielectric_properties_changed``.

    Returns
    -------
    bool
        True if no conductivity has been computed yet (``_sigma is None``).
        Otherwise, the result of comparing the dielectric properties at the
        previous and the current frequency.
    """
    if self._sigma is None:
        return True
    base_frequency = self.signal.base_frequency
    # NOTE(review): for freq_idx == 0 this reads frequency_indices[-1]
    # (Python wrap-around), i.e. it compares with the last frequency.
    previous_frequency = frequency_indices[freq_idx - 1] * base_frequency
    current_frequency = frequency_indices[freq_idx] * base_frequency
    return have_dielectric_properties_changed(
        self._conductivity_cf.dielectric_properties,
        self.is_complex,
        previous_frequency,
        current_frequency,
        threshold,
    )
def _frequency_domain_exports(
    self,
    point_models: list,
    export_frequency_index: int,
    activation_threshold: float,
):
    """Export solution at desired frequency.

    Parameters
    ----------
    point_models : list
        Point models that export potential (and, if enabled, the field).
    export_frequency_index : int
        Index into ``signal.frequencies`` / ``signal.amplitudes``.
    activation_threshold : float
        If not None, the supra-threshold E-field volume is computed and
        stored in ``self._vta_volume``.
    """
    _logger.info(
        f"Exporting results at {self.signal.frequencies[export_frequency_index]} Hz."
    )
    if activation_threshold is not None:
        amplitude = self.signal.amplitudes[export_frequency_index]
        self._vta_volume = self.threshold_frequency_domain_Efield(
            self._scale_factor * amplitude, activation_threshold
        )
        _logger.info(f"VTA volume is: {self._vta_volume:.3f}")
    # NOTE(review): the exports below use self._export_frequency, while the
    # first log line uses signal.frequencies[export_frequency_index];
    # confirm both always agree.
    for point_model in point_models:
        _logger.info(f"Exporting for point model type {type(point_model)}.")
        point_model.export_potential_at_frequency(
            self._export_frequency, export_frequency_index
        )
        if point_model._export_field:
            _logger.info(
                f"Exporting electric field at frequency {self._export_frequency}Hz."
            )
            point_model.export_field_at_frequency(
                self._export_frequency,
                export_frequency_index,
                electrode=self.model_geometry.electrodes[0],
                activation_threshold=activation_threshold,
            )
        # NOTE(review): without a threshold this forwards whatever value
        # self._vta_volume held before this call.
        if isinstance(point_model, Lattice):
            point_model.VTA_volume = self._vta_volume
def _process_frequency_domain_solution(
    self, band_indices: list | np.ndarray, point_models: list[PointModel]
) -> None:
    """Copy results to points.

    For each point model, the potential and field of the current solution
    are evaluated once on ``point_model.lattice``. They are then copied,
    scaled per frequency, for all frequencies in ``band_indices``.

    Parameters
    ----------
    band_indices : list or np.ndarray
        Frequency indices that share the current solution.
    point_models : list of PointModel
        Point models to receive the values (iterated, hence a list).
    """
    for point_model in point_models:
        # evaluate once per point model; reused for every frequency in the band
        potentials = self.evaluate_potential_at_points(point_model.lattice)
        fields = self.evaluate_field_at_points(point_model.lattice)
        # copy values for time-domain analysis
        self._copy_frequency_domain_solution(
            band_indices, point_model, potentials, fields
        )
def adaptive_mesh_refinement(self):
    """Refine mesh adaptively.

    The current density is projected onto the space returned by
    ``flux_space()``. The pointwise product of the difference with its
    complex conjugate is used as the error indicator for
    ``mesh.refine_by_error_cf``, after which the FE space is updated.
    """
    current_density = self.current_density
    projected_flux = ngsolve.GridFunction(space=self.flux_space())
    projected_flux.Set(coefficient=current_density)
    residual = current_density - projected_flux
    # Conj makes the indicator real-valued for complex solutions
    self.mesh.refine_by_error_cf(residual * ngsolve.Conj(residual))
    # the space must follow the refined mesh
    self.update_space()
def _check_AMR_settings(self, adaptive_mesh_refinement_settings: dict) -> None:
    """Validate that the adaptive mesh refinement settings are complete.

    Parameters
    ----------
    adaptive_mesh_refinement_settings : dict
        Must contain the keys ``ErrorTolerance`` and ``MaxIterations``.

    Raises
    ------
    ValueError
        If any of the mandatory keys is missing.
    """
    mandatory_keys = {"ErrorTolerance", "MaxIterations"}
    missing_keys = mandatory_keys.difference(adaptive_mesh_refinement_settings.keys())
    if missing_keys:
        raise ValueError(
            "Need to specify ErrorTolerance and "
            "MaxIterations for adaptive mesh refinement"
        )
def _resolve_amr_active(
self, adaptive_mesh_refinement_settings: dict | None
) -> bool:
"""Decide whether AMR should run for this analysis.
AMR is disabled when hp-refinement has been applied to the mesh,
because NGSolve's standard refinement produces an inconsistent
mesh on top of hp-refined elements.
"""
if adaptive_mesh_refinement_settings is None:
return False
self._check_AMR_settings(adaptive_mesh_refinement_settings)
active = bool(adaptive_mesh_refinement_settings.get("Active", False))
if active and self.mesh.hp_refinement_applied:
# hp refinement introduces elements that NGSolve cannot refine
_logger.warning(
"Attention: Adaptive mesh refinement and hp-refinement "
"are mutually exclusive"
)
return False
return active
def _save_report(self, timings: dict):
    """Save simulation run report to disk.

    Writes ``VCM_report.json`` into ``self.output_path``. It contains the
    number of DOFs, the number of mesh elements and the given timings.
    ``VTA_volume_mm3`` is added only if a VTA volume has been computed.

    Parameters
    ----------
    timings : dict
        Execution-time bookkeeping, stored under ``"Timings"``.
    """
    report = {
        "DOF": self._space.ndof,
        "Elements": self.mesh.n_elements,
        "Timings": timings,
    }
    if self._vta_volume is not None:
        report["VTA_volume_mm3"] = self._vta_volume
    report_path = os.path.join(self.output_path, "VCM_report.json")
    with open(report_path, "w") as report_file:
        json.dump(report, report_file)
def refine_mesh_by_material(self, material_mesh_refinement_steps: int) -> None:
    """Check materials and refine mesh if more than one material per element.

    Parameters
    ----------
    material_mesh_refinement_steps : int
        Number of refinement passes; zero or negative performs none.
    """
    for _step in range(material_mesh_refinement_steps):
        # recomputed on every pass, after the previous pass has refined self.mesh
        distribution = self.conductivity_cf.material_distribution(self.mesh)
        self.mesh.refine_by_material_cf(distribution)