plan9port

fork of plan9port with libvec, libstr and libsdb
Log | Files | Refs | README | LICENSE

inflate.c (13119B)


      1 #include <u.h>
      2 #include <libc.h>
      3 #include <flate.h>
      4 
      5 enum {
      6 	HistorySize=	32*1024,
      7 	BufSize=	4*1024,
      8 	MaxHuffBits=	17,	/* maximum bits in a encoded code */
      9 	Nlitlen=	288,	/* number of litlen codes */
     10 	Noff=		32,	/* number of offset codes */
     11 	Nclen=		19,	/* number of codelen codes */
     12 	LenShift=	10,	/* code = len<<LenShift|code */
     13 	LitlenBits=	7,	/* number of bits in litlen decode table */
     14 	OffBits=	6,	/* number of bits in offset decode table */
     15 	ClenBits=	6,	/* number of bits in code len decode table */
     16 	MaxFlatBits=	LitlenBits,
     17 	MaxLeaf=	Nlitlen
     18 };
     19 
     20 typedef struct Input	Input;
     21 typedef struct History	History;
     22 typedef struct Huff	Huff;
     23 
     24 struct Input
     25 {
     26 	int	error;		/* first error encountered, or FlateOk */
     27 	void	*wr;
     28 	int	(*w)(void*, void*, int);
     29 	void	*getr;
     30 	int	(*get)(void*);
     31 	ulong	sreg;
     32 	int	nbits;
     33 };
     34 
     35 struct History
     36 {
     37 	uchar	his[HistorySize];
     38 	uchar	*cp;		/* current pointer in history */
     39 	int	full;		/* his has been filled up at least once */
     40 };
     41 
     42 struct Huff
     43 {
     44 	int	maxbits;	/* max bits for any code */
     45 	int	minbits;	/* min bits to get before looking in flat */
     46 	int	flatmask;	/* bits used in "flat" fast decoding table */
     47 	ulong	flat[1<<MaxFlatBits];
     48 	ulong	maxcode[MaxHuffBits];
     49 	ulong	last[MaxHuffBits];
     50 	ulong	decode[MaxLeaf];
     51 };
     52 
     53 /* litlen code words 257-285 extra bits */
     54 static int litlenextra[Nlitlen-257] =
     55 {
     56 /* 257 */	0, 0, 0,
     57 /* 260 */	0, 0, 0, 0, 0, 1, 1, 1, 1, 2,
     58 /* 270 */	2, 2, 2, 3, 3, 3, 3, 4, 4, 4,
     59 /* 280 */	4, 5, 5, 5, 5, 0, 0, 0
     60 };
     61 
     62 static int litlenbase[Nlitlen-257];
     63 
     64 /* offset code word extra bits */
     65 static int offextra[Noff] =
     66 {
     67 	0,  0,  0,  0,  1,  1,  2,  2,  3,  3,
     68 	4,  4,  5,  5,  6,  6,  7,  7,  8,  8,
     69 	9,  9,  10, 10, 11, 11, 12, 12, 13, 13,
     70 	0,  0,
     71 };
     72 static int offbase[Noff];
     73 
     74 /* order code lengths */
     75 static int clenorder[Nclen] =
     76 {
     77         16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15
     78 };
     79 
     80 /* for static huffman tables */
     81 static	Huff	litlentab;
     82 static	Huff	offtab;
     83 static	uchar	revtab[256];
     84 
     85 static int	uncblock(Input *in, History*);
     86 static int	fixedblock(Input *in, History*);
     87 static int	dynamicblock(Input *in, History*);
     88 static int	sregfill(Input *in, int n);
     89 static int	sregunget(Input *in);
     90 static int	decode(Input*, History*, Huff*, Huff*);
     91 static int	hufftab(Huff*, char*, int, int);
     92 static int	hdecsym(Input *in, Huff *h, int b);
     93 
     94 int
     95 inflateinit(void)
     96 {
     97 	char *len;
     98 	int i, j, base;
     99 
    100 	/* byte reverse table */
    101 	for(i=0; i<256; i++)
    102 		for(j=0; j<8; j++)
    103 			if(i & (1<<j))
    104 				revtab[i] |= 0x80 >> j;
    105 
    106 	for(i=257,base=3; i<Nlitlen; i++) {
    107 		litlenbase[i-257] = base;
    108 		base += 1<<litlenextra[i-257];
    109 	}
    110 	/* strange table entry in spec... */
    111 	litlenbase[285-257]--;
    112 
    113 	for(i=0,base=1; i<Noff; i++) {
    114 		offbase[i] = base;
    115 		base += 1<<offextra[i];
    116 	}
    117 
    118 	len = malloc(MaxLeaf);
    119 	if(len == nil)
    120 		return FlateNoMem;
    121 
    122 	/* static Litlen bit lengths */
    123 	for(i=0; i<144; i++)
    124 		len[i] = 8;
    125 	for(i=144; i<256; i++)
    126 		len[i] = 9;
    127 	for(i=256; i<280; i++)
    128 		len[i] = 7;
    129 	for(i=280; i<Nlitlen; i++)
    130 		len[i] = 8;
    131 
    132 	if(!hufftab(&litlentab, len, Nlitlen, MaxFlatBits))
    133 		return FlateInternal;
    134 
    135 	/* static Offset bit lengths */
    136 	for(i=0; i<Noff; i++)
    137 		len[i] = 5;
    138 
    139 	if(!hufftab(&offtab, len, Noff, MaxFlatBits))
    140 		return FlateInternal;
    141 	free(len);
    142 
    143 	return FlateOk;
    144 }
    145 
    146 int
    147 inflate(void *wr, int (*w)(void*, void*, int), void *getr, int (*get)(void*))
    148 {
    149 	History *his;
    150 	Input in;
    151 	int final, type;
    152 
    153 	his = malloc(sizeof(History));
    154 	if(his == nil)
    155 		return FlateNoMem;
    156 	his->cp = his->his;
    157 	his->full = 0;
    158 	in.getr = getr;
    159 	in.get = get;
    160 	in.wr = wr;
    161 	in.w = w;
    162 	in.nbits = 0;
    163 	in.sreg = 0;
    164 	in.error = FlateOk;
    165 
    166 	do {
    167 		if(!sregfill(&in, 3))
    168 			goto bad;
    169 		final = in.sreg & 0x1;
    170 		type = (in.sreg>>1) & 0x3;
    171 		in.sreg >>= 3;
    172 		in.nbits -= 3;
    173 		switch(type) {
    174 		default:
    175 			in.error = FlateCorrupted;
    176 			goto bad;
    177 		case 0:
    178 			/* uncompressed */
    179 			if(!uncblock(&in, his))
    180 				goto bad;
    181 			break;
    182 		case 1:
    183 			/* fixed huffman */
    184 			if(!fixedblock(&in, his))
    185 				goto bad;
    186 			break;
    187 		case 2:
    188 			/* dynamic huffman */
    189 			if(!dynamicblock(&in, his))
    190 				goto bad;
    191 			break;
    192 		}
    193 	} while(!final);
    194 
    195 	if(his->cp != his->his && (*w)(wr, his->his, his->cp - his->his) != his->cp - his->his) {
    196 		in.error = FlateOutputFail;
    197 		goto bad;
    198 	}
    199 
    200 	if(!sregunget(&in))
    201 		goto bad;
    202 
    203 	free(his);
    204 	if(in.error != FlateOk)
    205 		return FlateInternal;
    206 	return FlateOk;
    207 
    208 bad:
    209 	free(his);
    210 	if(in.error == FlateOk)
    211 		return FlateInternal;
    212 	return in.error;
    213 }
    214 
    215 static int
    216 uncblock(Input *in, History *his)
    217 {
    218 	int len, nlen, c;
    219 	uchar *hs, *hp, *he;
    220 
    221 	if(!sregunget(in))
    222 		return 0;
    223 	len = (*in->get)(in->getr);
    224 	len |= (*in->get)(in->getr)<<8;
    225 	nlen = (*in->get)(in->getr);
    226 	nlen |= (*in->get)(in->getr)<<8;
    227 	if(len != (~nlen&0xffff)) {
    228 		in->error = FlateCorrupted;
    229 		return 0;
    230 	}
    231 
    232 	hp = his->cp;
    233 	hs = his->his;
    234 	he = hs + HistorySize;
    235 
    236 	while(len > 0) {
    237 		c = (*in->get)(in->getr);
    238 		if(c < 0)
    239 			return 0;
    240 		*hp++ = c;
    241 		if(hp == he) {
    242 			his->full = 1;
    243 			if((*in->w)(in->wr, hs, HistorySize) != HistorySize) {
    244 				in->error = FlateOutputFail;
    245 				return 0;
    246 			}
    247 			hp = hs;
    248 		}
    249 		len--;
    250 	}
    251 
    252 	his->cp = hp;
    253 
    254 	return 1;
    255 }
    256 
    257 static int
    258 fixedblock(Input *in, History *his)
    259 {
    260 	return decode(in, his, &litlentab, &offtab);
    261 }
    262 
    263 static int
    264 dynamicblock(Input *in, History *his)
    265 {
    266 	Huff *lentab, *offtab;
    267 	char *len;
    268 	int i, j, n, c, nlit, ndist, nclen, res, nb;
    269 
    270 	if(!sregfill(in, 14))
    271 		return 0;
    272 	nlit = (in->sreg&0x1f) + 257;
    273 	ndist = ((in->sreg>>5) & 0x1f) + 1;
    274 	nclen = ((in->sreg>>10) & 0xf) + 4;
    275 	in->sreg >>= 14;
    276 	in->nbits -= 14;
    277 
    278 	if(nlit > Nlitlen || ndist > Noff || nlit < 257) {
    279 		in->error = FlateCorrupted;
    280 		return 0;
    281 	}
    282 
    283 	/* huff table header */
    284 	len = malloc(Nlitlen+Noff);
    285 	lentab = malloc(sizeof(Huff));
    286 	offtab = malloc(sizeof(Huff));
    287 	if(len == nil || lentab == nil || offtab == nil){
    288 		in->error = FlateNoMem;
    289 		goto bad;
    290 	}
    291 	for(i=0; i < Nclen; i++)
    292 		len[i] = 0;
    293 	for(i=0; i<nclen; i++) {
    294 		if(!sregfill(in, 3))
    295 			goto bad;
    296 		len[clenorder[i]] = in->sreg & 0x7;
    297 		in->sreg >>= 3;
    298 		in->nbits -= 3;
    299 	}
    300 
    301 	if(!hufftab(lentab, len, Nclen, ClenBits)){
    302 		in->error = FlateCorrupted;
    303 		goto bad;
    304 	}
    305 
    306 	n = nlit+ndist;
    307 	for(i=0; i<n;) {
    308 		nb = lentab->minbits;
    309 		for(;;){
    310 			if(in->nbits<nb && !sregfill(in, nb))
    311 				goto bad;
    312 			c = lentab->flat[in->sreg & lentab->flatmask];
    313 			nb = c & 0xff;
    314 			if(nb > in->nbits){
    315 				if(nb != 0xff)
    316 					continue;
    317 				c = hdecsym(in, lentab, c);
    318 				if(c < 0)
    319 					goto bad;
    320 			}else{
    321 				c >>= 8;
    322 				in->sreg >>= nb;
    323 				in->nbits -= nb;
    324 			}
    325 			break;
    326 		}
    327 
    328 		if(c < 16) {
    329 			j = 1;
    330 		} else if(c == 16) {
    331 			if(in->nbits<2 && !sregfill(in, 2))
    332 				goto bad;
    333 			j = (in->sreg&0x3)+3;
    334 			in->sreg >>= 2;
    335 			in->nbits -= 2;
    336 			if(i == 0) {
    337 				in->error = FlateCorrupted;
    338 				goto bad;
    339 			}
    340 			c = len[i-1];
    341 		} else if(c == 17) {
    342 			if(in->nbits<3 && !sregfill(in, 3))
    343 				goto bad;
    344 			j = (in->sreg&0x7)+3;
    345 			in->sreg >>= 3;
    346 			in->nbits -= 3;
    347 			c = 0;
    348 		} else if(c == 18) {
    349 			if(in->nbits<7 && !sregfill(in, 7))
    350 				goto bad;
    351 			j = (in->sreg&0x7f)+11;
    352 			in->sreg >>= 7;
    353 			in->nbits -= 7;
    354 			c = 0;
    355 		} else {
    356 			in->error = FlateCorrupted;
    357 			goto bad;
    358 		}
    359 
    360 		if(i+j > n) {
    361 			in->error = FlateCorrupted;
    362 			goto bad;
    363 		}
    364 
    365 		while(j) {
    366 			len[i] = c;
    367 			i++;
    368 			j--;
    369 		}
    370 	}
    371 
    372 	if(!hufftab(lentab, len, nlit, LitlenBits)
    373 	|| !hufftab(offtab, &len[nlit], ndist, OffBits)){
    374 		in->error = FlateCorrupted;
    375 		goto bad;
    376 	}
    377 
    378 	res = decode(in, his, lentab, offtab);
    379 
    380 	free(len);
    381 	free(lentab);
    382 	free(offtab);
    383 
    384 	return res;
    385 
    386 bad:
    387 	free(len);
    388 	free(lentab);
    389 	free(offtab);
    390 	return 0;
    391 }
    392 
    393 static int
    394 decode(Input *in, History *his, Huff *litlentab, Huff *offtab)
    395 {
    396 	int len, off;
    397 	uchar *hs, *hp, *hq, *he;
    398 	int c;
    399 	int nb;
    400 
    401 	hs = his->his;
    402 	he = hs + HistorySize;
    403 	hp = his->cp;
    404 
    405 	for(;;) {
    406 		nb = litlentab->minbits;
    407 		for(;;){
    408 			if(in->nbits<nb && !sregfill(in, nb))
    409 				return 0;
    410 			c = litlentab->flat[in->sreg & litlentab->flatmask];
    411 			nb = c & 0xff;
    412 			if(nb > in->nbits){
    413 				if(nb != 0xff)
    414 					continue;
    415 				c = hdecsym(in, litlentab, c);
    416 				if(c < 0)
    417 					return 0;
    418 			}else{
    419 				c >>= 8;
    420 				in->sreg >>= nb;
    421 				in->nbits -= nb;
    422 			}
    423 			break;
    424 		}
    425 
    426 		if(c < 256) {
    427 			/* literal */
    428 			*hp++ = c;
    429 			if(hp == he) {
    430 				his->full = 1;
    431 				if((*in->w)(in->wr, hs, HistorySize) != HistorySize) {
    432 					in->error = FlateOutputFail;
    433 					return 0;
    434 				}
    435 				hp = hs;
    436 			}
    437 			continue;
    438 		}
    439 
    440 		if(c == 256)
    441 			break;
    442 
    443 		if(c > 285) {
    444 			in->error = FlateCorrupted;
    445 			return 0;
    446 		}
    447 
    448 		c -= 257;
    449 		nb = litlenextra[c];
    450 		if(in->nbits < nb && !sregfill(in, nb))
    451 			return 0;
    452 		len = litlenbase[c] + (in->sreg & ((1<<nb)-1));
    453 		in->sreg >>= nb;
    454 		in->nbits -= nb;
    455 
    456 		/* get offset */
    457 		nb = offtab->minbits;
    458 		for(;;){
    459 			if(in->nbits<nb && !sregfill(in, nb))
    460 				return 0;
    461 			c = offtab->flat[in->sreg & offtab->flatmask];
    462 			nb = c & 0xff;
    463 			if(nb > in->nbits){
    464 				if(nb != 0xff)
    465 					continue;
    466 				c = hdecsym(in, offtab, c);
    467 				if(c < 0)
    468 					return 0;
    469 			}else{
    470 				c >>= 8;
    471 				in->sreg >>= nb;
    472 				in->nbits -= nb;
    473 			}
    474 			break;
    475 		}
    476 
    477 		if(c > 29) {
    478 			in->error = FlateCorrupted;
    479 			return 0;
    480 		}
    481 
    482 		nb = offextra[c];
    483 		if(in->nbits < nb && !sregfill(in, nb))
    484 			return 0;
    485 
    486 		off = offbase[c] + (in->sreg & ((1<<nb)-1));
    487 		in->sreg >>= nb;
    488 		in->nbits -= nb;
    489 
    490 		hq = hp - off;
    491 		if(hq < hs) {
    492 			if(!his->full) {
    493 				in->error = FlateCorrupted;
    494 				return 0;
    495 			}
    496 			hq += HistorySize;
    497 		}
    498 
    499 		/* slow but correct */
    500 		while(len) {
    501 			*hp = *hq;
    502 			hq++;
    503 			hp++;
    504 			if(hq >= he)
    505 				hq = hs;
    506 			if(hp == he) {
    507 				his->full = 1;
    508 				if((*in->w)(in->wr, hs, HistorySize) != HistorySize) {
    509 					in->error = FlateOutputFail;
    510 					return 0;
    511 				}
    512 				hp = hs;
    513 			}
    514 			len--;
    515 		}
    516 
    517 	}
    518 
    519 	his->cp = hp;
    520 
    521 	return 1;
    522 }
    523 
    524 static int
    525 revcode(int c, int b)
    526 {
    527 	/* shift encode up so it starts on bit 15 then reverse */
    528 	c <<= (16-b);
    529 	c = revtab[c>>8] | (revtab[c&0xff]<<8);
    530 	return c;
    531 }
    532 
    533 /*
    534  * construct the huffman decoding arrays and a fast lookup table.
    535  * the fast lookup is a table indexed by the next flatbits bits,
    536  * which returns the symbol matched and the number of bits consumed,
    537  * or the minimum number of bits needed and 0xff if more than flatbits
    538  * bits are needed.
    539  *
    540  * flatbits can be longer than the smallest huffman code,
    541  * because shorter codes are assigned smaller lexical prefixes.
    542  * this means assuming zeros for the next few bits will give a
    543  * conservative answer, in the sense that it will either give the
    544  * correct answer, or return the minimum number of bits which
    545  * are needed for an answer.
    546  */
    547 static int
    548 hufftab(Huff *h, char *hb, int maxleaf, int flatbits)
    549 {
    550 	ulong bitcount[MaxHuffBits];
    551 	ulong c, fc, ec, mincode, code, nc[MaxHuffBits];
    552 	int i, b, minbits, maxbits;
    553 
    554 	for(i = 0; i < MaxHuffBits; i++)
    555 		bitcount[i] = 0;
    556 	maxbits = -1;
    557 	minbits = MaxHuffBits + 1;
    558 	for(i=0; i < maxleaf; i++){
    559 		b = hb[i];
    560 		if(b){
    561 			bitcount[b]++;
    562 			if(b < minbits)
    563 				minbits = b;
    564 			if(b > maxbits)
    565 				maxbits = b;
    566 		}
    567 	}
    568 
    569 	h->maxbits = maxbits;
    570 	if(maxbits <= 0){
    571 		h->maxbits = 0;
    572 		h->minbits = 0;
    573 		h->flatmask = 0;
    574 		return 1;
    575 	}
    576 	code = 0;
    577 	c = 0;
    578 	for(b = 0; b <= maxbits; b++){
    579 		h->last[b] = c;
    580 		c += bitcount[b];
    581 		mincode = code << 1;
    582 		nc[b] = mincode;
    583 		code = mincode + bitcount[b];
    584 		if(code > (1 << b))
    585 			return 0;
    586 		h->maxcode[b] = code - 1;
    587 		h->last[b] += code - 1;
    588 	}
    589 
    590 	if(flatbits > maxbits)
    591 		flatbits = maxbits;
    592 	h->flatmask = (1 << flatbits) - 1;
    593 	if(minbits > flatbits)
    594 		minbits = flatbits;
    595 	h->minbits = minbits;
    596 
    597 	b = 1 << flatbits;
    598 	for(i = 0; i < b; i++)
    599 		h->flat[i] = ~0;
    600 
    601 	/*
    602 	 * initialize the flat table to include the minimum possible
    603 	 * bit length for each code prefix
    604 	 */
    605 	for(b = maxbits; b > flatbits; b--){
    606 		code = h->maxcode[b];
    607 		if(code == -1)
    608 			break;
    609 		mincode = code + 1 - bitcount[b];
    610 		mincode >>= b - flatbits;
    611 		code >>= b - flatbits;
    612 		for(; mincode <= code; mincode++)
    613 			h->flat[revcode(mincode, flatbits)] = (b << 8) | 0xff;
    614 	}
    615 
    616 	for(i = 0; i < maxleaf; i++){
    617 		b = hb[i];
    618 		if(b <= 0)
    619 			continue;
    620 		c = nc[b]++;
    621 		if(b <= flatbits){
    622 			code = (i << 8) | b;
    623 			ec = (c + 1) << (flatbits - b);
    624 			if(ec > (1<<flatbits))
    625 				return 0;	/* this is actually an internal error */
    626 			for(fc = c << (flatbits - b); fc < ec; fc++)
    627 				h->flat[revcode(fc, flatbits)] = code;
    628 		}
    629 		if(b > minbits){
    630 			c = h->last[b] - c;
    631 			if(c >= maxleaf)
    632 				return 0;
    633 			h->decode[c] = i;
    634 		}
    635 	}
    636 	return 1;
    637 }
    638 
    639 static int
    640 hdecsym(Input *in, Huff *h, int nb)
    641 {
    642 	long c;
    643 
    644 	if((nb & 0xff) == 0xff)
    645 		nb = nb >> 8;
    646 	else
    647 		nb = nb & 0xff;
    648 	for(; nb <= h->maxbits; nb++){
    649 		if(in->nbits<nb && !sregfill(in, nb))
    650 			return -1;
    651 		c = revtab[in->sreg&0xff]<<8;
    652 		c |= revtab[(in->sreg>>8)&0xff];
    653 		c >>= (16-nb);
    654 		if(c <= h->maxcode[nb]){
    655 			in->sreg >>= nb;
    656 			in->nbits -= nb;
    657 			return h->decode[h->last[nb] - c];
    658 		}
    659 	}
    660 	in->error = FlateCorrupted;
    661 	return -1;
    662 }
    663 
    664 static int
    665 sregfill(Input *in, int n)
    666 {
    667 	int c;
    668 
    669 	while(n > in->nbits) {
    670 		c = (*in->get)(in->getr);
    671 		if(c < 0){
    672 			in->error = FlateInputFail;
    673 			return 0;
    674 		}
    675 		in->sreg |= c<<in->nbits;
    676 		in->nbits += 8;
    677 	}
    678 	return 1;
    679 }
    680 
    681 static int
    682 sregunget(Input *in)
    683 {
    684 	if(in->nbits >= 8) {
    685 		in->error = FlateInternal;
    686 		return 0;
    687 	}
    688 
    689 	/* throw other bits on the floor */
    690 	in->nbits = 0;
    691 	in->sreg = 0;
    692 	return 1;
    693 }