#!/usr/bin/env python3

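# Integrates the FitzHugh-Nagumo model and makes phase-space and time-evolution plots.
# We use SciPy's `scipy.integrate.ode` interface with the "dop853" integrator: the
# Dormand-Prince explicit Runge-Kutta method of order 8(5,3) with adaptive step-size
# control (a much higher-order, error-controlled relative of Euler's method).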

import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import ode
plt.rcParams.update({'font.size':14})

# DEFINE FUNCTIONS
# the model itself
def model(t, x, params):
    """Right-hand side of the FitzHugh-Nagumo equations.

        dx/dt = -y + x (a - x) (x - 1)
        dy/dt = b x - c y

    Parameters
    ----------
    t : float
        Current time. Unused (the system has no explicit time dependence),
        but required by the ``scipy.integrate.ode`` callback signature
        ``f(t, y, *f_args)``.
    x : sequence or array
        State ``(x, y)``. Each component may be a scalar, or an array of
        matching shape so the vector field can be evaluated on many points
        at once (e.g. a grid for quiver/nullcline plots).
    params : sequence of three numbers
        Model parameters ``(a, b, c)``.

    Returns
    -------
    numpy.ndarray
        ``[dx/dt, dy/dt]`` as floats: shape ``(2,)`` for a scalar state, or
        ``(2, ...)`` when the components are arrays.
    """
    a, b, c = params
    v, w = x[0], x[1]
    dv = -w + v * (a - v) * (v - 1)
    dw = b * v - c * w
    # dtype=float keeps the result floating-point even for integer inputs,
    # matching the previous np.zeros(2)-based implementation.
    return np.array([dv, dw], dtype=float)

# MODEL PARAMETERS
# `a` is the parameter you are expected to vary between runs
a = 0.1

# `b` and `c` are best left at their recommended values
b, c = 0.01, 0.02

# integration window: start time and end time
tstart, tend = 0, 1000

# output time step (spacing between recorded states)
dt = 0.1

# initial state [x(0), y(0)]
x0 = [0.2, 0.0]

# INTEGRATE AND PLOT
# Output time grid. The number of intervals is rounded (not truncated) so
# float error in (tend - tstart)/dt cannot drop a step, and the grid has
# n_steps + 1 points so it starts at tstart, ends at tend, and is spaced by
# dt (exactly, whenever dt divides the interval evenly).
n_steps = int(round((tend - tstart) / dt))
t = np.linspace(tstart, tend, n_steps + 1)
x = np.zeros((len(t), 2))
x[0, :] = x0

# dop853 = adaptive Dormand-Prince Runge-Kutta; the parameter list is passed
# to model() as its third argument.
model_instance = ode(model).set_integrator("dop853")
# Start the solver's clock at tstart (set_initial_value defaults to t = 0),
# and hand it its own copy of the initial state rather than a view into x.
model_instance.set_initial_value(x0, tstart).set_f_params([a, b, c])

# Integrate to each grid time directly instead of repeatedly adding dt to the
# solver's clock, so x[i] is the state at exactly t[i] with no drift.
for i, t_next in enumerate(t[1:], start=1):
    model_instance.integrate(t_next)
    if not model_instance.successful():
        raise RuntimeError(
            f"integration failed at t = {model_instance.t:g} "
            f"(return code {model_instance.get_return_code()})"
        )
    x[i, :] = model_instance.y


# Phase-plane figure: the computed trajectory together with both nullclines
fig_phase, ax_phase = plt.subplots()
ax_phase.plot(x[:, 0], x[:, 1], label='trajectory')
ax_phase.set_xlabel('x')
ax_phase.set_ylabel('y')

# x-nullcline (dx/dt = 0):  y = x (a - x) (x - 1)
x_null = np.linspace(-0.25, 1, 100)
ax_phase.plot(x_null, x_null * (a - x_null) * (x_null - 1), label='dx/dt = 0')

# y-nullcline (dy/dt = 0):  x = (c / b) y, parametrised by y
y_null = np.linspace(-0.05, 0.3, 100)
ax_phase.plot(y_null * c / b, y_null, label='dy/dt=0')
ax_phase.legend()

# Time-series figure: x as a function of time
fig_time, ax_time = plt.subplots()
ax_time.plot(t, x[:, 0])
ax_time.set_xlabel('Time')
ax_time.set_ylabel('x')

plt.show()


