18 Commits

Author SHA1 Message Date
dd99a0677c Updating image converter and aading simple script to visulaize segmentation 2025-12-16 11:27:38 +02:00
9c4c39fb39 Adding image converter 2025-12-12 23:52:34 +02:00
20a87c9040 Updating config 2025-12-12 21:51:12 +02:00
9f7d2be1ac Updating the base model preset 2025-12-11 23:27:02 +02:00
dbde07c0e8 Making training tab scrollable 2025-12-11 23:12:39 +02:00
b3c5a51dbb Using QPolygonF instead of drawLine 2025-12-11 17:14:07 +02:00
9a221acb63 Making image manipulations thru one class 2025-12-11 16:59:56 +02:00
32a6a122bd Fixing circular import 2025-12-11 16:06:39 +02:00
9ba44043ef Defining image extensions only in one place 2025-12-11 15:50:14 +02:00
8eb1cc8c86 Fixing grayscale conversion 2025-12-11 15:15:38 +02:00
e4ce882a18 Grayscale RGB conversion modified 2025-12-11 15:06:59 +02:00
6b6d6fad03 2Stage training fix 2025-12-11 12:50:34 +02:00
c0684a9c14 Implementing 2 stage training 2025-12-11 12:04:08 +02:00
221c80aa8c Small image showing fix 2025-12-11 11:20:20 +02:00
833b222fad Adding result shower 2025-12-10 16:55:28 +02:00
5370d31dce Merge pull request 'Update training' (#2) from training into main
Reviewed-on: #2
2025-12-10 15:47:00 +02:00
5d196c3a4a Update training 2025-12-10 15:46:26 +02:00
f719c7ec40 Merge pull request 'segmentation' (#1) from segmentation into main
Reviewed-on: #1
2025-12-10 12:08:54 +02:00
17 changed files with 3056 additions and 122 deletions

View File

@@ -12,12 +12,28 @@ image_repository:
models:
default_base_model: yolov8s-seg.pt
models_directory: data/models
base_model_choices:
- yolov8s-seg.pt
- yolo11s-seg.pt
training:
default_epochs: 100
default_batch_size: 16
default_imgsz: 640
default_imgsz: 1024
default_patience: 50
default_lr0: 0.01
two_stage:
enabled: false
stage1:
epochs: 20
lr0: 0.0005
patience: 10
freeze: 10
stage2:
epochs: 150
lr0: 0.0003
patience: 30
last_dataset_yaml: /home/martin/code/object_detection/data/datasets/data.yaml
last_dataset_dir: /home/martin/code/object_detection/data/datasets
detection:
default_confidence: 0.25
default_iou: 0.45

View File

@@ -10,6 +10,14 @@ from typing import List, Dict, Optional, Tuple, Any, Union
from pathlib import Path
import csv
import hashlib
import yaml
from src.utils.logger import get_logger
from src.utils.image import Image
IMAGE_EXTENSIONS = tuple(Image.SUPPORTED_EXTENSIONS)
logger = get_logger(__name__)
class DatabaseManager:
@@ -443,6 +451,25 @@ class DatabaseManager:
filters["model_id"] = model_id
return self.get_detections(filters)
def delete_detections_for_image(
self, image_id: int, model_id: Optional[int] = None
) -> int:
"""Delete detections tied to a specific image and optional model."""
conn = self.get_connection()
try:
cursor = conn.cursor()
if model_id is not None:
cursor.execute(
"DELETE FROM detections WHERE image_id = ? AND model_id = ?",
(image_id, model_id),
)
else:
cursor.execute("DELETE FROM detections WHERE image_id = ?", (image_id,))
conn.commit()
return cursor.rowcount
finally:
conn.close()
def delete_detections_for_model(self, model_id: int) -> int:
"""Delete all detections for a specific model."""
conn = self.get_connection()
@@ -861,6 +888,187 @@ class DatabaseManager:
finally:
conn.close()
# ==================== Dataset Utilities ====================
def compose_data_yaml(
self,
dataset_root: str,
output_path: Optional[str] = None,
splits: Optional[Dict[str, str]] = None,
) -> str:
"""
Compose a YOLO data.yaml file based on dataset folders and database metadata.
Args:
dataset_root: Base directory containing the dataset structure.
output_path: Optional output path; defaults to <dataset_root>/data.yaml.
splits: Optional mapping overriding train/val/test image directories (relative
to dataset_root or absolute paths).
Returns:
Path to the generated YAML file.
"""
dataset_root_path = Path(dataset_root).expanduser()
if not dataset_root_path.exists():
raise ValueError(f"Dataset root does not exist: {dataset_root_path}")
dataset_root_path = dataset_root_path.resolve()
split_map: Dict[str, str] = {key: "" for key in ("train", "val", "test")}
if splits:
for key, value in splits.items():
if key in split_map and value:
split_map[key] = value
inferred = self._infer_split_dirs(dataset_root_path)
for key in split_map:
if not split_map[key]:
split_map[key] = inferred.get(key, "")
for required in ("train", "val"):
if not split_map[required]:
raise ValueError(
"Unable to determine %s image directory under %s. Provide it "
"explicitly via the 'splits' argument."
% (required, dataset_root_path)
)
yaml_splits: Dict[str, str] = {}
for key, value in split_map.items():
if not value:
continue
yaml_splits[key] = self._normalize_split_value(value, dataset_root_path)
class_names = self._fetch_annotation_class_names()
if not class_names:
class_names = [cls["class_name"] for cls in self.get_object_classes()]
if not class_names:
raise ValueError("No object classes available to populate data.yaml")
names_map = {idx: name for idx, name in enumerate(class_names)}
payload: Dict[str, Any] = {
"path": dataset_root_path.as_posix(),
"train": yaml_splits["train"],
"val": yaml_splits["val"],
"names": names_map,
"nc": len(class_names),
}
if yaml_splits.get("test"):
payload["test"] = yaml_splits["test"]
output_path_obj = (
Path(output_path).expanduser()
if output_path
else dataset_root_path / "data.yaml"
)
output_path_obj.parent.mkdir(parents=True, exist_ok=True)
with open(output_path_obj, "w", encoding="utf-8") as handle:
yaml.safe_dump(payload, handle, sort_keys=False)
logger.info(f"Generated data.yaml at {output_path_obj}")
return output_path_obj.as_posix()
def _fetch_annotation_class_names(self) -> List[str]:
"""Return class names referenced by annotations (ordered by class ID)."""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"""
SELECT DISTINCT c.id, c.class_name
FROM annotations a
JOIN object_classes c ON a.class_id = c.id
ORDER BY c.id
"""
)
rows = cursor.fetchall()
return [row["class_name"] for row in rows]
finally:
conn.close()
def _infer_split_dirs(self, dataset_root: Path) -> Dict[str, str]:
"""Infer train/val/test image directories relative to dataset_root."""
patterns = {
"train": [
"train/images",
"training/images",
"images/train",
"images/training",
"train",
"training",
],
"val": [
"val/images",
"validation/images",
"images/val",
"images/validation",
"val",
"validation",
],
"test": [
"test/images",
"testing/images",
"images/test",
"images/testing",
"test",
"testing",
],
}
inferred: Dict[str, str] = {key: "" for key in patterns}
for split_name, options in patterns.items():
for relative in options:
candidate = (dataset_root / relative).resolve()
if (
candidate.exists()
and candidate.is_dir()
and self._directory_has_images(candidate)
):
try:
inferred[split_name] = candidate.relative_to(
dataset_root
).as_posix()
except ValueError:
inferred[split_name] = candidate.as_posix()
break
return inferred
def _normalize_split_value(self, split_value: str, dataset_root: Path) -> str:
"""Validate and normalize a split directory to a YAML-friendly string."""
split_path = Path(split_value).expanduser()
if not split_path.is_absolute():
split_path = (dataset_root / split_path).resolve()
else:
split_path = split_path.resolve()
if not split_path.exists() or not split_path.is_dir():
raise ValueError(f"Split directory not found: {split_path}")
if not self._directory_has_images(split_path):
raise ValueError(f"No images found under {split_path}")
try:
return split_path.relative_to(dataset_root).as_posix()
except ValueError:
return split_path.as_posix()
@staticmethod
def _directory_has_images(directory: Path, max_checks: int = 2000) -> bool:
"""Return True if directory tree contains at least one image file."""
checked = 0
try:
for file_path in directory.rglob("*"):
if not file_path.is_file():
continue
if file_path.suffix.lower() in IMAGE_EXTENSIONS:
return True
checked += 1
if checked >= max_checks:
break
except Exception:
return False
return False
@staticmethod
def calculate_checksum(file_path: str) -> str:
"""Calculate MD5 checksum of a file."""

View File

@@ -297,7 +297,9 @@ class MainWindow(QMainWindow):
# Save window state before closing
self._save_window_state()
# Save annotation tab state if it exists
# Persist tab state and stop background work before exit
if hasattr(self, "training_tab"):
self.training_tab.shutdown()
if hasattr(self, "annotation_tab"):
self.annotation_tab.save_state()

View File

@@ -168,7 +168,7 @@ class AnnotationTab(QWidget):
self,
"Select Image",
start_dir,
"Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
"Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
)
if not file_path:

View File

@@ -20,12 +20,14 @@ from PySide6.QtWidgets import (
)
from PySide6.QtCore import Qt, QThread, Signal
from pathlib import Path
from typing import Optional
from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
from src.utils.file_utils import get_image_files
from src.model.inference import InferenceEngine
from src.utils.image import Image
logger = get_logger(__name__)
@@ -147,30 +149,66 @@ class DetectionTab(QWidget):
self.model_combo.currentIndexChanged.connect(self._on_model_changed)
def _load_models(self):
"""Load available models from database."""
"""Load available models from database and local storage."""
try:
models = self.db_manager.get_models()
self.model_combo.clear()
models = self.db_manager.get_models()
has_models = False
if not models:
self.model_combo.addItem("No models available", None)
self._set_buttons_enabled(False)
return
known_paths = set()
# Add base model option
# Add base model option first (always available)
base_model = self.config_manager.get(
"models.default_base_model", "yolov8s-seg.pt"
)
self.model_combo.addItem(
f"Base Model ({base_model})", {"id": 0, "path": base_model}
)
if base_model:
base_data = {
"id": 0,
"path": base_model,
"model_name": Path(base_model).stem or "Base Model",
"model_version": "pretrained",
"base_model": base_model,
"source": "base",
}
self.model_combo.addItem(f"Base Model ({base_model})", base_data)
known_paths.add(self._normalize_model_path(base_model))
has_models = True
# Add trained models
# Add trained models from database
for model in models:
display_name = f"{model['model_name']} v{model['model_version']}"
self.model_combo.addItem(display_name, model)
model_data = {**model, "path": model.get("model_path")}
normalized = self._normalize_model_path(model_data.get("path"))
if normalized:
known_paths.add(normalized)
self.model_combo.addItem(display_name, model_data)
has_models = True
self._set_buttons_enabled(True)
# Discover local model files not yet in the database
local_models = self._discover_local_models()
for model_path in local_models:
normalized = self._normalize_model_path(model_path)
if normalized in known_paths:
continue
display_name = f"Local Model ({Path(model_path).stem})"
model_data = {
"id": None,
"path": str(model_path),
"model_name": Path(model_path).stem,
"model_version": "local",
"base_model": Path(model_path).stem,
"source": "local",
}
self.model_combo.addItem(display_name, model_data)
known_paths.add(normalized)
has_models = True
if not has_models:
self.model_combo.addItem("No models available", None)
self._set_buttons_enabled(False)
else:
self._set_buttons_enabled(True)
except Exception as e:
logger.error(f"Error loading models: {e}")
@@ -199,7 +237,7 @@ class DetectionTab(QWidget):
self,
"Select Image",
start_dir,
"Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
"Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
)
if not file_path:
@@ -249,25 +287,39 @@ class DetectionTab(QWidget):
QMessageBox.warning(self, "No Model", "Please select a model first.")
return
model_path = model_data["path"]
model_id = model_data["id"]
model_path = model_data.get("path")
if not model_path:
QMessageBox.warning(
self, "Invalid Model", "Selected model is missing a file path."
)
return
# Ensure we have a valid model ID (create entry for base model if needed)
if model_id == 0:
# Create database entry for base model
base_model = self.config_manager.get(
"models.default_base_model", "yolov8s-seg.pt"
)
model_id = self.db_manager.add_model(
model_name="Base Model",
model_version="pretrained",
model_path=base_model,
base_model=base_model,
if not Path(model_path).exists():
QMessageBox.critical(
self,
"Model Not Found",
f"The selected model file could not be found:\n{model_path}",
)
return
model_id = model_data.get("id")
# Ensure we have a database entry for the selected model
if model_id in (None, 0):
model_id = self._ensure_model_record(model_data)
if not model_id:
QMessageBox.critical(
self,
"Model Registration Failed",
"Unable to register the selected model in the database.",
)
return
normalized_model_path = self._normalize_model_path(model_path) or model_path
# Create inference engine
self.inference_engine = InferenceEngine(
model_path, self.db_manager, model_id
normalized_model_path, self.db_manager, model_id
)
# Get confidence threshold
@@ -338,6 +390,76 @@ class DetectionTab(QWidget):
self.batch_btn.setEnabled(enabled)
self.model_combo.setEnabled(enabled)
def _discover_local_models(self) -> list:
"""Scan the models directory for standalone .pt files."""
models_dir = self.config_manager.get_models_directory()
if not models_dir:
return []
models_path = Path(models_dir)
if not models_path.exists():
return []
try:
return sorted(
[p for p in models_path.rglob("*.pt") if p.is_file()],
key=lambda p: str(p).lower(),
)
except Exception as e:
logger.warning(f"Error discovering local models: {e}")
return []
def _normalize_model_path(self, path_value) -> str:
"""Return a normalized absolute path string for comparison."""
if not path_value:
return ""
try:
return str(Path(path_value).resolve())
except Exception:
return str(path_value)
def _ensure_model_record(self, model_data: dict) -> Optional[int]:
"""Ensure a database record exists for the selected model."""
model_path = model_data.get("path")
if not model_path:
return None
normalized_target = self._normalize_model_path(model_path)
try:
existing_models = self.db_manager.get_models()
for model in existing_models:
existing_path = model.get("model_path")
if not existing_path:
continue
normalized_existing = self._normalize_model_path(existing_path)
if (
normalized_existing == normalized_target
or existing_path == model_path
):
return model["id"]
model_name = (
model_data.get("model_name") or Path(model_path).stem or "Custom Model"
)
model_version = (
model_data.get("model_version") or model_data.get("source") or "local"
)
base_model = model_data.get(
"base_model",
self.config_manager.get("models.default_base_model", "yolov8s-seg.pt"),
)
return self.db_manager.add_model(
model_name=model_name,
model_version=model_version,
model_path=normalized_target,
base_model=base_model,
)
except Exception as e:
logger.error(f"Failed to ensure model record for {model_path}: {e}")
return None
def refresh(self):
"""Refresh the tab."""
self._load_models()

View File

@@ -1,15 +1,39 @@
"""
Results tab for the microscopy object detection application.
Results tab for browsing stored detections and visualizing overlays.
"""
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
from pathlib import Path
from typing import Dict, List, Optional
from PySide6.QtWidgets import (
QWidget,
QVBoxLayout,
QHBoxLayout,
QLabel,
QGroupBox,
QPushButton,
QSplitter,
QTableWidget,
QTableWidgetItem,
QHeaderView,
QAbstractItemView,
QMessageBox,
QCheckBox,
)
from PySide6.QtCore import Qt
from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
from src.utils.image import Image, ImageLoadError
from src.gui.widgets import AnnotationCanvasWidget
logger = get_logger(__name__)
class ResultsTab(QWidget):
"""Results tab placeholder."""
"""Results tab showing detection history and preview overlays."""
def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
@@ -18,29 +42,398 @@ class ResultsTab(QWidget):
self.db_manager = db_manager
self.config_manager = config_manager
self.detection_summary: List[Dict] = []
self.current_selection: Optional[Dict] = None
self.current_image: Optional[Image] = None
self.current_detections: List[Dict] = []
self._image_path_cache: Dict[str, str] = {}
self._setup_ui()
self.refresh()
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
group = QGroupBox("Results")
group_layout = QVBoxLayout()
label = QLabel(
"Results viewer will be implemented here.\n\n"
"Features:\n"
"- Detection history browser\n"
"- Advanced filtering\n"
"- Statistics dashboard\n"
"- Export functionality"
)
group_layout.addWidget(label)
group.setLayout(group_layout)
# Splitter for list + preview
splitter = QSplitter(Qt.Horizontal)
layout.addWidget(group)
layout.addStretch()
# Left pane: detection list
left_container = QWidget()
left_layout = QVBoxLayout()
left_layout.setContentsMargins(0, 0, 0, 0)
controls_layout = QHBoxLayout()
self.refresh_btn = QPushButton("Refresh")
self.refresh_btn.clicked.connect(self.refresh)
controls_layout.addWidget(self.refresh_btn)
controls_layout.addStretch()
left_layout.addLayout(controls_layout)
self.results_table = QTableWidget(0, 5)
self.results_table.setHorizontalHeaderLabels(
["Image", "Model", "Detections", "Classes", "Last Updated"]
)
self.results_table.horizontalHeader().setSectionResizeMode(
0, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
1, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
2, QHeaderView.ResizeToContents
)
self.results_table.horizontalHeader().setSectionResizeMode(
3, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
4, QHeaderView.ResizeToContents
)
self.results_table.setSelectionBehavior(QAbstractItemView.SelectRows)
self.results_table.setSelectionMode(QAbstractItemView.SingleSelection)
self.results_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
self.results_table.itemSelectionChanged.connect(self._on_result_selected)
left_layout.addWidget(self.results_table)
left_container.setLayout(left_layout)
# Right pane: preview canvas and controls
right_container = QWidget()
right_layout = QVBoxLayout()
right_layout.setContentsMargins(0, 0, 0, 0)
preview_group = QGroupBox("Detection Preview")
preview_layout = QVBoxLayout()
self.preview_canvas = AnnotationCanvasWidget()
self.preview_canvas.set_polyline_enabled(False)
self.preview_canvas.set_show_bboxes(True)
preview_layout.addWidget(self.preview_canvas)
toggles_layout = QHBoxLayout()
self.show_masks_checkbox = QCheckBox("Show Masks")
self.show_masks_checkbox.setChecked(False)
self.show_masks_checkbox.stateChanged.connect(self._apply_detection_overlays)
self.show_bboxes_checkbox = QCheckBox("Show Bounding Boxes")
self.show_bboxes_checkbox.setChecked(True)
self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
self.show_confidence_checkbox = QCheckBox("Show Confidence")
self.show_confidence_checkbox.setChecked(False)
self.show_confidence_checkbox.stateChanged.connect(
self._apply_detection_overlays
)
toggles_layout.addWidget(self.show_masks_checkbox)
toggles_layout.addWidget(self.show_bboxes_checkbox)
toggles_layout.addWidget(self.show_confidence_checkbox)
toggles_layout.addStretch()
preview_layout.addLayout(toggles_layout)
self.summary_label = QLabel("Select a detection result to preview.")
self.summary_label.setWordWrap(True)
preview_layout.addWidget(self.summary_label)
preview_group.setLayout(preview_layout)
right_layout.addWidget(preview_group)
right_container.setLayout(right_layout)
splitter.addWidget(left_container)
splitter.addWidget(right_container)
splitter.setStretchFactor(0, 1)
splitter.setStretchFactor(1, 2)
layout.addWidget(splitter)
self.setLayout(layout)
def refresh(self):
"""Refresh the tab."""
pass
"""Refresh the detection list and preview."""
self._load_detection_summary()
self._populate_results_table()
self.current_selection = None
self.current_image = None
self.current_detections = []
self.preview_canvas.clear()
self.summary_label.setText("Select a detection result to preview.")
def _load_detection_summary(self):
"""Load latest detection summaries grouped by image + model."""
try:
detections = self.db_manager.get_detections(limit=500)
summary_map: Dict[tuple, Dict] = {}
for det in detections:
key = (det["image_id"], det["model_id"])
metadata = det.get("metadata") or {}
entry = summary_map.setdefault(
key,
{
"image_id": det["image_id"],
"model_id": det["model_id"],
"image_path": det.get("image_path"),
"image_filename": det.get("image_filename")
or det.get("image_path"),
"model_name": det.get("model_name", ""),
"model_version": det.get("model_version", ""),
"last_detected": det.get("detected_at"),
"count": 0,
"classes": set(),
"source_path": metadata.get("source_path"),
"repository_root": metadata.get("repository_root"),
},
)
entry["count"] += 1
if det.get("detected_at") and (
not entry.get("last_detected")
or str(det.get("detected_at")) > str(entry.get("last_detected"))
):
entry["last_detected"] = det.get("detected_at")
if det.get("class_name"):
entry["classes"].add(det["class_name"])
if metadata.get("source_path") and not entry.get("source_path"):
entry["source_path"] = metadata.get("source_path")
if metadata.get("repository_root") and not entry.get("repository_root"):
entry["repository_root"] = metadata.get("repository_root")
self.detection_summary = sorted(
summary_map.values(),
key=lambda x: str(x.get("last_detected") or ""),
reverse=True,
)
except Exception as e:
logger.error(f"Failed to load detection summary: {e}")
QMessageBox.critical(
self,
"Error",
f"Failed to load detection results:\n{str(e)}",
)
self.detection_summary = []
def _populate_results_table(self):
"""Populate the table widget with detection summaries."""
self.results_table.setRowCount(len(self.detection_summary))
for row, entry in enumerate(self.detection_summary):
model_label = f"{entry['model_name']} {entry['model_version']}".strip()
class_list = (
", ".join(sorted(entry["classes"])) if entry["classes"] else "-"
)
items = [
QTableWidgetItem(entry.get("image_filename", "")),
QTableWidgetItem(model_label),
QTableWidgetItem(str(entry.get("count", 0))),
QTableWidgetItem(class_list),
QTableWidgetItem(str(entry.get("last_detected") or "")),
]
for col, item in enumerate(items):
item.setData(Qt.UserRole, row)
self.results_table.setItem(row, col, item)
self.results_table.clearSelection()
def _on_result_selected(self):
"""Handle selection changes in the detection table."""
selected_items = self.results_table.selectedItems()
if not selected_items:
return
row = selected_items[0].data(Qt.UserRole)
if row is None or row >= len(self.detection_summary):
return
entry = self.detection_summary[row]
if (
self.current_selection
and self.current_selection.get("image_id") == entry["image_id"]
and self.current_selection.get("model_id") == entry["model_id"]
):
return
self.current_selection = entry
image_path = self._resolve_image_path(entry)
if not image_path:
QMessageBox.warning(
self,
"Image Not Found",
"Unable to locate the image file for this detection.",
)
return
try:
self.current_image = Image(image_path)
self.preview_canvas.load_image(self.current_image)
except ImageLoadError as e:
logger.error(f"Failed to load image '{image_path}': {e}")
QMessageBox.critical(
self,
"Image Error",
f"Failed to load image for preview:\n{str(e)}",
)
return
self._load_detections_for_selection(entry)
self._apply_detection_overlays()
self._update_summary_label(entry)
def _load_detections_for_selection(self, entry: Dict):
"""Load detection records for the selected image/model pair."""
self.current_detections = []
if not entry:
return
try:
filters = {"image_id": entry["image_id"], "model_id": entry["model_id"]}
self.current_detections = self.db_manager.get_detections(filters)
except Exception as e:
logger.error(f"Failed to load detections for preview: {e}")
QMessageBox.critical(
self,
"Error",
f"Failed to load detections for this image:\n{str(e)}",
)
self.current_detections = []
def _apply_detection_overlays(self):
"""Draw detections onto the preview canvas based on current toggles."""
self.preview_canvas.clear_annotations()
self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())
if not self.current_detections or not self.current_image:
return
for det in self.current_detections:
color = self._get_class_color(det.get("class_name"))
if self.show_masks_checkbox.isChecked() and det.get("segmentation_mask"):
mask_points = self._convert_mask(det["segmentation_mask"])
if mask_points:
self.preview_canvas.draw_saved_polyline(mask_points, color)
bbox = [
det.get("x_min"),
det.get("y_min"),
det.get("x_max"),
det.get("y_max"),
]
if all(v is not None for v in bbox):
label = None
if self.show_confidence_checkbox.isChecked():
confidence = det.get("confidence")
if confidence is not None:
label = f"{confidence:.2f}"
self.preview_canvas.draw_saved_bbox(bbox, color, label=label)
def _convert_mask(self, mask_points: List[List[float]]) -> List[List[float]]:
"""Convert stored [x, y] masks to [y, x] format for the canvas."""
converted = []
for point in mask_points:
if len(point) >= 2:
x, y = point[0], point[1]
converted.append([y, x])
return converted
def _toggle_bboxes(self):
"""Update bounding box visibility on the canvas."""
self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())
# Re-render to respect show/hide when toggled
self._apply_detection_overlays()
def _update_summary_label(self, entry: Dict):
"""Display textual summary for the selected detection run."""
classes = ", ".join(sorted(entry.get("classes", []))) or "-"
summary_text = (
f"Image: {entry.get('image_filename', 'unknown')}\n"
f"Model: {entry.get('model_name', '')} {entry.get('model_version', '')}\n"
f"Detections: {entry.get('count', 0)}\n"
f"Classes: {classes}\n"
f"Last Updated: {entry.get('last_detected', 'n/a')}"
)
self.summary_label.setText(summary_text)
def _resolve_image_path(self, entry: Dict) -> Optional[str]:
"""Resolve an image path using metadata, cache, and repository hints."""
relative_path = entry.get("image_path") if entry else None
cache_key = relative_path or entry.get("source_path")
if cache_key and cache_key in self._image_path_cache:
cached = Path(self._image_path_cache[cache_key])
if cached.exists():
return self._image_path_cache[cache_key]
del self._image_path_cache[cache_key]
candidates = []
source_path = entry.get("source_path") if entry else None
if source_path:
candidates.append(Path(source_path))
repo_roots = []
if entry.get("repository_root"):
repo_roots.append(entry["repository_root"])
config_repo = self.config_manager.get_image_repository_path()
if config_repo:
repo_roots.append(config_repo)
for root in repo_roots:
if relative_path:
candidates.append(Path(root) / relative_path)
if relative_path:
candidates.append(Path(relative_path))
for candidate in candidates:
try:
if candidate and candidate.exists():
resolved = str(candidate.resolve())
if cache_key:
self._image_path_cache[cache_key] = resolved
return resolved
except Exception:
continue
# Fallback: search by filename in known roots
filename = Path(relative_path).name if relative_path else None
if filename:
search_roots = [Path(root) for root in repo_roots if root]
if not search_roots:
search_roots = [Path("data")]
match = self._search_in_roots(filename, search_roots)
if match:
resolved = str(match.resolve())
if cache_key:
self._image_path_cache[cache_key] = resolved
return resolved
return None
def _search_in_roots(self, filename: str, roots: List[Path]) -> Optional[Path]:
"""Search for a file name within a list of root directories."""
for root in roots:
try:
if not root.exists():
continue
for candidate in root.rglob(filename):
return candidate
except Exception as e:
logger.debug(f"Error searching for {filename} in {root}: {e}")
return None
def _get_class_color(self, class_name: Optional[str]) -> str:
"""Return consistent color hex for a class name."""
if not class_name:
return "#FF6B6B"
color_map = self.config_manager.get_bbox_colors()
if class_name in color_map:
return color_map[class_name]
# Deterministic fallback color based on hash
palette = [
"#FF6B6B",
"#4ECDC4",
"#FFD166",
"#1D3557",
"#F4A261",
"#E76F51",
]
return palette[hash(class_name) % len(palette)]

File diff suppressed because it is too large Load Diff

View File

@@ -16,8 +16,9 @@ from PySide6.QtGui import (
QKeyEvent,
QMouseEvent,
QPaintEvent,
QPolygonF,
)
from PySide6.QtCore import Qt, QEvent, Signal, QPoint
from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QPointF, QRect
from typing import Any, Dict, List, Optional, Tuple
from src.utils.image import Image, ImageLoadError
@@ -246,10 +247,10 @@ class AnnotationCanvasWidget(QWidget):
return
try:
# Get RGB image data
if self.current_image.channels == 3:
# Get image data in a format compatible with Qt
if self.current_image.channels in (3, 4):
image_data = self.current_image.get_rgb()
height, width, channels = image_data.shape
height, width = image_data.shape[:2]
else:
image_data = self.current_image.get_grayscale()
height, width = image_data.shape
@@ -263,7 +264,7 @@ class AnnotationCanvasWidget(QWidget):
height,
bytes_per_line,
self.current_image.qtimage_format,
)
).copy() # Copy so Qt owns the buffer even after numpy array goes out of scope
self.original_pixmap = QPixmap.fromImage(qimage)
@@ -496,8 +497,10 @@ class AnnotationCanvasWidget(QWidget):
)
painter.setPen(pen)
for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]):
painter.drawLine(int(x1), int(y1), int(x2), int(y2))
# Use QPolygonF for efficient polygon rendering (single call vs N-1 calls)
# drawPolygon() automatically closes the shape, ensuring proper visual closure
polygon = QPolygonF([QPointF(x, y) for x, y in polyline])
painter.drawPolygon(polygon)
# Draw bounding boxes (dashed) if enabled
if self.show_bboxes and self.original_pixmap is not None and self.bboxes:
@@ -529,6 +532,40 @@ class AnnotationCanvasWidget(QWidget):
painter.setPen(pen)
painter.drawRect(x_min, y_min, rect_width, rect_height)
label_text = meta.get("label")
if label_text:
painter.save()
font = painter.font()
font.setPointSizeF(max(10.0, width + 4))
painter.setFont(font)
metrics = painter.fontMetrics()
text_width = metrics.horizontalAdvance(label_text)
text_height = metrics.height()
padding = 4
bg_width = text_width + padding * 2
bg_height = text_height + padding * 2
canvas_width = self.original_pixmap.width()
canvas_height = self.original_pixmap.height()
bg_x = max(0, min(x_min, canvas_width - bg_width))
bg_y = y_min - bg_height
if bg_y < 0:
bg_y = min(y_min, canvas_height - bg_height)
bg_y = max(0, bg_y)
background_rect = QRect(bg_x, bg_y, bg_width, bg_height)
background_color = QColor(pen_color)
background_color.setAlpha(220)
painter.fillRect(background_rect, background_color)
text_color = QColor(0, 0, 0)
if background_color.lightness() < 128:
text_color = QColor(255, 255, 255)
painter.setPen(text_color)
painter.drawText(
background_rect.adjusted(padding, padding, -padding, -padding),
Qt.AlignLeft | Qt.AlignVCenter,
label_text,
)
painter.restore()
painter.end()
self._update_display()
@@ -787,7 +824,13 @@ class AnnotationCanvasWidget(QWidget):
f"Drew saved polyline with {len(polyline)} points in color {color}"
)
def draw_saved_bbox(self, bbox: List[float], color: str, width: int = 3):
def draw_saved_bbox(
self,
bbox: List[float],
color: str,
width: int = 3,
label: Optional[str] = None,
):
"""
Draw a bounding box from database coordinates onto the annotation canvas.
@@ -796,6 +839,7 @@ class AnnotationCanvasWidget(QWidget):
in normalized coordinates (0-1)
color: Color hex string (e.g., '#FF0000')
width: Line width in pixels
label: Optional text label to render near the bounding box
"""
if not self.annotation_pixmap or not self.original_pixmap:
logger.warning("Cannot draw bounding box: no image loaded")
@@ -828,11 +872,11 @@ class AnnotationCanvasWidget(QWidget):
self.bboxes.append(
[float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)]
)
self.bbox_meta.append({"color": pen_color, "width": int(width)})
self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label})
# Store in all_strokes for consistency
self.all_strokes.append(
{"bbox": bbox, "color": color, "alpha": 128, "width": width}
{"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label}
)
# Redraw overlay (polylines + all bounding boxes)

View File

@@ -137,7 +137,7 @@ class ImageDisplayWidget(QWidget):
height,
bytes_per_line,
self.current_image.qtimage_format,
)
).copy() # Copy to ensure Qt owns its memory after this scope
# Convert to pixmap
pixmap = QPixmap.fromImage(qimage)

View File

@@ -5,12 +5,12 @@ Handles detection inference and result storage.
from typing import List, Dict, Optional, Callable
from pathlib import Path
from PIL import Image
import cv2
import numpy as np
from src.model.yolo_wrapper import YOLOWrapper
from src.database.db_manager import DatabaseManager
from src.utils.image import Image
from src.utils.logger import get_logger
from src.utils.file_utils import get_relative_path
@@ -42,6 +42,7 @@ class InferenceEngine:
relative_path: str,
conf: float = 0.25,
save_to_db: bool = True,
repository_root: Optional[str] = None,
) -> Dict:
"""
Detect objects in a single image.
@@ -51,49 +52,79 @@ class InferenceEngine:
relative_path: Relative path from repository root
conf: Confidence threshold
save_to_db: Whether to save results to database
repository_root: Base directory used to compute relative_path (if known)
Returns:
Dictionary with detection results
"""
try:
# Normalize storage path (fall back to absolute path when repo root is unknown)
stored_relative_path = relative_path
if not repository_root:
stored_relative_path = str(Path(image_path).resolve())
# Get image dimensions
img = Image.open(image_path)
width, height = img.size
img.close()
img = Image(image_path)
width = img.width
height = img.height
# Perform detection
detections = self.yolo.predict(image_path, conf=conf)
# Add/get image in database
image_id = self.db_manager.get_or_create_image(
relative_path=relative_path,
relative_path=stored_relative_path,
filename=Path(image_path).name,
width=width,
height=height,
)
# Save detections to database
if save_to_db and detections:
detection_records = []
for det in detections:
# Use normalized bbox from detection
bbox_normalized = det[
"bbox_normalized"
] # [x_min, y_min, x_max, y_max]
inserted_count = 0
deleted_count = 0
record = {
"image_id": image_id,
"model_id": self.model_id,
"class_name": det["class_name"],
"bbox": tuple(bbox_normalized),
"confidence": det["confidence"],
"segmentation_mask": det.get("segmentation_mask"),
"metadata": {"class_id": det["class_id"]},
}
detection_records.append(record)
# Save detections to database, replacing any previous results for this image/model
if save_to_db:
deleted_count = self.db_manager.delete_detections_for_image(
image_id, self.model_id
)
if detections:
detection_records = []
for det in detections:
# Use normalized bbox from detection
bbox_normalized = det[
"bbox_normalized"
] # [x_min, y_min, x_max, y_max]
self.db_manager.add_detections_batch(detection_records)
logger.info(f"Saved {len(detection_records)} detections to database")
metadata = {
"class_id": det["class_id"],
"source_path": str(Path(image_path).resolve()),
}
if repository_root:
metadata["repository_root"] = str(
Path(repository_root).resolve()
)
record = {
"image_id": image_id,
"model_id": self.model_id,
"class_name": det["class_name"],
"bbox": tuple(bbox_normalized),
"confidence": det["confidence"],
"segmentation_mask": det.get("segmentation_mask"),
"metadata": metadata,
}
detection_records.append(record)
inserted_count = self.db_manager.add_detections_batch(
detection_records
)
logger.info(
f"Saved {inserted_count} detections to database (replaced {deleted_count})"
)
else:
logger.info(
f"Detection run removed {deleted_count} stale entries but produced no new detections"
)
return {
"success": True,
@@ -142,7 +173,12 @@ class InferenceEngine:
rel_path = get_relative_path(image_path, repository_root)
# Perform detection
result = self.detect_single(image_path, rel_path, conf)
result = self.detect_single(
image_path,
rel_path,
conf=conf,
repository_root=repository_root,
)
results.append(result)
# Update progress

View File

@@ -7,6 +7,9 @@ from ultralytics import YOLO
from pathlib import Path
from typing import Optional, List, Dict, Callable, Any
import torch
import tempfile
import os
from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range
from src.utils.logger import get_logger
@@ -55,6 +58,7 @@ class YOLOWrapper:
save_dir: str = "data/models",
name: str = "custom_model",
resume: bool = False,
callbacks: Optional[Dict[str, Callable]] = None,
**kwargs,
) -> Dict[str, Any]:
"""
@@ -69,13 +73,15 @@ class YOLOWrapper:
save_dir: Directory to save trained model
name: Name for the training run
resume: Resume training from last checkpoint
callbacks: Optional Ultralytics callback dictionary
**kwargs: Additional training arguments
Returns:
Dictionary with training results
"""
if self.model is None:
self.load_model()
if not self.load_model():
raise RuntimeError(f"Failed to load model from {self.model_path}")
try:
logger.info(f"Starting training: {name}")
@@ -117,7 +123,8 @@ class YOLOWrapper:
Dictionary with validation metrics
"""
if self.model is None:
self.load_model()
if not self.load_model():
raise RuntimeError(f"Failed to load model from {self.model_path}")
try:
logger.info(f"Starting validation on {split} split")
@@ -158,12 +165,15 @@ class YOLOWrapper:
List of detection dictionaries
"""
if self.model is None:
self.load_model()
if not self.load_model():
raise RuntimeError(f"Failed to load model from {self.model_path}")
prepared_source, cleanup_path = self._prepare_source(source)
try:
logger.info(f"Running inference on {source}")
results = self.model.predict(
source=source,
source=prepared_source,
conf=conf,
iou=iou,
save=save,
@@ -180,6 +190,14 @@ class YOLOWrapper:
except Exception as e:
logger.error(f"Error during inference: {e}")
raise
finally:
if 0: # cleanup_path:
try:
os.remove(cleanup_path)
except OSError as cleanup_error:
logger.warning(
f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}"
)
def export(
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
@@ -196,7 +214,8 @@ class YOLOWrapper:
Path to exported model
"""
if self.model is None:
self.load_model()
if not self.load_model():
raise RuntimeError(f"Failed to load model from {self.model_path}")
try:
logger.info(f"Exporting model to {format} format")
@@ -208,6 +227,38 @@ class YOLOWrapper:
logger.error(f"Error exporting model: {e}")
raise
def _prepare_source(self, source):
"""Convert single-channel images to RGB temporarily for inference."""
cleanup_path = None
if isinstance(source, (str, Path)):
source_path = Path(source)
if source_path.is_file():
try:
img_obj = Image(source_path)
pil_img = img_obj.pil_image
if len(pil_img.getbands()) == 1:
rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img)
else:
rgb_img = pil_img.convert("RGB")
suffix = source_path.suffix or ".png"
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
tmp_path = tmp.name
tmp.close()
rgb_img.save(tmp_path)
cleanup_path = tmp_path
logger.info(
f"Converted image {source_path} to RGB for inference at {tmp_path}"
)
return tmp_path, cleanup_path
except Exception as convert_error:
logger.warning(
f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}"
)
return source, cleanup_path
def _format_training_results(self, results) -> Dict[str, Any]:
"""Format training results into dictionary."""
try:

View File

@@ -7,6 +7,7 @@ import yaml
from pathlib import Path
from typing import Any, Dict, Optional
from src.utils.logger import get_logger
from src.utils.image import Image
logger = get_logger(__name__)
@@ -46,18 +47,15 @@ class ConfigManager:
"database": {"path": "data/detections.db"},
"image_repository": {
"base_path": "",
"allowed_extensions": [
".jpg",
".jpeg",
".png",
".tif",
".tiff",
".bmp",
],
"allowed_extensions": Image.SUPPORTED_EXTENSIONS,
},
"models": {
"default_base_model": "yolov8s-seg.pt",
"models_directory": "data/models",
"base_model_choices": [
"yolov8s-seg.pt",
"yolov11s-seg.pt",
],
},
"training": {
"default_epochs": 100,
@@ -65,6 +63,20 @@ class ConfigManager:
"default_imgsz": 640,
"default_patience": 50,
"default_lr0": 0.01,
"two_stage": {
"enabled": False,
"stage1": {
"epochs": 20,
"lr0": 0.0005,
"patience": 10,
"freeze": 10,
},
"stage2": {
"epochs": 150,
"lr0": 0.0003,
"patience": 30,
},
},
},
"detection": {
"default_confidence": 0.25,
@@ -214,5 +226,5 @@ class ConfigManager:
def get_allowed_extensions(self) -> list:
"""Get list of allowed image file extensions."""
return self.get(
"image_repository.allowed_extensions", [".jpg", ".jpeg", ".png"]
"image_repository.allowed_extensions", Image.SUPPORTED_EXTENSIONS
)

View File

@@ -28,7 +28,9 @@ def get_image_files(
List of absolute paths to image files
"""
if allowed_extensions is None:
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
from src.utils.image import Image
allowed_extensions = Image.SUPPORTED_EXTENSIONS
# Normalize extensions to lowercase
allowed_extensions = [ext.lower() for ext in allowed_extensions]
@@ -204,7 +206,9 @@ def is_image_file(
True if file is an image
"""
if allowed_extensions is None:
allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
from src.utils.image import Image
allowed_extensions = Image.SUPPORTED_EXTENSIONS
extension = Path(file_path).suffix.lower()
return extension in [ext.lower() for ext in allowed_extensions]

View File

@@ -289,3 +289,40 @@ class Image:
def __str__(self) -> str:
"""String representation of the Image object."""
return self.__repr__()
def convert_grayscale_to_rgb_preserve_range(
pil_image: PILImage.Image,
) -> PILImage.Image:
"""Convert a single-channel PIL image to RGB while preserving dynamic range.
Args:
pil_image: Single-channel PIL image (e.g., 16-bit grayscale).
Returns:
PIL Image in RGB mode with intensities normalized to 0-255.
"""
if pil_image.mode == "RGB":
return pil_image
grayscale = np.array(pil_image)
if grayscale.ndim == 3:
grayscale = grayscale[:, :, 0]
original_dtype = grayscale.dtype
grayscale = grayscale.astype(np.float32)
if grayscale.size == 0:
return PILImage.new("RGB", pil_image.size, color=(0, 0, 0))
if np.issubdtype(original_dtype, np.integer):
denom = float(max(np.iinfo(original_dtype).max, 1))
else:
max_val = float(grayscale.max())
denom = max(max_val, 1.0)
grayscale = np.clip(grayscale / denom, 0.0, 1.0)
grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8)
rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2)
return PILImage.fromarray(rgb_arr, mode="RGB")

View File

@@ -0,0 +1,139 @@
import numpy as np
from roifile import ImagejRoi
from tifffile import TiffFile, TiffWriter
from pathlib import Path
class UT:
"""
Docstring for UT
Operetta files along with rois drawn in ImageJ
"""
def __init__(self, roifile_fn: Path):
self.roifile_fn = roifile_fn
print("is file", self.roifile_fn.is_file())
self.rois = ImagejRoi.fromfile(self.roifile_fn)
self.stem = self.roifile_fn.stem.split("Roi-")[1]
self.image, self.image_props = self._load_images()
def _load_images(self):
"""Loading sequence of tif files
array sequence is CZYX
"""
print(self.roifile_fn.parent, self.stem)
fns = list(self.roifile_fn.parent.glob(f"{self.stem.lower()}*.tif*"))
stems = [fn.stem.split(self.stem)[-1] for fn in fns]
n_ch = len(set([stem.split("-ch")[-1].split("t")[0] for stem in stems]))
n_p = len(set([stem.split("-")[0] for stem in stems]))
n_t = len(set([stem.split("t")[1] for stem in stems]))
with TiffFile(fns[0]) as tif:
img = tif.asarray()
w, h = img.shape
dtype = img.dtype
self.image_props = {
"channels": n_ch,
"planes": n_p,
"tiles": n_t,
"width": w,
"height": h,
"dtype": dtype,
}
print("Image props", self.image_props)
image_stack = np.zeros((n_ch, n_p, w, h), dtype=dtype)
for fn in fns:
with TiffFile(fn) as tif:
img = tif.asarray()
stem = fn.stem.split(self.stem)[-1]
ch = int(stem.split("-ch")[-1].split("t")[0])
p = int(stem.split("-")[0].split("p")[1])
t = int(stem.split("t")[1])
print(fn.stem, "ch", ch, "p", p, "t", t)
image_stack[ch - 1, p - 1] = img
print(image_stack.shape)
return image_stack, self.image_props
@property
def width(self):
return self.image_props["width"]
@property
def height(self):
return self.image_props["height"]
@property
def nchannels(self):
return self.image_props["channels"]
@property
def nplanes(self):
return self.image_props["planes"]
def export_rois(
self,
path: Path,
subfolder: str = "labels",
class_index: int = 0,
):
"""Export rois to a file"""
with open(path / subfolder / f"{self.stem}.txt", "w") as f:
for i, roi in enumerate(self.rois):
rc = roi.subpixel_coordinates
if rc is None:
print(
f"No coordinates: {self.roifile_fn}, element {i}, out of {len(self.rois)}"
)
continue
xmn, ymn = rc.min(axis=0)
xmx, ymx = rc.max(axis=0)
xc = (xmn + xmx) / 2
yc = (ymn + ymx) / 2
bw = xmx - xmn
bh = ymx - ymn
coords = f"{xc/self.width} {yc/self.height} {bw/self.width} {bh/self.height} "
for x, y in rc:
coords += f"{x/self.width} {y/self.height} "
f.write(f"{class_index} {coords}\n")
return
def export_image(
self,
path: Path,
subfolder: str = "images",
plane_mode: str = "max projection",
channel: int = 0,
):
"""Export image to a file"""
if plane_mode == "max projection":
self.image = np.max(self.image[channel], axis=0)
print(self.image.shape)
with TiffWriter(path / subfolder / f"{self.stem}.tif") as tif:
tif.write(self.image)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input", nargs="*", type=Path)
parser.add_argument("-o", "--output", type=Path)
args = parser.parse_args()
for path in args.input:
print("Path:", path)
for rfn in Path(path).glob("*.zip"):
print("Roi FN:", rfn)
ut = UT(rfn)
ut.export_rois(args.output, class_index=0)
ut.export_image(args.output, plane_mode="max projection", channel=0)
print()

184
tests/show_yolo_seg.py Normal file
View File

@@ -0,0 +1,184 @@
#!/usr/bin/env python3
"""
show_yolo_seg.py
Usage:
python show_yolo_seg.py /path/to/image.jpg /path/to/labels.txt
Supports:
- Segmentation polygons: "class x1 y1 x2 y2 ... xn yn"
- YOLO bbox lines as fallback: "class x_center y_center width height"
Coordinates can be normalized [0..1] or absolute pixels (auto-detected).
"""
import sys
import cv2
import numpy as np
import matplotlib.pyplot as plt
import argparse
from pathlib import Path
import random
def parse_label_line(line):
parts = line.strip().split()
if not parts:
return None
cls = int(float(parts[0]))
coords = [float(x) for x in parts[1:]]
return cls, coords
def coords_are_normalized(coords):
# If every coordinate is between 0 and 1 (inclusive-ish), assume normalized
if not coords:
return False
return max(coords) <= 1.001
def yolo_bbox_to_xyxy(coords, img_w, img_h):
# coords: [xc, yc, w, h] normalized or absolute
xc, yc, w, h = coords[:4]
if max(coords) <= 1.001:
xc *= img_w
yc *= img_h
w *= img_w
h *= img_h
x1 = int(round(xc - w / 2))
y1 = int(round(yc - h / 2))
x2 = int(round(xc + w / 2))
y2 = int(round(yc + h / 2))
return x1, y1, x2, y2
def poly_to_pts(coords, img_w, img_h):
# coords: [x1 y1 x2 y2 ...] either normalized or absolute
if coords_are_normalized(coords[4:]):
coords = [
coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))
]
pts = np.array(coords, dtype=np.int32).reshape(-1, 2)
return pts
def random_color_for_class(cls):
random.seed(cls) # deterministic per class
return tuple(int(x) for x in np.array([random.randint(0, 255) for _ in range(3)]))
def draw_annotations(img, labels, alpha=0.4, draw_bbox_for_poly=True):
# img: BGR numpy array
overlay = img.copy()
h, w = img.shape[:2]
for cls, coords in labels:
if not coords:
continue
# polygon case (>=6 coordinates)
if len(coords) >= 6:
color = random_color_for_class(cls)
x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords[:4], w, h)
cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
pts = poly_to_pts(coords[4:], w, h)
# fill on overlay
cv2.fillPoly(overlay, [pts], color)
# outline on base image
cv2.polylines(img, [pts], isClosed=True, color=color, thickness=2)
# put class text at first point
x, y = int(pts[0, 0]), int(pts[0, 1]) - 6
cv2.putText(
img,
str(cls),
(x, max(6, y)),
cv2.FONT_HERSHEY_SIMPLEX,
0.6,
(255, 255, 255),
2,
cv2.LINE_AA,
)
# YOLO bbox case (4 coords)
elif len(coords) == 4:
x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords, w, h)
color = random_color_for_class(cls)
cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
cv2.putText(
img,
str(cls),
(x1, max(6, y1 - 4)),
cv2.FONT_HERSHEY_SIMPLEX,
0.6,
(255, 255, 255),
2,
cv2.LINE_AA,
)
else:
# Unknown / invalid format, skip
continue
# blend overlay for filled polygons
cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, img)
return img
def load_labels_file(label_path):
labels = []
with open(label_path, "r") as f:
for raw in f:
line = raw.strip()
if not line:
continue
parsed = parse_label_line(line)
if parsed:
labels.append(parsed)
return labels
def main():
parser = argparse.ArgumentParser(
description="Show YOLO segmentation / polygon annotations"
)
parser.add_argument("image", type=str, help="Path to image file")
parser.add_argument("labels", type=str, help="Path to YOLO label file (polygons)")
parser.add_argument(
"--alpha", type=float, default=0.4, help="Polygon fill alpha (0..1)"
)
parser.add_argument(
"--no-bbox", action="store_true", help="Don't draw bounding boxes for polygons"
)
args = parser.parse_args()
img_path = Path(args.image)
lbl_path = Path(args.labels)
if not img_path.exists():
print("Image not found:", img_path)
sys.exit(1)
if not lbl_path.exists():
print("Label file not found:", lbl_path)
sys.exit(1)
img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
if img is None:
print("Could not load image:", img_path)
sys.exit(1)
labels = load_labels_file(str(lbl_path))
if not labels:
print("No labels parsed from", lbl_path)
# continue and just show image
out = draw_annotations(
img.copy(), labels, alpha=args.alpha, draw_bbox_for_poly=(not args.no_bbox)
)
# Convert BGR -> RGB for matplotlib display
out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
plt.imshow(out_rgb)
plt.axis("off")
plt.title(f"{img_path.name} ({lbl_path.name})")
plt.show()
if __name__ == "__main__":
main()

View File

@@ -27,7 +27,7 @@ class TestImage:
def test_supported_extensions(self):
"""Test that supported extensions are correctly defined."""
expected_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
expected_extensions = Image.SUPPORTED_EXTENSIONS
assert Image.SUPPORTED_EXTENSIONS == expected_extensions
def test_image_properties(self, tmp_path):