Skip to content

Implementation of 6K model #8

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
255 changes: 253 additions & 2 deletions pyrho/fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,249 @@ def solveGo(tlag, Gd, Go0=1000, tol=1e-9):



def fit6Kstates(fluxSet, quickSet, run, vInd, params, method=defMethod): # , verbose=config.verbose):
"""
fluxSet := ProtocolData set (of Photocurrent objects) to fit
quickSet:= ProtocolData set (of Photocurrent objects) with short pulses to fit opsin activation rates
run := Index for the run within the ProtocolData set
vInd := Index for Voltage clamp value within the ProtocolData set
params := Parameters object of model parameters with initial values [and bounds, expressions]
method := Fitting algorithm for the optimiser to use
"""
# verbose := Text output (verbosity) level


plotResult = bool(config.verbose > 1)

nStates = '6K'

### Prepare the data
nRuns = fluxSet.nRuns
nPhis = fluxSet.nPhis
nVs = fluxSet.nVs

assert(0 < nPhis)
assert(0 <= run < nRuns)
assert(0 <= vInd < nVs)

Ions = [None for phiInd in range(nPhis)]
Ioffs = [None for phiInd in range(nPhis)]
tons = [None for phiInd in range(nPhis)]
toffs = [None for phiInd in range(nPhis)]
phis = []
Is = []
ts = []
Vs = []

Icycles = []
nfs = [] # Normalisation factors: e.g. /Ions[trial][-1] or /min(Ions[trial])


# Trim off phase data
#frac = 1
#chop = int(round(len(Ioffs[0])*frac))

for phiInd in range(nPhis):
targetPC = fluxSet.trials[run][phiInd][vInd]
#targetPC.alignToTime()
I = targetPC.I
t = targetPC.t
onInd = targetPC._idx_pulses_[0,0] ### Consider multiple pulse scenarios
offInd = targetPC._idx_pulses_[0,1]
Ions[phiInd] = I[onInd:offInd+1]
Ioffs[phiInd] = I[offInd:] #[I[offInd:] for I in Is]
tons[phiInd] = t[onInd:offInd+1]-t[onInd]
toffs[phiInd] = t[offInd:]-t[offInd] #[t[offInd:]-t[offInd] for t in ts]
#args=(Ioffs[phiInd][:chop+1],toffs[phiInd][:chop+1])
phi = targetPC.phi
phis.append(phi)

Is.append(I)
ts.append(t)
V = targetPC.V
Vs.append(V)

Icycles.append(I[onInd:])
nfs.append(I[offInd])
#nfs.append(targetPC.I_peak_)


### OFF PHASE
### 3a. OFF CURVE: Fit biexponential to off curve to find lambdas

OffKeys = ['Gd1', 'Gd2', 'Gf0', 'Ga3']

iOffPs = Parameters() # Create parameter dictionary
for k in OffKeys:
copyParam(k, params, iOffPs)

### Trim the first 10% of the off curve to allow I1 and I2 to empty?


### This is an approximation based on the 4-state model which ignores the effects of Go1 and Go2 after light off.

# lam1 + lam2 == Gd1 + Gd2 + Gf0 + Gb0
# lam1 * lam2 == Gd1*Gd2 + Gd1*Gb0 + Gd2*Gf0

#, Gd2, Gf0, Gb0: (Gd1 + Gd2 + Gf0 + Gb0)/2
#calcC = lambda b, Gd1, Gd2, Gf0, Gb0: np.sqrt(b**2 - (Gd1*Gd2 + Gd1*Gb0 + Gd2*Gf0))

def lams(p):

Gd1 = p['Gd1'].value
Gd2 = p['Gd2'].value
Ga3 = p['Ga3'].value

lam1 = Gd1
lam2 = (Gd2 + Ga3)
return lam1, lam2

# Create dummy parameters for each phi
for phiInd in range(nPhis):
Iss = Ioffs[phiInd][0]
if Iss < 0:
iOffPs.add('Islow_'+str(phiInd), value=0.2*Iss, vary=True, max=0)
iOffPs.add('Ifast_'+str(phiInd), value=0.8*Iss, vary=True, max=0, expr='{} - {}'.format(Iss, 'Islow_'+str(phiInd)))
else:
iOffPs.add('Islow_'+str(phiInd), value=0.2*Iss, vary=True, min=0)
iOffPs.add('Ifast_'+str(phiInd), value=0.8*Iss, vary=True, min=0, expr='{} - {}'.format(Iss, 'Islow_'+str(phiInd)))

def fit6Koff(p,t,trial):
Islow = p['Islow_'+str(trial)].value
Ifast = p['Ifast_'+str(trial)].value
lam1, lam2 = lams(p)
return Islow*np.exp(-lam1*t) + Ifast*np.exp(-lam2*t)

def err6Koff(p,Ioffs,toffs):
"""Normalise by the first element of the off-curve""" # [-1]
return np.r_[ [(Ioffs[i] - fit6Koff(p,toffs[i],i))/Ioffs[i][0] for i in range(len(Ioffs))] ]

#fitfunc = lambda p, t: -(p['a0'].value + p['a1'].value*np.exp(-lams(p)[0]*t) + p['a2'].value*np.exp(-lams(p)[1]*t))
##fitfunc = lambda p, t: -(p['a0'].value + p['a1'].value*np.exp(-p['lam1'].value*t) + p['a2'].value*np.exp(-p['lam2'].value*t))
#errfunc = lambda p, Ioff, toff: Ioff - fitfunc(p,toff)

offPmin = minimize(err6Koff, iOffPs, args=(Ioffs,toffs), method=method)#, fit_kws={'maxfun':100000})
pOffs = offPmin.params

reportFit(offPmin, "Off-phase fit report for the 6K-state model", method)
if config.verbose > 0:
print('Gd1 = {}; Gd2 = {}; Gf0 = {}'.format(pOffs['Gd1'].value, pOffs['Gd2'].value,
pOffs['Gf0'].value))

if plotResult:
lam1, lam2 = lams(pOffs)
plotOffPhaseFits(toffs, Ioffs, pOffs, phis, nStates, fit6Koff, lam1, lam2, Gd=None)


# Fix off-curve parameters
for k in OffKeys:
pOffs[k].vary = False


### Calculate Go (1/tau_opsin)
print('\nCalculating opsin activation rate')
# Assume that Gd1 > Gd2
# Assume that Gd = Gd1 for short pulses

def solveGo(tlag, Gd, Go0=1000, tol=1e-9):
Go, Go_m1 = Go0, 0
while abs(Go_m1 - Go) > tol:
Go_m1 = Go
Go = ((tlag*Gd) - np.log(Gd/Go_m1))/tlag
#Go_m1, Go = Go, ((tlag*Gd) - np.log(Gd/Go_m1))/tlag
return Go

#if 'shortPulse' in dataSet: # Fit Go
if quickSet.nRuns > 1:
#from scipy.optimize import curve_fit
# Fit tpeak = tpulse + tmaxatp0 * np.exp(-k*tpulse)
#dataSet['shortPulse'].getProtPeaks()
#tpeaks = dataSet['shortPulse'].IrunPeaks

#PD = dataSet['shortPulse']
PCs = [quickSet.trials[p][0][0] for p in range(quickSet.nRuns)] # Aligned to the pulse i.e. t_on = 0
#[pc.alignToTime() for pc in PCs]

#tpeaks = np.asarray([PD.trials[p][0][0].tpeak for p in range(PD.nRuns)]) # - PD.trials[p][0][0].t[0]
#tpulses = np.asarray([PD.trials[p][0][0].Dt_ons[0] for p in range(PD.nRuns)])
tpeaks = np.asarray([pc.t_peak_ for pc in PCs])
tpulses = np.asarray([pc.Dt_ons_[0] for pc in PCs])

devFunc = lambda tpulses, t0, k: tpulses + t0 * np.exp(-k*tpulses)
p0 = (0, 1)
popt, pcov = curve_fit(devFunc, tpulses, tpeaks, p0=p0)
if plotResult:
fig = plt.figure()
ax = fig.add_subplot(111, aspect='equal')
nPoints = 10*int(round(max(tpulses))+1) # 101
tsmooth = np.linspace(0, max(tpulses), nPoints)
ax.plot(tpulses, tpeaks, 'x')
ax.plot(tsmooth, devFunc(tsmooth, *popt))
ax.plot(tsmooth, tsmooth, '--')
ax.set_ylim([0, max(tpulses)]) #+5
ax.set_xlim([0, max(tpulses)]) #+5
#plt.tight_layout()
#plt.axis('equal')
plt.show()

# Solve iteratively Go = ((tlag*Gd) - np.log(Gd/Go))/tlag
Gd1 = pOffs['Gd1'].value
Go = solveGo(tlag=popt[0], Gd=Gd1, Go0=1000, tol=1e-9)
print('t_lag = {:.3g}; Gd = {:.3g} --> Go = {:.3g}'.format(popt[0], Gd1, Go))

elif quickSet.nRuns == 1: #'delta' in dataSet:
#PD = dataSet['delta']
#PCs = [PD.trials[p][0][0] for p in range(PD.nRuns)]
PC = quickSet.trials[0][0][0]
tlag = PC.Dt_lag_ # := Dt_lags_[0] ############################### Add to Photocurrent...
Go = solveGo(tlag=tlag, Gd=Gd1, Go0=1000, tol=1e-9)
print('t_lag = {:.3g}; Gd = {:.3g} --> Go = {:.3g}'.format(tlag, Gd1, Go))

else:
Go = 1 # Default
print('No data found to estimate Go: defaulting to Go = {}'.format(Go))


### ON PHASE

iOnPs = Parameters() # deepcopy(params)

# Set parameters from Off-curve optimisation
for k in OffKeys:
copyParam(k, pOffs, iOnPs)

# Set parameters from general rhodopsin analysis routines
for k in ['Go1', 'Go2', 'k1', 'k2', 'k3', 'k_f', 'k_b', 'gam', 'p', 'q', 'phi_m', 'g0', 'Gb', 'E', 'v0', 'v1']: #.extend(OffKeys):
copyParam(k, params, iOnPs)

# Set parameters from short pulse calculations
iOnPs['Go1'].value = Go; iOnPs['Go1'].vary = False
iOnPs['Go2'].value = Go; iOnPs['Go2'].vary = False

RhO = models['6K']()

### Trim down ton? Take 10% of data or one point every ms? ==> [0::5]

if config.verbose > 2:
print('Optimising ',end='')

onPmin = minimize(errOnPhase, iOnPs, args=(Ions,tons,RhO,Vs,phis), method=method)
pOns = onPmin.params

reportFit(onPmin, "On-phase fit report for the 6K-state model", method)

if config.verbose > 0:
print('k1 = {}; k2 = {}; k_f = {}; k_b = {}'.format(pOns['k1'].value, pOns['k2'].value,
pOns['k_f'].value, pOns['k_b'].value))
print('gam = {}; phi_m = {}; p = {}; q = {}'.format(pOns['gam'].value, pOns['phi_m'].value,
pOns['p'].value, pOns['q'].value))

fitParams = pOns

return fitParams, onPmin




#TODO: Tidy up and refactor getRecoveryPeaks and fitRecovery
def getRecoveryPeaks(recData, phiInd=None, vInd=None, usePeakTime=False):
Expand Down Expand Up @@ -1769,11 +2012,14 @@ def fitModel(dataSet, nStates='3', params=None, postFitOpt=True, relaxFact=2, me
"""Fit a model (with initial parameters) to a dataset of optogenetic photocurrents."""

### Define non-optimised parameters to exclude in post-fit optimisation
nonOptParams = ['Gr0', 'E', 'v0', 'v1']

if not isinstance(nStates, str):
nStates = str(nStates) # .lower()

if nStates == '3' or nStates == '4' or nStates == '6':
nonOptParams = ['Gr0', 'E', 'v0', 'v1']
elif nStates == '6K':
nonOptParams = ['E', 'v0', 'v1', 'Ga3']

if nStates not in modelParams:
print(f"Error in selecting model {nStates} - please choose from {list(modelParams)} states")
raise NotImplementedError(nStates)
Expand Down Expand Up @@ -2000,6 +2246,11 @@ def fitModel(dataSet, nStates='3', params=None, postFitOpt=True, relaxFact=2, me
constrainedParams = ['Gd1', 'Gd2', 'Gf0', 'Gb0', 'Go1', 'Go2']
#constrainedParams = ['Go1', 'Go2', 'Gf0', 'Gb0']
#nonOptParams.append(['Gd1', 'Gd2'])
elif nStates == '6K':
fittedParams, miniObj = fit6Kstates(setPC, quickSet, runInd, vIndm70, fitParams, method) # , verbose)
constrainedParams = ['Gd1', 'Gd2', 'Gf0', 'Go1', 'Go2', 'Gb']
#constrainedParams = ['Go1', 'Go2', 'Gf0', 'Gb0']
#nonOptParams.append(['Gd1', 'Gd2'])
else:
raise Exception(f'Invalid choice for nStates: {nStates}!')

Expand Down
Loading