Segmentation Plugin Example
This example demonstrates how to use the segmentation plugin system in ZarrNii. The plugin system is built on the pluggy framework, providing a flexible and extensible architecture for implementing custom segmentation algorithms.
Plugin Architecture Overview
ZarrNii's segmentation plugins use pluggy hook specifications and implementations:
- Hook specifications define the interface that all plugins must implement
- Hook implementations (using @hookimpl decorator) provide the actual functionality
- Plugins can be used directly or registered with the plugin manager for discovery
Basic Otsu Segmentation
import numpy as np
import dask.array as da
from zarrnii import ZarrNii, OtsuSegmentation
# Load or create your image data
# For this example, we'll create synthetic bimodal data
np.random.seed(42)
image_data = np.random.normal(0.2, 0.05, (1, 50, 100, 100)) # Background
image_data[0, 20:30, 40:60, 40:60] = np.random.normal(0.8, 0.05, (10, 20, 20)) # Foreground
# Create ZarrNii instance
darr = da.from_array(image_data, chunks=(1, 25, 50, 50))
znimg = ZarrNii.from_darr(darr, axes_order="ZYX", orientation="RAS")
# Method 1: Using the convenience method
segmented = znimg.segment_otsu(nbins=256)
# Method 2: Using the plugin directly
plugin = OtsuSegmentation(nbins=256)
segmented = znimg.segment(plugin)
# Method 3: Using plugin class with parameters
segmented = znimg.segment(OtsuSegmentation, nbins=128)
# The result is a new ZarrNii instance with binary segmentation
print(f"Original shape: {znimg.shape}")
print(f"Segmented shape: {segmented.shape}")
print(f"Segmented dtype: {segmented.data.dtype}")
print(f"Unique values: {np.unique(segmented.data.compute())}")
# Save segmented result as OME-Zarr
segmented.to_ome_zarr("segmented_image.ome.zarr")
Custom Chunk Processing
For large datasets, you can control the chunk size for blockwise processing:
# Segment with custom chunk size for memory efficiency
custom_chunks = (1, 10, 25, 25)
segmented = znimg.segment_otsu(chunk_size=custom_chunks)
# The segmentation will be applied block-wise using dask
result_data = segmented.data.compute()
Creating Custom Segmentation Plugins
You can create your own segmentation plugins by inheriting from SegmentationPlugin and implementing the required hook methods with the @hookimpl decorator:
from zarrnii.plugins import SegmentationPlugin
from zarrnii.plugins.segmentation.base import hookimpl
import numpy as np
class ThresholdSegmentation(SegmentationPlugin):
"""Simple threshold-based segmentation plugin."""
def __init__(self, threshold: float = 0.5, **kwargs):
"""Initialize the plugin.
Args:
threshold: Threshold value for segmentation
"""
super().__init__(threshold=threshold, **kwargs)
self.threshold = threshold
@hookimpl
def segment(self, image: np.ndarray, metadata=None) -> np.ndarray:
"""Apply threshold segmentation.
Args:
image: Input image as numpy array
metadata: Optional metadata dictionary
Returns:
Binary segmentation mask
"""
binary_mask = image > self.threshold
return binary_mask.astype(np.uint8)
@hookimpl
def segmentation_plugin_name(self) -> str:
"""Return the name of the plugin."""
return "Threshold Segmentation"
@hookimpl
def segmentation_plugin_description(self) -> str:
"""Return a description of the plugin."""
return f"Simple thresholding at value {self.threshold}"
# Method 1: Direct usage with ZarrNii (recommended for simple cases)
custom_plugin = ThresholdSegmentation(threshold=0.3)
segmented = znimg.segment(custom_plugin)
# Method 2: Using the plugin manager (recommended for external plugins)
from zarrnii.plugins import get_plugin_manager
pm = get_plugin_manager()
plugin = ThresholdSegmentation(threshold=0.3)
pm.register(plugin)
# Now the plugin is available through the plugin manager
# and can be discovered by other tools
registered_plugins = pm.get_plugins()
print(f"Registered {len(registered_plugins)} plugins")
External Segmentation Plugin Development
External plugins allow you to package and distribute your custom segmentation algorithms as separate Python packages. Here's a complete example:
Step 1: Create Your Plugin Package Structure
my_segmentation_plugin/
├── pyproject.toml
└── my_segmentation_plugin/
├── __init__.py
└── adaptive_threshold.py
Step 2: Implement Your Segmentation Plugin
In adaptive_threshold.py:
"""Adaptive threshold segmentation plugin."""
from zarrnii.plugins import SegmentationPlugin
from zarrnii.plugins.segmentation.base import hookimpl
import numpy as np
from skimage.filters import threshold_local
class AdaptiveThresholdSegmentation(SegmentationPlugin):
"""Adaptive thresholding segmentation for images with varying illumination."""
def __init__(self, block_size=35, offset=10, method='gaussian', **kwargs):
"""Initialize adaptive threshold segmentation.
Args:
block_size: Size of pixel neighborhood for threshold calculation
offset: Constant subtracted from weighted mean
method: Method for computing threshold ('gaussian' or 'mean')
"""
super().__init__(
block_size=block_size,
offset=offset,
method=method,
**kwargs
)
self.block_size = block_size
self.offset = offset
self.method = method
@hookimpl
def segment(self, image: np.ndarray, metadata=None) -> np.ndarray:
"""Segment image using adaptive thresholding.
Args:
image: Input image as numpy array
metadata: Optional metadata (unused)
Returns:
Binary segmentation mask
"""
if image.size == 0:
raise ValueError("Input image is empty")
if image.ndim < 2:
raise ValueError("Input image must be at least 2D")
# Work with 2D slice for threshold computation
work_image = image
if work_image.ndim > 2:
# Use first channel/slice if multi-dimensional
if work_image.ndim == 3 and work_image.shape[0] <= 4:
work_image = work_image[0]
elif work_image.ndim > 3:
work_image = work_image.reshape(-1, *work_image.shape[-2:])[0]
# Compute adaptive threshold
threshold = threshold_local(
work_image,
block_size=self.block_size,
offset=self.offset,
method=self.method
)
# Apply threshold to original image
binary_mask = image > threshold
return binary_mask.astype(np.uint8)
@hookimpl
def segmentation_plugin_name(self) -> str:
"""Return the name of the plugin."""
return "Adaptive Threshold Segmentation"
@hookimpl
def segmentation_plugin_description(self) -> str:
"""Return a description of the plugin."""
return (
f"Adaptive thresholding using {self.method} method "
f"(block_size={self.block_size}, offset={self.offset})"
)
Step 3: Configure Package Discovery
In __init__.py:
"""My Segmentation Plugin Package."""
from .adaptive_threshold import AdaptiveThresholdSegmentation
__all__ = ["AdaptiveThresholdSegmentation"]
In pyproject.toml:
[project]
name = "my-segmentation-plugin"
version = "0.1.0"
description = "Adaptive threshold segmentation plugin for ZarrNii"
dependencies = [
"zarrnii>=0.1.0",
"scikit-image>=0.21.0",
]
[project.entry-points."zarrnii.plugins"]
adaptive_threshold = "my_segmentation_plugin:AdaptiveThresholdSegmentation"
Step 4: Use Your External Plugin
After installing (pip install my-segmentation-plugin):
from zarrnii import ZarrNii
from my_segmentation_plugin import AdaptiveThresholdSegmentation
# Load your data
znimg = ZarrNii.from_ome_zarr("input.ome.zarr")
# Use the plugin directly
segmented = znimg.segment(
AdaptiveThresholdSegmentation(block_size=51, offset=5)
)
# Or register with the plugin manager for discovery
from zarrnii.plugins import get_plugin_manager
pm = get_plugin_manager()
plugin = AdaptiveThresholdSegmentation(block_size=51, offset=5)
pm.register(plugin)
# Call hooks through the plugin manager
test_image = znimg.data.compute()
results = pm.hook.segment(image=test_image)
# Get plugin information
names = pm.hook.segmentation_plugin_name()
descriptions = pm.hook.segmentation_plugin_description()
print(f"Using: {names[0]} - {descriptions[0]}")
Working with Multi-channel Images
The segmentation plugins automatically handle multi-channel images:
# Create multi-channel test data
multichannel_data = np.random.rand(3, 50, 100, 100) # 3 channels
multichannel_data[0, 20:30, 40:60, 40:60] += 0.5 # Add signal to first channel
darr = da.from_array(multichannel_data, chunks=(1, 25, 50, 50))
znimg = ZarrNii.from_darr(darr, axes_order="ZYX", orientation="RAS")
# Segment - will use first channel for threshold calculation
# but preserve all channel dimensions in output
segmented = znimg.segment_otsu()
print(f"Input shape: {znimg.shape}") # (3, 50, 100, 100)
print(f"Output shape: {segmented.shape}") # (3, 50, 100, 100)
Integration with Existing Workflows
The segmentation plugins integrate seamlessly with other ZarrNii operations:
# Complete workflow: load, downsample, segment, save
znimg = ZarrNii.from_ome_zarr("input_image.ome.zarr", level=1)
# Downsample for faster processing
downsampled = znimg.downsample(factors=2, spatial_dims=["z", "y", "x"])
# Apply segmentation
segmented = downsampled.segment_otsu(nbins=128)
# Crop to region of interest
bbox_min = (10, 20, 20)
bbox_max = (40, 80, 80)
cropped = segmented.crop_with_bounding_box(bbox_min, bbox_max)
# Save final result
cropped.to_ome_zarr("processed_segmentation.ome.zarr")