mpmul.c (3182B)
1 #include "os.h" 2 #include <mp.h> 3 #include "dat.h" 4 5 /* */ 6 /* from knuth's 1969 seminumberical algorithms, pp 233-235 and pp 258-260 */ 7 /* */ 8 /* mpvecmul is an assembly language routine that performs the inner */ 9 /* loop. */ 10 /* */ 11 /* the karatsuba trade off is set empiricly by measuring the algs on */ 12 /* a 400 MHz Pentium II. */ 13 /* */ 14 15 /* karatsuba like (see knuth pg 258) */ 16 /* prereq: p is already zeroed */ 17 static void 18 mpkaratsuba(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p) 19 { 20 mpdigit *t, *u0, *u1, *v0, *v1, *u0v0, *u1v1, *res, *diffprod; 21 int u0len, u1len, v0len, v1len, reslen; 22 int sign, n; 23 24 /* divide each piece in half */ 25 n = alen/2; 26 if(alen&1) 27 n++; 28 u0len = n; 29 u1len = alen-n; 30 if(blen > n){ 31 v0len = n; 32 v1len = blen-n; 33 } else { 34 v0len = blen; 35 v1len = 0; 36 } 37 u0 = a; 38 u1 = a + u0len; 39 v0 = b; 40 v1 = b + v0len; 41 42 /* room for the partial products */ 43 t = mallocz(Dbytes*5*(2*n+1), 1); 44 if(t == nil) 45 sysfatal("mpkaratsuba: %r"); 46 u0v0 = t; 47 u1v1 = t + (2*n+1); 48 diffprod = t + 2*(2*n+1); 49 res = t + 3*(2*n+1); 50 reslen = 4*n+1; 51 52 /* t[0] = (u1-u0) */ 53 sign = 1; 54 if(mpveccmp(u1, u1len, u0, u0len) < 0){ 55 sign = -1; 56 mpvecsub(u0, u0len, u1, u1len, u0v0); 57 } else 58 mpvecsub(u1, u1len, u0, u1len, u0v0); 59 60 /* t[1] = (v0-v1) */ 61 if(mpveccmp(v0, v0len, v1, v1len) < 0){ 62 sign *= -1; 63 mpvecsub(v1, v1len, v0, v1len, u1v1); 64 } else 65 mpvecsub(v0, v0len, v1, v1len, u1v1); 66 67 /* t[4:5] = (u1-u0)*(v0-v1) */ 68 mpvecmul(u0v0, u0len, u1v1, v0len, diffprod); 69 70 /* t[0:1] = u1*v1 */ 71 memset(t, 0, 2*(2*n+1)*Dbytes); 72 if(v1len > 0) 73 mpvecmul(u1, u1len, v1, v1len, u1v1); 74 75 /* t[2:3] = u0v0 */ 76 mpvecmul(u0, u0len, v0, v0len, u0v0); 77 78 /* res = u0*v0<<n + u0*v0 */ 79 mpvecadd(res, reslen, u0v0, u0len+v0len, res); 80 mpvecadd(res+n, reslen-n, u0v0, u0len+v0len, res+n); 81 82 /* res += u1*v1<<n + u1*v1<<2*n */ 83 if(v1len > 0){ 84 mpvecadd(res+n, reslen-n, u1v1, u1len+v1len, res+n); 85 mpvecadd(res+2*n, reslen-2*n, u1v1, u1len+v1len, res+2*n); 86 } 87 88 /* res += (u1-u0)*(v0-v1)<<n */ 89 if(sign < 0) 90 mpvecsub(res+n, reslen-n, diffprod, u0len+v0len, res+n); 91 else 92 mpvecadd(res+n, reslen-n, diffprod, u0len+v0len, res+n); 93 memmove(p, res, (alen+blen)*Dbytes); 94 95 free(t); 96 } 97 98 #define KARATSUBAMIN 32 99 100 void 101 mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p) 102 { 103 int i; 104 mpdigit d; 105 mpdigit *t; 106 107 /* both mpvecdigmuladd and karatsuba are fastest when a is the longer vector */ 108 if(alen < blen){ 109 i = alen; 110 alen = blen; 111 blen = i; 112 t = a; 113 a = b; 114 b = t; 115 } 116 if(blen == 0){ 117 memset(p, 0, Dbytes*(alen+blen)); 118 return; 119 } 120 121 if(alen >= KARATSUBAMIN && blen > 1){ 122 /* O(n^1.585) */ 123 mpkaratsuba(a, alen, b, blen, p); 124 } else { 125 /* O(n^2) */ 126 for(i = 0; i < blen; i++){ 127 d = b[i]; 128 if(d != 0) 129 mpvecdigmuladd(a, alen, d, &p[i]); 130 } 131 } 132 } 133 134 void 135 mpmul(mpint *b1, mpint *b2, mpint *prod) 136 { 137 mpint *oprod; 138 139 oprod = nil; 140 if(prod == b1 || prod == b2){ 141 oprod = prod; 142 prod = mpnew(0); 143 } 144 145 prod->top = 0; 146 mpbits(prod, (b1->top+b2->top+1)*Dbits); 147 mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p); 148 prod->top = b1->top+b2->top+1; 149 prod->sign = b1->sign*b2->sign; 150 mpnorm(prod); 151 152 if(oprod != nil){ 153 mpassign(prod, oprod); 154 mpfree(prod); 155 } 156 }