From 7ec9a7ed1267a0aa0b414f41808ec9a979ee66b9 Mon Sep 17 00:00:00 2001 From: Otto Gustavson Date: Tue, 14 Oct 2025 09:33:38 +0300 Subject: [PATCH] initial commit --- README.md | 0 deconvolve_func.py | 261 +++++++++++++++++++++++++++++++++++++++++++++ plot.py | 82 ++++++++++++++ pohi.py | 95 +++++++++++++++++ zeiss2D.py | 94 ++++++++++++++++ 5 files changed, 532 insertions(+) create mode 100644 README.md create mode 100644 deconvolve_func.py create mode 100644 plot.py create mode 100644 pohi.py create mode 100644 zeiss2D.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/deconvolve_func.py b/deconvolve_func.py new file mode 100644 index 0000000..0121f98 --- /dev/null +++ b/deconvolve_func.py @@ -0,0 +1,261 @@ +import numpy as np +import pylab as plt +from scipy.signal import convolve +from scipy.fftpack import fftn, ifftn, ifftshift +from scipy.ndimage import gaussian_filter +from scipy.optimize import least_squares + +import h5py #for saving data +from sewar.full_ref import msssim + +from skimage.metrics import structural_similarity as ssim + +def _generate_bead(shape: np.ndarray, voxel_size: np.ndarray, bead_diameter: float): + shape = np.array(shape) + pz, py, px = voxel_size * 1e6 + midpoint = shape // 2 + + z_shape, y_shape, x_shape = np.array(shape) + x_axis = np.arange(x_shape) + y_axis = np.arange(y_shape) + z_axis = np.arange(z_shape) + + matrix_x, matrix_y, matrix_z = np.meshgrid(z_axis, y_axis, x_axis, indexing="ij") + + bead = (px * (matrix_z - midpoint[2])) ** 2 + (py * (matrix_y - midpoint[1])) ** 2 + (pz * (matrix_x - midpoint[0])) ** 2 + return (bead <= bead_diameter**2 / 4).astype(float) + + +def generate_bead(shape: np.ndarray, voxel_size: np.ndarray, bead_outer_diameter: float, bead_inner_diameter: float = 0): + core = _generate_bead(shape, voxel_size, np.abs(bead_inner_diameter)) + bead = _generate_bead(shape, voxel_size, bead_outer_diameter) + np.sign(bead_inner_diameter) * core + return bead / bead.sum() + + +class DeconvolveWithBead: + def __init__( + self, + image: np.ndarray, + voxel_size: np.ndarray, + bead: np.ndarray = None, + bead_diameter: float = None, + bead_inner_diameter: float = 0.0, + stop_thr: float = 15, + ): + self.image = image + self.voxel_size = voxel_size + self.bead = bead + self.bead_diameter = bead_diameter + self.bead_inner_diameter = bead_inner_diameter + self.stop_thr: float = stop_thr # WHAT is this varable for + self.pad_width: int = 0 + + if self.bead is None: + self.bead = generate_bead(self.image.shape, voxel_size, bead_diameter, bead_inner_diameter) + + def deconvolve(self, n_iterations: int = 100, g_sigma: float = 5, pad_width: int = 17, method: str = "rl"): + # { Preparing + bead = self.bead + self.pad_width = pad_width + + psf = np.pad(bead.copy(), pad_width=pad_width) + + im = self.image.copy() + im = np.pad(im, pad_width=pad_width) + if g_sigma > 0: + print(f"Appling gaussian filter with sigma {g_sigma} on PSF and INPUT image") + psf = gaussian_filter(psf, g_sigma) + im = gaussian_filter(im, g_sigma) + + psf /= psf.sum() + + # } + + if method.lower() == "rl": + on, mse = self._deconvolve_rl(n_iterations, bead, psf, im) + else: + raise NotImplementedError(f"{method} is not implemented") + + return on[pad_width:-pad_width, pad_width:-pad_width, pad_width:-pad_width], mse + + def _calc_mse(self, image, bead): + pd = self.pad_width + if pd > 0: + return np.sum((image - bead) ** 2) / image.size + else: + return np.sum((image - bead) ** 2) / image.size # changed for calcs + + + def deconvolve_rl(self, n_iterations, bead: np.ndarray, psf: np.ndarray, im: np.ndarray): + + otf = fftn(psf) + otf_conj = otf.conj() + on = im.copy() + + mse = [] + imsum = im.sum() + + itr = [] #added for iterations + mssim = [] # for ssim results (mssim) + win_size = 7 #default = 7 + #ms_ssim = [] + + # minimum = None #[] + # previous = None + best_mse = np.inf #float('inf') + minn = None + + print("Deconvolving:", psf.min(), psf.max(), psf.sum(), otf.sum(), on.sum()) + for it in range(n_iterations): + on_f = fftn(on) + div = ifftshift(ifftn(on_f * otf).real) + res = np.divide(im, div, out=np.zeros_like(div), where=div != 0) + on *= ifftshift(ifftn(fftn(res) * otf_conj).real) + on = np.where(on < 0, 0, on) + mse.append(self._calc_mse(on[20:-20,20:-20], bead[20:-20,20:-20])) # mse between ground truth and image + #slicing mse to remove unwanted irrelevant data around the "sõõr" + onsum = on.sum() + + # # saving the previous image + # if previous is None: + # minimum = on.copy() #[on] + # else: + # minimum = previous.copy() #[previous] + # previous = on.copy() + + if mse[-1] < best_mse: + best_mse = mse[-1] + minn = on.copy() + + # if len(mse) >= 2 and mse[-2] < mse[-1]: + # minn = minimum + # print("---------------------miinimum-pilt-on-leitud------------------------------------") + + + # fig = plt.figure() + # ax1 = fig.add_subplot(141) + # ax2 = fig.add_subplot(142) + # ax3 = fig.add_subplot(143) + # ax4 = fig.add_subplot(144) + + # ax1.imshow(bead) + # ax2.imshow(im) + # ax3.imshow(on) + # ax4.imshow(bead-on) + # plt.show() + #ms_ssim.append(msssim(on,bead)) + + itr.append(it) # for iterations + _mssim = ssim(bead, on, + win_size = win_size, # okaalsem siis 3,5,7, kui globaalsem võrdlus, siis suurem + data_range = max(bead.max(),on.max()) - min(bead.min(),on.min())) + mssim.append(_mssim) + + if not (0.5 * onsum < imsum < 2 * onsum): + msg = f"On iteration {it} iterations were stopped because of instability. Probably caused noise on imput image. Try using --smoothing-sigma." + print(msg) + print(" ", it, imsum, onsum, f"mse={mse[-1]}") + raise StopIteration(msg) + + if it % 10 == 0: + print(" ", it, imsum, onsum, f"mse={mse[-1]}") + + return on, mse, mssim, itr, minn #, ms_ssim + + + + def deconvolve_extra(self, + n_iterations, + bead: np.ndarray, + psf: np.ndarray, + im: np.ndarray, + hdf: h5py.File): # added hdf + + otf = fftn(psf) + otf_conj = otf.conj() + on = im.copy() + + mse = [] + imsum = im.sum() + + itr = [] # added + lok = [] + mssim = [] # for ssim + mssim_lok = [] + win_size = 7 + previous_on = None + + ms_ssim = [] + + def total_variation(image): + dx = np.diff(image, axis=1) + dy = np.diff(image, axis=0) + + dx = np.pad(dx, ((0, 0), (0, 1)), mode='constant') + dy = np.pad(dy, ((0, 1), (0, 0)), mode='constant') + + # Compute isotropic total variation + tv = np.sum(np.sqrt(dx**2 + dy**2)) + tv_aniso = np.sum(np.abs(dx) + np.abs(dy)) + return tv, tv_aniso + + print("Deconvolving:", psf.min(), psf.max(), psf.sum(), otf.sum(), on.sum()) + for it in range(n_iterations): + on_f = fftn(on) + div = ifftshift(ifftn(on_f * otf).real) + res = np.divide(im, div, out=np.zeros_like(div), where=div != 0) + on *= ifftshift(ifftn(fftn(res) * otf_conj).real) + on = np.where(on < 0, 0, on) + mse.append(self._calc_mse(on, bead)) # mse between ground truth and image + onsum = on.sum() + + ms_ssim.append(msssim(bead,on)) + + itr.append(it) # for iterations + + _mssim, _grad = ssim(bead, on, + win_size = win_size, #7, lokaalsem siis 3,5,7, kui globaalsem võrdlus, siis suurem + gradient = True, # 400/1900, maatrix mis muutus kahe pildi vahel + data_range = on.max() - on.min(),) # gaussian - window has unifrom weights, + #full = True) + #with this, more weight is on the center pixels (less sigma = more centre weight, smaller window) + # S - per pixel similarity (2D-array) [-1,1] - 1 most similar + mssim.append(_mssim) + + if previous_on is not None: + lok.append(self._calc_mse(on, previous_on)) # for the mse between two pictures next to eachother + _mssim_lok, _grad_lok = ssim(previous_on, on, + win_size = win_size, + gradient = True, + data_range = on.max() - on.min() ) + mssim_lok.append(_mssim_lok) + else: + lok.append(self._calc_mse(on, on)) + _mssim_lok, _grad_lok = ssim(on, on, + win_size = win_size, + gradient = True, + data_range = on.max() - on.min() ) + mssim_lok.append(_mssim_lok) + previous_on = on.copy() + + if not (0.5 * onsum < imsum < 2 * onsum): + msg = f"On iteration {it} iterations were stopped because of instability. Probably caused noise on imput image. Try using --smoothing-sigma." + print(msg) + print(" ", it, imsum, onsum, f"mse={mse[-1]}") + raise StopIteration(msg) + + if it % 10 == 0: + print(" ", it, imsum, onsum, f"mse={mse[-1]}") + + + if hdf is not None: # saving the images of on-bead + sss = on - bead + #print(type(hdf)) + with h5py.File(hdf, 'a') as hhh: + hhh.create_dataset(f'diff/{it:05d}', data = sss) + hhh.create_dataset(f'grad/{it:05d}', data = _grad) + #hhh.create_dataset(f'ssimg/{it:05d}', data = _S) + #print("----",np.min(_grad),"max",np.max(_grad), "==", np.sum(_grad)) + #tvssim, tvssim_aniso = total_variation(_grad) + #print("------",tvssim, "----", tvssim_aniso) + return on, mse, itr, lok, mssim, mssim_lok, ms_ssim \ No newline at end of file diff --git a/plot.py b/plot.py new file mode 100644 index 0000000..46978a4 --- /dev/null +++ b/plot.py @@ -0,0 +1,82 @@ +'''plotting the MSE and SSIM against the number of iterations''' + +import numpy as np +import matplotlib.pyplot as plt +import h5py +import glob +import argparse + +from scipy.signal import find_peaks + +parser = argparse.ArgumentParser() +parser.add_argument('--infile', type = str, required = True, help = 'file name of the .h5 file for names') +parser.add_argument('--data', type = str, required = True, help = "write only the part of 1 file name before the _(nr-s).npy; files for creating the plots") +parser.add_argument('--images', type = str, default = "mse_mssim", help = "naming the output images, default = mse_mssim") + +args = parser.parse_args() + +group_to_datasets = {} + +with h5py.File(args.infile, 'r') as hdf5_file: + for group_name in hdf5_file: # loop üle grupi nimede + group = hdf5_file[group_name] #salvestab grupi nime + if isinstance(group, h5py.Group): + datasets = [] #kui grupi nimi on h5py.Group nimi siis + for ds_name in group: #vaatab üle kõik datasetid grupi sees + if isinstance(group[ds_name], h5py.Dataset): # kui vastab ds nimele + datasets.append(ds_name) # appenditakse + group_to_datasets[group_name] = datasets # iga grupile apenditakse tema oma ds + + +npy_files = sorted(glob.glob(f"{args.data}*.npy")) +group_names = list(group_to_datasets.keys()) + +for group_name, filename in zip(group_names, npy_files): #üle kõigi mürade + data = np.load(filename) + iterations = data[:,0] + + cols_rem = data.shape[1] - 1 + cols = cols_rem // 2 + mses = data[:, 1:1+cols] + mssim = data[:, 1+cols:1+2*cols] + + #print(f"keskmine mses {group_name}",np.mean(mses)) + #print(f"keskmine mssim {group_name}",np.mean(mssim)) + + #np.set_printoptions(threshold=np.inf) + #print("---",mssim) + fig, ax = plt.subplots(2, 1, figsize = (15,9)) + ax0 = ax[0] + ax1 = ax[1] + + labels = group_to_datasets[group_name] + #print("######################################################################") + for i in range(mses.shape[1]): #üle kõigi sigmade + ax0.plot(iterations, mses[:, i], label = labels[i]) + #print(np.argmin(mses[:,i])) + #print(f"------{i}--------------------{i}-----------------") + + ax0.set_xlabel('iterations') + ax0.set_ylabel('mse') + #ax0.set_yscale('log') + ax0.set_title(f'mse change - {group_name}') + ax0.legend() + ax0.grid(False) + + for i in range(mssim.shape[1]): + ax1.plot(iterations, mssim[:, i], label = labels[i]) + ax1.set_xlabel('iterations') + ax1.set_ylabel('mssim') + ax1.set_title(f'mssim - {group_name}') + ax1.legend() + ax1.grid(False) + + plt.tight_layout() + plt.savefig(f"{args.images}_{group_name}.png") + +plt.show() + + + + + diff --git a/pohi.py b/pohi.py new file mode 100644 index 0000000..dae597c --- /dev/null +++ b/pohi.py @@ -0,0 +1,95 @@ +'''deconvolving 2D images''' + +import numpy as np +import h5py +import argparse +from scipy.signal import convolve +from scipy.ndimage import gaussian_filter +from deconvolve_func import DeconvolveWithBead as DWB + +parser = argparse.ArgumentParser() +parser.add_argument('bead', type = str, help = "original file") +parser.add_argument('--output', type = str, required = True, help = "file name for the output of images") +parser.add_argument('-pd', '--plot_data', type = str, required = True, help = "data storage file name, iterations added automatically") +parser.add_argument('-it', '--iterations', type = int, required = True, help = "nr of iterations") +parser.add_argument('-in', '--intensity', type = float, nargs = '+', required = True, help = "image intensity; division of the signal by intensity") +parser.add_argument('-k', '--kernel', type = float, nargs = '+', required = True, help = "kernel g sigma values") +parser.add_argument('--vox', type = float, nargs = '+', default = [1.0, 1.0, 1.0], help = "voxel values in micrometers, default [1.0, 1.0, 1.0]") + +args = parser.parse_args() + + +#--------- importing the image, that will become the GROUND TRUTH ------ +with h5py.File(args.bead, 'r') as hdf5_file: + original = hdf5_file['t0/channel0'][:] + +print("algse pildi integraal", np.sum(original)) + +#--------- creating the 2D gaussian PSF---------------------------------- +points = np.max(original) +point_sources = np.zeros(original.shape) + +center_x = original.shape[0] // 2 +center_y = original.shape[1] // 2 + +point_sources[center_x, center_y] = points +#point_sources[90, 110] = points +psf = gaussian_filter(point_sources, sigma = 2.4) + +print("psf integraal", np.sum(psf)) +if not np.isclose(np.sum(psf), 1.0, atol=1e-6): + psf /= np.sum(psf) # normaliseerin, et piksli väärtused oleksid samas suurusjärgus +print("psf-integraal-uuesti", np.sum(psf)) + +#-------------- DECONVOLUTION-------------------------------------------------------------- +with h5py.File(args.output, 'w') as hdf5_file: + 'erinevad mürad ### signal intensity #################################################' + intensity = args.intensity + + for i in intensity: + scaled_original = original / i if i!=0 else original.copy() + scaled_original = scaled_original.astype(np.float64) + image = gaussian_filter(scaled_original, sigma = 2.4)#convolve(scaled_original, psf, mode='same') + image[image <= 0] = 0 + + "rakendan pildile Poissoni müra #################################################" + if i == 0: + noisy_image = image.copy() + else: + noisy_image = np.random.poisson(lam = image, size = image.shape).astype(np.float64) + + hdf5_file.create_dataset(f"noise_level_{i:08.3f}", data = noisy_image) + hdf5_file.create_dataset(f"scaled_original_{i:08.3f}", data = scaled_original) + + '''kernelid erinevate sigmadega - pildi taastamine ##################''' + g_sigma = args.kernel + vox = np.array(args.vox) + mses = [] + mssim = [] + + for j in g_sigma: + image_kerneled = gaussian_filter(noisy_image, sigma = j) + psf_kerneled = gaussian_filter(psf, sigma = j) + print(np.sum(image_kerneled),np.sum(noisy_image), np.sum(psf_kerneled), np.sum(psf)) + 'dekonvolveerin - taastan pilti ######################################' + deconv = DWB(image = image_kerneled, voxel_size = vox, bead = scaled_original) + + dec = deconv.deconvolve_rl(n_iterations = args.iterations, + bead = scaled_original, + psf = psf_kerneled, + im = image_kerneled) + + hdf5_file.create_dataset(f'intensity_{i:08.3f}/SIGMA_{j:03.1f}', data = dec[0]) + hdf5_file.create_dataset(f"intensity_{i:08.3f}/minimum_image_{j:03.1f}", data = dec[4]) + + mses.append(dec[1]) + mssim.append(dec[2]) + + + print("---------------lõpp-pildi integraal", np.sum(dec[0])) + data = np.column_stack((np.column_stack(mses),np.column_stack(mssim))) + np.save(f"{args.plot_data}_{args.iterations}_{i:08.3f}.npy", np.column_stack((dec[3], data))) + + + + diff --git a/zeiss2D.py b/zeiss2D.py new file mode 100644 index 0000000..3334238 --- /dev/null +++ b/zeiss2D.py @@ -0,0 +1,94 @@ +import numpy as np + +import czifile as czi +import xml.etree.ElementTree as ET + + +class ZeissData: + + def __init__(self, filename): + + self.filename = filename + self._initialize() + + def _initialize(self): + czi_obj = czi.CziFile(self.filename) + axes = czi_obj.axes + tmp = czi_obj.asarray() + self.axes = "" + # print(tmp.shape) + for i, size in enumerate(tmp.shape): + if size > 1: + self.axes += axes[i] + + self.image = np.squeeze(tmp) + self.shape = self.image.shape + self.ndim = len(self.shape) + self.origin = np.array([[0, i1] for i1 in self.shape]) + + self.metadata = czi_obj.metadata() + root = ET.fromstring(self.metadata) + for d in root.findall("./Metadata/Scaling/Items/Distance"): + el = d.attrib["Id"] + if el in self.axes: + setattr(self, "p" + el.lower(), float(d[0].text) * 1e6) + + # print(self.axes) + # print(self.shape) + + # for k, v in kwargs.items(): + # if k in ImageJ.allowed_kwargs: + # setattr(self, k, v) + # else: + # print(k, "is not allowed keyword argument!") + + def get_channle_info(self, channelID): + # print(self.metadata) + root = ET.fromstring(self.metadata) + # for child in root: + # print(child.tag, child.attrib) + # print(type(root)) + + na = None + refractive_index = None + for _child in root.findall("./Metadata/Information/Instrument/Objectives/Objective"): + for _ch in _child: + if _ch.tag == "LensNA": + na = _ch.text + + if _ch.tag == "ImmersionRefractiveIndex": + refractive_index = _ch.text + # print(_child.tag, _child.attrib, _child[1].tag, type(_child.tag), type(_child.attrib)) + + em, ex = None, None + for _child in root.findall("./Metadata/DisplaySetting/Channels"): + + for _ch in _child.findall("./Channel"): + if _ch.attrib["Id"] == f"Channel:{channelID}": + for _c in _ch: + if _c.tag == "DyeMaxEmission": + em = _c.text + if _c.tag == "DyeMaxExcitation": + ex = _c.text + if ex is not None and em is not None: + break + + # print(_c.tag) + # print(_ch.tag, _ch.attrib["Id"], _ch.text) + + # print() + # print(_child.tag, _child.attrib, _child[1].tag, type(_child.tag), type(_child.attrib)) + + info = { + "MainEmissionWavelength": float(em), + "MainExcitationWavelength": float(ex), + "ObjectiveNA": float(na), + "RefractiveIndex": float(refractive_index), + } + return info + + def get_stack(self, channleID): + #spacings = dict(X=self.px * 1e-6, Y=self.py * 1e-6, Z=self.pz * 1e-6) + spacings = dict(X=self.px * 1e-6, Y=self.py * 1e-6) + + return self.image[channleID], spacings, self.get_channle_info(channleID)