2025-10-14 09:33:38 +03:00
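'''Plot MSE and mean SSIM (MSSIM) against the number of iterations.

Also writes the lowest MSE and the highest MSSIM of every group and sigma
to mse.csv and mssim.csv.
'''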
import numpy as np
import matplotlib . pyplot as plt
import h5py
import glob
import argparse
2025-10-14 16:16:50 +03:00
import csv
import re
2025-10-14 09:33:38 +03:00
from scipy . signal import find_peaks
# Command-line interface: input HDF5 file (source of group/dataset names),
# prefix of the .npy result files, output image prefix and two plot flags.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--infile",
    required=True,
    type=str,
    help="file name of the .h5 file for names",
)
parser.add_argument(
    "--data",
    required=True,
    type=str,
    help="write only the part of 1 file name before the _(nr-s).npy; files for creating the plots",
)
parser.add_argument(
    "--images",
    default="mse_mssim",
    type=str,
    help="naming the output images, default = mse_mssim",
)
# BooleanOptionalAction also creates --no-squared / --no-sigma_plots; with no
# default the attribute is None when neither form is given (falsy).
parser.add_argument(
    "-q", "--squared",
    action=argparse.BooleanOptionalAction,
    help="squares the intensity values, otherwise normal values",
)
parser.add_argument(
    "-sp", "--sigma_plots",
    action=argparse.BooleanOptionalAction,
    help="shows the plots, where all intensity values of a sigma value are on one plot",
)
args = parser.parse_args()
# Map each top-level HDF5 group name to the names of the datasets it directly
# contains; members that are not h5py.Dataset (e.g. sub-groups) are skipped,
# as are top-level members that are not groups. Only names are kept, so the
# file can be closed as soon as the mapping is built.
with h5py.File(args.infile, "r") as hdf5_file:
    group_to_datasets = {
        group_name: [
            member for member in group
            if isinstance(group[member], h5py.Dataset)
        ]
        for group_name, group in hdf5_file.items()
        if isinstance(group, h5py.Group)
    }
2025-10-14 16:16:50 +03:00
# HDF5 group names (the noise/intensity variants), in h5py iteration order.
group_names = list(group_to_datasets)
if not group_names:
    raise SystemExit(f"no HDF5 groups found in {args.infile}")
# Dataset labels of the first group; used as legend entries and CSV column
# names for every group. The original author notes the groups hold sigma
# datasets plus others (e.g. a minimum) and only the sigma ones are plotted.
labels = group_to_datasets[group_names[0]]


def _natural_key(path):
    """Sort key that orders embedded integers numerically ('x_2' < 'x_10')."""
    # re.split with a capturing group alternates text / digit-run pieces, so
    # the pieces at odd positions are always digit runs and compare as ints.
    return [int(piece) if i % 2 else piece
            for i, piece in enumerate(re.split(r"(\d+)", path))]


# The result files are later paired with group_names purely by position, so
# their order matters. A plain lexicographic sort puts "_10.npy" before
# "_2.npy"; the natural-sort key keeps numbered files in numeric order.
# NOTE(review): assumes the file numbers follow the HDF5 group order —
# confirm against the script that writes the .npy files.
npy_files = sorted(glob.glob(f"{args.data}*.npy"), key=_natural_key)
if len(npy_files) != len(group_names):
    print(f"warning: found {len(npy_files)} .npy files for {len(group_names)} "
          f"HDF5 groups; they are paired by position, so check the --data prefix")
2025-10-14 16:16:50 +03:00
if args.sigma_plots:
    # One figure per sigma (dataset) column, overlaying that column's MSE
    # curve from every result file, i.e. from every noise/intensity group.
    # Each .npy row is [iteration, mse_1..mse_k, mssim_1..mssim_k]; only the
    # MSE half is needed here, plotted against the row index.
    column_accumulators = None
    for filename in npy_files:
        data = np.load(filename)
        cols = (data.shape[1] - 1) // 2
        mses = data[:, 1:1 + cols]
        if column_accumulators is None:
            column_accumulators = [[] for _ in range(mses.shape[1])]
        for i in range(mses.shape[1]):
            column_accumulators[i].append(mses[:, i])

    if column_accumulators is None:
        print(f"no files match {args.data}*.npy; skipping sigma plots")
    else:
        # column_vars[i] has one column per result file for sigma column i.
        column_vars = [np.column_stack(col) for col in column_accumulators]

        # One --squared scale factor per group: the first number in the group
        # name, clamped to >= 1, then squared. Taking exactly one number per
        # name keeps factors aligned with groups even when a name contains no
        # number (factor 1) or several numbers.
        intensity_scales = []
        for name in group_names:
            match = re.search(r"-?\d+\.?\d*", name)
            value = float(match.group()) if match else 1.0
            intensity_scales.append(max(value, 1.0) ** 2)

        for i, col in enumerate(column_vars):
            plt.figure(figsize=(15, 9))
            # col.T yields one curve per result file; zip pairs the curves
            # with group names by position and ignores any extra files.
            for curve, name, scale in zip(col.T, group_names, intensity_scales):
                plt.plot(curve * scale if args.squared else curve, label=name)
            plt.title(f"{labels[i]}")
            plt.legend()
def _plot_metric(axis, x_values, curves, curve_labels, ylabel, title, best):
    """Draw every column of *curves* against *x_values* on *axis*.

    Each column is labelled with the matching entry of *curve_labels*.
    Returns ``best`` (np.min or np.max here) applied to each column, in
    column order, for the summary CSV row.
    """
    best_values = []
    for idx in range(curves.shape[1]):
        axis.plot(x_values, curves[:, idx], label=curve_labels[idx])
        best_values.append(best(curves[:, idx]))
    axis.set_xlabel('iterations')
    axis.set_ylabel(ylabel)
    axis.set_title(title)
    axis.legend()
    axis.grid(False)
    return best_values


# For every group, draw one figure with the MSE (top) and MSSIM (bottom)
# curves of all sigma columns and save it as <images>_<group>.png. Also write
# one summary row per group to mse.csv (lowest MSE per column) and to
# mssim.csv (highest MSSIM per column).
with open('mse.csv', 'w', newline='') as mse_csv, \
        open('mssim.csv', 'w', newline='') as mssim_csv:
    mse_writer = csv.writer(mse_csv)
    mssim_writer = csv.writer(mssim_csv)
    header_written = False
    # Result files are matched to groups by position; zip stops at the
    # shorter of the two lists.
    for group_name, filename in zip(group_names, npy_files):
        labels = group_to_datasets[group_name]
        if not header_written:
            # CSV column names come from the first group's dataset labels.
            mse_writer.writerow([''] + labels)
            mssim_writer.writerow([''] + labels)
            header_written = True

        # Row layout: [iteration, mse_1..mse_k, mssim_1..mssim_k].
        table = np.load(filename)
        iterations = table[:, 0]
        n_cols = (table.shape[1] - 1) // 2
        mses = table[:, 1:1 + n_cols]
        mssim = table[:, 1 + n_cols:1 + 2 * n_cols]

        fig, (ax_mse, ax_mssim) = plt.subplots(2, 1, figsize=(15, 9))
        mse_row = [group_name] + _plot_metric(
            ax_mse, iterations, mses, labels,
            'mse', f'mse change - {group_name}', np.min)
        mssim_row = [group_name] + _plot_metric(
            ax_mssim, iterations, mssim, labels,
            'mssim', f'mssim - {group_name}', np.max)
        mse_writer.writerow(mse_row)
        mssim_writer.writerow(mssim_row)

        plt.tight_layout()
        plt.savefig(f"{args.images}_{group_name}.png")

plt.show()