Transfer - Gorilla-Lab-SCUT/gorilla-core GitHub Wiki

代码迁移手册

针对 Solver 不完善和灵活度较差的问题,我们现阶段选择一种折中的方案,提供一个 demo 性质的网络训练脚本,供同学们进行参考,整个训练流程不需要定义诸如 Solver 等封装类别,均由函数搭建实现。

我们提供了利用 PointNetModelNet40 进行分类的训练脚本,代码详见: 该脚本重点在于实现了分布式的训练(虽然对于 ModelNet40 这样的小数据集分布式不一定能实现更快的速度,这里仅提供一个简单的任务作为 demo 脚本)。 代码中已经包含了尽可能详尽的注释,该手册的目的在于更详尽的解读脚本以及辅助同学们更好地进行代码的迁移。 理论上,同学们只需要将脚本中的 FIXME 部分根据自己的需求进行修改即可。

脚本导入依赖项后,开始于 __main__

if __name__ == "__main__":
    # 获取命令行参数
    args = get_parser()

    # 自动查找空闲 gpu 并设定 (Optional, 可手动设置,详情可观察源码)
    gorilla.set_cuda_visible_devices(num_gpu=args.num_gpus)

    # 启动器 (分布式训练的关键)
    # args.num_gpus = 1 的情况下
    # 等同于 main(args)
    gorilla.launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,) # use tuple to wrap
    )

首先,我们观察 get_parser()

def get_parser():
    # FIXME: `default_argument_parser` 中包含了对于分布式来说
    #        必要的参数,如 `num-gpus`,以及一下辅助参数,如 `--resume`
    #        同学们应该在此之上进行添加自己需要的参数
    #        我在这里添加了 `config` 参数在于确定配置文件的路径
    parser = gorilla.default_argument_parser()
    parser.add_argument("--config",
                        type=str,
                        default="config/default.yaml",
                        help="path to config file")

    args_cfg = parser.parse_args()

    return args_cfg

以上得到参数后,在暂时忽略 gorilla.launch(本质上是一个包装器,仅在多 gpu 时起作用),进入 main(args)

def main(args):
    # 读取配置文件,Config 详见 API 文档
    cfg = gorilla.Config.fromfile(args.config)
    # 将 `args` 中的参数写入读取配置文件得到的 `cfg` 中
    cfg = gorilla.config.merge_cfg_and_args(cfg, args)

    # 初始化 logger 以及获取 log 文件所在的目录如下:
    # root
    #  └── log
    #       └── ${config_prefix} -> log_dir
    log_dir, logger = gorilla.collect_logger(
        prefix=os.path.splitext(os.path.basename(args.config))[0])
    #### NOTE: 同学们可以自己给定路径初始化 logger,我们也提供了相应的 API 如下:
    # logger = gorilla.get_logger(log_file)
    cfg.log_dir = log_dir # 获取日志目录方便后续对 epoch 进行保存

    # 对必要文件和目录进行备份
    # FIXME: 同学们可以修改 `backup_list` 来修改需要备份的文件和目录
    #        目录则会保持现有的结构进行拷贝(详见 API 文件)
    #        同学们可以修改 `backup_dir` 修改备份地址
    backup_list = ["plain_train.py", "test.py", "network", args.config]
    backup_dir = os.path.join(log_dir, "backup")
    gorilla.backup(backup_dir, backup_list)
    
    # 设置随机种子
    seed = cfg.get("seed", 0)
    gorilla.set_random_seed(seed)

    # 打印信息
    logger.info("=> creating model ...")

    # 初始化模型
    # NOTE: 该例子采用了 gorilla-core 实现的构建函数,非常推荐同学们使用
    #       当然不习惯的同学也可以不使用构建函数手动初始化模型
    model = gorilla.build_model(cfg.model)
    model = model.cuda()

    # 分布式操作
    if args.num_gpus > 1:
        # 将模型中的 BatchNorm 转换为 SyncBatchNorm (NOTE: 低版本 pytorch 该函数存在问题,已知 1.7 及以上ok,1.3.1 及以下不行)
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        # 分布式包装函数
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gorilla.get_local_rank()])

    # (可选)打印模型参数
    # logger.info("Model:\n{}".format(model))

    # (可选)统计模型参数量
    count_parameters = sum(gorilla.parameter_count(model).values())
    logger.info(f"#classifier parameters new: {count_parameters}")

    # 可是训练模型
    do_train(model, cfg, logger)

以上就是对模型进行初始化以及进行运行相关辅助函数,接下来进入到 do_train 函数,实现网络的训练:

def do_train(model, cfg, logger):
    model.train()
    # 根据配置文件中的 `optimizer/lr_scheduler` 参数项初始化 `optimizer/lr_scheduler`
    # 同学们也可以手动进行初始化
    # 如果不使用 lr_scheduler 也可以手动编写函数管理学习率
    optimizer = gorilla.build_optimizer(model, cfg.optimizer)
    lr_scheduler = gorilla.build_lr_scheduler(optimizer, cfg.lr_scheduler)

    # loss 函数初始化,可以手动初始化 loss 函数
    # 如果 loss 计算在 model 中进行,可以省略
    criterion = gorilla.build_loss(cfg.loss)

    ### 网络/优化器/学习策略 的相关加载
    # get_checkpoint 可选,只要获取到 checkpoint 路径即可
    checkpoint = get_checkpoint(cfg.log_dir)
    # 判断 checkpoint 是否为有效文件
    if gorilla.is_filepath(checkpoint):
        # meta is the dict save some necessary information (last epoch/iteration, acc, loss)
        # meta 是一个保存与网络参数无关但是必要的参数的字典
        # 用来存储例如 epoch/iter/acc 等参数
        meta = gorilla.resume(model=model,
                              filename=checkpoint,
                              optimizer=optimizer,    # optimizer/scheduler 可以缺省
                              scheduler=lr_scheduler, # 不给定参数即可(为 None)
                              resume_optimizer=True,  # 是否对 optimizer/scheduler
                              resume_scheduler=True,  # 进行加载
                              strict=False,           # 严格判断加载参数一致性
                              )
        # 获取加载权重中保存的 epoch/iter
        # 不然初始化为 1
        epoch = meta.get("epoch", 0) + 1
        iter = meta.get("iter", 0) + 1
    
    # 初始化数据集和迭代器
    # NOTE: build_dataloader 包装了 build_dataset
    #       同学们可通过阅读源码了解原理,如果不使用该 API
    #       注意如果进行分布式则需要对 torch.utils.data.Dataloader
    #       的 sampler 针对分布式进行特别定义,为此,我们提供了
    #       gorilla.data.DistributedSampler 供同学们使用
    #       同学们只需根据以下提示进行定义即可实现 dataloader 的分布式
    #       >>> sampler = gorilla.data.DistributedSampler(dataset)
    #       >>> dataloader = torch.utils.data.Dataloader(dataset, batch_size, **kwargs)
    cfg.dataset.split = "train"
    dataset = gorilla.build_dataset(cfg.dataset)
    train_dataloader = gorilla.build_dataloader(dataset,
                                                cfg.dataloader,
                                                shuffle=True,
                                                drop_last=True)

    # 初始化 tensorboard
    # NOTE: TensorBoardWriter 可以视作 tensorboardX.SummaryWriter 的轻微封装
    #       它支持和 tensorboardX.SummaryWriter 一样的 add_scalar 和 add_scalars API
    #       另外它内置了 buffer 支持记录功能,详见 API文档
    writer = gorilla.TensorBoardWriter(cfg.log_dir)

    # initialize time buffer and timers (Optional)
    # 初始化计时器
    iter_timer = gorilla.Timer()
    epoch_timer = gorilla.Timer()

    # 初始化关于 loss 和时间的 buffer,用于统计
    # NOTE: HistoryBuffer 可以视作实现了 "clear/avg/sum` 功能的 List
    #       详情见 API文档
    loss_buffer = gorilla.HistoryBuffer()
    iter_time = gorilla.HistoryBuffer()
    data_time = gorilla.HistoryBuffer()

    # 开始训练
    while epoch <= cfg.epochs:
        torch.cuda.empty_cache() # (empty cuda cache, Optional)
        for i, batch in enumerate(train_dataloader):
            # 从计时器中获取数据加载的时间并记录
            data_time.update(iter_timer.since_last())
            ### FIXME: 以下就是同学们需要替换的内容,换成自己所需要的前传步骤
            point_sets = batch["point_set"].cuda() # [B, N, C]
            labels = batch["label"].long().cuda() # [B, N]

            # 模型前传和计算loss
            logits = model(point_sets)
            loss_inp = {"logits": logits,
                        "labels": labels}
            loss, loss_out = criterion(loss_inp) # [B, N, num_class]
            # FIXME: 以上直到 FIXME 的内容为同学们需要替换的部分

            # 记录 loss
            loss_buffer.update(loss)

            # 采样获取学习率
            lr = optimizer.param_groups[0]["lr"]
            #### 写入 tensorboard (NOTE: 以下为 3 种等效的实现方式)
            # # solution1: (record the data dict into buffer)
            # writer.update({"train/loss": loss, "lr": lr})
            # writer.write(iter)
            # solution2:
            writer.update({"train/loss": loss, "lr": lr}, iter) # given the `global_step` means write immediately
            # # # solution3:
            # writer.add_scalar(f"train/loss", loss, iter)
            # writer.add_scalar(f"lr", lr, iter)

            # # (NOTE: `loss_out` is work for multi losses, which saves each loss item)
            # writer.update(loss_out, iter) # it will not add the `train/` prefix
            # for k, v in loss_out.items():
            #     writer.add_scalar(f"train/{k}", v, iter)

            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # 更新 iter
            iter += 1

            # 从计时器获取前传花费的时间
            iter_time.update(iter_timer.since_start())
            iter_timer.reset() # 重置计时器以便进行下一次计时

            # 计算训练所需的剩余时间
            remain_iter = (cfg.epochs - epoch + 1) * len(train_dataloader) + i + 1
            remain_time = gorilla.convert_seconds(remain_iter * iter_time.avg) # 将描述转换为 "时:分:秒" 的字符串

            # 打印信息
            print(f"epoch: {epoch}/{cfg.epochs} iter: {i + 1}/{len(train_dataloader)} "
                  f"lr: {lr:4f} loss: {loss_buffer.latest:.4f}({loss_buffer.avg:.4f}) "
                  f"data_time: {data_time.latest:.2f}({data_time.avg:.2f}) "
                  f"iter_time: {iter_time.latest:.2f}({iter_time.avg:.2f}) eta: {remain_time}")
        
        # 同步(仅分布式起作用)
        gorilla.synchronize()
        # 更新学习率
        # NOTE: 如果是 iteration 来控制学习率则放到循环内部
        #       同学们也可以手动控制学习率
        lr_scheduler.step()

        # 记录每个 epoch 的训练信息
        logger.info(f"epoch: {epoch}/{cfg.epochs}, train loss: {loss_buffer.avg}, time: {epoch_timer.since_start()}s")
        # 清空各个 buffer
        iter_time.clear()
        epoch_timer.clear()
        loss_buffer.clear()

        # 将必要信息保存到 meta 中
        meta = {"epoch": epoch,
                "iter": iter,
                "loss": loss_buffer.avg}
    
        # 保存权重
        # NOTE: 我们提供了 save_checkpoint 的 API 可供同学们保存权重,
        #       路径同学们可以自己决定,其中 `model` 和 `filename` 为
        #       必要输入参数,其余均可缺省(为None,不保存)
        checkpoint = os.path.join(cfg.log_dir, "epoch_{0:05d}.pth".format(epoch))
        gorilla.save_checkpoint(model=model,
                                filename=checkpoint,
                                optimizer=optimizer,
                                scheduler=lr_scheduler,
                                meta=meta)
        logger.info("Saving " + checkpoint)
        # 最后一个 epoch 始终保存为 "epoch_latest.pth" (Optional)
        latest_checkpoint = os.path.join(cfg.log_dir, "epoch_latest.pth")
        gorilla.save_checkpoint(model=model,
                                filename=latest_checkpoint,
                                optimizer=optimizer,
                                scheduler=lr_scheduler,
                                meta=meta)

        # 更新 epoch
        epoch += 1

以上就实现了一个训练脚本,训练的启动也非常简单,运行以下命令即可:

python plain_train.py --config ${config-file}

其中分布式相关的参数为 --num-gpus,默认为 1,要实现分布式训练,运行以下命令即可:

python plain_train.py --config ${config-file} --num-gpus ${gpus}

日志目录

关于日志目录,当运行 gorilla.collect_logger 可得到日志目录结构如下:(假设配置文件为 ./config/default.yaml)

root (project root)
  |
  └── log (default log root, can be modified in `gorilla.collect_logger`)
      |
      ├── default -> `log_dir` 
      |     ├── backup            (directory to backup)
      |     |    └── ...
      |     |
      |     ├── ${timestamp}.log  (log file)
      |     ├── events.out.xxx    (tensorboard file)
      |     ├── epoch_xxx.pth     (save checkpoint)
      |     ├── ...
      |     └── epoch_latest.pth
      |
      └── (other config directory)

同学们可以结合源码根据需要自行更改。

注册与构建机制

从源码中可以看到,dataloader/model/criterion 均用 build_dataloader/build_model/build_loss 实现。 这里就要说明一下 gorilla 的注册以及构建机制。 注册机制详情见 API文档,这里粗略说明一下。

import gorilla

@gorilla.MODELS.register_module()
class PointNetCls(nn.Module):
    def __init__(self,
                 in_channels: int = 3,
                 feat_size: int = 1024,
                 num_classes: int = 2,
                 dropout: float = 0.,
                 classifier_layer_dims: Iterable[int] = [512, 256],
                 feat_layer_dims: Iterable[int] = [64, 128],
                 activation=F.relu,
                 batchnorm: bool = True,
                 transposed_input: bool = False):
    ...

以上为 PointNet 的初始化,可以看到在类别声明前多了装饰器 @gorilla.MODELS.register_module() 这帮助我们将 PointNetCls 加入 gorilla.MODELS 的注册列表中,我们可以通过 gorilla.MODELS 进行查看:

>>> gorilla.MODELS
Registry(type=models)
PointNetCls:
    <class 'gorilla3d.nn.models.pointnet.pointnet.PointNetCls'>:

这个注册列表同学们可以看作一个包含相关类被调用的 字典。只要同学们在相应目录中写好相应的 __init__.py 脚本以及加上修饰符 @gorilla.REGISTRY.register_module() 即可实现相关的注册。 注册后,我们即可用构建函数进行构建,对应于 gorilla.MODELS 我们有 build_model 函数,利用构建函数构建的时候,参数是需要一一对应的(带有默认参数可不填),PointNetCls 的配置如下(.yaml 格式):

# model
model:
  type: PointNetCls
  in_channels: 3
  feat_size: 1024
  num_classes: 40
  dropout: 0.4
  classifier_layer_dims:
    - 512
    - 256
  feat_layer_dims:
    - 64
    - 128

可以看到与 PointNetCls 的初始化参数是一一对应的关系,由此我们可以利用 model=build_model(cfg.model) 实现 PointNetCls 的初始化。 build_module/build_loss/build_dataset 同理。

为了更好地抽象化我们提供了以下注册表和相应的构建函数:

LOSSES = Registry("losses")         # -> build_loss
MODELS = Registry("models")         # -> build_model
MODULES = Registry("modules")       # -> build_modules
DATASETS = Registry("datasets")     # -> build_dataset
OPTIMIZERS = Registry("optimizers") # -> build_optimizer
SCHEDULERS = Registry("schedulers") # -> build_scheduler

Tips:

关于 build_dataloader 与上面的构建函数有些许不同,其是对 build_dataset 的进一步封装,内容很简单,推荐同学们阅读一下源码