炒冷饭,一道16年前的北美OUG SQL挑战赛回顾‐3 - l1t1/note GitHub Wiki

前面说过当N>300时,有的面值的概率已经超过了双精度浮点数的表示范围,只能显示0,怎么才能较精确地显示这种数据?首先想到用gmpy2库执行高精度计算,结果比较下来,mpfr虽然精度较高,但速度下降很大,不大满意。mpq分数在运算过程中保留精确的分子和分母,速度比较适中。

先让Deepseek改写我的递归CTE SQL,提示了好几遍,他明白了“SQL CTE每次迭代只能访问上次迭代的数据,考虑用二维列表存储,第一个维度是lvl,两个表交叉连接”,测试了mpq分数,告诉他“计算完全精确,用这个思路来重写递归程序,完成:第一步,计算不大于投掷次数的2的幂次投掷结果。第二步,筛选出bitand投掷次数不为0的幂次。最后一步,把有用的幂次逐个相乘。”同时让他把超过精度的数字正确显示成科学计数法。

import gmpy2
from collections import defaultdict

def exact_dice_probabilities(die_faces, num_throws):
    # 将输入概率转换为精确分数
    die_dist = {}
    for face, prob in die_faces.items():
        if isinstance(prob, (int, float)):
            prob = gmpy2.mpq(gmpy2.mpfr(prob))  # 浮点数转最接近的分数
        die_dist[gmpy2.mpz(face)] = prob
    
    # 计算所有不超过num_throws的2的幂次的结果
    power_levels = {}
    power_levels[1] = die_dist
    
    current_power = 1
    while current_power * 2 <= num_throws:
        next_power = current_power * 2
        new_level = defaultdict(gmpy2.mpq)
        for v1, p1 in power_levels[current_power].items():
            for v2, p2 in power_levels[current_power].items():
                new_level[v1 + v2] += p1 * p2
        power_levels[next_power] = dict(new_level)
        current_power = next_power
    
    # 筛选需要的幂次
    needed_powers = [p for p in power_levels if p & num_throws]
    needed_powers.sort()
    
    if not needed_powers:
        raise ValueError("无效的投掷次数")
    
    # 逐步合并结果
    result = power_levels[needed_powers[0]].copy()
    for power in needed_powers[1:]:
        new_result = defaultdict(gmpy2.mpq)
        for v1, p1 in result.items():
            for v2, p2 in power_levels[power].items():
                new_result[v1 + v2] += p1 * p2
        result = new_result
    
    # 转换为科学计数法输出
    output = []
    for v in sorted(result):
        frac = result[v]
        # 对数法科学计数
        try:
            log10_val = float(gmpy2.log10(frac))
            exponent = int(gmpy2.floor(log10_val))
            mantissa = 10 ** (log10_val - exponent)
            sci_notation = f"{mantissa:.4f}E{exponent}"
        except:
            sci_notation = "0.0000E+00"
        
        output.append({
            'face_value': int(v),
            'probability': sci_notation,
            'exact_fraction': f"{frac.numerator}/{frac.denominator}"
        })
    
    return output

# 使用示例
if __name__ == "__main__":
    die_faces = {
        1: gmpy2.mpq(1,4),
        3: gmpy2.mpq(1,4),
        4: gmpy2.mpq(1,4),
        5: gmpy2.mpq(1,12),
        6: gmpy2.mpq(1,12),
        8: gmpy2.mpq(1,12)
    }
    
    num_throws = 300
    print(f"=== 精确计算 {num_throws}次非均匀骰子 ===")
    results = exact_dice_probabilities(die_faces, num_throws)
    
    # 打印部分结果
    print("\n前5个结果:")
    for r in results[:5]:
        print(f"面值 {r['face_value']}: {r['probability']} (精确分数: {r['exact_fraction']})")
    
    print("\n... 中间省略 ...")
    
    print("\n后5个结果:")
    for r in results[-5:]:
        print(f"面值 {r['face_value']}: {r['probability']} (精确分数: {r['exact_fraction']})")

感觉中间保留了一些用不到的2的幂次,有些多余,再让他参考numpy快速幂写法改写,结果并没有提升效率。再把两重循环计算乘积改成pandas的groupby sum, 用itertools的product等,也都提升不大。

最后突然想到,既然pandas允许计算mpq数据, numpy作为它的底层依赖库也应当支持,一试果然支持,科学计数法沿用前面的即可。所以有以下解法。另外,经过比较,numpy计算双精度浮点数的结果与mpq在精度范围内基本上结果一致,因此当N<300时,直接用浮点计算就能保证精度。

import numpy as np
from numpy.polynomial import polynomial as P
from gmpy2 import mpq, log10, floor
import gmpy2

def exact_poly_pow_numpy(die_faces, power):
    # 创建多项式系数数组
    max_face = max(die_faces.keys())
    poly = np.zeros(max_face + 1, dtype=object)  # object类型存储mpq
    
    for face, prob in die_faces.items():
        poly[face] = mpq(prob) if not isinstance(prob, mpq) else prob
    
    # 计算幂次
    result_poly = P.polypow(poly, power)
    
    # 准备输出
    results = []
    for exp in range(len(result_poly)):
        coeff = result_poly[exp]
        if coeff != 0:
            # 安全转换科学计数法
            sci_notation = safe_scientific_notation(coeff)
            
            results.append({
                'face_value': exp,
                'probability': sci_notation,
                'exact_fraction': f"{coeff.numerator}/{coeff.denominator}"
            })
    
    return results

def safe_scientific_notation(coeff):
    """三层保护的科学计数法转换"""
    try:
        # 第一层:尝试直接浮点转换
        f_val = float(coeff)
        if f_val != 0:
            return f"{f_val:.4e}"
    except:
        pass
    
    try:
        # 第二层:对数法计算
        num = coeff.numerator
        den = coeff.denominator
        log_num = float(log10(abs(num)))
        log_den = float(log10(den))
        log10_val = log_num - log_den
        exponent = int(floor(log10_val))
        mantissa = 10 ** (log10_val - exponent)
        return f"{mantissa:.4f}E{exponent}"
    except:
        # 第三层:显示分数数量级
        den = coeff.denominator
        return f"< 1E{int(floor(float(log10(den))))}"

# 使用示例
if __name__ == "__main__":
    die_faces = {
        1: mpq(1,4),
        3: mpq(1,4),
        4: mpq(1,4),
        5: mpq(1,12),
        6: mpq(1,12),
        8: mpq(1,12)
    }
    
    results = exact_poly_pow_numpy(die_faces, 400)
    
    # 打印极端小值的正确显示
    for r in results[-3:]:
        print(f"面值 {r['face_value']}: {r['probability']} (精确: {r['exact_fraction']})")