# verRL: load dataset

## Where in the Code is the Training Data Read?
The user is asking where in the code the training data is read. Since the previous response mentioned that the script runs `verl.trainer.main_ppo`, I should examine that file. I'll start by listing the files in `verl/trainer/` to confirm the file exists and see if there are any other relevant files.
`verl/trainer/main_ppo.py` exists. I should read this file to find where the training data is read.
After examining `verl/trainer/main_ppo.py`, I don't see any explicit data-loading logic. The script seems to configure the training process and initialize workers, but the actual data loading likely happens within the `RayPPOTrainer` class or its associated worker classes. The config parameters in the script point to data files, but the code itself doesn't read them directly. The `RayPPOTrainer` class is imported from `verl.trainer.ppo.ray_trainer`, so I should examine `verl/trainer/ppo/ray_trainer.py`.
After reading `verl/trainer/ppo/ray_trainer.py`, I found the `_create_dataloader` method, which creates `train_dataloader` and `val_dataloader`. It uses `verl.utils.dataset.rl_dataset.RLHFDataset` and `torch.utils.data.DataLoader`. The `RLHFDataset` is initialized with `parquet_files` from the config, which are the training and validation files specified in the original script. So the data loading happens within the `RLHFDataset` class and the standard `DataLoader`. I should examine `verl/utils/dataset/rl_dataset.py` to see how the `RLHFDataset` class reads the data.
I have found the relevant code. In `verl/utils/dataset/rl_dataset.py`, the `RLHFDataset` class reads the data in the `_read_files_and_tokenize` method. It uses `pd.read_parquet` to read the parquet files specified in the constructor. The data is stored in `self.dataframe`. The `__getitem__` method then processes individual rows and tokenizes the prompts.

The training data is read in the `verl/utils/dataset/rl_dataset.py` file, specifically within the `RLHFDataset` class.
## How the Data Loading Works

**Initialization (`__init__` method):**

- The `RLHFDataset` is initialized with a list of parquet file paths (`parquet_files`), a tokenizer (`tokenizer`), and other configuration options.
- It stores these file paths and downloads the files locally using `copy_local_path_from_hdfs` in the `_download` method.
**Reading and Tokenizing (`_read_files_and_tokenize` method):**

- The `_read_files_and_tokenize` method reads the parquet files using `pd.read_parquet` from the pandas library.
- It reads each file into a pandas DataFrame and then concatenates them into a single DataFrame (`self.dataframe`).
- It then filters out prompts that are too long, based on `max_prompt_length`.
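A minimal sketch of what this step amounts to, assuming a standard Hugging Face tokenizer and a chat-style prompt column (the real `_read_files_and_tokenize` may measure prompt length differently):

```python
import pandas as pd

def read_and_filter(parquet_files, tokenizer, prompt_key="prompt", max_prompt_length=1024):
    """Simplified stand-in for the behavior described above."""
    # One DataFrame per parquet file, merged into a single DataFrame.
    frames = [pd.read_parquet(path) for path in parquet_files]
    dataframe = pd.concat(frames, ignore_index=True)

    # Drop rows whose tokenized prompt would exceed max_prompt_length.
    def fits(chat):
        text = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        return len(tokenizer(text)["input_ids"]) <= max_prompt_length

    return dataframe[dataframe[prompt_key].apply(fits)]
```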
**Getting Items (`__getitem__` method):**

- The `__getitem__` method is called by the PyTorch DataLoader to retrieve individual data samples.
- It retrieves a row from `self.dataframe` as a dictionary.
- It extracts the prompt, applies a chat template (if provided), and tokenizes the prompt using `verl_F.tokenize_and_postprocess_data`, which uses the provided tokenizer.
- It returns a dictionary containing `input_ids`, `attention_mask`, `position_ids`, and other information from the row.
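Roughly, the per-row processing can be pictured like this. It is a simplified sketch: the real code delegates tokenization to `verl_F.tokenize_and_postprocess_data` and handles padding and truncation options not shown here, and the `index` extraction is an assumption based on the description later in this page.

```python
import torch

def getitem_sketch(self, item):
    """Illustrative, simplified version of RLHFDataset.__getitem__."""
    row_dict = self.dataframe.iloc[item].to_dict()
    chat = row_dict.pop(self.prompt_key)   # e.g. [{"role": "user", "content": "..."}]

    # Render the chat into a prompt string and tokenize it.
    prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    enc = self.tokenizer(prompt, max_length=self.max_prompt_length,
                         padding="max_length", truncation=True, return_tensors="pt")

    attention_mask = enc["attention_mask"][0]
    row_dict["input_ids"] = enc["input_ids"][0]
    row_dict["attention_mask"] = attention_mask
    row_dict["position_ids"] = torch.clamp(torch.cumsum(attention_mask, dim=-1) - 1, min=0)

    if self.return_raw_chat:
        row_dict["raw_prompt"] = chat
    row_dict["index"] = row_dict.get("extra_info", {}).get("index", 0)
    return row_dict
```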
**Collation (`collate_fn` function):**

- The `collate_fn` function takes a list of dictionaries (returned by `__getitem__`) and combines them into a single batch.
- It stacks the tensors and converts the non-tensor data into numpy arrays.
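A minimal sketch of that collation behavior (not the exact verl implementation):

```python
import numpy as np
import torch

def collate_fn_sketch(batch):
    """Stack tensor fields; keep everything else as numpy object arrays."""
    tensors, non_tensors = {}, {}
    for sample in batch:
        for key, value in sample.items():
            target = tensors if isinstance(value, torch.Tensor) else non_tensors
            target.setdefault(key, []).append(value)

    output = {key: torch.stack(values, dim=0) for key, values in tensors.items()}
    output.update({key: np.array(values, dtype=object) for key, values in non_tensors.items()})
    return output
```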
## Summary

The `RLHFDataset` class handles reading the data from the specified parquet files, tokenizing the prompts, and preparing the data for the model. The `DataLoader` then uses this class to efficiently load and batch the data during training.
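Putting it together, iterating over the resulting dataloader during training would look roughly like this (illustrative only; `train_dataloader` refers to the object sketched earlier):

```python
# Each batch is a dict of stacked prompt tensors plus pass-through metadata fields.
for batch in train_dataloader:
    input_ids = batch["input_ids"]            # shape: (batch_size, max_prompt_length)
    attention_mask = batch["attention_mask"]
    # ... generate rollouts from these prompts and run the PPO update ...
```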
## Inferring the Format of the Training Data

The user wants to know the format of the training data and see an example. I have already examined the relevant code in `verl/utils/dataset/rl_dataset.py`. The `RLHFDataset` class reads data from parquet files into a pandas DataFrame, and the `__getitem__` method shows how a row from this DataFrame is processed.
The key part is `row_dict = self.dataframe.iloc[item].to_dict()`, which converts a row to a dictionary. The prompt is then extracted with `chat = row_dict.pop(self.prompt_key)`. Other relevant keys are added to the dictionary, such as `input_ids`, `attention_mask`, `position_ids`, `raw_prompt` (if `return_raw_chat` is True), and `index`. The `prompt_key` defaults to `'prompt'` but can be configured.
**Inferred Format:**

The parquet file is expected to have at least one column containing the prompts. The name of this column is configurable via the `prompt_key` parameter (default: `"prompt"`). Other columns may be present and will be included in the processed data.
The `__getitem__` method processes each row as follows:

- It converts the row to a dictionary (`row_dict`).
- It extracts the prompt from the dictionary using `chat = row_dict.pop(self.prompt_key)`; the `chat` variable contains the prompt data.
- It tokenizes the prompt, creating `input_ids`, `attention_mask`, and `position_ids`.
- It adds these tokenized representations to `row_dict`.
- If `return_raw_chat` is True, it adds the raw prompt to `row_dict` under the key `'raw_prompt'`.
- It adds an `index` to `row_dict`.
- It includes any extra information from the original row data.
**Example Data Entry (Before Processing by `__getitem__`):**

Let's assume the parquet file has the following columns: `prompt`, `response`, `data_source`, and `extra_info`. A possible row (and thus a data entry) could look like this:

```
{
    "prompt": [{"role": "user", "content": "Explain the theory of relativity."}],
    "response": "...",                // not used by RLHFDataset, but may be present in the data
    "data_source": "some_dataset",
    "extra_info": {"index": 123, "some_other_info": "value"}
}
```
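For illustration, a parquet file with rows in this shape could be produced with plain pandas; the file name and field values below are placeholders, not something taken from verl:

```python
import pandas as pd

# Hypothetical script that writes a parquet file matching the structure shown above.
rows = [
    {
        "prompt": [{"role": "user", "content": "Explain the theory of relativity."}],
        "data_source": "some_dataset",
        "extra_info": {"index": 123, "some_other_info": "value"},
    },
]
pd.DataFrame(rows).to_parquet("train.parquet")
```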
**Example Data Entry (After Processing by `__getitem__`):**

After processing by the `__getitem__` method, the data entry would look like this (assuming `return_raw_chat` is True):

```
{
    "input_ids": [1, 23, 45, ..., 67],        // example token IDs
    "attention_mask": [1, 1, 1, ..., 1],      // example attention mask
    "position_ids": [0, 1, 2, ..., 255],      // example position IDs
    "response": "...",
    "data_source": "some_dataset",
    "raw_prompt": [{"role": "user", "content": "Explain the theory of relativity."}],
    "index": 123,
    "extra_info": {"index": 123, "some_other_info": "value"}
}
```
**Key Points:**

- The most important part is the `prompt` column, which contains the prompt itself. The prompt is expected to be a list of dictionaries, each with a `'role'` and a `'content'` key, similar to the OpenAI chat format.
- The `response` field is present in the original data but is not used during prompt processing in `RLHFDataset`.
- `input_ids`, `attention_mask`, and `position_ids` are generated by tokenizing the prompt.
- The `extra_info` field can contain arbitrary additional information.
- The `index` field is extracted from the `extra_info` field if present; otherwise it defaults to `0`.
This example demonstrates the expected structure of the data within the parquet file and how it is transformed by the `RLHFDataset`.