1
0
mirror of https://git.tartarus.org/simon/putty.git synced 2025-01-09 09:27:59 +00:00

Rewrite packet parsing in sshshare.c using BinarySource.

Another set of localised decoding routines get thrown away here. Also,
I've changed the APIs of a couple of helper functions in x11fwd.c to
take ptrlens in place of zero-terminated C strings, because that's the
format in which they come back from the decode, and it saves mallocing
a zero-terminated version of each one just to pass to those helpers.
This commit is contained in:
Simon Tatham 2018-05-29 19:11:22 +01:00
parent 28c086ca9a
commit 5be57af173
3 changed files with 100 additions and 175 deletions

4
ssh.h
View File

@ -657,8 +657,8 @@ char *platform_get_x_display(void);
*/
void x11_get_auth_from_authfile(struct X11Display *display,
const char *authfilename);
int x11_identify_auth_proto(const char *proto);
void *x11_dehexify(const char *hex, int *outlen);
int x11_identify_auth_proto(ptrlen protoname);
void *x11_dehexify(ptrlen hex, int *outlen);
Bignum copybn(Bignum b);
Bignum bn_power_2(int n);

View File

@ -728,30 +728,30 @@ static void send_packet_to_downstream(struct ssh_sharing_connstate *cs,
* If that happens, we just chop up the packet into pieces and
* send them as separate CHANNEL_DATA packets.
*/
const char *upkt = (const char *)pkt;
BinarySource src[1];
unsigned channel;
ptrlen data;
int len = toint(GET_32BIT(upkt + 4));
upkt += 8; /* skip channel id + length field */
if (len < 0 || len > pktlen - 8)
len = pktlen - 8;
BinarySource_BARE_INIT(src, pkt, pktlen);
channel = get_uint32(src);
data = get_string(src);
do {
int this_len = (len > chan->downstream_maxpkt ?
chan->downstream_maxpkt : len);
int this_len = (data.len > chan->downstream_maxpkt ?
chan->downstream_maxpkt : data.len);
packet = strbuf_new();
put_uint32(packet, 0); /* placeholder for length field */
put_byte(packet, type);
put_uint32(packet, chan->downstream_id);
put_uint32(packet, channel);
put_uint32(packet, this_len);
put_data(packet, upkt, this_len);
len -= this_len;
upkt += this_len;
put_data(packet, data.ptr, this_len);
data.ptr = (const char *)data.ptr + this_len;
data.len -= this_len;
PUT_32BIT(packet->s, packet->len-4);
sk_write(cs->sock, packet->s, packet->len);
strbuf_free(packet);
} while (len > 0);
} while (data.len > 0);
} else {
/*
* Just do the obvious thing.
@ -924,44 +924,6 @@ static void share_closing(Plug plug, const char *error_msg, int error_code,
share_begin_cleanup(cs);
}
static int getstring_inner(const void *vdata, int datalen,
char **out, int *outlen)
{
const unsigned char *data = (const unsigned char *)vdata;
int len;
if (datalen < 4)
return FALSE;
len = toint(GET_32BIT(data));
if (len < 0 || len > datalen - 4)
return FALSE;
if (outlen)
*outlen = len + 4; /* total size including length field */
if (out)
*out = dupprintf("%.*s", len, (char *)data + 4);
return TRUE;
}
static char *getstring(const void *data, int datalen)
{
char *ret;
if (getstring_inner(data, datalen, &ret, NULL))
return ret;
else
return NULL;
}
static int getstring_size(const void *data, int datalen)
{
int ret;
if (getstring_inner(data, datalen, NULL, &ret))
return ret;
else
return -1;
}
/*
* Append a message to the end of an xchannel's queue.
*/
@ -1011,9 +973,11 @@ void share_dead_xchannel_respond(struct ssh_sharing_connstate *cs,
* A CHANNEL_REQUEST is responded to by sending
* CHANNEL_FAILURE, if it has want_reply set.
*/
int wantreplypos = getstring_size(msg->data, msg->datalen);
if (wantreplypos > 0 && wantreplypos < msg->datalen &&
msg->data[wantreplypos] != 0) {
BinarySource src[1];
BinarySource_BARE_INIT(src, msg->data, msg->datalen);
get_uint32(src); /* skip channel id */
get_string(src); /* skip request type */
if (get_bool(src)) {
strbuf *packet = strbuf_new();
put_uint32(packet, xc->server_id);
ssh_send_packet_from_downstream
@ -1170,10 +1134,13 @@ void share_got_pkt_from_server(void *csv, int type,
{
struct ssh_sharing_connstate *cs = (struct ssh_sharing_connstate *)csv;
struct share_globreq *globreq;
int id_pos;
size_t id_pos;
unsigned upstream_id, server_id;
struct share_channel *chan;
struct share_xchannel *xc;
BinarySource src[1];
BinarySource_BARE_INIT(src, pkt, pktlen);
switch (type) {
case SSH2_MSG_REQUEST_SUCCESS:
@ -1207,9 +1174,9 @@ void share_got_pkt_from_server(void *csv, int type,
break;
case SSH2_MSG_CHANNEL_OPEN:
id_pos = getstring_size(pkt, pktlen);
assert(id_pos >= 0);
server_id = GET_32BIT(pkt + id_pos);
get_string(src);
server_id = get_uint32(src);
assert(!get_err(src));
share_add_halfchannel(cs, server_id);
send_packet_to_downstream(cs, type, pkt, pktlen, NULL);
@ -1230,13 +1197,13 @@ void share_got_pkt_from_server(void *csv, int type,
* first uint32 field in the packet. Substitute the downstream
* channel id for our one and pass the packet downstream.
*/
assert(pktlen >= 4);
upstream_id = GET_32BIT(pkt);
id_pos = src->pos;
upstream_id = get_uint32(src);
if ((chan = share_find_channel_by_upstream(cs, upstream_id)) != NULL) {
/*
* The normal case: this id refers to an open channel.
*/
PUT_32BIT(pkt, chan->downstream_id);
PUT_32BIT(pkt + id_pos, chan->downstream_id);
send_packet_to_downstream(cs, type, pkt, pktlen, chan);
/*
@ -1296,9 +1263,10 @@ static void share_got_pkt_from_downstream(struct ssh_sharing_connstate *cs,
int type,
unsigned char *pkt, int pktlen)
{
char *request_name;
ptrlen request_name;
struct share_forwarding *fwd;
int id_pos;
size_t id_pos;
unsigned maxpkt;
unsigned old_id, new_id, server_id;
struct share_globreq *globreq;
struct share_channel *chan;
@ -1306,6 +1274,11 @@ static void share_got_pkt_from_downstream(struct ssh_sharing_connstate *cs,
struct share_xchannel *xc;
strbuf *packet;
char *err = NULL;
BinarySource src[1];
size_t wantreplypos;
int orig_wantreply;
BinarySource_BARE_INIT(src, pkt, pktlen);
switch (type) {
case SSH2_MSG_DISCONNECT:
@ -1326,39 +1299,26 @@ static void share_got_pkt_from_downstream(struct ssh_sharing_connstate *cs,
* will probably require that too, and so we don't forward on
* any request we don't understand.
*/
request_name = getstring(pkt, pktlen);
if (request_name == NULL) {
err = dupprintf("Truncated GLOBAL_REQUEST packet");
goto confused;
}
request_name = get_string(src);
wantreplypos = src->pos;
orig_wantreply = get_bool(src);
if (!strcmp(request_name, "tcpip-forward")) {
int wantreplypos, orig_wantreply, port, ret;
if (ptrlen_eq_string(request_name, "tcpip-forward")) {
ptrlen hostpl;
char *host;
sfree(request_name);
int port, ret;
/*
* Pick the packet apart to find the want_reply field and
* the host/port we're going to ask to listen on.
*/
wantreplypos = getstring_size(pkt, pktlen);
if (wantreplypos < 0 || wantreplypos >= pktlen) {
hostpl = get_string(src);
port = toint(get_uint32(src));
if (get_err(src)) {
err = dupprintf("Truncated GLOBAL_REQUEST packet");
goto confused;
}
orig_wantreply = pkt[wantreplypos];
port = getstring_size(pkt + (wantreplypos + 1),
pktlen - (wantreplypos + 1));
port += (wantreplypos + 1);
if (port < 0 || port > pktlen - 4) {
err = dupprintf("Truncated GLOBAL_REQUEST packet");
goto confused;
}
host = getstring(pkt + (wantreplypos + 1),
pktlen - (wantreplypos + 1));
assert(host != NULL);
port = GET_32BIT(pkt + port);
host = mkstr(hostpl);
/*
* See if we can allocate space in ssh.c's tree of remote
@ -1382,11 +1342,10 @@ static void share_got_pkt_from_downstream(struct ssh_sharing_connstate *cs,
* that we know whether this forwarding needs to be
* cleaned up if downstream goes away.
*/
int old_wantreply = pkt[wantreplypos];
pkt[wantreplypos] = 1;
ssh_send_packet_from_downstream
(cs->parent->ssh, cs->id, type, pkt, pktlen,
old_wantreply ? NULL : "upstream added want_reply flag");
orig_wantreply ? NULL : "upstream added want_reply flag");
fwd = share_add_forwarding(cs, host, port);
ssh_sharing_queue_global_request(cs->parent->ssh, cs);
@ -1404,34 +1363,23 @@ static void share_got_pkt_from_downstream(struct ssh_sharing_connstate *cs,
}
sfree(host);
} else if (!strcmp(request_name, "cancel-tcpip-forward")) {
int wantreplypos, orig_wantreply, port;
} else if (ptrlen_eq_string(request_name, "cancel-tcpip-forward")) {
ptrlen hostpl;
char *host;
int port;
struct share_forwarding *fwd;
sfree(request_name);
/*
* Pick the packet apart to find the want_reply field and
* the host/port we're going to ask to listen on.
*/
wantreplypos = getstring_size(pkt, pktlen);
if (wantreplypos < 0 || wantreplypos >= pktlen) {
hostpl = get_string(src);
port = toint(get_uint32(src));
if (get_err(src)) {
err = dupprintf("Truncated GLOBAL_REQUEST packet");
goto confused;
}
orig_wantreply = pkt[wantreplypos];
port = getstring_size(pkt + (wantreplypos + 1),
pktlen - (wantreplypos + 1));
port += (wantreplypos + 1);
if (port < 0 || port > pktlen - 4) {
err = dupprintf("Truncated GLOBAL_REQUEST packet");
goto confused;
}
host = getstring(pkt + (wantreplypos + 1),
pktlen - (wantreplypos + 1));
assert(host != NULL);
port = GET_32BIT(pkt + port);
host = mkstr(hostpl);
/*
* Look up the existing forwarding with these details.
@ -1449,11 +1397,10 @@ static void share_got_pkt_from_downstream(struct ssh_sharing_connstate *cs,
* that _we_ know whether the forwarding has been
* deleted even if downstream doesn't want to know.
*/
int old_wantreply = pkt[wantreplypos];
pkt[wantreplypos] = 1;
ssh_send_packet_from_downstream
(cs->parent->ssh, cs->id, type, pkt, pktlen,
old_wantreply ? NULL : "upstream added want_reply flag");
orig_wantreply ? NULL : "upstream added want_reply flag");
ssh_sharing_queue_global_request(cs->parent->ssh, cs);
}
@ -1463,16 +1410,7 @@ static void share_got_pkt_from_downstream(struct ssh_sharing_connstate *cs,
* Request we don't understand. Manufacture a failure
* message if an answer was required.
*/
int wantreplypos;
sfree(request_name);
wantreplypos = getstring_size(pkt, pktlen);
if (wantreplypos < 0 || wantreplypos >= pktlen) {
err = dupprintf("Truncated GLOBAL_REQUEST packet");
goto confused;
}
if (pkt[wantreplypos])
if (orig_wantreply)
send_packet_to_downstream(cs, SSH2_MSG_REQUEST_FAILURE,
"", 0, NULL);
}
@ -1480,16 +1418,17 @@ static void share_got_pkt_from_downstream(struct ssh_sharing_connstate *cs,
case SSH2_MSG_CHANNEL_OPEN:
/* Sender channel id comes after the channel type string */
id_pos = getstring_size(pkt, pktlen);
if (id_pos < 0 || id_pos > pktlen - 12) {
get_string(src);
id_pos = src->pos;
old_id = get_uint32(src);
new_id = ssh_alloc_sharing_channel(cs->parent->ssh, cs);
get_uint32(src); /* skip initial window size */
maxpkt = get_uint32(src);
if (get_err(src)) {
err = dupprintf("Truncated CHANNEL_OPEN packet");
goto confused;
}
old_id = GET_32BIT(pkt + id_pos);
new_id = ssh_alloc_sharing_channel(cs->parent->ssh, cs);
share_add_channel(cs, old_id, new_id, 0, UNACKNOWLEDGED,
GET_32BIT(pkt + id_pos + 8));
share_add_channel(cs, old_id, new_id, 0, UNACKNOWLEDGED, maxpkt);
PUT_32BIT(pkt + id_pos, new_id);
ssh_send_packet_from_downstream(cs->parent->ssh, cs->id,
type, pkt, pktlen, NULL);
@ -1501,10 +1440,16 @@ static void share_got_pkt_from_downstream(struct ssh_sharing_connstate *cs,
goto confused;
}
id_pos = 4; /* sender channel id is 2nd uint32 field in packet */
old_id = GET_32BIT(pkt + id_pos);
server_id = get_uint32(src);
id_pos = src->pos;
old_id = get_uint32(src);
get_uint32(src); /* skip initial window size */
maxpkt = get_uint32(src);
if (get_err(src)) {
err = dupprintf("Truncated CHANNEL_OPEN_CONFIRMATION packet");
goto confused;
}
server_id = GET_32BIT(pkt);
/* This server id may refer to either a halfchannel or an xchannel. */
hc = NULL, xc = NULL; /* placate optimiser */
if ((hc = share_find_halfchannel(cs, server_id)) != NULL) {
@ -1519,8 +1464,7 @@ static void share_got_pkt_from_downstream(struct ssh_sharing_connstate *cs,
PUT_32BIT(pkt + id_pos, new_id);
chan = share_add_channel(cs, old_id, new_id, server_id, OPEN,
GET_32BIT(pkt + 12));
chan = share_add_channel(cs, old_id, new_id, server_id, OPEN, maxpkt);
if (hc) {
ssh_send_packet_from_downstream(cs->parent->ssh, cs->id,
@ -1539,12 +1483,12 @@ static void share_got_pkt_from_downstream(struct ssh_sharing_connstate *cs,
break;
case SSH2_MSG_CHANNEL_OPEN_FAILURE:
if (pktlen < 4) {
server_id = get_uint32(src);
if (get_err(src)) {
err = dupprintf("Truncated CHANNEL_OPEN_FAILURE packet");
goto confused;
}
server_id = GET_32BIT(pkt);
/* This server id may refer to either a halfchannel or an xchannel. */
if ((hc = share_find_halfchannel(cs, server_id)) != NULL) {
ssh_send_packet_from_downstream(cs->parent->ssh, cs->id,
@ -1570,8 +1514,11 @@ static void share_got_pkt_from_downstream(struct ssh_sharing_connstate *cs,
case SSH2_MSG_CHANNEL_FAILURE:
case SSH2_MSG_IGNORE:
case SSH2_MSG_DEBUG:
if (type == SSH2_MSG_CHANNEL_REQUEST &&
(request_name = getstring(pkt + 4, pktlen - 4)) != NULL) {
server_id = get_uint32(src);
if (type == SSH2_MSG_CHANNEL_REQUEST) {
request_name = get_string(src);
/*
* Agent forwarding requests from downstream are treated
* specially. Because OpenSSHD doesn't let us enable agent
@ -1597,11 +1544,8 @@ static void share_got_pkt_from_downstream(struct ssh_sharing_connstate *cs,
* subsequent CHANNEL_OPENs still can't be associated with
* a parent session channel.)
*/
if (!strcmp(request_name, "auth-agent-req@openssh.com") &&
if (ptrlen_eq_string(request_name, "auth-agent-req@openssh.com") &&
!ssh_agent_forwarding_permitted(cs->parent->ssh)) {
unsigned server_id = GET_32BIT(pkt);
sfree(request_name);
chan = share_find_channel_by_server(cs, server_id);
if (chan) {
@ -1630,14 +1574,10 @@ static void share_got_pkt_from_downstream(struct ssh_sharing_connstate *cs,
* whether it's one to handle locally or one to pass on to
* a downstream, and if the latter, which one.
*/
if (!strcmp(request_name, "x11-req")) {
unsigned server_id = GET_32BIT(pkt);
if (ptrlen_eq_string(request_name, "x11-req")) {
int want_reply, single_connection, screen;
char *auth_proto_str, *auth_data;
ptrlen auth_data;
int auth_proto;
int pos;
sfree(request_name);
chan = share_find_channel_by_server(cs, server_id);
if (!chan) {
@ -1652,26 +1592,16 @@ static void share_got_pkt_from_downstream(struct ssh_sharing_connstate *cs,
* Pick apart the whole message to find the downstream
* auth details.
*/
/* we have already seen: 4 bytes channel id, 4+7 request name */
if (pktlen < 17) {
err = dupprintf("Truncated CHANNEL_REQUEST(\"x11\") packet");
want_reply = get_bool(src);
single_connection = get_bool(src);
auth_proto = x11_identify_auth_proto(get_string(src));
auth_data = get_string(src);
screen = toint(get_uint32(src));
if (get_err(src)) {
err = dupprintf("Truncated CHANNEL_REQUEST(\"x11-req\")"
" packet");
goto confused;
}
want_reply = pkt[15] != 0;
single_connection = pkt[16] != 0;
auth_proto_str = getstring(pkt+17, pktlen-17);
auth_proto = x11_identify_auth_proto(auth_proto_str);
sfree(auth_proto_str);
pos = 17 + getstring_size(pkt+17, pktlen-17);
auth_data = getstring(pkt+pos, pktlen-pos);
pos += getstring_size(pkt+pos, pktlen-pos);
if (pktlen < pos+4) {
err = dupprintf("Truncated CHANNEL_REQUEST(\"x11\") packet");
sfree(auth_data);
goto confused;
}
screen = GET_32BIT(pkt+pos);
if (auth_proto < 0) {
/* Reject due to not understanding downstream's
@ -1682,14 +1612,12 @@ static void share_got_pkt_from_downstream(struct ssh_sharing_connstate *cs,
cs, SSH2_MSG_CHANNEL_FAILURE,
packet->s, packet->len, NULL);
strbuf_free(packet);
sfree(auth_data);
break;
}
chan->x11_auth_proto = auth_proto;
chan->x11_auth_data = x11_dehexify(auth_data,
&chan->x11_auth_datalen);
sfree(auth_data);
chan->x11_auth_upstream =
ssh_sharing_add_x11_display(cs->parent->ssh, auth_proto,
cs, chan);
@ -1715,14 +1643,11 @@ static void share_got_pkt_from_downstream(struct ssh_sharing_connstate *cs,
break;
}
sfree(request_name);
}
ssh_send_packet_from_downstream(cs->parent->ssh, cs->id,
type, pkt, pktlen, NULL);
if (type == SSH2_MSG_CHANNEL_CLOSE && pktlen >= 4) {
server_id = GET_32BIT(pkt);
chan = share_find_channel_by_server(cs, server_id);
if (chan) {
if (chan->state == RCVD_CLOSE) {

View File

@ -974,29 +974,29 @@ void x11_send_eof(struct X11Connection *xconn)
* representations of an X11 auth protocol name + hex cookie into our
* usual integer protocol id and binary auth data.
*/
int x11_identify_auth_proto(const char *protoname)
int x11_identify_auth_proto(ptrlen protoname)
{
int protocol;
for (protocol = 1; protocol < lenof(x11_authnames); protocol++)
if (!strcmp(protoname, x11_authnames[protocol]))
if (ptrlen_eq_string(protoname, x11_authnames[protocol]))
return protocol;
return -1;
}
void *x11_dehexify(const char *hex, int *outlen)
void *x11_dehexify(ptrlen hexpl, int *outlen)
{
int len, i;
unsigned char *ret;
len = strlen(hex) / 2;
len = hexpl.len / 2;
ret = snewn(len, unsigned char);
for (i = 0; i < len; i++) {
char bytestr[3];
unsigned val = 0;
bytestr[0] = hex[2*i];
bytestr[1] = hex[2*i+1];
bytestr[0] = ((const char *)hexpl.ptr)[2*i];
bytestr[1] = ((const char *)hexpl.ptr)[2*i+1];
bytestr[2] = '\0';
sscanf(bytestr, "%x", &val);
ret[i] = val;