Files
deconv-test/plot.py
2025-10-14 16:16:50 +03:00

140 lines
4.7 KiB
Python

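"""Plot MSE and MSSIM against the number of iterations.

Reads group/dataset names from an HDF5 file and per-group metric curves from
.npy files, saves one MSE/MSSIM figure per group, and writes each sigma's best
value (minimum MSE, maximum MSSIM) to mse.csv and mssim.csv.
"""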
import argparse
import csv
import glob
import re
import sys

import h5py
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import find_peaks
# Command-line interface. The boolean flags default to False explicitly
# (BooleanOptionalAction would otherwise default to None).
parser = argparse.ArgumentParser()
parser.add_argument(
    '--infile', type=str, required=True,
    help='HDF5 (.h5) file whose group and dataset names label the curves',
)
parser.add_argument(
    '--data', type=str, required=True,
    help="prefix of the .npy result files (the part of one file name before "
         "'_<nr>.npy'); every file matching '<prefix>*.npy' is plotted",
)
parser.add_argument(
    '--images', type=str, default="mse_mssim",
    help="prefix of the saved images, written as <prefix>_<group>.png "
         "(default: %(default)s)",
)
parser.add_argument(
    '-q', '--squared', action=argparse.BooleanOptionalAction, default=False,
    help="in the --sigma_plots figures, multiply each MSE curve by the squared "
         "intensity read from its group name (values below 1 count as 1)",
)
parser.add_argument(
    '-sp', '--sigma_plots', action=argparse.BooleanOptionalAction, default=False,
    help='also draw one figure per sigma with the MSE curves of all groups overlaid',
)
args = parser.parse_args()
# Collect, for every top-level HDF5 group, the names of the datasets it holds.
# Per the author's notes, each group is one intensity level and its datasets
# include both sigma and "minimum" entries; only the sigma ones are plotted.
group_to_datasets = {}
with h5py.File(args.infile, 'r') as hdf5_file:
    for group_name, group in hdf5_file.items():
        if isinstance(group, h5py.Group):
            group_to_datasets[group_name] = [
                ds_name for ds_name, member in group.items()
                if isinstance(member, h5py.Dataset)
            ]
group_names = list(group_to_datasets)
if not group_names:
    raise SystemExit(f"No HDF5 groups found in {args.infile!r}; nothing to plot.")
# Dataset names of the first group; used as the sigma-plot titles.
labels = group_to_datasets[group_names[0]]

# Result files. Each file is paired with the group at the same position, so a
# count mismatch means some curves are dropped or mislabelled.
npy_files = sorted(glob.glob(f"{args.data}*.npy"))
if len(npy_files) != len(group_names):
    print(
        f"Warning: {len(group_names)} HDF5 groups but {len(npy_files)} files "
        f"match '{args.data}*.npy'; files are paired with groups by sorted "
        f"position, so only {min(len(group_names), len(npy_files))} pairs are used.",
        file=sys.stderr,
    )
if args.sigma_plots:
    # One figure per sigma column: overlay that sigma's MSE curve from every
    # .npy file, labelled with the group name at the same position.
    if not npy_files:
        print(f"No files match '{args.data}*.npy'; skipping the sigma plots.",
              file=sys.stderr)
    else:
        per_file_mses = []
        iterations = None
        for filename in npy_files:
            data = np.load(filename)
            # Column layout: [iteration, mse_1..mse_k, mssim_1..mssim_k].
            if iterations is None:
                iterations = data[:, 0]
            n_sigmas = (data.shape[1] - 1) // 2
            per_file_mses.append(data[:, 1:1 + n_sigmas])
        # Shape (n_iterations, n_sigmas, n_files); raises if the files disagree.
        mse_stack = np.stack(per_file_mses, axis=2)

        # Per-group scale for --squared: the first number in the group name,
        # clamped to at least 1 and squared (1 if the name has no number).
        # Taking exactly one number per name keeps scales aligned with groups.
        intensities = []
        for name in group_names:
            match = re.search(r"-?\d+\.?\d*", name)
            intensities.append(float(match.group()) if match else 1.0)
        scales = np.square(np.maximum(intensities, 1.0))

        for i in range(mse_stack.shape[1]):
            plt.figure(figsize=(15, 9))
            # Rows of the transpose are the per-file curves; zip stops at the
            # shorter of files/groups instead of raising IndexError.
            curves = mse_stack[:, i, :].T
            for group_name, scale, curve in zip(group_names, scales, curves):
                plt.plot(iterations, curve * scale if args.squared else curve,
                         label=group_name)
            plt.xlabel('iterations')
            plt.ylabel('mse * intensity^2' if args.squared else 'mse')
            plt.title(f"{labels[i]}")
            plt.legend()
            # plt.show() is deliberately not called here: these figures stay
            # open until the script's next plt.show() call displays them.
def _plot_metric(ax, iterations, values, curve_labels, ylabel, title):
    """Draw one curve per column of ``values`` against ``iterations`` on ``ax``.

    Column ``i`` is labelled ``curve_labels[i]``.
    """
    for i in range(values.shape[1]):
        ax.plot(iterations, values[:, i], label=curve_labels[i])
    ax.set_xlabel('iterations')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(False)


# For each group, paired with the .npy file at the same sorted position:
# save a two-panel MSE/MSSIM figure and add one CSV row per metric holding
# each sigma's best value (minimum MSE in mse.csv, maximum MSSIM in mssim.csv).
with open('mse.csv', 'w', newline='', encoding='utf-8') as csvfile, open(
    'mssim.csv', 'w', newline='', encoding='utf-8'
) as f2:
    writer = csv.writer(csvfile)
    writer2 = csv.writer(f2)
    for index, (group_name, filename) in enumerate(zip(group_names, npy_files)):
        labels = group_to_datasets[group_name]
        if index == 0:
            # Header row: the first group's dataset names, one column per sigma.
            writer.writerow([''] + labels)
            writer2.writerow([''] + labels)

        data = np.load(filename)
        # Column layout: [iteration, mse_1..mse_k, mssim_1..mssim_k].
        iterations = data[:, 0]
        n_sigmas = (data.shape[1] - 1) // 2
        mses = data[:, 1:1 + n_sigmas]
        mssim = data[:, 1 + n_sigmas:1 + 2 * n_sigmas]

        fig, (ax0, ax1) = plt.subplots(2, 1, figsize=(15, 9))
        _plot_metric(ax0, iterations, mses, labels, 'mse', f'mse change - {group_name}')
        _plot_metric(ax1, iterations, mssim, labels, 'mssim', f'mssim - {group_name}')

        writer.writerow([group_name] + [np.min(mses[:, i]) for i in range(mses.shape[1])])
        writer2.writerow([group_name] + [np.max(mssim[:, i]) for i in range(mssim.shape[1])])

        fig.tight_layout()
        fig.savefig(f"{args.images}_{group_name}.png")
        plt.show()
        # Release the figure. Otherwise every group's figure stays in memory; with
        # a non-interactive backend (where show() returns immediately) matplotlib
        # warns once more than 20 figures are open.
        plt.close(fig)