/* inflate.c */
1 #include <u.h> 2 #include <libc.h> 3 #include <flate.h> 4 5 enum { 6 HistorySize= 32*1024, 7 BufSize= 4*1024, 8 MaxHuffBits= 17, /* maximum bits in a encoded code */ 9 Nlitlen= 288, /* number of litlen codes */ 10 Noff= 32, /* number of offset codes */ 11 Nclen= 19, /* number of codelen codes */ 12 LenShift= 10, /* code = len<<LenShift|code */ 13 LitlenBits= 7, /* number of bits in litlen decode table */ 14 OffBits= 6, /* number of bits in offset decode table */ 15 ClenBits= 6, /* number of bits in code len decode table */ 16 MaxFlatBits= LitlenBits, 17 MaxLeaf= Nlitlen 18 }; 19 20 typedef struct Input Input; 21 typedef struct History History; 22 typedef struct Huff Huff; 23 24 struct Input 25 { 26 int error; /* first error encountered, or FlateOk */ 27 void *wr; 28 int (*w)(void*, void*, int); 29 void *getr; 30 int (*get)(void*); 31 ulong sreg; 32 int nbits; 33 }; 34 35 struct History 36 { 37 uchar his[HistorySize]; 38 uchar *cp; /* current pointer in history */ 39 int full; /* his has been filled up at least once */ 40 }; 41 42 struct Huff 43 { 44 int maxbits; /* max bits for any code */ 45 int minbits; /* min bits to get before looking in flat */ 46 int flatmask; /* bits used in "flat" fast decoding table */ 47 ulong flat[1<<MaxFlatBits]; 48 ulong maxcode[MaxHuffBits]; 49 ulong last[MaxHuffBits]; 50 ulong decode[MaxLeaf]; 51 }; 52 53 /* litlen code words 257-285 extra bits */ 54 static int litlenextra[Nlitlen-257] = 55 { 56 /* 257 */ 0, 0, 0, 57 /* 260 */ 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 58 /* 270 */ 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 59 /* 280 */ 4, 5, 5, 5, 5, 0, 0, 0 60 }; 61 62 static int litlenbase[Nlitlen-257]; 63 64 /* offset code word extra bits */ 65 static int offextra[Noff] = 66 { 67 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 68 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 69 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, 70 0, 0, 71 }; 72 static int offbase[Noff]; 73 74 /* order code lengths */ 75 static int clenorder[Nclen] = 76 { 77 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 
1, 15 78 }; 79 80 /* for static huffman tables */ 81 static Huff litlentab; 82 static Huff offtab; 83 static uchar revtab[256]; 84 85 static int uncblock(Input *in, History*); 86 static int fixedblock(Input *in, History*); 87 static int dynamicblock(Input *in, History*); 88 static int sregfill(Input *in, int n); 89 static int sregunget(Input *in); 90 static int decode(Input*, History*, Huff*, Huff*); 91 static int hufftab(Huff*, char*, int, int); 92 static int hdecsym(Input *in, Huff *h, int b); 93 94 int 95 inflateinit(void) 96 { 97 char *len; 98 int i, j, base; 99 100 /* byte reverse table */ 101 for(i=0; i<256; i++) 102 for(j=0; j<8; j++) 103 if(i & (1<<j)) 104 revtab[i] |= 0x80 >> j; 105 106 for(i=257,base=3; i<Nlitlen; i++) { 107 litlenbase[i-257] = base; 108 base += 1<<litlenextra[i-257]; 109 } 110 /* strange table entry in spec... */ 111 litlenbase[285-257]--; 112 113 for(i=0,base=1; i<Noff; i++) { 114 offbase[i] = base; 115 base += 1<<offextra[i]; 116 } 117 118 len = malloc(MaxLeaf); 119 if(len == nil) 120 return FlateNoMem; 121 122 /* static Litlen bit lengths */ 123 for(i=0; i<144; i++) 124 len[i] = 8; 125 for(i=144; i<256; i++) 126 len[i] = 9; 127 for(i=256; i<280; i++) 128 len[i] = 7; 129 for(i=280; i<Nlitlen; i++) 130 len[i] = 8; 131 132 if(!hufftab(&litlentab, len, Nlitlen, MaxFlatBits)) 133 return FlateInternal; 134 135 /* static Offset bit lengths */ 136 for(i=0; i<Noff; i++) 137 len[i] = 5; 138 139 if(!hufftab(&offtab, len, Noff, MaxFlatBits)) 140 return FlateInternal; 141 free(len); 142 143 return FlateOk; 144 } 145 146 int 147 inflate(void *wr, int (*w)(void*, void*, int), void *getr, int (*get)(void*)) 148 { 149 History *his; 150 Input in; 151 int final, type; 152 153 his = malloc(sizeof(History)); 154 if(his == nil) 155 return FlateNoMem; 156 his->cp = his->his; 157 his->full = 0; 158 in.getr = getr; 159 in.get = get; 160 in.wr = wr; 161 in.w = w; 162 in.nbits = 0; 163 in.sreg = 0; 164 in.error = FlateOk; 165 166 do { 167 if(!sregfill(&in, 3)) 
168 goto bad; 169 final = in.sreg & 0x1; 170 type = (in.sreg>>1) & 0x3; 171 in.sreg >>= 3; 172 in.nbits -= 3; 173 switch(type) { 174 default: 175 in.error = FlateCorrupted; 176 goto bad; 177 case 0: 178 /* uncompressed */ 179 if(!uncblock(&in, his)) 180 goto bad; 181 break; 182 case 1: 183 /* fixed huffman */ 184 if(!fixedblock(&in, his)) 185 goto bad; 186 break; 187 case 2: 188 /* dynamic huffman */ 189 if(!dynamicblock(&in, his)) 190 goto bad; 191 break; 192 } 193 } while(!final); 194 195 if(his->cp != his->his && (*w)(wr, his->his, his->cp - his->his) != his->cp - his->his) { 196 in.error = FlateOutputFail; 197 goto bad; 198 } 199 200 if(!sregunget(&in)) 201 goto bad; 202 203 free(his); 204 if(in.error != FlateOk) 205 return FlateInternal; 206 return FlateOk; 207 208 bad: 209 free(his); 210 if(in.error == FlateOk) 211 return FlateInternal; 212 return in.error; 213 } 214 215 static int 216 uncblock(Input *in, History *his) 217 { 218 int len, nlen, c; 219 uchar *hs, *hp, *he; 220 221 if(!sregunget(in)) 222 return 0; 223 len = (*in->get)(in->getr); 224 len |= (*in->get)(in->getr)<<8; 225 nlen = (*in->get)(in->getr); 226 nlen |= (*in->get)(in->getr)<<8; 227 if(len != (~nlen&0xffff)) { 228 in->error = FlateCorrupted; 229 return 0; 230 } 231 232 hp = his->cp; 233 hs = his->his; 234 he = hs + HistorySize; 235 236 while(len > 0) { 237 c = (*in->get)(in->getr); 238 if(c < 0) 239 return 0; 240 *hp++ = c; 241 if(hp == he) { 242 his->full = 1; 243 if((*in->w)(in->wr, hs, HistorySize) != HistorySize) { 244 in->error = FlateOutputFail; 245 return 0; 246 } 247 hp = hs; 248 } 249 len--; 250 } 251 252 his->cp = hp; 253 254 return 1; 255 } 256 257 static int 258 fixedblock(Input *in, History *his) 259 { 260 return decode(in, his, &litlentab, &offtab); 261 } 262 263 static int 264 dynamicblock(Input *in, History *his) 265 { 266 Huff *lentab, *offtab; 267 char *len; 268 int i, j, n, c, nlit, ndist, nclen, res, nb; 269 270 if(!sregfill(in, 14)) 271 return 0; 272 nlit = 
(in->sreg&0x1f) + 257; 273 ndist = ((in->sreg>>5) & 0x1f) + 1; 274 nclen = ((in->sreg>>10) & 0xf) + 4; 275 in->sreg >>= 14; 276 in->nbits -= 14; 277 278 if(nlit > Nlitlen || ndist > Noff || nlit < 257) { 279 in->error = FlateCorrupted; 280 return 0; 281 } 282 283 /* huff table header */ 284 len = malloc(Nlitlen+Noff); 285 lentab = malloc(sizeof(Huff)); 286 offtab = malloc(sizeof(Huff)); 287 if(len == nil || lentab == nil || offtab == nil){ 288 in->error = FlateNoMem; 289 goto bad; 290 } 291 for(i=0; i < Nclen; i++) 292 len[i] = 0; 293 for(i=0; i<nclen; i++) { 294 if(!sregfill(in, 3)) 295 goto bad; 296 len[clenorder[i]] = in->sreg & 0x7; 297 in->sreg >>= 3; 298 in->nbits -= 3; 299 } 300 301 if(!hufftab(lentab, len, Nclen, ClenBits)){ 302 in->error = FlateCorrupted; 303 goto bad; 304 } 305 306 n = nlit+ndist; 307 for(i=0; i<n;) { 308 nb = lentab->minbits; 309 for(;;){ 310 if(in->nbits<nb && !sregfill(in, nb)) 311 goto bad; 312 c = lentab->flat[in->sreg & lentab->flatmask]; 313 nb = c & 0xff; 314 if(nb > in->nbits){ 315 if(nb != 0xff) 316 continue; 317 c = hdecsym(in, lentab, c); 318 if(c < 0) 319 goto bad; 320 }else{ 321 c >>= 8; 322 in->sreg >>= nb; 323 in->nbits -= nb; 324 } 325 break; 326 } 327 328 if(c < 16) { 329 j = 1; 330 } else if(c == 16) { 331 if(in->nbits<2 && !sregfill(in, 2)) 332 goto bad; 333 j = (in->sreg&0x3)+3; 334 in->sreg >>= 2; 335 in->nbits -= 2; 336 if(i == 0) { 337 in->error = FlateCorrupted; 338 goto bad; 339 } 340 c = len[i-1]; 341 } else if(c == 17) { 342 if(in->nbits<3 && !sregfill(in, 3)) 343 goto bad; 344 j = (in->sreg&0x7)+3; 345 in->sreg >>= 3; 346 in->nbits -= 3; 347 c = 0; 348 } else if(c == 18) { 349 if(in->nbits<7 && !sregfill(in, 7)) 350 goto bad; 351 j = (in->sreg&0x7f)+11; 352 in->sreg >>= 7; 353 in->nbits -= 7; 354 c = 0; 355 } else { 356 in->error = FlateCorrupted; 357 goto bad; 358 } 359 360 if(i+j > n) { 361 in->error = FlateCorrupted; 362 goto bad; 363 } 364 365 while(j) { 366 len[i] = c; 367 i++; 368 j--; 369 } 370 } 371 
372 if(!hufftab(lentab, len, nlit, LitlenBits) 373 || !hufftab(offtab, &len[nlit], ndist, OffBits)){ 374 in->error = FlateCorrupted; 375 goto bad; 376 } 377 378 res = decode(in, his, lentab, offtab); 379 380 free(len); 381 free(lentab); 382 free(offtab); 383 384 return res; 385 386 bad: 387 free(len); 388 free(lentab); 389 free(offtab); 390 return 0; 391 } 392 393 static int 394 decode(Input *in, History *his, Huff *litlentab, Huff *offtab) 395 { 396 int len, off; 397 uchar *hs, *hp, *hq, *he; 398 int c; 399 int nb; 400 401 hs = his->his; 402 he = hs + HistorySize; 403 hp = his->cp; 404 405 for(;;) { 406 nb = litlentab->minbits; 407 for(;;){ 408 if(in->nbits<nb && !sregfill(in, nb)) 409 return 0; 410 c = litlentab->flat[in->sreg & litlentab->flatmask]; 411 nb = c & 0xff; 412 if(nb > in->nbits){ 413 if(nb != 0xff) 414 continue; 415 c = hdecsym(in, litlentab, c); 416 if(c < 0) 417 return 0; 418 }else{ 419 c >>= 8; 420 in->sreg >>= nb; 421 in->nbits -= nb; 422 } 423 break; 424 } 425 426 if(c < 256) { 427 /* literal */ 428 *hp++ = c; 429 if(hp == he) { 430 his->full = 1; 431 if((*in->w)(in->wr, hs, HistorySize) != HistorySize) { 432 in->error = FlateOutputFail; 433 return 0; 434 } 435 hp = hs; 436 } 437 continue; 438 } 439 440 if(c == 256) 441 break; 442 443 if(c > 285) { 444 in->error = FlateCorrupted; 445 return 0; 446 } 447 448 c -= 257; 449 nb = litlenextra[c]; 450 if(in->nbits < nb && !sregfill(in, nb)) 451 return 0; 452 len = litlenbase[c] + (in->sreg & ((1<<nb)-1)); 453 in->sreg >>= nb; 454 in->nbits -= nb; 455 456 /* get offset */ 457 nb = offtab->minbits; 458 for(;;){ 459 if(in->nbits<nb && !sregfill(in, nb)) 460 return 0; 461 c = offtab->flat[in->sreg & offtab->flatmask]; 462 nb = c & 0xff; 463 if(nb > in->nbits){ 464 if(nb != 0xff) 465 continue; 466 c = hdecsym(in, offtab, c); 467 if(c < 0) 468 return 0; 469 }else{ 470 c >>= 8; 471 in->sreg >>= nb; 472 in->nbits -= nb; 473 } 474 break; 475 } 476 477 if(c > 29) { 478 in->error = FlateCorrupted; 479 return 0; 
480 } 481 482 nb = offextra[c]; 483 if(in->nbits < nb && !sregfill(in, nb)) 484 return 0; 485 486 off = offbase[c] + (in->sreg & ((1<<nb)-1)); 487 in->sreg >>= nb; 488 in->nbits -= nb; 489 490 hq = hp - off; 491 if(hq < hs) { 492 if(!his->full) { 493 in->error = FlateCorrupted; 494 return 0; 495 } 496 hq += HistorySize; 497 } 498 499 /* slow but correct */ 500 while(len) { 501 *hp = *hq; 502 hq++; 503 hp++; 504 if(hq >= he) 505 hq = hs; 506 if(hp == he) { 507 his->full = 1; 508 if((*in->w)(in->wr, hs, HistorySize) != HistorySize) { 509 in->error = FlateOutputFail; 510 return 0; 511 } 512 hp = hs; 513 } 514 len--; 515 } 516 517 } 518 519 his->cp = hp; 520 521 return 1; 522 } 523 524 static int 525 revcode(int c, int b) 526 { 527 /* shift encode up so it starts on bit 15 then reverse */ 528 c <<= (16-b); 529 c = revtab[c>>8] | (revtab[c&0xff]<<8); 530 return c; 531 } 532 533 /* 534 * construct the huffman decoding arrays and a fast lookup table. 535 * the fast lookup is a table indexed by the next flatbits bits, 536 * which returns the symbol matched and the number of bits consumed, 537 * or the minimum number of bits needed and 0xff if more than flatbits 538 * bits are needed. 539 * 540 * flatbits can be longer than the smallest huffman code, 541 * because shorter codes are assigned smaller lexical prefixes. 542 * this means assuming zeros for the next few bits will give a 543 * conservative answer, in the sense that it will either give the 544 * correct answer, or return the minimum number of bits which 545 * are needed for an answer. 
546 */ 547 static int 548 hufftab(Huff *h, char *hb, int maxleaf, int flatbits) 549 { 550 ulong bitcount[MaxHuffBits]; 551 ulong c, fc, ec, mincode, code, nc[MaxHuffBits]; 552 int i, b, minbits, maxbits; 553 554 for(i = 0; i < MaxHuffBits; i++) 555 bitcount[i] = 0; 556 maxbits = -1; 557 minbits = MaxHuffBits + 1; 558 for(i=0; i < maxleaf; i++){ 559 b = hb[i]; 560 if(b){ 561 bitcount[b]++; 562 if(b < minbits) 563 minbits = b; 564 if(b > maxbits) 565 maxbits = b; 566 } 567 } 568 569 h->maxbits = maxbits; 570 if(maxbits <= 0){ 571 h->maxbits = 0; 572 h->minbits = 0; 573 h->flatmask = 0; 574 return 1; 575 } 576 code = 0; 577 c = 0; 578 for(b = 0; b <= maxbits; b++){ 579 h->last[b] = c; 580 c += bitcount[b]; 581 mincode = code << 1; 582 nc[b] = mincode; 583 code = mincode + bitcount[b]; 584 if(code > (1 << b)) 585 return 0; 586 h->maxcode[b] = code - 1; 587 h->last[b] += code - 1; 588 } 589 590 if(flatbits > maxbits) 591 flatbits = maxbits; 592 h->flatmask = (1 << flatbits) - 1; 593 if(minbits > flatbits) 594 minbits = flatbits; 595 h->minbits = minbits; 596 597 b = 1 << flatbits; 598 for(i = 0; i < b; i++) 599 h->flat[i] = ~0; 600 601 /* 602 * initialize the flat table to include the minimum possible 603 * bit length for each code prefix 604 */ 605 for(b = maxbits; b > flatbits; b--){ 606 code = h->maxcode[b]; 607 if(code == -1) 608 break; 609 mincode = code + 1 - bitcount[b]; 610 mincode >>= b - flatbits; 611 code >>= b - flatbits; 612 for(; mincode <= code; mincode++) 613 h->flat[revcode(mincode, flatbits)] = (b << 8) | 0xff; 614 } 615 616 for(i = 0; i < maxleaf; i++){ 617 b = hb[i]; 618 if(b <= 0) 619 continue; 620 c = nc[b]++; 621 if(b <= flatbits){ 622 code = (i << 8) | b; 623 ec = (c + 1) << (flatbits - b); 624 if(ec > (1<<flatbits)) 625 return 0; /* this is actually an internal error */ 626 for(fc = c << (flatbits - b); fc < ec; fc++) 627 h->flat[revcode(fc, flatbits)] = code; 628 } 629 if(b > minbits){ 630 c = h->last[b] - c; 631 if(c >= maxleaf) 632 return 0; 
633 h->decode[c] = i; 634 } 635 } 636 return 1; 637 } 638 639 static int 640 hdecsym(Input *in, Huff *h, int nb) 641 { 642 long c; 643 644 if((nb & 0xff) == 0xff) 645 nb = nb >> 8; 646 else 647 nb = nb & 0xff; 648 for(; nb <= h->maxbits; nb++){ 649 if(in->nbits<nb && !sregfill(in, nb)) 650 return -1; 651 c = revtab[in->sreg&0xff]<<8; 652 c |= revtab[(in->sreg>>8)&0xff]; 653 c >>= (16-nb); 654 if(c <= h->maxcode[nb]){ 655 in->sreg >>= nb; 656 in->nbits -= nb; 657 return h->decode[h->last[nb] - c]; 658 } 659 } 660 in->error = FlateCorrupted; 661 return -1; 662 } 663 664 static int 665 sregfill(Input *in, int n) 666 { 667 int c; 668 669 while(n > in->nbits) { 670 c = (*in->get)(in->getr); 671 if(c < 0){ 672 in->error = FlateInputFail; 673 return 0; 674 } 675 in->sreg |= c<<in->nbits; 676 in->nbits += 8; 677 } 678 return 1; 679 } 680 681 static int 682 sregunget(Input *in) 683 { 684 if(in->nbits >= 8) { 685 in->error = FlateInternal; 686 return 0; 687 } 688 689 /* throw other bits on the floor */ 690 in->nbits = 0; 691 in->sreg = 0; 692 return 1; 693 }