ResAttUNet trained on MARIDA and Litter Windrows DataSet - elena-andreini/TriesteItalyChapter_PlasticDebrisDetection GitHub Wiki
This document describes the Python code which implements a deep learning segmentation model for detecting marine debris using multispectral satellite imagery from the MARIDA and Litter Windrows datasets. The model employs a Residual Attention U-Net to perform binary segmentation (debris vs. non-debris) on four Sentinel-2 bands (B4, B6, B8A, B11). The following sections detail the environment setup, data preparation, preprocessing, model training, evaluation, and optional features.
The code processes multispectral satellite imagery from two datasets:
- MARIDA Dataset: Contains 11 Sentinel-2 bands with 15 semantic classes, including marine debris, sargassum, and various water types.
- Litter Windrows (LR) Dataset: Focuses on binary classification of marine debris in windrows.
The model predicts binary masks (debris vs. non-debris) using a Residual Attention U-Net, incorporating data augmentation, custom preprocessing, and evaluation metrics to handle class imbalance and invalid pixels.
- Dependencies:
- rasterio: For reading GeoTIFF files.
- torch, torchvision: For deep learning operations.
- pandas, numpy: For data manipulation.
- matplotlib, skimage: For visualization and image processing.
- sklearn: For evaluation metrics.
- pytorch_lightning (commented out): For potential scalability.
- Environment: Designed for Kaggle or Colab environments with GPU support, using kagglehub to download the datasets.
- Seeding: Sets a random seed (seed=42) for reproducibility across NumPy, PyTorch, and Python's random module via the set_seed function.
- Device Configuration: Automatically selects the GPU (CUDA) if available, otherwise defaults to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
set_seed(seed=42)
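A minimal set_seed sketch consistent with the description above (NumPy, PyTorch, and Python's random module); the project's own implementation may differ slightly.

import os
import random
import numpy as np
import torch

def set_seed(seed=42):
    # Seed every random number generator used by the pipeline
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)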
- MARIDA Dataset:
- Function: create_marida_df
- Purpose: Constructs a DataFrame from the MARIDA dataset, reading image and mask file paths from the split files (train_X.txt, val_X.txt, test_X.txt).
- Details: Organizes paths by date and tile using extract_date_tile with a regex (see the sketch below). Filters out invalid files (e.g., missing paths).
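A hypothetical sketch of extract_date_tile; the MARIDA patch naming convention assumed here (e.g., S2_1-12-19_48MYU_0.tif) should be checked against the dataset before reuse.

import re

def extract_date_tile(filename):
    """Return (date, tile) parsed from a MARIDA patch file name, or None if it does not match."""
    match = re.search(r"S2_(\d{1,2}-\d{1,2}-\d{2})_([0-9A-Z]+)_\d+", filename)
    return (match.group(1), match.group(2)) if match else None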
- Litter Windrows (LR) Dataset:
- Function: create_LR_dataframe
- Purpose: Creates a DataFrame from the binary split files (train_X.txt, val_X.txt, test_X.txt) containing image and mask paths.
- Function: compute_invalid_pixels
- Purpose: Identifies images with invalid pixels (NaNs, negative values, or values > 1) and creates filtered DataFrames (marida_df_F, lr_df_F, etc.).
- On first use, the output is saved to CSV files; on subsequent runs, the code loads the precomputed invalid pixel statistics from these CSV files (e.g., marida_df_invalid_info.csv) to exclude problematic images.
marida_df = create_marida_df(MARIDA_path)
marida_df_invalid = pd.read_csv('/kaggle/working/marida_df_invalid_info.csv')
marida_df_F = marida_df.drop(marida_df_invalid[marida_df_invalid['nan pixels'] > 0].index)
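A hedged sketch of what compute_invalid_pixels likely does, based on the description above; the real function may record additional statistics, and apart from 'nan pixels' (used in the snippet above) the column names are assumptions.

import numpy as np
import pandas as pd
import rasterio

def compute_invalid_pixels(df, image_col="image_path"):
    records = []
    for path in df[image_col]:
        with rasterio.open(path) as src:
            data = src.read().astype(np.float32)   # shape (bands, H, W)
        records.append({
            "image_path": path,
            "nan pixels": int(np.isnan(data).sum()),
            "negative pixels": int((data < 0).sum()),
            "pixels > 1": int((data > 1).sum()),
        })
    return pd.DataFrame(records)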
- Function: compute_stats
- Purpose: Calculates per-band mean and standard deviation for the joint dataset (MARIDA + Litter Windrows), stored in global_stats.npz.
- MARIDA: Maps classes [1, 2, 3, 4, 9] (marine debris, dense sargassum, sparse sargassum, organic material, foam) to debris, yielding a debris fraction of approximately 0.01186.
- LR: Sub-samples non-debris pixels according to a debris-to-background ratio (LR_ratio).
- Purpose: controls the debris/non-debris imbalance by heavily sub-sampling unlabeled pixels, which are treated as background.
- An effective ratio is computed as the weighted mean of the MARIDA and LR ratios, yielding the following class distribution:
class_distribution = np.array([1 - effective_ratio, effective_ratio])
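A hypothetical sketch of how the effective ratio could be obtained as a patch-count-weighted mean of the two per-dataset debris fractions; the exact weighting used by the notebook is not documented on this page, so every name and formula below is an assumption.

import numpy as np

marida_debris_fraction = 0.01186                 # debris fraction reported above for MARIDA
lr_debris_fraction = 1.0 / (1.0 + LR_ratio)      # assumed debris share after background sub-sampling at LR_ratio

n_marida, n_lr = len(marida_df_F), len(lr_df_F)  # patches per dataset
effective_ratio = (n_marida * marida_debris_fraction + n_lr * lr_debris_fraction) / (n_marida + n_lr)

class_distribution = np.array([1 - effective_ratio, effective_ratio])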
- Class: MergedSegmentationDataset_B
- Purpose: Combines MARIDA and LR datasets into a single PyTorch Dataset.
- Inputs:
- df_dataset1: MARIDA DataFrame.
- df_dataset2: LR DataFrame.
- bands_mean, bands_std: Normalization statistics.
- selected_bands: Indices of bands (B4, B6, B8A, B11).
- transform, standardization: Data augmentation and normalization transforms.
- Features:
- Loads GeoTIFF images and masks using rasterio.
- Filters invalid pixels (NaNs, no-data, < -1.5, > 1.5) with get_invalid_mask.
- Replaces invalid pixels with band mean values to ensure robust training.
- Applies transformations to image-mask pairs.
- Returns image, mask, and dataset ID (0 for MARIDA, 1 for LR).
merged_ds = MergedSegmentationDataset_B(marida_df_F, lr_df_F, global_bands_mean, global_bands_std, selected_bands, transform=transformTrain, standardization=standardization)
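A hedged sketch of the invalid-pixel handling described above; the actual get_invalid_mask signature, the no-data convention, and the impute_invalid helper name are assumptions.

import numpy as np

def get_invalid_mask(image, nodata_value=None, low=-1.5, high=1.5):
    """Boolean mask of NaN, no-data, and out-of-range reflectance values, per band."""
    invalid = np.isnan(image) | (image < low) | (image > high)
    if nodata_value is not None:
        invalid |= (image == nodata_value)
    return invalid

def impute_invalid(image, bands_mean):
    """Replace invalid pixels with the per-band mean so normalization and training are not corrupted."""
    invalid = get_invalid_mask(image)            # shape (C, H, W)
    for c in range(image.shape[0]):
        image[c][invalid[c]] = bands_mean[c]
    return image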
- Training Transforms (transformTrain):
- Converts data to PyTorch tensors.
- Applies random rotations using RandomRotationTransform (angles: [-90, 0, 90, 180]).
- Applies random horizontal flips.
- Testing Transforms (transformTest):
- Converts data to tensors without augmentation.
- Normalization:
- Uses transforms.Normalize with global_bands_mean and global_bands_std for selected bands.
transformTrain = transforms.Compose([
transforms.ToTensor(),
RandomRotationTransform([-90, 0, 90, 180]),
transforms.RandomHorizontalFlip()
])
standardization = transforms.Normalize(global_bands_mean[selected_bands].tolist(), global_bands_std[selected_bands].tolist())
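A minimal sketch of a RandomRotationTransform restricted to a fixed list of angles; in the pipeline the image and mask are transformed together (e.g., stacked along the channel axis) so both receive the same rotation, but those details are assumptions.

import random
import torchvision.transforms.functional as TF

class RandomRotationTransform:
    """Rotate a tensor by an angle drawn uniformly from a fixed set."""
    def __init__(self, angles):
        self.angles = angles

    def __call__(self, x):
        return TF.rotate(x, random.choice(self.angles))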
- Function: custom_collate_fn
- Purpose: Processes batches to ensure consistent handling of MARIDA and LR data.
- Steps:
- Stacks images, masks, and dataset IDs into tensors.
- Moves data to the appropriate device (GPU/CPU).
- Processes MARIDA masks (batch_process_marida_masks): maps classes [1, 2, 3, 4, 9] to debris (2) and all others to non-debris (1), then subtracts 1 to yield [0, 1].
- Processes LR masks (batch_select_bg_pixels): sets debris pixels (value 1) to 2; creates annular background masks by dilating the debris masks with radii r1=5 and r2=20, sampling background pixels at target_ratio=10; sets background pixels to 1 and subtracts 1 to yield [0, 1].
- Combines the processed masks for both datasets.
def custom_collate_fn(batch):
images, masks, dataset_ids = zip(*batch)
images = torch.stack(images)
masks = torch.stack(masks)
dataset_ids = torch.tensor(dataset_ids, dtype=torch.long)
images, masks, dataset_ids = images.to(device), masks.to(device), dataset_ids.to(device)
lr_masks = batch_select_bg_pixels(images, masks, dataset_ids, r1=5, r2=20, target_ratio=LR_ratio, device=device)
marida_masks = batch_process_marida_masks(masks, dataset_ids, device=device)
masks = lr_masks + marida_masks
return images, masks, dataset_ids
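A hedged sketch of batch_process_marida_masks following the mapping described above. Keeping unlabeled pixels (class 0) at 0, so that they become -1 (the loss's ignore_index) after the subtraction, is an assumption not confirmed by this page.

import torch

DEBRIS_CLASSES = (1, 2, 3, 4, 9)  # marine debris, dense/sparse sargassum, organic material, foam

def batch_process_marida_masks(masks, dataset_ids, device="cpu"):
    out = torch.zeros_like(masks)
    marida = dataset_ids == 0                                # only touch MARIDA samples
    m = masks[marida].long()
    remapped = torch.ones_like(m)                            # labeled non-debris classes -> 1
    remapped[torch.isin(m, torch.tensor(DEBRIS_CLASSES, device=m.device))] = 2
    remapped[m == 0] = 0                                     # unlabeled stays 0 (assumption)
    out[marida] = remapped - 1                               # debris -> 1, non-debris -> 0, unlabeled -> -1
    return out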
Model Architecture
- Model: ResidualAttentionUNet
- Purpose: A U-Net variant with residual connections and attention mechanisms for enhanced feature extraction.
- Components:
- DownSampleWithAttention: Downsampling blocks with convolutions, batch normalization, LeakyReLU, average pooling, and channel/spatial attention.
- ResidualBlock: Residual connections with convolutions, batch normalization, ReLU, and attention mechanisms.
- UpSampleWithAttention: Upsampling blocks with bilinear interpolation, convolutions, batch normalization, LeakyReLU, and attention.
- Classification Layer: 1x1 convolution to output 2 classes (debris, non-debris).
- Input: 4-channel input (B4, B6, B8A, B11).
- Output: 2-channel logits for binary segmentation.
model = ResidualAttentionUNet(len(selected_bands), 2).to(device)
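As an illustration of the channel and spatial attention mentioned above, a squeeze-and-excitation-style channel gate and a simple spatial gate are sketched below; the actual modules inside ResidualAttentionUNet may be implemented differently.

import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    def __init__(self, channels, reduction=8):
        super().__init__()
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return x * self.fc(x)          # re-weight each channel

class SpatialAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3)

    def forward(self, x):
        avg = x.mean(dim=1, keepdim=True)
        mx, _ = x.max(dim=1, keepdim=True)
        attn = torch.sigmoid(self.conv(torch.cat([avg, mx], dim=1)))
        return x * attn                # re-weight each spatial location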
- DataLoader:
- Training: Uses trainLoader with MergedSegmentationDataset_B, batch_size=16, shuffling, and custom_collate_fn.
- Validation: Uses testLoader without shuffling.
- Loss Function:
- Cross-entropy loss (nn.CrossEntropyLoss) with class weights computed via gen_weights (c=1.03) to address class imbalance.
- Ignores pixels labeled -1 (invalid/unlabeled).
- Optimizer: AdamW with learning rate 1e-3 and weight decay 1e-4.
- Scheduler: ReduceLROnPlateau reduces learning rate by 0.5 if validation debris IoU does not improve for 6 epochs, with a minimum learning rate of 1e-6.
- Training Loop:
- Runs for 70 epochs with early stopping (patience=10).
- Trains the model, computes loss, and updates weights.
- Evaluates training metrics (precision, recall, F1, IoU) every 10 epochs using the Evaluation function.
- Validates on the validation set, saving the model to best_model.pth if debris IoU improves.
criterion = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='mean', weight=gen_weights(torch.from_numpy(class_distribution), c=1.03).to(torch.float32))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=6, min_lr=1e-6, verbose=True)
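For reference, a compact sketch of the training loop described above, reusing the model, criterion, optimizer, scheduler, and loaders defined in the snippets on this page. The evaluate_debris_iou helper is hypothetical, and gen_weights here uses the common 1/ln(c + p) class weighting, which is an assumption about the project's actual formula.

import torch

def gen_weights(class_distribution, c=1.03):
    # Rare classes (debris) receive a larger weight
    return 1.0 / torch.log(c + class_distribution)

best_iou, epochs_without_improvement = 0.0, 0
for epoch in range(70):
    model.train()
    for images, masks, dataset_ids in trainLoader:
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, masks.long())   # masks assumed to have shape (N, H, W)
        loss.backward()
        optimizer.step()

    # Validation: track debris IoU and step the scheduler on it (mode='max')
    val_debris_iou = evaluate_debris_iou(model, testLoader)   # hypothetical helper
    scheduler.step(val_debris_iou)

    if val_debris_iou > best_iou:
        best_iou, epochs_without_improvement = val_debris_iou, 0
        torch.save(model.state_dict(), "best_model.pth")
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= 10:     # early stopping patience
            break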
- Test Dataset:
- Uses marida_test_ds (MARIDA test set only) with transformTest and custom_collate_fn.
- Metrics:
- Computed via the Evaluation function:
- Macro, micro, and weighted precision, recall, and F1 scores.
- Subset accuracy and IoU.
- Debris-specific precision, recall, F1, and IoU (class 1).
- Applies a 0.7 threshold to debris class probabilities (pixels with probs[:, 1] < 0.7 are set to 0).
- Process:
- Loads the best model from best_model.pth.
- Evaluates on the MARIDA test set, ignoring unlabeled pixels (mask = target != -1).
- Stores metrics in test_metrics_history and prints results.
model.load_state_dict(torch.load("/kaggle/working/best_model.pth", map_location=device))
model.eval()
probs[probs[:, 1] < 0.8] = 0. # Thresholding
acc = Evaluation(yPredicted, yTrue)
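For clarity, a minimal sketch of the debris-specific IoU on labeled pixels only; the project's Evaluation function computes many more metrics and may rely on sklearn utilities instead.

import numpy as np

def debris_iou(y_pred, y_true, debris_class=1, ignore_index=-1):
    valid = y_true != ignore_index
    pred, true = y_pred[valid] == debris_class, y_true[valid] == debris_class
    intersection = np.logical_and(pred, true).sum()
    union = np.logical_or(pred, true).sum()
    return intersection / union if union > 0 else float("nan")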
- Output:
- Copies the best model to a named file (e.g., model_60_epochs_ratio_1_15_bs16_test_iou_debris_083_thr0.7.pth).
- FDI and NDWI:
- Functions: compute_fdi_from_tiff, compute_ndwi
- Purpose: Compute Floating Debris Index (FDI) and Normalized Difference Water Index (NDWI) for visualization.
- Visualization: plot_fdi displays RGB patches, binary masks, FDI, and NDWI side-by-side.
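A hedged sketch of the two indices using their commonly cited definitions for Sentinel-2 reflectance (FDI after Biermann et al., 2020; NDWI = (Green - NIR)/(Green + NIR)); how compute_fdi_from_tiff and compute_ndwi actually read and scale the bands is not shown on this page, so the details below are assumptions.

import numpy as np

# Central wavelengths (nm) of the Sentinel-2 bands involved
L_RED, L_NIR, L_SWIR1 = 664.6, 832.8, 1613.7

def fdi(b6, b8, b11):
    """Floating Debris Index from red-edge (B6), NIR (B8/B8A), and SWIR1 (B11) reflectance."""
    nir_prime = b6 + (b11 - b6) * (L_NIR - L_RED) / (L_SWIR1 - L_RED) * 10.0
    return b8 - nir_prime

def ndwi(b3, b8):
    """Normalized Difference Water Index from green (B3) and NIR (B8) reflectance."""
    return (b3 - b8) / (b3 + b8 + 1e-12)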
- RGB Conversion:
- Functions: cvt_RGB, cvt_RGB_from_4_bands
- Purpose: Convert multispectral images to RGB for visualization.
- Display:
- Function: display
- Purpose: Visualizes image-mask pairs using Matplotlib.
- PyTorch Lightning Implementation:
- Purpose: A commented-out BinaryClassificationModel class suggests a PyTorch Lightning implementation for scalable training, logging, and checkpointing.
- VSCP Augmentation:
- Function: vscp
- Purpose: Implements copy-paste augmentation by copying valid pixels from the second half of the batch to the first half. Disabled to avoid index mixing issues.
- Multi-Dataset Integration: Combines MARIDA and LR datasets with consistent preprocessing and class mapping.
- Attention-Based Architecture: Uses channel and spatial attention to focus on relevant features.
- Class Imbalance Handling: Employs weighted loss and background pixel sampling to address low debris fraction.
- Invalid Pixel Management: Filters and imputes invalid pixels for robust training.
- Reproducibility: Despite seeding, GPU computations are currently not deterministic, so complete reproducibility is not guaranteed.
- Evaluation Metrics: Includes debris-specific IoU for performance assessment on imbalanced data.
- Install dependencies (rasterio, torch, etc.).
- Download MARIDA and LR datasets via kagglehub.
- Place precomputed statistics and CSV files in /kaggle/working/.
MARIDA_path = '/kaggle/input/marida-marine-debrish-dataset'
LR_splits_path = '/kaggle/input/litter-windrows-patches/binary_splits'
- Execute in a Kaggle or Colab environment with GPU support.
- Outputs include the trained model (best_model.pth) and evaluation metrics.
- Adjust selected_bands to use different Sentinel-2 bands.
- Modify LR_ratio to change the background-to-debris sampling ratio.
- Tune hyperparameters (e.g., batch_size, learning rate, patience) for better performance.
- Commented-Out Code: Enable VSCP augmentation or PyTorch Lightning for enhanced functionality.
- Band Selection: Experiment with additional bands or derived indices (FDI, NDWI).
- Scalability: Integrate PyTorch Lightning for efficient training and logging.
The code provides a pipeline for marine debris segmentation using a Residual Attention U-Net on multispectral satellite imagery. It handles data preprocessing, augmentation, training, and evaluation, addressing challenges like class imbalance and invalid pixels. The model achieves reliable debris detection and can be extended with additional features or datasets for improved performance.