plan9port

fork of plan9port with libvec, libstr and libsdb
Log | Files | Refs | README | LICENSE

ssh-agent.c (18364B)


      1 /*
      2  * Present factotum in ssh agent clothing.
      3  */
      4 #include <u.h>
      5 #include <libc.h>
      6 #include <mp.h>
      7 #include <libsec.h>
      8 #include <auth.h>
      9 #include <thread.h>
     10 #include <9pclient.h>
     11 
     12 enum
     13 {
     14 	STACK = 65536
     15 };
     16 enum		/* agent protocol packet types */
     17 {
     18 	SSH_AGENTC_NONE = 0,
     19 	SSH_AGENTC_REQUEST_RSA_IDENTITIES,
     20 	SSH_AGENT_RSA_IDENTITIES_ANSWER,
     21 	SSH_AGENTC_RSA_CHALLENGE,
     22 	SSH_AGENT_RSA_RESPONSE,
     23 	SSH_AGENT_FAILURE,
     24 	SSH_AGENT_SUCCESS,
     25 	SSH_AGENTC_ADD_RSA_IDENTITY,
     26 	SSH_AGENTC_REMOVE_RSA_IDENTITY,
     27 	SSH_AGENTC_REMOVE_ALL_RSA_IDENTITIES,
     28 
     29 	SSH2_AGENTC_REQUEST_IDENTITIES = 11,
     30 	SSH2_AGENT_IDENTITIES_ANSWER,
     31 	SSH2_AGENTC_SIGN_REQUEST,
     32 	SSH2_AGENT_SIGN_RESPONSE,
     33 
     34 	SSH2_AGENTC_ADD_IDENTITY = 17,
     35 	SSH2_AGENTC_REMOVE_IDENTITY,
     36 	SSH2_AGENTC_REMOVE_ALL_IDENTITIES,
     37 	SSH2_AGENTC_ADD_SMARTCARD_KEY,
     38 	SSH2_AGENTC_REMOVE_SMARTCARD_KEY,
     39 
     40 	SSH_AGENTC_LOCK,
     41 	SSH_AGENTC_UNLOCK,
     42 	SSH_AGENTC_ADD_RSA_ID_CONSTRAINED,
     43 	SSH2_AGENTC_ADD_ID_CONSTRAINED,
     44 	SSH_AGENTC_ADD_SMARTCARD_KEY_CONSTRAINED,
     45 
     46 	SSH_AGENT_CONSTRAIN_LIFETIME = 1,
     47 	SSH_AGENT_CONSTRAIN_CONFIRM = 2,
     48 
     49 	SSH2_AGENT_FAILURE = 30,
     50 
     51 	SSH_COM_AGENT2_FAILURE = 102,
     52 	SSH_AGENT_OLD_SIGNATURE = 0x01
     53 };
     54 
     55 typedef struct Aconn Aconn;
     56 struct Aconn
     57 {
     58 	uchar *data;
     59 	uint ndata;
     60 	int ctl;
     61 	int fd;
     62 	char dir[40];
     63 };
     64 
     65 typedef struct Msg Msg;
     66 struct Msg
     67 {
     68 	uchar *bp;
     69 	uchar *p;
     70 	uchar *ep;
     71 	int bpalloc;
     72 };
     73 
     74 char adir[40];
     75 int afd;
     76 int chatty;
     77 char *factotum = "factotum";
     78 
     79 void		agentproc(void *v);
     80 void*	emalloc(int n);
     81 void*	erealloc(void *v, int n);
     82 void		listenproc(void *v);
     83 int		runmsg(Aconn *a);
     84 void		listkeystext(void);
     85 
     86 void
     87 usage(void)
     88 {
     89 	fprint(2, "usage: 9 ssh-agent [-D] [factotum]\n");
     90 	threadexitsall("usage");
     91 }
     92 
     93 int
     94 threadmaybackground(void)
     95 {
     96 	return 1;
     97 }
     98 
     99 void
    100 threadmain(int argc, char **argv)
    101 {
    102 	int fd, pid, export, dotextlist;
    103 	char dir[100], *ns;
    104 	char sock[200], addr[200];
    105 	uvlong x;
    106 
    107 	export = 0;
    108 	dotextlist = 0;
    109 	pid = getpid();
    110 	fmtinstall('B', mpfmt);
    111 	fmtinstall('H', encodefmt);
    112 	fmtinstall('[', encodefmt);
    113 
    114 	ARGBEGIN{
    115 	case '9':
    116 		chatty9pclient++;
    117 		break;
    118 	case 'D':
    119 		chatty++;
    120 		break;
    121 	case 'e':
    122 		export = 1;
    123 		break;
    124 	case 'l':
    125 		dotextlist = 1;
    126 		break;
    127 	default:
    128 		usage();
    129 	}ARGEND
    130 
    131 	if(argc > 1)
    132 		usage();
    133 	if(argc == 1)
    134 		factotum = argv[0];
    135 
    136 	if(dotextlist)
    137 		listkeystext();
    138 
    139 	ns = getns();
    140 	snprint(sock, sizeof sock, "%s/ssh-agent.socket", ns);
    141 	if(0){
    142 		x = ((uvlong)fastrand()<<32) | fastrand();
    143 		x ^= ((uvlong)fastrand()<<32) | fastrand();
    144 		snprint(dir, sizeof dir, "/tmp/ssh-%llux", x);
    145 		if((fd = create(dir, OREAD, DMDIR|0700)) < 0)
    146 			sysfatal("mkdir %s: %r", dir);
    147 		close(fd);
    148 		snprint(sock, sizeof sock, "%s/agent.%d", dir, pid);
    149 	}
    150 	snprint(addr, sizeof addr, "unix!%s", sock);
    151 
    152 	if((afd = announce(addr, adir)) < 0)
    153 		sysfatal("announce %s: %r", addr);
    154 
    155 	print("SSH_AUTH_SOCK=%s;\n", sock);
    156 	if(export)
    157 		print("export SSH_AUTH_SOCK;\n");
    158 	print("SSH_AGENT_PID=%d;\n", pid);
    159 	if(export)
    160 		print("export SSH_AGENT_PID;\n");
    161 	close(1);
    162 	rfork(RFNOTEG);
    163 	proccreate(listenproc, nil, STACK);
    164 	threadexits(0);
    165 }
    166 
    167 void
    168 listenproc(void *v)
    169 {
    170 	Aconn *a;
    171 
    172 	USED(v);
    173 	for(;;){
    174 		a = emalloc(sizeof *a);
    175 		a->ctl = listen(adir, a->dir);
    176 		if(a->ctl < 0)
    177 			sysfatal("listen: %r");
    178 		proccreate(agentproc, a, STACK);
    179 	}
    180 }
    181 
    182 void
    183 agentproc(void *v)
    184 {
    185 	Aconn *a;
    186 	int n;
    187 
    188 	a = v;
    189 	a->fd = accept(a->ctl, a->dir);
    190 	close(a->ctl);
    191 	a->ctl = -1;
    192 	for(;;){
    193 		a->data = erealloc(a->data, a->ndata+1024);
    194 		n = read(a->fd, a->data+a->ndata, 1024);
    195 		if(n <= 0)
    196 			break;
    197 		a->ndata += n;
    198 		while(runmsg(a))
    199 			;
    200 	}
    201 	close(a->fd);
    202 	free(a);
    203 	threadexits(nil);
    204 }
    205 
    206 int
    207 get1(Msg *m)
    208 {
    209 	if(m->p >= m->ep)
    210 		return 0;
    211 	return *m->p++;
    212 }
    213 
    214 int
    215 get2(Msg *m)
    216 {
    217 	uint x;
    218 
    219 	if(m->p+2 > m->ep)
    220 		return 0;
    221 	x = (m->p[0]<<8)|m->p[1];
    222 	m->p += 2;
    223 	return x;
    224 }
    225 
    226 int
    227 get4(Msg *m)
    228 {
    229 	uint x;
    230 	if(m->p+4 > m->ep)
    231 		return 0;
    232 	x = (m->p[0]<<24)|(m->p[1]<<16)|(m->p[2]<<8)|m->p[3];
    233 	m->p += 4;
    234 	return x;
    235 }
    236 
    237 uchar*
    238 getn(Msg *m, uint n)
    239 {
    240 	uchar *p;
    241 
    242 	if(m->p+n > m->ep)
    243 		return nil;
    244 	p = m->p;
    245 	m->p += n;
    246 	return p;
    247 }
    248 
    249 char*
    250 getstr(Msg *m)
    251 {
    252 	uint n;
    253 	uchar *p;
    254 
    255 	n = get4(m);
    256 	p = getn(m, n);
    257 	if(p == nil)
    258 		return nil;
    259 	p--;
    260 	memmove(p, p+1, n);
    261 	p[n] = 0;
    262 	return (char*)p;
    263 }
    264 
    265 mpint*
    266 getmp(Msg *m)
    267 {
    268 	int n;
    269 	uchar *p;
    270 
    271 	n = (get2(m)+7)/8;
    272 	if((p=getn(m, n)) == nil)
    273 		return nil;
    274 	return betomp(p, n, nil);
    275 }
    276 
    277 mpint*
    278 getmp2(Msg *m)
    279 {
    280 	int n;
    281 	uchar *p;
    282 
    283 	n = get4(m);
    284 	if((p = getn(m, n)) == nil)
    285 		return nil;
    286 	return betomp(p, n, nil);
    287 }
    288 
    289 void
    290 newmsg(Msg *m)
    291 {
    292 	memset(m, 0, sizeof *m);
    293 }
    294 
    295 void
    296 mreset(Msg *m)
    297 {
    298 	if(m->bpalloc){
    299 		memset(m->bp, 0, m->ep-m->bp);
    300 		free(m->bp);
    301 	}
    302 	memset(m, 0, sizeof *m);
    303 }
    304 
    305 Msg*
    306 getm(Msg *m, Msg *mm)
    307 {
    308 	uint n;
    309 	uchar *p;
    310 
    311 	n = get4(m);
    312 	if((p = getn(m, n)) == nil)
    313 		return nil;
    314 	mm->bp = p;
    315 	mm->p = p;
    316 	mm->ep = p+n;
    317 	mm->bpalloc = 0;
    318 	return mm;
    319 }
    320 
    321 uchar*
    322 ensure(Msg *m, int n)
    323 {
    324 	int len;
    325 	uchar *p;
    326 	uchar *obp;
    327 
    328 	if(m->bp == nil)
    329 		m->bpalloc = 1;
    330 	if(!m->bpalloc){
    331 		p = emalloc(m->ep - m->bp);
    332 		memmove(p, m->bp, m->ep - m->bp);
    333 		obp = m->bp;
    334 		m->bp = p;
    335 		m->ep += m->bp - obp;
    336 		m->p += m->bp - obp;
    337 		m->bpalloc = 1;
    338 	}
    339 	len = m->ep - m->bp;
    340 	if(m->p+n > m->ep){
    341 		obp = m->bp;
    342 		m->bp = erealloc(m->bp, len+n+1024);
    343 		m->p += m->bp - obp;
    344 		m->ep += m->bp - obp;
    345 		m->ep += n+1024;
    346 	}
    347 	p = m->p;
    348 	m->p += n;
    349 	return p;
    350 }
    351 
    352 void
    353 put4(Msg *m, uint n)
    354 {
    355 	uchar *p;
    356 
    357 	p = ensure(m, 4);
    358 	p[0] = (n>>24)&0xFF;
    359 	p[1] = (n>>16)&0xFF;
    360 	p[2] = (n>>8)&0xFF;
    361 	p[3] = n&0xFF;
    362 }
    363 
    364 void
    365 put2(Msg *m, uint n)
    366 {
    367 	uchar *p;
    368 
    369 	p = ensure(m, 2);
    370 	p[0] = (n>>8)&0xFF;
    371 	p[1] = n&0xFF;
    372 }
    373 
    374 void
    375 put1(Msg *m, uint n)
    376 {
    377 	uchar *p;
    378 
    379 	p = ensure(m, 1);
    380 	p[0] = n&0xFF;
    381 }
    382 
    383 void
    384 putn(Msg *m, void *a, uint n)
    385 {
    386 	uchar *p;
    387 
    388 	p = ensure(m, n);
    389 	memmove(p, a, n);
    390 }
    391 
    392 void
    393 putmp(Msg *m, mpint *b)
    394 {
    395 	int bits, n;
    396 	uchar *p;
    397 
    398 	bits = mpsignif(b);
    399 	put2(m, bits);
    400 	n = (bits+7)/8;
    401 	p = ensure(m, n);
    402 	mptobe(b, p, n, nil);
    403 }
    404 
    405 void
    406 putmp2(Msg *m, mpint *b)
    407 {
    408 	int bits, n;
    409 	uchar *p;
    410 
    411 	if(mpcmp(b, mpzero) == 0){
    412 		put4(m, 0);
    413 		return;
    414 	}
    415 	bits = mpsignif(b);
    416 	n = (bits+7)/8;
    417 	if(bits%8 == 0){
    418 		put4(m, n+1);
    419 		put1(m, 0);
    420 	}else
    421 		put4(m, n);
    422 	p = ensure(m, n);
    423 	mptobe(b, p, n, nil);
    424 }
    425 
    426 void
    427 putstr(Msg *m, char *s)
    428 {
    429 	int n;
    430 
    431 	n = strlen(s);
    432 	put4(m, n);
    433 	putn(m, s, n);
    434 }
    435 
    436 void
    437 putm(Msg *m, Msg *mm)
    438 {
    439 	uint n;
    440 
    441 	n = mm->p - mm->bp;
    442 	put4(m, n);
    443 	putn(m, mm->bp, n);
    444 }
    445 
    446 void
    447 newreply(Msg *m, int type)
    448 {
    449 	memset(m, 0, sizeof *m);
    450 	put4(m, 0);
    451 	put1(m, type);
    452 }
    453 
    454 void
    455 reply(Aconn *a, Msg *m)
    456 {
    457 	uint n;
    458 	uchar *p;
    459 
    460 	n = (m->p - m->bp) - 4;
    461 	p = m->bp;
    462 	p[0] = (n>>24)&0xFF;
    463 	p[1] = (n>>16)&0xFF;
    464 	p[2] = (n>>8)&0xFF;
    465 	p[3] = n&0xFF;
    466 	if(chatty)
    467 		fprint(2, "respond %d t=%d: %.*H\n", n, p[4], n, m->bp+4);
    468 	write(a->fd, p, n+4);
    469 	mreset(m);
    470 }
    471 
    472 typedef struct Key Key;
    473 struct Key
    474 {
    475 	mpint *mod;
    476 	mpint *ek;
    477 	char *comment;
    478 };
    479 
    480 static char*
    481 find(char **f, int nf, char *k)
    482 {
    483 	int i, len;
    484 
    485 	len = strlen(k);
    486 	for(i=1; i<nf; i++)	/* i=1: f[0] is "key" */
    487 		if(strncmp(f[i], k, len) == 0 && f[i][len] == '=')
    488 			return f[i]+len+1;
    489 	return nil;
    490 }
    491 
    492 static int
    493 putrsa1(Msg *m, char **f, int nf)
    494 {
    495 	char *p;
    496 	mpint *mod, *ek;
    497 
    498 	p = find(f, nf, "n");
    499 	if(p == nil || (mod = strtomp(p, nil, 16, nil)) == nil)
    500 		return -1;
    501 	p = find(f, nf, "ek");
    502 	if(p == nil || (ek = strtomp(p, nil, 16, nil)) == nil){
    503 		mpfree(mod);
    504 		return -1;
    505 	}
    506 	p = find(f, nf, "comment");
    507 	if(p == nil)
    508 		p = "";
    509 	put4(m, mpsignif(mod));
    510 	putmp(m, ek);
    511 	putmp(m, mod);
    512 	putstr(m, p);
    513 	mpfree(mod);
    514 	mpfree(ek);
    515 	return 0;
    516 }
    517 
    518 void
    519 printattr(char **f, int nf)
    520 {
    521 	int i;
    522 
    523 	print("#");
    524 	for(i=0; i<nf; i++)
    525 		print(" %s", f[i]);
    526 	print("\n");
    527 }
    528 
    529 void
    530 printrsa1(char **f, int nf)
    531 {
    532 	char *p;
    533 	mpint *mod, *ek;
    534 
    535 	p = find(f, nf, "n");
    536 	if(p == nil || (mod = strtomp(p, nil, 16, nil)) == nil)
    537 		return;
    538 	p = find(f, nf, "ek");
    539 	if(p == nil || (ek = strtomp(p, nil, 16, nil)) == nil){
    540 		mpfree(mod);
    541 		return;
    542 	}
    543 	p = find(f, nf, "comment");
    544 	if(p == nil)
    545 		p = "";
    546 
    547 	if(chatty)
    548 		printattr(f, nf);
    549 	print("%d %.10B %.10B %s\n", mpsignif(mod), ek, mod, p);
    550 	mpfree(ek);
    551 	mpfree(mod);
    552 }
    553 
    554 static int
    555 putrsa(Msg *m, char **f, int nf)
    556 {
    557 	char *p;
    558 	mpint *mod, *ek;
    559 
    560 	p = find(f, nf, "n");
    561 	if(p == nil || (mod = strtomp(p, nil, 16, nil)) == nil)
    562 		return -1;
    563 	p = find(f, nf, "ek");
    564 	if(p == nil || (ek = strtomp(p, nil, 16, nil)) == nil){
    565 		mpfree(mod);
    566 		return -1;
    567 	}
    568 	putstr(m, "ssh-rsa");
    569 	putmp2(m, ek);
    570 	putmp2(m, mod);
    571 	mpfree(ek);
    572 	mpfree(mod);
    573 	return 0;
    574 }
    575 
    576 RSApub*
    577 getrsapub(Msg *m)
    578 {
    579 	RSApub *k;
    580 
    581 	k = rsapuballoc();
    582 	if(k == nil)
    583 		return nil;
    584 	k->ek = getmp2(m);
    585 	k->n = getmp2(m);
    586 	if(k->ek == nil || k->n == nil){
    587 		rsapubfree(k);
    588 		return nil;
    589 	}
    590 	return k;
    591 }
    592 
    593 static int
    594 putdsa(Msg *m, char **f, int nf)
    595 {
    596 	char *p;
    597 	int ret;
    598 	mpint *dp, *dq, *dalpha, *dkey;
    599 
    600 	ret = -1;
    601 	dp = dq = dalpha = dkey = nil;
    602 	p = find(f, nf, "p");
    603 	if(p == nil || (dp = strtomp(p, nil, 16, nil)) == nil)
    604 		goto out;
    605 	p = find(f, nf, "q");
    606 	if(p == nil || (dq = strtomp(p, nil, 16, nil)) == nil)
    607 		goto out;
    608 	p = find(f, nf, "alpha");
    609 	if(p == nil || (dalpha = strtomp(p, nil, 16, nil)) == nil)
    610 		goto out;
    611 	p = find(f, nf, "key");
    612 	if(p == nil || (dkey = strtomp(p, nil, 16, nil)) == nil)
    613 		goto out;
    614 	putstr(m, "ssh-dss");
    615 	putmp2(m, dp);
    616 	putmp2(m, dq);
    617 	putmp2(m, dalpha);
    618 	putmp2(m, dkey);
    619 	ret = 0;
    620 out:
    621 	mpfree(dp);
    622 	mpfree(dq);
    623 	mpfree(dalpha);
    624 	mpfree(dkey);
    625 	return ret;
    626 }
    627 
    628 static int
    629 putkey2(Msg *m, int (*put)(Msg*,char**,int), char **f, int nf)
    630 {
    631 	char *p;
    632 	Msg mm;
    633 
    634 	newmsg(&mm);
    635 	if(put(&mm, f, nf) < 0)
    636 		return -1;
    637 	putm(m, &mm);
    638 	mreset(&mm);
    639 	p = find(f, nf, "comment");
    640 	if(p == nil)
    641 		p = "";
    642 	putstr(m, p);
    643 	return 0;
    644 }
    645 
    646 static int
    647 printkey(char *type, int (*put)(Msg*,char**,int), char **f, int nf)
    648 {
    649 	Msg m;
    650 	char *p;
    651 
    652 	newmsg(&m);
    653 	if(put(&m, f, nf) < 0)
    654 		return -1;
    655 	p = find(f, nf, "comment");
    656 	if(p == nil)
    657 		p = "";
    658 	if(chatty)
    659 		printattr(f, nf);
    660 	print("%s %.*[ %s\n", type, m.p-m.bp, m.bp, p);
    661 	mreset(&m);
    662 	return 0;
    663 }
    664 
    665 DSApub*
    666 getdsapub(Msg *m)
    667 {
    668 	DSApub *k;
    669 
    670 	k = dsapuballoc();
    671 	if(k == nil)
    672 		return nil;
    673 	k->p = getmp2(m);
    674 	k->q = getmp2(m);
    675 	k->alpha = getmp2(m);
    676 	k->key = getmp2(m);
    677 	if(!k->p || !k->q || !k->alpha || !k->key){
    678 		dsapubfree(k);
    679 		return nil;
    680 	}
    681 	return k;
    682 }
    683 
    684 static int
    685 listkeys(Msg *m, int version)
    686 {
    687 	char buf[8192+1], *line[100], *f[20], *p, *s;
    688 	int pnk;
    689 	int i, n, nl, nf, nk;
    690 	CFid *fid;
    691 
    692 	nk = 0;
    693 	pnk = m->p - m->bp;
    694 	put4(m, 0);
    695 	if((fid = nsopen(factotum, nil, "ctl", OREAD)) == nil){
    696 		fprint(2, "ssh-agent: open factotum: %r\n");
    697 		return -1;
    698 	}
    699 	for(;;){
    700 		if((n = fsread(fid, buf, sizeof buf-1)) <= 0)
    701 			break;
    702 		buf[n] = 0;
    703 		nl = getfields(buf, line, nelem(line), 1, "\n");
    704 		for(i=0; i<nl; i++){
    705 			nf = tokenize(line[i], f, nelem(f));
    706 			if(nf == 0 || strcmp(f[0], "key") != 0)
    707 				continue;
    708 			p = find(f, nf, "proto");
    709 			if(p == nil)
    710 				continue;
    711 			s = find(f, nf, "service");
    712 			if(s == nil)
    713 				continue;
    714 
    715 			if(version == 1 && strcmp(p, "rsa") == 0 && strcmp(s, "ssh") == 0)
    716 				if(putrsa1(m, f, nf) >= 0)
    717 					nk++;
    718 			if(version == 2 && strcmp(p, "rsa") == 0 && strcmp(s, "ssh-rsa") == 0)
    719 				if(putkey2(m, putrsa, f, nf) >= 0)
    720 					nk++;
    721 			if(version == 2 && strcmp(p, "dsa") == 0 && strcmp(s, "ssh-dss") == 0)
    722 				if(putkey2(m, putdsa, f, nf) >= 0)
    723 					nk++;
    724 		}
    725 	}
    726 	if(chatty)
    727 		fprint(2, "sending %d keys\n", nk);
    728 	fsclose(fid);
    729 	m->bp[pnk+0] = (nk>>24)&0xFF;
    730 	m->bp[pnk+1] = (nk>>16)&0xFF;
    731 	m->bp[pnk+2] = (nk>>8)&0xFF;
    732 	m->bp[pnk+3] = nk&0xFF;
    733 	return nk;
    734 }
    735 
    736 void
    737 listkeystext(void)
    738 {
    739 	char buf[8192+1], *line[100], *f[20], *p, *s;
    740 	int i, n, nl, nf;
    741 	CFid *fid;
    742 
    743 	if((fid = nsopen(factotum, nil, "ctl", OREAD)) == nil){
    744 		fprint(2, "ssh-agent: open factotum: %r\n");
    745 		return;
    746 	}
    747 	for(;;){
    748 		if((n = fsread(fid, buf, sizeof buf-1)) <= 0)
    749 			break;
    750 		buf[n] = 0;
    751 		nl = getfields(buf, line, nelem(line), 1, "\n");
    752 		for(i=0; i<nl; i++){
    753 			nf = tokenize(line[i], f, nelem(f));
    754 			if(nf == 0 || strcmp(f[0], "key") != 0)
    755 				continue;
    756 			p = find(f, nf, "proto");
    757 			if(p == nil)
    758 				continue;
    759 			s = find(f, nf, "service");
    760 			if(s == nil)
    761 				continue;
    762 
    763 			if(strcmp(p, "rsa") == 0 && strcmp(s, "ssh") == 0)
    764 				printrsa1(f, nf);
    765 			if(strcmp(p, "rsa") == 0 && strcmp(s, "ssh-rsa") == 0)
    766 				printkey("ssh-rsa", putrsa, f, nf);
    767 			if(strcmp(p, "dsa") == 0 && strcmp(s, "ssh-dss") == 0)
    768 				printkey("ssh-dss", putdsa, f, nf);
    769 		}
    770 	}
    771 	fsclose(fid);
    772 	threadexitsall(nil);
    773 }
    774 
    775 mpint*
    776 rsaunpad(mpint *b)
    777 {
    778 	int i, n;
    779 	uchar buf[2560];
    780 
    781 	n = (mpsignif(b)+7)/8;
    782 	if(n > sizeof buf){
    783 		werrstr("rsaunpad: too big");
    784 		return nil;
    785 	}
    786 	mptobe(b, buf, n, nil);
    787 
    788 	/* the initial zero has been eaten by the betomp -> mptobe sequence */
    789 	if(buf[0] != 2){
    790 		werrstr("rsaunpad: expected leading 2");
    791 		return nil;
    792 	}
    793 	for(i=1; i<n; i++)
    794 		if(buf[i]==0)
    795 			break;
    796 	return betomp(buf+i, n-i, nil);
    797 }
    798 
    799 void
    800 mptoberjust(mpint *b, uchar *buf, int len)
    801 {
    802 	int n;
    803 
    804 	n = mptobe(b, buf, len, nil);
    805 	assert(n >= 0);
    806 	if(n < len){
    807 		len -= n;
    808 		memmove(buf+len, buf, n);
    809 		memset(buf, 0, len);
    810 	}
    811 }
    812 
    813 static int
    814 dorsa(Aconn *a, mpint *mod, mpint *exp, mpint *chal, uchar chalbuf[32])
    815 {
    816 	AuthRpc *rpc;
    817 	char buf[4096], *p;
    818 	mpint *decr, *unpad;
    819 
    820 	USED(exp);
    821 	if((rpc = auth_allocrpc()) == nil){
    822 		fprint(2, "ssh-agent: auth_allocrpc: %r\n");
    823 		return -1;
    824 	}
    825 	snprint(buf, sizeof buf, "proto=rsa service=ssh role=decrypt n=%lB ek=%lB", mod, exp);
    826 	if(chatty)
    827 		fprint(2, "ssh-agent: start %s\n", buf);
    828 	if(auth_rpc(rpc, "start", buf, strlen(buf)) != ARok){
    829 		fprint(2, "ssh-agent: auth 'start' failed: %r\n");
    830 	Die:
    831 		auth_freerpc(rpc);
    832 		return -1;
    833 	}
    834 
    835 	p = mptoa(chal, 16, nil, 0);
    836 	if(p == nil){
    837 		fprint(2, "ssh-agent: dorsa: mptoa: %r\n");
    838 		goto Die;
    839 	}
    840 	if(chatty)
    841 		fprint(2, "ssh-agent: challenge %B => %s\n", chal, p);
    842 	if(auth_rpc(rpc, "writehex", p, strlen(p)) != ARok){
    843 		fprint(2, "ssh-agent: dorsa: auth 'write': %r\n");
    844 		free(p);
    845 		goto Die;
    846 	}
    847 	free(p);
    848 	if(auth_rpc(rpc, "readhex", nil, 0) != ARok){
    849 		fprint(2, "ssh-agent: dorsa: auth 'read': %r\n");
    850 		goto Die;
    851 	}
    852 	decr = strtomp(rpc->arg, nil, 16, nil);
    853 	if(chatty)
    854 		fprint(2, "ssh-agent: response %s => %B\n", rpc->arg, decr);
    855 	if(decr == nil){
    856 		fprint(2, "ssh-agent: dorsa: strtomp: %r\n");
    857 		goto Die;
    858 	}
    859 	unpad = rsaunpad(decr);
    860 	if(chatty)
    861 		fprint(2, "ssh-agent: unpad %B => %B\n", decr, unpad);
    862 	if(unpad == nil){
    863 		fprint(2, "ssh-agent: dorsa: rsaunpad: %r\n");
    864 		mpfree(decr);
    865 		goto Die;
    866 	}
    867 	mpfree(decr);
    868 	mptoberjust(unpad, chalbuf, 32);
    869 	mpfree(unpad);
    870 	auth_freerpc(rpc);
    871 	return 0;
    872 }
    873 
    874 int
    875 keysign(Msg *mkey, Msg *mdata, Msg *msig)
    876 {
    877 	char *s;
    878 	AuthRpc *rpc;
    879 	RSApub *rsa;
    880 	DSApub *dsa;
    881 	char buf[4096];
    882 	uchar digest[SHA1dlen];
    883 
    884 	s = getstr(mkey);
    885 	if(strcmp(s, "ssh-rsa") == 0){
    886 		rsa = getrsapub(mkey);
    887 		if(rsa == nil)
    888 			return -1;
    889 		snprint(buf, sizeof buf, "proto=rsa service=ssh-rsa role=sign n=%lB ek=%lB",
    890 			rsa->n, rsa->ek);
    891 		rsapubfree(rsa);
    892 	}else if(strcmp(s, "ssh-dss") == 0){
    893 		dsa = getdsapub(mkey);
    894 		if(dsa == nil)
    895 			return -1;
    896 		snprint(buf, sizeof buf, "proto=dsa service=ssh-dss role=sign p=%lB q=%lB alpha=%lB key=%lB",
    897 			dsa->p, dsa->q, dsa->alpha, dsa->key);
    898 		dsapubfree(dsa);
    899 	}else{
    900 		fprint(2, "ssh-agent: cannot sign key type %s\n", s);
    901 		werrstr("unknown key type %s", s);
    902 		return -1;
    903 	}
    904 
    905 	if((rpc = auth_allocrpc()) == nil){
    906 		fprint(2, "ssh-agent: auth_allocrpc: %r\n");
    907 		return -1;
    908 	}
    909 	if(chatty)
    910 		fprint(2, "ssh-agent: start %s\n", buf);
    911 	if(auth_rpc(rpc, "start", buf, strlen(buf)) != ARok){
    912 		fprint(2, "ssh-agent: auth 'start' failed: %r\n");
    913 	Die:
    914 		auth_freerpc(rpc);
    915 		return -1;
    916 	}
    917 	sha1(mdata->bp, mdata->ep-mdata->bp, digest, nil);
    918 	if(auth_rpc(rpc, "write", digest, SHA1dlen) != ARok){
    919 		fprint(2, "ssh-agent: auth 'write in sign failed: %r\n");
    920 		goto Die;
    921 	}
    922 	if(auth_rpc(rpc, "read", nil, 0) != ARok){
    923 		fprint(2, "ssh-agent: auth 'read' failed: %r\n");
    924 		goto Die;
    925 	}
    926 	newmsg(msig);
    927 	putstr(msig, s);
    928 	put4(msig, rpc->narg);
    929 	putn(msig, rpc->arg, rpc->narg);
    930 	auth_freerpc(rpc);
    931 	return 0;
    932 }
    933 
    934 int
    935 runmsg(Aconn *a)
    936 {
    937 	char *p;
    938 	int n, nk, type, rt, vers;
    939 	mpint *ek, *mod, *chal;
    940 	uchar sessid[16], chalbuf[32], digest[MD5dlen];
    941 	uint len, flags;
    942 	DigestState *s;
    943 	Msg m, mkey, mdata, msig;
    944 
    945 	if(a->ndata < 4)
    946 		return 0;
    947 	len = (a->data[0]<<24)|(a->data[1]<<16)|(a->data[2]<<8)|a->data[3];
    948 	if(a->ndata < 4+len)
    949 		return 0;
    950 	m.p = a->data+4;
    951 	m.ep = m.p+len;
    952 	type = get1(&m);
    953 	if(chatty)
    954 		fprint(2, "msg %d: %.*H\n", type, len, m.p);
    955 	switch(type){
    956 	default:
    957 	Failure:
    958 		newreply(&m, SSH_AGENT_FAILURE);
    959 		reply(a, &m);
    960 		break;
    961 
    962 	case SSH_AGENTC_REQUEST_RSA_IDENTITIES:
    963 		vers = 1;
    964 		newreply(&m, SSH_AGENT_RSA_IDENTITIES_ANSWER);
    965 		goto Identities;
    966 	case SSH2_AGENTC_REQUEST_IDENTITIES:
    967 		vers = 2;
    968 		newreply(&m, SSH2_AGENT_IDENTITIES_ANSWER);
    969 	Identities:
    970 		nk = listkeys(&m, vers);
    971 		if(nk < 0){
    972 			mreset(&m);
    973 			goto Failure;
    974 		}
    975 		if(chatty)
    976 			fprint(2, "request identities\n", nk);
    977 		reply(a, &m);
    978 		break;
    979 
    980 	case SSH_AGENTC_RSA_CHALLENGE:
    981 		n = get4(&m);
    982 		USED(n);
    983 		ek = getmp(&m);
    984 		mod = getmp(&m);
    985 		chal = getmp(&m);
    986 		if((p = (char*)getn(&m, 16)) == nil){
    987 		Failchal:
    988 			mpfree(ek);
    989 			mpfree(mod);
    990 			mpfree(chal);
    991 			goto Failure;
    992 		}
    993 		memmove(sessid, p, 16);
    994 		rt = get4(&m);
    995 		if(rt != 1 || dorsa(a, mod, ek, chal, chalbuf) < 0)
    996 			goto Failchal;
    997 		s = md5(chalbuf, 32, nil, nil);
    998 		if(s == nil)
    999 			goto Failchal;
   1000 		md5(sessid, 16, digest, s);
   1001 		print("md5 %.*H %.*H => %.*H\n", 32, chalbuf, 16, sessid, MD5dlen, digest);
   1002 
   1003 		newreply(&m, SSH_AGENT_RSA_RESPONSE);
   1004 		putn(&m, digest, 16);
   1005 		reply(a, &m);
   1006 
   1007 		mpfree(ek);
   1008 		mpfree(mod);
   1009 		mpfree(chal);
   1010 		break;
   1011 
   1012 	case SSH2_AGENTC_SIGN_REQUEST:
   1013 		if(getm(&m, &mkey) == nil
   1014 		|| getm(&m, &mdata) == nil)
   1015 			goto Failure;
   1016 		flags = get4(&m);
   1017 		if(flags & SSH_AGENT_OLD_SIGNATURE)
   1018 			goto Failure;
   1019 		if(keysign(&mkey, &mdata, &msig) < 0)
   1020 			goto Failure;
   1021 		if(chatty)
   1022 			fprint(2, "signature: %.*H\n",
   1023 				msig.p-msig.bp, msig.bp);
   1024 		newreply(&m, SSH2_AGENT_SIGN_RESPONSE);
   1025 		putm(&m, &msig);
   1026 		mreset(&msig);
   1027 		reply(a, &m);
   1028 		break;
   1029 
   1030 	case SSH_AGENTC_ADD_RSA_IDENTITY:
   1031 		/*
   1032 			msg: n[4] mod[mp] pubexp[exp] privexp[mp]
   1033 				p^-1 mod q[mp] p[mp] q[mp] comment[str]
   1034 		 */
   1035 		goto Failure;
   1036 
   1037 	case SSH_AGENTC_REMOVE_RSA_IDENTITY:
   1038 		/*
   1039 			msg: n[4] mod[mp] pubexp[mp]
   1040 		 */
   1041 		goto Failure;
   1042 
   1043 	}
   1044 
   1045 	a->ndata -= 4+len;
   1046 	memmove(a->data, a->data+4+len, a->ndata);
   1047 	return 1;
   1048 }
   1049 
   1050 void*
   1051 emalloc(int n)
   1052 {
   1053 	void *v;
   1054 
   1055 	v = mallocz(n, 1);
   1056 	if(v == nil){
   1057 		abort();
   1058 		sysfatal("out of memory allocating %d", n);
   1059 	}
   1060 	return v;
   1061 }
   1062 
   1063 void*
   1064 erealloc(void *v, int n)
   1065 {
   1066 	v = realloc(v, n);
   1067 	if(v == nil){
   1068 		abort();
   1069 		sysfatal("out of memory reallocating %d", n);
   1070 	}
   1071 	return v;
   1072 }