Core Scripts - Rvosuke/TLLM GitHub Wiki
Time-LLM核心脚本详解
本文档详细介绍Time-LLM项目中的核心Python脚本,包括run_main.py
、robustness_test.py
和missing_data_test.py
,解析其设计思路、功能实现和运行流程。
目录
- run_main.py - 主运行脚本
- robustness_test.py - 噪声鲁棒性测试
- missing_data_test.py - 缺失数据测试
- 共同设计特点与最佳实践
- 脚本使用指南
- 常见问题解决
run_main.py - 主运行脚本
run_main.py
是Time-LLM项目的核心运行脚本,负责模型的训练、验证和测试功能。
功能概述
- 训练模式:训练TimeLLM、Autoformer或DLinear模型
- 测试模式:加载预训练模型并进行预测与评估
- 可视化:支持预测结果可视化
代码流程
-
参数解析:使用
argparse
模块解析命令行参数,包括:- 任务类型(长期/短期预测)
- 数据集配置
- 模型参数
- 训练设置
-
分布式环境初始化:
if torch.cuda.is_available(): ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) deepspeed_plugin = DeepSpeedPlugin(hf_ds_config='./ds_config_zero2.json') accelerator = Accelerator(kwargs_handlers=[ddp_kwargs], deepspeed_plugin=deepspeed_plugin) else: accelerator = Accelerator()
脚本使用Accelerate库实现分布式训练,根据是否有GPU使用不同的配置策略。
-
数据加载:
train_data, train_loader = data_provider(args, 'train') test_data, test_loader = data_provider(args, 'test')
通过
data_provider
函数加载训练和测试数据。 -
模型初始化:
if args.model == 'Autoformer': model = Autoformer.Model(args).float() elif args.model == 'DLinear': model = DLinear.Model(args).float() else: model = TimeLLM.Model(args).float()
根据参数选择初始化相应的模型。
-
训练流程(
is_training=1
时):- 初始化优化器和学习率调度器
- 设置损失函数和评估指标
- 使用Accelerate准备数据加载器、模型和优化器
- 循环训练多个epoch
- 每个epoch执行以下步骤:
- 按批次迭代训练数据
- 构建解码器输入(使用teacher forcing技术)
- 前向计算得到预测输出
- 计算损失并反向传播
- 更新参数和学习率
- 验证模型性能
- 应用早停策略
-
测试流程(
is_training=0
时):- 加载预训练模型权重
- 评估模型在测试集上的性能
- 可视化预测结果
- 保存测试结果
设计思路
-
模块化设计:
- 数据处理、模型定义和训练逻辑分离
- 支持多种模型架构(TimeLLM、Autoformer、DLinear)
-
灵活配置:
- 通过命令行参数实现高度可配置化
- 支持不同的任务类型和数据集
-
分布式训练支持:
- 通过Accelerate和DeepSpeed实现高效分布式训练
- 自动适应不同硬件环境(CPU/GPU)
-
混合精度训练:
- 支持自动混合精度训练,提高训练速度和内存效率
- 适配不同的计算精度要求
robustness_test.py - 噪声鲁棒性测试
robustness_test.py
用于评估模型对输入数据中噪声的鲁棒性,通过向测试数据添加不同水平的高斯噪声,分析模型性能变化。
功能概述
- 噪声注入:向时间序列数据中添加高斯噪声
- 性能评估:在不同噪声水平下评估模型性能
- 结果可视化:绘制噪声水平与模型性能的关系图
核心函数解析
-
噪声添加函数:
def add_noise_to_data(data_loader, noise_level=0.0, accelerator=None): """向数据集添加高斯噪声""" noisy_data_loader = [] device = accelerator.device if accelerator else torch.device("cpu") for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(data_loader): if not batch_x.is_cuda and device.type == 'cuda': batch_x = batch_x.to(device) if noise_level > 0: # 创建与输入相同数据类型的噪声 dtype = batch_x.dtype noise = torch.randn_like(batch_x, dtype=torch.float32) * noise_level noise = noise.to(dtype=dtype, device=device) # 添加噪声 batch_x = batch_x + noise noisy_data_loader.append((batch_x, batch_y, batch_x_mark, batch_y_mark)) return noisy_data_loader
该函数仅向输入序列添加噪声,保持标签序列不变,并确保噪声的数据类型与输入一致。
-
模型评估函数:
def evaluate_model_with_noise(args, model, test_data, test_loader, noise_levels, accelerator): """在不同噪声水平下评估模型性能""" criterion = nn.MSELoss() mae_metric = nn.L1Loss() mse_losses = [] mae_losses = [] for noise_level in noise_levels: # 添加噪声并评估 if noise_level == 0: test_loss, test_mae_loss = vali(args, accelerator, model, test_data, test_loader, criterion, mae_metric) else: noisy_test_loader = add_noise_to_data(test_loader, noise_level, accelerator) noisy_loader = NoiseDataLoader(noisy_test_loader) test_loss, test_mae_loss = vali(args, accelerator, model, test_data, noisy_loader, criterion, mae_metric) mse_losses.append(test_loss) mae_losses.append(test_mae_loss) return mse_losses, mae_losses
该函数对每个噪声水平执行评估,返回MSE和MAE损失列表。
-
结果可视化函数:
def plot_results(noise_levels, mse_losses, mae_losses, output_dir, model_id, task_name, pred_len): """绘制噪声水平与损失的关系图""" plt.figure(figsize=(12, 5)) # 绘制MSE损失 plt.subplot(1, 2, 1) plt.plot(noise_levels, mse_losses, 'o-', linewidth=2, markersize=8) # ...绘图代码 # 保存结果 plt.savefig(os.path.join(output_dir, f'robustness_test_{model_id}_{task_name}_pred{pred_len}.png'), dpi=300) # 保存CSV数据 results_df = pd.DataFrame({ 'noise_level': noise_levels, 'mse_loss': mse_losses, 'mae_loss': mae_losses }) results_df.to_csv(os.path.join(output_dir, f'robustness_test_{model_id}_{task_name}_pred{pred_len}.csv'), index=False)
该函数创建直观的可视化图表,并将原始数据保存为CSV格式。
设计思路
-
渐进式噪声测试:
- 从无噪声逐步增加噪声水平
- 分析模型性能随噪声增加的衰减曲线
-
数据类型兼容性:
- 特别注意保持噪声与原始数据的数据类型一致
- 避免混合精度训练环境中的类型转换问题
-
非侵入式设计:
- 不修改原始模型架构,仅改变输入数据
- 便于快速评估任何预训练模型的鲁棒性
missing_data_test.py - 缺失数据测试
missing_data_test.py
用于评估模型处理输入序列中缺失数据的能力,模拟真实场景中的数据缺失问题。
功能概述
- 缺失数据模拟:在输入序列中随机创建NaN值
- 数据插值:使用插值方法填充缺失值
- 性能评估:在不同缺失率下评估模型性能
- 可视化:展示原始数据、缺失数据和插值效果
核心函数解析
-
缺失数据创建函数:
def create_missing_data(batch_x, missing_rate=0.1, mask=None): """在输入序列中创建缺失值""" if mask is None: # 创建随机掩码,1表示保留,0表示缺失 mask = torch.FloatTensor(batch_x.shape).uniform_() > missing_rate mask = mask.to(batch_x.device) # 将缺失位置的值设为NaN missing_x = batch_x.clone() missing_x[~mask] = float('nan') return missing_x, mask
该函数根据指定的缺失率随机创建缺失值,并使用NaN标记。
-
插值填充函数:
def interpolate_missing_values(missing_x): """使用插值法填充缺失值""" batch_size, seq_len, feature_dim = missing_x.shape device = missing_x.device # 将数据移到CPU进行处理 x_cpu = missing_x.detach().cpu().numpy() interpolated_x = np.zeros_like(x_cpu) # 为每个batch和每个特征进行插值 for b in range(batch_size): for f in range(feature_dim): # 获取当前序列 series = x_cpu[b, :, f] # 找出非NaN值的索引 valid_indices = np.where(~np.isnan(series))[0] valid_values = series[valid_indices] if len(valid_indices) <= 1: # 如果只有0或1个有效值,无法插值 interpolated_x[b, :, f] = 0.0 # 填充为0 else: if len(valid_indices) < seq_len: # 如果有缺失值 # 使用线性插值填充缺失值 interp_func = interpolate.interp1d( valid_indices, valid_values, kind='linear', bounds_error=False, fill_value=(valid_values[0], valid_values[-1]) ) all_indices = np.arange(seq_len) interpolated_series = interp_func(all_indices) interpolated_x[b, :, f] = interpolated_series else: # 如果没有缺失值 interpolated_x[b, :, f] = series # 将插值后的数据转回tensor并移回原设备 return torch.FloatTensor(interpolated_x).to(device)
该函数使用Scipy的插值功能为缺失数据进行线性插值,处理边界情况并保持数据类型一致。
-
可视化函数:
def visualize_missing_data_examples(test_loader, missing_rates, output_dir, num_examples=3): """可视化展示原始数据、缺失数据和插值后的数据""" os.makedirs(output_dir, exist_ok=True) # 获取样本并可视化 samples = [] for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader): if i < num_examples: samples.append((batch_x[0], batch_y[0], batch_x_mark[0], batch_y_mark[0])) else: break # 绘制不同缺失率下的数据和插值效果 for sample_idx, (x, y, x_mark, y_mark) in enumerate(samples): # ...绘图代码
该函数创建多个图表,直观展示缺失值的位置和插值的效果,帮助理解插值方法的影响。
设计思路
-
现实场景模拟:
- 随机缺失模式反映真实世界中的传感器故障或数据丢失
- 不同缺失率反映现实中不同的数据质量情况
-
数据修复策略:
- 使用线性插值作为默认的缺失值修复方法
- 处理极端情况(如几乎所有数据都缺失)
-
可视化分析:
- 直观展示原始数据、缺失位置和插值效果
- 帮助理解缺失值对预测性能的影响
共同设计特点与最佳实践
以上三个脚本展现了几个共同的设计特点和最佳实践:
1. 统一的命令行接口
所有脚本使用相同风格的argparse
参数解析,保持参数名称一致性,便于使用者快速上手。
2. 异常处理机制
所有脚本都包含了全面的异常处理,特别是在模型加载环节:
try:
# 尝试直接加载到当前设备
state_dict = torch.load(args.checkpoint_path, map_location=accelerator.device)
model.load_state_dict(state_dict)
except Exception as e:
# 备选加载方法
try:
state_dict = torch.load(args.checkpoint_path, map_location='cpu')
model.load_state_dict(state_dict)
except Exception as e2:
# 继续使用未初始化模型
这种多层次的异常处理确保了脚本在不同环境下的鲁棒性。
3. 分布式计算支持
所有脚本都使用了Accelerate库来支持分布式计算:
accelerator = Accelerator()
test_loader, model = accelerator.prepare(test_loader, model)
这种设计使得脚本能够无缝适应单GPU、多GPU和CPU环境。
4. 结果可视化与存储
所有测试脚本都包含了完善的结果可视化和存储机制:
- 使用Matplotlib创建可视化图表
- 将数值结果保存为CSV格式
- 使用统一的命名规则便于结果分析
5. 精度管理
脚本特别关注数据类型和精度问题:
- 确保噪声和插值后的数据与原始数据类型一致
- 支持混合精度训练
脚本使用指南
run_main.py
训练模式:
python run_main.py --task_name long_term_forecast --is_training 1 --model TimeLLM \
--data beam --data_path processed_number_value_cell_2_beam_19.csv \
--model_id CELL_2_BEAM_19 --root_path ./dataset/processed/ \
--seq_len 12 --label_len 12 --pred_len 24 --batch_size 32 --train_rate 0.8
测试模式:
python run_main.py --task_name long_term_forecast --is_training 0 --model TimeLLM \
--data beam --data_path processed_number_value_cell_2_beam_19.csv \
--model_id CELL_2_BEAM_19 --root_path ./dataset/processed/ \
--checkpoint_path ./results/number_value_cell_2_beam_19/long_term_train8.ckpt \
--seq_len 12 --label_len 12 --pred_len 24 --batch_size 32 --visualize_predictions
robustness_test.py
python robustness_test.py --model TimeLLM --task_name long_term_forecast \
--checkpoint_path ./results/number_value_cell_2_beam_19/long_term_train8.ckpt \
--data beam --model_id CELL_2_BEAM_19 \
--output_dir ./results/number_value_cell_2_beam_19/long_term_train8/ \
--data_path processed_number_value_cell_2_beam_19.csv \
--root_path ./dataset/processed/ \
--noise_min 0.01 --noise_max 0.05 --noise_steps 5
missing_data_test.py
python missing_data_test.py --model TimeLLM --task_name long_term_forecast \
--checkpoint_path ./results/number_value_cell_2_beam_19/long_term_train8.ckpt \
--data beam --model_id CELL_2_BEAM_19 \
--output_dir ./results/number_value_cell_2_beam_19/long_term_train8/ \
--data_path processed_number_value_cell_2_beam_19.csv \
--root_path ./dataset/processed/ \
--missing_min 0.0 --missing_max 0.25 --missing_steps 6 \
--interpolation linear --visualize_examples
常见问题解决
1. 模型加载问题
问题:加载预训练模型时出现数据类型不匹配错误。
解决方案:
- 确保加载时使用正确的
map_location
参数 - 检查模型保存时的数据类型和加载时的期望类型
- 使用脚本中的备选加载方法,先加载到CPU然后再转移到目标设备
2. 内存溢出问题
问题:处理大型数据集或大型模型时出现内存溢出。
解决方案:
- 减小批量大小(
batch_size
) - 启用混合精度训练(
--use_amp
) - 减少语言模型层数(
--llm_layers
) - 使用较小的语言模型(如从LLAMA切换到BERT)
3. 数据类型不一致问题
问题:在噪声测试或缺失数据测试中出现数据类型不一致警告或错误。
解决方案:
- 确保生成的噪声和原始数据类型一致
- 在插值过程中注意保持数据类型
- 使用
to(dtype=x.dtype)
显式指定数据类型
4. 分布式训练问题
问题:多GPU环境下模型训练不同步或崩溃。
解决方案:
- 确保DeepSpeed配置正确(
ds_config_zero2.json
) - 使用
find_unused_parameters=True
参数处理未使用参数 - 在不同GPU上使用相同的随机种子确保一致性
5. 插值异常问题
问题:极端缺失率下插值失败或结果异常。
解决方案:
- 使用更鲁棒的插值方法(如'nearest'而非'linear')
- 对极端缺失情况(如只有0或1个有效点)使用特殊处理逻辑
- 考虑在预处理阶段完全删除缺失严重的样本