Source code for ossdbs.point_analysis.pathway

# Copyright 2023, 2024 Konstantin Butenko, Jan Philipp Payonk, Julius Zimmermann
# SPDX-License-Identifier: GPL-3.0-or-later

import logging
import os
from dataclasses import dataclass

import h5py
import numpy as np
import pandas as pd

from ossdbs.fem import Mesh
from ossdbs.utils.collapse_vta import get_collapsed_VTA
from ossdbs.utils.field_computation import (
    compute_field_magnitude_from_components,
)

from .lattice import PointModel
from .time_results import TimeResult

# Module-level logger used by the Pathway model below.
_logger = logging.getLogger(__name__)


class Pathway(PointModel):
    """Pathways comprise populations of axons.

    The model is read from an HDF5 file in which every top-level group is a
    population and every entry of a group is one axon given as an array of
    3D points. Wherever axons are flattened into one point list, they are
    visited in numerical order of their name suffix (``axon0``, ``axon1``,
    ...).
    """

    @dataclass
    class Axon:
        """
        Attributes
        ----------
        name: str
            Naming of axons needs to be axon0, axon1, axon2, ... to be
            processed in the correct order.
        points: np.ndarray
            Contains 3D coordinates of each point within one axon.
        status: int
            0 - normal, -1 - outside domain/encap, -2 - csf
        orig_inx: int
            "Original" index of the streamline.
        """

        name: str
        points: np.ndarray
        status: int  # 0 - normal, -1 - outside domain/encap, -2 - csf
        orig_inx: int  # "original" indices of streamlines

    @dataclass
    class Population:
        """
        Attributes
        ----------
        name: str
            Name of neuronal population, e.g. a pathway.
        axons: list["Pathway.Axon"]
            List that contains all axons within one population.
        """

        name: str
        axons: list["Pathway.Axon"]

    def __init__(self, input_path: str, export_field: bool = False) -> None:
        """Load all populations and axons from an HDF5 file.

        Parameters
        ----------
        input_path: str
            Path to the .h5 file; every top-level group is read as one
            population.
        export_field: bool
            Stored as ``self._export_field``.
        """
        # identifiers
        self._name = "PAM"
        self._export_field = export_field
        # path from where to read model
        self._path = input_path
        # never collapse VTA
        self.collapse_VTA = False
        # always compute time-domain signal
        self.time_domain_conversion = True

        with h5py.File(self._path, "r") as file:
            self._populations = [
                self.Population(group, self._create_axons(file, group))
                for group in file.keys()
            ]

        n_points = sum(
            len(axon.points)
            for population in self._populations
            for axon in population.axons
        )
        # NOTE(review): np.full(n, "") infers dtype '<U1'; if longer location
        # labels are assigned into this array later they would be truncated
        # to one character -- confirm against the code that fills it.
        self._location = np.full(n_points, "")
        self._coordinates = self._initialize_coordinates()
        # will be set later
        self._lattice = None

    @staticmethod
    def _sorted_axons(axons: list) -> list:
        """Return ``axons`` sorted by the integer suffix of their names.

        Raises ValueError if a name does not carry an integer after its
        first four characters (e.g. ``axon12``).
        """
        return sorted(axons, key=lambda axon: int(axon.name[4:]))

    def _create_axons(self, file: h5py.File, group: str) -> list:
        """Create axons based on the input from the .h5 file.

        Parameters
        ----------
        file: h5py.File
            Loaded .h5 file, which contains structural information.
        group: str
            Name of the group, which contains the axons.

        Returns
        -------
        axons: list
            Returns list of all axons within one group.
        """
        axons = []
        for sub_group in file[group].keys():
            dataset = file[group][sub_group]
            if "inx" in dataset.attrs:
                orig_inx = dataset.attrs["inx"]
            else:
                _logger.debug(
                    "Dataset %s/%s has no 'inx' attribute; falling back to the "
                    "axon index encoded in the dataset name.",
                    group,
                    sub_group,
                )
                orig_inx = int(sub_group[4:]) if sub_group.startswith("axon") else 0
            # every axon starts with status 0 ("normal")
            axons.append(self.Axon(sub_group, np.array(dataset), 0, orig_inx))
        return axons

    def _initialize_coordinates(self) -> np.ndarray:
        """Stack the points of all axons (numerically sorted) into one array."""
        return np.concatenate(
            [
                axon.points
                for population in self._populations
                for axon in self._sorted_axons(population.axons)
            ]
        )

    def _write_file(self, data: TimeResult, file: h5py.File):
        """Create datasets in HDF5 file.

        Parameters
        ----------
        data: TimeResult
            Time-domain result to be exported.
        file: h5py.File
            HDF5 file that shall contain data.

        Notes
        -----
        Creates groups for each population in .h5 file.

        TODO Rename to 'create_data_export' or alike.
        """
        file.create_dataset("TimeSteps[s]", data=data.time_steps)
        start = 0
        idx = 0
        for population in self._populations:
            group = file.create_group(population.name)
            start, idx, status_list = self._create_datasets(
                data, start, idx, population, group
            )
            group.create_dataset("Status", data=status_list)

    def _create_datasets(self, data, start, idx, population, group):
        """Create datasets for each axon within the corresponding population.

        Axons are sorted numerically.

        Parameters
        ----------
        data: TimeResult
            Time-domain result to be exported.
        start: int
            Row offset into the arrays of ``data`` at which this population
            begins.
        idx: int
            Running axon counter across all populations; used to slice
            ``self._location``.
        population: Pathway.Population
            Population whose axons are written.
        group: h5py.Group
            HDF5 group that receives one sub-group per axon.

        Returns
        -------
        tuple
            Updated ``start`` and ``idx``, and the list of axon statuses.
        """
        status_list = []
        for axon in self._sorted_axons(population.axons):
            n_axon_points = len(axon.points)
            sub_group = group.create_group(axon.name)
            sub_group.attrs["inx"] = axon.orig_inx
            sub_group.create_dataset("Points[mm]", data=axon.points)
            # assumes every axon has the same number of points
            # (cf. get_axon_length)
            location = self._location[
                idx * n_axon_points : (idx + 1) * n_axon_points
            ]
            sub_group.create_dataset("Location", data=location.astype("S"))
            status_list.append(axon.status)
            # NOTE(review): filter_for_geometry keeps encapsulation-layer axons
            # in the lattice, but filter_csf_encap later sets their status to
            # -1, so `start` is not advanced for them here. If `data` rows
            # follow the lattice, later axons may receive shifted values --
            # verify against how TimeResult is assembled.
            if axon.status != -1:
                end = start + n_axon_points
                # export potential
                sub_group.create_dataset(
                    "Potential[V]", data=data.potential[start:end]
                )
                if data.electric_field_magnitude is not None:
                    # export field magnitude
                    sub_group.create_dataset(
                        "Electric field magnitude[Vm^(-1)]",
                        data=data.electric_field_magnitude[start:end],
                    )
                if not (
                    data.electric_field_vector_x is None
                    and data.electric_field_vector_y is None
                    and data.electric_field_vector_z is None
                ):
                    # export field vector component-wise
                    sub_group.create_dataset(
                        "Electric field vector x[Vm^(-1)]",
                        data=data.electric_field_vector_x[start:end],
                    )
                    sub_group.create_dataset(
                        "Electric field vector y[Vm^(-1)]",
                        data=data.electric_field_vector_y[start:end],
                    )
                    sub_group.create_dataset(
                        "Electric field vector z[Vm^(-1)]",
                        data=data.electric_field_vector_z[start:end],
                    )
                start = end
            idx = idx + 1
        return start, idx, status_list

    def filter_for_geometry(self, grid_pts: np.ma.MaskedArray) -> np.ndarray:
        """Check if any point of an axon is outside the geometry.

        If this is the case, the entire axon will be marked (status -1), and
        its points will be removed from further processing.

        Parameters
        ----------
        grid_pts: np.ma.MaskedArray
            Array containing points inside the mesh; masked rows are outside.

        Returns
        -------
        filtered_points: np.ndarray
            Returns filtered_points after removing axons that are
            (partially) outside the geometry.

        Raises
        ------
        ValueError
            If no axon lies completely inside the computational domain.
        """
        x, y, z = grid_pts.T
        lattice_mask = np.invert(grid_pts.mask)[:, 0]

        total_points = sum(
            axon.points.shape[0]
            for population in self._populations
            for axon in population.axons
        )

        # True for every point that belongs to an axon completely inside
        keep = np.zeros(total_points, dtype=bool)
        offset = 0
        pop_axons_stats = []  # (population_name, n_axons, n_axons_inside)
        for population in self._populations:
            n_axons_inside = 0
            for axon in self._sorted_axons(population.axons):
                axon_length = axon.points.shape[0]
                segment = slice(offset, offset + axon_length)
                if lattice_mask[segment].all():
                    keep[segment] = True
                    n_axons_inside += 1
                else:
                    # a single point outside the mesh discards the whole axon
                    axon.status = -1
                offset += axon_length
            pop_axons_stats.append(
                (population.name, len(population.axons), n_axons_inside)
            )

        coordinates = np.column_stack(
            (x.data[:total_points], y.data[:total_points], z.data[:total_points])
        ).astype(float)
        filtered_points = coordinates[keep]

        if filtered_points.shape[0] == 0:
            raise ValueError("No points inside the computational domain.")

        for name, n_axons, n_axons_inside in pop_axons_stats:
            _logger.info("Total axons in %s: %s", name, n_axons)
            _logger.info("Outside the domain: %s", n_axons - n_axons_inside)
        return filtered_points

    def filter_csf_encap(
        self, inside_csf: np.ndarray, inside_encap: np.ndarray
    ) -> None:
        """Change axon status if a single point of the axon is within the CSF
        or encapsulation layer.

        Only axons that are not already marked with status -1 are checked;
        their points are assumed to be laid out consecutively, as in the
        lattice returned by ``filter_for_geometry``. The first flagged point
        of an axon decides its status. If that point is flagged for the
        encapsulation layer, the status is -1 (this takes precedence over
        CSF). Otherwise it is -2 (CSF).

        Parameters
        ----------
        inside_csf: np.ndarray
            The array contains 1 if the corresponding point is inside the CSF,
            0 otherwise.
        inside_encap: np.ndarray
            The array contains 1 if the corresponding point is inside the
            encapsulation layer, 0 otherwise.
        """
        # flatten (N, 1) column vectors to one flag per point
        encap_flags = np.asarray(inside_encap).astype(bool).ravel()
        csf_flags = np.asarray(inside_csf).astype(bool).ravel()

        idx_axon = 0
        for population in self._populations:
            for axon in self._sorted_axons(population.axons):
                if axon.status == -1:
                    # removed by filter_for_geometry, not part of the lattice
                    continue
                axon_length = axon.points.shape[0]
                segment = slice(idx_axon, idx_axon + axon_length)
                axon_encap = encap_flags[segment]
                hits = np.flatnonzero(axon_encap | csf_flags[segment])
                if hits.size > 0:
                    # -1 for inside encap, -2 for inside csf
                    axon.status = -1 if axon_encap[hits[0]] else -2
                idx_axon += axon_length
        _logger.info("Marked axons inside CSF and encapsulation layer")
        return

    def create_index(self, lattice: np.ndarray) -> np.ndarray:
        """Create index for each point to the matching axon.

        Points are grouped into consecutive blocks of ``get_axon_length()``
        points; block ``i`` receives index ``i``. Trailing points that do not
        fill a complete block keep index 0.

        Parameters
        ----------
        lattice: np.ndarray
            Points of all axons that remain in the computation.

        Returns
        -------
        index: np.ndarray
            Column vector of shape (len(lattice), 1).
        """
        n_points = len(lattice)
        axon_length = self.get_axon_length()
        n_full_axons = n_points // axon_length
        index = np.zeros(shape=n_points, dtype=int)
        index[: n_full_axons * axon_length] = np.repeat(
            np.arange(n_full_axons), axon_length
        )
        return index.reshape((n_points, 1))

    def get_axon_length(self) -> int:
        """
        Returns
        -------
        axon_length: int
            Number of points per axon

        Notes
        -----
        Assume the same length for all axons; the first axon of the first
        population is used.
        """
        return self._populations[0].axons[0].points.shape[0]

    def get_population_names(self) -> list:
        """
        Returns
        -------
        population_names: list[str]
            Names of all populations defined
        """
        return [population.name for population in self._populations]

    def get_axon_names(self) -> list:
        """
        Returns
        -------
        axon_names: list[list,list,...]
            Names of axons in each population, in the order stored in
            ``population.axons``
        """
        return [
            [axon.name for axon in population.axons]
            for population in self._populations
        ]

    def get_axon_numbers(self) -> list:
        """Get list of number of axons per population.

        Returns
        -------
        axon_number: list[int]
            Number of axons per population
        """
        return [len(population.axons) for population in self._populations]

    def save_as_nifti(
        self, scalar_field, filename, binarize=False, activation_threshold=None
    ):
        """Save scalar field in abstract orthogonal space in nifti format.

        Parameters
        ----------
        scalar_field : numpy.ndarray
            Nx1 array of scalar values on the lattice
        filename: str
            Name for the nifti file that should contain full path
        binarize: bool
            Choose to threshold the scalar field and save the binarized result
        activation_threshold: float
            Activation threshold for VTA estimate

        Raises
        ------
        NotImplementedError
            Always; pathway results have no Nifti representation.
        """
        raise NotImplementedError("Pathway results can not be stored in Nifti format.")

    def prepare_VCM_specific_evaluation(self, mesh: Mesh, conductivity_cf):
        """Prepare data structure according to mesh.

        Parameters
        ----------
        mesh: Mesh
            Mesh object on which VCM is defined
        conductivity_cf: ConductivityCF
            Conductivity function that holds material info

        Notes
        -----
        Mask all points outside domain, filter CSF and encapsulation layer
        etc. Prepares data storage for all frequencies at all points.
        """
        grid_pts = self.points_in_mesh(mesh)
        self._lattice_mask = np.invert(grid_pts.mask)
        self._lattice = self.filter_for_geometry(grid_pts)
        self._inside_csf = self.get_points_in_csf(mesh, conductivity_cf)
        self._inside_encap = self.get_points_in_encapsulation_layer(mesh)

        # mark complete axons and log how many axons were finally seeded
        self.filter_csf_encap(self.inside_csf, self.inside_encap)
        total_axons = sum(len(population.axons) for population in self._populations)
        seeded_axons = sum(
            axon.status == 0
            for population in self._populations
            for axon in population.axons
        )
        _logger.info("Axons finally seeded: %s / %s", seeded_axons, total_axons)

        # create index for axons
        self._axon_index = self.create_index(self.lattice)

    def export_field_at_frequency(
        self,
        frequency: float,
        frequency_index: int,
        electrode=None,
        activation_threshold: float | None = None,
    ):
        """Write field values to CSV.

        Parameters
        ----------
        frequency: float
            Frequency of exported solution
        frequency_index: int
            Index at which frequency is stored
        electrode: ElectrodeModel
            electrode model that holds geometry information; required only
            when ``collapse_VTA`` is set
        activation_threshold: float
            Threshold to define VTA (unused here)

        Raises
        ------
        RuntimeError
            If ``prepare_VCM_specific_evaluation`` has not been called.
        ValueError
            If ``collapse_VTA`` is set but no electrode is given.

        Notes
        -----
        No Nifti file is exported for a Pathway model.
        """
        if self.lattice is None:
            raise RuntimeError(
                "Please call first prepare_VCM_specific_evaluation "
                "to classify pathway points."
            )
        Ex = self.tmp_Ex_freq_domain[:, frequency_index]
        Ey = self.tmp_Ey_freq_domain[:, frequency_index]
        Ez = self.tmp_Ez_freq_domain[:, frequency_index]
        field_mags = compute_field_magnitude_from_components(Ex, Ey, Ez)

        # real parts of the field components as column vectors
        ex_col = Ex.reshape((Ex.shape[0], 1)).real
        ey_col = Ey.reshape((Ey.shape[0], 1)).real
        ez_col = Ez.reshape((Ez.shape[0], 1)).real
        columns = [
            "index",
            "x-pt",
            "y-pt",
            "z-pt",
            "x-field",
            "y-field",
            "z-field",
            "magnitude",
            "inside_csf",
            "inside_encap",
        ]

        if self.collapse_VTA:
            _logger.info("Collapse VTA by virtually removing the electrode")
            if electrode is None:
                raise ValueError(
                    "Electrode for exporting the collapsed VTA is missing."
                )
            field_on_probed_points = np.concatenate(
                [
                    self.lattice,
                    ex_col,
                    ey_col,
                    ez_col,
                    field_mags.reshape((field_mags.shape[0], 1)).real,
                ],
                axis=1,
            )
            field_on_probed_points_collapsed = get_collapsed_VTA(
                field_on_probed_points,
                electrode._position,
                electrode._direction,
                electrode._parameters.lead_diameter,
            )
            df_export = pd.DataFrame(
                np.concatenate(
                    [
                        self.axon_index,
                        field_on_probed_points_collapsed,
                        self.inside_csf,
                        self.inside_encap,
                    ],
                    axis=1,
                ),
                columns=columns,
            )
        else:
            df_export = pd.DataFrame(
                np.concatenate(
                    [
                        self.axon_index,
                        self.lattice,
                        ex_col,
                        ey_col,
                        ez_col,
                        field_mags.reshape(field_mags.shape[0], 1),
                        self.inside_csf,
                        self.inside_encap,
                    ],
                    axis=1,
                ),
                columns=columns,
            )
            # save frequency
            df_export["frequency"] = frequency

        df_export.to_csv(
            os.path.join(self.output_path, f"E_field_{self.name}.csv"),
            index=False,
        )
        return