1 // Fast integer multiplication using FFT in a modular ring.
2 // Bruno Haible 14.5.,16.5.1996
4 // FFT in the complex domain has the drawback that it needs careful round-off
5 // error analysis. So here we choose another field of characteristic 0: Q_p.
6 // Since Q_p contains exactly the (p-1)th roots of unity, we choose
7 // p == 1 mod N and have the Nth roots of unity (N = 2^n) in Q_p and
8 // even in Z_p. Actually, we compute in Z/(p^m Z).
10 // All operations the FFT algorithm needs are addition, subtraction,
11 // multiplication, multiplication by the Nth root of unity and division
12 // by N. Hence we can use the domain Z/(p^m Z) even if p is not a prime!
14 // We use the Schönhage-Strassen choice of the modulus: p = 2^R+1. This
15 // has two big advantages: Multiplication and division by 2 (which is a
16 // (2R)th root of unity) or a power of 2 is just a shift and a subtraction.
17 // And multiplication mod p is just a normal multiplication, followed by a subtraction.
19 // In order to exploit the (2R)th root of unity for FFT, we choose R = 2^r,
20 // and do an FFT of size M with M = 2^m and M | 2R.
22 // Say we want to compute the product of two integers with N1 and N2 bits,
23 // respectively. We choose N >= N1+N2 and K, R, M with
24 // ceiling(N1/K)+ceiling(N2/K)-1 <= M (i.e. roughly N <= K*M),
25 // 2*K+ceiling(log2(M)) <= R,
26 // R = 2^r, M = 2^m, M | 2R.
27 // We then split each of the factors into M chunks of K bits each, and do
28 // an FFT mod p = 2^R+1. We then recover the convolution of the chunks
29 // from the FFT product (the first inequality ensures that this is possible).
30 // The second inequality ensures that we have no overflow, i.e. the
31 // convolution result is valid in Z, not only in Z/pZ.
33 // The computation time (bit complexity) will be proportional to
34 // Mul(N) = O(M log(M) * O(2R)) + M * Mul(R+1).
35 // Hence we try to choose R as small as possible.
36 // Roughly, R >= 2*K, R >= M/2, hence R^2 >= K*M >= N.
38 // For example, when N1 = N2 = 1000000:
39 // Choosing R = 1024, M = 2048, K = 506, ceiling(N1/K) = ceiling(N2/K) = 1977,
40 // M >= 3953, doesn't work.
41 // Choosing R = 2048, M = 4096, K = 1018, ceiling(N1/K) = ceiling(N2/K) = 983, M >= 1965, works.
43 // Actually, we will also want intDsize | K, so that splitting into chunks
44 // and putting together the result can be done without shifts. So
45 // choose R = 2048, M = 4096, K = 992, ceiling(N1/K) = ceiling(N2/K) = 1009.
46 // We see that M = 2048 suffices.
48 // In contrast to Nussbaumer multiplication, here we can use the standard
49 // Karatsuba algorithm for multiplication mod p = 2^R+1. We don't have to
53 // Define this for (cheap) consistency checks.
57 // Operations modulo p = 2^R+1, each chunk represented as chlen words
58 // (chlen = floor(R/intDsize)+1).
60 static inline void assign (const uintC R, const uintC chlen,
61 const uintD* a, uintD* r)
64 copy_loop_lsp(a,r,chlen);
68 static void addm (const uintC R, const uintC chlen,
69 const uintD* a, const uintD* b, uintD* r)
73 add_loop_lsp(a,b, r, chlen);
75 if (lspref(r,chlen-1) < ((uintD)1 << (R % intDsize)))
77 if (lspref(r,chlen-1) == ((uintD)1 << (R % intDsize)))
78 if (!DS_test_loop(r lspop (chlen-1),chlen-1,r))
80 // r >= p, so subtract r := r-p.
81 lspref(r,chlen-1) -= ((uintD)1 << (R % intDsize));
82 dec_loop_lsp(r,chlen);
84 if (lspref(r,chlen-1) < 1)
86 if (lspref(r,chlen-1) == 1)
87 if (!DS_test_loop(r lspop (chlen-1),chlen-1,r))
89 // r >= p, so subtract r := r-p.
90 lspref(r,chlen-1) -= 1;
91 dec_loop_lsp(r,chlen);
96 static void subm (const uintC R, const uintC chlen,
97 const uintD* a, const uintD* b, uintD* r)
101 sub_loop_lsp(a,b, r, chlen);
103 if ((sintD)lspref(r,chlen-1) >= 0)
105 // r < 0, so add r := r+p.
106 lspref(r,chlen-1) += ((uintD)1 << (R % intDsize));
107 inc_loop_lsp(r,chlen);
109 if ((sintD)lspref(r,chlen-1) >= 0)
111 // r < 0, so add r := r+p.
112 lspref(r,chlen-1) += 1;
113 inc_loop_lsp(r,chlen);
117 // r := (a << s) mod p (0 <= s < R).
118 // Assume that a and r don't overlap.
// NOTE(review): a is presumably a reduced chunk (a <= 2^R); the word-aligned
// branch below relies on that when it skips neg_loop_lsp -- confirm with callers.
119 static void shiftleftm (const uintC R, const uintC chlen,
120 const uintD* a, uintL s, uintD* r)
122 // Write a = 2^(R-s)*b + c, then
123 // a << s = 2^R*b + (c << s) = (c << s) - b.
// Single-word variant: the whole chunk is lspref(a,0) and 2^R fits in a uintD,
// i.e. R < intDsize (the condition selecting this variant is not visible here).
127 var uintD b = lspref(a,0) >> (R-s);
128 var uintD c = lspref(a,0) & (((uintD)1 << (R-s)) - 1);
// Adds p = 2^R+1 to c; presumably applied when (c << s) - b is negative
// (the lines computing and testing that are not visible here).
132 c += ((uintD)1 << R) + 1;
137 // Here R >= intDsize, hence intDsize | R.
// Word-aligned shift: c moves up by whole words and b lands in the low
// words, so plain word copies suffice.
138 if ((s % intDsize) == 0) {
139 var uintP lenb = s/intDsize;
140 var uintP lenc = (R-s)/intDsize;
141 // chlen = 1 + lenb + lenc.
142 lspref(r,lenb+lenc) = 0;
// r[lenb .. lenb+lenc-1] := the low lenc words of a, i.e. c << s.
143 copy_loop_lsp(a,r lspop lenb,lenc);
// r[0 .. lenb-1] := words lenc .. lenc+lenb-1 of a, i.e. the low words of b.
144 copy_loop_lsp(a lspop lenc,r,lenb);
// Subtract b. If a's top word is nonzero, b has a 1 at word lenb and (for
// reduced a) zero low words, so no negation is needed. Otherwise, negating b's
// low words may borrow. In both cases 1 is subtracted starting at word lenb,
// and a borrow out of that means the value went negative.
145 if ((lspref(a,lenb+lenc) > 0) || neg_loop_lsp(r,lenb)) // -b gives carry?
146 if (dec_loop_lsp(r lspop lenb,lenc))
147 // add p = 2^R+1 to compensate with carry
148 inc_loop_lsp(r,chlen);
// Bit-unaligned shift: same idea, but the copies also shift by a bit count,
// and b's partial top word is gathered in bov. Some lines of this branch are
// not visible here; NOTE(review): confirm the shift count passed to
// shiftleftcopy_loop_lsp is reduced mod intDsize.
150 var uintP lenb = floor(s,intDsize);
151 var uintP lenc = floor(R-s,intDsize)+1;
152 // chlen = 1 + lenb + lenc.
154 lspref(r,lenb+lenc) = 0;
155 var uintD b0 = shiftleftcopy_loop_lsp(a,r lspop lenb,lenc,s);
160 bov = shiftleftcopy_loop_lsp(a lspop lenc,r,lenb,s);
// Include a's top word (bit 2^R of a) in the part of b that is subtracted at word lenb.
163 bov |= lspref(a,lenb+lenc) << s;
164 if (neg_loop_lsp(r,lenb))
166 if (lspref(r,lenb) >= bov)
167 lspref(r,lenb) -= bov;
169 lspref(r,lenb) -= bov;
170 if (dec_loop_lsp(r lspop (lenb+1),lenc-1))
171 // add p = 2^R+1 to compensate with carry
172 inc_loop_lsp(r,chlen);
177 // r := (a * b) mod p
// a, b and r are chlen-word chunks. The full product c = a*b is formed in a
// 2*chlen-word buffer and then reduced using 2^R == -1 (mod p).
178 static void mulm (const uintC R, const uintC chlen,
179 const uintD* a, const uintD* b, uintD* r)
182 // The leading digits are very likely to be 0.
// Trim leading zero words of a so cl_UDS_mul only sees its significant part.
// If a is zero, r is cleared (the surrounding loop/exit lines are not fully visible here).
183 var uintP a_len = chlen;
184 if (lspref(a,a_len-1) == 0)
187 } while ((a_len > 0) && (lspref(a,a_len-1) == 0));
189 clear_loop_lsp(r,chlen);
// Same trimming for b.
192 var uintP b_len = chlen;
193 if (lspref(b,b_len-1) == 0)
196 } while ((b_len > 0) && (lspref(b,b_len-1) == 0));
198 clear_loop_lsp(r,chlen);
// tmp := c = a*b as a 2*chlen-word number. cl_UDS_mul fills the low
// a_len+b_len words, and DS_clear_loop zeroes the high words above them.
201 CL_SMALL_ALLOCA_STACK;
202 var uintD* tmp = cl_small_alloc_array(uintD,2*chlen);
203 cl_UDS_mul(a,a_len, b,b_len, arrayLSDptr(tmp,2*chlen));
204 DS_clear_loop(arrayMSDptr(tmp,2*chlen),2*chlen-(a_len+b_len),arrayLSDptr(tmp,2*chlen) lspop (a_len+b_len));
205 // To divide c (0 <= c < p^2) by p = 2^R+1,
206 // we set q := floor(c/2^R) and r := c - q*p = (c mod 2^R) - q.
207 // If this becomes negative, set r := r + p (at most twice).
208 // (This works because floor(c/p) <= q <= floor(c/p)+2.)
209 // (Actually, here, 0 <= c <= (p-1)^2, hence
210 // floor(c/p) <= q <= floor(c/p)+1, so we have
211 // to set r := r + p at most once!)
// Single-word variant (2^R fits in a uintD; selecting condition not visible):
// r0 := (c mod 2^R) - floor(c/2^R), with floor(c/2^R) assembled from both words of tmp.
215 var uintD r0 = (arrayLSref(tmp,2,0) & (((uintD)1 << R) - 1))
216 - ((arrayLSref(tmp,2,1) << (intDsize-R)) | (arrayLSref(tmp,2,0) >> R));
// Add p = 2^R+1; presumably only when r0 is negative (the test is not visible here).
218 r0 += ((uintD)1 << R) + 1;
223 // Here R >= intDsize, hence intDsize | R.
224 // R/intDsize = chlen-1.
225 // arrayLSref(tmp,2*chlen,2*chlen-1) = 0, arrayLSref(tmp,2*chlen,2*chlen-2) <= 1.
// r := (c mod 2^R) - (low chlen-1 words of q), with the top word of r cleared first.
// A borrow from that subtraction, or q's extra word (weight 2^R within q),
// leaves the true value 2^R too low; adding p = 2^R+1 then amounts to r+1.
226 lspref(r,chlen-1) = 0;
227 if (sub_loop_lsp(arrayLSDptr(tmp,2*chlen),arrayLSDptr(tmp,2*chlen) lspop (chlen-1),r,chlen-1) || arrayLSref(tmp,2*chlen,2*chlen-2))
228 // add p = 2^R+1 to compensate with carry
229 inc_loop_lsp(r,chlen);
232 // b := (a / 2) mod p
233 static void shiftm (const uintC R, const uintC chlen,
234 const uintD* a, uintD* b)
237 shiftrightcopy_loop_msp(a lspop chlen,b lspop chlen,chlen,1,0);
238 if (lspref(a,0) & 1) {
239 // ((a + p) >> 1) = (a >> 1) + (p>>1) + 1.
243 lspref(b,0) |= ((uintD)1 << (R-1));
247 lspref(b,chlen-2) |= ((uintD)1 << (intDsize-1));
248 inc_loop_lsp(b,chlen);
255 // Reverse an n-bit number x. n>0.
256 static uintC bit_reverse (uintL n, uintC x)
263 } while (!(--n == 0));
// Multiplies the len1-digit number at sourceptr1 by the len2-digit number at
// sourceptr2 with a length-M FFT over Z/(2^R+1)Z. Each factor is cut into
// chunks of k words, both are transformed, multiplied pointwise (mulm) and
// transformed back. The chunk convolution is then summed into destptr
// (len1+len2 words).
// NOTE(review): the destptr parameter line and parts of the body (loop index
// declarations, braces, preprocessor alternatives) are not visible in this listing.
268 static void mulu_fftm (const uintL r, const uintC R, // R = 2^r
269 const uintL m, const uintC M, // M = 2^m
270 const uintC k, // K = intDsize*k
271 const uintD* sourceptr1, uintC len1,
272 const uintD* sourceptr2, uintC len2,
275 // ceiling(len1/k)+ceiling(len2/k)-1 <= M,
277 // R = 2^r, M = 2^m, M | 2R.
280 var const uintC chlen = floor(R,intDsize)+1; // chunk length (in words)
// Each array holds M chunks of chlen words (chlen<<m = chlen*M).
282 var uintD* const arrX = cl_alloc_array(uintD,chlen<<m);
283 var uintD* const arrY = cl_alloc_array(uintD,chlen<<m);
// Two alternative definitions of arrZ follow; presumably a preprocessor
// conditional (not visible here) selects one of them.
285 var uintD* const arrZ = cl_alloc_array(uintD,chlen<<m);
287 var uintD* const arrZ = arrX; // put Z in place of X - saves memory
// X(i), Y(i), Z(i): LSD pointer of the i-th chunk of the respective array.
289 #define X(i) arrayLSDptr(&arrX[chlen*(i)],chlen)
290 #define Y(i) arrayLSDptr(&arrY[chlen*(i)],chlen)
291 #define Z(i) (arrayLSDptr(&arrZ[chlen*(i)],chlen))
// One-chunk scratch buffers for the butterflies.
295 num_stack_alloc(chlen,,tmp=);
296 num_stack_alloc(chlen,,sum=);
297 num_stack_alloc(chlen,,diff=);
// When squaring, the pointwise product below uses X(i)*X(i); presumably the
// Y setup and FFT are skipped in that case (the conditionals are not visible here).
298 var bool squaring = ((sourceptr1 == sourceptr2) && (len1 == len2));
300 // Initialize factors X(i) and Y(i).
// Chunk i takes the next k digits of the source, zero-padded to chlen words.
// When fewer than k digits remain, those are copied and padded instead.
302 var const uintD* sptr = sourceptr1;
303 var uintC slen = len1;
304 for (i = 0; i < M; i++) {
305 var uintD* ptr = X(i);
307 copy_loop_lsp(sptr,ptr,k);
308 clear_loop_lsp(ptr lspop k,chlen-k);
312 copy_loop_lsp(sptr,ptr,slen);
313 clear_loop_lsp(ptr lspop slen,chlen-slen);
318 // X(i) := ... := X(M-1) := 0
319 clear_loop_up(&arrX[chlen*i],chlen*(M-i));
322 var const uintD* sptr = sourceptr2;
323 var uintC slen = len2;
324 for (i = 0; i < M; i++) {
325 var uintD* ptr = Y(i);
327 copy_loop_lsp(sptr,ptr,k);
328 clear_loop_lsp(ptr lspop k,chlen-k);
332 copy_loop_lsp(sptr,ptr,slen);
333 clear_loop_lsp(ptr lspop slen,chlen-slen);
338 // Y(i) := ... := Y(M-1) := 0
339 clear_loop_up(&arrY[chlen*i],chlen*(M-i));
341 // Do an FFT of length M on X. w = 2^(2R/M) = 2^(2^(r+1-m)).
// First stage (l = m-1, single group s = 0): the twiddle factor is w^0 = 1,
// so this is plain add/sub.
345 var const uintC tmax = M>>1; // tmax = 2^(m-1)
346 for (var uintC t = 0; t < tmax; t++) {
348 var uintC i2 = i1 + tmax;
349 // Butterfly: replace (X(i1),X(i2)) by
350 // (X(i1) + X(i2), X(i1) - X(i2)).
351 assign(R,chlen, X(i2), tmp);
352 subm(R,chlen, X(i1),tmp, X(i2));
353 addm(R,chlen, X(i1),tmp, X(i1));
// Remaining stages l = m-2, ..., 0: 2^(m-1-l) groups, each with butterflies
// of span 2^l and its own twiddle factor.
356 for (l = m-2; l>=0; l--) {
357 var const uintC smax = (uintC)1 << (m-1-l);
358 var const uintC tmax = (uintC)1 << l;
359 for (var uintC s = 0; s < smax; s++) {
360 // w^exp = 2^(exp << (r+1-m)).
361 var uintC exp = bit_reverse(m-1-l,s) << (r-(m-1-l));
// The variable exp is already the power of 2 passed to shiftleftm. Since
// bit_reverse(m-1-l,s) < 2^(m-1-l), exp < 2^r = R, as shiftleftm requires.
362 for (var uintC t = 0; t < tmax; t++) {
363 var uintC i1 = (s << (l+1)) + t;
364 var uintC i2 = i1 + tmax;
365 // Butterfly: replace (X(i1),X(i2)) by
366 // (X(i1) + w^exp*X(i2), X(i1) - w^exp*X(i2)).
367 shiftleftm(R,chlen, X(i2),exp, tmp);
368 subm(R,chlen, X(i1),tmp, X(i2));
369 addm(R,chlen, X(i1),tmp, X(i1));
374 // Do an FFT of length M on Y. w = 2^(2R/M) = 2^(2^(r+1-m)).
// Same transform as for X above.
378 var const uintC tmax = M>>1; // tmax = 2^(m-1)
379 for (var uintC t = 0; t < tmax; t++) {
381 var uintC i2 = i1 + tmax;
382 // Butterfly: replace (Y(i1),Y(i2)) by
383 // (Y(i1) + Y(i2), Y(i1) - Y(i2)).
384 assign(R,chlen, Y(i2), tmp);
385 subm(R,chlen, Y(i1),tmp, Y(i2));
386 addm(R,chlen, Y(i1),tmp, Y(i1));
389 for (l = m-2; l>=0; l--) {
390 var const uintC smax = (uintC)1 << (m-1-l);
391 var const uintC tmax = (uintC)1 << l;
392 for (var uintC s = 0; s < smax; s++) {
393 // w^exp = 2^(exp << (r+1-m)).
394 var uintC exp = bit_reverse(m-1-l,s) << (r-(m-1-l));
395 for (var uintC t = 0; t < tmax; t++) {
396 var uintC i1 = (s << (l+1)) + t;
397 var uintC i2 = i1 + tmax;
398 // Butterfly: replace (Y(i1),Y(i2)) by
399 // (Y(i1) + w^exp*Y(i2), Y(i1) - w^exp*Y(i2)).
400 shiftleftm(R,chlen, Y(i2),exp, tmp);
401 subm(R,chlen, Y(i1),tmp, Y(i2));
402 addm(R,chlen, Y(i1),tmp, Y(i1));
407 // Multiply the transformed vectors into Z.
// General case: Z(i) := X(i)*Y(i) mod p.
409 for (i = 0; i < M; i++)
410 mulm(R,chlen, X(i),Y(i),Z(i));
// Squaring: Z(i) := X(i)^2 mod p.
412 for (i = 0; i < M; i++)
413 mulm(R,chlen, X(i),X(i),Z(i));
415 // Undo an FFT of length M on Z. w = 2^(2R/M) = 2^(2^(r+1-m)).
// The inverse stages run in the reverse order of the forward ones. Each of the
// m stages halves its outputs (shiftm / the "-1" in exp-1), which supplies the
// overall factor 1/M.
418 for (l = 0; l < m-1; l++) {
419 var const uintC smax = (uintC)1 << (m-1-l);
420 var const uintC tmax = (uintC)1 << l;
// Group s = 0 has exp = 0 and is handled separately, because the general
// case below shifts by exp-1.
421 /* s = 0, exp = 0 */ {
422 for (var uintC t = 0; t < tmax; t++) {
424 var uintC i2 = i1 + tmax;
425 // Inverse Butterfly: replace (Z(i1),Z(i2)) by
426 // ((Z(i1)+Z(i2))/2, (Z(i1)-Z(i2))/(2*w^exp)),
428 addm(R,chlen, Z(i1),Z(i2), sum);
429 subm(R,chlen, Z(i1),Z(i2), diff);
430 shiftm(R,chlen, sum, Z(i1));
431 shiftm(R,chlen, diff, Z(i2));
434 for (var uintC s = 1; s < smax; s++) {
435 // w^exp = 2^(exp << (r+1-m)).
436 var uintC exp = bit_reverse(m-1-l,s) << (r-(m-1-l));
437 exp = R - exp; // negate exp (use w^-1 instead of w)
// With e = the forward exponent (1 <= e < R for s >= 1): mod p,
// 2^(-e) = 2^(2R-e) = -2^(R-e). Hence
// (Z(i1)-Z(i2))/(2*2^e) = (Z(i2)-Z(i1)) * 2^((R-e)-1), and the shift count
// (R-e)-1 lies in [0, R).
438 for (var uintC t = 0; t < tmax; t++) {
439 var uintC i1 = (s << (l+1)) + t;
440 var uintC i2 = i1 + tmax;
441 // Inverse Butterfly: replace (Z(i1),Z(i2)) by
442 // ((Z(i1)+Z(i2))/2, (Z(i1)-Z(i2))/(2*w^exp)),
443 // with exp <-- (M/2 - exp).
444 addm(R,chlen, Z(i1),Z(i2), sum);
445 subm(R,chlen, Z(i2),Z(i1), diff); // note that w^(M/2) = 2^R = -1
446 shiftm(R,chlen, sum, Z(i1));
447 shiftleftm(R,chlen, diff,exp-1, Z(i2));
// Last inverse stage (l = m-1, single group): twiddle 1, halve sum and difference.
452 var const uintC tmax = M>>1; // tmax = 2^(m-1)
453 for (var uintC t = 0; t < tmax; t++) {
455 var uintC i2 = i1 + tmax;
456 // Inverse Butterfly: replace (Z(i1),Z(i2)) by
457 // ((Z(i1)+Z(i2))/2, (Z(i1)-Z(i2))/2).
458 addm(R,chlen, Z(i1),Z(i2), sum);
459 subm(R,chlen, Z(i1),Z(i2), diff);
460 shiftm(R,chlen, sum, Z(i1));
461 shiftm(R,chlen, diff, Z(i2));
// Each Z(i) should now be a convolution coefficient below 2^(2K+m), i.e.
// 2*k words plus ceiling(m/intDsize) more.
465 var uintC zchlen = 2*k + ceiling(m,intDsize);
467 // Check that every Z(i) has at most 2*K+m bits.
// The words of Z(i) above zchlen must be zero (the action taken on failure is not visible here).
469 var uintC zerodigits = chlen - zchlen;
470 for (i = 0; i < M; i++)
471 if (DS_test_loop(Z(i) lspop chlen,zerodigits,Z(i) lspop zchlen))
475 // Put together result.
// Z(i) has weight 2^(K*i): add it at destptr offset i*k and propagate the
// carry. Near the top, only destlen words fit; the excess words of Z(i) are
// checked to be zero.
476 var uintC destlen = len1+len2;
477 clear_loop_lsp(destptr,destlen);
478 for (i = 0; i < M; i++, destptr = destptr lspop k, destlen -= k) {
479 if (zchlen <= destlen) {
480 if (addto_loop_lsp(Z(i),destptr,zchlen))
481 if (inc_loop_lsp(destptr lspop zchlen,destlen-zchlen))
485 if (DS_test_loop(Z(i) lspop zchlen,zchlen-destlen,Z(i) lspop destlen))
488 if (addto_loop_lsp(Z(i),destptr,destlen))
497 // Check that Z(i)..Z(M-1) are all zero.
498 if (test_loop_up(&arrZ[chlen*i],chlen*(M-i)))
509 // The running time of mulu_fftm() is roughly
510 // O(M log(M) * O(2R)) + M * R^(1+c), where c = log3/log2 - 1 = 0.585...
511 // Try to minimize this given the constraints
512 // ceiling(len1/k)+ceiling(len2/k)-1 <= M,
513 // K = intDsize*k, 2*K+m <= R,
514 // R = 2^r, M = 2^m, M | 2R.
516 // Necessary conditions:
517 // len1+len2 <= k*(M+1), intDsize*(len1+len2) <= K*(M+1) <= (R-1)/2 * (2*R+1) < R^2.
518 // 2*intDsize+1 <= R, log2_intDsize+1 < r.
519 // So we start with len1 <= len2,
520 // r := max(log2_intDsize+2,ceiling(ceiling(log2(intDsize*2*len1))/2)), R := 2^r.
522 // kmax := floor((R-(r+1))/(2*intDsize)), Kmax := intDsize*kmax,
523 // m := max(1,ceiling(log2(2*ceiling(len1/kmax)-1))), M := 2^m,
524 // if m > r+1 retry with r <- r+1.
525 // [Now we are sure that we can at least multiply len1 and len1 digits using these
526 // values of r and m, symbolically (r,m) OKFOR (len1,len1).]
527 // [Normally, we will have m=r+1 or m=r.]
528 // For (len1,len2), we might want to split the second integer into pieces.
529 // If (r,m) OKFOR (len1,len2)
530 // If (r-1,m) OKFOR (len1,ceiling(len2/2))
531 // then use (r-1,m) and two pieces
532 // else use (r,m) and one piece
534 // q1 := number of pieces len2 needs to be split into to be OKFOR (r,m),
536 // q2 := number of pieces len2 needs to be split into to be OKFOR (r,m+1),
538 // then use (r,m+1) and q2 pieces
539 // else use (r,m) and q1 pieces
541 // q2 := number of pieces len2 needs to be split into to be OKFOR (r+1,m),
543 // then use (r+1,m) and q2 pieces
544 // else use (r,m) and q1 pieces
546 // Because we always choose r >= log2_intDsize+2, R >= 4*intDsize, so chlen >= 5.
547 // To avoid infinite recursion, mulu_fft_modm() must only be called with len1 > 5.
549 static bool okfor (uintL r, uintL m, uintC len1, uintC len2)
551 var uintC R = (uintC)1 << r;
552 var uintC M = (uintC)1 << m;
553 var uintC k = floor(R-m,2*intDsize);
554 return (ceiling(len1,k)+ceiling(len2,k) <= M+1);
// Number of pieces len2 must be split into so that len1 times each piece is
// OKFOR (r,m). With k = floor((R-m)/(2*intDsize)), okfor's condition allows a
// piece of at most (M+1-ceiling(len1/k))*k digits.
557 static uintC numpieces (uintL r, uintL m, uintC len1, uintC len2)
559 var uintC R = (uintC)1 << r;
560 var uintC M = (uintC)1 << m;
561 var uintC k = floor(R-m,2*intDsize);
562 var uintC piecelen2 = (M+1-ceiling(len1,k))*k;
// A non-positive piecelen2 (viewed as signed) means len1 alone is already too
// long for (r,m). NOTE(review): the statement this test guards is not visible
// here; presumably it rejects that case -- confirm.
564 if ((sintC)piecelen2 <= 0)
567 return ceiling(len2,piecelen2);
570 static void mulu_fft_modm (const uintD* sourceptr1, uintC len1,
571 const uintD* sourceptr2, uintC len2,
573 // Called only with 6 <= len1 <= len2.
576 integerlengthC(len1-1, n=); // 2^(n-1) < len1 <= 2^n
579 r = ceiling(log2_intDsize+1+n,2);
580 if (r < log2_intDsize+2)
583 var uintC k = floor(((uintC)1 << r) - (r+1), 2*intDsize);
584 var uintC M = 2*ceiling(len1,k)-1;
585 integerlengthC(M, m=);
594 if (!(m > 0 && m <= r+1 && okfor(r,m,len1,len1)))
597 if (okfor(r,m,len1,len2)) {
598 if ((m <= r) && (r > log2_intDsize+2) && okfor(r-1,m,len1,ceiling(len2,2)))
599 if (!(sourceptr1 == sourceptr2 && len1 == len2)) // when squaring, keep one piece
602 var uintC q1 = numpieces(r,m,len1,len2);
604 var uintC q2 = numpieces(r,m+1,len1,len2);
608 var uintC q2 = numpieces(r+1,m,len1,len2);
613 var uintC R = (uintC)1 << r;
614 var uintC M = (uintC)1 << m;
615 var uintC k = floor(R-m,2*intDsize);
616 var uintC piecelen2 = (M+1-ceiling(len1,k))*k;
617 if (piecelen2 >= len2) {
619 mulu_fftm(r,R, m,M, k, sourceptr1,len1, sourceptr2,len2, destptr);
624 num_stack_alloc(len1+piecelen2,,tmpptr=);
625 var uintC destlen = len1+len2;
626 clear_loop_lsp(destptr,destlen);
628 var uintC len2p; // length of a piece of source2
632 // len2p = min(piecelen2,len2).
633 var uintC destlenp = len1 + len2p;
634 // destlenp = min(len1+piecelen2,destlen).
635 // Use tmpptr[-destlenp..-1].
638 mulu_loop_lsp(lspref(sourceptr2,0),sourceptr1,tmpptr,len1);
639 } else if (2*len2p < piecelen2) {
641 cl_UDS_mul(sourceptr1,len1, sourceptr2,len2p, tmpptr);
643 mulu_fftm(r,R, m,M, k, sourceptr1,len1, sourceptr2,len2p, tmpptr);
645 if (addto_loop_lsp(tmpptr,destptr,destlenp))
646 if (inc_loop_lsp(destptr lspop destlenp,destlen-destlenp))
649 destptr = destptr lspop len2p;
651 sourceptr2 = sourceptr2 lspop len2p;