initial commit

This commit is contained in:
Otto Gustavson
2025-10-14 09:33:38 +03:00
commit 7ec9a7ed12
5 changed files with 532 additions and 0 deletions

0
README.md Normal file
View File

261
deconvolve_func.py Normal file
View File

@@ -0,0 +1,261 @@
import numpy as np
import pylab as plt
from scipy.signal import convolve
from scipy.fftpack import fftn, ifftn, ifftshift
from scipy.ndimage import gaussian_filter
from scipy.optimize import least_squares
import h5py #for saving data
from sewar.full_ref import msssim
from skimage.metrics import structural_similarity as ssim
def _generate_bead(shape: np.ndarray, voxel_size: np.ndarray, bead_diameter: float):
shape = np.array(shape)
pz, py, px = voxel_size * 1e6
midpoint = shape // 2
z_shape, y_shape, x_shape = np.array(shape)
x_axis = np.arange(x_shape)
y_axis = np.arange(y_shape)
z_axis = np.arange(z_shape)
matrix_x, matrix_y, matrix_z = np.meshgrid(z_axis, y_axis, x_axis, indexing="ij")
bead = (px * (matrix_z - midpoint[2])) ** 2 + (py * (matrix_y - midpoint[1])) ** 2 + (pz * (matrix_x - midpoint[0])) ** 2
return (bead <= bead_diameter**2 / 4).astype(float)
def generate_bead(shape: np.ndarray, voxel_size: np.ndarray, bead_outer_diameter: float, bead_inner_diameter: float = 0):
core = _generate_bead(shape, voxel_size, np.abs(bead_inner_diameter))
bead = _generate_bead(shape, voxel_size, bead_outer_diameter) + np.sign(bead_inner_diameter) * core
return bead / bead.sum()
class DeconvolveWithBead:
def __init__(
self,
image: np.ndarray,
voxel_size: np.ndarray,
bead: np.ndarray = None,
bead_diameter: float = None,
bead_inner_diameter: float = 0.0,
stop_thr: float = 15,
):
self.image = image
self.voxel_size = voxel_size
self.bead = bead
self.bead_diameter = bead_diameter
self.bead_inner_diameter = bead_inner_diameter
self.stop_thr: float = stop_thr # WHAT is this varable for
self.pad_width: int = 0
if self.bead is None:
self.bead = generate_bead(self.image.shape, voxel_size, bead_diameter, bead_inner_diameter)
def deconvolve(self, n_iterations: int = 100, g_sigma: float = 5, pad_width: int = 17, method: str = "rl"):
# { Preparing
bead = self.bead
self.pad_width = pad_width
psf = np.pad(bead.copy(), pad_width=pad_width)
im = self.image.copy()
im = np.pad(im, pad_width=pad_width)
if g_sigma > 0:
print(f"Appling gaussian filter with sigma {g_sigma} on PSF and INPUT image")
psf = gaussian_filter(psf, g_sigma)
im = gaussian_filter(im, g_sigma)
psf /= psf.sum()
# }
if method.lower() == "rl":
on, mse = self._deconvolve_rl(n_iterations, bead, psf, im)
else:
raise NotImplementedError(f"{method} is not implemented")
return on[pad_width:-pad_width, pad_width:-pad_width, pad_width:-pad_width], mse
def _calc_mse(self, image, bead):
pd = self.pad_width
if pd > 0:
return np.sum((image - bead) ** 2) / image.size
else:
return np.sum((image - bead) ** 2) / image.size # changed for calcs
def deconvolve_rl(self, n_iterations, bead: np.ndarray, psf: np.ndarray, im: np.ndarray):
otf = fftn(psf)
otf_conj = otf.conj()
on = im.copy()
mse = []
imsum = im.sum()
itr = [] #added for iterations
mssim = [] # for ssim results (mssim)
win_size = 7 #default = 7
#ms_ssim = []
# minimum = None #[]
# previous = None
best_mse = np.inf #float('inf')
minn = None
print("Deconvolving:", psf.min(), psf.max(), psf.sum(), otf.sum(), on.sum())
for it in range(n_iterations):
on_f = fftn(on)
div = ifftshift(ifftn(on_f * otf).real)
res = np.divide(im, div, out=np.zeros_like(div), where=div != 0)
on *= ifftshift(ifftn(fftn(res) * otf_conj).real)
on = np.where(on < 0, 0, on)
mse.append(self._calc_mse(on[20:-20,20:-20], bead[20:-20,20:-20])) # mse between ground truth and image
#slicing mse to remove unwanted irrelevant data around the "sõõr"
onsum = on.sum()
# # saving the previous image
# if previous is None:
# minimum = on.copy() #[on]
# else:
# minimum = previous.copy() #[previous]
# previous = on.copy()
if mse[-1] < best_mse:
best_mse = mse[-1]
minn = on.copy()
# if len(mse) >= 2 and mse[-2] < mse[-1]:
# minn = minimum
# print("---------------------miinimum-pilt-on-leitud------------------------------------")
# fig = plt.figure()
# ax1 = fig.add_subplot(141)
# ax2 = fig.add_subplot(142)
# ax3 = fig.add_subplot(143)
# ax4 = fig.add_subplot(144)
# ax1.imshow(bead)
# ax2.imshow(im)
# ax3.imshow(on)
# ax4.imshow(bead-on)
# plt.show()
#ms_ssim.append(msssim(on,bead))
itr.append(it) # for iterations
_mssim = ssim(bead, on,
win_size = win_size, # okaalsem siis 3,5,7, kui globaalsem võrdlus, siis suurem
data_range = max(bead.max(),on.max()) - min(bead.min(),on.min()))
mssim.append(_mssim)
if not (0.5 * onsum < imsum < 2 * onsum):
msg = f"On iteration {it} iterations were stopped because of instability. Probably caused noise on imput image. Try using --smoothing-sigma."
print(msg)
print(" ", it, imsum, onsum, f"mse={mse[-1]}")
raise StopIteration(msg)
if it % 10 == 0:
print(" ", it, imsum, onsum, f"mse={mse[-1]}")
return on, mse, mssim, itr, minn #, ms_ssim
def deconvolve_extra(self,
n_iterations,
bead: np.ndarray,
psf: np.ndarray,
im: np.ndarray,
hdf: h5py.File): # added hdf
otf = fftn(psf)
otf_conj = otf.conj()
on = im.copy()
mse = []
imsum = im.sum()
itr = [] # added
lok = []
mssim = [] # for ssim
mssim_lok = []
win_size = 7
previous_on = None
ms_ssim = []
def total_variation(image):
dx = np.diff(image, axis=1)
dy = np.diff(image, axis=0)
dx = np.pad(dx, ((0, 0), (0, 1)), mode='constant')
dy = np.pad(dy, ((0, 1), (0, 0)), mode='constant')
# Compute isotropic total variation
tv = np.sum(np.sqrt(dx**2 + dy**2))
tv_aniso = np.sum(np.abs(dx) + np.abs(dy))
return tv, tv_aniso
print("Deconvolving:", psf.min(), psf.max(), psf.sum(), otf.sum(), on.sum())
for it in range(n_iterations):
on_f = fftn(on)
div = ifftshift(ifftn(on_f * otf).real)
res = np.divide(im, div, out=np.zeros_like(div), where=div != 0)
on *= ifftshift(ifftn(fftn(res) * otf_conj).real)
on = np.where(on < 0, 0, on)
mse.append(self._calc_mse(on, bead)) # mse between ground truth and image
onsum = on.sum()
ms_ssim.append(msssim(bead,on))
itr.append(it) # for iterations
_mssim, _grad = ssim(bead, on,
win_size = win_size, #7, lokaalsem siis 3,5,7, kui globaalsem võrdlus, siis suurem
gradient = True, # 400/1900, maatrix mis muutus kahe pildi vahel
data_range = on.max() - on.min(),) # gaussian - window has unifrom weights,
#full = True)
#with this, more weight is on the center pixels (less sigma = more centre weight, smaller window)
# S - per pixel similarity (2D-array) [-1,1] - 1 most similar
mssim.append(_mssim)
if previous_on is not None:
lok.append(self._calc_mse(on, previous_on)) # for the mse between two pictures next to eachother
_mssim_lok, _grad_lok = ssim(previous_on, on,
win_size = win_size,
gradient = True,
data_range = on.max() - on.min() )
mssim_lok.append(_mssim_lok)
else:
lok.append(self._calc_mse(on, on))
_mssim_lok, _grad_lok = ssim(on, on,
win_size = win_size,
gradient = True,
data_range = on.max() - on.min() )
mssim_lok.append(_mssim_lok)
previous_on = on.copy()
if not (0.5 * onsum < imsum < 2 * onsum):
msg = f"On iteration {it} iterations were stopped because of instability. Probably caused noise on imput image. Try using --smoothing-sigma."
print(msg)
print(" ", it, imsum, onsum, f"mse={mse[-1]}")
raise StopIteration(msg)
if it % 10 == 0:
print(" ", it, imsum, onsum, f"mse={mse[-1]}")
if hdf is not None: # saving the images of on-bead
sss = on - bead
#print(type(hdf))
with h5py.File(hdf, 'a') as hhh:
hhh.create_dataset(f'diff/{it:05d}', data = sss)
hhh.create_dataset(f'grad/{it:05d}', data = _grad)
#hhh.create_dataset(f'ssimg/{it:05d}', data = _S)
#print("----",np.min(_grad),"max",np.max(_grad), "==", np.sum(_grad))
#tvssim, tvssim_aniso = total_variation(_grad)
#print("------",tvssim, "----", tvssim_aniso)
return on, mse, itr, lok, mssim, mssim_lok, ms_ssim

82
plot.py Normal file
View File

@@ -0,0 +1,82 @@
'''plotting the MSE and SSIM against the number of iterations'''
import numpy as np
import matplotlib.pyplot as plt
import h5py
import glob
import argparse
from scipy.signal import find_peaks
parser = argparse.ArgumentParser()
parser.add_argument('--infile', type = str, required = True, help = 'file name of the .h5 file for names')
parser.add_argument('--data', type = str, required = True, help = "write only the part of 1 file name before the _(nr-s).npy; files for creating the plots")
parser.add_argument('--images', type = str, default = "mse_mssim", help = "naming the output images, default = mse_mssim")
args = parser.parse_args()
group_to_datasets = {}
with h5py.File(args.infile, 'r') as hdf5_file:
for group_name in hdf5_file: # loop üle grupi nimede
group = hdf5_file[group_name] #salvestab grupi nime
if isinstance(group, h5py.Group):
datasets = [] #kui grupi nimi on h5py.Group nimi siis
for ds_name in group: #vaatab üle kõik datasetid grupi sees
if isinstance(group[ds_name], h5py.Dataset): # kui vastab ds nimele
datasets.append(ds_name) # appenditakse
group_to_datasets[group_name] = datasets # iga grupile apenditakse tema oma ds
npy_files = sorted(glob.glob(f"{args.data}*.npy"))
group_names = list(group_to_datasets.keys())
for group_name, filename in zip(group_names, npy_files): #üle kõigi mürade
data = np.load(filename)
iterations = data[:,0]
cols_rem = data.shape[1] - 1
cols = cols_rem // 2
mses = data[:, 1:1+cols]
mssim = data[:, 1+cols:1+2*cols]
#print(f"keskmine mses {group_name}",np.mean(mses))
#print(f"keskmine mssim {group_name}",np.mean(mssim))
#np.set_printoptions(threshold=np.inf)
#print("---",mssim)
fig, ax = plt.subplots(2, 1, figsize = (15,9))
ax0 = ax[0]
ax1 = ax[1]
labels = group_to_datasets[group_name]
#print("######################################################################")
for i in range(mses.shape[1]): #üle kõigi sigmade
ax0.plot(iterations, mses[:, i], label = labels[i])
#print(np.argmin(mses[:,i]))
#print(f"------{i}--------------------{i}-----------------")
ax0.set_xlabel('iterations')
ax0.set_ylabel('mse')
#ax0.set_yscale('log')
ax0.set_title(f'mse change - {group_name}')
ax0.legend()
ax0.grid(False)
for i in range(mssim.shape[1]):
ax1.plot(iterations, mssim[:, i], label = labels[i])
ax1.set_xlabel('iterations')
ax1.set_ylabel('mssim')
ax1.set_title(f'mssim - {group_name}')
ax1.legend()
ax1.grid(False)
plt.tight_layout()
plt.savefig(f"{args.images}_{group_name}.png")
plt.show()

95
pohi.py Normal file
View File

@@ -0,0 +1,95 @@
'''deconvolving 2D images'''
import numpy as np
import h5py
import argparse
from scipy.signal import convolve
from scipy.ndimage import gaussian_filter
from deconvolve_func import DeconvolveWithBead as DWB
parser = argparse.ArgumentParser()
parser.add_argument('bead', type = str, help = "original file")
parser.add_argument('--output', type = str, required = True, help = "file name for the output of images")
parser.add_argument('-pd', '--plot_data', type = str, required = True, help = "data storage file name, iterations added automatically")
parser.add_argument('-it', '--iterations', type = int, required = True, help = "nr of iterations")
parser.add_argument('-in', '--intensity', type = float, nargs = '+', required = True, help = "image intensity; division of the signal by intensity")
parser.add_argument('-k', '--kernel', type = float, nargs = '+', required = True, help = "kernel g sigma values")
parser.add_argument('--vox', type = float, nargs = '+', default = [1.0, 1.0, 1.0], help = "voxel values in micrometers, default [1.0, 1.0, 1.0]")
args = parser.parse_args()
#--------- importing the image, that will become the GROUND TRUTH ------
with h5py.File(args.bead, 'r') as hdf5_file:
original = hdf5_file['t0/channel0'][:]
print("algse pildi integraal", np.sum(original))
#--------- creating the 2D gaussian PSF----------------------------------
points = np.max(original)
point_sources = np.zeros(original.shape)
center_x = original.shape[0] // 2
center_y = original.shape[1] // 2
point_sources[center_x, center_y] = points
#point_sources[90, 110] = points
psf = gaussian_filter(point_sources, sigma = 2.4)
print("psf integraal", np.sum(psf))
if not np.isclose(np.sum(psf), 1.0, atol=1e-6):
psf /= np.sum(psf) # normaliseerin, et piksli väärtused oleksid samas suurusjärgus
print("psf-integraal-uuesti", np.sum(psf))
#-------------- DECONVOLUTION--------------------------------------------------------------
with h5py.File(args.output, 'w') as hdf5_file:
'erinevad mürad ### signal intensity #################################################'
intensity = args.intensity
for i in intensity:
scaled_original = original / i if i!=0 else original.copy()
scaled_original = scaled_original.astype(np.float64)
image = gaussian_filter(scaled_original, sigma = 2.4)#convolve(scaled_original, psf, mode='same')
image[image <= 0] = 0
"rakendan pildile Poissoni müra #################################################"
if i == 0:
noisy_image = image.copy()
else:
noisy_image = np.random.poisson(lam = image, size = image.shape).astype(np.float64)
hdf5_file.create_dataset(f"noise_level_{i:08.3f}", data = noisy_image)
hdf5_file.create_dataset(f"scaled_original_{i:08.3f}", data = scaled_original)
'''kernelid erinevate sigmadega - pildi taastamine ##################'''
g_sigma = args.kernel
vox = np.array(args.vox)
mses = []
mssim = []
for j in g_sigma:
image_kerneled = gaussian_filter(noisy_image, sigma = j)
psf_kerneled = gaussian_filter(psf, sigma = j)
print(np.sum(image_kerneled),np.sum(noisy_image), np.sum(psf_kerneled), np.sum(psf))
'dekonvolveerin - taastan pilti ######################################'
deconv = DWB(image = image_kerneled, voxel_size = vox, bead = scaled_original)
dec = deconv.deconvolve_rl(n_iterations = args.iterations,
bead = scaled_original,
psf = psf_kerneled,
im = image_kerneled)
hdf5_file.create_dataset(f'intensity_{i:08.3f}/SIGMA_{j:03.1f}', data = dec[0])
hdf5_file.create_dataset(f"intensity_{i:08.3f}/minimum_image_{j:03.1f}", data = dec[4])
mses.append(dec[1])
mssim.append(dec[2])
print("---------------lõpp-pildi integraal", np.sum(dec[0]))
data = np.column_stack((np.column_stack(mses),np.column_stack(mssim)))
np.save(f"{args.plot_data}_{args.iterations}_{i:08.3f}.npy", np.column_stack((dec[3], data)))

94
zeiss2D.py Normal file
View File

@@ -0,0 +1,94 @@
import numpy as np
import czifile as czi
import xml.etree.ElementTree as ET
class ZeissData:
def __init__(self, filename):
self.filename = filename
self._initialize()
def _initialize(self):
czi_obj = czi.CziFile(self.filename)
axes = czi_obj.axes
tmp = czi_obj.asarray()
self.axes = ""
# print(tmp.shape)
for i, size in enumerate(tmp.shape):
if size > 1:
self.axes += axes[i]
self.image = np.squeeze(tmp)
self.shape = self.image.shape
self.ndim = len(self.shape)
self.origin = np.array([[0, i1] for i1 in self.shape])
self.metadata = czi_obj.metadata()
root = ET.fromstring(self.metadata)
for d in root.findall("./Metadata/Scaling/Items/Distance"):
el = d.attrib["Id"]
if el in self.axes:
setattr(self, "p" + el.lower(), float(d[0].text) * 1e6)
# print(self.axes)
# print(self.shape)
# for k, v in kwargs.items():
# if k in ImageJ.allowed_kwargs:
# setattr(self, k, v)
# else:
# print(k, "is not allowed keyword argument!")
def get_channle_info(self, channelID):
# print(self.metadata)
root = ET.fromstring(self.metadata)
# for child in root:
# print(child.tag, child.attrib)
# print(type(root))
na = None
refractive_index = None
for _child in root.findall("./Metadata/Information/Instrument/Objectives/Objective"):
for _ch in _child:
if _ch.tag == "LensNA":
na = _ch.text
if _ch.tag == "ImmersionRefractiveIndex":
refractive_index = _ch.text
# print(_child.tag, _child.attrib, _child[1].tag, type(_child.tag), type(_child.attrib))
em, ex = None, None
for _child in root.findall("./Metadata/DisplaySetting/Channels"):
for _ch in _child.findall("./Channel"):
if _ch.attrib["Id"] == f"Channel:{channelID}":
for _c in _ch:
if _c.tag == "DyeMaxEmission":
em = _c.text
if _c.tag == "DyeMaxExcitation":
ex = _c.text
if ex is not None and em is not None:
break
# print(_c.tag)
# print(_ch.tag, _ch.attrib["Id"], _ch.text)
# print()
# print(_child.tag, _child.attrib, _child[1].tag, type(_child.tag), type(_child.attrib))
info = {
"MainEmissionWavelength": float(em),
"MainExcitationWavelength": float(ex),
"ObjectiveNA": float(na),
"RefractiveIndex": float(refractive_index),
}
return info
def get_stack(self, channleID):
#spacings = dict(X=self.px * 1e-6, Y=self.py * 1e-6, Z=self.pz * 1e-6)
spacings = dict(X=self.px * 1e-6, Y=self.py * 1e-6)
return self.image[channleID], spacings, self.get_channle_info(channleID)