DNN训练mnist 日志 - housekeeper-software/tech GitHub Wiki

尝试1

参数

SimpleDNN model({784, 256, 128, 64, 10}, {0.2, 0.4, 0.2});
AdamW optimizer(0.0002, 0.9, 0.999, 1e-8, 0.01);
int epochs = 100;
int batch_size = 64;
// 定义早停的默认参数
const int kDefaultPatience =
    20;  // 默认在验证准确率连续10个eval_interval没有改善后停止
const double kDefaultMinDelta = 1e-4;  // 认为验证准确率有改善的最小阈值

图像增强代码:

// ====================================================================
// 更新的图像增强函数,包含随机旋转
// ====================================================================
torch::Tensor apply_image_augmentations(const torch::Tensor& input_batch,
                                        int image_height, int image_width) {
  torch::Tensor augmented_batch = input_batch.clone();

  int64_t batch_size = augmented_batch.size(0);
  int64_t flat_dim = augmented_batch.size(1);

  if (flat_dim != image_height * image_width) {
    DLOG(WARNING) << "Flat dimension (" << flat_dim
                  << ") does not match image_height * image_width ("
                  << image_height * image_width
                  << "). Skipping augmentation for this batch.";
    return input_batch;
  }

  std::random_device rd;
  std::mt19937 gen(rd());
  std::uniform_real_distribution<> prob_dis(
      0.0, 1.0);  // 用于随机决策是否执行某种增强

  for (int64_t i = 0; i < batch_size; ++i) {
    torch::Tensor flat_image_tensor = augmented_batch[i];
    torch::Tensor reshaped_image_tensor =
        flat_image_tensor.reshape({image_height, image_width}).contiguous();
    cv::Mat image_mat(image_height, image_width, CV_32FC1,
                      reshaped_image_tensor.data_ptr());

    // 1. 随机水平翻转
    if (prob_dis(gen) > 0.5) {
      cv::flip(image_mat, image_mat, 1);
    }

    // 2. 随机亮度调整
    if (prob_dis(gen) > 0.5) {
      float delta_brightness = random_float(-0.2, 0.2, gen);  // 更大的调整范围
      image_mat += delta_brightness;
      cv::max(image_mat, 0.0, image_mat);
      cv::min(image_mat, 1.0, image_mat);
    }

    // 3. 随机对比度调整
    if (prob_dis(gen) > 0.5) {
      float alpha = random_float(0.8, 1.2, gen);  // 对比度因子 [0.8, 1.2]
      image_mat *= alpha;
      cv::max(image_mat, 0.0, image_mat);
      cv::min(image_mat, 1.0, image_mat);
    }

    // 4. 随机高斯模糊 (小核,避免过度模糊)
    if (prob_dis(gen) > 0.7) {  // 较低的概率
      int kernel_size_val =
          static_cast<int>(random_float(0, 1, gen) > 0.5 ? 3 : 5);
      // 确保 kernel_size 是奇数
      if (kernel_size_val % 2 == 0) {
        kernel_size_val++;
      }
      cv::GaussianBlur(image_mat, image_mat,
                       cv::Size(kernel_size_val, kernel_size_val), 0);
    }

    // 5. 随机旋转 (使用 warpAffine)
    if (prob_dis(gen) > 0.6) {  // 例如,60% 的概率进行旋转
      float angle =
          random_float(-15.0, 15.0, gen);  // 旋转角度范围 -15 到 15 度
      float scale =
          1.0;  // 缩放因子,可以设置为 random_float(0.8, 1.2, gen) 进行随机缩放

      // 计算旋转中心 (图像中心)
      cv::Point2f center((image_width - 1) / 2.0F, (image_height - 1) / 2.0F);

      // 获取旋转矩阵
      cv::Mat rot_mat = cv::getRotationMatrix2D(center, angle, scale);

      // 定义输出图像的大小。通常保持与原图相同大小
      cv::Size output_size(image_width, image_height);

      // 应用仿射变换。
      // INTER_LINEAR 是线性插值。
      // BORDER_CONSTANT 指定边界模式,即超出图像部分的填充方式。
      // Scalar(0) 指定填充值,对于归一化到 [0,1] 的灰度图,0 代表黑色。
      cv::warpAffine(image_mat, image_mat, rot_mat, output_size,
                     cv::INTER_LINEAR, cv::BORDER_CONSTANT, cv::Scalar(0));
    }

    torch::Tensor augmented_image_tensor = torch::from_blob(
        image_mat.data, {image_height, image_width}, torch::kFloat32);
    augmented_batch[i] = augmented_image_tensor.clone().reshape({flat_dim});
  }

  return augmented_batch;
}

结果

最好的结果是发生在 epoch=91, test acc = 0.9835
epoch =100:
[E604 14:01:26.000000000 training.cc:318] train,   训练耗时: 175.56  秒
[E604 14:01:26.000000000 training.cc:320] train,   平均训练损失: 0.0756407
[E604 14:01:26.000000000 training.cc:321] train,   训练准确率: 0.977038
[E604 14:01:26.000000000 training.cc:328] train,   正在评估测试集:
[E604 14:01:28.000000000 training.cc:331] train,   测试准确率: 0.9812
[E604 14:01:28.000000000 training.cc:345] train,   测试准确率未改善,耐心计数: 9/20
[E604 14:01:28.000000000 training.cc:360] train,   训练完成,达到最大 epoch 数。

AI建议

显著增加 epochs 数量。 这是一个低风险、高回报的策略
如果可能,实现更多类型的图像增强

SimpleDNN model({784, 256, 128, 64, 10}, {0.2, 0.35, 0.2}); // 尝试调整
int epochs = 200; // 尝试 200 或 300

尝试2

参数

SimpleDNN model({784, 256, 128, 64, 10});
AdamW optimizer(0.0002, 0.9, 0.999, 1e-8, 0.01);
  int epochs = 300;
  int batch_size = 64;
  int eval_interval = 1;

  // 学习率调度参数
  int lr_decay_epochs = 50;    // 每 50 个 epoch 衰减学习率
  double lr_decay_rate = 0.5;  // 学习率衰减到原来的一半
除了最后一个全连接层,其他全连接层增加了 BatchNorm1D,同时去除了dropout层。
现在的网络结构:
layers_ = {fc1_.get(),    bn1_.get(),    relue1_.get(), fc2_.get(),
             bn2_.get(),    relue2_.get(), fc3_.get(),    bn3_.get(),
             relue3_.get(), fc4_.get()};

图像增强:

// ====================================================================
// 更新的图像增强函数,包含随机旋转、平移、缩放、弹性形变
// ====================================================================
torch::Tensor apply_image_augmentations(const torch::Tensor& input_batch,
                                        int image_height, int image_width) {
  torch::Tensor augmented_batch = input_batch.clone();

  int64_t batch_size = augmented_batch.size(0);
  int64_t flat_dim = augmented_batch.size(1);

  if (flat_dim != image_height * image_width) {
    DLOG(WARNING) << "Flat dimension (" << flat_dim
                  << ") does not match image_height * image_width ("
                  << image_height * image_width
                  << "). Skipping augmentation for this batch.";
    return input_batch;
  }

  std::mt19937 gen(std::chrono::system_clock::now().time_since_epoch().count());
  std::uniform_real_distribution<> prob_dis(
      0.0, 1.0);  // 用于随机决策是否执行某种增强

  for (int64_t i = 0; i < batch_size; ++i) {
    torch::Tensor flat_image_tensor = augmented_batch[i];
    torch::Tensor reshaped_image_tensor =
        flat_image_tensor.reshape({image_height, image_width}).contiguous();
    cv::Mat image_mat(image_height, image_width, CV_32FC1,
                      reshaped_image_tensor.data_ptr());

    // 1. 随机水平翻转 (较少用于 MNIST)
    // if (prob_dis(gen) > 0.9) { // 降低概率,因为数字翻转通常是另一个数字
    //   cv::flip(image_mat, image_mat, 1);
    // }

    // 2. 随机亮度调整
    if (prob_dis(gen) > 0.5) {
      float delta_brightness = random_float(-0.2, 0.2, gen);
      image_mat += delta_brightness;
      cv::max(image_mat, 0.0, image_mat);
      cv::min(image_mat, 1.0, image_mat);
    }

    // 3. 随机对比度调整
    if (prob_dis(gen) > 0.5) {
      float alpha_contrast = random_float(0.8, 1.2, gen);
      image_mat *= alpha_contrast;
      cv::max(image_mat, 0.0, image_mat);
      cv::min(image_mat, 1.0, image_mat);
    }

    // 4. 随机高斯模糊 (小核,避免过度模糊)
    if (prob_dis(gen) > 0.7) {
      int kernel_size_val =
          static_cast<int>(random_float(0, 1, gen) > 0.5 ? 3 : 5);
      if (kernel_size_val % 2 == 0) {
        kernel_size_val++;
      }
      cv::GaussianBlur(image_mat, image_mat,
                       cv::Size(kernel_size_val, kernel_size_val), 0);
    }

    // 5. 随机旋转 (使用 warpAffine)
    if (prob_dis(gen) > 0.6) {
      float angle = random_float(-15.0, 15.0, gen);
      float scale = 1.0;  // 初始缩放为 1.0

      cv::Point2f center((image_width - 1) / 2.0F, (image_height - 1) / 2.0F);
      cv::Mat rot_mat = cv::getRotationMatrix2D(center, angle, scale);
      cv::Size output_size(image_width, image_height);
      cv::warpAffine(image_mat, image_mat, rot_mat, output_size,
                     cv::INTER_LINEAR, cv::BORDER_CONSTANT, cv::Scalar(0));
    }

    // ====== 新增增强 ======

    // 6. 随机平移 (Translation)
    if (prob_dis(gen) > 0.6) {           // 60% 概率进行平移
      float max_translation_pixels = 2;  // 最大平移像素数,可以调整
      float tx =
          random_float(-max_translation_pixels, max_translation_pixels, gen);
      float ty =
          random_float(-max_translation_pixels, max_translation_pixels, gen);

      cv::Mat trans_mat = (cv::Mat_<float>(2, 3) << 1, 0, tx, 0, 1, ty);
      cv::warpAffine(image_mat, image_mat, trans_mat, image_mat.size(),
                     cv::INTER_LINEAR, cv::BORDER_CONSTANT, cv::Scalar(0));
    }

    // 7. 随机缩放 (Scaling)
    // 缩放应与旋转结合使用,或者单独进行
    if (prob_dis(gen) > 0.6) {                           // 60% 概率进行缩放
      float scale_factor = random_float(0.8, 1.2, gen);  // 缩放范围 [0.8, 1.2]

      cv::Point2f center((image_width - 1) / 2.0F, (image_height - 1) / 2.0F);
      cv::Mat scale_mat = cv::getRotationMatrix2D(
          center, 0, scale_factor);  // 0 度旋转,只做缩放
      cv::warpAffine(image_mat, image_mat, scale_mat, image_mat.size(),
                     cv::INTER_LINEAR, cv::BORDER_CONSTANT, cv::Scalar(0));
    }

    // 8. 弹性形变 (Elastic Distortion)
    // 这是一个计算量较大的增强,可以降低其应用概率或参数范围
    if (prob_dis(gen) > 0.8) {  // 例如,20% 的概率应用弹性形变
      float alpha = random_float(20.0, 40.0, gen);  // 形变强度
      float sigma = random_float(4.0, 6.0, gen);    // 高斯核标准差
      image_mat = elastic_transform(image_mat, alpha, sigma, gen);
    }
    // ============================

    torch::Tensor augmented_image_tensor = torch::from_blob(
        image_mat.data, {image_height, image_width}, torch::kFloat32);
    augmented_batch[i] = augmented_image_tensor.clone().reshape({flat_dim});
  }

  return augmented_batch;
}

结果

在epoch=109,在测试集上的准确率:0.9924

⚠️ **GitHub.com Fallback** ⚠️