Learning - MalondaClement/pipeline GitHub Wiki
Learning Module
Learner 🤖
train_epoch(dataloader, model, criterion, optimizer, lr_scheduler, epoch, validClasses, void=-1, args=None)
Parameters
- dataloader (torch.utils.data.DataLoader) - Dataloader with training data (dataloaders["train"]).
- model (torch.nn.Module) - Model returned by the get_model function.
- criterion - Loss function.
- optimizer - Optimizer used to update the model's parameters.
- lr_scheduler - Learning rate scheduler.
- epoch (int) - Current epoch number.
- validClasses (list) - List of all valid classes in the dataset.
- void - Label index used for the void class (default -1).
- args (helpers.ARGS.ARGS) - Object with all arguments used during training.
Returns
- (average loss, average accuracy)
Return type
- tuple
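The page does not say how train_epoch computes its average loss and accuracy. A common choice in segmentation training loops, assumed here rather than taken from the repository, is to skip pixels labelled with the void value when measuring accuracy and to keep count-weighted running averages across batches. Below is a minimal pure-Python sketch of that bookkeeping; AverageMeter and pixel_accuracy are hypothetical names, and plain lists stand in for flattened label tensors.

```python
from typing import Sequence, Tuple


class AverageMeter:
    """Hypothetical helper: count-weighted running average of a metric."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float, n: int = 1) -> None:
        # Weight the value by how many samples (or pixels) it summarises.
        self.total += value * n
        self.count += n

    @property
    def avg(self) -> float:
        return self.total / self.count if self.count else 0.0


def pixel_accuracy(pred: Sequence[int], target: Sequence[int],
                   void: int = -1) -> Tuple[float, int]:
    """Hypothetical helper: accuracy over pixels whose target is not `void`.

    Returns (accuracy, number of counted pixels) so callers can weight it.
    """
    counted = [(p, t) for p, t in zip(pred, target) if t != void]
    if not counted:
        return 0.0, 0
    correct = sum(1 for p, t in counted if p == t)
    return correct / len(counted), len(counted)


# Each entry: (flattened prediction, flattened ground truth, batch loss); -1 marks void pixels.
batches = [
    ([0, 1, 1, 2], [0, 1, 2, -1], 0.75),
    ([2, 2, 0, 0], [2, 2, 0, 1], 0.25),
]

loss_meter, acc_meter = AverageMeter(), AverageMeter()
for pred, target, loss in batches:
    loss_meter.update(loss)                  # one loss value per batch
    acc, n_pixels = pixel_accuracy(pred, target)
    acc_meter.update(acc, n_pixels)          # weight by non-void pixel count

print(loss_meter.avg, acc_meter.avg)  # average loss, average accuracy
```

Weighting each batch's accuracy by its number of counted pixels makes the epoch value equal to total correct pixels divided by total non-void pixels, rather than a plain mean of per-batch ratios.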
validate_epoch(dataloader, model, criterion, epoch, classLabels, validClasses, void=-1, maskColors=None, folder='baseline_run', args=None)
Parameters
- dataloader (torch.utils.data.DataLoader) - Dataloader with validation or test data (dataloaders["val"] or dataloaders["test"]).
- model (torch.nn.Module) - Model returned by the get_model function.
- criterion - Loss function.
- epoch (int) - Current epoch number.
- classLabels (list) - List of all class names.
- validClasses (list) - List of all valid classes in the dataset.
- void - Label index used for the void class (default -1).
- maskColors (numpy.ndarray) - Array of mask colors, one per class.
- folder (str) - Save path (unused).
- args (helpers.ARGS.ARGS) - Object with all arguments used during training.
Returns
- (average loss, average accuracy, mIoU), where mIoU is the mean intersection over union
Return type
- tuple
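validate_epoch also returns a mean IoU, which this page does not define. The sketch below uses the standard definition, IoU_c = TP_c / (TP_c + FP_c + FN_c) for each class c, averaged over the classes in validClasses, with void-labelled pixels ignored. Skipping classes whose union is empty is an assumption, and confusion_matrix and mean_iou are hypothetical names, not functions from the repository.

```python
from typing import List, Sequence


def confusion_matrix(pred: Sequence[int], target: Sequence[int],
                     num_classes: int, void: int = -1) -> List[List[int]]:
    """cm[t][p] counts pixels with ground truth t predicted as p; void pixels are skipped."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for p, t in zip(pred, target):
        if t == void:
            continue
        cm[t][p] += 1
    return cm


def mean_iou(cm: List[List[int]], valid_classes: Sequence[int]) -> float:
    """Mean of TP / (TP + FP + FN) over valid classes with a non-empty union."""
    ious = []
    for c in valid_classes:
        tp = cm[c][c]
        fn = sum(cm[c]) - tp                  # class-c pixels predicted as another class
        fp = sum(row[c] for row in cm) - tp   # other pixels predicted as class c
        union = tp + fp + fn
        if union > 0:
            ious.append(tp / union)
    return sum(ious) / len(ious) if ious else 0.0


pred = [0, 0, 1, 1, 2, 2]
target = [0, 1, 1, 1, 2, -1]   # the last pixel is void and is ignored
cm = confusion_matrix(pred, target, num_classes=3)
print(mean_iou(cm, valid_classes=[0, 1, 2]))
```

Accumulating one confusion matrix over the whole validation set and computing IoU once at the end gives a dataset-level mIoU; averaging per-batch mIoU values would generally give a different number. Which of the two validate_epoch reports is not stated on this page.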
Utils 🧰
get_dataloader(dataset, args)
Parameters
- dataset (class) - Dataset class (not an instance).
- args (helpers.ARGS.ARGS) - Object with all arguments used during training.
Returns
- Dictionary of dataloaders for the train, test and val sets, keyed "train", "test" and "val".
Return type
- dict
Example
from helpers.ARGS import ARGS
from datasets.tunnel import Tunnel
from learning.utils import get_dataloader
Dataset = Tunnel  # pass the dataset class itself, not an instance
args = ARGS("DeepLabV3_Resnet50", "batch_17", len(Dataset.validClasses), labels_type="csv", batch_size=4, epochs=300)
dataloaders = get_dataloader(Dataset, args)  # dict with "train", "test" and "val" dataloaders
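get_dataloader takes the dataset class rather than an instance. A plausible reason, assumed here and not confirmed by this page, is that the function builds a separate dataset object for each split. The pure-Python sketch below illustrates that pattern. ToyDataset, its split constructor argument, SimpleLoader and make_dataloaders are all hypothetical, and SimpleLoader only stands in for torch.utils.data.DataLoader so the example runs without PyTorch. Shuffling only the training split is also an assumption.

```python
import random
from typing import Dict, Iterator, List


class ToyDataset:
    """Hypothetical stand-in for a dataset class such as Tunnel; its constructor is an assumption."""

    validClasses = [0, 1, 2]  # mirrors the Dataset.validClasses attribute used in the example
    _sizes = {"train": 5, "val": 2, "test": 2}

    def __init__(self, split: str) -> None:
        self.split = split
        self.items = [f"{split}_{i}" for i in range(self._sizes[split])]

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> str:
        return self.items[idx]


class SimpleLoader:
    """Re-iterable batcher standing in for torch.utils.data.DataLoader."""

    def __init__(self, dataset, batch_size: int, shuffle: bool) -> None:
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __len__(self) -> int:
        # Number of batches, counting a final partial batch.
        return -(-len(self.dataset) // self.batch_size)

    def __iter__(self) -> Iterator[List[str]]:
        order = list(range(len(self.dataset)))
        if self.shuffle:
            random.shuffle(order)  # new order on every pass, like shuffle=True
        for start in range(0, len(order), self.batch_size):
            yield [self.dataset[i] for i in order[start:start + self.batch_size]]


def make_dataloaders(dataset_cls, batch_size: int) -> Dict[str, SimpleLoader]:
    # Receiving the class (not an instance) lets the function build one dataset per split.
    return {
        split: SimpleLoader(dataset_cls(split=split), batch_size, shuffle=(split == "train"))
        for split in ("train", "val", "test")
    }


loaders = make_dataloaders(ToyDataset, batch_size=2)
print({split: len(loader) for split, loader in loaders.items()})  # {'train': 3, 'val': 1, 'test': 1}
```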