Prior Prediction - Nerogar/OneTrainer GitHub Wiki

Background

A long-term project by dxqbYD to add a form of model preservation to Lora training. It was originally developed for Flux, as that model tends to lose concepts or an entire class more easily when training a Lora.

Original Pull Request

Usage

In order to use prior prediction, you need to do the following:

  1. Create a new concept that will be used as the prior prediction set.
  2. In the new concept, choose prior prediction as the concept type.

That is it. The new concept will now be used as the prior prediction set for your training run.
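The concept is created and configured in the UI, but as a rough illustration, a concepts-file entry for such a set might look like the sketch below. The key names and the concept-type value are assumptions based on the UI labels, not a verified OneTrainer schema.

```python
# Hypothetical concepts-file entry; keys and values are illustrative only.
prior_prediction_concept = {
    "name": "prior prediction set",    # any name you like
    "path": "/path/to/prior/images",   # the images used to hold the model in place
    "type": "PRIOR_PREDICTION",        # step 2: the concept type chosen in the UI (assumed value)
}
```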

Limitations

  • Only works for Lora training.
  • Primarily developed for use with Flux, but also tested with SDXL, Hunyuan Video, and HiDream.
  • Prior prediction is more closely tied to the train dtype than normal training is. If you change the training dtype, you may have to recache.
  • Adaptive optimizers will likely struggle with prior prediction, as the loss for the prior prediction train steps will be very close to zero (the adapted model's prediction starts out nearly identical to the original model's prediction, which is the target).

Guidance

Practical guidelines for prior prediction are a work in progress. This section will be updated as more information is found.

  • What should you NOT use for your prior prediction concept?

    • Your training dataset with the same captions. The conflicting targets confuse the model, so do not do it.
  • What should you use for your prior prediction set?

    • Related concept with simple captions (from joycaption or similar)
    • Related concept with no captions (not recommended, but has been shown to work)
  • What should the ratio of images be?

    • A 1:1 ratio of training images to prior prediction images is a good starting point
  • What batch size should I use?

    • Batch sizes of 1 and 2 have both been tested.
  • What LR do I need to use?

    • You may need to increase your LR slightly.

Technical Information

What does prior prediction do at a technical level?

If a concept is marked as a prior prediction concept, the training target is changed. Normally, the model is trained to denoise a noised image toward the training image. For prior prediction images, the model is instead trained to denoise the image into the same thing the original, unmodified model would. It is essentially trying to prevent the model from changing its behavior on those training images. A rough sketch of the target swap is shown below.
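The following is a minimal sketch of that target swap, assuming a PyTorch-style training step. It is not OneTrainer's actual code: the function and argument names (`model`, `base_model`, `noisy_latent`, `cond`, `noise`) are illustrative, and in practice the original model's prediction may be cached ahead of time rather than computed on the fly.

```python
import torch
import torch.nn.functional as F

def prior_prediction_step(model, base_model, noisy_latent, timestep, cond, noise, is_prior_concept):
    """One training step with the prior prediction target swap (illustrative only).

    `model` has the Lora applied; `base_model` is the frozen original model
    (or the same model with the Lora temporarily disabled).
    """
    prediction = model(noisy_latent, timestep, cond)

    if is_prior_concept:
        # Prior prediction concept: the target is whatever the *original* model
        # predicts, so the Lora is pushed to leave these images unchanged.
        with torch.no_grad():
            target = base_model(noisy_latent, timestep, cond)
    else:
        # Normal concept: the target is derived from the training image,
        # e.g. the added noise for an epsilon-prediction model.
        target = noise

    return F.mse_loss(prediction, target)
```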

Example

Example image from AI Toolkit.

Masked Prior Preservation

Masked prior preservation is related to, but not the same as, the prior prediction described above. It applies the same idea in a different way: instead of adding a separate prior prediction dataset, it works with a single masked dataset. This can save time, but results will vary.

Normal masked training has an inherent drawback: the model is free to put whatever it wants in the masked area. Masked prior preservation is a way to combat this weakness. With masked prior preservation, the unmasked area is treated as a standard learning target, while the masked area is treated as a prior prediction target. This is an effective way to prevent unwanted elements of a training image from being learned and overwhelming a Lora. For example, Flux and most modern models automatically add a bokeh effect to some images. If you train only with images that lack this effect, the resulting Lora will produce sharp images by default. Masked prior preservation helps the model retain its original behavior without you having to prompt it back. A rough sketch of how the two targets can be combined follows below.
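The per-region combination of the two losses could be sketched roughly as follows, again as an illustration rather than OneTrainer's implementation. The variable names are assumptions, and which side of the mask receives which target depends on the mask convention in use.

```python
import torch.nn.functional as F

def masked_prior_loss(prediction, normal_target, prior_target, mask,
                      unmasked_weight=0.0, prior_weight=1.0):
    """Blend the standard loss and the prior prediction loss per pixel (illustrative).

    `prediction` is the adapted model's output, `normal_target` the usual
    training target, `prior_target` the original model's prediction, and
    `mask` a 0..1 tensor broadcastable to the prediction's shape.
    """
    # Standard learning signal inside the learning region, optionally with a
    # small contribution outside it (controlled by unmasked_weight).
    learn_weight = mask + unmasked_weight * (1.0 - mask)
    learn_loss = (F.mse_loss(prediction, normal_target, reduction="none") * learn_weight).mean()

    # Prior preservation signal on the remaining region: keep the adapted
    # model's prediction close to the original model's prediction there.
    prior_loss = (F.mse_loss(prediction, prior_target, reduction="none") * (1.0 - mask)).mean()

    return learn_loss + prior_weight * prior_loss
```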

To enable masked prior preservation, you need a dataset with masks. The concept should be set to "normal" when you create it. Caption the entire image, as both the normal training target and the prior prediction target use the full caption.

On the train tab, you need to enable masked prior preservation by changing a few settings in the masked training area (a config-style summary of these values follows the list below).

  • Masked training needs to be enabled.

  • Unmasked probability and unmasked weight should be set to low values; 0 is what has been tested.

  • Masked Prior Preservation Weight should be set high; 1 is what has been tested.

  • A known limitation with masked prior preservation is that validation sets do not work.
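For reference, the tested values above might be expressed as a config fragment like the sketch below. The key names are assumptions derived from the UI labels, not a verified OneTrainer schema.

```python
# Hypothetical config fragment mirroring the UI settings above; key names are
# assumed from the labels, not taken from OneTrainer's actual config file.
masked_prior_preservation_settings = {
    "masked_training": True,                  # masked training must be enabled
    "unmasked_probability": 0.0,              # tested value
    "unmasked_weight": 0.0,                   # tested value
    "masked_prior_preservation_weight": 1.0,  # tested value
}
```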