Source code for ossdbs.main

# Copyright 2023, 2024 Konstantin Butenko, Jan Philipp Payonk
# Copyright 2023, 2024 Johannes Reding, Julius Zimmermann
# SPDX-License-Identifier: GPL-3.0-or-later

import argparse
import json
import logging
import multiprocessing
import os
import pprint
import time

import ngsolve

from ossdbs import log_to_file, set_logger
from ossdbs.api import (
    build_brain_model,
    generate_electrodes,
    generate_signal,
    load_images,
    prepare_dielectric_properties,
    prepare_solver,
    prepare_stimulation_signal,
    prepare_volume_conductor_model,
    run_stim_sets,
    run_volume_conductor_model,
    set_contact_and_encapsulation_layer_properties,
    validate_solver_settings,
)
from ossdbs.fem import ConductivityCF
from ossdbs.model_geometry import ModelGeometry
from ossdbs.utils.settings import Settings
from ossdbs.utils.type_check import TypeChecker

_logger = logging.getLogger(__name__)


[docs]def main_run(input_settings: dict): """Run OSS-DBS from input dictionary. Parameters ---------- input_settings: dict Input dictionary run_path: str Path where to run OSS-DBS """ timings = {} time_0 = time.time() settings = Settings(input_settings).complete_settings() TypeChecker.check(settings) _logger.debug(f"Final settings:\\ {settings}") # create output path if not os.path.isdir(settings["OutputPath"]): os.mkdir(settings["OutputPath"]) log_to_file( output_file=os.path.join(settings["OutputPath"], "ossdbs.log"), level=_logger.getEffectiveLevel(), ) # create fail flag open( os.path.join( settings["StimulationFolder"], "fail_" + settings["FailFlag"] + ".txt" ), "w", ).close() time_1 = time.time() timings["Settings"] = time_1 - time_0 time_0 = time_1 mri_image, dti_image = load_images(settings) time_1 = time.time() timings["MRI"] = time_1 - time_0 time_0 = time_1 electrodes = generate_electrodes(settings) time_1 = time.time() timings["Electrodes"] = time_1 - time_0 time_0 = time_1 _logger.info("Generate full model geometry") brain_model = build_brain_model(settings, mri_image) try: geometry = ModelGeometry(brain_model, electrodes) except RuntimeError: _logger.warning( "Initial geometry failed, now building with rotated geometry." "If this fails, too, change the shape of the brain geometry." ) brain_model = build_brain_model(settings, mri_image, rotate_initial_geo=True) geometry = ModelGeometry(brain_model, electrodes) time_1 = time.time() timings["ModelGeometry"] = time_1 - time_0 time_0 = time_1 set_contact_and_encapsulation_layer_properties(settings, geometry) time_1 = time.time() timings["ContactProperties"] = time_1 - time_0 time_0 = time_1 # Validate solver settings for FloatingImpedance + EQS mode validate_solver_settings(settings, geometry) dielectric_properties = prepare_dielectric_properties(settings) time_1 = time.time() timings["DielectricModel"] = time_1 - time_0 time_0 = time_1 _logger.info("Prepare conductivity coefficient function") materials = settings["MaterialDistribution"]["MRIMapping"] conductivity = ConductivityCF( mri_image, brain_model.brain_region, dielectric_properties, materials, geometry.encapsulation_layers, complex_data=settings["EQSMode"], dti_image=dti_image, wm_masking=settings["MaterialDistribution"]["WMMasking"], ) time_1 = time.time() timings["ConductivityCF"] = time_1 - time_0 time_0 = time_1 # decide on truncation truncation_time = None if "TruncateAfterActivePartRatio" in settings: truncation_ratio = settings["TruncateAfterActivePartRatio"] if truncation_ratio is not None: if not isinstance(truncation_ratio, float): raise ValueError( "Please provide the ratio to truncate the signal " "as a floating-point number. " "Set e.g. to 20 for 20 times pulse + counterpulse width." ) if truncation_ratio < 1.0: raise ValueError( "The truncation ratio is a multiple of the " "active signal part." "Values smaller than 1.0 are not permitted." ) time_domain_signal = generate_signal(settings) truncation_time = truncation_ratio * time_domain_signal.get_active_time() # save Mesh for StimSets if settings["StimSets"]["Active"]: settings["Mesh"]["SavePath"] = os.path.join(settings["OutputPath"], "tmp_mesh") settings["Mesh"]["LoadPath"] = os.path.join( settings["OutputPath"], "tmp_mesh.vol.gz" ) settings["Mesh"]["SaveMesh"] = False settings["Mesh"]["LoadMesh"] = False # because of floating settings["Solver"]["Preconditioner"] = "local" settings["Solver"]["PreconditionerKwargs"] = {} # run in parallel with ngsolve.TaskManager(): solver = prepare_solver(settings) volume_conductor = prepare_volume_conductor_model( settings, geometry, conductivity, solver ) frequency_domain_signal = prepare_stimulation_signal(settings) if not settings["StimSets"]["Active"]: volume_conductor.prepare_mesh_refinements( settings["Mesh"]["MaterialRefinementSteps"] ) vcm_timings = run_volume_conductor_model( settings, volume_conductor, frequency_domain_signal, truncation_time=truncation_time, ) _logger.info(f"Volume conductor timings:\n{pprint.pformat(vcm_timings)}") else: # Apply h-refinement (material bisection) and save the # h-refined mesh. HP refinement is deferred: it will be # applied on each per-contact VCM after loading the mesh. volume_conductor.apply_h_refinements( settings["Mesh"]["MaterialRefinementSteps"] ) volume_conductor.mesh.save(settings["Mesh"]["SavePath"]) settings["Mesh"]["LoadMesh"] = True run_stim_sets( settings, geometry, conductivity, solver, frequency_domain_signal ) time_1 = time.time() timings["VolumeConductor"] = time_1 - time_0 time_0 = time_1 # run PAM if settings["PathwayFile"] is not None: _logger.info("Please compute the pathway activation separately.") # commented because of interaction with Lead-DBS """ if settings["StimSets"]["Active"]: _logger.info( "No PAM run because you specified StimSets." "Compute the pathway activation separately." ) elif settings["CalcAxonActivation"] is False: _logger.info("Axon activation is not computed.") else: run_PAM(settings) time_1 = time.time() timings["PAM"] = time_1 - time_0 """ _logger.info(f"Timings:\n {pprint.pformat(timings)}") # write success file open( os.path.join( settings["StimulationFolder"], "success_" + settings["FailFlag"] + ".txt" ), "w", ).close() os.remove( os.path.join( settings["StimulationFolder"], "fail_" + settings["FailFlag"] + ".txt" ) ) _logger.info("Process Completed")
[docs]def main() -> None: """Main function to run OSS-DBS in CLI mode.""" parser = argparse.ArgumentParser( prog="OSS-DBS", description="Welcome to OSS-DBS v2.", epilog="Please report bugs and errors on GitHub", ) parser.add_argument( "--loglevel", type=int, help="specify verbosity of logger", default=logging.INFO ) parser.add_argument( "input_dictionary", type=str, help="input dictionary in JSON format" ) args = parser.parse_args() set_logger(level=args.loglevel) _logger.info("Loading settings from input file") _logger.debug(f"Input file: {args.input_dictionary}") with open(args.input_dictionary) as json_file: input_settings = json.load(json_file) # add the stimulation folder (where input dict.json is stored, needed for Lead-DBS) input_settings["StimulationFolder"] = os.path.dirname( os.path.abspath(args.input_dictionary) ) try: main_run(input_settings) finally: for handler in logging.getLogger("ossdbs").handlers: try: handler.flush() except Exception: pass logging.shutdown()
if __name__ == "__main__": multiprocessing.freeze_support() main()