训练一个上下文分类器 - nanpuhaha/SerpentAI GitHub Wiki

使用 Serpent.AI 框架附带的一些工具你就可以训练上下文场景分类器了,它们是特殊类型的机器学习模型。如果你想要流畅地管理游戏代理中的游戏流,那么正确训练的上下文分类器对你来说至关重要。在本指南中,我们将以 Super Hexagon 游戏为例子从头开始构建一个场景分类器。

解构游戏

想要创建一个超凡的上下文分类器的话,你至少需要花点时间深入研究一下这个游戏。将游戏解构成不同的阶段。以 Super Hexagon 为例,我们希望所创建的上下文分类器能识别四种不同的状态: main_menu 主菜单, level_select 关卡选择, game 游戏中和 game_over 游戏结束。

启动游戏

serpent launch SuperHexagon

捕捉上下文游戏帧图像

假设我们需要从 main_menu 主菜单捕捉图像的话:

  1. 手动进入 main_menu 主菜单界面。
  2. 启动一个帧捕捉器捕捉一些图像帧: serpent capture context SuperHexagon 0.5 main_menu
  3. 在捕获器运行的时候,尽量在不离开该状态的前提下多进行一些操作,以产生尽可能多的视觉变化。
  4. 等到至少有 200 帧被捕获的时候(越多越好!),单击终端,让框架将内容保存到磁盘,然后按 Ctrl-C 结束该捕捉器。
  5. 如果你到 datasets/collect_frames_for_context/main_menu 目录下查看一番,你应该可以看到这些帧了。
  6. 重复该操作,直到你捕获完所有的其他场景为止。(你至少需要两个场景样本来开始训练)

训练上下文分类器

接下来,你将执行的是你这辈子能执行的最简单的机器学习训练操作:

serpent train context 15

该命令会把你所收集的每个场景的图像帧拆分成训练集及验证集,并进行一次 15 个周期的模型训练。如果你使用的是 GPU 加速版本的 Tensorflow,训练操作将只需要几分钟的时间。不然的话就去喝杯咖啡吧。☕️

训练结束后,你要做的第一件事就是系统的分析一下终端输出的内容。 首先,我们来看一下每一个迭代的准确率,寻找一下准确率高于 0.9 的几个周期。 如果你的所有准确率值都在 0.9 以下,那你可能需要为每个场景收集更多的帧图像。如果这种情况没有出现的话,你已经成功地训练了一个模型!但不要过早的庆祝... 事实上你的模型有可能过拟合了。不过这都没关系。让我们再来看看终端输出。我们要寻找的是一些迭代周期的稳定点: 高精确度,下一个损失值几乎没有减少(或上涨!)。你要记住这个周期次数,因为我们要再次用这个迭代次数进行训练。比如说这个数字是 8 的话,我们就运行:

serpent train context 8

再次验证这个结果是否让你满意。你的模型文件将存储在 datasets/context_classifier.model 目录下。

将分类器打包到游戏代理中

将模型复制到 plugins/SerpentSuperHexagonGameAgentPlugin/files/ml_models/ 目录下。

在游戏代理代码中装载分类器

将以下内容添加到游戏代理的构造函数或帧图像处理程序的初始设置函数中:

plugin_path = offshoot.config["file_paths"]["plugins"]

context_classifier_path = f"{plugin_path}/SerpentSuperHexagonGameAgentPlugin/files/ml_models/context_classifier.model"

from serpent.machine_learning.context_classification.context_classifiers.cnn_inception_v3_context_classifier import CNNInceptionV3ContextClassifier
context_classifier = CNNInceptionV3ContextClassifier(input_shape=(384, 512, 3))  # Replace with the shape (rows, cols, channels) of your captured context frames

context_classifier.prepare_generators()
context_classifier.load_classifier(context_classifier_path)

self.machine_learning_models["context_classifier"] = context_classifier

在运行中使用上下文分类器

context = self.machine_learning_models["context_classifier"].predict(game_frame.frame)