37 Commits

Author SHA1 Message Date
9c8931e6f3 Finish validation tab 2026-01-16 13:58:02 +02:00
20578c1fdf Adding a file and feature to delete all detections from database 2026-01-16 13:43:05 +02:00
2c494dac49 Adding export for labels in results 2026-01-16 11:15:12 +02:00
506c74e53a Small update 2026-01-16 10:39:46 +02:00
eefda5b878 Adding metdata to tiled images 2026-01-16 10:39:14 +02:00
31cb6a6c8e Using 8bit images 2026-01-16 10:38:34 +02:00
0c19ea2557 Updating 2026-01-16 10:30:13 +02:00
89e47591db Formatting 2026-01-16 10:27:15 +02:00
69cde09e53 Changing alpha value 2026-01-16 10:26:25 +02:00
fcbd5fb16d correcting label writing and formatting code 2026-01-16 10:24:19 +02:00
ca52312925 Adding LIKE option for filtering queries 2026-01-16 10:18:48 +02:00
0a93bf797a Adding auto zoom when result is loaded 2026-01-12 14:15:02 +02:00
d998c65665 Updating image splitter 2026-01-12 13:28:00 +02:00
510eabfa94 Adding splitter method 2026-01-05 13:56:57 +02:00
395d263900 Update 2026-01-05 08:59:36 +02:00
e98d287b8a Updating tiff image patch 2026-01-02 12:44:06 +02:00
d25101de2d adding files 2026-01-02 12:40:44 +02:00
f88beef188 Another test 2025-12-19 13:50:49 +02:00
2fd9a2acf4 RGB 2025-12-19 13:31:24 +02:00
2bcd18cc75 Bug fix 2025-12-19 13:13:12 +02:00
5d25378c46 Testing with uint conversion 2025-12-19 13:10:36 +02:00
2b0b48921e Testing more grayscale 2025-12-19 12:02:11 +02:00
b0c05f0225 testing grayscale 2025-12-19 11:55:38 +02:00
97badaa390 Samll update 2025-12-19 11:31:12 +02:00
8f8132ce61 Testing detect 2025-12-19 10:44:11 +02:00
6ae7481e25 Adding debug messages 2025-12-19 10:15:53 +02:00
061f8b3ca2 Fixing pseudo rgb 2025-12-19 09:56:43 +02:00
a8e5db3135 Small change 2025-12-18 13:03:12 +02:00
268ed5175e Appling pseudo channels for RGB 2025-12-18 12:52:13 +02:00
5e9d3b1dc4 Adding logger 2025-12-18 12:04:41 +02:00
7d83e9b9b1 Adding important file 2025-12-17 00:45:56 +02:00
e364d06217 Implementing uint16 reading with tifffile 2025-12-16 23:02:45 +02:00
e5036c10cf Small fix 2025-12-16 18:03:56 +02:00
c7e388d9ae Updating progressbar 2025-12-16 17:20:25 +02:00
6b995e7325 upate 2025-12-16 13:24:20 +02:00
0e0741d323 Update on convert_grayscale_to_rgb_preserve_range, making it class method 2025-12-16 12:37:34 +02:00
dd99a0677c Updating image converter and aading simple script to visulaize segmentation 2025-12-16 11:27:38 +02:00
17 changed files with 2086 additions and 521 deletions

View File

@@ -1,57 +0,0 @@
database:
path: data/detections.db
image_repository:
base_path: ''
allowed_extensions:
- .jpg
- .jpeg
- .png
- .tif
- .tiff
- .bmp
models:
default_base_model: yolov8s-seg.pt
models_directory: data/models
base_model_choices:
- yolov8s-seg.pt
- yolo11s-seg.pt
training:
default_epochs: 100
default_batch_size: 16
default_imgsz: 1024
default_patience: 50
default_lr0: 0.01
two_stage:
enabled: false
stage1:
epochs: 20
lr0: 0.0005
patience: 10
freeze: 10
stage2:
epochs: 150
lr0: 0.0003
patience: 30
last_dataset_yaml: /home/martin/code/object_detection/data/datasets/data.yaml
last_dataset_dir: /home/martin/code/object_detection/data/datasets
detection:
default_confidence: 0.25
default_iou: 0.45
max_batch_size: 100
visualization:
bbox_colors:
organelle: '#FF6B6B'
membrane_branch: '#4ECDC4'
default: '#00FF00'
bbox_thickness: 2
font_size: 12
export:
formats:
- csv
- json
- excel
default_format: csv
logging:
level: INFO
file: logs/app.log
format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

View File

@@ -82,12 +82,12 @@ include-package-data = true
"src.database" = ["*.sql"] "src.database" = ["*.sql"]
[tool.black] [tool.black]
line-length = 88 line-length = 120
target-version = ['py38', 'py39', 'py310', 'py311'] target-version = ['py38', 'py39', 'py310', 'py311']
include = '\.pyi?$' include = '\.pyi?$'
[tool.pylint.messages_control] [tool.pylint.messages_control]
max-line-length = 88 max-line-length = 120
[tool.mypy] [tool.mypy]
python_version = "3.8" python_version = "3.8"

View File

@@ -60,9 +60,7 @@ class DatabaseManager:
cursor = conn.cursor() cursor = conn.cursor()
# Check if annotations table exists # Check if annotations table exists
cursor.execute( cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'")
"SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'"
)
if not cursor.fetchone(): if not cursor.fetchone():
# Table doesn't exist yet, no migration needed # Table doesn't exist yet, no migration needed
return return
@@ -242,9 +240,7 @@ class DatabaseManager:
return cursor.lastrowid return cursor.lastrowid
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
# Image already exists, return its ID # Image already exists, return its ID
cursor.execute( cursor.execute("SELECT id FROM images WHERE relative_path = ?", (relative_path,))
"SELECT id FROM images WHERE relative_path = ?", (relative_path,)
)
row = cursor.fetchone() row = cursor.fetchone()
return row["id"] if row else None return row["id"] if row else None
finally: finally:
@@ -255,17 +251,13 @@ class DatabaseManager:
conn = self.get_connection() conn = self.get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute("SELECT * FROM images WHERE relative_path = ?", (relative_path,))
"SELECT * FROM images WHERE relative_path = ?", (relative_path,)
)
row = cursor.fetchone() row = cursor.fetchone()
return dict(row) if row else None return dict(row) if row else None
finally: finally:
conn.close() conn.close()
def get_or_create_image( def get_or_create_image(self, relative_path: str, filename: str, width: int, height: int) -> int:
self, relative_path: str, filename: str, width: int, height: int
) -> int:
"""Get existing image or create new one.""" """Get existing image or create new one."""
existing = self.get_image_by_path(relative_path) existing = self.get_image_by_path(relative_path)
if existing: if existing:
@@ -355,16 +347,8 @@ class DatabaseManager:
bbox[2], bbox[2],
bbox[3], bbox[3],
det["confidence"], det["confidence"],
( (json.dumps(det.get("segmentation_mask")) if det.get("segmentation_mask") else None),
json.dumps(det.get("segmentation_mask")) (json.dumps(det.get("metadata")) if det.get("metadata") else None),
if det.get("segmentation_mask")
else None
),
(
json.dumps(det.get("metadata"))
if det.get("metadata")
else None
),
), ),
) )
conn.commit() conn.commit()
@@ -409,12 +393,13 @@ class DatabaseManager:
if filters: if filters:
conditions = [] conditions = []
for key, value in filters.items(): for key, value in filters.items():
if ( if key.startswith("d.") or key.startswith("i.") or key.startswith("m."):
key.startswith("d.") if "like" in value.lower():
or key.startswith("i.") conditions.append(f"{key} LIKE ?")
or key.startswith("m.") params.append(value.split(" ")[1])
): else:
conditions.append(f"{key} = ?") conditions.append(f"{key} = ?")
params.append(value)
else: else:
conditions.append(f"d.{key} = ?") conditions.append(f"d.{key} = ?")
params.append(value) params.append(value)
@@ -442,18 +427,14 @@ class DatabaseManager:
finally: finally:
conn.close() conn.close()
def get_detections_for_image( def get_detections_for_image(self, image_id: int, model_id: Optional[int] = None) -> List[Dict]:
self, image_id: int, model_id: Optional[int] = None
) -> List[Dict]:
"""Get all detections for a specific image.""" """Get all detections for a specific image."""
filters = {"image_id": image_id} filters = {"image_id": image_id}
if model_id: if model_id:
filters["model_id"] = model_id filters["model_id"] = model_id
return self.get_detections(filters) return self.get_detections(filters)
def delete_detections_for_image( def delete_detections_for_image(self, image_id: int, model_id: Optional[int] = None) -> int:
self, image_id: int, model_id: Optional[int] = None
) -> int:
"""Delete detections tied to a specific image and optional model.""" """Delete detections tied to a specific image and optional model."""
conn = self.get_connection() conn = self.get_connection()
try: try:
@@ -481,6 +462,22 @@ class DatabaseManager:
finally: finally:
conn.close() conn.close()
def delete_all_detections(self) -> int:
"""Delete all detections from the database.
Returns:
Number of rows deleted.
"""
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute("DELETE FROM detections")
conn.commit()
return cursor.rowcount
finally:
conn.close()
# ==================== Statistics Operations ==================== # ==================== Statistics Operations ====================
def get_detection_statistics( def get_detection_statistics(
@@ -524,9 +521,7 @@ class DatabaseManager:
""", """,
params, params,
) )
class_counts = { class_counts = {row["class_name"]: row["count"] for row in cursor.fetchall()}
row["class_name"]: row["count"] for row in cursor.fetchall()
}
# Average confidence # Average confidence
cursor.execute( cursor.execute(
@@ -583,9 +578,7 @@ class DatabaseManager:
# ==================== Export Operations ==================== # ==================== Export Operations ====================
def export_detections_to_csv( def export_detections_to_csv(self, output_path: str, filters: Optional[Dict] = None) -> bool:
self, output_path: str, filters: Optional[Dict] = None
) -> bool:
"""Export detections to CSV file.""" """Export detections to CSV file."""
try: try:
detections = self.get_detections(filters) detections = self.get_detections(filters)
@@ -614,9 +607,7 @@ class DatabaseManager:
for det in detections: for det in detections:
row = {k: det[k] for k in fieldnames if k in det} row = {k: det[k] for k in fieldnames if k in det}
# Convert segmentation mask list to JSON string for CSV # Convert segmentation mask list to JSON string for CSV
if row.get("segmentation_mask") and isinstance( if row.get("segmentation_mask") and isinstance(row["segmentation_mask"], list):
row["segmentation_mask"], list
):
row["segmentation_mask"] = json.dumps(row["segmentation_mask"]) row["segmentation_mask"] = json.dumps(row["segmentation_mask"])
writer.writerow(row) writer.writerow(row)
@@ -625,9 +616,7 @@ class DatabaseManager:
print(f"Error exporting to CSV: {e}") print(f"Error exporting to CSV: {e}")
return False return False
def export_detections_to_json( def export_detections_to_json(self, output_path: str, filters: Optional[Dict] = None) -> bool:
self, output_path: str, filters: Optional[Dict] = None
) -> bool:
"""Export detections to JSON file.""" """Export detections to JSON file."""
try: try:
detections = self.get_detections(filters) detections = self.get_detections(filters)
@@ -785,17 +774,13 @@ class DatabaseManager:
conn = self.get_connection() conn = self.get_connection()
try: try:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute( cursor.execute("SELECT * FROM object_classes WHERE class_name = ?", (class_name,))
"SELECT * FROM object_classes WHERE class_name = ?", (class_name,)
)
row = cursor.fetchone() row = cursor.fetchone()
return dict(row) if row else None return dict(row) if row else None
finally: finally:
conn.close() conn.close()
def add_object_class( def add_object_class(self, class_name: str, color: str, description: Optional[str] = None) -> int:
self, class_name: str, color: str, description: Optional[str] = None
) -> int:
""" """
Add a new object class. Add a new object class.
@@ -928,8 +913,7 @@ class DatabaseManager:
if not split_map[required]: if not split_map[required]:
raise ValueError( raise ValueError(
"Unable to determine %s image directory under %s. Provide it " "Unable to determine %s image directory under %s. Provide it "
"explicitly via the 'splits' argument." "explicitly via the 'splits' argument." % (required, dataset_root_path)
% (required, dataset_root_path)
) )
yaml_splits: Dict[str, str] = {} yaml_splits: Dict[str, str] = {}
@@ -955,11 +939,7 @@ class DatabaseManager:
if yaml_splits.get("test"): if yaml_splits.get("test"):
payload["test"] = yaml_splits["test"] payload["test"] = yaml_splits["test"]
output_path_obj = ( output_path_obj = Path(output_path).expanduser() if output_path else dataset_root_path / "data.yaml"
Path(output_path).expanduser()
if output_path
else dataset_root_path / "data.yaml"
)
output_path_obj.parent.mkdir(parents=True, exist_ok=True) output_path_obj.parent.mkdir(parents=True, exist_ok=True)
with open(output_path_obj, "w", encoding="utf-8") as handle: with open(output_path_obj, "w", encoding="utf-8") as handle:
@@ -1019,15 +999,9 @@ class DatabaseManager:
for split_name, options in patterns.items(): for split_name, options in patterns.items():
for relative in options: for relative in options:
candidate = (dataset_root / relative).resolve() candidate = (dataset_root / relative).resolve()
if ( if candidate.exists() and candidate.is_dir() and self._directory_has_images(candidate):
candidate.exists()
and candidate.is_dir()
and self._directory_has_images(candidate)
):
try: try:
inferred[split_name] = candidate.relative_to( inferred[split_name] = candidate.relative_to(dataset_root).as_posix()
dataset_root
).as_posix()
except ValueError: except ValueError:
inferred[split_name] = candidate.as_posix() inferred[split_name] = candidate.as_posix()
break break

View File

@@ -55,10 +55,7 @@ CREATE TABLE IF NOT EXISTS object_classes (
-- Insert default object classes -- Insert default object classes
INSERT OR IGNORE INTO object_classes (class_name, color, description) VALUES INSERT OR IGNORE INTO object_classes (class_name, color, description) VALUES
('cell', '#FF0000', 'Cell object'), ('terminal', '#FFFF00', 'Axion terminal');
('nucleus', '#00FF00', 'Cell nucleus'),
('mitochondria', '#0000FF', 'Mitochondria'),
('vesicle', '#FFFF00', 'Vesicle');
-- Annotations table: stores manual annotations -- Annotations table: stores manual annotations
CREATE TABLE IF NOT EXISTS annotations ( CREATE TABLE IF NOT EXISTS annotations (

View File

@@ -3,7 +3,7 @@ Results tab for browsing stored detections and visualizing overlays.
""" """
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional from typing import Dict, List, Optional, Tuple
from PySide6.QtWidgets import ( from PySide6.QtWidgets import (
QWidget, QWidget,
@@ -35,9 +35,7 @@ logger = get_logger(__name__)
class ResultsTab(QWidget): class ResultsTab(QWidget):
"""Results tab showing detection history and preview overlays.""" """Results tab showing detection history and preview overlays."""
def __init__( def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
super().__init__(parent) super().__init__(parent)
self.db_manager = db_manager self.db_manager = db_manager
self.config_manager = config_manager self.config_manager = config_manager
@@ -67,28 +65,32 @@ class ResultsTab(QWidget):
self.refresh_btn = QPushButton("Refresh") self.refresh_btn = QPushButton("Refresh")
self.refresh_btn.clicked.connect(self.refresh) self.refresh_btn.clicked.connect(self.refresh)
controls_layout.addWidget(self.refresh_btn) controls_layout.addWidget(self.refresh_btn)
self.delete_all_btn = QPushButton("Delete All Detections")
self.delete_all_btn.setToolTip(
"Permanently delete ALL detections from the database.\n" "This cannot be undone."
)
self.delete_all_btn.clicked.connect(self._delete_all_detections)
controls_layout.addWidget(self.delete_all_btn)
self.export_labels_btn = QPushButton("Export Labels")
self.export_labels_btn.setToolTip(
"Export YOLO .txt labels for the selected image/model run.\n"
"Output path is inferred from the image path (images/ -> labels/)."
)
self.export_labels_btn.clicked.connect(self._export_labels_for_current_selection)
controls_layout.addWidget(self.export_labels_btn)
controls_layout.addStretch() controls_layout.addStretch()
left_layout.addLayout(controls_layout) left_layout.addLayout(controls_layout)
self.results_table = QTableWidget(0, 5) self.results_table = QTableWidget(0, 5)
self.results_table.setHorizontalHeaderLabels( self.results_table.setHorizontalHeaderLabels(["Image", "Model", "Detections", "Classes", "Last Updated"])
["Image", "Model", "Detections", "Classes", "Last Updated"] self.results_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
) self.results_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.Stretch)
self.results_table.horizontalHeader().setSectionResizeMode( self.results_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeToContents)
0, QHeaderView.Stretch self.results_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.Stretch)
) self.results_table.horizontalHeader().setSectionResizeMode(4, QHeaderView.ResizeToContents)
self.results_table.horizontalHeader().setSectionResizeMode(
1, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
2, QHeaderView.ResizeToContents
)
self.results_table.horizontalHeader().setSectionResizeMode(
3, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
4, QHeaderView.ResizeToContents
)
self.results_table.setSelectionBehavior(QAbstractItemView.SelectRows) self.results_table.setSelectionBehavior(QAbstractItemView.SelectRows)
self.results_table.setSelectionMode(QAbstractItemView.SingleSelection) self.results_table.setSelectionMode(QAbstractItemView.SingleSelection)
self.results_table.setEditTriggers(QAbstractItemView.NoEditTriggers) self.results_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
@@ -106,6 +108,8 @@ class ResultsTab(QWidget):
preview_layout = QVBoxLayout() preview_layout = QVBoxLayout()
self.preview_canvas = AnnotationCanvasWidget() self.preview_canvas = AnnotationCanvasWidget()
# Auto-zoom so newly loaded images fill the available preview viewport.
self.preview_canvas.set_auto_fit_to_view(True)
self.preview_canvas.set_polyline_enabled(False) self.preview_canvas.set_polyline_enabled(False)
self.preview_canvas.set_show_bboxes(True) self.preview_canvas.set_show_bboxes(True)
preview_layout.addWidget(self.preview_canvas) preview_layout.addWidget(self.preview_canvas)
@@ -119,9 +123,7 @@ class ResultsTab(QWidget):
self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes) self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
self.show_confidence_checkbox = QCheckBox("Show Confidence") self.show_confidence_checkbox = QCheckBox("Show Confidence")
self.show_confidence_checkbox.setChecked(False) self.show_confidence_checkbox.setChecked(False)
self.show_confidence_checkbox.stateChanged.connect( self.show_confidence_checkbox.stateChanged.connect(self._apply_detection_overlays)
self._apply_detection_overlays
)
toggles_layout.addWidget(self.show_masks_checkbox) toggles_layout.addWidget(self.show_masks_checkbox)
toggles_layout.addWidget(self.show_bboxes_checkbox) toggles_layout.addWidget(self.show_bboxes_checkbox)
toggles_layout.addWidget(self.show_confidence_checkbox) toggles_layout.addWidget(self.show_confidence_checkbox)
@@ -144,6 +146,41 @@ class ResultsTab(QWidget):
layout.addWidget(splitter) layout.addWidget(splitter)
self.setLayout(layout) self.setLayout(layout)
def _delete_all_detections(self):
"""Delete all detections from the database after user confirmation."""
confirm = QMessageBox.warning(
self,
"Delete All Detections",
"This will permanently delete ALL detections from the database.\n\n"
"This action cannot be undone.\n\n"
"Do you want to continue?",
QMessageBox.Yes | QMessageBox.No,
QMessageBox.No,
)
if confirm != QMessageBox.Yes:
return
try:
deleted = self.db_manager.delete_all_detections()
except Exception as exc:
logger.error(f"Failed to delete all detections: {exc}")
QMessageBox.critical(
self,
"Error",
f"Failed to delete detections:\n{exc}",
)
return
QMessageBox.information(
self,
"Delete All Detections",
f"Deleted {deleted} detection(s) from the database.",
)
# Reset UI state.
self.refresh()
def refresh(self): def refresh(self):
"""Refresh the detection list and preview.""" """Refresh the detection list and preview."""
self._load_detection_summary() self._load_detection_summary()
@@ -153,6 +190,8 @@ class ResultsTab(QWidget):
self.current_detections = [] self.current_detections = []
self.preview_canvas.clear() self.preview_canvas.clear()
self.summary_label.setText("Select a detection result to preview.") self.summary_label.setText("Select a detection result to preview.")
if hasattr(self, "export_labels_btn"):
self.export_labels_btn.setEnabled(False)
def _load_detection_summary(self): def _load_detection_summary(self):
"""Load latest detection summaries grouped by image + model.""" """Load latest detection summaries grouped by image + model."""
@@ -169,8 +208,7 @@ class ResultsTab(QWidget):
"image_id": det["image_id"], "image_id": det["image_id"],
"model_id": det["model_id"], "model_id": det["model_id"],
"image_path": det.get("image_path"), "image_path": det.get("image_path"),
"image_filename": det.get("image_filename") "image_filename": det.get("image_filename") or det.get("image_path"),
or det.get("image_path"),
"model_name": det.get("model_name", ""), "model_name": det.get("model_name", ""),
"model_version": det.get("model_version", ""), "model_version": det.get("model_version", ""),
"last_detected": det.get("detected_at"), "last_detected": det.get("detected_at"),
@@ -183,8 +221,7 @@ class ResultsTab(QWidget):
entry["count"] += 1 entry["count"] += 1
if det.get("detected_at") and ( if det.get("detected_at") and (
not entry.get("last_detected") not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected"))
or str(det.get("detected_at")) > str(entry.get("last_detected"))
): ):
entry["last_detected"] = det.get("detected_at") entry["last_detected"] = det.get("detected_at")
if det.get("class_name"): if det.get("class_name"):
@@ -214,9 +251,7 @@ class ResultsTab(QWidget):
for row, entry in enumerate(self.detection_summary): for row, entry in enumerate(self.detection_summary):
model_label = f"{entry['model_name']} {entry['model_version']}".strip() model_label = f"{entry['model_name']} {entry['model_version']}".strip()
class_list = ( class_list = ", ".join(sorted(entry["classes"])) if entry["classes"] else "-"
", ".join(sorted(entry["classes"])) if entry["classes"] else "-"
)
items = [ items = [
QTableWidgetItem(entry.get("image_filename", "")), QTableWidgetItem(entry.get("image_filename", "")),
@@ -276,6 +311,231 @@ class ResultsTab(QWidget):
self._load_detections_for_selection(entry) self._load_detections_for_selection(entry)
self._apply_detection_overlays() self._apply_detection_overlays()
self._update_summary_label(entry) self._update_summary_label(entry)
if hasattr(self, "export_labels_btn"):
self.export_labels_btn.setEnabled(True)
def _export_labels_for_current_selection(self):
"""Export YOLO label file(s) for the currently selected image/model."""
if not self.current_selection:
QMessageBox.information(self, "Export Labels", "Select a detection result first.")
return
entry = self.current_selection
image_path_str = self._resolve_image_path(entry)
if not image_path_str:
QMessageBox.warning(
self,
"Export Labels",
"Unable to locate the image file for this detection; cannot infer labels path.",
)
return
# Ensure we have the detections for the selection.
if not self.current_detections:
self._load_detections_for_selection(entry)
if not self.current_detections:
QMessageBox.information(
self,
"Export Labels",
"No detections found for this image/model selection.",
)
return
image_path = Path(image_path_str)
try:
label_path = self._infer_yolo_label_path(image_path)
except Exception as exc:
logger.error(f"Failed to infer label path for {image_path}: {exc}")
QMessageBox.critical(
self,
"Export Labels",
f"Failed to infer export path for labels:\n{exc}",
)
return
class_map = self._build_detection_class_index_map(self.current_detections)
if not class_map:
QMessageBox.warning(
self,
"Export Labels",
"Unable to build class->index mapping (missing class names).",
)
return
lines_written = 0
skipped = 0
label_path.parent.mkdir(parents=True, exist_ok=True)
try:
with open(label_path, "w", encoding="utf-8") as handle:
print("writing to", label_path)
for det in self.current_detections:
yolo_line = self._format_detection_as_yolo_line(det, class_map)
if not yolo_line:
skipped += 1
continue
handle.write(yolo_line + "\n")
lines_written += 1
except OSError as exc:
logger.error(f"Failed to write labels file {label_path}: {exc}")
QMessageBox.critical(
self,
"Export Labels",
f"Failed to write label file:\n{label_path}\n\n{exc}",
)
return
return
# Optional: write a classes.txt next to the labels root to make the mapping discoverable.
# This is not required by Ultralytics (data.yaml usually holds class names), but helps reuse.
try:
classes_txt = label_path.parent.parent / "classes.txt"
classes_txt.parent.mkdir(parents=True, exist_ok=True)
inv = {idx: name for name, idx in class_map.items()}
with open(classes_txt, "w", encoding="utf-8") as handle:
for idx in range(len(inv)):
handle.write(f"{inv[idx]}\n")
except Exception:
# Non-fatal
pass
QMessageBox.information(
self,
"Export Labels",
f"Exported {lines_written} label line(s) to:\n{label_path}\n\nSkipped {skipped} invalid detection(s).",
)
def _infer_yolo_label_path(self, image_path: Path) -> Path:
"""Infer a YOLO label path from an image path.
If the image lives under an `images/` directory (anywhere in the path), we mirror the
subpath under a sibling `labels/` directory at the same level.
Example:
/dataset/train/images/sub/img.jpg -> /dataset/train/labels/sub/img.txt
"""
resolved = image_path.expanduser().resolve()
# Find the nearest ancestor directory named 'images'
images_dir: Optional[Path] = None
for parent in [resolved.parent, *resolved.parents]:
if parent.name.lower() == "images":
images_dir = parent
break
if images_dir is not None:
rel = resolved.relative_to(images_dir)
labels_dir = images_dir.parent / "labels"
return (labels_dir / rel).with_suffix(".txt")
# Fallback: create a local sibling labels folder next to the image.
return (resolved.parent / "labels" / resolved.name).with_suffix(".txt")
def _build_detection_class_index_map(self, detections: List[Dict]) -> Dict[str, int]:
"""Build a stable class_name -> YOLO class index mapping.
Preference order:
1) Database object_classes table (alphabetical class_name order)
2) Fallback to class_name values present in the detections (alphabetical)
"""
names: List[str] = []
try:
db_classes = self.db_manager.get_object_classes() or []
names = [str(row.get("class_name")) for row in db_classes if row.get("class_name")]
except Exception:
names = []
if not names:
observed = sorted({str(det.get("class_name")) for det in detections if det.get("class_name")})
names = list(observed)
return {name: idx for idx, name in enumerate(names)}
def _format_detection_as_yolo_line(self, det: Dict, class_map: Dict[str, int]) -> Optional[str]:
"""Convert a detection row to a YOLO label line.
- If segmentation_mask is present, exports segmentation polygon format:
class x1 y1 x2 y2 ...
(normalized coordinates)
- Otherwise exports bbox format:
class x_center y_center width height
(normalized coordinates)
"""
class_name = det.get("class_name")
if not class_name or class_name not in class_map:
return None
class_idx = class_map[class_name]
mask = det.get("segmentation_mask")
polygon = self._convert_segmentation_mask_to_polygon(mask)
if polygon:
coords = " ".join(f"{value:.6f}" for value in polygon)
return f"{class_idx} {coords}".strip()
bbox = self._convert_bbox_to_yolo_xywh(det)
if bbox is None:
return None
x_center, y_center, width, height = bbox
return f"{class_idx} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}"
def _convert_bbox_to_yolo_xywh(self, det: Dict) -> Optional[Tuple[float, float, float, float]]:
"""Convert stored xyxy (normalized) bbox to YOLO xywh (normalized)."""
x_min = det.get("x_min")
y_min = det.get("y_min")
x_max = det.get("x_max")
y_max = det.get("y_max")
if any(v is None for v in (x_min, y_min, x_max, y_max)):
return None
try:
x_min_f = self._clamp01(float(x_min))
y_min_f = self._clamp01(float(y_min))
x_max_f = self._clamp01(float(x_max))
y_max_f = self._clamp01(float(y_max))
except (TypeError, ValueError):
return None
width = max(0.0, x_max_f - x_min_f)
height = max(0.0, y_max_f - y_min_f)
if width <= 0.0 or height <= 0.0:
return None
x_center = x_min_f + width / 2.0
y_center = y_min_f + height / 2.0
return x_center, y_center, width, height
def _convert_segmentation_mask_to_polygon(self, mask_data) -> List[float]:
"""Convert stored segmentation_mask [[x,y], ...] to YOLO polygon coords [x1,y1,...]."""
if not isinstance(mask_data, list):
return []
coords: List[float] = []
for point in mask_data:
if not isinstance(point, (list, tuple)) or len(point) < 2:
continue
try:
x = self._clamp01(float(point[0]))
y = self._clamp01(float(point[1]))
except (TypeError, ValueError):
continue
coords.extend([x, y])
# Need at least 3 points => 6 values.
return coords if len(coords) >= 6 else []
@staticmethod
def _clamp01(value: float) -> float:
if value < 0.0:
return 0.0
if value > 1.0:
return 1.0
return value
def _load_detections_for_selection(self, entry: Dict): def _load_detections_for_selection(self, entry: Dict):
"""Load detection records for the selected image/model pair.""" """Load detection records for the selected image/model pair."""

View File

@@ -10,6 +10,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
import yaml import yaml
import numpy as np
from PySide6.QtCore import Qt, QThread, Signal from PySide6.QtCore import Qt, QThread, Signal
from PySide6.QtWidgets import ( from PySide6.QtWidgets import (
QWidget, QWidget,
@@ -34,7 +35,7 @@ from PySide6.QtWidgets import (
from src.database.db_manager import DatabaseManager from src.database.db_manager import DatabaseManager
from src.model.yolo_wrapper import YOLOWrapper from src.model.yolo_wrapper import YOLOWrapper
from src.utils.config_manager import ConfigManager from src.utils.config_manager import ConfigManager
from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range from src.utils.image import Image
from src.utils.logger import get_logger from src.utils.logger import get_logger
@@ -91,10 +92,7 @@ class TrainingWorker(QThread):
}, },
} }
] ]
computed_total = sum( computed_total = sum(max(0, int((stage.get("params") or {}).get("epochs", 0))) for stage in self.stage_plan)
max(0, int((stage.get("params") or {}).get("epochs", 0)))
for stage in self.stage_plan
)
self.total_epochs = total_epochs if total_epochs else computed_total or epochs self.total_epochs = total_epochs if total_epochs else computed_total or epochs
self._stop_requested = False self._stop_requested = False
@@ -201,9 +199,7 @@ class TrainingWorker(QThread):
class TrainingTab(QWidget): class TrainingTab(QWidget):
"""Training tab for model training.""" """Training tab for model training."""
def __init__( def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
super().__init__(parent) super().__init__(parent)
self.db_manager = db_manager self.db_manager = db_manager
self.config_manager = config_manager self.config_manager = config_manager
@@ -337,18 +333,14 @@ class TrainingTab(QWidget):
self.model_version_edit = QLineEdit("v1") self.model_version_edit = QLineEdit("v1")
form_layout.addRow("Version:", self.model_version_edit) form_layout.addRow("Version:", self.model_version_edit)
default_base_model = self.config_manager.get( default_base_model = self.config_manager.get("models.default_base_model", "yolov8s-seg.pt")
"models.default_base_model", "yolov8s-seg.pt"
)
base_model_choices = self.config_manager.get("models.base_model_choices", []) base_model_choices = self.config_manager.get("models.base_model_choices", [])
self.base_model_combo = QComboBox() self.base_model_combo = QComboBox()
self.base_model_combo.addItem("Custom path…", "") self.base_model_combo.addItem("Custom path…", "")
for choice in base_model_choices: for choice in base_model_choices:
self.base_model_combo.addItem(choice, choice) self.base_model_combo.addItem(choice, choice)
self.base_model_combo.currentIndexChanged.connect( self.base_model_combo.currentIndexChanged.connect(self._on_base_model_preset_changed)
self._on_base_model_preset_changed
)
form_layout.addRow("Base Model Preset:", self.base_model_combo) form_layout.addRow("Base Model Preset:", self.base_model_combo)
base_model_layout = QHBoxLayout() base_model_layout = QHBoxLayout()
@@ -434,12 +426,8 @@ class TrainingTab(QWidget):
group_layout = QVBoxLayout() group_layout = QVBoxLayout()
self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune") self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune")
two_stage_defaults = ( two_stage_defaults = training_defaults.get("two_stage", {}) if training_defaults else {}
training_defaults.get("two_stage", {}) if training_defaults else {} self.two_stage_checkbox.setChecked(bool(two_stage_defaults.get("enabled", False)))
)
self.two_stage_checkbox.setChecked(
bool(two_stage_defaults.get("enabled", False))
)
self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled) self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled)
group_layout.addWidget(self.two_stage_checkbox) group_layout.addWidget(self.two_stage_checkbox)
@@ -501,9 +489,7 @@ class TrainingTab(QWidget):
stage2_group.setLayout(stage2_form) stage2_group.setLayout(stage2_form)
controls_layout.addWidget(stage2_group) controls_layout.addWidget(stage2_group)
helper_label = QLabel( helper_label = QLabel("When enabled, staged hyperparameters override the global epochs/patience/lr.")
"When enabled, staged hyperparameters override the global epochs/patience/lr."
)
helper_label.setWordWrap(True) helper_label.setWordWrap(True)
controls_layout.addWidget(helper_label) controls_layout.addWidget(helper_label)
@@ -548,9 +534,7 @@ class TrainingTab(QWidget):
if normalized == preset_value: if normalized == preset_value:
target_index = idx target_index = idx
break break
if normalized.endswith(f"/{preset_value}") or normalized.endswith( if normalized.endswith(f"/{preset_value}") or normalized.endswith(f"\\{preset_value}"):
f"\\{preset_value}"
):
target_index = idx target_index = idx
break break
self.base_model_combo.blockSignals(True) self.base_model_combo.blockSignals(True)
@@ -638,9 +622,7 @@ class TrainingTab(QWidget):
def _browse_dataset(self): def _browse_dataset(self):
"""Open a file dialog to manually select data.yaml.""" """Open a file dialog to manually select data.yaml."""
start_dir = self.config_manager.get( start_dir = self.config_manager.get("training.last_dataset_dir", "data/datasets")
"training.last_dataset_dir", "data/datasets"
)
start_path = Path(start_dir).expanduser() start_path = Path(start_dir).expanduser()
if not start_path.exists(): if not start_path.exists():
start_path = Path.cwd() start_path = Path.cwd()
@@ -676,9 +658,7 @@ class TrainingTab(QWidget):
return return
except Exception as exc: except Exception as exc:
logger.exception("Unexpected error while generating data.yaml") logger.exception("Unexpected error while generating data.yaml")
self._display_dataset_error( self._display_dataset_error("Unexpected error while generating data.yaml. Check logs for details.")
"Unexpected error while generating data.yaml. Check logs for details."
)
QMessageBox.critical( QMessageBox.critical(
self, self,
"data.yaml Generation Failed", "data.yaml Generation Failed",
@@ -755,13 +735,9 @@ class TrainingTab(QWidget):
self.selected_dataset = info self.selected_dataset = info
self.dataset_root_label.setText(info["root"]) # type: ignore[arg-type] self.dataset_root_label.setText(info["root"]) # type: ignore[arg-type]
self.train_count_label.setText( self.train_count_label.setText(self._format_split_info(info["splits"].get("train")))
self._format_split_info(info["splits"].get("train"))
)
self.val_count_label.setText(self._format_split_info(info["splits"].get("val"))) self.val_count_label.setText(self._format_split_info(info["splits"].get("val")))
self.test_count_label.setText( self.test_count_label.setText(self._format_split_info(info["splits"].get("test")))
self._format_split_info(info["splits"].get("test"))
)
self.num_classes_label.setText(str(info["num_classes"])) self.num_classes_label.setText(str(info["num_classes"]))
class_names = ", ".join(info["class_names"]) or "" class_names = ", ".join(info["class_names"]) or ""
self.class_names_label.setText(class_names) self.class_names_label.setText(class_names)
@@ -815,18 +791,12 @@ class TrainingTab(QWidget):
if split_path.exists(): if split_path.exists():
split_info["count"] = self._count_images(split_path) split_info["count"] = self._count_images(split_path)
if split_info["count"] == 0: if split_info["count"] == 0:
warnings.append( warnings.append(f"No images found for {split_name} split at {split_path}")
f"No images found for {split_name} split at {split_path}"
)
else: else:
warnings.append( warnings.append(f"{split_name.capitalize()} path does not exist: {split_path}")
f"{split_name.capitalize()} path does not exist: {split_path}"
)
else: else:
if split_name in ("train", "val"): if split_name in ("train", "val"):
warnings.append( warnings.append(f"{split_name.capitalize()} split missing in data.yaml")
f"{split_name.capitalize()} split missing in data.yaml"
)
splits[split_name] = split_info splits[split_name] = split_info
names_list = self._normalize_class_names(data.get("names")) names_list = self._normalize_class_names(data.get("names"))
@@ -844,9 +814,7 @@ class TrainingTab(QWidget):
if not names_list and nc_value: if not names_list and nc_value:
names_list = [f"class_{idx}" for idx in range(int(nc_value))] names_list = [f"class_{idx}" for idx in range(int(nc_value))]
elif nc_value and len(names_list) not in (0, int(nc_value)): elif nc_value and len(names_list) not in (0, int(nc_value)):
warnings.append( warnings.append(f"Number of class names ({len(names_list)}) does not match nc={nc_value}")
f"Number of class names ({len(names_list)}) does not match nc={nc_value}"
)
dataset_name = data.get("name") or base_path.name dataset_name = data.get("name") or base_path.name
@@ -898,16 +866,12 @@ class TrainingTab(QWidget):
class_index_map = self._build_class_index_map(dataset_info) class_index_map = self._build_class_index_map(dataset_info)
if not class_index_map: if not class_index_map:
self._append_training_log( self._append_training_log("Skipping label export: dataset classes do not match database entries.")
"Skipping label export: dataset classes do not match database entries."
)
return return
dataset_root_str = dataset_info.get("root") dataset_root_str = dataset_info.get("root")
dataset_yaml_path = dataset_info.get("yaml_path") dataset_yaml_path = dataset_info.get("yaml_path")
dataset_yaml = ( dataset_yaml = Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
)
dataset_root: Optional[Path] dataset_root: Optional[Path]
if dataset_root_str: if dataset_root_str:
dataset_root = Path(dataset_root_str).resolve() dataset_root = Path(dataset_root_str).resolve()
@@ -941,7 +905,9 @@ class TrainingTab(QWidget):
if stats["registered_images"]: if stats["registered_images"]:
message += f" {stats['registered_images']} image(s) had database-backed annotations." message += f" {stats['registered_images']} image(s) had database-backed annotations."
if stats["missing_records"]: if stats["missing_records"]:
message += f" {stats['missing_records']} image(s) had no database entry; empty label files were written." message += (
f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
)
split_messages.append(message) split_messages.append(message)
for msg in split_messages: for msg in split_messages:
@@ -973,9 +939,7 @@ class TrainingTab(QWidget):
continue continue
processed_images += 1 processed_images += 1
label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix( label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(".txt")
".txt"
)
label_path.parent.mkdir(parents=True, exist_ok=True) label_path.parent.mkdir(parents=True, exist_ok=True)
found, annotation_entries = self._fetch_annotations_for_image( found, annotation_entries = self._fetch_annotations_for_image(
@@ -991,25 +955,23 @@ class TrainingTab(QWidget):
for entry in annotation_entries: for entry in annotation_entries:
polygon = entry.get("polygon") or [] polygon = entry.get("polygon") or []
if polygon: if polygon:
print(image_file, polygon[:4], polygon[-2:], entry.get("bbox"))
# coords = " ".join(f"{value:.6f}" for value in entry.get("bbox"))
# coords += " "
coords = " ".join(f"{value:.6f}" for value in polygon) coords = " ".join(f"{value:.6f}" for value in polygon)
handle.write(f"{entry['class_idx']} {coords}\n") handle.write(f"{entry['class_idx']} {coords}\n")
annotations_written += 1 annotations_written += 1
elif entry.get("bbox"): elif entry.get("bbox"):
x_center, y_center, width, height = entry["bbox"] x_center, y_center, width, height = entry["bbox"]
handle.write( handle.write(f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")
f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
)
annotations_written += 1 annotations_written += 1
total_annotations += annotations_written total_annotations += annotations_written
cache_reset_root = labels_dir.parent cache_reset_root = labels_dir.parent
self._invalidate_split_cache(cache_reset_root) self._invalidate_split_cache(cache_reset_root)
if processed_images == 0: if processed_images == 0:
self._append_training_log( self._append_training_log(f"[{split_name}] No images found to export labels for.")
f"[{split_name}] No images found to export labels for."
)
return None return None
return { return {
@@ -1135,6 +1097,10 @@ class TrainingTab(QWidget):
xs.append(x_val) xs.append(x_val)
ys.append(y_val) ys.append(y_val)
if any(np.abs(np.array(coords[:2]) - np.array(coords[-2:])) < 1e-5):
print("Closing polygon")
coords.extend(coords[:2])
if len(coords) < 6: if len(coords) < 6:
continue continue
@@ -1147,6 +1113,11 @@ class TrainingTab(QWidget):
+ abs((min(ys) if ys else 0.0) - y_min) + abs((min(ys) if ys else 0.0) - y_min)
+ abs((max(ys) if ys else 0.0) - y_max) + abs((max(ys) if ys else 0.0) - y_max)
) )
width = max(0.0, x_max - x_min)
height = max(0.0, y_max - y_min)
x_center = x_min + width / 2.0
y_center = y_min + height / 2.0
score = (x_center, y_center, width, height)
candidates.append((score, coords)) candidates.append((score, coords))
@@ -1164,13 +1135,10 @@ class TrainingTab(QWidget):
return 1.0 return 1.0
return value return value
def _prepare_dataset_for_training( def _prepare_dataset_for_training(self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None) -> Path:
self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None
) -> Path:
dataset_info = dataset_info or ( dataset_info = dataset_info or (
self.selected_dataset self.selected_dataset
if self.selected_dataset if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
else self._parse_dataset_yaml(dataset_yaml) else self._parse_dataset_yaml(dataset_yaml)
) )
@@ -1189,14 +1157,10 @@ class TrainingTab(QWidget):
cache_root = self._get_rgb_cache_root(dataset_yaml) cache_root = self._get_rgb_cache_root(dataset_yaml)
rgb_yaml = cache_root / "data.yaml" rgb_yaml = cache_root / "data.yaml"
if rgb_yaml.exists(): if rgb_yaml.exists():
self._append_training_log( self._append_training_log(f"Detected grayscale dataset; reusing RGB cache at {cache_root}")
f"Detected grayscale dataset; reusing RGB cache at {cache_root}"
)
return rgb_yaml return rgb_yaml
self._append_training_log( self._append_training_log(f"Detected grayscale dataset; creating RGB cache at {cache_root}")
f"Detected grayscale dataset; creating RGB cache at {cache_root}"
)
self._build_rgb_dataset(cache_root, dataset_info) self._build_rgb_dataset(cache_root, dataset_info)
return rgb_yaml return rgb_yaml
@@ -1303,6 +1267,14 @@ class TrainingTab(QWidget):
sample_image = self._find_first_image(images_dir) sample_image = self._find_first_image(images_dir)
if not sample_image: if not sample_image:
return False return False
# Do not force an RGB cache for TIFF datasets.
# We handle grayscale/16-bit TIFFs via runtime Ultralytics patches that:
# - load TIFFs with `tifffile`
# - replicate grayscale to 3 channels without quantization
# - normalize uint16 correctly during training
if sample_image.suffix.lower() in {".tif", ".tiff"}:
return False
try: try:
img = Image(sample_image) img = Image(sample_image)
return img.pil_image.mode.upper() != "RGB" return img.pil_image.mode.upper() != "RGB"
@@ -1368,7 +1340,7 @@ class TrainingTab(QWidget):
img_obj = Image(src) img_obj = Image(src)
pil_img = img_obj.pil_image pil_img = img_obj.pil_image
if len(pil_img.getbands()) == 1: if len(pil_img.getbands()) == 1:
rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img) rgb_img = img_obj.convert_grayscale_to_rgb_preserve_range()
else: else:
rgb_img = pil_img.convert("RGB") rgb_img = pil_img.convert("RGB")
rgb_img.save(dst) rgb_img.save(dst)
@@ -1455,15 +1427,12 @@ class TrainingTab(QWidget):
dataset_path = Path(dataset_yaml).expanduser() dataset_path = Path(dataset_yaml).expanduser()
if not dataset_path.exists(): if not dataset_path.exists():
QMessageBox.warning( QMessageBox.warning(self, "Invalid Dataset", "Selected data.yaml file does not exist.")
self, "Invalid Dataset", "Selected data.yaml file does not exist."
)
return return
dataset_info = ( dataset_info = (
self.selected_dataset self.selected_dataset
if self.selected_dataset if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_path)
and self.selected_dataset.get("yaml_path") == str(dataset_path)
else self._parse_dataset_yaml(dataset_path) else self._parse_dataset_yaml(dataset_path)
) )
@@ -1472,16 +1441,12 @@ class TrainingTab(QWidget):
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info) dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
if dataset_to_use != dataset_path: if dataset_to_use != dataset_path:
self._append_training_log( self._append_training_log(f"Using RGB-converted dataset at {dataset_to_use.parent}")
f"Using RGB-converted dataset at {dataset_to_use.parent}"
)
params = self._collect_training_params() params = self._collect_training_params()
stage_plan = self._compose_stage_plan(params) stage_plan = self._compose_stage_plan(params)
params["stage_plan"] = stage_plan params["stage_plan"] = stage_plan
total_planned_epochs = ( total_planned_epochs = self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
)
params["total_planned_epochs"] = total_planned_epochs params["total_planned_epochs"] = total_planned_epochs
self._active_training_params = params self._active_training_params = params
self._training_cancelled = False self._training_cancelled = False
@@ -1490,9 +1455,7 @@ class TrainingTab(QWidget):
self._append_training_log("Two-stage fine-tuning schedule:") self._append_training_log("Two-stage fine-tuning schedule:")
self._log_stage_plan(stage_plan) self._log_stage_plan(stage_plan)
self._append_training_log( self._append_training_log(f"Starting training run '{params['run_name']}' using {params['base_model']}")
f"Starting training run '{params['run_name']}' using {params['base_model']}"
)
self.training_progress_bar.setVisible(True) self.training_progress_bar.setVisible(True)
self.training_progress_bar.setMaximum(max(1, total_planned_epochs)) self.training_progress_bar.setMaximum(max(1, total_planned_epochs))
@@ -1520,9 +1483,7 @@ class TrainingTab(QWidget):
def _stop_training(self): def _stop_training(self):
if self.training_worker and self.training_worker.isRunning(): if self.training_worker and self.training_worker.isRunning():
self._training_cancelled = True self._training_cancelled = True
self._append_training_log( self._append_training_log("Stop requested. Waiting for the current epoch to finish...")
"Stop requested. Waiting for the current epoch to finish..."
)
self.training_worker.stop() self.training_worker.stop()
self.stop_training_button.setEnabled(False) self.stop_training_button.setEnabled(False)
@@ -1558,9 +1519,7 @@ class TrainingTab(QWidget):
if worker.isRunning(): if worker.isRunning():
if not worker.wait(wait_timeout_ms): if not worker.wait(wait_timeout_ms):
logger.warning( logger.warning("Training worker did not finish within %sms", wait_timeout_ms)
"Training worker did not finish within %sms", wait_timeout_ms
)
worker.deleteLater() worker.deleteLater()
@@ -1577,16 +1536,12 @@ class TrainingTab(QWidget):
self._set_training_state(False) self._set_training_state(False)
self.training_progress_bar.setVisible(False) self.training_progress_bar.setVisible(False)
def _on_training_progress( def _on_training_progress(self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]):
self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]
):
self.training_progress_bar.setMaximum(total_epochs) self.training_progress_bar.setMaximum(total_epochs)
self.training_progress_bar.setValue(current_epoch) self.training_progress_bar.setValue(current_epoch)
parts = [f"Epoch {current_epoch}/{total_epochs}"] parts = [f"Epoch {current_epoch}/{total_epochs}"]
if metrics: if metrics:
metric_text = ", ".join( metric_text = ", ".join(f"{key}: {value:.4f}" for key, value in metrics.items())
f"{key}: {value:.4f}" for key, value in metrics.items()
)
parts.append(metric_text) parts.append(metric_text)
self._append_training_log(" | ".join(parts)) self._append_training_log(" | ".join(parts))
@@ -1613,9 +1568,7 @@ class TrainingTab(QWidget):
f"Model trained but not registered: {exc}", f"Model trained but not registered: {exc}",
) )
else: else:
QMessageBox.information( QMessageBox.information(self, "Training Complete", "Training finished successfully.")
self, "Training Complete", "Training finished successfully."
)
def _on_training_error(self, message: str): def _on_training_error(self, message: str):
self._cleanup_training_worker() self._cleanup_training_worker()
@@ -1661,9 +1614,7 @@ class TrainingTab(QWidget):
metrics=results.get("metrics"), metrics=results.get("metrics"),
) )
self._append_training_log( self._append_training_log(f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}")
f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}"
)
self._active_training_params = None self._active_training_params = None
def _set_training_state(self, is_training: bool): def _set_training_state(self, is_training: bool):
@@ -1706,9 +1657,7 @@ class TrainingTab(QWidget):
def _browse_save_dir(self): def _browse_save_dir(self):
start_path = self.save_dir_edit.text().strip() or "data/models" start_path = self.save_dir_edit.text().strip() or "data/models"
directory = QFileDialog.getExistingDirectory( directory = QFileDialog.getExistingDirectory(self, "Select Save Directory", start_path)
self, "Select Save Directory", start_path
)
if directory: if directory:
self.save_dir_edit.setText(directory) self.save_dir_edit.setText(directory)

View File

@@ -2,45 +2,554 @@
Validation tab for the microscopy object detection application. Validation tab for the microscopy object detection application.
""" """
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple
from PySide6.QtCore import Qt, QSize
from PySide6.QtGui import QPainter, QPixmap
from PySide6.QtWidgets import (
QWidget,
QVBoxLayout,
QLabel,
QGroupBox,
QHBoxLayout,
QPushButton,
QComboBox,
QFormLayout,
QScrollArea,
QGridLayout,
QFrame,
QTableWidget,
QTableWidgetItem,
QHeaderView,
QSplitter,
QListWidget,
QListWidgetItem,
QAbstractItemView,
QGraphicsView,
QGraphicsScene,
QGraphicsPixmapItem,
)
from src.database.db_manager import DatabaseManager from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
logger = get_logger(__name__)
@dataclass(frozen=True)
class _PlotItem:
label: str
path: Path
class _ZoomableImageView(QGraphicsView):
"""Zoomable image viewer.
- Mouse wheel: zoom in/out
- Left mouse drag: pan (ScrollHandDrag)
"""
def __init__(self, parent: Optional[QWidget] = None):
super().__init__(parent)
self._scene = QGraphicsScene(self)
self.setScene(self._scene)
self._pixmap_item = QGraphicsPixmapItem()
self._scene.addItem(self._pixmap_item)
# QGraphicsView render hints are QPainter.RenderHints.
self.setRenderHints(self.renderHints() | QPainter.RenderHint.SmoothPixmapTransform)
self.setDragMode(QGraphicsView.DragMode.ScrollHandDrag)
self.setTransformationAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse)
self.setResizeAnchor(QGraphicsView.ViewportAnchor.AnchorUnderMouse)
self._has_pixmap = False
def clear(self) -> None:
self._pixmap_item.setPixmap(QPixmap())
self._scene.setSceneRect(0, 0, 1, 1)
self.resetTransform()
self._has_pixmap = False
def set_pixmap(self, pixmap: QPixmap, *, fit: bool = True) -> None:
self._pixmap_item.setPixmap(pixmap)
self._scene.setSceneRect(pixmap.rect())
self._has_pixmap = not pixmap.isNull()
self.resetTransform()
if fit and self._has_pixmap:
self.fitInView(self._pixmap_item, Qt.AspectRatioMode.KeepAspectRatio)
def wheelEvent(self, event) -> None: # type: ignore[override]
if not self._has_pixmap:
return
zoom_in_factor = 1.25
zoom_out_factor = 1.0 / zoom_in_factor
factor = zoom_in_factor if event.angleDelta().y() > 0 else zoom_out_factor
self.scale(factor, factor)
class ValidationTab(QWidget): class ValidationTab(QWidget):
"""Validation tab placeholder.""" """Validation tab that shows stored validation metrics + plots for a selected model."""
def __init__( def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
super().__init__(parent) super().__init__(parent)
self.db_manager = db_manager self.db_manager = db_manager
self.config_manager = config_manager self.config_manager = config_manager
self._models: List[Dict[str, Any]] = []
self._selected_model_id: Optional[int] = None
self._plot_widgets: List[QWidget] = []
self._plot_items: List[_PlotItem] = []
self._setup_ui() self._setup_ui()
self.refresh()
def _setup_ui(self): def _setup_ui(self):
"""Setup user interface.""" """Setup user interface."""
layout = QVBoxLayout() layout = QVBoxLayout(self)
group = QGroupBox("Validation") # ===== Header controls =====
group_layout = QVBoxLayout() header = QGroupBox("Validation")
label = QLabel( header_layout = QVBoxLayout()
"Validation functionality will be implemented here.\n\n" header_row = QHBoxLayout()
"Features:\n"
"- Model validation\n"
"- Metrics visualization\n"
"- Confusion matrix\n"
"- Precision-Recall curves"
)
group_layout.addWidget(label)
group.setLayout(group_layout)
layout.addWidget(group) header_row.addWidget(QLabel("Select model:"))
layout.addStretch()
self.setLayout(layout) self.model_combo = QComboBox()
self.model_combo.setMinimumWidth(420)
self.model_combo.currentIndexChanged.connect(self._on_model_selected)
header_row.addWidget(self.model_combo, 1)
self.refresh_btn = QPushButton("Refresh")
self.refresh_btn.clicked.connect(self.refresh)
header_row.addWidget(self.refresh_btn)
header_row.addStretch()
header_layout.addLayout(header_row)
self.header_status = QLabel("No models loaded.")
self.header_status.setWordWrap(True)
header_layout.addWidget(self.header_status)
header.setLayout(header_layout)
layout.addWidget(header)
# ===== Metrics =====
metrics_group = QGroupBox("Validation Metrics")
metrics_layout = QVBoxLayout()
self.metrics_form = QFormLayout()
self.metric_labels: Dict[str, QLabel] = {}
for key in ("mAP50", "mAP50-95", "precision", "recall", "fitness"):
value_label = QLabel("")
value_label.setTextInteractionFlags(Qt.TextSelectableByMouse)
self.metric_labels[key] = value_label
self.metrics_form.addRow(f"{key}:", value_label)
metrics_layout.addLayout(self.metrics_form)
self.per_class_table = QTableWidget(0, 3)
self.per_class_table.setHorizontalHeaderLabels(["Class", "AP", "AP50"])
self.per_class_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
self.per_class_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeToContents)
self.per_class_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeToContents)
self.per_class_table.setEditTriggers(QTableWidget.NoEditTriggers)
self.per_class_table.setMinimumHeight(160)
metrics_layout.addWidget(QLabel("Per-class metrics (if available):"))
metrics_layout.addWidget(self.per_class_table)
metrics_group.setLayout(metrics_layout)
layout.addWidget(metrics_group)
# ===== Plots =====
plots_group = QGroupBox("Validation Plots")
plots_layout = QVBoxLayout()
self.plots_status = QLabel("Select a model to see validation plots.")
self.plots_status.setWordWrap(True)
plots_layout.addWidget(self.plots_status)
self.plots_splitter = QSplitter(Qt.Orientation.Horizontal)
# Left: selected image viewer
left_widget = QWidget()
left_layout = QVBoxLayout(left_widget)
left_layout.setContentsMargins(0, 0, 0, 0)
self.selected_plot_title = QLabel("No image selected.")
self.selected_plot_title.setWordWrap(True)
self.selected_plot_title.setTextInteractionFlags(Qt.TextSelectableByMouse)
left_layout.addWidget(self.selected_plot_title)
self.plot_view = _ZoomableImageView()
self.plot_view.setMinimumHeight(360)
left_layout.addWidget(self.plot_view, 1)
self.selected_plot_path = QLabel("")
self.selected_plot_path.setWordWrap(True)
self.selected_plot_path.setStyleSheet("color: #888;")
self.selected_plot_path.setTextInteractionFlags(Qt.TextSelectableByMouse)
left_layout.addWidget(self.selected_plot_path)
# Right: scrollable list
right_widget = QWidget()
right_layout = QVBoxLayout(right_widget)
right_layout.setContentsMargins(0, 0, 0, 0)
right_layout.addWidget(QLabel("Images:"))
self.plots_list = QListWidget()
self.plots_list.setSelectionMode(QAbstractItemView.SelectionMode.SingleSelection)
self.plots_list.setIconSize(QSize(160, 160))
self.plots_list.itemSelectionChanged.connect(self._on_plot_item_selected)
right_layout.addWidget(self.plots_list, 1)
self.plots_splitter.addWidget(left_widget)
self.plots_splitter.addWidget(right_widget)
self.plots_splitter.setStretchFactor(0, 3)
self.plots_splitter.setStretchFactor(1, 1)
plots_layout.addWidget(self.plots_splitter, 1)
plots_group.setLayout(plots_layout)
layout.addWidget(plots_group, 1)
layout.addStretch(0)
self._clear_metrics()
self._clear_plots()
# ==================== Public API ====================
def refresh(self): def refresh(self):
"""Refresh the tab.""" """Refresh the tab."""
self._load_models()
self._populate_model_combo()
self._restore_or_select_default_model()
# ==================== Internal: models ====================
def _load_models(self) -> None:
try:
self._models = self.db_manager.get_models() or []
except Exception as exc:
logger.error("Failed to load models: %s", exc)
self._models = []
def _populate_model_combo(self) -> None:
self.model_combo.blockSignals(True)
self.model_combo.clear()
self.model_combo.addItem("Select a model…", None)
for model in self._models:
model_id = model.get("id")
name = (model.get("model_name") or "").strip()
version = (model.get("model_version") or "").strip()
created_at = model.get("created_at")
label = f"{name} {version}".strip()
if created_at:
label = f"{label} ({created_at})"
self.model_combo.addItem(label, model_id)
self.model_combo.blockSignals(False)
if self._models:
self.header_status.setText(f"Loaded {len(self._models)} model(s).")
else:
self.header_status.setText("No models found. Train a model first.")
def _restore_or_select_default_model(self) -> None:
if not self._models:
self._selected_model_id = None
self._clear_metrics()
self._clear_plots()
return
# Keep selection if still present.
if self._selected_model_id is not None:
for idx in range(1, self.model_combo.count()):
if self.model_combo.itemData(idx) == self._selected_model_id:
self.model_combo.setCurrentIndex(idx)
return
# Otherwise select the newest model (top of get_models ORDER BY created_at DESC).
first_model_id = self.model_combo.itemData(1) if self.model_combo.count() > 1 else None
if first_model_id is not None:
self.model_combo.setCurrentIndex(1)
def _on_model_selected(self, index: int) -> None:
model_id = self.model_combo.itemData(index)
if not model_id:
self._selected_model_id = None
self._clear_metrics()
self._clear_plots()
self.plots_status.setText("Select a model to see validation plots.")
return
self._selected_model_id = int(model_id)
model = self._get_model_by_id(self._selected_model_id)
if not model:
self._clear_metrics()
self._clear_plots()
self.plots_status.setText("Selected model not found.")
return
self._render_metrics(model)
self._render_plots(model)
def _get_model_by_id(self, model_id: int) -> Optional[Dict[str, Any]]:
for model in self._models:
if model.get("id") == model_id:
return model
try:
return self.db_manager.get_model_by_id(model_id)
except Exception:
return None
# ==================== Internal: metrics ====================
def _clear_metrics(self) -> None:
for label in self.metric_labels.values():
label.setText("")
self.per_class_table.setRowCount(0)
def _render_metrics(self, model: Dict[str, Any]) -> None:
self._clear_metrics()
metrics: Dict[str, Any] = model.get("metrics") or {}
# Training tab stores metrics under results['metrics'] in training results payload.
if isinstance(metrics, dict) and "metrics" in metrics and isinstance(metrics.get("metrics"), dict):
metrics = metrics.get("metrics") or {}
def set_metric(key: str, value: Any) -> None:
if key not in self.metric_labels:
return
if value is None:
self.metric_labels[key].setText("")
return
try:
self.metric_labels[key].setText(f"{float(value):.4f}")
except Exception:
self.metric_labels[key].setText(str(value))
set_metric("mAP50", metrics.get("mAP50"))
set_metric("mAP50-95", metrics.get("mAP50-95") or metrics.get("mAP50_95") or metrics.get("mAP50-95"))
set_metric("precision", metrics.get("precision"))
set_metric("recall", metrics.get("recall"))
set_metric("fitness", metrics.get("fitness"))
# Optional per-class metrics
class_metrics = metrics.get("class_metrics") if isinstance(metrics, dict) else None
if isinstance(class_metrics, dict) and class_metrics:
items = sorted(class_metrics.items(), key=lambda kv: str(kv[0]))
self.per_class_table.setRowCount(len(items))
for row, (cls_name, cls_stats) in enumerate(items):
ap = (cls_stats or {}).get("ap")
ap50 = (cls_stats or {}).get("ap50")
self.per_class_table.setItem(row, 0, QTableWidgetItem(str(cls_name)))
self.per_class_table.setItem(row, 1, QTableWidgetItem(self._format_float(ap)))
self.per_class_table.setItem(row, 2, QTableWidgetItem(self._format_float(ap50)))
else:
self.per_class_table.setRowCount(0)
@staticmethod
def _format_float(value: Any) -> str:
if value is None:
return ""
try:
return f"{float(value):.4f}"
except Exception:
return str(value)
# ==================== Internal: plots ====================
def _clear_plots(self) -> None:
# Remove legacy grid widgets (from the initial implementation).
for widget in self._plot_widgets:
widget.setParent(None)
widget.deleteLater()
self._plot_widgets = []
self._plot_items = []
if hasattr(self, "plots_list"):
self.plots_list.blockSignals(True)
self.plots_list.clear()
self.plots_list.blockSignals(False)
if hasattr(self, "plot_view"):
self.plot_view.clear()
if hasattr(self, "selected_plot_title"):
self.selected_plot_title.setText("No image selected.")
if hasattr(self, "selected_plot_path"):
self.selected_plot_path.setText("")
def _render_plots(self, model: Dict[str, Any]) -> None:
self._clear_plots()
plot_dirs = self._infer_run_directories(model)
plot_items = self._discover_plot_items(plot_dirs)
if not plot_items:
dirs_text = "\n".join(str(p) for p in plot_dirs if p)
self.plots_status.setText(
"No validation plot images found for this model.\n\n"
"Searched directories:\n" + (dirs_text or "(none)")
)
return
self._plot_items = list(plot_items)
self.plots_status.setText(f"Found {len(plot_items)} plot image(s). Select one to view/zoom.")
self.plots_list.blockSignals(True)
self.plots_list.clear()
for idx, item in enumerate(self._plot_items):
qitem = QListWidgetItem(item.label)
qitem.setData(Qt.ItemDataRole.UserRole, idx)
pix = QPixmap(str(item.path))
if not pix.isNull():
thumb = pix.scaled(
self.plots_list.iconSize(),
Qt.AspectRatioMode.KeepAspectRatio,
Qt.TransformationMode.SmoothTransformation,
)
qitem.setIcon(thumb)
self.plots_list.addItem(qitem)
self.plots_list.blockSignals(False)
if self.plots_list.count() > 0:
self.plots_list.setCurrentRow(0)
def _on_plot_item_selected(self) -> None:
if not self._plot_items:
return
selected = self.plots_list.selectedItems()
if not selected:
return
idx = selected[0].data(Qt.ItemDataRole.UserRole)
try:
idx_int = int(idx)
except Exception:
return
if idx_int < 0 or idx_int >= len(self._plot_items):
return
plot = self._plot_items[idx_int]
self.selected_plot_title.setText(plot.label)
self.selected_plot_path.setText(str(plot.path))
pix = QPixmap(str(plot.path))
if pix.isNull():
self.plot_view.clear()
return
self.plot_view.set_pixmap(pix, fit=True)
def _infer_run_directories(self, model: Dict[str, Any]) -> List[Path]:
dirs: List[Path] = []
# 1) Infer from model_path: .../<run>/weights/best.pt -> <run>
model_path = model.get("model_path")
if model_path:
try:
p = Path(str(model_path)).expanduser()
if p.name.lower().endswith(".pt"):
# If it lives under weights/, use parent.parent.
if p.parent.name == "weights" and p.parent.parent.exists():
dirs.append(p.parent.parent)
elif p.parent.exists():
dirs.append(p.parent)
except Exception:
pass pass
# 2) Look at training_params.stage_results[].results.save_dir
training_params = model.get("training_params") or {}
stage_results = None
if isinstance(training_params, dict):
stage_results = training_params.get("stage_results")
if isinstance(stage_results, list):
for stage in stage_results:
results = (stage or {}).get("results")
save_dir = (results or {}).get("save_dir") if isinstance(results, dict) else None
if save_dir:
try:
save_path = Path(str(save_dir)).expanduser()
if save_path.exists():
dirs.append(save_path)
except Exception:
continue
# Deduplicate while preserving order.
unique: List[Path] = []
seen: set[str] = set()
for d in dirs:
try:
resolved = str(d.resolve())
except Exception:
resolved = str(d)
if resolved not in seen and d.exists() and d.is_dir():
seen.add(resolved)
unique.append(d)
return unique
def _discover_plot_items(self, directories: Sequence[Path]) -> List[_PlotItem]:
# Prefer canonical Ultralytics filenames first, then fall back to any png/jpg.
preferred_names = [
"results.png",
"results.jpg",
"confusion_matrix.png",
"confusion_matrix_normalized.png",
"labels.jpg",
"labels.png",
"BoxPR_curve.png",
"BoxP_curve.png",
"BoxR_curve.png",
"BoxF1_curve.png",
"MaskPR_curve.png",
"MaskP_curve.png",
"MaskR_curve.png",
"MaskF1_curve.png",
"val_batch0_pred.jpg",
"val_batch0_labels.jpg",
]
found: List[_PlotItem] = []
seen: set[str] = set()
for d in directories:
# 1) Preferred
for name in preferred_names:
p = d / name
if p.exists() and p.is_file():
key = str(p)
if key in seen:
continue
seen.add(key)
found.append(_PlotItem(label=f"{name} (from {d.name})", path=p))
# 2) Curated globs
for pattern in ("train_batch*.jpg", "val_batch*.jpg", "*curve*.png"):
for p in sorted(d.glob(pattern)):
if not p.is_file():
continue
key = str(p)
if key in seen:
continue
seen.add(key)
found.append(_PlotItem(label=f"{p.name} (from {d.name})", path=p))
# 3) Fallback: any top-level png/jpg (excluding weights dir contents)
for ext in ("*.png", "*.jpg", "*.jpeg", "*.webp"):
for p in sorted(d.glob(ext)):
if not p.is_file():
continue
key = str(p)
if key in seen:
continue
seen.add(key)
found.append(_PlotItem(label=f"{p.name} (from {d.name})", path=p))
# Keep list bounded to avoid UI overload for huge runs.
return found[:60]

View File

@@ -18,7 +18,7 @@ from PySide6.QtGui import (
QPaintEvent, QPaintEvent,
QPolygonF, QPolygonF,
) )
from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QPointF, QRect from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QPointF, QRect, QTimer
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from src.utils.image import Image, ImageLoadError from src.utils.image import Image, ImageLoadError
@@ -79,9 +79,7 @@ def rdp(points: List[Tuple[float, float]], epsilon: float) -> List[Tuple[float,
return [start, end] return [start, end]
def simplify_polyline( def simplify_polyline(points: List[Tuple[float, float]], epsilon: float) -> List[Tuple[float, float]]:
points: List[Tuple[float, float]], epsilon: float
) -> List[Tuple[float, float]]:
""" """
Simplify a polyline with RDP while preserving closure semantics. Simplify a polyline with RDP while preserving closure semantics.
@@ -145,6 +143,10 @@ class AnnotationCanvasWidget(QWidget):
self.zoom_step = 0.1 self.zoom_step = 0.1
self.zoom_wheel_step = 0.15 self.zoom_wheel_step = 0.15
# Auto-fit behavior (opt-in): when enabled, newly loaded images (and resizes)
# will scale to fill the available viewport while preserving aspect ratio.
self._auto_fit_to_view: bool = False
# Drawing / interaction state # Drawing / interaction state
self.is_drawing = False self.is_drawing = False
self.polyline_enabled = False self.polyline_enabled = False
@@ -175,6 +177,35 @@ class AnnotationCanvasWidget(QWidget):
self._setup_ui() self._setup_ui()
def set_auto_fit_to_view(self, enabled: bool):
"""Enable/disable automatic zoom-to-fit behavior."""
self._auto_fit_to_view = bool(enabled)
if self._auto_fit_to_view and self.original_pixmap is not None:
QTimer.singleShot(0, self.fit_to_view)
def fit_to_view(self, padding_px: int = 6):
"""Zoom the image so it fits the scroll area's viewport (aspect preserved)."""
if self.original_pixmap is None:
return
viewport = self.scroll_area.viewport().size()
available_w = max(1, int(viewport.width()) - int(padding_px))
available_h = max(1, int(viewport.height()) - int(padding_px))
img_w = max(1, int(self.original_pixmap.width()))
img_h = max(1, int(self.original_pixmap.height()))
scale_w = available_w / img_w
scale_h = available_h / img_h
new_scale = min(scale_w, scale_h)
new_scale = max(self.zoom_min, min(self.zoom_max, float(new_scale)))
if abs(new_scale - self.zoom_scale) < 1e-4:
return
self.zoom_scale = new_scale
self._apply_zoom()
def _setup_ui(self): def _setup_ui(self):
"""Setup user interface.""" """Setup user interface."""
layout = QVBoxLayout() layout = QVBoxLayout()
@@ -187,9 +218,7 @@ class AnnotationCanvasWidget(QWidget):
self.canvas_label = QLabel("No image loaded") self.canvas_label = QLabel("No image loaded")
self.canvas_label.setAlignment(Qt.AlignCenter) self.canvas_label.setAlignment(Qt.AlignCenter)
self.canvas_label.setStyleSheet( self.canvas_label.setStyleSheet("QLabel { background-color: #2b2b2b; color: #888; }")
"QLabel { background-color: #2b2b2b; color: #888; }"
)
self.canvas_label.setScaledContents(False) self.canvas_label.setScaledContents(False)
self.canvas_label.setMouseTracking(True) self.canvas_label.setMouseTracking(True)
@@ -212,9 +241,18 @@ class AnnotationCanvasWidget(QWidget):
self.zoom_scale = 1.0 self.zoom_scale = 1.0
self.clear_annotations() self.clear_annotations()
self._display_image() self._display_image()
logger.debug(
f"Loaded image into annotation canvas: {image.width}x{image.height}" # Defer fit-to-view until the widget has a valid viewport size.
) if self._auto_fit_to_view:
QTimer.singleShot(0, self.fit_to_view)
logger.debug(f"Loaded image into annotation canvas: {image.width}x{image.height}")
def resizeEvent(self, event):
"""Optionally keep the image fitted when the widget is resized."""
super().resizeEvent(event)
if self._auto_fit_to_view and self.original_pixmap is not None:
QTimer.singleShot(0, self.fit_to_view)
def clear(self): def clear(self):
"""Clear the displayed image and all annotations.""" """Clear the displayed image and all annotations."""
@@ -250,12 +288,10 @@ class AnnotationCanvasWidget(QWidget):
# Get image data in a format compatible with Qt # Get image data in a format compatible with Qt
if self.current_image.channels in (3, 4): if self.current_image.channels in (3, 4):
image_data = self.current_image.get_rgb() image_data = self.current_image.get_rgb()
height, width = image_data.shape[:2]
else: else:
image_data = self.current_image.get_grayscale() image_data = self.current_image.get_qt_rgb()
height, width = image_data.shape
image_data = np.ascontiguousarray(image_data) height, width = image_data.shape[:2]
bytes_per_line = image_data.strides[0] bytes_per_line = image_data.strides[0]
qimage = QImage( qimage = QImage(
@@ -263,7 +299,7 @@ class AnnotationCanvasWidget(QWidget):
width, width,
height, height,
bytes_per_line, bytes_per_line,
self.current_image.qtimage_format, QImage.Format_RGBX32FPx4, # self.current_image.qtimage_format,
).copy() # Copy so Qt owns the buffer even after numpy array goes out of scope ).copy() # Copy so Qt owns the buffer even after numpy array goes out of scope
self.original_pixmap = QPixmap.fromImage(qimage) self.original_pixmap = QPixmap.fromImage(qimage)
@@ -291,22 +327,14 @@ class AnnotationCanvasWidget(QWidget):
scaled_width, scaled_width,
scaled_height, scaled_height,
Qt.KeepAspectRatio, Qt.KeepAspectRatio,
( (Qt.SmoothTransformation if self.zoom_scale >= 1.0 else Qt.FastTransformation),
Qt.SmoothTransformation
if self.zoom_scale >= 1.0
else Qt.FastTransformation
),
) )
scaled_annotations = self.annotation_pixmap.scaled( scaled_annotations = self.annotation_pixmap.scaled(
scaled_width, scaled_width,
scaled_height, scaled_height,
Qt.KeepAspectRatio, Qt.KeepAspectRatio,
( (Qt.SmoothTransformation if self.zoom_scale >= 1.0 else Qt.FastTransformation),
Qt.SmoothTransformation
if self.zoom_scale >= 1.0
else Qt.FastTransformation
),
) )
# Composite image and annotations # Composite image and annotations
@@ -392,16 +420,11 @@ class AnnotationCanvasWidget(QWidget):
y = (pos.y() - offset_y) / self.zoom_scale y = (pos.y() - offset_y) / self.zoom_scale
# Check bounds # Check bounds
if ( if 0 <= x < self.original_pixmap.width() and 0 <= y < self.original_pixmap.height():
0 <= x < self.original_pixmap.width()
and 0 <= y < self.original_pixmap.height()
):
return (int(x), int(y)) return (int(x), int(y))
return None return None
def _find_polyline_at( def _find_polyline_at(self, img_x: float, img_y: float, threshold_px: float = 5.0) -> Optional[int]:
self, img_x: float, img_y: float, threshold_px: float = 5.0
) -> Optional[int]:
""" """
Find index of polyline whose geometry is within threshold_px of (img_x, img_y). Find index of polyline whose geometry is within threshold_px of (img_x, img_y).
Returns the index in self.polylines, or None if none is close enough. Returns the index in self.polylines, or None if none is close enough.
@@ -423,9 +446,7 @@ class AnnotationCanvasWidget(QWidget):
# Precise distance to all segments # Precise distance to all segments
for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]): for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]):
d = perpendicular_distance( d = perpendicular_distance((img_x, img_y), (float(x1), float(y1)), (float(x2), float(y2)))
(img_x, img_y), (float(x1), float(y1)), (float(x2), float(y2))
)
if d < best_dist: if d < best_dist:
best_dist = d best_dist = d
best_index = idx best_index = idx
@@ -626,11 +647,7 @@ class AnnotationCanvasWidget(QWidget):
def mouseMoveEvent(self, event: QMouseEvent): def mouseMoveEvent(self, event: QMouseEvent):
"""Handle mouse move events for drawing.""" """Handle mouse move events for drawing."""
if ( if not self.is_drawing or not self.polyline_enabled or self.annotation_pixmap is None:
not self.is_drawing
or not self.polyline_enabled
or self.annotation_pixmap is None
):
super().mouseMoveEvent(event) super().mouseMoveEvent(event)
return return
@@ -690,15 +707,10 @@ class AnnotationCanvasWidget(QWidget):
if len(simplified) >= 2: if len(simplified) >= 2:
# Store polyline and redraw all annotations # Store polyline and redraw all annotations
self._add_polyline( self._add_polyline(simplified, self.polyline_pen_color, self.polyline_pen_width)
simplified, self.polyline_pen_color, self.polyline_pen_width
)
# Convert to normalized coordinates for metadata + signal # Convert to normalized coordinates for metadata + signal
normalized_stroke = [ normalized_stroke = [self._image_to_normalized_coords(int(x), int(y)) for (x, y) in simplified]
self._image_to_normalized_coords(int(x), int(y))
for (x, y) in simplified
]
self.all_strokes.append( self.all_strokes.append(
{ {
"points": normalized_stroke, "points": normalized_stroke,
@@ -711,8 +723,7 @@ class AnnotationCanvasWidget(QWidget):
# Emit signal with normalized coordinates # Emit signal with normalized coordinates
self.annotation_drawn.emit(normalized_stroke) self.annotation_drawn.emit(normalized_stroke)
logger.debug( logger.debug(
f"Completed stroke with {len(simplified)} points " f"Completed stroke with {len(simplified)} points " f"(normalized len={len(normalized_stroke)})"
f"(normalized len={len(normalized_stroke)})"
) )
self.current_stroke = [] self.current_stroke = []
@@ -752,9 +763,7 @@ class AnnotationCanvasWidget(QWidget):
# Store polyline as [y_norm, x_norm] to match DB convention and # Store polyline as [y_norm, x_norm] to match DB convention and
# the expectations of draw_saved_polyline(). # the expectations of draw_saved_polyline().
normalized_polyline = [ normalized_polyline = [[y / img_height, x / img_width] for (x, y) in polyline]
[y / img_height, x / img_width] for (x, y) in polyline
]
logger.debug( logger.debug(
f"Polyline {idx}: {len(polyline)} points, " f"Polyline {idx}: {len(polyline)} points, "
@@ -774,7 +783,7 @@ class AnnotationCanvasWidget(QWidget):
self, self,
polyline: List[List[float]], polyline: List[List[float]],
color: str, color: str,
width: int = 3, width: int = 1,
annotation_id: Optional[int] = None, annotation_id: Optional[int] = None,
): ):
""" """
@@ -812,17 +821,13 @@ class AnnotationCanvasWidget(QWidget):
# Store and redraw using common pipeline # Store and redraw using common pipeline
pen_color = QColor(color) pen_color = QColor(color)
pen_color.setAlpha(128) # Add semi-transparency pen_color.setAlpha(255) # Add semi-transparency
self._add_polyline(img_coords, pen_color, width, annotation_id=annotation_id) self._add_polyline(img_coords, pen_color, width, annotation_id=annotation_id)
# Store in all_strokes for consistency (uses normalized coordinates) # Store in all_strokes for consistency (uses normalized coordinates)
self.all_strokes.append( self.all_strokes.append({"points": polyline, "color": color, "alpha": 255, "width": width})
{"points": polyline, "color": color, "alpha": 128, "width": width}
)
logger.debug( logger.debug(f"Drew saved polyline with {len(polyline)} points in color {color}")
f"Drew saved polyline with {len(polyline)} points in color {color}"
)
def draw_saved_bbox( def draw_saved_bbox(
self, self,
@@ -846,9 +851,7 @@ class AnnotationCanvasWidget(QWidget):
return return
if len(bbox) != 4: if len(bbox) != 4:
logger.warning( logger.warning(f"Invalid bounding box format: expected 4 values, got {len(bbox)}")
f"Invalid bounding box format: expected 4 values, got {len(bbox)}"
)
return return
# Convert normalized coordinates to image coordinates (for logging/debug) # Convert normalized coordinates to image coordinates (for logging/debug)
@@ -869,15 +872,11 @@ class AnnotationCanvasWidget(QWidget):
# in _redraw_annotations() together with all polylines. # in _redraw_annotations() together with all polylines.
pen_color = QColor(color) pen_color = QColor(color)
pen_color.setAlpha(128) # Add semi-transparency pen_color.setAlpha(128) # Add semi-transparency
self.bboxes.append( self.bboxes.append([float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)])
[float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)]
)
self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label}) self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label})
# Store in all_strokes for consistency # Store in all_strokes for consistency
self.all_strokes.append( self.all_strokes.append({"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label})
{"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label}
)
# Redraw overlay (polylines + all bounding boxes) # Redraw overlay (polylines + all bounding boxes)
self._redraw_annotations() self._redraw_annotations()

View File

@@ -1,16 +1,21 @@
""" """YOLO model wrapper for the microscopy object detection application.
YOLO model wrapper for the microscopy object detection application.
Provides a clean interface to YOLOv8 for training, validation, and inference. Notes on 16-bit TIFF support:
- Ultralytics training defaults assume 8-bit images and normalize by dividing by 255.
- This project can patch Ultralytics at runtime to decode TIFFs via `tifffile` and
normalize `uint16` correctly.
See [`apply_ultralytics_16bit_tiff_patches()`](src/utils/ultralytics_16bit_patch.py:1).
""" """
from ultralytics import YOLO
from pathlib import Path from pathlib import Path
from typing import Optional, List, Dict, Callable, Any from typing import Optional, List, Dict, Callable, Any
import torch import torch
import tempfile import tempfile
import os import os
from src.utils.image import Image, convert_grayscale_to_rgb_preserve_range from src.utils.image import Image
from src.utils.logger import get_logger from src.utils.logger import get_logger
from src.utils.ultralytics_16bit_patch import apply_ultralytics_16bit_tiff_patches
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -31,6 +36,9 @@ class YOLOWrapper:
self.device = "cuda" if torch.cuda.is_available() else "cpu" self.device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"YOLOWrapper initialized with device: {self.device}") logger.info(f"YOLOWrapper initialized with device: {self.device}")
# Apply Ultralytics runtime patches early (before first import/instantiation of YOLO datasets/trainers).
apply_ultralytics_16bit_tiff_patches()
def load_model(self) -> bool: def load_model(self) -> bool:
""" """
Load YOLO model from path. Load YOLO model from path.
@@ -40,6 +48,9 @@ class YOLOWrapper:
""" """
try: try:
logger.info(f"Loading YOLO model from {self.model_path}") logger.info(f"Loading YOLO model from {self.model_path}")
# Import YOLO lazily to ensure runtime patches are applied first.
from ultralytics import YOLO
self.model = YOLO(self.model_path) self.model = YOLO(self.model_path)
self.model.to(self.device) self.model.to(self.device)
logger.info("Model loaded successfully") logger.info("Model loaded successfully")
@@ -85,9 +96,17 @@ class YOLOWrapper:
try: try:
logger.info(f"Starting training: {name}") logger.info(f"Starting training: {name}")
logger.info( logger.info(f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}")
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
) # Defaults for 16-bit safety: disable augmentations that force uint8 and HSV ops that assume 0..255.
# Users can override by passing explicit kwargs.
kwargs.setdefault("mosaic", 0.0)
kwargs.setdefault("mixup", 0.0)
kwargs.setdefault("cutmix", 0.0)
kwargs.setdefault("copy_paste", 0.0)
kwargs.setdefault("hsv_h", 0.0)
kwargs.setdefault("hsv_s", 0.0)
kwargs.setdefault("hsv_v", 0.0)
# Train the model # Train the model
results = self.model.train( results = self.model.train(
@@ -128,9 +147,7 @@ class YOLOWrapper:
try: try:
logger.info(f"Starting validation on {split} split") logger.info(f"Starting validation on {split} split")
results = self.model.val( results = self.model.val(data=data_yaml, split=split, device=self.device, **kwargs)
data=data_yaml, split=split, device=self.device, **kwargs
)
logger.info("Validation completed successfully") logger.info("Validation completed successfully")
return self._format_validation_results(results) return self._format_validation_results(results)
@@ -169,17 +186,18 @@ class YOLOWrapper:
raise RuntimeError(f"Failed to load model from {self.model_path}") raise RuntimeError(f"Failed to load model from {self.model_path}")
prepared_source, cleanup_path = self._prepare_source(source) prepared_source, cleanup_path = self._prepare_source(source)
imgsz = 1088
try: try:
logger.info(f"Running inference on {source}") logger.info(f"Running inference on {source} -> prepared_source {prepared_source}")
results = self.model.predict( results = self.model.predict(
source=prepared_source, source=source,
conf=conf, conf=conf,
iou=iou, iou=iou,
save=save, save=save,
save_txt=save_txt, save_txt=save_txt,
save_conf=save_conf, save_conf=save_conf,
device=self.device, device=self.device,
imgsz=imgsz,
**kwargs, **kwargs,
) )
@@ -195,13 +213,9 @@ class YOLOWrapper:
try: try:
os.remove(cleanup_path) os.remove(cleanup_path)
except OSError as cleanup_error: except OSError as cleanup_error:
logger.warning( logger.warning(f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}")
f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}"
)
def export( def export(self, format: str = "onnx", output_path: Optional[str] = None, **kwargs) -> str:
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
) -> str:
""" """
Export model to different format. Export model to different format.
@@ -236,21 +250,13 @@ class YOLOWrapper:
if source_path.is_file(): if source_path.is_file():
try: try:
img_obj = Image(source_path) img_obj = Image(source_path)
pil_img = img_obj.pil_image
if len(pil_img.getbands()) == 1:
rgb_img = convert_grayscale_to_rgb_preserve_range(pil_img)
else:
rgb_img = pil_img.convert("RGB")
suffix = source_path.suffix or ".png" suffix = source_path.suffix or ".png"
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False) tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
tmp_path = tmp.name tmp_path = tmp.name
tmp.close() tmp.close()
rgb_img.save(tmp_path) img_obj.save(tmp_path)
cleanup_path = tmp_path cleanup_path = tmp_path
logger.info( logger.info(f"Converted image {source_path} to RGB for inference at {tmp_path}")
f"Converted image {source_path} to RGB for inference at {tmp_path}"
)
return tmp_path, cleanup_path return tmp_path, cleanup_path
except Exception as convert_error: except Exception as convert_error:
logger.warning( logger.warning(
@@ -263,9 +269,7 @@ class YOLOWrapper:
"""Format training results into dictionary.""" """Format training results into dictionary."""
try: try:
# Get the results dict # Get the results dict
results_dict = ( results_dict = results.results_dict if hasattr(results, "results_dict") else {}
results.results_dict if hasattr(results, "results_dict") else {}
)
formatted = { formatted = {
"success": True, "success": True,
@@ -298,9 +302,7 @@ class YOLOWrapper:
"mAP50-95": float(box_metrics.map), "mAP50-95": float(box_metrics.map),
"precision": float(box_metrics.mp), "precision": float(box_metrics.mp),
"recall": float(box_metrics.mr), "recall": float(box_metrics.mr),
"fitness": ( "fitness": (float(results.fitness) if hasattr(results, "fitness") else 0.0),
float(results.fitness) if hasattr(results, "fitness") else 0.0
),
} }
# Add per-class metrics if available # Add per-class metrics if available
@@ -310,11 +312,7 @@ class YOLOWrapper:
if idx < len(box_metrics.ap): if idx < len(box_metrics.ap):
class_metrics[name] = { class_metrics[name] = {
"ap": float(box_metrics.ap[idx]), "ap": float(box_metrics.ap[idx]),
"ap50": ( "ap50": (float(box_metrics.ap50[idx]) if hasattr(box_metrics, "ap50") else 0.0),
float(box_metrics.ap50[idx])
if hasattr(box_metrics, "ap50")
else 0.0
),
} }
formatted["class_metrics"] = class_metrics formatted["class_metrics"] = class_metrics
@@ -347,21 +345,15 @@ class YOLOWrapper:
"class_id": int(boxes.cls[i]), "class_id": int(boxes.cls[i]),
"class_name": result.names[int(boxes.cls[i])], "class_name": result.names[int(boxes.cls[i])],
"confidence": float(boxes.conf[i]), "confidence": float(boxes.conf[i]),
"bbox_normalized": [ "bbox_normalized": [float(v) for v in xyxyn], # [x_min, y_min, x_max, y_max]
float(v) for v in xyxyn "bbox_absolute": [float(v) for v in boxes.xyxy[i].cpu().numpy()], # Absolute pixels
], # [x_min, y_min, x_max, y_max]
"bbox_absolute": [
float(v) for v in boxes.xyxy[i].cpu().numpy()
], # Absolute pixels
} }
# Extract segmentation mask if available # Extract segmentation mask if available
if has_masks: if has_masks:
try: try:
# Get the mask for this detection # Get the mask for this detection
mask_data = result.masks.xy[ mask_data = result.masks.xy[i] # Polygon coordinates in absolute pixels
i
] # Polygon coordinates in absolute pixels
# Convert to normalized coordinates # Convert to normalized coordinates
if len(mask_data) > 0: if len(mask_data) > 0:
@@ -374,9 +366,7 @@ class YOLOWrapper:
else: else:
detection["segmentation_mask"] = None detection["segmentation_mask"] = None
except Exception as mask_error: except Exception as mask_error:
logger.warning( logger.warning(f"Error extracting mask for detection {i}: {mask_error}")
f"Error extracting mask for detection {i}: {mask_error}"
)
detection["segmentation_mask"] = None detection["segmentation_mask"] = None
else: else:
detection["segmentation_mask"] = None detection["segmentation_mask"] = None
@@ -390,9 +380,7 @@ class YOLOWrapper:
return [] return []
@staticmethod @staticmethod
def convert_bbox_format( def convert_bbox_format(bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy") -> List[float]:
bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
) -> List[float]:
""" """
Convert bounding box between formats. Convert bounding box between formats.

View File

@@ -54,7 +54,7 @@ class ConfigManager:
"models_directory": "data/models", "models_directory": "data/models",
"base_model_choices": [ "base_model_choices": [
"yolov8s-seg.pt", "yolov8s-seg.pt",
"yolov11s-seg.pt", "yolo11s-seg.pt",
], ],
}, },
"training": { "training": {
@@ -225,6 +225,4 @@ class ConfigManager:
def get_allowed_extensions(self) -> list: def get_allowed_extensions(self) -> list:
"""Get list of allowed image file extensions.""" """Get list of allowed image file extensions."""
return self.get( return self.get("image_repository.allowed_extensions", Image.SUPPORTED_EXTENSIONS)
"image_repository.allowed_extensions", Image.SUPPORTED_EXTENSIONS
)

View File

@@ -0,0 +1,103 @@
import numpy as np
from pathlib import Path
from skimage.draw import polygon
from tifffile import TiffFile
from src.database.db_manager import DatabaseManager
def read_image(image_path: Path) -> np.ndarray:
metadata = {}
with TiffFile(image_path) as tif:
image = tif.asarray()
metadata = tif.imagej_metadata
return image, metadata
def main():
polygon_vertices = np.array([[10, 10], [50, 10], [50, 50], [10, 50]])
image = np.zeros((100, 100), dtype=np.uint8)
rr, cc = polygon(polygon_vertices[:, 0], polygon_vertices[:, 1])
image[rr, cc] = 255
if __name__ == "__main__":
db = DatabaseManager()
model_name = "c17"
model_id = db.get_models(filters={"model_name": model_name})[0]["id"]
print(f"Model name {model_name}, id {model_id}")
detections = db.get_detections(filters={"model_id": model_id})
file_stems = set()
for detection in detections:
file_stems.add(detection["image_filename"].split("_")[0])
print("Files:", file_stems)
for stem in file_stems:
print(stem)
detections = db.get_detections(filters={"model_id": model_id, "i.filename": f"LIKE %{stem}%"})
annotations = []
for detection in detections:
source_path = Path(detection["metadata"]["source_path"])
image, metadata = read_image(source_path)
offset = np.array(list(map(int, metadata["tile_section"].split(","))))[::-1]
scale = np.array(list(map(int, metadata["patch_size"].split(","))))[::-1]
# tile_size = np.array(list(map(int, metadata["tile_size"].split(","))))
segmentation = np.array(detection["segmentation_mask"]) # * tile_size
# print(source_path, image, metadata, segmentation.shape)
# print(offset)
# print(scale)
# print(segmentation)
# segmentation = (segmentation + offset * tile_size) / (tile_size * scale)
segmentation = (segmentation + offset) / scale
yolo_annotation = f"{detection['metadata']['class_id']} " + " ".join(
[f"{x:.6f} {y:.6f}" for x, y in segmentation]
)
annotations.append(yolo_annotation)
# print(segmentation)
# print(yolo_annotation)
# aa
print(
" ",
detection["model_name"],
detection["image_id"],
detection["image_filename"],
source_path,
metadata["label_path"],
)
# section_i_section_j = detection["image_filename"].split("_")[1].split(".")[0]
# print(" ", section_i_section_j)
label_path = metadata["label_path"]
print(" ", label_path)
with open(label_path, "w") as f:
f.write("\n".join(annotations))
exit()
for detection in detections:
print(detection["model_name"], detection["image_id"], detection["image_filename"])
print(detections[0])
# polygon_vertices = np.array([[10, 10], [50, 10], [50, 50], [10, 50]])
# image = np.zeros((100, 100), dtype=np.uint8)
# rr, cc = polygon(polygon_vertices[:, 0], polygon_vertices[:, 1])
# image[rr, cc] = 255
# import matplotlib.pyplot as plt
# plt.imshow(image, cmap='gray')
# plt.show()

View File

@@ -6,16 +6,55 @@ import cv2
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
from PIL import Image as PILImage
from src.utils.logger import get_logger from src.utils.logger import get_logger
from src.utils.file_utils import validate_file_path, is_image_file from src.utils.file_utils import validate_file_path, is_image_file
from PySide6.QtGui import QImage from PySide6.QtGui import QImage
from tifffile import imread, imwrite
logger = get_logger(__name__) logger = get_logger(__name__)
def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
"""
Convert a grayscale image to a pseudo-RGB image using a gamma correction.
Args:
arr: Input grayscale image as numpy array
Returns:
Pseudo-RGB image as numpy array
"""
if arr.ndim != 2:
raise ValueError("Input array must be a grayscale image with shape (H, W)")
a1 = arr.copy().astype(np.float32)
a1 -= np.percentile(a1, 2)
a1[a1 < 0] = 0
p999 = np.percentile(a1, 99.9)
a1[a1 > p999] = p999
a1 /= a1.max()
if 1:
a2 = a1.copy()
a2 = a2**gamma
a2 /= a2.max()
a3 = a1.copy()
p9999 = np.percentile(a3, 99.99)
a3[a3 > p9999] = p9999
a3 /= a3.max()
# return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
out = np.stack([a1, a2, a3], axis=0)
# print(any(np.isnan(out).flatten()))
return out
class ImageLoadError(Exception): class ImageLoadError(Exception):
"""Exception raised when an image cannot be loaded.""" """Exception raised when an image cannot be loaded."""
@@ -54,7 +93,6 @@ class Image:
""" """
self.path = Path(image_path) self.path = Path(image_path)
self._data: Optional[np.ndarray] = None self._data: Optional[np.ndarray] = None
self._pil_image: Optional[PILImage.Image] = None
self._width: int = 0 self._width: int = 0
self._height: int = 0 self._height: int = 0
self._channels: int = 0 self._channels: int = 0
@@ -80,11 +118,14 @@ class Image:
if not is_image_file(str(self.path), self.SUPPORTED_EXTENSIONS): if not is_image_file(str(self.path), self.SUPPORTED_EXTENSIONS):
ext = self.path.suffix.lower() ext = self.path.suffix.lower()
raise ImageLoadError( raise ImageLoadError(
f"Unsupported image format: {ext}. " f"Unsupported image format: {ext}. " f"Supported formats: {', '.join(self.SUPPORTED_EXTENSIONS)}"
f"Supported formats: {', '.join(self.SUPPORTED_EXTENSIONS)}"
) )
try: try:
if self.path.suffix.lower() in [".tif", ".tiff"]:
self._data = imread(str(self.path))
else:
# raise NotImplementedError("RGB is not implemented")
# Load with OpenCV (returns BGR format) # Load with OpenCV (returns BGR format)
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED) self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
@@ -92,23 +133,19 @@ class Image:
raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}") raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}")
# Extract metadata # Extract metadata
# print(self._data.shape)
if len(self._data.shape) == 2:
self._height, self._width = self._data.shape[:2] self._height, self._width = self._data.shape[:2]
self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1 self._channels = 1
else:
self._height, self._width = self._data.shape[1:]
self._channels = self._data.shape[0]
# self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1
self._format = self.path.suffix.lower().lstrip(".") self._format = self.path.suffix.lower().lstrip(".")
self._size_bytes = self.path.stat().st_size self._size_bytes = self.path.stat().st_size
self._dtype = self._data.dtype self._dtype = self._data.dtype
# Load PIL version for compatibility (convert BGR to RGB) if 0:
if self._channels == 3:
rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
self._pil_image = PILImage.fromarray(rgb_data)
elif self._channels == 4:
rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
self._pil_image = PILImage.fromarray(rgba_data)
else:
# Grayscale
self._pil_image = PILImage.fromarray(self._data)
logger.info( logger.info(
f"Successfully loaded image: {self.path.name} " f"Successfully loaded image: {self.path.name} "
f"({self._width}x{self._height}, {self._channels} channels, " f"({self._width}x{self._height}, {self._channels} channels, "
@@ -131,18 +168,6 @@ class Image:
raise ImageLoadError("Image data not available") raise ImageLoadError("Image data not available")
return self._data return self._data
@property
def pil_image(self) -> PILImage.Image:
"""
Get image data as PIL Image (RGB or grayscale).
Returns:
PIL Image object
"""
if self._pil_image is None:
raise ImageLoadError("PIL image not available")
return self._pil_image
@property @property
def width(self) -> int: def width(self) -> int:
"""Get image width in pixels.""" """Get image width in pixels."""
@@ -187,6 +212,7 @@ class Image:
@property @property
def dtype(self) -> np.dtype: def dtype(self) -> np.dtype:
"""Get the data type of the image array.""" """Get the data type of the image array."""
if self._dtype is None: if self._dtype is None:
raise ImageLoadError("Image dtype not available") raise ImageLoadError("Image dtype not available")
return self._dtype return self._dtype
@@ -206,8 +232,10 @@ class Image:
elif self._channels == 1: elif self._channels == 1:
if self._dtype == np.uint16: if self._dtype == np.uint16:
return QImage.Format_Grayscale16 return QImage.Format_Grayscale16
else: elif self._dtype == np.uint8:
return QImage.Format_Grayscale8 return QImage.Format_Grayscale8
elif self._dtype == np.float32:
return QImage.Format_BGR30
else: else:
raise ImageLoadError(f"Unsupported number of channels: {self._channels}") raise ImageLoadError(f"Unsupported number of channels: {self._channels}")
@@ -218,12 +246,36 @@ class Image:
Returns: Returns:
Image data in RGB format as numpy array Image data in RGB format as numpy array
""" """
if self._channels == 3: if self.channels == 1:
return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB) img = get_pseudo_rgb(self.data)
self._dtype = img.dtype
return img, True
elif self._channels == 3:
return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB), False
elif self._channels == 4: elif self._channels == 4:
return cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA) return cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA), False
else: else:
return self._data raise NotImplementedError
# else:
# return self._data
def get_qt_rgb(self) -> np.ascontiguousarray:
# we keep data as (C, H, W)
_img, pseudo = self.get_rgb()
if pseudo:
img = np.zeros((self.height, self.width, 4), dtype=np.float32)
img[..., 0] = _img[0] # R gradient
img[..., 1] = _img[1] # G gradient
img[..., 2] = _img[2] # B constant
img[..., 3] = 1.0 # A = 1.0 (opaque)
return np.ascontiguousarray(img)
else:
return np.ascontiguousarray(_img)
def get_grayscale(self) -> np.ndarray: def get_grayscale(self) -> np.ndarray:
""" """
@@ -277,11 +329,26 @@ class Image:
""" """
return self._channels >= 3 return self._channels >= 3
def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None:
if self.channels == 1:
if pseudo_rgb:
img = get_pseudo_rgb(self.data)
print("Image.save", img.shape)
else:
img = np.repeat(self.data, 3, axis=2)
else:
raise NotImplementedError("Only grayscale images are supported for now.")
imwrite(path, data=img)
def __repr__(self) -> str: def __repr__(self) -> str:
"""String representation of the Image object.""" """String representation of the Image object."""
return ( return (
f"Image(path='{self.path.name}', " f"Image(path='{self.path.name}', "
f"shape=({self._width}x{self._height}x{self._channels}), " # Display as HxWxC to match the conventional NumPy shape semantics.
f"shape=({self._height}x{self._width}x{self._channels}), "
f"format={self._format}, " f"format={self._format}, "
f"size={self.size_mb:.2f}MB)" f"size={self.size_mb:.2f}MB)"
) )
@@ -291,38 +358,13 @@ class Image:
return self.__repr__() return self.__repr__()
def convert_grayscale_to_rgb_preserve_range( if __name__ == "__main__":
pil_image: PILImage.Image, import argparse
) -> PILImage.Image:
"""Convert a single-channel PIL image to RGB while preserving dynamic range.
Args: parser = argparse.ArgumentParser()
pil_image: Single-channel PIL image (e.g., 16-bit grayscale). parser.add_argument("--path", type=str, required=True)
args = parser.parse_args()
Returns: img = Image(args.path)
PIL Image in RGB mode with intensities normalized to 0-255. img.save(args.path + "test.tif")
""" print(img)
if pil_image.mode == "RGB":
return pil_image
grayscale = np.array(pil_image)
if grayscale.ndim == 3:
grayscale = grayscale[:, :, 0]
original_dtype = grayscale.dtype
grayscale = grayscale.astype(np.float32)
if grayscale.size == 0:
return PILImage.new("RGB", pil_image.size, color=(0, 0, 0))
if np.issubdtype(original_dtype, np.integer):
denom = float(max(np.iinfo(original_dtype).max, 1))
else:
max_val = float(grayscale.max())
denom = max(max_val, 1.0)
grayscale = np.clip(grayscale / denom, 0.0, 1.0)
grayscale_u8 = (grayscale * 255.0).round().astype(np.uint8)
rgb_arr = np.repeat(grayscale_u8[:, :, None], 3, axis=2)
return PILImage.fromarray(rgb_arr, mode="RGB")

View File

@@ -12,23 +12,38 @@ class UT:
Operetta files along with rois drawn in ImageJ Operetta files along with rois drawn in ImageJ
""" """
def __init__(self, roifile_fn: Path): def __init__(self, roifile_fn: Path, no_labels: bool):
self.roifile_fn = roifile_fn self.roifile_fn = roifile_fn
print("is file", self.roifile_fn.is_file())
self.rois = None
if no_labels:
self.rois = ImagejRoi.fromfile(self.roifile_fn) self.rois = ImagejRoi.fromfile(self.roifile_fn)
self.stem = self.roifile_fn.stem.strip("-RoiSet") print(self.roifile_fn.stem)
print(self.roifile_fn.parent.parts[-1])
if "Roi-" in self.roifile_fn.stem:
self.stem = self.roifile_fn.stem.split("Roi-")[1]
else:
self.stem = self.roifile_fn.parent.parts[-1]
else:
self.roifile_fn = roifile_fn / roifile_fn.parts[-1]
self.stem = self.roifile_fn.stem
print(self.roifile_fn)
print(self.stem)
self.image, self.image_props = self._load_images() self.image, self.image_props = self._load_images()
def _load_images(self): def _load_images(self):
"""Loading sequence of tif files """Loading sequence of tif files
array sequence is CZYX array sequence is CZYX
""" """
print(self.roifile_fn.parent, self.stem) print("Loading images:", self.roifile_fn.parent, self.stem)
fns = list(self.roifile_fn.parent.glob(f"{self.stem}*.tif*")) fns = list(self.roifile_fn.parent.glob(f"{self.stem.lower()}*.tif*"))
stems = [fn.stem.split(self.stem)[-1] for fn in fns] stems = [fn.stem.split(self.stem)[-1] for fn in fns]
n_ch = len(set([stem.split("-ch")[-1].split("t")[0] for stem in stems])) n_ch = len(set([stem.split("-ch")[-1].split("t")[0] for stem in stems]))
n_p = len(set([stem.split("-")[0] for stem in stems])) n_p = len(set([stem.split("-")[0] for stem in stems]))
n_t = len(set([stem.split("t")[1] for stem in stems])) n_t = len(set([stem.split("t")[1] for stem in stems]))
print(n_ch, n_p, n_t)
with TiffFile(fns[0]) as tif: with TiffFile(fns[0]) as tif:
img = tif.asarray() img = tif.asarray()
@@ -42,6 +57,7 @@ class UT:
"height": h, "height": h,
"dtype": dtype, "dtype": dtype,
} }
print("Image props", self.image_props)
image_stack = np.zeros((n_ch, n_p, w, h), dtype=dtype) image_stack = np.zeros((n_ch, n_p, w, h), dtype=dtype)
for fn in fns: for fn in fns:
@@ -49,7 +65,7 @@ class UT:
img = tif.asarray() img = tif.asarray()
stem = fn.stem.split(self.stem)[-1] stem = fn.stem.split(self.stem)[-1]
ch = int(stem.split("-ch")[-1].split("t")[0]) ch = int(stem.split("-ch")[-1].split("t")[0])
p = int(stem.split("-")[0].lstrip("p")) p = int(stem.split("-")[0].split("p")[1])
t = int(stem.split("t")[1]) t = int(stem.split("t")[1])
print(fn.stem, "ch", ch, "p", p, "t", t) print(fn.stem, "ch", ch, "p", p, "t", t)
image_stack[ch - 1, p - 1] = img image_stack[ch - 1, p - 1] = img
@@ -82,10 +98,19 @@ class UT:
): ):
"""Export rois to a file""" """Export rois to a file"""
with open(path / subfolder / f"{self.stem}.txt", "w") as f: with open(path / subfolder / f"{self.stem}.txt", "w") as f:
for roi in self.rois: for i, roi in enumerate(self.rois):
# TODO add image coordinates normalization rc = roi.subpixel_coordinates
coords = "" if rc is None:
for x, y in roi.subpixel_coordinates: print(f"No coordinates: {self.roifile_fn}, element {i}, out of {len(self.rois)}")
continue
xmn, ymn = rc.min(axis=0)
xmx, ymx = rc.max(axis=0)
xc = (xmn + xmx) / 2
yc = (ymn + ymx) / 2
bw = xmx - xmn
bh = ymx - ymn
coords = f"{xc/self.width} {yc/self.height} {bw/self.width} {bh/self.height} "
for x, y in rc:
coords += f"{x/self.width} {y/self.height} " coords += f"{x/self.width} {y/self.height} "
f.write(f"{class_index} {coords}\n") f.write(f"{class_index} {coords}\n")
@@ -104,6 +129,7 @@ class UT:
self.image = np.max(self.image[channel], axis=0) self.image = np.max(self.image[channel], axis=0)
print(self.image.shape) print(self.image.shape)
print(path / subfolder / f"{self.stem}.tif")
with TiffWriter(path / subfolder / f"{self.stem}.tif") as tif: with TiffWriter(path / subfolder / f"{self.stem}.tif") as tif:
tif.write(self.image) tif.write(self.image)
@@ -112,11 +138,31 @@ if __name__ == "__main__":
import argparse import argparse
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("input", type=Path) parser.add_argument("-i", "--input", nargs="*", type=Path)
parser.add_argument("output", type=Path) parser.add_argument("-o", "--output", type=Path)
parser.add_argument(
"--no-labels",
action="store_false",
help="Source does not have labels, export only images",
)
args = parser.parse_args() args = parser.parse_args()
for rfn in args.input.glob("*.zip"): # print(args)
ut = UT(rfn) # aa
for path in args.input:
print("Path:", path)
if not args.no_labels:
print("No labels")
ut = UT(path, args.no_labels)
ut.export_image(args.output, plane_mode="max projection", channel=0)
else:
for rfn in Path(path).glob("*.zip"):
# if Path(path).suffix == ".zip":
print("Roi FN:", rfn)
ut = UT(rfn, args.no_labels)
ut.export_rois(args.output, class_index=0) ut.export_rois(args.output, class_index=0)
ut.export_image(args.output, plane_mode="max projection", channel=0) ut.export_image(args.output, plane_mode="max projection", channel=0)
print()

368
src/utils/image_splitter.py Normal file
View File

@@ -0,0 +1,368 @@
import numpy as np
from pathlib import Path
from tifffile import imread, imwrite
from shapely.geometry import LineString
from copy import deepcopy
from scipy.ndimage import zoom
# debug
from src.utils.image import Image
from show_yolo_seg import draw_annotations
import pylab as plt
import cv2
class Label:
def __init__(self, yolo_annotation: str):
class_id, bbox, polygon = self.parse_yolo_annotation(yolo_annotation)
self.class_id = class_id
self.bbox = bbox
self.polygon = polygon
def parse_yolo_annotation(self, yolo_annotation: str):
class_id, *coords = yolo_annotation.split()
class_id = int(class_id)
bbox = np.array(coords[:4], dtype=np.float32)
polygon = np.array(coords[4:], dtype=np.float32).reshape(-1, 2) if len(coords) > 4 else None
if not any(np.isclose(polygon[0], polygon[-1])):
polygon = np.vstack([polygon, polygon[0]])
return class_id, bbox, polygon
def offset_label(
self,
img_w,
img_h,
distance: float = 1.0,
cap_style: int = 2,
join_style: int = 2,
):
if self.polygon is None:
self.bbox = np.array(
[
self.bbox[0] - distance if self.bbox[0] - distance > 0 else 0,
self.bbox[1] - distance if self.bbox[1] - distance > 0 else 0,
self.bbox[2] + distance if self.bbox[2] + distance < 1 else 1,
self.bbox[3] + distance if self.bbox[3] + distance < 1 else 1,
],
dtype=np.float32,
)
return self.bbox
def coords_are_normalized(coords):
# If every coordinate is between 0 and 1 (inclusive-ish), assume normalized
print(coords)
# if not coords:
# return False
return all(max(coords.flatten)) <= 1.001
def poly_to_pts(coords, img_w, img_h):
# coords: [x1 y1 x2 y2 ...] either normalized or absolute
# if coords_are_normalized(coords):
coords = [coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))]
pts = np.array(coords, dtype=np.int32).reshape(-1, 2)
return pts
pts = poly_to_pts(self.polygon, img_w, img_h)
line = LineString(pts)
# Buffer distance in pixels
buffered = line.buffer(distance=distance, cap_style=cap_style, join_style=join_style)
self.polygon = np.array(buffered.exterior.coords, dtype=np.float32) / (img_w, img_h)
xmn, ymn = self.polygon.min(axis=0)
xmx, ymx = self.polygon.max(axis=0)
xc = (xmn + xmx) / 2
yc = (ymn + ymx) / 2
bw = xmx - xmn
bh = ymx - ymn
self.bbox = np.array([xc, yc, bw, bh], dtype=np.float32)
return self.bbox, self.polygon
def translate(self, x, y, scale_x, scale_y):
self.bbox[0] -= x
self.bbox[0] *= scale_x
self.bbox[1] -= y
self.bbox[1] *= scale_y
self.bbox[2] *= scale_x
self.bbox[3] *= scale_y
if self.polygon is not None:
self.polygon[:, 0] -= x
self.polygon[:, 0] *= scale_x
self.polygon[:, 1] -= y
self.polygon[:, 1] *= scale_y
def in_range(self, hrange, wrange):
xc, yc, h, w = self.bbox
x1 = xc - w / 2
y1 = yc - h / 2
x2 = xc + w / 2
y2 = yc + h / 2
truth_val = (
xc >= wrange[0]
and x1 <= wrange[1]
and x2 >= wrange[0]
and x2 <= wrange[1]
and y1 >= hrange[0]
and y1 <= hrange[1]
and y2 >= hrange[0]
and y2 <= hrange[1]
)
print(x1, x2, wrange, y1, y2, hrange, truth_val)
return truth_val
def to_string(self, bbox: list = None, polygon: list = None):
coords = ""
if bbox is None:
bbox = self.bbox
# coords += " ".join([f"{x:.6f}" for x in self.bbox])
if polygon is None:
polygon = self.polygon
if self.polygon is not None:
coords += " " + " ".join([f"{x:.6f} {y:.6f}" for x, y in self.polygon])
return f"{self.class_id} {coords}"
def __str__(self):
return f"Class: {self.class_id}, BBox: {self.bbox}, Polygon: {self.polygon}"
class YoloLabelReader:
def __init__(self, label_path: Path):
self.label_path = label_path
self.labels = self._read_labels()
def _read_labels(self):
with open(self.label_path, "r") as f:
labels = [Label(line) for line in f.readlines()]
return labels
def get_labels(self, hrange, wrange):
"""hrange and wrange are tuples of (start, end) normalized to [0, 1]"""
labels = []
# print(hrange, wrange)
for lbl in self.labels:
# print(lbl)
if lbl.in_range(hrange, wrange):
labels.append(lbl)
return labels if len(labels) > 0 else None
def __get_item__(self, index):
return self.labels[index]
def __len__(self):
return len(self.labels)
def __iter__(self):
return iter(self.labels)
class ImageSplitter:
def __init__(self, image_path: Path, label_path: Path):
self.image = imread(image_path)
self.image_path = image_path
self.label_path = label_path
if not label_path.exists():
print(f"Label file {label_path} not found")
self.labels = None
else:
self.labels = YoloLabelReader(label_path)
def split_into_tiles(self, patch_size: tuple = (2, 2)):
"""Split image into patches of size patch_size"""
hstep, wstep = (
self.image.shape[0] // patch_size[0],
self.image.shape[1] // patch_size[1],
)
h, w = self.image.shape[:2]
for i in range(patch_size[0]):
for j in range(patch_size[1]):
metadata = {
"image_path": str(self.image_path),
"label_path": str(self.label_path),
"tile_section": f"{i}, {j}",
"tile_size": f"{hstep}, {wstep}",
"patch_size": f"{patch_size[0]}, {patch_size[1]}",
}
tile_reference = f"i{i}j{j}"
hrange = (i * hstep / h, (i + 1) * hstep / h)
wrange = (j * wstep / w, (j + 1) * wstep / w)
tile = self.image[i * hstep : (i + 1) * hstep, j * wstep : (j + 1) * wstep]
labels = None
if self.labels is not None:
labels = deepcopy(self.labels.get_labels(hrange, wrange))
print(id(labels))
if labels is not None:
print(hrange[0], wrange[0])
for l in labels:
print(l.bbox)
[l.translate(wrange[0], hrange[0], 2, 2) for l in labels]
print("translated")
for l in labels:
print(l.bbox)
# print(labels)
yield tile_reference, tile, labels, metadata
def split_respective_to_label(self, padding: int = 67):
if self.labels is None:
raise ValueError("No labels found. Only images having labels can be split.")
for i, label in enumerate(self.labels):
tile_reference = f"_lbl-{i+1:02d}"
# print(label.bbox)
metadata = {"image_path": str(self.image_path), "label_path": str(self.label_path), "label_index": str(i)}
xc_norm, yc_norm, h_norm, w_norm = label.bbox # normalized coords
xc, yc, h, w = [
int(np.round(f))
for f in [
xc_norm * self.image.shape[1],
yc_norm * self.image.shape[0],
h_norm * self.image.shape[0],
w_norm * self.image.shape[1],
]
] # image coords
# print("img coords:", xc, yc, h, w)
pad_xneg = padding + 1 # int(w / 2) + padding
pad_xpos = padding # int(w / 2) + padding
pad_yneg = padding + 1 # int(h / 2) + padding
pad_ypos = padding # int(h / 2) + padding
if xc - pad_xneg < 0:
pad_xneg = xc
if pad_xpos + xc > self.image.shape[1]:
pad_xpos = self.image.shape[1] - xc
if yc - pad_yneg < 0:
pad_yneg = yc
if pad_ypos + yc > self.image.shape[0]:
pad_ypos = self.image.shape[0] - yc
# print("pads:", pad_xneg, pad_xpos, pad_yneg, pad_ypos)
tile = self.image[
yc - pad_yneg : yc + pad_ypos,
xc - pad_xneg : xc + pad_xpos,
]
ny, nx = tile.shape
x_offset = pad_xneg
y_offset = pad_yneg
# print("tile shape:", tile.shape)
yolo_annotation = f"{label.class_id} " # {x_offset/nx} {y_offset/ny} {h_norm} {w_norm} "
yolo_annotation += " ".join(
[
f"{(x*self.image.shape[1]-(xc - x_offset))/nx:.6f} {(y*self.image.shape[0]-(yc-y_offset))/ny:.6f}"
for x, y in label.polygon
]
)
print(yolo_annotation)
new_label = Label(yolo_annotation=yolo_annotation)
yield tile_reference, tile, [new_label], metadata
def main(args):
if args.output:
args.output.mkdir(exist_ok=True, parents=True)
(args.output / "images").mkdir(exist_ok=True)
(args.output / "images-zoomed").mkdir(exist_ok=True)
(args.output / "labels").mkdir(exist_ok=True)
for image_path in (args.input / "images").glob("*.tif"):
data = ImageSplitter(
image_path=image_path,
label_path=(args.input / "labels" / image_path.stem).with_suffix(".txt"),
)
if args.split_around_label:
data = data.split_respective_to_label(padding=args.padding)
else:
data = data.split_into_tiles(patch_size=args.patch_size)
for tile_reference, tile, labels, metadata in data:
print()
print(tile_reference, tile.shape, labels, metadata) # len(labels) if labels else None)
# { debug
debug = False
if debug:
plt.figure(figsize=(10, 10 * tile.shape[0] / tile.shape[1]))
if labels is None:
plt.imshow(tile, cmap="gray")
plt.axis("off")
plt.title(f"{image_path.name} ({tile_reference})")
plt.show()
continue
print(labels[0].bbox)
# Draw annotations
out = draw_annotations(
cv2.cvtColor((tile / tile.max() * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR),
[l.to_string() for l in labels],
alpha=0.1,
)
# Convert BGR -> RGB for matplotlib display
out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
plt.imshow(out_rgb)
plt.axis("off")
plt.title(f"{image_path.name} ({tile_reference})")
plt.show()
# } debug
if args.output:
# imwrite(args.output / "images" / f"{image_path.stem}_{tile_reference}.tif", tile, metadata=metadata)
scale = 5
tile_zoomed = zoom(tile, zoom=scale)
metadata["scale"] = scale
imwrite(
args.output / "images" / f"{image_path.stem}_{tile_reference}.tif",
tile_zoomed,
metadata=metadata,
imagej=True,
)
if labels is not None:
with open(args.output / "labels" / f"{image_path.stem}_{tile_reference}.txt", "w") as f:
for label in labels:
# label.offset_label(tile.shape[1], tile.shape[0])
f.write(label.to_string() + "\n")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input", type=Path)
parser.add_argument("-o", "--output", type=Path)
parser.add_argument(
"-p",
"--patch-size",
nargs=2,
type=int,
default=[2, 2],
help="Number of patches along height and width, rows and columns, respectively",
)
parser.add_argument(
"-sal",
"--split-around-label",
action="store_true",
help="If enabled, the image will be split around the label and for each label, a separate image will be created.",
)
parser.add_argument(
"--padding",
type=int,
default=67,
help="Padding around the label when splitting around the label.",
)
args = parser.parse_args()
main(args)

1
src/utils/show_yolo_seg.py Symbolic link
View File

@@ -0,0 +1 @@
../../tests/show_yolo_seg.py

View File

@@ -0,0 +1,157 @@
"""Ultralytics runtime patches for 16-bit TIFF training.
Goals:
- Use `tifffile` to decode `.tif/.tiff` reliably (OpenCV can silently drop bit-depth depending on codec).
- Preserve 16-bit data through the dataloader as `uint16` tensors.
- Fix Ultralytics trainer normalization (default divides by 255) to scale `uint16` correctly.
- Avoid uint8-forcing augmentations by recommending/setting hyp values (handled by caller).
This module is intended to be imported/called **before** instantiating/using YOLO.
"""
from __future__ import annotations
from typing import Optional
from src.utils.logger import get_logger
logger = get_logger(__name__)
def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
"""Apply runtime monkey-patches to Ultralytics to better support 16-bit TIFFs.
This function is safe to call multiple times.
Args:
force: If True, re-apply patches even if already applied.
"""
# Import inside function to ensure patching occurs before YOLO model/dataset is created.
import os
import cv2
import numpy as np
# import tifffile
import torch
from src.utils.image import Image
from ultralytics.utils import patches as ul_patches
already_patched = getattr(ul_patches.imread, "__name__", "") == "tifffile_imread"
if already_patched and not force:
return
_original_imread = ul_patches.imread
def tifffile_imread(filename: str, flags: int = cv2.IMREAD_COLOR, pseudo_rgb: bool = True) -> Optional[np.ndarray]:
"""Replacement for [`ultralytics.utils.patches.imread()`](venv/lib/python3.12/site-packages/ultralytics/utils/patches.py:20).
- For `.tif/.tiff`, uses `tifffile.imread()` and preserves dtype (e.g. uint16).
- For other formats, falls back to Ultralytics' original implementation.
- Always returns HWC (3 dims). For grayscale, returns (H, W, 1) or (H, W, 3) depending on requested flags.
"""
# print("here")
# return _original_imread(filename, flags)
ext = os.path.splitext(filename)[1].lower()
if ext in (".tif", ".tiff"):
arr = Image(filename).get_qt_rgb()[:, :, :3]
# Normalize common shapes:
# - (H, W) -> (H, W, 1)
# - (C, H, W) -> (H, W, C) (heuristic)
if arr is None:
return None
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[0] < arr.shape[1]:
arr = np.transpose(arr, (1, 2, 0))
if arr.ndim == 2:
arr = arr[..., None]
# Ensure contiguous array for downstream OpenCV ops.
# logger.info(f"Loading with monkey-patched imread: {filename}")
arr = arr.astype(np.float32)
arr /= arr.max()
arr *= 2**8 - 1
arr = arr.astype(np.uint8)
# print(arr.shape, arr.dtype, any(np.isnan(arr).flatten()), np.where(np.isnan(arr)), arr.min(), arr.max())
return np.ascontiguousarray(arr)
# logger.info(f"Loading with original imread: {filename}")
return _original_imread(filename, flags)
# Patch the canonical reference.
ul_patches.imread = tifffile_imread
# Patch common module-level imports (some Ultralytics modules do `from ... import imread`).
# Importing these modules is safe and helps ensure the patched function is used.
try:
import ultralytics.data.base as _ul_base
_ul_base.imread = tifffile_imread
except Exception:
pass
try:
import ultralytics.data.loaders as _ul_loaders
_ul_loaders.imread = tifffile_imread
except Exception:
pass
# Patch trainer normalization: default divides by 255 regardless of input dtype.
from ultralytics.models.yolo.detect import train as detect_train
_orig_preprocess_batch = detect_train.DetectionTrainer.preprocess_batch
def preprocess_batch_16bit(self, batch: dict) -> dict: # type: ignore[override]
# Start from upstream behavior to keep device placement + multiscale identical,
# but replace the 255 division with dtype-aware scaling.
# logger.info(f"Preprocessing batch with monkey-patched preprocess_batch")
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
img = batch.get("img")
if isinstance(img, torch.Tensor):
# Decide scaling denom based on dtype (avoid expensive reductions if possible).
if img.dtype == torch.uint8:
denom = 255.0
elif img.dtype == torch.uint16:
denom = 65535.0
elif img.dtype.is_floating_point:
# Assume already in 0-1 range if float.
denom = 1.0
else:
# Generic integer fallback.
try:
denom = float(torch.iinfo(img.dtype).max)
except Exception:
denom = 255.0
batch["img"] = img.float() / denom
# Multi-scale branch copied from upstream to avoid re-introducing `/255` scaling.
if getattr(self.args, "multi_scale", False):
import math
import random
import torch.nn as nn
imgs = batch["img"]
sz = (
random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
// self.stride
* self.stride
)
sf = sz / max(imgs.shape[2:])
if sf != 1:
ns = [math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]]
imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
batch["img"] = imgs
return batch
detect_train.DetectionTrainer.preprocess_batch = preprocess_batch_16bit
# Tag function to make it easier to detect patch state.
setattr(detect_train.DetectionTrainer.preprocess_batch, "_ultralytics_16bit_patch", True)

231
tests/show_yolo_seg.py Normal file
View File

@@ -0,0 +1,231 @@
#!/usr/bin/env python3
"""
show_yolo_seg.py
Usage:
python show_yolo_seg.py /path/to/image.jpg /path/to/labels.txt
Supports:
- Segmentation polygons: "class x1 y1 x2 y2 ... xn yn"
- YOLO bbox lines as fallback: "class x_center y_center width height"
Coordinates can be normalized [0..1] or absolute pixels (auto-detected).
"""
import sys
import cv2
import numpy as np
import matplotlib.pyplot as plt
import argparse
from pathlib import Path
import random
from shapely.geometry import LineString
from src.utils.image import Image
def parse_label_line(line):
parts = line.strip().split()
if not parts:
return None
cls = int(float(parts[0]))
coords = [float(x) for x in parts[1:]]
return cls, coords
def coords_are_normalized(coords):
# If every coordinate is between 0 and 1 (inclusive-ish), assume normalized
if not coords:
return False
return max(coords) <= 1.001
def yolo_bbox_to_xyxy(coords, img_w, img_h):
# coords: [xc, yc, w, h] normalized or absolute
xc, yc, w, h = coords[:4]
if max(coords) <= 1.001:
xc *= img_w
yc *= img_h
w *= img_w
h *= img_h
x1 = int(round(xc - w / 2))
y1 = int(round(yc - h / 2))
x2 = int(round(xc + w / 2))
y2 = int(round(yc + h / 2))
return x1, y1, x2, y2
def poly_to_pts(coords, img_w, img_h):
# coords: [x1 y1 x2 y2 ...] either normalized or absolute
if coords_are_normalized(coords[4:]):
coords = [coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))]
pts = np.array(coords, dtype=np.int32).reshape(-1, 2)
return pts
def random_color_for_class(cls):
random.seed(cls) # deterministic per class
return (
0,
0,
255,
) # tuple(int(x) for x in np.array([random.randint(0, 255) for _ in range(3)]))
def draw_annotations(img, labels, alpha=0.4, draw_bbox_for_poly=True):
# img: BGR numpy array
overlay = img.copy()
h, w = img.shape[:2]
for line in labels:
if isinstance(line, str):
cls, coords = parse_label_line(line)
if isinstance(line, tuple):
cls, coords = line
if not coords:
continue
# polygon case (>=6 coordinates)
if len(coords) >= 6:
color = random_color_for_class(cls)
x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords[:4], w, h)
print(x1, y1, x2, y2)
cv2.rectangle(img, (x1, y1), (x2, y2), color, 1)
pts = poly_to_pts(coords[4:], w, h)
# line = LineString(pts)
# # Buffer distance in pixels
# buffered = line.buffer(3, cap_style=2, join_style=2)
# coords = np.array(buffered.exterior.coords, dtype=np.int32)
# cv2.fillPoly(overlay, [coords], color=(255, 255, 255))
# fill on overlay
cv2.fillPoly(overlay, [pts], color)
# outline on base image
cv2.polylines(img, [pts], isClosed=True, color=color, thickness=1)
# put class text at first point
x, y = int(pts[0, 0]), int(pts[0, 1]) - 6
if 0:
cv2.putText(
img,
str(cls),
(x, max(6, y)),
cv2.FONT_HERSHEY_SIMPLEX,
0.6,
(255, 255, 255),
2,
cv2.LINE_AA,
)
# YOLO bbox case (4 coords)
elif len(coords) == 4:
x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords, w, h)
color = random_color_for_class(cls)
cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
cv2.putText(
img,
str(cls),
(x1, max(6, y1 - 4)),
cv2.FONT_HERSHEY_SIMPLEX,
0.6,
(255, 255, 255),
2,
cv2.LINE_AA,
)
else:
# Unknown / invalid format, skip
continue
# blend overlay for filled polygons
cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, img)
return img
def load_labels_file(label_path):
labels = []
with open(label_path, "r") as f:
for raw in f:
line = raw.strip()
if not line:
continue
parsed = parse_label_line(line)
if parsed:
labels.append(parsed)
return labels
def main():
parser = argparse.ArgumentParser(description="Show YOLO segmentation / polygon annotations")
parser.add_argument("image", type=str, help="Path to image file")
parser.add_argument("--labels", type=str, help="Path to YOLO label file (polygons)")
parser.add_argument("--alpha", type=float, default=0.4, help="Polygon fill alpha (0..1)")
parser.add_argument("--no-bbox", action="store_true", help="Don't draw bounding boxes for polygons")
args = parser.parse_args()
print(args)
img_path = Path(args.image)
if args.labels:
lbl_path = Path(args.labels)
else:
lbl_path = img_path.with_suffix(".txt")
lbl_path = Path(str(lbl_path).replace("images", "labels"))
if not img_path.exists():
print("Image not found:", img_path)
sys.exit(1)
if not lbl_path.exists():
print("Label file not found:", lbl_path)
sys.exit(1)
# img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
img = (Image(img_path).get_qt_rgb() * 255).astype(np.uint8)
if img is None:
print("Could not load image:", img_path)
sys.exit(1)
labels = load_labels_file(str(lbl_path))
if not labels:
print("No labels parsed from", lbl_path)
# continue and just show image
out = draw_annotations(img.copy(), labels, alpha=args.alpha, draw_bbox_for_poly=(not args.no_bbox))
out_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
if 0:
plt.imshow(out_rgb.transpose(1, 0, 2))
else:
plt.imshow(out_rgb)
for label in labels:
lclass, coords = label
# print(lclass, coords)
bbox = coords[:4]
# print("bbox", bbox)
bbox = np.array(bbox) * np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
yc, xc, h, w = bbox
# print("bbox", bbox)
# polyline = np.array(coords[4:]).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
polyline = np.array(coords).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
# print("pl", coords[4:])
# print("pl", polyline)
# Convert BGR -> RGB for matplotlib display
# out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
# out_rgb = Image()
plt.plot(polyline[:, 0], polyline[:, 1], "y", linewidth=2)
if 0:
plt.plot(
[yc - h / 2, yc - h / 2, yc + h / 2, yc + h / 2, yc - h / 2],
[xc - w / 2, xc + w / 2, xc + w / 2, xc - w / 2, xc - w / 2],
"r",
linewidth=2,
)
# plt.axis("off")
plt.title(f"{img_path.name} ({lbl_path.name})")
plt.show()
if __name__ == "__main__":
main()