plan9port

fork of plan9port with libvec, libstr and libsdb
Log | Files | Refs | README | LICENSE

SConn.c (4441B)


      1 #include <u.h>
      2 #include <libc.h>
      3 #include <mp.h>
      4 #include <libsec.h>
      5 #include "SConn.h"
      6 
      7 extern int verbose;
      8 
      9 typedef struct ConnState {
     10 	uchar secret[SHA1dlen];
     11 	ulong seqno;
     12 	RC4state rc4;
     13 } ConnState;
     14 
     15 #undef SS
     16 typedef struct SS {
     17 	int fd;		/* file descriptor for read/write of encrypted data */
     18 	int alg;	/* if nonzero, "alg sha rc4_128" */
     19 	ConnState in, out;
     20 } SS;
     21 
     22 static int
     23 SC_secret(SConn *conn, uchar *sigma, int direction)
     24 {
     25 	SS *ss = (SS*)(conn->chan);
     26 	int nsigma = conn->secretlen;
     27 
     28 	if(direction != 0){
     29 		hmac_sha1(sigma, nsigma, (uchar*)"one", 3, ss->out.secret, nil);
     30 		hmac_sha1(sigma, nsigma, (uchar*)"two", 3, ss->in.secret, nil);
     31 	}else{
     32 		hmac_sha1(sigma, nsigma, (uchar*)"two", 3, ss->out.secret, nil);
     33 		hmac_sha1(sigma, nsigma, (uchar*)"one", 3, ss->in.secret, nil);
     34 	}
     35 	setupRC4state(&ss->in.rc4, ss->in.secret, 16); /* restrict to 128 bits */
     36 	setupRC4state(&ss->out.rc4, ss->out.secret, 16);
     37 	ss->alg = 1;
     38 	return 0;
     39 }
     40 
     41 static void
     42 hash(uchar secret[SHA1dlen], uchar *data, int len, int seqno, uchar d[SHA1dlen])
     43 {
     44 	DigestState sha;
     45 	uchar seq[4];
     46 
     47 	seq[0] = seqno>>24;
     48 	seq[1] = seqno>>16;
     49 	seq[2] = seqno>>8;
     50 	seq[3] = seqno;
     51 	memset(&sha, 0, sizeof sha);
     52 	sha1(secret, SHA1dlen, nil, &sha);
     53 	sha1(data, len, nil, &sha);
     54 	sha1(seq, 4, d, &sha);
     55 }
     56 
     57 static int
     58 verify(uchar secret[SHA1dlen], uchar *data, int len, int seqno, uchar d[SHA1dlen])
     59 {
     60 	DigestState sha;
     61 	uchar seq[4];
     62 	uchar digest[SHA1dlen];
     63 
     64 	seq[0] = seqno>>24;
     65 	seq[1] = seqno>>16;
     66 	seq[2] = seqno>>8;
     67 	seq[3] = seqno;
     68 	memset(&sha, 0, sizeof sha);
     69 	sha1(secret, SHA1dlen, nil, &sha);
     70 	sha1(data, len, nil, &sha);
     71 	sha1(seq, 4, digest, &sha);
     72 	return memcmp(d, digest, SHA1dlen);
     73 }
     74 
     75 static int
     76 SC_read(SConn *conn, uchar *buf, int n)
     77 {
     78 	SS *ss = (SS*)(conn->chan);
     79 	uchar count[2], digest[SHA1dlen];
     80 	int len, nr;
     81 
     82 	if(read(ss->fd, count, 2) != 2 || (count[0]&0x80) == 0){
     83 		snprint((char*)buf,n,"!SC_read invalid count");
     84 		return -1;
     85 	}
     86 	len = (count[0]&0x7f)<<8 | count[1];	/* SSL-style count; no pad */
     87 	if(ss->alg){
     88 		len -= SHA1dlen;
     89 		if(len <= 0 || readn(ss->fd, digest, SHA1dlen) != SHA1dlen){
     90 			snprint((char*)buf,n,"!SC_read missing sha1");
     91 			return -1;
     92 		}
     93 		if(len > n || readn(ss->fd, buf, len) != len){
     94 			snprint((char*)buf,n,"!SC_read missing data");
     95 			return -1;
     96 		}
     97 		rc4(&ss->in.rc4, digest, SHA1dlen);
     98 		rc4(&ss->in.rc4, buf, len);
     99 		if(verify(ss->in.secret, buf, len, ss->in.seqno, digest) != 0){
    100 			snprint((char*)buf,n,"!SC_read integrity check failed");
    101 			return -1;
    102 		}
    103 	}else{
    104 		if(len <= 0 || len > n){
    105 			snprint((char*)buf,n,"!SC_read implausible record length");
    106 			return -1;
    107 		}
    108 		if( (nr = readn(ss->fd, buf, len)) != len){
    109 			snprint((char*)buf,n,"!SC_read expected %d bytes, but got %d", len, nr);
    110 			return -1;
    111 		}
    112 	}
    113 	ss->in.seqno++;
    114 	return len;
    115 }
    116 
    117 static int
    118 SC_write(SConn *conn, uchar *buf, int n)
    119 {
    120 	SS *ss = (SS*)(conn->chan);
    121 	uchar count[2], digest[SHA1dlen], enc[Maxmsg+1];
    122 	int len;
    123 
    124 	if(n <= 0 || n > Maxmsg+1){
    125 		werrstr("!SC_write invalid n %d", n);
    126 		return -1;
    127 	}
    128 	len = n;
    129 	if(ss->alg)
    130 		len += SHA1dlen;
    131 	count[0] = 0x80 | len>>8;
    132 	count[1] = len;
    133 	if(write(ss->fd, count, 2) != 2){
    134 		werrstr("!SC_write invalid count");
    135 		return -1;
    136 	}
    137 	if(ss->alg){
    138 		hash(ss->out.secret, buf, n, ss->out.seqno, digest);
    139 		rc4(&ss->out.rc4, digest, SHA1dlen);
    140 		memcpy(enc, buf, n);
    141 		rc4(&ss->out.rc4, enc, n);
    142 		if(write(ss->fd, digest, SHA1dlen) != SHA1dlen ||
    143 				write(ss->fd, enc, n) != n){
    144 			werrstr("!SC_write error on send");
    145 			return -1;
    146 		}
    147 	}else{
    148 		if(write(ss->fd, buf, n) != n){
    149 			werrstr("!SC_write error on send");
    150 			return -1;
    151 		}
    152 	}
    153 	ss->out.seqno++;
    154 	return n;
    155 }
    156 
    157 static void
    158 SC_free(SConn *conn)
    159 {
    160 	SS *ss = (SS*)(conn->chan);
    161 
    162 	close(ss->fd);
    163 	free(ss);
    164 	free(conn);
    165 }
    166 
    167 SConn*
    168 newSConn(int fd)
    169 {
    170 	SS *ss;
    171 	SConn *conn;
    172 
    173 	if(fd < 0)
    174 		return nil;
    175 	ss = (SS*)emalloc(sizeof(*ss));
    176 	conn = (SConn*)emalloc(sizeof(*conn));
    177 	ss->fd  = fd;
    178 	ss->alg = 0;
    179 	conn->chan = (void*)ss;
    180 	conn->secretlen = SHA1dlen;
    181 	conn->free = SC_free;
    182 	conn->secret = SC_secret;
    183 	conn->read = SC_read;
    184 	conn->write = SC_write;
    185 	return conn;
    186 }
    187 
    188 void
    189 writerr(SConn *conn, char *s)
    190 {
    191 	char buf[Maxmsg];
    192 
    193 	snprint(buf, Maxmsg, "!%s", s);
    194 	conn->write(conn, (uchar*)buf, strlen(buf));
    195 }
    196 
    197 int
    198 readstr(SConn *conn, char *s)
    199 {
    200 	int n;
    201 
    202 	n = conn->read(conn, (uchar*)s, Maxmsg);
    203 	if(n >= 0){
    204 		s[n] = 0;
    205 		if(s[0] == '!'){
    206 			memmove(s, s+1, n);
    207 			n = -1;
    208 		}
    209 	}else{
    210 		strcpy(s, "read error");
    211 	}
    212 	return n;
    213 }