Compare commits
9 Commits
d998c65665
...
506c74e53a
| Author | SHA1 | Date | |
|---|---|---|---|
| 506c74e53a | |||
| eefda5b878 | |||
| 31cb6a6c8e | |||
| 0c19ea2557 | |||
| 89e47591db | |||
| 69cde09e53 | |||
| fcbd5fb16d | |||
| ca52312925 | |||
| 0a93bf797a |
@@ -60,9 +60,7 @@ class DatabaseManager:
|
|||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# Check if annotations table exists
|
# Check if annotations table exists
|
||||||
cursor.execute(
|
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'")
|
||||||
"SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'"
|
|
||||||
)
|
|
||||||
if not cursor.fetchone():
|
if not cursor.fetchone():
|
||||||
# Table doesn't exist yet, no migration needed
|
# Table doesn't exist yet, no migration needed
|
||||||
return
|
return
|
||||||
@@ -242,9 +240,7 @@ class DatabaseManager:
|
|||||||
return cursor.lastrowid
|
return cursor.lastrowid
|
||||||
except sqlite3.IntegrityError:
|
except sqlite3.IntegrityError:
|
||||||
# Image already exists, return its ID
|
# Image already exists, return its ID
|
||||||
cursor.execute(
|
cursor.execute("SELECT id FROM images WHERE relative_path = ?", (relative_path,))
|
||||||
"SELECT id FROM images WHERE relative_path = ?", (relative_path,)
|
|
||||||
)
|
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
return row["id"] if row else None
|
return row["id"] if row else None
|
||||||
finally:
|
finally:
|
||||||
@@ -255,17 +251,13 @@ class DatabaseManager:
|
|||||||
conn = self.get_connection()
|
conn = self.get_connection()
|
||||||
try:
|
try:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute(
|
cursor.execute("SELECT * FROM images WHERE relative_path = ?", (relative_path,))
|
||||||
"SELECT * FROM images WHERE relative_path = ?", (relative_path,)
|
|
||||||
)
|
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
return dict(row) if row else None
|
return dict(row) if row else None
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def get_or_create_image(
|
def get_or_create_image(self, relative_path: str, filename: str, width: int, height: int) -> int:
|
||||||
self, relative_path: str, filename: str, width: int, height: int
|
|
||||||
) -> int:
|
|
||||||
"""Get existing image or create new one."""
|
"""Get existing image or create new one."""
|
||||||
existing = self.get_image_by_path(relative_path)
|
existing = self.get_image_by_path(relative_path)
|
||||||
if existing:
|
if existing:
|
||||||
@@ -355,16 +347,8 @@ class DatabaseManager:
|
|||||||
bbox[2],
|
bbox[2],
|
||||||
bbox[3],
|
bbox[3],
|
||||||
det["confidence"],
|
det["confidence"],
|
||||||
(
|
(json.dumps(det.get("segmentation_mask")) if det.get("segmentation_mask") else None),
|
||||||
json.dumps(det.get("segmentation_mask"))
|
(json.dumps(det.get("metadata")) if det.get("metadata") else None),
|
||||||
if det.get("segmentation_mask")
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
(
|
|
||||||
json.dumps(det.get("metadata"))
|
|
||||||
if det.get("metadata")
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
@@ -409,15 +393,16 @@ class DatabaseManager:
|
|||||||
if filters:
|
if filters:
|
||||||
conditions = []
|
conditions = []
|
||||||
for key, value in filters.items():
|
for key, value in filters.items():
|
||||||
if (
|
if key.startswith("d.") or key.startswith("i.") or key.startswith("m."):
|
||||||
key.startswith("d.")
|
if "like" in value.lower():
|
||||||
or key.startswith("i.")
|
conditions.append(f"{key} LIKE ?")
|
||||||
or key.startswith("m.")
|
params.append(value.split(" ")[1])
|
||||||
):
|
else:
|
||||||
conditions.append(f"{key} = ?")
|
conditions.append(f"{key} = ?")
|
||||||
|
params.append(value)
|
||||||
else:
|
else:
|
||||||
conditions.append(f"d.{key} = ?")
|
conditions.append(f"d.{key} = ?")
|
||||||
params.append(value)
|
params.append(value)
|
||||||
query += " WHERE " + " AND ".join(conditions)
|
query += " WHERE " + " AND ".join(conditions)
|
||||||
|
|
||||||
query += " ORDER BY d.detected_at DESC"
|
query += " ORDER BY d.detected_at DESC"
|
||||||
@@ -442,18 +427,14 @@ class DatabaseManager:
|
|||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def get_detections_for_image(
|
def get_detections_for_image(self, image_id: int, model_id: Optional[int] = None) -> List[Dict]:
|
||||||
self, image_id: int, model_id: Optional[int] = None
|
|
||||||
) -> List[Dict]:
|
|
||||||
"""Get all detections for a specific image."""
|
"""Get all detections for a specific image."""
|
||||||
filters = {"image_id": image_id}
|
filters = {"image_id": image_id}
|
||||||
if model_id:
|
if model_id:
|
||||||
filters["model_id"] = model_id
|
filters["model_id"] = model_id
|
||||||
return self.get_detections(filters)
|
return self.get_detections(filters)
|
||||||
|
|
||||||
def delete_detections_for_image(
|
def delete_detections_for_image(self, image_id: int, model_id: Optional[int] = None) -> int:
|
||||||
self, image_id: int, model_id: Optional[int] = None
|
|
||||||
) -> int:
|
|
||||||
"""Delete detections tied to a specific image and optional model."""
|
"""Delete detections tied to a specific image and optional model."""
|
||||||
conn = self.get_connection()
|
conn = self.get_connection()
|
||||||
try:
|
try:
|
||||||
@@ -524,9 +505,7 @@ class DatabaseManager:
|
|||||||
""",
|
""",
|
||||||
params,
|
params,
|
||||||
)
|
)
|
||||||
class_counts = {
|
class_counts = {row["class_name"]: row["count"] for row in cursor.fetchall()}
|
||||||
row["class_name"]: row["count"] for row in cursor.fetchall()
|
|
||||||
}
|
|
||||||
|
|
||||||
# Average confidence
|
# Average confidence
|
||||||
cursor.execute(
|
cursor.execute(
|
||||||
@@ -583,9 +562,7 @@ class DatabaseManager:
|
|||||||
|
|
||||||
# ==================== Export Operations ====================
|
# ==================== Export Operations ====================
|
||||||
|
|
||||||
def export_detections_to_csv(
|
def export_detections_to_csv(self, output_path: str, filters: Optional[Dict] = None) -> bool:
|
||||||
self, output_path: str, filters: Optional[Dict] = None
|
|
||||||
) -> bool:
|
|
||||||
"""Export detections to CSV file."""
|
"""Export detections to CSV file."""
|
||||||
try:
|
try:
|
||||||
detections = self.get_detections(filters)
|
detections = self.get_detections(filters)
|
||||||
@@ -614,9 +591,7 @@ class DatabaseManager:
|
|||||||
for det in detections:
|
for det in detections:
|
||||||
row = {k: det[k] for k in fieldnames if k in det}
|
row = {k: det[k] for k in fieldnames if k in det}
|
||||||
# Convert segmentation mask list to JSON string for CSV
|
# Convert segmentation mask list to JSON string for CSV
|
||||||
if row.get("segmentation_mask") and isinstance(
|
if row.get("segmentation_mask") and isinstance(row["segmentation_mask"], list):
|
||||||
row["segmentation_mask"], list
|
|
||||||
):
|
|
||||||
row["segmentation_mask"] = json.dumps(row["segmentation_mask"])
|
row["segmentation_mask"] = json.dumps(row["segmentation_mask"])
|
||||||
writer.writerow(row)
|
writer.writerow(row)
|
||||||
|
|
||||||
@@ -625,9 +600,7 @@ class DatabaseManager:
|
|||||||
print(f"Error exporting to CSV: {e}")
|
print(f"Error exporting to CSV: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def export_detections_to_json(
|
def export_detections_to_json(self, output_path: str, filters: Optional[Dict] = None) -> bool:
|
||||||
self, output_path: str, filters: Optional[Dict] = None
|
|
||||||
) -> bool:
|
|
||||||
"""Export detections to JSON file."""
|
"""Export detections to JSON file."""
|
||||||
try:
|
try:
|
||||||
detections = self.get_detections(filters)
|
detections = self.get_detections(filters)
|
||||||
@@ -785,17 +758,13 @@ class DatabaseManager:
|
|||||||
conn = self.get_connection()
|
conn = self.get_connection()
|
||||||
try:
|
try:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
cursor.execute(
|
cursor.execute("SELECT * FROM object_classes WHERE class_name = ?", (class_name,))
|
||||||
"SELECT * FROM object_classes WHERE class_name = ?", (class_name,)
|
|
||||||
)
|
|
||||||
row = cursor.fetchone()
|
row = cursor.fetchone()
|
||||||
return dict(row) if row else None
|
return dict(row) if row else None
|
||||||
finally:
|
finally:
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
def add_object_class(
|
def add_object_class(self, class_name: str, color: str, description: Optional[str] = None) -> int:
|
||||||
self, class_name: str, color: str, description: Optional[str] = None
|
|
||||||
) -> int:
|
|
||||||
"""
|
"""
|
||||||
Add a new object class.
|
Add a new object class.
|
||||||
|
|
||||||
@@ -928,8 +897,7 @@ class DatabaseManager:
|
|||||||
if not split_map[required]:
|
if not split_map[required]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unable to determine %s image directory under %s. Provide it "
|
"Unable to determine %s image directory under %s. Provide it "
|
||||||
"explicitly via the 'splits' argument."
|
"explicitly via the 'splits' argument." % (required, dataset_root_path)
|
||||||
% (required, dataset_root_path)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
yaml_splits: Dict[str, str] = {}
|
yaml_splits: Dict[str, str] = {}
|
||||||
@@ -955,11 +923,7 @@ class DatabaseManager:
|
|||||||
if yaml_splits.get("test"):
|
if yaml_splits.get("test"):
|
||||||
payload["test"] = yaml_splits["test"]
|
payload["test"] = yaml_splits["test"]
|
||||||
|
|
||||||
output_path_obj = (
|
output_path_obj = Path(output_path).expanduser() if output_path else dataset_root_path / "data.yaml"
|
||||||
Path(output_path).expanduser()
|
|
||||||
if output_path
|
|
||||||
else dataset_root_path / "data.yaml"
|
|
||||||
)
|
|
||||||
output_path_obj.parent.mkdir(parents=True, exist_ok=True)
|
output_path_obj.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
with open(output_path_obj, "w", encoding="utf-8") as handle:
|
with open(output_path_obj, "w", encoding="utf-8") as handle:
|
||||||
@@ -1019,15 +983,9 @@ class DatabaseManager:
|
|||||||
for split_name, options in patterns.items():
|
for split_name, options in patterns.items():
|
||||||
for relative in options:
|
for relative in options:
|
||||||
candidate = (dataset_root / relative).resolve()
|
candidate = (dataset_root / relative).resolve()
|
||||||
if (
|
if candidate.exists() and candidate.is_dir() and self._directory_has_images(candidate):
|
||||||
candidate.exists()
|
|
||||||
and candidate.is_dir()
|
|
||||||
and self._directory_has_images(candidate)
|
|
||||||
):
|
|
||||||
try:
|
try:
|
||||||
inferred[split_name] = candidate.relative_to(
|
inferred[split_name] = candidate.relative_to(dataset_root).as_posix()
|
||||||
dataset_root
|
|
||||||
).as_posix()
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
inferred[split_name] = candidate.as_posix()
|
inferred[split_name] = candidate.as_posix()
|
||||||
break
|
break
|
||||||
|
|||||||
@@ -35,9 +35,7 @@ logger = get_logger(__name__)
|
|||||||
class ResultsTab(QWidget):
|
class ResultsTab(QWidget):
|
||||||
"""Results tab showing detection history and preview overlays."""
|
"""Results tab showing detection history and preview overlays."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
|
||||||
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
|
||||||
):
|
|
||||||
super().__init__(parent)
|
super().__init__(parent)
|
||||||
self.db_manager = db_manager
|
self.db_manager = db_manager
|
||||||
self.config_manager = config_manager
|
self.config_manager = config_manager
|
||||||
@@ -71,24 +69,12 @@ class ResultsTab(QWidget):
|
|||||||
left_layout.addLayout(controls_layout)
|
left_layout.addLayout(controls_layout)
|
||||||
|
|
||||||
self.results_table = QTableWidget(0, 5)
|
self.results_table = QTableWidget(0, 5)
|
||||||
self.results_table.setHorizontalHeaderLabels(
|
self.results_table.setHorizontalHeaderLabels(["Image", "Model", "Detections", "Classes", "Last Updated"])
|
||||||
["Image", "Model", "Detections", "Classes", "Last Updated"]
|
self.results_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
|
||||||
)
|
self.results_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.Stretch)
|
||||||
self.results_table.horizontalHeader().setSectionResizeMode(
|
self.results_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeToContents)
|
||||||
0, QHeaderView.Stretch
|
self.results_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.Stretch)
|
||||||
)
|
self.results_table.horizontalHeader().setSectionResizeMode(4, QHeaderView.ResizeToContents)
|
||||||
self.results_table.horizontalHeader().setSectionResizeMode(
|
|
||||||
1, QHeaderView.Stretch
|
|
||||||
)
|
|
||||||
self.results_table.horizontalHeader().setSectionResizeMode(
|
|
||||||
2, QHeaderView.ResizeToContents
|
|
||||||
)
|
|
||||||
self.results_table.horizontalHeader().setSectionResizeMode(
|
|
||||||
3, QHeaderView.Stretch
|
|
||||||
)
|
|
||||||
self.results_table.horizontalHeader().setSectionResizeMode(
|
|
||||||
4, QHeaderView.ResizeToContents
|
|
||||||
)
|
|
||||||
self.results_table.setSelectionBehavior(QAbstractItemView.SelectRows)
|
self.results_table.setSelectionBehavior(QAbstractItemView.SelectRows)
|
||||||
self.results_table.setSelectionMode(QAbstractItemView.SingleSelection)
|
self.results_table.setSelectionMode(QAbstractItemView.SingleSelection)
|
||||||
self.results_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
|
self.results_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
|
||||||
@@ -106,6 +92,8 @@ class ResultsTab(QWidget):
|
|||||||
preview_layout = QVBoxLayout()
|
preview_layout = QVBoxLayout()
|
||||||
|
|
||||||
self.preview_canvas = AnnotationCanvasWidget()
|
self.preview_canvas = AnnotationCanvasWidget()
|
||||||
|
# Auto-zoom so newly loaded images fill the available preview viewport.
|
||||||
|
self.preview_canvas.set_auto_fit_to_view(True)
|
||||||
self.preview_canvas.set_polyline_enabled(False)
|
self.preview_canvas.set_polyline_enabled(False)
|
||||||
self.preview_canvas.set_show_bboxes(True)
|
self.preview_canvas.set_show_bboxes(True)
|
||||||
preview_layout.addWidget(self.preview_canvas)
|
preview_layout.addWidget(self.preview_canvas)
|
||||||
@@ -119,9 +107,7 @@ class ResultsTab(QWidget):
|
|||||||
self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
|
self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
|
||||||
self.show_confidence_checkbox = QCheckBox("Show Confidence")
|
self.show_confidence_checkbox = QCheckBox("Show Confidence")
|
||||||
self.show_confidence_checkbox.setChecked(False)
|
self.show_confidence_checkbox.setChecked(False)
|
||||||
self.show_confidence_checkbox.stateChanged.connect(
|
self.show_confidence_checkbox.stateChanged.connect(self._apply_detection_overlays)
|
||||||
self._apply_detection_overlays
|
|
||||||
)
|
|
||||||
toggles_layout.addWidget(self.show_masks_checkbox)
|
toggles_layout.addWidget(self.show_masks_checkbox)
|
||||||
toggles_layout.addWidget(self.show_bboxes_checkbox)
|
toggles_layout.addWidget(self.show_bboxes_checkbox)
|
||||||
toggles_layout.addWidget(self.show_confidence_checkbox)
|
toggles_layout.addWidget(self.show_confidence_checkbox)
|
||||||
@@ -169,8 +155,7 @@ class ResultsTab(QWidget):
|
|||||||
"image_id": det["image_id"],
|
"image_id": det["image_id"],
|
||||||
"model_id": det["model_id"],
|
"model_id": det["model_id"],
|
||||||
"image_path": det.get("image_path"),
|
"image_path": det.get("image_path"),
|
||||||
"image_filename": det.get("image_filename")
|
"image_filename": det.get("image_filename") or det.get("image_path"),
|
||||||
or det.get("image_path"),
|
|
||||||
"model_name": det.get("model_name", ""),
|
"model_name": det.get("model_name", ""),
|
||||||
"model_version": det.get("model_version", ""),
|
"model_version": det.get("model_version", ""),
|
||||||
"last_detected": det.get("detected_at"),
|
"last_detected": det.get("detected_at"),
|
||||||
@@ -183,8 +168,7 @@ class ResultsTab(QWidget):
|
|||||||
|
|
||||||
entry["count"] += 1
|
entry["count"] += 1
|
||||||
if det.get("detected_at") and (
|
if det.get("detected_at") and (
|
||||||
not entry.get("last_detected")
|
not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected"))
|
||||||
or str(det.get("detected_at")) > str(entry.get("last_detected"))
|
|
||||||
):
|
):
|
||||||
entry["last_detected"] = det.get("detected_at")
|
entry["last_detected"] = det.get("detected_at")
|
||||||
if det.get("class_name"):
|
if det.get("class_name"):
|
||||||
@@ -214,9 +198,7 @@ class ResultsTab(QWidget):
|
|||||||
|
|
||||||
for row, entry in enumerate(self.detection_summary):
|
for row, entry in enumerate(self.detection_summary):
|
||||||
model_label = f"{entry['model_name']} {entry['model_version']}".strip()
|
model_label = f"{entry['model_name']} {entry['model_version']}".strip()
|
||||||
class_list = (
|
class_list = ", ".join(sorted(entry["classes"])) if entry["classes"] else "-"
|
||||||
", ".join(sorted(entry["classes"])) if entry["classes"] else "-"
|
|
||||||
)
|
|
||||||
|
|
||||||
items = [
|
items = [
|
||||||
QTableWidgetItem(entry.get("image_filename", "")),
|
QTableWidgetItem(entry.get("image_filename", "")),
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
import numpy as np
|
||||||
from PySide6.QtCore import Qt, QThread, Signal
|
from PySide6.QtCore import Qt, QThread, Signal
|
||||||
from PySide6.QtWidgets import (
|
from PySide6.QtWidgets import (
|
||||||
QWidget,
|
QWidget,
|
||||||
@@ -91,10 +92,7 @@ class TrainingWorker(QThread):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
computed_total = sum(
|
computed_total = sum(max(0, int((stage.get("params") or {}).get("epochs", 0))) for stage in self.stage_plan)
|
||||||
max(0, int((stage.get("params") or {}).get("epochs", 0)))
|
|
||||||
for stage in self.stage_plan
|
|
||||||
)
|
|
||||||
self.total_epochs = total_epochs if total_epochs else computed_total or epochs
|
self.total_epochs = total_epochs if total_epochs else computed_total or epochs
|
||||||
self._stop_requested = False
|
self._stop_requested = False
|
||||||
|
|
||||||
@@ -201,9 +199,7 @@ class TrainingWorker(QThread):
|
|||||||
class TrainingTab(QWidget):
|
class TrainingTab(QWidget):
|
||||||
"""Training tab for model training."""
|
"""Training tab for model training."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
|
||||||
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
|
|
||||||
):
|
|
||||||
super().__init__(parent)
|
super().__init__(parent)
|
||||||
self.db_manager = db_manager
|
self.db_manager = db_manager
|
||||||
self.config_manager = config_manager
|
self.config_manager = config_manager
|
||||||
@@ -337,18 +333,14 @@ class TrainingTab(QWidget):
|
|||||||
self.model_version_edit = QLineEdit("v1")
|
self.model_version_edit = QLineEdit("v1")
|
||||||
form_layout.addRow("Version:", self.model_version_edit)
|
form_layout.addRow("Version:", self.model_version_edit)
|
||||||
|
|
||||||
default_base_model = self.config_manager.get(
|
default_base_model = self.config_manager.get("models.default_base_model", "yolov8s-seg.pt")
|
||||||
"models.default_base_model", "yolov8s-seg.pt"
|
|
||||||
)
|
|
||||||
base_model_choices = self.config_manager.get("models.base_model_choices", [])
|
base_model_choices = self.config_manager.get("models.base_model_choices", [])
|
||||||
|
|
||||||
self.base_model_combo = QComboBox()
|
self.base_model_combo = QComboBox()
|
||||||
self.base_model_combo.addItem("Custom path…", "")
|
self.base_model_combo.addItem("Custom path…", "")
|
||||||
for choice in base_model_choices:
|
for choice in base_model_choices:
|
||||||
self.base_model_combo.addItem(choice, choice)
|
self.base_model_combo.addItem(choice, choice)
|
||||||
self.base_model_combo.currentIndexChanged.connect(
|
self.base_model_combo.currentIndexChanged.connect(self._on_base_model_preset_changed)
|
||||||
self._on_base_model_preset_changed
|
|
||||||
)
|
|
||||||
form_layout.addRow("Base Model Preset:", self.base_model_combo)
|
form_layout.addRow("Base Model Preset:", self.base_model_combo)
|
||||||
|
|
||||||
base_model_layout = QHBoxLayout()
|
base_model_layout = QHBoxLayout()
|
||||||
@@ -434,12 +426,8 @@ class TrainingTab(QWidget):
|
|||||||
group_layout = QVBoxLayout()
|
group_layout = QVBoxLayout()
|
||||||
|
|
||||||
self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune")
|
self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune")
|
||||||
two_stage_defaults = (
|
two_stage_defaults = training_defaults.get("two_stage", {}) if training_defaults else {}
|
||||||
training_defaults.get("two_stage", {}) if training_defaults else {}
|
self.two_stage_checkbox.setChecked(bool(two_stage_defaults.get("enabled", False)))
|
||||||
)
|
|
||||||
self.two_stage_checkbox.setChecked(
|
|
||||||
bool(two_stage_defaults.get("enabled", False))
|
|
||||||
)
|
|
||||||
self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled)
|
self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled)
|
||||||
group_layout.addWidget(self.two_stage_checkbox)
|
group_layout.addWidget(self.two_stage_checkbox)
|
||||||
|
|
||||||
@@ -501,9 +489,7 @@ class TrainingTab(QWidget):
|
|||||||
stage2_group.setLayout(stage2_form)
|
stage2_group.setLayout(stage2_form)
|
||||||
controls_layout.addWidget(stage2_group)
|
controls_layout.addWidget(stage2_group)
|
||||||
|
|
||||||
helper_label = QLabel(
|
helper_label = QLabel("When enabled, staged hyperparameters override the global epochs/patience/lr.")
|
||||||
"When enabled, staged hyperparameters override the global epochs/patience/lr."
|
|
||||||
)
|
|
||||||
helper_label.setWordWrap(True)
|
helper_label.setWordWrap(True)
|
||||||
controls_layout.addWidget(helper_label)
|
controls_layout.addWidget(helper_label)
|
||||||
|
|
||||||
@@ -548,9 +534,7 @@ class TrainingTab(QWidget):
|
|||||||
if normalized == preset_value:
|
if normalized == preset_value:
|
||||||
target_index = idx
|
target_index = idx
|
||||||
break
|
break
|
||||||
if normalized.endswith(f"/{preset_value}") or normalized.endswith(
|
if normalized.endswith(f"/{preset_value}") or normalized.endswith(f"\\{preset_value}"):
|
||||||
f"\\{preset_value}"
|
|
||||||
):
|
|
||||||
target_index = idx
|
target_index = idx
|
||||||
break
|
break
|
||||||
self.base_model_combo.blockSignals(True)
|
self.base_model_combo.blockSignals(True)
|
||||||
@@ -638,9 +622,7 @@ class TrainingTab(QWidget):
|
|||||||
|
|
||||||
def _browse_dataset(self):
|
def _browse_dataset(self):
|
||||||
"""Open a file dialog to manually select data.yaml."""
|
"""Open a file dialog to manually select data.yaml."""
|
||||||
start_dir = self.config_manager.get(
|
start_dir = self.config_manager.get("training.last_dataset_dir", "data/datasets")
|
||||||
"training.last_dataset_dir", "data/datasets"
|
|
||||||
)
|
|
||||||
start_path = Path(start_dir).expanduser()
|
start_path = Path(start_dir).expanduser()
|
||||||
if not start_path.exists():
|
if not start_path.exists():
|
||||||
start_path = Path.cwd()
|
start_path = Path.cwd()
|
||||||
@@ -676,9 +658,7 @@ class TrainingTab(QWidget):
|
|||||||
return
|
return
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.exception("Unexpected error while generating data.yaml")
|
logger.exception("Unexpected error while generating data.yaml")
|
||||||
self._display_dataset_error(
|
self._display_dataset_error("Unexpected error while generating data.yaml. Check logs for details.")
|
||||||
"Unexpected error while generating data.yaml. Check logs for details."
|
|
||||||
)
|
|
||||||
QMessageBox.critical(
|
QMessageBox.critical(
|
||||||
self,
|
self,
|
||||||
"data.yaml Generation Failed",
|
"data.yaml Generation Failed",
|
||||||
@@ -755,13 +735,9 @@ class TrainingTab(QWidget):
|
|||||||
self.selected_dataset = info
|
self.selected_dataset = info
|
||||||
|
|
||||||
self.dataset_root_label.setText(info["root"]) # type: ignore[arg-type]
|
self.dataset_root_label.setText(info["root"]) # type: ignore[arg-type]
|
||||||
self.train_count_label.setText(
|
self.train_count_label.setText(self._format_split_info(info["splits"].get("train")))
|
||||||
self._format_split_info(info["splits"].get("train"))
|
|
||||||
)
|
|
||||||
self.val_count_label.setText(self._format_split_info(info["splits"].get("val")))
|
self.val_count_label.setText(self._format_split_info(info["splits"].get("val")))
|
||||||
self.test_count_label.setText(
|
self.test_count_label.setText(self._format_split_info(info["splits"].get("test")))
|
||||||
self._format_split_info(info["splits"].get("test"))
|
|
||||||
)
|
|
||||||
self.num_classes_label.setText(str(info["num_classes"]))
|
self.num_classes_label.setText(str(info["num_classes"]))
|
||||||
class_names = ", ".join(info["class_names"]) or "–"
|
class_names = ", ".join(info["class_names"]) or "–"
|
||||||
self.class_names_label.setText(class_names)
|
self.class_names_label.setText(class_names)
|
||||||
@@ -815,18 +791,12 @@ class TrainingTab(QWidget):
|
|||||||
if split_path.exists():
|
if split_path.exists():
|
||||||
split_info["count"] = self._count_images(split_path)
|
split_info["count"] = self._count_images(split_path)
|
||||||
if split_info["count"] == 0:
|
if split_info["count"] == 0:
|
||||||
warnings.append(
|
warnings.append(f"No images found for {split_name} split at {split_path}")
|
||||||
f"No images found for {split_name} split at {split_path}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
warnings.append(
|
warnings.append(f"{split_name.capitalize()} path does not exist: {split_path}")
|
||||||
f"{split_name.capitalize()} path does not exist: {split_path}"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
if split_name in ("train", "val"):
|
if split_name in ("train", "val"):
|
||||||
warnings.append(
|
warnings.append(f"{split_name.capitalize()} split missing in data.yaml")
|
||||||
f"{split_name.capitalize()} split missing in data.yaml"
|
|
||||||
)
|
|
||||||
splits[split_name] = split_info
|
splits[split_name] = split_info
|
||||||
|
|
||||||
names_list = self._normalize_class_names(data.get("names"))
|
names_list = self._normalize_class_names(data.get("names"))
|
||||||
@@ -844,9 +814,7 @@ class TrainingTab(QWidget):
|
|||||||
if not names_list and nc_value:
|
if not names_list and nc_value:
|
||||||
names_list = [f"class_{idx}" for idx in range(int(nc_value))]
|
names_list = [f"class_{idx}" for idx in range(int(nc_value))]
|
||||||
elif nc_value and len(names_list) not in (0, int(nc_value)):
|
elif nc_value and len(names_list) not in (0, int(nc_value)):
|
||||||
warnings.append(
|
warnings.append(f"Number of class names ({len(names_list)}) does not match nc={nc_value}")
|
||||||
f"Number of class names ({len(names_list)}) does not match nc={nc_value}"
|
|
||||||
)
|
|
||||||
|
|
||||||
dataset_name = data.get("name") or base_path.name
|
dataset_name = data.get("name") or base_path.name
|
||||||
|
|
||||||
@@ -898,16 +866,12 @@ class TrainingTab(QWidget):
|
|||||||
|
|
||||||
class_index_map = self._build_class_index_map(dataset_info)
|
class_index_map = self._build_class_index_map(dataset_info)
|
||||||
if not class_index_map:
|
if not class_index_map:
|
||||||
self._append_training_log(
|
self._append_training_log("Skipping label export: dataset classes do not match database entries.")
|
||||||
"Skipping label export: dataset classes do not match database entries."
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
dataset_root_str = dataset_info.get("root")
|
dataset_root_str = dataset_info.get("root")
|
||||||
dataset_yaml_path = dataset_info.get("yaml_path")
|
dataset_yaml_path = dataset_info.get("yaml_path")
|
||||||
dataset_yaml = (
|
dataset_yaml = Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
|
||||||
Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
|
|
||||||
)
|
|
||||||
dataset_root: Optional[Path]
|
dataset_root: Optional[Path]
|
||||||
if dataset_root_str:
|
if dataset_root_str:
|
||||||
dataset_root = Path(dataset_root_str).resolve()
|
dataset_root = Path(dataset_root_str).resolve()
|
||||||
@@ -941,7 +905,9 @@ class TrainingTab(QWidget):
|
|||||||
if stats["registered_images"]:
|
if stats["registered_images"]:
|
||||||
message += f" {stats['registered_images']} image(s) had database-backed annotations."
|
message += f" {stats['registered_images']} image(s) had database-backed annotations."
|
||||||
if stats["missing_records"]:
|
if stats["missing_records"]:
|
||||||
message += f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
|
message += (
|
||||||
|
f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
|
||||||
|
)
|
||||||
split_messages.append(message)
|
split_messages.append(message)
|
||||||
|
|
||||||
for msg in split_messages:
|
for msg in split_messages:
|
||||||
@@ -973,9 +939,7 @@ class TrainingTab(QWidget):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
processed_images += 1
|
processed_images += 1
|
||||||
label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(
|
label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(".txt")
|
||||||
".txt"
|
|
||||||
)
|
|
||||||
label_path.parent.mkdir(parents=True, exist_ok=True)
|
label_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
found, annotation_entries = self._fetch_annotations_for_image(
|
found, annotation_entries = self._fetch_annotations_for_image(
|
||||||
@@ -991,25 +955,23 @@ class TrainingTab(QWidget):
|
|||||||
for entry in annotation_entries:
|
for entry in annotation_entries:
|
||||||
polygon = entry.get("polygon") or []
|
polygon = entry.get("polygon") or []
|
||||||
if polygon:
|
if polygon:
|
||||||
|
print(image_file, polygon[:4], polygon[-2:], entry.get("bbox"))
|
||||||
|
# coords = " ".join(f"{value:.6f}" for value in entry.get("bbox"))
|
||||||
|
# coords += " "
|
||||||
coords = " ".join(f"{value:.6f}" for value in polygon)
|
coords = " ".join(f"{value:.6f}" for value in polygon)
|
||||||
handle.write(f"{entry['class_idx']} {coords}\n")
|
handle.write(f"{entry['class_idx']} {coords}\n")
|
||||||
annotations_written += 1
|
annotations_written += 1
|
||||||
elif entry.get("bbox"):
|
elif entry.get("bbox"):
|
||||||
x_center, y_center, width, height = entry["bbox"]
|
x_center, y_center, width, height = entry["bbox"]
|
||||||
handle.write(
|
handle.write(f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")
|
||||||
f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
|
|
||||||
)
|
|
||||||
annotations_written += 1
|
annotations_written += 1
|
||||||
|
|
||||||
total_annotations += annotations_written
|
total_annotations += annotations_written
|
||||||
|
|
||||||
cache_reset_root = labels_dir.parent
|
cache_reset_root = labels_dir.parent
|
||||||
self._invalidate_split_cache(cache_reset_root)
|
self._invalidate_split_cache(cache_reset_root)
|
||||||
|
|
||||||
if processed_images == 0:
|
if processed_images == 0:
|
||||||
self._append_training_log(
|
self._append_training_log(f"[{split_name}] No images found to export labels for.")
|
||||||
f"[{split_name}] No images found to export labels for."
|
|
||||||
)
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -1135,6 +1097,10 @@ class TrainingTab(QWidget):
|
|||||||
xs.append(x_val)
|
xs.append(x_val)
|
||||||
ys.append(y_val)
|
ys.append(y_val)
|
||||||
|
|
||||||
|
if any(np.abs(np.array(coords[:2]) - np.array(coords[-2:])) < 1e-5):
|
||||||
|
print("Closing polygon")
|
||||||
|
coords.extend(coords[:2])
|
||||||
|
|
||||||
if len(coords) < 6:
|
if len(coords) < 6:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -1147,6 +1113,11 @@ class TrainingTab(QWidget):
|
|||||||
+ abs((min(ys) if ys else 0.0) - y_min)
|
+ abs((min(ys) if ys else 0.0) - y_min)
|
||||||
+ abs((max(ys) if ys else 0.0) - y_max)
|
+ abs((max(ys) if ys else 0.0) - y_max)
|
||||||
)
|
)
|
||||||
|
width = max(0.0, x_max - x_min)
|
||||||
|
height = max(0.0, y_max - y_min)
|
||||||
|
x_center = x_min + width / 2.0
|
||||||
|
y_center = y_min + height / 2.0
|
||||||
|
score = (x_center, y_center, width, height)
|
||||||
|
|
||||||
candidates.append((score, coords))
|
candidates.append((score, coords))
|
||||||
|
|
||||||
@@ -1164,13 +1135,10 @@ class TrainingTab(QWidget):
|
|||||||
return 1.0
|
return 1.0
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def _prepare_dataset_for_training(
|
def _prepare_dataset_for_training(self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None) -> Path:
|
||||||
self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None
|
|
||||||
) -> Path:
|
|
||||||
dataset_info = dataset_info or (
|
dataset_info = dataset_info or (
|
||||||
self.selected_dataset
|
self.selected_dataset
|
||||||
if self.selected_dataset
|
if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
|
||||||
and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
|
|
||||||
else self._parse_dataset_yaml(dataset_yaml)
|
else self._parse_dataset_yaml(dataset_yaml)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1189,14 +1157,10 @@ class TrainingTab(QWidget):
|
|||||||
cache_root = self._get_rgb_cache_root(dataset_yaml)
|
cache_root = self._get_rgb_cache_root(dataset_yaml)
|
||||||
rgb_yaml = cache_root / "data.yaml"
|
rgb_yaml = cache_root / "data.yaml"
|
||||||
if rgb_yaml.exists():
|
if rgb_yaml.exists():
|
||||||
self._append_training_log(
|
self._append_training_log(f"Detected grayscale dataset; reusing RGB cache at {cache_root}")
|
||||||
f"Detected grayscale dataset; reusing RGB cache at {cache_root}"
|
|
||||||
)
|
|
||||||
return rgb_yaml
|
return rgb_yaml
|
||||||
|
|
||||||
self._append_training_log(
|
self._append_training_log(f"Detected grayscale dataset; creating RGB cache at {cache_root}")
|
||||||
f"Detected grayscale dataset; creating RGB cache at {cache_root}"
|
|
||||||
)
|
|
||||||
self._build_rgb_dataset(cache_root, dataset_info)
|
self._build_rgb_dataset(cache_root, dataset_info)
|
||||||
return rgb_yaml
|
return rgb_yaml
|
||||||
|
|
||||||
@@ -1463,15 +1427,12 @@ class TrainingTab(QWidget):
|
|||||||
|
|
||||||
dataset_path = Path(dataset_yaml).expanduser()
|
dataset_path = Path(dataset_yaml).expanduser()
|
||||||
if not dataset_path.exists():
|
if not dataset_path.exists():
|
||||||
QMessageBox.warning(
|
QMessageBox.warning(self, "Invalid Dataset", "Selected data.yaml file does not exist.")
|
||||||
self, "Invalid Dataset", "Selected data.yaml file does not exist."
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
dataset_info = (
|
dataset_info = (
|
||||||
self.selected_dataset
|
self.selected_dataset
|
||||||
if self.selected_dataset
|
if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_path)
|
||||||
and self.selected_dataset.get("yaml_path") == str(dataset_path)
|
|
||||||
else self._parse_dataset_yaml(dataset_path)
|
else self._parse_dataset_yaml(dataset_path)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1480,16 +1441,12 @@ class TrainingTab(QWidget):
|
|||||||
|
|
||||||
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
|
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
|
||||||
if dataset_to_use != dataset_path:
|
if dataset_to_use != dataset_path:
|
||||||
self._append_training_log(
|
self._append_training_log(f"Using RGB-converted dataset at {dataset_to_use.parent}")
|
||||||
f"Using RGB-converted dataset at {dataset_to_use.parent}"
|
|
||||||
)
|
|
||||||
|
|
||||||
params = self._collect_training_params()
|
params = self._collect_training_params()
|
||||||
stage_plan = self._compose_stage_plan(params)
|
stage_plan = self._compose_stage_plan(params)
|
||||||
params["stage_plan"] = stage_plan
|
params["stage_plan"] = stage_plan
|
||||||
total_planned_epochs = (
|
total_planned_epochs = self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
|
||||||
self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
|
|
||||||
)
|
|
||||||
params["total_planned_epochs"] = total_planned_epochs
|
params["total_planned_epochs"] = total_planned_epochs
|
||||||
self._active_training_params = params
|
self._active_training_params = params
|
||||||
self._training_cancelled = False
|
self._training_cancelled = False
|
||||||
@@ -1498,9 +1455,7 @@ class TrainingTab(QWidget):
|
|||||||
self._append_training_log("Two-stage fine-tuning schedule:")
|
self._append_training_log("Two-stage fine-tuning schedule:")
|
||||||
self._log_stage_plan(stage_plan)
|
self._log_stage_plan(stage_plan)
|
||||||
|
|
||||||
self._append_training_log(
|
self._append_training_log(f"Starting training run '{params['run_name']}' using {params['base_model']}")
|
||||||
f"Starting training run '{params['run_name']}' using {params['base_model']}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.training_progress_bar.setVisible(True)
|
self.training_progress_bar.setVisible(True)
|
||||||
self.training_progress_bar.setMaximum(max(1, total_planned_epochs))
|
self.training_progress_bar.setMaximum(max(1, total_planned_epochs))
|
||||||
@@ -1528,9 +1483,7 @@ class TrainingTab(QWidget):
|
|||||||
def _stop_training(self):
|
def _stop_training(self):
|
||||||
if self.training_worker and self.training_worker.isRunning():
|
if self.training_worker and self.training_worker.isRunning():
|
||||||
self._training_cancelled = True
|
self._training_cancelled = True
|
||||||
self._append_training_log(
|
self._append_training_log("Stop requested. Waiting for the current epoch to finish...")
|
||||||
"Stop requested. Waiting for the current epoch to finish..."
|
|
||||||
)
|
|
||||||
self.training_worker.stop()
|
self.training_worker.stop()
|
||||||
self.stop_training_button.setEnabled(False)
|
self.stop_training_button.setEnabled(False)
|
||||||
|
|
||||||
@@ -1566,9 +1519,7 @@ class TrainingTab(QWidget):
|
|||||||
|
|
||||||
if worker.isRunning():
|
if worker.isRunning():
|
||||||
if not worker.wait(wait_timeout_ms):
|
if not worker.wait(wait_timeout_ms):
|
||||||
logger.warning(
|
logger.warning("Training worker did not finish within %sms", wait_timeout_ms)
|
||||||
"Training worker did not finish within %sms", wait_timeout_ms
|
|
||||||
)
|
|
||||||
|
|
||||||
worker.deleteLater()
|
worker.deleteLater()
|
||||||
|
|
||||||
@@ -1585,16 +1536,12 @@ class TrainingTab(QWidget):
|
|||||||
self._set_training_state(False)
|
self._set_training_state(False)
|
||||||
self.training_progress_bar.setVisible(False)
|
self.training_progress_bar.setVisible(False)
|
||||||
|
|
||||||
def _on_training_progress(
|
def _on_training_progress(self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]):
|
||||||
self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]
|
|
||||||
):
|
|
||||||
self.training_progress_bar.setMaximum(total_epochs)
|
self.training_progress_bar.setMaximum(total_epochs)
|
||||||
self.training_progress_bar.setValue(current_epoch)
|
self.training_progress_bar.setValue(current_epoch)
|
||||||
parts = [f"Epoch {current_epoch}/{total_epochs}"]
|
parts = [f"Epoch {current_epoch}/{total_epochs}"]
|
||||||
if metrics:
|
if metrics:
|
||||||
metric_text = ", ".join(
|
metric_text = ", ".join(f"{key}: {value:.4f}" for key, value in metrics.items())
|
||||||
f"{key}: {value:.4f}" for key, value in metrics.items()
|
|
||||||
)
|
|
||||||
parts.append(metric_text)
|
parts.append(metric_text)
|
||||||
self._append_training_log(" | ".join(parts))
|
self._append_training_log(" | ".join(parts))
|
||||||
|
|
||||||
@@ -1621,9 +1568,7 @@ class TrainingTab(QWidget):
|
|||||||
f"Model trained but not registered: {exc}",
|
f"Model trained but not registered: {exc}",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
QMessageBox.information(
|
QMessageBox.information(self, "Training Complete", "Training finished successfully.")
|
||||||
self, "Training Complete", "Training finished successfully."
|
|
||||||
)
|
|
||||||
|
|
||||||
def _on_training_error(self, message: str):
|
def _on_training_error(self, message: str):
|
||||||
self._cleanup_training_worker()
|
self._cleanup_training_worker()
|
||||||
@@ -1669,9 +1614,7 @@ class TrainingTab(QWidget):
|
|||||||
metrics=results.get("metrics"),
|
metrics=results.get("metrics"),
|
||||||
)
|
)
|
||||||
|
|
||||||
self._append_training_log(
|
self._append_training_log(f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}")
|
||||||
f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}"
|
|
||||||
)
|
|
||||||
self._active_training_params = None
|
self._active_training_params = None
|
||||||
|
|
||||||
def _set_training_state(self, is_training: bool):
|
def _set_training_state(self, is_training: bool):
|
||||||
@@ -1714,9 +1657,7 @@ class TrainingTab(QWidget):
|
|||||||
|
|
||||||
def _browse_save_dir(self):
|
def _browse_save_dir(self):
|
||||||
start_path = self.save_dir_edit.text().strip() or "data/models"
|
start_path = self.save_dir_edit.text().strip() or "data/models"
|
||||||
directory = QFileDialog.getExistingDirectory(
|
directory = QFileDialog.getExistingDirectory(self, "Select Save Directory", start_path)
|
||||||
self, "Select Save Directory", start_path
|
|
||||||
)
|
|
||||||
if directory:
|
if directory:
|
||||||
self.save_dir_edit.setText(directory)
|
self.save_dir_edit.setText(directory)
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from PySide6.QtGui import (
|
|||||||
QPaintEvent,
|
QPaintEvent,
|
||||||
QPolygonF,
|
QPolygonF,
|
||||||
)
|
)
|
||||||
from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QPointF, QRect
|
from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QPointF, QRect, QTimer
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
from src.utils.image import Image, ImageLoadError
|
from src.utils.image import Image, ImageLoadError
|
||||||
@@ -79,9 +79,7 @@ def rdp(points: List[Tuple[float, float]], epsilon: float) -> List[Tuple[float,
|
|||||||
return [start, end]
|
return [start, end]
|
||||||
|
|
||||||
|
|
||||||
def simplify_polyline(
|
def simplify_polyline(points: List[Tuple[float, float]], epsilon: float) -> List[Tuple[float, float]]:
|
||||||
points: List[Tuple[float, float]], epsilon: float
|
|
||||||
) -> List[Tuple[float, float]]:
|
|
||||||
"""
|
"""
|
||||||
Simplify a polyline with RDP while preserving closure semantics.
|
Simplify a polyline with RDP while preserving closure semantics.
|
||||||
|
|
||||||
@@ -145,6 +143,10 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
self.zoom_step = 0.1
|
self.zoom_step = 0.1
|
||||||
self.zoom_wheel_step = 0.15
|
self.zoom_wheel_step = 0.15
|
||||||
|
|
||||||
|
# Auto-fit behavior (opt-in): when enabled, newly loaded images (and resizes)
|
||||||
|
# will scale to fill the available viewport while preserving aspect ratio.
|
||||||
|
self._auto_fit_to_view: bool = False
|
||||||
|
|
||||||
# Drawing / interaction state
|
# Drawing / interaction state
|
||||||
self.is_drawing = False
|
self.is_drawing = False
|
||||||
self.polyline_enabled = False
|
self.polyline_enabled = False
|
||||||
@@ -175,6 +177,35 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
|
|
||||||
self._setup_ui()
|
self._setup_ui()
|
||||||
|
|
||||||
|
def set_auto_fit_to_view(self, enabled: bool):
|
||||||
|
"""Enable/disable automatic zoom-to-fit behavior."""
|
||||||
|
self._auto_fit_to_view = bool(enabled)
|
||||||
|
if self._auto_fit_to_view and self.original_pixmap is not None:
|
||||||
|
QTimer.singleShot(0, self.fit_to_view)
|
||||||
|
|
||||||
|
def fit_to_view(self, padding_px: int = 6):
|
||||||
|
"""Zoom the image so it fits the scroll area's viewport (aspect preserved)."""
|
||||||
|
if self.original_pixmap is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
viewport = self.scroll_area.viewport().size()
|
||||||
|
available_w = max(1, int(viewport.width()) - int(padding_px))
|
||||||
|
available_h = max(1, int(viewport.height()) - int(padding_px))
|
||||||
|
|
||||||
|
img_w = max(1, int(self.original_pixmap.width()))
|
||||||
|
img_h = max(1, int(self.original_pixmap.height()))
|
||||||
|
|
||||||
|
scale_w = available_w / img_w
|
||||||
|
scale_h = available_h / img_h
|
||||||
|
new_scale = min(scale_w, scale_h)
|
||||||
|
new_scale = max(self.zoom_min, min(self.zoom_max, float(new_scale)))
|
||||||
|
|
||||||
|
if abs(new_scale - self.zoom_scale) < 1e-4:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.zoom_scale = new_scale
|
||||||
|
self._apply_zoom()
|
||||||
|
|
||||||
def _setup_ui(self):
|
def _setup_ui(self):
|
||||||
"""Setup user interface."""
|
"""Setup user interface."""
|
||||||
layout = QVBoxLayout()
|
layout = QVBoxLayout()
|
||||||
@@ -187,9 +218,7 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
|
|
||||||
self.canvas_label = QLabel("No image loaded")
|
self.canvas_label = QLabel("No image loaded")
|
||||||
self.canvas_label.setAlignment(Qt.AlignCenter)
|
self.canvas_label.setAlignment(Qt.AlignCenter)
|
||||||
self.canvas_label.setStyleSheet(
|
self.canvas_label.setStyleSheet("QLabel { background-color: #2b2b2b; color: #888; }")
|
||||||
"QLabel { background-color: #2b2b2b; color: #888; }"
|
|
||||||
)
|
|
||||||
self.canvas_label.setScaledContents(False)
|
self.canvas_label.setScaledContents(False)
|
||||||
self.canvas_label.setMouseTracking(True)
|
self.canvas_label.setMouseTracking(True)
|
||||||
|
|
||||||
@@ -212,9 +241,18 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
self.zoom_scale = 1.0
|
self.zoom_scale = 1.0
|
||||||
self.clear_annotations()
|
self.clear_annotations()
|
||||||
self._display_image()
|
self._display_image()
|
||||||
logger.debug(
|
|
||||||
f"Loaded image into annotation canvas: {image.width}x{image.height}"
|
# Defer fit-to-view until the widget has a valid viewport size.
|
||||||
)
|
if self._auto_fit_to_view:
|
||||||
|
QTimer.singleShot(0, self.fit_to_view)
|
||||||
|
|
||||||
|
logger.debug(f"Loaded image into annotation canvas: {image.width}x{image.height}")
|
||||||
|
|
||||||
|
def resizeEvent(self, event):
|
||||||
|
"""Optionally keep the image fitted when the widget is resized."""
|
||||||
|
super().resizeEvent(event)
|
||||||
|
if self._auto_fit_to_view and self.original_pixmap is not None:
|
||||||
|
QTimer.singleShot(0, self.fit_to_view)
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
"""Clear the displayed image and all annotations."""
|
"""Clear the displayed image and all annotations."""
|
||||||
@@ -289,22 +327,14 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
scaled_width,
|
scaled_width,
|
||||||
scaled_height,
|
scaled_height,
|
||||||
Qt.KeepAspectRatio,
|
Qt.KeepAspectRatio,
|
||||||
(
|
(Qt.SmoothTransformation if self.zoom_scale >= 1.0 else Qt.FastTransformation),
|
||||||
Qt.SmoothTransformation
|
|
||||||
if self.zoom_scale >= 1.0
|
|
||||||
else Qt.FastTransformation
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
scaled_annotations = self.annotation_pixmap.scaled(
|
scaled_annotations = self.annotation_pixmap.scaled(
|
||||||
scaled_width,
|
scaled_width,
|
||||||
scaled_height,
|
scaled_height,
|
||||||
Qt.KeepAspectRatio,
|
Qt.KeepAspectRatio,
|
||||||
(
|
(Qt.SmoothTransformation if self.zoom_scale >= 1.0 else Qt.FastTransformation),
|
||||||
Qt.SmoothTransformation
|
|
||||||
if self.zoom_scale >= 1.0
|
|
||||||
else Qt.FastTransformation
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Composite image and annotations
|
# Composite image and annotations
|
||||||
@@ -390,16 +420,11 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
y = (pos.y() - offset_y) / self.zoom_scale
|
y = (pos.y() - offset_y) / self.zoom_scale
|
||||||
|
|
||||||
# Check bounds
|
# Check bounds
|
||||||
if (
|
if 0 <= x < self.original_pixmap.width() and 0 <= y < self.original_pixmap.height():
|
||||||
0 <= x < self.original_pixmap.width()
|
|
||||||
and 0 <= y < self.original_pixmap.height()
|
|
||||||
):
|
|
||||||
return (int(x), int(y))
|
return (int(x), int(y))
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _find_polyline_at(
|
def _find_polyline_at(self, img_x: float, img_y: float, threshold_px: float = 5.0) -> Optional[int]:
|
||||||
self, img_x: float, img_y: float, threshold_px: float = 5.0
|
|
||||||
) -> Optional[int]:
|
|
||||||
"""
|
"""
|
||||||
Find index of polyline whose geometry is within threshold_px of (img_x, img_y).
|
Find index of polyline whose geometry is within threshold_px of (img_x, img_y).
|
||||||
Returns the index in self.polylines, or None if none is close enough.
|
Returns the index in self.polylines, or None if none is close enough.
|
||||||
@@ -421,9 +446,7 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
|
|
||||||
# Precise distance to all segments
|
# Precise distance to all segments
|
||||||
for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]):
|
for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]):
|
||||||
d = perpendicular_distance(
|
d = perpendicular_distance((img_x, img_y), (float(x1), float(y1)), (float(x2), float(y2)))
|
||||||
(img_x, img_y), (float(x1), float(y1)), (float(x2), float(y2))
|
|
||||||
)
|
|
||||||
if d < best_dist:
|
if d < best_dist:
|
||||||
best_dist = d
|
best_dist = d
|
||||||
best_index = idx
|
best_index = idx
|
||||||
@@ -624,11 +647,7 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
|
|
||||||
def mouseMoveEvent(self, event: QMouseEvent):
|
def mouseMoveEvent(self, event: QMouseEvent):
|
||||||
"""Handle mouse move events for drawing."""
|
"""Handle mouse move events for drawing."""
|
||||||
if (
|
if not self.is_drawing or not self.polyline_enabled or self.annotation_pixmap is None:
|
||||||
not self.is_drawing
|
|
||||||
or not self.polyline_enabled
|
|
||||||
or self.annotation_pixmap is None
|
|
||||||
):
|
|
||||||
super().mouseMoveEvent(event)
|
super().mouseMoveEvent(event)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -688,15 +707,10 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
|
|
||||||
if len(simplified) >= 2:
|
if len(simplified) >= 2:
|
||||||
# Store polyline and redraw all annotations
|
# Store polyline and redraw all annotations
|
||||||
self._add_polyline(
|
self._add_polyline(simplified, self.polyline_pen_color, self.polyline_pen_width)
|
||||||
simplified, self.polyline_pen_color, self.polyline_pen_width
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert to normalized coordinates for metadata + signal
|
# Convert to normalized coordinates for metadata + signal
|
||||||
normalized_stroke = [
|
normalized_stroke = [self._image_to_normalized_coords(int(x), int(y)) for (x, y) in simplified]
|
||||||
self._image_to_normalized_coords(int(x), int(y))
|
|
||||||
for (x, y) in simplified
|
|
||||||
]
|
|
||||||
self.all_strokes.append(
|
self.all_strokes.append(
|
||||||
{
|
{
|
||||||
"points": normalized_stroke,
|
"points": normalized_stroke,
|
||||||
@@ -709,8 +723,7 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
# Emit signal with normalized coordinates
|
# Emit signal with normalized coordinates
|
||||||
self.annotation_drawn.emit(normalized_stroke)
|
self.annotation_drawn.emit(normalized_stroke)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Completed stroke with {len(simplified)} points "
|
f"Completed stroke with {len(simplified)} points " f"(normalized len={len(normalized_stroke)})"
|
||||||
f"(normalized len={len(normalized_stroke)})"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.current_stroke = []
|
self.current_stroke = []
|
||||||
@@ -750,9 +763,7 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
|
|
||||||
# Store polyline as [y_norm, x_norm] to match DB convention and
|
# Store polyline as [y_norm, x_norm] to match DB convention and
|
||||||
# the expectations of draw_saved_polyline().
|
# the expectations of draw_saved_polyline().
|
||||||
normalized_polyline = [
|
normalized_polyline = [[y / img_height, x / img_width] for (x, y) in polyline]
|
||||||
[y / img_height, x / img_width] for (x, y) in polyline
|
|
||||||
]
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Polyline {idx}: {len(polyline)} points, "
|
f"Polyline {idx}: {len(polyline)} points, "
|
||||||
@@ -772,7 +783,7 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
self,
|
self,
|
||||||
polyline: List[List[float]],
|
polyline: List[List[float]],
|
||||||
color: str,
|
color: str,
|
||||||
width: int = 3,
|
width: int = 1,
|
||||||
annotation_id: Optional[int] = None,
|
annotation_id: Optional[int] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -810,17 +821,13 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
|
|
||||||
# Store and redraw using common pipeline
|
# Store and redraw using common pipeline
|
||||||
pen_color = QColor(color)
|
pen_color = QColor(color)
|
||||||
pen_color.setAlpha(128) # Add semi-transparency
|
pen_color.setAlpha(255) # Add semi-transparency
|
||||||
self._add_polyline(img_coords, pen_color, width, annotation_id=annotation_id)
|
self._add_polyline(img_coords, pen_color, width, annotation_id=annotation_id)
|
||||||
|
|
||||||
# Store in all_strokes for consistency (uses normalized coordinates)
|
# Store in all_strokes for consistency (uses normalized coordinates)
|
||||||
self.all_strokes.append(
|
self.all_strokes.append({"points": polyline, "color": color, "alpha": 255, "width": width})
|
||||||
{"points": polyline, "color": color, "alpha": 128, "width": width}
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(f"Drew saved polyline with {len(polyline)} points in color {color}")
|
||||||
f"Drew saved polyline with {len(polyline)} points in color {color}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def draw_saved_bbox(
|
def draw_saved_bbox(
|
||||||
self,
|
self,
|
||||||
@@ -844,9 +851,7 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
return
|
return
|
||||||
|
|
||||||
if len(bbox) != 4:
|
if len(bbox) != 4:
|
||||||
logger.warning(
|
logger.warning(f"Invalid bounding box format: expected 4 values, got {len(bbox)}")
|
||||||
f"Invalid bounding box format: expected 4 values, got {len(bbox)}"
|
|
||||||
)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Convert normalized coordinates to image coordinates (for logging/debug)
|
# Convert normalized coordinates to image coordinates (for logging/debug)
|
||||||
@@ -867,15 +872,11 @@ class AnnotationCanvasWidget(QWidget):
|
|||||||
# in _redraw_annotations() together with all polylines.
|
# in _redraw_annotations() together with all polylines.
|
||||||
pen_color = QColor(color)
|
pen_color = QColor(color)
|
||||||
pen_color.setAlpha(128) # Add semi-transparency
|
pen_color.setAlpha(128) # Add semi-transparency
|
||||||
self.bboxes.append(
|
self.bboxes.append([float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)])
|
||||||
[float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)]
|
|
||||||
)
|
|
||||||
self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label})
|
self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label})
|
||||||
|
|
||||||
# Store in all_strokes for consistency
|
# Store in all_strokes for consistency
|
||||||
self.all_strokes.append(
|
self.all_strokes.append({"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label})
|
||||||
{"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Redraw overlay (polylines + all bounding boxes)
|
# Redraw overlay (polylines + all bounding boxes)
|
||||||
self._redraw_annotations()
|
self._redraw_annotations()
|
||||||
|
|||||||
@@ -96,9 +96,7 @@ class YOLOWrapper:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Starting training: {name}")
|
logger.info(f"Starting training: {name}")
|
||||||
logger.info(
|
logger.info(f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}")
|
||||||
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Defaults for 16-bit safety: disable augmentations that force uint8 and HSV ops that assume 0..255.
|
# Defaults for 16-bit safety: disable augmentations that force uint8 and HSV ops that assume 0..255.
|
||||||
# Users can override by passing explicit kwargs.
|
# Users can override by passing explicit kwargs.
|
||||||
@@ -149,9 +147,7 @@ class YOLOWrapper:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
logger.info(f"Starting validation on {split} split")
|
logger.info(f"Starting validation on {split} split")
|
||||||
results = self.model.val(
|
results = self.model.val(data=data_yaml, split=split, device=self.device, **kwargs)
|
||||||
data=data_yaml, split=split, device=self.device, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info("Validation completed successfully")
|
logger.info("Validation completed successfully")
|
||||||
return self._format_validation_results(results)
|
return self._format_validation_results(results)
|
||||||
@@ -190,11 +186,9 @@ class YOLOWrapper:
|
|||||||
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
raise RuntimeError(f"Failed to load model from {self.model_path}")
|
||||||
|
|
||||||
prepared_source, cleanup_path = self._prepare_source(source)
|
prepared_source, cleanup_path = self._prepare_source(source)
|
||||||
|
imgsz = 1088
|
||||||
try:
|
try:
|
||||||
logger.info(
|
logger.info(f"Running inference on {source} -> prepared_source {prepared_source}")
|
||||||
f"Running inference on {source} -> prepared_source {prepared_source}"
|
|
||||||
)
|
|
||||||
results = self.model.predict(
|
results = self.model.predict(
|
||||||
source=source,
|
source=source,
|
||||||
conf=conf,
|
conf=conf,
|
||||||
@@ -203,6 +197,7 @@ class YOLOWrapper:
|
|||||||
save_txt=save_txt,
|
save_txt=save_txt,
|
||||||
save_conf=save_conf,
|
save_conf=save_conf,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
|
imgsz=imgsz,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -218,13 +213,9 @@ class YOLOWrapper:
|
|||||||
try:
|
try:
|
||||||
os.remove(cleanup_path)
|
os.remove(cleanup_path)
|
||||||
except OSError as cleanup_error:
|
except OSError as cleanup_error:
|
||||||
logger.warning(
|
logger.warning(f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}")
|
||||||
f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def export(
|
def export(self, format: str = "onnx", output_path: Optional[str] = None, **kwargs) -> str:
|
||||||
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
|
|
||||||
) -> str:
|
|
||||||
"""
|
"""
|
||||||
Export model to different format.
|
Export model to different format.
|
||||||
|
|
||||||
@@ -265,9 +256,7 @@ class YOLOWrapper:
|
|||||||
tmp.close()
|
tmp.close()
|
||||||
img_obj.save(tmp_path)
|
img_obj.save(tmp_path)
|
||||||
cleanup_path = tmp_path
|
cleanup_path = tmp_path
|
||||||
logger.info(
|
logger.info(f"Converted image {source_path} to RGB for inference at {tmp_path}")
|
||||||
f"Converted image {source_path} to RGB for inference at {tmp_path}"
|
|
||||||
)
|
|
||||||
return tmp_path, cleanup_path
|
return tmp_path, cleanup_path
|
||||||
except Exception as convert_error:
|
except Exception as convert_error:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -280,9 +269,7 @@ class YOLOWrapper:
|
|||||||
"""Format training results into dictionary."""
|
"""Format training results into dictionary."""
|
||||||
try:
|
try:
|
||||||
# Get the results dict
|
# Get the results dict
|
||||||
results_dict = (
|
results_dict = results.results_dict if hasattr(results, "results_dict") else {}
|
||||||
results.results_dict if hasattr(results, "results_dict") else {}
|
|
||||||
)
|
|
||||||
|
|
||||||
formatted = {
|
formatted = {
|
||||||
"success": True,
|
"success": True,
|
||||||
@@ -315,9 +302,7 @@ class YOLOWrapper:
|
|||||||
"mAP50-95": float(box_metrics.map),
|
"mAP50-95": float(box_metrics.map),
|
||||||
"precision": float(box_metrics.mp),
|
"precision": float(box_metrics.mp),
|
||||||
"recall": float(box_metrics.mr),
|
"recall": float(box_metrics.mr),
|
||||||
"fitness": (
|
"fitness": (float(results.fitness) if hasattr(results, "fitness") else 0.0),
|
||||||
float(results.fitness) if hasattr(results, "fitness") else 0.0
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Add per-class metrics if available
|
# Add per-class metrics if available
|
||||||
@@ -327,11 +312,7 @@ class YOLOWrapper:
|
|||||||
if idx < len(box_metrics.ap):
|
if idx < len(box_metrics.ap):
|
||||||
class_metrics[name] = {
|
class_metrics[name] = {
|
||||||
"ap": float(box_metrics.ap[idx]),
|
"ap": float(box_metrics.ap[idx]),
|
||||||
"ap50": (
|
"ap50": (float(box_metrics.ap50[idx]) if hasattr(box_metrics, "ap50") else 0.0),
|
||||||
float(box_metrics.ap50[idx])
|
|
||||||
if hasattr(box_metrics, "ap50")
|
|
||||||
else 0.0
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
formatted["class_metrics"] = class_metrics
|
formatted["class_metrics"] = class_metrics
|
||||||
|
|
||||||
@@ -364,21 +345,15 @@ class YOLOWrapper:
|
|||||||
"class_id": int(boxes.cls[i]),
|
"class_id": int(boxes.cls[i]),
|
||||||
"class_name": result.names[int(boxes.cls[i])],
|
"class_name": result.names[int(boxes.cls[i])],
|
||||||
"confidence": float(boxes.conf[i]),
|
"confidence": float(boxes.conf[i]),
|
||||||
"bbox_normalized": [
|
"bbox_normalized": [float(v) for v in xyxyn], # [x_min, y_min, x_max, y_max]
|
||||||
float(v) for v in xyxyn
|
"bbox_absolute": [float(v) for v in boxes.xyxy[i].cpu().numpy()], # Absolute pixels
|
||||||
], # [x_min, y_min, x_max, y_max]
|
|
||||||
"bbox_absolute": [
|
|
||||||
float(v) for v in boxes.xyxy[i].cpu().numpy()
|
|
||||||
], # Absolute pixels
|
|
||||||
}
|
}
|
||||||
|
|
||||||
# Extract segmentation mask if available
|
# Extract segmentation mask if available
|
||||||
if has_masks:
|
if has_masks:
|
||||||
try:
|
try:
|
||||||
# Get the mask for this detection
|
# Get the mask for this detection
|
||||||
mask_data = result.masks.xy[
|
mask_data = result.masks.xy[i] # Polygon coordinates in absolute pixels
|
||||||
i
|
|
||||||
] # Polygon coordinates in absolute pixels
|
|
||||||
|
|
||||||
# Convert to normalized coordinates
|
# Convert to normalized coordinates
|
||||||
if len(mask_data) > 0:
|
if len(mask_data) > 0:
|
||||||
@@ -391,9 +366,7 @@ class YOLOWrapper:
|
|||||||
else:
|
else:
|
||||||
detection["segmentation_mask"] = None
|
detection["segmentation_mask"] = None
|
||||||
except Exception as mask_error:
|
except Exception as mask_error:
|
||||||
logger.warning(
|
logger.warning(f"Error extracting mask for detection {i}: {mask_error}")
|
||||||
f"Error extracting mask for detection {i}: {mask_error}"
|
|
||||||
)
|
|
||||||
detection["segmentation_mask"] = None
|
detection["segmentation_mask"] = None
|
||||||
else:
|
else:
|
||||||
detection["segmentation_mask"] = None
|
detection["segmentation_mask"] = None
|
||||||
@@ -407,9 +380,7 @@ class YOLOWrapper:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def convert_bbox_format(
|
def convert_bbox_format(bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy") -> List[float]:
|
||||||
bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
|
|
||||||
) -> List[float]:
|
|
||||||
"""
|
"""
|
||||||
Convert bounding box between formats.
|
Convert bounding box between formats.
|
||||||
|
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
|
|||||||
a1[a1 > p999] = p999
|
a1[a1 > p999] = p999
|
||||||
a1 /= a1.max()
|
a1 /= a1.max()
|
||||||
|
|
||||||
if 0:
|
if 1:
|
||||||
a2 = a1.copy()
|
a2 = a1.copy()
|
||||||
a2 = a2**gamma
|
a2 = a2**gamma
|
||||||
a2 /= a2.max()
|
a2 /= a2.max()
|
||||||
@@ -47,9 +47,12 @@ def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
|
|||||||
a3[a3 > p9999] = p9999
|
a3[a3 > p9999] = p9999
|
||||||
a3 /= a3.max()
|
a3 /= a3.max()
|
||||||
|
|
||||||
return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
|
# return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
|
||||||
# return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
|
# return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
|
||||||
# return np.stack([a1, a2, a3], axis=0)
|
out = np.stack([a1, a2, a3], axis=0)
|
||||||
|
# print(any(np.isnan(out).flatten()))
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
class ImageLoadError(Exception):
|
class ImageLoadError(Exception):
|
||||||
@@ -122,7 +125,7 @@ class Image:
|
|||||||
if self.path.suffix.lower() in [".tif", ".tiff"]:
|
if self.path.suffix.lower() in [".tif", ".tiff"]:
|
||||||
self._data = imread(str(self.path))
|
self._data = imread(str(self.path))
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("RGB is not implemented")
|
# raise NotImplementedError("RGB is not implemented")
|
||||||
# Load with OpenCV (returns BGR format)
|
# Load with OpenCV (returns BGR format)
|
||||||
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
|
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
|
||||||
|
|
||||||
@@ -246,27 +249,33 @@ class Image:
|
|||||||
if self.channels == 1:
|
if self.channels == 1:
|
||||||
img = get_pseudo_rgb(self.data)
|
img = get_pseudo_rgb(self.data)
|
||||||
self._dtype = img.dtype
|
self._dtype = img.dtype
|
||||||
return img
|
return img, True
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
if self._channels == 3:
|
elif self._channels == 3:
|
||||||
return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
|
return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB), False
|
||||||
elif self._channels == 4:
|
elif self._channels == 4:
|
||||||
return cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
|
return cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA), False
|
||||||
|
|
||||||
else:
|
else:
|
||||||
return self._data
|
raise NotImplementedError
|
||||||
|
|
||||||
|
# else:
|
||||||
|
# return self._data
|
||||||
|
|
||||||
def get_qt_rgb(self) -> np.ascontiguousarray:
|
def get_qt_rgb(self) -> np.ascontiguousarray:
|
||||||
# we keep data as (C, H, W)
|
# we keep data as (C, H, W)
|
||||||
_img = self.get_rgb()
|
_img, pseudo = self.get_rgb()
|
||||||
|
|
||||||
img = np.zeros((self.height, self.width, 4), dtype=np.float32)
|
if pseudo:
|
||||||
img[..., 0] = _img[0] # R gradient
|
img = np.zeros((self.height, self.width, 4), dtype=np.float32)
|
||||||
img[..., 1] = _img[1] # G gradient
|
img[..., 0] = _img[0] # R gradient
|
||||||
img[..., 2] = _img[2] # B constant
|
img[..., 1] = _img[1] # G gradient
|
||||||
img[..., 3] = 1.0 # A = 1.0 (opaque)
|
img[..., 2] = _img[2] # B constant
|
||||||
|
img[..., 3] = 1.0 # A = 1.0 (opaque)
|
||||||
|
|
||||||
return np.ascontiguousarray(img)
|
return np.ascontiguousarray(img)
|
||||||
|
else:
|
||||||
|
return np.ascontiguousarray(_img)
|
||||||
|
|
||||||
def get_grayscale(self) -> np.ndarray:
|
def get_grayscale(self) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -114,11 +114,12 @@ class Label:
|
|||||||
return truth_val
|
return truth_val
|
||||||
|
|
||||||
def to_string(self, bbox: list = None, polygon: list = None):
|
def to_string(self, bbox: list = None, polygon: list = None):
|
||||||
|
coords = ""
|
||||||
if bbox is None:
|
if bbox is None:
|
||||||
bbox = self.bbox
|
bbox = self.bbox
|
||||||
|
# coords += " ".join([f"{x:.6f}" for x in self.bbox])
|
||||||
if polygon is None:
|
if polygon is None:
|
||||||
polygon = self.polygon
|
polygon = self.polygon
|
||||||
coords = " ".join([f"{x:.6f}" for x in self.bbox])
|
|
||||||
if self.polygon is not None:
|
if self.polygon is not None:
|
||||||
coords += " " + " ".join([f"{x:.6f} {y:.6f}" for x, y in self.polygon])
|
coords += " " + " ".join([f"{x:.6f} {y:.6f}" for x, y in self.polygon])
|
||||||
return f"{self.class_id} {coords}"
|
return f"{self.class_id} {coords}"
|
||||||
@@ -179,6 +180,13 @@ class ImageSplitter:
|
|||||||
|
|
||||||
for i in range(patch_size[0]):
|
for i in range(patch_size[0]):
|
||||||
for j in range(patch_size[1]):
|
for j in range(patch_size[1]):
|
||||||
|
metadata = {
|
||||||
|
"image_path": str(self.image_path),
|
||||||
|
"label_path": str(self.label_path),
|
||||||
|
"tile_section": f"{i}, {j}",
|
||||||
|
"tile_size": f"{hstep}, {wstep}",
|
||||||
|
"patch_size": f"{patch_size[0]}, {patch_size[1]}",
|
||||||
|
}
|
||||||
tile_reference = f"i{i}j{j}"
|
tile_reference = f"i{i}j{j}"
|
||||||
hrange = (i * hstep / h, (i + 1) * hstep / h)
|
hrange = (i * hstep / h, (i + 1) * hstep / h)
|
||||||
wrange = (j * wstep / w, (j + 1) * wstep / w)
|
wrange = (j * wstep / w, (j + 1) * wstep / w)
|
||||||
@@ -199,7 +207,7 @@ class ImageSplitter:
|
|||||||
print(l.bbox)
|
print(l.bbox)
|
||||||
|
|
||||||
# print(labels)
|
# print(labels)
|
||||||
yield tile_reference, tile, labels
|
yield tile_reference, tile, labels, metadata
|
||||||
|
|
||||||
def split_respective_to_label(self, padding: int = 67):
|
def split_respective_to_label(self, padding: int = 67):
|
||||||
if self.labels is None:
|
if self.labels is None:
|
||||||
@@ -208,6 +216,7 @@ class ImageSplitter:
|
|||||||
for i, label in enumerate(self.labels):
|
for i, label in enumerate(self.labels):
|
||||||
tile_reference = f"_lbl-{i+1:02d}"
|
tile_reference = f"_lbl-{i+1:02d}"
|
||||||
# print(label.bbox)
|
# print(label.bbox)
|
||||||
|
metadata = {"image_path": str(self.image_path), "label_path": str(self.label_path), "label_index": str(i)}
|
||||||
|
|
||||||
xc_norm, yc_norm, h_norm, w_norm = label.bbox # normalized coords
|
xc_norm, yc_norm, h_norm, w_norm = label.bbox # normalized coords
|
||||||
xc, yc, h, w = [
|
xc, yc, h, w = [
|
||||||
@@ -246,17 +255,17 @@ class ImageSplitter:
|
|||||||
|
|
||||||
# print("tile shape:", tile.shape)
|
# print("tile shape:", tile.shape)
|
||||||
|
|
||||||
yolo_annotation = f"{label.class_id} {x_offset/nx} {y_offset/ny} {h_norm} {w_norm} "
|
yolo_annotation = f"{label.class_id} " # {x_offset/nx} {y_offset/ny} {h_norm} {w_norm} "
|
||||||
print(yolo_annotation)
|
|
||||||
yolo_annotation += " ".join(
|
yolo_annotation += " ".join(
|
||||||
[
|
[
|
||||||
f"{(x*self.image.shape[1]-(xc - x_offset))/nx:.6f} {(y*self.image.shape[0]-(yc-y_offset))/ny:.6f}"
|
f"{(x*self.image.shape[1]-(xc - x_offset))/nx:.6f} {(y*self.image.shape[0]-(yc-y_offset))/ny:.6f}"
|
||||||
for x, y in label.polygon
|
for x, y in label.polygon
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
print(yolo_annotation)
|
||||||
new_label = Label(yolo_annotation=yolo_annotation)
|
new_label = Label(yolo_annotation=yolo_annotation)
|
||||||
|
|
||||||
yield tile_reference, tile, [new_label]
|
yield tile_reference, tile, [new_label], metadata
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
@@ -278,9 +287,9 @@ def main(args):
|
|||||||
else:
|
else:
|
||||||
data = data.split_into_tiles(patch_size=args.patch_size)
|
data = data.split_into_tiles(patch_size=args.patch_size)
|
||||||
|
|
||||||
for tile_reference, tile, labels in data:
|
for tile_reference, tile, labels, metadata in data:
|
||||||
print()
|
print()
|
||||||
print(tile_reference, tile.shape, labels) # len(labels) if labels else None)
|
print(tile_reference, tile.shape, labels, metadata) # len(labels) if labels else None)
|
||||||
|
|
||||||
# { debug
|
# { debug
|
||||||
debug = False
|
debug = False
|
||||||
@@ -310,15 +319,21 @@ def main(args):
|
|||||||
# } debug
|
# } debug
|
||||||
|
|
||||||
if args.output:
|
if args.output:
|
||||||
imwrite(args.output / "images" / f"{image_path.stem}_{tile_reference}.tif", tile)
|
# imwrite(args.output / "images" / f"{image_path.stem}_{tile_reference}.tif", tile, metadata=metadata)
|
||||||
scale = 5
|
scale = 5
|
||||||
tile_zoomed = zoom(tile, zoom=scale)
|
tile_zoomed = zoom(tile, zoom=scale)
|
||||||
imwrite(args.output / "images-zoomed" / f"{image_path.stem}_{tile_reference}.tif", tile_zoomed)
|
metadata["scale"] = scale
|
||||||
|
imwrite(
|
||||||
|
args.output / "images" / f"{image_path.stem}_{tile_reference}.tif",
|
||||||
|
tile_zoomed,
|
||||||
|
metadata=metadata,
|
||||||
|
imagej=True,
|
||||||
|
)
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
with open(args.output / "labels" / f"{image_path.stem}_{tile_reference}.txt", "w") as f:
|
with open(args.output / "labels" / f"{image_path.stem}_{tile_reference}.txt", "w") as f:
|
||||||
for label in labels:
|
for label in labels:
|
||||||
label.offset_label(tile.shape[1], tile.shape[0])
|
# label.offset_label(tile.shape[1], tile.shape[0])
|
||||||
f.write(label.to_string() + "\n")
|
f.write(label.to_string() + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -72,8 +72,9 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
|
|||||||
# logger.info(f"Loading with monkey-patched imread: {filename}")
|
# logger.info(f"Loading with monkey-patched imread: {filename}")
|
||||||
arr = arr.astype(np.float32)
|
arr = arr.astype(np.float32)
|
||||||
arr /= arr.max()
|
arr /= arr.max()
|
||||||
arr *= 2**16 - 1
|
arr *= 2**8 - 1
|
||||||
arr = arr.astype(np.uint16)
|
arr = arr.astype(np.uint8)
|
||||||
|
# print(arr.shape, arr.dtype, any(np.isnan(arr).flatten()), np.where(np.isnan(arr)), arr.min(), arr.max())
|
||||||
return np.ascontiguousarray(arr)
|
return np.ascontiguousarray(arr)
|
||||||
|
|
||||||
# logger.info(f"Loading with original imread: {filename}")
|
# logger.info(f"Loading with original imread: {filename}")
|
||||||
@@ -105,7 +106,7 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
|
|||||||
def preprocess_batch_16bit(self, batch: dict) -> dict: # type: ignore[override]
|
def preprocess_batch_16bit(self, batch: dict) -> dict: # type: ignore[override]
|
||||||
# Start from upstream behavior to keep device placement + multiscale identical,
|
# Start from upstream behavior to keep device placement + multiscale identical,
|
||||||
# but replace the 255 division with dtype-aware scaling.
|
# but replace the 255 division with dtype-aware scaling.
|
||||||
logger.info(f"Preprocessing batch with monkey-patched preprocess_batch")
|
# logger.info(f"Preprocessing batch with monkey-patched preprocess_batch")
|
||||||
for k, v in batch.items():
|
for k, v in batch.items():
|
||||||
if isinstance(v, torch.Tensor):
|
if isinstance(v, torch.Tensor):
|
||||||
batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
|
batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
|
||||||
|
|||||||
@@ -196,7 +196,9 @@ def main():
|
|||||||
bbox = np.array(bbox) * np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
|
bbox = np.array(bbox) * np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
|
||||||
yc, xc, h, w = bbox
|
yc, xc, h, w = bbox
|
||||||
print("bbox", bbox)
|
print("bbox", bbox)
|
||||||
polyline = np.array(coords[4:]).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
|
|
||||||
|
# polyline = np.array(coords[4:]).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
|
||||||
|
polyline = np.array(coords).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
|
||||||
print("pl", coords[4:])
|
print("pl", coords[4:])
|
||||||
print("pl", polyline)
|
print("pl", polyline)
|
||||||
|
|
||||||
@@ -207,12 +209,13 @@ def main():
|
|||||||
plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
|
plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
|
||||||
plt.imshow(out_rgb)
|
plt.imshow(out_rgb)
|
||||||
plt.plot(polyline[:, 0], polyline[:, 1], "y", linewidth=2)
|
plt.plot(polyline[:, 0], polyline[:, 1], "y", linewidth=2)
|
||||||
plt.plot(
|
if 0:
|
||||||
[yc - h / 2, yc - h / 2, yc + h / 2, yc + h / 2, yc - h / 2],
|
plt.plot(
|
||||||
[xc - w / 2, xc + w / 2, xc + w / 2, xc - w / 2, xc - w / 2],
|
[yc - h / 2, yc - h / 2, yc + h / 2, yc + h / 2, yc - h / 2],
|
||||||
"r",
|
[xc - w / 2, xc + w / 2, xc + w / 2, xc - w / 2, xc - w / 2],
|
||||||
linewidth=2,
|
"r",
|
||||||
)
|
linewidth=2,
|
||||||
|
)
|
||||||
|
|
||||||
# plt.axis("off")
|
# plt.axis("off")
|
||||||
plt.title(f"{img_path.name} ({lbl_path.name})")
|
plt.title(f"{img_path.name} ({lbl_path.name})")
|
||||||
|
|||||||
Reference in New Issue
Block a user