Core Scripts - Rvosuke/TLLM GitHub Wiki

Time-LLM核心脚本详解

本文档详细介绍Time-LLM项目中的核心Python脚本,包括run_main.pyrobustness_test.pymissing_data_test.py,解析其设计思路、功能实现和运行流程。

目录

  1. run_main.py - 主运行脚本
  2. robustness_test.py - 噪声鲁棒性测试
  3. missing_data_test.py - 缺失数据测试
  4. 共同设计特点与最佳实践
  5. 脚本使用指南
  6. 常见问题解决

run_main.py - 主运行脚本

run_main.py是Time-LLM项目的核心运行脚本,负责模型的训练、验证和测试功能。

功能概述

  1. 训练模式:训练TimeLLM、Autoformer或DLinear模型
  2. 测试模式:加载预训练模型并进行预测与评估
  3. 可视化:支持预测结果可视化

代码流程

  1. 参数解析:使用argparse模块解析命令行参数,包括:

    • 任务类型(长期/短期预测)
    • 数据集配置
    • 模型参数
    • 训练设置
  2. 分布式环境初始化

    if torch.cuda.is_available():
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        deepspeed_plugin = DeepSpeedPlugin(hf_ds_config='./ds_config_zero2.json')
        accelerator = Accelerator(kwargs_handlers=[ddp_kwargs], deepspeed_plugin=deepspeed_plugin)
    else:
        accelerator = Accelerator()
    

    脚本使用Accelerate库实现分布式训练,根据是否有GPU使用不同的配置策略。

  3. 数据加载

    train_data, train_loader = data_provider(args, 'train')
    test_data, test_loader = data_provider(args, 'test')
    

    通过data_provider函数加载训练和测试数据。

  4. 模型初始化

    if args.model == 'Autoformer':
        model = Autoformer.Model(args).float()
    elif args.model == 'DLinear':
        model = DLinear.Model(args).float()
    else:
        model = TimeLLM.Model(args).float()
    

    根据参数选择初始化相应的模型。

  5. 训练流程is_training=1时):

    • 初始化优化器和学习率调度器
    • 设置损失函数和评估指标
    • 使用Accelerate准备数据加载器、模型和优化器
    • 循环训练多个epoch
    • 每个epoch执行以下步骤:
      • 按批次迭代训练数据
      • 构建解码器输入(使用teacher forcing技术)
      • 前向计算得到预测输出
      • 计算损失并反向传播
      • 更新参数和学习率
      • 验证模型性能
      • 应用早停策略
  6. 测试流程is_training=0时):

    • 加载预训练模型权重
    • 评估模型在测试集上的性能
    • 可视化预测结果
    • 保存测试结果

设计思路

  1. 模块化设计

    • 数据处理、模型定义和训练逻辑分离
    • 支持多种模型架构(TimeLLM、Autoformer、DLinear)
  2. 灵活配置

    • 通过命令行参数实现高度可配置化
    • 支持不同的任务类型和数据集
  3. 分布式训练支持

    • 通过Accelerate和DeepSpeed实现高效分布式训练
    • 自动适应不同硬件环境(CPU/GPU)
  4. 混合精度训练

    • 支持自动混合精度训练,提高训练速度和内存效率
    • 适配不同的计算精度要求

robustness_test.py - 噪声鲁棒性测试

robustness_test.py用于评估模型对输入数据中噪声的鲁棒性,通过向测试数据添加不同水平的高斯噪声,分析模型性能变化。

功能概述

  1. 噪声注入:向时间序列数据中添加高斯噪声
  2. 性能评估:在不同噪声水平下评估模型性能
  3. 结果可视化:绘制噪声水平与模型性能的关系图

核心函数解析

  1. 噪声添加函数

    def add_noise_to_data(data_loader, noise_level=0.0, accelerator=None):
        """向数据集添加高斯噪声"""
        noisy_data_loader = []
        device = accelerator.device if accelerator else torch.device("cpu")
        
        for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(data_loader):
            if not batch_x.is_cuda and device.type == 'cuda':
                batch_x = batch_x.to(device)
                
            if noise_level > 0:
                # 创建与输入相同数据类型的噪声
                dtype = batch_x.dtype
                noise = torch.randn_like(batch_x, dtype=torch.float32) * noise_level
                noise = noise.to(dtype=dtype, device=device)
                
                # 添加噪声
                batch_x = batch_x + noise
            
            noisy_data_loader.append((batch_x, batch_y, batch_x_mark, batch_y_mark))
        
        return noisy_data_loader
    

    该函数仅向输入序列添加噪声,保持标签序列不变,并确保噪声的数据类型与输入一致。

  2. 模型评估函数

    def evaluate_model_with_noise(args, model, test_data, test_loader, noise_levels, accelerator):
        """在不同噪声水平下评估模型性能"""
        criterion = nn.MSELoss()
        mae_metric = nn.L1Loss()
        
        mse_losses = []
        mae_losses = []
        
        for noise_level in noise_levels:
            # 添加噪声并评估
            if noise_level == 0:
                test_loss, test_mae_loss = vali(args, accelerator, model, test_data, test_loader, criterion, mae_metric)
            else:
                noisy_test_loader = add_noise_to_data(test_loader, noise_level, accelerator)
                noisy_loader = NoiseDataLoader(noisy_test_loader)
                test_loss, test_mae_loss = vali(args, accelerator, model, test_data, noisy_loader, criterion, mae_metric)
            
            mse_losses.append(test_loss)
            mae_losses.append(test_mae_loss)
        
        return mse_losses, mae_losses
    

    该函数对每个噪声水平执行评估,返回MSE和MAE损失列表。

  3. 结果可视化函数

    def plot_results(noise_levels, mse_losses, mae_losses, output_dir, model_id, task_name, pred_len):
        """绘制噪声水平与损失的关系图"""
        plt.figure(figsize=(12, 5))
        
        # 绘制MSE损失
        plt.subplot(1, 2, 1)
        plt.plot(noise_levels, mse_losses, 'o-', linewidth=2, markersize=8)
        # ...绘图代码
        
        # 保存结果
        plt.savefig(os.path.join(output_dir, f'robustness_test_{model_id}_{task_name}_pred{pred_len}.png'), dpi=300)
        
        # 保存CSV数据
        results_df = pd.DataFrame({
            'noise_level': noise_levels,
            'mse_loss': mse_losses,
            'mae_loss': mae_losses
        })
        results_df.to_csv(os.path.join(output_dir, f'robustness_test_{model_id}_{task_name}_pred{pred_len}.csv'), index=False)
    

    该函数创建直观的可视化图表,并将原始数据保存为CSV格式。

设计思路

  1. 渐进式噪声测试

    • 从无噪声逐步增加噪声水平
    • 分析模型性能随噪声增加的衰减曲线
  2. 数据类型兼容性

    • 特别注意保持噪声与原始数据的数据类型一致
    • 避免混合精度训练环境中的类型转换问题
  3. 非侵入式设计

    • 不修改原始模型架构,仅改变输入数据
    • 便于快速评估任何预训练模型的鲁棒性

missing_data_test.py - 缺失数据测试

missing_data_test.py用于评估模型处理输入序列中缺失数据的能力,模拟真实场景中的数据缺失问题。

功能概述

  1. 缺失数据模拟:在输入序列中随机创建NaN值
  2. 数据插值:使用插值方法填充缺失值
  3. 性能评估:在不同缺失率下评估模型性能
  4. 可视化:展示原始数据、缺失数据和插值效果

核心函数解析

  1. 缺失数据创建函数

    def create_missing_data(batch_x, missing_rate=0.1, mask=None):
        """在输入序列中创建缺失值"""
        if mask is None:
            # 创建随机掩码,1表示保留,0表示缺失
            mask = torch.FloatTensor(batch_x.shape).uniform_() > missing_rate
            mask = mask.to(batch_x.device)
        
        # 将缺失位置的值设为NaN
        missing_x = batch_x.clone()
        missing_x[~mask] = float('nan')
        
        return missing_x, mask
    

    该函数根据指定的缺失率随机创建缺失值,并使用NaN标记。

  2. 插值填充函数

    def interpolate_missing_values(missing_x):
        """使用插值法填充缺失值"""
        batch_size, seq_len, feature_dim = missing_x.shape
        device = missing_x.device
        
        # 将数据移到CPU进行处理
        x_cpu = missing_x.detach().cpu().numpy()
        interpolated_x = np.zeros_like(x_cpu)
        
        # 为每个batch和每个特征进行插值
        for b in range(batch_size):
            for f in range(feature_dim):
                # 获取当前序列
                series = x_cpu[b, :, f]
                
                # 找出非NaN值的索引
                valid_indices = np.where(~np.isnan(series))[0]
                valid_values = series[valid_indices]
                
                if len(valid_indices) <= 1:  # 如果只有0或1个有效值,无法插值
                    interpolated_x[b, :, f] = 0.0  # 填充为0
                else:
                    if len(valid_indices) < seq_len:  # 如果有缺失值
                        # 使用线性插值填充缺失值
                        interp_func = interpolate.interp1d(
                            valid_indices, valid_values,
                            kind='linear', bounds_error=False, 
                            fill_value=(valid_values[0], valid_values[-1])
                        )
                        all_indices = np.arange(seq_len)
                        interpolated_series = interp_func(all_indices)
                        interpolated_x[b, :, f] = interpolated_series
                    else:  # 如果没有缺失值
                        interpolated_x[b, :, f] = series
        
        # 将插值后的数据转回tensor并移回原设备
        return torch.FloatTensor(interpolated_x).to(device)
    

    该函数使用Scipy的插值功能为缺失数据进行线性插值,处理边界情况并保持数据类型一致。

  3. 可视化函数

    def visualize_missing_data_examples(test_loader, missing_rates, output_dir, num_examples=3):
        """可视化展示原始数据、缺失数据和插值后的数据"""
        os.makedirs(output_dir, exist_ok=True)
        
        # 获取样本并可视化
        samples = []
        for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader):
            if i < num_examples:
                samples.append((batch_x[0], batch_y[0], batch_x_mark[0], batch_y_mark[0]))
            else:
                break
        
        # 绘制不同缺失率下的数据和插值效果
        for sample_idx, (x, y, x_mark, y_mark) in enumerate(samples):
            # ...绘图代码
    

    该函数创建多个图表,直观展示缺失值的位置和插值的效果,帮助理解插值方法的影响。

设计思路

  1. 现实场景模拟

    • 随机缺失模式反映真实世界中的传感器故障或数据丢失
    • 不同缺失率反映现实中不同的数据质量情况
  2. 数据修复策略

    • 使用线性插值作为默认的缺失值修复方法
    • 处理极端情况(如几乎所有数据都缺失)
  3. 可视化分析

    • 直观展示原始数据、缺失位置和插值效果
    • 帮助理解缺失值对预测性能的影响

共同设计特点与最佳实践

以上三个脚本展现了几个共同的设计特点和最佳实践:

1. 统一的命令行接口

所有脚本使用相同风格的argparse参数解析,保持参数名称一致性,便于使用者快速上手。

2. 异常处理机制

所有脚本都包含了全面的异常处理,特别是在模型加载环节:

try:
    # 尝试直接加载到当前设备
    state_dict = torch.load(args.checkpoint_path, map_location=accelerator.device)
    model.load_state_dict(state_dict)
except Exception as e:
    # 备选加载方法
    try:
        state_dict = torch.load(args.checkpoint_path, map_location='cpu')
        model.load_state_dict(state_dict)
    except Exception as e2:
        # 继续使用未初始化模型

这种多层次的异常处理确保了脚本在不同环境下的鲁棒性。

3. 分布式计算支持

所有脚本都使用了Accelerate库来支持分布式计算:

accelerator = Accelerator()
test_loader, model = accelerator.prepare(test_loader, model)

这种设计使得脚本能够无缝适应单GPU、多GPU和CPU环境。

4. 结果可视化与存储

所有测试脚本都包含了完善的结果可视化和存储机制:

  • 使用Matplotlib创建可视化图表
  • 将数值结果保存为CSV格式
  • 使用统一的命名规则便于结果分析

5. 精度管理

脚本特别关注数据类型和精度问题:

  • 确保噪声和插值后的数据与原始数据类型一致
  • 支持混合精度训练

脚本使用指南

run_main.py

训练模式

python run_main.py --task_name long_term_forecast --is_training 1 --model TimeLLM \
  --data beam --data_path processed_number_value_cell_2_beam_19.csv \
  --model_id CELL_2_BEAM_19 --root_path ./dataset/processed/ \
  --seq_len 12 --label_len 12 --pred_len 24 --batch_size 32 --train_rate 0.8

测试模式

python run_main.py --task_name long_term_forecast --is_training 0 --model TimeLLM \
  --data beam --data_path processed_number_value_cell_2_beam_19.csv \
  --model_id CELL_2_BEAM_19 --root_path ./dataset/processed/ \
  --checkpoint_path ./results/number_value_cell_2_beam_19/long_term_train8.ckpt \
  --seq_len 12 --label_len 12 --pred_len 24 --batch_size 32 --visualize_predictions

robustness_test.py

python robustness_test.py --model TimeLLM --task_name long_term_forecast \
  --checkpoint_path ./results/number_value_cell_2_beam_19/long_term_train8.ckpt \
  --data beam --model_id CELL_2_BEAM_19 \
  --output_dir ./results/number_value_cell_2_beam_19/long_term_train8/ \
  --data_path processed_number_value_cell_2_beam_19.csv \
  --root_path ./dataset/processed/ \
  --noise_min 0.01 --noise_max 0.05 --noise_steps 5

missing_data_test.py

python missing_data_test.py --model TimeLLM --task_name long_term_forecast \
  --checkpoint_path ./results/number_value_cell_2_beam_19/long_term_train8.ckpt \
  --data beam --model_id CELL_2_BEAM_19 \
  --output_dir ./results/number_value_cell_2_beam_19/long_term_train8/ \
  --data_path processed_number_value_cell_2_beam_19.csv \
  --root_path ./dataset/processed/ \
  --missing_min 0.0 --missing_max 0.25 --missing_steps 6 \
  --interpolation linear --visualize_examples

常见问题解决

1. 模型加载问题

问题:加载预训练模型时出现数据类型不匹配错误。

解决方案

  • 确保加载时使用正确的map_location参数
  • 检查模型保存时的数据类型和加载时的期望类型
  • 使用脚本中的备选加载方法,先加载到CPU然后再转移到目标设备

2. 内存溢出问题

问题:处理大型数据集或大型模型时出现内存溢出。

解决方案

  • 减小批量大小(batch_size)
  • 启用混合精度训练(--use_amp)
  • 减少语言模型层数(--llm_layers)
  • 使用较小的语言模型(如从LLAMA切换到BERT)

3. 数据类型不一致问题

问题:在噪声测试或缺失数据测试中出现数据类型不一致警告或错误。

解决方案

  • 确保生成的噪声和原始数据类型一致
  • 在插值过程中注意保持数据类型
  • 使用to(dtype=x.dtype)显式指定数据类型

4. 分布式训练问题

问题:多GPU环境下模型训练不同步或崩溃。

解决方案

  • 确保DeepSpeed配置正确(ds_config_zero2.json)
  • 使用find_unused_parameters=True参数处理未使用参数
  • 在不同GPU上使用相同的随机种子确保一致性

5. 插值异常问题

问题:极端缺失率下插值失败或结果异常。

解决方案

  • 使用更鲁棒的插值方法(如'nearest'而非'linear')
  • 对极端缺失情况(如只有0或1个有效点)使用特殊处理逻辑
  • 考虑在预处理阶段完全删除缺失严重的样本