diff --git a/brainpy/integrators/base.py b/brainpy/integrators/base.py index fbc15c6c..bdf444f7 100644 --- a/brainpy/integrators/base.py +++ b/brainpy/integrators/base.py @@ -141,6 +141,10 @@ def state_delays(self, value): raise ValueError('Cannot set "state_delays" by users.') def _call_integral(self, *args, **kwargs): + kwargs = dict(kwargs) + t = kwargs.get('t', None) + kwargs['t'] = 0. if t is None else t + if _during_compile: jaxpr, out_shapes = jax.make_jaxpr(self.integral, return_shape=True)(**kwargs) outs = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *jax.tree.leaves(kwargs)) diff --git a/brainpy/integrators/ode/explicit_rk.py b/brainpy/integrators/ode/explicit_rk.py index 59e76994..9b425a91 100644 --- a/brainpy/integrators/ode/explicit_rk.py +++ b/brainpy/integrators/ode/explicit_rk.py @@ -178,8 +178,7 @@ def __init__(self, def build(self): # step stage - common.step(self.variables, C.DT, - self.A, self.C, self.code_lines, self.parameters) + common.step(self.variables, C.DT, self.A, self.C, self.code_lines, self.parameters) # variable update return_args = common.update(self.variables, C.DT, self.B, self.code_lines) # returns @@ -189,7 +188,8 @@ def build(self): code_scope={k: v for k, v in self.code_scope.items()}, code_lines=self.code_lines, show_code=self.show_code, - func_name=self.func_name) + func_name=self.func_name + ) class Euler(ExplicitRKIntegrator): diff --git a/brainpy/math/environment.py b/brainpy/math/environment.py index e9cb6ac9..ca258ce2 100644 --- a/brainpy/math/environment.py +++ b/brainpy/math/environment.py @@ -25,7 +25,6 @@ import brainstate.environ import jax from jax import config, numpy as jnp, devices -from jax.lib import xla_bridge from . import modes from . import scales @@ -733,8 +732,13 @@ def clear_buffer_memory( Clear name cache. Default is True. """ + if jax.__version_info__ < (0, 8, 0): + from jax.lib.xla_bridge import get_backend + else: + from jax.extend.backend import get_backend + if array: - for buf in xla_bridge.get_backend(platform).live_buffers(): + for buf in get_backend(platform).live_buffers(): buf.delete() if compilation: jax.clear_caches()