利用DeepSeek4个步骤1秒内将根号2计算到小数点后百万位 - l1t1/note GitHub Wiki

首先是来自常数计算简易程序网站的一个c程序sqrt2series.c ,抄录如下:

/*
** Pascal Sebah : July 1999
** 
** Subject:
**
**    A very easy program to compute sqrt(2) with many digits.
**    No optimisations, no tricks, just a basic program to learn how
**    to compute in multiprecision.  
**
** Formula:
**
**    sqrt(2) = (239/169)*1/sqrt(1-1/57122)
**
** Data:
**
**    A big real (or multiprecision real) is defined in base B as:
**      X = x(0) + x(1)/B^1 + ... + x(n-1)/B^(n-1)
**      where 0<=x(i)<B
**
** Results: (PentiumII, 450Mhz)
**    
**    1000   decimals :   0.02seconds
**    10000  decimals :   1.7s
**    100000 decimals : 176.0s
**
** With a little work it's possible to reduce those computation
** times by a factor of 3 and more.
*/
#include <stdio.h>
#include <malloc.h>
long B=10000; /* Working base */
long LB=4;    /* Log10(base)  */
/*
** Set the big real x to the small integer Integer 
*/
void SetToInteger (long n, long *x, long Integer) {
  long i;
  for (i=1; i<n; i++) x[i] = 0;
  x[0] = Integer;
}
/*
** Is the big real x equal to zero ?
*/
long IsZero (long n, long *x) {
  long i;
  for (i=0; i<n; i++)  
    if (x[i])	return 0;
	return 1;
}
/*
** Addition of big reals : x += y
**  Like school addition with carry management
*/
void Add (long n, long *x, long *y) {
  long carry=0, i;
  for (i=n-1; i>=0; i--) {
    x[i] += y[i]+carry;
    if (x[i]<B) carry = 0;
    else {
      carry = 1;
      x[i] -= B;
    }
  }  
}
/*
** Multiplication of the big real x by the integer q 
*/
void Mul (long n, long *x, long q) {
  long carry = 0, xi, i;
  for (i=n-1; i>=0; i--) {
    xi  = x[i]*q;		
    xi += carry;		
    if (xi>=B) {
      carry = xi/B;
      xi -= (carry*B);
    }
    else 
      carry = 0;
    x[i] = xi;
	}  
}
/*
** Division of the big real x by the integer d 
**  Like school division with carry management
*/
void Div (long n, long *x, long d) {
  long carry=0, xi, q, i;
  for (i=0; i<n; i++) {
    xi    = x[i]+carry*B;
    q     = xi/d;
    carry = xi-q*d;   
    x[i]  = q;        
  }  
}
/*
** Print the big real x
*/
void Print (long n, long *x) {
  long i;
  printf ("%d.", x[0]);
  for (i=1; i<n; i++) {
    printf ("%.4d", x[i]);
    if (i%25==0) printf ("%8d\n", i*4);
  }
  printf ("\n");
}
/*
** Computation of the constant sqrt(2)
*/
void main () {
  long NbDigits=10000, size=1+NbDigits/LB;
  long *r2 = (long *)malloc(size*sizeof(long));
  long *uk = (long *)malloc(size*sizeof(long));
  long k=1;
  /*
  ** Formula used :
  **    sqrt(2) = (239/169)*1/sqrt(1-1/57122)
  ** and
  **   1/sqrt(1-x) = 1+(1/2)x+(1.3)/(2.4)x^2+(1.3.5)/(2.4.6)x^3+...
  */
  SetToInteger (size, r2, 1); /* r2 = 1 */
  SetToInteger (size, uk, 1);	/* uk = 1 */
  while (!IsZero(size, uk)) {
    Div (size, uk, 57122); /* uk = u(k-1)/57122 * (2k-1)/(2k) */
    Div (size, uk, 2*k);  
    Mul (size, uk, 2*k-1);  
    Add (size, r2, uk);    /* r2 = r2+uk */
    k++;
  }
  Mul (size, r2, 239);  
  Div (size, r2, 169);  /* r2 = (239/169)*r2 */
  Print (size, r2);     /* Print out of sqrt(2) */
  free (r2);
 	free (uk);
}

这里的sqrt(2) = (239/169)*1/sqrt(1-1/57122)公式来自http://numbers.computation.free.fr/Constants/Sqrt2/sqrt2.html 泰勒展开式的例8。

设m=239,n=169,因为mm=2nn-1,所以sqrt(2) = m/nsqrt(1-1/(2nn)),化简即可得。为什么要把简单的根号2化成复杂的sqrt(1-1/57122)?因为还有一条公式:

1/sqrt(1-x) = 1+(1/2)x+(1.3)/(2.4)x^2+(1.3.5)/(2.4.6)x^3+...,

设x=1/57122,就可以用级数来计算1/sqrt(1-1/57122),然后把结果乘上239/169就可以得到sqrt(2)。

作者选用57122应该是因为当时的c编译器long是4字节。太大的数难以在long型数组的一个元素中完成运算。

第一步:让Deepseek改写成利用Python原生的高精度整数

前期丢给他main函数直接让他改写,它用了gmpy2库的浮点类型mpfr,效率较低,反复改了几轮没有改善。然后丢给他完整的sqrt2.c(我将sqrt2series.c改名了,便于操作)提示他:

请不用gmpy2库,原封不动地翻译附件sqrt2.c为python程序,包括各个函数和全局变量,用传统字符串,不用f"",只输出代码,不用给出任何文字说明,

这次Deepseek给出的程序完全能够正确执行,效率还可以。

我反馈:翻译的程序和c程序输出完全一致。谢谢,我把BASE改为10^8,LB = 4也输出了结果,前10000位也对。时间缩短1半。考虑到python原生的整数就是高精度的,直接把BASE改为10^10012,就不需要数组了,一个变量就行了。请按这个思路改写。

于是有了如下的第一版,算1万位0.04秒,比python BASE10000数组版本快几十倍。算10万位是0.9秒,而c程序要4.6秒:

def compute_sqrt2_single_var(NbDigits=10000):
    # 直接使用Python原生大整数,BASE=10^NbDigits
    BASE = 10 ** NbDigits
    r2 = BASE  # 初始化为1 * BASE
    uk = BASE  # 初始化为1 * BASE
    k = 1
    
    while uk != 0:
        uk = uk // 57122
        uk = uk // (2*k)
        uk = uk * (2*k - 1)
        r2 += uk
        k += 1
    
    r2 = r2 * 239
    r2 = r2 // 169
    
    # 格式化为字符串
    result_str = str(r2)
    return result_str[0] + '.' + result_str[1:NbDigits+1]

if __name__ == "__main__":
    digits = 10000
    sqrt2 = compute_sqrt2_single_var(digits)
    
    # 输出前100位验证
    print("√2的前100位:")
    print(sqrt2[:102])

第二步:利用佩尔方程计算更大的m,n对

提示Deepseek:请编写找出满足m^2=2n^2-1的整数对的python程序,尽量高效(我还不知道这是佩尔方程),结果他给出了:

def find_pell_solutions(max_k=10):
    """
    找出满足 m^2 = 2n^2 - 1 的整数对 (m, n)
    使用佩尔方程的递推关系式生成解序列
    """
    solutions = []
    # 初始解 (m1, n1) = (1, 1)
    m, n = 1, 1
    solutions.append((m, n))
    
    # 递推关系式: 
    # m_{k+1} = 3*m_k + 4*n_k
    # n_{k+1} = 2*m_k + 3*n_k
    for _ in range(1, max_k):
        m_new = 3 * m + 4 * n
        n_new = 2 * m + 3 * n
        solutions.append((m_new, n_new))
        m, n = m_new, n_new
    
    return solutions

# 示例:生成前10个解
solutions = find_pell_solutions(10)
for idx, (m, n) in enumerate(solutions, 1):
    print(f"解 {idx}: m = {m}, n = {n}")
    print(f"验证: {m}^2 = {m**2}, 2*{n}^2 - 1 = {2*n**2 - 1}\n")

经过验证,更大的佩尔解确实能提高计算效率。一开始收益挺多,后面就小了,是因为python大数运算效率随着数字增大而下降。

原方法(m,n=239,169): 21024次迭代, 0.923s

=== 使用第2组佩尔解 (m=9369319, n=6625109) ===
新方法: 7173次迭代, 0.516s
=== 使用第3组佩尔解 (m=423859315570607, n=299713796309065) ===
新方法: 3419次迭代, 0.263s
=== 使用第4组佩尔解 (m=19175002942688032928599, n=13558774610046711780701) ===
新方法: 2245次迭代, 0.216s
=== 使用第5组佩尔解 (m=867459377074481256712011306719, n=613386407933224037990008001809) ===
新方法: 1671次迭代, 0.180s

第三步,改用gmpy2库进行大数计算

只要很少量的修改,我自己完成了,对代码做了微小的改动,比如计算uk改为一步完成,uk = uk * mpz(2*k - 1)//(mpz(d)*mpz(2*k)))

import sys
sys.set_int_max_str_digits(0) 
from gmpy2 import mpz
def compute_sqrt2_single_var(NbDigits=1000000):
    # 直接使用Python原生大整数,BASE=10^NbDigits
    #m, n = mpz(239), mpz(169 ) #36.920 s
    #m, n = mpz(423859315570607), mpz(299713796309065 ) #6.976 s
    m, n = mpz(867459377074481256712011306719), mpz(613386407933224037990008001809 ) #3.580 s
    d=mpz(2) * n * n
    BASE = mpz(mpz(10) ** NbDigits)
    r2 = BASE  # 初始化为1 * BASE
    uk = BASE  # 初始化为1 * BASE
    k = 1
    
    while uk != mpz(0):
        uk = uk * mpz(2*k - 1)//(mpz(d)*mpz(2*k))
        r2 += uk
        k += 1
    
    r2 = r2 * m // n
    
    # 格式化为字符串
    result_str = str(r2)
    return result_str[0] + '.' + result_str[1:NbDigits+1]

if __name__ == "__main__":
    digits = 1000012
    sqrt2 = compute_sqrt2_single_var(digits)
    
    # 输出前100位验证
    print("√2的前1000000位最后10位:")
    print(sqrt2[1000002-10:1000002])

第四步,让Deepseek把程序改为多进程。

起初,他用参数来给各进程传递uk,结果光预计算这些就耗费了完整计算sqrt(2)的时间。

提示他:每部分只计算本部分的 mpz(2k - 1) // (d * mpz(2k))之和,比如: 1/sqrt(1-x) = 1+(1/2)x+(13)/(24)x^2+(135)/(246)x^3+(1357)/(2468)x^4,第一部分算1+(1/2)x+(13)/(24)x^2不变,第二部分算(5)/(6)x^1和(57)/(68)x^2,然后乘以第一部分的最后一项就得到了真正的第二部分,每部分都返回总和以及最后一项,最终计算的时候从第2部分开始,要乘的项逐步增加。

但他好像没理解,还坚持用预计算的方法。

于是继续提示:

我对比了单进程与多进程结果的每一位,结果正确,但用时多进程比单进程多了2倍。
调试发现:
1.并行度算上了超线程,我的8核CPU,cpu_count()算出了16,num_workers改成8,速度基本不变。
2.单线程预计算确定总迭代次数用时和单进程计算全部结果用时一样,等于这步就抵消了优化,能否用数学方法推算。
3.预计算后续块的初始uk值,包括所有任务下一块的uk值加在一起也差不多和单进程计算全部结果用时一样
4.真正的并行计算确实比单进程节约了时间,
下面逐步解决这些问题,我们先解决第2点用数学方法推算总迭代次数。我给出一些特殊值,1万位:168,10万位:1671,50万位:8351,100万位:16702。大约是0.0168*位数,用它替换实际计算,快了1/3,结果不变,所以第2点问题解决。
请提供第3点的优化思路,我上次说的每部分返回最后一次迭代值和总和的思路能否实现?

又提示:

并行程序还存在两个问题,第一,经验值0.0168只适用于特定的n值,导致当n更大时迭代过多,比如n=mpz(5257949235243715217981999992354509006832052451463219203025345254761919215209 ),100万位精度,迭代6592次,uk就=0了,剩下的迭代都是空转。 第二,不均衡,以分为2部分为例,前3296 次迭代耗费70%的时间,因为uk的长度长,从100万位到50万位,后3296 次迭代只要30%时间,因为uk的长度短,从50万位到1位,请提供更合理的分配策略,只要文字说明,不写代码

还是没改好,只好基于他的版本自己动手改,其中动态计算迭代次数的公式int((log(10)*10**6+log(d))/(log(2*d)-1))是deepseek给出的。

把这里的10**6替换成变量p, 代表最终精度,移项可得p=int(k*(log10(2*d)-1)-log10(d), 于是有如下计算在p,d条件下,在第k次迭代的安全最小uk估计值的函数。def find_uk(p,k,d):return p-int(k*(log10(2*d)-1)-log10(d))

动态sub_uk数组是基于观察到,uk的长度随迭代次数增加逐步变短,最后uk=0,在单进程处理时,同样的迭代次数,如1000次,前面的用时大于后面。

在分块处理时,为了简化操作,给各块初始uk统一赋予了BASE,结果就是每块都和第一块一样慢。预计算能得到任何区间的sub_uk值,但需要双倍时间,更好的办法是通过数学公式计算,这就是上述函数的用武之地。

经过多次测试,得出100万位精度在本例d的每块迭代次数上限经验数据如下:[500, 1000, 1500, 2000,3000,4000,5000, 7000], 此时各块的运行时间大致相同。

完整程序如下:

import sys
from gmpy2 import mpz, log10
from multiprocessing import Pool, cpu_count, freeze_support
from math import log
import time
def compute_sqrt2_single_var(NbDigits=1000000):
    t=time.time()
    # 直接使用Python原生大整数,BASE=10^NbDigits
    #m, n = mpz(239), mpz(169 ) #36.920 s
    #m, n = mpz(423859315570607), mpz(299713796309065 ) #6.976 s
    #m, n = mpz(867459377074481256712011306719), mpz(613386407933224037990008001809 ) # 3.580 s
    m, n = mpz(7435863118750905263055986496384859433567398577182017960020667655210005041519), mpz(5257949235243715217981999992354509006832052451463219203025345254761919215209 ) #2.237
    d=mpz(2) * n * n
    BASE = mpz(mpz(10) ** NbDigits)
    r2 = BASE  # 初始化为1 * BASE
    uk = BASE  # 初始化为1 * BASE
    k = 1
    
    while uk != mpz(0):
        uk = uk * mpz(2*k - 1)//(mpz(d)*mpz(2*k))
        r2 += uk
        k += 1
    
    r2 = r2 * m // n
    
    # 格式化为字符串
    result_str = str(r2)
    print("单进程", time.time()-t)
    return result_str[0] + '.' + result_str[1:NbDigits+1]


def compute_segment(args):
    t=time.time()
    """计算一个区间的部分和"""
    start_k, end_k, initial_uk, d = args
    uk = initial_uk
    segment_sum = mpz(0)
    
    for k in range(start_k, end_k + 1):
        uk = uk * (2*k - 1) // (d * 2*k)
        segment_sum += uk
    print("区间",start_k, time.time()-t)
    return (start_k, segment_sum, uk)  # 返回部分和及最后的uk值

def parallel_sqrt2(NbDigits=1000000):
    def find_uk(p,k,d):return p-int(k*(log10(2*d)-1)-log10(d))
    
    t=time.time()
    #m = mpz(867459377074481256712011306719)
    #n = mpz(613386407933224037990008001809)
    m, n = mpz(7435863118750905263055986496384859433567398577182017960020667655210005041519), mpz(5257949235243715217981999992354509006832052451463219203025345254761919215209 ) #2.237

    d = mpz(2) * n * n
    BASE = mpz(10)**NbDigits
    
    # 公式计算确定总迭代次数
    uk = BASE

    total_iters =int((log(10)*10**6+log(d))/(log(2*d)-1)) #total_iters =int(0.0168*NbDigits)
    #print(total_iters, time.time()-t)

    # 分块计算
    num_workers = min(cpu_count(), 8)
    chunk_size = (total_iters + num_workers - 1) // num_workers
    tasks = []
    current_uk = BASE
    
    # 第一块特殊处理
    start_k = 1
    end_k = chunk_size
    end_ks=[500, 1000, 1500, 2000,3000,4000,5000, 7000] #人工对8块分配上限


    print(end_ks)

    # 创建剩余任务

    p=10**6
    sub_uk=[0]
    for i in range(len((end_ks))):
        #print(i, find_uk(p,end_ks[i],d))
        sub_uk.append(mpz(10**find_uk(p,end_ks[i],d))) #添加公式计算的uk

    end_k =end_ks[0] #第一块的上限
    tasks.append((start_k, end_k, current_uk, d))    
    for i in range(1, num_workers):
        start_k = end_k + 1
        #end_k = min(start_k + chunk_size - 1, total_iters)
        end_k = min(end_ks[i], total_iters) #每块的上限
        tasks.append((start_k, end_k,sub_uk[i] , d))


    # 并行计算
    with Pool(num_workers) as pool:
        results = pool.map(compute_segment, tasks)
    

    # 合并结果
    #r2 = BASE + sum(seg_sum for seg_sum, _ in results)
    r2=BASE
    cum_end=mpz(1)
    results.sort(key=lambda x:x[0])
    k=0
    for start_k, seg_sum, end_uk in results:
        if cum_end==mpz(1):
            r2=r2+seg_sum*cum_end
            cum_end=cum_end*end_uk
        else:
            r2=r2+seg_sum*cum_end//sub_uk[k]
            cum_end=cum_end*end_uk//sub_uk[k]
        k=k+1
    # 最终处理

    final_result = str(r2 * m // n)
    print("多进程并行", time.time()-t)
    return final_result[0] + '.' + final_result[1:NbDigits+1]

if __name__ == "__main__":
    #sys.set_int_max_str_digits(0)
    freeze_support()
    digits = 1000012
    sqrt2 = parallel_sqrt2(digits)
    print("最后10位:", sqrt2[:52], sqrt2[1000002-10:1000002])
    new=sqrt2
    old= compute_sqrt2_single_var(NbDigits=digits)
            # 比较结果
    diff_pos = None
    for i in range(min(len(old), len(new))):
        if old[i] != new[i]:
            diff_pos = i;print(i, old[i] , new[i] )
            break
    

上述程序在鲲鹏920 8核CPU上能在1.8秒内算出一百万位, 在AMD 8845HS上的测试数据如下, digits = 1000012余量是为了保证前100万位正确。这里输出的最后十位9048412043,经与 http://www.numberworld.org/digits/Sqrt(2)/ 给出的正确值核对无误。

C:\d>python rt2.py
[500, 1000, 1500, 2000, 3000, 4000, 5000, 7000]
区间 5001 0.11553835868835449
区间 4001 0.2580108642578125
区间 1501 0.3613595962524414
区间 1001 0.41032886505126953
区间 3001 0.41078948974609375
区间 501 0.4555642604827881
区间 1 0.47553300857543945
区间 2001 0.5027995109558105
多进程并行 1.3588900566101074
最后10位: 1.41421356237309504880168872420969807856967187537694 9048412043
单进程 1.9617667198181152
1000009 0 1
⚠️ **GitHub.com Fallback** ⚠️