20 Commits

Author SHA1 Message Date
e5036c10cf Small fix 2025-12-16 18:03:56 +02:00
c7e388d9ae Updating progressbar 2025-12-16 17:20:25 +02:00
6b995e7325 upate 2025-12-16 13:24:20 +02:00
0e0741d323 Update on convert_grayscale_to_rgb_preserve_range, making it class method 2025-12-16 12:37:34 +02:00
dd99a0677c Updating image converter and aading simple script to visulaize segmentation 2025-12-16 11:27:38 +02:00
9c4c39fb39 Adding image converter 2025-12-12 23:52:34 +02:00
20a87c9040 Updating config 2025-12-12 21:51:12 +02:00
9f7d2be1ac Updating the base model preset 2025-12-11 23:27:02 +02:00
dbde07c0e8 Making training tab scrollable 2025-12-11 23:12:39 +02:00
b3c5a51dbb Using QPolygonF instead of drawLine 2025-12-11 17:14:07 +02:00
9a221acb63 Making image manipulations thru one class 2025-12-11 16:59:56 +02:00
32a6a122bd Fixing circular import 2025-12-11 16:06:39 +02:00
9ba44043ef Defining image extensions only in one place 2025-12-11 15:50:14 +02:00
8eb1cc8c86 Fixing grayscale conversion 2025-12-11 15:15:38 +02:00
e4ce882a18 Grayscale RGB conversion modified 2025-12-11 15:06:59 +02:00
6b6d6fad03 2Stage training fix 2025-12-11 12:50:34 +02:00
c0684a9c14 Implementing 2 stage training 2025-12-11 12:04:08 +02:00
221c80aa8c Small image showing fix 2025-12-11 11:20:20 +02:00
833b222fad Adding result shower 2025-12-10 16:55:28 +02:00
5370d31dce Merge pull request 'Update training' (#2) from training into main
Reviewed-on: #2
2025-12-10 15:47:00 +02:00
16 changed files with 1585 additions and 151 deletions

View File

@@ -12,12 +12,26 @@ image_repository:
models:
default_base_model: yolov8s-seg.pt
models_directory: data/models
base_model_choices:
- yolov8s-seg.pt
- yolo11s-seg.pt
training:
default_epochs: 100
default_batch_size: 16
default_imgsz: 640
default_imgsz: 1024
default_patience: 50
default_lr0: 0.01
two_stage:
enabled: false
stage1:
epochs: 20
lr0: 0.0005
patience: 10
freeze: 10
stage2:
epochs: 150
lr0: 0.0003
patience: 30
last_dataset_yaml: /home/martin/code/object_detection/data/datasets/data.yaml
last_dataset_dir: /home/martin/code/object_detection/data/datasets
detection:

View File

@@ -13,8 +13,9 @@ import hashlib
import yaml
from src.utils.logger import get_logger
from src.utils.image import Image
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp")
IMAGE_EXTENSIONS = tuple(Image.SUPPORTED_EXTENSIONS)
logger = get_logger(__name__)
@@ -450,6 +451,25 @@ class DatabaseManager:
filters["model_id"] = model_id
return self.get_detections(filters)
def delete_detections_for_image(
self, image_id: int, model_id: Optional[int] = None
) -> int:
"""Delete detections tied to a specific image and optional model."""
conn = self.get_connection()
try:
cursor = conn.cursor()
if model_id is not None:
cursor.execute(
"DELETE FROM detections WHERE image_id = ? AND model_id = ?",
(image_id, model_id),
)
else:
cursor.execute("DELETE FROM detections WHERE image_id = ?", (image_id,))
conn.commit()
return cursor.rowcount
finally:
conn.close()
def delete_detections_for_model(self, model_id: int) -> int:
"""Delete all detections for a specific model."""
conn = self.get_connection()

View File

@@ -168,7 +168,7 @@ class AnnotationTab(QWidget):
self,
"Select Image",
start_dir,
"Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
"Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
)
if not file_path:

View File

@@ -20,12 +20,14 @@ from PySide6.QtWidgets import (
)
from PySide6.QtCore import Qt, QThread, Signal
from pathlib import Path
from typing import Optional
from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
from src.utils.file_utils import get_image_files
from src.model.inference import InferenceEngine
from src.utils.image import Image
logger = get_logger(__name__)
@@ -147,29 +149,65 @@ class DetectionTab(QWidget):
self.model_combo.currentIndexChanged.connect(self._on_model_changed)
def _load_models(self):
"""Load available models from database."""
"""Load available models from database and local storage."""
try:
models = self.db_manager.get_models()
self.model_combo.clear()
models = self.db_manager.get_models()
has_models = False
if not models:
self.model_combo.addItem("No models available", None)
self._set_buttons_enabled(False)
return
known_paths = set()
# Add base model option
# Add base model option first (always available)
base_model = self.config_manager.get(
"models.default_base_model", "yolov8s-seg.pt"
)
self.model_combo.addItem(
f"Base Model ({base_model})", {"id": 0, "path": base_model}
)
if base_model:
base_data = {
"id": 0,
"path": base_model,
"model_name": Path(base_model).stem or "Base Model",
"model_version": "pretrained",
"base_model": base_model,
"source": "base",
}
self.model_combo.addItem(f"Base Model ({base_model})", base_data)
known_paths.add(self._normalize_model_path(base_model))
has_models = True
# Add trained models
# Add trained models from database
for model in models:
display_name = f"{model['model_name']} v{model['model_version']}"
self.model_combo.addItem(display_name, model)
model_data = {**model, "path": model.get("model_path")}
normalized = self._normalize_model_path(model_data.get("path"))
if normalized:
known_paths.add(normalized)
self.model_combo.addItem(display_name, model_data)
has_models = True
# Discover local model files not yet in the database
local_models = self._discover_local_models()
for model_path in local_models:
normalized = self._normalize_model_path(model_path)
if normalized in known_paths:
continue
display_name = f"Local Model ({Path(model_path).stem})"
model_data = {
"id": None,
"path": str(model_path),
"model_name": Path(model_path).stem,
"model_version": "local",
"base_model": Path(model_path).stem,
"source": "local",
}
self.model_combo.addItem(display_name, model_data)
known_paths.add(normalized)
has_models = True
if not has_models:
self.model_combo.addItem("No models available", None)
self._set_buttons_enabled(False)
else:
self._set_buttons_enabled(True)
except Exception as e:
@@ -199,7 +237,7 @@ class DetectionTab(QWidget):
self,
"Select Image",
start_dir,
"Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
"Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
)
if not file_path:
@@ -249,25 +287,39 @@ class DetectionTab(QWidget):
QMessageBox.warning(self, "No Model", "Please select a model first.")
return
model_path = model_data["path"]
model_id = model_data["id"]
model_path = model_data.get("path")
if not model_path:
QMessageBox.warning(
self, "Invalid Model", "Selected model is missing a file path."
)
return
# Ensure we have a valid model ID (create entry for base model if needed)
if model_id == 0:
# Create database entry for base model
base_model = self.config_manager.get(
"models.default_base_model", "yolov8s-seg.pt"
if not Path(model_path).exists():
QMessageBox.critical(
self,
"Model Not Found",
f"The selected model file could not be found:\n{model_path}",
)
model_id = self.db_manager.add_model(
model_name="Base Model",
model_version="pretrained",
model_path=base_model,
base_model=base_model,
return
model_id = model_data.get("id")
# Ensure we have a database entry for the selected model
if model_id in (None, 0):
model_id = self._ensure_model_record(model_data)
if not model_id:
QMessageBox.critical(
self,
"Model Registration Failed",
"Unable to register the selected model in the database.",
)
return
normalized_model_path = self._normalize_model_path(model_path) or model_path
# Create inference engine
self.inference_engine = InferenceEngine(
model_path, self.db_manager, model_id
normalized_model_path, self.db_manager, model_id
)
# Get confidence threshold
@@ -338,6 +390,76 @@ class DetectionTab(QWidget):
self.batch_btn.setEnabled(enabled)
self.model_combo.setEnabled(enabled)
def _discover_local_models(self) -> list:
"""Scan the models directory for standalone .pt files."""
models_dir = self.config_manager.get_models_directory()
if not models_dir:
return []
models_path = Path(models_dir)
if not models_path.exists():
return []
try:
return sorted(
[p for p in models_path.rglob("*.pt") if p.is_file()],
key=lambda p: str(p).lower(),
)
except Exception as e:
logger.warning(f"Error discovering local models: {e}")
return []
def _normalize_model_path(self, path_value) -> str:
"""Return a normalized absolute path string for comparison."""
if not path_value:
return ""
try:
return str(Path(path_value).resolve())
except Exception:
return str(path_value)
def _ensure_model_record(self, model_data: dict) -> Optional[int]:
"""Ensure a database record exists for the selected model."""
model_path = model_data.get("path")
if not model_path:
return None
normalized_target = self._normalize_model_path(model_path)
try:
existing_models = self.db_manager.get_models()
for model in existing_models:
existing_path = model.get("model_path")
if not existing_path:
continue
normalized_existing = self._normalize_model_path(existing_path)
if (
normalized_existing == normalized_target
or existing_path == model_path
):
return model["id"]
model_name = (
model_data.get("model_name") or Path(model_path).stem or "Custom Model"
)
model_version = (
model_data.get("model_version") or model_data.get("source") or "local"
)
base_model = model_data.get(
"base_model",
self.config_manager.get("models.default_base_model", "yolov8s-seg.pt"),
)
return self.db_manager.add_model(
model_name=model_name,
model_version=model_version,
model_path=normalized_target,
base_model=base_model,
)
except Exception as e:
logger.error(f"Failed to ensure model record for {model_path}: {e}")
return None
def refresh(self):
"""Refresh the tab."""
self._load_models()

View File

@@ -1,15 +1,39 @@
"""
Results tab for the microscopy object detection application.
Results tab for browsing stored detections and visualizing overlays.
"""
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
from pathlib import Path
from typing import Dict, List, Optional
from PySide6.QtWidgets import (
QWidget,
QVBoxLayout,
QHBoxLayout,
QLabel,
QGroupBox,
QPushButton,
QSplitter,
QTableWidget,
QTableWidgetItem,
QHeaderView,
QAbstractItemView,
QMessageBox,
QCheckBox,
)
from PySide6.QtCore import Qt
from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
from src.utils.image import Image, ImageLoadError
from src.gui.widgets import AnnotationCanvasWidget
logger = get_logger(__name__)
class ResultsTab(QWidget):
"""Results tab placeholder."""
"""Results tab showing detection history and preview overlays."""
def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
@@ -18,29 +42,398 @@ class ResultsTab(QWidget):
self.db_manager = db_manager
self.config_manager = config_manager
self.detection_summary: List[Dict] = []
self.current_selection: Optional[Dict] = None
self.current_image: Optional[Image] = None
self.current_detections: List[Dict] = []
self._image_path_cache: Dict[str, str] = {}
self._setup_ui()
self.refresh()
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
group = QGroupBox("Results")
group_layout = QVBoxLayout()
label = QLabel(
"Results viewer will be implemented here.\n\n"
"Features:\n"
"- Detection history browser\n"
"- Advanced filtering\n"
"- Statistics dashboard\n"
"- Export functionality"
)
group_layout.addWidget(label)
group.setLayout(group_layout)
# Splitter for list + preview
splitter = QSplitter(Qt.Horizontal)
layout.addWidget(group)
layout.addStretch()
# Left pane: detection list
left_container = QWidget()
left_layout = QVBoxLayout()
left_layout.setContentsMargins(0, 0, 0, 0)
controls_layout = QHBoxLayout()
self.refresh_btn = QPushButton("Refresh")
self.refresh_btn.clicked.connect(self.refresh)
controls_layout.addWidget(self.refresh_btn)
controls_layout.addStretch()
left_layout.addLayout(controls_layout)
self.results_table = QTableWidget(0, 5)
self.results_table.setHorizontalHeaderLabels(
["Image", "Model", "Detections", "Classes", "Last Updated"]
)
self.results_table.horizontalHeader().setSectionResizeMode(
0, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
1, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
2, QHeaderView.ResizeToContents
)
self.results_table.horizontalHeader().setSectionResizeMode(
3, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
4, QHeaderView.ResizeToContents
)
self.results_table.setSelectionBehavior(QAbstractItemView.SelectRows)
self.results_table.setSelectionMode(QAbstractItemView.SingleSelection)
self.results_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
self.results_table.itemSelectionChanged.connect(self._on_result_selected)
left_layout.addWidget(self.results_table)
left_container.setLayout(left_layout)
# Right pane: preview canvas and controls
right_container = QWidget()
right_layout = QVBoxLayout()
right_layout.setContentsMargins(0, 0, 0, 0)
preview_group = QGroupBox("Detection Preview")
preview_layout = QVBoxLayout()
self.preview_canvas = AnnotationCanvasWidget()
self.preview_canvas.set_polyline_enabled(False)
self.preview_canvas.set_show_bboxes(True)
preview_layout.addWidget(self.preview_canvas)
toggles_layout = QHBoxLayout()
self.show_masks_checkbox = QCheckBox("Show Masks")
self.show_masks_checkbox.setChecked(False)
self.show_masks_checkbox.stateChanged.connect(self._apply_detection_overlays)
self.show_bboxes_checkbox = QCheckBox("Show Bounding Boxes")
self.show_bboxes_checkbox.setChecked(True)
self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
self.show_confidence_checkbox = QCheckBox("Show Confidence")
self.show_confidence_checkbox.setChecked(False)
self.show_confidence_checkbox.stateChanged.connect(
self._apply_detection_overlays
)
toggles_layout.addWidget(self.show_masks_checkbox)
toggles_layout.addWidget(self.show_bboxes_checkbox)
toggles_layout.addWidget(self.show_confidence_checkbox)
toggles_layout.addStretch()
preview_layout.addLayout(toggles_layout)
self.summary_label = QLabel("Select a detection result to preview.")
self.summary_label.setWordWrap(True)
preview_layout.addWidget(self.summary_label)
preview_group.setLayout(preview_layout)
right_layout.addWidget(preview_group)
right_container.setLayout(right_layout)
splitter.addWidget(left_container)
splitter.addWidget(right_container)
splitter.setStretchFactor(0, 1)
splitter.setStretchFactor(1, 2)
layout.addWidget(splitter)
self.setLayout(layout)
def refresh(self):
"""Refresh the tab."""
pass
"""Refresh the detection list and preview."""
self._load_detection_summary()
self._populate_results_table()
self.current_selection = None
self.current_image = None
self.current_detections = []
self.preview_canvas.clear()
self.summary_label.setText("Select a detection result to preview.")
def _load_detection_summary(self):
"""Load latest detection summaries grouped by image + model."""
try:
detections = self.db_manager.get_detections(limit=500)
summary_map: Dict[tuple, Dict] = {}
for det in detections:
key = (det["image_id"], det["model_id"])
metadata = det.get("metadata") or {}
entry = summary_map.setdefault(
key,
{
"image_id": det["image_id"],
"model_id": det["model_id"],
"image_path": det.get("image_path"),
"image_filename": det.get("image_filename")
or det.get("image_path"),
"model_name": det.get("model_name", ""),
"model_version": det.get("model_version", ""),
"last_detected": det.get("detected_at"),
"count": 0,
"classes": set(),
"source_path": metadata.get("source_path"),
"repository_root": metadata.get("repository_root"),
},
)
entry["count"] += 1
if det.get("detected_at") and (
not entry.get("last_detected")
or str(det.get("detected_at")) > str(entry.get("last_detected"))
):
entry["last_detected"] = det.get("detected_at")
if det.get("class_name"):
entry["classes"].add(det["class_name"])
if metadata.get("source_path") and not entry.get("source_path"):
entry["source_path"] = metadata.get("source_path")
if metadata.get("repository_root") and not entry.get("repository_root"):
entry["repository_root"] = metadata.get("repository_root")
self.detection_summary = sorted(
summary_map.values(),
key=lambda x: str(x.get("last_detected") or ""),
reverse=True,
)
except Exception as e:
logger.error(f"Failed to load detection summary: {e}")
QMessageBox.critical(
self,
"Error",
f"Failed to load detection results:\n{str(e)}",
)
self.detection_summary = []
def _populate_results_table(self):
"""Populate the table widget with detection summaries."""
self.results_table.setRowCount(len(self.detection_summary))
for row, entry in enumerate(self.detection_summary):
model_label = f"{entry['model_name']} {entry['model_version']}".strip()
class_list = (
", ".join(sorted(entry["classes"])) if entry["classes"] else "-"
)
items = [
QTableWidgetItem(entry.get("image_filename", "")),
QTableWidgetItem(model_label),
QTableWidgetItem(str(entry.get("count", 0))),
QTableWidgetItem(class_list),
QTableWidgetItem(str(entry.get("last_detected") or "")),
]
for col, item in enumerate(items):
item.setData(Qt.UserRole, row)
self.results_table.setItem(row, col, item)
self.results_table.clearSelection()
def _on_result_selected(self):
"""Handle selection changes in the detection table."""
selected_items = self.results_table.selectedItems()
if not selected_items:
return
row = selected_items[0].data(Qt.UserRole)
if row is None or row >= len(self.detection_summary):
return
entry = self.detection_summary[row]
if (
self.current_selection
and self.current_selection.get("image_id") == entry["image_id"]
and self.current_selection.get("model_id") == entry["model_id"]
):
return
self.current_selection = entry
image_path = self._resolve_image_path(entry)
if not image_path:
QMessageBox.warning(
self,
"Image Not Found",
"Unable to locate the image file for this detection.",
)
return
try:
self.current_image = Image(image_path)
self.preview_canvas.load_image(self.current_image)
except ImageLoadError as e:
logger.error(f"Failed to load image '{image_path}': {e}")
QMessageBox.critical(
self,
"Image Error",
f"Failed to load image for preview:\n{str(e)}",
)
return
self._load_detections_for_selection(entry)
self._apply_detection_overlays()
self._update_summary_label(entry)
def _load_detections_for_selection(self, entry: Dict):
"""Load detection records for the selected image/model pair."""
self.current_detections = []
if not entry:
return
try:
filters = {"image_id": entry["image_id"], "model_id": entry["model_id"]}
self.current_detections = self.db_manager.get_detections(filters)
except Exception as e:
logger.error(f"Failed to load detections for preview: {e}")
QMessageBox.critical(
self,
"Error",
f"Failed to load detections for this image:\n{str(e)}",
)
self.current_detections = []
def _apply_detection_overlays(self):
"""Draw detections onto the preview canvas based on current toggles."""
self.preview_canvas.clear_annotations()
self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())
if not self.current_detections or not self.current_image:
return
for det in self.current_detections:
color = self._get_class_color(det.get("class_name"))
if self.show_masks_checkbox.isChecked() and det.get("segmentation_mask"):
mask_points = self._convert_mask(det["segmentation_mask"])
if mask_points:
self.preview_canvas.draw_saved_polyline(mask_points, color)
bbox = [
det.get("x_min"),
det.get("y_min"),
det.get("x_max"),
det.get("y_max"),
]
if all(v is not None for v in bbox):
label = None
if self.show_confidence_checkbox.isChecked():
confidence = det.get("confidence")
if confidence is not None:
label = f"{confidence:.2f}"
self.preview_canvas.draw_saved_bbox(bbox, color, label=label)
def _convert_mask(self, mask_points: List[List[float]]) -> List[List[float]]:
"""Convert stored [x, y] masks to [y, x] format for the canvas."""
converted = []
for point in mask_points:
if len(point) >= 2:
x, y = point[0], point[1]
converted.append([y, x])
return converted
def _toggle_bboxes(self):
"""Update bounding box visibility on the canvas."""
self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())
# Re-render to respect show/hide when toggled
self._apply_detection_overlays()
def _update_summary_label(self, entry: Dict):
"""Display textual summary for the selected detection run."""
classes = ", ".join(sorted(entry.get("classes", []))) or "-"
summary_text = (
f"Image: {entry.get('image_filename', 'unknown')}\n"
f"Model: {entry.get('model_name', '')} {entry.get('model_version', '')}\n"
f"Detections: {entry.get('count', 0)}\n"
f"Classes: {classes}\n"
f"Last Updated: {entry.get('last_detected', 'n/a')}"
)
self.summary_label.setText(summary_text)
def _resolve_image_path(self, entry: Dict) -> Optional[str]:
"""Resolve an image path using metadata, cache, and repository hints."""
relative_path = entry.get("image_path") if entry else None
cache_key = relative_path or entry.get("source_path")
if cache_key and cache_key in self._image_path_cache:
cached = Path(self._image_path_cache[cache_key])
if cached.exists():
return self._image_path_cache[cache_key]
del self._image_path_cache[cache_key]
candidates = []
source_path = entry.get("source_path") if entry else None
if source_path:
candidates.append(Path(source_path))
repo_roots = []
if entry.get("repository_root"):
repo_roots.append(entry["repository_root"])
config_repo = self.config_manager.get_image_repository_path()
if config_repo:
repo_roots.append(config_repo)
for root in repo_roots:
if relative_path:
candidates.append(Path(root) / relative_path)
if relative_path:
candidates.append(Path(relative_path))
for candidate in candidates:
try:
if candidate and candidate.exists():
resolved = str(candidate.resolve())
if cache_key:
self._image_path_cache[cache_key] = resolved
return resolved
except Exception:
continue
# Fallback: search by filename in known roots
filename = Path(relative_path).name if relative_path else None
if filename:
search_roots = [Path(root) for root in repo_roots if root]
if not search_roots:
search_roots = [Path("data")]
match = self._search_in_roots(filename, search_roots)
if match:
resolved = str(match.resolve())
if cache_key:
self._image_path_cache[cache_key] = resolved
return resolved
return None
def _search_in_roots(self, filename: str, roots: List[Path]) -> Optional[Path]:
"""Search for a file name within a list of root directories."""
for root in roots:
try:
if not root.exists():
continue
for candidate in root.rglob(filename):
return candidate
except Exception as e:
logger.debug(f"Error searching for {filename} in {root}: {e}")
return None
def _get_class_color(self, class_name: Optional[str]) -> str:
"""Return consistent color hex for a class name."""
if not class_name:
return "#FF6B6B"
color_map = self.config_manager.get_bbox_colors()
if class_name in color_map:
return color_map[class_name]
# Deterministic fallback color based on hash
palette = [
"#FF6B6B",
"#4ECDC4",
"#FFD166",
"#1D3557",
"#F4A261",
"#E76F51",
]
return palette[hash(class_name) % len(palette)]

View File

@@ -10,7 +10,6 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import yaml
from PIL import Image as PILImage
from PySide6.QtCore import Qt, QThread, Signal
from PySide6.QtWidgets import (
QWidget,
@@ -28,24 +27,20 @@ from PySide6.QtWidgets import (
QProgressBar,
QSpinBox,
QDoubleSpinBox,
QCheckBox,
QScrollArea,
)
from src.database.db_manager import DatabaseManager
from src.model.yolo_wrapper import YOLOWrapper
from src.utils.config_manager import ConfigManager
from src.utils.image import Image
from src.utils.logger import get_logger
logger = get_logger(__name__)
DEFAULT_IMAGE_EXTENSIONS = {
".jpg",
".jpeg",
".png",
".tif",
".tiff",
".bmp",
}
DEFAULT_IMAGE_EXTENSIONS = set(Image.SUPPORTED_EXTENSIONS)
class TrainingWorker(QThread):
@@ -67,6 +62,8 @@ class TrainingWorker(QThread):
save_dir: str,
run_name: str,
parent: Optional[QThread] = None,
stage_plan: Optional[List[Dict[str, Any]]] = None,
total_epochs: Optional[int] = None,
):
super().__init__(parent)
self.data_yaml = data_yaml
@@ -78,6 +75,27 @@ class TrainingWorker(QThread):
self.lr0 = lr0
self.save_dir = save_dir
self.run_name = run_name
self.stage_plan = stage_plan or [
{
"label": "Single Stage",
"model_path": base_model,
"use_previous_best": False,
"params": {
"epochs": epochs,
"batch": batch,
"imgsz": imgsz,
"patience": patience,
"lr0": lr0,
"freeze": 0,
"name": run_name,
},
}
]
computed_total = sum(
max(0, int((stage.get("params") or {}).get("epochs", 0)))
for stage in self.stage_plan
)
self.total_epochs = total_epochs if total_epochs else computed_total or epochs
self._stop_requested = False
def stop(self):
@@ -86,36 +104,98 @@ class TrainingWorker(QThread):
self.requestInterruption()
def run(self):
"""Execute YOLO training and emit progress/finished signals."""
wrapper = YOLOWrapper(self.base_model)
"""Execute YOLO training over one or more stages and emit progress/finished signals."""
def on_epoch_end(trainer):
completed_epochs = 0
stage_history: List[Dict[str, Any]] = []
last_stage_results: Optional[Dict[str, Any]] = None
for stage_index, stage in enumerate(self.stage_plan, start=1):
if self._stop_requested or self.isInterruptionRequested():
break
stage_label = stage.get("label") or f"Stage {stage_index}"
stage_params = dict(stage.get("params") or {})
stage_epochs = int(stage_params.get("epochs", self.epochs))
if stage_epochs <= 0:
stage_epochs = 1
batch = int(stage_params.get("batch", self.batch))
imgsz = int(stage_params.get("imgsz", self.imgsz))
patience = int(stage_params.get("patience", self.patience))
lr0 = float(stage_params.get("lr0", self.lr0))
freeze = int(stage_params.get("freeze", 0))
run_name = stage_params.get("name") or f"{self.run_name}_stage{stage_index}"
weights_path = stage.get("model_path") or self.base_model
if stage.get("use_previous_best") and last_stage_results:
weights_path = (
last_stage_results.get("best_model_path")
or last_stage_results.get("last_model_path")
or weights_path
)
wrapper = YOLOWrapper(weights_path)
stage_offset = completed_epochs
def on_epoch_end(trainer, offset=stage_offset):
current_epoch = getattr(trainer, "epoch", 0) + 1
metrics: Dict[str, float] = {}
loss_items = getattr(trainer, "loss_items", None)
if loss_items:
metrics["loss"] = float(loss_items[-1])
self.progress.emit(current_epoch, self.epochs, metrics)
absolute_epoch = min(
max(1, offset + current_epoch),
max(1, self.total_epochs),
)
self.progress.emit(absolute_epoch, self.total_epochs, metrics)
if self.isInterruptionRequested() or self._stop_requested:
setattr(trainer, "stop_training", True)
callbacks = {"on_fit_epoch_end": on_epoch_end}
try:
results = wrapper.train(
stage_result = wrapper.train(
data_yaml=self.data_yaml,
epochs=self.epochs,
imgsz=self.imgsz,
batch=self.batch,
patience=self.patience,
epochs=stage_epochs,
imgsz=imgsz,
batch=batch,
patience=patience,
save_dir=self.save_dir,
name=self.run_name,
lr0=self.lr0,
name=run_name,
lr0=lr0,
callbacks=callbacks,
freeze=freeze,
)
self.finished.emit(results)
except Exception as exc:
self.error.emit(str(exc))
return
stage_history.append(
{
"label": stage_label,
"params": stage_params,
"weights_used": weights_path,
"results": stage_result,
}
)
last_stage_results = stage_result
completed_epochs += stage_epochs
final_payload: Dict[str, Any]
if last_stage_results:
final_payload = dict(last_stage_results)
else:
final_payload = {
"success": False,
"message": "Training stopped before any stage completed.",
}
final_payload["stage_results"] = stage_history
final_payload["total_epochs_completed"] = completed_epochs
final_payload["total_epochs_planned"] = self.total_epochs
final_payload["stages_completed"] = len(stage_history)
self.finished.emit(final_payload)
class TrainingTab(QWidget):
@@ -146,12 +226,23 @@ class TrainingTab(QWidget):
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
# Create a container widget for all content
container = QWidget()
container_layout = QVBoxLayout(container)
layout.addWidget(self._create_dataset_group())
layout.addWidget(self._create_training_controls_group())
layout.addStretch()
self.setLayout(layout)
container_layout.addWidget(self._create_dataset_group())
container_layout.addWidget(self._create_training_controls_group())
container_layout.addStretch()
# Create scroll area and set the container as its widget
scroll_area = QScrollArea()
scroll_area.setWidget(container)
scroll_area.setWidgetResizable(True)
# Set main layout with scroll area
main_layout = QVBoxLayout(self)
main_layout.setContentsMargins(0, 0, 0, 0)
main_layout.addWidget(scroll_area)
self._discover_datasets()
self._load_saved_dataset()
@@ -249,13 +340,26 @@ class TrainingTab(QWidget):
default_base_model = self.config_manager.get(
"models.default_base_model", "yolov8s-seg.pt"
)
base_model_choices = self.config_manager.get("models.base_model_choices", [])
self.base_model_combo = QComboBox()
self.base_model_combo.addItem("Custom path…", "")
for choice in base_model_choices:
self.base_model_combo.addItem(choice, choice)
self.base_model_combo.currentIndexChanged.connect(
self._on_base_model_preset_changed
)
form_layout.addRow("Base Model Preset:", self.base_model_combo)
base_model_layout = QHBoxLayout()
self.base_model_edit = QLineEdit(default_base_model)
self.base_model_edit.editingFinished.connect(self._on_base_model_path_edited)
base_model_layout.addWidget(self.base_model_edit)
self.base_model_browse_button = QPushButton("Browse…")
self.base_model_browse_button.clicked.connect(self._browse_base_model)
base_model_layout.addWidget(self.base_model_browse_button)
form_layout.addRow("Base Model (.pt):", base_model_layout)
self._sync_base_model_preset_selection(default_base_model)
models_dir = self.config_manager.get("models.models_directory", "data/models")
save_dir_layout = QHBoxLayout()
@@ -298,6 +402,9 @@ class TrainingTab(QWidget):
group_layout.addLayout(form_layout)
self.two_stage_group = self._create_two_stage_group(training_defaults)
group_layout.addWidget(self.two_stage_group)
button_layout = QHBoxLayout()
self.start_training_button = QPushButton("Start Training")
self.start_training_button.clicked.connect(self._start_training)
@@ -322,6 +429,134 @@ class TrainingTab(QWidget):
group.setLayout(group_layout)
return group
def _create_two_stage_group(self, training_defaults: Dict[str, Any]) -> QGroupBox:
group = QGroupBox("Two-Stage Fine-Tuning")
group_layout = QVBoxLayout()
self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune")
two_stage_defaults = (
training_defaults.get("two_stage", {}) if training_defaults else {}
)
self.two_stage_checkbox.setChecked(
bool(two_stage_defaults.get("enabled", False))
)
self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled)
group_layout.addWidget(self.two_stage_checkbox)
self.two_stage_controls_widget = QWidget()
controls_layout = QVBoxLayout()
controls_layout.setContentsMargins(0, 0, 0, 0)
controls_layout.setSpacing(8)
stage1_group = QGroupBox("Stage 1 — Head-only stabilization")
stage1_form = QFormLayout()
stage1_defaults = two_stage_defaults.get("stage1", {})
self.stage1_epochs_spin = QSpinBox()
self.stage1_epochs_spin.setRange(1, 500)
self.stage1_epochs_spin.setValue(int(stage1_defaults.get("epochs", 20)))
stage1_form.addRow("Epochs:", self.stage1_epochs_spin)
self.stage1_lr_spin = QDoubleSpinBox()
self.stage1_lr_spin.setDecimals(5)
self.stage1_lr_spin.setRange(0.00001, 0.1)
self.stage1_lr_spin.setSingleStep(0.0005)
self.stage1_lr_spin.setValue(float(stage1_defaults.get("lr0", 0.0005)))
stage1_form.addRow("Learning Rate:", self.stage1_lr_spin)
self.stage1_patience_spin = QSpinBox()
self.stage1_patience_spin.setRange(1, 200)
self.stage1_patience_spin.setValue(int(stage1_defaults.get("patience", 10)))
stage1_form.addRow("Patience:", self.stage1_patience_spin)
self.stage1_freeze_spin = QSpinBox()
self.stage1_freeze_spin.setRange(0, 24)
self.stage1_freeze_spin.setValue(int(stage1_defaults.get("freeze", 10)))
stage1_form.addRow("Freeze layers:", self.stage1_freeze_spin)
stage1_group.setLayout(stage1_form)
controls_layout.addWidget(stage1_group)
stage2_group = QGroupBox("Stage 2 — Full fine-tuning")
stage2_form = QFormLayout()
stage2_defaults = two_stage_defaults.get("stage2", {})
self.stage2_epochs_spin = QSpinBox()
self.stage2_epochs_spin.setRange(1, 2000)
self.stage2_epochs_spin.setValue(int(stage2_defaults.get("epochs", 150)))
stage2_form.addRow("Epochs:", self.stage2_epochs_spin)
self.stage2_lr_spin = QDoubleSpinBox()
self.stage2_lr_spin.setDecimals(5)
self.stage2_lr_spin.setRange(0.00001, 0.1)
self.stage2_lr_spin.setSingleStep(0.0005)
self.stage2_lr_spin.setValue(float(stage2_defaults.get("lr0", 0.0003)))
stage2_form.addRow("Learning Rate:", self.stage2_lr_spin)
self.stage2_patience_spin = QSpinBox()
self.stage2_patience_spin.setRange(1, 200)
self.stage2_patience_spin.setValue(int(stage2_defaults.get("patience", 30)))
stage2_form.addRow("Patience:", self.stage2_patience_spin)
stage2_group.setLayout(stage2_form)
controls_layout.addWidget(stage2_group)
helper_label = QLabel(
"When enabled, staged hyperparameters override the global epochs/patience/lr."
)
helper_label.setWordWrap(True)
controls_layout.addWidget(helper_label)
self.two_stage_controls_widget.setLayout(controls_layout)
group_layout.addWidget(self.two_stage_controls_widget)
group.setLayout(group_layout)
self._on_two_stage_toggled(self.two_stage_checkbox.isChecked())
return group
def _on_two_stage_toggled(self, checked: bool):
self._refresh_two_stage_controls_enabled(checked)
def _refresh_two_stage_controls_enabled(self, checked: Optional[bool] = None):
if not hasattr(self, "two_stage_controls_widget"):
return
desired_state = checked
if desired_state is None:
desired_state = self.two_stage_checkbox.isChecked()
can_edit = self.two_stage_checkbox.isEnabled()
self.two_stage_controls_widget.setEnabled(bool(desired_state and can_edit))
def _on_base_model_preset_changed(self, index: int):
preset_value = self.base_model_combo.itemData(index)
if preset_value:
self.base_model_edit.setText(str(preset_value))
elif index == 0:
self.base_model_edit.setFocus()
def _on_base_model_path_edited(self):
self._sync_base_model_preset_selection(self.base_model_edit.text().strip())
def _sync_base_model_preset_selection(self, model_path: str):
if not hasattr(self, "base_model_combo"):
return
normalized = (model_path or "").strip()
target_index = 0
for idx in range(1, self.base_model_combo.count()):
preset_value = self.base_model_combo.itemData(idx)
if not preset_value:
continue
if normalized == preset_value:
target_index = idx
break
if normalized.endswith(f"/{preset_value}") or normalized.endswith(
f"\\{preset_value}"
):
target_index = idx
break
self.base_model_combo.blockSignals(True)
self.base_model_combo.setCurrentIndex(target_index)
self.base_model_combo.blockSignals(False)
def _get_dataset_search_roots(self) -> List[Path]:
roots: List[Path] = []
default_root = Path("data/datasets").expanduser()
@@ -346,6 +581,7 @@ class TrainingTab(QWidget):
for yaml_path in root.rglob("*.yaml"):
if yaml_path.name.lower() not in {"data.yaml", "dataset.yaml"}:
continue
discovered.append(yaml_path.resolve())
except Exception as exc:
logger.warning(f"Unable to scan {root}: {exc}")
@@ -964,6 +1200,90 @@ class TrainingTab(QWidget):
self._build_rgb_dataset(cache_root, dataset_info)
return rgb_yaml
def _compose_stage_plan(self, params: Dict[str, Any]) -> List[Dict[str, Any]]:
two_stage = params.get("two_stage") or {}
base_stage = {
"label": "Single Stage",
"model_path": params["base_model"],
"use_previous_best": False,
"params": {
"epochs": params["epochs"],
"batch": params["batch"],
"imgsz": params["imgsz"],
"patience": params["patience"],
"lr0": params["lr0"],
"freeze": 0,
"name": params["run_name"],
},
}
if not two_stage.get("enabled"):
return [base_stage]
stage_plan: List[Dict[str, Any]] = []
stage1 = two_stage.get("stage1", {})
stage2 = two_stage.get("stage2", {})
stage_plan.append(
{
"label": "Stage 1 — Head-only",
"model_path": params["base_model"],
"use_previous_best": False,
"params": {
"epochs": stage1.get("epochs", params["epochs"]),
"batch": params["batch"],
"imgsz": params["imgsz"],
"patience": stage1.get("patience", params["patience"]),
"lr0": stage1.get("lr0", params["lr0"]),
"freeze": stage1.get("freeze", 0),
"name": f"{params['run_name']}_head_ft",
},
}
)
stage_plan.append(
{
"label": "Stage 2 — Full",
"model_path": params["base_model"],
"use_previous_best": True,
"params": {
"epochs": stage2.get("epochs", params["epochs"]),
"batch": params["batch"],
"imgsz": params["imgsz"],
"patience": stage2.get("patience", params["patience"]),
"lr0": stage2.get("lr0", params["lr0"]),
"freeze": stage2.get("freeze", 0),
"name": f"{params['run_name']}_full_ft",
},
}
)
return stage_plan
def _calculate_total_stage_epochs(self, stage_plan: List[Dict[str, Any]]) -> int:
total = 0
for stage in stage_plan:
params = stage.get("params") or {}
try:
stage_epochs = int(params.get("epochs", 0))
except (TypeError, ValueError):
stage_epochs = 0
if stage_epochs > 0:
total += stage_epochs
return total
def _log_stage_plan(self, stage_plan: List[Dict[str, Any]]):
for index, stage in enumerate(stage_plan, start=1):
stage_label = stage.get("label") or f"Stage {index}"
params = stage.get("params") or {}
epochs = params.get("epochs", "?")
lr0 = params.get("lr0", "?")
patience = params.get("patience", "?")
freeze = params.get("freeze", 0)
self._append_training_log(
f"{stage_label}: epochs={epochs}, lr0={lr0}, patience={patience}, freeze={freeze}"
)
def _get_rgb_cache_root(self, dataset_yaml: Path) -> Path:
cache_base = Path("data/datasets/_rgb_cache")
cache_base.mkdir(parents=True, exist_ok=True)
@@ -984,8 +1304,8 @@ class TrainingTab(QWidget):
if not sample_image:
return False
try:
with PILImage.open(sample_image) as img:
return img.mode.upper() != "RGB"
img = Image(sample_image)
return img.pil_image.mode.upper() != "RGB"
except Exception as exc:
logger.warning(f"Failed to inspect image {sample_image}: {exc}")
return False
@@ -1045,8 +1365,12 @@ class TrainingTab(QWidget):
dst = dst_dir / relative
dst.parent.mkdir(parents=True, exist_ok=True)
try:
with PILImage.open(src) as img:
rgb_img = img.convert("RGB")
img_obj = Image(src)
pil_img = img_obj.pil_image
if len(pil_img.getbands()) == 1:
rgb_img = img_obj.convert_grayscale_to_rgb_preserve_range()
else:
rgb_img = pil_img.convert("RGB")
rgb_img.save(dst)
except Exception as exc:
logger.warning(f"Failed to convert {src} to RGB: {exc}")
@@ -1085,6 +1409,21 @@ class TrainingTab(QWidget):
save_dir_path.mkdir(parents=True, exist_ok=True)
run_name = f"{model_name}_{model_version}".replace(" ", "_")
two_stage_config = {
"enabled": self.two_stage_checkbox.isChecked(),
"stage1": {
"epochs": self.stage1_epochs_spin.value(),
"lr0": self.stage1_lr_spin.value(),
"patience": self.stage1_patience_spin.value(),
"freeze": self.stage1_freeze_spin.value(),
},
"stage2": {
"epochs": self.stage2_epochs_spin.value(),
"lr0": self.stage2_lr_spin.value(),
"patience": self.stage2_patience_spin.value(),
},
}
return {
"model_name": model_name,
"model_version": model_version,
@@ -1096,6 +1435,7 @@ class TrainingTab(QWidget):
"imgsz": self.imgsz_spin.value(),
"patience": self.patience_spin.value(),
"lr0": self.lr_spin.value(),
"two_stage": two_stage_config,
}
def _start_training(self):
@@ -1137,15 +1477,25 @@ class TrainingTab(QWidget):
)
params = self._collect_training_params()
stage_plan = self._compose_stage_plan(params)
params["stage_plan"] = stage_plan
total_planned_epochs = (
self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
)
params["total_planned_epochs"] = total_planned_epochs
self._active_training_params = params
self._training_cancelled = False
if len(stage_plan) > 1:
self._append_training_log("Two-stage fine-tuning schedule:")
self._log_stage_plan(stage_plan)
self._append_training_log(
f"Starting training run '{params['run_name']}' using {params['base_model']}"
)
self.training_progress_bar.setVisible(True)
self.training_progress_bar.setMaximum(params["epochs"])
self.training_progress_bar.setMaximum(max(1, total_planned_epochs))
self.training_progress_bar.setValue(0)
self._set_training_state(True)
@@ -1159,6 +1509,8 @@ class TrainingTab(QWidget):
lr0=params["lr0"],
save_dir=params["save_dir"],
run_name=params["run_name"],
stage_plan=stage_plan,
total_epochs=total_planned_epochs,
)
self.training_worker.progress.connect(self._on_training_progress)
self.training_worker.finished.connect(self._on_training_finished)
@@ -1283,14 +1635,22 @@ class TrainingTab(QWidget):
if not model_path:
raise ValueError("Training results did not include a model path.")
effective_epochs = params.get("total_planned_epochs", params["epochs"])
training_params = {
"epochs": params["epochs"],
"epochs": effective_epochs,
"batch": params["batch"],
"imgsz": params["imgsz"],
"patience": params["patience"],
"lr0": params["lr0"],
"run_name": params["run_name"],
"two_stage": params.get("two_stage"),
}
if params.get("stage_plan"):
training_params["stage_plan"] = params["stage_plan"]
if results.get("stage_results"):
training_params["stage_results"] = results["stage_results"]
if results.get("total_epochs_completed") is not None:
training_params["epochs_completed"] = results["total_epochs_completed"]
model_id = self.db_manager.add_model(
model_name=params["model_name"],
@@ -1315,6 +1675,7 @@ class TrainingTab(QWidget):
self.rescan_button.setEnabled(not is_training)
self.model_name_edit.setEnabled(not is_training)
self.model_version_edit.setEnabled(not is_training)
self.base_model_combo.setEnabled(not is_training)
self.base_model_edit.setEnabled(not is_training)
self.base_model_browse_button.setEnabled(not is_training)
self.save_dir_edit.setEnabled(not is_training)
@@ -1324,6 +1685,8 @@ class TrainingTab(QWidget):
self.imgsz_spin.setEnabled(not is_training)
self.patience_spin.setEnabled(not is_training)
self.lr_spin.setEnabled(not is_training)
self.two_stage_checkbox.setEnabled(not is_training)
self._refresh_two_stage_controls_enabled()
def _append_training_log(self, message: str):
timestamp = datetime.now().strftime("%H:%M:%S")
@@ -1339,6 +1702,7 @@ class TrainingTab(QWidget):
)
if file_path:
self.base_model_edit.setText(file_path)
self._sync_base_model_preset_selection(file_path)
def _browse_save_dir(self):
start_path = self.save_dir_edit.text().strip() or "data/models"

View File

@@ -16,8 +16,9 @@ from PySide6.QtGui import (
QKeyEvent,
QMouseEvent,
QPaintEvent,
QPolygonF,
)
from PySide6.QtCore import Qt, QEvent, Signal, QPoint
from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QPointF, QRect
from typing import Any, Dict, List, Optional, Tuple
from src.utils.image import Image, ImageLoadError
@@ -246,10 +247,10 @@ class AnnotationCanvasWidget(QWidget):
return
try:
# Get RGB image data
if self.current_image.channels == 3:
# Get image data in a format compatible with Qt
if self.current_image.channels in (3, 4):
image_data = self.current_image.get_rgb()
height, width, channels = image_data.shape
height, width = image_data.shape[:2]
else:
image_data = self.current_image.get_grayscale()
height, width = image_data.shape
@@ -263,7 +264,7 @@ class AnnotationCanvasWidget(QWidget):
height,
bytes_per_line,
self.current_image.qtimage_format,
)
).copy() # Copy so Qt owns the buffer even after numpy array goes out of scope
self.original_pixmap = QPixmap.fromImage(qimage)
@@ -496,8 +497,10 @@ class AnnotationCanvasWidget(QWidget):
)
painter.setPen(pen)
for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]):
painter.drawLine(int(x1), int(y1), int(x2), int(y2))
# Use QPolygonF for efficient polygon rendering (single call vs N-1 calls)
# drawPolygon() automatically closes the shape, ensuring proper visual closure
polygon = QPolygonF([QPointF(x, y) for x, y in polyline])
painter.drawPolygon(polygon)
# Draw bounding boxes (dashed) if enabled
if self.show_bboxes and self.original_pixmap is not None and self.bboxes:
@@ -529,6 +532,40 @@ class AnnotationCanvasWidget(QWidget):
painter.setPen(pen)
painter.drawRect(x_min, y_min, rect_width, rect_height)
label_text = meta.get("label")
if label_text:
painter.save()
font = painter.font()
font.setPointSizeF(max(10.0, width + 4))
painter.setFont(font)
metrics = painter.fontMetrics()
text_width = metrics.horizontalAdvance(label_text)
text_height = metrics.height()
padding = 4
bg_width = text_width + padding * 2
bg_height = text_height + padding * 2
canvas_width = self.original_pixmap.width()
canvas_height = self.original_pixmap.height()
bg_x = max(0, min(x_min, canvas_width - bg_width))
bg_y = y_min - bg_height
if bg_y < 0:
bg_y = min(y_min, canvas_height - bg_height)
bg_y = max(0, bg_y)
background_rect = QRect(bg_x, bg_y, bg_width, bg_height)
background_color = QColor(pen_color)
background_color.setAlpha(220)
painter.fillRect(background_rect, background_color)
text_color = QColor(0, 0, 0)
if background_color.lightness() < 128:
text_color = QColor(255, 255, 255)
painter.setPen(text_color)
painter.drawText(
background_rect.adjusted(padding, padding, -padding, -padding),
Qt.AlignLeft | Qt.AlignVCenter,
label_text,
)
painter.restore()
painter.end()
self._update_display()
@@ -787,7 +824,13 @@ class AnnotationCanvasWidget(QWidget):
f"Drew saved polyline with {len(polyline)} points in color {color}"
)
def draw_saved_bbox(self, bbox: List[float], color: str, width: int = 3):
def draw_saved_bbox(
self,
bbox: List[float],
color: str,
width: int = 3,
label: Optional[str] = None,
):
"""
Draw a bounding box from database coordinates onto the annotation canvas.
@@ -796,6 +839,7 @@ class AnnotationCanvasWidget(QWidget):
in normalized coordinates (0-1)
color: Color hex string (e.g., '#FF0000')
width: Line width in pixels
label: Optional text label to render near the bounding box
"""
if not self.annotation_pixmap or not self.original_pixmap:
logger.warning("Cannot draw bounding box: no image loaded")
@@ -828,11 +872,11 @@ class AnnotationCanvasWidget(QWidget):
self.bboxes.append(
[float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)]
)
self.bbox_meta.append({"color": pen_color, "width": int(width)})
self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label})
# Store in all_strokes for consistency
self.all_strokes.append(
{"bbox": bbox, "color": color, "alpha": 128, "width": width}
{"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label}
)
# Redraw overlay (polylines + all bounding boxes)

View File

@@ -137,7 +137,7 @@ class ImageDisplayWidget(QWidget):
height,
bytes_per_line,
self.current_image.qtimage_format,
)
).copy() # Copy to ensure Qt owns its memory after this scope
# Convert to pixmap
pixmap = QPixmap.fromImage(qimage)

View File

@@ -5,12 +5,12 @@ Handles detection inference and result storage.
from typing import List, Dict, Optional, Callable
from pathlib import Path
from PIL import Image
import cv2
import numpy as np
from src.model.yolo_wrapper import YOLOWrapper
from src.database.db_manager import DatabaseManager
from src.utils.image import Image
from src.utils.logger import get_logger
from src.utils.file_utils import get_relative_path
@@ -42,6 +42,7 @@ class InferenceEngine:
relative_path: str,
conf: float = 0.25,
save_to_db: bool = True,
repository_root: Optional[str] = None,
) -> Dict:
"""
Detect objects in a single image.
@@ -51,29 +52,42 @@ class InferenceEngine:
relative_path: Relative path from repository root
conf: Confidence threshold
save_to_db: Whether to save results to database
repository_root: Base directory used to compute relative_path (if known)
Returns:
Dictionary with detection results
"""
try:
# Normalize storage path (fall back to absolute path when repo root is unknown)
stored_relative_path = relative_path
if not repository_root:
stored_relative_path = str(Path(image_path).resolve())
# Get image dimensions
img = Image.open(image_path)
width, height = img.size
img.close()
img = Image(image_path)
width = img.width
height = img.height
# Perform detection
detections = self.yolo.predict(image_path, conf=conf)
# Add/get image in database
image_id = self.db_manager.get_or_create_image(
relative_path=relative_path,
relative_path=stored_relative_path,
filename=Path(image_path).name,
width=width,
height=height,
)
# Save detections to database
if save_to_db and detections:
inserted_count = 0
deleted_count = 0
# Save detections to database, replacing any previous results for this image/model
if save_to_db:
deleted_count = self.db_manager.delete_detections_for_image(
image_id, self.model_id
)
if detections:
detection_records = []
for det in detections:
# Use normalized bbox from detection
@@ -81,6 +95,15 @@ class InferenceEngine:
"bbox_normalized"
] # [x_min, y_min, x_max, y_max]
metadata = {
"class_id": det["class_id"],
"source_path": str(Path(image_path).resolve()),
}
if repository_root:
metadata["repository_root"] = str(
Path(repository_root).resolve()
)
record = {
"image_id": image_id,
"model_id": self.model_id,
@@ -88,12 +111,20 @@ class InferenceEngine:
"bbox": tuple(bbox_normalized),
"confidence": det["confidence"],
"segmentation_mask": det.get("segmentation_mask"),
"metadata": {"class_id": det["class_id"]},
"metadata": metadata,
}
detection_records.append(record)
self.db_manager.add_detections_batch(detection_records)
logger.info(f"Saved {len(detection_records)} detections to database")
inserted_count = self.db_manager.add_detections_batch(
detection_records
)
logger.info(
f"Saved {inserted_count} detections to database (replaced {deleted_count})"
)
else:
logger.info(
f"Detection run removed {deleted_count} stale entries but produced no new detections"
)
return {
"success": True,
@@ -142,7 +173,12 @@ class InferenceEngine:
rel_path = get_relative_path(image_path, repository_root)
# Perform detection
result = self.detect_single(image_path, rel_path, conf)
result = self.detect_single(
image_path,
rel_path,
conf=conf,
repository_root=repository_root,
)
results.append(result)
# Update progress

View File

@@ -7,6 +7,9 @@ from ultralytics import YOLO
from pathlib import Path
from typing import Optional, List, Dict, Callable, Any
import torch
import tempfile
import os
from src.utils.image import Image
from src.utils.logger import get_logger
@@ -77,7 +80,8 @@ class YOLOWrapper:
Dictionary with training results
"""
if self.model is None:
self.load_model()
if not self.load_model():
raise RuntimeError(f"Failed to load model from {self.model_path}")
try:
logger.info(f"Starting training: {name}")
@@ -119,7 +123,8 @@ class YOLOWrapper:
Dictionary with validation metrics
"""
if self.model is None:
self.load_model()
if not self.load_model():
raise RuntimeError(f"Failed to load model from {self.model_path}")
try:
logger.info(f"Starting validation on {split} split")
@@ -160,12 +165,15 @@ class YOLOWrapper:
List of detection dictionaries
"""
if self.model is None:
self.load_model()
if not self.load_model():
raise RuntimeError(f"Failed to load model from {self.model_path}")
prepared_source, cleanup_path = self._prepare_source(source)
try:
logger.info(f"Running inference on {source}")
results = self.model.predict(
source=source,
source=prepared_source,
conf=conf,
iou=iou,
save=save,
@@ -182,6 +190,14 @@ class YOLOWrapper:
except Exception as e:
logger.error(f"Error during inference: {e}")
raise
finally:
if 0: # cleanup_path:
try:
os.remove(cleanup_path)
except OSError as cleanup_error:
logger.warning(
f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}"
)
def export(
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
@@ -198,7 +214,8 @@ class YOLOWrapper:
Path to exported model
"""
if self.model is None:
self.load_model()
if not self.load_model():
raise RuntimeError(f"Failed to load model from {self.model_path}")
try:
logger.info(f"Exporting model to {format} format")
@@ -210,6 +227,38 @@ class YOLOWrapper:
logger.error(f"Error exporting model: {e}")
raise
def _prepare_source(self, source):
"""Convert single-channel images to RGB temporarily for inference."""
cleanup_path = None
if isinstance(source, (str, Path)):
source_path = Path(source)
if source_path.is_file():
try:
img_obj = Image(source_path)
pil_img = img_obj.pil_image
if len(pil_img.getbands()) == 1:
rgb_img = img_obj.convert_grayscale_to_rgb_preserve_range()
else:
rgb_img = pil_img.convert("RGB")
suffix = source_path.suffix or ".png"
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
tmp_path = tmp.name
tmp.close()
rgb_img.save(tmp_path)
cleanup_path = tmp_path
logger.info(
f"Converted image {source_path} to RGB for inference at {tmp_path}"
)
return tmp_path, cleanup_path
except Exception as convert_error:
logger.warning(
f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}"
)
return source, cleanup_path
def _format_training_results(self, results) -> Dict[str, Any]:
"""Format training results into dictionary."""
try:

View File

@@ -7,6 +7,7 @@ import yaml
from pathlib import Path
from typing import Any, Dict, Optional
from src.utils.logger import get_logger
from src.utils.image import Image
logger = get_logger(__name__)
@@ -46,18 +47,15 @@ class ConfigManager:
"database": {"path": "data/detections.db"},
"image_repository": {
"base_path": "",
"allowed_extensions": [
".jpg",
".jpeg",
".png",
".tif",
".tiff",
".bmp",
],
"allowed_extensions": Image.SUPPORTED_EXTENSIONS,
},
"models": {
"default_base_model": "yolov8s-seg.pt",
"models_directory": "data/models",
"base_model_choices": [
"yolov8s-seg.pt",
"yolov11s-seg.pt",
],
},
"training": {
"default_epochs": 100,
@@ -65,6 +63,20 @@ class ConfigManager:
"default_imgsz": 640,
"default_patience": 50,
"default_lr0": 0.01,
"two_stage": {
"enabled": False,
"stage1": {
"epochs": 20,
"lr0": 0.0005,
"patience": 10,
"freeze": 10,
},
"stage2": {
"epochs": 150,
"lr0": 0.0003,
"patience": 30,
},
},
},
"detection": {
"default_confidence": 0.25,
@@ -214,5 +226,5 @@ class ConfigManager:
def get_allowed_extensions(self) -> list:
"""Get list of allowed image file extensions."""
return self.get(
"image_repository.allowed_extensions", [".jpg", ".jpeg", ".png"]
"image_repository.allowed_extensions", Image.SUPPORTED_EXTENSIONS
)

View File

@@ -28,7 +28,9 @@ def get_image_files(
List of absolute paths to image files
"""
if allowed_extensions is None:
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
from src.utils.image import Image
allowed_extensions = Image.SUPPORTED_EXTENSIONS
# Normalize extensions to lowercase
allowed_extensions = [ext.lower() for ext in allowed_extensions]
@@ -204,7 +206,9 @@ def is_image_file(
True if file is an image
"""
if allowed_extensions is None:
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
from src.utils.image import Image
allowed_extensions = Image.SUPPORTED_EXTENSIONS
extension = Path(file_path).suffix.lower()
return extension in [ext.lower() for ext in allowed_extensions]

View File

@@ -277,6 +277,38 @@ class Image:
"""
return self._channels >= 3
def convert_grayscale_to_rgb_preserve_range(
self,
) -> PILImage.Image:
"""Convert a single-channel PIL image to RGB while preserving dynamic range.
Returns:
PIL Image in RGB mode with intensities normalized to 0-255.
"""
if self._channels == 3:
return self.pil_image
grayscale = self.data
if grayscale.ndim == 3:
grayscale = grayscale[:, :, 0]
original_dtype = grayscale.dtype
grayscale = grayscale.astype(np.float32)
if grayscale.size == 0:
return PILImage.new("RGB", self.shape, color=(0, 0, 0))
if np.issubdtype(original_dtype, np.integer):
denom = float(max(np.iinfo(original_dtype).max, 1))
else:
max_val = float(grayscale.max())
denom = max(max_val, 1.0)
grayscale = np.clip(grayscale / denom, 0.0, 1.0)
grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8)
rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2)
return PILImage.fromarray(rgb_arr, mode="RGB")
def __repr__(self) -> str:
"""String representation of the Image object."""
return (

View File

@@ -0,0 +1,160 @@
import numpy as np
from roifile import ImagejRoi
from tifffile import TiffFile, TiffWriter
from pathlib import Path
class UT:
"""
Docstring for UT
Operetta files along with rois drawn in ImageJ
"""
def __init__(self, roifile_fn: Path, no_labels: bool):
self.roifile_fn = roifile_fn
print("is file", self.roifile_fn.is_file())
self.rois = None
if no_labels:
self.rois = ImagejRoi.fromfile(self.roifile_fn)
self.stem = self.roifile_fn.stem.split("Roi-")[1]
else:
self.roifile_fn = roifile_fn / roifile_fn.parts[-1]
self.stem = self.roifile_fn.stem
print(self.roifile_fn)
print(self.stem)
self.image, self.image_props = self._load_images()
def _load_images(self):
"""Loading sequence of tif files
array sequence is CZYX
"""
print("Loading images:", self.roifile_fn.parent, self.stem)
fns = list(self.roifile_fn.parent.glob(f"{self.stem.lower()}*.tif*"))
stems = [fn.stem.split(self.stem)[-1] for fn in fns]
n_ch = len(set([stem.split("-ch")[-1].split("t")[0] for stem in stems]))
n_p = len(set([stem.split("-")[0] for stem in stems]))
n_t = len(set([stem.split("t")[1] for stem in stems]))
with TiffFile(fns[0]) as tif:
img = tif.asarray()
w, h = img.shape
dtype = img.dtype
self.image_props = {
"channels": n_ch,
"planes": n_p,
"tiles": n_t,
"width": w,
"height": h,
"dtype": dtype,
}
print("Image props", self.image_props)
image_stack = np.zeros((n_ch, n_p, w, h), dtype=dtype)
for fn in fns:
with TiffFile(fn) as tif:
img = tif.asarray()
stem = fn.stem.split(self.stem)[-1]
ch = int(stem.split("-ch")[-1].split("t")[0])
p = int(stem.split("-")[0].split("p")[1])
t = int(stem.split("t")[1])
print(fn.stem, "ch", ch, "p", p, "t", t)
image_stack[ch - 1, p - 1] = img
print(image_stack.shape)
return image_stack, self.image_props
@property
def width(self):
return self.image_props["width"]
@property
def height(self):
return self.image_props["height"]
@property
def nchannels(self):
return self.image_props["channels"]
@property
def nplanes(self):
return self.image_props["planes"]
def export_rois(
self,
path: Path,
subfolder: str = "labels",
class_index: int = 0,
):
"""Export rois to a file"""
with open(path / subfolder / f"{self.stem}.txt", "w") as f:
for i, roi in enumerate(self.rois):
rc = roi.subpixel_coordinates
if rc is None:
print(
f"No coordinates: {self.roifile_fn}, element {i}, out of {len(self.rois)}"
)
continue
xmn, ymn = rc.min(axis=0)
xmx, ymx = rc.max(axis=0)
xc = (xmn + xmx) / 2
yc = (ymn + ymx) / 2
bw = xmx - xmn
bh = ymx - ymn
coords = f"{xc/self.width} {yc/self.height} {bw/self.width} {bh/self.height} "
for x, y in rc:
coords += f"{x/self.width} {y/self.height} "
f.write(f"{class_index} {coords}\n")
return
def export_image(
self,
path: Path,
subfolder: str = "images",
plane_mode: str = "max projection",
channel: int = 0,
):
"""Export image to a file"""
if plane_mode == "max projection":
self.image = np.max(self.image[channel], axis=0)
print(self.image.shape)
print(path / subfolder / f"{self.stem}.tif")
with TiffWriter(path / subfolder / f"{self.stem}.tif") as tif:
tif.write(self.image)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input", nargs="*", type=Path)
parser.add_argument("-o", "--output", type=Path)
parser.add_argument(
"--no-labels",
action="store_false",
help="Source does not have labels, export only images",
)
args = parser.parse_args()
for path in args.input:
print("Path:", path)
if not args.no_labels:
print("No labels")
ut = UT(path, args.no_labels)
ut.export_image(args.output, plane_mode="max projection", channel=0)
else:
for rfn in Path(path).glob("*.zip"):
print("Roi FN:", rfn)
ut = UT(rfn, args.no_labels)
ut.export_rois(args.output, class_index=0)
ut.export_image(args.output, plane_mode="max projection", channel=0)
print()

184
tests/show_yolo_seg.py Normal file
View File

@@ -0,0 +1,184 @@
#!/usr/bin/env python3
"""
show_yolo_seg.py
Usage:
python show_yolo_seg.py /path/to/image.jpg /path/to/labels.txt
Supports:
- Segmentation polygons: "class x1 y1 x2 y2 ... xn yn"
- YOLO bbox lines as fallback: "class x_center y_center width height"
Coordinates can be normalized [0..1] or absolute pixels (auto-detected).
"""
import sys
import cv2
import numpy as np
import matplotlib.pyplot as plt
import argparse
from pathlib import Path
import random
def parse_label_line(line):
parts = line.strip().split()
if not parts:
return None
cls = int(float(parts[0]))
coords = [float(x) for x in parts[1:]]
return cls, coords
def coords_are_normalized(coords):
# If every coordinate is between 0 and 1 (inclusive-ish), assume normalized
if not coords:
return False
return max(coords) <= 1.001
def yolo_bbox_to_xyxy(coords, img_w, img_h):
# coords: [xc, yc, w, h] normalized or absolute
xc, yc, w, h = coords[:4]
if max(coords) <= 1.001:
xc *= img_w
yc *= img_h
w *= img_w
h *= img_h
x1 = int(round(xc - w / 2))
y1 = int(round(yc - h / 2))
x2 = int(round(xc + w / 2))
y2 = int(round(yc + h / 2))
return x1, y1, x2, y2
def poly_to_pts(coords, img_w, img_h):
# coords: [x1 y1 x2 y2 ...] either normalized or absolute
if coords_are_normalized(coords[4:]):
coords = [
coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))
]
pts = np.array(coords, dtype=np.int32).reshape(-1, 2)
return pts
def random_color_for_class(cls):
random.seed(cls) # deterministic per class
return tuple(int(x) for x in np.array([random.randint(0, 255) for _ in range(3)]))
def draw_annotations(img, labels, alpha=0.4, draw_bbox_for_poly=True):
# img: BGR numpy array
overlay = img.copy()
h, w = img.shape[:2]
for cls, coords in labels:
if not coords:
continue
# polygon case (>=6 coordinates)
if len(coords) >= 6:
color = random_color_for_class(cls)
x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords[:4], w, h)
cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
pts = poly_to_pts(coords[4:], w, h)
# fill on overlay
cv2.fillPoly(overlay, [pts], color)
# outline on base image
cv2.polylines(img, [pts], isClosed=True, color=color, thickness=2)
# put class text at first point
x, y = int(pts[0, 0]), int(pts[0, 1]) - 6
cv2.putText(
img,
str(cls),
(x, max(6, y)),
cv2.FONT_HERSHEY_SIMPLEX,
0.6,
(255, 255, 255),
2,
cv2.LINE_AA,
)
# YOLO bbox case (4 coords)
elif len(coords) == 4:
x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords, w, h)
color = random_color_for_class(cls)
cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
cv2.putText(
img,
str(cls),
(x1, max(6, y1 - 4)),
cv2.FONT_HERSHEY_SIMPLEX,
0.6,
(255, 255, 255),
2,
cv2.LINE_AA,
)
else:
# Unknown / invalid format, skip
continue
# blend overlay for filled polygons
cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, img)
return img
def load_labels_file(label_path):
labels = []
with open(label_path, "r") as f:
for raw in f:
line = raw.strip()
if not line:
continue
parsed = parse_label_line(line)
if parsed:
labels.append(parsed)
return labels
def main():
parser = argparse.ArgumentParser(
description="Show YOLO segmentation / polygon annotations"
)
parser.add_argument("image", type=str, help="Path to image file")
parser.add_argument("labels", type=str, help="Path to YOLO label file (polygons)")
parser.add_argument(
"--alpha", type=float, default=0.4, help="Polygon fill alpha (0..1)"
)
parser.add_argument(
"--no-bbox", action="store_true", help="Don't draw bounding boxes for polygons"
)
args = parser.parse_args()
img_path = Path(args.image)
lbl_path = Path(args.labels)
if not img_path.exists():
print("Image not found:", img_path)
sys.exit(1)
if not lbl_path.exists():
print("Label file not found:", lbl_path)
sys.exit(1)
img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
if img is None:
print("Could not load image:", img_path)
sys.exit(1)
labels = load_labels_file(str(lbl_path))
if not labels:
print("No labels parsed from", lbl_path)
# continue and just show image
out = draw_annotations(
img.copy(), labels, alpha=args.alpha, draw_bbox_for_poly=(not args.no_bbox)
)
# Convert BGR -> RGB for matplotlib display
out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
plt.imshow(out_rgb)
plt.axis("off")
plt.title(f"{img_path.name} ({lbl_path.name})")
plt.show()
if __name__ == "__main__":
main()

View File

@@ -27,7 +27,7 @@ class TestImage:
def test_supported_extensions(self):
"""Test that supported extensions are correctly defined."""
expected_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
expected_extensions = Image.SUPPORTED_EXTENSIONS
assert Image.SUPPORTED_EXTENSIONS == expected_extensions
def test_image_properties(self, tmp_path):