利用DeepSeek4个步骤1秒内将根号2计算到小数点后百万位 - l1t1/note GitHub Wiki
首先是来自常数计算简易程序网站的一个c程序sqrt2series.c ,抄录如下:
/*
** Pascal Sebah : July 1999
**
** Subject:
**
** A very easy program to compute sqrt(2) with many digits.
** No optimisations, no tricks, just a basic program to learn how
** to compute in multiprecision.
**
** Formula:
**
** sqrt(2) = (239/169)*1/sqrt(1-1/57122)
**
** Data:
**
** A big real (or multiprecision real) is defined in base B as:
** X = x(0) + x(1)/B^1 + ... + x(n-1)/B^(n-1)
** where 0<=x(i)<B
**
** Results: (PentiumII, 450Mhz)
**
** 1000 decimals : 0.02seconds
** 10000 decimals : 1.7s
** 100000 decimals : 176.0s
**
** With a little work it's possible to reduce those computation
** times by a factor of 3 and more.
*/
#include <stdio.h>
#include <malloc.h>
long B=10000; /* Working base */
long LB=4; /* Log10(base) */
/*
** Set the big real x to the small integer Integer
*/
void SetToInteger (long n, long *x, long Integer) {
long i;
for (i=1; i<n; i++) x[i] = 0;
x[0] = Integer;
}
/*
** Is the big real x equal to zero ?
*/
long IsZero (long n, long *x) {
long i;
for (i=0; i<n; i++)
if (x[i]) return 0;
return 1;
}
/*
** Addition of big reals : x += y
** Like school addition with carry management
*/
void Add (long n, long *x, long *y) {
long carry=0, i;
for (i=n-1; i>=0; i--) {
x[i] += y[i]+carry;
if (x[i]<B) carry = 0;
else {
carry = 1;
x[i] -= B;
}
}
}
/*
** Multiplication of the big real x by the integer q
*/
void Mul (long n, long *x, long q) {
long carry = 0, xi, i;
for (i=n-1; i>=0; i--) {
xi = x[i]*q;
xi += carry;
if (xi>=B) {
carry = xi/B;
xi -= (carry*B);
}
else
carry = 0;
x[i] = xi;
}
}
/*
** Division of the big real x by the integer d
** Like school division with carry management
*/
void Div (long n, long *x, long d) {
long carry=0, xi, q, i;
for (i=0; i<n; i++) {
xi = x[i]+carry*B;
q = xi/d;
carry = xi-q*d;
x[i] = q;
}
}
/*
** Print the big real x
*/
void Print (long n, long *x) {
long i;
printf ("%d.", x[0]);
for (i=1; i<n; i++) {
printf ("%.4d", x[i]);
if (i%25==0) printf ("%8d\n", i*4);
}
printf ("\n");
}
/*
** Computation of the constant sqrt(2)
*/
void main () {
long NbDigits=10000, size=1+NbDigits/LB;
long *r2 = (long *)malloc(size*sizeof(long));
long *uk = (long *)malloc(size*sizeof(long));
long k=1;
/*
** Formula used :
** sqrt(2) = (239/169)*1/sqrt(1-1/57122)
** and
** 1/sqrt(1-x) = 1+(1/2)x+(1.3)/(2.4)x^2+(1.3.5)/(2.4.6)x^3+...
*/
SetToInteger (size, r2, 1); /* r2 = 1 */
SetToInteger (size, uk, 1); /* uk = 1 */
while (!IsZero(size, uk)) {
Div (size, uk, 57122); /* uk = u(k-1)/57122 * (2k-1)/(2k) */
Div (size, uk, 2*k);
Mul (size, uk, 2*k-1);
Add (size, r2, uk); /* r2 = r2+uk */
k++;
}
Mul (size, r2, 239);
Div (size, r2, 169); /* r2 = (239/169)*r2 */
Print (size, r2); /* Print out of sqrt(2) */
free (r2);
free (uk);
}
这里的sqrt(2) = (239/169)*1/sqrt(1-1/57122)公式来自http://numbers.computation.free.fr/Constants/Sqrt2/sqrt2.html 泰勒展开式的例8。
设m=239,n=169,因为mm=2nn-1,所以sqrt(2) = m/nsqrt(1-1/(2nn)),化简即可得。为什么要把简单的根号2化成复杂的sqrt(1-1/57122)?因为还有一条公式:
1/sqrt(1-x) = 1+(1/2)x+(1.3)/(2.4)x^2+(1.3.5)/(2.4.6)x^3+...,
设x=1/57122,就可以用级数来计算1/sqrt(1-1/57122),然后把结果乘上239/169就可以得到sqrt(2)。
作者选用57122应该是因为当时的c编译器long是4字节。太大的数难以在long型数组的一个元素中完成运算。
前期丢给他main函数直接让他改写,它用了gmpy2库的浮点类型mpfr,效率较低,反复改了几轮没有改善。然后丢给他完整的sqrt2.c(我将sqrt2series.c改名了,便于操作)提示他:
请不用gmpy2库,原封不动地翻译附件sqrt2.c为python程序,包括各个函数和全局变量,用传统字符串,不用f"",只输出代码,不用给出任何文字说明,
这次Deepseek给出的程序完全能够正确执行,效率还可以。
我反馈:翻译的程序和c程序输出完全一致。谢谢,我把BASE改为10^8,LB = 4也输出了结果,前10000位也对。时间缩短1半。考虑到python原生的整数就是高精度的,直接把BASE改为10^10012,就不需要数组了,一个变量就行了。请按这个思路改写。
于是有了如下的第一版,算1万位0.04秒,比python BASE10000数组版本快几十倍。算10万位是0.9秒,而c程序要4.6秒:
def compute_sqrt2_single_var(NbDigits=10000):
# 直接使用Python原生大整数,BASE=10^NbDigits
BASE = 10 ** NbDigits
r2 = BASE # 初始化为1 * BASE
uk = BASE # 初始化为1 * BASE
k = 1
while uk != 0:
uk = uk // 57122
uk = uk // (2*k)
uk = uk * (2*k - 1)
r2 += uk
k += 1
r2 = r2 * 239
r2 = r2 // 169
# 格式化为字符串
result_str = str(r2)
return result_str[0] + '.' + result_str[1:NbDigits+1]
if __name__ == "__main__":
digits = 10000
sqrt2 = compute_sqrt2_single_var(digits)
# 输出前100位验证
print("√2的前100位:")
print(sqrt2[:102])
提示Deepseek:请编写找出满足m^2=2n^2-1的整数对的python程序,尽量高效(我还不知道这是佩尔方程),结果他给出了:
def find_pell_solutions(max_k=10):
"""
找出满足 m^2 = 2n^2 - 1 的整数对 (m, n)
使用佩尔方程的递推关系式生成解序列
"""
solutions = []
# 初始解 (m1, n1) = (1, 1)
m, n = 1, 1
solutions.append((m, n))
# 递推关系式:
# m_{k+1} = 3*m_k + 4*n_k
# n_{k+1} = 2*m_k + 3*n_k
for _ in range(1, max_k):
m_new = 3 * m + 4 * n
n_new = 2 * m + 3 * n
solutions.append((m_new, n_new))
m, n = m_new, n_new
return solutions
# 示例:生成前10个解
solutions = find_pell_solutions(10)
for idx, (m, n) in enumerate(solutions, 1):
print(f"解 {idx}: m = {m}, n = {n}")
print(f"验证: {m}^2 = {m**2}, 2*{n}^2 - 1 = {2*n**2 - 1}\n")
经过验证,更大的佩尔解确实能提高计算效率。一开始收益挺多,后面就小了,是因为python大数运算效率随着数字增大而下降。
原方法(m,n=239,169): 21024次迭代, 0.923s
=== 使用第2组佩尔解 (m=9369319, n=6625109) ===
新方法: 7173次迭代, 0.516s
=== 使用第3组佩尔解 (m=423859315570607, n=299713796309065) ===
新方法: 3419次迭代, 0.263s
=== 使用第4组佩尔解 (m=19175002942688032928599, n=13558774610046711780701) ===
新方法: 2245次迭代, 0.216s
=== 使用第5组佩尔解 (m=867459377074481256712011306719, n=613386407933224037990008001809) ===
新方法: 1671次迭代, 0.180s
只要很少量的修改,我自己完成了,对代码做了微小的改动,比如计算uk改为一步完成,uk = uk * mpz(2*k - 1)//(mpz(d)*mpz(2*k)))
。
import sys
sys.set_int_max_str_digits(0)
from gmpy2 import mpz
def compute_sqrt2_single_var(NbDigits=1000000):
# 直接使用Python原生大整数,BASE=10^NbDigits
#m, n = mpz(239), mpz(169 ) #36.920 s
#m, n = mpz(423859315570607), mpz(299713796309065 ) #6.976 s
m, n = mpz(867459377074481256712011306719), mpz(613386407933224037990008001809 ) #3.580 s
d=mpz(2) * n * n
BASE = mpz(mpz(10) ** NbDigits)
r2 = BASE # 初始化为1 * BASE
uk = BASE # 初始化为1 * BASE
k = 1
while uk != mpz(0):
uk = uk * mpz(2*k - 1)//(mpz(d)*mpz(2*k))
r2 += uk
k += 1
r2 = r2 * m // n
# 格式化为字符串
result_str = str(r2)
return result_str[0] + '.' + result_str[1:NbDigits+1]
if __name__ == "__main__":
digits = 1000012
sqrt2 = compute_sqrt2_single_var(digits)
# 输出前100位验证
print("√2的前1000000位最后10位:")
print(sqrt2[1000002-10:1000002])
起初,他用参数来给各进程传递uk,结果光预计算这些就耗费了完整计算sqrt(2)的时间。
提示他:每部分只计算本部分的 mpz(2k - 1) // (d * mpz(2k))之和,比如: 1/sqrt(1-x) = 1+(1/2)x+(13)/(24)x^2+(135)/(246)x^3+(1357)/(2468)x^4,第一部分算1+(1/2)x+(13)/(24)x^2不变,第二部分算(5)/(6)x^1和(57)/(68)x^2,然后乘以第一部分的最后一项就得到了真正的第二部分,每部分都返回总和以及最后一项,最终计算的时候从第2部分开始,要乘的项逐步增加。
但他好像没理解,还坚持用预计算的方法。
于是继续提示:
我对比了单进程与多进程结果的每一位,结果正确,但用时多进程比单进程多了2倍。
调试发现:
1.并行度算上了超线程,我的8核CPU,cpu_count()算出了16,num_workers改成8,速度基本不变。
2.单线程预计算确定总迭代次数用时和单进程计算全部结果用时一样,等于这步就抵消了优化,能否用数学方法推算。
3.预计算后续块的初始uk值,包括所有任务下一块的uk值加在一起也差不多和单进程计算全部结果用时一样
4.真正的并行计算确实比单进程节约了时间,
下面逐步解决这些问题,我们先解决第2点用数学方法推算总迭代次数。我给出一些特殊值,1万位:168,10万位:1671,50万位:8351,100万位:16702。大约是0.0168*位数,用它替换实际计算,快了1/3,结果不变,所以第2点问题解决。
请提供第3点的优化思路,我上次说的每部分返回最后一次迭代值和总和的思路能否实现?
又提示:
并行程序还存在两个问题,第一,经验值0.0168只适用于特定的n值,导致当n更大时迭代过多,比如n=mpz(5257949235243715217981999992354509006832052451463219203025345254761919215209 ),100万位精度,迭代6592次,uk就=0了,剩下的迭代都是空转。 第二,不均衡,以分为2部分为例,前3296 次迭代耗费70%的时间,因为uk的长度长,从100万位到50万位,后3296 次迭代只要30%时间,因为uk的长度短,从50万位到1位,请提供更合理的分配策略,只要文字说明,不写代码
还是没改好,只好基于他的版本自己动手改,其中动态计算迭代次数的公式int((log(10)*10**6+log(d))/(log(2*d)-1))
是deepseek给出的。
把这里的10**6替换成变量p, 代表最终精度,移项可得p=int(k*(log10(2*d)-1)-log10(d), 于是有如下计算在p,d条件下,在第k次迭代的安全最小uk估计值的函数。def find_uk(p,k,d):return p-int(k*(log10(2*d)-1)-log10(d))
。
动态sub_uk数组是基于观察到,uk的长度随迭代次数增加逐步变短,最后uk=0,在单进程处理时,同样的迭代次数,如1000次,前面的用时大于后面。
在分块处理时,为了简化操作,给各块初始uk统一赋予了BASE,结果就是每块都和第一块一样慢。预计算能得到任何区间的sub_uk值,但需要双倍时间,更好的办法是通过数学公式计算,这就是上述函数的用武之地。
经过多次测试,得出100万位精度在本例d的每块迭代次数上限经验数据如下:[500, 1000, 1500, 2000,3000,4000,5000, 7000]
, 此时各块的运行时间大致相同。
完整程序如下:
import sys
from gmpy2 import mpz, log10
from multiprocessing import Pool, cpu_count, freeze_support
from math import log
import time
def compute_sqrt2_single_var(NbDigits=1000000):
t=time.time()
# 直接使用Python原生大整数,BASE=10^NbDigits
#m, n = mpz(239), mpz(169 ) #36.920 s
#m, n = mpz(423859315570607), mpz(299713796309065 ) #6.976 s
#m, n = mpz(867459377074481256712011306719), mpz(613386407933224037990008001809 ) # 3.580 s
m, n = mpz(7435863118750905263055986496384859433567398577182017960020667655210005041519), mpz(5257949235243715217981999992354509006832052451463219203025345254761919215209 ) #2.237
d=mpz(2) * n * n
BASE = mpz(mpz(10) ** NbDigits)
r2 = BASE # 初始化为1 * BASE
uk = BASE # 初始化为1 * BASE
k = 1
while uk != mpz(0):
uk = uk * mpz(2*k - 1)//(mpz(d)*mpz(2*k))
r2 += uk
k += 1
r2 = r2 * m // n
# 格式化为字符串
result_str = str(r2)
print("单进程", time.time()-t)
return result_str[0] + '.' + result_str[1:NbDigits+1]
def compute_segment(args):
t=time.time()
"""计算一个区间的部分和"""
start_k, end_k, initial_uk, d = args
uk = initial_uk
segment_sum = mpz(0)
for k in range(start_k, end_k + 1):
uk = uk * (2*k - 1) // (d * 2*k)
segment_sum += uk
print("区间",start_k, time.time()-t)
return (start_k, segment_sum, uk) # 返回部分和及最后的uk值
def parallel_sqrt2(NbDigits=1000000):
def find_uk(p,k,d):return p-int(k*(log10(2*d)-1)-log10(d))
t=time.time()
#m = mpz(867459377074481256712011306719)
#n = mpz(613386407933224037990008001809)
m, n = mpz(7435863118750905263055986496384859433567398577182017960020667655210005041519), mpz(5257949235243715217981999992354509006832052451463219203025345254761919215209 ) #2.237
d = mpz(2) * n * n
BASE = mpz(10)**NbDigits
# 公式计算确定总迭代次数
uk = BASE
total_iters =int((log(10)*10**6+log(d))/(log(2*d)-1)) #total_iters =int(0.0168*NbDigits)
#print(total_iters, time.time()-t)
# 分块计算
num_workers = min(cpu_count(), 8)
chunk_size = (total_iters + num_workers - 1) // num_workers
tasks = []
current_uk = BASE
# 第一块特殊处理
start_k = 1
end_k = chunk_size
end_ks=[500, 1000, 1500, 2000,3000,4000,5000, 7000] #人工对8块分配上限
print(end_ks)
# 创建剩余任务
p=10**6
sub_uk=[0]
for i in range(len((end_ks))):
#print(i, find_uk(p,end_ks[i],d))
sub_uk.append(mpz(10**find_uk(p,end_ks[i],d))) #添加公式计算的uk
end_k =end_ks[0] #第一块的上限
tasks.append((start_k, end_k, current_uk, d))
for i in range(1, num_workers):
start_k = end_k + 1
#end_k = min(start_k + chunk_size - 1, total_iters)
end_k = min(end_ks[i], total_iters) #每块的上限
tasks.append((start_k, end_k,sub_uk[i] , d))
# 并行计算
with Pool(num_workers) as pool:
results = pool.map(compute_segment, tasks)
# 合并结果
#r2 = BASE + sum(seg_sum for seg_sum, _ in results)
r2=BASE
cum_end=mpz(1)
results.sort(key=lambda x:x[0])
k=0
for start_k, seg_sum, end_uk in results:
if cum_end==mpz(1):
r2=r2+seg_sum*cum_end
cum_end=cum_end*end_uk
else:
r2=r2+seg_sum*cum_end//sub_uk[k]
cum_end=cum_end*end_uk//sub_uk[k]
k=k+1
# 最终处理
final_result = str(r2 * m // n)
print("多进程并行", time.time()-t)
return final_result[0] + '.' + final_result[1:NbDigits+1]
if __name__ == "__main__":
#sys.set_int_max_str_digits(0)
freeze_support()
digits = 1000012
sqrt2 = parallel_sqrt2(digits)
print("最后10位:", sqrt2[:52], sqrt2[1000002-10:1000002])
new=sqrt2
old= compute_sqrt2_single_var(NbDigits=digits)
# 比较结果
diff_pos = None
for i in range(min(len(old), len(new))):
if old[i] != new[i]:
diff_pos = i;print(i, old[i] , new[i] )
break
上述程序在鲲鹏920 8核CPU上能在1.8秒内算出一百万位, 在AMD 8845HS上的测试数据如下, digits = 1000012余量是为了保证前100万位正确。这里输出的最后十位9048412043,经与 http://www.numberworld.org/digits/Sqrt(2)/ 给出的正确值核对无误。
C:\d>python rt2.py
[500, 1000, 1500, 2000, 3000, 4000, 5000, 7000]
区间 5001 0.11553835868835449
区间 4001 0.2580108642578125
区间 1501 0.3613595962524414
区间 1001 0.41032886505126953
区间 3001 0.41078948974609375
区间 501 0.4555642604827881
区间 1 0.47553300857543945
区间 2001 0.5027995109558105
多进程并行 1.3588900566101074
最后10位: 1.41421356237309504880168872420969807856967187537694 9048412043
单进程 1.9617667198181152
1000009 0 1