Ablations
While not all runs are directly linked/documented, almost all the completed runs are here:
- Ablation final output/configs: https://huggingface.co/shisa-ai
- Training logs/configs: https://wandb.ai/augmxnt/shisa-v2
Current best model:
- ablation-53-rafathenev2.rp-shisa-v2-mistral-nemo-12b

Current best llama3-8b:
- ablation-66-a55.dpo.armorm-shisa-v2-llama-3.1-8b
Model | Overall | EN | JA | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | evalplus |
---|---|---|---|---|---|---|---|---|---|---|---|
tokyotech-llm/Llama-3.3-Swallow-70B-Instruct-v0.4 | 57.91 | 52.56 | 60.34 | 7.508 | 0.497 | 41.600 | 0.662 | 1.700 | 0.393 | 0.803 | 0.726 |
cyberagent/Mistral-Nemo-Japanese-Instruct-2408 | 50.91 | 37.47 | 61.11 | 7.265 | 0.340 | 28.500 | 0.575 | 8.086 | 0.260 | 0.625 | 0.671 |
ablation-66-a55.dpo.armorm-shisa-v2-llama-3.1-8b | 50.58 | 46.25 | 52.83 | 7.228 | 0.427 | 31.800 | 0.555 | 1.452 | 0.207 | 0.821 | 0.610 |
tokyotech-llm/Llama-3.1-Swallow-8B-Instruct-v0.3 | 45.65 | 35.44 | 55.36 | 7.180 | 0.301 | 26.400 | 0.559 | 3.213 | 0.253 | 0.641 | 0.482 |
elyza/Llama-3-ELYZA-JP-8B | 43.97 | 32.88 | 55.20 | 7.015 | 0.338 | 17.500 | 0.414 | 7.042 | 0.240 | 0.618 | 0.433 |
sbintuitions/sarashina2.2-3b-instruct-v0.1 | 42.55 | 29.76 | 52.38 | 7.281 | 0.284 | 21.200 | 0.422 | 3.126 | 0.220 | 0.495 | 0.573 |
shisa-ai/shisa-v1-llama3-8b | 38.31 | 35.15 | 38.75 | 6.303 | 0.362 | 20.200 | 0.225 | 0.993 | 0.087 | 0.630 | 0.518 |
augmxnt/shisa-gamma-7b-v1 | 28.71 | 18.65 | 40.85 | 5.487 | 0.258 | 2.200 | 0.520 | 0.483 | 0.127 | 0.372 | 0.183 |
augmxnt/shisa-7b-v1 | 20.92 | 20.19 | 21.66 | 3.508 | 0.203 | 16.500 | 0.021 | 0.642 | 0.153 | 0.273 | 0.000 |
This experiment tests whether pairwise training improves performance.
Start with our best shisa-v1.1 SFT, ablation-11-rafathenev2, and try these four variations (sketched below):
- seed=42 shuffle (should be almost the same as ablation 11, sanity check)
- seed=42 pairwise EN-JA
- seed=42 pairwise JA-EN
- seed=42 random pairwise
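A minimal sketch of how these orderings might be constructed - this assumes parallel lists where `ja_rows[i]` is the JA translation of `en_rows[i]`, and the helper plus the interpretation of "random pairwise" are illustrative, not our actual pipeline:

```python
import random

# Illustrative only: assumes ja_rows[i] is the JA translation of en_rows[i].
def build_order(en_rows, ja_rows, mode, seed=42):
    rng = random.Random(seed)
    if mode == "shuffle":  # sanity check, should be ~same as ablation 11
        rows = en_rows + ja_rows
        rng.shuffle(rows)
        return rows
    pairs = list(zip(en_rows, ja_rows))
    rng.shuffle(pairs)
    out = []
    for en, ja in pairs:
        if mode == "pairwise-en-ja":    # EN immediately followed by its JA pair
            out += [en, ja]
        elif mode == "pairwise-ja-en":  # JA first, then its EN pair
            out += [ja, en]
        elif mode == "pairwise-random": # random order within each pair
            out += rng.sample([en, ja], 2)
    return out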
Sadly, SPIN simply does not replicate for us. It is worse than most of our DPO attempts, and worse than the base SFT model.
Why does this fail to replicate? Our theory is that SPIN was tested on zephyr-7b-sft-full
(Mistral 7B SFT'd on UltraChat200k), a much weaker model. 50K rows were then taken from UltraChat200k for generating new answers, with 50K/50K used from the current/n-1 iteration for stability.
In Abacus AI's Smaug DPOP paper, they point out that one of DPO's failure modes is that if edit distances are low, DPO can actually pick the worse answer and quality can go down, so maybe this is related: https://arxiv.org/pdf/2402.13228
- https://github.com/abacusai/smaug/issues/2 - maybe also doesn't replicate well
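For reference, the DPOP objective from that paper adds a penalty term that keeps the policy from dropping the likelihood of the chosen answer below the reference's (roughly, with chosen/rejected answers $y_w$/$y_l$ and penalty weight $\lambda > 0$):

$$\mathcal{L}_{\mathrm{DPOP}} = -\log \sigma\left(\beta\left[\log\frac{\pi_\theta(y_w|x)}{\pi_{\mathrm{ref}}(y_w|x)} - \log\frac{\pi_\theta(y_l|x)}{\pi_{\mathrm{ref}}(y_l|x)} - \lambda \cdot \max\left(0,\ \log\frac{\pi_{\mathrm{ref}}(y_w|x)}{\pi_\theta(y_w|x)}\right)\right]\right)$$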
Note: despite doing a test w/ an equal mix, it turns out that using a mix closer to DNO's (25% old data) seemed to be better than an equal amount.
Here's the topline summary: SPIN didn't seem to work at all on our test model. Each iteration is worse than the last, so something hinky is going on, since the accuracy/margins go up normally, as they would for a DPO tune.
Model | Overall | EN | JA | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | evalplus |
---|---|---|---|---|---|---|---|---|---|---|---|
ablation-55-rafathenev2.rp.tl-shisa-v2-llama-3.1-8b | 50.18 | 46.16 | 51.91 | 7.151 | 0.423 | 32.600 | 0.553 | 0.961 | 0.207 | 0.809 | 0.616 |
spin-01-a55.test-shisa-v2-llama-3.1-8b-iter0 | 48.39 | 45.44 | 50.17 | 7.004 | 0.439 | 29.800 | 0.535 | 1.102 | 0.167 | 0.799 | 0.543 |
spin-00-a55.baseline-shisa-v2-llama-3.1-8b | 44.25 | 43.34 | 45.84 | 6.561 | 0.437 | 25.200 | 0.497 | 0.157 | 0.147 | 0.789 | 0.409 |
spin-01-a55.test-shisa-v2-llama-3.1-8b-iter1 | 41.26 | 43.92 | 37.57 | 4.829 | 0.418 | 27.200 | 0.522 | 0.049 | 0.147 | 0.816 | 0.463 |
- The baseline uses 72K rows (the same total number of rows the 3 iterations see) to compare against what plain DPO does
- We pre-shuffle our dataset to make sure there is a mix of categories before we slice
- This recipe mixes some of the previous rounds in for stability (see the sketch after this list):
  - spin-01-iter0 (100%): ~20,000 rows
  - spin-01-iter1 (100%) + spin-01-iter0 (25%)
  - spin-01-iter2 (100%) + spin-01-iter1 (25%) + spin-01-iter0 (5%)
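A minimal sketch of that mixing step, assuming HF `datasets` objects (one per SPIN iteration, newest first); the names are illustrative:

```python
from datasets import concatenate_datasets

def spin_mix(iteration_sets, fractions=(1.00, 0.25, 0.05), seed=42):
    """Mix the newest iteration with shrinking slices of older ones."""
    parts = []
    for ds, frac in zip(iteration_sets, fractions):
        n = int(len(ds) * frac)
        parts.append(ds.shuffle(seed=seed).select(range(n)))
    return concatenate_datasets(parts).shuffle(seed=seed)

# e.g. the iter2 mix: 100% iter2 + 25% iter1 + 5% iter0
# mixed = spin_mix([iter2_ds, iter1_ds, iter0_ds])
```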
The SPIN recipe uses 3 epochs of DPO. Is that the right number for each iteration? It seems so - overall, our output is better with more DPO epochs:
Model | Overall | EN | JA | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | evalplus |
---|---|---|---|---|---|---|---|---|---|---|---|
ablation-55-rafathenev2.rp.tl-shisa-v2-llama-3.1-8b | 50.18 | 46.16 | 51.91 | 7.151 | 0.423 | 32.600 | 0.553 | 0.961 | 0.207 | 0.809 | 0.616 |
spin-01-a55.test-shisa-v2-llama-3.1-8b-iter0 (3epoch) | 48.39 | 45.44 | 50.17 | 7.004 | 0.439 | 29.800 | 0.535 | 1.102 | 0.167 | 0.799 | 0.543 |
spin-01-a55.test-shisa-v2-llama-3.1-8b-iter0.2epoch | 47.84 | 45.49 | 49.39 | 6.960 | 0.436 | 29.900 | 0.541 | 0.719 | 0.153 | 0.804 | 0.518 |
spin-00-a55.baseline-shisa-v2-llama-3.1-8b.1epoch | 47.79 | 45.68 | 50.06 | 6.808 | 0.454 | 29.200 | 0.525 | 0.853 | 0.233 | 0.793 | 0.470 |
As an interesting counterpoint, on the 75K baseline set a single epoch does best, with each additional epoch making the model perform worse:
Model | Overall | EN | JA | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | evalplus |
---|---|---|---|---|---|---|---|---|---|---|---|
spin-00-a55.baseline-shisa-v2-llama-3.1-8b.1epoch | 47.79 | 45.68 | 50.06 | 6.808 | 0.454 | 29.200 | 0.525 | 0.853 | 0.233 | 0.793 | 0.470 |
spin-00-a55.baseline-shisa-v2-llama-3.1-8b.2epoch | 46.64 | 44.57 | 47.46 | 6.662 | N/A | 26.400 | 0.511 | 0.582 | 0.167 | 0.809 | 0.488 |
spin-00-a55.baseline-shisa-v2-llama-3.1-8b (3 epoch) | 44.25 | 43.34 | 45.84 | 6.561 | 0.437 | 25.200 | 0.497 | 0.157 | 0.147 | 0.789 | 0.409 |
Both iter1 runs are worse than iter0, and the 100% iter1 + 100% iter0 mix (iter1-2x) is worse than 100% iter1 + 25% iter0.
Model | Overall | EN | JA | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | evalplus |
---|---|---|---|---|---|---|---|---|---|---|---|
spin-01-a55.test-shisa-v2-llama-3.1-8b-iter1 | 41.26 | 43.92 | 37.57 | 4.829 | 0.418 | 27.200 | 0.522 | 0.049 | 0.147 | 0.816 | 0.463 |
spin-01-a55.test-shisa-v2-llama-3.1-8b-iter1-2x | 39.98 | 43.65 | 35.05 | 4.168 | 0.414 | 26.900 | 0.516 | 0.021 | 0.193 | 0.816 | 0.463 |
There is basically no difference between the ArmoRM and OLMo 2 preference mixes, except that the latter takes far longer to train.
Model | Overall | EN | JA | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | evalplus |
---|---|---|---|---|---|---|---|---|---|---|---|
ablation-66-a55.dpo.armorm-shisa-v2-llama-3.1-8b | 50.58 | 46.25 | 52.83 | 7.228 | 0.427 | 31.800 | 0.555 | 1.452 | 0.207 | 0.821 | 0.610 |
ablation-74-a55.dpo.olmo2-shisa-v2-llama-3.1-8b | 50.54 | 48.35 | 51.14 | 7.089 | 0.461 | 32.000 | 0.557 | 0.294 | 0.213 | 0.856 | 0.585 |
ablation-55-rafathenev2.rp.tl-shisa-v2-llama-3.1-8b | 50.18 | 46.16 | 51.91 | 7.151 | 0.423 | 32.600 | 0.553 | 0.961 | 0.207 | 0.809 | 0.616 |
- The OLMo 2 preference mix takes 24h 42m: https://wandb.ai/augmxnt/shisa-v2/runs/r9ycjoyt/overview
- The ArmoRM DPO set takes 1h 17m: https://wandb.ai/augmxnt/shisa-v2/runs/vtuig65r/overview
The KTOTrainer does not appear to work with DeepSpeed ZeRO-3:
```
[ip-10-1-21-143:7]:[rank7]: Traceback (most recent call last):
[ip-10-1-21-143:7]:[rank7]: File "<frozen runpy>", line 198, in _run_module_as_main
[ip-10-1-21-143:7]:[rank7]: File "<frozen runpy>", line 88, in _run_code
[ip-10-1-21-143:7]:[rank7]: File "/fsx2/axolotl/axolotl/src/axolotl/cli/train.py", line 116, in <module>
[ip-10-1-21-143:7]:[rank7]: fire.Fire(do_cli)
[ip-10-1-21-143:7]:[rank7]: File "/fsx/ubuntu/miniforge3/envs/axolotl/lib/python3.12/site-packages/fire/core.py", line 135, in Fire
[ip-10-1-21-143:7]:[rank7]: component_trace = _Fire(component, args, parsed_flag_args, context, name)
[ip-10-1-21-143:7]:[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[ip-10-1-21-143:7]:[rank7]: File "/fsx/ubuntu/miniforge3/envs/axolotl/lib/python3.12/site-packages/fire/core.py", line 468, in _Fire
[ip-10-1-21-143:7]:[rank7]: component, remaining_args = _CallAndUpdateTrace(
[ip-10-1-21-143:7]:[rank7]: ^^^^^^^^^^^^^^^^^^^^
[ip-10-1-21-143:7]:[rank7]: File "/fsx/ubuntu/miniforge3/envs/axolotl/lib/python3.12/site-packages/fire/core.py", line 684, in _CallAndUpdateTrace
[ip-10-1-21-143:7]:[rank7]: component = fn(*varargs, **kwargs)
[ip-10-1-21-143:7]:[rank7]: ^^^^^^^^^^^^^^^^^^^^^^
[ip-10-1-21-143:7]:[rank7]: File "/fsx2/axolotl/axolotl/src/axolotl/cli/train.py", line 90, in do_cli
[ip-10-1-21-143:7]:[rank7]: return do_train(parsed_cfg, parsed_cli_args)
[ip-10-1-21-143:7]:[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[ip-10-1-21-143:7]:[rank7]: File "/fsx2/axolotl/axolotl/src/axolotl/cli/train.py", line 46, in do_train
[ip-10-1-21-143:7]:[rank7]: model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
[ip-10-1-21-143:7]:[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[ip-10-1-21-143:7]:[rank7]: File "/fsx2/axolotl/axolotl/src/axolotl/train.py", line 483, in train
[ip-10-1-21-143:7]:[rank7]: ) = setup_model_and_trainer(cfg, dataset_meta)
[ip-10-1-21-143:7]:[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[ip-10-1-21-143:7]:[rank7]: File "/fsx2/axolotl/axolotl/src/axolotl/train.py", line 444, in setup_model_and_trainer
[ip-10-1-21-143:7]:[rank7]: trainer = setup_trainer(
[ip-10-1-21-143:7]:[rank7]: ^^^^^^^^^^^^^^
[ip-10-1-21-143:7]:[rank7]: File "/fsx2/axolotl/axolotl/src/axolotl/utils/trainer.py", line 615, in setup_trainer
[ip-10-1-21-143:7]:[rank7]: return trainer_builder.build(total_num_steps)
[ip-10-1-21-143:7]:[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[ip-10-1-21-143:7]:[rank7]: File "/fsx2/axolotl/axolotl/src/axolotl/core/trainer_builder.py", line 1146, in build
[ip-10-1-21-143:7]:[rank7]: dpo_trainer = trainer_cls(
[ip-10-1-21-143:7]:[rank7]: ^^^^^^^^^^^^
[ip-10-1-21-143:7]:[rank7]: File "/fsx/ubuntu/miniforge3/envs/axolotl/lib/python3.12/site-packages/trl/trainer/kto_trainer.py", line 474, in __init__
[ip-10-1-21-143:7]:[rank7]: self.ref_model = create_reference_model(model)
[ip-10-1-21-143:7]:[rank7]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[ip-10-1-21-143:7]:[rank7]: File "/fsx/ubuntu/miniforge3/envs/axolotl/lib/python3.12/site-packages/trl/models/modeling_base.py", line 622, in create_reference_model
[ip-10-1-21-143:7]:[rank7]: raise ValueError(
[ip-10-1-21-143:7]:[rank7]: ValueError: DeepSpeed ZeRO-3 is enabled and is not compatible with `create_reference_model()`. Please instantiate your reference model directly with `AutoModelForCausalLM.from_pretrained()`.
```
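The error message itself points at the workaround: instantiate the reference model explicitly. A minimal, untested sketch - the model name and dataset are placeholders, not our actual run:

```python
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import KTOConfig, KTOTrainer

MODEL = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder

# Instantiate the reference model explicitly instead of letting
# KTOTrainer call create_reference_model(), which ZeRO-3 forbids.
model = AutoModelForCausalLM.from_pretrained(MODEL)
ref_model = AutoModelForCausalLM.from_pretrained(MODEL)
tokenizer = AutoTokenizer.from_pretrained(MODEL)

trainer = KTOTrainer(
    model=model,
    ref_model=ref_model,  # explicit, so no cloning under ZeRO-3
    args=KTOConfig(output_dir="kto-out"),
    train_dataset=load_dataset("trl-lib/kto-mix-14k", split="train"),
    processing_class=tokenizer,
)
trainer.train()
```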
Apparently TRL >0.13.0 is broken w/ SimPO
- ablation-60-a55.dpo.enja-shisa-v2-llama-3.1-8b
- ablation-64-a55.simpo.enja-shisa-v2-llama-3.1-8b
- ablation-65-a55.simpo.armorm-shisa-v2-llama-3.1-8b
- See https://github.com/huggingface/trl/issues/2882 - SimPO (https://github.com/princeton-nlp/SimPO w/ https://huggingface.co/datasets/princeton-nlp/gemma2-ultrafeedback-armorm) is currently broken in TRL for multi-GPU?
OK, we can get this working on Axolotl 0.8.0 by downgrading trl from 0.15.1 (0.15.2 is current) to 0.13.0.
- Related? ORPO also broken since 0.13.0
- `pip install trl==0.13.0`
- follow directions
- make some changes to the script imports
- `pip install wandb`
```
[rank2]: Traceback (most recent call last):
[rank2]: File "/fsx/ubuntu/meti/train/simpo/SimPO/scripts/run_simpo.py", line 316, in <module>
[rank2]: main()
[rank2]: File "/fsx/ubuntu/meti/train/simpo/SimPO/scripts/run_simpo.py", line 269, in main
[rank2]: train_result = trainer.train(resume_from_checkpoint=checkpoint)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/fsx/ubuntu/miniforge3/envs/simpo/lib/python3.12/site-packages/transformers/trainer.py", line 2241, in train
[rank2]: return inner_training_loop(
[rank2]: ^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/fsx/ubuntu/miniforge3/envs/simpo/lib/python3.12/site-packages/transformers/trainer.py", line 2500, in _inner_training_loop
[rank2]: batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/fsx/ubuntu/meti/train/simpo/SimPO/scripts/simpo_trainer.py", line 764, in get_batch_samples
[rank2]: policy_output = model.generate(
[rank2]: ^^^^^^^^^^^^^^
[rank2]: AttributeError: 'generator' object has no attribute 'generate'
```
SimPO was being a pain, but since they had a DPO set, I decided to throw an ablation on to test it. It gives a bit more of a boost than the EN/JA preference set we were using.
- Preference set: https://huggingface.co/datasets/princeton-nlp/gemma2-ultrafeedback-armorm
Model | Overall | EN | JA | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | evalplus |
---|---|---|---|---|---|---|---|---|---|---|---|
ablation-66-a55.dpo.armorm-shisa-v2-llama-3.1-8b | 47.95 | 46.25 | 47.06 | 7.228 | 0.427 | 31.800 | 0.555 | 1.452 | 0.207 | 0.821 | 0.610 |
ablation-59-a55.dpo.en-shisa-v2-llama-3.1-8b | 47.65 | 45.91 | 46.48 | 7.133 | 0.429 | 31.400 | 0.559 | 1.312 | 0.207 | 0.810 | 0.622 |
ablation-60-a55.dpo.enja-shisa-v2-llama-3.1-8b | 47.56 | 45.41 | 46.30 | 7.225 | 0.420 | 31.400 | 0.554 | 1.029 | 0.213 | 0.802 | 0.646 |
ablation-55-rafathenev2.rp.tl-shisa-v2-llama-3.1-8b | 47.36 | 46.16 | 45.72 | 7.151 | 0.423 | 32.600 | 0.553 | 0.961 | 0.207 | 0.809 | 0.616 |
Adding the Yahoo! dataset has very little effect on Llama 3.1 8B
Model | Overall | EN | JA | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | evalplus |
---|---|---|---|---|---|---|---|---|---|---|---|
ablation-46-rafathenev2.rp-shisa-v2-llama-3.1-8b-lr8e6 | 50.32 | 46.41 | 51.73 | 7.141 | 0.417 | 33.900 | 0.554 | 1.489 | 0.173 | 0.808 | 0.628 |
ablation-55-rafathenev2.rp.tl-shisa-v2-llama-3.1-8b | 50.18 | 46.16 | 51.91 | 7.151 | 0.423 | 32.600 | 0.553 | 0.961 | 0.207 | 0.809 | 0.616 |
ablation-63-rafathenev2.rp.tl.yahoo-shisa-v2-llama-3.1-8b | 49.62 | 45.66 | 50.70 | 7.161 | 0.425 | 31.200 | 0.559 | 0.910 | 0.140 | 0.808 | 0.640 |
So we give this a try on Mistral Nemo 12B and get a rather disastrous result:
- brings down Shaberi scores
- improves MixEval
- dramatically lowers LiveBench
- slightly lowers llm-jp-eval
- dramatically lowers JP RP-Bench
- lowers IFEval

We do 2 eval runs just to confirm:
Model | Overall | EN | JA | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | evalplus |
---|---|---|---|---|---|---|---|---|---|---|---|
ablation-53-rafathenev2.rp-shisa-v2-mistral-nemo-12b | 51.36 | 43.29 | 57.01 | 7.212 | 0.390 | 34.100 | 0.575 | 5.596 | 0.193 | 0.702 | 0.634 |
ablation-53-rafathenev2.rp-shisa-v2-mistral-nemo-12b.run1 | 51.34 | 42.05 | 58.21 | 7.222 | 0.373 | 32.700 | 0.573 | 6.504 | 0.207 | 0.702 | 0.634 |
ablation-67-rafathenev2.rp.tl.yahoo-mistral-nemo-12b | 46.90 | 40.46 | 50.27 | 6.870 | 0.461 | 25.000 | 0.560 | 0.856 | 0.193 | 0.601 | 0.622 |
ablation-67-rafathenev2.rp.tl.yahoo-mistral-nemo-12b.run1 | 46.68 | 37.70 | 52.54 | 6.875 | 0.398 | 24.500 | 0.558 | 2.209 | 0.240 | 0.600 | 0.622 |
This initial test uses https://huggingface.co/datasets/shisa-ai/allenai_llama-3.1-tulu-3-405b-preference-mixture-filter-datecutoff_jpn, a 46.5K-row subset of AI2's 361K-row Tulu 405B preference set (https://huggingface.co/datasets/allenai/llama-3.1-tulu-3-405b-preference-mixture) that uses Tulu 405B to create Japanese translations.
We cannot use our dataset directly, as there are mismatched rows that cause Axolotl to freak out. Instead, we have custom code that generates clean prompt, chosen, and rejected columns (see the sketch after this list). The ablations:
- ablation-55-rafathenev2.rp.tl - current SFT mix
- ablation-58-a55.dpo.ultra - original EN chat results in a naive way (this is improper training)
- ablation-59-a55.dpo.en - EN chat set
- ablation-60-a55.dpo.enja - EN+JA chat sets shuffled
- ablation-61-a55.dpo.enjanot - EN+JA, plus shuffles in variants using the EN/JA cross sets as rejected - the theory is this would help with replying in the correct language, but it trains on double the data, so I guess we'll see what happens
- ablation-62-dpo.notseq - this takes ablation-60 (EN+JA) and then trains the EN/JA cross sets sequentially after
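A sketch of what that cleanup does, assuming Tulu-style rows where `chosen`/`rejected` are full message lists that should share the same context; the helper itself is illustrative, not our actual code:

```python
# Turn a Tulu-style preference row into the flat prompt/chosen/rejected
# shape that a DPO trainer expects; drop rows whose contexts disagree
# (the mismatched rows are what upset Axolotl).
def to_dpo_row(row):
    chosen, rejected = row["chosen"], row["rejected"]
    if chosen[:-1] != rejected[:-1]:
        return None  # mismatched prompt/context: filter out
    return {
        "prompt": chosen[:-1],            # shared conversation prefix
        "chosen": chosen[-1]["content"],  # final assistant turn
        "rejected": rejected[-1]["content"],
    }
```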
We train all of these for 1 epoch at a low 5e-7 LR w/ linear decay, and we see steady improvement of margins and accuracy, which shows that the DPO is effective with our hyperparameters.
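For reference, the equivalent knobs expressed in raw TRL (we actually train via Axolotl; this is just a sketch, and `beta` is shown at TRL's default since it isn't stated in these notes):

```python
from trl import DPOConfig

# Mirrors the run above: 1 epoch, 5e-7 LR, linear decay.
args = DPOConfig(
    output_dir="dpo-out",
    num_train_epochs=1,
    learning_rate=5e-7,
    lr_scheduler_type="linear",
    beta=0.1,  # assumption: TRL default, not from our config
)
```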
We don't get much benefit (slightly better, but really, within the margin of error):
Model | Overall | EN | JA | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | evalplus |
---|---|---|---|---|---|---|---|---|---|---|---|
ablation-59-a55.dpo.en-shisa-v2-llama-3.1-8b | 47.65 | 45.91 | 46.48 | 7.133 | 0.429 | 31.400 | 0.559 | 1.312 | 0.207 | 0.810 | 0.622 |
ablation-60-a55.dpo.enja-shisa-v2-llama-3.1-8b | 47.56 | 45.41 | 46.30 | 7.225 | 0.420 | 31.400 | 0.554 | 1.029 | 0.213 | 0.802 | 0.646 |
ablation-58-a55.dpo-shisa-v2-llama-3.1-8b | 47.38 | 45.42 | 46.26 | 7.050 | 0.427 | 30.800 | 0.560 | 1.163 | 0.227 | 0.801 | 0.628 |
ablation-55-rafathenev2.rp.tl-shisa-v2-llama-3.1-8b | 47.36 | 46.16 | 45.72 | 7.151 | 0.423 | 32.600 | 0.553 | 0.961 | 0.207 | 0.809 | 0.616 |
ablation-61-a55.dpo.enjanot-shisa-v2-llama-3.1-8b | 47.13 | 45.73 | 45.77 | 7.166 | 0.435 | 30.700 | 0.555 | 1.001 | 0.200 | 0.803 | 0.610 |
ablation-62-dpo.notseq-ablation-60-a55.dpo.enja | 42.88 | 40.50 | 41.76 | 6.893 | 0.401 | 25.400 | 0.512 | 0.777 | 0.120 | 0.715 | 0.604 |
- While it's good to have JA preference data, one issue with having Tulu 405B do translations is that it will even out quality to a certain level (good/bad) - if we are using preference data, we really need to do our own native generations
- An additional factor is that, from a quality perspective, translations are likely not the way to go vs native generations
- JA perf improved even with EN-only preference data, which is a useful data point
Adding the translation benchmark results:
Model | Overall | EN | JA | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | evalplus |
---|---|---|---|---|---|---|---|---|---|---|---|
ablation-46-rafathenev2.rp-shisa-v2-llama-3.1-8b-lr8e6 | 47.75 | 46.41 | 46.08 | 7.141 | 0.423 | 33.900 | 0.553 | 1.489 | 0.173 | 0.800 | 0.628 |
ablation-55-rafathenev2.rp.toklen-shisa-v2-llama-3.1-8b-lr8e6 | 47.36 | 46.16 | 45.72 | 7.151 | 0.423 | 32.600 | 0.553 | 0.961 | 0.207 | 0.809 | 0.616 |
```
MODEL=ablation-46-rafathenev2.rp-shisa-v2-llama-3.1-8b-lr8e6 OPENAI_URL=http://localhost:8000/v1 ./run_translation_bench.sh
vllm serve /fsx2/outputs/ablation-46-rafathenev2.rp-shisa-v2-llama-3.1-8b-lr8e6 -tp 8 --gpu-memory-utilization 0.90 --num-scheduler-steps 20 --port 8000 --served-model-name ablation-46-rafathenev2.rp-shisa-v2-llama-3.1-8b-lr8e6
```
```
EASY
---
9   ablation-46-rafathenev2.rp-shisa-v2-llama-3.1-...  1.276373  171  299  2.652980e-01  7.818317e+00
9   ablation-55-rafathenev2.rp.tl-shisa-v2-llama-3...  0.862285  153  300  1.813070e-01  7.031378e+00

HARD
---
10  ablation-46-rafathenev2.rp-shisa-v2-llama-3.1-...  0.147662  208  420  0.145712  5.368485
10  ablation-55-rafathenev2.rp.tl-shisa-v2-llama-3...  0.153531  208  420  0.153409  5.383076
```
shisa-v1.1-nemo-12b, shisa-v1.1-llama3.1-8b, shisa-v1.1-gamma-7b, shisa-v1.1-7b
Model | Overall | EN | JA | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | evalplus |
---|---|---|---|---|---|---|---|---|---|---|---|
augmxnt/shisa-gamma-7b-v1 | 26.43 | 18.65 | 35.85 | 5.487 | 0.258 | 2.200 | 0.520 | 0.483 | 0.127 | 0.372 | 0.183 |
augmxnt/shisa-7b-v1 | 19.49 | 20.19 | 18.79 | 3.508 | 0.203 | 16.500 | 0.021 | 0.642 | 0.153 | 0.273 | 0.000 |
Model | Overall | EN | JA | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | evalplus |
---|---|---|---|---|---|---|---|---|---|---|---|
ablation-11-rafathenev2-shisa-v2-llama-3.1-8b-lr8e6 | 47.72 | 46.15 | 45.80 | 7.182 | 0.430 | 32.400 | 0.557 | 1.437 | 0.153 | 0.798 | 0.652 |
ablation-52-rafathenev2.0.8.0-shisa-v2-llama-3.1-8b-lr8e6 | 45.81 | 45.04 | 43.54 | 7.053 | 0.406 | 32.400 | 0.406 | 0.406 | 0.406 | 0.767 | 0.610 |
meta-llama/Llama-3.1-8B-Instruct | 42.52 | 44.79 | 36.07 | 6.459 | 0.434 | 29.200 | 0.241 | 0.673 | 0.153 | 0.803 | 0.634 |
This is to test if extra synthetic Japanese generation improves Japanese:
Model | Overall | EN | JA | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | evalplus |
---|---|---|---|---|---|---|---|---|---|---|---|
ablation-11-rafathenev2-shisa-v2-llama-3.1-8b-lr8e6 | 47.72 | 46.15 | 45.80 | 7.182 | 0.430 | 32.400 | 0.557 | 1.437 | 0.153 | 0.798 | 0.652 |
ablation-49-rafathenev2.magpief-shisa-v2-llama-3.1-8b | 43.07 | 39.25 | 43.30 | 7.154 | 0.401 | 19.700 | 0.539 | 0.020 | 0.193 | 0.767 | 0.610 |
meta-llama/Llama-3.1-8B-Instruct | 42.52 | 44.79 | 36.07 | 6.476 | 0.401 | 27.700 | 0.247 | 1.010 | 0.160 | 0.803 | 0.634 |
- Results are about the same except that JP-IFEval is better
This is just to confirm that a mathcode-only mix improves our coding ability:
Model | Overall | EN | JA | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | evalplus |
---|---|---|---|---|---|---|---|---|---|---|---|
meta-llama/Llama-3.1-8B-Instruct | 42.52 | 44.79 | 36.07 | 6.476 | 0.442 | 27.700 | 0.247 | 1.010 | 0.160 | 0.803 | 0.634 |
ablation-50-mathcode-shisa-v2-llama-3.1-8b-lr8e6 | 38.72 | 36.53 | 35.24 | 6.150 | 0.395 | 15.200 | 0.385 | 0.003 | 0.147 | 0.733 | 0.671 |
- Yes, we get some improvements in EvalPlus
- However, LiveBench code and data analysis give us 0.0, so something seems wrong. Data format wrong?
- Look at output.xlsx to see the broken-down scores
- Look at results for the raw output
51 mathcode reasoning...
We should see a coding improvement if we use only the "good" English code?
Model | Overall | EN | JA | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | evalplus |
---|---|---|---|---|---|---|---|---|---|---|---|
ablation-11-rafathenev2-shisa-v2-llama-3.1-8b-lr8e6 | 47.72 | 46.15 | 45.80 | 7.182 | 0.430 | 32.400 | 0.557 | 1.437 | 0.153 | 0.798 | 0.652 |
ablation-42-rafathenev2.cmrmixen.masked-shisa-v2-llama-3.1-8b-lr8e6 | 42.69 | 38.59 | 43.13 | 7.089 | 0.357 | 27.600 | 0.557 | 0.087 | 0.173 | 0.662 | 0.610 |
meta-llama/Meta-Llama-3.1-8B-Instruct | 42.30 | 45.19 | 35.07 | 6.459 | 0.434 | 29.200 | 0.241 | 0.673 | 0.153 | 0.807 | 0.640 |
ablation-41-cmrmixen.masked-shisa-v2-llama-3.1-8b-lr8e6 | 37.69 | 38.33 | 33.24 | 5.172 | 0.358 | 27.500 | 0.515 | 0.000 | 0.113 | 0.650 | 0.567 |
- Is the Tulu mix just bad?
- It does not improve either our EN or JA when shuffled in. Every metric, including EvalPlus, is worse.
Here's a comparison of just the EN vs our questionable JA:
Model | Overall | EN | JA | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | evalplus |
---|---|---|---|---|---|---|---|---|---|---|---|
ablation-40-rafathenev2.cmrmix.masked-shisa-v2-llama-3.1-8b-lr8e6 | 43.98 | 39.72 | 44.72 | 7.149 | 0.357 | 28.800 | 0.553 | 0.805 | 0.173 | 0.696 | 0.616 |
ablation-42-rafathenev2.cmrmixen.masked-shisa-v2-llama-3.1-8b-lr8e6 | 42.69 | 38.59 | 43.13 | 7.089 | 0.357 | 27.600 | 0.557 | 0.087 | 0.173 | 0.662 | 0.610 |
Model | Overall | EN | JA | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | evalplus |
---|---|---|---|---|---|---|---|---|---|---|---|
ablation-39-cmrmix.masked-shisa-v2-llama-3.1-8b-lr8e6 | 42.28 | 38.58 | 43.61 | 5.422 | 0.372 | 25.500 | 0.507 | N/A | 0.153 | 0.674 | 0.555 |
ablation-41-cmrmixen.masked-shisa-v2-llama-3.1-8b-lr8e6 | 37.69 | 38.33 | 33.24 | 5.172 | 0.358 | 27.500 | 0.515 | 0.000 | 0.113 | 0.650 | 0.567 |
As a sanity check, we test whether it's better to do input masking: yes, it is (this is w/ the flawed CMR mix).
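("Input masking" here means excluding the prompt/input tokens from the loss. A minimal sketch, using the HF convention that label -100 is ignored:)

```python
# Mask the prompt so only response tokens contribute to the loss,
# using the HF convention that label -100 is ignored.
def mask_inputs(input_ids, prompt_len):
    labels = list(input_ids)
    labels[:prompt_len] = [-100] * prompt_len
    return labels

# Example: only the last 3 (response) tokens are trained on.
print(mask_inputs([1, 2, 3, 4, 5, 6, 7], prompt_len=4))
# -> [-100, -100, -100, -100, 5, 6, 7]
```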
Here is just the cmrmix vs the masked cmrmix:
Model | Overall | EN | JA | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | evalplus |
---|---|---|---|---|---|---|---|---|---|---|---|
ablation-39-cmrmix.masked-shisa-v2-llama-3.1-8b-lr8e6 | 42.28 | 38.58 | 43.61 | 5.422 | 0.372 | 25.500 | 0.507 | N/A | 0.153 | 0.674 | 0.555 |
ablation-36-cmrmix-shisa-v2-llama-3.1-8b-lr8e6 | 38.02 | 37.27 | 34.67 | 5.458 | 0.360 | 24.800 | 0.509 | 0.001 | 0.133 | 0.647 | 0.585 |
And here it is blended:
Model | Overall | EN | JA | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | evalplus |
---|---|---|---|---|---|---|---|---|---|---|---|
ablation-40-rafathenev2.cmrmix.masked-shisa-v2-llama-3.1-8b-lr8e6 | 43.98 | 39.72 | 44.72 | 7.149 | 0.357 | 28.800 | 0.553 | 0.805 | 0.173 | 0.696 | 0.616 |
ablation-37-rafathenev2.cmrmix-shisa-v2-llama-3.1-8b-lr8e6 | 43.09 | 39.29 | 43.69 | 7.105 | 0.358 | 28.900 | 0.549 | 0.273 | 0.187 | 0.671 | 0.591 |
This experiment tests whether the addition of shisa-ai/shisa-v2-code-math-reasoning-sft-mix, a 192.4K-row EN/JA code/math/reasoning dataset, improves our performance.
- ablation-11-rafathenev2 is shisa-v1.1-athenev2-filtered
- ablation-36-cmrmix is the shisa-v2-code-math-reasoning-sft-mix alone
- ablation-37-rafathenev2.cmrmix is shisa-v1.1-athenev2-filtered + shisa-v2-code-math-reasoning-sft-mix mixed
- ablation-38-rafathenv2.cmrmix.fixedlength is mixed, but with 50% of each to approximate the same training length as ablation-11
Model | Overall | EN | JA | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | evalplus |
---|---|---|---|---|---|---|---|---|---|---|---|
ablation-11-rafathenev2-shisa-v2-llama-3.1-8b-lr8e6 | 47.72 | 46.15 | 45.80 | 7.182 | 0.430 | 32.400 | 0.557 | 1.437 | 0.153 | 0.798 | 0.652 |
ablation-38-rafathenev2.cmrmix.fixedlength-shisa-v2-llama-3.1-8b-lr8e6 | 45.67 | 37.75 | 52.04 | 7.008 | 0.348 | 26.100 | 0.533 | N/A | 0.147 | 0.670 | 0.598 |
ablation-37-rafathenev2.cmrmix-shisa-v2-llama-3.1-8b-lr8e6 | 43.09 | 39.29 | 43.69 | 7.105 | 0.358 | 28.900 | 0.549 | 0.273 | 0.187 | 0.671 | 0.591 |
meta-llama/Meta-Llama-3.1-8B-Instruct | 42.30 | 45.19 | 35.07 | 6.459 | 0.434 | 29.200 | 0.241 | 0.673 | 0.153 | 0.807 | 0.640 |
ablation-36-cmrmix-shisa-v2-llama-3.1-8b-lr8e6 | 38.02 | 37.27 | 34.67 | 5.458 | 0.360 | 24.800 | 0.509 | 0.001 | 0.133 | 0.647 | 0.585 |
- the JA cmr-mix should improve coding but doesn't, so something is very wrong! (sure enough, our translation model was eliding the conversion, no bueno)
This dataset is what we're using from allenai/tulu-3-sft-mixture, with machine translation to Japanese.
https://huggingface.co/datasets/shisa-ai/shisa-v2-code-math-reasoning-sft-mix/viewer/default/train?p=2&sql=--+The+SQL+console+is+powered+by+DuckDB+WASM+and+runs+entirely+in+the+browser.%0A--+Get+started+by+typing+a+query+or+selecting+a+view+from+the+options+below.%0ASELECT+DISTINCT%28SOURCE%29+FROM+train%3B&views%5B%5D=train
```sql
SELECT DISTINCT(SOURCE) FROM train;
```

which returns:

```
ai2-adapt-dev/no_robots_converted
allenai/tulu-3-sft-personas-math-grade
ai2-adapt-dev/tulu_v3.9_open_math_2_gsm8k_50k
ai2-adapt-dev/evol_codealpaca_heval_decontaminated
ai2-adapt-dev/flan_v2_converted
ai2-adapt-dev/personahub_code_v2_34999
ai2-adapt-dev/numinamath_tir_math_decontaminated
```
In axolotl, the formatting is Tulu's, and here's how to train:
```yaml
- path: shisa-ai/shisa-v2-code-math-reasoning-sft-mix
  type: chat_template
  field_messages: conversations
  message_property_mappings:
    role: role
    content: content
  roles:
    system:
      - system
    assistant:
      - gpt
      - model
      - assistant
    user:
      - human
      - user
  roles_to_train: ["system", "input", "assistant"]
```
For ablation-37, we add `split: train[:50%]` to the config (see: https://github.com/axolotl-ai-cloud/axolotl/discussions/2280).
- ablation-01
- ablation-02
- ablation-03
- ablation-04
Confirmation that using the Liger Kernels improves training efficiency and doesn't negatively affect performance.
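(Axolotl enables this via its Liger plugin config; for reference, outside Axolotl the liger-kernel package patches the HF modeling code directly - a minimal sketch:)

```python
from liger_kernel.transformers import apply_liger_kernel_to_llama

# Monkey-patches HF's Llama modeling code with fused Liger kernels
# (RMSNorm, RoPE, SwiGLU, fused linear cross-entropy); call this
# before loading the model.
apply_liger_kernel_to_llama()
```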
This is a sanity check using basically the exact same config (the ShareGPT config has been changed to the updated syntax). The only thing that has changed is the versions of the software (May 2024 vs Feb 2025).
- shisa-v1-llama3-8b
- ablation-00-baseline-shisa-v2-llama3-8b-lr-8e6
Trained from:
- Llama 3 8B Instruct
On almost every metric, our new ablation performs better using the same hyperparameters and datasets. I guess there were a lot of under-the-hood bugs/improvements made between PyTorch, Transformers, TRL, and Axolotl...
Model | Overall | Shaberi | MixEval | LiveBench | llm-jp-eval | JP RP-Bench | JP IFEval | IFEval | HumanEval+ |
---|---|---|---|---|---|---|---|---|---|
ablation-00 | 40.43 | 6.182 | 0.343 | 27.300 | 0.550 | 0.933 | 0.127 | 0.693 | 0.537 |
shisa-v1-llama3 | 34.42 | 6.303 | 0.362 | 20.200 | 0.225 | 0.993 | 0.087 | 0.630 | 0.518 |
- Overall score is +17.5%
- see https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html and https://github.com/axolotl-ai-cloud/axolotl/issues/1375
- https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/data/sft.py
- https://axolotl-ai-cloud.github.io/axolotl/docs/dataset_preprocessing.html
- https://github.com/axolotl-ai-cloud/axolotl/issues/1508
- https://github.com/axolotl-ai-cloud/axolotl/issues/735
- https://github.com/axolotl-ai-cloud/axolotl/discussions/2280