plan9port

fork of plan9port with libvec, libstr and libsdb
Log | Files | Refs | README | LICENSE

bayes.c (3822B)


      1 #include <u.h>
      2 #include <libc.h>
      3 #include <bio.h>
      4 #include <regexp.h>
      5 #include "hash.h"
      6 
      7 enum
      8 {
      9 	MAXTAB = 256,
     10 	MAXBEST = 32
     11 };
     12 
     13 typedef struct Table Table;
     14 struct Table
     15 {
     16 	char *file;
     17 	Hash *hash;
     18 	int nmsg;
     19 };
     20 
     21 typedef struct Word Word;
     22 struct Word
     23 {
     24 	Stringtab *s;	/* from hmsg */
     25 	int count[MAXTAB];	/* counts from each table */
     26 	double p[MAXTAB];	/* probabilities from each table */
     27 	double mp;	/* max probability */
     28 	int mi;		/* w.p[w.mi] = w.mp */
     29 };
     30 
     31 Table tab[MAXTAB];
     32 int ntab;
     33 
     34 Word best[MAXBEST];
     35 int mbest;
     36 int nbest;
     37 
     38 int debug;
     39 
     40 void
     41 usage(void)
     42 {
     43 	fprint(2, "usage: bayes [-D] [-m maxword] boxhash ... ~ msghash ...\n");
     44 	exits("usage");
     45 }
     46 
     47 void*
     48 emalloc(int n)
     49 {
     50 	void *v;
     51 
     52 	v = mallocz(n, 1);
     53 	if(v == nil)
     54 		sysfatal("out of memory");
     55 	return v;
     56 }
     57 
     58 void
     59 noteword(Word *w)
     60 {
     61 	int i;
     62 
     63 	for(i=nbest-1; i>=0; i--)
     64 		if(w->mp < best[i].mp)
     65 			break;
     66 	i++;
     67 
     68 	if(i >= mbest)
     69 		return;
     70 	if(nbest == mbest)
     71 		nbest--;
     72 	if(i < nbest)
     73 		memmove(&best[i+1], &best[i], (nbest-i)*sizeof(best[0]));
     74 	best[i] = *w;
     75 	nbest++;
     76 }
     77 
     78 Hash*
     79 hread(char *s)
     80 {
     81 	Hash *h;
     82 	Biobuf *b;
     83 
     84 	if((b = Bopenlock(s, OREAD)) == nil)
     85 		sysfatal("open %s: %r", s);
     86 
     87 	h = emalloc(sizeof(Hash));
     88 	Breadhash(b, h, 1);
     89 	Bterm(b);
     90 	return h;
     91 }
     92 
     93 void
     94 main(int argc, char **argv)
     95 {
     96 	int i, j, a, mi, oi, tot, keywords;
     97 	double totp, p, xp[MAXTAB];
     98 	Hash *hmsg;
     99 	Word w;
    100 	Stringtab *s, *t;
    101 	Biobuf bout;
    102 
    103 	mbest = 15;
    104 	keywords = 0;
    105 	ARGBEGIN{
    106 	case 'D':
    107 		debug = 1;
    108 		break;
    109 	case 'k':
    110 		keywords = 1;
    111 		break;
    112 	case 'm':
    113 		mbest = atoi(EARGF(usage()));
    114 		if(mbest > MAXBEST)
    115 			sysfatal("cannot keep more than %d words", MAXBEST);
    116 		break;
    117 	default:
    118 		usage();
    119 	}ARGEND
    120 
    121 	for(i=0; i<argc; i++)
    122 		if(strcmp(argv[i], "~") == 0)
    123 			break;
    124 
    125 	if(i > MAXTAB)
    126 		sysfatal("cannot handle more than %d tables", MAXTAB);
    127 
    128 	if(i+1 >= argc)
    129 		usage();
    130 
    131 	for(i=0; i<argc; i++){
    132 		if(strcmp(argv[i], "~") == 0)
    133 			break;
    134 		tab[ntab].file = argv[i];
    135 		tab[ntab].hash = hread(argv[i]);
    136 		s = findstab(tab[ntab].hash, "*nmsg*", 6, 1);
    137 		if(s == nil || s->count == 0)
    138 			tab[ntab].nmsg = 1;
    139 		else
    140 			tab[ntab].nmsg = s->count;
    141 		ntab++;
    142 	}
    143 
    144 	Binit(&bout, 1, OWRITE);
    145 
    146 	oi = ++i;
    147 	for(a=i; a<argc; a++){
    148 		hmsg = hread(argv[a]);
    149 		nbest = 0;
    150 		for(s=hmsg->all; s; s=s->link){
    151 			w.s = s;
    152 			tot = 0;
    153 			totp = 0.0;
    154 			for(i=0; i<ntab; i++){
    155 				t = findstab(tab[i].hash, s->str, s->n, 0);
    156 				if(t == nil)
    157 					w.count[i] = 0;
    158 				else
    159 					w.count[i] = t->count;
    160 				tot += w.count[i];
    161 				p = w.count[i]/(double)tab[i].nmsg;
    162 				if(p >= 1.0)
    163 					p = 1.0;
    164 				w.p[i] = p;
    165 				totp += p;
    166 			}
    167 
    168 			if(tot < 5){		/* word does not appear enough; give to box 0 */
    169 				w.p[0] = 0.5;
    170 				for(i=1; i<ntab; i++)
    171 					w.p[i] = 0.1;
    172 				w.mp = 0.5;
    173 				w.mi = 0;
    174 				noteword(&w);
    175 				continue;
    176 			}
    177 
    178 			w.mp = 0.0;
    179 			for(i=0; i<ntab; i++){
    180 				p = w.p[i];
    181 				p /= totp;
    182 				if(p < 0.01)
    183 					p = 0.01;
    184 				else if(p > 0.99)
    185 					p = 0.99;
    186 				if(p > w.mp){
    187 					w.mp = p;
    188 					w.mi = i;
    189 				}
    190 				w.p[i] = p;
    191 			}
    192 			noteword(&w);
    193 		}
    194 
    195 		totp = 0.0;
    196 		for(i=0; i<ntab; i++){
    197 			p = 1.0;
    198 			for(j=0; j<nbest; j++)
    199 				p *= best[j].p[i];
    200 			xp[i] = p;
    201 			totp += p;
    202 		}
    203 		for(i=0; i<ntab; i++)
    204 			xp[i] /= totp;
    205 		mi = 0;
    206 		for(i=1; i<ntab; i++)
    207 			if(xp[i] > xp[mi])
    208 				mi = i;
    209 		if(oi != argc-1)
    210 			Bprint(&bout, "%s: ", argv[a]);
    211 		Bprint(&bout, "%s %f", tab[mi].file, xp[mi]);
    212 		if(keywords){
    213 			for(i=0; i<nbest; i++){
    214 				Bprint(&bout, " ");
    215 				Bwrite(&bout, best[i].s->str, best[i].s->n);
    216 				Bprint(&bout, " %f", best[i].p[mi]);
    217 			}
    218 		}
    219 		freehash(hmsg);
    220 		Bprint(&bout, "\n");
    221 		if(debug){
    222 			for(i=0; i<nbest; i++){
    223 				Bwrite(&bout, best[i].s->str, best[i].s->n);
    224 				Bprint(&bout, " %f", best[i].p[mi]);
    225 				if(best[i].p[mi] < best[i].mp)
    226 					Bprint(&bout, " (%f %s)", best[i].mp, tab[best[i].mi].file);
    227 				Bprint(&bout, "\n");
    228 			}
    229 		}
    230 	}
    231 	Bterm(&bout);
    232 }