Model Training - golololologol/LLM-Distillery GitHub Wiki

Model Training and Fine-Tuning

The core of this project is the collect_and_finetune.py script, which orchestrates the entire distillation process, from data preparation and model compatibility checks to teacher-guided fine-tuning of the student model. This page outlines the key steps in this pipeline.


1. Initialization

Configuration and Paths

  • Configuration Loading: The script begins by loading configuration settings using the get_config function. These settings govern everything from device selection and cache sizes to model paths and training parameters.
  • Setting Up Paths: The Paths class initializes directories (such as the cache folder) used to store HDF5 datasets for the main and validation splits.
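
The following is a minimal sketch of this initialization step, not the project's actual code. The JSON config file, the DEFAULTS keys, the load_config helper, the CachePaths class and the HDF5 file names are all hypothetical stand-ins for get_config and Paths.

```python
import json
from dataclasses import dataclass
from pathlib import Path

# Hypothetical defaults; the real configuration keys may differ.
DEFAULTS = {"device": "cuda:0", "max_cache_size_gb": 200, "cache_folder": "cache"}


def load_config(config_file: Path) -> dict:
    """Merge user settings from a JSON file over the defaults (a stand-in for get_config)."""
    with open(config_file, encoding="utf-8") as f:
        return {**DEFAULTS, **json.load(f)}


@dataclass
class CachePaths:
    """Hypothetical stand-in for Paths: one HDF5 file per split inside the cache folder."""
    cache: Path

    @property
    def main_h5(self) -> Path:
        return self.cache / "dataset" / "main.hdf5"

    @property
    def validation_h5(self) -> Path:
        return self.cache / "dataset" / "validation.hdf5"

    def create(self) -> None:
        # Both split files share one parent folder; create it if it is missing.
        self.main_h5.parent.mkdir(parents=True, exist_ok=True)
```

Usage might look like `paths = CachePaths(Path(load_config(Path("config.json"))["cache_folder"]))` followed by `paths.create()`.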

Model Initialization

  • Teacher Models:
    The get_teachers function loads teacher models from the specified folder. Whether the folder is itself a single model directory (containing files such as .bin, .json, or .safetensors) or holds multiple model directories, each teacher is instantiated using the TeacherModel class.
  • Student Model:
    A student model is created via the StudentModel class. This model will be fine-tuned using the collected knowledge from one or multiple teacher models.
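
The folder-detection logic described for get_teachers can be sketched as follows. This is an illustration under assumptions: find_teacher_dirs is a hypothetical helper, and it treats any folder that directly contains a .bin, .json, or .safetensors file as one model, otherwise every subdirectory as a separate teacher.

```python
from pathlib import Path

# File types that mark a folder as a single model directory (from the description above).
MODEL_FILE_SUFFIXES = {".bin", ".json", ".safetensors"}


def find_teacher_dirs(teachers_folder: Path) -> list[Path]:
    """Return one path per teacher model found in teachers_folder (hypothetical helper)."""
    folder = Path(teachers_folder)
    # Case 1: the folder itself holds model files, so it is a single teacher.
    if any(p.is_file() and p.suffix in MODEL_FILE_SUFFIXES for p in folder.iterdir()):
        return [folder]
    # Case 2: each subdirectory is treated as a separate teacher model.
    return sorted(p for p in folder.iterdir() if p.is_dir())
```

Each returned path would then be passed to TeacherModel, e.g. `[TeacherModel(p) for p in find_teacher_dirs(folder)]`, assuming its constructor accepts a model path.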

2. Ensuring Compatibility

For successful distillation, teacher and student models must be compatible. The script performs the following checks:

  • Vocabulary Family Verification:
    Using the ensure_compatibility function, the vocabulary family of the teachers (in particular, that of the first teacher) is compared against the student's. Any mismatch results in an error.
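
A minimal sketch of such a check, assuming that every teacher must share the first teacher's vocabulary family and that the student must match it too. ModelInfo, its vocab_family label and check_vocab_families are hypothetical; the real ensure_compatibility may read this information differently.

```python
from dataclasses import dataclass


@dataclass
class ModelInfo:
    name: str
    vocab_family: str  # hypothetical label, e.g. "llama" or "mistral"


def check_vocab_families(teachers: list[ModelInfo], student: ModelInfo) -> None:
    """Raise ValueError unless all teachers and the student share one vocabulary family."""
    if not teachers:
        raise ValueError("At least one teacher model is required.")
    reference = teachers[0].vocab_family
    # All teachers must agree with the first teacher.
    for teacher in teachers[1:]:
        if teacher.vocab_family != reference:
            raise ValueError(
                f"{teacher.name} uses {teacher.vocab_family!r}, "
                f"but the first teacher uses {reference!r}."
            )
    # The student must use the same vocabulary family as the teachers.
    if student.vocab_family != reference:
        raise ValueError(
            f"Student {student.name} uses {student.vocab_family!r}, "
            f"but the teachers use {reference!r}."
        )
```

Failing early here is cheaper than discovering a token-ID mismatch after hours of collection.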

3. Dataset Preparation and Tokenization

Tokenization of Datasets

  • Training and Validation Datasets:
    Both the training and validation datasets are tokenized via the tokenize_dataset utility.
    • The student’s dataset is transformed into a sequence of tokenized conversations.
    • Similarly, each teacher’s dataset is tokenized to provide corresponding data for collection.

Filtering and Reordering

  • Common IDs Filtering:
    After tokenization, conversation IDs that are common across the teacher and student datasets are identified. This filtering ensures that only valid samples are collected and passed to the training loop.
  • Dataset Reordering:
    The student’s dataset is reordered to align with the teacher datasets, preparing it for efficient training.
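
The filtering and reordering steps can be sketched like this. It assumes each tokenized dataset is a list of (conversation ID, token IDs) pairs and that all teacher datasets list conversations in the same order; filter_and_align is a hypothetical name, not the project's function.

```python
# A tokenized conversation: (conversation ID, token IDs). Hypothetical representation.
Conversation = tuple[int, list[int]]


def filter_and_align(
    student: list[Conversation], teachers: list[list[Conversation]]
) -> tuple[list[Conversation], list[list[Conversation]]]:
    """Keep only IDs present everywhere, then order the student like the teachers."""
    # IDs that survived tokenization in the student dataset and in every teacher dataset.
    common = {cid for cid, _ in student}
    for dataset in teachers:
        common &= {cid for cid, _ in dataset}
    # Drop non-common samples while preserving each teacher's order.
    filtered_teachers = [[conv for conv in ds if conv[0] in common] for ds in teachers]
    # Reorder the student so its i-th sample has the same ID as the teachers' i-th sample.
    order = [cid for cid, _ in filtered_teachers[0]]
    by_id = dict(student)
    aligned_student = [(cid, by_id[cid]) for cid in order]
    return aligned_student, filtered_teachers
```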

4. HDF5 Data Synchronization

SHA-Based Dataset Sync

The system maintains an HDF5 cache of tokenized data. To synchronize this cache with the current text dataset, the script:

  • Calculates SHA Checksums:
    Each sample in the dataset is assigned a SHA checksum, which is compared against the stored checksums in the H5 datasets.
  • Determines Collection Actions:
    The calculate_sync function computes which samples need to be collected, renamed, or removed based on these comparisons.
  • Dataset Remapping:
    Depending on whether the configuration indicates a rebase (i.e., a complete refresh of the dataset on disk) or an incremental update, the HDF5 data managers for both the main and validation datasets are synchronized appropriately.
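
The checksum comparison can be sketched as below, assuming SHA-256 over the UTF-8 text of each sample, a stored mapping of checksum to sample ID read from the H5 datasets, and unique samples. plan_sync is a hypothetical stand-in for calculate_sync, whose real inputs and outputs may differ.

```python
import hashlib


def sample_sha(text: str) -> str:
    """SHA-256 checksum of one text sample (the hash variant is an assumption)."""
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


def plan_sync(
    samples: list[str], stored: dict[str, int]
) -> tuple[list[int], dict[int, int], list[int]]:
    """Compare current samples with stored checksums (sha -> ID in the HDF5 cache).

    Returns (new IDs to collect, {old ID: new ID} renames, old IDs to remove).
    """
    current = {sample_sha(text): new_id for new_id, text in enumerate(samples)}
    # Samples whose checksum is unknown to the cache must be collected by the teachers.
    to_collect = sorted(nid for sha, nid in current.items() if sha not in stored)
    # Samples already cached but at a different position only need renaming.
    renames = {stored[sha]: nid for sha, nid in current.items()
               if sha in stored and stored[sha] != nid}
    # Cached samples that no longer exist in the text dataset are removed.
    to_remove = sorted(oid for sha, oid in stored.items() if sha not in current)
    return to_collect, renames, to_remove
```

In this sketch a rebase would simply ignore `stored` and collect every sample.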

Top-K Parameter Check

  • Verification:
    The check_topk function ensures that the top-K value stored in the H5 dataset matches the one set in the configuration. If they differ, the user is prompted to confirm whether to proceed with the dataset's stored value.
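
A minimal sketch of this confirmation flow, assuming the stored value is None for an empty dataset. resolve_topk and its ask parameter (injected so the prompt can be tested) are hypothetical, not the signature of check_topk.

```python
from typing import Callable, Optional


def resolve_topk(config_topk: int, stored_topk: Optional[int],
                 ask: Callable[[str], str] = input) -> int:
    """Return the top-K value to use, asking the user when config and dataset disagree."""
    if stored_topk is None or stored_topk == config_topk:
        # Empty dataset or matching values: nothing to reconcile.
        return config_topk
    answer = ask(
        f"Dataset was collected with top-K={stored_topk}, but the config says "
        f"{config_topk}. Continue with {stored_topk}? [y/n] "
    )
    if answer.strip().lower() in ("y", "yes"):
        return stored_topk
    raise SystemExit("Aborted: top-K mismatch between the config and the H5 dataset.")
```

Keeping the stored value avoids mixing samples collected with different top-K settings in one dataset.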

5. Training and Fine-Tuning

The training process can operate in one of two modes, depending on the size of the HDF5 dataset relative to the specified cache size.

Full Data Collection

  • When Within Cache Limits:
    If the full dataset fits within the cache size, all sample IDs are treated as a single chunk and processed in one step.
  • Teacher Processing:
    Each teacher processes its designated chunk using the process_chunk function. This function runs inference with configurable options such as reserved VRAM size and number of inference workers.
  • Student Training:
    The student model’s train_chunk method is then invoked to fine-tune on the collected data.

Chunked Data Collection

  • Handling Large Datasets:
    If the dataset size exceeds the maximum cache size, data collection and training occur in smaller, manageable chunks.
  • Epoch-Based Processing:
    The script iterates over a set number of epochs. Within each epoch, the dataset is purged, and each teacher processes individual chunks sequentially.
  • Incremental Training:
    After each chunk is processed, the student model is trained on it, which lets the system handle very large datasets without exhausting available resources.
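
The choice between full and chunked collection, and the order of operations in the chunked loop, can be sketched as follows. Everything here is hypothetical: the byte-based size estimate, plan_chunks, and run_epochs with plain callables standing in for process_chunk, train_chunk, and the purge step. The sketch purges before every chunk so the cache never holds more than one chunk; the real script's purge timing may differ.

```python
from typing import Callable, Sequence


def plan_chunks(ids: Sequence[int], bytes_per_sample: int,
                max_cache_bytes: int) -> list[list[int]]:
    """Split sample IDs into chunks whose collected data fits in the cache."""
    if bytes_per_sample <= 0 or max_cache_bytes < bytes_per_sample:
        raise ValueError("The cache must hold at least one sample.")
    per_chunk = max_cache_bytes // bytes_per_sample
    if len(ids) <= per_chunk:
        return [list(ids)]  # full data collection: a single chunk
    return [list(ids[i:i + per_chunk]) for i in range(0, len(ids), per_chunk)]


def run_epochs(chunks: list[list[int]],
               teachers: list[Callable[[list[int]], None]],
               train_chunk: Callable[[list[int]], None],
               purge: Callable[[], None], epochs: int) -> None:
    """Collect each chunk with every teacher, then train the student on it."""
    for _ in range(epochs):
        for chunk in chunks:
            purge()  # free the cache before collecting the next chunk (assumption)
            for process_chunk in teachers:
                process_chunk(chunk)  # teachers run inference sequentially
            train_chunk(chunk)  # the student trains on the freshly collected data
```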

Validation Data Collection

  • Parallel Processing for Validation:
    Validation samples are processed with a similar approach, but without chunking, so that the student model's performance can be monitored during training.

6. Finalization and Cleanup

After the training loop concludes:

  • Resource Management:
    All data managers (for both the main and validation datasets) and model instances (teachers and student) are closed to free up system resources.
  • Graceful Termination:
    The handle_termination function is responsible for safely closing connections and exiting the script, minimizing the risk of data corruption or system instability.
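
A minimal sketch of graceful termination, assuming every data manager and model exposes a close() method. close_all and install_termination_handler are hypothetical helpers illustrating the idea behind handle_termination, built on Python's standard signal module.

```python
import signal
import sys
from typing import Iterable, Protocol


class Closable(Protocol):
    def close(self) -> None: ...


def close_all(resources: Iterable[Closable]) -> list[Exception]:
    """Close every resource, collecting errors instead of stopping at the first one."""
    errors: list[Exception] = []
    for resource in resources:
        try:
            resource.close()
        except Exception as exc:  # keep closing the rest even if one fails
            errors.append(exc)
    return errors


def install_termination_handler(resources: list[Closable]) -> None:
    """On Ctrl+C or SIGTERM, close everything before exiting (must run in the main thread)."""
    def handler(signum, frame):
        close_all(resources)
        sys.exit(1)

    signal.signal(signal.SIGINT, handler)
    signal.signal(signal.SIGTERM, handler)
```

Closing the HDF5 managers before exit matters because an interrupted write can leave the cache file unreadable.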

Summary

In short, collect_and_finetune.py:

  • Loads and verifies configurations and models.
  • Prepares and synchronizes tokenized datasets.
  • Performs teacher-led data collection, including handling oversize datasets via chunking.
  • Fine-tunes the student model, running both training and validation loops.
  • Cleans up resources upon completion.