83 lines
2.7 KiB
Python
83 lines
2.7 KiB
Python
![]() |
'''Plot the MSE and SSIM against the number of iterations.'''
|
||
|
|
||
|
import argparse
import glob
import warnings

import h5py
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import find_peaks
|
||
|
|
||
|
def parse_args(argv=None):
    """Parse the command-line options.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read ``sys.argv``.

    Returns:
        The parsed namespace with ``infile``, ``data`` and ``images``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--infile', type=str, required=True,
                        help='file name of the .h5 file for names')
    parser.add_argument('--data', type=str, required=True,
                        help="write only the part of 1 file name before the _(nr-s).npy; files for creating the plots")
    parser.add_argument('--images', type=str, default="mse_mssim",
                        help="naming the output images, default = mse_mssim")
    return parser.parse_args(argv)


def collect_group_datasets(path):
    """Map each top-level group of an HDF5 file to the names of its datasets.

    Args:
        path: Path of the HDF5 file, opened read-only.

    Returns:
        A dict ``{group_name: [dataset_name, ...]}``. Top-level entries that
        are not groups, and group members that are not datasets, are skipped.
    """
    group_to_datasets = {}
    with h5py.File(path, 'r') as hdf5_file:
        for group_name in hdf5_file:
            group = hdf5_file[group_name]
            if not isinstance(group, h5py.Group):
                continue
            group_to_datasets[group_name] = [
                ds_name for ds_name in group
                if isinstance(group[ds_name], h5py.Dataset)
            ]
    return group_to_datasets


def pair_groups_with_files(group_names, npy_files):
    """Pair group names with metric files by position.

    The pairing is purely positional, so a count mismatch means some groups
    or files are left out; that is reported with a warning instead of being
    dropped silently.

    Args:
        group_names: Group names in the order they should be plotted.
        npy_files: Metric file names in matching order.

    Returns:
        A list of ``(group_name, filename)`` tuples, as long as the shorter
        of the two inputs.
    """
    if len(group_names) != len(npy_files):
        warnings.warn(
            f"found {len(group_names)} group(s) but {len(npy_files)} .npy file(s); "
            "only the first matching pairs are plotted",
            stacklevel=2,
        )
    return list(zip(group_names, npy_files))


def split_metrics(data):
    """Split a metrics table into iterations, MSE columns and MSSIM columns.

    Column 0 holds the iteration numbers. The remaining columns are halved:
    the first half is taken as MSE, the second half as MSSIM. If the number
    of remaining columns is odd, the trailing unpaired column is ignored.

    Args:
        data: 2-D array with at least one column.

    Returns:
        A tuple ``(iterations, mses, mssim)`` of array views into ``data``.

    Raises:
        ValueError: If ``data`` is not 2-D or has no columns.
    """
    if data.ndim != 2 or data.shape[1] == 0:
        raise ValueError(
            f"expected a 2-D array with at least one column, got shape {data.shape}"
        )
    n_curves = (data.shape[1] - 1) // 2
    iterations = data[:, 0]
    mses = data[:, 1:1 + n_curves]
    mssim = data[:, 1 + n_curves:1 + 2 * n_curves]
    return iterations, mses, mssim


def curve_label(labels, index):
    """Return the legend label for curve ``index``.

    Falls back to a generic name when ``labels`` has fewer entries than there
    are curves, so a short label list does not abort the plot.
    """
    if index < len(labels):
        return labels[index]
    return f"curve {index}"


def plot_group(group_name, iterations, mses, mssim, labels, image_prefix):
    """Draw, save and show the MSE / MSSIM figure of one group.

    The figure has two stacked panels (MSE on top, MSSIM below) with one line
    per column, and is written to ``{image_prefix}_{group_name}.png``.

    Args:
        group_name: Group name, used in the titles and the output file name.
        iterations: 1-D array of x values shared by all curves.
        mses: 2-D array, one MSE curve per column.
        mssim: 2-D array, one MSSIM curve per column.
        labels: Legend labels, indexed by column.
        image_prefix: Prefix of the output image file name.
    """
    fig, (mse_ax, mssim_ax) = plt.subplots(2, 1, figsize=(15, 9))
    panels = (
        (mse_ax, mses, 'mse', f'mse change - {group_name}'),
        (mssim_ax, mssim, 'mssim', f'mssim - {group_name}'),
    )
    for axis, curves, ylabel, title in panels:
        # One line per column (per sigma, according to the original comments).
        for i in range(curves.shape[1]):
            axis.plot(iterations, curves[:, i], label=curve_label(labels, i))
        axis.set_xlabel('iterations')
        axis.set_ylabel(ylabel)
        axis.set_title(title)
        if curves.shape[1]:
            # legend() warns when there is nothing labelled to show.
            axis.legend()
        axis.grid(False)

    fig.tight_layout()
    fig.savefig(f"{image_prefix}_{group_name}.png")
    plt.show()
    # Release the figure; otherwise every group keeps one alive in pyplot.
    plt.close(fig)


def main(argv=None):
    """Plot MSE and MSSIM against iterations, one figure per HDF5 group.

    Group names come from the HDF5 file given by ``--infile``; the metric
    tables come from the ``.npy`` files whose names start with ``--data``.
    Groups and files are matched by position after sorting the file names.
    """
    args = parse_args(argv)
    group_to_datasets = collect_group_datasets(args.infile)
    npy_files = sorted(glob.glob(f"{args.data}*.npy"))

    # One figure per group (per noise type, according to the original comments).
    for group_name, filename in pair_groups_with_files(list(group_to_datasets), npy_files):
        data = np.load(filename)
        try:
            iterations, mses, mssim = split_metrics(data)
        except ValueError as err:
            raise ValueError(f"{filename}: {err}") from err
        plot_group(group_name, iterations, mses, mssim,
                   group_to_datasets[group_name], args.images)


if __name__ == "__main__":
    main()
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
|