# Calcium_Model/fitter.py
# Fits the calcium model's calculated current to a measured current trace.

import time
import numpy as np
import pylab as plt
import pandas as pd
import re
from scipy.optimize import least_squares
from model import Model
from Data import Data
class Fitter:
    """Fit a model's calculated current to a measured current trace.

    Six parameters are fitted with ``scipy.optimize.least_squares``, in the
    order given by ``PARAMETER_NAMES``: ``[gGaL, ECal, K_pc_half, tau_xfer,
    tau_RC, offset]``. The first four are written onto a fresh model instance
    (as ``gCaL``, ``ECaL``, ``K_pc_half``, ``tau_xfer``). ``tau_RC`` is the time
    constant of the exponential filter in :meth:`convolve_current`, applied to
    the model's current, and ``offset`` is added to the filtered result.
    """

    PARAMETER_NAMES = ("gGaL", "ECal", "K_pc_half", "tau_xfer", "tau_RC", "offset")

    # Default (lower, upper) bounds, one entry per name in PARAMETER_NAMES.
    # NOTE(review): the original tuples had 7 entries for 6 parameters, which
    # made least_squares raise. The surplus (0.1, 100) pair in 5th position
    # was dropped, so tau_xfer gets (0.1, 1), tau_RC (-5, 10) and offset
    # (-10, 10). Confirm this mapping, and that the model's default tau_xfer
    # lies inside it (least_squares rejects an infeasible starting point).
    DEFAULT_BOUNDS = (
        (0.01, 10, 0.1, 0.1, -5, -10),
        (10, 100, 100, 1, 10, 10),
    )

    # cost_func shows diagnostic plots for the first N iterations (0 = never).
    debug_plot_iterations = 0

    def __init__(self, model: "Model", data: "Data", current_fit_range: tuple = (107, 341)) -> None:
        """
        model : the model class, not an instance. It is called with no
            arguments to create a fresh instance for every evaluation.
        data : measured data object. Its get_current_slice(times_in_seconds)
            provides the measured current on the fit window.
        current_fit_range : tuple (t0, t1), start and stop times (ms,
            inclusive) between which the current is fitted.
        """
        self.model = model
        self.data = data
        self.current_fit_range = current_fit_range
        self.fit_results = {}
        self.tspan = [0, 1000]
        self.dt = 1  # 1.0
        self.time_points = np.arange(*self.tspan, self.dt)
        self.iteration = 0  # least squares iteration counter
        t0, t1 = current_fit_range
        self.current_time_indecies = (t0 <= self.time_points) & (self.time_points <= t1)
        # The model is calculated in ms but the data are recorded in seconds.
        self.measured_current = self.data.get_current_slice(
            self.time_points[self.current_time_indecies] / 1000
        )

    def convolve_current(self, current: np.ndarray, tau=1.5):
        """Smooth ``current`` with a normalised one-sided exponential kernel.

        The kernel decays as exp(-k / |tau|), with k the sample index. It
        covers the second half of a kernel array as long as ``current`` and
        sums to 1.
        - tau > 0: causal (trailing) filter.
        - tau < 0: the mirrored, anti-causal filter.
        - |tau| < 1e-8: ``current`` is returned unchanged.

        NOTE(review): with np.convolve(mode="same"), the causal branch lags by
        one extra sample for even-length input; the anti-causal branch does
        not. This is kept as-is so existing fits are reproduced.
        """
        if np.abs(tau) < 1e-8 or current.size == 0:
            return current
        n = current.size
        half = n // 2
        kernel = np.zeros(n)
        # The tail must hold n - half samples; the former np.arange(half)
        # was one short for odd n and raised a shape mismatch.
        kernel[half:] = np.exp(-np.arange(n - half) / np.abs(tau))
        kernel /= kernel.sum()
        if tau < 0:
            kernel = kernel[::-1]
        return np.convolve(current, kernel, mode="same")

    def _simulate_current(self, parameters):
        """Solve a fresh model with ``parameters``.

        Returns (raw, filtered) currents over all of ``self.time_points``.
        ``filtered`` is ``raw`` passed through convolve_current(tau=tau_RC),
        plus ``offset``.
        """
        gGaL, ECal, K_pc_half, tau_xfer, tau_RC, offset = parameters
        model = self.model()
        model.ECaL = ECal
        model.gCaL = gGaL
        model.K_pc_half = K_pc_half
        model.tau_xfer = tau_xfer
        model.solve(times=self.time_points)
        raw = model.calculated_current()
        return raw, self.convolve_current(raw, tau=tau_RC) + offset

    def _plot_iteration(self, raw_current, calculated_current):
        """Show raw, filtered and measured current plus the residual on the fit window."""
        t = self.time_points[self.current_time_indecies]
        plt.plot(t, raw_current[self.current_time_indecies], label="calculated current")
        plt.plot(t, self.measured_current, label="measured current")
        plt.plot(t, calculated_current, label="conv calculated current")
        plt.plot(t, self.measured_current - calculated_current, label="error")
        plt.xlabel("time, ms")
        plt.ylabel("current, pA/pF")
        plt.legend(frameon=False)
        plt.show()

    def cost_func(self, parameters: np.ndarray):
        """Return residuals (measured - calculated) on the fit window.

        This is the objective passed to least_squares. It prints the
        iteration number, the parameters and the mean squared error on every
        call.
        """
        raw, filtered = self._simulate_current(parameters)
        calculated_current = filtered[self.current_time_indecies]
        res = self.measured_current - calculated_current
        err = np.mean(res**2)  # mean squared error
        self.iteration += 1
        print(self.iteration, parameters.tolist(), "err", err)
        if self.iteration <= self.debug_plot_iterations:
            self._plot_iteration(raw, calculated_current)
        return res

    def _plot_fit(self, calculated_current):
        """Return a figure: whole recording vs. model (left), fit window (right)."""
        fig = plt.figure(figsize=(24, 12))
        ax1 = fig.add_subplot(121)
        ax2 = fig.add_subplot(122)
        # The data time axis is in seconds; convert to ms to match the model.
        ax1.plot(1000 * self.data.current_t, self.data.current, label="Measured")
        ax1.plot(self.time_points, calculated_current, label="Calculated")
        tp = self.time_points[self.current_time_indecies]
        ax2.plot(tp, self.measured_current, label="Measured")
        ax2.plot(tp, calculated_current[self.current_time_indecies], label="Calculated")
        for ax in (ax1, ax2):
            ax.set_xlabel("time, ms")
            ax.set_ylabel("current, pA/pF")
            ax.legend(frameon=False)
        return fig

    def optimize(self, init_parameters=None, bounds=None):
        """Fit the parameters with least_squares and plot the result.

        init_parameters : 6 values ordered as PARAMETER_NAMES, or None. None
            starts from the data's gGaL/ECal, the model's default
            K_pc_half/tau_xfer, tau_RC=1.5 and offset=0.
        bounds : (lower, upper) pair of 6-element sequences. Defaults to
            DEFAULT_BOUNDS.

        Returns (res, fig): the least_squares result, and a figure of the
        fitted current over the whole trace and over the fit window. The
        fitted values and the final mean squared error are also stored in
        ``self.fit_results``.
        """
        t0 = time.time()
        self.iteration = 0
        if init_parameters is None:
            m = self.model()
            d = self.data
            init_parameters = [d.gGaL, d.ECal, m.K_pc_half, m.tau_xfer, 1.5, 0]
        # Accept any sequence; the prints below need an ndarray.
        init_parameters = np.asarray(init_parameters, dtype=float)
        print(init_parameters.tolist())
        if bounds is None:
            bounds = self.DEFAULT_BOUNDS
        res = least_squares(self.cost_func, init_parameters, bounds=bounds, xtol=1e-10)
        print()
        print(" Parameters: [gGaL, ECal, K_pc_half, tau_xfer, tau_RC, offset]")
        print(" Initial:", init_parameters.tolist())
        print(" Optimized:", res.x.tolist())
        print(" Optim status:", res.status)
        print("Optim message:", res.message)
        # res.fun holds the residuals at the solution (the old code referenced
        # an undefined `err` here).
        mean_squared_error = float(np.mean(res.fun**2))
        self.fit_results.update(dict(zip(self.PARAMETER_NAMES, res.x.tolist())))
        self.fit_results["mean_squared_error"] = mean_squared_error
        _, calculated_current = self._simulate_current(res.x)
        print("Elapsed time:", time.time() - t0)
        fig = self._plot_fit(calculated_current)
        return res, fig
def covcor_from_lsq(res):
    """Estimate parameter covariance and correlation from a least_squares result.

    Uses the Jacobian at the solution, ``res.jac``: cov = (J^T J)^+. This is
    computed from the SVD of J, discarding singular values at or below
    eps * max(J.shape) * s_max (the approach scipy.optimize.curve_fit takes).
    The covariance is not scaled by the residual variance.

    Returns (cov, cor), both square with one row/column per parameter.
    Correlation entries whose covariance is exactly zero are set to 0, not NaN.
    """
    jac = np.asarray(res.jac, dtype=float)
    # `svd` was previously used without being imported; numpy's works here.
    _, s, VT = np.linalg.svd(jac, full_matrices=False)
    threshold = np.finfo(float).eps * max(jac.shape) * s[0]
    s = s[s > threshold]
    VT = VT[: s.size]
    cov = np.dot(VT.T / s**2, VT)
    std = np.sqrt(np.diag(cov))
    # Zero variances give 0/0 here; those entries are reset just below.
    with np.errstate(divide="ignore", invalid="ignore"):
        cor = cov / np.outer(std, std)
    cor[cov == 0] = 0
    return cov, cor
def plot_correlation_matrix(cor, labels=None, cmap='viridis'):
    """Show a correlation matrix as a heat map and display it with plt.show().

    cor : square 2-D array of correlation coefficients, e.g. the second value
        returned by covcor_from_lsq.
    labels : optional sequence of variable names, used as tick labels on both
        axes. Defaults to matplotlib's index ticks.
    cmap : matplotlib colormap name.

    The colour scale is pinned to [-1, 1], so a colour always means the same
    correlation. With matplotlib's default autoscaling, the colormap extremes
    would only track the data's own min/max.
    """
    plt.imshow(cor, cmap=cmap, interpolation='nearest', vmin=-1, vmax=1)
    plt.colorbar(label='Correlation')
    if labels is not None:
        ticks = range(len(labels))
        plt.xticks(ticks, labels, rotation=45, ha='right')
        plt.yticks(ticks, labels)
    plt.title('Correlation Matrix')
    plt.xlabel('Variables')
    plt.ylabel('Variables')
    plt.show()
if __name__ == "__main__":
    filename = "ltcc_current.h5"
    eid = "0033635a51b096dc449eb9964e70443a67fc16b9587ae3ff6564eea1fa0e3437_2018.06.18 14:48:40"
    data = Data(filename, eid)
    fit = Fitter(Model, data)
    # Run the fit. This fills fit.fit_results and returns the figure saved
    # below; previously it was never called, leaving `fig` undefined.
    res, fig = fit.optimize()

    # One-row table holding the fitted parameters and the mean squared error.
    fit_hist = pd.DataFrame.from_dict(fit.fit_results, orient='index').T
    fit_hist.index.name = 'Iterations'
    res_filename = f"fit_results_{eid}.csv"
    # The experiment id contains a timestamp; spaces and colons are not
    # filename-safe on every platform.
    res_filename = res_filename.replace(" ", "_").replace(":", "-")
    fit_hist.to_csv(res_filename, index=True)

    # Remove everything except word characters (letters, digits, underscore),
    # dots and hyphens.
    eid_cleaned = re.sub(r'[^\w.-]', '', eid)
    fig.savefig(f"plot_{eid_cleaned}.png")
    fig.savefig(f"plot_{eid_cleaned}.pdf")
    fig.savefig("naidis_fit.pdf")
    plt.show()