LLM MLX RAG - eiichiromomma/CVMLAB GitHub Wiki

(LLM) MLX-LM でlocalなRAG

Open WebUIとかAnything LLMとか良さげなツールはあるのだが,元データのウンコっぷりをどうにかしないと,Garbage in, garbage outになる.結局サーバー動作させるよりもスクリプトで起動してChatにした方が良さそう.

  • 全角スペースでの均等割り付け調整
  • 全角括弧などの記号
  • 改行連発での空間調整
  • ページ跨ぎ対策

への対策をしたうえでチャンクを上手く食わせるかが問題なので,クリーンなドキュメントにしてからデカめのチャンクで重複も多めにして上手く立ち回るようになった.

モデルについては日本語のドキュメント中心ということもありQwenが優秀.

  • mlx-community/Qwen2.5-14B-Instruct-4bit: 24GBのMacBook Airだとしんどい.時間もかかる
  • mlx-community/Qwen2.5-7B-Instruct-4bit: 良いバランス
  • mlx-community/Qwen2.5-3B-Instruct-4bit: 速いけどアホ.簡体字も混ざり始める
import os
import re
import mlx.core as mx
from mlx_lm import load
from mlx_lm.generate import generate_step
from mlx_lm.sample_utils import make_sampler
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings, StorageContext, load_index_from_storage, Document
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.core.node_parser import SentenceWindowNodeParser
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.retrievers.bm25 import BM25Retriever
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# モデルの準備
# model_path = "mlx-community/Qwen2.5-14B-Instruct-4bit"
model_path = "mlx-community/Qwen2.5-7B-Instruct-4bit"
# model_path = "mlx-community/Qwen2.5-3B-Instruct-4bit" # アホ

print(f"Loading model: {model_path} ...")
model, tokenizer = load(model_path)
N_RANK = 10

# LlamaIndex & Reranker 設定
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-m3")
Settings.llm = None 
Settings.chunk_size = 1024
Settings.chunk_overlap = 512
node_parser = SentenceWindowNodeParser.from_defaults(
    window_size=5,  # 前後5文を含めてページ跨ぎも繋げる
    window_metadata_key="window",
    original_text_metadata_key="original_text",
)
Settings.node_parser = node_parser
RERANK_MODEL_NAME = "BAAI/bge-reranker-base"

print(f"Loading Reranker: {RERANK_MODEL_NAME} ...")
rerank_tokenizer = AutoTokenizer.from_pretrained(RERANK_MODEL_NAME)
rerank_model = AutoModelForSequenceClassification.from_pretrained(RERANK_MODEL_NAME)
rerank_model.eval() # 推論モード
# GPU (Metal) が使えるなら使う
if torch.backends.mps.is_available():
    rerank_model.to("mps")
    print("Reranker is using MPS (Metal).")

def clean_japanese_text(text):
    # 全角スペース(\u3000) を削除
    text = text.replace('\u3000', '')
    # 連続する空白・タブを1つのスペースにまとめる
    text = re.sub(r'[ \t]+', ' ', text)
    # 連続する改行(空行)を削除して詰める
    text = re.sub(r'\n\s*\n', '\n', text)
    # 行頭・行末の空白削除
    text = "\n".join([line.strip() for line in text.split('\n')])
    return text

def rerank_with_bge(query, nodes, top_n=5):
    if not nodes:
        return []
    # モデルに入力するペアを作成: [[クエリ, 文章A], [クエリ, 文章B], ...]
    pairs = [[query, node.get_content()] for node in nodes]
    
    with torch.no_grad():
        inputs = rerank_tokenizer(
            pairs, 
            padding=True, 
            truncation=True, 
            return_tensors='pt', 
            max_length=512
        )
        # デバイス転送
        if torch.backends.mps.is_available():
            inputs = {k: v.to("mps") for k, v in inputs.items()}
        # スコア算出
        scores = rerank_model(**inputs, return_dict=True).logits.view(-1,).float()
        
    # スコアが高い順にソートして、上位 top_n 個の (node, score) を返す
    # ※ item() でPythonのfloatに戻す
    scores_list = scores.cpu().numpy().tolist()
    ranked_nodes = sorted(zip(nodes, scores_list), key=lambda x: x[1], reverse=True)
    
    return ranked_nodes[:top_n]

def run_rag_generation(model, tokenizer, prompt_tokens):
    sampler = make_sampler(temp=0.4, top_p=0.9, min_p=0.05)
    generator = generate_step(
        prompt=prompt_tokens,
        model=model,
        sampler=sampler,
        # kv_bits=4,
        max_tokens=2000
    )

    all_tokens = []
    displayed_text = ""
    
    print("\n回答生成中...", flush=True)

    for i, (token, logits) in enumerate(generator):
        token_id = token.item() if hasattr(token, "item") else token
        if token_id == tokenizer.eos_token_id:
            break

        all_tokens.append(token_id)
        current_text = tokenizer.decode(all_tokens)

        # --- 停止制御 ---
        if "</answer>" in current_text:
            # 最後まで表示してから終了
            if len(current_text) > len(displayed_text):
                print(current_text[len(displayed_text):], end="", flush=True)
            print("\n\n[システム: 生成完了]")
            break

        # --- 文字化け防止待機ロジック ---
        clean_text = current_text
        if clean_text.endswith('\ufffd'):
            continue # 文字が完成するまで待つ

        # --- 差分表示 ---
        if len(clean_text) > len(displayed_text):
            new_text = clean_text[len(displayed_text):]
            print(new_text, end="", flush=True)
            displayed_text = clean_text

        # --- 無限ループ保険 ---
        if current_text.count("<answer>") >= 2:
            print("\n[システム: ループ検知により強制終了]")
            break

    return current_text

def run_rag():
    # データディレクトリのチェック
    if not os.path.exists("./data"):
        print("Error: './data' ディレクトリが見つかりません。PDFを入れてください。")
        return
    PERSIST_DIR = "./storage"
    print("Creating new index from './data'...")
    # ファイル名をメタデータとして埋め込む関数
    def get_meta(file_path):
        return {"file_name": os.path.basename(file_path)}
    raw_documents = SimpleDirectoryReader(
        "./data", 
        file_metadata=get_meta
    ).load_data()
    # テキストクリーニング & オブジェクト再作成
    # doc.text を直接書き換えるとエラーになるため新しい箱(Document)に入れ替え
    print(f"Cleaning {len(raw_documents)} document pages...")
    cleaned_documents = [] # ここに綺麗なドキュメントを詰める

    for doc in raw_documents:
        # ファイル名を先頭に埋め込むヘッダー
        file_name = doc.metadata.get('file_name', '不明')
        file_header = f"【文書名: {file_name}\n"
        # 本文のクリーニング
        cleaned_body = clean_japanese_text(doc.text)
        # 結合
        final_text = file_header + cleaned_body
        # 新しい Document オブジェクトとして作成
        new_doc = Document(
            text=final_text,
            metadata=doc.metadata # メタデータ(ページ番号など)は引き継ぐ
        )
        cleaned_documents.append(new_doc)
    # インデックスのロードまたは作成
    if not os.path.exists(PERSIST_DIR):
        # インデックス作成(綺麗なドキュメントを使う)
        print("Building Vector Index with BGE-M3...")
        index = VectorStoreIndex.from_documents(cleaned_documents)
        index.storage_context.persist(persist_dir=PERSIST_DIR)        
    else:
        print("Loading existing index from storage...")
        storage_context = StorageContext.from_defaults(persist_dir=PERSIST_DIR)
        index = load_index_from_storage(storage_context)    

    # ベクトル検索器 (意味検索)
    # ここで N_RANK件 拾う(取りこぼしを防ぐため)
    vector_retriever = index.as_retriever(similarity_top_k=N_RANK)
    # キーワード検索器 (BM25)
    # 固有名称や数値に強い。ここでも N_RANK件 拾う
    bm25_retriever = BM25Retriever.from_defaults(
        nodes=cleaned_documents, 
        similarity_top_k=N_RANK
    )
    all_nodes = list(index.docstore.docs.values())
    bm25_retriever = BM25Retriever.from_defaults(
        nodes=all_nodes, 
        similarity_top_k=N_RANK
    )
    # 両方の結果を混ぜて、いいとこ取りをする
    retriever = QueryFusionRetriever(
        [vector_retriever, bm25_retriever],
        similarity_top_k=N_RANK,   # 合計で上位N_RANK件を次の Rerank に回す
        num_queries=1,         # クエリ生成数
        mode="reciprocal_rerank", # 順位を公平に混ぜるアルゴリズム
        use_async=False
    )
    
    print("\n--- RAG Ready (Hybrid Search ) ---")
    print("Type 'exit' to quit.")
    while True:
        try:
            question = input("\n質問: ")
            if question.lower() in ["exit", "quit"]: break
            if not question.strip(): continue
                
            # ベクトル検索
            initial_nodes = retriever.retrieve(question)
            # リランク (BGE-Reranker)
            top_nodes_with_score = rerank_with_bge(question, initial_nodes, top_n=N_RANK)
            # コンテキスト構築
            context_list = []
            for node, score in top_nodes_with_score:
                file_name = node.metadata.get('file_name', '不明')
                page = node.metadata.get('page_label', '')
                # デバッグ表示
                print(f"  [Score: {score:.4f}] {file_name} : {node.get_content()[:30]}...")
                source_info = f"【ソース: {file_name} {f'(p.{page})' if page else ''}】"
                context_list.append(f"{source_info}\n{node.get_content()}")
            context_text = "\n\n".join(context_list)
            # メッセージ作成
            messages = [
                {"role": "user", "content": f"""以下の提供された【参考資料】に基づいて、ユーザーの質問に回答してください。

【制約事項】
1. 回答の中に情報の根拠となるソース名(例: 『資料その1.pdf』など)を明記してください。
2. まず <thinking> タグ内で思考プロセスを整理し、その後に <answer> タグ内で最終回答を記述してください。
3. </answer> を出力したら即座に終了してください。

【参考資料】
{context_text}

【質問】
{question}
"""}
            ]

            # テンプレート適用 & トークナイズ
            try:
                input_ids = tokenizer.apply_chat_template(
                    messages, 
                    tokenize=True, 
                    add_generation_prompt=True,
                    return_tensors="np" 
                )
            except Exception as e:
                # テンプレートがないモデルへのフォールバック
                print(f"Template Error: {e}, falling back to raw text.")
                text_prompt = f"Context:\n{context_text}\n\nQuestion:\n{question}\n\nAnswer:"
                input_ids = tokenizer.encode(text_prompt)

            # MLX配列への変換
            prompt_tokens = mx.array(input_ids)

            # 形状チェック: (1, N) のような2次元配列で来たら (N,) に平坦化する
            if len(prompt_tokens.shape) > 1:
                prompt_tokens = prompt_tokens.flatten()

            print(f"  -> Generating with {prompt_tokens.size} tokens...")

            # 生成実行
            run_rag_generation(model, tokenizer, prompt_tokens)

        except KeyboardInterrupt:
            print("\nInterrupted.")
            break
        except Exception as e:
            import traceback
            traceback.print_exc()
            print(f"\nError: {e}")

if __name__ == "__main__":
    run_rag()
⚠️ **GitHub.com Fallback** ⚠️