import numpy as np
from scipy.integrate import odeint
import matplotlib.pyplot as plt

# Define the system of ODEs with 3 state variables
def nmda_model_3var(y, t, params):
    """
    Right-hand side of the 3-variable NMDA receptor conductance model.

    An auxiliary activation variable drives two conductance components
    that decay with their own time constants:

        dx/dt      = -x / Tau_rise
        dg_fast/dt = A1 * x - g_fast / Tau_decay1
        dg_slow/dt = A2 * x - g_slow / Tau_decay2

    A time constant that is not strictly positive switches its term off
    instead of dividing by it: dx/dt becomes 0, and the affected
    conductance keeps only its drive term (A1 * x or A2 * x).

    Arguments:
        y :  sequence of the three state variables, in this order:
                  y[0] = x (auxiliary activation variable)
                  y[1] = g_fast (fast decaying conductance component)
                  y[2] = g_slow (slow decaying conductance component)
        t :  time; not used by the equations, present only because the
             function has the (y, t, ...) callback signature
        params : tuple (Tau_rise, Tau_decay1, Tau_decay2, A1, A2), where
                 A1 and A2 weight the drive into the fast and slow
                 components respectively

    Returns:
        list [dx/dt, dg_fast/dt, dg_slow/dt]
    """
    aux, fast, slow = y
    tau_aux, tau_fast, tau_slow, w_fast, w_slow = params

    # Auxiliary variable: plain exponential decay, frozen when the time
    # constant cannot be divided by.
    if tau_aux > 0:
        d_aux = -aux / tau_aux
    else:
        d_aux = 0

    # Fast component: drive from x minus its own leak (leak dropped when
    # the time constant is not positive).
    if tau_fast > 0:
        d_fast = w_fast * aux - fast / tau_fast
    else:
        d_fast = w_fast * aux

    # Slow component: same structure with its own weight and time constant.
    if tau_slow > 0:
        d_slow = w_slow * aux - slow / tau_slow
    else:
        d_slow = w_slow * aux

    return [d_aux, d_fast, d_slow]

# --- Simulation Parameters ---

# Output time grid covering [t_start, t_end], in milliseconds.
t_start = 0.0
t_end = 500.0  # ms
dt = 0.1  # requested spacing between output samples (ms)

# Build the grid with linspace and an explicit sample count. np.arange with a
# non-integer step derives its length from ceil((stop - start) / step), so
# floating-point round-off can add or drop the last sample; linspace always
# starts at t_start and ends exactly at t_end. If (t_end - t_start) is not a
# whole multiple of dt, the actual spacing is adjusted slightly so that the
# grid still ends at t_end.
n_samples = int(round((t_end - t_start) / dt)) + 1
t = np.linspace(t_start, t_end, n_samples)

# --- Model Parameters ---
# Example time constants, all in milliseconds (ms). The two decay constants
# are deliberately far apart (10 ms vs 1000 ms) so that the fast and slow
# components are clearly distinguishable in the plot.
Tau_rise = 5.0       # Time constant of the auxiliary variable x (ms)
Tau_decay1 = 10.0    # FAST decay time constant (ms)
Tau_decay2 = 1000.0  # SLOW decay time constant (ms)

# Amplitude/weighting factors for each component.
# These determine how strongly the auxiliary variable x drives each pathway,
# and thus the relative contributions of g_fast and g_slow.
# Adjust these to change the shape of the total conductance.
A1 = 0.5  # Weight for the fast component
A2 = 0.5  # Weight for the slow component

# Initial conditions
x_init = 1.0       # Start auxiliary variable at 1 (simulates pulse/activation)
g_fast_init = 0.0  # Initial fast conductance component is zero
g_slow_init = 0.0  # Initial slow conductance component is zero
y0 = [x_init, g_fast_init, g_slow_init]

# Pack parameters into a tuple, in the order nmda_model_3var unpacks them
params = (Tau_rise, Tau_decay1, Tau_decay2, A1, A2)

# --- Run Integration ---
# odeint calls nmda_model_3var(y, t, params); the one-element tuple makes the
# whole parameter tuple arrive as a single extra argument.
sol = odeint(
    nmda_model_3var,
    y0,
    t,
    args=(params,),
)

# One column of `sol` per state variable, in the order given by y0;
# transposing lets all three time courses be unpacked in one statement.
x_t, g_fast_t, g_slow_t = sol.T

# Total conductance is the sum of the fast and slow components
g_NMDA_total_t = g_fast_t + g_slow_t

# --- Plotting ---
# Work on explicit Figure/Axes objects instead of the implicit pyplot state.
fig = plt.figure(figsize=(12, 7))
ax = fig.gca()

# Total conductance as the heavy black trace
ax.plot(t, g_NMDA_total_t, color='black', linewidth=2.5,
        label='$g_{NMDA}(t) = g_{fast} + g_{slow}$')

# Fast and slow components as dashed traces, labelled with their parameters
ax.plot(t, g_fast_t, color='red', linestyle='--',
        label=f'$g_{{fast}}(t)$ ($\\tau = {Tau_decay1}$ ms, A={A1})')
ax.plot(t, g_slow_t, color='blue', linestyle='--',
        label=f'$g_{{slow}}(t)$ ($\\tau = {Tau_decay2}$ ms, A={A2})')

# Uncomment to overlay the auxiliary variable x(t)
# ax.plot(t, x_t, label=f'$x(t)$ ($\\tau = {Tau_rise}$ ms)', linestyle=':', color='green')

# Title, axis labels, legend and grid
ax.set_title('NMDA Receptor Model Simulation (3-Variable)')
ax.set_xlabel('Time (ms)')
ax.set_ylabel('Conductance / Activation (arbitrary units)')
ax.legend(loc='upper right')
ax.grid(True)

# Axis limits: start slightly below zero so the baseline stays visible,
# and clip the x-axis to the simulated interval.
ax.set_ylim(bottom=-0.05)
ax.set_xlim(left=t_start, right=t_end)

# Summary of the parameters and initial conditions used for this run
param_text = (
    f'Parameters:\n'
    f'  $\\tau_{{rise}} = {Tau_rise}$ ms\n'
    f'  $\\tau_{{decay1}} = {Tau_decay1}$ ms, $A1 = {A1}$\n'
    f'  $\\tau_{{decay2}} = {Tau_decay2}$ ms, $A2 = {A2}$\n'
    f'Initial Conditions:\n'
    f'  $x(0) = {x_init}$\n'
    f'  $g_{{fast}}(0) = {g_fast_init}$\n'
    f'  $g_{{slow}}(0) = {g_slow_init}$'
)

# Text box positioned in axes-fraction coordinates, anchored at its
# top-right corner.
box_style = {'boxstyle': 'round,pad=0.5', 'fc': 'wheat', 'alpha': 0.5}
ax.text(
    0.95, 0.65, param_text,
    transform=ax.transAxes,
    fontsize=9,
    verticalalignment='top',
    horizontalalignment='right',
    bbox=box_style,
)


fig.tight_layout()
plt.show()
