Skip to content

Motoneuron

This example demonstrates multi-objective optimization of a biophysical motoneuron model using the joint surrogate model. The motoneuron model is derived from electrophysiological recordings of mouse embryonic stem cell-derived motoneurons (Bhatt et al., J. Neurosci 2004).

The optimization targets four electrophysiological objectives (input resistance, membrane time constant, frequency-current (f-I) relationship, and spike amplitude) subject to eight feasibility constraints. ISI adaptation currently enters only through a feasibility constraint; its objective term is disabled in the code. The example uses the JointFTTransformer custom-training surrogate, which trains a single multi-task model over all objectives and constraints simultaneously.

Requirements

  • NEURON simulator
  • Keras 3 with a backend of your choice (see joint model docs)
  • click matplotlib mpi4py numpy pyyaml scipy
  • Compile the NMODL mechanisms: cd examples/motoneuron/mechanisms && nrnivmodl

Running the example

bash
export KERAS_BACKEND=torch
mpirun -n 8 python example_dmosopt_motoneuron.py

Source

py
import os
import sys
import logging

import click
import numpy as np
import yaml
from functools import partial
from numpy.random import default_rng
from neuron import h
from neuron import load_mechanisms
from mpi4py import MPI
from scipy import optimize, signal

from dmosopt import dmosopt

# Module-wide logging configuration (INFO level) and this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Load the compiled NMODL mechanisms from the "mechanisms" directory next to
# this script (built with nrnivmodl, see Requirements), then NEURON's standard
# run system, which provides the h.init()/h.run()/h.tstop machinery used below.
load_mechanisms(os.path.join(os.path.dirname(__file__), "mechanisms"))
h.load_file("stdrun.hoc")

pc = h.ParallelContext()
# Only call mpiabort_on_error when this NEURON build exposes it. Passing 0
# presumably stops a NEURON error on one rank from aborting the whole MPI job,
# so failures can be handled in Python; confirm against the NEURON
# ParallelContext documentation.
if hasattr(pc, "mpiabort_on_error"):
    pc.mpiabort_on_error(0)


class BRK:
    """Two-compartment (soma + dendrite) Booth-Rinzel-Kiehn motoneuron model.

    When ``params`` is given, parameters are read from
    ``params["BoothRinzelKiehn"]``; otherwise built-in defaults are used.
    Construction creates the two sections, connects them, sets geometry and
    discretisation, and inserts and parameterises the membrane mechanisms.
    """

    def __init__(self, params=None):
        if params is not None:
            params = params["BoothRinzelKiehn"]

        # Create sections
        self.soma = h.Section(name="soma", cell=self)
        self.dend = h.Section(name="dend", cell=self)

        # hack: use default attribute to trigger detection
        # NOTE(review): ``hillock`` just aliases the dendrite; presumably some
        # external tooling looks for a ``hillock`` attribute. Confirm.
        self.hillock = self.dend

        # Initialize position coordinates
        self.x = self.y = self.z = 0

        # Create section lists
        # (``self.sections`` is replaced by a plain Python list at the end.)
        self.sections = h.SectionList()
        self.all = h.SectionList()

        if params is not None:
            self.set_parameters(params)
        else:
            self.set_default_parameters()

        # Geometry (and therefore nseg via lambda_f, which reads Ra and cm) is
        # set BEFORE biophys() assigns Ra/cm, so nseg is computed from the
        # sections' Ra/cm at that point. NOTE(review): confirm this is intended.
        self.init_topology()
        self.geometry()
        self.biophys()

        # Add sections to lists
        for sec in [self.soma, self.dend]:
            self.all.append(sec)
        self.sections = list(self.all)

    def set_default_parameters(self):
        """Populate all model parameters with built-in default values."""
        self.pp = 0.5  # proportion of area taken up by soma
        self.Ltotal = 400 / np.pi  # total length of compartments
        self.gc = 10.5  # mS/cm2; Ra in ohm-cm

        self.global_e_pas = -60
        self.soma_g_pas = 0.0001
        self.soma_gmax_Na = 0.00030
        self.soma_gmax_K = 0.00010
        self.soma_gmax_KCa = 0.0005
        self.soma_gmax_CaN = 0.00010

        self.soma_f_Caconc = 0.004
        self.soma_alpha_Caconc = 1
        self.soma_kCa_Caconc = 8

        self.dend_g_pas = 0.0001
        self.dend_gmax_CaN = 0.00010
        self.dend_gmax_CaL = 0.00010
        self.dend_gmax_KCa = 0.00015

        self.dend_f_Caconc = 0.004
        self.dend_alpha_Caconc = 1
        self.dend_kCa_Caconc = 8

        self.global_cm = 3
        self.global_diam = 10  # Default value
        self.cm_ratio = 1

    def set_parameters(self, params):
        """Copy model parameters from the ``params`` mapping.

        Missing keys become None (via ``dict.get``). The exceptions are
        ``cm_ratio`` (default 1.0) and ``e_pas`` (default -60), which is
        stored as ``global_e_pas``.
        """
        self.pp = params.get("pp")
        self.Ltotal = params.get("Ltotal")
        self.gc = params.get("gc")

        self.global_diam = params.get("global_diam")
        self.global_cm = params.get("global_cm")
        self.cm_ratio = params.get("cm_ratio", 1.0)

        self.global_e_pas = params.get("e_pas", -60)
        self.soma_g_pas = params.get("soma_g_pas")
        self.soma_gmax_Na = params.get("soma_gmax_Na")
        self.soma_gmax_K = params.get("soma_gmax_K")
        self.soma_gmax_KCa = params.get("soma_gmax_KCa")
        self.soma_gmax_CaN = params.get("soma_gmax_CaN")

        self.soma_f_Caconc = params.get("soma_f_Caconc")
        self.soma_alpha_Caconc = params.get("soma_alpha_Caconc")
        self.soma_kCa_Caconc = params.get("soma_kCa_Caconc")

        self.dend_g_pas = params.get("dend_g_pas")
        self.dend_gmax_CaN = params.get("dend_gmax_CaN")
        self.dend_gmax_CaL = params.get("dend_gmax_CaL")
        self.dend_gmax_KCa = params.get("dend_gmax_KCa")

        self.dend_f_Caconc = params.get("dend_f_Caconc")
        self.dend_alpha_Caconc = params.get("dend_alpha_Caconc")
        self.dend_kCa_Caconc = params.get("dend_kCa_Caconc")

    def lambda_f(self, section, freq):
        """Return the AC length constant of ``section`` at frequency ``freq``.

        This is a Python port of NEURON's d_lambda-rule helper. Without 3-D
        points it uses the uniform-cylinder formula from ``diam``, ``Ra`` and
        ``cm``. Otherwise it integrates along the 3-D points and returns
        ``L`` divided by the section's electrotonic length.
        """
        if section.n3d() < 2:
            return 1e5 * np.sqrt(
                section.diam / (4 * np.pi * freq * section.Ra * section.cm)
            )

        x1 = section.arc3d(0)
        d1 = section.diam3d(0)
        lam = 0

        for i in range(1, section.n3d()):
            x2 = section.arc3d(i)
            d2 = section.diam3d(i)
            lam += (x2 - x1) / np.sqrt(d1 + d2)
            x1, d1 = x2, d2

        lam *= np.sqrt(2) * 1e-5 * np.sqrt(4 * np.pi * freq * section.Ra * section.cm)
        return section.L / lam

    def init_topology(self):
        """Attach the dendrite's 0 end to the soma's 1 end."""
        self.dend.connect(self.soma(1), 0)

    def geometry(self):
        """Set section lengths, diameters and segment counts (in that order)."""
        self.init_dx()
        self.init_diam()
        self.init_nseg()

    def init_dx(self):
        """Split ``Ltotal`` between soma (fraction ``pp``) and dendrite."""
        self.soma.L = self.pp * self.Ltotal
        self.dend.L = (1 - self.pp) * self.Ltotal

    def init_diam(self):
        """Give both sections the same diameter, ``global_diam``."""
        self.soma.diam = self.global_diam
        self.dend.diam = self.global_diam

    def init_nseg(self, freq=100, d_lambda=0.1):
        """Set an odd ``nseg`` per section from the d_lambda rule at ``freq``."""
        for sec in [self.soma, self.dend]:
            # Round L / (d_lambda * lambda) up to the next odd integer.
            nseg = (
                int((sec.L / (d_lambda * self.lambda_f(sec, freq)) + 0.9) / 2) * 2 + 1
            )
            sec.nseg = nseg

    def init_ic(self, v_init):
        """Initialise to ``v_init`` and set the soma's ``ic_constant`` to cancel
        the summed Na, K, Ca and passive currents at soma(0.5)."""
        h.finitialize(v_init)
        seg = self.soma(0.5)
        self.soma.ic_constant = -(seg.ina + seg.ik + seg.ica + seg.i_pas)

    def biophys(self):
        """Set axial resistance and capacitance, insert mechanisms and assign
        their parameters for soma and dendrite."""
        # Set global parameters
        for sec in [self.soma, self.dend]:
            # Ra is first set to 1 so that h.ri(0.5) gives the axial
            # resistance per unit Ra. Ra is then scaled so that ri(0.5) equals
            # half the coupling resistance implied by gc. gc is in mS/cm2
            # (the 1e-3 factor gives S/cm2) and area is in um2 (the 1e-8
            # factor gives cm2). NOTE(review): derived from the arithmetic
            # below; confirm against the model's reference formulation.
            sec.Ra = 1
            sec.Ra = (
                1e-6 / (self.gc / self.pp * (h.area(0.5, sec=sec) * 1e-8) * 1e-3)
            ) / (2 * h.ri(0.5, sec=sec))
            sec.cm = self.global_cm

        # Soma-specific parameters
        self.soma.cm = self.global_cm * self.cm_ratio

        self.soma.insert("pas")
        self.soma.insert("constant")  # provides the ic_constant parameter
        self.soma.insert("Na_conc")
        self.soma.insert("K_conc")
        self.soma.insert("Ca_conc")
        self.soma.insert("Kdr")
        self.soma.insert("Nas")
        self.soma.insert("CaN")
        self.soma.insert("KCa")
        self.soma.insert("extracellular")  # For stimulation

        self.soma.gmax_Nas = self.soma_gmax_Na
        self.soma.gmax_Kdr = self.soma_gmax_K
        self.soma.gmax_CaN = self.soma_gmax_CaN
        self.soma.gmax_KCa = self.soma_gmax_KCa

        self.soma.f_Ca_conc = self.soma_f_Caconc
        self.soma.alpha_Ca_conc = self.soma_alpha_Caconc
        self.soma.kCa_Ca_conc = self.soma_kCa_Caconc

        self.soma.g_pas = self.soma_g_pas
        self.soma.e_pas = self.global_e_pas

        # Dendrite-specific parameters
        self.dend.insert("pas")
        self.dend.insert("CaN")
        self.dend.insert("CaL")
        self.dend.insert("KCa")
        self.dend.insert("Ca_conc")
        self.dend.insert("K_conc")
        self.dend.insert("extracellular")  # For stimulation

        self.dend.f_Ca_conc = self.dend_f_Caconc
        self.dend.alpha_Ca_conc = self.dend_alpha_Caconc
        self.dend.kCa_Ca_conc = self.dend_kCa_Caconc

        self.dend.g_pas = self.dend_g_pas
        self.dend.e_pas = self.global_e_pas

        self.dend.gmax_CaN = self.dend_gmax_CaN
        self.dend.gmax_CaL = self.dend_gmax_CaL
        self.dend.gmax_KCa = self.dend_gmax_KCa

    def position(self, x, y, z):
        """Translate any 3-D points so the cell sits at (x, y, z).

        Sections defined only through L/diam have no 3-D points. In that case
        only the stored ``x``/``y``/``z`` attributes change.
        """
        for sec in [self.soma, self.dend]:
            for i in range(sec.n3d()):
                h.pt3dchange(
                    i,
                    x - self.x + sec.x3d(i),
                    y - self.y + sec.y3d(i),
                    z - self.z + sec.z3d(i),
                    sec.diam3d(i),
                    sec=sec,
                )
        self.x, self.y, self.z = x, y, z

    def is_art(self):
        """Return False: this is not an artificial (point-process) cell."""
        return False

    def is_reduced(self):
        """Return True: this is a reduced (two-compartment) model."""
        return True

    def __repr__(self):
        return "BRK"


def ic_constant_f(
    x,
    template_class,
    param_dict,
    ic_constant,
    v_hold=-60,
    tstop=1000.0,
    dt=0.01,
    record_dt=0.01,
    celsius=36.0,
    use_cvode=False,
    use_coreneuron=True,
):
    """Return the mean somatic voltage offset from ``v_hold`` for a trial current.

    A fresh cell ``template_class({"BoothRinzelKiehn": param_dict})`` is built,
    its somatic ``ic_constant`` is set to ``ic_constant + round(x, 6)``, and
    the model is run for ``tstop``. ``init_cell`` uses this as the root
    function for ``scipy.optimize.brentq``.

    Args:
        x: Offset added to ``ic_constant`` (the quantity being solved for).
        template_class: Cell class to instantiate.
        param_dict: Model parameters, passed under the "BoothRinzelKiehn" key.
        ic_constant: Baseline ``ic_constant`` value.
        v_hold: Initial voltage, and the voltage the mean is compared against.
        tstop: Simulation end time.
        dt: Fixed integration step.
        record_dt: Voltage sampling interval; raised to ``dt`` if smaller.
        celsius: Simulation temperature.
        use_cvode: Use adaptive CVode integration when true, fixed step otherwise.
        use_coreneuron: Enable CoreNEURON when true.

    Returns:
        ``mean(v) - v_hold`` if the somatic trace stays below 0 mV. Otherwise
        ``0 - v_hold`` (the trace reached 0 mV, e.g. it spiked).
    """
    h.cvode.use_fast_imem(1)
    h.cvode.cache_efficient(1)
    h.secondorder = 2
    h.dt = dt
    record_dt = max(record_dt, dt)
    # CVode activity is global NEURON state. Set it explicitly in both
    # directions so an earlier adaptive run cannot leak into a fixed-step one.
    h.cvode.active(1 if use_cvode else 0)
    if use_coreneuron:
        from neuron import coreneuron

        coreneuron.enable = True
    h.celsius = celsius

    cell = template_class({"BoothRinzelKiehn": param_dict})

    vec_v = h.Vector()
    vec_v.record(cell.soma(0.5)._ref_v, record_dt)

    h.tstop = tstop
    h.v_init = v_hold
    h.init()
    cell.soma.ic_constant = ic_constant + round(x, 6)
    # Re-initialise after changing ic_constant (done twice, as in init_cell).
    h.finitialize(h.v_init)
    h.finitialize(h.v_init)

    h.run()

    v = vec_v.as_numpy()
    mean_v = np.mean(v) if np.max(v) < 0.0 else 0.0
    return mean_v - v_hold


def run_iclamp(
    cell,
    amp,
    t0,
    t1,
    v_init=-65,
    tstop=1000.0,
    dt=0.01,
    record_dt=0.01,
    celsius=36.0,
    use_cvode=False,
    use_coreneuron=True,
):
    """Apply a single current step at soma(0.5) and record time and voltage.

    Args:
        cell: Cell with a ``soma`` section.
        amp: Step amplitude (IClamp ``amp``).
        t0, t1: Step onset and offset; the step lasts ``t1 - t0``.
        v_init: Initial membrane voltage.
        tstop: Simulation end; raised to at least ``t1 + 1.0``.
        dt: Fixed integration step.
        record_dt: Sampling interval of the recordings; raised to ``dt`` if smaller.
        celsius: Simulation temperature.
        use_cvode: Use adaptive CVode integration when true, fixed step otherwise.
        use_coreneuron: Enable CoreNEURON when true.

    Returns:
        ``(t, v)`` numpy arrays copied from the recorded vectors.
    """
    h.cvode.use_fast_imem(1)
    h.cvode.cache_efficient(1)
    h.secondorder = 2
    h.dt = dt
    record_dt = max(record_dt, dt)
    # CVode activity is global NEURON state. Set it explicitly in both
    # directions so an earlier adaptive run cannot leak into a fixed-step one.
    h.cvode.active(1 if use_cvode else 0)
    if use_coreneuron:
        from neuron import coreneuron

        coreneuron.enable = True
    h.celsius = celsius

    vec_t = h.Vector()
    vec_v = h.Vector()
    vec_t.record(h._ref_t, record_dt)
    vec_v.record(cell.soma(0.5)._ref_v, record_dt)

    # ``stim`` must stay referenced until h.run() finishes.
    stim = h.IClamp(0.5, sec=cell.soma)
    stim.delay = t0
    stim.dur = t1 - t0
    stim.amp = amp

    # Always simulate slightly past the end of the step.
    tstop = max(tstop, t1 + 1.0)

    h.tstop = tstop
    h.v_init = v_init
    h.init()
    h.finitialize(h.v_init)
    h.run()

    return np.array(vec_t), np.array(vec_v)


def run_iclamp_steps(cell, Isteps, **kwargs):
    """Run one current-clamp sweep per ``(amp, t0, t1)`` row of ``Isteps``.

    Extra keyword arguments are forwarded to ``run_iclamp``.

    Returns:
        A list aligned with ``Isteps``: ``{"t": t, "v": v}`` for each sweep
        that completed, or ``None`` for a sweep whose simulation raised.
    """
    results = []
    for amp, t0, t1 in Isteps:
        try:
            t, v = run_iclamp(cell, amp, t0, t1, **kwargs)
        except Exception:
            # Best effort: keep a None placeholder so sweeps stay index-aligned,
            # but record the failure instead of dropping it silently.
            logging.getLogger(__name__).warning(
                "I-clamp sweep failed (amp=%s, t0=%s, t1=%s)",
                amp,
                t0,
                t1,
                exc_info=True,
            )
            results.append(None)
        else:
            results.append({"t": t, "v": v})
    return results


# Structured record for one f-I sweep: its firing frequency.
fi_value_dtype = np.dtype([("frequency", float)])

# Structured record summarising the inter-spike intervals of one sweep:
# five floating-point statistics followed by the integer spike count "N".
isi_value_dtype = np.dtype(
    [(stat_name, float) for stat_name in ("first", "last", "ratio", "mean", "std")]
    + [("N", int)]
)


def measure_deflection(t, v, t0, t1, stim_amp=None):
    """Locate the peak voltage deflection inside a stimulus window.

    The window starts at the first sample with ``t >= 0.999 * t0`` and ends,
    exclusively, at the first sample with ``t >= 0.999 * t1``. A positive
    ``stim_amp`` selects the maximum voltage; a zero, negative or ``None``
    amplitude selects the minimum.

    Returns:
        dict with the peak time/voltage/index, the baseline (window start)
        time/voltage/index, and ``stim_amp`` as given.
    """
    lo = int(np.argwhere(t >= t0 * 0.999)[0, 0])
    hi = int(np.argwhere(t >= t1 * 0.999)[0, 0])
    depolarizing = stim_amp is not None and stim_amp > 0
    window = v[lo:hi]
    offset = np.argmax(window) if depolarizing else np.argmin(window)
    peak = offset + lo
    return dict(
        t_peak=t[peak],
        v_peak=v[peak],
        peak_index=peak,
        t_baseline=t[lo],
        v_baseline=v[lo],
        baseline_index=lo,
        stim_amp=stim_amp,
    )


def fit_membrane_time_constant(t, v, t0, t1, rmse_max_tol=1.0):
    """Fit ``y0 + a * exp(-inv_tau * (t - t_start))`` to ``v`` over a window.

    The window runs from the first sample with ``t >= 0.999 * t0`` up to,
    but excluding, the first sample with ``t >= 0.999 * t1``.

    Args:
        t, v: Time and voltage arrays of equal length.
        t0, t1: Window bounds.
        rmse_max_tol: Fits whose RMSE exceeds this are rejected.

    Returns:
        The fitted ``(a, inv_tau, y0)`` as returned by ``curve_fit``, or
        ``(nan, nan, nan)`` when the fit cannot be performed or does not
        converge, or when its RMSE exceeds ``rmse_max_tol``. The fit cannot be
        performed with too few samples or non-finite data.
    """

    def exp_curve(x, a, inv_tau, y0):
        return y0 + a * np.exp(-inv_tau * x)

    start_index = int(np.argwhere(t >= t0 * 0.999)[0, 0])
    end_index = int(np.argwhere(t >= t1 * 0.999)[0, 0])
    # Initial guess: amplitude = start-to-end drop, inv_tau = 0.1,
    # offset = voltage at the end of the window.
    p0 = (v[start_index] - v[end_index], 0.1, v[end_index])
    t_window = (t[start_index:end_index] - t[start_index]).astype(np.float64)
    v_window = v[start_index:end_index].astype(np.float64)
    try:
        popt, _ = optimize.curve_fit(exp_curve, t_window, v_window, p0=p0)
    except (TypeError, ValueError, RuntimeError):
        # TypeError: fewer samples than parameters; ValueError: NaN/inf in the
        # data (e.g. a diverged simulation); RuntimeError: no convergence.
        return np.nan, np.nan, np.nan

    pred = exp_curve(t_window, *popt)
    rmse = np.sqrt(np.mean((pred - v_window) ** 2))
    if rmse > rmse_max_tol:
        return np.nan, np.nan, np.nan
    return popt


def measure_time_constant(
    t, v, t0, t1, stim_amp, frac=0.1, baseline_interval=100.0, min_snr=20.0
):
    """Estimate the membrane time constant from the response to a current step.

    The fit runs from the first sample at or after the stimulus onset whose
    voltage is at or below ``baseline + frac * (peak - baseline)``, up to the
    peak of the deflection.

    Returns:
        ``1 / inv_tau`` from the exponential fit. Returns NaN when the trace
        ends before ``t0`` or ``t1``, when the deflection's signal-to-noise
        ratio is below ``min_snr``, or when no valid fit window exists. Noise
        is the voltage SD over ``baseline_interval`` before ``t0``.
    """
    t_max = np.max(t)
    if t_max < t0 or t_max < t1:
        return np.nan

    deflection = measure_deflection(t, v, t0, t1, stim_amp)
    v_base = deflection["v_baseline"]
    v_peak = deflection["v_peak"]
    base_idx = deflection["baseline_index"]
    peak_idx = deflection["peak_index"]

    noise_start = int(np.argwhere(t >= (t0 - baseline_interval) * 0.999)[0, 0])
    noise = np.std(v[noise_start:base_idx])
    # Zero noise means infinite SNR, which always passes.
    if noise != 0 and np.abs(v_base - v_peak) / noise < min_snr:
        return np.nan

    level = frac * (v_peak - v_base) + v_base
    crossing = np.flatnonzero(v[base_idx:] <= level)
    if crossing.size == 0:
        return np.nan

    fit_t0 = t[crossing[0] + base_idx]
    fit_t1 = t[peak_idx]
    if fit_t0 >= fit_t1:
        return np.nan

    _, inv_tau, _ = fit_membrane_time_constant(t, v, fit_t0, fit_t1)
    return 1.0 / inv_tau


def measure_passive(t, v, t0, t1, stim_amp):
    """Measure input resistance and membrane time constant from a step response.

    ``Rinp`` is the peak voltage deflection divided by ``stim_amp``. With NEURON's
    mV and nA units this is presumably in megohms. ``tau`` comes from
    ``measure_time_constant``.

    Returns:
        ``{"Rinp": ..., "tau": ...}``; both are NaN when the trace ends
        before ``t0`` or ``t1``.
    """
    last_t = np.max(t)
    if t0 > last_t or t1 > last_t:
        return {"Rinp": np.nan, "tau": np.nan}
    peak = measure_deflection(t, v, t0, t1, stim_amp=stim_amp)
    delta_v = peak["v_peak"] - peak["v_baseline"]
    return {
        "Rinp": delta_v / stim_amp,
        "tau": measure_time_constant(t, v, t0, t1, stim_amp),
    }


def detect_spikes(T, Y, t0, t1, before_peak=50.0):
    """Detect spikes in a voltage trace and summarise each one.

    Steps:
      1. Count peaks above -20 (in Y's units) at times ``T < t0 - before_peak``.
      2. In ``[t0 - before_peak, t1]``, find the first such peak. The spike
         threshold is the voltage at the first sample in the ``before_peak``
         window preceding it where dY/dT reaches ``mean + 2 * SD``. The mean
         is taken over the whole period; the SD over that preceding window.
      3. In ``[t0, t1]``, split the trace at downward threshold crossings and
         record one spike per interval of at least two samples.

    Returns:
        ``(N_peaks_pre, N_peaks, threshold, spk_info)``. ``spk_info`` is a
        structured array with fields Vpeak, Tpeak, amplitude (peak minus
        threshold), T0 and T1 (interval bounds), or ``None`` if no spike is
        recorded. ``threshold`` is ``None`` when it cannot be determined.
    """
    spk_info_dtype = np.dtype(
        [
            ("Vpeak", float),
            ("Tpeak", float),
            ("amplitude", float),
            ("T0", float),
            ("T1", float),
        ]
    )

    dt = np.mean(np.diff(T))
    # Maximum accepted peak width, in samples.
    max_width = int(before_peak / dt)

    pre_period_idxs = np.argwhere(T < t0 - before_peak).flat
    pre_peak_info = signal.find_peaks(
        Y[pre_period_idxs], height=-20.0, width=(None, max_width)
    )
    N_peaks_pre = len(pre_peak_info[0])

    spk_period_idxs = np.argwhere(np.logical_and(T >= t0 - before_peak, T <= t1)).flat
    T_spk = T[spk_period_idxs]
    Y_spk = Y[spk_period_idxs]
    peak_info = signal.find_peaks(Y_spk, height=-20.0, width=(None, max_width))
    peak_idxs = peak_info[0]
    if len(peak_idxs) == 0:
        return N_peaks_pre, 0, None, None

    dydt = np.gradient(Y_spk, T_spk)
    mean_dydt = np.mean(dydt)
    peak_idx = peak_idxs[0]
    T_peak = T_spk[peak_idx]
    T_before_idxs = np.argwhere(
        np.isclose(T_spk, T_peak - before_peak, rtol=1e-4, atol=1e-4)
    )
    if len(T_before_idxs) == 0:
        return N_peaks_pre, 0, None, None

    T_before_idx = T_before_idxs[0][0]
    sd_dydt = np.std(dydt[T_before_idx:peak_idx])
    dydt_threshold = mean_dydt + 2 * sd_dydt
    rising = np.argwhere(dydt[T_before_idx:peak_idx] >= dydt_threshold)
    if len(rising) == 0:
        return N_peaks_pre, 0, None, None
    threshold = Y_spk[T_before_idx + rising[0, 0]]

    period_idxs = np.argwhere(np.logical_and(T >= t0, T <= t1)).flat
    T = T[period_idxs]
    Y = Y[period_idxs]

    # On a boolean array np.diff marks every threshold crossing. With
    # prepend=False the crossings alternate up, down, up, down, ...
    threshold_crossings = np.diff(Y > threshold, prepend=False)
    crossing_idx = np.argwhere(threshold_crossings)[:, 0]
    up_crossing_idx = crossing_idx[::2]

    N_peaks = 0
    spk_info = None
    if len(up_crossing_idx) > 0:
        # Split at the downward crossings. Each piece except the trailing one
        # (an unfinished excursion) contains exactly one excursion above
        # threshold.
        down_crossing_idx = crossing_idx[1::2]
        Y_intervals = np.split(Y, down_crossing_idx)[:-1]
        T_intervals = np.split(T, down_crossing_idx)[:-1]
        records = []
        for T_interval, Y_interval in zip(T_intervals, Y_intervals):
            if len(T_interval) < 2:
                # Too short to hold a spike. T0/T1 below come from the same
                # interval as the peak, so skipping keeps the fields aligned.
                continue
            pidx = np.argmax(Y_interval)
            records.append(
                (
                    Y_interval[pidx],
                    T_interval[pidx],
                    Y_interval[pidx] - threshold,
                    T_interval[0],
                    T_interval[-1],
                )
            )
        N_peaks = len(records)
        if N_peaks > 0:
            spk_info = np.array(records, dtype=spk_info_dtype)

    return N_peaks_pre, N_peaks, threshold, spk_info


def measure_spike_features(iclamp_results, t0, t1):
    """Run ``detect_spikes`` on every sweep and collect per-sweep spike features.

    Args:
        iclamp_results: Sequence of ``{"t": ..., "v": ...}`` dicts, with
            ``None`` for failed sweeps (as produced by ``run_iclamp_steps``).
        t0, t1: Detection window passed to ``detect_spikes``.

    Returns:
        ``(pre_spk_cnt, spk_cnt, spk_infos, thresholds, mean_spike_amplitudes)``.
        The counts have shape (N, 1). Thresholds and mean amplitudes have shape
        (N,) and stay 0 where unavailable. ``spk_infos`` has exactly one entry
        per sweep (``None`` for failed or spikeless sweeps), so
        ``spk_infos[i]`` always belongs to sweep ``i``.
    """
    N_sweeps = len(iclamp_results)
    pre_spk_cnt = np.zeros(shape=(N_sweeps, 1), dtype=int)
    spk_cnt = np.zeros(shape=(N_sweeps, 1), dtype=int)
    spk_infos = []
    thresholds = np.zeros(shape=(N_sweeps,), dtype=np.float32)
    mean_spike_amplitudes = np.zeros(shape=(N_sweeps,), dtype=np.float32)
    for i, result in enumerate(iclamp_results):
        if result is None:
            # Keep spk_infos index-aligned with the sweeps: measure_ISI looks
            # entries up by current-step index.
            spk_infos.append(None)
            continue
        N_peaks_pre, N_peaks, threshold, spk_info = detect_spikes(
            result["t"], result["v"], t0, t1
        )
        spk_infos.append(spk_info)
        pre_spk_cnt[i, 0] = N_peaks_pre
        spk_cnt[i, 0] = N_peaks
        if threshold is not None:
            thresholds[i] = threshold
        if spk_info is not None:
            mean_spike_amplitudes[i] = np.mean(spk_info["amplitude"])
    return pre_spk_cnt, spk_cnt, spk_infos, thresholds, mean_spike_amplitudes


def measure_fI(spk_cnt, t0, t1, cur_steps):
    """Convert per-sweep spike counts into firing rates for an f-I curve.

    Args:
        spk_cnt: (N, 1) integer spike counts, one row per sweep.
        t0, t1: Stimulus window. The rate is ``count * 1000 / (t1 - t0)``,
            i.e. spikes per second when times are in ms.
        cur_steps: Injected current amplitudes; only their number is used.

    Returns:
        Structured array of ``fi_value_dtype`` with one entry per current
        step. Steps without a matching row in ``spk_cnt`` keep frequency 0.
    """
    n_steps = len(cur_steps)
    fI_array = np.zeros(shape=(n_steps,), dtype=fi_value_dtype)
    n_measured = min(n_steps, len(spk_cnt))
    if n_measured:
        fI_array["frequency"][:n_measured] = (
            spk_cnt[:n_measured, 0] * 1000 / (t1 - t0)
        )
    return fI_array


def measure_ISI(cur_steps_amp, spk_infos):
    """Summarise inter-spike intervals (ISIs) for each current step.

    Args:
        cur_steps_amp: Current-step amplitudes; only their number is used.
        spk_infos: Per-step spike records (with a ``Tpeak`` field) or ``None``.

    Returns:
        Structured array of ``isi_value_dtype``, one entry per step. ``N`` is
        the spike count. ``first`` needs at least 2 spikes; ``last``,
        ``ratio`` (last/first), ``mean`` and ``std`` need at least 3. Missing
        statistics are NaN.
    """
    n_steps = len(cur_steps_amp)
    ISI_array = np.zeros(shape=(n_steps,), dtype=isi_value_dtype)
    for stat_name in ("first", "last", "ratio", "mean", "std"):
        ISI_array[stat_name] = np.nan

    for i, spk_info in enumerate(spk_infos[:n_steps]):
        n_spikes = 0 if spk_info is None else spk_info.shape[0]
        ISI_array["N"][i] = n_spikes
        if n_spikes < 2:
            continue
        intervals = np.diff(spk_info["Tpeak"])
        ISI_array["first"][i] = intervals[0]
        if len(intervals) < 2:
            continue
        ISI_array["last"][i] = intervals[-1]
        ISI_array["ratio"][i] = intervals[-1] / intervals[0]
        ISI_array["mean"][i] = np.mean(intervals)
        ISI_array["std"][i] = np.std(intervals)
    return ISI_array


class ExperimentalProtocol:
    """Experimental targets and stimulation settings read from a config dict.

    ``params_dict`` needs a ``"Numerics"`` section. It also needs either a
    ``"Targets"`` section or, when ``target_namespace`` is given, a
    ``"Target namespaces"`` mapping that contains that namespace.
    """

    def __init__(self, params_dict, target_namespace=None):
        self.params_dict = params_dict
        self.target_namespace = target_namespace
        self._init_params(params_dict)

    def _target_config(self, params):
        """Return the active target section of ``params``."""
        if self.target_namespace is not None:
            return params["Target namespaces"][self.target_namespace]
        return params["Targets"]

    def _init_params(self, params):
        """Read numerics and target values from ``params`` into attributes."""
        num_config = params["Numerics"]
        self.t0 = num_config["t0"]
        self.tstop = num_config["tstop"]
        self.v_init = num_config["v_init"]
        self.use_cvode = num_config.get("adaptive", False)
        self.use_coreneuron = num_config.get("use_coreneuron", False)
        self.dt = num_config.get("dt", 0.01)
        self.record_dt = num_config.get("record_dt", self.dt)

        target_config = self._target_config(params)

        self.v_hold = target_config["V_hold"]["val"]
        self.v_rest = target_config["V_rest"]["val"]

        Rin_config = target_config["Rin"]
        self.target_rn = (Rin_config["lower"][0], Rin_config["upper"][0])
        self.rn_exp_type = "vclamp" if "V" in Rin_config else "iclamp"

        tau0_config = target_config["tau0"]
        self.target_tau = (tau0_config["lower"][0], tau0_config["upper"][0])

        f_I_config = target_config["f_I"]
        N_exp = len(f_I_config["I"])
        self.exp_i_inj_amp_f_I = np.asarray(f_I_config["I"]) * f_I_config.get(
            "I_factor", 1.0
        )
        self.exp_i_inj_t0_f_I = f_I_config["t"][0]
        self.exp_i_inj_t1_f_I = f_I_config["t"][1]
        self.exp_i_mean_rate_f_I = (
            np.asarray(f_I_config["mean"]) if "mean" in f_I_config else None
        )
        # Rate bounds fall back to the mean rate when not given explicitly.
        self.exp_i_ub_rate_f_I = np.asarray(
            f_I_config.get("upper", f_I_config.get("mean"))
        )
        self.exp_i_lb_rate_f_I = np.asarray(
            f_I_config.get("lower", f_I_config.get("mean"))
        )

        spike_amp_config = target_config["spike_amp"]
        self.exp_i_ub_spk_amp = (
            np.asarray(spike_amp_config["upper"])
            if "upper" in spike_amp_config
            else None
        )
        self.exp_i_lb_spk_amp = (
            np.asarray(spike_amp_config["lower"])
            if "lower" in spike_amp_config
            else None
        )

        spike_adaptation_config = target_config["spike_adaptation"]
        self.exp_i_mean_spk_adaptation = (
            np.asarray(spike_adaptation_config["mean"])
            if "mean" in spike_adaptation_config
            else None
        )
        # Adaptation bounds fall back to the mean, then to zeros.
        self.exp_i_ub_spk_adaptation = np.asarray(
            spike_adaptation_config.get(
                "upper", spike_adaptation_config.get("mean", np.zeros(N_exp))
            )
        )
        self.exp_i_lb_spk_adaptation = np.asarray(
            spike_adaptation_config.get(
                "lower", spike_adaptation_config.get("mean", np.zeros(N_exp))
            )
        )

        self.target_threshold = target_config["threshold"]

        # First-ISI bounds default to 0.5x / 1.5x the mean (zeros if absent).
        constraint_config = target_config.get("Constraints", {})
        first_ISI_constraints = constraint_config.get("first_ISI", {})
        self.exp_first_ISI_mean = np.asarray(
            first_ISI_constraints.get("mean", np.zeros((N_exp,)))
        )
        self.exp_first_ISI_lower = np.asarray(
            first_ISI_constraints.get("lower", 0.5 * self.exp_first_ISI_mean)
        )
        self.exp_first_ISI_upper = np.asarray(
            first_ISI_constraints.get("upper", 1.5 * self.exp_first_ISI_mean)
        )

    def run_iclamp(self, cell, target, tstop=10000.0):
        """Run the single current step configured for ``target`` (e.g. "Rin").

        The amplitude is ``I[0] * I_factor``. The step window is the target's
        ``t`` pair, or ``[tstop / 2, tstop / 2 + 1000]`` when absent. The
        configured numerics (dt, record_dt, adaptive, use_coreneuron) are
        applied, as in ``run_iclamp_steps``.

        Returns:
            dict with ``t``, ``v``, ``t0``, ``t1`` and ``stim_amp``, suitable
            for ``measure_passive(**result)``.
        """
        this_target_config = self._target_config(self.params_dict)[target]

        this_target_amp = this_target_config["I"][0] * this_target_config.get(
            "I_factor", 1.0
        )
        tstim = this_target_config.get("t", None)
        if tstim is None:
            t0 = tstop / 2.0
            t1 = t0 + 1000.0
        else:
            t0, t1 = tstim

        t, v = run_iclamp(
            cell,
            t0=t0,
            t1=t1,
            amp=this_target_amp,
            tstop=tstop,
            v_init=self.v_hold,
            dt=self.dt,
            record_dt=self.record_dt,
            use_cvode=self.use_cvode,
            use_coreneuron=self.use_coreneuron,
        )
        return {"t": t, "v": v, "t0": t0, "t1": t1, "stim_amp": this_target_amp}

    def run_iclamp_steps(self, cell):
        """Run the f-I current steps and return the per-sweep results list."""
        Isteps = np.asarray(
            [
                (amp, self.exp_i_inj_t0_f_I, self.exp_i_inj_t1_f_I)
                for amp in self.exp_i_inj_amp_f_I
            ]
        )
        return run_iclamp_steps(
            cell,
            Isteps=Isteps,
            v_init=self.v_hold,
            dt=self.dt,
            record_dt=self.record_dt,
            tstop=self.tstop,
            use_cvode=self.use_cvode,
            use_coreneuron=self.use_coreneuron,
        )


def range_distance(x, lb, ub):
    """Return how far ``x`` lies outside the closed interval ``[lb, ub]``.

    Inside the interval the distance is 0.0. Outside, it is the absolute
    distance to the nearer bound. A NaN ``x`` yields NaN.
    """
    if lb <= x <= ub:
        return 0.0
    return min(abs(x - lb), abs(x - ub))


def init_cell(pp, v_hold=-60, celsius=36.0, ic_constant_val=None):
    """Create a BRK cell and set its somatic ``ic_constant`` for holding at ``v_hold``.

    If ``ic_constant_val`` is None, the holding current is found in two steps:
      1. ``BRK.init_ic`` cancels the membrane currents at ``v_hold``.
      2. ``scipy.optimize.brentq`` on ``ic_constant_f`` refines it over
         offsets in [-0.5, 0.5].
    If brentq raises ValueError (e.g. no sign change over the bracket) or
    does not converge, the offset is 0. Otherwise ``ic_constant_val`` is
    applied as given.

    Args:
        pp: Model parameters, passed to BRK under "BoothRinzelKiehn".
        v_hold: Holding / initial voltage.
        celsius: Simulation temperature.
        ic_constant_val: Precomputed ``ic_constant``, or None to solve for it.

    Returns:
        The initialised BRK cell.
    """
    h.cvode.use_fast_imem(1)
    h.cvode.cache_efficient(1)
    h.secondorder = 2
    h.celsius = celsius

    cell = BRK({"BoothRinzelKiehn": pp})

    h.v_init = v_hold
    h.init()

    if ic_constant_val is None:
        cell.init_ic(h.v_init)
        ic_constant_0 = cell.soma.ic_constant

        x0 = 0.0
        try:
            # args map to ic_constant_f(x, template_class, param_dict,
            # ic_constant, v_hold); each evaluation builds and runs a new cell.
            x0, res = optimize.brentq(
                ic_constant_f,
                -0.5,
                0.5,
                args=(BRK, pp, ic_constant_0, h.v_init),
                xtol=1e-6,
                maxiter=200,
                disp=False,
                full_output=True,
            )
        except ValueError:
            x0 = 0.0
        else:
            if not res.converged:
                x0 = 0.0

        ic_constant_val = ic_constant_0 + x0

    cell.soma.ic_constant = ic_constant_val
    # Re-initialise with the final ic_constant (done twice, as in ic_constant_f).
    h.finitialize(h.v_init)
    h.finitialize(h.v_init)
    return cell


def make_obj_fun(protocol_config_dict, feature_dtypes, target_namespace, worker):
    """Build the objective function handed to dmosopt.

    Returns ``obj_fun`` with the experimental protocol and feature dtypes
    bound, so it only needs the parameter set. ``worker`` is unused here;
    it is presumably supplied by dmosopt's initializer call. Confirm.
    """
    protocol = ExperimentalProtocol(
        protocol_config_dict, target_namespace=target_namespace
    )
    bound_objective = partial(obj_fun, protocol, feature_dtypes)
    return bound_objective


def obj_fun(exp_protocol, feature_dtypes, pp):
    """Evaluate one parameter set against the experimental protocol.

    Args:
        exp_protocol: ExperimentalProtocol with targets and stimulus settings.
        feature_dtypes: Structured dtype spec for the returned feature record.
        pp: BRK model parameters (the "BoothRinzelKiehn" parameter dict).

    Returns:
        ``(obj_values, feature_values, constr_values)``:
          - obj_values: float32 array in OBJECTIVE_NAMES order (squared
            distances outside the target ranges; 0 is best).
          - feature_values: length-1 structured array of ``feature_dtypes``.
          - constr_values: float32 array of +1 / -1 in CONSTRAINT_NAMES order.
    """
    # Solve for the holding current at V_hold, then measure how far an
    # independent run with that current drifts from V_hold.
    cell = init_cell(pp, v_hold=exp_protocol.v_hold)
    ic_constant_hold = cell.soma.ic_constant

    initial_v_error_hold = float(
        ic_constant_f(0.0, BRK, pp, ic_constant_hold, v_hold=exp_protocol.v_hold)
    )
    initial_v_constr = 1 if abs(initial_v_error_hold) < 1.0 else -1

    # Measure passive properties (input resistance, time constant)
    cell = init_cell(pp, v_hold=exp_protocol.v_hold, ic_constant_val=ic_constant_hold)

    rn, tau = np.nan, np.nan
    if initial_v_constr > 0:
        try:
            iclamp_results = exp_protocol.run_iclamp(cell, target="Rin", tstop=3000.0)
        except Exception:
            # Best effort: rn/tau stay NaN, which the objectives and
            # constraints below penalise, but the failure is now recorded.
            logging.getLogger(__name__).warning(
                "Passive (Rin) current-clamp run failed; rn/tau left as NaN",
                exc_info=True,
            )
        else:
            passive_results = measure_passive(**iclamp_results)
            rn = passive_results["Rinp"]
            tau = passive_results["tau"]

    target_rn = exp_protocol.target_rn
    target_tau = exp_protocol.target_tau
    rn_obj_value = range_distance(rn, target_rn[0], target_rn[1]) ** 2
    tau_obj_value = range_distance(tau, target_tau[0], target_tau[1]) ** 2

    # NaN fails both comparisons, so a missing measurement violates these.
    tau_constr = 1 if (tau > 0.0 and tau < 1000.0) else -1
    rn_constr = 1 if (rn > 0.0 and rn < 1000.0) else -1

    # Run f-I current steps
    cell = init_cell(pp, v_hold=exp_protocol.v_hold, ic_constant_val=ic_constant_hold)
    iclamp_results = exp_protocol.run_iclamp_steps(cell)

    # Measure spike features (detection window extended 2.0 past the step end)
    pre_spk_cnt, spk_cnt, spk_infos, thresholds, mean_spike_amplitudes = (
        measure_spike_features(
            iclamp_results,
            exp_protocol.exp_i_inj_t0_f_I,
            exp_protocol.exp_i_inj_t1_f_I + 2.0,
        )
    )
    # Any spike well before the step onset violates this constraint.
    pre_spk_count_constr = -1 if np.sum(pre_spk_cnt) > 0 else 1

    ISI_values = measure_ISI(exp_protocol.exp_i_inj_amp_f_I, spk_infos)

    # ISI adaptation ratios and bounds are compared on a percent scale.
    ISI_adaptation_dists = list(
        map(
            lambda ratio, target_range: range_distance(
                ratio * 100.0, target_range[0] * 100.0, target_range[1] * 100.0
            ),
            ISI_values["ratio"],
            zip(
                exp_protocol.exp_i_lb_spk_adaptation,
                exp_protocol.exp_i_ub_spk_adaptation,
            ),
        )
    )
    ISI_adaptation_obj_value = np.mean([dist**2 for dist in ISI_adaptation_dists])
    ISI_adaptation_constr = -1 if np.isnan(ISI_adaptation_obj_value) else 1

    first_ISI_constr = (
        1 if np.all(ISI_values["first"] > exp_protocol.exp_first_ISI_lower) else -1
    )

    fI_values = measure_fI(
        spk_cnt,
        exp_protocol.exp_i_inj_t0_f_I,
        exp_protocol.exp_i_inj_t1_f_I,
        exp_protocol.exp_i_inj_amp_f_I,
    )

    if exp_protocol.exp_i_mean_rate_f_I is None:
        # No mean f-I target configured: record NaN instead of failing on
        # zip(..., None).
        fI_mean_target_rate_diff = np.nan
    else:
        fI_mean_target_rate_diff = np.mean(
            [
                (target_rate - rate) ** 2
                for rate, target_rate in zip(
                    fI_values["frequency"], exp_protocol.exp_i_mean_rate_f_I
                )
            ]
        )

    fI_range_dists = list(
        map(
            lambda rate, target_range: range_distance(
                rate, target_range[0], target_range[1]
            ),
            fI_values["frequency"],
            zip(exp_protocol.exp_i_lb_rate_f_I, exp_protocol.exp_i_ub_rate_f_I),
        )
    )
    fI_obj_value = np.mean([dist**2 for dist in fI_range_dists])

    # Steps whose lower spike-amplitude bound is NaN are excluded.
    mean_spike_amplitude_range_dists = list(
        map(
            lambda amp, target_amp: None
            if np.isnan(target_amp[0])
            else range_distance(amp, target_amp[0], target_amp[1]),
            mean_spike_amplitudes,
            zip(exp_protocol.exp_i_lb_spk_amp, exp_protocol.exp_i_ub_spk_amp),
        )
    )
    mean_spike_amplitude_obj_value = np.mean(
        [
            dist**2
            for dist in filter(
                lambda x: x is not None, mean_spike_amplitude_range_dists
            )
        ]
    )
    spike_amplitude_constr = -1 if np.isnan(mean_spike_amplitude_obj_value) else 1

    # Rates must strictly increase from step to step; the last step is
    # excluded from this check.
    fI_rate_diff = np.diff(fI_values["frequency"][:-1])
    monotonic_fI_constr = 1 if np.all(fI_rate_diff > 0) else -1

    # Obtain ic_constant for v_rest target
    cell = init_cell(pp, v_hold=exp_protocol.v_rest)
    ic_constant_rest = cell.soma.ic_constant

    # Assemble results
    feature_values = np.asarray(
        [
            (
                ic_constant_hold,
                ic_constant_rest,
                initial_v_error_hold,
                rn,
                tau,
                fI_values,
                fI_mean_target_rate_diff,
                ISI_values,
                thresholds,
                mean_spike_amplitudes,
            )
        ],
        dtype=np.dtype(feature_dtypes),
    )

    obj_values = np.array(
        [
            rn_obj_value,
            tau_obj_value,
            fI_obj_value,
            mean_spike_amplitude_obj_value,
            # ISI_adaptation_obj_value,
        ],
        dtype=np.float32,
    )

    constr_values = np.array(
        [
            monotonic_fI_constr,
            rn_constr,
            tau_constr,
            spike_amplitude_constr,
            first_ISI_constr,
            ISI_adaptation_constr,
            pre_spk_count_constr,
            initial_v_constr,
        ],
        dtype=np.float32,
    )

    return obj_values, feature_values, constr_values


# Objective names, in the order obj_fun returns its objective values.
# "ISI_adaptation_error" is currently disabled here and in obj_fun.
OBJECTIVE_NAMES = ["rn_error", "tau_error", "fI_error", "spike_amplitude_error"]

# Constraint names, in the order obj_fun returns its +1/-1 constraint values.
CONSTRAINT_NAMES = [
    "monotonic_fI", "rn_constr", "tau_constr", "spike_amplitude_constr",
    "first_ISI_constr", "ISI_adaptation_constr", "pre_spk_count", "initial_v_constr",
]

# File name of this script; the __main__ block uses it to find the script's
# position in sys.argv.
script_name = os.path.basename(__file__)


@click.command()
@click.option(
    "--config-path",
    default="config/motoneuron.yaml",
    type=click.Path(exists=True, file_okay=True, dir_okay=False),
)
@click.option(
    "--results-path",
    "-p",
    default="results",
    type=click.Path(file_okay=False, dir_okay=True),
)
@click.option("--num-epochs", default=5, type=int)
@click.option("--num-initial", default=10, type=int)
@click.option("--population-size", default=100, type=int)
@click.option("--num-generations", default=10, type=int)
@click.option("--seed", default=None, type=int)
@click.option("--optimizer", default="nsga2", type=str)
@click.option("--verbose", "-v", is_flag=True)
def main(
    config_path,
    results_path,
    num_epochs,
    num_initial,
    population_size,
    num_generations,
    seed,
    optimizer,
    verbose,
):
    """Optimise the motoneuron model with dmosopt using the YAML config.

    Rank 0 reads the config and broadcasts it over MPI.COMM_WORLD. Results
    are written to <results-path>/dmosopt_<Celltype>.h5. When a best set is
    returned, a summary plot is saved as <Celltype>_results.svg in the
    current directory.
    """
    comm = MPI.COMM_WORLD
    protocol_config_dict = None
    if comm.rank == 0:
        with open(config_path) as f:
            protocol_config_dict = yaml.load(f, Loader=yaml.FullLoader)
        # click does not require the results directory to exist; create it
        # so the HDF5 output below has somewhere to go.
        if results_path:
            os.makedirs(results_path, exist_ok=True)

    local_random = default_rng(seed=seed) if seed is not None else None
    protocol_config_dict = comm.bcast(protocol_config_dict)

    celltype = protocol_config_dict["Celltype"]

    # One f-I sweep per configured current step sizes the per-sweep features.
    N_exp = len(protocol_config_dict["Targets"]["f_I"]["I"])
    feature_dtypes = [
        ("ic_constant_hold", np.float32),
        ("ic_constant_rest", np.float32),
        ("initial_v_error_hold", np.float32),
        ("rn", np.float32),
        ("tau", np.float32),
        ("fI", fi_value_dtype, N_exp),
        ("mean_fI_diff", np.float32),
        ("ISI", isi_value_dtype, N_exp),
        ("threshold", np.float32, N_exp),
        ("spike_amplitude", np.float32, N_exp),
    ]

    problem_parameters = protocol_config_dict["Parameters"]
    space = protocol_config_dict["Space"]

    dmosopt_params = {
        "opt_id": f"dmosopt_{celltype}_neuron",
        "obj_fun_init_name": "make_obj_fun",
        "obj_fun_init_module": "example_dmosopt_motoneuron",
        "obj_fun_init_args": {
            "protocol_config_dict": protocol_config_dict,
            "feature_dtypes": feature_dtypes,
            "target_namespace": None,
        },
        "problem_parameters": problem_parameters,
        "space": space,
        "objective_names": OBJECTIVE_NAMES,
        "constraint_names": CONSTRAINT_NAMES,
        "feature_dtypes": feature_dtypes,
        # Optimizer
        "optimizer": optimizer,
        "population_size": population_size,
        "num_generations": num_generations,
        # Surrogate: joint model via custom training
        "surrogate_custom_training": "dmosopt.model_transformer.joint",
        "surrogate_custom_training_kwargs": {
            "mode": "c+o",
            "epochs": "auto",
        },
        # Sampling
        "n_initial": num_initial,
        "n_epochs": num_epochs,
        "initial_maxiter": 10,
        "initial_method": "slh",
        "termination_conditions": True,
        # Output
        "file_path": f"{results_path}/dmosopt_{celltype}.h5",
        "save": True,
        "local_random": local_random,
    }

    # Honour the --verbose flag (previously verbose=True was hard-coded).
    best = dmosopt.run(dmosopt_params, verbose=verbose)

    if best is not None:
        import matplotlib.pyplot as plt

        bestx, besty = best
        besty_dict = dict(besty)

        fig, axes = plt.subplots(1, 2, figsize=(10, 4))

        axes[0].scatter(
            besty_dict["rn_error"],
            besty_dict["fI_error"],
            c="tab:red",
            s=10,
            label="Pareto front",
        )
        axes[0].set_xlabel("Input resistance error")
        axes[0].set_ylabel("f-I error")
        axes[0].set_title("Objective trade-off")
        axes[0].legend()

        obj_values = np.column_stack([besty_dict[n] for n in OBJECTIVE_NAMES])
        axes[1].boxplot(obj_values, tick_labels=OBJECTIVE_NAMES)
        axes[1].set_ylabel("Error")
        axes[1].set_title("Objective distributions")
        plt.setp(axes[1].get_xticklabels(), rotation=30, ha="right")

        plt.tight_layout()
        plt.savefig(f"{celltype}_results.svg")
        print(f"Results saved to {celltype}_results.svg")


if __name__ == "__main__":
    # Give click only the arguments that follow this script's path in
    # sys.argv; if the script name is not found, fall back to sys.argv[1:].
    script_pos = next(
        (
            pos
            for pos, arg in enumerate(sys.argv)
            if os.path.basename(arg) == script_name
        ),
        0,
    )
    main(args=sys.argv[script_pos + 1 :])