SchonhageStrassen乘法实现

2024-06-01 02:49:10 发布

您现在位置:Python中文网/ 问答频道 /正文

我试图用NTT实现Schonhage-Strassen乘法算法,但遇到了一个问题,最终得到的向量实际上并不等于它应该是什么。你知道吗

对于两个输入向量ab,每个输入向量由K位的N“数字”(每个集合的最终N/2项为0)组成,每个输入向量给定一个模M = 2^(2*K)+1、一个单位根w = N^(4*K-1) | w^N = 1 mod M、该值的模逆wi | wi*w = 1 mod Mu | u*N = 1 mod M,以下python代码用于(尝试)使用Schonhage-Strassen算法将这些向量相乘:

#a and b are lists of length N, representing large integers
A = [ sum([ (a[i]*pow(w,i*j,M))%M for i in range(N)]) for j in range(N)] #NTT of a
B = [ sum([ (b[i]*pow(w,i*j,M))%M for i in range(N)]) for j in range(N)] #NTT of b
C = [ (A[i]*B[i])%M for i in range(N)] #A * B multiplied pointwise
c = [ sum([ (C[i]*pow(wi,i*j,M))%M for i in range(N)]) for j in range(N)] #intermediate step in INTT of C
ci = [ (i*u)%M for i in c] #INTT of C, should be product of a and b

理论上,取ab的NTT,逐点相乘,然后取结果的INTT,如果我没有弄错的话,应该得到乘积,我已经测试了NTT和INTT的这些方法,以确认它们是彼此相反的。然而,最终得到的向量ci,不是等于ab的乘积,而是每个元素取模M的乘积,给出了不正确的乘积结果。你知道吗

例如,使用N=K=8a, b的随机向量运行测试,得到以下结果:

M = 2^(2*8)+1 = 65537
w = 16, wi = 61441
u = 57345
a = [212, 251, 84, 186, 0, 0, 0, 0] (3126131668 as an integer)
b = [180, 27, 234, 225, 0, 0, 0, 0] (3790216116)
NTT(a) = [733, 66681, 147842, 92262, 130933, 107825, 114562, 127302]
NTT(b) = [666, 64598, 80332, 54468, 131236, 186644, 181708, 88232]
Pointwise product of above two lines mod M = [29419, 39913, 25015, 14993, 42695, 49488, 52438, 51319]
INTT of above line (i.e. result) = [38160, 50904, 5968, 11108, 15616, 62424, 41850, 0] (11848430946168040720)
Actual product of a x b = [38160, 50904, 71505, 142182, 81153, 62424, 41850, 0] (11848714628791561488)

在这个例子中,几乎每次我尝试它时,实际乘积的元素和我的算法的结果在向量的开始和结束附近的几个元素是相同的,但是在向量的中间它们会偏离。如上所述,ci的元素都等于a*bM的元素。我一定是误解了这个算法,虽然我不完全确定是什么。我用错模数了吗?你知道吗


Tags: ofin算法cimod元素forrange
1条回答
网友
1楼 · 发布于 2024-06-01 02:49:10

谨防数量理论,<强> NTT <强>不是我的专业领域,所以用偏见来阅读,但我自己成功地在<强> C++ +<强>中实现<强> NTT <强>,并将其用于Bigimm乘法(^ {CD1>},^ {< CD2>},^ {CD3>}),因此这里有一些我的研究。我强烈建议你先阅读我的两篇QA:

所以你可以把你的结果/代码/常量和我的比较。然而,我发展了我的NTT来使用单一的硬编码素数(适合32位值的最大单位根)。你知道吗

现在你的代码有什么问题。我不使用python编写代码,但在您的问题中没有看到NTT代码。不管怎样,从我看到的情况来看:

  1. 检查你的根或团结

    在你的问题中你提到了条件:

    w^N = 1 mod M
    

    但这远远不够。请参阅上面的第一个链接,它描述了必须满足的所有条件(使用查找和检查它的代码)。我不确定您的参数是否符合所有需要的条件,您只是忘记或遗漏了这些条件,或者没有所以请检查它。IIRC我也与这些条件作斗争,因为在我编码的时候,我掌握的NTT信息非常少,大多数信息不完整或错误。。。

  2. 您没有使用模块化算法!!!

    我假设你的素数是M(在我的术语中是p),所以所有的子结果都必须小于M,这在你的例子中显然不是真的:

    M = 65537
    NTT(a) = [733, 66681, 147842, 92262, 130933, 107825, 114562, 127302]
    NTT(b) = [666, 64598, 80332, 54468, 131236, 186644, 181708, 88232]
    

    正如您所看到的,只有两个ntt的第一个元素是有效的,所有其他元素都大于M这是错误的!!!

  3. 小心溢出

    你的M非常小~16bit相比之下,你的输入值看起来~8bit会很快溢出,从而使你的NTT结果失效。你知道吗

    这里引用了我的第二个链接,我找到了一个艰难而经验主义的方法:

    To avoid overflows for big datasets, limit input numbers to p/4 bits. Where p is number of bits per NTT element so for this 32 bit version use max (32 bit/4 -> 8 bit) input values.

    因此,在您的情况下,您应该处理16/4 = 4bit块而不是8位,或者使用更大的M,例如我的0xC0000001,它是~32bit。你知道吗

    这就解释了你的观察:产品的第一个要素是好的,然后不是。。。要知道,如果你把2个8位的数字相乘,你会得到16位。。。现在意识到您正在对乘法子结果进行更多的递归加法,因此它很快就会在第二个值中超过16位M。。。

总之,您没有使用模块化算术,素数太小和/或处理的数据块太大,也可能选择了错误的素数。你知道吗

相关问题 更多 >