From 3c8247b3bc881329907ea07dc2cb40e8e4d6bbf2 Mon Sep 17 00:00:00 2001 From: Martin Laasmaa Date: Wed, 21 Jan 2026 08:51:39 +0200 Subject: [PATCH] Fixing annotations in database --- src/database/db_manager.py | 140 +++++++++++- src/database/schema.sql | 8 +- src/gui/tabs/annotation_tab.py | 402 ++++++++++++++++++++++++++++++++- src/gui/tabs/training_tab.py | 30 ++- src/utils/image_converters.py | 14 +- src/utils/image_splitter.py | 12 +- 6 files changed, 589 insertions(+), 17 deletions(-) diff --git a/src/database/db_manager.py b/src/database/db_manager.py index 9efa5c2..9bf2f04 100644 --- a/src/database/db_manager.py +++ b/src/database/db_manager.py @@ -40,8 +40,15 @@ class DatabaseManager: conn = self.get_connection() try: - # Check if annotations table needs migration + # Pre-schema migrations. + # These must run BEFORE executing schema.sql because schema.sql may + # contain CREATE INDEX statements referencing newly added columns. + # + # 1) Check if annotations table needs migration (may drop an old table) self._migrate_annotations_table(conn) + # 2) Ensure images table has the required columns (e.g. 'source') + self._migrate_images_table(conn) + conn.commit() # Read schema file and execute schema_path = Path(__file__).parent / "schema.sql" @@ -53,6 +60,19 @@ class DatabaseManager: finally: conn.close() + def _migrate_images_table(self, conn: sqlite3.Connection) -> None: + """Migrate images table to include the 'source' column if missing.""" + + cursor = conn.cursor() + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='images'") + if not cursor.fetchone(): + return + + cursor.execute("PRAGMA table_info(images)") + columns = {row[1] for row in cursor.fetchall()} + if "source" not in columns: + cursor.execute("ALTER TABLE images ADD COLUMN source TEXT") + def _migrate_annotations_table(self, conn: sqlite3.Connection) -> None: """ Migrate annotations table from old schema (class_name) to new schema (class_id). @@ -233,6 +253,7 @@ class DatabaseManager: height: int, captured_at: Optional[datetime] = None, checksum: Optional[str] = None, + source: Optional[str] = None, ) -> int: """ Add a new image to the database. @@ -253,10 +274,10 @@ class DatabaseManager: cursor = conn.cursor() cursor.execute( """ - INSERT INTO images (relative_path, filename, width, height, captured_at, checksum) - VALUES (?, ?, ?, ?, ?, ?) + INSERT INTO images (relative_path, filename, width, height, captured_at, checksum, source) + VALUES (?, ?, ?, ?, ?, ?, ?) """, - (relative_path, filename, width, height, captured_at, checksum), + (relative_path, filename, width, height, captured_at, checksum, source), ) conn.commit() return cursor.lastrowid @@ -286,6 +307,18 @@ class DatabaseManager: return existing["id"] return self.add_image(relative_path, filename, width, height) + def set_image_source(self, image_id: int, source: Optional[str]) -> bool: + """Set/update the source marker for an image row.""" + + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute("UPDATE images SET source = ? WHERE id = ?", (source, int(image_id))) + conn.commit() + return cursor.rowcount > 0 + finally: + conn.close() + # ==================== Detection Operations ==================== def add_detection( @@ -658,6 +691,84 @@ class DatabaseManager: # ==================== Annotation Operations ==================== + def get_images_summary( + self, + name_filter: Optional[str] = None, + source_filter: Optional[str] = None, + order_by: str = "filename", + order_dir: str = "ASC", + limit: Optional[int] = None, + offset: int = 0, + ) -> List[Dict]: + """Return all images with annotation counts (including zero). + + This is used by the Annotation tab to populate the image list even when + no annotations exist yet. + + Args: + name_filter: Optional substring filter applied to filename/relative_path. + order_by: One of: 'filename', 'relative_path', 'annotation_count', 'added_at'. + order_dir: 'ASC' or 'DESC'. + limit: Optional max number of rows. + offset: Pagination offset. + + Returns: + List of dicts: {id, relative_path, filename, added_at, annotation_count} + """ + + allowed_order_by = { + "filename": "i.filename", + "relative_path": "i.relative_path", + "annotation_count": "annotation_count", + "added_at": "i.added_at", + } + order_expr = allowed_order_by.get(order_by, "i.filename") + dir_norm = str(order_dir).upper().strip() + if dir_norm not in {"ASC", "DESC"}: + dir_norm = "ASC" + + conn = self.get_connection() + try: + params: List[Any] = [] + where_clauses: List[str] = [] + if name_filter: + token = f"%{name_filter}%" + where_clauses.append("(i.filename LIKE ? OR i.relative_path LIKE ?)") + params.extend([token, token]) + if source_filter: + where_clauses.append("i.source = ?") + params.append(source_filter) + + where_sql = "" + if where_clauses: + where_sql = "WHERE " + " AND ".join(where_clauses) + + limit_sql = "" + if limit is not None: + limit_sql = " LIMIT ? OFFSET ?" + params.extend([int(limit), int(offset)]) + + query = f""" + SELECT + i.id, + i.relative_path, + i.filename, + i.added_at, + COUNT(a.id) AS annotation_count + FROM images i + LEFT JOIN annotations a ON a.image_id = i.id + {where_sql} + GROUP BY i.id + ORDER BY {order_expr} {dir_norm} + {limit_sql} + """ + + cursor = conn.cursor() + cursor.execute(query, params) + return [dict(row) for row in cursor.fetchall()] + finally: + conn.close() + def get_annotated_images_summary( self, name_filter: Optional[str] = None, @@ -832,6 +943,27 @@ class DatabaseManager: finally: conn.close() + def delete_annotations_for_image(self, image_id: int) -> int: + """Delete ALL annotations for a specific image. + + This is primarily used for import/overwrite workflows. + + Args: + image_id: ID of the image whose annotations should be deleted. + + Returns: + Number of rows deleted. + """ + + conn = self.get_connection() + try: + cursor = conn.cursor() + cursor.execute("DELETE FROM annotations WHERE image_id = ?", (int(image_id),)) + conn.commit() + return int(cursor.rowcount or 0) + finally: + conn.close() + # ==================== Object Class Operations ==================== def get_object_classes(self) -> List[Dict]: diff --git a/src/database/schema.sql b/src/database/schema.sql index 9345465..abef287 100644 --- a/src/database/schema.sql +++ b/src/database/schema.sql @@ -7,7 +7,7 @@ CREATE TABLE IF NOT EXISTS models ( model_name TEXT NOT NULL, model_version TEXT NOT NULL, model_path TEXT NOT NULL, - base_model TEXT NOT NULL DEFAULT 'yolov8s.pt', + base_model TEXT NOT NULL DEFAULT 'yolo11s.pt', created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, training_params TEXT, -- JSON string of training parameters metrics TEXT, -- JSON string of validation metrics @@ -23,7 +23,8 @@ CREATE TABLE IF NOT EXISTS images ( height INTEGER, captured_at TIMESTAMP, added_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - checksum TEXT + checksum TEXT, + source TEXT ); -- Detections table: stores detection results @@ -82,7 +83,8 @@ CREATE INDEX IF NOT EXISTS idx_detections_detected_at ON detections(detected_at) CREATE INDEX IF NOT EXISTS idx_detections_confidence ON detections(confidence); CREATE INDEX IF NOT EXISTS idx_images_relative_path ON images(relative_path); CREATE INDEX IF NOT EXISTS idx_images_added_at ON images(added_at); +CREATE INDEX IF NOT EXISTS idx_images_source ON images(source); CREATE INDEX IF NOT EXISTS idx_annotations_image_id ON annotations(image_id); CREATE INDEX IF NOT EXISTS idx_annotations_class_id ON annotations(class_id); CREATE INDEX IF NOT EXISTS idx_models_created_at ON models(created_at); -CREATE INDEX IF NOT EXISTS idx_object_classes_class_name ON object_classes(class_name); \ No newline at end of file +CREATE INDEX IF NOT EXISTS idx_object_classes_class_name ON object_classes(class_name); diff --git a/src/gui/tabs/annotation_tab.py b/src/gui/tabs/annotation_tab.py index 9520a5e..3e16a92 100644 --- a/src/gui/tabs/annotation_tab.py +++ b/src/gui/tabs/annotation_tab.py @@ -135,11 +135,27 @@ class AnnotationTab(QWidget): load_group = QGroupBox("Image Loading") load_layout = QVBoxLayout() - # Load image button + # Buttons row button_layout = QHBoxLayout() self.load_image_btn = QPushButton("Load Image") self.load_image_btn.clicked.connect(self._load_image) button_layout.addWidget(self.load_image_btn) + + self.import_images_btn = QPushButton("Import Images") + self.import_images_btn.setToolTip( + "Import one or more images into the database.\n" "Images already present in the DB are skipped." + ) + self.import_images_btn.clicked.connect(self._import_images) + button_layout.addWidget(self.import_images_btn) + + self.import_annotations_btn = QPushButton("Import Annotations") + self.import_annotations_btn.setToolTip( + "Import YOLO .txt annotation files and register them with their corresponding images.\n" + "Existing annotations for those images will be overwritten." + ) + self.import_annotations_btn.clicked.connect(self._import_annotations) + button_layout.addWidget(self.import_annotations_btn) + button_layout.addStretch() load_layout.addLayout(button_layout) @@ -168,6 +184,372 @@ class AnnotationTab(QWidget): # Populate list on startup. self._refresh_annotated_images_list() + def _import_images(self) -> None: + """Import one or more images into the database and refresh the list.""" + + settings = QSettings("microscopy_app", "object_detection") + last_dir = settings.value("annotation_tab/last_image_import_directory", None) + + repo_root = (self.config_manager.get_image_repository_path() or "").strip() + if last_dir and Path(str(last_dir)).exists(): + start_dir = str(last_dir) + elif repo_root and Path(repo_root).exists(): + start_dir = repo_root + else: + start_dir = str(Path.home()) + + # Build filter string for supported extensions + patterns = " ".join(f"*{ext}" for ext in Image.SUPPORTED_EXTENSIONS) + file_paths, _ = QFileDialog.getOpenFileNames( + self, + "Select Image File(s)", + start_dir, + f"Images ({patterns})", + ) + + if not file_paths: + return + + try: + settings.setValue("annotation_tab/last_image_import_directory", str(Path(file_paths[0]).parent)) + # Keep compatibility with the existing image resolver fallback (it checks last_directory). + settings.setValue("annotation_tab/last_directory", str(Path(file_paths[0]).parent)) + except Exception: + pass + + imported = 0 + tagged_existing = 0 + skipped = 0 + errors: list[str] = [] + + for fp in file_paths: + try: + img_path = Path(fp) + img = Image(str(img_path)) + relative_path = self._compute_relative_path_for_repo(img_path) + + # Skip if already present + existing = self.db_manager.get_image_by_path(relative_path) + if existing: + # If the image already exists (e.g. created earlier by other workflows), + # tag it as being managed by the Annotation tab so it becomes visible + # in the left list. + try: + self.db_manager.set_image_source(int(existing["id"]), "annotation_tab") + tagged_existing += 1 + except Exception: + # If tagging fails, fall back to treating as skipped. + skipped += 1 + continue + + image_id = self.db_manager.add_image( + relative_path, + img_path.name, + img.width, + img.height, + source="annotation_tab", + ) + try: + # In case the DB row was created by an older schema/migration path. + self.db_manager.set_image_source(image_id, "annotation_tab") + except Exception: + pass + imported += 1 + except ImageLoadError as exc: + skipped += 1 + errors.append(f"Failed to load image {fp}: {exc}") + except Exception as exc: + skipped += 1 + errors.append(f"Failed to import image {fp}: {exc}") + + self._refresh_annotated_images_list(select_image_id=self.current_image_id) + + msg = ( + f"Imported: {imported}\n" + f"Already in DB (tagged for Annotation tab): {tagged_existing}\n" + f"Skipped (errors): {skipped}" + ) + if errors: + details = "\n".join(errors[:25]) + if len(errors) > 25: + details += f"\n... and {len(errors) - 25} more" + msg += "\n\nDetails:\n" + details + QMessageBox.information(self, "Import Images", msg) + + # ==================== Import annotations (YOLO .txt) ==================== + + def _import_annotations(self) -> None: + """Import YOLO segmentation/bbox annotations from one or more .txt files.""" + + settings = QSettings("microscopy_app", "object_detection") + last_dir = settings.value("annotation_tab/last_annotation_directory", None) + + # Default start dir: repo root if set, otherwise last used, otherwise home. + repo_root = (self.config_manager.get_image_repository_path() or "").strip() + if last_dir and Path(str(last_dir)).exists(): + start_dir = str(last_dir) + elif repo_root and Path(repo_root).exists(): + start_dir = repo_root + else: + start_dir = str(Path.home()) + + file_paths, _ = QFileDialog.getOpenFileNames( + self, + "Select YOLO Annotation File(s)", + start_dir, + "YOLO annotations (*.txt)", + ) + + if not file_paths: + return + + # Persist last annotation directory for the next import. + try: + settings.setValue("annotation_tab/last_annotation_directory", str(Path(file_paths[0]).parent)) + except Exception: + pass + + imported_images = 0 + imported_annotations = 0 + overwritten_images = 0 + skipped = 0 + errors: list[str] = [] + + for label_file in file_paths: + label_path = Path(label_file) + try: + image_path = self._infer_corresponding_image_path(label_path) + if not image_path: + skipped += 1 + errors.append(f"Image not found for label file: {label_path}") + continue + + # Load image to obtain width/height for DB entry. + img = Image(str(image_path)) + + # Store in DB using a repo-relative path if possible. + relative_path = self._compute_relative_path_for_repo(image_path) + image_id = self.db_manager.get_or_create_image(relative_path, image_path.name, img.width, img.height) + try: + self.db_manager.set_image_source(image_id, "annotation_tab") + except Exception: + pass + + # Overwrite existing annotations for this image. + try: + deleted = self.db_manager.delete_annotations_for_image(image_id) + except AttributeError: + # Safety fallback if older DBManager is used. + deleted = 0 + if deleted > 0: + overwritten_images += 1 + + # Parse YOLO lines and insert as annotations. + parsed = self._parse_yolo_annotation_file(label_path) + if not parsed: + # Empty/invalid label file: treat as "clear" operation (already deleted above) + imported_images += 1 + continue + + db_classes = self.db_manager.get_object_classes() or [] + classes_by_index = {idx: row for idx, row in enumerate(db_classes)} + + for class_index, bbox, poly in parsed: + class_row = classes_by_index.get(int(class_index)) + if not class_row: + skipped += 1 + errors.append( + f"Unknown class index {class_index} in {label_path.name}. " + "Create object classes in the UI first (class index is based on DB ordering)." + ) + continue + + ann_id = self.db_manager.add_annotation( + image_id=image_id, + class_id=int(class_row["id"]), + bbox=bbox, + annotator="import", + segmentation_mask=poly, + verified=False, + ) + if ann_id: + imported_annotations += 1 + + imported_images += 1 + + # If we imported for the currently open image, reload. + if self.current_image_id and int(self.current_image_id) == int(image_id): + self._load_annotations_for_current_image() + + except ImageLoadError as exc: + skipped += 1 + errors.append(f"Failed to load image for {label_path.name}: {exc}") + except Exception as exc: + skipped += 1 + errors.append(f"Import failed for {label_path.name}: {exc}") + + # Refresh annotated images list. + self._refresh_annotated_images_list(select_image_id=self.current_image_id) + + summary = ( + f"Imported files: {len(file_paths)}\n" + f"Images processed: {imported_images}\n" + f"Annotations inserted: {imported_annotations}\n" + f"Images overwritten (had existing annotations): {overwritten_images}\n" + f"Skipped: {skipped}" + ) + if errors: + # Cap error details to avoid huge dialogs. + details = "\n".join(errors[:25]) + if len(errors) > 25: + details += f"\n... and {len(errors) - 25} more" + summary += "\n\nDetails:\n" + details + + QMessageBox.information(self, "Import Annotations", summary) + + def _infer_corresponding_image_path(self, label_path: Path) -> Path | None: + """Infer image path from YOLO label file path. + + Requirement: image(s) live in an `images/` folder located in the label file's parent directory. + Example: + /dataset/train/labels/img123.txt -> /dataset/train/images/img123.(any supported ext) + """ + + parent = label_path.parent + images_dir = parent.parent / "images" + stem = label_path.stem + + # 1) Direct stem match in images dir (any supported extension) + for ext in Image.SUPPORTED_EXTENSIONS: + candidate = images_dir / f"{stem}{ext}" + if candidate.exists() and candidate.is_file(): + return candidate + + # 2) Fallback: repository-root search by filename + repo_root = (self.config_manager.get_image_repository_path() or "").strip() + if repo_root: + root = Path(repo_root).expanduser() + try: + if root.exists(): + for ext in Image.SUPPORTED_EXTENSIONS: + filename = f"{stem}{ext}" + for match in root.rglob(filename): + if match.is_file(): + return match.resolve() + except Exception: + pass + + return None + + def _compute_relative_path_for_repo(self, image_path: Path) -> str: + """Compute a stable `relative_path` suitable for DB storage. + + Policy: + - If an image repository root is configured and the image is under it, store a repo-relative path. + - Otherwise, store an absolute resolved path so the image can be reopened later. + """ + + repo_root = (self.config_manager.get_image_repository_path() or "").strip() + try: + if repo_root: + repo_root_path = Path(repo_root).expanduser().resolve() + img_resolved = image_path.expanduser().resolve() + if img_resolved.is_relative_to(repo_root_path): + return img_resolved.relative_to(repo_root_path).as_posix() + except Exception: + pass + try: + return str(image_path.expanduser().resolve()) + except Exception: + return str(image_path) + + def _parse_yolo_annotation_file( + self, label_path: Path + ) -> list[tuple[int, tuple[float, float, float, float], list[list[float]] | None]]: + """Parse a YOLO .txt label file. + + Supports: + - YOLO segmentation polygon format: "class x1 y1 x2 y2 ..." (normalized) + - YOLO bbox format: "class x_center y_center width height" (normalized) + + Returns: + List of (class_index, bbox_xyxy_norm, segmentation_mask_db) + Where segmentation_mask_db is [[y_norm, x_norm], ...] or None. + """ + + out: list[tuple[int, tuple[float, float, float, float], list[list[float]] | None]] = [] + try: + raw = label_path.read_text(encoding="utf-8").splitlines() + except OSError as exc: + logger.error(f"Failed to read label file {label_path}: {exc}") + return out + + for line in raw: + stripped = line.strip() + if not stripped: + continue + parts = stripped.split() + if len(parts) < 5: + # not enough for bbox + continue + + try: + class_idx = int(float(parts[0])) + coords = [float(x) for x in parts[1:]] + except Exception: + continue + + # Segmentation polygon format (>= 6 values) + if len(coords) >= 6: + # bbox is not explicitly present in this format in our importer; compute from polygon. + xs = coords[0::2] + ys = coords[1::2] + if not xs or not ys: + continue + x_min, x_max = min(xs), max(xs) + y_min, y_max = min(ys), max(ys) + bbox = ( + self._clamp01(x_min), + self._clamp01(y_min), + self._clamp01(x_max), + self._clamp01(y_max), + ) + + # Convert to DB polyline convention: [[y_norm, x_norm], ...] + poly: list[list[float]] = [] + for x, y in zip(xs, ys): + poly.append([self._clamp01(float(y)), self._clamp01(float(x))]) + # Ensure closure for consistency (optional) + if poly and poly[0] != poly[-1]: + poly.append(list(poly[0])) + out.append((class_idx, bbox, poly)) + continue + + # bbox format: xc yc w h + if len(coords) >= 4: + xc, yc, w, h = coords[:4] + x_min = xc - w / 2.0 + y_min = yc - h / 2.0 + x_max = xc + w / 2.0 + y_max = yc + h / 2.0 + bbox = ( + self._clamp01(float(x_min)), + self._clamp01(float(y_min)), + self._clamp01(float(x_max)), + self._clamp01(float(y_max)), + ) + out.append((class_idx, bbox, None)) + + return out + + @staticmethod + def _clamp01(value: float) -> float: + if value < 0.0: + return 0.0 + if value > 1.0: + return 1.0 + return float(value) + def _load_image(self): """Load and display an image file.""" # Get last opened directory from QSettings @@ -222,6 +604,11 @@ class AnnotationTab(QWidget): self.current_image.width, self.current_image.height, ) + # Mark as managed by Annotation tab so it appears in the left list. + try: + self.db_manager.set_image_source(int(self.current_image_id), "annotation_tab") + except Exception: + pass # Display image using the AnnotationCanvasWidget self.annotation_canvas.load_image(self.current_image) @@ -596,9 +983,9 @@ class AnnotationTab(QWidget): name_filter = self.annotated_filter_edit.text().strip() try: - rows = self.db_manager.get_annotated_images_summary(name_filter=name_filter) + rows = self.db_manager.get_images_summary(name_filter=name_filter, source_filter="annotation_tab") except Exception as exc: - logger.error(f"Failed to load annotated images summary: {exc}") + logger.error(f"Failed to load images summary: {exc}") rows = [] sorting_enabled = self.annotated_images_table.isSortingEnabled() @@ -719,6 +1106,15 @@ class AnnotationTab(QWidget): except Exception: pass + # 4b) Try the last directory used by the image import picker + try: + settings = QSettings("microscopy_app", "object_detection") + last_import_dir = settings.value("annotation_tab/last_image_import_directory", None) + if last_import_dir: + candidates.append(Path(str(last_import_dir)) / Path(rel).name) + except Exception: + pass + for p in candidates: try: expanded = p.expanduser() diff --git a/src/gui/tabs/training_tab.py b/src/gui/tabs/training_tab.py index d6bdbd0..7041fed 100644 --- a/src/gui/tabs/training_tab.py +++ b/src/gui/tabs/training_tab.py @@ -905,9 +905,14 @@ class TrainingTab(QWidget): if stats["registered_images"]: message += f" {stats['registered_images']} image(s) had database-backed annotations." if stats["missing_records"]: - message += ( - f" {stats['missing_records']} image(s) had no database entry; empty label files were written." - ) + preserved = stats.get("preserved_existing_labels", 0) + if preserved: + message += ( + f" {stats['missing_records']} image(s) had no database annotations; " + f"preserved {preserved} existing label file(s) (no overwrite)." + ) + else: + message += f" {stats['missing_records']} image(s) had no database annotations; empty label files were written." split_messages.append(message) for msg in split_messages: @@ -929,6 +934,7 @@ class TrainingTab(QWidget): processed_images = 0 registered_images = 0 missing_records = 0 + preserved_existing_labels = 0 total_annotations = 0 for image_file in images_dir.rglob("*"): @@ -950,6 +956,12 @@ class TrainingTab(QWidget): else: missing_records += 1 + # If the database has no entry for this image, do not overwrite an existing label file + # with an empty one (preserve any manually created labels on disk). + if not found and label_path.exists(): + preserved_existing_labels += 1 + continue + annotations_written = 0 with open(label_path, "w", encoding="utf-8") as handle: for entry in annotation_entries: @@ -979,6 +991,7 @@ class TrainingTab(QWidget): "processed_images": processed_images, "registered_images": registered_images, "missing_records": missing_records, + "preserved_existing_labels": preserved_existing_labels, "total_annotations": total_annotations, } @@ -1008,6 +1021,10 @@ class TrainingTab(QWidget): resolved_image = image_path.resolve() candidates: List[str] = [] + # Some DBs store absolute paths in `images.relative_path`. + # Include the absolute resolved path as a lookup candidate. + candidates.append(resolved_image.as_posix()) + for base in (dataset_root, images_dir): try: relative = resolved_image.relative_to(base.resolve()).as_posix() @@ -1032,6 +1049,13 @@ class TrainingTab(QWidget): return False, [] annotations = self.db_manager.get_annotations_for_image(image_row["id"]) or [] + + # Treat "found" as "has database-backed annotations". + # If the image exists in DB but has no annotations yet, we don't want to overwrite + # an existing label file on disk with an empty one. + if not annotations: + return False, [] + yolo_entries: List[Dict[str, Any]] = [] for ann in annotations: diff --git a/src/utils/image_converters.py b/src/utils/image_converters.py index 93e18bd..094c192 100644 --- a/src/utils/image_converters.py +++ b/src/utils/image_converters.py @@ -97,7 +97,7 @@ class UT: class_index: int = 0, ): """Export rois to a file""" - with open(path / subfolder / f"{self.stem}.txt", "w") as f: + with open(path / subfolder / f"{PREFIX}-{self.stem}.txt", "w") as f: for i, roi in enumerate(self.rois): rc = roi.subpixel_coordinates if rc is None: @@ -129,8 +129,8 @@ class UT: self.image = np.max(self.image[channel], axis=0) print(self.image.shape) - print(path / subfolder / f"{self.stem}.tif") - with TiffWriter(path / subfolder / f"{self.stem}.tif") as tif: + print(path / subfolder / f"{PREFIX}_{self.stem}.tif") + with TiffWriter(path / subfolder / f"{PREFIX}-{self.stem}.tif") as tif: tif.write(self.image) @@ -145,11 +145,19 @@ if __name__ == "__main__": action="store_false", help="Source does not have labels, export only images", ) + parser.add_argument("--prefix", help="Prefix for output files") args = parser.parse_args() + PREFIX = args.prefix # print(args) # aa + # for path in args.input: + # print(path) + # ut = UT(path, args.no_labels) + # ut.export_image(args.output, plane_mode="max projection", channel=0) + # ut.export_rois(args.output, class_index=0) + for path in args.input: print("Path:", path) if not args.no_labels: diff --git a/src/utils/image_splitter.py b/src/utils/image_splitter.py index 0cf22f9..bef3941 100644 --- a/src/utils/image_splitter.py +++ b/src/utils/image_splitter.py @@ -273,7 +273,7 @@ def main(args): if args.output: args.output.mkdir(exist_ok=True, parents=True) (args.output / "images").mkdir(exist_ok=True) - (args.output / "images-zoomed").mkdir(exist_ok=True) + # (args.output / "images-zoomed").mkdir(exist_ok=True) (args.output / "labels").mkdir(exist_ok=True) for image_path in (args.input / "images").glob("*.tif"): @@ -332,10 +332,20 @@ def main(args): if labels is not None: with open(args.output / "labels" / f"{image_path.stem}_{tile_reference}.txt", "w") as f: + print( + f"Writing {len(labels)} labels to {args.output / 'labels' / f'{image_path.stem}_{tile_reference}.txt'}" + ) for label in labels: # label.offset_label(tile.shape[1], tile.shape[0]) f.write(label.to_string() + "\n") + # { debug + if debug: + print(label.to_string()) + # } debug + # break + # break + if __name__ == "__main__": import argparse