Major changes to annotation and to results storage/display
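
Adds a detection_runs table plus DatabaseManager helpers (upsert_detection_run,
get_detection_run_summaries, get_detection_run_total) so every per-image per-model
run is recorded, including runs that produced zero detections. The Results tab reads
these summaries, gains Prev/Next pagination, and falls back to the legacy
per-detection aggregation on databases without the new table. The inference engine
always stores a run summary after a run, the image splitter gains an optional
rolling-ball background subtraction (-bg/--subtract-bg), and the default inference
image size drops from 1088 to 640.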
@@ -105,6 +105,103 @@ class DatabaseManager:
         conn.execute("PRAGMA foreign_keys = ON")  # Enable foreign keys
         return conn
 
+    # ==================== Detection Run Operations ====================
+
+    def upsert_detection_run(
+        self,
+        image_id: int,
+        model_id: int,
+        count: int,
+        metadata: Optional[Dict] = None,
+    ) -> bool:
+        """Insert/update a per-image per-model detection run summary.
+
+        This enables the UI to show runs even when zero detections were produced.
+        """
+
+        conn = self.get_connection()
+        try:
+            cursor = conn.cursor()
+            cursor.execute(
+                """
+                INSERT INTO detection_runs (image_id, model_id, detected_at, count, metadata)
+                VALUES (?, ?, CURRENT_TIMESTAMP, ?, ?)
+                ON CONFLICT(image_id, model_id) DO UPDATE SET
+                    detected_at = CURRENT_TIMESTAMP,
+                    count = excluded.count,
+                    metadata = excluded.metadata
+                """,
+                (
+                    int(image_id),
+                    int(model_id),
+                    int(count),
+                    json.dumps(metadata) if metadata else None,
+                ),
+            )
+            conn.commit()
+            return True
+        finally:
+            conn.close()
+
+    def get_detection_run_summaries(self, limit: int = 500, offset: int = 0) -> List[Dict]:
+        """Return latest detection run summaries grouped by image+model.
+
+        Includes runs with 0 detections.
+        """
+
+        conn = self.get_connection()
+        try:
+            cursor = conn.cursor()
+            cursor.execute(
+                """
+                SELECT
+                    dr.image_id,
+                    dr.model_id,
+                    dr.detected_at,
+                    dr.count,
+                    dr.metadata,
+                    i.relative_path AS image_path,
+                    i.filename AS image_filename,
+                    m.model_name,
+                    m.model_version,
+                    GROUP_CONCAT(DISTINCT d.class_name) AS classes
+                FROM detection_runs dr
+                JOIN images i ON dr.image_id = i.id
+                JOIN models m ON dr.model_id = m.id
+                LEFT JOIN detections d
+                    ON d.image_id = dr.image_id AND d.model_id = dr.model_id
+                GROUP BY dr.image_id, dr.model_id
+                ORDER BY dr.detected_at DESC
+                LIMIT ? OFFSET ?
+                """,
+                (int(limit), int(offset)),
+            )
+
+            rows: List[Dict] = []
+            for row in cursor.fetchall():
+                item = dict(row)
+                if item.get("metadata"):
+                    try:
+                        item["metadata"] = json.loads(item["metadata"])
+                    except Exception:
+                        item["metadata"] = None
+                rows.append(item)
+            return rows
+        finally:
+            conn.close()
+
+    def get_detection_run_total(self) -> int:
+        """Return total number of detection_runs rows."""
+
+        conn = self.get_connection()
+        try:
+            cursor = conn.cursor()
+            cursor.execute("SELECT COUNT(*) AS cnt FROM detection_runs")
+            row = cursor.fetchone()
+            return int(row["cnt"] if row and row["cnt"] is not None else 0)
+        finally:
+            conn.close()
+
     # ==================== Model Operations ====================
 
     def add_model(
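A minimal usage sketch of the new run-summary API (hedged: the `db` handle and the
IDs here are hypothetical; the three methods are the ones added above):

    db = DatabaseManager()  # assumption: the default constructor knows the DB path
    db.upsert_detection_run(image_id=1, model_id=2, count=0,
                            metadata={"source_path": "/data/img.tif"})
    db.upsert_detection_run(image_id=1, model_id=2, count=7)  # same key: row updated in place
    page = db.get_detection_run_summaries(limit=50, offset=0)  # newest runs first
    total = db.get_detection_run_total()

Repeated calls for the same (image_id, model_id) pair never create duplicate rows;
the ON CONFLICT clause refreshes detected_at, count, and metadata instead.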
@@ -527,6 +624,14 @@ class DatabaseManager:
         conn = self.get_connection()
         try:
             cursor = conn.cursor()
+            # Also clear detection run summaries so the Results tab does not continue
+            # to show historical runs after detections have been wiped.
+            try:
+                cursor.execute("DELETE FROM detection_runs")
+            except sqlite3.OperationalError:
+                # Backwards-compatible: table may not exist on older DB files.
+                pass
+
             cursor.execute("DELETE FROM detections")
             conn.commit()
             return cursor.rowcount
@@ -45,6 +45,19 @@ CREATE TABLE IF NOT EXISTS detections (
     FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
 );
 
+-- Detection runs table: stores per-image per-model run summaries (including 0 detections)
+CREATE TABLE IF NOT EXISTS detection_runs (
+    id INTEGER PRIMARY KEY AUTOINCREMENT,
+    image_id INTEGER NOT NULL,
+    model_id INTEGER NOT NULL,
+    detected_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+    count INTEGER NOT NULL DEFAULT 0,
+    metadata TEXT,
+    UNIQUE(image_id, model_id),
+    FOREIGN KEY (image_id) REFERENCES images (id) ON DELETE CASCADE,
+    FOREIGN KEY (model_id) REFERENCES models (id) ON DELETE CASCADE
+);
+
 -- Object classes table: stores annotation class definitions with colors
 CREATE TABLE IF NOT EXISTS object_classes (
     id INTEGER PRIMARY KEY AUTOINCREMENT,
@@ -81,6 +94,9 @@ CREATE INDEX IF NOT EXISTS idx_detections_model_id ON detections(model_id);
 CREATE INDEX IF NOT EXISTS idx_detections_class_name ON detections(class_name);
 CREATE INDEX IF NOT EXISTS idx_detections_detected_at ON detections(detected_at);
 CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence);
+CREATE INDEX IF NOT EXISTS idx_detection_runs_image_id ON detection_runs(image_id);
+CREATE INDEX IF NOT EXISTS idx_detection_runs_model_id ON detection_runs(model_id);
+CREATE INDEX IF NOT EXISTS idx_detection_runs_detected_at ON detection_runs(detected_at);
 CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path);
 CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at);
 CREATE INDEX IF NOT EXISTS idx_images_source ON images(source);
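For reference, a self-contained sketch of the upsert semantics that the
UNIQUE(image_id, model_id) constraint enables (plain sqlite3 against an in-memory
DB; ON CONFLICT ... DO UPDATE requires SQLite 3.24+):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE detection_runs ("
        " id INTEGER PRIMARY KEY AUTOINCREMENT,"
        " image_id INTEGER NOT NULL, model_id INTEGER NOT NULL,"
        " count INTEGER NOT NULL DEFAULT 0,"
        " UNIQUE(image_id, model_id))"
    )
    for c in (5, 0):  # the second run overwrites the first instead of adding a row
        conn.execute(
            "INSERT INTO detection_runs (image_id, model_id, count) VALUES (?, ?, ?)"
            " ON CONFLICT(image_id, model_id) DO UPDATE SET count = excluded.count",
            (1, 1, c),
        )
    print(conn.execute("SELECT image_id, model_id, count FROM detection_runs").fetchall())  # [(1, 1, 0)]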
@@ -40,6 +40,12 @@ class ResultsTab(QWidget):
         self.db_manager = db_manager
         self.config_manager = config_manager
 
+        # Pagination
+        self.page_size = 200
+        self.current_page = 0  # 0-based
+        self.total_runs = 0
+        self.total_pages = 0
+
         self.detection_summary: List[Dict] = []
         self.current_selection: Optional[Dict] = None
         self.current_image: Optional[Image] = None
@@ -66,6 +72,20 @@ class ResultsTab(QWidget):
         self.refresh_btn.clicked.connect(self.refresh)
         controls_layout.addWidget(self.refresh_btn)
 
+        self.prev_page_btn = QPushButton("◀ Prev")
+        self.prev_page_btn.setToolTip("Previous page")
+        self.prev_page_btn.clicked.connect(self._prev_page)
+        controls_layout.addWidget(self.prev_page_btn)
+
+        self.next_page_btn = QPushButton("Next ▶")
+        self.next_page_btn.setToolTip("Next page")
+        self.next_page_btn.clicked.connect(self._next_page)
+        controls_layout.addWidget(self.next_page_btn)
+
+        self.page_label = QLabel("Page 1/1")
+        self.page_label.setMinimumWidth(100)
+        controls_layout.addWidget(self.page_label)
+
         self.delete_all_btn = QPushButton("Delete All Detections")
         self.delete_all_btn.setToolTip(
             "Permanently delete ALL detections from the database.\n" "This cannot be undone."
@@ -183,6 +203,8 @@ class ResultsTab(QWidget):
 
     def refresh(self):
         """Refresh the detection list and preview."""
+        # Reset to first page on refresh.
+        self.current_page = 0
         self._load_detection_summary()
         self._populate_results_table()
         self.current_selection = None
@@ -193,57 +215,135 @@ class ResultsTab(QWidget):
         if hasattr(self, "export_labels_btn"):
             self.export_labels_btn.setEnabled(False)
 
+        self._update_pagination_controls()
+
+    def _prev_page(self):
+        """Go to previous results page."""
+        if self.current_page <= 0:
+            return
+        self.current_page -= 1
+        self._load_detection_summary()
+        self._populate_results_table()
+        self._update_pagination_controls()
+
+    def _next_page(self):
+        """Go to next results page."""
+        if self.total_pages and self.current_page >= (self.total_pages - 1):
+            return
+        self.current_page += 1
+        self._load_detection_summary()
+        self._populate_results_table()
+        self._update_pagination_controls()
+
+    def _update_pagination_controls(self):
+        """Update pagination label/button enabled state."""
+
+        # Default state (safe)
+        total_pages = max(int(self.total_pages or 0), 1)
+        current_page = min(max(int(self.current_page or 0), 0), total_pages - 1)
+        self.current_page = current_page
+
+        self.page_label.setText(f"Page {current_page + 1}/{total_pages}")
+        self.prev_page_btn.setEnabled(current_page > 0)
+        self.next_page_btn.setEnabled(current_page < (total_pages - 1))
+
     def _load_detection_summary(self):
         """Load latest detection summaries grouped by image + model."""
         try:
-            detections = self.db_manager.get_detections(limit=500)
-            summary_map: Dict[tuple, Dict] = {}
-
-            for det in detections:
-                key = (det["image_id"], det["model_id"])
-                metadata = det.get("metadata") or {}
-                entry = summary_map.setdefault(
-                    key,
-                    {
-                        "image_id": det["image_id"],
-                        "model_id": det["model_id"],
-                        "image_path": det.get("image_path"),
-                        "image_filename": det.get("image_filename") or det.get("image_path"),
-                        "model_name": det.get("model_name", ""),
-                        "model_version": det.get("model_version", ""),
-                        "last_detected": det.get("detected_at"),
-                        "count": 0,
-                        "classes": set(),
-                        "source_path": metadata.get("source_path"),
-                        "repository_root": metadata.get("repository_root"),
-                    },
-                )
-
-                entry["count"] += 1
-                if det.get("detected_at") and (
-                    not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected"))
-                ):
-                    entry["last_detected"] = det.get("detected_at")
-                if det.get("class_name"):
-                    entry["classes"].add(det["class_name"])
-                if metadata.get("source_path") and not entry.get("source_path"):
-                    entry["source_path"] = metadata.get("source_path")
-                if metadata.get("repository_root") and not entry.get("repository_root"):
-                    entry["repository_root"] = metadata.get("repository_root")
-
-            self.detection_summary = sorted(
-                summary_map.values(),
-                key=lambda x: str(x.get("last_detected") or ""),
-                reverse=True,
-            )
+            # Prefer run summaries (supports zero-detection runs). Fall back to legacy
+            # detection aggregation if the DB/table isn't available.
+            try:
+                self.total_runs = int(self.db_manager.get_detection_run_total())
+            except Exception:
+                self.total_runs = 0
+
+            self.total_pages = (self.total_runs + self.page_size - 1) // self.page_size if self.total_runs > 0 else 1
+
+            offset = int(self.current_page) * int(self.page_size)
+            runs = self.db_manager.get_detection_run_summaries(
+                limit=int(self.page_size),
+                offset=offset,
+            )
+            summary: List[Dict] = []
+            for run in runs:
+                meta = run.get("metadata") or {}
+                classes_raw = run.get("classes")
+                classes = set()
+                if isinstance(classes_raw, str) and classes_raw.strip():
+                    classes = {c.strip() for c in classes_raw.split(",") if c.strip()}
+
+                summary.append(
+                    {
+                        "image_id": run.get("image_id"),
+                        "model_id": run.get("model_id"),
+                        "image_path": run.get("image_path"),
+                        "image_filename": run.get("image_filename") or run.get("image_path"),
+                        "model_name": run.get("model_name", ""),
+                        "model_version": run.get("model_version", ""),
+                        "last_detected": run.get("detected_at"),
+                        "count": int(run.get("count") or 0),
+                        "classes": classes,
+                        "source_path": meta.get("source_path"),
+                        "repository_root": meta.get("repository_root"),
+                    }
+                )
+
+            self.detection_summary = summary
         except Exception as e:
-            logger.error(f"Failed to load detection summary: {e}")
-            QMessageBox.critical(
-                self,
-                "Error",
-                f"Failed to load detection results:\n{str(e)}",
-            )
-            self.detection_summary = []
+            logger.error(f"Failed to load detection run summaries, falling back: {e}")
+            # Disable pagination if we can't page via detection_runs.
+            self.total_runs = 0
+            self.total_pages = 1
+            self.current_page = 0
+            try:
+                detections = self.db_manager.get_detections(limit=int(self.page_size))
+                summary_map: Dict[tuple, Dict] = {}
+
+                for det in detections:
+                    key = (det["image_id"], det["model_id"])
+                    metadata = det.get("metadata") or {}
+                    entry = summary_map.setdefault(
+                        key,
+                        {
+                            "image_id": det["image_id"],
+                            "model_id": det["model_id"],
+                            "image_path": det.get("image_path"),
+                            "image_filename": det.get("image_filename") or det.get("image_path"),
+                            "model_name": det.get("model_name", ""),
+                            "model_version": det.get("model_version", ""),
+                            "last_detected": det.get("detected_at"),
+                            "count": 0,
+                            "classes": set(),
+                            "source_path": metadata.get("source_path"),
+                            "repository_root": metadata.get("repository_root"),
+                        },
+                    )
+
+                    entry["count"] += 1
+                    if det.get("detected_at") and (
+                        not entry.get("last_detected") or str(det.get("detected_at")) > str(entry.get("last_detected"))
+                    ):
+                        entry["last_detected"] = det.get("detected_at")
+                    if det.get("class_name"):
+                        entry["classes"].add(det["class_name"])
+                    if metadata.get("source_path") and not entry.get("source_path"):
+                        entry["source_path"] = metadata.get("source_path")
+                    if metadata.get("repository_root") and not entry.get("repository_root"):
+                        entry["repository_root"] = metadata.get("repository_root")
+
+                self.detection_summary = sorted(
+                    summary_map.values(),
+                    key=lambda x: str(x.get("last_detected") or ""),
+                    reverse=True,
+                )
+            except Exception as inner:
+                logger.error(f"Failed to load detection summary: {inner}")
+                QMessageBox.critical(
+                    self,
+                    "Error",
+                    f"Failed to load detection results:\n{str(inner)}",
+                )
+                self.detection_summary = []
 
     def _populate_results_table(self):
         """Populate the table widget with detection summaries."""
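The paging above is plain ceiling division over the detection_runs row count; a
quick standalone check with hypothetical numbers:

    page_size = 200
    for total_runs in (0, 1, 200, 201, 401):
        total_pages = (total_runs + page_size - 1) // page_size if total_runs > 0 else 1
        print(total_runs, "->", total_pages)  # 0 -> 1, 1 -> 1, 200 -> 1, 201 -> 2, 401 -> 3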
@@ -84,25 +84,19 @@ class InferenceEngine:
 
         # Save detections to database, replacing any previous results for this image/model
         if save_to_db:
-            deleted_count = self.db_manager.delete_detections_for_image(
-                image_id, self.model_id
-            )
+            deleted_count = self.db_manager.delete_detections_for_image(image_id, self.model_id)
             if detections:
                 detection_records = []
                 for det in detections:
                     # Use normalized bbox from detection
-                    bbox_normalized = det[
-                        "bbox_normalized"
-                    ]  # [x_min, y_min, x_max, y_max]
+                    bbox_normalized = det["bbox_normalized"]  # [x_min, y_min, x_max, y_max]
 
                     metadata = {
                         "class_id": det["class_id"],
                         "source_path": str(Path(image_path).resolve()),
                     }
                     if repository_root:
-                        metadata["repository_root"] = str(
-                            Path(repository_root).resolve()
-                        )
+                        metadata["repository_root"] = str(Path(repository_root).resolve())
 
                     record = {
                         "image_id": image_id,
@@ -115,16 +109,27 @@ class InferenceEngine:
                     }
                     detection_records.append(record)
 
-                inserted_count = self.db_manager.add_detections_batch(
-                    detection_records
-                )
-                logger.info(
-                    f"Saved {inserted_count} detections to database (replaced {deleted_count})"
-                )
+                inserted_count = self.db_manager.add_detections_batch(detection_records)
+                logger.info(f"Saved {inserted_count} detections to database (replaced {deleted_count})")
             else:
-                logger.info(
-                    f"Detection run removed {deleted_count} stale entries but produced no new detections"
-                )
+                logger.info(f"Detection run removed {deleted_count} stale entries but produced no new detections")
+
+            # Always store a run summary so the Results tab can show zero-detection runs.
+            try:
+                run_metadata = {
+                    "source_path": str(Path(image_path).resolve()),
+                }
+                if repository_root:
+                    run_metadata["repository_root"] = str(Path(repository_root).resolve())
+                self.db_manager.upsert_detection_run(
+                    image_id=image_id,
+                    model_id=self.model_id,
+                    count=len(detections),
+                    metadata=run_metadata,
+                )
+            except Exception as exc:
+                # Non-fatal: detection records may still be present.
+                logger.warning(f"Failed to store detection run summary: {exc}")
 
         return {
             "success": True,
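Note the run summary is written whether or not the model produced boxes:
count=len(detections) is simply 0 for an empty run, which is what lets the Results
tab list zero-detection runs, and the surrounding try/except keeps a summary-write
failure from breaking the detection path itself.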
@@ -232,9 +237,7 @@ class InferenceEngine:
         for det in detections:
             # Get color for this class
             class_name = det["class_name"]
-            color_hex = bbox_colors.get(
-                class_name, bbox_colors.get("default", "#00FF00")
-            )
+            color_hex = bbox_colors.get(class_name, bbox_colors.get("default", "#00FF00"))
             color = self._hex_to_bgr(color_hex)
 
             # Draw segmentation mask if available and requested
@@ -243,10 +246,7 @@ class InferenceEngine:
             if mask_normalized and len(mask_normalized) > 0:
                 # Convert normalized coordinates to absolute pixels
                 mask_points = np.array(
-                    [
-                        [int(pt[0] * width), int(pt[1] * height)]
-                        for pt in mask_normalized
-                    ],
+                    [[int(pt[0] * width), int(pt[1] * height)] for pt in mask_normalized],
                     dtype=np.int32,
                 )
 
@@ -270,9 +270,7 @@ class InferenceEngine:
             label = f"{class_name} {det['confidence']:.2f}"
 
             # Draw label background
-            (label_w, label_h), baseline = cv2.getTextSize(
-                label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1
-            )
+            (label_w, label_h), baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
             cv2.rectangle(
                 img,
                 (x1, y1 - label_h - baseline - 5),
@@ -186,7 +186,7 @@ class YOLOWrapper:
             raise RuntimeError(f"Failed to load model from {self.model_path}")
 
         prepared_source, cleanup_path = self._prepare_source(source)
-        imgsz = 1088
+        imgsz = 640  # 1088
         try:
             logger.info(f"Running inference on {source} -> prepared_source {prepared_source}")
             results = self.model.predict(
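For context on the imgsz change, a hedged sketch (assuming the wrapper sits on the
Ultralytics YOLO API, which self.model.predict suggests; the weights path is
hypothetical):

    from ultralytics import YOLO

    model = YOLO("weights.pt")  # hypothetical checkpoint
    # imgsz is the side length frames are resized/letterboxed to before inference;
    # 640 is the library default, and 1088 -> 640 cuts pixel count roughly (1088/640)^2 ~ 2.9x.
    results = model.predict("image.tif", imgsz=640)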
@@ -6,6 +6,9 @@ import cv2
 import numpy as np
 from pathlib import Path
 from typing import Optional, Tuple, Union
+from skimage.restoration import rolling_ball
+from skimage.filters import threshold_otsu
+from scipy.ndimage import median_filter, gaussian_filter
 
 from src.utils.logger import get_logger
 from src.utils.file_utils import validate_file_path, is_image_file
@@ -37,6 +40,8 @@ def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
     a1[a1 > p999] = p999
     a1 /= a1.max()
 
+    # print("Using get_pseudo_rgb")
+
     if 1:
         a2 = a1.copy()
         a2 = a2**gamma
@@ -46,14 +51,75 @@ def get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.5) -> np.ndarray:
         p9999 = np.percentile(a3, 99.99)
         a3[a3 > p9999] = p9999
         a3 /= a3.max()
+        out = np.stack([a1, a2, a3], axis=0)
+
+    else:
+        out = np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
+
+    return out
+
     # return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
     # return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
-    out = np.stack([a1, a2, a3], axis=0)
     # print(any(np.isnan(out).flatten()))
 
 
+def _get_pseudo_rgb(arr: np.ndarray, gamma: float = 0.3) -> np.ndarray:
+    """
+    Convert a grayscale image to a pseudo-RGB image using a gamma correction.
+
+    Args:
+        arr: Input grayscale image as numpy array
+
+    Returns:
+        Pseudo-RGB image as numpy array
+    """
+    if arr.ndim != 2:
+        raise ValueError("Input array must be a grayscale image with shape (H, W)")
+
+    radius = 80
+    # bg = rolling_ball(arr, radius=radius)
+    a1 = arr.copy().astype(np.float32)
+    a1 = a1.astype(np.float32)
+    # a1 -= bg
+    # a1[a1 < 0] = 0
+    # a1 -= np.percentile(a1, 2)
+    # a1[a1 < 0] = 0
+    p999 = np.percentile(a1, 99.99)
+    a1[a1 > p999] = p999
+    a1 /= a1.max()
+
+    print("Using get_pseudo_rgb")
+
+    if 1:
+        a2 = a1.copy()
+        _a2 = a2**gamma
+        thr = threshold_otsu(_a2)
+        mask = gaussian_filter((_a2 > thr).astype(np.float32), sigma=5)
+        mask[mask > 0.0001] = 1
+        mask[mask <= 0.0001] = 0
+        a2 *= mask
+
+        # bg2 = rolling_ball(a2, radius=radius)
+        # a2 -= bg2
+        a2 -= np.percentile(a2, 2)
+        a2[a2 < 0] = 0
+        a2 /= a2.max()
+
+        a3 = a1.copy()
+        p9999 = np.percentile(a3, 99.99)
+        a3[a3 > p9999] = p9999
+        a3 /= a3.max()
+        out = np.stack([a1, a2, _a2], axis=0)
+
+    else:
+        out = np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
+
+    return out
+
+    # return np.stack([a1, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
+    # return np.stack([a2, np.zeros(a1.shape), np.zeros(a1.shape)], axis=0)
+    # print(any(np.isnan(out).flatten()))
+
+
 class ImageLoadError(Exception):
     """Exception raised when an image cannot be loaded."""
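A small sanity sketch of what the reworked get_pseudo_rgb returns (random data as a
stand-in input):

    import numpy as np

    arr = np.random.rand(64, 64).astype(np.float32)  # stand-in grayscale frame
    out = get_pseudo_rgb(arr)
    assert out.shape == (3, 64, 64)  # channels-first stack of a1/a2/a3
    assert out.min() >= 0.0 and out.max() <= 1.0  # each channel percentile-clipped and max-normalised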
@@ -330,17 +396,20 @@ class Image:
         return self._channels >= 3
 
     def save(self, path: Union[str, Path], pseudo_rgb: bool = True) -> None:
 
+        # print("Image.save", self.data.shape)
         if self.channels == 1:
+            # print("Image.save grayscale")
             if pseudo_rgb:
                 img = get_pseudo_rgb(self.data)
-                print("Image.save", img.shape)
+                # print("Image.save", img.shape)
 
             else:
                 img = np.repeat(self.data, 3, axis=2)
 
+                # print("Image.save no pseudo", img.shape)
         else:
             raise NotImplementedError("Only grayscale images are supported for now.")
 
+        # print("Image.save imwrite", img.shape)
         imwrite(path, data=img)
 
     def __repr__(self) -> str:
@@ -5,7 +5,7 @@ from tifffile import imread, imwrite
 from shapely.geometry import LineString
 from copy import deepcopy
 from scipy.ndimage import zoom
+from skimage.restoration import rolling_ball
 
 # debug
 from src.utils.image import Image
@@ -160,8 +160,11 @@ class YoloLabelReader:
 
 
 class ImageSplitter:
-    def __init__(self, image_path: Path, label_path: Path):
+    def __init__(self, image_path: Path, label_path: Path, subtract_bg: bool = False):
         self.image = imread(image_path)
+        if subtract_bg:
+            self.image = self.image - rolling_ball(self.image, radius=12)
+            self.image[self.image < 0] = 0
         self.image_path = image_path
         self.label_path = label_path
         if not label_path.exists():
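The new subtract_bg path uses scikit-image's rolling-ball background estimation; a
minimal standalone sketch of the same operation:

    import numpy as np
    from skimage.restoration import rolling_ball

    image = np.random.rand(128, 128).astype(np.float32)  # stand-in for imread(image_path)
    background = rolling_ball(image, radius=12)  # smooth background estimate
    cleaned = image - background
    cleaned[cleaned < 0] = 0  # clamp negatives, as the constructor above does

One caveat worth noting: with unsigned integer input the subtraction itself can wrap
around before the clamp runs, so converting to float first (as in this sketch) is
the safe variant.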
@@ -280,6 +283,7 @@ def main(args):
         data = ImageSplitter(
             image_path=image_path,
             label_path=(args.input / "labels" / image_path.stem).with_suffix(".txt"),
+            subtract_bg=args.subtract_bg,
         )
 
         if args.split_around_label:
@@ -373,6 +377,7 @@ if __name__ == "__main__":
         default=67,
         help="Padding around the label when splitting around the label.",
     )
+    parser.add_argument("-bg", "--subtract-bg", action="store_true", help="Subtract background from the image.")
     args = parser.parse_args()
 
     main(args)