6 Commits

8 changed files with 1553 additions and 135 deletions

View File

@@ -201,6 +201,28 @@ class DatabaseManager:
finally:
conn.close()
def delete_model(self, model_id: int) -> bool:
"""Delete a model from the database.
Note: detections referencing this model are deleted automatically via
the `detections.model_id` foreign key (ON DELETE CASCADE).
Args:
model_id: ID of the model to delete.
Returns:
True if a model row was deleted, False otherwise.
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("DELETE FROM models WHERE id = ?", (model_id,))
conn.commit()
return cursor.rowcount > 0
finally:
conn.close()
# ==================== Image Operations ====================
def add_image(
@@ -462,6 +484,22 @@ class DatabaseManager:
finally:
conn.close()
def delete_all_detections(self) -> int:
"""Delete all detections from the database.
Returns:
Number of rows deleted.
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("DELETE FROM detections")
conn.commit()
return cursor.rowcount
finally:
conn.close()
# ==================== Statistics Operations ====================
def get_detection_statistics(
@@ -620,6 +658,75 @@ class DatabaseManager:
# ==================== Annotation Operations ====================
def get_annotated_images_summary(
self,
name_filter: Optional[str] = None,
order_by: str = "filename",
order_dir: str = "ASC",
limit: Optional[int] = None,
offset: int = 0,
) -> List[Dict]:
"""Return images that have at least one manual annotation.
Args:
name_filter: Optional substring filter applied to filename/relative_path.
order_by: One of: 'filename', 'relative_path', 'annotation_count', 'added_at'.
order_dir: 'ASC' or 'DESC'.
limit: Optional max number of rows.
offset: Pagination offset.
Returns:
List of dicts: {id, relative_path, filename, added_at, annotation_count}
"""
allowed_order_by = {
"filename": "i.filename",
"relative_path": "i.relative_path",
"annotation_count": "annotation_count",
"added_at": "i.added_at",
}
order_expr = allowed_order_by.get(order_by, "i.filename")
dir_norm = str(order_dir).upper().strip()
if dir_norm not in {"ASC", "DESC"}:
dir_norm = "ASC"
conn = self.get_connection()
try:
params: List[Any] = []
where_sql = ""
if name_filter:
# Case-insensitive substring search.
token = f"%{name_filter}%"
where_sql = "WHERE (i.filename LIKE ? OR i.relative_path LIKE ?)"
params.extend([token, token])
limit_sql = ""
if limit is not None:
limit_sql = " LIMIT ? OFFSET ?"
params.extend([int(limit), int(offset)])
query = f"""
SELECT
i.id,
i.relative_path,
i.filename,
i.added_at,
COUNT(a.id) AS annotation_count
FROM images i
JOIN annotations a ON a.image_id = i.id
{where_sql}
GROUP BY i.id
HAVING annotation_count > 0
ORDER BY {order_expr} {dir_norm}
{limit_sql}
"""
cursor = conn.cursor()
cursor.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
finally:
conn.close()
def add_annotation(
self,
image_id: int,

View File

@@ -55,10 +55,7 @@ CREATE TABLE IF NOT EXISTS object_classes (
-- Insert default object classes
INSERT OR IGNORE INTO object_classes (class_name, color, description) VALUES
('cell', '#FF0000', 'Cell object'),
('nucleus', '#00FF00', 'Cell nucleus'),
('mitochondria', '#0000FF', 'Mitochondria'),
('vesicle', '#FFFF00', 'Vesicle');
('terminal', '#FFFF00', 'Axion terminal');
-- Annotations table: stores manual annotations
CREATE TABLE IF NOT EXISTS annotations (

View File

@@ -1,6 +1,7 @@
"""
Main window for the microscopy object detection application.
"""
"""Main window for the microscopy object detection application."""
import shutil
from pathlib import Path
from PySide6.QtWidgets import (
QMainWindow,
@@ -20,6 +21,7 @@ from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
from src.gui.dialogs.config_dialog import ConfigDialog
from src.gui.dialogs.delete_model_dialog import DeleteModelDialog
from src.gui.tabs.detection_tab import DetectionTab
from src.gui.tabs.training_tab import TrainingTab
from src.gui.tabs.validation_tab import ValidationTab
@@ -91,6 +93,12 @@ class MainWindow(QMainWindow):
db_stats_action.triggered.connect(self._show_database_stats)
tools_menu.addAction(db_stats_action)
tools_menu.addSeparator()
delete_model_action = QAction("Delete &Model…", self)
delete_model_action.triggered.connect(self._show_delete_model_dialog)
tools_menu.addAction(delete_model_action)
# Help menu
help_menu = menubar.addMenu("&Help")
@@ -117,10 +125,10 @@ class MainWindow(QMainWindow):
# Add tabs to widget
self.tab_widget.addTab(self.detection_tab, "Detection")
self.tab_widget.addTab(self.results_tab, "Results")
self.tab_widget.addTab(self.annotation_tab, "Annotation")
self.tab_widget.addTab(self.training_tab, "Training")
self.tab_widget.addTab(self.validation_tab, "Validation")
self.tab_widget.addTab(self.results_tab, "Results")
self.tab_widget.addTab(self.annotation_tab, "Annotation (Future)")
# Connect tab change signal
self.tab_widget.currentChanged.connect(self._on_tab_changed)
@@ -152,9 +160,7 @@ class MainWindow(QMainWindow):
"""Center window on screen."""
screen = self.screen().geometry()
size = self.geometry()
self.move(
(screen.width() - size.width()) // 2, (screen.height() - size.height()) // 2
)
self.move((screen.width() - size.width()) // 2, (screen.height() - size.height()) // 2)
def _restore_window_state(self):
"""Restore window geometry from settings or center window."""
@@ -193,6 +199,10 @@ class MainWindow(QMainWindow):
self.training_tab.refresh()
if hasattr(self, "results_tab"):
self.results_tab.refresh()
if hasattr(self, "annotation_tab"):
self.annotation_tab.refresh()
if hasattr(self, "validation_tab"):
self.validation_tab.refresh()
except Exception as e:
logger.error(f"Error applying settings: {e}")
@@ -209,6 +219,14 @@ class MainWindow(QMainWindow):
logger.debug(f"Switched to tab: {tab_name}")
self._update_status(f"Viewing: {tab_name}")
# Ensure the Annotation tab always shows up-to-date DB-backed lists.
try:
current_widget = self.tab_widget.widget(index)
if hasattr(self, "annotation_tab") and current_widget is self.annotation_tab:
self.annotation_tab.refresh()
except Exception as exc:
logger.debug(f"Failed to refresh annotation tab on selection: {exc}")
def _show_database_stats(self):
"""Show database statistics dialog."""
try:
@@ -231,10 +249,230 @@ class MainWindow(QMainWindow):
except Exception as e:
logger.error(f"Error getting database stats: {e}")
QMessageBox.warning(
self, "Error", f"Failed to get database statistics:\n{str(e)}"
QMessageBox.warning(self, "Error", f"Failed to get database statistics:\n{str(e)}")
def _show_delete_model_dialog(self) -> None:
"""Open the model deletion dialog."""
dialog = DeleteModelDialog(self.db_manager, self)
if not dialog.exec():
return
model_ids = dialog.selected_model_ids
if not model_ids:
return
self._delete_models(model_ids)
def _delete_models(self, model_ids: list[int]) -> None:
"""Delete one or more models from the database and remove artifacts from disk."""
deleted_count = 0
removed_paths: list[str] = []
remove_errors: list[str] = []
for model_id in model_ids:
model = None
try:
model = self.db_manager.get_model_by_id(int(model_id))
except Exception as exc:
logger.error(f"Failed to load model {model_id} before deletion: {exc}")
if not model:
remove_errors.append(f"Model id {model_id} not found in database.")
continue
try:
deleted = self.db_manager.delete_model(int(model_id))
except Exception as exc:
logger.error(f"Failed to delete model {model_id}: {exc}")
remove_errors.append(f"Failed to delete model id {model_id} from DB: {exc}")
continue
if not deleted:
remove_errors.append(f"Model id {model_id} was not deleted (already removed?).")
continue
deleted_count += 1
removed, errors = self._delete_model_artifacts_from_disk(model)
removed_paths.extend(removed)
remove_errors.extend(errors)
# Refresh tabs to reflect the deletion(s).
try:
if hasattr(self, "detection_tab"):
self.detection_tab.refresh()
if hasattr(self, "results_tab"):
self.results_tab.refresh()
if hasattr(self, "validation_tab"):
self.validation_tab.refresh()
if hasattr(self, "training_tab"):
self.training_tab.refresh()
except Exception as exc:
logger.warning(f"Failed to refresh tabs after model deletion: {exc}")
details: list[str] = []
if removed_paths:
details.append("Removed from disk:\n" + "\n".join(removed_paths))
if remove_errors:
details.append("\nDisk cleanup warnings:\n" + "\n".join(remove_errors))
QMessageBox.information(
self,
"Delete Model",
f"Deleted {deleted_count} model(s) from database." + ("\n\n" + "\n".join(details) if details else ""),
)
def _delete_model(self, model_id: int) -> None:
"""Delete a model from the database and remove its artifacts from disk."""
model = None
try:
model = self.db_manager.get_model_by_id(model_id)
except Exception as exc:
logger.error(f"Failed to load model {model_id} before deletion: {exc}")
if not model:
QMessageBox.warning(self, "Delete Model", "Selected model was not found in the database.")
return
model_path = str(model.get("model_path") or "")
try:
deleted = self.db_manager.delete_model(model_id)
except Exception as exc:
logger.error(f"Failed to delete model {model_id}: {exc}")
QMessageBox.critical(self, "Delete Model", f"Failed to delete model from database:\n{exc}")
return
if not deleted:
QMessageBox.warning(self, "Delete Model", "No model was deleted (it may have already been removed).")
return
removed_paths, remove_errors = self._delete_model_artifacts_from_disk(model)
# Refresh tabs to reflect the deletion.
try:
if hasattr(self, "detection_tab"):
self.detection_tab.refresh()
if hasattr(self, "results_tab"):
self.results_tab.refresh()
if hasattr(self, "validation_tab"):
self.validation_tab.refresh()
if hasattr(self, "training_tab"):
self.training_tab.refresh()
except Exception as exc:
logger.warning(f"Failed to refresh tabs after model deletion: {exc}")
details = []
if model_path:
details.append(f"Deleted model record for: {model_path}")
if removed_paths:
details.append("\nRemoved from disk:\n" + "\n".join(removed_paths))
if remove_errors:
details.append("\nDisk cleanup warnings:\n" + "\n".join(remove_errors))
QMessageBox.information(
self,
"Delete Model",
"Model deleted from database." + ("\n\n" + "\n".join(details) if details else ""),
)
def _delete_model_artifacts_from_disk(self, model: dict) -> tuple[list[str], list[str]]:
"""Best-effort removal of model artifacts on disk.
Strategy:
- Remove run directories inferred from:
- model.model_path (…/<run>/weights/*.pt => <run>)
- training_params.stage_results[].results.save_dir
but only if they are under the configured models directory.
- If the weights file itself exists and is outside the models directory, delete only the file.
Returns:
(removed_paths, errors)
"""
removed: list[str] = []
errors: list[str] = []
models_root = Path(self.config_manager.get_models_directory() or "data/models").expanduser()
try:
models_root_resolved = models_root.resolve()
except Exception:
models_root_resolved = models_root
inferred_dirs: list[Path] = []
# 1) From model_path
model_path_value = model.get("model_path")
if model_path_value:
try:
p = Path(str(model_path_value)).expanduser()
p_resolved = p.resolve() if p.exists() else p
if p_resolved.is_file():
if p_resolved.parent.name == "weights" and p_resolved.parent.parent.exists():
inferred_dirs.append(p_resolved.parent.parent)
elif p_resolved.parent.exists():
inferred_dirs.append(p_resolved.parent)
except Exception:
pass
# 2) From training_params.stage_results[].results.save_dir
training_params = model.get("training_params") or {}
if isinstance(training_params, dict):
stage_results = training_params.get("stage_results")
if isinstance(stage_results, list):
for stage in stage_results:
results = (stage or {}).get("results")
save_dir = (results or {}).get("save_dir") if isinstance(results, dict) else None
if not save_dir:
continue
try:
d = Path(str(save_dir)).expanduser()
if d.exists() and d.is_dir():
inferred_dirs.append(d)
except Exception:
continue
# Deduplicate inferred_dirs
unique_dirs: list[Path] = []
seen: set[str] = set()
for d in inferred_dirs:
try:
key = str(d.resolve())
except Exception:
key = str(d)
if key in seen:
continue
seen.add(key)
unique_dirs.append(d)
# Delete directories under models_root
for d in unique_dirs:
try:
d_resolved = d.resolve()
except Exception:
d_resolved = d
try:
if d_resolved.exists() and d_resolved.is_dir() and d_resolved.is_relative_to(models_root_resolved):
shutil.rmtree(d_resolved)
removed.append(str(d_resolved))
except Exception as exc:
errors.append(f"Failed to remove directory {d_resolved}: {exc}")
# If nothing matched (e.g., model_path outside models_root), delete just the file.
if model_path_value:
try:
p = Path(str(model_path_value)).expanduser()
if p.exists() and p.is_file():
p_resolved = p.resolve()
if not p_resolved.is_relative_to(models_root_resolved):
p_resolved.unlink()
removed.append(str(p_resolved))
except Exception as exc:
errors.append(f"Failed to remove model file {model_path_value}: {exc}")
return removed, errors
def _show_about(self):
"""Show about dialog."""
about_text = """
@@ -301,6 +539,11 @@ class MainWindow(QMainWindow):
if hasattr(self, "training_tab"):
self.training_tab.shutdown()
if hasattr(self, "annotation_tab"):
# Best-effort refresh so DB-backed UI state is consistent at shutdown.
try:
self.annotation_tab.refresh()
except Exception:
pass
self.annotation_tab.save_state()
logger.info("Application closing")

View File

@@ -13,6 +13,11 @@ from PySide6.QtWidgets import (
QFileDialog,
QMessageBox,
QSplitter,
QLineEdit,
QTableWidget,
QTableWidgetItem,
QHeaderView,
QAbstractItemView,
)
from PySide6.QtCore import Qt, QSettings
from pathlib import Path
@@ -29,9 +34,7 @@ logger = get_logger(__name__)
class AnnotationTab(QWidget):
"""Annotation tab for manual image annotation."""
def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
super().__init__(parent)
self.db_manager = db_manager
self.config_manager = config_manager
@@ -52,6 +55,32 @@ class AnnotationTab(QWidget):
self.main_splitter = QSplitter(Qt.Horizontal)
self.main_splitter.setHandleWidth(10)
# { Left-most pane: annotated images list
annotated_group = QGroupBox("Annotated Images")
annotated_layout = QVBoxLayout()
filter_row = QHBoxLayout()
filter_row.addWidget(QLabel("Filter:"))
self.annotated_filter_edit = QLineEdit()
self.annotated_filter_edit.setPlaceholderText("Type to filter by image name…")
self.annotated_filter_edit.textChanged.connect(self._refresh_annotated_images_list)
filter_row.addWidget(self.annotated_filter_edit, 1)
annotated_layout.addLayout(filter_row)
self.annotated_images_table = QTableWidget(0, 2)
self.annotated_images_table.setHorizontalHeaderLabels(["Image", "Annotations"])
self.annotated_images_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
self.annotated_images_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeToContents)
self.annotated_images_table.setSelectionBehavior(QAbstractItemView.SelectRows)
self.annotated_images_table.setSelectionMode(QAbstractItemView.SingleSelection)
self.annotated_images_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
self.annotated_images_table.setSortingEnabled(True)
self.annotated_images_table.itemSelectionChanged.connect(self._on_annotated_image_selected)
annotated_layout.addWidget(self.annotated_images_table, 1)
annotated_group.setLayout(annotated_layout)
# }
# { Left splitter for image display and zoom info
self.left_splitter = QSplitter(Qt.Vertical)
self.left_splitter.setHandleWidth(10)
@@ -62,6 +91,9 @@ class AnnotationTab(QWidget):
# Use the AnnotationCanvasWidget
self.annotation_canvas = AnnotationCanvasWidget()
# Auto-zoom so newly loaded images fill the available canvas viewport.
# (Matches the behavior used in ResultsTab.)
self.annotation_canvas.set_auto_fit_to_view(True)
self.annotation_canvas.zoom_changed.connect(self._on_zoom_changed)
self.annotation_canvas.annotation_drawn.connect(self._on_annotation_drawn)
# Selection of existing polylines (when tool is not in drawing mode)
@@ -72,9 +104,7 @@ class AnnotationTab(QWidget):
self.left_splitter.addWidget(canvas_group)
# Controls info
controls_info = QLabel(
"Zoom: Mouse wheel or +/- keys | Drawing: Enable pen and drag mouse"
)
controls_info = QLabel("Zoom: Mouse wheel or +/- keys | Drawing: Enable pen and drag mouse")
controls_info.setStyleSheet("QLabel { color: #888; font-style: italic; }")
self.left_splitter.addWidget(controls_info)
# }
@@ -85,36 +115,20 @@ class AnnotationTab(QWidget):
# Annotation tools section
self.annotation_tools = AnnotationToolsWidget(self.db_manager)
self.annotation_tools.polyline_enabled_changed.connect(
self.annotation_canvas.set_polyline_enabled
)
self.annotation_tools.polyline_pen_color_changed.connect(
self.annotation_canvas.set_polyline_pen_color
)
self.annotation_tools.polyline_pen_width_changed.connect(
self.annotation_canvas.set_polyline_pen_width
)
self.annotation_tools.polyline_enabled_changed.connect(self.annotation_canvas.set_polyline_enabled)
self.annotation_tools.polyline_pen_color_changed.connect(self.annotation_canvas.set_polyline_pen_color)
self.annotation_tools.polyline_pen_width_changed.connect(self.annotation_canvas.set_polyline_pen_width)
# Show / hide bounding boxes
self.annotation_tools.show_bboxes_changed.connect(
self.annotation_canvas.set_show_bboxes
)
self.annotation_tools.show_bboxes_changed.connect(self.annotation_canvas.set_show_bboxes)
# RDP simplification controls
self.annotation_tools.simplify_on_finish_changed.connect(
self._on_simplify_on_finish_changed
)
self.annotation_tools.simplify_epsilon_changed.connect(
self._on_simplify_epsilon_changed
)
self.annotation_tools.simplify_on_finish_changed.connect(self._on_simplify_on_finish_changed)
self.annotation_tools.simplify_epsilon_changed.connect(self._on_simplify_epsilon_changed)
# Class selection and class-color changes
self.annotation_tools.class_selected.connect(self._on_class_selected)
self.annotation_tools.class_color_changed.connect(self._on_class_color_changed)
self.annotation_tools.clear_annotations_requested.connect(
self._on_clear_annotations
)
self.annotation_tools.clear_annotations_requested.connect(self._on_clear_annotations)
# Delete selected annotation on canvas
self.annotation_tools.delete_selected_annotation_requested.connect(
self._on_delete_selected_annotation
)
self.annotation_tools.delete_selected_annotation_requested.connect(self._on_delete_selected_annotation)
self.right_splitter.addWidget(self.annotation_tools)
# Image loading section
@@ -137,12 +151,13 @@ class AnnotationTab(QWidget):
self.right_splitter.addWidget(load_group)
# }
# Add both splitters to the main horizontal splitter
# Add list + both splitters to the main horizontal splitter
self.main_splitter.addWidget(annotated_group)
self.main_splitter.addWidget(self.left_splitter)
self.main_splitter.addWidget(self.right_splitter)
# Set initial sizes: 75% for left (image), 25% for right (controls)
self.main_splitter.setSizes([750, 250])
# Set initial sizes: list (left), canvas (middle), controls (right)
self.main_splitter.setSizes([320, 650, 280])
layout.addWidget(self.main_splitter)
self.setLayout(layout)
@@ -150,6 +165,9 @@ class AnnotationTab(QWidget):
# Restore splitter positions from settings
self._restore_state()
# Populate list on startup.
self._refresh_annotated_images_list()
def _load_image(self):
"""Load and display an image file."""
# Get last opened directory from QSettings
@@ -180,12 +198,24 @@ class AnnotationTab(QWidget):
self.current_image_path = file_path
# Store the directory for next time
settings.setValue(
"annotation_tab/last_directory", str(Path(file_path).parent)
)
settings.setValue("annotation_tab/last_directory", str(Path(file_path).parent))
# Get or create image in database
relative_path = str(Path(file_path).name) # Simplified for now
repo_root = self.config_manager.get_image_repository_path()
relative_path: str
try:
if repo_root:
repo_root_path = Path(repo_root).expanduser().resolve()
file_resolved = Path(file_path).expanduser().resolve()
if file_resolved.is_relative_to(repo_root_path):
relative_path = file_resolved.relative_to(repo_root_path).as_posix()
else:
# Fallback: store filename only to avoid leaking absolute paths.
relative_path = file_resolved.name
else:
relative_path = str(Path(file_path).name)
except Exception:
relative_path = str(Path(file_path).name)
self.current_image_id = self.db_manager.get_or_create_image(
relative_path,
Path(file_path).name,
@@ -199,6 +229,9 @@ class AnnotationTab(QWidget):
# Load and display any existing annotations for this image
self._load_annotations_for_current_image()
# Update annotated images list (newly annotated image added/selected).
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
# Update info label
self._update_image_info()
@@ -206,9 +239,7 @@ class AnnotationTab(QWidget):
except ImageLoadError as e:
logger.error(f"Failed to load image: {e}")
QMessageBox.critical(
self, "Error Loading Image", f"Failed to load image:\n{str(e)}"
)
QMessageBox.critical(self, "Error Loading Image", f"Failed to load image:\n{str(e)}")
except Exception as e:
logger.error(f"Unexpected error loading image: {e}")
QMessageBox.critical(self, "Error", f"Unexpected error:\n{str(e)}")
@@ -296,6 +327,9 @@ class AnnotationTab(QWidget):
# Reload annotations from DB and redraw (respecting current class filter)
self._load_annotations_for_current_image()
# Update list counts.
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
except Exception as e:
logger.error(f"Failed to save annotation: {e}")
QMessageBox.critical(self, "Error", f"Failed to save annotation:\n{str(e)}")
@@ -340,9 +374,7 @@ class AnnotationTab(QWidget):
if not self.current_image_id:
return
logger.debug(
f"Class color changed; reloading annotations for image ID {self.current_image_id}"
)
logger.debug(f"Class color changed; reloading annotations for image ID {self.current_image_id}")
self._load_annotations_for_current_image()
def _on_class_selected(self, class_data):
@@ -355,9 +387,7 @@ class AnnotationTab(QWidget):
if class_data:
logger.debug(f"Object class selected: {class_data['class_name']}")
else:
logger.debug(
'No class selected ("-- Select Class --"), showing all annotations'
)
logger.debug('No class selected ("-- Select Class --"), showing all annotations')
# Changing the class filter invalidates any previous selection
self.selected_annotation_ids = []
@@ -390,9 +420,7 @@ class AnnotationTab(QWidget):
question = "Are you sure you want to delete the selected annotation?"
title = "Delete Annotation"
else:
question = (
f"Are you sure you want to delete the {count} selected annotations?"
)
question = f"Are you sure you want to delete the {count} selected annotations?"
title = "Delete Annotations"
reply = QMessageBox.question(
@@ -420,13 +448,11 @@ class AnnotationTab(QWidget):
QMessageBox.warning(
self,
"Partial Failure",
"Some annotations could not be deleted:\n"
+ ", ".join(str(a) for a in failed_ids),
"Some annotations could not be deleted:\n" + ", ".join(str(a) for a in failed_ids),
)
else:
logger.info(
f"Deleted {count} annotation(s): "
+ ", ".join(str(a) for a in self.selected_annotation_ids)
f"Deleted {count} annotation(s): " + ", ".join(str(a) for a in self.selected_annotation_ids)
)
# Clear selection and reload annotations for the current image from DB
@@ -434,6 +460,9 @@ class AnnotationTab(QWidget):
self.annotation_tools.set_has_selected_annotation(False)
self._load_annotations_for_current_image()
# Update list counts.
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
except Exception as e:
logger.error(f"Failed to delete annotations: {e}")
QMessageBox.critical(
@@ -456,17 +485,13 @@ class AnnotationTab(QWidget):
return
try:
self.current_annotations = self.db_manager.get_annotations_for_image(
self.current_image_id
)
self.current_annotations = self.db_manager.get_annotations_for_image(self.current_image_id)
# New annotations loaded; reset any selection
self.selected_annotation_ids = []
self.annotation_tools.set_has_selected_annotation(False)
self._redraw_annotations_for_current_filter()
except Exception as e:
logger.error(
f"Failed to load annotations for image {self.current_image_id}: {e}"
)
logger.error(f"Failed to load annotations for image {self.current_image_id}: {e}")
QMessageBox.critical(
self,
"Error",
@@ -490,10 +515,7 @@ class AnnotationTab(QWidget):
drawn_count = 0
for ann in self.current_annotations:
# Filter by class if one is selected
if (
selected_class_id is not None
and ann.get("class_id") != selected_class_id
):
if selected_class_id is not None and ann.get("class_id") != selected_class_id:
continue
if ann.get("segmentation_mask"):
@@ -545,22 +567,176 @@ class AnnotationTab(QWidget):
settings = QSettings("microscopy_app", "object_detection")
# Save main splitter state
settings.setValue(
"annotation_tab/main_splitter_state", self.main_splitter.saveState()
)
settings.setValue("annotation_tab/main_splitter_state", self.main_splitter.saveState())
# Save left splitter state
settings.setValue(
"annotation_tab/left_splitter_state", self.left_splitter.saveState()
)
settings.setValue("annotation_tab/left_splitter_state", self.left_splitter.saveState())
# Save right splitter state
settings.setValue(
"annotation_tab/right_splitter_state", self.right_splitter.saveState()
)
settings.setValue("annotation_tab/right_splitter_state", self.right_splitter.saveState())
logger.debug("Saved annotation tab splitter states")
def refresh(self):
"""Refresh the tab."""
self._refresh_annotated_images_list(select_image_id=self.current_image_id)
# ==================== Annotated images list ====================
def _refresh_annotated_images_list(self, select_image_id: int | None = None) -> None:
"""Reload annotated-images list from the database."""
if not hasattr(self, "annotated_images_table"):
return
# Preserve selection if possible
desired_id = select_image_id if select_image_id is not None else self.current_image_id
name_filter = ""
if hasattr(self, "annotated_filter_edit"):
name_filter = self.annotated_filter_edit.text().strip()
try:
rows = self.db_manager.get_annotated_images_summary(name_filter=name_filter)
except Exception as exc:
logger.error(f"Failed to load annotated images summary: {exc}")
rows = []
sorting_enabled = self.annotated_images_table.isSortingEnabled()
self.annotated_images_table.setSortingEnabled(False)
self.annotated_images_table.blockSignals(True)
try:
self.annotated_images_table.setRowCount(len(rows))
for r, entry in enumerate(rows):
image_name = str(entry.get("filename") or "")
count = int(entry.get("annotation_count") or 0)
rel_path = str(entry.get("relative_path") or "")
name_item = QTableWidgetItem(image_name)
# Tooltip shows full path of the image (best-effort: repository_root + relative_path)
full_path = rel_path
repo_root = self.config_manager.get_image_repository_path()
if repo_root and rel_path and not Path(rel_path).is_absolute():
try:
full_path = str((Path(repo_root) / rel_path).resolve())
except Exception:
full_path = str(Path(repo_root) / rel_path)
name_item.setToolTip(full_path)
name_item.setData(Qt.UserRole, int(entry.get("id")))
name_item.setData(Qt.UserRole + 1, rel_path)
count_item = QTableWidgetItem()
# Use EditRole to ensure numeric sorting.
count_item.setData(Qt.EditRole, count)
count_item.setData(Qt.UserRole, int(entry.get("id")))
count_item.setData(Qt.UserRole + 1, rel_path)
self.annotated_images_table.setItem(r, 0, name_item)
self.annotated_images_table.setItem(r, 1, count_item)
# Re-select desired row
if desired_id is not None:
for r in range(self.annotated_images_table.rowCount()):
item = self.annotated_images_table.item(r, 0)
if item and item.data(Qt.UserRole) == desired_id:
self.annotated_images_table.selectRow(r)
break
finally:
self.annotated_images_table.blockSignals(False)
self.annotated_images_table.setSortingEnabled(sorting_enabled)
def _on_annotated_image_selected(self) -> None:
"""When user clicks an item in the list, load that image in the annotation canvas."""
selected = self.annotated_images_table.selectedItems()
if not selected:
return
# Row selection -> take the first column item
row = self.annotated_images_table.currentRow()
item = self.annotated_images_table.item(row, 0)
if not item:
return
image_id = item.data(Qt.UserRole)
rel_path = item.data(Qt.UserRole + 1) or ""
if not image_id:
return
image_path = self._resolve_image_path_for_relative_path(rel_path)
if not image_path:
QMessageBox.warning(
self,
"Image Not Found",
"Unable to locate image on disk for:\n"
f"{rel_path}\n\n"
"Tip: set Settings → Image repository path to the folder containing your images.",
)
return
try:
self.current_image = Image(image_path)
self.current_image_path = image_path
self.current_image_id = int(image_id)
self.annotation_canvas.load_image(self.current_image)
self._load_annotations_for_current_image()
self._update_image_info()
except ImageLoadError as exc:
logger.error(f"Failed to load image '{image_path}': {exc}")
QMessageBox.critical(self, "Error Loading Image", f"Failed to load image:\n{exc}")
except Exception as exc:
logger.error(f"Unexpected error loading image '{image_path}': {exc}")
QMessageBox.critical(self, "Error", f"Unexpected error:\n{exc}")
def _resolve_image_path_for_relative_path(self, relative_path: str) -> str | None:
"""Best-effort conversion from a DB relative_path to an on-disk file path."""
rel = (relative_path or "").strip()
if not rel:
return None
candidates: list[Path] = []
# 1) Repository root + relative
repo_root = (self.config_manager.get_image_repository_path() or "").strip()
if repo_root:
candidates.append(Path(repo_root) / rel)
# 2) If the DB path is absolute, try it directly.
candidates.append(Path(rel))
# 3) Try the directory of the currently loaded image (helps when DB stores only filenames)
if self.current_image_path:
try:
candidates.append(Path(self.current_image_path).expanduser().resolve().parent / Path(rel).name)
except Exception:
pass
# 4) Try the last directory used by the annotation file picker
try:
settings = QSettings("microscopy_app", "object_detection")
last_dir = settings.value("annotation_tab/last_directory", None)
if last_dir:
candidates.append(Path(str(last_dir)) / Path(rel).name)
except Exception:
pass
for p in candidates:
try:
expanded = p.expanduser()
if expanded.exists() and expanded.is_file():
return str(expanded.resolve())
except Exception:
continue
# 5) Fallback: search by filename within repository root.
filename = Path(rel).name
if repo_root and filename:
root = Path(repo_root).expanduser()
try:
if root.exists():
for match in root.rglob(filename):
if match.is_file():
return str(match.resolve())
except Exception as exc:
logger.debug(f"Search for {filename} under {root} failed: {exc}")
return None

View File

@@ -3,7 +3,7 @@ Results tab for browsing stored detections and visualizing overlays.
"""
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple
from PySide6.QtWidgets import (
QWidget,
@@ -65,6 +65,22 @@ class ResultsTab(QWidget):
self.refresh_btn = QPushButton("Refresh")
self.refresh_btn.clicked.connect(self.refresh)
controls_layout.addWidget(self.refresh_btn)
self.delete_all_btn = QPushButton("Delete All Detections")
self.delete_all_btn.setToolTip(
"Permanently delete ALL detections from the database.\n" "This cannot be undone."
)
self.delete_all_btn.clicked.connect(self._delete_all_detections)
controls_layout.addWidget(self.delete_all_btn)
self.export_labels_btn = QPushButton("Export Labels")
self.export_labels_btn.setToolTip(
"Export YOLO .txt labels for the selected image/model run.\n"
"Output path is inferred from the image path (images/ -> labels/)."
)
self.export_labels_btn.clicked.connect(self._export_labels_for_current_selection)
controls_layout.addWidget(self.export_labels_btn)
controls_layout.addStretch()
left_layout.addLayout(controls_layout)
@@ -130,6 +146,41 @@ class ResultsTab(QWidget):
layout.addWidget(splitter)
self.setLayout(layout)
def _delete_all_detections(self):
"""Delete all detections from the database after user confirmation."""
confirm = QMessageBox.warning(
self,
"Delete All Detections",
"This will permanently delete ALL detections from the database.\n\n"
"This action cannot be undone.\n\n"
"Do you want to continue?",
QMessageBox.Yes | QMessageBox.No,
QMessageBox.No,
)
if confirm != QMessageBox.Yes:
return
try:
deleted = self.db_manager.delete_all_detections()
except Exception as exc:
logger.error(f"Failed to delete all detections: {exc}")
QMessageBox.critical(
self,
"Error",
f"Failed to delete detections:\n{exc}",
)
return
QMessageBox.information(
self,
"Delete All Detections",
f"Deleted {deleted} detection(s) from the database.",
)
# Reset UI state.
self.refresh()
def refresh(self):
"""Refresh the detection list and preview."""
self._load_detection_summary()
@@ -139,6 +190,8 @@ class ResultsTab(QWidget):
self.current_detections = []
self.preview_canvas.clear()
self.summary_label.setText("Select a detection result to preview.")
if hasattr(self, "export_labels_btn"):
self.export_labels_btn.setEnabled(False)
def _load_detection_summary(self):
"""Load latest detection summaries grouped by image + model."""
@@ -258,6 +311,231 @@ class ResultsTab(QWidget):
self._load_detections_for_selection(entry)
self._apply_detection_overlays()
self._update_summary_label(entry)
if hasattr(self, "export_labels_btn"):
self.export_labels_btn.setEnabled(True)
def _export_labels_for_current_selection(self):
"""Export YOLO label file(s) for the currently selected image/model."""
if not self.current_selection:
QMessageBox.information(self, "Export Labels", "Select a detection result first.")
return
entry = self.current_selection
image_path_str = self._resolve_image_path(entry)
if not image_path_str:
QMessageBox.warning(
self,
"Export Labels",
"Unable to locate the image file for this detection; cannot infer labels path.",
)
return
# Ensure we have the detections for the selection.
if not self.current_detections:
self._load_detections_for_selection(entry)
if not self.current_detections:
QMessageBox.information(
self,
"Export Labels",
"No detections found for this image/model selection.",
)
return
image_path = Path(image_path_str)
try:
label_path = self._infer_yolo_label_path(image_path)
except Exception as exc:
logger.error(f"Failed to infer label path for {image_path}: {exc}")
QMessageBox.critical(
self,
"Export Labels",
f"Failed to infer export path for labels:\n{exc}",
)
return
class_map = self._build_detection_class_index_map(self.current_detections)
if not class_map:
QMessageBox.warning(
self,
"Export Labels",
"Unable to build class->index mapping (missing class names).",
)
return
lines_written = 0
skipped = 0
label_path.parent.mkdir(parents=True, exist_ok=True)
try:
with open(label_path, "w", encoding="utf-8") as handle:
print("writing to", label_path)
for det in self.current_detections:
yolo_line = self._format_detection_as_yolo_line(det, class_map)
if not yolo_line:
skipped += 1
continue
handle.write(yolo_line + "\n")
lines_written += 1
except OSError as exc:
logger.error(f"Failed to write labels file {label_path}: {exc}")
QMessageBox.critical(
self,
"Export Labels",
f"Failed to write label file:\n{label_path}\n\n{exc}",
)
return
return
# Optional: write a classes.txt next to the labels root to make the mapping discoverable.
# This is not required by Ultralytics (data.yaml usually holds class names), but helps reuse.
try:
classes_txt = label_path.parent.parent / "classes.txt"
classes_txt.parent.mkdir(parents=True, exist_ok=True)
inv = {idx: name for name, idx in class_map.items()}
with open(classes_txt, "w", encoding="utf-8") as handle:
for idx in range(len(inv)):
handle.write(f"{inv[idx]}\n")
except Exception:
# Non-fatal
pass
QMessageBox.information(
self,
"Export Labels",
f"Exported {lines_written} label line(s) to:\n{label_path}\n\nSkipped {skipped} invalid detection(s).",
)
def _infer_yolo_label_path(self, image_path: Path) -> Path:
"""Infer a YOLO label path from an image path.
If the image lives under an `images/` directory (anywhere in the path), we mirror the
subpath under a sibling `labels/` directory at the same level.
Example:
/dataset/train/images/sub/img.jpg -> /dataset/train/labels/sub/img.txt
"""
resolved = image_path.expanduser().resolve()
# Find the nearest ancestor directory named 'images'
images_dir: Optional[Path] = None
for parent in [resolved.parent, *resolved.parents]:
if parent.name.lower() == "images":
images_dir = parent
break
if images_dir is not None:
rel = resolved.relative_to(images_dir)
labels_dir = images_dir.parent / "labels"
return (labels_dir / rel).with_suffix(".txt")
# Fallback: create a local sibling labels folder next to the image.
return (resolved.parent / "labels" / resolved.name).with_suffix(".txt")
def _build_detection_class_index_map(self, detections: List[Dict]) -> Dict[str, int]:
"""Build a stable class_name -> YOLO class index mapping.
Preference order:
1) Database object_classes table (alphabetical class_name order)
2) Fallback to class_name values present in the detections (alphabetical)
"""
names: List[str] = []
try:
db_classes = self.db_manager.get_object_classes() or []
names = [str(row.get("class_name")) for row in db_classes if row.get("class_name")]
except Exception:
names = []
if not names:
observed = sorted({str(det.get("class_name")) for det in detections if det.get("class_name")})
names = list(observed)
return {name: idx for idx, name in enumerate(names)}
def _format_detection_as_yolo_line(self, det: Dict, class_map: Dict[str, int]) -> Optional[str]:
"""Convert a detection row to a YOLO label line.
- If segmentation_mask is present, exports segmentation polygon format:
class x1 y1 x2 y2 ...
(normalized coordinates)
- Otherwise exports bbox format:
class x_center y_center width height
(normalized coordinates)
"""
class_name = det.get("class_name")
if not class_name or class_name not in class_map:
return None
class_idx = class_map[class_name]
mask = det.get("segmentation_mask")
polygon = self._convert_segmentation_mask_to_polygon(mask)
if polygon:
coords = " ".join(f"{value:.6f}" for value in polygon)
return f"{class_idx} {coords}".strip()
bbox = self._convert_bbox_to_yolo_xywh(det)
if bbox is None:
return None
x_center, y_center, width, height = bbox
return f"{class_idx} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
def _convert_bbox_to_yolo_xywh(self, det: Dict) -> Optional[Tuple[float, float, float, float]]:
"""Convert stored xyxy (normalized) bbox to YOLO xywh (normalized)."""
x_min = det.get("x_min")
y_min = det.get("y_min")
x_max = det.get("x_max")
y_max = det.get("y_max")
if any(v is None for v in (x_min, y_min, x_max, y_max)):
return None
try:
x_min_f = self._clamp01(float(x_min))
y_min_f = self._clamp01(float(y_min))
x_max_f = self._clamp01(float(x_max))
y_max_f = self._clamp01(float(y_max))
except (TypeError, ValueError):
return None
width = max(0.0, x_max_f - x_min_f)
height = max(0.0, y_max_f - y_min_f)
if width <= 0.0 or height <= 0.0:
return None
x_center = x_min_f + width / 2.0
y_center = y_min_f + height / 2.0
return x_center, y_center, width, height
def _convert_segmentation_mask_to_polygon(self, mask_data) -> List[float]:
"""Convert stored segmentation_mask [[x,y], ...] to YOLO polygon coords [x1,y1,...]."""
if not isinstance(mask_data, list):
return []
coords: List[float] = []
for point in mask_data:
if not isinstance(point, (list, tuple)) or len(point) < 2:
continue
try:
x = self._clamp01(float(point[0]))
y = self._clamp01(float(point[1]))
except (TypeError, ValueError):
continue
coords.extend([x, y])
# Need at least 3 points => 6 values.
return coords if len(coords) >= 6 else []
@staticmethod
def _clamp01(value: float) -> float:
if value < 0.0:
return 0.0
if value > 1.0:
return 1.0
return value
def _load_detections_for_selection(self, entry: Dict):
"""Load detection records for the selected image/model pair."""

View File

@@ -2,45 +2,554 @@
Validation tab for the microscopy object detection application.
"""
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple
from PySide6.QtCore import Qt, QSize
from PySide6.QtGui import QPainter, QPixmap
from PySide6.QtWidgets import (
QWidget,
QVBoxLayout,
QLabel,
QGroupBox,
QHBoxLayout,
QPushButton,
QComboBox,
QFormLayout,
QScrollArea,
QGridLayout,
QFrame,
QTableWidget,
QTableWidgetItem,
QHeaderView,
QSplitter,
QListWidget,
QListWidgetItem,
QAbstractItemView,
QGraphicsView,
QGraphicsScene,
QGraphicsPixmapItem,
)
from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
logger = get_logger(__name__)
@dataclass(frozen=True)
class _PlotItem:
label: str
path: Path
class _ZoomableImageView(QGraphicsView):
"""Zoomable image viewer.
- Mouse wheel: zoom in/out
- Left mouse drag: pan (ScrollHandDrag)
"""
def __init__(self, parent: Optional[QWidget] = None):
super().__init__(parent)
self._scene = QGraphicsScene(self)
self.setScene(self._scene)
self._pixmap_item = QGraphicsPixmapItem()
self._scene.addItem(self._pixmap_item)
# QGraphicsView render hints are QPainter.RenderHints.
self.setRenderHints(self.renderHints() | QPainter.RenderHint.SmoothPixmapTransform)
self.setDragMode(QGraphicsView.DragMode.ScrollHandDrag)
self.setTransformationAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse)
self.setResizeAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse)
self._has_pixmap = False
def clear(self) -> None:
self._pixmap_item.setPixmap(QPixmap())
self._scene.setSceneRect(0, 0, 1, 1)
self.resetTransform()
self._has_pixmap = False
def set_pixmap(self, pixmap: QPixmap, *, fit: bool = True) -> None:
self._pixmap_item.setPixmap(pixmap)
self._scene.setSceneRect(pixmap.rect())
self._has_pixmap = not pixmap.isNull()
self.resetTransform()
if fit and self._has_pixmap:
self.fitInView(self._pixmap_item, Qt.AspectRatioMode.KeepAspectRatio)
def wheelEvent(self, event) -> None: # type: ignore[override]
if not self._has_pixmap:
return
zoom_in_factor = 1.25
zoom_out_factor = 1.0 / zoom_in_factor
factor = zoom_in_factor if event.angleDelta().y() > 0 else zoom_out_factor
self.scale(factor, factor)
class ValidationTab(QWidget):
"""Validation tab placeholder."""
"""Validation tab that shows stored validation metrics + plots for a selected model."""
def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
super().__init__(parent)
self.db_manager = db_manager
self.config_manager = config_manager
self._models: List[Dict[str, Any]] = []
self._selected_model_id: Optional[int] = None
self._plot_widgets: List[QWidget] = []
self._plot_items: List[_PlotItem] = []
self._setup_ui()
self.refresh()
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
layout = QVBoxLayout(self)
group = QGroupBox("Validation")
group_layout = QVBoxLayout()
label = QLabel(
"Validation functionality will be implemented here.\n\n"
"Features:\n"
"- Model validation\n"
"- Metrics visualization\n"
"- Confusion matrix\n"
"- Precision-Recall curves"
)
group_layout.addWidget(label)
group.setLayout(group_layout)
# ===== Header controls =====
header = QGroupBox("Validation")
header_layout = QVBoxLayout()
header_row = QHBoxLayout()
layout.addWidget(group)
layout.addStretch()
self.setLayout(layout)
header_row.addWidget(QLabel("Select model:"))
self.model_combo = QComboBox()
self.model_combo.setMinimumWidth(420)
self.model_combo.currentIndexChanged.connect(self._on_model_selected)
header_row.addWidget(self.model_combo, 1)
self.refresh_btn = QPushButton("Refresh")
self.refresh_btn.clicked.connect(self.refresh)
header_row.addWidget(self.refresh_btn)
header_row.addStretch()
header_layout.addLayout(header_row)
self.header_status = QLabel("No models loaded.")
self.header_status.setWordWrap(True)
header_layout.addWidget(self.header_status)
header.setLayout(header_layout)
layout.addWidget(header)
# ===== Metrics =====
metrics_group = QGroupBox("Validation Metrics")
metrics_layout = QVBoxLayout()
self.metrics_form = QFormLayout()
self.metric_labels: Dict[str, QLabel] = {}
for key in ("mAP50", "mAP50-95", "precision", "recall", "fitness"):
value_label = QLabel("")
value_label.setTextInteractionFlags(Qt.TextSelectableByMouse)
self.metric_labels[key] = value_label
self.metrics_form.addRow(f"{key}:", value_label)
metrics_layout.addLayout(self.metrics_form)
self.per_class_table = QTableWidget(0, 3)
self.per_class_table.setHorizontalHeaderLabels(["Class", "AP", "AP50"])
self.per_class_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
self.per_class_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeToContents)
self.per_class_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeToContents)
self.per_class_table.setEditTriggers(QTableWidget.NoEditTriggers)
self.per_class_table.setMinimumHeight(160)
metrics_layout.addWidget(QLabel("Per-class metrics (if available):"))
metrics_layout.addWidget(self.per_class_table)
metrics_group.setLayout(metrics_layout)
layout.addWidget(metrics_group)
# ===== Plots =====
plots_group = QGroupBox("Validation Plots")
plots_layout = QVBoxLayout()
self.plots_status = QLabel("Select a model to see validation plots.")
self.plots_status.setWordWrap(True)
plots_layout.addWidget(self.plots_status)
self.plots_splitter = QSplitter(Qt.Orientation.Horizontal)
# Left: selected image viewer
left_widget = QWidget()
left_layout = QVBoxLayout(left_widget)
left_layout.setContentsMargins(0, 0, 0, 0)
self.selected_plot_title = QLabel("No image selected.")
self.selected_plot_title.setWordWrap(True)
self.selected_plot_title.setTextInteractionFlags(Qt.TextSelectableByMouse)
left_layout.addWidget(self.selected_plot_title)
self.plot_view = _ZoomableImageView()
self.plot_view.setMinimumHeight(360)
left_layout.addWidget(self.plot_view, 1)
self.selected_plot_path = QLabel("")
self.selected_plot_path.setWordWrap(True)
self.selected_plot_path.setStyleSheet("color: #888;")
self.selected_plot_path.setTextInteractionFlags(Qt.TextSelectableByMouse)
left_layout.addWidget(self.selected_plot_path)
# Right: scrollable list
right_widget = QWidget()
right_layout = QVBoxLayout(right_widget)
right_layout.setContentsMargins(0, 0, 0, 0)
right_layout.addWidget(QLabel("Images:"))
self.plots_list = QListWidget()
self.plots_list.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection)
self.plots_list.setIconSize(QSize(160, 160))
self.plots_list.itemSelectionChanged.connect(self._on_plot_item_selected)
right_layout.addWidget(self.plots_list, 1)
self.plots_splitter.addWidget(left_widget)
self.plots_splitter.addWidget(right_widget)
self.plots_splitter.setStretchFactor(0, 3)
self.plots_splitter.setStretchFactor(1, 1)
plots_layout.addWidget(self.plots_splitter, 1)
plots_group.setLayout(plots_layout)
layout.addWidget(plots_group, 1)
layout.addStretch(0)
self._clear_metrics()
self._clear_plots()
# ==================== Public API ====================
def refresh(self):
"""Refresh the tab."""
self._load_models()
self._populate_model_combo()
self._restore_or_select_default_model()
# ==================== Internal: models ====================
def _load_models(self) -> None:
try:
self._models = self.db_manager.get_models() or []
except Exception as exc:
logger.error("Failed to load models: %s", exc)
self._models = []
def _populate_model_combo(self) -> None:
self.model_combo.blockSignals(True)
self.model_combo.clear()
self.model_combo.addItem("Select a model…", None)
for model in self._models:
model_id = model.get("id")
name = (model.get("model_name") or "").strip()
version = (model.get("model_version") or "").strip()
created_at = model.get("created_at")
label = f"{name} {version}".strip()
if created_at:
label = f"{label} ({created_at})"
self.model_combo.addItem(label, model_id)
self.model_combo.blockSignals(False)
if self._models:
self.header_status.setText(f"Loaded {len(self._models)} model(s).")
else:
self.header_status.setText("No models found. Train a model first.")
def _restore_or_select_default_model(self) -> None:
if not self._models:
self._selected_model_id = None
self._clear_metrics()
self._clear_plots()
return
# Keep selection if still present.
if self._selected_model_id is not None:
for idx in range(1, self.model_combo.count()):
if self.model_combo.itemData(idx) == self._selected_model_id:
self.model_combo.setCurrentIndex(idx)
return
# Otherwise select the newest model (top of get_models ORDER BY created_at DESC).
first_model_id = self.model_combo.itemData(1) if self.model_combo.count() > 1 else None
if first_model_id is not None:
self.model_combo.setCurrentIndex(1)
def _on_model_selected(self, index: int) -> None:
model_id = self.model_combo.itemData(index)
if not model_id:
self._selected_model_id = None
self._clear_metrics()
self._clear_plots()
self.plots_status.setText("Select a model to see validation plots.")
return
self._selected_model_id = int(model_id)
model = self._get_model_by_id(self._selected_model_id)
if not model:
self._clear_metrics()
self._clear_plots()
self.plots_status.setText("Selected model not found.")
return
self._render_metrics(model)
self._render_plots(model)
def _get_model_by_id(self, model_id: int) -> Optional[Dict[str, Any]]:
for model in self._models:
if model.get("id") == model_id:
return model
try:
return self.db_manager.get_model_by_id(model_id)
except Exception:
return None
# ==================== Internal: metrics ====================
def _clear_metrics(self) -> None:
for label in self.metric_labels.values():
label.setText("")
self.per_class_table.setRowCount(0)
def _render_metrics(self, model: Dict[str, Any]) -> None:
self._clear_metrics()
metrics: Dict[str, Any] = model.get("metrics") or {}
# Training tab stores metrics under results['metrics'] in training results payload.
if isinstance(metrics, dict) and "metrics" in metrics and isinstance(metrics.get("metrics"), dict):
metrics = metrics.get("metrics") or {}
def set_metric(key: str, value: Any) -> None:
if key not in self.metric_labels:
return
if value is None:
self.metric_labels[key].setText("")
return
try:
self.metric_labels[key].setText(f"{float(value):.4f}")
except Exception:
self.metric_labels[key].setText(str(value))
set_metric("mAP50", metrics.get("mAP50"))
set_metric("mAP50-95", metrics.get("mAP50-95") or metrics.get("mAP50_95") or metrics.get("mAP50-95"))
set_metric("precision", metrics.get("precision"))
set_metric("recall", metrics.get("recall"))
set_metric("fitness", metrics.get("fitness"))
# Optional per-class metrics
class_metrics = metrics.get("class_metrics") if isinstance(metrics, dict) else None
if isinstance(class_metrics, dict) and class_metrics:
items = sorted(class_metrics.items(), key=lambda kv: str(kv[0]))
self.per_class_table.setRowCount(len(items))
for row, (cls_name, cls_stats) in enumerate(items):
ap = (cls_stats or {}).get("ap")
ap50 = (cls_stats or {}).get("ap50")
self.per_class_table.setItem(row, 0, QTableWidgetItem(str(cls_name)))
self.per_class_table.setItem(row, 1, QTableWidgetItem(self._format_float(ap)))
self.per_class_table.setItem(row, 2, QTableWidgetItem(self._format_float(ap50)))
else:
self.per_class_table.setRowCount(0)
@staticmethod
def _format_float(value: Any) -> str:
if value is None:
return ""
try:
return f"{float(value):.4f}"
except Exception:
return str(value)
# ==================== Internal: plots ====================
def _clear_plots(self) -> None:
# Remove legacy grid widgets (from the initial implementation).
for widget in self._plot_widgets:
widget.setParent(None)
widget.deleteLater()
self._plot_widgets = []
self._plot_items = []
if hasattr(self, "plots_list"):
self.plots_list.blockSignals(True)
self.plots_list.clear()
self.plots_list.blockSignals(False)
if hasattr(self, "plot_view"):
self.plot_view.clear()
if hasattr(self, "selected_plot_title"):
self.selected_plot_title.setText("No image selected.")
if hasattr(self, "selected_plot_path"):
self.selected_plot_path.setText("")
def _render_plots(self, model: Dict[str, Any]) -> None:
self._clear_plots()
plot_dirs = self._infer_run_directories(model)
plot_items = self._discover_plot_items(plot_dirs)
if not plot_items:
dirs_text = "\n".join(str(p) for p in plot_dirs if p)
self.plots_status.setText(
"No validation plot images found for this model.\n\n"
"Searched directories:\n" + (dirs_text or "(none)")
)
return
self._plot_items = list(plot_items)
self.plots_status.setText(f"Found {len(plot_items)} plot image(s). Select one to view/zoom.")
self.plots_list.blockSignals(True)
self.plots_list.clear()
for idx, item in enumerate(self._plot_items):
qitem = QListWidgetItem(item.label)
qitem.setData(Qt.ItemDataRole.UserRole, idx)
pix = QPixmap(str(item.path))
if not pix.isNull():
thumb = pix.scaled(
self.plots_list.iconSize(),
Qt.AspectRatioMode.KeepAspectRatio,
Qt.TransformationMode.SmoothTransformation,
)
qitem.setIcon(thumb)
self.plots_list.addItem(qitem)
self.plots_list.blockSignals(False)
if self.plots_list.count() > 0:
self.plots_list.setCurrentRow(0)
def _on_plot_item_selected(self) -> None:
if not self._plot_items:
return
selected = self.plots_list.selectedItems()
if not selected:
return
idx = selected[0].data(Qt.ItemDataRole.UserRole)
try:
idx_int = int(idx)
except Exception:
return
if idx_int < 0 or idx_int >= len(self._plot_items):
return
plot = self._plot_items[idx_int]
self.selected_plot_title.setText(plot.label)
self.selected_plot_path.setText(str(plot.path))
pix = QPixmap(str(plot.path))
if pix.isNull():
self.plot_view.clear()
return
self.plot_view.set_pixmap(pix, fit=True)
def _infer_run_directories(self, model: Dict[str, Any]) -> List[Path]:
dirs: List[Path] = []
# 1) Infer from model_path: .../<run>/weights/best.pt -> <run>
model_path = model.get("model_path")
if model_path:
try:
p = Path(str(model_path)).expanduser()
if p.name.lower().endswith(".pt"):
# If it lives under weights/, use parent.parent.
if p.parent.name == "weights" and p.parent.parent.exists():
dirs.append(p.parent.parent)
elif p.parent.exists():
dirs.append(p.parent)
except Exception:
pass
# 2) Look at training_params.stage_results[].results.save_dir
training_params = model.get("training_params") or {}
stage_results = None
if isinstance(training_params, dict):
stage_results = training_params.get("stage_results")
if isinstance(stage_results, list):
for stage in stage_results:
results = (stage or {}).get("results")
save_dir = (results or {}).get("save_dir") if isinstance(results, dict) else None
if save_dir:
try:
save_path = Path(str(save_dir)).expanduser()
if save_path.exists():
dirs.append(save_path)
except Exception:
continue
# Deduplicate while preserving order.
unique: List[Path] = []
seen: set[str] = set()
for d in dirs:
try:
resolved = str(d.resolve())
except Exception:
resolved = str(d)
if resolved not in seen and d.exists() and d.is_dir():
seen.add(resolved)
unique.append(d)
return unique
def _discover_plot_items(self, directories: Sequence[Path]) -> List[_PlotItem]:
# Prefer canonical Ultralytics filenames first, then fall back to any png/jpg.
preferred_names = [
"results.png",
"results.jpg",
"confusion_matrix.png",
"confusion_matrix_normalized.png",
"labels.jpg",
"labels.png",
"BoxPR_curve.png",
"BoxP_curve.png",
"BoxR_curve.png",
"BoxF1_curve.png",
"MaskPR_curve.png",
"MaskP_curve.png",
"MaskR_curve.png",
"MaskF1_curve.png",
"val_batch0_pred.jpg",
"val_batch0_labels.jpg",
]
found: List[_PlotItem] = []
seen: set[str] = set()
for d in directories:
# 1) Preferred
for name in preferred_names:
p = d / name
if p.exists() and p.is_file():
key = str(p)
if key in seen:
continue
seen.add(key)
found.append(_PlotItem(label=f"{name} (from {d.name})", path=p))
# 2) Curated globs
for pattern in ("train_batch*.jpg", "val_batch*.jpg", "*curve*.png"):
for p in sorted(d.glob(pattern)):
if not p.is_file():
continue
key = str(p)
if key in seen:
continue
seen.add(key)
found.append(_PlotItem(label=f"{p.name} (from {d.name})", path=p))
# 3) Fallback: any top-level png/jpg (excluding weights dir contents)
for ext in ("*.png", "*.jpg", "*.jpeg", "*.webp"):
for p in sorted(d.glob(ext)):
if not p.is_file():
continue
key = str(p)
if key in seen:
continue
seen.add(key)
found.append(_PlotItem(label=f"{p.name} (from {d.name})", path=p))
# Keep list bounded to avoid UI overload for huge runs.
return found[:60]

View File

@@ -0,0 +1,103 @@
import numpy as np
from pathlib import Path
from skimage.draw import polygon
from tifffile import TiffFile
from src.database.db_manager import DatabaseManager
def read_image(image_path: Path) -> np.ndarray:
metadata = {}
with TiffFile(image_path) as tif:
image = tif.asarray()
metadata = tif.imagej_metadata
return image, metadata
def main():
polygon_vertices = np.array([[10, 10], [50, 10], [50, 50], [10, 50]])
image = np.zeros((100, 100), dtype=np.uint8)
rr, cc = polygon(polygon_vertices[:, 0], polygon_vertices[:, 1])
image[rr, cc] = 255
if __name__ == "__main__":
db = DatabaseManager()
model_name = "c17"
model_id = db.get_models(filters={"model_name": model_name})[0]["id"]
print(f"Model name {model_name}, id {model_id}")
detections = db.get_detections(filters={"model_id": model_id})
file_stems = set()
for detection in detections:
file_stems.add(detection["image_filename"].split("_")[0])
print("Files:", file_stems)
for stem in file_stems:
print(stem)
detections = db.get_detections(filters={"model_id": model_id, "i.filename": f"LIKE %{stem}%"})
annotations = []
for detection in detections:
source_path = Path(detection["metadata"]["source_path"])
image, metadata = read_image(source_path)
offset = np.array(list(map(int, metadata["tile_section"].split(","))))[::-1]
scale = np.array(list(map(int, metadata["patch_size"].split(","))))[::-1]
# tile_size = np.array(list(map(int, metadata["tile_size"].split(","))))
segmentation = np.array(detection["segmentation_mask"]) # * tile_size
# print(source_path, image, metadata, segmentation.shape)
# print(offset)
# print(scale)
# print(segmentation)
# segmentation = (segmentation + offset * tile_size) / (tile_size * scale)
segmentation = (segmentation + offset) / scale
yolo_annotation = f"{detection['metadata']['class_id']} " + " ".join(
[f"{x:.6f} {y:.6f}" for x, y in segmentation]
)
annotations.append(yolo_annotation)
# print(segmentation)
# print(yolo_annotation)
# aa
print(
" ",
detection["model_name"],
detection["image_id"],
detection["image_filename"],
source_path,
metadata["label_path"],
)
# section_i_section_j = detection["image_filename"].split("_")[1].split(".")[0]
# print(" ", section_i_section_j)
label_path = metadata["label_path"]
print(" ", label_path)
with open(label_path, "w") as f:
f.write("\n".join(annotations))
exit()
for detection in detections:
print(detection["model_name"], detection["image_id"], detection["image_filename"])
print(detections[0])
# polygon_vertices = np.array([[10, 10], [50, 10], [50, 50], [10, 50]])
# image = np.zeros((100, 100), dtype=np.uint8)
# rr, cc = polygon(polygon_vertices[:, 0], polygon_vertices[:, 1])
# image[rr, cc] = 255
# import matplotlib.pyplot as plt
# plt.imshow(image, cmap='gray')
# plt.show()

View File

@@ -189,25 +189,30 @@ def main():
# continue and just show image
out = draw_annotations(img.copy(), labels, alpha=args.alpha, draw_bbox_for_poly=(not args.no_bbox))
lclass, coords = labels[0]
print(lclass, coords)
out_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
if 0:
plt.imshow(out_rgb.transpose(1, 0, 2))
else:
plt.imshow(out_rgb)
for label in labels:
lclass, coords = label
# print(lclass, coords)
bbox = coords[:4]
print("bbox", bbox)
# print("bbox", bbox)
bbox = np.array(bbox) * np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
yc, xc, h, w = bbox
print("bbox", bbox)
# print("bbox", bbox)
# polyline = np.array(coords[4:]).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
polyline = np.array(coords).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
print("pl", coords[4:])
print("pl", polyline)
# print("pl", coords[4:])
# print("pl", polyline)
# Convert BGR -> RGB for matplotlib display
# out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
out_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# out_rgb = Image()
plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
plt.imshow(out_rgb)
plt.plot(polyline[:, 0], polyline[:, 1], "y", linewidth=2)
if 0:
plt.plot(