"""Richardson-Lucy deconvolution using a synthetic (or user-supplied) bead as the kernel."""
import numpy as np
import pylab as plt
from scipy.signal import convolve
from scipy.fftpack import fftn, ifftn, fftshift, ifftshift
from scipy.ndimage import gaussian_filter
from scipy.optimize import least_squares
import h5py  # for saving data
from sewar.full_ref import msssim
from skimage.metrics import structural_similarity as ssim


def _as_float(array):
    """Return a copy of ``array`` with a floating-point dtype.

    Floating-point inputs keep their dtype (plain copy). Integer/bool inputs are
    converted to float64 so the multiplicative in-place RL update does not fail
    with a casting error and ``gaussian_filter`` does not truncate the result.
    """
    array = np.asarray(array)
    if np.issubdtype(array.dtype, np.floating):
        return array.copy()
    return array.astype(np.float64)


def _crop_border(array, width, n_axes=None):
    """Strip ``width`` elements from both ends of the first ``n_axes`` axes.

    ``n_axes=None`` crops every axis. Unlike the ``a[w:-w]`` idiom this also
    works for ``width == 0`` (where ``a[0:-0]`` would be empty).
    """
    if n_axes is None:
        n_axes = array.ndim
    index = tuple(slice(width, size - width) for size in array.shape[:n_axes])
    return array[index]


def _generate_bead(shape: np.ndarray, voxel_size: np.ndarray, bead_diameter: float):
    """Return a solid ball (1.0 inside, 0.0 outside) centred at voxel ``shape // 2``.

    Parameters
    ----------
    shape : (z, y, x) array size in voxels.
    voxel_size : (z, y, x) voxel size. It is multiplied by 1e6 before use
        (presumably metres -> micrometres; TODO confirm), so ``bead_diameter``
        must be given in that scaled unit.
    bead_diameter : Ball diameter in the scaled unit.
    """
    shape = np.asarray(shape)
    pz, py, px = np.asarray(voxel_size) * 1e6
    midpoint = shape // 2
    z_shape, y_shape, x_shape = shape
    grid_z, grid_y, grid_x = np.meshgrid(
        np.arange(z_shape), np.arange(y_shape), np.arange(x_shape), indexing="ij"
    )
    # Squared physical distance of every voxel from the centre voxel.
    dist_sq = (
        (px * (grid_x - midpoint[2])) ** 2
        + (py * (grid_y - midpoint[1])) ** 2
        + (pz * (grid_z - midpoint[0])) ** 2
    )
    return (dist_sq <= bead_diameter**2 / 4).astype(float)


def generate_bead(
    shape: np.ndarray,
    voxel_size: np.ndarray,
    bead_outer_diameter: float,
    bead_inner_diameter: float = 0,
):
    """Return a synthetic bead image normalised so that it sums to 1.

    ``bead_inner_diameter`` modifies the ball of ``bead_outer_diameter``:
    a positive value adds a second ball of that diameter (a brighter core),
    a negative value subtracts a ball of ``abs(bead_inner_diameter)``
    (a hollow shell), and 0 leaves a plain solid ball.

    Raises
    ------
    ValueError
        If the resulting bead has a non-positive total intensity (e.g. a shell
        whose inner diameter is not smaller than the outer one); normalising it
        would produce NaNs or negative intensities.
    """
    core = _generate_bead(shape, voxel_size, np.abs(bead_inner_diameter))
    bead = _generate_bead(shape, voxel_size, bead_outer_diameter) + np.sign(bead_inner_diameter) * core
    total = bead.sum()
    if total <= 0:
        raise ValueError(
            "Generated bead has non-positive total intensity; "
            "check bead_outer_diameter / bead_inner_diameter."
        )
    return bead / total


class DeconvolveWithBead:
    """Richardson-Lucy (RL) deconvolution of an image using a bead as the kernel.

    Each RL iteration updates the estimate ``on`` as
    ``on <- max(on * correlate(im / convolve(on, psf), psf), 0)``,
    with both convolution and correlation done via FFTs.
    """

    def __init__(
        self,
        image: np.ndarray,
        voxel_size: np.ndarray,
        bead: np.ndarray = None,
        bead_diameter: float = None,
        bead_inner_diameter: float = 0.0,
        stop_thr: float = 15,
    ):
        """Store the inputs; synthesise the bead with ``generate_bead`` if none is given.

        Raises
        ------
        ValueError
            If neither ``bead`` nor ``bead_diameter`` is provided.
        """
        self.image = image
        self.voxel_size = voxel_size
        self.bead = bead
        self.bead_diameter = bead_diameter
        self.bead_inner_diameter = bead_inner_diameter
        # NOTE(review): stored but never read in this file; kept for backward
        # compatibility. Its intended meaning is unknown.
        self.stop_thr: float = stop_thr
        # Padding used by the most recent deconvolve() call.
        self.pad_width: int = 0
        if self.bead is None:
            if bead_diameter is None:
                raise ValueError("Either `bead` or `bead_diameter` must be given.")
            self.bead = generate_bead(self.image.shape, voxel_size, bead_diameter, bead_inner_diameter)

    def deconvolve(self, n_iterations: int = 100, g_sigma: float = 5, pad_width: int = 17, method: str = "rl"):
        """Deconvolve ``self.image`` with ``self.bead`` as the kernel.

        Image and bead are zero-padded by ``pad_width`` on every side. If
        ``g_sigma > 0``, both are smoothed with a Gaussian of that sigma, and the
        kernel is renormalised to sum 1. The padded, unsmoothed bead is used as
        the ground truth for the error metrics.

        Returns
        -------
        (estimate, mse)
            The deconvolved image with the padding removed, and the per-iteration
            MSE list produced by ``deconvolve_rl``.

        Raises
        ------
        NotImplementedError
            For any ``method`` other than ``"rl"`` (case-insensitive).
        StopIteration
            Propagated from ``deconvolve_rl`` when the iteration becomes unstable.
        """
        bead = _as_float(self.bead)
        self.pad_width = pad_width
        psf = np.pad(bead, pad_width=pad_width)
        im = np.pad(_as_float(self.image), pad_width=pad_width)
        # Reference for MSE/SSIM; it must have the same (padded) shape as `im`.
        truth = np.pad(bead, pad_width=pad_width)
        if g_sigma > 0:
            print(f"Appling gaussian filter with sigma {g_sigma} on PSF and INPUT image")
            psf = gaussian_filter(psf, g_sigma)
            im = gaussian_filter(im, g_sigma)
        psf /= psf.sum()

        if method.lower() == "rl":
            # deconvolve_rl returns (on, mse, mssim, itr, minn); only the first
            # two are part of this method's result.
            on, mse = self.deconvolve_rl(n_iterations, truth, psf, im)[:2]
        else:
            raise NotImplementedError(f"{method} is not implemented")
        return _crop_border(on, pad_width), mse

    def _calc_mse(self, image, bead):
        """Return the mean squared difference between ``image`` and ``bead``."""
        return np.sum((image - bead) ** 2) / image.size

    @staticmethod
    def _rl_step(on, im, otf, otf_conj):
        """Apply one Richardson-Lucy update to ``on`` and return the new estimate.

        ``on`` is multiplied in place; the returned array is a new array with
        negative values clipped to 0.
        """
        # Forward model. The PSF is centred at shape // 2 rather than at index 0,
        # so the FFT convolution comes out circularly shifted by +shape // 2;
        # ifftshift rolls it back.
        blurred = ifftshift(ifftn(fftn(on) * otf).real)
        ratio = np.divide(im, blurred, out=np.zeros_like(blurred), where=blurred != 0)
        # Back-projection. Correlating with the same off-centre PSF shifts the
        # result by -shape // 2, which fftshift undoes. (ifftshift, used here
        # previously, is identical for even sizes but off by one voxel on every
        # odd-sized axis, making the estimate drift each iteration.)
        on *= fftshift(ifftn(fftn(ratio) * otf_conj).real)
        return np.where(on < 0, 0, on)

    @staticmethod
    def _check_stability(it, imsum, onsum, last_mse):
        """Raise StopIteration if the estimate's total intensity left (imsum/2, 2*imsum)."""
        if not (0.5 * onsum < imsum < 2 * onsum):
            msg = f"On iteration {it} iterations were stopped because of instability. Probably caused noise on imput image. Try using --smoothing-sigma."
            print(msg)
            print(" ", it, imsum, onsum, f"mse={last_mse}")
            raise StopIteration(msg)

    def deconvolve_rl(
        self,
        n_iterations,
        bead: np.ndarray,
        psf: np.ndarray,
        im: np.ndarray,
        mse_crop: int = 20,
        win_size: int = 7,
    ):
        """Run RL deconvolution of ``im`` with ``psf``, tracking errors against ``bead``.

        ``psf`` is assumed to be centred at ``shape // 2`` and to have the same
        shape as ``im`` (TODO confirm at call sites). ``bead`` is the ground truth
        and must also match that shape.

        Parameters
        ----------
        mse_crop : Number of elements dropped from both ends of the first two
            axes before the MSE is computed, to ignore the border around the ring.
        win_size : SSIM window size. Small values (3, 5, 7) give a more local
            comparison; larger values give a more global one.

        Returns
        -------
        (on, mse, mssim, itr, minn)
            Final estimate; per-iteration MSE on the cropped region; per-iteration
            SSIM against ``bead``; iteration indices; and a copy of the estimate
            with the lowest MSE (``None`` if no MSE was ever finite and below inf).

        Raises
        ------
        StopIteration
            If the estimate's total intensity leaves (0.5, 2) x that of ``im``.
        """
        otf = fftn(psf)
        otf_conj = otf.conj()
        on = _as_float(im)
        imsum = im.sum()
        mse = []
        mssim = []  # mean SSIM per iteration
        itr = []
        best_mse = np.inf
        minn = None
        bead_cropped = _crop_border(bead, mse_crop, 2)

        print("Deconvolving:", psf.min(), psf.max(), psf.sum(), otf.sum(), on.sum())
        for it in range(n_iterations):
            on = self._rl_step(on, im, otf, otf_conj)
            # MSE between estimate and ground truth, ignoring the border region.
            mse.append(self._calc_mse(_crop_border(on, mse_crop, 2), bead_cropped))
            onsum = on.sum()

            if mse[-1] < best_mse:
                best_mse = mse[-1]
                minn = on.copy()

            itr.append(it)
            data_range = max(bead.max(), on.max()) - min(bead.min(), on.min())
            mssim.append(ssim(bead, on, win_size=win_size, data_range=data_range))

            self._check_stability(it, imsum, onsum, mse[-1])
            if it % 10 == 0:
                print(" ", it, imsum, onsum, f"mse={mse[-1]}")
        return on, mse, mssim, itr, minn

    def deconvolve_extra(
        self,
        n_iterations,
        bead: np.ndarray,
        psf: np.ndarray,
        im: np.ndarray,
        hdf: "str | None",
        win_size: int = 7,
    ):
        """Run RL deconvolution like ``deconvolve_rl`` and record extra diagnostics.

        ``hdf`` is passed to ``h5py.File(hdf, "a")`` (i.e. a path to an HDF5
        file), or ``None`` to skip saving. When given, every completed iteration
        writes ``diff/{it:05d}`` (``on - bead``) and ``grad/{it:05d}`` (SSIM
        gradient). ``create_dataset`` fails if those datasets already exist.

        Returns
        -------
        (on, mse, itr, lok, mssim, mssim_lok, ms_ssim)
            Final estimate; MSE vs ``bead``; iteration indices; MSE vs the previous
            iterate (0 on the first iteration); SSIM vs ``bead``; SSIM vs the
            previous iterate; MS-SSIM vs ``bead``.

        Raises
        ------
        StopIteration
            If the estimate's total intensity leaves (0.5, 2) x that of ``im``.
        """
        otf = fftn(psf)
        otf_conj = otf.conj()
        on = _as_float(im)
        imsum = im.sum()
        mse = []
        itr = []
        lok = []  # MSE between consecutive iterates
        mssim = []  # SSIM against the ground truth
        mssim_lok = []  # SSIM between consecutive iterates
        ms_ssim = []
        previous_on = None

        print("Deconvolving:", psf.min(), psf.max(), psf.sum(), otf.sum(), on.sum())
        for it in range(n_iterations):
            on = self._rl_step(on, im, otf, otf_conj)
            mse.append(self._calc_mse(on, bead))
            onsum = on.sum()
            ms_ssim.append(msssim(bead, on))
            itr.append(it)

            data_range = on.max() - on.min()
            # _grad: gradient of the mean SSIM with respect to `on`.
            _mssim, _grad = ssim(bead, on, win_size=win_size, gradient=True, data_range=data_range)
            mssim.append(_mssim)

            # On the first iteration there is no previous iterate, so compare `on`
            # with itself.
            reference = on if previous_on is None else previous_on
            lok.append(self._calc_mse(on, reference))
            _mssim_lok = ssim(reference, on, win_size=win_size, gradient=True, data_range=data_range)[0]
            mssim_lok.append(_mssim_lok)
            previous_on = on.copy()

            self._check_stability(it, imsum, onsum, mse[-1])
            if it % 10 == 0:
                print(" ", it, imsum, onsum, f"mse={mse[-1]}")

            if hdf is not None:
                # Reopened every iteration so results are flushed to disk as they
                # are produced.
                with h5py.File(hdf, "a") as h5:
                    h5.create_dataset(f"diff/{it:05d}", data=on - bead)
                    h5.create_dataset(f"grad/{it:05d}", data=_grad)
        return on, mse, itr, lok, mssim, mssim_lok, ms_ssim