# DoLa
## Hallucination in LLMs
LLMs are prone to "hallucination": generating content that deviates from the real-world facts observed during pre-training.
The cause of hallucination is not fully understood, but one likely factor is the language modeling objective itself, which trains the model to minimize the difference (KL-divergence) between the data distribution and the model's probability distribution. In other words, the language model is trained to pick up surface patterns in its training corpus rather than to recognize real-world facts.
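A standard identity makes this claim concrete: maximizing average log-likelihood over the training corpus is equivalent (up to the constant entropy of the data) to minimizing the forward KL-divergence from the data distribution to the model, so the objective rewards matching corpus patterns, not verifying facts:

$$\arg\max_{\theta}\; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log p_{\theta}(x)\big] \;=\; \arg\min_{\theta}\; \mathrm{KL}\big(p_{\text{data}} \,\|\, p_{\theta}\big)$$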
## A Model-Interpretability Perspective
Prior work has shown that the lower layers of an LM encode low-level linguistic information (e.g., Korean particles such as 은, 는, 이, 가, 한, …), while the upper layers encode semantic information (e.g., the Korean War, its outbreak, the year, …).
It has also been shown that factual knowledge inside an LM can be edited by modifying specific feed-forward layers.
In the figure above, the LM outputs Seattle because it has the highest probability at the last layer.
The probability of the true answer (Olympia) should grow toward the upper layers, but the hallucination occurs because Seattle's probability was already high from the lower layers onward.
→ Could we reduce hallucination by contrasting the probability distributions of different layers and correcting the output accordingly?
The figure above plots the JSD (Jensen-Shannon divergence) between the last layer and each earlier layer; a smaller JSD means that layer's distribution differs less from the last layer's. (A minimal sketch of this computation follows below.)
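As a reference point, here is one way such a per-layer JSD could be computed from two sets of next-token logits. The helper name `jsd` and the textbook definition used here are my own illustration, not code from the DoLa repository (whose variant appears in the Code section below):

```python
import torch
import torch.nn.functional as F

def jsd(p_logits: torch.Tensor, q_logits: torch.Tensor) -> torch.Tensor:
    """Textbook Jensen-Shannon divergence between two next-token
    distributions, given raw logits of shape (..., vocab_size)."""
    p = F.softmax(p_logits, dim=-1)
    q = F.softmax(q_logits, dim=-1)
    m = 0.5 * (p + q)  # mixture distribution
    # F.kl_div(input, target) expects log-probs as `input` and probs as
    # `target`, and computes KL(target || exp(input)).
    kl_pm = F.kl_div(m.log(), p, reduction="none").sum(-1)  # KL(p || m)
    kl_qm = F.kl_div(m.log(), q, reduction="none").sum(-1)  # KL(q || m)
    return 0.5 * (kl_pm + kl_qm)  # symmetric, bounded above by log 2

# e.g., JSD between the final layer and one earlier layer:
# jsd(final_layer_logits, early_layer_logits)  -> shape (batch,)
```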
- For key entities that require factual knowledge, the JSD remains high even in the upper layers. → The distribution still differs substantially from the last layer's, meaning the model can still change its prediction in the final layers.
- When predicting easy tokens, the JSD becomes very small from the middle layers on. → The model has already decided which token to generate by the middle layers and keeps the output distribution almost unchanged afterward.
→ If we contrast against the layer where the JSD changes abruptly (contrastive decoding), the LM should amplify what the upper layers know, use its factual knowledge more effectively, and mitigate the hallucination of generating incorrect facts!
This approach has two advantages: 1) it requires no additional external knowledge, and 2) it requires no fine-tuning whatsoever.
(1) Compute the JSD between the transformer's last layer and each earlier candidate layer, and select the layer whose distribution changed the most (the premature layer).
(2) Contrast the logits of the layer found in (1) with the logits of the transformer's original last (mature) layer to correct the final logits, as written out below.
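In the DoLa paper's notation (as I recall it: $q_j$ is the next-token distribution obtained by applying the LM head to layer $j$, $N$ is the final mature layer, and $\mathcal{J}$ is the candidate set of premature layers), the two steps read roughly as:

$$M = \arg\max_{j \in \mathcal{J}}\; \mathrm{JSD}\big(q_N(\cdot \mid x_{<t}) \,\|\, q_j(\cdot \mid x_{<t})\big)$$

$$\hat{p}(x_t \mid x_{<t}) = \mathrm{softmax}\big(\mathcal{F}(q_N(x_t), q_M(x_t))\big),
\qquad
\mathcal{F}(q_N(x_t), q_M(x_t)) =
\begin{cases}
\log \dfrac{q_N(x_t)}{q_M(x_t)} & \text{if } x_t \in \mathcal{V}_{\text{head}} \\
-\infty & \text{otherwise}
\end{cases}$$

where $\mathcal{V}_{\text{head}} = \{\, x : q_N(x) \geq \alpha \max_{w} q_N(w) \,\}$ is the set of sufficiently probable tokens, with $\alpha$ corresponding to `relative_top` (0.1 by default) in the filter code below.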
## Code
```python
import torch
import torch.nn.functional as F

# Context assumed from the surrounding generation loop:
#   final_logits: last-layer (mature) logits, shape (batch_size, vocab_size)
#   candidate_premature_logits: dict mapping layer index -> that layer's logits
#   candidate_premature_layers: list of candidate layer indices

# 1. Stack all premature_layers into a new dimension
stacked_premature_layers = torch.stack(
    [candidate_premature_logits[i].to(final_logits) for i in candidate_premature_layers], dim=0
)

# 2. Calculate the softmax values for the mature_layer and all premature_layers
softmax_mature_layer = F.softmax(final_logits, dim=-1).detach()  # shape: (batch_size, vocab_size)
softmax_premature_layers = F.softmax(
    stacked_premature_layers, dim=-1
)  # shape: (num_premature_layers, batch_size, vocab_size)

# 3. Calculate M, the average (mixture) distribution
M = 0.5 * (
    softmax_mature_layer[None, :, :] + softmax_premature_layers
)  # shape: (num_premature_layers, batch_size, vocab_size)

# 4. Calculate log-softmax for the KL divergence
log_softmax_mature_layer = F.log_softmax(final_logits, dim=-1)  # shape: (batch_size, vocab_size)
log_softmax_premature_layers = F.log_softmax(
    stacked_premature_layers, dim=-1
)  # shape: (num_premature_layers, batch_size, vocab_size)

# 5. Calculate the KL divergences and then the JS divergences
#    (F.kl_div takes log-probs as input and probs as target)
kl1 = F.kl_div(log_softmax_mature_layer[None, :, :], M, reduction="none").mean(
    -1
)  # shape: (num_premature_layers, batch_size)
kl2 = F.kl_div(log_softmax_premature_layers, M, reduction="none").mean(
    -1
)  # shape: (num_premature_layers, batch_size)
js_divs = 0.5 * (kl1 + kl2)  # shape: (num_premature_layers, batch_size)

# 6. Reduce with the batch mean, then pick the layer with the largest JSD
js_divs = js_divs.mean(-1)  # shape: (num_premature_layers,)
premature_layer = candidate_premature_layers[int(js_divs.argmax().cpu().item())]

# Contrast the mature logits against the selected premature layer's logits,
# after filtering (_relative_top_filter is defined later on this page)
base_logits = candidate_premature_logits[premature_layer]
final_logits, base_logits = _relative_top_filter(final_logits.detach(), base_logits.detach())
logits = final_logits - base_logits.to(final_logits.device)
```
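The snippet above assumes `final_logits` and `candidate_premature_logits` have already been collected. A hypothetical sketch of one way to obtain them from a Hugging Face causal LM, by "early-exiting" with the shared LM head on intermediate hidden states (the model choice, layer bucket, and variable names are illustrative, not the DoLa repo's API; a faithful implementation would also apply the model's final normalization layer before the head):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")  # stand-in model for illustration
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model.eval()

inputs = tokenizer("The capital of Washington State is", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

# hidden_states is a tuple of (num_layers + 1) tensors, each (batch, seq_len, hidden);
# index 0 is the embedding output, index -1 the last layer.
hidden_states = outputs.hidden_states
candidate_premature_layers = [2, 4, 6, 8, 10]  # one possible candidate bucket

# Early exit: reuse the output embedding (LM head) on each candidate layer's
# last-position hidden state. NOTE: for GPT-2 the final LayerNorm (ln_f)
# should really be applied first; omitted here to keep the sketch short.
lm_head = model.get_output_embeddings()
candidate_premature_logits = {
    i: lm_head(hidden_states[i][:, -1, :]) for i in candidate_premature_layers
}
final_logits = outputs.logits[:, -1, :]  # mature (last-layer) logits
```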
```python
import numpy as np
import torch

def _relative_top_filter(
    scores: torch.FloatTensor,
    baseline_scores: torch.FloatTensor,
    relative_top: float = 0.1,
    filter_value: float = -float("Inf"),
    base_filter_value: float = -1e-3,
    min_tokens_to_keep: int = 1,
) -> tuple[torch.FloatTensor, torch.FloatTensor]:
    scores_normalized = scores.log_softmax(dim=-1)
    baseline_scores_normalized = baseline_scores.log_softmax(dim=-1)
    sorted_logits, sorted_indices = torch.sort(scores_normalized, descending=True)
    # Always keep at least `min_tokens_to_keep` tokens.
    min_thresh = sorted_logits[..., min_tokens_to_keep - 1]
    probs_max = torch.max(scores_normalized, dim=-1).values  # log-prob of the top token
    probs_thresh = probs_max + np.log(relative_top)  # top log-prob shifted down by log(relative_top)
    probs_thresh = torch.min(min_thresh, probs_thresh)
    probs_thresh = probs_thresh.unsqueeze(-1)
    # Tokens below the threshold are masked out of both distributions.
    baseline_scores_normalized[scores_normalized < probs_thresh] = base_filter_value
    scores_normalized[scores_normalized < probs_thresh] = filter_value
    return scores_normalized, baseline_scores_normalized
```
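This filter corresponds to the adaptive plausibility constraint from the formula above: only tokens whose log-probability under the mature distribution is within log(`relative_top`) of the top token survive, so the subtraction cannot promote tokens the full model considers implausible. A toy smoke test (shapes and values are arbitrary, purely for illustration):

```python
torch.manual_seed(0)
mature = torch.randn(1, 50)     # stand-in mature-layer logits over a 50-token vocab
premature = torch.randn(1, 50)  # stand-in premature-layer logits
kept, kept_base = _relative_top_filter(mature, premature, relative_top=0.1)
contrast = kept - kept_base     # filtered tokens remain at -inf here
print((kept == -float("inf")).sum().item(), "tokens filtered out")
next_token = contrast.argmax(dim=-1)  # greedy choice over the contrasted logits
```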