SimpleDNN model({784, 256, 128, 64, 10}, {0.2, 0.4, 0.2});
AdamW optimizer(0.0002, 0.9, 0.999, 1e-8, 0.01);
int epochs = 100;
int batch_size = 64;
// 定义早停的默认参数
const int kDefaultPatience =
20; // 默认在验证准确率连续10个eval_interval没有改善后停止
const double kDefaultMinDelta = 1e-4; // 认为验证准确率有改善的最小阈值
图像增强代码:
// ====================================================================
// 更新的图像增强函数,包含随机旋转
// ====================================================================
torch::Tensor apply_image_augmentations(const torch::Tensor& input_batch,
int image_height, int image_width) {
torch::Tensor augmented_batch = input_batch.clone();
int64_t batch_size = augmented_batch.size(0);
int64_t flat_dim = augmented_batch.size(1);
if (flat_dim != image_height * image_width) {
DLOG(WARNING) << "Flat dimension (" << flat_dim
<< ") does not match image_height * image_width ("
<< image_height * image_width
<< "). Skipping augmentation for this batch.";
return input_batch;
}
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_real_distribution<> prob_dis(
0.0, 1.0); // 用于随机决策是否执行某种增强
for (int64_t i = 0; i < batch_size; ++i) {
torch::Tensor flat_image_tensor = augmented_batch[i];
torch::Tensor reshaped_image_tensor =
flat_image_tensor.reshape({image_height, image_width}).contiguous();
cv::Mat image_mat(image_height, image_width, CV_32FC1,
reshaped_image_tensor.data_ptr());
// 1. 随机水平翻转
if (prob_dis(gen) > 0.5) {
cv::flip(image_mat, image_mat, 1);
}
// 2. 随机亮度调整
if (prob_dis(gen) > 0.5) {
float delta_brightness = random_float(-0.2, 0.2, gen); // 更大的调整范围
image_mat += delta_brightness;
cv::max(image_mat, 0.0, image_mat);
cv::min(image_mat, 1.0, image_mat);
}
// 3. 随机对比度调整
if (prob_dis(gen) > 0.5) {
float alpha = random_float(0.8, 1.2, gen); // 对比度因子 [0.8, 1.2]
image_mat *= alpha;
cv::max(image_mat, 0.0, image_mat);
cv::min(image_mat, 1.0, image_mat);
}
// 4. 随机高斯模糊 (小核,避免过度模糊)
if (prob_dis(gen) > 0.7) { // 较低的概率
int kernel_size_val =
static_cast<int>(random_float(0, 1, gen) > 0.5 ? 3 : 5);
// 确保 kernel_size 是奇数
if (kernel_size_val % 2 == 0) {
kernel_size_val++;
}
cv::GaussianBlur(image_mat, image_mat,
cv::Size(kernel_size_val, kernel_size_val), 0);
}
// 5. 随机旋转 (使用 warpAffine)
if (prob_dis(gen) > 0.6) { // 例如,60% 的概率进行旋转
float angle =
random_float(-15.0, 15.0, gen); // 旋转角度范围 -15 到 15 度
float scale =
1.0; // 缩放因子,可以设置为 random_float(0.8, 1.2, gen) 进行随机缩放
// 计算旋转中心 (图像中心)
cv::Point2f center((image_width - 1) / 2.0F, (image_height - 1) / 2.0F);
// 获取旋转矩阵
cv::Mat rot_mat = cv::getRotationMatrix2D(center, angle, scale);
// 定义输出图像的大小。通常保持与原图相同大小
cv::Size output_size(image_width, image_height);
// 应用仿射变换。
// INTER_LINEAR 是线性插值。
// BORDER_CONSTANT 指定边界模式,即超出图像部分的填充方式。
// Scalar(0) 指定填充值,对于归一化到 [0,1] 的灰度图,0 代表黑色。
cv::warpAffine(image_mat, image_mat, rot_mat, output_size,
cv::INTER_LINEAR, cv::BORDER_CONSTANT, cv::Scalar(0));
}
torch::Tensor augmented_image_tensor = torch::from_blob(
image_mat.data, {image_height, image_width}, torch::kFloat32);
augmented_batch[i] = augmented_image_tensor.clone().reshape({flat_dim});
}
return augmented_batch;
}
最好的结果是发生在 epoch=91, test acc = 0.9835
epoch =100:
[E604 14:01:26.000000000 training.cc:318] train, 训练耗时: 175.56 秒
[E604 14:01:26.000000000 training.cc:320] train, 平均训练损失: 0.0756407
[E604 14:01:26.000000000 training.cc:321] train, 训练准确率: 0.977038
[E604 14:01:26.000000000 training.cc:328] train, 正在评估测试集:
[E604 14:01:28.000000000 training.cc:331] train, 测试准确率: 0.9812
[E604 14:01:28.000000000 training.cc:345] train, 测试准确率未改善,耐心计数: 9/20
[E604 14:01:28.000000000 training.cc:360] train, 训练完成,达到最大 epoch 数。
显著增加 epochs 数量。 这是一个低风险、高回报的策略
如果可能,实现更多类型的图像增强
SimpleDNN model({784, 256, 128, 64, 10}, {0.2, 0.35, 0.2}); // 尝试调整
int epochs = 200; // 尝试 200 或 300
SimpleDNN model({784, 256, 128, 64, 10});
AdamW optimizer(0.0002, 0.9, 0.999, 1e-8, 0.01);
int epochs = 300;
int batch_size = 64;
int eval_interval = 1;
// 学习率调度参数
int lr_decay_epochs = 50; // 每 50 个 epoch 衰减学习率
double lr_decay_rate = 0.5; // 学习率衰减到原来的一半
除了最后一个全连接层,其他全连接层增加了 BatchNorm1D,同时去除了dropout层。
现在的网络结构:
layers_ = {fc1_.get(), bn1_.get(), relue1_.get(), fc2_.get(),
bn2_.get(), relue2_.get(), fc3_.get(), bn3_.get(),
relue3_.get(), fc4_.get()};
图像增强:
// ====================================================================
// 更新的图像增强函数,包含随机旋转、平移、缩放、弹性形变
// ====================================================================
torch::Tensor apply_image_augmentations(const torch::Tensor& input_batch,
int image_height, int image_width) {
torch::Tensor augmented_batch = input_batch.clone();
int64_t batch_size = augmented_batch.size(0);
int64_t flat_dim = augmented_batch.size(1);
if (flat_dim != image_height * image_width) {
DLOG(WARNING) << "Flat dimension (" << flat_dim
<< ") does not match image_height * image_width ("
<< image_height * image_width
<< "). Skipping augmentation for this batch.";
return input_batch;
}
std::mt19937 gen(std::chrono::system_clock::now().time_since_epoch().count());
std::uniform_real_distribution<> prob_dis(
0.0, 1.0); // 用于随机决策是否执行某种增强
for (int64_t i = 0; i < batch_size; ++i) {
torch::Tensor flat_image_tensor = augmented_batch[i];
torch::Tensor reshaped_image_tensor =
flat_image_tensor.reshape({image_height, image_width}).contiguous();
cv::Mat image_mat(image_height, image_width, CV_32FC1,
reshaped_image_tensor.data_ptr());
// 1. 随机水平翻转 (较少用于 MNIST)
// if (prob_dis(gen) > 0.9) { // 降低概率,因为数字翻转通常是另一个数字
// cv::flip(image_mat, image_mat, 1);
// }
// 2. 随机亮度调整
if (prob_dis(gen) > 0.5) {
float delta_brightness = random_float(-0.2, 0.2, gen);
image_mat += delta_brightness;
cv::max(image_mat, 0.0, image_mat);
cv::min(image_mat, 1.0, image_mat);
}
// 3. 随机对比度调整
if (prob_dis(gen) > 0.5) {
float alpha_contrast = random_float(0.8, 1.2, gen);
image_mat *= alpha_contrast;
cv::max(image_mat, 0.0, image_mat);
cv::min(image_mat, 1.0, image_mat);
}
// 4. 随机高斯模糊 (小核,避免过度模糊)
if (prob_dis(gen) > 0.7) {
int kernel_size_val =
static_cast<int>(random_float(0, 1, gen) > 0.5 ? 3 : 5);
if (kernel_size_val % 2 == 0) {
kernel_size_val++;
}
cv::GaussianBlur(image_mat, image_mat,
cv::Size(kernel_size_val, kernel_size_val), 0);
}
// 5. 随机旋转 (使用 warpAffine)
if (prob_dis(gen) > 0.6) {
float angle = random_float(-15.0, 15.0, gen);
float scale = 1.0; // 初始缩放为 1.0
cv::Point2f center((image_width - 1) / 2.0F, (image_height - 1) / 2.0F);
cv::Mat rot_mat = cv::getRotationMatrix2D(center, angle, scale);
cv::Size output_size(image_width, image_height);
cv::warpAffine(image_mat, image_mat, rot_mat, output_size,
cv::INTER_LINEAR, cv::BORDER_CONSTANT, cv::Scalar(0));
}
// ====== 新增增强 ======
// 6. 随机平移 (Translation)
if (prob_dis(gen) > 0.6) { // 60% 概率进行平移
float max_translation_pixels = 2; // 最大平移像素数,可以调整
float tx =
random_float(-max_translation_pixels, max_translation_pixels, gen);
float ty =
random_float(-max_translation_pixels, max_translation_pixels, gen);
cv::Mat trans_mat = (cv::Mat_<float>(2, 3) << 1, 0, tx, 0, 1, ty);
cv::warpAffine(image_mat, image_mat, trans_mat, image_mat.size(),
cv::INTER_LINEAR, cv::BORDER_CONSTANT, cv::Scalar(0));
}
// 7. 随机缩放 (Scaling)
// 缩放应与旋转结合使用,或者单独进行
if (prob_dis(gen) > 0.6) { // 60% 概率进行缩放
float scale_factor = random_float(0.8, 1.2, gen); // 缩放范围 [0.8, 1.2]
cv::Point2f center((image_width - 1) / 2.0F, (image_height - 1) / 2.0F);
cv::Mat scale_mat = cv::getRotationMatrix2D(
center, 0, scale_factor); // 0 度旋转,只做缩放
cv::warpAffine(image_mat, image_mat, scale_mat, image_mat.size(),
cv::INTER_LINEAR, cv::BORDER_CONSTANT, cv::Scalar(0));
}
// 8. 弹性形变 (Elastic Distortion)
// 这是一个计算量较大的增强,可以降低其应用概率或参数范围
if (prob_dis(gen) > 0.8) { // 例如,20% 的概率应用弹性形变
float alpha = random_float(20.0, 40.0, gen); // 形变强度
float sigma = random_float(4.0, 6.0, gen); // 高斯核标准差
image_mat = elastic_transform(image_mat, alpha, sigma, gen);
}
// ============================
torch::Tensor augmented_image_tensor = torch::from_blob(
image_mat.data, {image_height, image_width}, torch::kFloat32);
augmented_batch[i] = augmented_image_tensor.clone().reshape({flat_dim});
}
return augmented_batch;
}
在epoch=109,在测试集上的准确率:0.9924