度量学习: 获取类中心 - bobo0810/Classification GitHub Wiki

场景

以人脸比对为例

  • 已知:预训练模型、新的大规模人脸比对数据集(类别数>10w)。
  • 目的:基于新数据集,微调训练。

当前问题 :训练时若随机初始化分类器,类别数过大将导致收敛过慢,甚至无法收敛。

  • 解决方案:基于新数据集,预训练模型先提取各类别的类中心特征,来初始化分类器权重。

    类中心为类内所有图片的特征均值

  • 原理:分类器维护着类中心矩阵,初始化权重后已基本可分,如下图。再进行微调训练可快速收敛,性能大幅提升。

image

实现

  1. 生成类中心
python  ./ExtraTools/build_class_center.py  --txt_path 新数据集路径 --weight_path 特征模型路径  --save_npy 类中心的保存路径 

访问更多参数

  1. 启用

metric_train.py 打开注释,加载指定类中心路径

# classcenter = np.load("home/xxx.npy")
# criterion.W.data = torch.from_numpy(classcenter.T)