-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathsolution.py
179 lines (162 loc) · 5.86 KB
/
solution.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
# -*- coding: utf-8 -*-
# @Author: Theo Lemaire
# @Date: 2022-02-01 18:50:11
# @Last Modified by: Theo Lemaire
# @Last Modified time: 2022-02-11 14:30:56
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from model import Model
from logger import *
from constants import *
class Solution(pd.DataFrame):
    ''' Wrapper around pandas DataFrame containing a simulation solution.

    The TIME_MS column is used as x variable in all plots, the V_MV column is
    plotted as the voltage trace, and every other column is treated as a state
    variable (see `states`).
    '''

    # gid tag attached to stimulus spans so they can be found and replaced when
    # a plot is updated in place (instead of stacking new spans on old ones)
    _STIM_MARK_GID = 'stim_mark'

    def __init__(self, data, hard_ylims=False):
        ''' Constructor.

        :param data: solution data, passed as-is to the pandas DataFrame constructor
        :param hard_ylims: if True, y-axis limits given to the plotting methods are
            applied exactly; otherwise they only define a minimal range that is
            expanded to fit the data
        '''
        super().__init__(data)
        self.hard_ylims = hard_ylims

    @property
    def states(self):
        ''' State variable names, i.e. all columns except time and voltage.

        Names are returned in column order (duplicates removed). Unlike set
        order, this is reproducible, which keeps legend ordering stable. It also
        guarantees that in-place plot updates map each state onto the line it
        was originally drawn on.
        '''
        excluded = (TIME_MS, V_MV)
        return list(dict.fromkeys(k for k in self.columns if k not in excluded))

    def plot(self, *args, **kwargs):
        ''' Wrapper around pandas plot, using time as x variable by default.

        All arguments are forwarded to pandas' DataFrame plot accessor. ``x``
        defaults to TIME_MS unless it is given explicitly, either as a keyword
        or positionally (the first positional argument of pandas plot is ``x``).
        '''
        if not args:
            kwargs.setdefault('x', TIME_MS)
        return super().plot(*args, **kwargs)

    @staticmethod
    def set_soft_ylims(ax, my_ylims):
        ''' Set a minimum y-axis range, but expand it if data exceeds that range.

        :param ax: matplotlib axis
        :param my_ylims: (ymin, ymax) minimal range to display
        '''
        # Autoscale first so that get_ylim reflects the data extent
        ax.autoscale(True)
        ymin, ymax = ax.get_ylim()
        ax.set_ylim(min(my_ylims[0], ymin), max(my_ylims[1], ymax))

    def set_ylims(self, ax, my_ylims):
        ''' Set y-axis limits, either exactly (hard_ylims) or as a minimal range.

        :param ax: matplotlib axis
        :param my_ylims: (ymin, ymax) y-axis limits
        '''
        if self.hard_ylims:
            ax.set_ylim(*my_ylims)
        else:
            self.set_soft_ylims(ax, my_ylims)

    @staticmethod
    def update_axis(ax):
        ''' Recompute data limits from the axis artists and rescale the view. '''
        ax.relim()
        ax.autoscale_view()

    @staticmethod
    def add_stim_mark(stim, ax):
        ''' Add stimulus marks on plots.

        Shaded spans are drawn from each ``stim.t_OFF_ON()`` time to the matching
        ``stim.t_ON_OFF()`` time. Marks previously added by this method on the
        same axis are removed first, so repeated calls (e.g. in-place plot
        updates) do not stack semi-transparent spans on top of each other.

        :param stim: stimulus object exposing t_OFF_ON() and t_ON_OFF()
        :param ax: matplotlib axis
        '''
        # Iterate over a copy: removing a patch mutates ax.patches
        for patch in list(ax.patches):
            if patch.get_gid() == Solution._STIM_MARK_GID:
                patch.remove()
        for t1, t2 in zip(stim.t_OFF_ON(), stim.t_ON_OFF()):
            ax.axvspan(t1, t2, color='silver', alpha=0.3, gid=Solution._STIM_MARK_GID)

    @staticmethod
    def _figure_and_axis(ax=None):
        ''' Return (figure, axis): a new despined figure with time x-label if ax is None,
        otherwise the given axis and its parent figure. '''
        if ax is None:
            fig, ax = plt.subplots()
            sns.despine(ax=ax)
            ax.set_xlabel(TIME_MS)
        else:
            fig = ax.get_figure()
        return fig, ax

    def plot_var(self, key, ax=None, stim=None, update=False, redraw=True, ylims=None):
        ''' Plot variable time course.

        :param key: name of the column to plot against time
        :param ax: axis to plot on (a new figure is created if None)
        :param stim: optional stimulus object, used to add stimulus marks
        :param update: if True, update the first existing line of ax in place
            instead of drawing a new one
        :param redraw: whether to redraw the canvas after an update
        :param ylims: optional (ymin, ymax) y-axis limits (see set_ylims)
        :return: figure handle
        '''
        fig, ax = self._figure_and_axis(ax)
        t = self[TIME_MS]
        if update:
            # set_data (rather than set_ydata) also supports a new time vector length
            ax.get_lines()[0].set_data(t, self[key])
            self.update_axis(ax)
        else:
            ax.set_ylabel(key)
            ax.plot(t, self[key], c='k')
        if stim is not None:
            self.add_stim_mark(stim, ax)
        if ylims is not None:
            self.set_ylims(ax, ylims)
        if update and redraw:
            fig.canvas.draw()
        return fig

    def plot_voltage(self, *args, **kwargs):
        ''' Plot solution voltage time course (plot_var on V_MV with V_LIMS y-limits). '''
        return self.plot_var(V_MV, *args, ylims=V_LIMS, **kwargs)

    def plot_states(self, ax=None, stim=None, update=False, redraw=True):
        ''' Plot solution states time course.

        :param ax: axis to plot on (a new figure is created if None)
        :param stim: optional stimulus object, used to add stimulus marks
        :param update: if True, update existing lines of ax in place, pairing them
            with states in `states` order
        :param redraw: whether to redraw the canvas after an update
        :return: figure handle
        '''
        fig, ax = self._figure_and_axis(ax)
        t = self[TIME_MS]
        states = self.states
        if update:
            for k, line in zip(states, ax.get_lines()):
                line.set_data(t, self[k])
            self.update_axis(ax)
        else:
            ax.set_ylabel('states')
            for k in states:
                ax.plot(t, self[k], label=k)
            ax.legend()
        if stim is not None:
            self.add_stim_mark(stim, ax)
        self.set_ylims(ax, STATES_LIMS)
        if update and redraw:
            fig.canvas.draw()
        return fig

    def plot_currents(self, cfuncs, stim=None, ax=None, update=False, redraw=True):
        ''' Plot solution currents time course.

        :param cfuncs: Model instance (its compute_currents method is used) or a
            callable ``(Vm, solution) -> {name: current}`` mapping. Current values
            are presumably pandas Series aligned on the solution index (they are
            combined with pd.concat) - TODO confirm
        :param stim: optional stimulus object; its current profile is added to
            the capacitive current, plotted as 'i_stim', and marked on the axis
        :param ax: axis to plot on (a new figure is created if None)
        :param update: if True, update existing lines of ax in place (currents in
            mapping order, then 'i_cap', then 'i_stim' as the last line)
        :param redraw: whether to redraw the canvas after an update
        :return: figure handle
        '''
        fig, ax = self._figure_and_axis(ax)
        if isinstance(cfuncs, Model):
            cfuncs = cfuncs.compute_currents
        t = self[TIME_MS]
        raw_currents = cfuncs(self[V_MV], self)
        # Copy, so that adding 'i_cap' below never mutates a mapping owned by the caller
        currents = dict(raw_currents) if raw_currents else {}

        # Capacitive current = -(sum of ionic currents) [+ stimulus current]
        if currents:
            i_cap = -pd.concat(list(currents.values()), axis=1).sum(axis=1)
        else:
            # A zero vector (not a scalar) so that i_cap can still be plotted against t
            i_cap = pd.Series(0., index=self.index)
        if stim is not None:
            tstim, Istim = stim.stim_profile(tstop=t.values[-1])
            i_cap = i_cap + np.interp(t, tstim, Istim)
        currents['i_cap'] = i_cap

        if update:
            for v, line in zip(currents.values(), ax.get_lines()):
                line.set_data(t, v)
        else:
            ax.set_ylabel(CURRENT_DENSITY)
            colors = plt.get_cmap('Dark2').colors
            for i, (k, v) in enumerate(currents.items()):
                # Cycle through the palette: zipping with it would silently drop
                # every current beyond the number of palette colors
                ax.plot(t, v, label=k, c=colors[i % len(colors)])
        if stim is not None:
            if update:
                ax.get_lines()[-1].set_data(tstim, Istim)
            else:
                ax.plot(tstim, Istim, label='i_stim', c='k')
            self.add_stim_mark(stim, ax)
        if update:
            self.update_axis(ax)
        else:
            ax.legend()
        self.set_ylims(ax, I_LIMS)
        if update and redraw:
            fig.canvas.draw()
        return fig

    def plot_all(self, cfuncs, fig=None, **kwargs):
        ''' Plot all time courses (voltage, states & currents) from a solution.

        :param cfuncs: currents function or Model, forwarded to plot_currents
        :param fig: existing figure (from a previous call) whose axes are updated
            in place; a new figure is created if None
        :param kwargs: extra keyword arguments forwarded to every sub-plot method
            (e.g. stim)
        :return: figure handle
        '''
        has_states = bool(self.states)
        naxes = 3 if has_states else 2
        if fig is None:
            fig, axes = plt.subplots(naxes, 1, figsize=(7, 2 * naxes))
            update = False
        else:
            axes = fig.axes
            update = True
        iax = 0
        self.plot_voltage(ax=axes[iax], update=update, redraw=False, **kwargs)
        iax += 1
        if has_states:
            self.plot_states(ax=axes[iax], update=update, redraw=False, **kwargs)
            iax += 1
        self.plot_currents(cfuncs, ax=axes[iax], update=update, redraw=False, **kwargs)
        if not update:
            for ax in axes:
                sns.despine(ax=ax)
            # Stacked axes share the time axis: only label the bottom one
            for ax in axes[:-1]:
                ax.set_xticks([])
            axes[-1].set_xlabel(TIME_MS)
        else:
            # Sub-plots were updated with redraw=False: draw once at the end
            fig.canvas.draw()
        return fig