Compare commits
42 Commits
segmentati...d998c65665
| SHA1 |
|---|
| d998c65665 |
| 510eabfa94 |
| 395d263900 |
| e98d287b8a |
| d25101de2d |
| f88beef188 |
| 2fd9a2acf4 |
| 2bcd18cc75 |
| 5d25378c46 |
| 2b0b48921e |
| b0c05f0225 |
| 97badaa390 |
| 8f8132ce61 |
| 6ae7481e25 |
| 061f8b3ca2 |
| a8e5db3135 |
| 268ed5175e |
| 5e9d3b1dc4 |
| 7d83e9b9b1 |
| e364d06217 |
| e5036c10cf |
| c7e388d9ae |
| 6b995e7325 |
| 0e0741d323 |
| dd99a0677c |
| 9c4c39fb39 |
| 20a87c9040 |
| 9f7d2be1ac |
| dbde07c0e8 |
| b3c5a51dbb |
| 9a221acb63 |
| 32a6a122bd |
| 9ba44043ef |
| 8eb1cc8c86 |
| e4ce882a18 |
| 6b6d6fad03 |
| c0684a9c14 |
| 221c80aa8c |
| 833b222fad |
| 5370d31dce |
| 5d196c3a4a |
| f719c7ec40 |
@@ -1,41 +0,0 @@
database:
  path: data/detections.db
image_repository:
  base_path: ''
  allowed_extensions:
  - .jpg
  - .jpeg
  - .png
  - .tif
  - .tiff
  - .bmp
models:
  default_base_model: yolov8s-seg.pt
  models_directory: data/models
training:
  default_epochs: 100
  default_batch_size: 16
  default_imgsz: 640
  default_patience: 50
  default_lr0: 0.01
detection:
  default_confidence: 0.25
  default_iou: 0.45
  max_batch_size: 100
visualization:
  bbox_colors:
    organelle: '#FF6B6B'
    membrane_branch: '#4ECDC4'
    default: '#00FF00'
  bbox_thickness: 2
  font_size: 12
export:
  formats:
  - csv
  - json
  - excel
  default_format: csv
logging:
  level: INFO
  file: logs/app.log
  format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
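The config.yaml above is deleted in this changeset; the same defaults now live in the ConfigManager hunk further down. A minimal sketch of reading such a file with a dotted-key lookup (the two helpers below are illustrative assumptions, not the project's ConfigManager):

```python
import yaml
from pathlib import Path

def load_config(path: str = "config.yaml") -> dict:
    """Read the YAML config, returning {} so built-in defaults apply if absent."""
    config_file = Path(path)
    if not config_file.exists():
        return {}
    with open(config_file, "r", encoding="utf-8") as handle:
        return yaml.safe_load(handle) or {}

def get(config: dict, dotted_key: str, default=None):
    """Resolve 'models.default_base_model'-style keys one level at a time."""
    node = config
    for part in dotted_key.split("."):
        if not isinstance(node, dict) or part not in node:
            return default
        node = node[part]
    return node

config = load_config()
base_model = get(config, "models.default_base_model", "yolov8s-seg.pt")
```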
@@ -82,12 +82,12 @@ include-package-data = true
"src.database" = ["*.sql"]

[tool.black]
line-length = 88
line-length = 120
target-version = ['py38', 'py39', 'py310', 'py311']
include = '\.pyi?$'

[tool.pylint.messages_control]
max-line-length = 88
max-line-length = 120

[tool.mypy]
python_version = "3.8"
@@ -10,6 +10,14 @@ from typing import List, Dict, Optional, Tuple, Any, Union
from pathlib import Path
import csv
import hashlib
import yaml

from src.utils.logger import get_logger
from src.utils.image import Image

IMAGE_EXTENSIONS = tuple(Image.SUPPORTED_EXTENSIONS)

logger = get_logger(__name__)


class DatabaseManager:
@@ -443,6 +451,25 @@
            filters["model_id"] = model_id
        return self.get_detections(filters)

    def delete_detections_for_image(
        self, image_id: int, model_id: Optional[int] = None
    ) -> int:
        """Delete detections tied to a specific image and optional model."""
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            if model_id is not None:
                cursor.execute(
                    "DELETE FROM detections WHERE image_id = ? AND model_id = ?",
                    (image_id, model_id),
                )
            else:
                cursor.execute("DELETE FROM detections WHERE image_id = ?", (image_id,))
            conn.commit()
            return cursor.rowcount
        finally:
            conn.close()

    def delete_detections_for_model(self, model_id: int) -> int:
        """Delete all detections for a specific model."""
        conn = self.get_connection()
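A usage sketch for the new deletion helpers (the method signatures come from the diff above; the DatabaseManager constructor argument is an assumption):

```python
# Clear one image's detections for model 3 only, then for every remaining model.
db = DatabaseManager("data/detections.db")  # constructor signature assumed
removed_one = db.delete_detections_for_image(image_id=42, model_id=3)
removed_rest = db.delete_detections_for_image(image_id=42)
print(f"Removed {removed_one} rows for model 3 and {removed_rest} for the rest")
```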
@@ -861,6 +888,187 @@ class DatabaseManager:
        finally:
            conn.close()

    # ==================== Dataset Utilities ====================

    def compose_data_yaml(
        self,
        dataset_root: str,
        output_path: Optional[str] = None,
        splits: Optional[Dict[str, str]] = None,
    ) -> str:
        """
        Compose a YOLO data.yaml file based on dataset folders and database metadata.

        Args:
            dataset_root: Base directory containing the dataset structure.
            output_path: Optional output path; defaults to <dataset_root>/data.yaml.
            splits: Optional mapping overriding train/val/test image directories (relative
                to dataset_root or absolute paths).

        Returns:
            Path to the generated YAML file.
        """
        dataset_root_path = Path(dataset_root).expanduser()
        if not dataset_root_path.exists():
            raise ValueError(f"Dataset root does not exist: {dataset_root_path}")
        dataset_root_path = dataset_root_path.resolve()

        split_map: Dict[str, str] = {key: "" for key in ("train", "val", "test")}
        if splits:
            for key, value in splits.items():
                if key in split_map and value:
                    split_map[key] = value

        inferred = self._infer_split_dirs(dataset_root_path)
        for key in split_map:
            if not split_map[key]:
                split_map[key] = inferred.get(key, "")

        for required in ("train", "val"):
            if not split_map[required]:
                raise ValueError(
                    "Unable to determine %s image directory under %s. Provide it "
                    "explicitly via the 'splits' argument."
                    % (required, dataset_root_path)
                )

        yaml_splits: Dict[str, str] = {}
        for key, value in split_map.items():
            if not value:
                continue
            yaml_splits[key] = self._normalize_split_value(value, dataset_root_path)

        class_names = self._fetch_annotation_class_names()
        if not class_names:
            class_names = [cls["class_name"] for cls in self.get_object_classes()]
        if not class_names:
            raise ValueError("No object classes available to populate data.yaml")

        names_map = {idx: name for idx, name in enumerate(class_names)}
        payload: Dict[str, Any] = {
            "path": dataset_root_path.as_posix(),
            "train": yaml_splits["train"],
            "val": yaml_splits["val"],
            "names": names_map,
            "nc": len(class_names),
        }
        if yaml_splits.get("test"):
            payload["test"] = yaml_splits["test"]

        output_path_obj = (
            Path(output_path).expanduser()
            if output_path
            else dataset_root_path / "data.yaml"
        )
        output_path_obj.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path_obj, "w", encoding="utf-8") as handle:
            yaml.safe_dump(payload, handle, sort_keys=False)

        logger.info(f"Generated data.yaml at {output_path_obj}")
        return output_path_obj.as_posix()

    def _fetch_annotation_class_names(self) -> List[str]:
        """Return class names referenced by annotations (ordered by class ID)."""
        conn = self.get_connection()
        try:
            cursor = conn.cursor()
            cursor.execute(
                """
                SELECT DISTINCT c.id, c.class_name
                FROM annotations a
                JOIN object_classes c ON a.class_id = c.id
                ORDER BY c.id
                """
            )
            rows = cursor.fetchall()
            return [row["class_name"] for row in rows]
        finally:
            conn.close()

    def _infer_split_dirs(self, dataset_root: Path) -> Dict[str, str]:
        """Infer train/val/test image directories relative to dataset_root."""
        patterns = {
            "train": [
                "train/images",
                "training/images",
                "images/train",
                "images/training",
                "train",
                "training",
            ],
            "val": [
                "val/images",
                "validation/images",
                "images/val",
                "images/validation",
                "val",
                "validation",
            ],
            "test": [
                "test/images",
                "testing/images",
                "images/test",
                "images/testing",
                "test",
                "testing",
            ],
        }

        inferred: Dict[str, str] = {key: "" for key in patterns}
        for split_name, options in patterns.items():
            for relative in options:
                candidate = (dataset_root / relative).resolve()
                if (
                    candidate.exists()
                    and candidate.is_dir()
                    and self._directory_has_images(candidate)
                ):
                    try:
                        inferred[split_name] = candidate.relative_to(
                            dataset_root
                        ).as_posix()
                    except ValueError:
                        inferred[split_name] = candidate.as_posix()
                    break
        return inferred

    def _normalize_split_value(self, split_value: str, dataset_root: Path) -> str:
        """Validate and normalize a split directory to a YAML-friendly string."""
        split_path = Path(split_value).expanduser()
        if not split_path.is_absolute():
            split_path = (dataset_root / split_path).resolve()
        else:
            split_path = split_path.resolve()

        if not split_path.exists() or not split_path.is_dir():
            raise ValueError(f"Split directory not found: {split_path}")

        if not self._directory_has_images(split_path):
            raise ValueError(f"No images found under {split_path}")

        try:
            return split_path.relative_to(dataset_root).as_posix()
        except ValueError:
            return split_path.as_posix()

    @staticmethod
    def _directory_has_images(directory: Path, max_checks: int = 2000) -> bool:
        """Return True if directory tree contains at least one image file."""
        checked = 0
        try:
            for file_path in directory.rglob("*"):
                if not file_path.is_file():
                    continue
                if file_path.suffix.lower() in IMAGE_EXTENSIONS:
                    return True
                checked += 1
                if checked >= max_checks:
                    break
        except Exception:
            return False
        return False

    @staticmethod
    def calculate_checksum(file_path: str) -> str:
        """Calculate MD5 checksum of a file."""
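For a dataset rooted at, say, /data/datasets/mito (a hypothetical path) with train/images and val/images subfolders and the two annotation classes seen elsewhere in this changeset, compose_data_yaml() would emit roughly this file (yaml.safe_dump with sort_keys=False preserves the payload order; "test" is omitted when no test split is found):

```yaml
path: /data/datasets/mito
train: train/images
val: val/images
names:
  0: organelle
  1: membrane_branch
nc: 2
```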
@@ -297,7 +297,9 @@ class MainWindow(QMainWindow):
        # Save window state before closing
        self._save_window_state()

        # Save annotation tab state if it exists
        # Persist tab state and stop background work before exit
        if hasattr(self, "training_tab"):
            self.training_tab.shutdown()
        if hasattr(self, "annotation_tab"):
            self.annotation_tab.save_state()
@@ -168,7 +168,7 @@ class AnnotationTab(QWidget):
            self,
            "Select Image",
            start_dir,
            "Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
            "Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
        )

        if not file_path:
@@ -20,12 +20,14 @@ from PySide6.QtWidgets import (
)
from PySide6.QtCore import Qt, QThread, Signal
from pathlib import Path
from typing import Optional

from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
from src.utils.file_utils import get_image_files
from src.model.inference import InferenceEngine
from src.utils.image import Image


logger = get_logger(__name__)
@@ -147,30 +149,66 @@ class DetectionTab(QWidget):
        self.model_combo.currentIndexChanged.connect(self._on_model_changed)

    def _load_models(self):
        """Load available models from database."""
        """Load available models from database and local storage."""
        try:
            models = self.db_manager.get_models()
            self.model_combo.clear()
            models = self.db_manager.get_models()
            has_models = False

            if not models:
                self.model_combo.addItem("No models available", None)
                self._set_buttons_enabled(False)
                return
            known_paths = set()

            # Add base model option
            # Add base model option first (always available)
            base_model = self.config_manager.get(
                "models.default_base_model", "yolov8s-seg.pt"
            )
            self.model_combo.addItem(
                f"Base Model ({base_model})", {"id": 0, "path": base_model}
            )
            if base_model:
                base_data = {
                    "id": 0,
                    "path": base_model,
                    "model_name": Path(base_model).stem or "Base Model",
                    "model_version": "pretrained",
                    "base_model": base_model,
                    "source": "base",
                }
                self.model_combo.addItem(f"Base Model ({base_model})", base_data)
                known_paths.add(self._normalize_model_path(base_model))
                has_models = True

            # Add trained models
            # Add trained models from database
            for model in models:
                display_name = f"{model['model_name']} v{model['model_version']}"
                self.model_combo.addItem(display_name, model)
                model_data = {**model, "path": model.get("model_path")}
                normalized = self._normalize_model_path(model_data.get("path"))
                if normalized:
                    known_paths.add(normalized)
                self.model_combo.addItem(display_name, model_data)
                has_models = True

            self._set_buttons_enabled(True)
            # Discover local model files not yet in the database
            local_models = self._discover_local_models()
            for model_path in local_models:
                normalized = self._normalize_model_path(model_path)
                if normalized in known_paths:
                    continue

                display_name = f"Local Model ({Path(model_path).stem})"
                model_data = {
                    "id": None,
                    "path": str(model_path),
                    "model_name": Path(model_path).stem,
                    "model_version": "local",
                    "base_model": Path(model_path).stem,
                    "source": "local",
                }
                self.model_combo.addItem(display_name, model_data)
                known_paths.add(normalized)
                has_models = True

            if not has_models:
                self.model_combo.addItem("No models available", None)
                self._set_buttons_enabled(False)
            else:
                self._set_buttons_enabled(True)

        except Exception as e:
            logger.error(f"Error loading models: {e}")
@@ -199,7 +237,7 @@
            self,
            "Select Image",
            start_dir,
            "Images (*.jpg *.jpeg *.png *.tif *.tiff *.bmp)",
            "Images (*" + " *".join(Image.SUPPORTED_EXTENSIONS) + ")",
        )

        if not file_path:
@@ -249,25 +287,39 @@
            QMessageBox.warning(self, "No Model", "Please select a model first.")
            return

        model_path = model_data["path"]
        model_id = model_data["id"]
        model_path = model_data.get("path")
        if not model_path:
            QMessageBox.warning(
                self, "Invalid Model", "Selected model is missing a file path."
            )
            return

        # Ensure we have a valid model ID (create entry for base model if needed)
        if model_id == 0:
            # Create database entry for base model
            base_model = self.config_manager.get(
                "models.default_base_model", "yolov8s-seg.pt"
            )
            model_id = self.db_manager.add_model(
                model_name="Base Model",
                model_version="pretrained",
                model_path=base_model,
                base_model=base_model,
        if not Path(model_path).exists():
            QMessageBox.critical(
                self,
                "Model Not Found",
                f"The selected model file could not be found:\n{model_path}",
            )
            return

        model_id = model_data.get("id")

        # Ensure we have a database entry for the selected model
        if model_id in (None, 0):
            model_id = self._ensure_model_record(model_data)
            if not model_id:
                QMessageBox.critical(
                    self,
                    "Model Registration Failed",
                    "Unable to register the selected model in the database.",
                )
                return

        normalized_model_path = self._normalize_model_path(model_path) or model_path

        # Create inference engine
        self.inference_engine = InferenceEngine(
            model_path, self.db_manager, model_id
            normalized_model_path, self.db_manager, model_id
        )

        # Get confidence threshold
@@ -338,6 +390,76 @@
        self.batch_btn.setEnabled(enabled)
        self.model_combo.setEnabled(enabled)

    def _discover_local_models(self) -> list:
        """Scan the models directory for standalone .pt files."""
        models_dir = self.config_manager.get_models_directory()
        if not models_dir:
            return []

        models_path = Path(models_dir)
        if not models_path.exists():
            return []

        try:
            return sorted(
                [p for p in models_path.rglob("*.pt") if p.is_file()],
                key=lambda p: str(p).lower(),
            )
        except Exception as e:
            logger.warning(f"Error discovering local models: {e}")
            return []

    def _normalize_model_path(self, path_value) -> str:
        """Return a normalized absolute path string for comparison."""
        if not path_value:
            return ""
        try:
            return str(Path(path_value).resolve())
        except Exception:
            return str(path_value)

    def _ensure_model_record(self, model_data: dict) -> Optional[int]:
        """Ensure a database record exists for the selected model."""
        model_path = model_data.get("path")
        if not model_path:
            return None

        normalized_target = self._normalize_model_path(model_path)

        try:
            existing_models = self.db_manager.get_models()
            for model in existing_models:
                existing_path = model.get("model_path")
                if not existing_path:
                    continue
                normalized_existing = self._normalize_model_path(existing_path)
                if (
                    normalized_existing == normalized_target
                    or existing_path == model_path
                ):
                    return model["id"]

            model_name = (
                model_data.get("model_name") or Path(model_path).stem or "Custom Model"
            )
            model_version = (
                model_data.get("model_version") or model_data.get("source") or "local"
            )
            base_model = model_data.get(
                "base_model",
                self.config_manager.get("models.default_base_model", "yolov8s-seg.pt"),
            )

            return self.db_manager.add_model(
                model_name=model_name,
                model_version=model_version,
                model_path=normalized_target,
                base_model=base_model,
            )
        except Exception as e:
            logger.error(f"Failed to ensure model record for {model_path}: {e}")
            return None

    def refresh(self):
        """Refresh the tab."""
        self._load_models()
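The deduplication above works because _normalize_model_path() collapses every spelling of the same file to one resolved key. A tiny illustration of the idea (standalone, not project code):

```python
from pathlib import Path

# A relative path and its absolute equivalent normalize to the same string,
# so the second occurrence is skipped as a duplicate combo-box entry.
a = str(Path("data/models/best.pt").resolve())
b = str((Path.cwd() / "data" / "models" / "best.pt").resolve())
assert a == b  # same normalized key -> treated as one model
```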
@@ -1,15 +1,39 @@
"""
Results tab for the microscopy object detection application.
Results tab for browsing stored detections and visualizing overlays.
"""

from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QGroupBox
from pathlib import Path
from typing import Dict, List, Optional

from PySide6.QtWidgets import (
    QWidget,
    QVBoxLayout,
    QHBoxLayout,
    QLabel,
    QGroupBox,
    QPushButton,
    QSplitter,
    QTableWidget,
    QTableWidgetItem,
    QHeaderView,
    QAbstractItemView,
    QMessageBox,
    QCheckBox,
)
from PySide6.QtCore import Qt

from src.database.db_manager import DatabaseManager
from src.utils.config_manager import ConfigManager
from src.utils.logger import get_logger
from src.utils.image import Image, ImageLoadError
from src.gui.widgets import AnnotationCanvasWidget


logger = get_logger(__name__)


class ResultsTab(QWidget):
    """Results tab placeholder."""
    """Results tab showing detection history and preview overlays."""

    def __init__(
        self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
@@ -18,29 +42,398 @@ class ResultsTab(QWidget):
        self.db_manager = db_manager
        self.config_manager = config_manager

        self.detection_summary: List[Dict] = []
        self.current_selection: Optional[Dict] = None
        self.current_image: Optional[Image] = None
        self.current_detections: List[Dict] = []
        self._image_path_cache: Dict[str, str] = {}

        self._setup_ui()
        self.refresh()

    def _setup_ui(self):
        """Setup user interface."""
        layout = QVBoxLayout()

        group = QGroupBox("Results")
        group_layout = QVBoxLayout()
        label = QLabel(
            "Results viewer will be implemented here.\n\n"
            "Features:\n"
            "- Detection history browser\n"
            "- Advanced filtering\n"
            "- Statistics dashboard\n"
            "- Export functionality"
        )
        group_layout.addWidget(label)
        group.setLayout(group_layout)
        # Splitter for list + preview
        splitter = QSplitter(Qt.Horizontal)

        layout.addWidget(group)
        layout.addStretch()
        # Left pane: detection list
        left_container = QWidget()
        left_layout = QVBoxLayout()
        left_layout.setContentsMargins(0, 0, 0, 0)

        controls_layout = QHBoxLayout()
        self.refresh_btn = QPushButton("Refresh")
        self.refresh_btn.clicked.connect(self.refresh)
        controls_layout.addWidget(self.refresh_btn)
        controls_layout.addStretch()
        left_layout.addLayout(controls_layout)

        self.results_table = QTableWidget(0, 5)
        self.results_table.setHorizontalHeaderLabels(
            ["Image", "Model", "Detections", "Classes", "Last Updated"]
        )
        self.results_table.horizontalHeader().setSectionResizeMode(
            0, QHeaderView.Stretch
        )
        self.results_table.horizontalHeader().setSectionResizeMode(
            1, QHeaderView.Stretch
        )
        self.results_table.horizontalHeader().setSectionResizeMode(
            2, QHeaderView.ResizeToContents
        )
        self.results_table.horizontalHeader().setSectionResizeMode(
            3, QHeaderView.Stretch
        )
        self.results_table.horizontalHeader().setSectionResizeMode(
            4, QHeaderView.ResizeToContents
        )
        self.results_table.setSelectionBehavior(QAbstractItemView.SelectRows)
        self.results_table.setSelectionMode(QAbstractItemView.SingleSelection)
        self.results_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
        self.results_table.itemSelectionChanged.connect(self._on_result_selected)

        left_layout.addWidget(self.results_table)
        left_container.setLayout(left_layout)

        # Right pane: preview canvas and controls
        right_container = QWidget()
        right_layout = QVBoxLayout()
        right_layout.setContentsMargins(0, 0, 0, 0)

        preview_group = QGroupBox("Detection Preview")
        preview_layout = QVBoxLayout()

        self.preview_canvas = AnnotationCanvasWidget()
        self.preview_canvas.set_polyline_enabled(False)
        self.preview_canvas.set_show_bboxes(True)
        preview_layout.addWidget(self.preview_canvas)

        toggles_layout = QHBoxLayout()
        self.show_masks_checkbox = QCheckBox("Show Masks")
        self.show_masks_checkbox.setChecked(False)
        self.show_masks_checkbox.stateChanged.connect(self._apply_detection_overlays)
        self.show_bboxes_checkbox = QCheckBox("Show Bounding Boxes")
        self.show_bboxes_checkbox.setChecked(True)
        self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
        self.show_confidence_checkbox = QCheckBox("Show Confidence")
        self.show_confidence_checkbox.setChecked(False)
        self.show_confidence_checkbox.stateChanged.connect(
            self._apply_detection_overlays
        )
        toggles_layout.addWidget(self.show_masks_checkbox)
        toggles_layout.addWidget(self.show_bboxes_checkbox)
        toggles_layout.addWidget(self.show_confidence_checkbox)
        toggles_layout.addStretch()
        preview_layout.addLayout(toggles_layout)

        self.summary_label = QLabel("Select a detection result to preview.")
        self.summary_label.setWordWrap(True)
        preview_layout.addWidget(self.summary_label)

        preview_group.setLayout(preview_layout)
        right_layout.addWidget(preview_group)
        right_container.setLayout(right_layout)

        splitter.addWidget(left_container)
        splitter.addWidget(right_container)
        splitter.setStretchFactor(0, 1)
        splitter.setStretchFactor(1, 2)

        layout.addWidget(splitter)
        self.setLayout(layout)

    def refresh(self):
        """Refresh the tab."""
        pass
        """Refresh the detection list and preview."""
        self._load_detection_summary()
        self._populate_results_table()
        self.current_selection = None
        self.current_image = None
        self.current_detections = []
        self.preview_canvas.clear()
        self.summary_label.setText("Select a detection result to preview.")

    def _load_detection_summary(self):
        """Load latest detection summaries grouped by image + model."""
        try:
            detections = self.db_manager.get_detections(limit=500)
            summary_map: Dict[tuple, Dict] = {}

            for det in detections:
                key = (det["image_id"], det["model_id"])
                metadata = det.get("metadata") or {}
                entry = summary_map.setdefault(
                    key,
                    {
                        "image_id": det["image_id"],
                        "model_id": det["model_id"],
                        "image_path": det.get("image_path"),
                        "image_filename": det.get("image_filename")
                        or det.get("image_path"),
                        "model_name": det.get("model_name", ""),
                        "model_version": det.get("model_version", ""),
                        "last_detected": det.get("detected_at"),
                        "count": 0,
                        "classes": set(),
                        "source_path": metadata.get("source_path"),
                        "repository_root": metadata.get("repository_root"),
                    },
                )

                entry["count"] += 1
                if det.get("detected_at") and (
                    not entry.get("last_detected")
                    or str(det.get("detected_at")) > str(entry.get("last_detected"))
                ):
                    entry["last_detected"] = det.get("detected_at")
                if det.get("class_name"):
                    entry["classes"].add(det["class_name"])
                if metadata.get("source_path") and not entry.get("source_path"):
                    entry["source_path"] = metadata.get("source_path")
                if metadata.get("repository_root") and not entry.get("repository_root"):
                    entry["repository_root"] = metadata.get("repository_root")

            self.detection_summary = sorted(
                summary_map.values(),
                key=lambda x: str(x.get("last_detected") or ""),
                reverse=True,
            )
        except Exception as e:
            logger.error(f"Failed to load detection summary: {e}")
            QMessageBox.critical(
                self,
                "Error",
                f"Failed to load detection results:\n{str(e)}",
            )
            self.detection_summary = []

    def _populate_results_table(self):
        """Populate the table widget with detection summaries."""
        self.results_table.setRowCount(len(self.detection_summary))

        for row, entry in enumerate(self.detection_summary):
            model_label = f"{entry['model_name']} {entry['model_version']}".strip()
            class_list = (
                ", ".join(sorted(entry["classes"])) if entry["classes"] else "-"
            )

            items = [
                QTableWidgetItem(entry.get("image_filename", "")),
                QTableWidgetItem(model_label),
                QTableWidgetItem(str(entry.get("count", 0))),
                QTableWidgetItem(class_list),
                QTableWidgetItem(str(entry.get("last_detected") or "")),
            ]

            for col, item in enumerate(items):
                item.setData(Qt.UserRole, row)
                self.results_table.setItem(row, col, item)

        self.results_table.clearSelection()

    def _on_result_selected(self):
        """Handle selection changes in the detection table."""
        selected_items = self.results_table.selectedItems()
        if not selected_items:
            return

        row = selected_items[0].data(Qt.UserRole)
        if row is None or row >= len(self.detection_summary):
            return

        entry = self.detection_summary[row]
        if (
            self.current_selection
            and self.current_selection.get("image_id") == entry["image_id"]
            and self.current_selection.get("model_id") == entry["model_id"]
        ):
            return

        self.current_selection = entry

        image_path = self._resolve_image_path(entry)
        if not image_path:
            QMessageBox.warning(
                self,
                "Image Not Found",
                "Unable to locate the image file for this detection.",
            )
            return

        try:
            self.current_image = Image(image_path)
            self.preview_canvas.load_image(self.current_image)
        except ImageLoadError as e:
            logger.error(f"Failed to load image '{image_path}': {e}")
            QMessageBox.critical(
                self,
                "Image Error",
                f"Failed to load image for preview:\n{str(e)}",
            )
            return

        self._load_detections_for_selection(entry)
        self._apply_detection_overlays()
        self._update_summary_label(entry)

    def _load_detections_for_selection(self, entry: Dict):
        """Load detection records for the selected image/model pair."""
        self.current_detections = []
        if not entry:
            return

        try:
            filters = {"image_id": entry["image_id"], "model_id": entry["model_id"]}
            self.current_detections = self.db_manager.get_detections(filters)
        except Exception as e:
            logger.error(f"Failed to load detections for preview: {e}")
            QMessageBox.critical(
                self,
                "Error",
                f"Failed to load detections for this image:\n{str(e)}",
            )
            self.current_detections = []

    def _apply_detection_overlays(self):
        """Draw detections onto the preview canvas based on current toggles."""
        self.preview_canvas.clear_annotations()
        self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())

        if not self.current_detections or not self.current_image:
            return

        for det in self.current_detections:
            color = self._get_class_color(det.get("class_name"))

            if self.show_masks_checkbox.isChecked() and det.get("segmentation_mask"):
                mask_points = self._convert_mask(det["segmentation_mask"])
                if mask_points:
                    self.preview_canvas.draw_saved_polyline(mask_points, color)

            bbox = [
                det.get("x_min"),
                det.get("y_min"),
                det.get("x_max"),
                det.get("y_max"),
            ]
            if all(v is not None for v in bbox):
                label = None
                if self.show_confidence_checkbox.isChecked():
                    confidence = det.get("confidence")
                    if confidence is not None:
                        label = f"{confidence:.2f}"
                self.preview_canvas.draw_saved_bbox(bbox, color, label=label)

    def _convert_mask(self, mask_points: List[List[float]]) -> List[List[float]]:
        """Convert stored [x, y] masks to [y, x] format for the canvas."""
        converted = []
        for point in mask_points:
            if len(point) >= 2:
                x, y = point[0], point[1]
                converted.append([y, x])
        return converted

    def _toggle_bboxes(self):
        """Update bounding box visibility on the canvas."""
        self.preview_canvas.set_show_bboxes(self.show_bboxes_checkbox.isChecked())
        # Re-render to respect show/hide when toggled
        self._apply_detection_overlays()

    def _update_summary_label(self, entry: Dict):
        """Display textual summary for the selected detection run."""
        classes = ", ".join(sorted(entry.get("classes", []))) or "-"
        summary_text = (
            f"Image: {entry.get('image_filename', 'unknown')}\n"
            f"Model: {entry.get('model_name', '')} {entry.get('model_version', '')}\n"
            f"Detections: {entry.get('count', 0)}\n"
            f"Classes: {classes}\n"
            f"Last Updated: {entry.get('last_detected', 'n/a')}"
        )
        self.summary_label.setText(summary_text)

    def _resolve_image_path(self, entry: Dict) -> Optional[str]:
        """Resolve an image path using metadata, cache, and repository hints."""
        relative_path = entry.get("image_path") if entry else None
        cache_key = relative_path or entry.get("source_path")
        if cache_key and cache_key in self._image_path_cache:
            cached = Path(self._image_path_cache[cache_key])
            if cached.exists():
                return self._image_path_cache[cache_key]
            del self._image_path_cache[cache_key]

        candidates = []
        source_path = entry.get("source_path") if entry else None
        if source_path:
            candidates.append(Path(source_path))

        repo_roots = []
        if entry.get("repository_root"):
            repo_roots.append(entry["repository_root"])
        config_repo = self.config_manager.get_image_repository_path()
        if config_repo:
            repo_roots.append(config_repo)

        for root in repo_roots:
            if relative_path:
                candidates.append(Path(root) / relative_path)

        if relative_path:
            candidates.append(Path(relative_path))

        for candidate in candidates:
            try:
                if candidate and candidate.exists():
                    resolved = str(candidate.resolve())
                    if cache_key:
                        self._image_path_cache[cache_key] = resolved
                    return resolved
            except Exception:
                continue

        # Fallback: search by filename in known roots
        filename = Path(relative_path).name if relative_path else None
        if filename:
            search_roots = [Path(root) for root in repo_roots if root]
            if not search_roots:
                search_roots = [Path("data")]
            match = self._search_in_roots(filename, search_roots)
            if match:
                resolved = str(match.resolve())
                if cache_key:
                    self._image_path_cache[cache_key] = resolved
                return resolved

        return None

    def _search_in_roots(self, filename: str, roots: List[Path]) -> Optional[Path]:
        """Search for a file name within a list of root directories."""
        for root in roots:
            try:
                if not root.exists():
                    continue
                for candidate in root.rglob(filename):
                    return candidate
            except Exception as e:
                logger.debug(f"Error searching for {filename} in {root}: {e}")
        return None

    def _get_class_color(self, class_name: Optional[str]) -> str:
        """Return consistent color hex for a class name."""
        if not class_name:
            return "#FF6B6B"

        color_map = self.config_manager.get_bbox_colors()
        if class_name in color_map:
            return color_map[class_name]

        # Deterministic fallback color based on hash
        palette = [
            "#FF6B6B",
            "#4ECDC4",
            "#FFD166",
            "#1D3557",
            "#F4A261",
            "#E76F51",
        ]
        return palette[hash(class_name) % len(palette)]
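One caveat on the fallback above: Python salts the built-in hash() for strings per process, so the "deterministic" palette color can change between application runs. A run-stable variant (a sketch, not part of this diff) hashes the bytes explicitly:

```python
import hashlib

PALETTE = ["#FF6B6B", "#4ECDC4", "#FFD166", "#1D3557", "#F4A261", "#E76F51"]

def stable_class_color(class_name: str) -> str:
    # md5 of the UTF-8 bytes is identical across runs, unlike built-in hash()
    digest = hashlib.md5(class_name.encode("utf-8")).digest()
    return PALETTE[digest[0] % len(PALETTE)]
```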
File diff suppressed because it is too large
@@ -16,8 +16,9 @@ from PySide6.QtGui import (
    QKeyEvent,
    QMouseEvent,
    QPaintEvent,
    QPolygonF,
)
from PySide6.QtCore import Qt, QEvent, Signal, QPoint
from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QPointF, QRect
from typing import Any, Dict, List, Optional, Tuple

from src.utils.image import Image, ImageLoadError
@@ -246,15 +247,13 @@
            return

        try:
            # Get RGB image data
            if self.current_image.channels == 3:
            # Get image data in a format compatible with Qt
            if self.current_image.channels in (3, 4):
                image_data = self.current_image.get_rgb()
                height, width, channels = image_data.shape
            else:
                image_data = self.current_image.get_grayscale()
                height, width = image_data.shape
            image_data = self.current_image.get_qt_rgb()

            image_data = np.ascontiguousarray(image_data)
            height, width = image_data.shape[:2]
            bytes_per_line = image_data.strides[0]

            qimage = QImage(
@@ -262,8 +261,8 @@
                width,
                height,
                bytes_per_line,
                self.current_image.qtimage_format,
            )
                QImage.Format_RGBX32FPx4,  # self.current_image.qtimage_format,
            ).copy()  # Copy so Qt owns the buffer even after numpy array goes out of scope

            self.original_pixmap = QPixmap.fromImage(qimage)
@@ -496,8 +495,10 @@
            )

            painter.setPen(pen)
            for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]):
                painter.drawLine(int(x1), int(y1), int(x2), int(y2))
            # Use QPolygonF for efficient polygon rendering (single call vs N-1 calls)
            # drawPolygon() automatically closes the shape, ensuring proper visual closure
            polygon = QPolygonF([QPointF(x, y) for x, y in polyline])
            painter.drawPolygon(polygon)

        # Draw bounding boxes (dashed) if enabled
        if self.show_bboxes and self.original_pixmap is not None and self.bboxes:
@@ -529,6 +530,40 @@
            painter.setPen(pen)
            painter.drawRect(x_min, y_min, rect_width, rect_height)

            label_text = meta.get("label")
            if label_text:
                painter.save()
                font = painter.font()
                font.setPointSizeF(max(10.0, width + 4))
                painter.setFont(font)
                metrics = painter.fontMetrics()
                text_width = metrics.horizontalAdvance(label_text)
                text_height = metrics.height()
                padding = 4
                bg_width = text_width + padding * 2
                bg_height = text_height + padding * 2
                canvas_width = self.original_pixmap.width()
                canvas_height = self.original_pixmap.height()
                bg_x = max(0, min(x_min, canvas_width - bg_width))
                bg_y = y_min - bg_height
                if bg_y < 0:
                    bg_y = min(y_min, canvas_height - bg_height)
                bg_y = max(0, bg_y)
                background_rect = QRect(bg_x, bg_y, bg_width, bg_height)
                background_color = QColor(pen_color)
                background_color.setAlpha(220)
                painter.fillRect(background_rect, background_color)
                text_color = QColor(0, 0, 0)
                if background_color.lightness() < 128:
                    text_color = QColor(255, 255, 255)
                painter.setPen(text_color)
                painter.drawText(
                    background_rect.adjusted(padding, padding, -padding, -padding),
                    Qt.AlignLeft | Qt.AlignVCenter,
                    label_text,
                )
                painter.restore()

        painter.end()

        self._update_display()
@@ -787,7 +822,13 @@
            f"Drew saved polyline with {len(polyline)} points in color {color}"
        )

    def draw_saved_bbox(self, bbox: List[float], color: str, width: int = 3):
    def draw_saved_bbox(
        self,
        bbox: List[float],
        color: str,
        width: int = 3,
        label: Optional[str] = None,
    ):
        """
        Draw a bounding box from database coordinates onto the annotation canvas.
@@ -796,6 +837,7 @@
                in normalized coordinates (0-1)
            color: Color hex string (e.g., '#FF0000')
            width: Line width in pixels
            label: Optional text label to render near the bounding box
        """
        if not self.annotation_pixmap or not self.original_pixmap:
            logger.warning("Cannot draw bounding box: no image loaded")
@@ -828,11 +870,11 @@
        self.bboxes.append(
            [float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)]
        )
        self.bbox_meta.append({"color": pen_color, "width": int(width)})
        self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label})

        # Store in all_strokes for consistency
        self.all_strokes.append(
            {"bbox": bbox, "color": color, "alpha": 128, "width": width}
            {"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label}
        )

        # Redraw overlay (polylines + all bounding boxes)
@@ -137,7 +137,7 @@ class ImageDisplayWidget(QWidget):
                height,
                bytes_per_line,
                self.current_image.qtimage_format,
            )
            ).copy()  # Copy to ensure Qt owns its memory after this scope

            # Convert to pixmap
            pixmap = QPixmap.fromImage(qimage)
@@ -5,12 +5,12 @@ Handles detection inference and result storage.

from typing import List, Dict, Optional, Callable
from pathlib import Path
from PIL import Image
import cv2
import numpy as np

from src.model.yolo_wrapper import YOLOWrapper
from src.database.db_manager import DatabaseManager
from src.utils.image import Image
from src.utils.logger import get_logger
from src.utils.file_utils import get_relative_path
@@ -42,6 +42,7 @@ class InferenceEngine:
        relative_path: str,
        conf: float = 0.25,
        save_to_db: bool = True,
        repository_root: Optional[str] = None,
    ) -> Dict:
        """
        Detect objects in a single image.
@@ -51,49 +52,79 @@
            relative_path: Relative path from repository root
            conf: Confidence threshold
            save_to_db: Whether to save results to database
            repository_root: Base directory used to compute relative_path (if known)

        Returns:
            Dictionary with detection results
        """
        try:
            # Normalize storage path (fall back to absolute path when repo root is unknown)
            stored_relative_path = relative_path
            if not repository_root:
                stored_relative_path = str(Path(image_path).resolve())

            # Get image dimensions
            img = Image.open(image_path)
            width, height = img.size
            img.close()
            img = Image(image_path)
            width = img.width
            height = img.height

            # Perform detection
            detections = self.yolo.predict(image_path, conf=conf)

            # Add/get image in database
            image_id = self.db_manager.get_or_create_image(
                relative_path=relative_path,
                relative_path=stored_relative_path,
                filename=Path(image_path).name,
                width=width,
                height=height,
            )

            # Save detections to database
            if save_to_db and detections:
                detection_records = []
                for det in detections:
                    # Use normalized bbox from detection
                    bbox_normalized = det[
                        "bbox_normalized"
                    ]  # [x_min, y_min, x_max, y_max]
            inserted_count = 0
            deleted_count = 0

                    record = {
                        "image_id": image_id,
                        "model_id": self.model_id,
                        "class_name": det["class_name"],
                        "bbox": tuple(bbox_normalized),
                        "confidence": det["confidence"],
                        "segmentation_mask": det.get("segmentation_mask"),
                        "metadata": {"class_id": det["class_id"]},
                    }
                    detection_records.append(record)
            # Save detections to database, replacing any previous results for this image/model
            if save_to_db:
                deleted_count = self.db_manager.delete_detections_for_image(
                    image_id, self.model_id
                )
                if detections:
                    detection_records = []
                    for det in detections:
                        # Use normalized bbox from detection
                        bbox_normalized = det[
                            "bbox_normalized"
                        ]  # [x_min, y_min, x_max, y_max]

                self.db_manager.add_detections_batch(detection_records)
                logger.info(f"Saved {len(detection_records)} detections to database")
                        metadata = {
                            "class_id": det["class_id"],
                            "source_path": str(Path(image_path).resolve()),
                        }
                        if repository_root:
                            metadata["repository_root"] = str(
                                Path(repository_root).resolve()
                            )

                        record = {
                            "image_id": image_id,
                            "model_id": self.model_id,
                            "class_name": det["class_name"],
                            "bbox": tuple(bbox_normalized),
                            "confidence": det["confidence"],
                            "segmentation_mask": det.get("segmentation_mask"),
                            "metadata": metadata,
                        }
                        detection_records.append(record)

                    inserted_count = self.db_manager.add_detections_batch(
                        detection_records
                    )
                    logger.info(
                        f"Saved {inserted_count} detections to database (replaced {deleted_count})"
                    )
                else:
                    logger.info(
                        f"Detection run removed {deleted_count} stale entries but produced no new detections"
                    )

            return {
                "success": True,
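The delete-then-insert sequence above makes re-running detection on the same image/model pair idempotent: stale rows never pile up. A standalone sketch of the same pattern against a bare sqlite3 connection (the column list is illustrative, not the project's actual schema):

```python
import sqlite3

def replace_detections(conn: sqlite3.Connection, image_id: int,
                       model_id: int, rows: list) -> tuple:
    """Swap old results for new ones in a single transaction."""
    cur = conn.cursor()
    cur.execute(
        "DELETE FROM detections WHERE image_id = ? AND model_id = ?",
        (image_id, model_id),
    )
    deleted = cur.rowcount
    cur.executemany(
        "INSERT INTO detections (image_id, model_id, class_name, confidence) "
        "VALUES (?, ?, ?, ?)",
        rows,
    )
    conn.commit()  # one commit covers both statements: the swap is all-or-nothing
    return deleted, cur.rowcount
```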
@@ -142,7 +173,12 @@
            rel_path = get_relative_path(image_path, repository_root)

            # Perform detection
            result = self.detect_single(image_path, rel_path, conf)
            result = self.detect_single(
                image_path,
                rel_path,
                conf=conf,
                repository_root=repository_root,
            )
            results.append(result)

            # Update progress
@@ -1,13 +1,21 @@
"""
YOLO model wrapper for the microscopy object detection application.
Provides a clean interface to YOLOv8 for training, validation, and inference.
"""YOLO model wrapper for the microscopy object detection application.

Notes on 16-bit TIFF support:
- Ultralytics training defaults assume 8-bit images and normalize by dividing by 255.
- This project can patch Ultralytics at runtime to decode TIFFs via `tifffile` and
  normalize `uint16` correctly.

See [`apply_ultralytics_16bit_tiff_patches()`](src/utils/ultralytics_16bit_patch.py:1).
"""

from ultralytics import YOLO
from pathlib import Path
from typing import Optional, List, Dict, Callable, Any
import torch
import tempfile
import os
from src.utils.image import Image
from src.utils.logger import get_logger
from src.utils.ultralytics_16bit_patch import apply_ultralytics_16bit_tiff_patches


logger = get_logger(__name__)
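The patch module itself is not part of this diff. As a rough illustration of the normalization problem it addresses (this helper is an assumption, not the project's apply_ultralytics_16bit_tiff_patches()): dividing a uint16 frame by 255 overflows the expected [0, 1] range, so the scale factor must follow the dtype.

```python
import numpy as np
from tifffile import imread

def load_tiff_as_float(path: str) -> np.ndarray:
    """Decode a TIFF and scale to [0, 1] by dtype range, not a hard-coded 255."""
    arr = imread(path)
    if arr.dtype == np.uint16:
        return arr.astype(np.float32) / 65535.0  # uint16 full scale
    if arr.dtype == np.uint8:
        return arr.astype(np.float32) / 255.0
    return arr.astype(np.float32)
```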
@@ -28,6 +36,9 @@ class YOLOWrapper:
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"YOLOWrapper initialized with device: {self.device}")

        # Apply Ultralytics runtime patches early (before first import/instantiation of YOLO datasets/trainers).
        apply_ultralytics_16bit_tiff_patches()

    def load_model(self) -> bool:
        """
        Load YOLO model from path.
@@ -37,6 +48,9 @@
        """
        try:
            logger.info(f"Loading YOLO model from {self.model_path}")
            # Import YOLO lazily to ensure runtime patches are applied first.
            from ultralytics import YOLO

            self.model = YOLO(self.model_path)
            self.model.to(self.device)
            logger.info("Model loaded successfully")
@@ -55,6 +69,7 @@
        save_dir: str = "data/models",
        name: str = "custom_model",
        resume: bool = False,
        callbacks: Optional[Dict[str, Callable]] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
@@ -69,13 +84,15 @@
            save_dir: Directory to save trained model
            name: Name for the training run
            resume: Resume training from last checkpoint
            callbacks: Optional Ultralytics callback dictionary
            **kwargs: Additional training arguments

        Returns:
            Dictionary with training results
        """
        if self.model is None:
            self.load_model()
            if not self.load_model():
                raise RuntimeError(f"Failed to load model from {self.model_path}")

        try:
            logger.info(f"Starting training: {name}")
@@ -83,6 +100,16 @@
                f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
            )

            # Defaults for 16-bit safety: disable augmentations that force uint8 and HSV ops that assume 0..255.
            # Users can override by passing explicit kwargs.
            kwargs.setdefault("mosaic", 0.0)
            kwargs.setdefault("mixup", 0.0)
            kwargs.setdefault("cutmix", 0.0)
            kwargs.setdefault("copy_paste", 0.0)
            kwargs.setdefault("hsv_h", 0.0)
            kwargs.setdefault("hsv_s", 0.0)
            kwargs.setdefault("hsv_v", 0.0)

            # Train the model
            results = self.model.train(
                data=data_yaml,
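Because the defaults are applied with setdefault(), an explicit keyword from the caller always wins. A hypothetical call site (the constructor argument and train() parameters are assumed from the diff context):

```python
# Caller re-enables mosaic for an 8-bit dataset; the other 16-bit-safe
# defaults (mixup, hsv_*, etc.) stay at 0.0 unless passed explicitly.
wrapper = YOLOWrapper("yolov8s-seg.pt")  # constructor signature assumed
results = wrapper.train(
    data_yaml="data/datasets/mito/data.yaml",
    epochs=100,
    mosaic=1.0,  # overrides kwargs.setdefault("mosaic", 0.0) above
)
```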
@@ -117,7 +144,8 @@
        Dictionary with validation metrics
        """
        if self.model is None:
            self.load_model()
            if not self.load_model():
                raise RuntimeError(f"Failed to load model from {self.model_path}")

        try:
            logger.info(f"Starting validation on {split} split")
@@ -158,10 +186,15 @@
        List of detection dictionaries
        """
        if self.model is None:
            self.load_model()
            if not self.load_model():
                raise RuntimeError(f"Failed to load model from {self.model_path}")

        prepared_source, cleanup_path = self._prepare_source(source)

        try:
            logger.info(f"Running inference on {source}")
            logger.info(
                f"Running inference on {source} -> prepared_source {prepared_source}"
            )
            results = self.model.predict(
                source=source,
                conf=conf,
@@ -180,6 +213,14 @@
        except Exception as e:
            logger.error(f"Error during inference: {e}")
            raise
        finally:
            if 0:  # cleanup_path:
                try:
                    os.remove(cleanup_path)
                except OSError as cleanup_error:
                    logger.warning(
                        f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}"
                    )

    def export(
        self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
@@ -196,7 +237,8 @@
        Path to exported model
        """
        if self.model is None:
            self.load_model()
            if not self.load_model():
                raise RuntimeError(f"Failed to load model from {self.model_path}")

        try:
            logger.info(f"Exporting model to {format} format")
@@ -208,6 +250,32 @@
            logger.error(f"Error exporting model: {e}")
            raise

    def _prepare_source(self, source):
        """Convert single-channel images to RGB temporarily for inference."""
        cleanup_path = None

        if isinstance(source, (str, Path)):
            source_path = Path(source)
            if source_path.is_file():
                try:
                    img_obj = Image(source_path)
                    suffix = source_path.suffix or ".png"
                    tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
                    tmp_path = tmp.name
                    tmp.close()
                    img_obj.save(tmp_path)
                    cleanup_path = tmp_path
                    logger.info(
                        f"Converted image {source_path} to RGB for inference at {tmp_path}"
                    )
                    return tmp_path, cleanup_path
                except Exception as convert_error:
                    logger.warning(
                        f"Failed to preprocess {source_path} as RGB, continuing with original file: {convert_error}"
                    )

        return source, cleanup_path

    def _format_training_results(self, results) -> Dict[str, Any]:
        """Format training results into dictionary."""
        try:
@@ -7,6 +7,7 @@ import yaml
from pathlib import Path
from typing import Any, Dict, Optional
from src.utils.logger import get_logger
from src.utils.image import Image


logger = get_logger(__name__)
@@ -46,18 +47,15 @@ class ConfigManager:
            "database": {"path": "data/detections.db"},
            "image_repository": {
                "base_path": "",
                "allowed_extensions": [
                    ".jpg",
                    ".jpeg",
                    ".png",
                    ".tif",
                    ".tiff",
                    ".bmp",
                ],
                "allowed_extensions": Image.SUPPORTED_EXTENSIONS,
            },
            "models": {
                "default_base_model": "yolov8s-seg.pt",
                "models_directory": "data/models",
                "base_model_choices": [
                    "yolov8s-seg.pt",
                    "yolo11s-seg.pt",
                ],
            },
            "training": {
                "default_epochs": 100,
@@ -65,6 +63,20 @@
                "default_imgsz": 640,
                "default_patience": 50,
                "default_lr0": 0.01,
                "two_stage": {
                    "enabled": False,
                    "stage1": {
                        "epochs": 20,
                        "lr0": 0.0005,
                        "patience": 10,
                        "freeze": 10,
                    },
                    "stage2": {
                        "epochs": 150,
                        "lr0": 0.0003,
                        "patience": 30,
                    },
                },
            },
            "detection": {
                "default_confidence": 0.25,
@@ -213,6 +225,4 @@

    def get_allowed_extensions(self) -> list:
        """Get list of allowed image file extensions."""
        return self.get(
            "image_repository.allowed_extensions", [".jpg", ".jpeg", ".png"]
        )
        return self.get("image_repository.allowed_extensions", Image.SUPPORTED_EXTENSIONS)
@@ -28,7 +28,9 @@ def get_image_files(
        List of absolute paths to image files
    """
    if allowed_extensions is None:
        allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
        from src.utils.image import Image

        allowed_extensions = Image.SUPPORTED_EXTENSIONS

    # Normalize extensions to lowercase
    allowed_extensions = [ext.lower() for ext in allowed_extensions]
@@ -204,7 +206,9 @@ def is_image_file(
        True if file is an image
    """
    if allowed_extensions is None:
        allowed_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
        from src.utils.image import Image

        allowed_extensions = Image.SUPPORTED_EXTENSIONS

    extension = Path(file_path).suffix.lower()
    return extension in [ext.lower() for ext in allowed_extensions]
@@ -6,16 +6,52 @@ import cv2
import numpy as np
from pathlib import Path
from typing import Optional, Tuple, Union
from PIL import Image as PILImage

from src.utils.logger import get_logger
from src.utils.file_utils import validate_file_path, is_image_file

from PySide6.QtGui import QImage

from tifffile import imread, imwrite

logger = get_logger(__name__)


def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
    """
    Convert a grayscale image to a pseudo-RGB image using a gamma correction.

    Args:
        arr: Input grayscale image as numpy array

    Returns:
        Pseudo-RGB image as numpy array
    """
    if arr.ndim != 2:
        raise ValueError("Input array must be a grayscale image with shape (H, W)")

    a1 = arr.copy().astype(np.float32)
    a1 -= np.percentile(a1, 2)
    a1[a1 < 0] = 0
    p999 = np.percentile(a1, 99.9)
    a1[a1 > p999] = p999
    a1 /= a1.max()

    if 0:
        a2 = a1.copy()
        a2 = a2**gamma
        a2 /= a2.max()

        a3 = a1.copy()
        p9999 = np.percentile(a3, 99.99)
        a3[a3 > p9999] = p9999
        a3 /= a3.max()

    return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
    # return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
    # return np.stack([a1, a2, a3], axis=0)


class ImageLoadError(Exception):
    """Exception raised when an image cannot be loaded."""
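A quick sanity check of get_pseudo_rgb() on synthetic data (illustrative only). Note that it returns a channel-first (C, H, W) float32 array with the percentile-stretched signal in the red plane and empty green/blue planes:

```python
import numpy as np

# Synthetic 16-bit grayscale frame with one bright square.
frame = np.zeros((64, 64), dtype=np.uint16)
frame[16:32, 16:32] = 40000

rgb = get_pseudo_rgb(frame)
assert rgb.shape == (3, 64, 64)           # channel-first layout
assert rgb[1].max() == rgb[2].max() == 0  # green/blue planes stay empty
assert 0.0 <= rgb[0].max() <= 1.0         # red plane normalized to [0, 1]
```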
@@ -54,7 +90,6 @@
        """
        self.path = Path(image_path)
        self._data: Optional[np.ndarray] = None
        self._pil_image: Optional[PILImage.Image] = None
        self._width: int = 0
        self._height: int = 0
        self._channels: int = 0
@@ -80,40 +115,39 @@ class Image:
|
||||
if not is_image_file(str(self.path), self.SUPPORTED_EXTENSIONS):
|
||||
ext = self.path.suffix.lower()
|
||||
raise ImageLoadError(
|
||||
f"Unsupported image format: {ext}. "
|
||||
f"Supported formats: {', '.join(self.SUPPORTED_EXTENSIONS)}"
|
||||
f"Unsupported image format: {ext}. " f"Supported formats: {', '.join(self.SUPPORTED_EXTENSIONS)}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Load with OpenCV (returns BGR format)
|
||||
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
|
||||
if self.path.suffix.lower() in [".tif", ".tiff"]:
|
||||
self._data = imread(str(self.path))
|
||||
else:
|
||||
raise NotImplementedError("RGB is not implemented")
|
||||
# Load with OpenCV (returns BGR format)
|
||||
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
|
||||
|
||||
if self._data is None:
|
||||
raise ImageLoadError(f"Failed to load image with OpenCV: {self.path}")
|
||||
|
||||
# Extract metadata
|
||||
self._height, self._width = self._data.shape[:2]
|
||||
self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1
|
||||
# print(self._data.shape)
|
||||
if len(self._data.shape) == 2:
|
||||
self._height, self._width = self._data.shape[:2]
|
||||
self._channels = 1
|
||||
else:
|
||||
self._height, self._width = self._data.shape[1:]
|
||||
self._channels = self._data.shape[0]
|
||||
# self._channels = self._data.shape[2] if len(self._data.shape) == 3 else 1
|
||||
self._format = self.path.suffix.lower().lstrip(".")
|
||||
self._size_bytes = self.path.stat().st_size
|
||||
self._dtype = self._data.dtype
|
||||
|
||||
# Load PIL version for compatibility (convert BGR to RGB)
|
||||
if self._channels == 3:
|
||||
rgb_data = cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
|
||||
self._pil_image = PILImage.fromarray(rgb_data)
|
||||
elif self._channels == 4:
|
||||
rgba_data = cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
|
||||
self._pil_image = PILImage.fromarray(rgba_data)
|
||||
else:
|
||||
# Grayscale
|
||||
self._pil_image = PILImage.fromarray(self._data)
|
||||
|
||||
logger.info(
|
||||
f"Successfully loaded image: {self.path.name} "
|
||||
f"({self._width}x{self._height}, {self._channels} channels, "
|
||||
f"{self._format.upper()})"
|
||||
)
|
||||
if 0:
|
||||
logger.info(
|
||||
f"Successfully loaded image: {self.path.name} "
|
||||
f"({self._width}x{self._height}, {self._channels} channels, "
|
||||
f"{self._format.upper()})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading image {self.path}: {e}")
|
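The hunk above switches `.tif/.tiff` loading to `tifffile` and treats 3-D arrays as channel-first. A sketch of the convention difference (the file name is a placeholder):

```python
from tifffile import imread

arr = imread("stack.tif")
# tifffile returns the array as stored: 2-D (H, W) for single-channel data,
# and (C, H, W) for multi-channel stacks here, unlike OpenCV's (H, W, C).
if arr.ndim == 2:
    channels, (height, width) = 1, arr.shape
else:
    channels, height, width = arr.shape
print(channels, height, width)
```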
@@ -131,18 +165,6 @@ class Image:
            raise ImageLoadError("Image data not available")
        return self._data

    @property
    def pil_image(self) -> PILImage.Image:
        """
        Get image data as PIL Image (RGB or grayscale).

        Returns:
            PIL Image object
        """
        if self._pil_image is None:
            raise ImageLoadError("PIL image not available")
        return self._pil_image

    @property
    def width(self) -> int:
        """Get image width in pixels."""
@@ -187,6 +209,7 @@ class Image:
    @property
    def dtype(self) -> np.dtype:
        """Get the data type of the image array."""

        if self._dtype is None:
            raise ImageLoadError("Image dtype not available")
        return self._dtype
@@ -206,8 +229,10 @@ class Image:
        elif self._channels == 1:
            if self._dtype == np.uint16:
                return QImage.Format_Grayscale16
            else:
            elif self._dtype == np.uint8:
                return QImage.Format_Grayscale8
            elif self._dtype == np.float32:
                return QImage.Format_BGR30
        else:
            raise ImageLoadError(f"Unsupported number of channels: {self._channels}")

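After this hunk the grayscale-format choice is dtype-aware. A standalone sketch of the mapping (the helper name is hypothetical; `Format_BGR30` for float32 mirrors the hunk's choice):

```python
import numpy as np
from PySide6.QtGui import QImage

def grayscale_qimage_format(dtype) -> QImage.Format:
    # Mirrors the dtype branch introduced above for single-channel images.
    if dtype == np.uint16:
        return QImage.Format_Grayscale16
    if dtype == np.uint8:
        return QImage.Format_Grayscale8
    if dtype == np.float32:
        return QImage.Format_BGR30
    raise ValueError(f"Unsupported dtype: {dtype}")
```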
@@ -218,6 +243,12 @@ class Image:
        Returns:
            Image data in RGB format as numpy array
        """
        if self.channels == 1:
            img = get_pseudo_rgb(self.data)
            self._dtype = img.dtype
            return img
        raise NotImplementedError

        if self._channels == 3:
            return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
        elif self._channels == 4:
@@ -225,6 +256,18 @@ class Image:
        else:
            return self._data

    def get_qt_rgb(self) -> np.ndarray:
        # Data is kept as (C, H, W); repack into an (H, W, 4) RGBA float array.
        _img = self.get_rgb()

        img = np.zeros((self.height, self.width, 4), dtype=np.float32)
        img[..., 0] = _img[0]  # R
        img[..., 1] = _img[1]  # G
        img[..., 2] = _img[2]  # B
        img[..., 3] = 1.0  # A = 1.0 (opaque)

        return np.ascontiguousarray(img)

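`get_qt_rgb()` returns a contiguous (H, W, 4) float32 RGBA array in [0, 1]. A sketch of handing it to Qt (the 8-bit rescale and format choice are this example's assumptions, not part of the diff; the path is a placeholder):

```python
import numpy as np
from PySide6.QtGui import QImage

img = Image("plate01_ch1.tif")
rgba = (img.get_qt_rgb() * 255).astype(np.uint8)
h, w, _ = rgba.shape
qimg = QImage(rgba.data, w, h, 4 * w, QImage.Format_RGBA8888)
```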
    def get_grayscale(self) -> np.ndarray:
        """
        Get image as grayscale numpy array.
@@ -277,11 +320,26 @@ class Image:
        """
        return self._channels >= 3

    def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None:

        if self.channels == 1:
            if pseudo_rgb:
                img = get_pseudo_rgb(self.data)
                print("Image.save", img.shape)
            else:
                # Add a channel axis before repeating the 2-D grayscale data.
                img = np.repeat(self.data[..., None], 3, axis=2)

        else:
            raise NotImplementedError("Only grayscale images are supported for now.")

        imwrite(path, data=img)

    def __repr__(self) -> str:
        """String representation of the Image object."""
        return (
            f"Image(path='{self.path.name}', "
            f"shape=({self._width}x{self._height}x{self._channels}), "
            # Display as HxWxC to match the conventional NumPy shape semantics.
            f"shape=({self._height}x{self._width}x{self._channels}), "
            f"format={self._format}, "
            f"size={self.size_mb:.2f}MB)"
        )
@@ -289,3 +347,15 @@ class Image:
    def __str__(self) -> str:
        """String representation of the Image object."""
        return self.__repr__()


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--path", type=str, required=True)
    args = parser.parse_args()

    img = Image(args.path)
    img.save(args.path + "test.tif")
    print(img)

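The new `__main__` block is a small smoke test: it loads the image, re-saves it as pseudo-RGB, and prints the repr. Note that `args.path + "test.tif"` concatenates without a separator, so `--path data/a.tif` writes `data/a.tiftest.tif`. Invocation sketch (path is a placeholder):

```
python -m src.utils.image --path data/a.tif
```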
168 src/utils/image_converters.py Normal file
@@ -0,0 +1,168 @@
import numpy as np

from roifile import ImagejRoi
from tifffile import TiffFile, TiffWriter
from pathlib import Path


class UT:
    """
    Loader/exporter for Operetta image files along with ROIs drawn in ImageJ.
    """

    def __init__(self, roifile_fn: Path, no_labels: bool):
        self.roifile_fn = roifile_fn
        print("is file", self.roifile_fn.is_file())
        self.rois = None
        if no_labels:
            self.rois = ImagejRoi.fromfile(self.roifile_fn)
            print(self.roifile_fn.stem)
            print(self.roifile_fn.parent.parts[-1])
            if "Roi-" in self.roifile_fn.stem:
                self.stem = self.roifile_fn.stem.split("Roi-")[1]
            else:
                self.stem = self.roifile_fn.parent.parts[-1]

        else:
            self.roifile_fn = roifile_fn / roifile_fn.parts[-1]
            self.stem = self.roifile_fn.stem

        print(self.roifile_fn)

        print(self.stem)
        self.image, self.image_props = self._load_images()

    def _load_images(self):
        """Loading sequence of tif files
        array sequence is CZYX
        """
        print("Loading images:", self.roifile_fn.parent, self.stem)
        fns = list(self.roifile_fn.parent.glob(f"{self.stem.lower()}*.tif*"))
        stems = [fn.stem.split(self.stem)[-1] for fn in fns]
        n_ch = len(set([stem.split("-ch")[-1].split("t")[0] for stem in stems]))
        n_p = len(set([stem.split("-")[0] for stem in stems]))
        n_t = len(set([stem.split("t")[1] for stem in stems]))

        with TiffFile(fns[0]) as tif:
            img = tif.asarray()
            w, h = img.shape
            dtype = img.dtype
        self.image_props = {
            "channels": n_ch,
            "planes": n_p,
            "tiles": n_t,
            "width": w,
            "height": h,
            "dtype": dtype,
        }
        print("Image props", self.image_props)

        image_stack = np.zeros((n_ch, n_p, w, h), dtype=dtype)
        for fn in fns:
            with TiffFile(fn) as tif:
                img = tif.asarray()
            stem = fn.stem.split(self.stem)[-1]
            ch = int(stem.split("-ch")[-1].split("t")[0])
            p = int(stem.split("-")[0].split("p")[1])
            t = int(stem.split("t")[1])
            print(fn.stem, "ch", ch, "p", p, "t", t)
            image_stack[ch - 1, p - 1] = img

        print(image_stack.shape)

        return image_stack, self.image_props

    @property
    def width(self):
        return self.image_props["width"]

    @property
    def height(self):
        return self.image_props["height"]

    @property
    def nchannels(self):
        return self.image_props["channels"]

    @property
    def nplanes(self):
        return self.image_props["planes"]

    def export_rois(
        self,
        path: Path,
        subfolder: str = "labels",
        class_index: int = 0,
    ):
        """Export rois to a file"""
        with open(path / subfolder / f"{self.stem}.txt", "w") as f:
            for i, roi in enumerate(self.rois):
                rc = roi.subpixel_coordinates
                if rc is None:
                    print(f"No coordinates: {self.roifile_fn}, element {i}, out of {len(self.rois)}")
                    continue
                xmn, ymn = rc.min(axis=0)
                xmx, ymx = rc.max(axis=0)
                xc = (xmn + xmx) / 2
                yc = (ymn + ymx) / 2
                bw = xmx - xmn
                bh = ymx - ymn
                coords = f"{xc/self.width} {yc/self.height} {bw/self.width} {bh/self.height} "
                for x, y in rc:
                    coords += f"{x/self.width} {y/self.height} "
                f.write(f"{class_index} {coords}\n")

        return

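`export_rois` emits one YOLO-segmentation line per ROI: the class index, the normalized bbox (xc yc w h), then the normalized polygon vertices. An illustrative line for a square ROI centred in the image:

```
0 0.500000 0.500000 0.100000 0.100000 0.45 0.45 0.55 0.45 0.55 0.55 0.45 0.55
```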
    def export_image(
        self,
        path: Path,
        subfolder: str = "images",
        plane_mode: str = "max projection",
        channel: int = 0,
    ):
        """Export image to a file"""

        if plane_mode == "max projection":
            self.image = np.max(self.image[channel], axis=0)
            print(self.image.shape)

        print(path / subfolder / f"{self.stem}.tif")
        with TiffWriter(path / subfolder / f"{self.stem}.tif") as tif:
            tif.write(self.image)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", nargs="*", type=Path)
    parser.add_argument("-o", "--output", type=Path)
    parser.add_argument(
        "--no-labels",
        action="store_false",
        help="Source does not have labels, export only images",
    )
    args = parser.parse_args()

    # print(args)

    for path in args.input:
        print("Path:", path)
        if not args.no_labels:
            print("No labels")
            ut = UT(path, args.no_labels)
            ut.export_image(args.output, plane_mode="max projection", channel=0)

        else:
            for rfn in Path(path).glob("*.zip"):
                # if Path(path).suffix == ".zip":
                print("Roi FN:", rfn)
                ut = UT(rfn, args.no_labels)
                ut.export_rois(args.output, class_index=0)
                ut.export_image(args.output, plane_mode="max projection", channel=0)

        print()
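Invocation sketch (directories are placeholders). Because `--no-labels` uses `action="store_false"`, omitting the flag takes the ROI branch (glob `*.zip`, export labels and images), while passing it exports images only:

```
python src/utils/image_converters.py -i data/operetta/run1 -o data/yolo_export
python src/utils/image_converters.py -i data/operetta/run2 -o data/yolo_export --no-labels
```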
353 src/utils/image_splitter.py Normal file
@@ -0,0 +1,353 @@
import numpy as np

from pathlib import Path
from tifffile import imread, imwrite
from shapely.geometry import LineString
from copy import deepcopy
from scipy.ndimage import zoom


# debug
from src.utils.image import Image
from show_yolo_seg import draw_annotations

import pylab as plt
import cv2


class Label:
    def __init__(self, yolo_annotation: str):
        class_id, bbox, polygon = self.parse_yolo_annotation(yolo_annotation)
        self.class_id = class_id
        self.bbox = bbox
        self.polygon = polygon

    def parse_yolo_annotation(self, yolo_annotation: str):
        class_id, *coords = yolo_annotation.split()
        class_id = int(class_id)
        bbox = np.array(coords[:4], dtype=np.float32)
        polygon = np.array(coords[4:], dtype=np.float32).reshape(-1, 2) if len(coords) > 4 else None
        # Close the polygon if first and last vertices differ (guard the bbox-only case).
        if polygon is not None and not any(np.isclose(polygon[0], polygon[-1])):
            polygon = np.vstack([polygon, polygon[0]])
        return class_id, bbox, polygon

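A parsing sketch for `Label` (values illustrative): a bbox-only line yields `polygon=None`; a line with vertices gets its polygon closed when the first and last points differ:

```python
lbl = Label("0 0.5 0.5 0.2 0.2 0.4 0.4 0.6 0.4 0.6 0.6")
print(lbl.bbox)           # [0.5 0.5 0.2 0.2]
print(lbl.polygon.shape)  # (4, 2) -- three vertices plus the closing point
```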
    def offset_label(
        self,
        img_w,
        img_h,
        distance: float = 1.0,
        cap_style: int = 2,
        join_style: int = 2,
    ):
        if self.polygon is None:
            self.bbox = np.array(
                [
                    self.bbox[0] - distance if self.bbox[0] - distance > 0 else 0,
                    self.bbox[1] - distance if self.bbox[1] - distance > 0 else 0,
                    self.bbox[2] + distance if self.bbox[2] + distance < 1 else 1,
                    self.bbox[3] + distance if self.bbox[3] + distance < 1 else 1,
                ],
                dtype=np.float32,
            )
            return self.bbox

        def coords_are_normalized(coords):
            # If every coordinate is between 0 and 1 (inclusive-ish), assume normalized
            return float(np.max(coords)) <= 1.001

        def poly_to_pts(coords, img_w, img_h):
            # coords: (N, 2) polygon, either normalized or absolute;
            # scale x by width and y by height, then cast for OpenCV/shapely.
            pts = np.asarray(coords, dtype=np.float32).reshape(-1, 2) * (img_w, img_h)
            return pts.astype(np.int32)

        pts = poly_to_pts(self.polygon, img_w, img_h)
        line = LineString(pts)
        # Buffer distance in pixels
        buffered = line.buffer(distance=distance, cap_style=cap_style, join_style=join_style)
        self.polygon = np.array(buffered.exterior.coords, dtype=np.float32) / (img_w, img_h)
        xmn, ymn = self.polygon.min(axis=0)
        xmx, ymx = self.polygon.max(axis=0)
        xc = (xmn + xmx) / 2
        yc = (ymn + ymx) / 2
        bw = xmx - xmn
        bh = ymx - ymn
        self.bbox = np.array([xc, yc, bw, bh], dtype=np.float32)

        return self.bbox, self.polygon

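The polygon branch above dilates the pixel-space outline with Shapely before re-normalizing. A standalone sketch of that buffering step:

```python
from shapely.geometry import LineString

line = LineString([(10, 10), (40, 10), (40, 40)])
# cap_style=2 (flat) and join_style=2 (mitre), as in offset_label's defaults.
outline = line.buffer(3, cap_style=2, join_style=2)
print(list(outline.exterior.coords)[:3])  # first vertices of the dilated polygon
```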
    def translate(self, x, y, scale_x, scale_y):
        self.bbox[0] -= x
        self.bbox[0] *= scale_x
        self.bbox[1] -= y
        self.bbox[1] *= scale_y
        self.bbox[2] *= scale_x
        self.bbox[3] *= scale_y
        if self.polygon is not None:
            self.polygon[:, 0] -= x
            self.polygon[:, 0] *= scale_x
            self.polygon[:, 1] -= y
            self.polygon[:, 1] *= scale_y

    def in_range(self, hrange, wrange):
        # YOLO bbox order is (xc, yc, w, h); require the whole box inside the ranges.
        xc, yc, w, h = self.bbox
        x1 = xc - w / 2
        y1 = yc - h / 2
        x2 = xc + w / 2
        y2 = yc + h / 2
        truth_val = (
            x1 >= wrange[0]
            and x1 <= wrange[1]
            and x2 >= wrange[0]
            and x2 <= wrange[1]
            and y1 >= hrange[0]
            and y1 <= hrange[1]
            and y2 >= hrange[0]
            and y2 <= hrange[1]
        )

        print(x1, x2, wrange, y1, y2, hrange, truth_val)
        return truth_val

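`translate` remaps normalized label coordinates into tile coordinates: subtract the tile origin, then scale by the grid dimensions. A worked example for a 2x2 split:

```python
lbl = Label("0 0.75 0.75 0.1 0.1")
lbl.translate(0.5, 0.5, 2, 2)
print(lbl.bbox)  # [0.5 0.5 0.2 0.2] -- centre of the bottom-right tile
```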
    def to_string(self, bbox: list = None, polygon: list = None):
        if bbox is None:
            bbox = self.bbox
        if polygon is None:
            polygon = self.polygon
        # Use the (possibly overridden) local bbox/polygon, not the attributes.
        coords = " ".join([f"{x:.6f}" for x in bbox])
        if polygon is not None:
            coords += " " + " ".join([f"{x:.6f} {y:.6f}" for x, y in polygon])
        return f"{self.class_id} {coords}"

    def __str__(self):
        return f"Class: {self.class_id}, BBox: {self.bbox}, Polygon: {self.polygon}"


class YoloLabelReader:
    def __init__(self, label_path: Path):
        self.label_path = label_path
        self.labels = self._read_labels()

    def _read_labels(self):
        with open(self.label_path, "r") as f:
            labels = [Label(line) for line in f.readlines()]

        return labels

    def get_labels(self, hrange, wrange):
        """hrange and wrange are tuples of (start, end) normalized to [0, 1]"""
        labels = []
        for lbl in self.labels:
            if lbl.in_range(hrange, wrange):
                labels.append(lbl)
        return labels if len(labels) > 0 else None

    def __getitem__(self, index):
        return self.labels[index]

    def __len__(self):
        return len(self.labels)

    def __iter__(self):
        return iter(self.labels)


class ImageSplitter:
    def __init__(self, image_path: Path, label_path: Path):
        self.image = imread(image_path)
        self.image_path = image_path
        self.label_path = label_path
        if not label_path.exists():
            print(f"Label file {label_path} not found")
            self.labels = None
        else:
            self.labels = YoloLabelReader(label_path)

    def split_into_tiles(self, patch_size: tuple = (2, 2)):
        """Split image into patches of size patch_size"""
        hstep, wstep = (
            self.image.shape[0] // patch_size[0],
            self.image.shape[1] // patch_size[1],
        )
        h, w = self.image.shape[:2]

        for i in range(patch_size[0]):
            for j in range(patch_size[1]):
                tile_reference = f"i{i}j{j}"
                hrange = (i * hstep / h, (i + 1) * hstep / h)
                wrange = (j * wstep / w, (j + 1) * wstep / w)
                tile = self.image[i * hstep : (i + 1) * hstep, j * wstep : (j + 1) * wstep]

                labels = None
                if self.labels is not None:
                    labels = deepcopy(self.labels.get_labels(hrange, wrange))

                if labels is not None:
                    # Remap from image coordinates to tile coordinates; the
                    # scale factors are the grid dimensions, not a fixed 2.
                    [l.translate(wrange[0], hrange[0], patch_size[1], patch_size[0]) for l in labels]

                yield tile_reference, tile, labels

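Usage sketch for the tile generator (paths are placeholders):

```python
from pathlib import Path

splitter = ImageSplitter(Path("images/scan.tif"), Path("labels/scan.txt"))
for ref, tile, labels in splitter.split_into_tiles(patch_size=(2, 2)):
    print(ref, tile.shape, 0 if labels is None else len(labels))
```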
    def split_respective_to_label(self, padding: int = 67):
        if self.labels is None:
            raise ValueError("No labels found. Only images having labels can be split.")

        for i, label in enumerate(self.labels):
            tile_reference = f"_lbl-{i+1:02d}"

            xc_norm, yc_norm, h_norm, w_norm = label.bbox  # normalized coords
            xc, yc, h, w = [
                int(np.round(f))
                for f in [
                    xc_norm * self.image.shape[1],
                    yc_norm * self.image.shape[0],
                    h_norm * self.image.shape[0],
                    w_norm * self.image.shape[1],
                ]
            ]  # image coords

            pad_xneg = padding + 1
            pad_xpos = padding
            pad_yneg = padding + 1
            pad_ypos = padding
            if xc - pad_xneg < 0:
                pad_xneg = xc
            if pad_xpos + xc > self.image.shape[1]:
                pad_xpos = self.image.shape[1] - xc
            if yc - pad_yneg < 0:
                pad_yneg = yc
            if pad_ypos + yc > self.image.shape[0]:
                pad_ypos = self.image.shape[0] - yc

            tile = self.image[
                yc - pad_yneg : yc + pad_ypos,
                xc - pad_xneg : xc + pad_xpos,
            ]
            ny, nx = tile.shape
            x_offset = pad_xneg
            y_offset = pad_yneg

            yolo_annotation = f"{label.class_id} {x_offset/nx} {y_offset/ny} {h_norm} {w_norm} "
            print(yolo_annotation)
            yolo_annotation += " ".join(
                [
                    f"{(x*self.image.shape[1]-(xc - x_offset))/nx:.6f} {(y*self.image.shape[0]-(yc-y_offset))/ny:.6f}"
                    for x, y in label.polygon
                ]
            )
            new_label = Label(yolo_annotation=yolo_annotation)

            yield tile_reference, tile, [new_label]


def main(args):

    if args.output:
        args.output.mkdir(exist_ok=True, parents=True)
        (args.output / "images").mkdir(exist_ok=True)
        (args.output / "images-zoomed").mkdir(exist_ok=True)
        (args.output / "labels").mkdir(exist_ok=True)

    for image_path in (args.input / "images").glob("*.tif"):
        data = ImageSplitter(
            image_path=image_path,
            label_path=(args.input / "labels" / image_path.stem).with_suffix(".txt"),
        )

        if args.split_around_label:
            data = data.split_respective_to_label(padding=args.padding)
        else:
            data = data.split_into_tiles(patch_size=args.patch_size)

        for tile_reference, tile, labels in data:
            print()
            print(tile_reference, tile.shape, labels)

            # { debug
            debug = False
            if debug:
                plt.figure(figsize=(10, 10 * tile.shape[0] / tile.shape[1]))
                if labels is None:
                    plt.imshow(tile, cmap="gray")
                    plt.axis("off")
                    plt.title(f"{image_path.name} ({tile_reference})")
                    plt.show()
                    continue

                print(labels[0].bbox)
                # Draw annotations
                out = draw_annotations(
                    cv2.cvtColor((tile / tile.max() * 255).astype(np.uint8), cv2.COLOR_GRAY2BGR),
                    [l.to_string() for l in labels],
                    alpha=0.1,
                )

                # Convert BGR -> RGB for matplotlib display
                out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
                plt.imshow(out_rgb)
                plt.axis("off")
                plt.title(f"{image_path.name} ({tile_reference})")
                plt.show()
            # } debug

            if args.output:
                imwrite(args.output / "images" / f"{image_path.stem}_{tile_reference}.tif", tile)
                scale = 5
                tile_zoomed = zoom(tile, zoom=scale)
                imwrite(args.output / "images-zoomed" / f"{image_path.stem}_{tile_reference}.tif", tile_zoomed)

                if labels is not None:
                    with open(args.output / "labels" / f"{image_path.stem}_{tile_reference}.txt", "w") as f:
                        for label in labels:
                            label.offset_label(tile.shape[1], tile.shape[0])
                            f.write(label.to_string() + "\n")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", type=Path)
    parser.add_argument("-o", "--output", type=Path)
    parser.add_argument(
        "-p",
        "--patch-size",
        nargs=2,
        type=int,
        default=[2, 2],
        help="Number of patches along height and width, rows and columns, respectively",
    )
    parser.add_argument(
        "-sal",
        "--split-around-label",
        action="store_true",
        help="If enabled, the image will be split around the label and for each label, a separate image will be created.",
    )
    parser.add_argument(
        "--padding",
        type=int,
        default=67,
        help="Padding around the label when splitting around the label.",
    )
    args = parser.parse_args()

    main(args)
1 src/utils/show_yolo_seg.py Symbolic link
@@ -0,0 +1 @@
../../tests/show_yolo_seg.py

156 src/utils/ultralytics_16bit_patch.py Normal file
@@ -0,0 +1,156 @@
"""Ultralytics runtime patches for 16-bit TIFF training.

Goals:
- Use `tifffile` to decode `.tif/.tiff` reliably (OpenCV can silently drop bit-depth depending on codec).
- Preserve 16-bit data through the dataloader as `uint16` tensors.
- Fix Ultralytics trainer normalization (default divides by 255) to scale `uint16` correctly.
- Avoid uint8-forcing augmentations by recommending/setting hyp values (handled by caller).

This module is intended to be imported/called **before** instantiating/using YOLO.
"""

from __future__ import annotations

from typing import Optional

from src.utils.logger import get_logger

logger = get_logger(__name__)


def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
    """Apply runtime monkey-patches to Ultralytics to better support 16-bit TIFFs.

    This function is safe to call multiple times.

    Args:
        force: If True, re-apply patches even if already applied.
    """

    # Import inside function to ensure patching occurs before YOLO model/dataset is created.
    import os

    import cv2
    import numpy as np

    import torch
    from src.utils.image import Image

    from ultralytics.utils import patches as ul_patches

    already_patched = getattr(ul_patches.imread, "__name__", "") == "tifffile_imread"
    if already_patched and not force:
        return

    _original_imread = ul_patches.imread

    def tifffile_imread(filename: str, flags: int = cv2.IMREAD_COLOR, pseudo_rgb: bool = True) -> Optional[np.ndarray]:
        """Replacement for `ultralytics.utils.patches.imread()`.

        - For `.tif/.tiff`, uses `tifffile.imread()` and preserves dtype (e.g. uint16).
        - For other formats, falls back to Ultralytics' original implementation.
        - Always returns HWC (3 dims). For grayscale, returns (H, W, 1) or (H, W, 3) depending on requested flags.
        """
        ext = os.path.splitext(filename)[1].lower()
        if ext in (".tif", ".tiff"):
            arr = Image(filename).get_qt_rgb()[:, :, :3]

            # Normalize common shapes:
            # - (H, W) -> (H, W, 1)
            # - (C, H, W) -> (H, W, C) (heuristic)
            if arr is None:
                return None
            if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[0] < arr.shape[1]:
                arr = np.transpose(arr, (1, 2, 0))
            if arr.ndim == 2:
                arr = arr[..., None]

            # Rescale to the full uint16 range and ensure a contiguous array
            # for downstream OpenCV ops.
            arr = arr.astype(np.float32)
            arr /= arr.max()
            arr *= 2**16 - 1
            arr = arr.astype(np.uint16)
            return np.ascontiguousarray(arr)

        return _original_imread(filename, flags)

    # Patch the canonical reference.
    ul_patches.imread = tifffile_imread

    # Patch common module-level imports (some Ultralytics modules do `from ... import imread`).
    # Importing these modules is safe and helps ensure the patched function is used.
    try:
        import ultralytics.data.base as _ul_base

        _ul_base.imread = tifffile_imread
    except Exception:
        pass
    try:
        import ultralytics.data.loaders as _ul_loaders

        _ul_loaders.imread = tifffile_imread
    except Exception:
        pass

    # Patch trainer normalization: default divides by 255 regardless of input dtype.
    from ultralytics.models.yolo.detect import train as detect_train

    _orig_preprocess_batch = detect_train.DetectionTrainer.preprocess_batch

    def preprocess_batch_16bit(self, batch: dict) -> dict:  # type: ignore[override]
        # Start from upstream behavior to keep device placement + multiscale identical,
        # but replace the 255 division with dtype-aware scaling.
        logger.info("Preprocessing batch with monkey-patched preprocess_batch")
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")

        img = batch.get("img")
        if isinstance(img, torch.Tensor):
            # Decide scaling denom based on dtype (avoid expensive reductions if possible).
            if img.dtype == torch.uint8:
                denom = 255.0
            elif img.dtype == torch.uint16:
                denom = 65535.0
            elif img.dtype.is_floating_point:
                # Assume already in 0-1 range if float.
                denom = 1.0
            else:
                # Generic integer fallback.
                try:
                    denom = float(torch.iinfo(img.dtype).max)
                except Exception:
                    denom = 255.0

            batch["img"] = img.float() / denom

        # Multi-scale branch copied from upstream to avoid re-introducing `/255` scaling.
        if getattr(self.args, "multi_scale", False):
            import math
            import random

            import torch.nn as nn

            imgs = batch["img"]
            sz = (
                random.randrange(int(self.args.imgsz * 0.5), int(self.args.imgsz * 1.5 + self.stride))
                // self.stride
                * self.stride
            )
            sf = sz / max(imgs.shape[2:])
            if sf != 1:
                ns = [math.ceil(x * sf / self.stride) * self.stride for x in imgs.shape[2:]]
                imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)
            batch["img"] = imgs

        return batch

    detect_train.DetectionTrainer.preprocess_batch = preprocess_batch_16bit

    # Tag function to make it easier to detect patch state.
    setattr(detect_train.DetectionTrainer.preprocess_batch, "_ultralytics_16bit_patch", True)
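Usage sketch: the patches must be applied before Ultralytics builds its dataloader; the model and dataset names below are illustrative:

```python
from src.utils.ultralytics_16bit_patch import apply_ultralytics_16bit_tiff_patches

apply_ultralytics_16bit_tiff_patches()

from ultralytics import YOLO

model = YOLO("yolov8s-seg.pt")
model.train(data="dataset.yaml", epochs=100, imgsz=640)
```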
223 tests/show_yolo_seg.py Normal file
@@ -0,0 +1,223 @@
#!/usr/bin/env python3
"""
show_yolo_seg.py

Usage:
    python show_yolo_seg.py /path/to/image.jpg /path/to/labels.txt

Supports:
- Segmentation polygons: "class x1 y1 x2 y2 ... xn yn"
- YOLO bbox lines as fallback: "class x_center y_center width height"
Coordinates can be normalized [0..1] or absolute pixels (auto-detected).
"""
import sys
import cv2
import numpy as np
import matplotlib.pyplot as plt
import argparse
from pathlib import Path
import random
from shapely.geometry import LineString

from src.utils.image import Image


def parse_label_line(line):
    parts = line.strip().split()
    if not parts:
        return None
    cls = int(float(parts[0]))
    coords = [float(x) for x in parts[1:]]
    return cls, coords


def coords_are_normalized(coords):
    # If every coordinate is between 0 and 1 (inclusive-ish), assume normalized
    if not coords:
        return False
    return max(coords) <= 1.001


def yolo_bbox_to_xyxy(coords, img_w, img_h):
    # coords: [xc, yc, w, h] normalized or absolute
    xc, yc, w, h = coords[:4]
    if max(coords) <= 1.001:
        xc *= img_w
        yc *= img_h
        w *= img_w
        h *= img_h
    x1 = int(round(xc - w / 2))
    y1 = int(round(yc - h / 2))
    x2 = int(round(xc + w / 2))
    y2 = int(round(yc + h / 2))
    return x1, y1, x2, y2


def poly_to_pts(coords, img_w, img_h):
    # coords: [x1 y1 x2 y2 ...] either normalized or absolute.
    # Callers already slice off the bbox, so test the whole list here.
    if coords_are_normalized(coords):
        coords = [coords[i] * (img_w if i % 2 == 0 else img_h) for i in range(len(coords))]
    pts = np.array(coords, dtype=np.int32).reshape(-1, 2)
    return pts


def random_color_for_class(cls):
    random.seed(cls)  # deterministic per class
    # Fixed red (BGR) for now; the commented variant returns a random per-class color.
    return (0, 0, 255)
    # return tuple(int(x) for x in np.array([random.randint(0, 255) for _ in range(3)]))

def draw_annotations(img, labels, alpha=0.4, draw_bbox_for_poly=True):
    # img: BGR numpy array
    overlay = img.copy()
    h, w = img.shape[:2]
    for line in labels:
        if isinstance(line, str):
            parsed = parse_label_line(line)
            if parsed is None:
                continue
            cls, coords = parsed
        elif isinstance(line, tuple):
            cls, coords = line
        else:
            continue

        if not coords:
            continue
        # polygon case (>=6 coordinates)
        if len(coords) >= 6:
            color = random_color_for_class(cls)

            if draw_bbox_for_poly:
                x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords[:4], w, h)
                cv2.rectangle(img, (x1, y1), (x2, y2), color, 1)

            pts = poly_to_pts(coords[4:], w, h)

            # fill on overlay
            cv2.fillPoly(overlay, [pts], color)
            # outline on base image
            cv2.polylines(img, [pts], isClosed=True, color=color, thickness=1)
            # put class text at first point
            x, y = int(pts[0, 0]), int(pts[0, 1]) - 6
            if 0:
                cv2.putText(
                    img,
                    str(cls),
                    (x, max(6, y)),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    0.6,
                    (255, 255, 255),
                    2,
                    cv2.LINE_AA,
                )

        # YOLO bbox case (4 coords)
        elif len(coords) == 4:
            x1, y1, x2, y2 = yolo_bbox_to_xyxy(coords, w, h)
            color = random_color_for_class(cls)
            cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
            cv2.putText(
                img,
                str(cls),
                (x1, max(6, y1 - 4)),
                cv2.FONT_HERSHEY_SIMPLEX,
                0.6,
                (255, 255, 255),
                2,
                cv2.LINE_AA,
            )
        else:
            # Unknown / invalid format, skip
            continue

    # blend overlay for filled polygons
    cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, img)
    return img

def load_labels_file(label_path):
    labels = []
    with open(label_path, "r") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            parsed = parse_label_line(line)
            if parsed:
                labels.append(parsed)
    return labels


def main():
    parser = argparse.ArgumentParser(description="Show YOLO segmentation / polygon annotations")
    parser.add_argument("image", type=str, help="Path to image file")
    parser.add_argument("--labels", type=str, help="Path to YOLO label file (polygons)")
    parser.add_argument("--alpha", type=float, default=0.4, help="Polygon fill alpha (0..1)")
    parser.add_argument("--no-bbox", action="store_true", help="Don't draw bounding boxes for polygons")
    args = parser.parse_args()

    print(args)

    img_path = Path(args.image)
    if args.labels:
        lbl_path = Path(args.labels)
    else:
        lbl_path = img_path.with_suffix(".txt")
        lbl_path = Path(str(lbl_path).replace("images", "labels"))

    if not img_path.exists():
        print("Image not found:", img_path)
        sys.exit(1)
    if not lbl_path.exists():
        print("Label file not found:", lbl_path)
        sys.exit(1)

    # img = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
    img = (Image(img_path).get_qt_rgb() * 255).astype(np.uint8)

    if img is None:
        print("Could not load image:", img_path)
        sys.exit(1)

    labels = load_labels_file(str(lbl_path))
    if not labels:
        print("No labels parsed from", lbl_path)
        # continue and just show image
    out = draw_annotations(img.copy(), labels, alpha=args.alpha, draw_bbox_for_poly=(not args.no_bbox))

    # Convert BGR -> RGB for matplotlib display
    out_rgb = cv2.cvtColor(out, cv2.COLOR_BGR2RGB)
    plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
    plt.imshow(out_rgb)

    if labels:
        # Overlay the first label's bbox and polygon with matplotlib as a cross-check.
        lclass, coords = labels[0]
        print(lclass, coords)
        bbox = np.array(coords[:4]) * np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
        xc, yc, w, h = bbox
        print("bbox", bbox)
        polyline = np.array(coords[4:]).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])

        plt.plot(polyline[:, 0], polyline[:, 1], "y", linewidth=2)
        plt.plot(
            [xc - w / 2, xc - w / 2, xc + w / 2, xc + w / 2, xc - w / 2],
            [yc - h / 2, yc + h / 2, yc + h / 2, yc - h / 2, yc - h / 2],
            "r",
            linewidth=2,
        )

    plt.title(f"{img_path.name} ({lbl_path.name})")
    plt.show()


if __name__ == "__main__":
    main()
@@ -27,7 +27,7 @@ class TestImage:

    def test_supported_extensions(self):
        """Test that supported extensions are correctly defined."""
        expected_extensions = [".jpg", ".jpeg", ".png", ".tif", ".tiff", ".bmp"]
        expected_extensions = Image.SUPPORTED_EXTENSIONS
        assert Image.SUPPORTED_EXTENSIONS == expected_extensions

    def test_image_properties(self, tmp_path):
