# Model Training and Fine-Tuning

The core of this project is the `collect_and_finetune.py` script, which orchestrates the entire distillation process, from data preparation and model compatibility checks to teacher-guided training and fine-tuning of the student model. This page outlines the key steps involved in this pipeline.
## 1. Initialization

### Configuration and Paths

- **Configuration Loading:** The script begins by loading configuration settings using the `get_config` function. These settings govern everything from device selection and cache sizes to model paths and training parameters.
- **Setting Up Paths:** The `Paths` class initializes directories (such as the cache folder) used to store HDF5 datasets for the main and validation splits. A rough sketch of both steps follows below.
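For orientation, here is a minimal, self-contained sketch of what this step looks like. The real `get_config` and `Paths` in the repository take different arguments and manage more state; the default values and directory names shown are purely hypothetical.

```python
# Illustrative sketch only: the real get_config() and Paths class in this repo
# may take different arguments and create additional directories.
import json
import os

def get_config(path: str = "config.json") -> dict:
    """Load pipeline settings: device, cache sizes, model paths, training params."""
    if not os.path.exists(path):
        # Hypothetical defaults, used here just to keep the sketch self-contained.
        return {"cache_folder": "./cache", "device": "cuda:0", "max_cache_size_gb": 200}
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

class Paths:
    """Create and hold the directories used for the HDF5 caches."""
    def __init__(self, cache_folder: str):
        self.cache = cache_folder
        self.dataset = os.path.join(cache_folder, "dataset")                # main split
        self.dataset_validation = os.path.join(cache_folder, "validation")  # validation split
        for d in (self.cache, self.dataset, self.dataset_validation):
            os.makedirs(d, exist_ok=True)

if __name__ == "__main__":
    config = get_config()
    paths = Paths(config["cache_folder"])
```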
### Model Initialization

- **Teacher Models:** The `get_teachers` function loads teacher models from the specified folder. Whether the folder contains a single model file (e.g., a `.bin`, `.json`, or `.safetensors` file) or multiple model directories, each teacher is instantiated using the `TeacherModel` class.
- **Student Model:** A student model is created via the `StudentModel` class. This model will be fine-tuned using the knowledge collected from one or multiple teacher models (see the discovery sketch below).
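The snippet below sketches the discovery logic described above. The stub classes and the file-suffix check are illustrative assumptions; the actual `TeacherModel` and `StudentModel` classes also hold tokenizers, sampling settings, and device state.

```python
# Sketch of teacher discovery: a folder is either one model (weight/config files
# directly inside it) or a collection of per-model subdirectories.
from pathlib import Path

MODEL_FILE_SUFFIXES = {".bin", ".json", ".safetensors"}  # assumed markers of a single model

class TeacherModel:
    def __init__(self, model_path: str):
        self.model_path = model_path  # the real class also loads tokenizer, config, etc.

class StudentModel:
    def __init__(self, model_path: str):
        self.model_path = model_path

def get_teachers(folder: str) -> list[TeacherModel]:
    root = Path(folder)
    # Single teacher: the folder itself contains model files.
    if any(p.suffix in MODEL_FILE_SUFFIXES for p in root.iterdir() if p.is_file()):
        return [TeacherModel(str(root))]
    # Multiple teachers: every subdirectory is treated as one model.
    return [TeacherModel(str(p)) for p in sorted(root.iterdir()) if p.is_dir()]
```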
## 2. Ensuring Compatibility

For successful distillation, teacher and student models must be compatible. The script performs the following check:

- **Vocabulary Family Verification:** Using the `ensure_compatibility` function, the vocabulary family of the teachers (the first teacher in particular) is compared against that of the student. Any mismatch results in an error, as sketched below.
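A minimal sketch of this gate, assuming (for illustration only) that each model object exposes a `vocab_family` attribute:

```python
# Sketch of the vocabulary-family check: every teacher and the student must share
# one family, otherwise the distillation targets would not line up token-for-token.
def ensure_compatibility(teachers, student) -> None:
    family = teachers[0].vocab_family  # the first teacher sets the reference family
    for teacher in teachers:
        if teacher.vocab_family != family:
            raise ValueError(f"Teacher {teacher.model_path} uses vocab family "
                             f"{teacher.vocab_family!r}, expected {family!r}")
    if student.vocab_family != family:
        raise ValueError(f"Student vocab family {student.vocab_family!r} "
                         f"does not match the teachers' {family!r}")
```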
## 3. Dataset Preparation and Tokenization

### Tokenization of Datasets

- **Training and Validation Datasets:** Both the training and validation datasets are tokenized via the `tokenize_dataset` utility (sketched below).
  - The student's dataset is transformed into a sequence of tokenized conversations.
  - Similarly, each teacher's dataset is tokenized to provide corresponding data for collection.
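One plausible shape of this step is sketched below using a Hugging Face tokenizer. The conversation schema and keying by sample ID are assumptions; the repository's `tokenize_dataset` additionally handles prompt formats, truncation rules, and per-model differences.

```python
# Sketch: turn a list of conversations into per-sample token ID lists, keyed by ID.
from transformers import AutoTokenizer

def tokenize_dataset(samples: list[dict], model_path: str, max_len: int = 2048) -> dict[int, list[int]]:
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenized = {}
    for sample in samples:
        # Hypothetical schema: {"id": int, "conversation": [{"role": ..., "content": ...}, ...]}
        text = "\n".join(turn["content"] for turn in sample["conversation"])
        tokenized[sample["id"]] = tokenizer(text, truncation=True, max_length=max_len)["input_ids"]
    return tokenized
```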
### Filtering and Reordering

- **Common IDs Filtering:** After tokenization, conversation IDs that are common across the teacher and student datasets are identified. This filtering ensures that only valid samples are collected and passed to the training loop.
- **Dataset Reordering:** The student's dataset is reordered to align with the teacher datasets, preparing it for efficient training (see the sketch below).
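A compact sketch of the intersection-and-reorder idea, assuming both tokenized datasets are dictionaries keyed by conversation ID:

```python
# Keep only samples present in both datasets, then put them in the same order.
def filter_and_reorder(student_data: dict[int, list[int]],
                       teacher_data: dict[int, list[int]]):
    common_ids = sorted(set(student_data) & set(teacher_data))  # shared conversation IDs
    student_ordered = [student_data[i] for i in common_ids]     # student aligned to teacher order
    teacher_ordered = [teacher_data[i] for i in common_ids]
    return common_ids, student_ordered, teacher_ordered
```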
## 4. HDF5 Data Synchronization

### SHA-Based Dataset Sync

The system maintains an HDF5 cache of tokenized data. To synchronize this cache with the current text dataset, the script:

- **Calculates SHA Checksums:** Each sample in the dataset is assigned a SHA checksum, which is compared against the stored checksums in the H5 datasets.
- **Determines Collection Actions:** The `calculate_sync` function computes which samples need to be collected, renamed, or removed based on these comparisons (see the sketch after this list).
- **Remaps the Dataset:** Depending on whether the configuration indicates a rebase (i.e., a complete refresh of the dataset on disk) or an incremental update, the HDF5 data managers for both the main and validation datasets are synchronized appropriately.
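The sketch below illustrates the diffing idea: hash each tokenized sample and compare it against what the cache already stores. The hashing scheme and return shape are assumptions; the real `calculate_sync` also handles renames and the rebase path.

```python
# Sketch of SHA-based cache diffing: anything whose hash changed (or is new)
# must be re-collected, anything no longer in the text dataset is removed.
import hashlib

def sha_of_sample(token_ids: list[int]) -> str:
    return hashlib.sha256(repr(token_ids).encode("utf-8")).hexdigest()

def calculate_sync(current_hashes: dict[int, str], stored_hashes: dict[int, str]):
    to_collect = [i for i, h in current_hashes.items() if stored_hashes.get(i) != h]
    to_remove = [i for i in stored_hashes if i not in current_hashes]
    return to_collect, to_remove
```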
### Top-K Parameter Check

- **Verification:** The `check_topk` function ensures that the top-K sampling parameter in the H5 dataset matches the configuration. If a mismatch is detected, the user is prompted to confirm whether to proceed with the dataset's stored value (sketched below).
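A sketch of this guard, assuming the stored and configured top-K values are plain integers:

```python
# If the cached logits were collected with a different top-K than the config
# requests, ask the user before continuing with the stored value.
def check_topk(stored_topk: int, config_topk: int) -> int:
    if stored_topk == config_topk:
        return config_topk
    answer = input(f"H5 dataset was collected with top_k={stored_topk}, but the "
                   f"config specifies {config_topk}. Proceed with {stored_topk}? [y/N] ")
    if answer.strip().lower() != "y":
        raise SystemExit("Top-K mismatch; aborting.")
    return stored_topk
```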
## 5. Training and Fine-Tuning

The training process can operate in one of two modes, depending on the size of the HDF5 dataset relative to the specified cache size.

### Full Data Collection

- **When Within Cache Limits:** If the full dataset can be collected within the cache size, all samples (or "chunk" IDs) are processed in one step.
- **Teacher Processing:** Each teacher processes its designated chunk using the `process_chunk` function. This function runs inference with configurable options such as reserved VRAM size and the number of inference workers.
- **Student Training:** The student model's `train_chunk` method is then invoked to fine-tune on the collected data. The sketch below outlines this single-pass mode.
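The single-pass mode has roughly the following shape. The keyword arguments are placeholders standing in for the options mentioned above, not the repository's actual parameter names.

```python
# Single-pass mode: every teacher collects logits for all sample IDs, then the
# student trains on the full cache in one go.
def full_collection(teachers, student, all_chunk_ids, config):
    for teacher in teachers:
        teacher.process_chunk(                                    # runs inference, writes top-K logits to HDF5
            all_chunk_ids,
            reserve_vram=config.get("reserve_vram_gb", 0),        # placeholder option names
            inference_workers=config.get("inference_workers", 1),
        )
    student.train_chunk(all_chunk_ids)                            # fine-tune on everything just collected
```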
### Chunked Data Collection

- **Handling Large Datasets:** If the dataset size exceeds the maximum cache size, data collection and training occur in smaller, manageable chunks.
- **Epoch-Based Processing:** The script iterates over a set number of epochs. Within each epoch, the dataset is purged, and each teacher processes individual chunks sequentially.
- **Incremental Training:** After processing each chunk, the student model is trained on that batch, allowing the system to accommodate very large datasets without overloading available resources (see the sketch below).
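A rough sketch of the chunked mode; purging before every chunk is an assumption based on the cache-size constraint described above, and the method names mirror the wiki's description rather than exact signatures.

```python
# Chunked mode: the cache only ever holds one chunk, so collection and training
# alternate; the cycle repeats for the configured number of epochs.
def chunked_collection(teachers, student, all_ids, chunk_size, num_epochs, data_manager):
    chunks = [all_ids[i:i + chunk_size] for i in range(0, len(all_ids), chunk_size)]
    for epoch in range(num_epochs):
        for chunk_ids in chunks:
            data_manager.purge()                  # drop the previous chunk from disk (assumed placement)
            for teacher in teachers:
                teacher.process_chunk(chunk_ids)  # collect logits for this chunk only
            student.train_chunk(chunk_ids)        # train before the next chunk overwrites the cache
```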
### Validation Data Collection

- **Parallel Processing for Validation:** Validation samples are processed using a similar approach, minus chunking, ensuring that the performance of the student model is monitored during training.
## 6. Finalization and Cleanup

After the training loop concludes:

- **Resource Management:** All data managers (for both the main and validation datasets) and model instances (teachers and student) are closed to free up system resources.
- **Graceful Termination:** The `handle_termination` function is responsible for safely closing connections and exiting the script, minimizing the risk of data corruption or system instability (sketched below).
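A sketch of the shutdown path, assuming (for illustration) that each data manager and model exposes a `close()` method:

```python
# Close every resource in a fixed order, then exit; called both on normal
# completion and on termination.
def handle_termination(data_managers, teachers, student, exit_code: int = 0):
    for manager in data_managers:   # main + validation HDF5 managers
        manager.close()
    for teacher in teachers:
        teacher.close()             # frees VRAM / worker processes held by each teacher
    student.close()
    raise SystemExit(exit_code)
```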
## Summary

In simpler words, `collect_and_finetune.py` does all of this:
- Loads and verifies configurations and models.
- Prepares and synchronizes tokenized datasets.
- Performs teacher-led data collection, including handling oversized datasets via chunking.
- Fine-tunes the student model through both validation and training loops.
- Cleans up resources upon completion.