Learning - MalondaClement/pipeline GitHub Wiki

Learning Module

Learner 🤖

train_epoch(dataloader, model, criterion, optimizer, lr_scheduler, epoch, validClasses, void=-1, args=None)
Parameters
  • dataloader (torch.utils.data.DataLoader) - Dataloader with training data (dataloaders["train"]).
  • model (torch.nn.Module) - Model returned by the get_model function.
  • criterion - Loss function.
  • optimizer - Optimizer used to update the model's parameters.
  • lr_scheduler - Learning-rate scheduler used to adjust the optimizer's learning rate during training.
  • epoch (int) - Current epoch number.
  • validClasses (list) - List of all valid classes in the dataset.
  • void - Label of the void class (default -1).
  • args (helpers.ARGS.ARGS) - Object with all arguments used during training.
Returns
  • (average loss, average accuracy)
Return type
  • tuple
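Example
The returned accuracy depends on how void pixels are handled and how per-batch values are averaged. The following is a minimal sketch, assuming void-labelled pixels are excluded and batches are weighted by their size. The helper names pixel_accuracy and average_over_batches are hypothetical and are not taken from the repository, whose implementation may differ.

```python
def pixel_accuracy(pred, target, void=-1):
    """Fraction of non-void pixels whose predicted class matches the label.

    pred and target are flat sequences of integer class ids (illustrative
    stand-ins for flattened prediction and label tensors).
    """
    correct = 0
    counted = 0
    for p, t in zip(pred, target):
        if t == void:
            continue  # void pixels count in neither numerator nor denominator
        counted += 1
        correct += int(p == t)
    return correct / counted if counted else 0.0


def average_over_batches(batch_values, batch_sizes):
    """Sample-weighted mean of per-batch values.

    Weighting by batch size keeps a short final batch from being over-weighted.
    """
    total = sum(batch_sizes)
    if total == 0:
        return 0.0
    return sum(v * n for v, n in zip(batch_values, batch_sizes)) / total


# Pixel 2 is void, so only three pixels are scored and two of them match.
acc = pixel_accuracy([0, 1, 1, 2], [0, 1, -1, 1], void=-1)

# Two batches of sizes 4 and 2 with mean losses 1.0 and 3.0.
avg_loss = average_over_batches([1.0, 3.0], [4, 2])
```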
validate_epoch(dataloader, model, criterion, epoch, classLabels, validClasses, void=-1, maskColors=None, folder='baseline_run', args=None)
Parameters
  • dataloader (torch.utils.data.DataLoader) - Dataloader with test or validation data (dataloaders["test"] or dataloaders["val"]).
  • model (torch.nn.Module) - Model returned by the get_model function.
  • criterion - Loss function.
  • epoch (int) - Current epoch number.
  • classLabels (list) - List of all class names.
  • validClasses (list) - List of all valid classes in the dataset.
  • void - Label of the void class (default -1).
  • maskColors (numpy.array) - Array of mask colors, one per class.
  • folder (str) - Save path (unused).
  • args (helpers.ARGS.ARGS) - Object with all arguments used during training.
Returns
  • (average loss, average accuracy, mean IoU (miou))
Return type
  • tuple
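Example
The miou value conventionally denotes the mean intersection-over-union across classes. The following is a minimal sketch of that metric, assuming void-labelled pixels are skipped and only the classes in validClasses are scored. The helper name mean_iou and the flat label lists are illustrative assumptions; the repository's own computation may differ.

```python
def mean_iou(pred, target, valid_classes, void=-1):
    """Mean intersection-over-union over valid classes, ignoring void pixels.

    pred and target are flat sequences of integer class ids. Classes that
    appear in neither pred nor target are left out of the mean.
    """
    ious = []
    for c in valid_classes:
        intersection = 0
        union = 0
        for p, t in zip(pred, target):
            if t == void:
                continue  # void pixels are not scored
            pred_is_c = p == c
            target_is_c = t == c
            if pred_is_c and target_is_c:
                intersection += 1
            if pred_is_c or target_is_c:
                union += 1
        if union > 0:
            ious.append(intersection / union)
    return sum(ious) / len(ious) if ious else 0.0


# Pixel 3 is void. Per-class IoU is 0.5 for class 0, 0.5 for class 1
# and 1.0 for class 2.
miou = mean_iou([0, 0, 1, 1, 2], [0, 1, 1, -1, 2], valid_classes=[0, 1, 2])
```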

Utils 🧰

get_dataloader(dataset, args)
Parameters
  • dataset (class) - Dataset class (not an instance).
  • args (helpers.ARGS.ARGS) - Object with all arguments used during training.
Returns
  • Dictionary with dataloaders for the train, test and val sets.
Return type
  • dict
Example
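from helpers.ARGS import ARGS
from datasets.tunnel import Tunnel
from learning.utils import get_dataloader

# Pass the dataset class itself, not an instance.
Dataset = Tunnel

# The third positional argument is the number of valid classes in the dataset.
args = ARGS("DeepLabV3_Resnet50", "batch_17", len(Dataset.validClasses), labels_type="csv", batch_size=4, epochs=300)

# Returns a dict of dataloaders for the train, test and val sets.
dataloaders = get_dataloader(Dataset, args)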
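
The dictionary is keyed by split name, so its entries can be passed directly to train_epoch and validate_epoch. The following is a rough sketch of an epoch loop built on the documented signatures. The fit function, its history format, the zero-based epoch counter and the choice to track the best miou are all assumptions made for illustration. The two epoch functions are passed in as arguments so the sketch runs on its own; in practice they would be the Learner functions documented above.

```python
def fit(train_epoch, validate_epoch, dataloaders, model, criterion,
        optimizer, lr_scheduler, class_labels, valid_classes, epochs,
        void=-1, args=None):
    """Alternate one training epoch and one validation epoch (hypothetical helper).

    Returns the per-epoch metrics and the index of the epoch with the best miou.
    """
    history = []
    best_miou = float("-inf")
    best_epoch = None
    for epoch in range(epochs):
        # train_epoch returns (average loss, average accuracy)
        train_loss, train_acc = train_epoch(
            dataloaders["train"], model, criterion, optimizer,
            lr_scheduler, epoch, valid_classes, void=void, args=args)
        # validate_epoch returns (average loss, average accuracy, miou)
        val_loss, val_acc, miou = validate_epoch(
            dataloaders["val"], model, criterion, epoch,
            class_labels, valid_classes, void=void, args=args)
        history.append({
            "epoch": epoch,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc,
            "miou": miou,
        })
        if miou > best_miou:
            best_miou, best_epoch = miou, epoch
    return history, best_epoch
```

A call might look like fit(train_epoch, validate_epoch, dataloaders, model, criterion, optimizer, lr_scheduler, Dataset.classLabels, Dataset.validClasses, epochs=300), assuming the dataset class exposes a classLabels attribute alongside validClasses.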