度量学习: 获取类中心 - bobo0810/Classification GitHub Wiki
场景
以人脸比对为例
- 已知:预训练模型、新的大规模人脸比对数据集(类别数>10w)。
- 目的:基于新数据集,微调训练。
当前问题 :训练时若随机初始化分类器,类别数过大将导致收敛过慢,甚至无法收敛。
-
解决方案:基于新数据集,预训练模型先提取各类别的类中心特征,来初始化分类器权重。
类中心为类内所有图片的特征均值
-
原理:分类器维护着类中心矩阵,初始化权重后已基本可分,如下图。再进行微调训练可快速收敛,性能大幅提升。
实现
- 生成类中心
python ./ExtraTools/build_class_center.py --txt_path 新数据集路径 --weight_path 特征模型路径 --save_npy 类中心的保存路径
- 启用
metric_train.py
打开注释,加载指定类中心路径
# classcenter = np.load("home/xxx.npy")
# criterion.W.data = torch.from_numpy(classcenter.T)