import numpy as np
import pylab as plt
from scipy.signal import convolve
from scipy.fftpack import fftn, ifftn, ifftshift
from scipy.ndimage import gaussian_filter
from scipy.optimize import least_squares

import h5py  # for saving data
from sewar.full_ref import msssim

from skimage.metrics import structural_similarity as ssim


def _generate_bead(shape: np.ndarray, voxel_size: np.ndarray, bead_diameter: float):
    """Return a binary sphere of diameter `bead_diameter`, centred in an array of `shape`."""
    shape = np.array(shape)
    pz, py, px = voxel_size * 1e6  # voxel size scaled to the same unit as bead_diameter (micrometres, assuming metres in)
    midpoint = shape // 2

    z_shape, y_shape, x_shape = shape
    x_axis = np.arange(x_shape)
    y_axis = np.arange(y_shape)
    z_axis = np.arange(z_shape)

    # index grids in (z, y, x) order
    grid_z, grid_y, grid_x = np.meshgrid(z_axis, y_axis, x_axis, indexing="ij")

    # squared physical distance of every voxel from the array centre
    dist_sq = (px * (grid_x - midpoint[2])) ** 2 + (py * (grid_y - midpoint[1])) ** 2 + (pz * (grid_z - midpoint[0])) ** 2

    return (dist_sq <= bead_diameter**2 / 4).astype(float)


def generate_bead(shape: np.ndarray, voxel_size: np.ndarray, bead_outer_diameter: float, bead_inner_diameter: float = 0):
    """Generate a normalised bead.

    A negative `bead_inner_diameter` subtracts the core (hollow shell); a positive one
    adds it on top of the outer sphere; zero leaves a solid bead.
    """
    core = _generate_bead(shape, voxel_size, np.abs(bead_inner_diameter))
    bead = _generate_bead(shape, voxel_size, bead_outer_diameter) + np.sign(bead_inner_diameter) * core
    return bead / bead.sum()
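

# A minimal usage sketch (illustrative only, not part of the original module): the grid
# size, voxel size and diameters below are assumptions. With the 1e6 scaling used in
# _generate_bead, voxel_size is given in metres and diameters in micrometres.
def _example_beads():
    voxel_size = np.array([50e-9, 50e-9, 50e-9])  # 50 nm isotropic voxels
    solid = generate_bead((64, 64, 64), voxel_size, bead_outer_diameter=0.2)
    # a negative inner diameter subtracts the core, giving a hollow shell
    shell = generate_bead((64, 64, 64), voxel_size, bead_outer_diameter=0.2, bead_inner_diameter=-0.1)
    return solid, shell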


class DeconvolveWithBead:
    def __init__(
        self,
        image: np.ndarray,
        voxel_size: np.ndarray,
        bead: np.ndarray = None,
        bead_diameter: float = None,
        bead_inner_diameter: float = 0.0,
        stop_thr: float = 15,
    ):
        self.image = image
        self.voxel_size = voxel_size
        self.bead = bead
        self.bead_diameter = bead_diameter
        self.bead_inner_diameter = bead_inner_diameter
        self.stop_thr: float = stop_thr  # stopping threshold; not used by the methods below
        self.pad_width: int = 0

        if self.bead is None:
            self.bead = generate_bead(self.image.shape, voxel_size, bead_diameter, bead_inner_diameter)
    def deconvolve(self, n_iterations: int = 100, g_sigma: float = 5, pad_width: int = 17, method: str = "rl"):
        # { Preparing
        bead = self.bead
        self.pad_width = pad_width

        psf = np.pad(bead.copy(), pad_width=pad_width)
        # pad the reference bead as well so it matches the padded image and PSF shapes
        bead_padded = np.pad(bead, pad_width=pad_width)

        im = self.image.copy()
        im = np.pad(im, pad_width=pad_width)
        if g_sigma > 0:
            print(f"Applying Gaussian filter with sigma {g_sigma} to the PSF and the input image")
            psf = gaussian_filter(psf, g_sigma)
            im = gaussian_filter(im, g_sigma)

        psf /= psf.sum()

        # }

        if method.lower() == "rl":
            # deconvolve_rl also returns per-iteration SSIM values, iteration indices and the
            # lowest-MSE estimate; only the final estimate and the MSE curve are returned here
            on, mse, mssim, itr, minn = self.deconvolve_rl(n_iterations, bead_padded, psf, im)
        else:
            raise NotImplementedError(f"{method} is not implemented")

        return on[pad_width:-pad_width, pad_width:-pad_width, pad_width:-pad_width], mse

    def _calc_mse(self, image, bead):
        # mean squared error between the two arrays (the caller crops any padding it
        # does not want included in the metric)
        return np.sum((image - bead) ** 2) / image.size
    def deconvolve_rl(self, n_iterations, bead: np.ndarray, psf: np.ndarray, im: np.ndarray):
        # Richardson-Lucy deconvolution of `im` with `psf`, using `bead` as the ground
        # truth for the convergence metrics
        otf = fftn(psf)
        otf_conj = otf.conj()
        on = im.copy()  # current estimate, initialised with the (padded) input image

        mse = []
        imsum = im.sum()

        itr = []  # iteration indices
        mssim = []  # mean SSIM between the estimate and the bead at each iteration
        win_size = 7  # SSIM window size (default = 7)
        # ms_ssim = []

        # minimum = None
        # previous = None
        best_mse = np.inf
        minn = None  # estimate with the lowest MSE seen so far

        print("Deconvolving:", psf.min(), psf.max(), psf.sum(), otf.sum(), on.sum())
        for it in range(n_iterations):
            # Richardson-Lucy multiplicative update:
            #   estimate <- estimate * correlate(psf, image / convolve(psf, estimate))
            on_f = fftn(on)
            div = ifftshift(ifftn(on_f * otf).real)  # current estimate blurred with the PSF
            res = np.divide(im, div, out=np.zeros_like(div), where=div != 0)  # ratio image, guarding against division by zero
            on *= ifftshift(ifftn(fftn(res) * otf_conj).real)
            on = np.where(on < 0, 0, on)  # clip negative values
            # MSE between the ground-truth bead and the current estimate; the border is
            # cropped so that irrelevant data around the bead does not dominate the metric
            mse.append(self._calc_mse(on[20:-20, 20:-20], bead[20:-20, 20:-20]))
            onsum = on.sum()

            # # saving the previous image
            # if previous is None:
            #     minimum = on.copy()
            # else:
            #     minimum = previous.copy()
            # previous = on.copy()

            # keep the estimate with the lowest MSE seen so far
            if mse[-1] < best_mse:
                best_mse = mse[-1]
                minn = on.copy()

            # if len(mse) >= 2 and mse[-2] < mse[-1]:
            #     minn = minimum
            #     print("--------------------- minimum image found ------------------------------------")

            # fig = plt.figure()
            # ax1 = fig.add_subplot(141)
            # ax2 = fig.add_subplot(142)
            # ax3 = fig.add_subplot(143)
            # ax4 = fig.add_subplot(144)

            # ax1.imshow(bead)
            # ax2.imshow(im)
            # ax3.imshow(on)
            # ax4.imshow(bead - on)
            # plt.show()
            # ms_ssim.append(msssim(on, bead))

            itr.append(it)
            _mssim = ssim(
                bead,
                on,
                win_size=win_size,  # small windows (3, 5, 7) give a more local comparison, larger ones a more global one
                data_range=max(bead.max(), on.max()) - min(bead.min(), on.min()),
            )
            mssim.append(_mssim)

            if not (0.5 * onsum < imsum < 2 * onsum):
                msg = f"Stopped at iteration {it} because of instability, probably caused by noise in the input image. Try using --smoothing-sigma."
                print(msg)
                print(" ", it, imsum, onsum, f"mse={mse[-1]}")
                raise StopIteration(msg)

            if it % 10 == 0:
                print(" ", it, imsum, onsum, f"mse={mse[-1]}")

        return on, mse, mssim, itr, minn

    def deconvolve_extra(
        self,
        n_iterations,
        bead: np.ndarray,
        psf: np.ndarray,
        im: np.ndarray,
        hdf: str,  # path to an HDF5 file for per-iteration diagnostics; None disables saving
    ):
        otf = fftn(psf)
        otf_conj = otf.conj()
        on = im.copy()  # current estimate

        mse = []
        imsum = im.sum()

        itr = []  # iteration indices
        lok = []  # MSE between consecutive estimates
        mssim = []  # mean SSIM between the estimate and the bead
        mssim_lok = []  # mean SSIM between consecutive estimates
        win_size = 7
        previous_on = None

        ms_ssim = []  # multi-scale SSIM between the estimate and the bead
        def total_variation(image):
            # total variation of a 2D array (currently only used in the commented-out
            # diagnostics below)
            dx = np.diff(image, axis=1)
            dy = np.diff(image, axis=0)

            dx = np.pad(dx, ((0, 0), (0, 1)), mode='constant')
            dy = np.pad(dy, ((0, 1), (0, 0)), mode='constant')

            # isotropic and anisotropic total variation
            tv = np.sum(np.sqrt(dx**2 + dy**2))
            tv_aniso = np.sum(np.abs(dx) + np.abs(dy))
            return tv, tv_aniso

print("Deconvolving:", psf.min(), psf.max(), psf.sum(), otf.sum(), on.sum())
|
||
|
for it in range(n_iterations):
|
||
|
on_f = fftn(on)
|
||
|
div = ifftshift(ifftn(on_f * otf).real)
|
||
|
res = np.divide(im, div, out=np.zeros_like(div), where=div != 0)
|
||
|
on *= ifftshift(ifftn(fftn(res) * otf_conj).real)
|
||
|
on = np.where(on < 0, 0, on)
|
||
|
mse.append(self._calc_mse(on, bead)) # mse between ground truth and image
|
||
|
onsum = on.sum()
|
||
|
|
||
|
ms_ssim.append(msssim(bead,on))
|
||
|
|
||
|
itr.append(it) # for iterations
|
||
|
|
||
|
_mssim, _grad = ssim(bead, on,
|
||
|
win_size = win_size, #7, lokaalsem siis 3,5,7, kui globaalsem võrdlus, siis suurem
|
||
|
gradient = True, # 400/1900, maatrix mis muutus kahe pildi vahel
|
||
|
data_range = on.max() - on.min(),) # gaussian - window has unifrom weights,
|
||
|
#full = True)
|
||
|
#with this, more weight is on the center pixels (less sigma = more centre weight, smaller window)
|
||
|
# S - per pixel similarity (2D-array) [-1,1] - 1 most similar
|
||
|
mssim.append(_mssim)
|
||
|
|
||

            if previous_on is not None:
                # MSE and SSIM between two consecutive estimates
                lok.append(self._calc_mse(on, previous_on))
                _mssim_lok, _grad_lok = ssim(
                    previous_on,
                    on,
                    win_size=win_size,
                    gradient=True,
                    data_range=on.max() - on.min(),
                )
                mssim_lok.append(_mssim_lok)
            else:
                # first iteration: compare the estimate with itself so the lists stay aligned
                lok.append(self._calc_mse(on, on))
                _mssim_lok, _grad_lok = ssim(
                    on,
                    on,
                    win_size=win_size,
                    gradient=True,
                    data_range=on.max() - on.min(),
                )
                mssim_lok.append(_mssim_lok)
            previous_on = on.copy()

            if not (0.5 * onsum < imsum < 2 * onsum):
                msg = f"Stopped at iteration {it} because of instability, probably caused by noise in the input image. Try using --smoothing-sigma."
                print(msg)
                print(" ", it, imsum, onsum, f"mse={mse[-1]}")
                raise StopIteration(msg)

            if it % 10 == 0:
                print(" ", it, imsum, onsum, f"mse={mse[-1]}")

            if hdf is not None:
                # save the per-iteration difference image (estimate - bead) and the SSIM gradient
                sss = on - bead
                with h5py.File(hdf, 'a') as hhh:
                    hhh.create_dataset(f'diff/{it:05d}', data=sss)
                    hhh.create_dataset(f'grad/{it:05d}', data=_grad)
                    # hhh.create_dataset(f'ssimg/{it:05d}', data=_S)
                # print("----", np.min(_grad), "max", np.max(_grad), "==", np.sum(_grad))
                # tvssim, tvssim_aniso = total_variation(_grad)
                # print("------", tvssim, "----", tvssim_aniso)

        return on, mse, itr, lok, mssim, mssim_lok, ms_ssim
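

# A minimal, self-contained usage sketch (illustrative only, not part of the original
# module): all sizes, diameters and iteration counts below are assumptions. It builds a
# synthetic blurred bead image, deconvolves it with the bead-derived PSF via the
# Richardson-Lucy routine above, and plots the resulting MSE curve.
if __name__ == "__main__":
    voxel_size = np.array([100e-9, 50e-9, 50e-9])  # (z, y, x) voxel size in metres
    truth = generate_bead((64, 64, 64), voxel_size, bead_outer_diameter=0.2)  # 0.2 um bead
    image = gaussian_filter(truth, 2)  # synthetic "measured" image: the bead blurred by the optics

    decon = DeconvolveWithBead(image, voxel_size, bead_diameter=0.2)
    restored, mse = decon.deconvolve(n_iterations=30, g_sigma=2, pad_width=17)

    plt.plot(mse)
    plt.xlabel("iteration")
    plt.ylabel("MSE vs. bead")
    plt.show()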