plan9port

fork of plan9port with libvec, libstr and libsdb
Log | Files | Refs | README | LICENSE

mpmul.c (3182B)


      1 #include "os.h"
      2 #include <mp.h>
      3 #include "dat.h"
      4 
/* */
/*  from Knuth's 1969 Seminumerical Algorithms, pp 233-235 and pp 258-260 */
/* */
/*  the inner loop of the schoolbook multiply is done by mpvecdigmuladd, */
/*  which is intended to be an assembly language routine where one is */
/*  available (mpvecmul itself, below, is plain C). */
/* */
/*  the Karatsuba cutoff (KARATSUBAMIN) was set empirically by measuring */
/*  both algorithms on a 400 MHz Pentium II. */
/* */
     14 
/*
 * Karatsuba-style multiplication (see Knuth, Seminumerical Algorithms, p. 258).
 *
 * Computes p = a*b, where a has alen digits and b has blen digits, stored
 * least significant digit first (u0/v0 below are the low halves at the
 * start of each vector).  Both operands are split at n = ceil(alen/2) digits:
 *	a = u1*B^n + u0,	b = v1*B^n + v0
 * and the product is assembled as
 *	u0v0 + (u0v0 + u1v1 + (u1-u0)*(v0-v1))*B^n + u1v1*B^2n
 * which needs three recursive mpvecmul calls instead of four.
 * Throughout, "<<k" / "*B^k" means a shift left by k digits.
 *
 * mpvecmul only calls this after ordering the operands so that
 * alen >= blen (and with blen > 1); the slot sizes and the asymmetric
 * subtraction lengths below assume blen <= alen.
 *
 * p needs room for alen+blen digits.  The final memmove overwrites all of
 * them, so p's previous contents are not actually relied on here.
 */
static void
mpkaratsuba(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
{
	mpdigit *t, *u0, *u1, *v0, *v1, *u0v0, *u1v1, *res, *diffprod;
	int u0len, u1len, v0len, v1len, reslen;
	int sign, n;

	/* split at n = ceil(alen/2) digits, so u1len <= u0len == n */
	n = alen/2;
	if(alen&1)
		n++;
	u0len = n;
	u1len = alen-n;
	/* if b fits in n digits it has no high half (v1len == 0) */
	if(blen > n){
		v0len = n;
		v1len = blen-n;
	} else {
		v0len = blen;
		v1len = 0;
	}
	u0 = a;
	u1 = a + u0len;
	v0 = b;
	v1 = b + v0len;

	/*
	 * zero-filled scratch: five slots of 2n+1 digits each
	 *	slot 0: u0v0	slot 1: u1v1	slot 2: diffprod
	 *	slots 3-4: res, 4n+2 digits (one more than reslen)
	 */
	t = mallocz(Dbytes*5*(2*n+1), 1);
	if(t == nil)
		sysfatal("mpkaratsuba: %r");
	u0v0 = t;
	u1v1 = t + (2*n+1);
	diffprod = t + 2*(2*n+1);
	res = t + 3*(2*n+1);
	reslen = 4*n+1;

	/* slot 0 temporarily holds |u1-u0|; sign tracks the sign of the cross term */
	sign = 1;
	if(mpveccmp(u1, u1len, u0, u0len) < 0){
		sign = -1;
		mpvecsub(u0, u0len, u1, u1len, u0v0);
	} else
		/*
		 * u1 >= u0 with u1len <= u0len, so any digit of u0 at or above
		 * u1len must be zero (assuming mpveccmp compares numeric values);
		 * only u1len digits of u0 need to be subtracted.
		 */
		mpvecsub(u1, u1len, u0, u1len, u0v0);

	/* slot 1 temporarily holds |v0-v1|; flip sign if v0 < v1 */
	if(mpveccmp(v0, v0len, v1, v1len) < 0){
		sign *= -1;
		/*
		 * v0 < v1 with v1len <= v0len (given blen <= alen), so v0's
		 * digits at or above v1len are zero; subtract only v1len of them.
		 */
		mpvecsub(v1, v1len, v0, v1len, u1v1);
	} else
		mpvecsub(v0, v0len, v1, v1len, u1v1);

	/*
	 * slot 2: diffprod = |u1-u0| * |v0-v1|.  The differences fit in u0len
	 * and v0len digits; the rest of their slots is still zero from mallocz.
	 */
	mpvecmul(u0v0, u0len, u1v1, v0len, diffprod);

	/*
	 * slots 0-1 are reused for the half products, so clear them first
	 * (mpvecmul's schoolbook path accumulates into its destination).
	 * slot 1: u1v1 = u1*v1; stays zero when b has no high half.
	 */
	memset(t, 0, 2*(2*n+1)*Dbytes);
	if(v1len > 0)
		mpvecmul(u1, u1len, v1, v1len, u1v1);

	/* slot 0: u0v0 = u0*v0 */
	mpvecmul(u0, u0len, v0, v0len, u0v0);

	/* res = u0v0<<n + u0v0 */
	mpvecadd(res, reslen, u0v0, u0len+v0len, res);
	mpvecadd(res+n, reslen-n, u0v0, u0len+v0len, res+n);

	/* res += u1v1<<n + u1v1<<2n (both terms are zero when v1len == 0) */
	if(v1len > 0){
		mpvecadd(res+n, reslen-n, u1v1, u1len+v1len, res+n);
		mpvecadd(res+2*n, reslen-2*n, u1v1, u1len+v1len, res+2*n);
	}

	/*
	 * res += (u1-u0)*(v0-v1)<<n.  diffprod holds the magnitude, so a
	 * negative cross term is subtracted.  The final value is the
	 * non-negative product a*b.
	 */
	if(sign < 0)
		mpvecsub(res+n, reslen-n, diffprod, u0len+v0len, res+n);
	else
		mpvecadd(res+n, reslen-n, diffprod, u0len+v0len, res+n);
	/* copy the alen+blen digits of the product out to the caller */
	memmove(p, res, (alen+blen)*Dbytes);

	free(t);
}
     97 
     98 #define KARATSUBAMIN 32
     99 
    100 void
    101 mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
    102 {
    103 	int i;
    104 	mpdigit d;
    105 	mpdigit *t;
    106 
    107 	/* both mpvecdigmuladd and karatsuba are fastest when a is the longer vector */
    108 	if(alen < blen){
    109 		i = alen;
    110 		alen = blen;
    111 		blen = i;
    112 		t = a;
    113 		a = b;
    114 		b = t;
    115 	}
    116 	if(blen == 0){
    117 		memset(p, 0, Dbytes*(alen+blen));
    118 		return;
    119 	}
    120 
    121 	if(alen >= KARATSUBAMIN && blen > 1){
    122 		/* O(n^1.585) */
    123 		mpkaratsuba(a, alen, b, blen, p);
    124 	} else {
    125 		/* O(n^2) */
    126 		for(i = 0; i < blen; i++){
    127 			d = b[i];
    128 			if(d != 0)
    129 				mpvecdigmuladd(a, alen, d, &p[i]);
    130 		}
    131 	}
    132 }
    133 
    134 void
    135 mpmul(mpint *b1, mpint *b2, mpint *prod)
    136 {
    137 	mpint *oprod;
    138 
    139 	oprod = nil;
    140 	if(prod == b1 || prod == b2){
    141 		oprod = prod;
    142 		prod = mpnew(0);
    143 	}
    144 
    145 	prod->top = 0;
    146 	mpbits(prod, (b1->top+b2->top+1)*Dbits);
    147 	mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p);
    148 	prod->top = b1->top+b2->top+1;
    149 	prod->sign = b1->sign*b2->sign;
    150 	mpnorm(prod);
    151 
    152 	if(oprod != nil){
    153 		mpassign(prod, oprod);
    154 		mpfree(prod);
    155 	}
    156 }