diff --git a/pynestml/codegeneration/nest_code_generator.py b/pynestml/codegeneration/nest_code_generator.py index 066109e3d..6ca0d641f 100644 --- a/pynestml/codegeneration/nest_code_generator.py +++ b/pynestml/codegeneration/nest_code_generator.py @@ -223,14 +223,23 @@ def setup_printers(self): self._gsl_variable_printer = GSLVariablePrinter(None) if self.option_exists("nest_version") and (self.get_option("nest_version").startswith("2") or self.get_option("nest_version").startswith("v2")): self._gsl_function_call_printer = NEST2GSLFunctionCallPrinter(None) + self._gsl_function_call_printer_no_origin = NEST2GSLFunctionCallPrinter(None) else: self._gsl_function_call_printer = NESTGSLFunctionCallPrinter(None) + self._gsl_function_call_printer_no_origin = NESTGSLFunctionCallPrinter(None) self._gsl_printer = CppExpressionPrinter(simple_expression_printer=CppSimpleExpressionPrinter(variable_printer=self._gsl_variable_printer, constant_printer=self._constant_printer, function_call_printer=self._gsl_function_call_printer)) self._gsl_function_call_printer._expression_printer = self._gsl_printer + self._gsl_variable_printer_no_origin = GSLVariablePrinter(None, with_origin=False) + self._gsl_printer_no_origin = CppExpressionPrinter(simple_expression_printer=CppSimpleExpressionPrinter(variable_printer=self._gsl_variable_printer_no_origin, + constant_printer=self._constant_printer, + function_call_printer=self._gsl_function_call_printer_no_origin)) + self._gsl_variable_printer_no_origin._expression_printer = self._gsl_printer_no_origin + self._gsl_function_call_printer_no_origin._expression_printer = self._gsl_printer_no_origin + + # ODE-toolbox printers self._ode_toolbox_variable_printer = ODEToolboxVariablePrinter(None) self._ode_toolbox_function_call_printer = ODEToolboxFunctionCallPrinter(None) @@ -521,6 +530,7 @@ def _get_model_namespace(self, astnode: ASTModel) -> Dict: namespace["printer"] = self._nest_printer namespace["printer_no_origin"] = self._printer_no_origin 
namespace["gsl_printer"] = self._gsl_printer + namespace["gsl_printer_no_origin"] = self._gsl_printer_no_origin namespace["nestml_printer"] = NESTMLPrinter() namespace["type_symbol_printer"] = self._type_symbol_printer @@ -666,6 +676,9 @@ def _get_synapse_model_namespace(self, synapse: ASTModel) -> Dict: expr_ast.accept(ASTSymbolTableVisitor()) namespace["numeric_update_expressions"][sym] = expr_ast + ASTUtils.assign_numeric_non_numeric_state_variables(synapse, namespace["numeric_state_variables"], + namespace["numeric_update_expressions"] if "numeric_update_expressions" in namespace.keys() else None, namespace["update_expressions"] if "update_expressions" in namespace.keys() else None) + namespace["spike_updates"] = synapse.spike_updates # special case for NEST delay variable (state or parameter) diff --git a/pynestml/codegeneration/nest_code_generator_utils.py b/pynestml/codegeneration/nest_code_generator_utils.py index 342c2321e..4ff5c7e9a 100644 --- a/pynestml/codegeneration/nest_code_generator_utils.py +++ b/pynestml/codegeneration/nest_code_generator_utils.py @@ -58,9 +58,6 @@ def print_symbol_origin(cls, variable_symbol: VariableSymbol, variable: ASTVaria if variable_symbol.block_type == BlockType.INTERNALS: return "V_.%s" - if variable_symbol.block_type == BlockType.INPUT: - return "B_.%s" - return "" @classmethod diff --git a/pynestml/codegeneration/printers/gsl_variable_printer.py b/pynestml/codegeneration/printers/gsl_variable_printer.py index 463833a43..ff5c93c0f 100644 --- a/pynestml/codegeneration/printers/gsl_variable_printer.py +++ b/pynestml/codegeneration/printers/gsl_variable_printer.py @@ -18,8 +18,10 @@ # # You should have received a copy of the GNU General Public License # along with NEST. If not, see . 
+from pynestml.codegeneration.nest_code_generator_utils import NESTCodeGeneratorUtils from pynestml.codegeneration.nest_unit_converter import NESTUnitConverter from pynestml.codegeneration.printers.cpp_variable_printer import CppVariablePrinter +from pynestml.codegeneration.printers.expression_printer import ExpressionPrinter from pynestml.meta_model.ast_variable import ASTVariable from pynestml.symbols.predefined_units import PredefinedUnits from pynestml.symbols.symbol import SymbolKind @@ -33,46 +35,42 @@ class GSLVariablePrinter(CppVariablePrinter): Variable printer for C++ syntax and using the GSL (GNU Scientific Library) API from inside the ``extern "C"`` stepping function. """ - def print_variable(self, node: ASTVariable) -> str: + def __init__(self, expression_printer: ExpressionPrinter, with_origin: bool = True, ): + super().__init__(expression_printer) + self.with_origin = with_origin + + def print_variable(self, variable: ASTVariable) -> str: """ Converts a single name reference to a gsl processable format. 
- :param node: a single variable + :param variable: a single variable :return: a gsl processable format of the variable """ - assert isinstance(node, ASTVariable) - symbol = node.get_scope().resolve_to_symbol(node.get_complete_name(), SymbolKind.VARIABLE) + assert isinstance(variable, ASTVariable) + symbol = variable.get_scope().resolve_to_symbol(variable.get_complete_name(), SymbolKind.VARIABLE) if symbol is None: # test if variable name can be resolved to a type - if PredefinedUnits.is_unit(node.get_complete_name()): - return str(NESTUnitConverter.get_factor(PredefinedUnits.get_unit(node.get_complete_name()).get_unit())) + if PredefinedUnits.is_unit(variable.get_complete_name()): + return str( + NESTUnitConverter.get_factor(PredefinedUnits.get_unit(variable.get_complete_name()).get_unit())) - code, message = Messages.get_could_not_resolve(node.get_name()) + code, message = Messages.get_could_not_resolve(variable.get_name()) Logger.log_message(log_level=LoggingLevel.ERROR, code=code, message=message, - error_position=node.get_source_position()) + error_position=variable.get_source_position()) return "" - if node.is_delay_variable(): - return self._print_delay_variable(node) + if variable.is_delay_variable(): + return self._print_delay_variable(variable) if symbol.is_state() and not symbol.is_inline_expression: - if "_is_numeric" in dir(node) and node._is_numeric: + if "_is_numeric" in dir(variable) and variable._is_numeric: # ode_state[] here is---and must be---the state vector supplied by the integrator, not the state vector in the node, node.S_.ode_state[]. - return "ode_state[State_::" + CppVariablePrinter._print_cpp_name(node.get_complete_name()) + "]" - - # non-ODE state symbol - return "node.S_." + CppVariablePrinter._print_cpp_name(node.get_complete_name()) - - if symbol.is_parameters(): - return "node.P_." + super().print_variable(node) - - if symbol.is_internals(): - return "node.V_." 
+ super().print_variable(node) + return "ode_state[State_::" + CppVariablePrinter._print_cpp_name(variable.get_complete_name()) + "]" if symbol.is_input(): - return "node.B_." + self._print_buffer_value(node) + return "node.B_." + self._print_buffer_value(variable) - raise Exception("Unknown node type") + return self._print(variable, symbol, with_origin=self.with_origin) def _print_delay_variable(self, variable: ASTVariable) -> str: """ @@ -104,3 +102,11 @@ def _print_buffer_value(self, variable: ASTVariable) -> str: return "spike_inputs_grid_sum_[node." + var_name + " - node.MIN_SPIKE_RECEPTOR]" return variable_symbol.get_symbol_name() + '_grid_sum_' + + def _print(self, variable, symbol, with_origin: bool = True): + variable_name = CppVariablePrinter._print_cpp_name(variable.get_complete_name()) + + if with_origin: + return "node." + NESTCodeGeneratorUtils.print_symbol_origin(symbol, variable) % variable_name + + return "node." + variable_name diff --git a/pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2 index b3cd2142b..8a49d7bd8 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2 +++ b/pynestml/codegeneration/resources_nest/point_neuron/common/SynapseHeader.h.jinja2 @@ -69,6 +69,19 @@ along with NEST. If not, see . #include "volume_transmitter.h" {%- endif %} +{%- if uses_numeric_solver %} +{%- if numeric_solver == "rk45" %} + +#ifndef HAVE_GSL +#error "The GSL library is required for the Runge-Kutta solver." 
+#endif + +// External includes: +#include <gsl/gsl_errno.h> +#include <gsl/gsl_matrix.h> +#include <gsl/gsl_odeiv.h> +{%- endif %} +{%- endif %} // Includes from sli: #include "dictdatum.h" @@ -101,9 +114,18 @@ namespace {{names_namespace}} const Name _{{sym.get_symbol_name()}}( "{{sym.get_symbol_name()}}" ); {%- endfor %} {%- endif %} -} +} // end namespace {{names_namespace}}; -class {{synapseName}}CommonSynapseProperties : public CommonSynapseProperties { +namespace {{ synapseName }} +{ +{%- if uses_numeric_solver %} +{%- for s in utils.create_integrate_odes_combinations(astnode) %} +extern "C" inline int {{synapseName}}_dynamics{% if s | length > 0 %}_{{ s }}{% endif %}( double, const double ode_state[], double f[], void* pnode ); +{%- endfor %} +{%- endif %} + +class {{synapseName}}CommonSynapseProperties : public CommonSynapseProperties +{ public: {{synapseName}}CommonSynapseProperties() @@ -214,43 +236,7 @@ public: } {%- endif %} -}; - -template < typename targetidentifierT > -class {{synapseName}} : public Connection< targetidentifierT > -{ -{%- if paired_neuron_name | length > 0 %} - typedef {{ paired_neuron_name }} post_neuron_t; - -{% endif %} -{%- if vt_ports is defined and vt_ports|length > 0 %} -public: -{%- if not (nest_version.startswith("v2") or nest_version.startswith("v3.0") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") or nest_version.startswith("v3.4")) %} - void trigger_update_weight( size_t t, - const std::vector< spikecounter >& vt_spikes, - double t_trig, - const {{synapseName}}CommonSynapseProperties& cp ); -{%- else %} - void trigger_update_weight( thread t, - const std::vector< spikecounter >& vt_spikes, - double t_trig, - const {{synapseName}}CommonSynapseProperties& cp ); -{%- endif %} -{%- endif %} -private: - double t_lastspike_; -{%- if vt_ports is defined and vt_ports|length > 0 %} - // time of last update, which is either time of last presyn.
spike or time-driven update - double t_last_update_; - - // vt_spikes_idx_ refers to the vt spike that has just been processed after trigger_update_weight - // a pseudo vt spike at t_trig is stored at index 0 and vt_spikes_idx_ = 0 -{%- if not (nest_version.startswith("v2") or nest_version.startswith("v3.0") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") or nest_version.startswith("v3.4")) %} - size_t vt_spikes_idx_; -{%- else %} - index vt_spikes_idx_; -{%- endif %} -{%- endif %} +}; // end class {{synapseName}}CommonSynapseProperties /** * Dynamic state of the synapse. @@ -293,17 +279,19 @@ private: //! state vector, must be C-array for GSL solver double ode_state[STATE_VEC_SIZE]; - // state variables from state block +{# // state variables from state block#} {%- filter indent(4,True) %} {%- for variable_symbol in synapse.get_state_symbols() %} -{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} -{%- include "directives_cpp/MemberDeclaration.jinja2" %} +{% if variable_symbol.get_symbol_name() not in numeric_state_variables %} +{%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} +{%- include "directives_cpp/MemberDeclaration.jinja2" %} +{%- endif %} {%- endfor %} {%- endfilter %} {%- endif %} State_() {}; - }; + }; // end State_ /** * Free parameters of the synapse. @@ -338,10 +326,49 @@ private: {%- endif %} {%- endfor %} {%- endfilter %} + double __gsl_abs_error_tol; + double __gsl_rel_error_tol; /** Initialize parameters to their default values. 
*/ Parameters_() {}; - }; + }; // end Parameters_ + + +template < typename targetidentifierT > +class {{synapseName}} : public Connection< targetidentifierT > +{ +{%- if paired_neuron_name | length > 0 %} + typedef {{ paired_neuron_name }} post_neuron_t; + +{% endif %} +{%- if vt_ports is defined and vt_ports|length > 0 %} +public: +{%- if not (nest_version.startswith("v2") or nest_version.startswith("v3.0") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") or nest_version.startswith("v3.4")) %} + void trigger_update_weight( size_t t, + const std::vector< spikecounter >& vt_spikes, + double t_trig, + const {{synapseName}}CommonSynapseProperties& cp ); +{%- else %} + void trigger_update_weight( thread t, + const std::vector< spikecounter >& vt_spikes, + double t_trig, + const {{synapseName}}CommonSynapseProperties& cp ); +{%- endif %} +{%- endif %} +private: + double t_lastspike_; +{%- if vt_ports is defined and vt_ports|length > 0 %} + // time of last update, which is either time of last presyn. spike or time-driven update + double t_last_update_; + + // vt_spikes_idx_ refers to the vt spike that has just been processed after trigger_update_weight + // a pseudo vt spike at t_trig is stored at index 0 and vt_spikes_idx_ = 0 +{%- if not (nest_version.startswith("v2") or nest_version.startswith("v3.0") or nest_version.startswith("v3.1") or nest_version.startswith("v3.2") or nest_version.startswith("v3.3") or nest_version.startswith("v3.4")) %} + size_t vt_spikes_idx_; +{%- else %} + index vt_spikes_idx_; +{%- endif %} +{%- endif %} /** * Internal variables of the synapse. 
@@ -361,10 +388,30 @@ private: {%- endfor %} }; +{%- if uses_numeric_solver %} +{%- if numeric_solver == "rk45" %} + gsl_odeiv_step* __s = nullptr; //!< stepping function + gsl_odeiv_control* __c = nullptr; //!< adaptive stepsize control function + gsl_odeiv_evolve* __e = nullptr; //!< evolution function + gsl_odeiv_system __sys; //!< struct describing system + + // __integration_step should be reset with the neuron on ResetNetwork, + // but remain unchanged during calibration. Since it is initialized with + // step_, and the resolution cannot change after nodes have been created, + // it is safe to place both here. + double __step; //!< step size in ms + double __integration_step; //!< current integration time step, updated by GSL +{%- endif %} +{%- endif %} + Parameters_ P_; //!< Free parameters. State_ S_; //!< Dynamic state. Variables_ V_; //!< Internal Variables -{%- if synapse.get_state_symbols()|length > 0 or synapse.get_parameter_symbols()|length > 0 %} + +{%- for s in utils.create_integrate_odes_combinations(astnode) %} + friend int {{synapseName}}_dynamics{% if s | length > 0 %}_{{ s }}{% endif %}( double, const double ode_state[], double f[], void* pnode ); +{%- endfor %} + // ------------------------------------------------------------------------- // Getters/setters for parameters and state variables // ------------------------------------------------------------------------- @@ -390,7 +437,6 @@ inline void set_{{ variable.get_name() }}(const {{ declarations.print_variable_t {%- endif %} {%- endfor %} {%- endfilter %} -{%- endif %} // ------------------------------------------------------------------------- // Getters/setters for inline expressions @@ -424,6 +470,8 @@ inline void set_{{ variable.get_name() }}(const {{ declarations.print_variable_t void recompute_internal_variables(); + std::string get_name() const; + public: // this line determines which common properties to use typedef {{synapseName}}CommonSynapseProperties CommonPropertiesType; @@ 
-676,7 +724,7 @@ void get_entry_from_continuous_variable_history(double t, runner = start; while ( runner != finish ) { - if ( fabs( t - runner->t_ ) < nest::kernel().connection_manager.get_stdp_eps() ) + if ( fabs( t - runner->t_ ) < kernel().connection_manager.get_stdp_eps() ) { histentry = *runner; return; @@ -704,7 +752,7 @@ void get_entry_from_continuous_variable_history(double t, send( Event& e, const thread tid, const {{synapseName}}CommonSynapseProperties& cp ) {%- endif %} { - const double __timestep = nest::Time::get_resolution().get_ms(); // do not remove, this is necessary for the timestep() function + const double __timestep = Time::get_resolution().get_ms(); // do not remove, this is necessary for the timestep() function auto get_thread = [tid]() { @@ -980,7 +1028,7 @@ void get_entry_from_continuous_variable_history(double t, {%- if nest_version.startswith("v2") %} librandom::NormalRandomDev normal_dev_; //!< random deviate generator {%- else %} - nest::normal_distribution normal_dev_; //!< random deviate generator + normal_distribution normal_dev_; //!< random deviate generator {%- endif %} {%- endif %} }; @@ -990,7 +1038,7 @@ void get_entry_from_continuous_variable_history(double t, void register_{{ synapseName }}( const std::string& name ) { - nest::register_connection_model< {{ synapseName }} >( name ); + register_connection_model< {{ synapseName }} >( name ); } {%- endif %} @@ -1070,6 +1118,14 @@ void } {%- endif %} +/* +** Synapse dynamics +*/ +{% if uses_numeric_solver %} +{%- for ast in utils.get_all_integrate_odes_calls_unique(synapse) %} +{%- include "directives_cpp/GSLDifferentiationFunction.jinja2" %} +{%- endfor %} +{%- endif %} template < typename targetidentifierT > void @@ -1087,11 +1143,11 @@ void {%- if variable.get_name() == nest_codegen_opt_delay_variable %} {#- special case for NEST special variable delay #} def< {{ declarations.print_variable_type(variable_symbol) }} >( __d, names::delay, {{ printer.print(variable) }} ); // NEST 
special case for delay variable -def(__d, nest::{{ synapseName }}_names::_{{ nest_codegen_opt_delay_variable }}, {{ printer.print(variable) }}); +def(__d, {{ synapseName }}_names::_{{ nest_codegen_opt_delay_variable }}, {{ printer.print(variable) }}); {#- special case for NEST special variable weight #} {%- elif variable.get_name() == synapse_weight_variable %} def< {{ declarations.print_variable_type(variable_symbol) }} >( __d, names::weight, {{ printer.print(variable) }} ); // NEST special case for weight variable -def< {{ declarations.print_variable_type(variable_symbol) }} >( __d, nest::{{ synapseName }}_names::_{{ synapse_weight_variable }}, {{ printer.print(variable) }} ); // NEST special case for weight variable +def< {{ declarations.print_variable_type(variable_symbol) }} >( __d, {{ synapseName }}_names::_{{ synapse_weight_variable }}, {{ printer.print(variable) }} ); // NEST special case for weight variable {%- else %} {%- include "directives_cpp/WriteInDictionary.jinja2" %} {%- endif %} @@ -1106,14 +1162,14 @@ void ConnectorModel& cm ) { {%- if synapse_weight_variable|length > 0 and synapse_weight_variable != "weight" %} - if (__d->known(nest::{{ synapseName }}_names::_{{ synapse_weight_variable }}) and __d->known(nest::names::weight)) + if (__d->known({{ synapseName }}_names::_{{ synapse_weight_variable }}) and __d->known(names::weight)) { throw BadProperty( "To prevent inconsistencies, please set either 'weight' or '{{ synapse_weight_variable }}' variable; not both at the same time." ); } {%- endif %} {%- if nest_codegen_opt_delay_variable != "delay" %} - if (__d->known(nest::{{ synapseName }}_names::_{{ nest_codegen_opt_delay_variable }}) and __d->known(nest::names::delay)) + if (__d->known({{ synapseName }}_names::_{{ nest_codegen_opt_delay_variable }}) and __d->known(names::delay)) { throw BadProperty( "To prevent inconsistencies, please set either 'delay' or '{{ nest_codegen_opt_delay_variable }}' variable; not both at the same time." 
); } @@ -1131,17 +1187,17 @@ void {%- if variable.get_name() == nest_codegen_opt_delay_variable %} // special treatment of NEST delay double tmp_{{ nest_codegen_opt_delay_variable }} = get_delay(); -updateValue(__d, nest::{{ synapseName }}_names::_{{ nest_codegen_opt_delay_variable }}, tmp_{{nest_codegen_opt_delay_variable}}); +updateValue(__d, {{ synapseName }}_names::_{{ nest_codegen_opt_delay_variable }}, tmp_{{nest_codegen_opt_delay_variable}}); {%- elif variable.get_name() == synapse_weight_variable %} // special treatment of NEST weight double tmp_{{ synapse_weight_variable }} = get_weight(); -if (__d->known(nest::{{ synapseName }}_names::_{{ synapse_weight_variable }})) +if (__d->known({{ synapseName }}_names::_{{ synapse_weight_variable }})) { - updateValue(__d, nest::{{ synapseName }}_names::_{{ synapse_weight_variable }}, tmp_{{synapse_weight_variable}}); + updateValue(__d, {{ synapseName }}_names::_{{ synapse_weight_variable }}, tmp_{{synapse_weight_variable}}); } -if (__d->known(nest::names::weight)) +if (__d->known(names::weight)) { - updateValue(__d, nest::names::weight, tmp_{{synapse_weight_variable}}); + updateValue(__d, names::weight, tmp_{{synapse_weight_variable}}); } {%- else %} {%- include "directives_cpp/ReadFromDictionaryToTmp.jinja2" %} @@ -1179,13 +1235,13 @@ set_delay(tmp_{{ nest_codegen_opt_delay_variable }}); {% for invariant in synapse.get_parameter_invariants() %} if ( !({{printer.print(invariant)}}) ) { - throw nest::BadProperty("The constraint '{{nestml_printer.print(invariant)}}' is violated!"); + throw BadProperty("The constraint '{{nestml_printer.print(invariant)}}' is violated!"); } {%- endfor %} {%- endif %} // recompute internal variables in case they are dependent on parameters or state that might have been updated in this call to set_status() - V_.__h = nest::Time::get_resolution().get_ms(); + V_.__h = Time::get_resolution().get_ms(); recompute_internal_variables(); } @@ -1195,7 +1251,7 @@ set_delay(tmp_{{ 
nest_codegen_opt_delay_variable }}); template < typename targetidentifierT > void {{synapseName}}< targetidentifierT >::recompute_internal_variables() { - const double __timestep = nest::Time::get_resolution().get_ms(); // do not remove, this is necessary for the timestep() function + const double __timestep = Time::get_resolution().get_ms(); // do not remove, this is necessary for the timestep() function {% filter indent(2) %} {%- for variable_symbol in synapse.get_internal_symbols() %} @@ -1207,14 +1263,20 @@ void {{synapseName}}< targetidentifierT >::recompute_internal_variables() {%- endfilter %} } +template < typename targetidentifierT > +std::string {{synapseName}}< targetidentifierT >::get_name() const +{ + std::string s ("{{ synapseName }}"); + return s; +} + /** * constructor **/ template < typename targetidentifierT > {{synapseName}}< targetidentifierT >::{{synapseName}}() : ConnectionBase() { - const double __timestep = nest::Time::get_resolution().get_ms(); // do not remove, this is necessary for the timestep() function - + const double __timestep = Time::get_resolution().get_ms(); // do not remove, this is necessary for the timestep() function // initial values for parameters {%- filter indent(2, True) %} {%- for variable_symbol in synapse.get_parameter_symbols() %} @@ -1226,6 +1288,10 @@ template < typename targetidentifierT > {%- endif %} {%- endif %} {%- endfor %} +{%- if uses_numeric_solver and numeric_solver == "rk45" %} +P_.__gsl_abs_error_tol = 1e-6; +P_.__gsl_rel_error_tol = 1e-6; +{%- endif %} {%- endfilter %} // initial values for internal variables @@ -1261,6 +1327,41 @@ template < typename targetidentifierT > {%- endif %} {%- endif %} +{%- if uses_numeric_solver and numeric_solver == "rk45" %} + if ( not __s ) + { + __s = gsl_odeiv_step_alloc( gsl_odeiv_step_rkf45, State_::STATE_VEC_SIZE ); + } + else + { + gsl_odeiv_step_reset( __s ); + } + + if ( not __c ) + { + __c = gsl_odeiv_control_y_new( P_.__gsl_abs_error_tol, P_.__gsl_rel_error_tol 
); + } + else + { + gsl_odeiv_control_init( __c, P_.__gsl_abs_error_tol, P_.__gsl_rel_error_tol, 1.0, 0.0 ); + } + + if ( not __e ) + { + __e = gsl_odeiv_evolve_alloc( State_::STATE_VEC_SIZE ); + } + else + { + gsl_odeiv_evolve_reset( __e ); + } + + __sys.jacobian = nullptr; + __sys.dimension = State_::STATE_VEC_SIZE; + __sys.params = reinterpret_cast< void* >( &P_ ); + __step = Time::get_resolution().get_ms(); + __integration_step = Time::get_resolution().get_ms(); +{%- endif %} + t_lastspike_ = 0.; {%- if vt_ports is defined and vt_ports|length > 0 %} t_last_update_ = 0.; @@ -1296,7 +1397,7 @@ template < typename targetidentifierT > {%- for variable_symbol in synapse.get_state_symbols() %} {%- set variable = utils.get_state_variable_by_name(astnode, variable_symbol.get_symbol_name()) %} {%- if variable.get_name() != synapse_weight_variable and variable.get_name() != nest_codegen_opt_delay_variable %} - S_.{{ printer_no_origin.print(variable) }} = rhs.S_.{{ printer_no_origin.print(variable) }}; + {{ nest_codegen_utils.print_symbol_origin(variable_symbol, variable) % printer_no_origin.print(variable) }} = rhs.{{ nest_codegen_utils.print_symbol_origin(variable_symbol, variable) % printer_no_origin.print(variable) }}; {%- endif %} {%- endfor %} @@ -1305,6 +1406,16 @@ template < typename targetidentifierT > {%- endif %} t_lastspike_ = rhs.t_lastspike_; +{%- if uses_numeric_solver and numeric_solver == "rk45" %} + // Numeric solver variables + __s = rhs.__s; + __c = rhs.__c; + __e = rhs.__e; + __sys = rhs.__sys; + __step = rhs.__step; + __integration_step = rhs.__integration_step; +{%- endif %} + // special treatment of NEST delay set_delay(rhs.get_delay()); {%- if synapse_weight_variable | length > 0 %} @@ -1473,6 +1584,7 @@ inline void {%- endif %} -} // namespace +} // namespace {{ synapseName }}; +} // end namespace nest; #endif /* #ifndef {{synapseName.upper()}}_H */ diff --git 
a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/GSLDifferentiationFunction.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/GSLDifferentiationFunction.jinja2 index e2495a676..52641ea04 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/GSLDifferentiationFunction.jinja2 +++ b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/GSLDifferentiationFunction.jinja2 @@ -1,16 +1,27 @@ {# Creates GSL implementation of the differentiation step for the system of ODEs. -#} -extern "C" inline int {{neuronName}}_dynamics{% if ast.get_args() | length > 0 %}_{{ utils.integrate_odes_args_str_from_function_call(ast) }}{% endif %}(double __time, const double ode_state[], double f[], void* pnode) +{%- if neuronName is defined %} +{%- set modelName = neuronName %} +{%- else %} +{%- set modelName = synapseName %} +{%- endif %} +extern "C" inline int {{modelName}}_dynamics{% if ast.get_args() | length > 0 %}_{{ utils.integrate_odes_args_str_from_function_call(ast) }}{% endif %}(double __time, const double ode_state[], double f[], void* pnode) { - typedef {{neuronName}}::State_ State_; - // get access to node so we can almost work as in a member function +{%- if neuronName is defined %} + typedef {{modelName}}::State_ State_; + // get access to node so we can almost work as in a member function assert( pnode ); const {{neuronName}}& node = *( reinterpret_cast< {{neuronName}}* >( pnode ) ); - +{%- else %} + // get access to node so we can almost work as in a member function + assert( pnode ); + const Parameters_& node = *( reinterpret_cast< Parameters_* >( pnode ) ); +{%- endif %} // ode_state[] here is---and must be---the state vector supplied by the integrator, // not the state vector in the node, node.S_.ode_state[]. 
+{%- if neuronName is defined %} {%- for eq_block in neuron.get_equations_blocks() %} {%- for ode in eq_block.get_declarations() %} {%- for inline_expr in utils.get_inline_expression_symbols(ode) %} @@ -22,7 +33,20 @@ extern "C" inline int {{neuronName}}_dynamics{% if ast.get_args() | length > 0 % {%- endfor %} {%- endfor %} -{%- if use_gap_junctions %} +{%- else %} +{%- for eq_block in synapse.get_equations_blocks() %} +{%- for ode in eq_block.get_declarations() %} +{%- for inline_expr in utils.get_inline_expression_symbols(ode) %} +{%- if not inline_expr.is_equation() %} +{%- set declaring_expr = inline_expr.get_declaring_expression() %} + double {{ printer.print(utils.get_state_variable_by_name(astnode, inline_expr)) }} = {{ gsl_printer.print(declaring_expr) }}; +{%- endif %} +{%- endfor %} +{%- endfor %} +{%- endfor %} +{%- endif %} + +{%- if use_gap_junctions and neuronName is defined %} // set I_gap depending on interpolation order double __I_gap = 0.0; @@ -51,6 +75,7 @@ extern "C" inline int {{neuronName}}_dynamics{% if ast.get_args() | length > 0 % } {%- endif %} +{%- if neuronName is defined %} {% set numeric_state_variables_to_be_integrated = numeric_state_variables + purely_numeric_state_variables_moved %} {%- if ast.get_args() | length > 0 %} {%- set numeric_state_variables_to_be_integrated = utils.filter_variables_list(numeric_state_variables_to_be_integrated, ast.get_args()) %} @@ -65,6 +90,14 @@ extern "C" inline int {{neuronName}}_dynamics{% if ast.get_args() | length > 0 % {%- endif %} {%- endfor %} +{%- else %} +{%- for variable_name in numeric_state_variables %} +{%- set update_expr = numeric_update_expressions[variable_name] %} +{%- set variable_symbol = variable_symbols[variable_name] %} + f[State_::{{ variable_symbol.get_symbol_name() }}] = {% if ast.get_args() | length > 0 %}{% if variable_name in utils.integrate_odes_args_strs_from_function_call(ast) + utils.all_convolution_variable_names(astnode) %}{{ 
gsl_printer_no_origin.print(update_expr) }}{% else %}0{% endif %}{% else %}{{ gsl_printer_no_origin.print(update_expr) }}{% endif %}; +{%- endfor %} +{%- endif %} + {%- if numeric_solver == "rk45" %} return GSL_SUCCESS; {%- else %} diff --git a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/GSLIntegrationStep.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/GSLIntegrationStep.jinja2 index 4a8090537..c0617ded1 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/GSLIntegrationStep.jinja2 +++ b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/GSLIntegrationStep.jinja2 @@ -5,7 +5,11 @@ {%- if tracing %}/* generated by {{self._TemplateReference__context.name}} */ {% endif %} {%- if numeric_solver == "rk45" %} double __t = 0; +{%- if neuronName is defined %} B_.__sys.function = {{neuronName}}_dynamics{% if ast.get_args() | length > 0 %}_{{ utils.integrate_odes_args_str_from_function_call(ast) }}{% endif %}; +{%- else %} +__sys.function = {{synapseName}}_dynamics{% if ast.get_args() | length > 0 %}_{{ utils.integrate_odes_args_str_from_function_call(ast) }}{% endif %}; +{%- endif %} // numerical integration with adaptive step size control: // ------------------------------------------------------ // gsl_odeiv_evolve_apply performs only a single numerical @@ -18,11 +22,12 @@ B_.__sys.function = {{neuronName}}_dynamics{% if ast.get_args() | length > 0 %}_ // enforce setting IntegrationStep to step-t; this is of advantage // for a consistent and efficient integration across subsequent // simulation intervals +{%- if neuronName is defined %} while ( __t < B_.__step ) { -{%- if use_gap_junctions %} +{%- if use_gap_junctions %} gap_junction_step = B_.__step; -{%- endif %} +{%- endif %} const int status = gsl_odeiv_evolve_apply(B_.__e, B_.__c, @@ -38,6 +43,25 @@ while ( __t < B_.__step ) throw nest::GSLSolverFailure( get_name(), status ); } } +{%- else %} +while ( __t < timestep ) +{ + 
const int status = gsl_odeiv_evolve_apply(__e, + __c, + __s, + &__sys, // system of ODE + &__t, // from t + timestep, // to t <= step + &__integration_step, // integration step size + S_.ode_state); // neuronal state + + if ( status != GSL_SUCCESS ) + { + throw nest::GSLSolverFailure( get_name(), status ); + } + } +{%- endif %} + {%- elif numeric_solver == "forward-Euler" %} double f[State_::STATE_VEC_SIZE]; {{neuronName}}_dynamics{% if ast.get_args() | length > 0 %}_{{ utils.integrate_odes_args_str_from_function_call(ast) }}{% endif %}( get_t(), S_.ode_state, f, reinterpret_cast< void* >( this ) ); diff --git a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/PredefinedFunction_integrate_odes.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/PredefinedFunction_integrate_odes.jinja2 index 65f8b218e..12f8bc0e2 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/PredefinedFunction_integrate_odes.jinja2 +++ b/pynestml/codegeneration/resources_nest/point_neuron/directives_cpp/PredefinedFunction_integrate_odes.jinja2 @@ -25,7 +25,12 @@ {%- if uses_numeric_solver %} +{%- if neuronName is defined %} {% set numeric_state_variables_to_be_integrated = numeric_state_variables + purely_numeric_state_variables_moved %} +{%- else %} +{% set numeric_state_variables_to_be_integrated = numeric_state_variables %} +{%- endif %} + {%- if ast.get_args() | length > 0 %} {%- set numeric_state_variables_to_be_integrated = utils.filter_variables_list(numeric_state_variables_to_be_integrated, ast.get_args()) %} {%- endif %} diff --git a/pynestml/codegeneration/resources_nest/point_neuron/setup/common/ModuleClass.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/setup/common/ModuleClass.jinja2 index 3f6646d42..f9175e2c9 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/setup/common/ModuleClass.jinja2 +++ b/pynestml/codegeneration/resources_nest/point_neuron/setup/common/ModuleClass.jinja2 @@ 
-132,7 +132,7 @@ void {%- if synapses %} // register synapses {%- for synapse in synapses %} - nest::register_connection_model< nest::{{synapse.get_name()}} >( "{{synapse.get_name()}}" ); + nest::register_connection_model< nest::{{synapse.get_name()}}::{{ synapse.get_name() }} >( "{{synapse.get_name()}}" ); {%- endfor %} {%- endif %} } // {{moduleName}}::init() diff --git a/pynestml/codegeneration/resources_nest/point_neuron/setup/common/ModuleClassMaster.jinja2 b/pynestml/codegeneration/resources_nest/point_neuron/setup/common/ModuleClassMaster.jinja2 index fe2d49582..a43c3912c 100644 --- a/pynestml/codegeneration/resources_nest/point_neuron/setup/common/ModuleClassMaster.jinja2 +++ b/pynestml/codegeneration/resources_nest/point_neuron/setup/common/ModuleClassMaster.jinja2 @@ -78,7 +78,7 @@ void {{moduleName}}::initialize() {%- if synapses %} // register synapses {%- for synapse in synapses %} - nest::register_{{synapse.get_name()}}( "{{synapse.get_name()}}" ); + nest::{{synapse.get_name()}}::register_{{synapse.get_name()}}( "{{synapse.get_name()}}" ); {%- endfor %} {%- endif %} -} \ No newline at end of file +} diff --git a/pynestml/transformers/synapse_post_neuron_transformer.py b/pynestml/transformers/synapse_post_neuron_transformer.py index 95d7e087e..73ed81390 100644 --- a/pynestml/transformers/synapse_post_neuron_transformer.py +++ b/pynestml/transformers/synapse_post_neuron_transformer.py @@ -294,7 +294,7 @@ def transform_neuron_synapse_pair_(self, neuron: ASTModel, synapse: ASTModel): strictly_synaptic_vars = ["t"] # "seed" this with the predefined variable t if self.option_exists("strictly_synaptic_vars") and removesuffix(synapse.get_name(), FrontendConfiguration.suffix) in self.get_option("strictly_synaptic_vars").keys() and self.get_option("strictly_synaptic_vars")[removesuffix(synapse.get_name(), FrontendConfiguration.suffix)]: - strictly_synaptic_vars.append(self.get_option("strictly_synaptic_vars")[removesuffix(synapse.get_name(), 
FrontendConfiguration.suffix)]) + strictly_synaptic_vars.extend(self.get_option("strictly_synaptic_vars")[removesuffix(synapse.get_name(), FrontendConfiguration.suffix)]) if self.option_exists("delay_variable") and removesuffix(synapse.get_name(), FrontendConfiguration.suffix) in self.get_option("delay_variable").keys() and self.get_option("delay_variable")[removesuffix(synapse.get_name(), FrontendConfiguration.suffix)]: strictly_synaptic_vars.append(self.get_option("delay_variable")[removesuffix(synapse.get_name(), FrontendConfiguration.suffix)]) diff --git a/pynestml/utils/ast_utils.py b/pynestml/utils/ast_utils.py index ef51e0812..29c4179a6 100644 --- a/pynestml/utils/ast_utils.py +++ b/pynestml/utils/ast_utils.py @@ -2555,7 +2555,7 @@ def get_spike_input_ports_in_pairs(cls, neuron: ASTModel) -> Dict[int, List[Vari return rport_to_port_map @classmethod - def assign_numeric_non_numeric_state_variables(cls, neuron, numeric_state_variable_names, numeric_update_expressions, update_expressions): + def assign_numeric_non_numeric_state_variables(cls, model, numeric_state_variable_names, numeric_update_expressions, update_expressions): r"""For each ASTVariable, set the ``node._is_numeric`` member to True or False based on whether this variable will be solved with the analytic or numeric solver. Ideally, this would not be a property of the ASTVariable as it is an implementation detail (that only emerges during code generation) and not an intrinsic part of the model itself. 
However, this approach is preferred over setting it as a property of the variable printers as it would have to make each printer aware of all models and variables therein.""" @@ -2574,10 +2574,10 @@ def visit_variable(self, node): visitor = ASTVariableOriginSetterVisitor() visitor._numeric_state_variables = numeric_state_variable_names - neuron.accept(visitor) + model.accept(visitor) - if "extra_on_emit_spike_stmts_from_synapse" in dir(neuron): - for expr in neuron.extra_on_emit_spike_stmts_from_synapse: + if "extra_on_emit_spike_stmts_from_synapse" in dir(model): + for expr in model.extra_on_emit_spike_stmts_from_synapse: expr.accept(visitor) if update_expressions: @@ -2588,15 +2588,17 @@ def visit_variable(self, node): for expr in numeric_update_expressions.values(): expr.accept(visitor) - for update_expr_list in neuron.spike_updates.values(): + for update_expr_list in model.spike_updates.values(): for update_expr in update_expr_list: update_expr.accept(visitor) - for update_expr in neuron.post_spike_updates.values(): - update_expr.accept(visitor) + if "post_spike_updates" in dir(model): + for update_expr in model.post_spike_updates.values(): + update_expr.accept(visitor) - for node in neuron.equations_with_delay_vars + neuron.equations_with_vector_vars: - node.accept(visitor) + if "equations_with_delay_vars" in dir(model): + for node in model.equations_with_delay_vars + model.equations_with_vector_vars: + node.accept(visitor) @classmethod def depends_only_on_vars(cls, expr, vars): diff --git a/tests/nest_tests/resources/non_linear_synapse.nestml b/tests/nest_tests/resources/non_linear_synapse.nestml new file mode 100644 index 000000000..76f4985cf --- /dev/null +++ b/tests/nest_tests/resources/non_linear_synapse.nestml @@ -0,0 +1,30 @@ +model non_linear_synapse: + state: + x real = 1. + y real = 1. + z real = 1. + w real = 0. 
+ d ms = 1.0 ms + + equations: + x' = (sigma * (y - x)) / ms + y' = (x * (rho - z) - y) / ms + z' = (x * y - beta * z) / ms + + parameters: + sigma real = 10. + beta real = 8/3 + rho real = 28 + + input: + pre_spikes <- spike + + output: + spike(weight real, delay ms) + + onReceive(pre_spikes): + w += x * y / z + emit_spike(w, d) + + update: + integrate_odes() diff --git a/tests/nest_tests/resources/stp_synapse.nestml b/tests/nest_tests/resources/stp_synapse.nestml new file mode 100644 index 000000000..b86f055bb --- /dev/null +++ b/tests/nest_tests/resources/stp_synapse.nestml @@ -0,0 +1,68 @@ +# stp_synapse.nestml +# ################## +# +# +# Description +# +++++++++++ +# +# This model is used to test vector operations with NEST. +# +# +# Copyright statement +# +++++++++++++++++++ +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . + +model stp_synapse: + input: + pre_spikes <- spike + + output: + spike(weight real, delay ms) + + state: + w real = 1 / U_0 # synaptic (baseline) weight + x real = 1. 
# fraction of available resources after neurotransmitter depletion + u real = U_0 # utilization parameter: fraction of available resources ready for use (release probability) + U real = U_0 # increment of u produced by a spike + + parameters: + d ms = 1 ms # synaptic transmission delay + U_0 real = 0.25 # basal release probability + K_A real = 0.0375 # controls how fast the baseline release probability increases with the activity + tau_D ms = 300 ms # depression time constant + tau_F ms = 1500 ms # facilitation time constant + tau_A ms = 20000 ms # augmentation time constant + tau_filter ms = 50 ms # filtered spike train time constant + + equations: + x' = (1. - x) / tau_D + u' = (U - u) / tau_F + U' = (U_0 - U) / tau_A + + onReceive(pre_spikes): + x -= u * x + u += U * (1. - u) + U += K_A * (1. - U) + + w_effective real = w * x * u + + emit_spike(w_effective, d) + + update: + integrate_odes() diff --git a/tests/nest_tests/test_synapse_numeric_solver.py b/tests/nest_tests/test_synapse_numeric_solver.py new file mode 100644 index 000000000..e0597d0ea --- /dev/null +++ b/tests/nest_tests/test_synapse_numeric_solver.py @@ -0,0 +1,251 @@ +# -*- coding: utf-8 -*- +# +# test_synapse_numeric_solver.py +# +# This file is part of NEST. +# +# Copyright (C) 2004 The NEST Initiative +# +# NEST is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 2 of the License, or +# (at your option) any later version. +# +# NEST is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with NEST. If not, see . 
+import os +import nest +import pytest +from scipy.integrate import solve_ivp + +from pynestml.codegeneration.nest_tools import NESTTools +from pynestml.frontend.pynestml_frontend import generate_target, generate_nest_target +import numpy as np + +try: + import matplotlib + + matplotlib.use("Agg") + import matplotlib.ticker + import matplotlib.pyplot as plt + + TEST_PLOTS = True +except Exception: + TEST_PLOTS = False + + +@pytest.mark.skipif(NESTTools.detect_nest_version().startswith("v2"), + reason="This test does not support NEST 2") +class TestSynapseNumericSolver: + """ + Tests a synapse with non-linear dynamics requiring a numeric solver for ODEs. + """ + + def test_synapse_with_numeric_solver(self): + nest.ResetKernel() + nest.set_verbosity("M_WARNING") + dt = 0.1 + nest.resolution = dt + + files = ["models/neurons/iaf_psc_exp_neuron.nestml", "tests/nest_tests/resources/stp_synapse.nestml"] + input_paths = [os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join( + os.pardir, os.pardir, s))) for s in files] + target_path = "target_stp" + modulename = "stp_module" + + generate_nest_target(input_path=input_paths, + target_path=target_path, + logging_level="INFO", + suffix="_nestml", + module_name=modulename, + codegen_opts={"neuron_synapse_pairs": [{"neuron": "iaf_psc_exp_neuron", + "synapse": "stp_synapse"}], + "delay_variable": {"stp_synapse": "d"}, + "weight_variable": {"stp_synapse": "w"}}) + nest.Install(modulename) + + # properties of the generated spike train + frequency = 50 # in Hz + spike_count = 10 + step = 1000. 
/ frequency # in ms + duration = spike_count * step + sim_time = duration + 11_000 + + spike_times = (([i * step for i in range(1, spike_count + 1)] # 10 spikes at 50Hz + + [duration + 500]) # then 500ms after + + [duration + 10_000]) # then 10s after + + # parameters for the spike generator (spike train injector) + params_sg = { + "spike_times": spike_times + } + print(spike_times) + neuron_model = "iaf_psc_exp_neuron_nestml__with_stp_synapse_nestml" + synapse_model = "stp_synapse_nestml__with_iaf_psc_exp_neuron_nestml" + + print("Creating the neuron model") + neuron = nest.Create(neuron_model) + + print("Creating spike generator") + spike_train_injector = nest.Create("spike_train_injector", params=params_sg) + + voltmeter = nest.Create("voltmeter", params={'interval': 0.1}) + spike_recorder = nest.Create("spike_recorder") + + print("Connecting the synapse") + nest.Connect(spike_train_injector, neuron, syn_spec={"synapse_model": synapse_model}) + nest.Connect(voltmeter, neuron) + nest.Connect(spike_train_injector, spike_recorder) + connections = nest.GetConnections(source=spike_train_injector, synapse_model=synapse_model) + x = [] + u = [] + U = [] + sim_step_size = 1. 
+ for i in np.arange(0., sim_time + 0.01, sim_step_size): + nest.Simulate(sim_step_size) + syn_stats = connections.get() # nest.GetConnections()[2].get() + x += [syn_stats["x"]] + u += [syn_stats["u"]] + U += [syn_stats["U"]] + + data_vm = voltmeter.events + data_sr = spike_recorder.events + + # TODO: add assertions + + if TEST_PLOTS: + fig, ax = plt.subplots(3, 1, sharex=True, figsize=(10, 15)) + + ax[0].vlines(data_sr["times"], 0, 1) + ax[0].set_xlim([0, sim_time]) + ax[0].set_xlabel('Time (s)') + + ax[1].set_xlim([0, sim_time]) + ax[1].set_ylim([0, 1]) + ax[1].set_xlabel('Time (s)') + + ax[1].plot(x, label='x') + ax[1].plot(u, label='u') + ax[1].plot(U, label='U') + ax[1].legend(loc='best') + + ax[2].set_xlim([0, sim_time]) + ax[2].set_xlabel('Time (ms)') + + for ax_ in ax: + ax_.set_xlim([1., sim_time]) + ax_.set_xscale('log') + + ax[2].plot(data_vm["times"], data_vm["V_m"]) + + fig.tight_layout() + fig.savefig('synaug_numsim.pdf') + + def lorenz_attractor_system(self, t, state, sigma, rho, beta): + x, y, z = state + dxdt = (sigma * (y - x)) + dydt = (x * (rho - z) - y) + dzdt = (x * y - beta * z) + return [dxdt, dydt, dzdt] + + def evaluate_odes_scipy(self, sigma, rho, beta, initial_state, spike_times, sim_time): + x_arr = [] + y_arr = [] + z_arr = [] + y0 = initial_state + + t_last_spike = 0. 
+ spike_idx = 0 + for i in np.arange(1., sim_time + 0.01, 1.0): + if spike_idx < len(spike_times) and i == spike_times[spike_idx]: + t_spike = spike_times[spike_idx] + t_span = (t_last_spike, t_spike) + print("Integrating over the interval: ", t_span) + # Solve using RK45 + solution = solve_ivp( + fun=self.lorenz_attractor_system, + t_span=t_span, + y0=y0, # [x_arr[-1], y_arr[-1], z_arr[-1]], + args=(sigma, rho, beta), + method='RK45', + first_step=0.1, + rtol=1e-6, # relative tolerance + atol=1e-6 # absolute tolerance + ) + y0 = solution.y[:, -1] + t_last_spike = t_spike + spike_idx += 1 + + x_arr += [y0[0]] + y_arr += [y0[1]] + z_arr += [y0[2]] + + return x_arr, y_arr, z_arr + + def test_non_linear_synapse(self): + nest.ResetKernel() + nest.set_verbosity("M_WARNING") + dt = 0.1 + nest.resolution = dt + sim_time = 8.0 + spike_times = [3.0, 5.0, 7.0] + + files = ["models/neurons/iaf_psc_exp_neuron.nestml", "tests/nest_tests/resources/non_linear_synapse.nestml"] + input_paths = [os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join( + os.pardir, os.pardir, s))) for s in files] + target_path = "target_nl" + modulename = "nl_syn_module" + + generate_nest_target(input_path=input_paths, + target_path=target_path, + logging_level="INFO", + suffix="_nestml", + module_name=modulename, + codegen_opts={"neuron_synapse_pairs": [{"neuron": "iaf_psc_exp_neuron", + "synapse": "non_linear_synapse"}], + "delay_variable": {"non_linear_synapse": "d"}, + "weight_variable": {"non_linear_synapse": "w"}, + "strictly_synaptic_vars": {"non_linear_synapse": ["x", "y", "z"]}}) + nest.Install(modulename) + + neuron_model = "iaf_psc_exp_neuron_nestml__with_non_linear_synapse_nestml" + synapse_model = "non_linear_synapse_nestml__with_iaf_psc_exp_neuron_nestml" + + neuron = nest.Create(neuron_model) + sg = nest.Create("spike_generator", params={"spike_times": spike_times}) + + nest.Connect(sg, neuron, syn_spec={"synapse_model": synapse_model}) + connections =
nest.GetConnections(source=sg, synapse_model=synapse_model) + + # Get the parameter values + sigma = connections.get("sigma") + rho = connections.get("rho") + beta = connections.get("beta") + + # Initial values of state variables + initial_state = [connections.get("x"), connections.get("y"), connections.get("z")] + + # Scipy simulation + x_expected, y_expected, z_expected = self.evaluate_odes_scipy(sigma, rho, beta, initial_state, spike_times, sim_time) + + # NEST simulation + x = [] + y = [] + z = [] + sim_step_size = 1. + for i in np.arange(0., sim_time, sim_step_size): + nest.Simulate(sim_step_size) + syn_stats = connections.get() # nest.GetConnections()[2].get() + x += [syn_stats["x"]] + y += [syn_stats["y"]] + z += [syn_stats["z"]] + + # TODO: Adjust tolerance + np.testing.assert_allclose(x, x_expected, atol=1e-2, rtol=1e-2) + np.testing.assert_allclose(y, y_expected, atol=1e-2, rtol=1e-2) + np.testing.assert_allclose(z, z_expected, atol=1e-2, rtol=1e-2)