Guides MOE Training - kennetholsenatm-gif/q_mini_wasm_v2 GitHub Wiki
This guide covers end-to-end training of 243-expert Mixture of Experts (MoE) models using Forward-Forward learning.
Data Preparation → Router Initialization → Expert Training → Load Balancing → Evaluation
#include "core/moe/moe_trainer.hpp"
#include "core/moe/unified_router.hpp"
// 1. Create 243-expert router
auto router = Create243ExpertRouter();
// 2. Setup trainer
MoETrainingConfig config;
config.learning_rate = 1;
config.training_batch_size = 64;
// ... (truncated)
// See source for complete code

Training data consists of positive samples (real data). Negative samples are generated automatically by corruption.
std::vector<std::vector<ternary::Trit>> training_data;
// Each sample is a vector of ternary values
std::vector<ternary::Trit> sample;
sample.push_back(ternary::Trit::POSITIVE);
sample.push_back(ternary::Trit::ZERO);
sample.push_back(ternary::Trit::NEGATIVE);
// ... more values
training_data.push_back(sample);

// Convert text to ternary encoding
std::vector<ternary::Trit> TextToTernary(const std::string& text) {
    std::vector<ternary::Trit> result;
    for (char c : text) {
        // Map characters to ternary values
        // Example: A-M -> POSITIVE, N-Z -> NEGATIVE, space/punct -> ZERO
        if (c >= 'A' && c <= 'M') {
            result.push_back(ternary::Trit::POSITIVE);
        } else if (c >= 'N' && c <= 'Z') {
            result.push_back(ternary::Trit::NEGATIVE);
        } else {
            result.push_back(ternary::Trit::ZERO);
        }
    }
    return result;
}

std::vector<std::vector<ternary::Trit>> LoadData(const std::string& path) {
std::vector<std::vector<ternary::Trit>> data;
std::ifstream file(path, std::ios::binary);
// Read number of samples
size_t num_samples;
file.read(reinterpret_cast<char*>(&num_samples), sizeof(num_samples));
for (size_t i = 0; i < num_samples; ++i) {
// Read sample size
// ... (truncated)
// See source for complete code

For better training, augment your data:
// Add noise to samples
std::vector<ternary::Trit> AddNoise(
    const std::vector<ternary::Trit>& sample,
    float noise_prob = 0.1f
) {
    static std::mt19937 rng(42);
    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
    std::uniform_int_distribution<int> val_dist(-1, 1);
    auto result = sample;
    for (auto& val : result) {
        if (dist(rng) < noise_prob) {
            // Replace with a random trit
            switch (val_dist(rng)) {
                case -1: val = ternary::Trit::NEGATIVE; break;
                case 0:  val = ternary::Trit::ZERO; break;
                default: val = ternary::Trit::POSITIVE; break;
            }
        }
    }
    return result;
}

MoETrainingConfig config;
// Forward-Forward settings
config.learning_rate = 1; // GF(3): typically ±1
config.negative_samples_per_positive = 1;
config.training_batch_size = 64;
// Load balancing
config.load_balance_alpha = 0.01f; // 1% load balancing penalty
config.rebalance_interval = 100; // Rebalance every 100 batches
// ... (truncated)
// See source for complete code

For quick debugging runs:

config.training_batch_size = 16; // Small batches
config.max_epochs = 20; // Fewer epochs
config.log_interval = 1; // Frequent logging
config.verbose = true;

For long production runs:

config.training_batch_size = 256; // Large batches
config.max_epochs = 500; // Many epochs
config.log_interval = 100; // Infrequent logging
config.early_stopping_patience = 50; // More patience

For memory-constrained environments:

config.training_batch_size = 32; // Smaller batches
config.negative_samples_per_positive = 1; // Minimal negatives
config.load_balance_alpha = 0.001f; // Minimal rebalancing

for (size_t epoch = 0; epoch < num_epochs; ++epoch) {
// 1. Shuffle data
Shuffle(data);
// 2. Process batches
for (size_t batch = 0; batch < num_batches; ++batch) {
auto batch_data = GetBatch(data, batch);
// 3. Train batch
auto metrics = trainer.TrainBatch(batch_data);
// ... (truncated)
// See source for complete code

auto metrics = trainer.GetMetrics();
std::cout << "Epoch: " << metrics.epoch << std::endl;
std::cout << "Goodness Delta: " << metrics.avg_goodness_delta << std::endl;
std::cout << "Load Balance: " << metrics.load_balance_score << std::endl;
std::cout << "Routing Latency: " << metrics.avg_routing_latency_ms << " ms" << std::endl;

| Metric | Healthy Range | Concerning | Action |
|---|---|---|---|
| Goodness Delta | > 0.5 | < 0.1 | Increase learning rate, check data |
| Load Balance Score | 0.2-0.4 | > 0.8 | Reduce load_balance_alpha |
| Routing Latency | < 100μs | > 500μs | Enable hierarchical selection |
| Expert Utilization | Uneven | Uniform | Training working! |
- Positive Sample: Real data from dataset
  auto positive = training_data[i];
- Generate Negative: Corrupt ~10% of values
  auto negative = trainer.GenerateNegativeSample(positive);
- Route to Experts: Select top-K experts
  auto routing = router.Route(positive);
- Compute Goodness: For both positive and negative
  int32_t pos_goodness = expert.ComputeGoodness(expert.Forward(positive));
  int32_t neg_goodness = expert.ComputeGoodness(expert.Forward(negative));
- Update Weights: If positive > negative
  int32_t delta = pos_goodness - neg_goodness;
  if (delta > 0) {
      // Reinforce weights
      expert.UpdateWeightsHebbian(positive, delta);
  }
Training converges when:
- Goodness delta stabilizes > 0.5
- Load balance score remains < 0.5
- No improvement for early_stopping_patience epochs
bool HasConverged(const MoETrainingMetrics& metrics) {
return metrics.avg_goodness_delta < min_threshold ||
epochs_without_improvement >= early_stopping_patience;
}

Without load balancing:
- Router sends all inputs to same few experts
- Other experts never train
- Model collapses to single-expert behavior
Loss = Forward-Forward Loss + α × Load Balance Loss
where:
Load Balance Loss = variance(expert_utilization)
α = load_balance_alpha (typically 0.01)
Too aggressive (α too high):
- Experts forced to be used equally
- No specialization occurs
- Model quality degrades
Solution: Reduce α
config.load_balance_alpha = 0.001f; // Less aggressive

Too weak (α too low):
- All inputs go to same 2-3 experts
- Other 240+ experts unused
- Wasted capacity
Solution: Increase α
config.load_balance_alpha = 0.1f; // More aggressive

auto stats = router.GetLoadStats();
// Print top 10 most used experts
std::vector<std::pair<float, size_t>> util;
for (size_t i = 0; i < stats.utilization_rates.size(); ++i) {
util.push_back({stats.utilization_rates[i], i});
}
std::sort(util.begin(), util.end(),
[](auto& a, auto& b) { return a.first > b.first; });
for (size_t i = 0; i < 10 && i < util.size(); ++i) {
std::cout << "Expert " << util[i].second << ": "
<< (util[i].first * 100.0f) << "%" << std::endl;
}

Start with easy examples, gradually increase difficulty:
// Sort by complexity (e.g., sequence length)
std::sort(data.begin(), data.end(),
[](const auto& a, const auto& b) {
return a.size() < b.size();
});
// Train in phases
for (size_t phase = 0; phase < 3; ++phase) {
size_t end_idx = data.size() * (phase + 1) / 3;
auto phase_data = std::vector(data.begin(), data.begin() + end_idx);
trainer.Train(phase_data, 20);
}

Randomly disable experts during training to improve robustness:
// With 10% probability, skip training this expert
std::uniform_real_distribution<float> dist(0.0, 1.0);
for (size_t expert_id : routing.selected_experts) {
if (dist(rng) > 0.1) { // 90% chance to train
auto expert = trainer.GetExpert(expert_id);
expert->TrainForwardForward(positive, negative);
}
}

Start with few experts, gradually add more:
// Start with 8 active experts
config.active_experts = 8;
trainer.Train(data, 20);
// Increase to 16
config.active_experts = 16;
trainer.Train(data, 20);
// Final: 32
config.active_experts = 32;
trainer.Train(data, 60);

Train on multiple tasks simultaneously:
// Task A: Code generation
auto code_data = LoadCodeData();
// Task B: Natural language
auto text_data = LoadTextData();
// Task C: Mathematics
auto math_data = LoadMathData();
// Combine
std::vector<std::vector<ternary::Trit>> all_data;
all_data.insert(all_data.end(), code_data.begin(), code_data.end());
all_data.insert(all_data.end(), text_data.begin(), text_data.end());
all_data.insert(all_data.end(), math_data.begin(), math_data.end());
Shuffle(all_data);
trainer.Train(all_data, num_epochs);

// Save every 10 epochs
if (epoch % 10 == 0) {
std::string path = "moe_checkpoint_epoch_" +
std::to_string(epoch) + ".chk";
trainer.SaveCheckpoint(path);
}

// Resume from checkpoint
trainer.LoadCheckpoint("moe_checkpoint_epoch_50.chk");
// Continue training
trainer.Train(data, remaining_epochs);

MoETrainingMetrics best_metrics;
float best_score = 0.0f;
for (size_t epoch = 0; epoch < num_epochs; ++epoch) {
auto metrics = trainer.TrainEpoch(data);
// Score based on goodness delta and load balance
float score = metrics.avg_goodness_delta * (1.0f - metrics.load_balance_score);
if (score > best_score) {
best_score = score;
// ... (truncated)
// See source for complete code

// Load test data
auto test_data = LoadData("test_data.bin");
// Evaluate
size_t correct = 0;
for (const auto& sample : test_data) {
// Route to experts
auto routing = router.Route(sample);
// Get expert outputs
for (size_t expert_id : routing.selected_experts) {
// ... (truncated)
// See source for complete code

// 1. Average goodness on test set
float total_goodness = 0.0f;
for (const auto& sample : test_data) {
auto routing = router.Route(sample);
for (size_t expert_id : routing.selected_experts) {
auto expert = trainer.GetExpert(expert_id);
total_goodness += expert->ComputeGoodness(expert->Forward(sample));
}
}
float avg_goodness = total_goodness / test_data.size();
// ... (truncated)
// See source for complete code

Symptoms: No output, no progress
Check:
- Data loaded correctly?
- Experts initialized?
- Configuration valid?
assert(!data.empty());
assert(trainer.GetExpert(0) != nullptr);
assert(router.Validate243Config());

Symptoms: No convergence, delta = 0
Causes:
- All inputs identical
- Experts not learning (weights not updating)
- Negative samples too similar to positive
Solutions:
// Check data diversity
assert(data.size() > 1000);
// Verify weight updates
auto expert = trainer.GetExpert(0);
auto old_weights = expert->GetWeights();
trainer.TrainBatch({sample});
auto new_weights = expert->GetWeights();
assert(old_weights != new_weights); // Should change

Symptoms: OOM crashes during training
Solutions:
- Reduce batch size:
  config.training_batch_size = 32;
- Reduce active experts:
  router.GetConfig().active_experts = 8;
- Use shallower experts (fewer layers also reduces memory):
  expert_config.num_layers = 2; // Instead of 4
Symptoms: Low accuracy, bad generations
Check:
- Enough training data? (Need >10K samples)
- Enough epochs? (Try 100+)
- Load balance appropriate? (Check utilization)
- Data quality? (Validate no corruption)
- Always validate configuration before training:
  assert(router.Validate243Config());
- Monitor training from epoch 0:
  config.verbose = true;
  config.log_interval = 1;
- Save checkpoints frequently:
  if (epoch % 5 == 0) SaveCheckpoint();
- Use a validation set to detect overfitting:
  auto train_metrics = trainer.TrainEpoch(train_data);
  auto val_metrics = Evaluate(val_data);
  if (val_metrics.goodness < best_val_goodness * 0.9) {
      // Overfitting! Stop training.
      break;
  }
- Start with small scale, then expand:
  - Test with 16 experts first
  - Scale to 64, then 128, then 243
  - Validate at each scale
- [Expert Network Architecture Guide](Architecture-Expert Networks.md)
- [243-Expert Configuration Guide](243 Expert Config)
- [API Reference: MoETrainer](API-MOE Trainer.md)
- [Forward-Forward Learning](Learning-Forward Forward.md)
Version: 1.0
Last Updated: April 2026