DoLa: Decoding by Contrasting Layers Improves Factuality in Large Language Models

💡 A new decoding strategy for reducing hallucination in LLMs

Introduction

  • Hallucination in LLMs

    LLMs tend to produce "hallucinations": generated content that deviates from the facts observed during pre-training.

    ν™˜κ°μ˜ μ΄μœ λŠ” λͺ…ν™•ν•˜μ§€ μ•ŠμœΌλ‚˜, 데이터와 μ–Έμ–΄λͺ¨λΈμ˜ ν™•λ₯  뢄포 κ°„ 차이(KL-divergence)λ₯Ό μ€„μ΄κ³ μž ν•™μŠ΅ν•˜λŠ” language modeling objective κ°€ 원인이 될 수 있음. 즉, μ–Έμ–΄ λͺ¨λΈμ€ ν•™μŠ΅λœ μ½”νΌμŠ€μ—μ„œ μ‹€μ œ 사싀을 μΈμ‹ν•˜λŠ” λŒ€μ‹ μ— μ™ΈλΆ€ νŒ¨ν„΄μ„ μΈμ‹ν•˜λ©° ν•™μŠ΅λ¨.

  • λͺ¨λΈ 해석 관점

    (figure: lower layers encode surface-level linguistic information, upper layers encode semantic/factual information)

    이전 μ—°κ΅¬μ—μ„œ LM의 ν•˜μœ„ λ ˆμ΄μ–΄λŠ” lower-level information(은,λŠ”,이,κ°€,ν•œ,…)을, μƒμœ„ λ ˆμ΄μ–΄λŠ” semantic information(6.25μ „μŸ, λ°œμƒ, 연도,…)을 μΈμ½”λ”©ν•˜λŠ” κ²ƒμœΌλ‘œ λ‚˜νƒ€λ‚¨.

    λ˜ν•œ LM λ‚΄μ—μ„œ νŠΉμ • feed-forward layerλ₯Ό λ³€ν˜•ν•˜μ—¬ 사싀적 지식을 νŽΈμ§‘ν•  수 μžˆμŒμ„ λ³΄μ—¬μ€Œ.

DoLa

(figure: layer-by-layer next-token probabilities for the "capital of Washington" example, contrasting Seattle and Olympia)

μœ„ κ·Έλ¦Όμ—μ„œ LM은 λ§ˆμ§€λ§‰ layerμ—μ„œ 높은 ν™•λ₯ μ„ κ°€μ§€λŠ” Seattle을 output으둜 λ±‰κ²Œ 됨.

The true answer (Olympia) should become more likely toward the upper layers, but hallucination occurs because Seattle's probability was already high from the lower layers onward and remains the highest at the end.

→ Could we reduce hallucination by contrasting the probability distributions of different layers and correcting for this?

(figure: Jensen-Shannon divergence between the final layer and each earlier layer, per decoded token)

The figure above plots the JSD (Jensen-Shannon divergence) between the final layer and each earlier layer. The smaller the JSD, the closer that layer's distribution already is to the final layer's.
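
For reference, the JSD used here is the symmetric divergence between the final-layer distribution q_N and an earlier layer's distribution q_j (notation added here for clarity; it is the standard definition):

    \mathrm{JSD}(q_N \,\|\, q_j) = \tfrac{1}{2}\,\mathrm{KL}(q_N \,\|\, M) + \tfrac{1}{2}\,\mathrm{KL}(q_j \,\|\, M),
    \qquad M = \tfrac{1}{2}\,(q_N + q_j)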

  1. For important entities, the JSD remains high even in the upper layers. → Those layers still differ substantially from the final layer, which means the prediction can still change there.

  2. When predicting easy tokens, the JSD becomes very small from the middle layers on. → The model has already decided on the token to generate by the middle layers and keeps the output distribution almost unchanged afterward.

β†’ λ ˆμ΄μ–΄μ˜ JSDκ°€ κ°‘μžκΈ° λ³€ν•  λ•Œλ₯Ό λŒ€μ‘°ν•˜λ©΄(contrastive decoding), LM의 μ‹€μ œ 사싀을 μ¦ν­μ‹œμΌœ 사싀적 지식을 효과적으둜 μ‚¬μš©ν•˜κ³  잘λͺ»λœ 사싀을 μƒμ„±ν•˜λŠ” ν™˜κ° ν˜„μƒμ„ ν•΄μ†Œν•  수 μžˆμ„ κ²ƒμž„!

μ΄λŸ¬ν•œ μ ‘κ·Ό 방식은 좔가적인 1) 외뢀지식이 ν•„μš”ν•˜μ§€ μ•Šκ³  2) μ–΄λ– ν•œ 파인 νŠœλ‹λ„ ν•˜μ§€ μ•Šμ•„λ„ λœλ‹€λŠ” μž₯점을 가짐.

Method

(1) Compute the JSD between the transformer's final layer and each other layer, and find the layer whose distribution differs from it the most (the premature layer).
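
In symbols (a sketch of the selection rule; q_j(· | x_<t) is the next-token distribution obtained by early exit at layer j, q_N the final-layer distribution, and \mathcal{J} the set of candidate premature layers):

    M = \arg\max_{j \in \mathcal{J}} \; \mathrm{JSD}\big(q_N(\cdot \mid x_{<t}) \,\|\, q_j(\cdot \mid x_{<t})\big)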

(2) (1)μ—μ„œ 찾은 layer의 logitκ³Ό μ›λž˜ transformer λ§ˆμ§€λ§‰ layer의 logit을 λΉ„κ΅ν•˜μ—¬, μ΅œμ’… logit을 ꡐ정함.

  • Code

    from typing import Tuple

    import numpy as np
    import torch
    import torch.nn.functional as F


    def _relative_top_filter(
        scores: torch.FloatTensor,
        baseline_scores: torch.FloatTensor,
        relative_top: float = 0.1,
        filter_value: float = -float("Inf"),
        base_filter_value: float = -1e-3,
        min_tokens_to_keep: int = 1,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        # Adaptive plausibility constraint: keep only tokens whose final-layer
        # log-probability is within log(relative_top) of the most likely token.
        scores_normalized = scores.log_softmax(dim=-1)
        baseline_scores_normalized = baseline_scores.log_softmax(dim=-1)
        sorted_logits, _ = torch.sort(scores_normalized, descending=True)
        min_thresh = sorted_logits[..., min_tokens_to_keep - 1]  # guarantee at least min_tokens_to_keep survive
        probs_max = torch.max(scores_normalized, dim=-1).values  # log-prob of the most likely token
        probs_thresh = probs_max + np.log(relative_top)          # threshold: log(relative_top * max prob)
        probs_thresh = torch.min(min_thresh, probs_thresh)
        probs_thresh = probs_thresh.unsqueeze(-1)
        # Below the threshold: mask the mature logits to -inf and flatten the premature
        # logits to a small constant, so the subtraction below cannot revive those tokens.
        baseline_scores_normalized[scores_normalized < probs_thresh] = base_filter_value
        scores_normalized[scores_normalized < probs_thresh] = filter_value
        return scores_normalized, baseline_scores_normalized


    # 1. Stack all candidate premature-layer logits into a new dimension
    stacked_premature_layers = torch.stack(
        [candidate_premature_logits[i].to(final_logits) for i in candidate_premature_layers], dim=0
    )

    # 2. Softmax for the mature (final) layer and all premature layers
    softmax_mature_layer = F.softmax(final_logits, dim=-1).detach()  # shape: (batch_size, vocab_size)
    softmax_premature_layers = F.softmax(
        stacked_premature_layers, dim=-1
    )  # shape: (num_premature_layers, batch_size, vocab_size)

    # 3. M, the average distribution used by the JSD
    M = 0.5 * (
        softmax_mature_layer[None, :, :] + softmax_premature_layers
    )  # shape: (num_premature_layers, batch_size, vocab_size)

    # 4. Log-softmax inputs for F.kl_div (which expects log-probabilities as input)
    log_softmax_mature_layer = F.log_softmax(final_logits, dim=-1)  # shape: (batch_size, vocab_size)
    log_softmax_premature_layers = F.log_softmax(
        stacked_premature_layers, dim=-1
    )  # shape: (num_premature_layers, batch_size, vocab_size)

    # 5. KL divergences, then the JS divergences
    #    (.mean(-1) instead of .sum(-1) only rescales by 1/vocab_size; the argmax below is unaffected)
    kl1 = F.kl_div(log_softmax_mature_layer[None, :, :], M, reduction="none").mean(
        -1
    )  # shape: (num_premature_layers, batch_size)
    kl2 = F.kl_div(log_softmax_premature_layers, M, reduction="none").mean(
        -1
    )  # shape: (num_premature_layers, batch_size)
    js_divs = 0.5 * (kl1 + kl2)  # shape: (num_premature_layers, batch_size)

    # 6. Average over the batch, then pick the premature layer with the largest JSD
    js_divs = js_divs.mean(-1)  # shape: (num_premature_layers,)
    premature_layer = candidate_premature_layers[int(js_divs.argmax().cpu().item())]

    # 7. Contrast the two layers: filter implausible tokens, then subtract log-probabilities
    base_logits = candidate_premature_logits[premature_layer]
    final_logits, base_logits = _relative_top_filter(final_logits.detach(), base_logits.detach())
    logits = final_logits - base_logits.to(final_logits.device)
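
The snippet above assumes that final_logits, candidate_premature_layers, and candidate_premature_logits already exist. Below is a minimal sketch of one way to obtain them by early exit, i.e., re-using the model's final normalization and LM head on intermediate hidden states. The attribute names model.model.norm and model.lm_head are LLaMA-style and differ for other architectures, whether to apply the final norm before the head is an implementation choice, and the model name and the choice of candidate layers are arbitrary examples.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "huggyllama/llama-7b"  # example LLaMA-style model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float16, device_map="auto"
    ).eval()

    inputs = tokenizer("The capital of Washington State is", return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model(**inputs, output_hidden_states=True)

    # hidden_states[0] is the embedding output; hidden_states[i] is the output of block i,
    # and the last entry is the final hidden state (already passed through the final norm).
    hidden_states = out.hidden_states
    num_layers = len(hidden_states) - 1

    # The mature (final) layer logits are simply the model's regular output at the last position.
    final_logits = out.logits[:, -1, :]  # shape: (batch_size, vocab_size)

    def early_exit_logits(h):
        # Early exit ("logit lens" style): apply the final norm and the LM head to an
        # intermediate hidden state, keeping only the last position.
        return model.lm_head(model.model.norm(h))[:, -1, :]  # shape: (batch_size, vocab_size)

    candidate_premature_layers = list(range(2, num_layers, 2))  # e.g., even-numbered layers; the paper buckets candidates by depth

    with torch.no_grad():
        candidate_premature_logits = {
            j: early_exit_logits(hidden_states[j]) for j in candidate_premature_layers
        }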