Files
deconv-test/deconvolve_func.py

261 lines
9.5 KiB
Python
Raw Normal View History

2025-10-14 09:33:38 +03:00
import numpy as np
import pylab as plt
from scipy.signal import convolve
from scipy.fftpack import fftn, ifftn, ifftshift
from scipy.ndimage import gaussian_filter
from scipy.optimize import least_squares
import h5py #for saving data
from sewar.full_ref import msssim
from skimage.metrics import structural_similarity as ssim
def _generate_bead(shape: np.ndarray, voxel_size: np.ndarray, bead_diameter: float):
shape = np.array(shape)
pz, py, px = voxel_size * 1e6
midpoint = shape // 2
z_shape, y_shape, x_shape = np.array(shape)
x_axis = np.arange(x_shape)
y_axis = np.arange(y_shape)
z_axis = np.arange(z_shape)
matrix_x, matrix_y, matrix_z = np.meshgrid(z_axis, y_axis, x_axis, indexing="ij")
bead = (px * (matrix_z - midpoint[2])) ** 2 + (py * (matrix_y - midpoint[1])) ** 2 + (pz * (matrix_x - midpoint[0])) ** 2
return (bead <= bead_diameter**2 / 4).astype(float)
def generate_bead(shape: np.ndarray, voxel_size: np.ndarray, bead_outer_diameter: float, bead_inner_diameter: float = 0):
    """Build a bead model normalised so that its voxels sum to 1.

    The model is a solid ball of ``bead_outer_diameter`` combined with a
    second ball of ``abs(bead_inner_diameter)``. The sign of
    ``bead_inner_diameter`` selects how the inner ball is combined:

    * positive: the inner ball is added on top of the outer one,
    * negative: the inner ball is subtracted from the outer one,
    * zero: the inner ball is ignored (its weight is ``sign(0) == 0``).

    Args:
        shape: Volume size, forwarded to ``_generate_bead``.
        voxel_size: Voxel pitch, forwarded to ``_generate_bead``.
        bead_outer_diameter: Diameter of the outer ball.
        bead_inner_diameter: Signed diameter of the inner ball.

    Returns:
        The combined volume divided by its own sum.
    """
    outer_ball = _generate_bead(shape, voxel_size, bead_outer_diameter)
    inner_ball = _generate_bead(shape, voxel_size, np.abs(bead_inner_diameter))
    inner_weight = np.sign(bead_inner_diameter)
    combined = outer_ball + inner_weight * inner_ball
    total = combined.sum()
    return combined / total
class DeconvolveWithBead:
    """Iterative FFT-based deconvolution of an image against a bead model.

    Each iteration applies a Richardson-Lucy-type multiplicative update: the
    current estimate is blurred with the kernel, the input image is divided by
    that blur, and the ratio is back-projected with the conjugate transfer
    function and multiplied into the estimate.

    Attributes set by ``__init__``:
        image, voxel_size, bead, bead_diameter, bead_inner_diameter: the
            constructor arguments (``bead`` is generated when not supplied).
        stop_thr: stored only; no method of this class reads it.
        pad_width: padding used by the most recent ``deconvolve`` call.
    """

    def __init__(
        self,
        image: np.ndarray,
        voxel_size: np.ndarray,
        bead: np.ndarray = None,
        bead_diameter: float = None,
        bead_inner_diameter: float = 0.0,
        stop_thr: float = 15,
    ):
        """Store the inputs and build the bead model when none is supplied.

        Args:
            image: Input image to deconvolve.
            voxel_size: Voxel pitch, forwarded to ``generate_bead``.
            bead: Ready-made bead model; when None it is generated with the
                shape of ``image``.
            bead_diameter: Outer bead diameter; required when ``bead`` is None.
            bead_inner_diameter: Signed inner diameter for ``generate_bead``.
            stop_thr: Kept on the instance; not used by the visible methods.

        Raises:
            ValueError: If neither ``bead`` nor ``bead_diameter`` is given.
        """
        if bead is None:
            if bead_diameter is None:
                # Fail early with a clear message instead of a TypeError from
                # squaring None deep inside the bead generator.
                raise ValueError("Either `bead` or `bead_diameter` must be provided.")
            bead = generate_bead(image.shape, voxel_size, bead_diameter, bead_inner_diameter)

        self.image = image
        self.voxel_size = voxel_size
        self.bead = bead
        self.bead_diameter = bead_diameter
        self.bead_inner_diameter = bead_inner_diameter
        self.stop_thr: float = stop_thr
        self.pad_width: int = 0

    def deconvolve(self, n_iterations: int = 100, g_sigma: float = 5, pad_width: int = 17, method: str = "rl"):
        """Pad, optionally smooth, deconvolve and crop back to the input size.

        Args:
            n_iterations: Number of update iterations.
            g_sigma: Sigma of the Gaussian filter applied to both the kernel
                and the image before iterating; ``<= 0`` disables smoothing.
            pad_width: Zero padding added on every side of every axis.
            method: Only ``"rl"`` (case-insensitive) is implemented.

        Returns:
            ``(estimate, mse)``: the final estimate cropped to the shape of
            ``self.image`` and the per-iteration MSE list produced by
            ``deconvolve_rl``.

        Raises:
            NotImplementedError: For any ``method`` other than ``"rl"``.
            StopIteration: Propagated from ``deconvolve_rl`` on instability.
        """
        if method.lower() != "rl":
            raise NotImplementedError(f"{method} is not implemented")

        self.pad_width = pad_width

        # The reference handed to deconvolve_rl must have the same shape as
        # the padded estimate, so the bead is zero-padded exactly like the image.
        reference = np.pad(np.asarray(self.bead, dtype=float), pad_width=pad_width)
        psf = reference.copy()
        # Work in float so the in-place multiplicative update cannot fail on
        # integer images.
        im = np.pad(np.asarray(self.image, dtype=float), pad_width=pad_width)

        if g_sigma > 0:
            print(f"Appling gaussian filter with sigma {g_sigma} on PSF and INPUT image")
            psf = gaussian_filter(psf, g_sigma)
            im = gaussian_filter(im, g_sigma)
        # NOTE(review): source indentation was lost, so it is unclear whether
        # the normalisation was conditional on smoothing; normalising always
        # keeps the kernel at unit sum for user-supplied beads as well.
        psf /= psf.sum()

        on, mse, *_ = self.deconvolve_rl(n_iterations, reference, psf, im)
        return self._unpad(on, pad_width), mse

    @staticmethod
    def _unpad(arr, pad_width):
        """Strip ``pad_width`` samples from both ends of every axis."""
        if pad_width <= 0:
            # arr[0:-0] would be empty, so return the array untouched.
            return arr
        return arr[(slice(pad_width, -pad_width),) * arr.ndim]

    @staticmethod
    def _trim_borders(arr, margin):
        """Drop ``margin`` samples from both ends of the first two axes."""
        if margin <= 0:
            return arr
        return arr[margin:-margin, margin:-margin]

    @staticmethod
    def _rl_update(estimate, im, otf, otf_conj):
        """Apply one multiplicative update step and clip negatives to zero.

        ``estimate`` is modified in place by the multiplication; the returned
        array is a new one produced by the clipping.
        """
        # The kernel sits at the array centre rather than at index 0, so the
        # FFT product comes out rolled by +n//2; ifftshift rolls it back.
        blurred = ifftshift(ifftn(fftn(estimate) * otf).real)
        # Where the blurred estimate is exactly 0 the ratio is defined as 0.
        ratio = np.divide(im, blurred, out=np.zeros_like(blurred), where=blurred != 0)
        # Multiplying by the conjugate OTF rolls the result by -n//2, so the
        # compensating roll is +n//2 (fftshift). ifftshift is only equivalent
        # on even-length axes and is off by one sample on odd-length ones.
        estimate *= np.fft.fftshift(ifftn(fftn(ratio) * otf_conj).real)
        return np.where(estimate < 0, 0, estimate)

    @staticmethod
    def _raise_if_unstable(it, imsum, onsum, last_mse):
        """Abort when the estimate's sum leaves (imsum / 2, 2 * imsum).

        StopIteration is kept as the exception type because callers may
        already catch it. NaN sums fail the comparison and abort as well.
        """
        if 0.5 * onsum < imsum < 2 * onsum:
            return
        msg = f"On iteration {it} iterations were stopped because of instability. Probably caused noise on imput image. Try using --smoothing-sigma."
        print(msg)
        print(" ", it, imsum, onsum, f"mse={last_mse}")
        raise StopIteration(msg)

    def _calc_mse(self, image, bead):
        """Return the mean squared difference between two same-shape arrays."""
        return np.sum((image - bead) ** 2) / image.size

    def deconvolve_rl(self, n_iterations, bead: np.ndarray, psf: np.ndarray, im: np.ndarray, mse_margin: int = 20):
        """Run the iterative update and track MSE / SSIM against ``bead``.

        Args:
            n_iterations: Number of iterations to run.
            bead: Reference array the estimate is compared with; must have the
                same shape as ``im``.
            psf: Kernel, centred in an array of the same shape as ``im``.
            im: Input image; also the starting estimate.
            mse_margin: Border width removed from the first two axes of both
                arrays before the MSE is computed; ``0`` uses the full arrays.

        Returns:
            ``(on, mse, mssim, itr, minn)``: the final estimate, per-iteration
            MSE values, per-iteration mean SSIM values, the iteration indices,
            and a copy of the estimate with the lowest MSE (None when no
            iteration produced a finite MSE or ``n_iterations`` is 0).

        Raises:
            StopIteration: When the estimate's sum drifts outside
                ``(imsum / 2, 2 * imsum)``.
        """
        otf = fftn(psf)
        otf_conj = otf.conj()
        on = im.copy()
        imsum = im.sum()
        win_size = 7  # SSIM window size; smaller = more local comparison

        mse = []
        mssim = []
        itr = []
        best_mse = np.inf
        minn = None
        reference_roi = self._trim_borders(bead, mse_margin)

        print("Deconvolving:", psf.min(), psf.max(), psf.sum(), otf.sum(), on.sum())
        for it in range(n_iterations):
            on = self._rl_update(on, im, otf, otf_conj)

            # MSE is taken on the interior only, leaving out the border region.
            mse.append(self._calc_mse(self._trim_borders(on, mse_margin), reference_roi))
            onsum = on.sum()

            if mse[-1] < best_mse:
                best_mse = mse[-1]
                minn = on.copy()

            itr.append(it)
            mssim.append(
                ssim(
                    bead,
                    on,
                    win_size=win_size,
                    data_range=max(bead.max(), on.max()) - min(bead.min(), on.min()),
                )
            )

            self._raise_if_unstable(it, imsum, onsum, mse[-1])
            if it % 10 == 0:
                print(" ", it, imsum, onsum, f"mse={mse[-1]}")
        return on, mse, mssim, itr, minn

    def deconvolve_extra(self, n_iterations, bead: np.ndarray, psf: np.ndarray, im: np.ndarray, hdf: "str | None"):
        """Run the iterative update while recording extra per-iteration metrics.

        Args:
            n_iterations: Number of iterations to run.
            bead: Reference array the estimate is compared with; must have the
                same shape as ``im``.
            psf: Kernel, centred in an array of the same shape as ``im``.
            im: Input image; also the starting estimate.
            hdf: Passed as the first argument of ``h5py.File(hdf, 'a')``, so it
                is a file name rather than an open file; None disables saving.
                When given, ``diff/<it>`` (estimate minus bead) and
                ``grad/<it>`` (SSIM gradient) datasets are written.

        Returns:
            ``(on, mse, itr, lok, mssim, mssim_lok, ms_ssim)``: the final
            estimate, MSE against ``bead``, iteration indices, MSE against the
            previous iterate, mean SSIM against ``bead``, mean SSIM against
            the previous iterate, and ``msssim`` against ``bead``. For the
            first iteration the "previous iterate" is the estimate itself.

        Raises:
            StopIteration: When the estimate's sum drifts outside
                ``(imsum / 2, 2 * imsum)``.
        """
        otf = fftn(psf)
        otf_conj = otf.conj()
        on = im.copy()
        imsum = im.sum()
        win_size = 7  # SSIM window size; smaller = more local comparison

        mse = []
        itr = []
        lok = []
        mssim = []
        mssim_lok = []
        ms_ssim = []
        previous_on = None

        print("Deconvolving:", psf.min(), psf.max(), psf.sum(), otf.sum(), on.sum())
        for it in range(n_iterations):
            on = self._rl_update(on, im, otf, otf_conj)

            mse.append(self._calc_mse(on, bead))
            onsum = on.sum()
            ms_ssim.append(msssim(bead, on))
            itr.append(it)

            # Unlike deconvolve_rl, the range here comes from the estimate alone.
            data_range = on.max() - on.min()
            _mssim, _grad = ssim(bead, on, win_size=win_size, gradient=True, data_range=data_range)
            mssim.append(_mssim)

            # Step-to-step metrics; on the first pass there is no previous
            # iterate, so the estimate is compared with itself.
            step_reference = on if previous_on is None else previous_on
            lok.append(self._calc_mse(on, step_reference))
            # Only the mean value is recorded, so no gradient is requested.
            mssim_lok.append(ssim(step_reference, on, win_size=win_size, data_range=data_range))
            previous_on = on.copy()

            self._raise_if_unstable(it, imsum, onsum, mse[-1])
            if it % 10 == 0:
                print(" ", it, imsum, onsum, f"mse={mse[-1]}")

            # NOTE(review): source indentation was lost; saving is assumed to
            # happen on every iteration, not only on the printed ones.
            if hdf is not None:
                # Re-opened in append mode each time so the file is closed
                # between iterations.
                with h5py.File(hdf, 'a') as hhh:
                    hhh.create_dataset(f'diff/{it:05d}', data=on - bead)
                    hhh.create_dataset(f'grad/{it:05d}', data=_grad)
        return on, mse, itr, lok, mssim, mssim_lok, ms_ssim