veRL: load dataset

veRL

Where in the Code is the Training Data Read?

The user is asking where in the code the training data is read. Since the previous response mentioned that the script runs verl.trainer.main_ppo, I should examine that file. I'll start by listing the files in verl/trainer/ to confirm the file exists and see if there are any other relevant files.

Checking verl/trainer/main_ppo.py

verl/trainer/main_ppo.py exists. I should read this file to find where the training data is read.

After examining verl/trainer/main_ppo.py, I don't see any explicit data loading logic. The script seems to configure the training process and initialize workers, but the actual data loading likely happens within the RayPPOTrainer class or its associated worker classes. The config parameters in the script point to data files, but the code itself doesn't directly read them. The RayPPOTrainer class is imported from verl.trainer.ppo.ray_trainer. So, I should examine verl/trainer/ppo/ray_trainer.py.
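
For orientation, the data-related config entries that the script forwards look roughly like the sketch below. It is built here with OmegaConf (which verl's Hydra config resolves to); the file paths and sizes are placeholders, not values taken from the repository.

```python
from omegaconf import OmegaConf

# Hypothetical illustration of the "data" section of the PPO config that
# main_ppo receives; paths and sizes are placeholders for this example.
config = OmegaConf.create({
    "data": {
        "train_files": "~/data/gsm8k/train.parquet",
        "val_files": "~/data/gsm8k/test.parquet",
        "prompt_key": "prompt",
        "max_prompt_length": 512,
        "train_batch_size": 1024,
    }
})
print(OmegaConf.to_yaml(config))
```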

Checking verl/trainer/ppo/ray_trainer.py

After reading verl/trainer/ppo/ray_trainer.py, I found the _create_dataloader method, which creates train_dataloader and val_dataloader. It uses verl.utils.dataset.rl_dataset.RLHFDataset and torch.utils.data.DataLoader. The RLHFDataset is initialized with parquet_files from the config, which are the training and validation files specified in the original script. So the data loading happens within the RLHFDataset class and the standard DataLoader. I should examine verl/utils/dataset/rl_dataset.py to see how the RLHFDataset class reads the data.

Checking verl/utils/dataset/rl_dataset.py

I have found the relevant code. In verl/utils/dataset/rl_dataset.py, the RLHFDataset class reads the data in the _read_files_and_tokenize method. It uses pd.read_parquet to read the parquet files specified in the constructor. The data is stored in self.dataframe. The __getitem__ method then processes individual rows and tokenizes the prompts.

The training data is read in the verl/utils/dataset/rl_dataset.py file, specifically within the RLHFDataset class.

How the Data Loading Works

Initialization (__init__ method):

  • The RLHFDataset is initialized with a list of parquet file paths (parquet_files), a tokenizer (tokenizer), and other configuration options.
  • It stores these file paths and downloads the files locally using copy_local_path_from_hdfs in the _download method; a condensed sketch follows this list.
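
A condensed sketch of this constructor behavior is shown below. Attribute and helper names mirror the description above, and the import path of copy_local_path_from_hdfs is an assumption; this is not the verbatim verl source.

```python
from verl.utils.fs import copy_local_path_from_hdfs

class RLHFDatasetSketch:
    def __init__(self, parquet_files, tokenizer, prompt_key="prompt",
                 max_prompt_length=1024):
        if isinstance(parquet_files, str):
            parquet_files = [parquet_files]
        self.parquet_files = list(parquet_files)
        self.tokenizer = tokenizer
        self.prompt_key = prompt_key
        self.max_prompt_length = max_prompt_length
        self._download()
        # The real constructor then calls self._read_files_and_tokenize(),
        # sketched in the next section.

    def _download(self):
        # Each (possibly remote HDFS) path is copied to a local file before reading.
        self.parquet_files = [copy_local_path_from_hdfs(p) for p in self.parquet_files]
```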

Reading and Tokenizing (_read_files_and_tokenize method):

  • The _read_files_and_tokenize method reads the parquet files using pd.read_parquet from the pandas library.
  • It reads each file into a pandas DataFrame and then concatenates them into a single DataFrame (self.dataframe).
  • It then filters out prompts that are too long based on max_prompt_length (see the sketch after this list).
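
A minimal sketch of that method is given below, assuming the length filter works by applying the chat template and counting tokens; the exact filtering criterion in verl may differ slightly.

```python
import pandas as pd

def _read_files_and_tokenize(self):
    # Read every parquet file and concatenate into a single DataFrame.
    dataframes = [pd.read_parquet(path) for path in self.parquet_files]
    self.dataframe = pd.concat(dataframes, ignore_index=True)

    tokenizer, max_len = self.tokenizer, self.max_prompt_length

    def prompt_fits(chat):
        # Render the chat-format prompt to text, then count its tokens.
        text = tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)
        return len(tokenizer(text)["input_ids"]) <= max_len

    # Drop rows whose prompt would exceed max_prompt_length.
    self.dataframe = self.dataframe[self.dataframe[self.prompt_key].apply(prompt_fits)]
```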

Getting Items (__getitem__ method):

  • The __getitem__ method is called by the PyTorch DataLoader to retrieve individual data samples.
  • It retrieves a row from self.dataframe as a dictionary.
  • It extracts the prompt, applies a chat template (if provided), and tokenizes the prompt using verl_F.tokenize_and_postprocess_data, which uses the provided tokenizer.
  • It returns a dictionary containing input_ids, attention_mask, position_ids, and other information from the row, as sketched below.
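
Assembled into one function, the steps above might look like this. This is a simplified sketch: the exact signature of verl_F.tokenize_and_postprocess_data and the import path of compute_position_id_with_mask are assumptions based on verl's utilities, not a line-by-line copy of the source.

```python
import verl.utils.torch_functional as verl_F
from verl.utils.model import compute_position_id_with_mask

def __getitem__(self, item):
    row_dict = self.dataframe.iloc[item].to_dict()

    # 1. Pop the chat-format prompt and render it to a plain string.
    chat = row_dict.pop(self.prompt_key)
    prompt = self.tokenizer.apply_chat_template(chat, add_generation_prompt=True, tokenize=False)

    # 2. Tokenize and pad/truncate to max_prompt_length.
    input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(
        prompt=prompt,
        tokenizer=self.tokenizer,
        max_length=self.max_prompt_length,
        pad_token_id=self.tokenizer.pad_token_id,
        left_pad=True,
        truncation="error",
    )
    position_ids = compute_position_id_with_mask(attention_mask)

    # 3. Attach tensors and bookkeeping fields to the remaining row data.
    row_dict["input_ids"] = input_ids[0]
    row_dict["attention_mask"] = attention_mask[0]
    row_dict["position_ids"] = position_ids[0]
    if getattr(self, "return_raw_chat", False):
        row_dict["raw_prompt"] = chat
    row_dict["index"] = row_dict.get("extra_info", {}).get("index", 0)
    return row_dict
```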

Collation (collate_fn function):

  • The collate_fn function takes a list of dictionaries (returned by __getitem__) and combines them into a single batch.
  • It stacks the tensors and converts the non-tensor data into numpy arrays; a short behavioral sketch follows.
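
The behavior described here can be sketched as follows; this mirrors the described logic rather than reproducing the verl implementation verbatim.

```python
from collections import defaultdict

import numpy as np
import torch

def collate_fn(data_list):
    tensors = defaultdict(list)
    non_tensors = defaultdict(list)

    # Split each sample's fields into tensor and non-tensor groups.
    for sample in data_list:
        for key, value in sample.items():
            if isinstance(value, torch.Tensor):
                tensors[key].append(value)
            else:
                non_tensors[key].append(value)

    # Stack tensors into batch tensors; keep everything else as numpy object arrays.
    batch = {key: torch.stack(values, dim=0) for key, values in tensors.items()}
    batch.update({key: np.array(values, dtype=object) for key, values in non_tensors.items()})
    return batch
```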

Summary

The RLHFDataset class handles reading the data from the specified parquet files, tokenizing the prompts, and preparing the data for the model. The DataLoader then uses this class to efficiently load and batch the data during training.
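
Putting the pieces together, the wiring that _create_dataloader performs can be sketched as below. The keyword names mirror the config keys shown earlier; this is a simplified illustration, not the verbatim verl code.

```python
from torch.utils.data import DataLoader
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn

def create_train_dataloader(config, tokenizer):
    train_dataset = RLHFDataset(
        parquet_files=config.data.train_files,
        tokenizer=tokenizer,
        prompt_key=config.data.prompt_key,
        max_prompt_length=config.data.max_prompt_length,
    )
    return DataLoader(
        dataset=train_dataset,
        batch_size=config.data.train_batch_size,
        shuffle=True,           # training data is shuffled
        drop_last=True,         # keep batch sizes uniform across PPO steps
        collate_fn=collate_fn,  # stacks tensors, keeps other fields as numpy arrays
    )
```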


Inferring the Format of the Training Data

The user wants to know the format of the training data and see an example. I have already examined the relevant code in verl/utils/dataset/rl_dataset.py. The RLHFDataset class reads data from parquet files into a pandas DataFrame. The __getitem__ method shows how a row from this DataFrame is processed.

The key part is:

row_dict = self.dataframe.iloc[item].to_dict()

which converts a row to a dictionary. The prompt is extracted using:

chat = row_dict.pop(self.prompt_key)

Other relevant keys are added to the dictionary, such as input_ids, attention_mask, position_ids, raw_prompt (if return_raw_chat is True), and index. The prompt_key defaults to 'prompt' but can be configured.

Inferred Format:

The parquet file is expected to have at least a column containing the prompts. The name of this column is configurable via the prompt_key parameter (defaults to "prompt"). Other columns can be present and will be included in the processed data.

The __getitem__ method processes each row as follows:

  1. It converts the row to a dictionary (row_dict).
  2. It extracts the prompt from the dictionary using chat = row_dict.pop(self.prompt_key). The chat variable will contain the prompt data.
  3. It tokenizes the prompt, creating input_ids, attention_mask, and position_ids.
  4. It adds these tokenized representations to row_dict.
  5. If return_raw_chat is True, it adds the raw prompt to row_dict under the key 'raw_prompt'.
  6. It adds an index to row_dict.
  7. It includes any extra information from the original row data.

Example Data Entry (Before Processing by __getitem__):

Let's assume the parquet file has the following columns: prompt, response, data_source, and extra_info. A possible row (and thus a data entry) could look like this:

{
  "prompt": [{"role": "user", "content": "Explain the theory of relativity."}],
  "response": "...",  // This is not used by the RLHFDataset, but might be present in the data
  "data_source": "some_dataset",
  "extra_info": {"index": 123, "some_other_info": "value"}
}

Example Data Entry (After Processing by __getitem__):

After processing by the __getitem__ method, the data entry would look like this (assuming return_raw_chat is True):

{
  "input_ids": [1, 23, 45, ..., 67],  // Example token IDs
  "attention_mask": [1, 1, 1, ..., 1],  // Example attention mask
  "position_ids": [0, 1, 2, ..., 255], // Example position IDs
  "response": "...",
  "data_source": "some_dataset",
  "raw_prompt": [{"role": "user", "content": "Explain the theory of relativity."}],
  "index": 123,
  "extra_info": {"index": 123, "some_other_info": "value"}
}

Key Points:

  • The most important column is prompt. Its value is expected to be a list of dictionaries, where each dictionary has a 'role' and a 'content' key, similar to the OpenAI chat format.
  • The response field is present in the original data but is not used during the prompt processing in RLHFDataset.
  • input_ids, attention_mask, and position_ids are generated by tokenizing the prompt.
  • The extra_info field can contain arbitrary additional information.
  • The index field is taken from the extra_info field if present; otherwise it defaults to 0.

This example demonstrates the expected structure of the data within the parquet file and how it's transformed by the RLHFDataset.
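
As a concrete, self-contained illustration, a tiny training file in this format could be produced with pandas as shown below. The file name and field values are invented for the example, and writing nested columns to parquet requires an engine such as pyarrow.

```python
import pandas as pd

# Hypothetical example: write a two-row parquet file matching the inferred format.
rows = [
    {
        "prompt": [{"role": "user", "content": "Explain the theory of relativity."}],
        "data_source": "some_dataset",
        "extra_info": {"index": 0, "split": "train"},
    },
    {
        "prompt": [{"role": "user", "content": "Summarize the plot of Hamlet."}],
        "data_source": "some_dataset",
        "extra_info": {"index": 1, "split": "train"},
    },
]

pd.DataFrame(rows).to_parquet("train_example.parquet")

# RLHFDataset would read this back with pd.read_parquet and pop the "prompt" column.
print(pd.read_parquet("train_example.parquet").iloc[0]["prompt"])
```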