225 lines
7.0 KiB
Python
225 lines
7.0 KiB
Python
import time
|
|
import numpy as np
|
|
import pylab as plt
|
|
import pandas as pd
|
|
import re
|
|
|
|
from scipy.optimize import least_squares
|
|
from Model import Model
|
|
from Data import Data
|
|
|
|
|
|
class Fitter:
    """Least-squares fit of model parameters to a measured current trace.

    The parameter vector used throughout is
    ``[gGaL, ECal, K_pc_half, tau_xfer, tau_RC, offset]``:

    * ``gGaL``, ``ECal``, ``K_pc_half`` and ``tau_xfer`` are written onto a
      fresh model instance (as ``gCaL``, ``ECaL``, ``K_pc_half`` and
      ``tau_xfer``) before the model is solved;
    * ``tau_RC`` is the time constant, in samples, of the exponential filter
      applied by :meth:`convolve_current` (one sample is ``dt`` = 1 ms by
      default);
    * ``offset`` is added to the filtered current.
    """

    # Default (lower, upper) bounds passed to ``least_squares``, in the
    # parameter-vector order documented above.
    DEFAULT_BOUNDS = (
        (0.01, 10, 0.1, 0.1, -5, -10),
        (10, 100, 100, 100, 10, 10),
    )

    # Set to True to show a diagnostic plot on every cost-function evaluation.
    debug_plot = False

    def __init__(
        self, model: "type[Model]", data: "Data", current_fit_range: tuple = (107, 341)
    ) -> None:
        """
        Parameters
        ----------
        model : type[Model]
            Model class (or zero-argument factory). It is called to create a
            fresh model instance whenever one is needed.
        data : Data
            Measured data. Must provide ``get_current_slice(times_in_s)``; the
            default initial guess in :meth:`optimize` also reads ``gGaL`` and
            ``ECal``, and the final plot reads ``current_t`` and ``current``.
        current_fit_range : tuple (t0, t1)
            Start and stop times (ms, inclusive) between which the current is
            fitted.
        """
        self.model = model
        self.data = data
        self.current_fit_range = current_fit_range
        self.fit_results = {}

        self.tspan = [0, 1000]
        self.dt = 1  # 1.0
        self.time_points = np.arange(*self.tspan, self.dt)

        # Use the injected model factory rather than the module-level Model.
        self.initial_values = self.model().get_initial_values()

        self.iteration = 0  # least squares iteration counter

        t0, t1 = current_fit_range
        # Boolean mask over ``time_points`` selecting the fit window.
        self.current_time_indecies = (t0 <= self.time_points) & (self.time_points <= t1)
        # Simulation runs in ms, but the data are recorded in seconds.
        self.measured_current = self.data.get_current_slice(
            self.time_points[self.current_time_indecies] / 1000
        )

    def convolve_current(self, current: np.ndarray, tau=1.5):
        """Filter ``current`` with a normalised one-sided exponential kernel.

        Parameters
        ----------
        current : np.ndarray
            1-D current trace.
        tau : float
            Decay constant of the kernel, in samples. ``tau > 0`` uses the
            kernel as built, ``tau < 0`` uses it time-reversed, and
            ``|tau| < 1e-8`` returns ``current`` unchanged (same object).

        Returns
        -------
        np.ndarray
            Filtered trace, same length as ``current`` (``np.convolve`` with
            ``mode="same"``).
        """
        if np.abs(tau) < 1e-8 or np.size(current) == 0:
            return current

        current = np.asarray(current)
        n = current.size
        half = n // 2
        kernel = np.zeros(n)
        # The decaying half holds n - half samples (half + 1 for odd n), so
        # odd-length traces no longer cause a shape mismatch.
        kernel[half:] = np.exp(-np.arange(n - half) / np.abs(tau))
        kernel /= kernel.sum()

        if tau < 0:
            kernel = kernel[::-1]
        return np.convolve(current, kernel, mode="same")

    def _simulate_current(self, parameters):
        # Solve a fresh model for ``parameters``; return (raw, filtered + offset)
        # currents, both sampled on ``self.time_points``.
        gGaL, ECal, K_pc_half, tau_xfer, tau_RC, offset = parameters

        model = self.model()
        model.ECaL = ECal
        model.gCaL = gGaL
        model.K_pc_half = K_pc_half
        model.tau_xfer = tau_xfer

        model.solve(
            times=self.time_points,
            tspan=self.tspan,
            dt=self.dt,
            initial_values=self.initial_values,
        )

        # The result is masked with ``current_time_indecies`` (defined over
        # ``time_points``), so the current must be computed on that full grid.
        t = self.time_points
        # NOTE(review): this passes the initial state rather than anything
        # produced by ``solve``; confirm this is what calculated_current expects.
        states = self.initial_values
        V = self.model.mem_potential(t)
        raw_current = model.calculated_current(states, t, V)
        filtered = self.convolve_current(raw_current, tau=tau_RC) + offset
        return raw_current, filtered

    def cost_func(self, parameters: np.ndarray):
        """Return the residuals ``measured - calculated`` over the fit window.

        ``parameters`` is ``[gGaL, ECal, K_pc_half, tau_xfer, tau_RC, offset]``.
        Each call increments ``self.iteration`` and prints the parameters and
        the mean squared error.
        """
        raw_current, calculated_full = self._simulate_current(parameters)
        calculated_current = calculated_full[self.current_time_indecies]

        res = self.measured_current - calculated_current
        err = np.mean(res**2)  # mean squared error
        self.iteration += 1
        print(self.iteration, parameters.tolist(), "err", err)

        if self.debug_plot:
            self._plot_iteration(raw_current, calculated_current)

        return res

    def _plot_iteration(self, raw_current, calculated_current):
        # Diagnostic plot of one cost evaluation over the fit window.
        t = self.time_points[self.current_time_indecies]
        plt.plot(t, raw_current[self.current_time_indecies], label="calculated current")
        plt.plot(t, self.measured_current, label="measured current")
        plt.plot(t, calculated_current, label="conv calculated current")
        plt.plot(t, self.measured_current - calculated_current, label="error")
        plt.xlabel("time, ms")
        plt.ylabel("current, pA/pF")
        plt.legend(frameon=False)
        plt.show()

    def _default_parameters(self):
        # Initial guess: gGaL/ECal from the data, K_pc_half/tau_xfer from a
        # fresh model, tau_RC = 1.5 and offset = 0.
        m = self.model()
        d = self.data
        tau_RC = 1.5
        offset = 0
        return np.array([d.gGaL, d.ECal, m.K_pc_half, m.tau_xfer, tau_RC, offset])

    def optimize(self, init_parameters=None, bounds=None):
        """Fit the parameters, store them in ``fit_results`` and plot the fit.

        Parameters
        ----------
        init_parameters : array-like of 6 floats, optional
            Starting point ``[gGaL, ECal, K_pc_half, tau_xfer, tau_RC, offset]``.
            Defaults to gGaL/ECal from the data, K_pc_half/tau_xfer from a
            fresh model, ``tau_RC = 1.5`` and ``offset = 0``.
        bounds : (lower, upper), optional
            Bounds passed to ``least_squares``; defaults to ``DEFAULT_BOUNDS``.

        Returns
        -------
        res : scipy.optimize.OptimizeResult
            Result returned by ``least_squares``.
        fig : matplotlib Figure
            Full measured/calculated trace (left) and the fit window (right).
        """
        start = time.perf_counter()
        self.iteration = 0
        if init_parameters is None:
            init_parameters = self._default_parameters()
            print(init_parameters.tolist())
        # Accept any array-like (e.g. a list) for the starting point.
        init_parameters = np.asarray(init_parameters, dtype=float)

        if bounds is None:
            bounds = self.DEFAULT_BOUNDS

        res = least_squares(self.cost_func, init_parameters, bounds=bounds, xtol=1e-10)
        print()
        print(" Parameters: [gGaL, ECal, K_pc_half, tau_xfer, tau_RC, offset]")
        print(" Initial:", init_parameters.tolist())
        print(" Optimized:", res.x.tolist())
        print(" Optim status:", res.status)
        print("Optim message:", res.message)

        gGaL, ECal, K_pc_half, tau_xfer, tau_RC, offset = res.x
        self.fit_results.update(
            {
                "gGaL": gGaL,
                "ECal": ECal,
                "K_pc_half": K_pc_half,
                "tau_xfer": tau_xfer,
                "tau_RC": tau_RC,
                "offset": offset,
                # res.fun holds the residuals at res.x; no extra model solve needed.
                "mean_squared_error": np.mean(res.fun**2),
            }
        )

        # Same simulation path as cost_func, so the plot shows what was fitted.
        _, calculated_current = self._simulate_current(res.x)
        print("Elapsed time:", time.perf_counter() - start)

        fig = self._plot_fit(calculated_current)
        return res, fig

    def _plot_fit(self, calculated_current):
        # Two-panel figure: full trace (left) and the fit window (right).
        # Labels are Estonian: "Mõõdetud" = measured, "Arvutatud" = calculated,
        # "Aeg" = time, "Vool" = current.
        fig = plt.figure(figsize=(6, 3))
        ax1 = fig.add_subplot(121)
        ax2 = fig.add_subplot(122)

        # Data times are in seconds; convert to ms to match time_points.
        ax1.plot(1000 * self.data.current_t, self.data.current, label="Mõõdetud")
        ax1.plot(self.time_points, calculated_current, label="Arvutatud")
        ax1.set_xlabel("Aeg [ms]")
        ax1.set_ylabel("Vool [pA/pF]")
        ax1.legend(frameon=False)

        tp = self.time_points[self.current_time_indecies]
        ax2.plot(tp, self.measured_current, label="Mõõdetud")
        ax2.plot(tp, calculated_current[self.current_time_indecies], label="Arvutatud")
        ax2.set_xlabel("Aeg [ms]")
        ax2.set_ylabel("Vool [pA/pF]")
        ax2.legend(frameon=False)
        return fig
|
|
|
|
def covcor_from_lsq(res):
    """Covariance and correlation matrices from a ``least_squares`` result.

    The covariance is the pseudo-inverse of ``J^T J`` assembled from the SVD
    of the Jacobian ``res.jac``. Singular values not above
    ``eps * max(J.shape) * s_max`` are discarded. No residual-variance factor
    is applied.

    Returns
    -------
    cov : np.ndarray
        Parameter covariance matrix.
    cor : np.ndarray
        Correlation matrix; entries where ``cov`` is exactly zero are 0.
    """
    jac = res.jac
    _, sing, vt = np.linalg.svd(jac, full_matrices=False)

    # Drop numerically-zero singular values (same cutoff as curve_fit).
    cutoff = np.finfo(float).eps * max(jac.shape) * sing[0]
    sing = sing[sing > cutoff]
    vt = vt[: sing.size]
    cov = (vt.T / sing**2) @ vt

    sigma = np.sqrt(cov.diagonal())
    cor = cov / (sigma[:, None] * sigma[None, :])
    cor = np.where(cov == 0, 0, cor)

    return cov, cor
|
|
|
|
def plot_correlation_matrix(cor):
    """Show ``cor`` as a heat map with a colour bar, then call ``plt.show()``."""
    image = plt.imshow(cor, cmap="viridis", interpolation="nearest")
    plt.colorbar(image, label="Correlation")

    plt.title("Correlation Matrix")
    plt.xlabel("Variables")
    plt.ylabel("Variables")

    plt.show()
|
|
|
|
|
|
if __name__ == "__main__":
    # Fit one recorded experiment and save the fitted parameters and figures.
    filename = "ltcc_current.h5"
    eid = "1c5ca4b12ae2ddffc3960c1fe39a3cce35967ce23dbac57c010f450e796d01fd_2017.11.27 14:07:04"

    data = Data(filename, eid)
    fit = Fitter(Model, data)
    res, fig = fit.optimize()

    # One-row table of the fitted parameters and final mean squared error.
    fit_hist = pd.DataFrame.from_dict(fit.fit_results, orient="index").T
    fit_hist.index.name = "Iterations"

    # CSV name: spaces become "_" and colons become "-".
    csv_safe = str.maketrans({" ": "_", ":": "-"})
    res_filename = f"fit_results_{eid}.csv".translate(csv_safe)
    fit_hist.to_csv(res_filename, index=True)

    # Plot names: remove every character that is not a word character, "." or "-".
    eid_cleaned = re.sub(r"[^\w.-]", "", eid)
    for plot_path in (f"plot_{eid_cleaned}.png", f"plot_{eid_cleaned}.pdf"):
        fig.savefig(plot_path)

    # Also save under a fixed example name (overwritten on every run).
    fig.savefig("naidis_fit.pdf")
    plt.show()
|