Compare commits
20 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e5036c10cf | |||
| c7e388d9ae | |||
| 6b995e7325 | |||
| 0e0741d323 | |||
| dd99a0677c | |||
| 9c4c39fb39 | |||
| 20a87c9040 | |||
| 9f7d2be1ac | |||
| dbde07c0e8 | |||
| b3c5a51dbb | |||
| 9a221acb63 | |||
| 32a6a122bd | |||
| 9ba44043ef | |||
| 8eb1cc8c86 | |||
| e4ce882a18 | |||
| 6b6d6fad03 | |||
| c0684a9c14 | |||
| 221c80aa8c | |||
| 833b222fad | |||
| 5370d31dce |
@@ -12,12 +12,26 @@ image_repository:
|
|||||||
models:
|
models:
|
||||||
default_base_model: yolov8s-seg.pt
|
default_base_model: yolov8s-seg.pt
|
||||||
models_directory: data/models
|
models_directory: data/models
|
||||||
|
base_model_choices:
|
||||||
|
- yolov8s-seg.pt
|
||||||
|
- yolo11s-seg.pt
|
||||||
training:
|
training:
|
||||||
default_epochs: 100
|
default_epochs: 100
|
||||||
default_batch_size: 16
|
default_batch_size: 16
|
||||||
default_imgsz: 640
|
default_imgsz: 1024
|
||||||
default_patience: 50
|
default_patience: 50
|
||||||
default_lr0: 0.01
|
default_lr0: 0.01
|
||||||
|
two_stage:
|
||||||
|
enabled: false
|
||||||
|
stage1:
|
||||||
|
epochs: 20
|
||||||
|
lr0: 0.0005
|
||||||
|
patience: 10
|
||||||
|
freeze: 10
|
||||||
|
stage2:
|
||||||
|
epochs: 150
|
||||||
|
lr0: 0.0003
|
||||||
|
patience: 30
|
||||||
last_dataset_yaml: /home/martin/code/object_detection/data/datasets/data.yaml
|
last_dataset_yaml: /home/martin/code/object_detection/data/datasets/data.yaml
|
||||||
last_dataset_dir: /home/martin/code/object_detection/data/datasets
|
last_dataset_dir: /home/martin/code/object_detection/data/datasets
|
||||||
detection:
|
detection:
|
||||||
|
|||||||
@@ -13,8 +13,9 @@ import hashlib
|
|||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
from src.utils.logger import get_logger
|
from src.utils.logger import get_logger
|
||||||
|
from src.utils.image import Image
|
||||||
|
|
||||||
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp")
|
IMAGE_EXTENSIONS = tuple(Image.SUPPORTED_EXTENSIONS)
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -450,6 +451,25 @@ class DatabaseManager:
|
|||||||
filters["model_id"] = model_id
|
filters["model_id"] = model_id
|
||||||
return self.get_detections(filters)
|
return self.get_detections(filters)
|
||||||
|
|
||||||
|
def delete_detections_for_image(
|
||||||
|
self, image_id: int, model_id: Optional[int] = None
|
||||||
|
) -> int:
|
||||||
|
"""Delete detections tied to a specific image and optional model."""
|
||||||
|
conn = self.get_connection()
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
if model_id is not None:
|
||||||
|
cursor.execute(
|
||||||
|
"DELETE FROM detections WHERE image_id = ? AND model_id = ?",
|
||||||
|
(image_id, model_id),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cursor.execute("DELETE FROM detections WHERE image_id = ?", (image_id,))
|
||||||
|
conn.commit()
|
||||||
|
return cursor.rowcount
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
def delete_detections_for_model(self, model_id: int) -> int:
|
def delete_detections_for_model(self, model_id: int) -> int:
|
||||||
"""Delete all detections for a specific model."""
|
"""Delete all detections for a specific model."""
|
||||||
conn = self.get_connection()
|
conn = self.get_connection()
|
||||||
|
|||||||
@@ -168,7 +168,7 @@ class AnnotationTab(QWidget):
|
|||||||
self,
|
self,
|
||||||
"Select Image",
|
"Select Image",
|
||||||
start_dir,
|
start_dir,
|
||||||
"Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
|
"Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not file_path:
|
if not file_path:
|
||||||
|
|||||||
@@ -20,12 +20,14 @@ from PySide6.QtWidgets import (
|
|||||||
)
|
)
|
||||||
from PySide6.QtCore import Qt, QThread, Signal
|
from PySide6.QtCore import Qt, QThread, Signal
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from src.database.db_manager import DatabaseManager
|
from src.database.db_manager import DatabaseManager
|
||||||
from src.utils.config_manager import ConfigManager
|
from src.utils.config_manager import ConfigManager
|
||||||
from src.utils.logger import get_logger
|
from src.utils.logger import get_logger
|
||||||
from src.utils.file_utils import get_image_files
|
from src.utils.file_utils import get_image_files
|
||||||
from src.model.inference import InferenceEngine
|
from src.model.inference import InferenceEngine
|
||||||
|
from src.utils.image import Image
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -147,30 +149,66 @@ class DetectionTab(QWidget):
|
|||||||
self.model_combo.currentIndexChanged.connect(self._on_model_changed)
|
self.model_combo.currentIndexChanged.connect(self._on_model_changed)
|
||||||
|
|
||||||
def _load_models(self):
|
def _load_models(self):
|
||||||
"""Load available models from database."""
|
"""Load available models from database and local storage."""
|
||||||
try:
|
try:
|
||||||
models = self.db_manager.get_models()
|
|
||||||
self.model_combo.clear()
|
self.model_combo.clear()
|
||||||
|
models = self.db_manager.get_models()
|
||||||
|
has_models = False
|
||||||
|
|
||||||
if not models:
|
known_paths = set()
|
||||||
self.model_combo.addItem("No models available", None)
|
|
||||||
self._set_buttons_enabled(False)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Add base model option
|
# Add base model option first (always available)
|
||||||
base_model = self.config_manager.get(
|
base_model = self.config_manager.get(
|
||||||
"models.default_base_model", "yolov8s-seg.pt"
|
"models.default_base_model", "yolov8s-seg.pt"
|
||||||
)
|
)
|
||||||
self.model_combo.addItem(
|
if base_model:
|
||||||
f"Base Model ({base_model})", {"id": 0, "path": base_model}
|
base_data = {
|
||||||
)
|
"id": 0,
|
||||||
|
"path": base_model,
|
||||||
|
"model_name": Path(base_model).stem or "Base Model",
|
||||||
|
"model_version": "pretrained",
|
||||||
|
"base_model": base_model,
|
||||||
|
"source": "base",
|
||||||
|
}
|
||||||
|
self.model_combo.addItem(f"Base Model ({base_model})", base_data)
|
||||||
|
known_paths.add(self._normalize_model_path(base_model))
|
||||||
|
has_models = True
|
||||||
|
|
||||||
# Add trained models
|
# Add trained models from database
|
||||||
for model in models:
|
for model in models:
|
||||||
display_name = f"{model['model_name']} v{model['model_version']}"
|
display_name = f"{model['model_name']} v{model['model_version']}"
|
||||||
self.model_combo.addItem(display_name, model)
|
model_data = {**model, "path": model.get("model_path")}
|
||||||
|
normalized = self._normalize_model_path(model_data.get("path"))
|
||||||
|
if normalized:
|
||||||
|
known_paths.add(normalized)
|
||||||
|
self.model_combo.addItem(display_name, model_data)
|
||||||
|
has_models = True
|
||||||
|
|
||||||
self._set_buttons_enabled(True)
|
# Discover local model files not yet in the database
|
||||||
|
local_models = self._discover_local_models()
|
||||||
|
for model_path in local_models:
|
||||||
|
normalized = self._normalize_model_path(model_path)
|
||||||
|
if normalized in known_paths:
|
||||||
|
continue
|
||||||
|
|
||||||
|
display_name = f"Local Model ({Path(model_path).stem})"
|
||||||
|
model_data = {
|
||||||
|
"id": None,
|
||||||
|
"path": str(model_path),
|
||||||
|
"model_name": Path(model_path).stem,
|
||||||
|
"model_version": "local",
|
||||||
|
"base_model": Path(model_path).stem,
|
||||||
|
"source": "local",
|
||||||
|
}
|
||||||
|
self.model_combo.addItem(display_name, model_data)
|
||||||
|
known_paths.add(normalized)
|
||||||
|
has_models = True
|
||||||
|
|
||||||
|
if not has_models:
|
||||||
|
self.model_combo.addItem("No models available", None)
|
||||||
|
self._set_buttons_enabled(False)
|
||||||
|
else:
|
||||||
|
self._set_buttons_enabled(True)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error loading models: {e}")
|
logger.error(f"Error loading models: {e}")
|
||||||
@@ -199,7 +237,7 @@ class DetectionTab(QWidget):
|
|||||||
self,
|
self,
|
||||||
"Select Image",
|
"Select Image",
|
||||||
start_dir,
|
start_dir,
|
||||||
"Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
|
"Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
|
||||||
)
|
)
|
||||||
|
|
||||||
if not file_path:
|
if not file_path:
|
||||||
@@ -249,25 +287,39 @@ class DetectionTab(QWidget):
|
|||||||
QMessageBox.warning(self, "No Model", "Please select a model first.")
|
QMessageBox.warning(self, "No Model", "Please select a model first.")
|
||||||
return
|
return
|
||||||
|
|
||||||
model_path = model_data["path"]
|
model_path = model_data.get("path")
|
||||||
model_id = model_data["id"]
|
if not model_path:
|
||||||
|
QMessageBox.warning(
|
||||||
|
self, "Invalid Model", "Selected model is missing a file path."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
# Ensure we have a valid model ID (create entry for base model if needed)
|
if not Path(model_path).exists():
|
||||||
if model_id == 0:
|
QMessageBox.critical(
|
||||||
# Create database entry for base model
|
self,
|
||||||
base_model = self.config_manager.get(
|
"Model Not Found",
|
||||||
"models.default_base_model", "yolov8s-seg.pt"
|
f"The selected model file could not be found:\n{model_path}",
|
||||||
)
|
|
||||||
model_id = self.db_manager.add_model(
|
|
||||||
model_name="Base Model",
|
|
||||||
model_version="pretrained",
|
|
||||||
model_path=base_model,
|
|
||||||
base_model=base_model,
|
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
model_id = model_data.get("id")
|
||||||
|
|
||||||
|
# Ensure we have a database entry for the selected model
|
||||||
|
if model_id in (None, 0):
|
||||||
|
model_id = self._ensure_model_record(model_data)
|
||||||
|
if not model_id:
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Model Registration Failed",
|
||||||
|
"Unable to register the selected model in the database.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
normalized_model_path = self._normalize_model_path(model_path) or model_path
|
||||||
|
|
||||||
# Create inference engine
|
# Create inference engine
|
||||||
self.inference_engine = InferenceEngine(
|
self.inference_engine = InferenceEngine(
|
||||||
model_path, self.db_manager, model_id
|
normalized_model_path, self.db_manager, model_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get confidence threshold
|
# Get confidence threshold
|
||||||
@@ -338,6 +390,76 @@ class DetectionTab(QWidget):
|
|||||||
self.batch_btn.setEnabled(enabled)
|
self.batch_btn.setEnabled(enabled)
|
||||||
self.model_combo.setEnabled(enabled)
|
self.model_combo.setEnabled(enabled)
|
||||||
|
|
||||||
|
def _discover_local_models(self) -> list:
|
||||||
|
"""Scan the models directory for standalone .pt files."""
|
||||||
|
models_dir = self.config_manager.get_models_directory()
|
||||||
|
if not models_dir:
|
||||||
|
return []
|
||||||
|
|
||||||
|
models_path = Path(models_dir)
|
||||||
|
if not models_path.exists():
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
return sorted(
|
||||||
|
[p for p in models_path.rglob("*.pt") if p.is_file()],
|
||||||
|
key=lambda p: str(p).lower(),
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error discovering local models: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _normalize_model_path(self, path_value) -> str:
|
||||||
|
"""Return a normalized absolute path string for comparison."""
|
||||||
|
if not path_value:
|
||||||
|
return ""
|
||||||
|
try:
|
||||||
|
return str(Path(path_value).resolve())
|
||||||
|
except Exception:
|
||||||
|
return str(path_value)
|
||||||
|
|
||||||
|
def _ensure_model_record(self, model_data: dict) -> Optional[int]:
|
||||||
|
"""Ensure a database record exists for the selected model."""
|
||||||
|
model_path = model_data.get("path")
|
||||||
|
if not model_path:
|
||||||
|
return None
|
||||||
|
|
||||||
|
normalized_target = self._normalize_model_path(model_path)
|
||||||
|
|
||||||
|
try:
|
||||||
|
existing_models = self.db_manager.get_models()
|
||||||
|
for model in existing_models:
|
||||||
|
existing_path = model.get("model_path")
|
||||||
|
if not existing_path:
|
||||||
|
continue
|
||||||
|
normalized_existing = self._normalize_model_path(existing_path)
|
||||||
|
if (
|
||||||
|
normalized_existing == normalized_target
|
||||||
|
or existing_path == model_path
|
||||||
|
):
|
||||||
|
return model["id"]
|
||||||
|
|
||||||
|
model_name = (
|
||||||
|
model_data.get("model_name") or Path(model_path).stem or "Custom Model"
|
||||||
|
)
|
||||||
|
model_version = (
|
||||||
|
model_data.get("model_version") or model_data.get("source") or "local"
|
||||||
|
)
|
||||||
|
base_model = model_data.get(
|
||||||
|
"base_model",
|
||||||
|
self.config_manager.get("models.default_base_model", "yolov8s-seg.pt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.db_manager.add_model(
|
||||||
|
model_name=model_name,
|
||||||
|
model_version=model_version,
|
||||||
|
model_path=normalized_target,
|
||||||
|
base_model=base_model,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to ensure model record for {model_path}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
def refresh(self):
|
def refresh(self):
|
||||||
"""Refresh the tab."""
|
"""Refresh the tab."""
|
||||||
self._load_models()
|
self._load_models()
|
||||||
|
|||||||
@@ -1,15 +1,39 @@
|
|||||||
"""
|
"""
|
||||||
Results tab for the microscopy object detection application.
|
Results tab for browsing stored detections and visualizing overlays.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from PySide6.QtWidgets import (
|
||||||
|
QWidget,
|
||||||
|
QVBoxLayout,
|
||||||
|
QHBoxLayout,
|
||||||
|
QLabel,
|
||||||
|
QGroupBox,
|
||||||
|
QPushButton,
|
||||||
|
QSplitter,
|
||||||
|
QTableWidget,
|
||||||
|
QTableWidgetItem,
|
||||||
|
QHeaderView,
|
||||||
|
QAbstractItemView,
|
||||||
|
QMessageBox,
|
||||||
|
QCheckBox,
|
||||||
|
)
|
||||||
|
from PySide6.QtCore import Qt
|
||||||
|
|
||||||
from src.database.db_manager import DatabaseManager
|
from src.database.db_manager import DatabaseManager
|
||||||
from src.utils.config_manager import ConfigManager
|
from src.utils.config_manager import ConfigManager
|
||||||
|
from src.utils.logger import get_logger
|
||||||
|
from src.utils.image import Image, ImageLoadError
|
||||||
|
from src.gui.widgets import AnnotationCanvasWidget
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class ResultsTab(QWidget):
|
class ResultsTab(QWidget):
|
||||||
"""Results tab placeholder."""
|
"""Results tab showing detection history and preview overlays."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
||||||
@@ -18,29 +42,398 @@ class ResultsTab(QWidget):
|
|||||||
self.db_manager = db_manager
|
self.db_manager = db_manager
|
||||||
self.config_manager = config_manager
|
self.config_manager = config_manager
|
||||||
|
|
||||||
|
self.detection_summary: List[Dict] = []
|
||||||
|
self.current_selection: Optional[Dict] = None
|
||||||
|
self.current_image: Optional[Image] = None
|
||||||
|
self.current_detections: List[Dict] = []
|
||||||
|
self._image_path_cache: Dict[str, str] = {}
|
||||||
|
|
||||||
self._setup_ui()
|
self._setup_ui()
|
||||||
|
self.refresh()
|
||||||
|
|
||||||
def _setup_ui(self):
|
def _setup_ui(self):
|
||||||
"""Setup user interface."""
|
"""Setup user interface."""
|
||||||
layout = QVBoxLayout()
|
layout = QVBoxLayout()
|
||||||
|
|
||||||
group = QGroupBox("Results")
|
# Splitter for list + preview
|
||||||
group_layout = QVBoxLayout()
|
splitter = QSplitter(Qt.Horizontal)
|
||||||
label = QLabel(
|
|
||||||
"Results viewer will be implemented here.\n\n"
|
|
||||||
"Features:\n"
|
|
||||||
"- Detection history browser\n"
|
|
||||||
"- Advanced filtering\n"
|
|
||||||
"- Statistics dashboard\n"
|
|
||||||
"- Export functionality"
|
|
||||||
)
|
|
||||||
group_layout.addWidget(label)
|
|
||||||
group.setLayout(group_layout)
|
|
||||||
|
|
||||||
layout.addWidget(group)
|
# Left pane: detection list
|
||||||
layout.addStretch()
|
left_container = QWidget()
|
||||||
|
left_layout = QVBoxLayout()
|
||||||
|
left_layout.setContentsMargins(0, 0, 0, 0)
|
||||||
|
|
||||||
|
controls_layout = QHBoxLayout()
|
||||||
|
self.refresh_btn = QPushButton("Refresh")
|
||||||
|
self.refresh_btn.clicked.connect(self.refresh)
|
||||||
|
controls_layout.addWidget(self.refresh_btn)
|
||||||
|
controls_layout.addStretch()
|
||||||
|
left_layout.addLayout(controls_layout)
|
||||||
|
|
||||||
|
self.results_table = QTableWidget(0, 5)
|
||||||
|
self.results_table.setHorizontalHeaderLabels(
|
||||||
|
["Image", "Model", "Detections", "Classes", "Last Updated"]
|
||||||
|
)
|
||||||
|
self.results_table.horizontalHeader().setSectionResizeMode(
|
||||||
|
0, QHeaderView.Stretch
|
||||||
|
)
|
||||||
|
self.results_table.horizontalHeader().setSectionResizeMode(
|
||||||
|
1, QHeaderView.Stretch
|
||||||
|
)
|
||||||
|
self.results_table.horizontalHeader().setSectionResizeMode(
|
||||||
|
2, QHeaderView.ResizeToContents
|
||||||
|
)
|
||||||
|
self.results_table.horizontalHeader().setSectionResizeMode(
|
||||||
|
3, QHeaderView.Stretch
|
||||||
|
)
|
||||||
|
self.results_table.horizontalHeader().setSectionResizeMode(
|
||||||
|
4, QHeaderView.ResizeToContents
|
||||||
|
)
|
||||||
|
self.results_table.setSelectionBehavior(QAbstractItemView.SelectRows)
|
||||||
|
self.results_table.setSelectionMode(QAbstractItemView.SingleSelection)
|
||||||
|
self.results_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
|
||||||
|
self.results_table.itemSelectionChanged.connect(self._on_result_selected)
|
||||||
|
|
||||||
|
left_layout.addWidget(self.results_table)
|
||||||
|
left_container.setLayout(left_layout)
|
||||||
|
|
||||||
|
# Right pane: preview canvas and controls
|
||||||
|
right_container = QWidget()
|
||||||
|
right_layout = QVBoxLayout()
|
||||||
|
right_layout.setContentsMargins(0, 0, 0, 0)
|
||||||
|
|
||||||
|
preview_group = QGroupBox("Detection Preview")
|
||||||
|
preview_layout = QVBoxLayout()
|
||||||
|
|
||||||
|
self.preview_canvas = AnnotationCanvasWidget()
|
||||||
|
self.preview_canvas.set_polyline_enabled(False)
|
||||||
|
self.preview_canvas.set_show_bboxes(True)
|
||||||
|
preview_layout.addWidget(self.preview_canvas)
|
||||||
|
|
||||||
|
toggles_layout = QHBoxLayout()
|
||||||
|
self.show_masks_checkbox = QCheckBox("Show Masks")
|
||||||
|
self.show_masks_checkbox.setChecked(False)
|
||||||
|
self.show_masks_checkbox.stateChanged.connect(self._apply_detection_overlays)
|
||||||
|
self.show_bboxes_checkbox = QCheckBox("Show Bounding Boxes")
|
||||||
|
self.show_bboxes_checkbox.setChecked(True)
|
||||||
|
self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
|
||||||
|
self.show_confidence_checkbox = QCheckBox("Show Confidence")
|
||||||
|
self.show_confidence_checkbox.setChecked(False)
|
||||||
|
self.show_confidence_checkbox.stateChanged.connect(
|
||||||
|
self._apply_detection_overlays
|
||||||
|
)
|
||||||
|
toggles_layout.addWidget(self.show_masks_checkbox)
|
||||||
|
toggles_layout.addWidget(self.show_bboxes_checkbox)
|
||||||
|
toggles_layout.addWidget(self.show_confidence_checkbox)
|
||||||
|
toggles_layout.addStretch()
|
||||||
|
preview_layout.addLayout(toggles_layout)
|
||||||
|
|
||||||
|
self.summary_label = QLabel("Select a detection result to preview.")
|
||||||
|
self.summary_label.setWordWrap(True)
|
||||||
|
preview_layout.addWidget(self.summary_label)
|
||||||
|
|
||||||
|
preview_group.setLayout(preview_layout)
|
||||||
|
right_layout.addWidget(preview_group)
|
||||||
|
right_container.setLayout(right_layout)
|
||||||
|
|
||||||
|
splitter.addWidget(left_container)
|
||||||
|
splitter.addWidget(right_container)
|
||||||
|
splitter.setStretchFactor(0, 1)
|
||||||
|
splitter.setStretchFactor(1, 2)
|
||||||
|
|
||||||
|
layout.addWidget(splitter)
|
||||||
self.setLayout(layout)
|
self.setLayout(layout)
|
||||||
|
|
||||||
def refresh(self):
|
def refresh(self):
|
||||||
"""Refresh the tab."""
|
"""Refresh the detection list and preview."""
|
||||||
pass
|
self._load_detection_summary()
|
||||||
|
self._populate_results_table()
|
||||||
|
self.current_selection = None
|
||||||
|
self.current_image = None
|
||||||
|
self.current_detections = []
|
||||||
|
self.preview_canvas.clear()
|
||||||
|
self.summary_label.setText("Select a detection result to preview.")
|
||||||
|
|
||||||
|
def _load_detection_summary(self):
|
||||||
|
"""Load latest detection summaries grouped by image + model."""
|
||||||
|
try:
|
||||||
|
detections = self.db_manager.get_detections(limit=500)
|
||||||
|
summary_map: Dict[tuple, Dict] = {}
|
||||||
|
|
||||||
|
for det in detections:
|
||||||
|
key = (det["image_id"], det["model_id"])
|
||||||
|
metadata = det.get("metadata") or {}
|
||||||
|
entry = summary_map.setdefault(
|
||||||
|
key,
|
||||||
|
{
|
||||||
|
"image_id": det["image_id"],
|
||||||
|
"model_id": det["model_id"],
|
||||||
|
"image_path": det.get("image_path"),
|
||||||
|
"image_filename": det.get("image_filename")
|
||||||
|
or det.get("image_path"),
|
||||||
|
"model_name": det.get("model_name", ""),
|
||||||
|
"model_version": det.get("model_version", ""),
|
||||||
|
"last_detected": det.get("detected_at"),
|
||||||
|
"count": 0,
|
||||||
|
"classes": set(),
|
||||||
|
"source_path": metadata.get("source_path"),
|
||||||
|
"repository_root": metadata.get("repository_root"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
entry["count"] += 1
|
||||||
|
if det.get("detected_at") and (
|
||||||
|
not entry.get("last_detected")
|
||||||
|
or str(det.get("detected_at")) > str(entry.get("last_detected"))
|
||||||
|
):
|
||||||
|
entry["last_detected"] = det.get("detected_at")
|
||||||
|
if det.get("class_name"):
|
||||||
|
entry["classes"].add(det["class_name"])
|
||||||
|
if metadata.get("source_path") and not entry.get("source_path"):
|
||||||
|
entry["source_path"] = metadata.get("source_path")
|
||||||
|
if metadata.get("repository_root") and not entry.get("repository_root"):
|
||||||
|
entry["repository_root"] = metadata.get("repository_root")
|
||||||
|
|
||||||
|
self.detection_summary = sorted(
|
||||||
|
summary_map.values(),
|
||||||
|
key=lambda x: str(x.get("last_detected") or ""),
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load detection summary: {e}")
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Error",
|
||||||
|
f"Failed to load detection results:\n{str(e)}",
|
||||||
|
)
|
||||||
|
self.detection_summary = []
|
||||||
|
|
||||||
|
def _populate_results_table(self):
|
||||||
|
"""Populate the table widget with detection summaries."""
|
||||||
|
self.results_table.setRowCount(len(self.detection_summary))
|
||||||
|
|
||||||
|
for row, entry in enumerate(self.detection_summary):
|
||||||
|
model_label = f"{entry['model_name']} {entry['model_version']}".strip()
|
||||||
|
class_list = (
|
||||||
|
", ".join(sorted(entry["classes"])) if entry["classes"] else "-"
|
||||||
|
)
|
||||||
|
|
||||||
|
items = [
|
||||||
|
QTableWidgetItem(entry.get("image_filename", "")),
|
||||||
|
QTableWidgetItem(model_label),
|
||||||
|
QTableWidgetItem(str(entry.get("count", 0))),
|
||||||
|
QTableWidgetItem(class_list),
|
||||||
|
QTableWidgetItem(str(entry.get("last_detected") or "")),
|
||||||
|
]
|
||||||
|
|
||||||
|
for col, item in enumerate(items):
|
||||||
|
item.setData(Qt.UserRole, row)
|
||||||
|
self.results_table.setItem(row, col, item)
|
||||||
|
|
||||||
|
self.results_table.clearSelection()
|
||||||
|
|
||||||
|
def _on_result_selected(self):
|
||||||
|
"""Handle selection changes in the detection table."""
|
||||||
|
selected_items = self.results_table.selectedItems()
|
||||||
|
if not selected_items:
|
||||||
|
return
|
||||||
|
|
||||||
|
row = selected_items[0].data(Qt.UserRole)
|
||||||
|
if row is None or row >= len(self.detection_summary):
|
||||||
|
return
|
||||||
|
|
||||||
|
entry = self.detection_summary[row]
|
||||||
|
if (
|
||||||
|
self.current_selection
|
||||||
|
and self.current_selection.get("image_id") == entry["image_id"]
|
||||||
|
and self.current_selection.get("model_id") == entry["model_id"]
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
self.current_selection = entry
|
||||||
|
|
||||||
|
image_path = self._resolve_image_path(entry)
|
||||||
|
if not image_path:
|
||||||
|
QMessageBox.warning(
|
||||||
|
self,
|
||||||
|
"Image Not Found",
|
||||||
|
"Unable to locate the image file for this detection.",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
self.current_image = Image(image_path)
|
||||||
|
self.preview_canvas.load_image(self.current_image)
|
||||||
|
except ImageLoadError as e:
|
||||||
|
logger.error(f"Failed to load image '{image_path}': {e}")
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Image Error",
|
||||||
|
f"Failed to load image for preview:\n{str(e)}",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
self._load_detections_for_selection(entry)
|
||||||
|
self._apply_detection_overlays()
|
||||||
|
self._update_summary_label(entry)
|
||||||
|
|
||||||
|
def _load_detections_for_selection(self, entry: Dict):
|
||||||
|
"""Load detection records for the selected image/model pair."""
|
||||||
|
self.current_detections = []
|
||||||
|
if not entry:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
filters = {"image_id": entry["image_id"], "model_id": entry["model_id"]}
|
||||||
|
self.current_detections = self.db_manager.get_detections(filters)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load detections for preview: {e}")
|
||||||
|
QMessageBox.critical(
|
||||||
|
self,
|
||||||
|
"Error",
|
||||||
|
f"Failed to load detections for this image:\n{str(e)}",
|
||||||
|
)
|
||||||
|
self.current_detections = []
|
||||||
|
|
||||||
|
def _apply_detection_overlays(self):
|
||||||
|
"""Draw detections onto the preview canvas based on current toggles."""
|
||||||
|
self.preview_canvas.clear_annotations()
|
||||||
|
self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())
|
||||||
|
|
||||||
|
if not self.current_detections or not self.current_image:
|
||||||
|
return
|
||||||
|
|
||||||
|
for det in self.current_detections:
|
||||||
|
color = self._get_class_color(det.get("class_name"))
|
||||||
|
|
||||||
|
if self.show_masks_checkbox.isChecked() and det.get("segmentation_mask"):
|
||||||
|
mask_points = self._convert_mask(det["segmentation_mask"])
|
||||||
|
if mask_points:
|
||||||
|
self.preview_canvas.draw_saved_polyline(mask_points, color)
|
||||||
|
|
||||||
|
bbox = [
|
||||||
|
det.get("x_min"),
|
||||||
|
det.get("y_min"),
|
||||||
|
det.get("x_max"),
|
||||||
|
det.get("y_max"),
|
||||||
|
]
|
||||||
|
if all(v is not None for v in bbox):
|
||||||
|
label = None
|
||||||
|
if self.show_confidence_checkbox.isChecked():
|
||||||
|
confidence = det.get("confidence")
|
||||||
|
if confidence is not None:
|
||||||
|
label = f"{confidence:.2f}"
|
||||||
|
self.preview_canvas.draw_saved_bbox(bbox, color, label=label)
|
||||||
|
|
||||||
|
def _convert_mask(self, mask_points: List[List[float]]) -> List[List[float]]:
|
||||||
|
"""Convert stored [x, y] masks to [y, x] format for the canvas."""
|
||||||
|
converted = []
|
||||||
|
for point in mask_points:
|
||||||
|
if len(point) >= 2:
|
||||||
|
x, y = point[0], point[1]
|
||||||
|
converted.append([y, x])
|
||||||
|
return converted
|
||||||
|
|
||||||
|
def _toggle_bboxes(self):
|
||||||
|
"""Update bounding box visibility on the canvas."""
|
||||||
|
self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())
|
||||||
|
# Re-render to respect show/hide when toggled
|
||||||
|
self._apply_detection_overlays()
|
||||||
|
|
||||||
|
def _update_summary_label(self, entry: Dict):
|
||||||
|
"""Display textual summary for the selected detection run."""
|
||||||
|
classes = ", ".join(sorted(entry.get("classes", []))) or "-"
|
||||||
|
summary_text = (
|
||||||
|
f"Image: {entry.get('image_filename', 'unknown')}\n"
|
||||||
|
f"Model: {entry.get('model_name', '')} {entry.get('model_version', '')}\n"
|
||||||
|
f"Detections: {entry.get('count', 0)}\n"
|
||||||
|
f"Classes: {classes}\n"
|
||||||
|
f"Last Updated: {entry.get('last_detected', 'n/a')}"
|
||||||
|
)
|
||||||
|
self.summary_label.setText(summary_text)
|
||||||
|
|
||||||
|
def _resolve_image_path(self, entry: Dict) -> Optional[str]:
|
||||||
|
"""Resolve an image path using metadata, cache, and repository hints."""
|
||||||
|
relative_path = entry.get("image_path") if entry else None
|
||||||
|
cache_key = relative_path or entry.get("source_path")
|
||||||
|
if cache_key and cache_key in self._image_path_cache:
|
||||||
|
cached = Path(self._image_path_cache[cache_key])
|
||||||
|
if cached.exists():
|
||||||
|
return self._image_path_cache[cache_key]
|
||||||
|
del self._image_path_cache[cache_key]
|
||||||
|
|
||||||
|
candidates = []
|
||||||
|
source_path = entry.get("source_path") if entry else None
|
||||||
|
if source_path:
|
||||||
|
candidates.append(Path(source_path))
|
||||||
|
|
||||||
|
repo_roots = []
|
||||||
|
if entry.get("repository_root"):
|
||||||
|
repo_roots.append(entry["repository_root"])
|
||||||
|
config_repo = self.config_manager.get_image_repository_path()
|
||||||
|
if config_repo:
|
||||||
|
repo_roots.append(config_repo)
|
||||||
|
|
||||||
|
for root in repo_roots:
|
||||||
|
if relative_path:
|
||||||
|
candidates.append(Path(root) / relative_path)
|
||||||
|
|
||||||
|
if relative_path:
|
||||||
|
candidates.append(Path(relative_path))
|
||||||
|
|
||||||
|
for candidate in candidates:
|
||||||
|
try:
|
||||||
|
if candidate and candidate.exists():
|
||||||
|
resolved = str(candidate.resolve())
|
||||||
|
if cache_key:
|
||||||
|
self._image_path_cache[cache_key] = resolved
|
||||||
|
return resolved
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Fallback: search by filename in known roots
|
||||||
|
filename = Path(relative_path).name if relative_path else None
|
||||||
|
if filename:
|
||||||
|
search_roots = [Path(root) for root in repo_roots if root]
|
||||||
|
if not search_roots:
|
||||||
|
search_roots = [Path("data")]
|
||||||
|
match = self._search_in_roots(filename, search_roots)
|
||||||
|
if match:
|
||||||
|
resolved = str(match.resolve())
|
||||||
|
if cache_key:
|
||||||
|
self._image_path_cache[cache_key] = resolved
|
||||||
|
return resolved
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _search_in_roots(self, filename: str, roots: List[Path]) -> Optional[Path]:
|
||||||
|
"""Search for a file name within a list of root directories."""
|
||||||
|
for root in roots:
|
||||||
|
try:
|
||||||
|
if not root.exists():
|
||||||
|
continue
|
||||||
|
for candidate in root.rglob(filename):
|
||||||
|
return candidate
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Error searching for {filename} in {root}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_class_color(self, class_name: Optional[str]) -> str:
|
||||||
|
"""Return consistent color hex for a class name."""
|
||||||
|
if not class_name:
|
||||||
|
return "#FF6B6B"
|
||||||
|
|
||||||
|
color_map = self.config_manager.get_bbox_colors()
|
||||||
|
if class_name in color_map:
|
||||||
|
return color_map[class_name]
|
||||||
|
|
||||||
|
# Deterministic fallback color based on hash
|
||||||
|
palette = [
|
||||||
|
"#FF6B6B",
|
||||||
|
"#4ECDC4",
|
||||||
|
"#FFD166",
|
||||||
|
"#1D3557",
|
||||||
|
"#F4A261",
|
||||||
|
"#E76F51",
|
||||||
|
]
|
||||||
|
return palette[hash(class_name) % len(palette)]
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from PIL import Image as PILImage
|
|
||||||
from PySide6.QtCore import Qt, QThread, Signal
|
from PySide6.QtCore import Qt, QThread, Signal
|
||||||
from PySide6.QtWidgets import (
|
from PySide6.QtWidgets import (
|
||||||
QWidget,
|
QWidget,
|
||||||
@@ -28,24 +27,20 @@ from PySide6.QtWidgets import (
|
|||||||
QProgressBar,
|
QProgressBar,
|
||||||
QSpinBox,
|
QSpinBox,
|
||||||
QDoubleSpinBox,
|
QDoubleSpinBox,
|
||||||
|
QCheckBox,
|
||||||
|
QScrollArea,
|
||||||
)
|
)
|
||||||
|
|
||||||
from src.database.db_manager import DatabaseManager
|
from src.database.db_manager import DatabaseManager
|
||||||
from src.model.yolo_wrapper import YOLOWrapper
|
from src.model.yolo_wrapper import YOLOWrapper
|
||||||
from src.utils.config_manager import ConfigManager
|
from src.utils.config_manager import ConfigManager
|
||||||
|
from src.utils.image import Image
|
||||||
from src.utils.logger import get_logger
|
from src.utils.logger import get_logger
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
DEFAULT_IMAGE_EXTENSIONS = {
|
DEFAULT_IMAGE_EXTENSIONS = set(Image.SUPPORTED_EXTENSIONS)
|
||||||
".jpg",
|
|
||||||
".jpeg",
|
|
||||||
".png",
|
|
||||||
".tif",
|
|
||||||
".tiff",
|
|
||||||
".bmp",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class TrainingWorker(QThread):
|
class TrainingWorker(QThread):
|
||||||
@@ -67,6 +62,8 @@ class TrainingWorker(QThread):
|
|||||||
save_dir: str,
|
save_dir: str,
|
||||||
run_name: str,
|
run_name: str,
|
||||||
parent: Optional[QThread] = None,
|
parent: Optional[QThread] = None,
|
||||||
|
stage_plan: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
total_epochs: Optional[int] = None,
|
||||||
):
|
):
|
||||||
super().__init__(parent)
|
super().__init__(parent)
|
||||||
self.data_yaml = data_yaml
|
self.data_yaml = data_yaml
|
||||||
@@ -78,6 +75,27 @@ class TrainingWorker(QThread):
|
|||||||
self.lr0 = lr0
|
self.lr0 = lr0
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
self.run_name = run_name
|
self.run_name = run_name
|
||||||
|
self.stage_plan = stage_plan or [
|
||||||
|
{
|
||||||
|
"label": "Single Stage",
|
||||||
|
"model_path": base_model,
|
||||||
|
"use_previous_best": False,
|
||||||
|
"params": {
|
||||||
|
"epochs": epochs,
|
||||||
|
"batch": batch,
|
||||||
|
"imgsz": imgsz,
|
||||||
|
"patience": patience,
|
||||||
|
"lr0": lr0,
|
||||||
|
"freeze": 0,
|
||||||
|
"name": run_name,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
computed_total = sum(
|
||||||
|
max(0, int((stage.get("params") or {}).get("epochs", 0)))
|
||||||
|
for stage in self.stage_plan
|
||||||
|
)
|
||||||
|
self.total_epochs = total_epochs if total_epochs else computed_total or epochs
|
||||||
self._stop_requested = False
|
self._stop_requested = False
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
@@ -86,36 +104,98 @@ class TrainingWorker(QThread):
|
|||||||
self.requestInterruption()
|
self.requestInterruption()
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
"""Execute YOLO training and emit progress/finished signals."""
|
"""Execute YOLO training over one or more stages and emit progress/finished signals."""
|
||||||
wrapper = YOLOWrapper(self.base_model)
|
|
||||||
|
|
||||||
def on_epoch_end(trainer):
|
completed_epochs = 0
|
||||||
current_epoch = getattr(trainer, "epoch", 0) + 1
|
stage_history: List[Dict[str, Any]] = []
|
||||||
metrics: Dict[str, float] = {}
|
last_stage_results: Optional[Dict[str, Any]] = None
|
||||||
loss_items = getattr(trainer, "loss_items", None)
|
|
||||||
if loss_items:
|
|
||||||
metrics["loss"] = float(loss_items[-1])
|
|
||||||
self.progress.emit(current_epoch, self.epochs, metrics)
|
|
||||||
if self.isInterruptionRequested() or self._stop_requested:
|
|
||||||
setattr(trainer, "stop_training", True)
|
|
||||||
|
|
||||||
callbacks = {"on_fit_epoch_end": on_epoch_end}
|
for stage_index, stage in enumerate(self.stage_plan, start=1):
|
||||||
|
if self._stop_requested or self.isInterruptionRequested():
|
||||||
|
break
|
||||||
|
|
||||||
try:
|
stage_label = stage.get("label") or f"Stage {stage_index}"
|
||||||
results = wrapper.train(
|
stage_params = dict(stage.get("params") or {})
|
||||||
data_yaml=self.data_yaml,
|
stage_epochs = int(stage_params.get("epochs", self.epochs))
|
||||||
epochs=self.epochs,
|
if stage_epochs <= 0:
|
||||||
imgsz=self.imgsz,
|
stage_epochs = 1
|
||||||
batch=self.batch,
|
batch = int(stage_params.get("batch", self.batch))
|
||||||
patience=self.patience,
|
imgsz = int(stage_params.get("imgsz", self.imgsz))
|
||||||
save_dir=self.save_dir,
|
patience = int(stage_params.get("patience", self.patience))
|
||||||
name=self.run_name,
|
lr0 = float(stage_params.get("lr0", self.lr0))
|
||||||
lr0=self.lr0,
|
freeze = int(stage_params.get("freeze", 0))
|
||||||
callbacks=callbacks,
|
run_name = stage_params.get("name") or f"{self.run_name}_stage{stage_index}"
|
||||||
|
|
||||||
|
weights_path = stage.get("model_path") or self.base_model
|
||||||
|
if stage.get("use_previous_best") and last_stage_results:
|
||||||
|
weights_path = (
|
||||||
|
last_stage_results.get("best_model_path")
|
||||||
|
or last_stage_results.get("last_model_path")
|
||||||
|
or weights_path
|
||||||
|
)
|
||||||
|
|
||||||
|
wrapper = YOLOWrapper(weights_path)
|
||||||
|
stage_offset = completed_epochs
|
||||||
|
|
||||||
|
def on_epoch_end(trainer, offset=stage_offset):
|
||||||
|
current_epoch = getattr(trainer, "epoch", 0) + 1
|
||||||
|
metrics: Dict[str, float] = {}
|
||||||
|
loss_items = getattr(trainer, "loss_items", None)
|
||||||
|
if loss_items:
|
||||||
|
metrics["loss"] = float(loss_items[-1])
|
||||||
|
absolute_epoch = min(
|
||||||
|
max(1, offset + current_epoch),
|
||||||
|
max(1, self.total_epochs),
|
||||||
|
)
|
||||||
|
self.progress.emit(absolute_epoch, self.total_epochs, metrics)
|
||||||
|
if self.isInterruptionRequested() or self._stop_requested:
|
||||||
|
setattr(trainer, "stop_training", True)
|
||||||
|
|
||||||
|
callbacks = {"on_fit_epoch_end": on_epoch_end}
|
||||||
|
|
||||||
|
try:
|
||||||
|
stage_result = wrapper.train(
|
||||||
|
data_yaml=self.data_yaml,
|
||||||
|
epochs=stage_epochs,
|
||||||
|
imgsz=imgsz,
|
||||||
|
batch=batch,
|
||||||
|
patience=patience,
|
||||||
|
save_dir=self.save_dir,
|
||||||
|
name=run_name,
|
||||||
|
lr0=lr0,
|
||||||
|
callbacks=callbacks,
|
||||||
|
freeze=freeze,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
self.error.emit(str(exc))
|
||||||
|
return
|
||||||
|
|
||||||
|
stage_history.append(
|
||||||
|
{
|
||||||
|
"label": stage_label,
|
||||||
|
"params": stage_params,
|
||||||
|
"weights_used": weights_path,
|
||||||
|
"results": stage_result,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
self.finished.emit(results)
|
last_stage_results = stage_result
|
||||||
except Exception as exc:
|
completed_epochs += stage_epochs
|
||||||
self.error.emit(str(exc))
|
|
||||||
|
final_payload: Dict[str, Any]
|
||||||
|
if last_stage_results:
|
||||||
|
final_payload = dict(last_stage_results)
|
||||||
|
else:
|
||||||
|
final_payload = {
|
||||||
|
"success": False,
|
||||||
|
"message": "Training stopped before any stage completed.",
|
||||||
|
}
|
||||||
|
|
||||||
|
final_payload["stage_results"] = stage_history
|
||||||
|
final_payload["total_epochs_completed"] = completed_epochs
|
||||||
|
final_payload["total_epochs_planned"] = self.total_epochs
|
||||||
|
final_payload["stages_completed"] = len(stage_history)
|
||||||
|
|
||||||
|
self.finished.emit(final_payload)
|
||||||
|
|
||||||
|
|
||||||
class TrainingTab(QWidget):
|
class TrainingTab(QWidget):
|
||||||
@@ -146,12 +226,23 @@ class TrainingTab(QWidget):
|
|||||||
|
|
||||||
def _setup_ui(self):
|
def _setup_ui(self):
|
||||||
"""Setup user interface."""
|
"""Setup user interface."""
|
||||||
layout = QVBoxLayout()
|
# Create a container widget for all content
|
||||||
|
container = QWidget()
|
||||||
|
container_layout = QVBoxLayout(container)
|
||||||
|
|
||||||
layout.addWidget(self._create_dataset_group())
|
container_layout.addWidget(self._create_dataset_group())
|
||||||
layout.addWidget(self._create_training_controls_group())
|
container_layout.addWidget(self._create_training_controls_group())
|
||||||
layout.addStretch()
|
container_layout.addStretch()
|
||||||
self.setLayout(layout)
|
|
||||||
|
# Create scroll area and set the container as its widget
|
||||||
|
scroll_area = QScrollArea()
|
||||||
|
scroll_area.setWidget(container)
|
||||||
|
scroll_area.setWidgetResizable(True)
|
||||||
|
|
||||||
|
# Set main layout with scroll area
|
||||||
|
main_layout = QVBoxLayout(self)
|
||||||
|
main_layout.setContentsMargins(0, 0, 0, 0)
|
||||||
|
main_layout.addWidget(scroll_area)
|
||||||
|
|
||||||
self._discover_datasets()
|
self._discover_datasets()
|
||||||
self._load_saved_dataset()
|
self._load_saved_dataset()
|
||||||
@@ -249,13 +340,26 @@ class TrainingTab(QWidget):
|
|||||||
default_base_model = self.config_manager.get(
|
default_base_model = self.config_manager.get(
|
||||||
"models.default_base_model", "yolov8s-seg.pt"
|
"models.default_base_model", "yolov8s-seg.pt"
|
||||||
)
|
)
|
||||||
|
base_model_choices = self.config_manager.get("models.base_model_choices", [])
|
||||||
|
|
||||||
|
self.base_model_combo = QComboBox()
|
||||||
|
self.base_model_combo.addItem("Custom path…", "")
|
||||||
|
for choice in base_model_choices:
|
||||||
|
self.base_model_combo.addItem(choice, choice)
|
||||||
|
self.base_model_combo.currentIndexChanged.connect(
|
||||||
|
self._on_base_model_preset_changed
|
||||||
|
)
|
||||||
|
form_layout.addRow("Base Model Preset:", self.base_model_combo)
|
||||||
|
|
||||||
base_model_layout = QHBoxLayout()
|
base_model_layout = QHBoxLayout()
|
||||||
self.base_model_edit = QLineEdit(default_base_model)
|
self.base_model_edit = QLineEdit(default_base_model)
|
||||||
|
self.base_model_edit.editingFinished.connect(self._on_base_model_path_edited)
|
||||||
base_model_layout.addWidget(self.base_model_edit)
|
base_model_layout.addWidget(self.base_model_edit)
|
||||||
self.base_model_browse_button = QPushButton("Browse…")
|
self.base_model_browse_button = QPushButton("Browse…")
|
||||||
self.base_model_browse_button.clicked.connect(self._browse_base_model)
|
self.base_model_browse_button.clicked.connect(self._browse_base_model)
|
||||||
base_model_layout.addWidget(self.base_model_browse_button)
|
base_model_layout.addWidget(self.base_model_browse_button)
|
||||||
form_layout.addRow("Base Model (.pt):", base_model_layout)
|
form_layout.addRow("Base Model (.pt):", base_model_layout)
|
||||||
|
self._sync_base_model_preset_selection(default_base_model)
|
||||||
|
|
||||||
models_dir = self.config_manager.get("models.models_directory", "data/models")
|
models_dir = self.config_manager.get("models.models_directory", "data/models")
|
||||||
save_dir_layout = QHBoxLayout()
|
save_dir_layout = QHBoxLayout()
|
||||||
@@ -298,6 +402,9 @@ class TrainingTab(QWidget):
|
|||||||
|
|
||||||
group_layout.addLayout(form_layout)
|
group_layout.addLayout(form_layout)
|
||||||
|
|
||||||
|
self.two_stage_group = self._create_two_stage_group(training_defaults)
|
||||||
|
group_layout.addWidget(self.two_stage_group)
|
||||||
|
|
||||||
button_layout = QHBoxLayout()
|
button_layout = QHBoxLayout()
|
||||||
self.start_training_button = QPushButton("Start Training")
|
self.start_training_button = QPushButton("Start Training")
|
||||||
self.start_training_button.clicked.connect(self._start_training)
|
self.start_training_button.clicked.connect(self._start_training)
|
||||||
@@ -322,6 +429,134 @@ class TrainingTab(QWidget):
|
|||||||
group.setLayout(group_layout)
|
group.setLayout(group_layout)
|
||||||
return group
|
return group
|
||||||
|
|
||||||
|
def _create_two_stage_group(self, training_defaults: Dict[str, Any]) -> QGroupBox:
|
||||||
|
group = QGroupBox("Two-Stage Fine-Tuning")
|
||||||
|
group_layout = QVBoxLayout()
|
||||||
|
|
||||||
|
self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune")
|
||||||
|
two_stage_defaults = (
|
||||||
|
training_defaults.get("two_stage", {}) if training_defaults else {}
|
||||||
|
)
|
||||||
|
self.two_stage_checkbox.setChecked(
|
||||||
|
bool(two_stage_defaults.get("enabled", False))
|
||||||
|
)
|
||||||
|
self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled)
|
||||||
|
group_layout.addWidget(self.two_stage_checkbox)
|
||||||
|
|
||||||
|
self.two_stage_controls_widget = QWidget()
|
||||||
|
controls_layout = QVBoxLayout()
|
||||||
|
controls_layout.setContentsMargins(0, 0, 0, 0)
|
||||||
|
controls_layout.setSpacing(8)
|
||||||
|
|
||||||
|
stage1_group = QGroupBox("Stage 1 — Head-only stabilization")
|
||||||
|
stage1_form = QFormLayout()
|
||||||
|
stage1_defaults = two_stage_defaults.get("stage1", {})
|
||||||
|
|
||||||
|
self.stage1_epochs_spin = QSpinBox()
|
||||||
|
self.stage1_epochs_spin.setRange(1, 500)
|
||||||
|
self.stage1_epochs_spin.setValue(int(stage1_defaults.get("epochs", 20)))
|
||||||
|
stage1_form.addRow("Epochs:", self.stage1_epochs_spin)
|
||||||
|
|
||||||
|
self.stage1_lr_spin = QDoubleSpinBox()
|
||||||
|
self.stage1_lr_spin.setDecimals(5)
|
||||||
|
self.stage1_lr_spin.setRange(0.00001, 0.1)
|
||||||
|
self.stage1_lr_spin.setSingleStep(0.0005)
|
||||||
|
self.stage1_lr_spin.setValue(float(stage1_defaults.get("lr0", 0.0005)))
|
||||||
|
stage1_form.addRow("Learning Rate:", self.stage1_lr_spin)
|
||||||
|
|
||||||
|
self.stage1_patience_spin = QSpinBox()
|
||||||
|
self.stage1_patience_spin.setRange(1, 200)
|
||||||
|
self.stage1_patience_spin.setValue(int(stage1_defaults.get("patience", 10)))
|
||||||
|
stage1_form.addRow("Patience:", self.stage1_patience_spin)
|
||||||
|
|
||||||
|
self.stage1_freeze_spin = QSpinBox()
|
||||||
|
self.stage1_freeze_spin.setRange(0, 24)
|
||||||
|
self.stage1_freeze_spin.setValue(int(stage1_defaults.get("freeze", 10)))
|
||||||
|
stage1_form.addRow("Freeze layers:", self.stage1_freeze_spin)
|
||||||
|
|
||||||
|
stage1_group.setLayout(stage1_form)
|
||||||
|
controls_layout.addWidget(stage1_group)
|
||||||
|
|
||||||
|
stage2_group = QGroupBox("Stage 2 — Full fine-tuning")
|
||||||
|
stage2_form = QFormLayout()
|
||||||
|
stage2_defaults = two_stage_defaults.get("stage2", {})
|
||||||
|
|
||||||
|
self.stage2_epochs_spin = QSpinBox()
|
||||||
|
self.stage2_epochs_spin.setRange(1, 2000)
|
||||||
|
self.stage2_epochs_spin.setValue(int(stage2_defaults.get("epochs", 150)))
|
||||||
|
stage2_form.addRow("Epochs:", self.stage2_epochs_spin)
|
||||||
|
|
||||||
|
self.stage2_lr_spin = QDoubleSpinBox()
|
||||||
|
self.stage2_lr_spin.setDecimals(5)
|
||||||
|
self.stage2_lr_spin.setRange(0.00001, 0.1)
|
||||||
|
self.stage2_lr_spin.setSingleStep(0.0005)
|
||||||
|
self.stage2_lr_spin.setValue(float(stage2_defaults.get("lr0", 0.0003)))
|
||||||
|
stage2_form.addRow("Learning Rate:", self.stage2_lr_spin)
|
||||||
|
|
||||||
|
self.stage2_patience_spin = QSpinBox()
|
||||||
|
self.stage2_patience_spin.setRange(1, 200)
|
||||||
|
self.stage2_patience_spin.setValue(int(stage2_defaults.get("patience", 30)))
|
||||||
|
stage2_form.addRow("Patience:", self.stage2_patience_spin)
|
||||||
|
|
||||||
|
stage2_group.setLayout(stage2_form)
|
||||||
|
controls_layout.addWidget(stage2_group)
|
||||||
|
|
||||||
|
helper_label = QLabel(
|
||||||
|
"When enabled, staged hyperparameters override the global epochs/patience/lr."
|
||||||
|
)
|
||||||
|
helper_label.setWordWrap(True)
|
||||||
|
controls_layout.addWidget(helper_label)
|
||||||
|
|
||||||
|
self.two_stage_controls_widget.setLayout(controls_layout)
|
||||||
|
group_layout.addWidget(self.two_stage_controls_widget)
|
||||||
|
|
||||||
|
group.setLayout(group_layout)
|
||||||
|
self._on_two_stage_toggled(self.two_stage_checkbox.isChecked())
|
||||||
|
return group
|
||||||
|
|
||||||
|
def _on_two_stage_toggled(self, checked: bool):
|
||||||
|
self._refresh_two_stage_controls_enabled(checked)
|
||||||
|
|
||||||
|
def _refresh_two_stage_controls_enabled(self, checked: Optional[bool] = None):
|
||||||
|
if not hasattr(self, "two_stage_controls_widget"):
|
||||||
|
return
|
||||||
|
desired_state = checked
|
||||||
|
if desired_state is None:
|
||||||
|
desired_state = self.two_stage_checkbox.isChecked()
|
||||||
|
can_edit = self.two_stage_checkbox.isEnabled()
|
||||||
|
self.two_stage_controls_widget.setEnabled(bool(desired_state and can_edit))
|
||||||
|
|
||||||
|
def _on_base_model_preset_changed(self, index: int):
|
||||||
|
preset_value = self.base_model_combo.itemData(index)
|
||||||
|
if preset_value:
|
||||||
|
self.base_model_edit.setText(str(preset_value))
|
||||||
|
elif index == 0:
|
||||||
|
self.base_model_edit.setFocus()
|
||||||
|
|
||||||
|
def _on_base_model_path_edited(self):
|
||||||
|
self._sync_base_model_preset_selection(self.base_model_edit.text().strip())
|
||||||
|
|
||||||
|
def _sync_base_model_preset_selection(self, model_path: str):
|
||||||
|
if not hasattr(self, "base_model_combo"):
|
||||||
|
return
|
||||||
|
normalized = (model_path or "").strip()
|
||||||
|
target_index = 0
|
||||||
|
for idx in range(1, self.base_model_combo.count()):
|
||||||
|
preset_value = self.base_model_combo.itemData(idx)
|
||||||
|
if not preset_value:
|
||||||
|
continue
|
||||||
|
if normalized == preset_value:
|
||||||
|
target_index = idx
|
||||||
|
break
|
||||||
|
if normalized.endswith(f"/{preset_value}") or normalized.endswith(
|
||||||
|
f"\\{preset_value}"
|
||||||
|
):
|
||||||
|
target_index = idx
|
||||||
|
break
|
||||||
|
self.base_model_combo.blockSignals(True)
|
||||||
|
self.base_model_combo.setCurrentIndex(target_index)
|
||||||
|
self.base_model_combo.blockSignals(False)
|
||||||
|
|
||||||
def _get_dataset_search_roots(self) -> List[Path]:
|
def _get_dataset_search_roots(self) -> List[Path]:
|
||||||
roots: List[Path] = []
|
roots: List[Path] = []
|
||||||
default_root = Path("data/datasets").expanduser()
|
default_root = Path("data/datasets").expanduser()
|
||||||
@@ -346,6 +581,7 @@ class TrainingTab(QWidget):
|
|||||||
for yaml_path in root.rglob("*.yaml"):
|
for yaml_path in root.rglob("*.yaml"):
|
||||||
if yaml_path.name.lower() not in {"data.yaml", "dataset.yaml"}:
|
if yaml_path.name.lower() not in {"data.yaml", "dataset.yaml"}:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
discovered.append(yaml_path.resolve())
|
discovered.append(yaml_path.resolve())
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning(f"Unable to scan {root}: {exc}")
|
logger.warning(f"Unable to scan {root}: {exc}")
|
||||||
@@ -964,6 +1200,90 @@ class TrainingTab(QWidget):
|
|||||||
self._build_rgb_dataset(cache_root, dataset_info)
|
self._build_rgb_dataset(cache_root, dataset_info)
|
||||||
return rgb_yaml
|
return rgb_yaml
|
||||||
|
|
||||||
|
def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||||
|
two_stage = params.get("two_stage") or {}
|
||||||
|
base_stage = {
|
||||||
|
"label": "Single Stage",
|
||||||
|
"model_path": params["base_model"],
|
||||||
|
"use_previous_best": False,
|
||||||
|
"params": {
|
||||||
|
"epochs": params["epochs"],
|
||||||
|
"batch": params["batch"],
|
||||||
|
"imgsz": params["imgsz"],
|
||||||
|
"patience": params["patience"],
|
||||||
|
"lr0": params["lr0"],
|
||||||
|
"freeze": 0,
|
||||||
|
"name": params["run_name"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if not two_stage.get("enabled"):
|
||||||
|
return [base_stage]
|
||||||
|
|
||||||
|
stage_plan: List[Dict[str, Any]] = []
|
||||||
|
stage1 = two_stage.get("stage1", {})
|
||||||
|
stage2 = two_stage.get("stage2", {})
|
||||||
|
|
||||||
|
stage_plan.append(
|
||||||
|
{
|
||||||
|
"label": "Stage 1 — Head-only",
|
||||||
|
"model_path": params["base_model"],
|
||||||
|
"use_previous_best": False,
|
||||||
|
"params": {
|
||||||
|
"epochs": stage1.get("epochs", params["epochs"]),
|
||||||
|
"batch": params["batch"],
|
||||||
|
"imgsz": params["imgsz"],
|
||||||
|
"patience": stage1.get("patience", params["patience"]),
|
||||||
|
"lr0": stage1.get("lr0", params["lr0"]),
|
||||||
|
"freeze": stage1.get("freeze", 0),
|
||||||
|
"name": f"{params['run_name']}_head_ft",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
stage_plan.append(
|
||||||
|
{
|
||||||
|
"label": "Stage 2 — Full",
|
||||||
|
"model_path": params["base_model"],
|
||||||
|
"use_previous_best": True,
|
||||||
|
"params": {
|
||||||
|
"epochs": stage2.get("epochs", params["epochs"]),
|
||||||
|
"batch": params["batch"],
|
||||||
|
"imgsz": params["imgsz"],
|
||||||
|
"patience": stage2.get("patience", params["patience"]),
|
||||||
|
"lr0": stage2.get("lr0", params["lr0"]),
|
||||||
|
"freeze": stage2.get("freeze", 0),
|
||||||
|
"name": f"{params['run_name']}_full_ft",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return stage_plan
|
||||||
|
|
||||||
|
def _calculate_total_stage_epochs(self, stage_plan: List[Dict[str, Any]]) -> int:
|
||||||
|
total = 0
|
||||||
|
for stage in stage_plan:
|
||||||
|
params = stage.get("params") or {}
|
||||||
|
try:
|
||||||
|
stage_epochs = int(params.get("epochs", 0))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
stage_epochs = 0
|
||||||
|
if stage_epochs > 0:
|
||||||
|
total += stage_epochs
|
||||||
|
return total
|
||||||
|
|
||||||
|
def _log_stage_plan(self, stage_plan: List[Dict[str, Any]]):
|
||||||
|
for index, stage in enumerate(stage_plan, start=1):
|
||||||
|
stage_label = stage.get("label") or f"Stage {index}"
|
||||||
|
params = stage.get("params") or {}
|
||||||
|
epochs = params.get("epochs", "?")
|
||||||
|
lr0 = params.get("lr0", "?")
|
||||||
|
patience = params.get("patience", "?")
|
||||||
|
freeze = params.get("freeze", 0)
|
||||||
|
self._append_training_log(
|
||||||
|
f" • {stage_label}: epochs={epochs}, lr0={lr0}, patience={patience}, freeze={freeze}"
|
||||||
|
)
|
||||||
|
|
||||||
def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
|
def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
|
||||||
cache_base = Path("data/datasets/_rgb_cache")
|
cache_base = Path("data/datasets/_rgb_cache")
|
||||||
cache_base.mkdir(parents=True, exist_ok=True)
|
cache_base.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -984,8 +1304,8 @@ class TrainingTab(QWidget):
|
|||||||
if not sample_image:
|
if not sample_image:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
with PILImage.open(sample_image) as img:
|
img = Image(sample_image)
|
||||||
return img.mode.upper() != "RGB"
|
return img.pil_image.mode.upper() != "RGB"
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning(f"Failed to inspect image {sample_image}: {exc}")
|
logger.warning(f"Failed to inspect image {sample_image}: {exc}")
|
||||||
return False
|
return False
|
||||||
@@ -1045,9 +1365,13 @@ class TrainingTab(QWidget):
|
|||||||
dst = dst_dir / relative
|
dst = dst_dir / relative
|
||||||
dst.parent.mkdir(parents=True, exist_ok=True)
|
dst.parent.mkdir(parents=True, exist_ok=True)
|
||||||
try:
|
try:
|
||||||
with PILImage.open(src) as img:
|
img_obj = Image(src)
|
||||||
rgb_img = img.convert("RGB")
|
pil_img = img_obj.pil_image
|
||||||
rgb_img.save(dst)
|
if len(pil_img.getbands()) == 1:
|
||||||
|
rgb_img = img_obj.convert_grayscale_to_rgb_preserve_range()
|
||||||
|
else:
|
||||||
|
rgb_img = pil_img.convert("RGB")
|
||||||
|
rgb_img.save(dst)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning(f"Failed to convert {src} to RGB: {exc}")
|
logger.warning(f"Failed to convert {src} to RGB: {exc}")
|
||||||
|
|
||||||
@@ -1085,6 +1409,21 @@ class TrainingTab(QWidget):
|
|||||||
save_dir_path.mkdir(parents=True, exist_ok=True)
|
save_dir_path.mkdir(parents=True, exist_ok=True)
|
||||||
run_name = f"{model_name}_{model_version}".replace(" ", "_")
|
run_name = f"{model_name}_{model_version}".replace(" ", "_")
|
||||||
|
|
||||||
|
two_stage_config = {
|
||||||
|
"enabled": self.two_stage_checkbox.isChecked(),
|
||||||
|
"stage1": {
|
||||||
|
"epochs": self.stage1_epochs_spin.value(),
|
||||||
|
"lr0": self.stage1_lr_spin.value(),
|
||||||
|
"patience": self.stage1_patience_spin.value(),
|
||||||
|
"freeze": self.stage1_freeze_spin.value(),
|
||||||
|
},
|
||||||
|
"stage2": {
|
||||||
|
"epochs": self.stage2_epochs_spin.value(),
|
||||||
|
"lr0": self.stage2_lr_spin.value(),
|
||||||
|
"patience": self.stage2_patience_spin.value(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"model_name": model_name,
|
"model_name": model_name,
|
||||||
"model_version": model_version,
|
"model_version": model_version,
|
||||||
@@ -1096,6 +1435,7 @@ class TrainingTab(QWidget):
|
|||||||
"imgsz": self.imgsz_spin.value(),
|
"imgsz": self.imgsz_spin.value(),
|
||||||
"patience": self.patience_spin.value(),
|
"patience": self.patience_spin.value(),
|
||||||
"lr0": self.lr_spin.value(),
|
"lr0": self.lr_spin.value(),
|
||||||
|
"two_stage": two_stage_config,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _start_training(self):
|
def _start_training(self):
|
||||||
@@ -1137,15 +1477,25 @@ class TrainingTab(QWidget):
|
|||||||
)
|
)
|
||||||
|
|
||||||
params = self._collect_training_params()
|
params = self._collect_training_params()
|
||||||
|
stage_plan = self._compose_stage_plan(params)
|
||||||
|
params["stage_plan"] = stage_plan
|
||||||
|
total_planned_epochs = (
|
||||||
|
self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
|
||||||
|
)
|
||||||
|
params["total_planned_epochs"] = total_planned_epochs
|
||||||
self._active_training_params = params
|
self._active_training_params = params
|
||||||
self._training_cancelled = False
|
self._training_cancelled = False
|
||||||
|
|
||||||
|
if len(stage_plan) > 1:
|
||||||
|
self._append_training_log("Two-stage fine-tuning schedule:")
|
||||||
|
self._log_stage_plan(stage_plan)
|
||||||
|
|
||||||
self._append_training_log(
|
self._append_training_log(
|
||||||
f"Starting training run '{params['run_name']}' using {params['base_model']}"
|
f"Starting training run '{params['run_name']}' using {params['base_model']}"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.training_progress_bar.setVisible(True)
|
self.training_progress_bar.setVisible(True)
|
||||||
self.training_progress_bar.setMaximum(params["epochs"])
|
self.training_progress_bar.setMaximum(max(1, total_planned_epochs))
|
||||||
self.training_progress_bar.setValue(0)
|
self.training_progress_bar.setValue(0)
|
||||||
self._set_training_state(True)
|
self._set_training_state(True)
|
||||||
|
|
||||||
@@ -1159,6 +1509,8 @@ class TrainingTab(QWidget):
|
|||||||
lr0=params["lr0"],
|
lr0=params["lr0"],
|
||||||
save_dir=params["save_dir"],
|
save_dir=params["save_dir"],
|
||||||
run_name=params["run_name"],
|
run_name=params["run_name"],
|
||||||
|
stage_plan=stage_plan,
|
||||||
|
total_epochs=total_planned_epochs,
|
||||||
)
|
)
|
||||||
self.training_worker.progress.connect(self._on_training_progress)
|
self.training_worker.progress.connect(self._on_training_progress)
|
||||||
self.training_worker.finished.connect(self._on_training_finished)
|
self.training_worker.finished.connect(self._on_training_finished)
|
||||||
@@ -1283,14 +1635,22 @@ class TrainingTab(QWidget):
|
|||||||
if not model_path:
|
if not model_path:
|
||||||
raise ValueError("Training results did not include a model path.")
|
raise ValueError("Training results did not include a model path.")
|
||||||
|
|
||||||
|
effective_epochs = params.get("total_planned_epochs", params["epochs"])
|
||||||
training_params = {
|
training_params = {
|
||||||
"epochs": params["epochs"],
|
"epochs": effective_epochs,
|
||||||
"batch": params["batch"],
|
"batch": params["batch"],
|
||||||
"imgsz": params["imgsz"],
|
"imgsz": params["imgsz"],
|
||||||
"patience": params["patience"],
|
"patience": params["patience"],
|
||||||
"lr0": params["lr0"],
|
"lr0": params["lr0"],
|
||||||
"run_name": params["run_name"],
|
"run_name": params["run_name"],
|
||||||
|
"two_stage": params.get("two_stage"),
|
||||||
}
|
}
|
||||||
|
if params.get("stage_plan"):
|
||||||
|
training_params["stage_plan"] = params["stage_plan"]
|
||||||
|
if results.get("stage_results"):
|
||||||
|
training_params["stage_results"] = results["stage_results"]
|
||||||
|
if results.get("total_epochs_completed") is not None:
|
||||||
|
training_params["epochs_completed"] = results["total_epochs_completed"]
|
||||||
|
|
||||||
model_id = self.db_manager.add_model(
|
model_id = self.db_manager.add_model(
|
||||||
model_name=params["model_name"],
|
model_name=params["model_name"],
|
||||||
@@ -1315,6 +1675,7 @@ class TrainingTab(QWidget):
|
|||||||
self.rescan_button.setEnabled(not is_training)
|
self.rescan_button.setEnabled(not is_training)
|
||||||
self.model_name_edit.setEnabled(not is_training)
|
self.model_name_edit.setEnabled(not is_training)
|
||||||
self.model_version_edit.setEnabled(not is_training)
|
self.model_version_edit.setEnabled(not is_training)
|
||||||
|
self.base_model_combo.setEnabled(not is_training)
|
||||||
self.base_model_edit.setEnabled(not is_training)
|
self.base_model_edit.setEnabled(not is_training)
|
||||||
self.base_model_browse_button.setEnabled(not is_training)
|
self.base_model_browse_button.setEnabled(not is_training)
|
||||||
self.save_dir_edit.setEnabled(not is_training)
|
self.save_dir_edit.setEnabled(not is_training)
|
||||||
@@ -1324,6 +1685,8 @@ class TrainingTab(QWidget):
|
|||||||
self.imgsz_spin.setEnabled(not is_training)
|
self.imgsz_spin.setEnabled(not is_training)
|
||||||
self.patience_spin.setEnabled(not is_training)
|
self.patience_spin.setEnabled(not is_training)
|
||||||
self.lr_spin.setEnabled(not is_training)
|
self.lr_spin.setEnabled(not is_training)
|
||||||
|
self.two_stage_checkbox.setEnabled(not is_training)
|
||||||
|
self._refresh_two_stage_controls_enabled()
|
||||||
|
|
||||||
def _append_training_log(self, message: str):
|
def _append_training_log(self, message: str):
|
||||||
timestamp = datetime.now().strftime("%H:%M:%S")
|
timestamp = datetime.now().strftime("%H:%M:%S")
|
||||||
@@ -1339,6 +1702,7 @@ class TrainingTab(QWidget):
|
|||||||
)
|
)
|
||||||
if file_path:
|
if file_path:
|
||||||
self.base_model_edit.setText(file_path)
|
self.base_model_edit.setText(file_path)
|
||||||
|
self._sync_base_model_preset_selection(file_path)
|
||||||
|
|
||||||
def _browse_save_dir(self):
|
def _browse_save_dir(self):
|
||||||
start_path = self.save_dir_edit.text().strip() or "data/models"
|
start_path = self.save_dir_edit.text().strip() or "data/models"
|
||||||
|
|||||||
@@ -16,8 +16,9 @@ from PySide6.QtGui import (
|
|||||||
QKeyEvent,
|
QKeyEvent,
|
||||||
QMouseEvent,
|
QMouseEvent,
|
||||||
QPaintEvent,
|
QPaintEvent,
|
||||||
|
QPolygonF,
|
||||||
)
|
)
|
||||||
from PySide6.QtCore import Qt, QEvent, Signal, QPoint
|
from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QPointF, QRect
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from src.utils.image import Image, ImageLoadError
|
from src.utils.image import Image, ImageLoadError
|
||||||
@@ -246,10 +247,10 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get RGB image data
|
# Get image data in a format compatible with Qt
|
||||||
if self.current_image.channels == 3:
|
if self.current_image.channels in (3, 4):
|
||||||
image_data = self.current_image.get_rgb()
|
image_data = self.current_image.get_rgb()
|
||||||
height, width, channels = image_data.shape
|
height, width = image_data.shape[:2]
|
||||||
else:
|
else:
|
||||||
image_data = self.current_image.get_grayscale()
|
image_data = self.current_image.get_grayscale()
|
||||||
height, width = image_data.shape
|
height, width = image_data.shape
|
||||||
@@ -263,7 +264,7 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
height,
|
height,
|
||||||
bytes_per_line,
|
bytes_per_line,
|
||||||
self.current_image.qtimage_format,
|
self.current_image.qtimage_format,
|
||||||
)
|
).copy() # Copy so Qt owns the buffer even after numpy array goes out of scope
|
||||||
|
|
||||||
self.original_pixmap = QPixmap.fromImage(qimage)
|
self.original_pixmap = QPixmap.fromImage(qimage)
|
||||||
|
|
||||||
@@ -496,8 +497,10 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
)
|
)
|
||||||
|
|
||||||
painter.setPen(pen)
|
painter.setPen(pen)
|
||||||
for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]):
|
# Use QPolygonF for efficient polygon rendering (single call vs N-1 calls)
|
||||||
painter.drawLine(int(x1), int(y1), int(x2), int(y2))
|
# drawPolygon() automatically closes the shape, ensuring proper visual closure
|
||||||
|
polygon = QPolygonF([QPointF(x, y) for x, y in polyline])
|
||||||
|
painter.drawPolygon(polygon)
|
||||||
|
|
||||||
# Draw bounding boxes (dashed) if enabled
|
# Draw bounding boxes (dashed) if enabled
|
||||||
if self.show_bboxes and self.original_pixmap is not None and self.bboxes:
|
if self.show_bboxes and self.original_pixmap is not None and self.bboxes:
|
||||||
@@ -529,6 +532,40 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
painter.setPen(pen)
|
painter.setPen(pen)
|
||||||
painter.drawRect(x_min, y_min, rect_width, rect_height)
|
painter.drawRect(x_min, y_min, rect_width, rect_height)
|
||||||
|
|
||||||
|
label_text = meta.get("label")
|
||||||
|
if label_text:
|
||||||
|
painter.save()
|
||||||
|
font = painter.font()
|
||||||
|
font.setPointSizeF(max(10.0, width + 4))
|
||||||
|
painter.setFont(font)
|
||||||
|
metrics = painter.fontMetrics()
|
||||||
|
text_width = metrics.horizontalAdvance(label_text)
|
||||||
|
text_height = metrics.height()
|
||||||
|
padding = 4
|
||||||
|
bg_width = text_width + padding * 2
|
||||||
|
bg_height = text_height + padding * 2
|
||||||
|
canvas_width = self.original_pixmap.width()
|
||||||
|
canvas_height = self.original_pixmap.height()
|
||||||
|
bg_x = max(0, min(x_min, canvas_width - bg_width))
|
||||||
|
bg_y = y_min - bg_height
|
||||||
|
if bg_y < 0:
|
||||||
|
bg_y = min(y_min, canvas_height - bg_height)
|
||||||
|
bg_y = max(0, bg_y)
|
||||||
|
background_rect = QRect(bg_x, bg_y, bg_width, bg_height)
|
||||||
|
background_color = QColor(pen_color)
|
||||||
|
background_color.setAlpha(220)
|
||||||
|
painter.fillRect(background_rect, background_color)
|
||||||
|
text_color = QColor(0, 0, 0)
|
||||||
|
if background_color.lightness() < 128:
|
||||||
|
text_color = QColor(255, 255, 255)
|
||||||
|
painter.setPen(text_color)
|
||||||
|
painter.drawText(
|
||||||
|
background_rect.adjusted(padding, padding, -padding, -padding),
|
||||||
|
Qt.AlignLeft | Qt.AlignVCenter,
|
||||||
|
label_text,
|
||||||
|
)
|
||||||
|
painter.restore()
|
||||||
|
|
||||||
painter.end()
|
painter.end()
|
||||||
|
|
||||||
self._update_display()
|
self._update_display()
|
||||||
@@ -787,7 +824,13 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
f"Drew saved polyline with {len(polyline)} points in color {color}"
|
f"Drew saved polyline with {len(polyline)} points in color {color}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def draw_saved_bbox(self, bbox: List[float], color: str, width: int = 3):
|
def draw_saved_bbox(
|
||||||
|
self,
|
||||||
|
bbox: List[float],
|
||||||
|
color: str,
|
||||||
|
width: int = 3,
|
||||||
|
label: Optional[str] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Draw a bounding box from database coordinates onto the annotation canvas.
|
Draw a bounding box from database coordinates onto the annotation canvas.
|
||||||
|
|
||||||
@@ -796,6 +839,7 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
in normalized coordinates (0-1)
|
in normalized coordinates (0-1)
|
||||||
color: Color hex string (e.g., '#FF0000')
|
color: Color hex string (e.g., '#FF0000')
|
||||||
width: Line width in pixels
|
width: Line width in pixels
|
||||||
|
label: Optional text label to render near the bounding box
|
||||||
"""
|
"""
|
||||||
if not self.annotation_pixmap or not self.original_pixmap:
|
if not self.annotation_pixmap or not self.original_pixmap:
|
||||||
logger.warning("Cannot draw bounding box: no image loaded")
|
logger.warning("Cannot draw bounding box: no image loaded")
|
||||||
@@ -828,11 +872,11 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
self.bboxes.append(
|
self.bboxes.append(
|
||||||
[float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)]
|
[float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)]
|
||||||
)
|
)
|
||||||
self.bbox_meta.append({"color": pen_color, "width": int(width)})
|
self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label})
|
||||||
|
|
||||||
# Store in all_strokes for consistency
|
# Store in all_strokes for consistency
|
||||||
self.all_strokes.append(
|
self.all_strokes.append(
|
||||||
{"bbox": bbox, "color": color, "alpha": 128, "width": width}
|
{"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Redraw overlay (polylines + all bounding boxes)
|
# Redraw overlay (polylines + all bounding boxes)
|
||||||
|
|||||||
@@ -137,7 +137,7 @@ class ImageDisplayWidget(QWidget):
|
|||||||
height,
|
height,
|
||||||
bytes_per_line,
|
bytes_per_line,
|
||||||
self.current_image.qtimage_format,
|
self.current_image.qtimage_format,
|
||||||
)
|
).copy() # Copy to ensure Qt owns its memory after this scope
|
||||||
|
|
||||||
# Convert to pixmap
|
# Convert to pixmap
|
||||||
pixmap = QPixmap.fromImage(qimage)
|
pixmap = QPixmap.fromImage(qimage)
|
||||||
|
|||||||
@@ -5,12 +5,12 @@ Handles detection inference and result storage.
|
|||||||
|
|
||||||
from typing import List, Dict, Optional, Callable
|
from typing import List, Dict, Optional, Callable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from PIL import Image
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from src.model.yolo_wrapper import YOLOWrapper
|
from src.model.yolo_wrapper import YOLOWrapper
|
||||||
from src.database.db_manager import DatabaseManager
|
from src.database.db_manager import DatabaseManager
|
||||||
|
from src.utils.image import Image
|
||||||
from src.utils.logger import get_logger
|
from src.utils.logger import get_logger
|
||||||
from src.utils.file_utils import get_relative_path
|
from src.utils.file_utils import get_relative_path
|
||||||
|
|
||||||
@@ -42,6 +42,7 @@ class InferenceEngine:
|
|||||||
relative_path: str,
|
relative_path: str,
|
||||||
conf: float = 0.25,
|
conf: float = 0.25,
|
||||||
save_to_db: bool = True,
|
save_to_db: bool = True,
|
||||||
|
repository_root: Optional[str] = None,
|
||||||
) -> Dict:
|
) -> Dict:
|
||||||
"""
|
"""
|
||||||
Detect objects in a single image.
|
Detect objects in a single image.
|
||||||
@@ -51,49 +52,79 @@ class InferenceEngine:
|
|||||||
relative_path: Relative path from repository root
|
relative_path: Relative path from repository root
|
||||||
conf: Confidence threshold
|
conf: Confidence threshold
|
||||||
save_to_db: Whether to save results to database
|
save_to_db: Whether to save results to database
|
||||||
|
repository_root: Base directory used to compute relative_path (if known)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with detection results
|
Dictionary with detection results
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
# Normalize storage path (fall back to absolute path when repo root is unknown)
|
||||||
|
stored_relative_path = relative_path
|
||||||
|
if not repository_root:
|
||||||
|
stored_relative_path = str(Path(image_path).resolve())
|
||||||
|
|
||||||
# Get image dimensions
|
# Get image dimensions
|
||||||
img = Image.open(image_path)
|
img = Image(image_path)
|
||||||
width, height = img.size
|
width = img.width
|
||||||
img.close()
|
height = img.height
|
||||||
|
|
||||||
# Perform detection
|
# Perform detection
|
||||||
detections = self.yolo.predict(image_path, conf=conf)
|
detections = self.yolo.predict(image_path, conf=conf)
|
||||||
|
|
||||||
# Add/get image in database
|
# Add/get image in database
|
||||||
image_id = self.db_manager.get_or_create_image(
|
image_id = self.db_manager.get_or_create_image(
|
||||||
relative_path=relative_path,
|
relative_path=stored_relative_path,
|
||||||
filename=Path(image_path).name,
|
filename=Path(image_path).name,
|
||||||
width=width,
|
width=width,
|
||||||
height=height,
|
height=height,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save detections to database
|
inserted_count = 0
|
||||||
if save_to_db and detections:
|
deleted_count = 0
|
||||||
detection_records = []
|
|
||||||
for det in detections:
|
|
||||||
# Use normalized bbox from detection
|
|
||||||
bbox_normalized = det[
|
|
||||||
"bbox_normalized"
|
|
||||||
] # [x_min, y_min, x_max, y_max]
|
|
||||||
|
|
||||||
record = {
|
# Save detections to database, replacing any previous results for this image/model
|
||||||
"image_id": image_id,
|
if save_to_db:
|
||||||
"model_id": self.model_id,
|
deleted_count = self.db_manager.delete_detections_for_image(
|
||||||
"class_name": det["class_name"],
|
image_id, self.model_id
|
||||||
"bbox": tuple(bbox_normalized),
|
)
|
||||||
"confidence": det["confidence"],
|
if detections:
|
||||||
"segmentation_mask": det.get("segmentation_mask"),
|
detection_records = []
|
||||||
"metadata": {"class_id": det["class_id"]},
|
for det in detections:
|
||||||
}
|
# Use normalized bbox from detection
|
||||||
detection_records.append(record)
|
bbox_normalized = det[
|
||||||
|
"bbox_normalized"
|
||||||
|
] # [x_min, y_min, x_max, y_max]
|
||||||
|
|
||||||
self.db_manager.add_detections_batch(detection_records)
|
metadata = {
|
||||||
logger.info(f"Saved {len(detection_records)} detections to database")
|
"class_id": det["class_id"],
|
||||||
|
"source_path": str(Path(image_path).resolve()),
|
||||||
|
}
|
||||||
|
if repository_root:
|
||||||
|
metadata["repository_root"] = str(
|
||||||
|
Path(repository_root).resolve()
|
||||||
|
)
|
||||||
|
|
||||||
|
record = {
|
||||||
|
"image_id": image_id,
|
||||||
|
"model_id": self.model_id,
|
||||||
|
"class_name": det["class_name"],
|
||||||
|
"bbox": tuple(bbox_normalized),
|
||||||
|
"confidence": det["confidence"],
|
||||||
|
"segmentation_mask": det.get("segmentation_mask"),
|
||||||
|
"metadata": metadata,
|
||||||
|
}
|
||||||
|
detection_records.append(record)
|
||||||
|
|
||||||
|
inserted_count = self.db_manager.add_detections_batch(
|
||||||
|
detection_records
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Saved {inserted_count} detections to database (replaced {deleted_count})"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
f"Detection run removed {deleted_count} stale entries but produced no new detections"
|
||||||
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
@@ -142,7 +173,12 @@ class InferenceEngine:
|
|||||||
rel_path = get_relative_path(image_path, repository_root)
|
rel_path = get_relative_path(image_path, repository_root)
|
||||||
|
|
||||||
# Perform detection
|
# Perform detection
|
||||||
result = self.detect_single(image_path, rel_path, conf)
|
result = self.detect_single(
|
||||||
|
image_path,
|
||||||
|
rel_path,
|
||||||
|
conf=conf,
|
||||||
|
repository_root=repository_root,
|
||||||
|
)
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
# Update progress
|
# Update progress
|
||||||
|
|||||||
@@ -7,6 +7,9 @@ from ultralytics import YOLO
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, List, Dict, Callable, Any
|
from typing import Optional, List, Dict, Callable, Any
|
||||||
import torch
|
import torch
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
from src.utils.image import Image
|
||||||
from src.utils.logger import get_logger
|
from src.utils.logger import get_logger
|
||||||
|
|
||||||
|
|
||||||
@@ -77,7 +80,8 @@ class YOLOWrapper:
|
|||||||
Dictionary with training results
|
Dictionary with training results
|
||||||
"""
|
"""
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
self.load_model()
|
if not self.load_model():
|
||||||
|
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Starting training: {name}")
|
logger.info(f"Starting training: {name}")
|
||||||
@@ -119,7 +123,8 @@ class YOLOWrapper:
|
|||||||
Dictionary with validation metrics
|
Dictionary with validation metrics
|
||||||
"""
|
"""
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
self.load_model()
|
if not self.load_model():
|
||||||
|
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Starting validation on {split} split")
|
logger.info(f"Starting validation on {split} split")
|
||||||
@@ -160,12 +165,15 @@ class YOLOWrapper:
|
|||||||
List of detection dictionaries
|
List of detection dictionaries
|
||||||
"""
|
"""
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
self.load_model()
|
if not self.load_model():
|
||||||
|
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
||||||
|
|
||||||
|
prepared_source, cleanup_path = self._prepare_source(source)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Running inference on {source}")
|
logger.info(f"Running inference on {source}")
|
||||||
results = self.model.predict(
|
results = self.model.predict(
|
||||||
source=source,
|
source=prepared_source,
|
||||||
conf=conf,
|
conf=conf,
|
||||||
iou=iou,
|
iou=iou,
|
||||||
save=save,
|
save=save,
|
||||||
@@ -182,6 +190,14 @@ class YOLOWrapper:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error during inference: {e}")
|
logger.error(f"Error during inference: {e}")
|
||||||
raise
|
raise
|
||||||
|
finally:
|
||||||
|
if 0: # cleanup_path:
|
||||||
|
try:
|
||||||
|
os.remove(cleanup_path)
|
||||||
|
except OSError as cleanup_error:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}"
|
||||||
|
)
|
||||||
|
|
||||||
def export(
|
def export(
|
||||||
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
|
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
|
||||||
@@ -198,7 +214,8 @@ class YOLOWrapper:
|
|||||||
Path to exported model
|
Path to exported model
|
||||||
"""
|
"""
|
||||||
if self.model is None:
|
if self.model is None:
|
||||||
self.load_model()
|
if not self.load_model():
|
||||||
|
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Exporting model to {format} format")
|
logger.info(f"Exporting model to {format} format")
|
||||||
@@ -210,6 +227,38 @@ class YOLOWrapper:
|
|||||||
logger.error(f"Error exporting model: {e}")
|
logger.error(f"Error exporting model: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def _prepare_source(self, source):
|
||||||
|
"""Convert single-channel images to RGB temporarily for inference."""
|
||||||
|
cleanup_path = None
|
||||||
|
|
||||||
|
if isinstance(source, (str, Path)):
|
||||||
|
source_path = Path(source)
|
||||||
|
if source_path.is_file():
|
||||||
|
try:
|
||||||
|
img_obj = Image(source_path)
|
||||||
|
pil_img = img_obj.pil_image
|
||||||
|
if len(pil_img.getbands()) == 1:
|
||||||
|
rgb_img = img_obj.convert_grayscale_to_rgb_preserve_range()
|
||||||
|
else:
|
||||||
|
rgb_img = pil_img.convert("RGB")
|
||||||
|
|
||||||
|
suffix = source_path.suffix or ".png"
|
||||||
|
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
|
||||||
|
tmp_path = tmp.name
|
||||||
|
tmp.close()
|
||||||
|
rgb_img.save(tmp_path)
|
||||||
|
cleanup_path = tmp_path
|
||||||
|
logger.info(
|
||||||
|
f"Converted image {source_path} to RGB for inference at {tmp_path}"
|
||||||
|
)
|
||||||
|
return tmp_path, cleanup_path
|
||||||
|
except Exception as convert_error:
|
||||||
|
logger.warning(
|
||||||
|
f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return source, cleanup_path
|
||||||
|
|
||||||
def _format_training_results(self, results) -> Dict[str, Any]:
|
def _format_training_results(self, results) -> Dict[str, Any]:
|
||||||
"""Format training results into dictionary."""
|
"""Format training results into dictionary."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import yaml
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
from src.utils.logger import get_logger
|
from src.utils.logger import get_logger
|
||||||
|
from src.utils.image import Image
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -46,18 +47,15 @@ class ConfigManager:
|
|||||||
"database": {"path": "data/detections.db"},
|
"database": {"path": "data/detections.db"},
|
||||||
"image_repository": {
|
"image_repository": {
|
||||||
"base_path": "",
|
"base_path": "",
|
||||||
"allowed_extensions": [
|
"allowed_extensions": Image.SUPPORTED_EXTENSIONS,
|
||||||
".jpg",
|
|
||||||
".jpeg",
|
|
||||||
".png",
|
|
||||||
".tif",
|
|
||||||
".tiff",
|
|
||||||
".bmp",
|
|
||||||
],
|
|
||||||
},
|
},
|
||||||
"models": {
|
"models": {
|
||||||
"default_base_model": "yolov8s-seg.pt",
|
"default_base_model": "yolov8s-seg.pt",
|
||||||
"models_directory": "data/models",
|
"models_directory": "data/models",
|
||||||
|
"base_model_choices": [
|
||||||
|
"yolov8s-seg.pt",
|
||||||
|
"yolov11s-seg.pt",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
"training": {
|
"training": {
|
||||||
"default_epochs": 100,
|
"default_epochs": 100,
|
||||||
@@ -65,6 +63,20 @@ class ConfigManager:
|
|||||||
"default_imgsz": 640,
|
"default_imgsz": 640,
|
||||||
"default_patience": 50,
|
"default_patience": 50,
|
||||||
"default_lr0": 0.01,
|
"default_lr0": 0.01,
|
||||||
|
"two_stage": {
|
||||||
|
"enabled": False,
|
||||||
|
"stage1": {
|
||||||
|
"epochs": 20,
|
||||||
|
"lr0": 0.0005,
|
||||||
|
"patience": 10,
|
||||||
|
"freeze": 10,
|
||||||
|
},
|
||||||
|
"stage2": {
|
||||||
|
"epochs": 150,
|
||||||
|
"lr0": 0.0003,
|
||||||
|
"patience": 30,
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"detection": {
|
"detection": {
|
||||||
"default_confidence": 0.25,
|
"default_confidence": 0.25,
|
||||||
@@ -214,5 +226,5 @@ class ConfigManager:
|
|||||||
def get_allowed_extensions(self) -> list:
|
def get_allowed_extensions(self) -> list:
|
||||||
"""Get list of allowed image file extensions."""
|
"""Get list of allowed image file extensions."""
|
||||||
return self.get(
|
return self.get(
|
||||||
"image_repository.allowed_extensions", [".jpg", ".jpeg", ".png"]
|
"image_repository.allowed_extensions", Image.SUPPORTED_EXTENSIONS
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -28,7 +28,9 @@ def get_image_files(
|
|||||||
List of absolute paths to image files
|
List of absolute paths to image files
|
||||||
"""
|
"""
|
||||||
if allowed_extensions is None:
|
if allowed_extensions is None:
|
||||||
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
|
from src.utils.image import Image
|
||||||
|
|
||||||
|
allowed_extensions = Image.SUPPORTED_EXTENSIONS
|
||||||
|
|
||||||
# Normalize extensions to lowercase
|
# Normalize extensions to lowercase
|
||||||
allowed_extensions = [ext.lower() for ext in allowed_extensions]
|
allowed_extensions = [ext.lower() for ext in allowed_extensions]
|
||||||
@@ -204,7 +206,9 @@ def is_image_file(
|
|||||||
True if file is an image
|
True if file is an image
|
||||||
"""
|
"""
|
||||||
if allowed_extensions is None:
|
if allowed_extensions is None:
|
||||||
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
|
from src.utils.image import Image
|
||||||
|
|
||||||
|
allowed_extensions = Image.SUPPORTED_EXTENSIONS
|
||||||
|
|
||||||
extension = Path(file_path).suffix.lower()
|
extension = Path(file_path).suffix.lower()
|
||||||
return extension in [ext.lower() for ext in allowed_extensions]
|
return extension in [ext.lower() for ext in allowed_extensions]
|
||||||
|
|||||||
@@ -277,6 +277,38 @@ class Image:
|
|||||||
"""
|
"""
|
||||||
return self._channels >= 3
|
return self._channels >= 3
|
||||||
|
|
||||||
|
def convert_grayscale_to_rgb_preserve_range(
|
||||||
|
self,
|
||||||
|
) -> PILImage.Image:
|
||||||
|
"""Convert a single-channel PIL image to RGB while preserving dynamic range.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PIL Image in RGB mode with intensities normalized to 0-255.
|
||||||
|
"""
|
||||||
|
if self._channels == 3:
|
||||||
|
return self.pil_image
|
||||||
|
|
||||||
|
grayscale = self.data
|
||||||
|
if grayscale.ndim == 3:
|
||||||
|
grayscale = grayscale[:, :, 0]
|
||||||
|
|
||||||
|
original_dtype = grayscale.dtype
|
||||||
|
grayscale = grayscale.astype(np.float32)
|
||||||
|
|
||||||
|
if grayscale.size == 0:
|
||||||
|
return PILImage.new("RGB", self.shape, color=(0, 0, 0))
|
||||||
|
|
||||||
|
if np.issubdtype(original_dtype, np.integer):
|
||||||
|
denom = float(max(np.iinfo(original_dtype).max, 1))
|
||||||
|
else:
|
||||||
|
max_val = float(grayscale.max())
|
||||||
|
denom = max(max_val, 1.0)
|
||||||
|
|
||||||
|
grayscale = np.clip(grayscale / denom, 0.0, 1.0)
|
||||||
|
grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8)
|
||||||
|
rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2)
|
||||||
|
return PILImage.fromarray(rgb_arr, mode="RGB")
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
"""String representation of the Image object."""
|
"""String representation of the Image object."""
|
||||||
return (
|
return (
|
||||||
|
|||||||
160
src/utils/image_converters.py
Normal file
160
src/utils/image_converters.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from roifile import ImagejRoi
|
||||||
|
from tifffile import TiffFile, TiffWriter
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
class UT:
|
||||||
|
"""
|
||||||
|
Docstring for UT
|
||||||
|
|
||||||
|
Operetta files along with rois drawn in ImageJ
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, roifile_fn: Path, no_labels: bool):
|
||||||
|
self.roifile_fn = roifile_fn
|
||||||
|
print("is file", self.roifile_fn.is_file())
|
||||||
|
self.rois = None
|
||||||
|
if no_labels:
|
||||||
|
self.rois = ImagejRoi.fromfile(self.roifile_fn)
|
||||||
|
self.stem = self.roifile_fn.stem.split("Roi-")[1]
|
||||||
|
else:
|
||||||
|
self.roifile_fn = roifile_fn / roifile_fn.parts[-1]
|
||||||
|
self.stem = self.roifile_fn.stem
|
||||||
|
|
||||||
|
print(self.roifile_fn)
|
||||||
|
|
||||||
|
print(self.stem)
|
||||||
|
self.image, self.image_props = self._load_images()
|
||||||
|
|
||||||
|
def _load_images(self):
|
||||||
|
"""Loading sequence of tif files
|
||||||
|
array sequence is CZYX
|
||||||
|
"""
|
||||||
|
print("Loading images:", self.roifile_fn.parent, self.stem)
|
||||||
|
fns = list(self.roifile_fn.parent.glob(f"{self.stem.lower()}*.tif*"))
|
||||||
|
stems = [fn.stem.split(self.stem)[-1] for fn in fns]
|
||||||
|
n_ch = len(set([stem.split("-ch")[-1].split("t")[0] for stem in stems]))
|
||||||
|
n_p = len(set([stem.split("-")[0] for stem in stems]))
|
||||||
|
n_t = len(set([stem.split("t")[1] for stem in stems]))
|
||||||
|
|
||||||
|
with TiffFile(fns[0]) as tif:
|
||||||
|
img = tif.asarray()
|
||||||
|
w, h = img.shape
|
||||||
|
dtype = img.dtype
|
||||||
|
self.image_props = {
|
||||||
|
"channels": n_ch,
|
||||||
|
"planes": n_p,
|
||||||
|
"tiles": n_t,
|
||||||
|
"width": w,
|
||||||
|
"height": h,
|
||||||
|
"dtype": dtype,
|
||||||
|
}
|
||||||
|
print("Image props", self.image_props)
|
||||||
|
|
||||||
|
image_stack = np.zeros((n_ch, n_p, w, h), dtype=dtype)
|
||||||
|
for fn in fns:
|
||||||
|
with TiffFile(fn) as tif:
|
||||||
|
img = tif.asarray()
|
||||||
|
stem = fn.stem.split(self.stem)[-1]
|
||||||
|
ch = int(stem.split("-ch")[-1].split("t")[0])
|
||||||
|
p = int(stem.split("-")[0].split("p")[1])
|
||||||
|
t = int(stem.split("t")[1])
|
||||||
|
print(fn.stem, "ch", ch, "p", p, "t", t)
|
||||||
|
image_stack[ch - 1, p - 1] = img
|
||||||
|
|
||||||
|
print(image_stack.shape)
|
||||||
|
|
||||||
|
return image_stack, self.image_props
|
||||||
|
|
||||||
|
@property
|
||||||
|
def width(self):
|
||||||
|
return self.image_props["width"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def height(self):
|
||||||
|
return self.image_props["height"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def nchannels(self):
|
||||||
|
return self.image_props["channels"]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def nplanes(self):
|
||||||
|
return self.image_props["planes"]
|
||||||
|
|
||||||
|
def export_rois(
|
||||||
|
self,
|
||||||
|
path: Path,
|
||||||
|
subfolder: str = "labels",
|
||||||
|
class_index: int = 0,
|
||||||
|
):
|
||||||
|
"""Export rois to a file"""
|
||||||
|
with open(path / subfolder / f"{self.stem}.txt", "w") as f:
|
||||||
|
for i, roi in enumerate(self.rois):
|
||||||
|
rc = roi.subpixel_coordinates
|
||||||
|
if rc is None:
|
||||||
|
print(
|
||||||
|
f"No coordinates: {self.roifile_fn}, element {i}, out of {len(self.rois)}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
xmn, ymn = rc.min(axis=0)
|
||||||
|
xmx, ymx = rc.max(axis=0)
|
||||||
|
xc = (xmn + xmx) / 2
|
||||||
|
yc = (ymn + ymx) / 2
|
||||||
|
bw = xmx - xmn
|
||||||
|
bh = ymx - ymn
|
||||||
|
coords = f"{xc/self.width} {yc/self.height} {bw/self.width} {bh/self.height} "
|
||||||
|
for x, y in rc:
|
||||||
|
coords += f"{x/self.width} {y/self.height} "
|
||||||
|
f.write(f"{class_index} {coords}\n")
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
def export_image(
|
||||||
|
self,
|
||||||
|
path: Path,
|
||||||
|
subfolder: str = "images",
|
||||||
|
plane_mode: str = "max projection",
|
||||||
|
channel: int = 0,
|
||||||
|
):
|
||||||
|
"""Export image to a file"""
|
||||||
|
|
||||||
|
if plane_mode == "max projection":
|
||||||
|
self.image = np.max(self.image[channel], axis=0)
|
||||||
|
print(self.image.shape)
|
||||||
|
|
||||||
|
print(path / subfolder / f"{self.stem}.tif")
|
||||||
|
with TiffWriter(path / subfolder / f"{self.stem}.tif") as tif:
|
||||||
|
tif.write(self.image)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("-i", "--input", nargs="*", type=Path)
|
||||||
|
parser.add_argument("-o", "--output", type=Path)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-labels",
|
||||||
|
action="store_false",
|
||||||
|
help="Source does not have labels, export only images",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
for path in args.input:
|
||||||
|
print("Path:", path)
|
||||||
|
if not args.no_labels:
|
||||||
|
print("No labels")
|
||||||
|
ut = UT(path, args.no_labels)
|
||||||
|
ut.export_image(args.output, plane_mode="max projection", channel=0)
|
||||||
|
|
||||||
|
else:
|
||||||
|
for rfn in Path(path).glob("*.zip"):
|
||||||
|
print("Roi FN:", rfn)
|
||||||
|
ut = UT(rfn, args.no_labels)
|
||||||
|
ut.export_rois(args.output, class_index=0)
|
||||||
|
ut.export_image(args.output, plane_mode="max projection", channel=0)
|
||||||
|
|
||||||
|
print()
|
||||||
184
tests/show_yolo_seg.py
Normal file
184
tests/show_yolo_seg.py
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
show_yolo_seg.py
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python show_yolo_seg.py /path/to/image.jpg /path/to/labels.txt
|
||||||
|
|
||||||
|
Supports:
|
||||||
|
- Segmentation polygons: "class x1 y1 x2 y2 ... xn yn"
|
||||||
|
- YOLO bbox lines as fallback: "class x_center y_center width height"
|
||||||
|
Coordinates can be normalized [0..1] or absolute pixels (auto-detected).
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
def parse_label_line(line):
|
||||||
|
parts = line.strip().split()
|
||||||
|
if not parts:
|
||||||
|
return None
|
||||||
|
cls = int(float(parts[0]))
|
||||||
|
coords = [float(x) for x in parts[1:]]
|
||||||
|
return cls, coords
|
||||||
|
|
||||||
|
|
||||||
|
def coords_are_normalized(coords):
|
||||||
|
# If every coordinate is between 0 and 1 (inclusive-ish), assume normalized
|
||||||
|
if not coords:
|
||||||
|
return False
|
||||||
|
return max(coords) <= 1.001
|
||||||
|
|
||||||
|
|
||||||
|
def yolo_bbox_to_xyxy(coords, img_w, img_h):
|
||||||
|
# coords: [xc, yc, w, h] normalized or absolute
|
||||||
|
xc, yc, w, h = coords[:4]
|
||||||
|
if max(coords) <= 1.001:
|
||||||
|
xc *= img_w
|
||||||
|
yc *= img_h
|
||||||
|
w *= img_w
|
||||||
|
h *= img_h
|
||||||
|
x1 = int(round(xc - w / 2))
|
||||||
|
y1 = int(round(yc - h / 2))
|
||||||
|
x2 = int(round(xc + w / 2))
|
||||||
|
y2 = int(round(yc + h / 2))
|
||||||
|
return x1, y1, x2, y2
|
||||||
|
|
||||||
|
|
||||||
|
def poly_to_pts(coords, img_w, img_h):
|
||||||
|
# coords: [x1 y1 x2 y2 ...] either normalized or absolute
|
||||||
|
if coords_are_normalized(coords[4:]):
|
||||||
|
coords = [
|
||||||
|
coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))
|
||||||
|
]
|
||||||
|
pts = np.array(coords, dtype=np.int32).reshape(-1, 2)
|
||||||
|
return pts
|
||||||
|
|
||||||
|
|
||||||
|
def random_color_for_class(cls):
|
||||||
|
random.seed(cls) # deterministic per class
|
||||||
|
return tuple(int(x) for x in np.array([random.randint(0, 255) for _ in range(3)]))
|
||||||
|
|
||||||
|
|
||||||
|
def draw_annotations(img, labels, alpha=0.4, draw_bbox_for_poly=True):
|
||||||
|
# img: BGR numpy array
|
||||||
|
overlay = img.copy()
|
||||||
|
h, w = img.shape[:2]
|
||||||
|
for cls, coords in labels:
|
||||||
|
if not coords:
|
||||||
|
continue
|
||||||
|
# polygon case (>=6 coordinates)
|
||||||
|
if len(coords) >= 6:
|
||||||
|
color = random_color_for_class(cls)
|
||||||
|
|
||||||
|
x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords[:4], w, h)
|
||||||
|
cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
|
||||||
|
|
||||||
|
pts = poly_to_pts(coords[4:], w, h)
|
||||||
|
# fill on overlay
|
||||||
|
cv2.fillPoly(overlay, [pts], color)
|
||||||
|
# outline on base image
|
||||||
|
cv2.polylines(img, [pts], isClosed=True, color=color, thickness=2)
|
||||||
|
# put class text at first point
|
||||||
|
x, y = int(pts[0, 0]), int(pts[0, 1]) - 6
|
||||||
|
cv2.putText(
|
||||||
|
img,
|
||||||
|
str(cls),
|
||||||
|
(x, max(6, y)),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX,
|
||||||
|
0.6,
|
||||||
|
(255, 255, 255),
|
||||||
|
2,
|
||||||
|
cv2.LINE_AA,
|
||||||
|
)
|
||||||
|
|
||||||
|
# YOLO bbox case (4 coords)
|
||||||
|
elif len(coords) == 4:
|
||||||
|
x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords, w, h)
|
||||||
|
color = random_color_for_class(cls)
|
||||||
|
cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
|
||||||
|
cv2.putText(
|
||||||
|
img,
|
||||||
|
str(cls),
|
||||||
|
(x1, max(6, y1 - 4)),
|
||||||
|
cv2.FONT_HERSHEY_SIMPLEX,
|
||||||
|
0.6,
|
||||||
|
(255, 255, 255),
|
||||||
|
2,
|
||||||
|
cv2.LINE_AA,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Unknown / invalid format, skip
|
||||||
|
continue
|
||||||
|
|
||||||
|
# blend overlay for filled polygons
|
||||||
|
cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, img)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def load_labels_file(label_path):
|
||||||
|
labels = []
|
||||||
|
with open(label_path, "r") as f:
|
||||||
|
for raw in f:
|
||||||
|
line = raw.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
parsed = parse_label_line(line)
|
||||||
|
if parsed:
|
||||||
|
labels.append(parsed)
|
||||||
|
return labels
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Show YOLO segmentation / polygon annotations"
|
||||||
|
)
|
||||||
|
parser.add_argument("image", type=str, help="Path to image file")
|
||||||
|
parser.add_argument("labels", type=str, help="Path to YOLO label file (polygons)")
|
||||||
|
parser.add_argument(
|
||||||
|
"--alpha", type=float, default=0.4, help="Polygon fill alpha (0..1)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-bbox", action="store_true", help="Don't draw bounding boxes for polygons"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
img_path = Path(args.image)
|
||||||
|
lbl_path = Path(args.labels)
|
||||||
|
|
||||||
|
if not img_path.exists():
|
||||||
|
print("Image not found:", img_path)
|
||||||
|
sys.exit(1)
|
||||||
|
if not lbl_path.exists():
|
||||||
|
print("Label file not found:", lbl_path)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
|
||||||
|
if img is None:
|
||||||
|
print("Could not load image:", img_path)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
labels = load_labels_file(str(lbl_path))
|
||||||
|
if not labels:
|
||||||
|
print("No labels parsed from", lbl_path)
|
||||||
|
# continue and just show image
|
||||||
|
out = draw_annotations(
|
||||||
|
img.copy(), labels, alpha=args.alpha, draw_bbox_for_poly=(not args.no_bbox)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert BGR -> RGB for matplotlib display
|
||||||
|
out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
|
||||||
|
plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
|
||||||
|
plt.imshow(out_rgb)
|
||||||
|
plt.axis("off")
|
||||||
|
plt.title(f"{img_path.name} ({lbl_path.name})")
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -27,7 +27,7 @@ class TestImage:
|
|||||||
|
|
||||||
def test_supported_extensions(self):
|
def test_supported_extensions(self):
|
||||||
"""Test that supported extensions are correctly defined."""
|
"""Test that supported extensions are correctly defined."""
|
||||||
expected_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
|
expected_extensions = Image.SUPPORTED_EXTENSIONS
|
||||||
assert Image.SUPPORTED_EXTENSIONS == expected_extensions
|
assert Image.SUPPORTED_EXTENSIONS == expected_extensions
|
||||||
|
|
||||||
def test_image_properties(self, tmp_path):
|
def test_image_properties(self, tmp_path):
|
||||||
|
|||||||
Reference in New Issue
Block a user