]> www.ginac.de Git - cln.git/blob - src/base/digitseq/cl_DS_recipsqrt.cc
* src/base/cl_macros.h: alloca(3) has size_t argument type.
[cln.git] / src / base / digitseq / cl_DS_recipsqrt.cc
1 // cl_UDS_recipsqrt().
2
3 // General includes.
4 #include "cl_sysdep.h"
5
6 // Specification.
7 #include "cl_DS.h"
8
9
10 // Implementation.
11
12 #include "cl_low.h"
13 #include "cln/abort.h"
14
15 namespace cln {
16
17 // Compute the reciprocal square root of a digit sequence.
18 // Input: UDS a_MSDptr/a_len/.. of length a_len,
19 //        with 1/4 <= a < 1.
20 //        [i.e. 1/4*beta^a_len <= a < beta^a_len]
21 // Output: UDS b_MSDptr/b_len+2/.. of length b_len+1 (b_len>1), plus 1 more bit
22 //         in the last limb) such that
23 //         1 <= b <= 2  [i.e. beta^b_len <= b <= 2*beta^b_len]
24 //         and  | 1/sqrt(a) - b | < 1/2*beta^(-b_len).
25 // If a_len > b_len, only the most significant b_len+1 limbs of a are used.
26   extern void cl_UDS_recipsqrt (const uintD* a_MSDptr, uintC a_len,
27                                 uintD* b_MSDptr, uintC b_len);
28 // Method:
29 // Using Newton iteration for computation of x^-1/2.
30 // The Newton iteration for f(y) = x-1/y^2 reads:
31 //   y --> y - (x-1/y^2)/(2/y^3) = y + y*(1-x*y^2)/2 =: g(y).
32 // We have  T^3-3*T+2 = (T-1)^2*(T+2), hence
33 //   1/sqrt(x) - g(y) = 1/(2*sqrt(x)) * (sqrt(x)*y-1)^2 * (sqrt(x)*y+2).
34 // Hence g(y) <= 1/sqrt(x).
35 // If we choose 0 < y_0 <= 1/sqrt(x), then set y_(n+1) := g(y_n), we will
36 // always have 0 < y_n <= 1/sqrt(x).
37 // Since
38 //   1/sqrt(x) - g(y) = sqrt(x)*(sqrt(x)*y+2)/2 * (1/sqrt(x) - y)^2,
39 // which is >= 0 and < 3/2 * (1/sqrt(x) - y)^2, we have a quadratically
40 // convergent iteration.
41 // For n = 1,2,...,b_len we compute approximations y with 1 <= yn <= 2
42 // and  | 1/sqrt(x) - yn | < 1/2*beta^(-n).
43 // Step n=1:
44 //   Compute the isqrt of the leading two digits of x, yields one digit.
45 //   Compute its reciprocal, then do one iteration as below (n=0 -> m=1).
46 // Step n -> m with n < m <= 2*n:
47 //   Write x = xm + xr with 0 <= xr < beta^-(m+1).
48 //   Set ym' = yn + (yn*(1-xm*yn*yn))/2, round down to a multiple ym
49 //   of beta^-(m+1).
50 //   (Actually, compute yn*yn, round up to a multiple of beta^-(m+1),   [1]
51 //    multiply with xm,        round up to a multiple of beta^-(m+1),   [2]
52 //    subtract from 1,         no rounding needed,                      [2]
53 //    multiply with yn,        round down to a multiple of beta^-(m+1), [5]
54 //    divide by 2,             round down to a multiple of beta^-(m+1), [3]
55 //    add to yn,               no rounding needed.  [Max rounding error: ^])
56 //   The exact value ym' (no rounding) would satisfy
57 //     0 <= 1/sqrt(xm) - ym' < 3/2 * (1/sqrt(xm) - yn)^2
58 //                           < 3/8 * beta^(-2*n)          by hypothesis,
59 //                           <= 3/8 * beta^-m.
60 //   The rounding errors all go into the same direction, so
61 //     0 <= ym' - ym < 3 * beta^-(m+1) < 1/4 * beta^-m.
62 //   Combine both inequalities:
63 //     0 <= 1/sqrt(xm) - ym < 1/2 * beta^-m.
64 //   Neglecting xr can introduce a small error in the opposite direction:
65 //     0 <= 1/sqrt(xm) - 1/sqrt(x) = (sqrt(x) - sqrt(xm))/(sqrt(x)*sqrt(xm))
66 //        = xr / (sqrt(x)*sqrt(xm)*(sqrt(x)+sqrt(xm)))
67 //        <= 4*xr < 4*beta^-(m+1) < 1/2*beta^-m.
68 //   Combine both inequalities:
69 //     | 1/sqrt(x) - ym | < 1/2 * beta^-m.
70 //   (Actually, choosing the opposite rounding direction wouldn't hurt either.)
71 // Choice of n:
72 //   So that the computation is minimal, e.g. in the case b_len=10:
73 //   1 -> 2 -> 3 -> 5 -> 10 and not 1 -> 2 -> 4 -> 8 -> 10.
74   void cl_UDS_recipsqrt (const uintD* a_MSDptr, uintC a_len,
75                          uintD* b_MSDptr, uintC b_len)
76     {
77         var uintC y_len = b_len+2;
78         var uintC x_len = (a_len <= b_len ? a_len : b_len+1);
79         var const uintD* const x_MSDptr = a_MSDptr;
80         var uintD* y_MSDptr;
81         var uintD* y2_MSDptr;
82         var uintD* y3_MSDptr;
83         var uintD* y4_MSDptr;
84         CL_ALLOCA_STACK;
85         num_stack_alloc(y_len,y_MSDptr=,);
86         num_stack_alloc(2*y_len,y2_MSDptr=,);
87         num_stack_alloc(2*y_len,y3_MSDptr=,);
88         num_stack_alloc(2*y_len,y4_MSDptr=,);
89         // Step n = 1.
90         { var uintD x1 = mspref(x_MSDptr,0);
91           var uintD x2 = (a_len > 1 ? mspref(x_MSDptr,1) : 0);
92           var uintD y0;
93           var uintD y1;
94           var bool sqrtp;
95           isqrtD(x1,x2, y1=,sqrtp=);
96           // 2^31 <= y1 < 2^32.
97           y0 = 1;
98           if (!sqrtp) // want to compute 1/sqrt(x) rounded down
99                 if (++y1 == 0)
100                         goto step1_done; // 1/1.0000 = 1.0000
101           // Set y0|y1 := 2^(2*intDsize)/y1
102           //            = 2^intDsize + (2^(2*intDsize)-2^intDsize*y1)/y1.
103           if ((uintD)(-y1) >= y1) {
104                 y0 = 2; y1 = 0;
105           } else {
106                 #if HAVE_DD
107                 divuD(highlowDD_0((uintD)(-y1)),y1, y1=,);
108                 #else
109                 divuD((uintD)(-y1),0,y1, y1=,);
110                 #endif
111           }
112         step1_done:
113           mspref(y_MSDptr,0) = y0;
114           mspref(y_MSDptr,1) = y1;
115         }
116         // Other steps.
117         var int k;
118         integerlength32((uint32)b_len-1,k=);
119         // 2^(k-1) < b_len <= 2^k, so we need k steps, plus one
120         // one more step at the beginning (because step 1 was not complete).
121         var uintC n = 0;
122         for (; k>=0; k--)
123           { var uintC m = ((b_len-1)>>k)+1; // = ceiling(b_len/2^k)
124             // Compute ym := yn + (yn*(1-xm*yn*yn))/2, rounded.
125             // Storage: at y_MSDptr: (1 + n+1) limbs, yn.
126             //          at y2_MSDptr: (2 + 2*n+2) limbs, yn^2.
127             //          at y3_MSDptr: (1 + m+1) limbs, xm*yn*yn, 1-xm*yn*yn.
128             //          at y4_MSDptr: (2-n + m+n+2) limbs, yn*(1-xm*yn*yn).
129             clear_loop_msp(y_MSDptr mspop (n+2),m-n);
130             cl_UDS_mul_square(y_MSDptr mspop (n+2),n+2,
131                               y2_MSDptr mspop 2*(n+2));
132             var uintC xm_len = (m < x_len ? m+1 : x_len);
133             var uintC y2_len = m+2; // = (m+1 <= 2*n+2 ? m+2 : 2*n+3);
134             cl_UDS_mul(x_MSDptr mspop xm_len,xm_len,
135                        y2_MSDptr mspop (y2_len+1),y2_len,
136                        y3_MSDptr mspop (xm_len+y2_len));
137             if (mspref(y3_MSDptr,0)==0)
138               // xm*yn*yn < 1
139               { neg_loop_lsp(y3_MSDptr mspop (m+2),m+2);
140                 mspref(y3_MSDptr,0) += 1;
141                 if (test_loop_msp(y3_MSDptr,n)) cl_abort(); // check 0 <= y3 < beta^-(n-1)
142                 cl_UDS_mul(y_MSDptr mspop (n+2),n+2,
143                            y3_MSDptr mspop (m+2),m+2-n,
144                            y4_MSDptr mspop (m+4));
145                 shift1right_loop_msp(y4_MSDptr,m+3-n,0);
146                 if (addto_loop_lsp(y4_MSDptr mspop (m+3-n),y_MSDptr mspop (m+2),m+3-n))
147                   if ((n<1) || inc_loop_lsp(y_MSDptr mspop (n-1),n-1)) cl_abort();
148               }
149               else
150               // xm*yn*yn >= 1 (this can happen since xm >= xn)
151               { mspref(y3_MSDptr,0) -= 1;
152                 if (test_loop_msp(y3_MSDptr,n)) cl_abort(); // check 0 >= y3 > -beta^-(n-1)
153                 cl_UDS_mul(y_MSDptr mspop (n+2),n+2,
154                            y3_MSDptr mspop (m+2),m+2-n,
155                            y4_MSDptr mspop (m+4));
156                 shift1right_loop_msp(y4_MSDptr,m+3-n,0);
157                 if (subfrom_loop_lsp(y4_MSDptr mspop (m+3-n),y_MSDptr mspop (m+2),m+3-n))
158                   if ((n<1) || dec_loop_lsp(y_MSDptr mspop (n-1),n-1)) cl_abort();
159               }
160             n = m;
161             // n = ceiling(b_len/2^k) limbs of y have now been computed.
162           }
163         copy_loop_msp(y_MSDptr,b_MSDptr,b_len+2);
164 }
165 // Bit complexity (N := b_len): O(M(N)).
166
167 }  // namespace cln