Guiding DeepSeek to implement an efficient high-precision algorithm for e, the base of the natural logarithm (part 2) - l1t1/note GitHub Wiki
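Last time, with DeepSeek's help, I implemented a program for computing e that produces 100,000 digits in 1 second. However, one million digits take 90 seconds, which suggests roughly O(n^2) time complexity. After going through the optimization process for computing π, I wondered whether binary splitting and the gmpy2 library could speed it up further.
Simply telling DeepSeek to use binary splitting did not work. The program it produced was completely wrong, and I was not able to fix it myself.
So I took a detour and designed a block-wise algorithm myself. I split the number of iterations n needed for `digits` decimal places into two halves and had DeepSeek hard-code the computation of the first and second blocks. I then compared the combined result against the single-pass computation.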
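The two-block split rests on the identity below. The notation is introduced here for illustration: N_1, D_1, N_2, D_2 correspond to numerator1, denominator1, numerator2, denominator2 in the program further down. m is the total number of terms and n = ⌊m/2⌋.

```latex
e \approx \sum_{j=0}^{m} \frac{1}{j!}
  = \frac{N_1}{D_1} + \frac{N_2}{D_1 D_2}
  = \frac{N_1 D_2 + N_2}{D_1 D_2},
\qquad
D_1 = n!,\quad N_1 = \sum_{j=0}^{n} \frac{n!}{j!},\quad
D_2 = \frac{m!}{n!},\quad N_2 = \sum_{j=n+1}^{m} \frac{m!}{j!}
```

Since D_1 D_2 = m!, the second fraction is exactly the sum of 1/j! for j = n+1, ..., m. Every N and D is an integer, so everything stays in integer arithmetic until the final division.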
There were some twists here too. The logic for the second block differs from the first, and DeepSeek's reasoning about it was wrong:
current_coeff = 1
numerator2 = current_coeff
for k in range(1, m+1-n):  # DeepSeek wrote range(n+1, m+1)
    current_coeff *= (m - k + 1)
    numerator2 += current_coeff
numerator2 -= current_coeff  # DeepSeek did not subtract the extra term added in the last iteration
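The merge formula was also wrong.
# DeepSeek wrote
total = (numerator1 * denominator2 + numerator2 * denominator1) * power // (denominator1 * denominator2)
It should be as follows. numerator2 is already expressed over the full denominator denominator1 * denominator2, so it must not be multiplied by denominator1 again:
total = (numerator1 * denominator2 + numerator2) * power // (denominator1 * denominator2)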
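As a quick check of both fixes, the sketch below reruns the corrected loops for a small m using exact fractions. It compares the corrected merge and DeepSeek's merge against the direct sum of 1/j!. The helper names `two_blocks`, `merge_correct` and `merge_deepseek` are hypothetical, chosen for illustration only.

```python
from fractions import Fraction
from math import factorial


def two_blocks(m):
    """Return (n, numerator1, denominator1, numerator2, denominator2) computed
    with the corrected loops from the article, where n = m // 2."""
    n = m // 2
    # first block: coefficients 1, n, n(n-1), ..., n!
    coeff = num1 = 1
    for k in range(1, n + 1):
        coeff *= n - k + 1
        num1 += coeff
    den1 = coeff
    # second block: coefficients 1, m, m(m-1), ..., m!/n!
    coeff = num2 = 1
    for k in range(1, m + 1 - n):
        coeff *= m - k + 1
        num2 += coeff
    num2 -= coeff  # m!/n! is the j = n term, already counted in the first block
    den2 = coeff
    return n, num1, den1, num2, den2


def merge_correct(num1, den1, num2, den2):
    return Fraction(num1 * den2 + num2, den1 * den2)


def merge_deepseek(num1, den1, num2, den2):
    # multiplies numerator2 by denominator1 one time too many
    return Fraction(num1 * den2 + num2 * den1, den1 * den2)


if __name__ == "__main__":
    m = 20
    n, num1, den1, num2, den2 = two_blocks(m)
    exact = sum(Fraction(1, factorial(j)) for j in range(m + 1))
    print(merge_correct(num1, den1, num2, den2) == exact)   # True
    print(merge_deepseek(num1, den1, num2, den2) == exact)  # False
```

Exact fractions make the comparison unambiguous. The real program only needs integers, because the final division by the combined denominator is done once.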
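Fortunately, both steps are simple enough. DeepSeek set up the framework, I did the patching, and the result is a correct program, shown below: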
import math
import time
import sys
sys.set_int_max_str_digits(0)
def find_required_terms(digits):
    """Compute the minimum n such that n! > 10^digits"""
    ln_10 = math.log(10)
    n_low = 0
    n_high = 2 * (digits + 1)  # upper-bound estimate
    # solve by binary search
    while n_low < n_high:
        mid = (n_low + n_high) // 2
        if mid <= 1:
            stirling = 0
        else:
            stirling = mid * math.log(mid) - mid + 0.5 * math.log(2 * math.pi * mid)
        if stirling > digits * ln_10:
            n_high = mid
        else:
            n_low = mid + 1
    return n_low
def calculate_e_optimized(digits):
    n = find_required_terms(digits)
    # accumulate with a single running coefficient
    current_coeff = 1
    numerator = current_coeff
    for k in range(1, n+1):
        current_coeff *= (n - k + 1)
        numerator += current_coeff
    denominator = current_coeff
    power = 10**digits
    print(str(numerator)[:10], str(denominator)[:10])
    total = (numerator * power) // denominator
    return f"2.{str(total % power).zfill(digits)}"  # 100,000 digits took 0.927s
    #return f"2.{str(total)[1:digits+1]}"  # 100,000 digits took 0.925s
def calculate_e_two_parts(digits):
    m = find_required_terms(digits)
    n = m//2
    # first block: accumulate with a single running coefficient
    current_coeff = 1
    numerator1 = current_coeff
    for k in range(1, n+1):
        current_coeff *= (n - k + 1)
        numerator1 += current_coeff
    denominator1 = current_coeff
    # second block
    current_coeff = 1
    numerator2 = current_coeff
    for k in range(1, m+1-n):
        current_coeff *= (m - k + 1)
        numerator2 += current_coeff
    numerator2 -= current_coeff  # remove the term added once too often
    denominator2 = current_coeff
    power = 10**digits
    print(str(numerator1*denominator2+numerator2)[:10], str(denominator1*denominator2)[:10])
    total = (numerator1*denominator2+numerator2) * power // (denominator1*denominator2)
    return f"2.{str(total % power).zfill(digits)}"  # 100,000 digits took 0.927s
    #return f"2.{str(total)[1:digits+1]}"  # 100,000 digits took 0.925s
# 1,000,000-digit test
digits = 1000000
start = time.time()
e_value = calculate_e_optimized(digits)
elapsed = time.time() - start
print(f"Time for {digits} digits: {elapsed:.3f}s")
print("Terms required:", find_required_terms(digits))
print("First 50 digits:", e_value[:52])
start = time.time()
e_two = calculate_e_two_parts(digits)
elapsed = time.time() - start
print(f"Time for {digits} digits: {elapsed:.3f}s")
print("Terms required:", find_required_terms(digits))
print("First 50 digits:", e_two[:52])
#print(e_value[40000:49999] , e_two[40000:49999])
for i in range(len(e_value)):
    if e_value[i] != e_two[i]:
        print(i, "differs")
        break
print(e_value)
print()
print(e_two)
print(f"Results match: {e_value == e_two}")
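The program above takes about 20 seconds for one million digits. I expected that having DeepSeek switch it to the gmpy2 library would go smoothly, since that is essentially a mechanical translation. Unexpectedly, it again made a "clever" change of its own and broke the function that computes the iteration count n. I compared the code line by line by eye and fixed it quickly. I then gave the corrected version back to it and asked for a three-block version along the same lines. This time everything went very smoothly, and it understood all of the fixes above correctly. The merge step was written perfectly: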
# Merge the three partial results
total = (numerator1 * denominator2 * denominator3 +
         numerator2 * denominator3 +
         numerator3) * power // (denominator1 * denominator2 * denominator3)
return f"2.{str(total % power).zfill(digits)}"
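Next I asked it to generalize to an arbitrary number of blocks, and its habit of showing off came back. Once again it broke the iteration-count function. It only needed to copy the function over unchanged, but instead it wrote the following, whose logic is entirely wrong.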
Correct version:
    if stirling > digits * ln_10:
        n_high = mid
    else:
        n_low = mid + 1
return n_low
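Wrong version. In the else branch, the walrus expression assigns mid + 1 to both n_low and n_high. The search therefore stops at the first step that moves right and returns an n that is generally too small:
    n_high = mid if stirling > digits * ln_10 else (n_low := mid + 1)
return n_high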
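The sketch below runs both versions side by side to make the failure concrete. It uses plain `math` instead of gmpy2, and the function names `terms_correct` and `terms_walrus` are hypothetical.

```python
import math


def ln_fact_stirling(n):
    """Stirling approximation of ln(n!), returning 0 for n <= 1 as the article does."""
    if n <= 1:
        return 0.0
    return n * math.log(n) - n + 0.5 * math.log(2 * math.pi * n)


def terms_correct(digits):
    target = digits * math.log(10)
    n_low, n_high = 0, 2 * (digits + 1)
    while n_low < n_high:
        mid = (n_low + n_high) // 2
        if ln_fact_stirling(mid) > target:
            n_high = mid       # mid is large enough, keep it as a candidate
        else:
            n_low = mid + 1    # mid is too small
    return n_low


def terms_walrus(digits):
    target = digits * math.log(10)
    n_low, n_high = 0, 2 * (digits + 1)
    while n_low < n_high:
        mid = (n_low + n_high) // 2
        # the one-liner: in the else branch BOTH n_low and n_high become mid + 1
        n_high = mid if ln_fact_stirling(mid) > target else (n_low := mid + 1)
    return n_high


if __name__ == "__main__":
    print(terms_correct(1000), terms_walrus(1000))  # 450 251
```

For 1,000 digits, 450 terms are needed, since 450! is the first factorial above 10^1000. The one-liner stops at 251, which would silently produce far too few correct digits.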
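The merge loop was wrong too. It seemed fixated on cross-multiplying num and den.
It wrote:
# Merge all parts
combined_num = parts[0][0]
combined_den = parts[0][1]
for num, den in parts[1:]:
    combined_num = combined_num * den + num * combined_den
    combined_den *= den
total = (combined_num * power) // combined_den
It should be as follows. Each block's numerator must be divided by the product of the denominators of that block and all blocks before it, not by its own denominator alone. Iterating in reverse multiplies each numerator by the denominators of the blocks after it, which achieves exactly that:
combined_num = 0
combined_den = 1
for num, den in reversed(parts):
    combined_num += num * combined_den
    combined_den *= den
total = (combined_num * power) // combined_den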
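The following sketch checks the n-block merge with exact fractions. It builds the blocks the same way as calculate_e_n_parts below, but with built-in ints instead of mpz, and compares both merge loops with the direct sum. The helper names are hypothetical.

```python
from fractions import Fraction
from math import factorial


def make_parts(m, num_parts):
    """(numerator, denominator) per block, following calculate_e_n_parts."""
    part_size = m // num_parts
    parts = []
    for i in range(num_parts):
        start = i * part_size
        end = (i + 1) * part_size if i < num_parts - 1 else m
        coeff = numerator = 1
        for k in range(1, end - start + 1):
            coeff *= end - k + 1
            numerator += coeff
        if i > 0:
            # end!/start! is the j = start term, already the previous block's last term
            numerator -= coeff
        parts.append((numerator, coeff))
    return parts


def merge_reversed(parts):
    """Corrected merge: block i's numerator ends up over den_0 * ... * den_i."""
    combined_num, combined_den = 0, 1
    for num, den in reversed(parts):
        combined_num += num * combined_den
        combined_den *= den
    return Fraction(combined_num, combined_den)


def merge_cross(parts):
    """DeepSeek's merge: adds num_i / den_i as if the blocks were independent."""
    combined_num, combined_den = parts[0]
    for num, den in parts[1:]:
        combined_num = combined_num * den + num * combined_den
        combined_den *= den
    return Fraction(combined_num, combined_den)


if __name__ == "__main__":
    m = 30
    exact = sum(Fraction(1, factorial(j)) for j in range(m + 1))
    for k in range(1, 6):
        parts = make_parts(m, k)
        # k = 1 prints True True; every k >= 2 prints True False
        print(k, merge_reversed(parts) == exact, merge_cross(parts) == exact)
```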
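In any case, these errors were easy to fix. The final program and its output are below. Splitting into blocks gives a clear speedup, though with diminishing returns: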
import math
import time
import sys
import gmpy2
from gmpy2 import mpz, log, const_pi
#sys.set_int_max_str_digits(0)
# Set gmpy2 precision
#gmpy2.get_context().precision = 1000000  # set sufficient precision
def find_required_terms(digits):
    """Compute the minimum n such that n! > 10^digits"""
    ln_10 = log(mpz(10))
    n_low, n_high = 0, 2 * (digits + 1)
    while n_low < n_high:
        mid = (n_low + n_high) // 2
        if mid <= 1:
            stirling = mpz(0)
        else:
            mid_mpz = mpz(mid)
            stirling = mid_mpz * log(mid_mpz) - mid_mpz + log(2 * const_pi() * mid_mpz) / 2
        if stirling > digits * ln_10:
            n_high = mid
        else:
            n_low = mid + 1
        #n_high = mid if stirling > digits * ln_10 else (n_low := mid + 1)
    return n_low  #high
def calculate_e_single(digits):
    """Single-block computation"""
    n = find_required_terms(digits)
    power = mpz(10)**digits
    coeff = numerator = mpz(1)
    for k in range(1, n+1):
        coeff *= (n - k + 1)
        numerator += coeff
    total = (numerator * power) // coeff
    return "2."+str(total % power)  #f"2.{str((numerator * power) // coeff % power).zfill(digits)}"
def calculate_e_n_parts(digits, num_parts=2):
    """Computation split into num_parts blocks"""
    m = find_required_terms(digits)
    power = mpz(10)**digits
    part_size = m // num_parts
    # compute the numerator and denominator of each block
    parts = []
    for i in range(num_parts):
        start = i * part_size
        end = (i+1)*part_size if i < num_parts-1 else m
        coeff = numerator = mpz(1)
        for k in range(1, end-start+1):
            coeff *= (end - k + 1)
            numerator += coeff
        # the last term was added once too often; subtract it
        if i > 0:
            numerator -= coeff
        parts.append((numerator, coeff))
    # merge all parts
    combined_num = 0
    combined_den = 1
    for num, den in reversed(parts):
        combined_num += num * combined_den
        combined_den *= den
    total = (combined_num * power) // combined_den
    return "2."+str(total % power)  #f"2.{str(total % power).zfill(digits)}"
# test configuration
digits = 1000000
test_cases = [
    ("Single part", 1),
    ("Two parts", 2),
    ("Three parts", 3),
    ("Four parts", 4),
    ("Five parts", 5)
]
print("Approximating e to", digits, "digits (optimized with gmpy2)\n")
# benchmark
reference = None
for name, parts in test_cases:
    start = time.time()
    result = calculate_e_n_parts(digits, parts) if parts > 1 else calculate_e_single(digits)
    elapsed = time.time() - start
    print(name, "time:", elapsed)
    #print(f"Terms required: {find_required_terms(digits)}")
    print("First 50 digits:", result[:52])
    # verify that the results agree
    if reference is None:
        reference = result
    else:
        mismatch = next((i for i in range(len(reference)) if reference[i] != result[i]), None)
        # parentheses needed: without them the prefix is dropped when there is a mismatch
        print("Verification: " + ('consistent' if mismatch is None else 'mismatch at position ' + str(mismatch)))
    print()
Approximating e to 1000000 digits (optimized with gmpy2)
Single part time: 16.216398000717163
First 50 digits: 2.71828182845904523536028747135266249775724709369995
Two parts time: 6.306462526321411
First 50 digits: 2.71828182845904523536028747135266249775724709369995
Verification: consistent
Three parts time: 4.106442451477051
First 50 digits: 2.71828182845904523536028747135266249775724709369995
Verification: consistent
Four parts time: 3.1275036334991455
First 50 digits: 2.71828182845904523536028747135266249775724709369995
Verification: consistent
Five parts time: 2.5681347846984863
First 50 digits: 2.71828182845904523536028747135266249775724709369995
Verification: consistent
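With four blocks, this program is already faster than e.c, the C program written in 1999 by the master Xavier Gourdon. His program uses binary splitting and implements its own big-number arithmetic and FFT.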
gcc e.c FFT.c BigInt.c -o e -O3 -lm
./e
*** E computation ***
Enter the number of decimal digits : 100
Total Allocated memory = 9 K
Starting series computation
Starting final division
Total time : 0.00 seconds
Worst error in FFT (should be less than 0.25): 0.0000001192
E = 2
7182818284 5904523536 0287471352 6624977572 4709369995 : 50
9574966967 6277240766 3035354759 4571382178 5251664274 : 100
time ./e <1m.txt >e1m.txt
real 0m3.770s
user 0m3.752s
sys 0m0.000s
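For reference, binary splitting for this series can be written as the recursion below. This is one common formulation, the same one implemented by calculate_e_terms in the final program further down. Whether e.c uses exactly this form is an assumption not checked here.

```latex
\begin{aligned}
\frac{P(l,r)}{Q(l,r)} &= \sum_{k=l}^{r} \frac{1}{l\,(l+1)\cdots k}, && P(l,l) = 1,\; Q(l,l) = l,\\
P(l,r) &= P(l,s)\,Q(s+1,r) + P(s+1,r), && s = \lfloor (l+r)/2 \rfloor,\\
Q(l,r) &= Q(l,s)\,Q(s+1,r),\\
e &\approx 2 + \frac{P(2,n)}{Q(2,n)} = 2 + \sum_{k=2}^{n} \frac{1}{k!}.
\end{aligned}
```

The usual motivation is that both halves at each level have operands of similar size. This keeps the big multiplications balanced, which is where fast multiplication (FFT-based, or GMP's) pays off.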
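I also tried changing the final merge from add-then-divide to divide-then-add, to reduce the size of the big-number operations involving numerator1. Surprisingly, it brought no improvement at all.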
# Merge the two partial results
#total = (numerator1 * denominator2 + numerator2) * power // (denominator1 * denominator2)  # add, then divide
total = numerator1 * power // denominator1 + numerator2 * power // (denominator1 * denominator2)  # divide, then add
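Besides not being faster, divide-then-add is not exactly equivalent to add-then-divide. Each `//` floors separately, and the sum of two floors can be one less than the floor of the sum, so the last computed digit may occasionally differ. The minimal sketch below uses hypothetical toy values. Its parameter names mirror numerator1, denominator1 and so on in the snippet above.

```python
def add_then_divide(num1, den1, num2, den2, power):
    # one floor division of the exact combined fraction
    return (num1 * den2 + num2) * power // (den1 * den2)


def divide_then_add(num1, den1, num2, den2, power):
    # two separate floor divisions; each one drops its own fractional part
    return num1 * power // den1 + num2 * power // (den1 * den2)


if __name__ == "__main__":
    # hypothetical toy values: 1/2 + 1/2 with power = 1
    print(add_then_divide(1, 2, 1, 1, 1))  # 1
    print(divide_then_add(1, 2, 1, 1, 1))  # 0
```

In general the two results differ by 0 or 1, so divide-then-add can at most lose one unit in the last digit.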
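While waiting for GitHub to become reachable, I came across a C++ program through the blog post 《π和e的无穷级数计算》 ("Infinite-series computation of π and e"). It is very concise and also fast, finishing one million digits in about half a second, and it also uses binary splitting. DeepSeek's first Python translation was actually much slower, taking 13 seconds. On analysis, the final division turned out to use
return f"{mpfr(2 + mpq(p, q)):.{actual_digits}f}"
and mpfr is slow in Python. After some testing, I changed it to integer arithmetic,
return "2."+str(p*mpz(10)**target_digits // q)
which made it 10 times faster. As expected, translating an existing program directly, plus a little knowledge that floating-point arithmetic is slow, was the easiest way to get a good optimization. The final Python program is as follows.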
import math
import time
from gmpy2 import mpz, mpq, log2, log10, lgamma, get_context, fac
def calculate_e_terms(l, r):
    # binary splitting: returns (P, Q) with P/Q = sum_{k=l}^{r} 1/(l*(l+1)*...*k)
    if l == r:
        return mpz(1), mpz(l)
    mid = (l + r) // 2
    p1, q1 = calculate_e_terms(l, mid)
    p2, q2 = calculate_e_terms(mid + 1, r)
    return p1 * q2 + p2, q1 * q2
def calculate_e_fast(n, target_digits):
    """Optimized version: produce the fractional part with pure integer arithmetic"""
    # compute the required binary precision
    # (only affects mpfr operations; the integer-only path below does not depend on it)
    ln_n_fact = float(n * math.log(n) - n + 0.5*math.log(2*math.pi*n))
    prec_bits = int((ln_n_fact + math.log(n)) / math.log(2)) + 128
    get_context().precision = prec_bits
    p, q = calculate_e_terms(2, n)
    power = mpz(10)**target_digits
    decimal_part = (p * power) // q  # key optimization
    return "2."+str(decimal_part)
def find_required_terms(digits):
    """Compute the minimum n (using Stirling's formula)"""
    ln10 = math.log(10)
    n = 1
    while True:
        ln_n_fact = n * math.log(n) - n + 0.5*math.log(2*math.pi*n)
        if (ln_n_fact - n * ln10) > digits * ln10:
            return n
        n += 1
if __name__ == "__main__":
    target_digits = 1000000
    n = find_required_terms(target_digits)
    start = time.time()
    e_str = calculate_e_fast(n, target_digits)
    elapsed = time.time() - start
    print("Number of terms:", n)
    print("Time:", elapsed)
    print(e_str)
Execution time:
time python3 gmp_e4.py > gmp_e4_1m.txt
real 0m1.131s
user 0m1.040s
sys 0m0.008s
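In the end, the Python version reaches a speed comparable to the C++ program: about 1.1 s here, versus roughly half a second for the C++ program.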
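If gmpy2 is not available, the same binary-splitting recursion also works with Python's built-in int. The sketch below assumes its own term count, based on math.lgamma with a 10-digit guard. This is not the article's find_required_terms, and the names e_terms and e_digits are hypothetical. At a million digits, expect it to be far slower than the gmpy2 version, because CPython's big-int multiplication and division are much slower than GMP's. It is meant for checking results, not for speed.

```python
import math
import sys

# Python 3.11+ limits int -> str conversion to 4300 digits by default; lift it
if hasattr(sys, "set_int_max_str_digits"):
    sys.set_int_max_str_digits(0)


def e_terms(l, r):
    """Binary splitting: return (P, Q) with P/Q = sum_{k=l}^{r} 1/(l*(l+1)*...*k)."""
    if l == r:
        return 1, l
    mid = (l + r) // 2
    p1, q1 = e_terms(l, mid)
    p2, q2 = e_terms(mid + 1, r)
    return p1 * q2 + p2, q1 * q2


def e_digits(digits, guard=10):
    """Return e as '2.' followed by `digits` decimals, using only built-in ints."""
    # assumed term count: smallest n with ln(n!) > (digits + guard) * ln(10),
    # where math.lgamma(n + 1) == ln(n!)
    target = (digits + guard) * math.log(10)
    n = 2
    while math.lgamma(n + 1) <= target:
        n += 1
    p, q = e_terms(2, n)              # p/q = sum_{k=2}^{n} 1/k!
    frac = p * 10**digits // q        # integer division, no floating point
    return "2." + str(frac).zfill(digits)


if __name__ == "__main__":
    print(e_digits(50))  # 2.71828182845904523536028747135266249775724709369995
```

The result truncates rather than rounds, which matches the programs above.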