Compare commits

9 Commits

SHA1 Message Date
506c74e53a Small update 2026-01-16 10:39:46 +02:00
eefda5b878 Adding metadata to tiled images 2026-01-16 10:39:14 +02:00
31cb6a6c8e Using 8bit images 2026-01-16 10:38:34 +02:00
0c19ea2557 Updating 2026-01-16 10:30:13 +02:00
89e47591db Formatting 2026-01-16 10:27:15 +02:00
69cde09e53 Changing alpha value 2026-01-16 10:26:25 +02:00
fcbd5fb16d correcting label writing and formatting code 2026-01-16 10:24:19 +02:00
ca52312925 Adding LIKE option for filtering queries 2026-01-16 10:18:48 +02:00
0a93bf797a Adding auto zoom when result is loaded 2026-01-12 14:15:02 +02:00
9 changed files with 238 additions and 357 deletions

View File

@@ -60,9 +60,7 @@ class DatabaseManager:
cursor = conn.cursor()
# Check if annotations table exists
cursor.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'"
)
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='annotations'")
if not cursor.fetchone():
# Table doesn't exist yet, no migration needed
return
@@ -242,9 +240,7 @@ class DatabaseManager:
return cursor.lastrowid
except sqlite3.IntegrityError:
# Image already exists, return its ID
cursor.execute(
"SELECT id FROM images WHERE relative_path = ?", (relative_path,)
)
cursor.execute("SELECT id FROM images WHERE relative_path = ?", (relative_path,))
row = cursor.fetchone()
return row["id"] if row else None
finally:
@@ -255,17 +251,13 @@ class DatabaseManager:
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"SELECT * FROM images WHERE relative_path = ?", (relative_path,)
)
cursor.execute("SELECT * FROM images WHERE relative_path = ?", (relative_path,))
row = cursor.fetchone()
return dict(row) if row else None
finally:
conn.close()
def get_or_create_image(
self, relative_path: str, filename: str, width: int, height: int
) -> int:
def get_or_create_image(self, relative_path: str, filename: str, width: int, height: int) -> int:
"""Get existing image or create new one."""
existing = self.get_image_by_path(relative_path)
if existing:
@@ -355,16 +347,8 @@ class DatabaseManager:
bbox[2],
bbox[3],
det["confidence"],
(
json.dumps(det.get("segmentation_mask"))
if det.get("segmentation_mask")
else None
),
(
json.dumps(det.get("metadata"))
if det.get("metadata")
else None
),
(json.dumps(det.get("segmentation_mask")) if det.get("segmentation_mask") else None),
(json.dumps(det.get("metadata")) if det.get("metadata") else None),
),
)
conn.commit()
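
For reference, a hedged sketch of the detection dict this INSERT consumes; only the keys visible in the hunk are shown, and the values are illustrative:

det = {
    "confidence": 0.87,
    "segmentation_mask": [[0.30, 0.45], [0.32, 0.55]],  # optional; JSON-encoded when present
    "metadata": {"source": "batch-run"},                # optional; JSON-encoded when present
}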
@@ -409,12 +393,13 @@ class DatabaseManager:
if filters:
conditions = []
for key, value in filters.items():
if (
key.startswith("d.")
or key.startswith("i.")
or key.startswith("m.")
):
if key.startswith("d.") or key.startswith("i.") or key.startswith("m."):
if value.lower().startswith("like "):
conditions.append(f"{key} LIKE ?")
params.append(value.split(" ", 1)[1])
else:
conditions.append(f"{key} = ?")
params.append(value)
else:
conditions.append(f"d.{key} = ?")
params.append(value)
@@ -442,18 +427,14 @@ class DatabaseManager:
finally:
conn.close()
def get_detections_for_image(
self, image_id: int, model_id: Optional[int] = None
) -> List[Dict]:
def get_detections_for_image(self, image_id: int, model_id: Optional[int] = None) -> List[Dict]:
"""Get all detections for a specific image."""
filters = {"image_id": image_id}
if model_id:
filters["model_id"] = model_id
return self.get_detections(filters)
def delete_detections_for_image(
self, image_id: int, model_id: Optional[int] = None
) -> int:
def delete_detections_for_image(self, image_id: int, model_id: Optional[int] = None) -> int:
"""Delete detections tied to a specific image and optional model."""
conn = self.get_connection()
try:
@@ -524,9 +505,7 @@ class DatabaseManager:
""",
params,
)
class_counts = {
row["class_name"]: row["count"] for row in cursor.fetchall()
}
class_counts = {row["class_name"]: row["count"] for row in cursor.fetchall()}
# Average confidence
cursor.execute(
@@ -583,9 +562,7 @@ class DatabaseManager:
# ==================== Export Operations ====================
def export_detections_to_csv(
self, output_path: str, filters: Optional[Dict] = None
) -> bool:
def export_detections_to_csv(self, output_path: str, filters: Optional[Dict] = None) -> bool:
"""Export detections to CSV file."""
try:
detections = self.get_detections(filters)
@@ -614,9 +591,7 @@ class DatabaseManager:
for det in detections:
row = {k: det[k] for k in fieldnames if k in det}
# Convert segmentation mask list to JSON string for CSV
if row.get("segmentation_mask") and isinstance(
row["segmentation_mask"], list
):
if row.get("segmentation_mask") and isinstance(row["segmentation_mask"], list):
row["segmentation_mask"] = json.dumps(row["segmentation_mask"])
writer.writerow(row)
@@ -625,9 +600,7 @@ class DatabaseManager:
print(f"Error exporting to CSV: {e}")
return False
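
Isolated sketch of the mask flattening the CSV writer performs above:

import json

row = {"segmentation_mask": [[0.1, 0.2], [0.3, 0.4]]}
if row.get("segmentation_mask") and isinstance(row["segmentation_mask"], list):
    row["segmentation_mask"] = json.dumps(row["segmentation_mask"])  # now a JSON string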
def export_detections_to_json(
self, output_path: str, filters: Optional[Dict] = None
) -> bool:
def export_detections_to_json(self, output_path: str, filters: Optional[Dict] = None) -> bool:
"""Export detections to JSON file."""
try:
detections = self.get_detections(filters)
@@ -785,17 +758,13 @@ class DatabaseManager:
conn = self.get_connection()
try:
cursor = conn.cursor()
cursor.execute(
"SELECT * FROM object_classes WHERE class_name = ?", (class_name,)
)
cursor.execute("SELECT * FROM object_classes WHERE class_name = ?", (class_name,))
row = cursor.fetchone()
return dict(row) if row else None
finally:
conn.close()
def add_object_class(
self, class_name: str, color: str, description: Optional[str] = None
) -> int:
def add_object_class(self, class_name: str, color: str, description: Optional[str] = None) -> int:
"""
Add a new object class.
@@ -928,8 +897,7 @@ class DatabaseManager:
if not split_map[required]:
raise ValueError(
"Unable to determine %s image directory under %s. Provide it "
"explicitly via the 'splits' argument."
% (required, dataset_root_path)
"explicitly via the 'splits' argument." % (required, dataset_root_path)
)
yaml_splits: Dict[str, str] = {}
@@ -955,11 +923,7 @@ class DatabaseManager:
if yaml_splits.get("test"):
payload["test"] = yaml_splits["test"]
output_path_obj = (
Path(output_path).expanduser()
if output_path
else dataset_root_path / "data.yaml"
)
output_path_obj = Path(output_path).expanduser() if output_path else dataset_root_path / "data.yaml"
output_path_obj.parent.mkdir(parents=True, exist_ok=True)
with open(output_path_obj, "w", encoding="utf-8") as handle:
@@ -1019,15 +983,9 @@ class DatabaseManager:
for split_name, options in patterns.items():
for relative in options:
candidate = (dataset_root / relative).resolve()
if (
candidate.exists()
and candidate.is_dir()
and self._directory_has_images(candidate)
):
if candidate.exists() and candidate.is_dir() and self._directory_has_images(candidate):
try:
inferred[split_name] = candidate.relative_to(
dataset_root
).as_posix()
inferred[split_name] = candidate.relative_to(dataset_root).as_posix()
except ValueError:
inferred[split_name] = candidate.as_posix()
break
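
A minimal usage sketch of the new LIKE filter path in get_detections() (hypothetical values; DatabaseManager construction and connection details elided):

db = DatabaseManager()  # constructor arguments assumed

# Exact-match filtering, unchanged behavior:
exact = db.get_detections({"d.class_name": "nucleus"})

# New LIKE option: a value of the form "like <pattern>" switches the condition
# to "d.class_name LIKE ?" and binds the pattern as the parameter.
fuzzy = db.get_detections({"d.class_name": "like nucleus%"})

Keys without a "d.", "i.", or "m." prefix still default to the detections table, as before.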

View File

@@ -35,9 +35,7 @@ logger = get_logger(__name__)
class ResultsTab(QWidget):
"""Results tab showing detection history and preview overlays."""
def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
super().__init__(parent)
self.db_manager = db_manager
self.config_manager = config_manager
@@ -71,24 +69,12 @@ class ResultsTab(QWidget):
left_layout.addLayout(controls_layout)
self.results_table = QTableWidget(0, 5)
self.results_table.setHorizontalHeaderLabels(
["Image", "Model", "Detections", "Classes", "Last Updated"]
)
self.results_table.horizontalHeader().setSectionResizeMode(
0, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
1, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
2, QHeaderView.ResizeToContents
)
self.results_table.horizontalHeader().setSectionResizeMode(
3, QHeaderView.Stretch
)
self.results_table.horizontalHeader().setSectionResizeMode(
4, QHeaderView.ResizeToContents
)
self.results_table.setHorizontalHeaderLabels(["Image", "Model", "Detections", "Classes", "Last Updated"])
self.results_table.horizontalHeader().setSectionResizeMode(0, QHeaderView.Stretch)
self.results_table.horizontalHeader().setSectionResizeMode(1, QHeaderView.Stretch)
self.results_table.horizontalHeader().setSectionResizeMode(2, QHeaderView.ResizeToContents)
self.results_table.horizontalHeader().setSectionResizeMode(3, QHeaderView.Stretch)
self.results_table.horizontalHeader().setSectionResizeMode(4, QHeaderView.ResizeToContents)
self.results_table.setSelectionBehavior(QAbstractItemView.SelectRows)
self.results_table.setSelectionMode(QAbstractItemView.SingleSelection)
self.results_table.setEditTriggers(QAbstractItemView.NoEditTriggers)
@@ -106,6 +92,8 @@ class ResultsTab(QWidget):
preview_layout = QVBoxLayout()
self.preview_canvas = AnnotationCanvasWidget()
# Auto-zoom so newly loaded images fill the available preview viewport.
self.preview_canvas.set_auto_fit_to_view(True)
self.preview_canvas.set_polyline_enabled(False)
self.preview_canvas.set_show_bboxes(True)
preview_layout.addWidget(self.preview_canvas)
@@ -119,9 +107,7 @@ class ResultsTab(QWidget):
self.show_bboxes_checkbox.stateChanged.connect(self._toggle_bboxes)
self.show_confidence_checkbox = QCheckBox("Show Confidence")
self.show_confidence_checkbox.setChecked(False)
self.show_confidence_checkbox.stateChanged.connect(
self._apply_detection_overlays
)
self.show_confidence_checkbox.stateChanged.connect(self._apply_detection_overlays)
toggles_layout.addWidget(self.show_masks_checkbox)
toggles_layout.addWidget(self.show_bboxes_checkbox)
toggles_layout.addWidget(self.show_confidence_checkbox)
@@ -169,8 +155,7 @@ class ResultsTab(QWidget):
"image_id": det["image_id"],
"model_id": det["model_id"],
"image_path": det.get("image_path"),
"image_filename": det.get("image_filename")
or det.get("image_path"),
"image_filename": det.get("image_filename") or det.get("image_path"),
"model_name": det.get("model_name", ""),
"model_version": det.get("model_version", ""),
"last_detected": det.get("detected_at"),
@@ -183,8 +168,7 @@ class ResultsTab(QWidget):
entry["count"] += 1
if det.get("detected_at") and (
not entry.get("last_detected")
or str(det.get("detected_at")) > str(entry.get("last_detected"))
not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected"))
):
entry["last_detected"] = det.get("detected_at")
if det.get("class_name"):
@@ -214,9 +198,7 @@ class ResultsTab(QWidget):
for row, entry in enumerate(self.detection_summary):
model_label = f"{entry['model_name']} {entry['model_version']}".strip()
class_list = (
", ".join(sorted(entry["classes"])) if entry["classes"] else "-"
)
class_list = ", ".join(sorted(entry["classes"])) if entry["classes"] else "-"
items = [
QTableWidgetItem(entry.get("image_filename", "")),

View File

@@ -10,6 +10,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import yaml
import numpy as np
from PySide6.QtCore import Qt, QThread, Signal
from PySide6.QtWidgets import (
QWidget,
@@ -91,10 +92,7 @@ class TrainingWorker(QThread):
},
}
]
computed_total = sum(
max(0, int((stage.get("params") or {}).get("epochs", 0)))
for stage in self.stage_plan
)
computed_total = sum(max(0, int((stage.get("params") or {}).get("epochs", 0))) for stage in self.stage_plan)
self.total_epochs = total_epochs if total_epochs else computed_total or epochs
self._stop_requested = False
@@ -201,9 +199,7 @@ class TrainingWorker(QThread):
class TrainingTab(QWidget):
"""Training tab for model training."""
def __init__(
self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None
):
def __init__(self, db_manager: DatabaseManager, config_manager: ConfigManager, parent=None):
super().__init__(parent)
self.db_manager = db_manager
self.config_manager = config_manager
@@ -337,18 +333,14 @@ class TrainingTab(QWidget):
self.model_version_edit = QLineEdit("v1")
form_layout.addRow("Version:", self.model_version_edit)
default_base_model = self.config_manager.get(
"models.default_base_model", "yolov8s-seg.pt"
)
default_base_model = self.config_manager.get("models.default_base_model", "yolov8s-seg.pt")
base_model_choices = self.config_manager.get("models.base_model_choices", [])
self.base_model_combo = QComboBox()
self.base_model_combo.addItem("Custom path…", "")
for choice in base_model_choices:
self.base_model_combo.addItem(choice, choice)
self.base_model_combo.currentIndexChanged.connect(
self._on_base_model_preset_changed
)
self.base_model_combo.currentIndexChanged.connect(self._on_base_model_preset_changed)
form_layout.addRow("Base Model Preset:", self.base_model_combo)
base_model_layout = QHBoxLayout()
@@ -434,12 +426,8 @@ class TrainingTab(QWidget):
group_layout = QVBoxLayout()
self.two_stage_checkbox = QCheckBox("Enable staged head-only + full fine-tune")
two_stage_defaults = (
training_defaults.get("two_stage", {}) if training_defaults else {}
)
self.two_stage_checkbox.setChecked(
bool(two_stage_defaults.get("enabled", False))
)
two_stage_defaults = training_defaults.get("two_stage", {}) if training_defaults else {}
self.two_stage_checkbox.setChecked(bool(two_stage_defaults.get("enabled", False)))
self.two_stage_checkbox.toggled.connect(self._on_two_stage_toggled)
group_layout.addWidget(self.two_stage_checkbox)
@@ -501,9 +489,7 @@ class TrainingTab(QWidget):
stage2_group.setLayout(stage2_form)
controls_layout.addWidget(stage2_group)
helper_label = QLabel(
"When enabled, staged hyperparameters override the global epochs/patience/lr."
)
helper_label = QLabel("When enabled, staged hyperparameters override the global epochs/patience/lr.")
helper_label.setWordWrap(True)
controls_layout.addWidget(helper_label)
@@ -548,9 +534,7 @@ class TrainingTab(QWidget):
if normalized == preset_value:
target_index = idx
break
if normalized.endswith(f"/{preset_value}") or normalized.endswith(
f"\\{preset_value}"
):
if normalized.endswith(f"/{preset_value}") or normalized.endswith(f"\\{preset_value}"):
target_index = idx
break
self.base_model_combo.blockSignals(True)
@@ -638,9 +622,7 @@ class TrainingTab(QWidget):
def _browse_dataset(self):
"""Open a file dialog to manually select data.yaml."""
start_dir = self.config_manager.get(
"training.last_dataset_dir", "data/datasets"
)
start_dir = self.config_manager.get("training.last_dataset_dir", "data/datasets")
start_path = Path(start_dir).expanduser()
if not start_path.exists():
start_path = Path.cwd()
@@ -676,9 +658,7 @@ class TrainingTab(QWidget):
return
except Exception as exc:
logger.exception("Unexpected error while generating data.yaml")
self._display_dataset_error(
"Unexpected error while generating data.yaml. Check logs for details."
)
self._display_dataset_error("Unexpected error while generating data.yaml. Check logs for details.")
QMessageBox.critical(
self,
"data.yaml Generation Failed",
@@ -755,13 +735,9 @@ class TrainingTab(QWidget):
self.selected_dataset = info
self.dataset_root_label.setText(info["root"]) # type: ignore[arg-type]
self.train_count_label.setText(
self._format_split_info(info["splits"].get("train"))
)
self.train_count_label.setText(self._format_split_info(info["splits"].get("train")))
self.val_count_label.setText(self._format_split_info(info["splits"].get("val")))
self.test_count_label.setText(
self._format_split_info(info["splits"].get("test"))
)
self.test_count_label.setText(self._format_split_info(info["splits"].get("test")))
self.num_classes_label.setText(str(info["num_classes"]))
class_names = ", ".join(info["class_names"]) or ""
self.class_names_label.setText(class_names)
@@ -815,18 +791,12 @@ class TrainingTab(QWidget):
if split_path.exists():
split_info["count"] = self._count_images(split_path)
if split_info["count"] == 0:
warnings.append(
f"No images found for {split_name} split at {split_path}"
)
warnings.append(f"No images found for {split_name} split at {split_path}")
else:
warnings.append(
f"{split_name.capitalize()} path does not exist: {split_path}"
)
warnings.append(f"{split_name.capitalize()} path does not exist: {split_path}")
else:
if split_name in ("train", "val"):
warnings.append(
f"{split_name.capitalize()} split missing in data.yaml"
)
warnings.append(f"{split_name.capitalize()} split missing in data.yaml")
splits[split_name] = split_info
names_list = self._normalize_class_names(data.get("names"))
@@ -844,9 +814,7 @@ class TrainingTab(QWidget):
if not names_list and nc_value:
names_list = [f"class_{idx}" for idx in range(int(nc_value))]
elif nc_value and len(names_list) not in (0, int(nc_value)):
warnings.append(
f"Number of class names ({len(names_list)}) does not match nc={nc_value}"
)
warnings.append(f"Number of class names ({len(names_list)}) does not match nc={nc_value}")
dataset_name = data.get("name") or base_path.name
@@ -898,16 +866,12 @@ class TrainingTab(QWidget):
class_index_map = self._build_class_index_map(dataset_info)
if not class_index_map:
self._append_training_log(
"Skipping label export: dataset classes do not match database entries."
)
self._append_training_log("Skipping label export: dataset classes do not match database entries.")
return
dataset_root_str = dataset_info.get("root")
dataset_yaml_path = dataset_info.get("yaml_path")
dataset_yaml = (
Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
)
dataset_yaml = Path(dataset_yaml_path).expanduser() if dataset_yaml_path else None
dataset_root: Optional[Path]
if dataset_root_str:
dataset_root = Path(dataset_root_str).resolve()
@@ -941,7 +905,9 @@ class TrainingTab(QWidget):
if stats["registered_images"]:
message += f" {stats['registered_images']} image(s) had database-backed annotations."
if stats["missing_records"]:
message += f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
message += (
f" {stats['missing_records']} image(s) had no database entry; empty label files were written."
)
split_messages.append(message)
for msg in split_messages:
@@ -973,9 +939,7 @@ class TrainingTab(QWidget):
continue
processed_images += 1
label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(
".txt"
)
label_path = (labels_dir / image_file.relative_to(images_dir)).with_suffix(".txt")
label_path.parent.mkdir(parents=True, exist_ok=True)
found, annotation_entries = self._fetch_annotations_for_image(
@@ -991,25 +955,23 @@ class TrainingTab(QWidget):
for entry in annotation_entries:
polygon = entry.get("polygon") or []
if polygon:
print(image_file, polygon[:4], polygon[-2:], entry.get("bbox"))
# coords = " ".join(f"{value:.6f}" for value in entry.get("bbox"))
# coords += " "
coords = " ".join(f"{value:.6f}" for value in polygon)
handle.write(f"{entry['class_idx']} {coords}\n")
annotations_written += 1
elif entry.get("bbox"):
x_center, y_center, width, height = entry["bbox"]
handle.write(
f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n"
)
handle.write(f"{entry['class_idx']} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}\n")
annotations_written += 1
total_annotations += annotations_written
cache_reset_root = labels_dir.parent
self._invalidate_split_cache(cache_reset_root)
if processed_images == 0:
self._append_training_log(
f"[{split_name}] No images found to export labels for."
)
self._append_training_log(f"[{split_name}] No images found to export labels for.")
return None
return {
@@ -1135,6 +1097,10 @@ class TrainingTab(QWidget):
xs.append(x_val)
ys.append(y_val)
if any(np.abs(np.array(coords[:2]) - np.array(coords[-2:])) < 1e-5):
print("Closing polygon")
coords.extend(coords[:2])
if len(coords) < 6:
continue
@@ -1147,6 +1113,11 @@ class TrainingTab(QWidget):
+ abs((min(ys) if ys else 0.0) - y_min)
+ abs((max(ys) if ys else 0.0) - y_max)
)
width = max(0.0, x_max - x_min)
height = max(0.0, y_max - y_min)
x_center = x_min + width / 2.0
y_center = y_min + height / 2.0
score = (x_center, y_center, width, height)
candidates.append((score, coords))
@@ -1164,13 +1135,10 @@ class TrainingTab(QWidget):
return 1.0
return value
def _prepare_dataset_for_training(
self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None
) -> Path:
def _prepare_dataset_for_training(self, dataset_yaml: Path, dataset_info: Optional[Dict[str, Any]] = None) -> Path:
dataset_info = dataset_info or (
self.selected_dataset
if self.selected_dataset
and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_yaml)
else self._parse_dataset_yaml(dataset_yaml)
)
@@ -1189,14 +1157,10 @@ class TrainingTab(QWidget):
cache_root = self._get_rgb_cache_root(dataset_yaml)
rgb_yaml = cache_root / "data.yaml"
if rgb_yaml.exists():
self._append_training_log(
f"Detected grayscale dataset; reusing RGB cache at {cache_root}"
)
self._append_training_log(f"Detected grayscale dataset; reusing RGB cache at {cache_root}")
return rgb_yaml
self._append_training_log(
f"Detected grayscale dataset; creating RGB cache at {cache_root}"
)
self._append_training_log(f"Detected grayscale dataset; creating RGB cache at {cache_root}")
self._build_rgb_dataset(cache_root, dataset_info)
return rgb_yaml
@@ -1463,15 +1427,12 @@ class TrainingTab(QWidget):
dataset_path = Path(dataset_yaml).expanduser()
if not dataset_path.exists():
QMessageBox.warning(
self, "Invalid Dataset", "Selected data.yaml file does not exist."
)
QMessageBox.warning(self, "Invalid Dataset", "Selected data.yaml file does not exist.")
return
dataset_info = (
self.selected_dataset
if self.selected_dataset
and self.selected_dataset.get("yaml_path") == str(dataset_path)
if self.selected_dataset and self.selected_dataset.get("yaml_path") == str(dataset_path)
else self._parse_dataset_yaml(dataset_path)
)
@@ -1480,16 +1441,12 @@ class TrainingTab(QWidget):
dataset_to_use = self._prepare_dataset_for_training(dataset_path, dataset_info)
if dataset_to_use != dataset_path:
self._append_training_log(
f"Using RGB-converted dataset at {dataset_to_use.parent}"
)
self._append_training_log(f"Using RGB-converted dataset at {dataset_to_use.parent}")
params = self._collect_training_params()
stage_plan = self._compose_stage_plan(params)
params["stage_plan"] = stage_plan
total_planned_epochs = (
self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
)
total_planned_epochs = self._calculate_total_stage_epochs(stage_plan) or params["epochs"]
params["total_planned_epochs"] = total_planned_epochs
self._active_training_params = params
self._training_cancelled = False
@@ -1498,9 +1455,7 @@ class TrainingTab(QWidget):
self._append_training_log("Two-stage fine-tuning schedule:")
self._log_stage_plan(stage_plan)
self._append_training_log(
f"Starting training run '{params['run_name']}' using {params['base_model']}"
)
self._append_training_log(f"Starting training run '{params['run_name']}' using {params['base_model']}")
self.training_progress_bar.setVisible(True)
self.training_progress_bar.setMaximum(max(1, total_planned_epochs))
@@ -1528,9 +1483,7 @@ class TrainingTab(QWidget):
def _stop_training(self):
if self.training_worker and self.training_worker.isRunning():
self._training_cancelled = True
self._append_training_log(
"Stop requested. Waiting for the current epoch to finish..."
)
self._append_training_log("Stop requested. Waiting for the current epoch to finish...")
self.training_worker.stop()
self.stop_training_button.setEnabled(False)
@@ -1566,9 +1519,7 @@ class TrainingTab(QWidget):
if worker.isRunning():
if not worker.wait(wait_timeout_ms):
logger.warning(
"Training worker did not finish within %sms", wait_timeout_ms
)
logger.warning("Training worker did not finish within %sms", wait_timeout_ms)
worker.deleteLater()
@@ -1585,16 +1536,12 @@ class TrainingTab(QWidget):
self._set_training_state(False)
self.training_progress_bar.setVisible(False)
def _on_training_progress(
self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]
):
def _on_training_progress(self, current_epoch: int, total_epochs: int, metrics: Dict[str, Any]):
self.training_progress_bar.setMaximum(total_epochs)
self.training_progress_bar.setValue(current_epoch)
parts = [f"Epoch {current_epoch}/{total_epochs}"]
if metrics:
metric_text = ", ".join(
f"{key}: {value:.4f}" for key, value in metrics.items()
)
metric_text = ", ".join(f"{key}: {value:.4f}" for key, value in metrics.items())
parts.append(metric_text)
self._append_training_log(" | ".join(parts))
@@ -1621,9 +1568,7 @@ class TrainingTab(QWidget):
f"Model trained but not registered: {exc}",
)
else:
QMessageBox.information(
self, "Training Complete", "Training finished successfully."
)
QMessageBox.information(self, "Training Complete", "Training finished successfully.")
def _on_training_error(self, message: str):
self._cleanup_training_worker()
@@ -1669,9 +1614,7 @@ class TrainingTab(QWidget):
metrics=results.get("metrics"),
)
self._append_training_log(
f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}"
)
self._append_training_log(f"Registered model '{params['model_name']}' (ID {model_id}) at {model_path}")
self._active_training_params = None
def _set_training_state(self, is_training: bool):
@@ -1714,9 +1657,7 @@ class TrainingTab(QWidget):
def _browse_save_dir(self):
start_path = self.save_dir_edit.text().strip() or "data/models"
directory = QFileDialog.getExistingDirectory(
self, "Select Save Directory", start_path
)
directory = QFileDialog.getExistingDirectory(self, "Select Save Directory", start_path)
if directory:
self.save_dir_edit.setText(directory)
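
The stage-plan epoch accounting in TrainingWorker (first hunk of this file) reduces to the following; a small sketch with illustrative stage values:

stage_plan = [
    {"name": "head-only", "params": {"epochs": 10}},
    {"name": "full fine-tune", "params": {"epochs": 40}},
]
computed_total = sum(
    max(0, int((stage.get("params") or {}).get("epochs", 0))) for stage in stage_plan
)
assert computed_total == 50  # used as total_epochs when none is passed explicitly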

View File

@@ -18,7 +18,7 @@ from PySide6.QtGui import (
QPaintEvent,
QPolygonF,
)
from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QPointF, QRect
from PySide6.QtCore import Qt, QEvent, Signal, QPoint, QPointF, QRect, QTimer
from typing import Any, Dict, List, Optional, Tuple
from src.utils.image import Image, ImageLoadError
@@ -79,9 +79,7 @@ def rdp(points: List[Tuple[float, float]], epsilon: float) -> List[Tuple[float,
return [start, end]
def simplify_polyline(
points: List[Tuple[float, float]], epsilon: float
) -> List[Tuple[float, float]]:
def simplify_polyline(points: List[Tuple[float, float]], epsilon: float) -> List[Tuple[float, float]]:
"""
Simplify a polyline with RDP while preserving closure semantics.
@@ -145,6 +143,10 @@ class AnnotationCanvasWidget(QWidget):
self.zoom_step = 0.1
self.zoom_wheel_step = 0.15
# Auto-fit behavior (opt-in): when enabled, newly loaded images (and resizes)
# will scale to fill the available viewport while preserving aspect ratio.
self._auto_fit_to_view: bool = False
# Drawing / interaction state
self.is_drawing = False
self.polyline_enabled = False
@@ -175,6 +177,35 @@ class AnnotationCanvasWidget(QWidget):
self._setup_ui()
def set_auto_fit_to_view(self, enabled: bool):
"""Enable/disable automatic zoom-to-fit behavior."""
self._auto_fit_to_view = bool(enabled)
if self._auto_fit_to_view and self.original_pixmap is not None:
QTimer.singleShot(0, self.fit_to_view)
def fit_to_view(self, padding_px: int = 6):
"""Zoom the image so it fits the scroll area's viewport (aspect preserved)."""
if self.original_pixmap is None:
return
viewport = self.scroll_area.viewport().size()
available_w = max(1, int(viewport.width()) - int(padding_px))
available_h = max(1, int(viewport.height()) - int(padding_px))
img_w = max(1, int(self.original_pixmap.width()))
img_h = max(1, int(self.original_pixmap.height()))
scale_w = available_w / img_w
scale_h = available_h / img_h
new_scale = min(scale_w, scale_h)
new_scale = max(self.zoom_min, min(self.zoom_max, float(new_scale)))
if abs(new_scale - self.zoom_scale) < 1e-4:
return
self.zoom_scale = new_scale
self._apply_zoom()
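
A worked instance of the computation above: a 2048x1024 image in a 900x700 viewport with the default 6 px padding gives

scale_w = (900 - 6) / 2048   # ≈ 0.4365
scale_h = (700 - 6) / 1024   # ≈ 0.6777
new_scale = min(scale_w, scale_h)   # ≈ 0.4365, then clamped to [zoom_min, zoom_max]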
def _setup_ui(self):
"""Setup user interface."""
layout = QVBoxLayout()
@@ -187,9 +218,7 @@ class AnnotationCanvasWidget(QWidget):
self.canvas_label = QLabel("No image loaded")
self.canvas_label.setAlignment(Qt.AlignCenter)
self.canvas_label.setStyleSheet(
"QLabel { background-color: #2b2b2b; color: #888; }"
)
self.canvas_label.setStyleSheet("QLabel { background-color: #2b2b2b; color: #888; }")
self.canvas_label.setScaledContents(False)
self.canvas_label.setMouseTracking(True)
@@ -212,9 +241,18 @@ class AnnotationCanvasWidget(QWidget):
self.zoom_scale = 1.0
self.clear_annotations()
self._display_image()
logger.debug(
f"Loaded image into annotation canvas: {image.width}x{image.height}"
)
# Defer fit-to-view until the widget has a valid viewport size.
if self._auto_fit_to_view:
QTimer.singleShot(0, self.fit_to_view)
logger.debug(f"Loaded image into annotation canvas: {image.width}x{image.height}")
def resizeEvent(self, event):
"""Optionally keep the image fitted when the widget is resized."""
super().resizeEvent(event)
if self._auto_fit_to_view and self.original_pixmap is not None:
QTimer.singleShot(0, self.fit_to_view)
def clear(self):
"""Clear the displayed image and all annotations."""
@@ -289,22 +327,14 @@ class AnnotationCanvasWidget(QWidget):
scaled_width,
scaled_height,
Qt.KeepAspectRatio,
(
Qt.SmoothTransformation
if self.zoom_scale >= 1.0
else Qt.FastTransformation
),
(Qt.SmoothTransformation if self.zoom_scale >= 1.0 else Qt.FastTransformation),
)
scaled_annotations = self.annotation_pixmap.scaled(
scaled_width,
scaled_height,
Qt.KeepAspectRatio,
(
Qt.SmoothTransformation
if self.zoom_scale >= 1.0
else Qt.FastTransformation
),
(Qt.SmoothTransformation if self.zoom_scale >= 1.0 else Qt.FastTransformation),
)
# Composite image and annotations
@@ -390,16 +420,11 @@ class AnnotationCanvasWidget(QWidget):
y = (pos.y() - offset_y) / self.zoom_scale
# Check bounds
if (
0 <= x < self.original_pixmap.width()
and 0 <= y < self.original_pixmap.height()
):
if 0 <= x < self.original_pixmap.width() and 0 <= y < self.original_pixmap.height():
return (int(x), int(y))
return None
def _find_polyline_at(
self, img_x: float, img_y: float, threshold_px: float = 5.0
) -> Optional[int]:
def _find_polyline_at(self, img_x: float, img_y: float, threshold_px: float = 5.0) -> Optional[int]:
"""
Find index of polyline whose geometry is within threshold_px of (img_x, img_y).
Returns the index in self.polylines, or None if none is close enough.
@@ -421,9 +446,7 @@ class AnnotationCanvasWidget(QWidget):
# Precise distance to all segments
for (x1, y1), (x2, y2) in zip(polyline[:-1], polyline[1:]):
d = perpendicular_distance(
(img_x, img_y), (float(x1), float(y1)), (float(x2), float(y2))
)
d = perpendicular_distance((img_x, img_y), (float(x1), float(y1)), (float(x2), float(y2)))
if d < best_dist:
best_dist = d
best_index = idx
@@ -624,11 +647,7 @@ class AnnotationCanvasWidget(QWidget):
def mouseMoveEvent(self, event: QMouseEvent):
"""Handle mouse move events for drawing."""
if (
not self.is_drawing
or not self.polyline_enabled
or self.annotation_pixmap is None
):
if not self.is_drawing or not self.polyline_enabled or self.annotation_pixmap is None:
super().mouseMoveEvent(event)
return
@@ -688,15 +707,10 @@ class AnnotationCanvasWidget(QWidget):
if len(simplified) >= 2:
# Store polyline and redraw all annotations
self._add_polyline(
simplified, self.polyline_pen_color, self.polyline_pen_width
)
self._add_polyline(simplified, self.polyline_pen_color, self.polyline_pen_width)
# Convert to normalized coordinates for metadata + signal
normalized_stroke = [
self._image_to_normalized_coords(int(x), int(y))
for (x, y) in simplified
]
normalized_stroke = [self._image_to_normalized_coords(int(x), int(y)) for (x, y) in simplified]
self.all_strokes.append(
{
"points": normalized_stroke,
@@ -709,8 +723,7 @@ class AnnotationCanvasWidget(QWidget):
# Emit signal with normalized coordinates
self.annotation_drawn.emit(normalized_stroke)
logger.debug(
f"Completed stroke with {len(simplified)} points "
f"(normalized len={len(normalized_stroke)})"
f"Completed stroke with {len(simplified)} points (normalized len={len(normalized_stroke)})"
)
self.current_stroke = []
@@ -750,9 +763,7 @@ class AnnotationCanvasWidget(QWidget):
# Store polyline as [y_norm, x_norm] to match DB convention and
# the expectations of draw_saved_polyline().
normalized_polyline = [
[y / img_height, x / img_width] for (x, y) in polyline
]
normalized_polyline = [[y / img_height, x / img_width] for (x, y) in polyline]
logger.debug(
f"Polyline {idx}: {len(polyline)} points, "
@@ -772,7 +783,7 @@ class AnnotationCanvasWidget(QWidget):
self,
polyline: List[List[float]],
color: str,
width: int = 3,
width: int = 1,
annotation_id: Optional[int] = None,
):
"""
@@ -810,17 +821,13 @@ class AnnotationCanvasWidget(QWidget):
# Store and redraw using common pipeline
pen_color = QColor(color)
pen_color.setAlpha(128) # Add semi-transparency
pen_color.setAlpha(255) # Fully opaque
self._add_polyline(img_coords, pen_color, width, annotation_id=annotation_id)
# Store in all_strokes for consistency (uses normalized coordinates)
self.all_strokes.append(
{"points": polyline, "color": color, "alpha": 128, "width": width}
)
self.all_strokes.append({"points": polyline, "color": color, "alpha": 255, "width": width})
logger.debug(
f"Drew saved polyline with {len(polyline)} points in color {color}"
)
logger.debug(f"Drew saved polyline with {len(polyline)} points in color {color}")
def draw_saved_bbox(
self,
@@ -844,9 +851,7 @@ class AnnotationCanvasWidget(QWidget):
return
if len(bbox) != 4:
logger.warning(
f"Invalid bounding box format: expected 4 values, got {len(bbox)}"
)
logger.warning(f"Invalid bounding box format: expected 4 values, got {len(bbox)}")
return
# Convert normalized coordinates to image coordinates (for logging/debug)
@@ -867,15 +872,11 @@ class AnnotationCanvasWidget(QWidget):
# in _redraw_annotations() together with all polylines.
pen_color = QColor(color)
pen_color.setAlpha(128) # Add semi-transparency
self.bboxes.append(
[float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)]
)
self.bboxes.append([float(x_min_norm), float(y_min_norm), float(x_max_norm), float(y_max_norm)])
self.bbox_meta.append({"color": pen_color, "width": int(width), "label": label})
# Store in all_strokes for consistency
self.all_strokes.append(
{"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label}
)
self.all_strokes.append({"bbox": bbox, "color": color, "alpha": 128, "width": width, "label": label})
# Redraw overlay (polylines + all bounding boxes)
self._redraw_annotations()
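
A minimal usage sketch of the new opt-in auto-fit; the image-loading call is assumed (only set_auto_fit_to_view() and fit_to_view() are introduced by this change):

canvas = AnnotationCanvasWidget()
canvas.set_auto_fit_to_view(True)   # schedules fit_to_view() via QTimer.singleShot
# canvas.load_image(image)          # loader name assumed; resizes then stay fitted
canvas.fit_to_view(padding_px=6)    # explicit call also works once a pixmap exists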

View File

@@ -96,9 +96,7 @@ class YOLOWrapper:
try:
logger.info(f"Starting training: {name}")
logger.info(
f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}"
)
logger.info(f"Data: {data_yaml}, Epochs: {epochs}, Batch: {batch}, ImgSz: {imgsz}")
# Defaults for 16-bit safety: disable augmentations that force uint8 and HSV ops that assume 0..255.
# Users can override by passing explicit kwargs.
@@ -149,9 +147,7 @@ class YOLOWrapper:
try:
logger.info(f"Starting validation on {split} split")
results = self.model.val(
data=data_yaml, split=split, device=self.device, **kwargs
)
results = self.model.val(data=data_yaml, split=split, device=self.device, **kwargs)
logger.info("Validation completed successfully")
return self._format_validation_results(results)
@@ -190,11 +186,9 @@ class YOLOWrapper:
raise RuntimeError(f"Failed to load model from {self.model_path}")
prepared_source, cleanup_path = self._prepare_source(source)
imgsz = 1088
try:
logger.info(
f"Running inference on {source} -> prepared_source {prepared_source}"
)
logger.info(f"Running inference on {source} -> prepared_source {prepared_source}")
results = self.model.predict(
source=prepared_source,
conf=conf,
@@ -203,6 +197,7 @@ class YOLOWrapper:
save_txt=save_txt,
save_conf=save_conf,
device=self.device,
imgsz=imgsz,
**kwargs,
)
@@ -218,13 +213,9 @@ class YOLOWrapper:
try:
os.remove(cleanup_path)
except OSError as cleanup_error:
logger.warning(
f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}"
)
logger.warning(f"Failed to delete temporary RGB image {cleanup_path}: {cleanup_error}")
def export(
self, format: str = "onnx", output_path: Optional[str] = None, **kwargs
) -> str:
def export(self, format: str = "onnx", output_path: Optional[str] = None, **kwargs) -> str:
"""
Export model to different format.
@@ -265,9 +256,7 @@ class YOLOWrapper:
tmp.close()
img_obj.save(tmp_path)
cleanup_path = tmp_path
logger.info(
f"Converted image {source_path} to RGB for inference at {tmp_path}"
)
logger.info(f"Converted image {source_path} to RGB for inference at {tmp_path}")
return tmp_path, cleanup_path
except Exception as convert_error:
logger.warning(
@@ -280,9 +269,7 @@ class YOLOWrapper:
"""Format training results into dictionary."""
try:
# Get the results dict
results_dict = (
results.results_dict if hasattr(results, "results_dict") else {}
)
results_dict = results.results_dict if hasattr(results, "results_dict") else {}
formatted = {
"success": True,
@@ -315,9 +302,7 @@ class YOLOWrapper:
"mAP50-95": float(box_metrics.map),
"precision": float(box_metrics.mp),
"recall": float(box_metrics.mr),
"fitness": (
float(results.fitness) if hasattr(results, "fitness") else 0.0
),
"fitness": (float(results.fitness) if hasattr(results, "fitness") else 0.0),
}
# Add per-class metrics if available
@@ -327,11 +312,7 @@ class YOLOWrapper:
if idx < len(box_metrics.ap):
class_metrics[name] = {
"ap": float(box_metrics.ap[idx]),
"ap50": (
float(box_metrics.ap50[idx])
if hasattr(box_metrics, "ap50")
else 0.0
),
"ap50": (float(box_metrics.ap50[idx]) if hasattr(box_metrics, "ap50") else 0.0),
}
formatted["class_metrics"] = class_metrics
@@ -364,21 +345,15 @@ class YOLOWrapper:
"class_id": int(boxes.cls[i]),
"class_name": result.names[int(boxes.cls[i])],
"confidence": float(boxes.conf[i]),
"bbox_normalized": [
float(v) for v in xyxyn
], # [x_min, y_min, x_max, y_max]
"bbox_absolute": [
float(v) for v in boxes.xyxy[i].cpu().numpy()
], # Absolute pixels
"bbox_normalized": [float(v) for v in xyxyn], # [x_min, y_min, x_max, y_max]
"bbox_absolute": [float(v) for v in boxes.xyxy[i].cpu().numpy()], # Absolute pixels
}
# Extract segmentation mask if available
if has_masks:
try:
# Get the mask for this detection
mask_data = result.masks.xy[
i
] # Polygon coordinates in absolute pixels
mask_data = result.masks.xy[i] # Polygon coordinates in absolute pixels
# Convert to normalized coordinates
if len(mask_data) > 0:
@@ -391,9 +366,7 @@ class YOLOWrapper:
else:
detection["segmentation_mask"] = None
except Exception as mask_error:
logger.warning(
f"Error extracting mask for detection {i}: {mask_error}"
)
logger.warning(f"Error extracting mask for detection {i}: {mask_error}")
detection["segmentation_mask"] = None
else:
detection["segmentation_mask"] = None
@@ -407,9 +380,7 @@ class YOLOWrapper:
return []
@staticmethod
def convert_bbox_format(
bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy"
) -> List[float]:
def convert_bbox_format(bbox: List[float], format_from: str = "xywh", format_to: str = "xyxy") -> List[float]:
"""
Convert bounding box between formats.
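
Illustrative call of the reflowed static helper, under the usual YOLO xywh = (x_center, y_center, width, height) convention; the conversion body itself is not shown in this hunk:

xyxy = YOLOWrapper.convert_bbox_format(
    [0.5, 0.5, 0.2, 0.4], format_from="xywh", format_to="xyxy"
)
# Expected under that convention: [0.4, 0.3, 0.6, 0.7]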

View File

@@ -37,7 +37,7 @@ def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
a1[a1 > p999] = p999
a1 /= a1.max()
if 0:
if 1:
a2 = a1.copy()
a2 = a2**gamma
a2 /= a2.max()
@@ -47,9 +47,12 @@ def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
a3[a3 > p9999] = p9999
a3 /= a3.max()
return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
# return np.stack([a1, a2, a3], axis=0)
out = np.stack([a1, a2, a3], axis=0)
# print(any(np.isnan(out).flatten()))
return out
class ImageLoadError(Exception):
@@ -122,7 +125,7 @@ class Image:
if self.path.suffix.lower() in [".tif", ".tiff"]:
self._data = imread(str(self.path))
else:
raise NotImplementedError("RGB is not implemented")
# raise NotImplementedError("RGB is not implemented")
# Load with OpenCV (returns BGR format)
self._data = cv2.imread(str(self.path), cv2.IMREAD_UNCHANGED)
@@ -246,20 +249,24 @@ class Image:
if self.channels == 1:
img = get_pseudo_rgb(self.data)
self._dtype = img.dtype
return img
return img, True
elif self._channels == 3:
return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB), False
elif self._channels == 4:
return cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA), False
else:
raise NotImplementedError
if self._channels == 3:
return cv2.cvtColor(self._data, cv2.COLOR_BGR2RGB)
elif self._channels == 4:
return cv2.cvtColor(self._data, cv2.COLOR_BGRA2RGBA)
else:
return self._data
# else:
# return self._data
def get_qt_rgb(self) -> np.ndarray:
# we keep data as (C, H, W)
_img = self.get_rgb()
_img, pseudo = self.get_rgb()
if pseudo:
img = np.zeros((self.height, self.width, 4), dtype=np.float32)
img[..., 0] = _img[0] # R gradient
img[..., 1] = _img[1] # G gradient
@@ -267,6 +274,8 @@ class Image:
img[..., 3] = 1.0 # A = 1.0 (opaque)
return np.ascontiguousarray(img)
else:
return np.ascontiguousarray(_img)
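
Per this hunk, get_rgb() now returns an (array, is_pseudo) pair; a sketch of the consuming pattern, assuming a loaded src.utils.image.Image instance:

rgb, pseudo = image.get_rgb()
if pseudo:
    # Single-channel input: pseudo-RGB planes stacked channel-first (C, H, W).
    r, g, b = rgb[0], rgb[1], rgb[2]
else:
    # 3/4-channel input: cv2-converted array in (H, W, C) order.
    height, width = rgb.shape[:2]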
def get_grayscale(self) -> np.ndarray:
"""

View File

@@ -114,11 +114,12 @@ class Label:
return truth_val
def to_string(self, bbox: list = None, polygon: list = None):
coords = ""
if bbox is None:
bbox = self.bbox
# coords += " ".join([f"{x:.6f}" for x in self.bbox])
if polygon is None:
polygon = self.polygon
coords = " ".join([f"{x:.6f}" for x in self.bbox])
if self.polygon is not None:
coords += " " + " ".join([f"{x:.6f} {y:.6f}" for x, y in self.polygon])
return f"{self.class_id} {coords}"
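
With the bbox serialization commented out, to_string() appears to emit the class id followed by polygon x/y pairs only; a hedged sketch of the resulting YOLO segmentation line:

# "<class_id> <x1> <y1> <x2> <y2> ..." with six-decimal formatting, e.g.
# 0 0.450000 0.300000 0.550000 0.300000 0.550000 0.700000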
@@ -179,6 +180,13 @@ class ImageSplitter:
for i in range(patch_size[0]):
for j in range(patch_size[1]):
metadata = {
"image_path": str(self.image_path),
"label_path": str(self.label_path),
"tile_section": f"{i}, {j}",
"tile_size": f"{hstep}, {wstep}",
"patch_size": f"{patch_size[0]}, {patch_size[1]}",
}
tile_reference = f"i{i}j{j}"
hrange = (i * hstep / h, (i + 1) * hstep / h)
wrange = (j * wstep / w, (j + 1) * wstep / w)
@@ -199,7 +207,7 @@ class ImageSplitter:
print(l.bbox)
# print(labels)
yield tile_reference, tile, labels
yield tile_reference, tile, labels, metadata
def split_respective_to_label(self, padding: int = 67):
if self.labels is None:
@@ -208,6 +216,7 @@ class ImageSplitter:
for i, label in enumerate(self.labels):
tile_reference = f"_lbl-{i+1:02d}"
# print(label.bbox)
metadata = {"image_path": str(self.image_path), "label_path": str(self.label_path), "label_index": str(i)}
xc_norm, yc_norm, h_norm, w_norm = label.bbox # normalized coords
xc, yc, h, w = [
@@ -246,17 +255,17 @@ class ImageSplitter:
# print("tile shape:", tile.shape)
yolo_annotation = f"{label.class_id} {x_offset/nx} {y_offset/ny} {h_norm} {w_norm} "
print(yolo_annotation)
yolo_annotation = f"{label.class_id} " # {x_offset/nx} {y_offset/ny} {h_norm} {w_norm} "
yolo_annotation += " ".join(
[
f"{(x*self.image.shape[1]-(xc - x_offset))/nx:.6f} {(y*self.image.shape[0]-(yc-y_offset))/ny:.6f}"
for x, y in label.polygon
]
)
print(yolo_annotation)
new_label = Label(yolo_annotation=yolo_annotation)
yield tile_reference, tile, [new_label]
yield tile_reference, tile, [new_label], metadata
def main(args):
@@ -278,9 +287,9 @@ def main(args):
else:
data = data.split_into_tiles(patch_size=args.patch_size)
for tile_reference, tile, labels in data:
for tile_reference, tile, labels, metadata in data:
print()
print(tile_reference, tile.shape, labels) # len(labels) if labels else None)
print(tile_reference, tile.shape, labels, metadata) # len(labels) if labels else None)
# { debug
debug = False
@@ -310,15 +319,21 @@ def main(args):
# } debug
if args.output:
imwrite(args.output / "images" / f"{image_path.stem}_{tile_reference}.tif", tile)
# imwrite(args.output / "images" / f"{image_path.stem}_{tile_reference}.tif", tile, metadata=metadata)
scale = 5
tile_zoomed = zoom(tile, zoom=scale)
imwrite(args.output / "images-zoomed" / f"{image_path.stem}_{tile_reference}.tif", tile_zoomed)
metadata["scale"] = scale
imwrite(
args.output / "images" / f"{image_path.stem}_{tile_reference}.tif",
tile_zoomed,
metadata=metadata,
imagej=True,
)
if labels is not None:
with open(args.output / "labels" / f"{image_path.stem}_{tile_reference}.txt", "w") as f:
for label in labels:
label.offset_label(tile.shape[1], tile.shape[0])
# label.offset_label(tile.shape[1], tile.shape[0])
f.write(label.to_string() + "\n")
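
A self-contained sketch of the new per-tile metadata write; tifffile's imwrite is assumed (matching this script's usage), and the paths and stand-in tile are illustrative:

import numpy as np
from tifffile import imwrite

tile_zoomed = np.zeros((335, 335), dtype=np.uint8)  # stand-in for the zoomed tile
metadata = {
    "image_path": "input/img_001.tif",
    "label_path": "input/img_001.txt",
    "label_index": "0",
    "scale": 5,
}
# Output directory assumed to exist; imagej=True embeds the dict in ImageJ metadata.
imwrite("out/images/img_001__lbl-01.tif", tile_zoomed, metadata=metadata, imagej=True)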

View File

@@ -72,8 +72,9 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
# logger.info(f"Loading with monkey-patched imread: {filename}")
arr = arr.astype(np.float32)
arr /= arr.max()
arr *= 2**16 - 1
arr = arr.astype(np.uint16)
arr *= 2**8 - 1
arr = arr.astype(np.uint8)
# print(arr.shape, arr.dtype, any(np.isnan(arr).flatten()), np.where(np.isnan(arr)), arr.min(), arr.max())
return np.ascontiguousarray(arr)
# logger.info(f"Loading with original imread: {filename}")
@@ -105,7 +106,7 @@ def apply_ultralytics_16bit_tiff_patches(*, force: bool = False) -> None:
def preprocess_batch_16bit(self, batch: dict) -> dict: # type: ignore[override]
# Start from upstream behavior to keep device placement + multiscale identical,
# but replace the 255 division with dtype-aware scaling.
logger.info(f"Preprocessing batch with monkey-patched preprocess_batch")
# logger.info(f"Preprocessing batch with monkey-patched preprocess_batch")
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to(self.device, non_blocking=self.device.type == "cuda")
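
The patched imread now normalizes to 8-bit (per the "Using 8bit images" commit); an isolated sketch of that scaling:

import numpy as np

def normalize_to_uint8(arr: np.ndarray) -> np.ndarray:
    arr = arr.astype(np.float32)
    arr /= arr.max()   # scale to 0..1 (assumes arr.max() > 0)
    arr *= 2**8 - 1    # 0..255, replacing the previous 2**16 - 1 / uint16 pair
    return np.ascontiguousarray(arr.astype(np.uint8))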

View File

@@ -196,7 +196,9 @@ def main():
bbox = np.array(bbox) * np.array([img.shape[1], img.shape[0], img.shape[1], img.shape[0]])
yc, xc, h, w = bbox
print("bbox", bbox)
polyline = np.array(coords[4:]).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
# polyline = np.array(coords[4:]).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
polyline = np.array(coords).reshape(-1, 2) * np.array([img.shape[1], img.shape[0]])
print("pl", coords[4:])
print("pl", polyline)
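
With bbox values no longer prefixed to the label line (see the Label.to_string change above), every coordinate is a polygon pair, hence reshaping the full list; illustrative values:

import numpy as np

coords = [0.45, 0.30, 0.55, 0.30, 0.55, 0.70]   # normalized x/y pairs
polyline = np.array(coords).reshape(-1, 2) * np.array([640, 480])   # to pixel coords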
@@ -207,6 +209,7 @@ def main():
plt.figure(figsize=(10, 10 * out.shape[0] / out.shape[1]))
plt.imshow(out_rgb)
plt.plot(polyline[:, 0], polyline[:, 1], "y", linewidth=2)
if 0:
plt.plot(
[yc - h / 2, yc - h / 2, yc + h / 2, yc + h / 2, yc - h / 2],
[xc - w / 2, xc + w / 2, xc + w / 2, xc - w / 2, xc - w / 2],