diff --git a/scp.c b/scp.c index 69ebac82..0cb2dee3 100644 --- a/scp.c +++ b/scp.c @@ -77,16 +77,22 @@ static void bump(char *fmt, ...) exit(1); } -static void get_password(const char *prompt, char *str, int maxlen) +static int get_password(const char *prompt, char *str, int maxlen) { HANDLE hin, hout; DWORD savemode, i; if (password) { - strncpy(str, password, maxlen); - str[maxlen-1] = '\0'; - password = NULL; - return; + static int tried_once = 0; + + if (tried_once) { + return 0; + } else { + strncpy(str, password, maxlen); + str[maxlen-1] = '\0'; + tried_once = 1; + return 1; + } } hin = GetStdHandle(STD_INPUT_HANDLE); @@ -107,6 +113,8 @@ static void get_password(const char *prompt, char *str, int maxlen) str[i] = '\0'; WriteFile(hout, "\r\n", 2, &i, NULL); + + return 1; } /* diff --git a/scp.h b/scp.h index 754f65c1..401e38db 100644 --- a/scp.h +++ b/scp.h @@ -9,7 +9,7 @@ /* Exported from ssh.c */ extern int scp_flags; -extern void (*ssh_get_password)(const char *prompt, char *str, int maxlen); +extern int (*ssh_get_password)(const char *prompt, char *str, int maxlen); char * ssh_scp_init(char *host, int port, char *cmd, char **realhost); int ssh_scp_recv(unsigned char *buf, int len); void ssh_scp_send(unsigned char *buf, int len); diff --git a/ssh.c b/ssh.c index a355c297..acc89cfc 100644 --- a/ssh.c +++ b/ssh.c @@ -81,7 +81,7 @@ static SOCKET s = INVALID_SOCKET; static unsigned char session_key[32]; static struct ssh_cipher *cipher = NULL; int scp_flags = 0; -void (*ssh_get_password)(const char *prompt, char *str, int maxlen) = NULL; +int (*ssh_get_password)(const char *prompt, char *str, int maxlen) = NULL; static char *savedhost; @@ -267,8 +267,12 @@ next_packet: static void ssh_gotdata(unsigned char *data, int datalen) { while (datalen > 0) { - if ( s_rdpkt(&data, &datalen) == 0 ) + if ( s_rdpkt(&data, &datalen) == 0 ) { ssh_protocol(NULL, 0, 1); + if (ssh_state == SSH_STATE_CLOSED) { + return; + } + } } } @@ -781,7 +785,17 @@ static int do_ssh_login(unsigned char *in, int inlen, int ispkt) if (IS_SCP) { char prompt[200]; sprintf(prompt, "%s@%s's password: ", cfg.username, savedhost); - ssh_get_password(prompt, password, sizeof(password)); + if (!ssh_get_password(prompt, password, sizeof(password))) { + /* + * get_password failed to get a password (for + * example because one was supplied on the command + * line which has already failed to work). + * Terminate. + */ + logevent("No more passwords to try"); + ssh_state = SSH_STATE_CLOSED; + crReturn(1); + } } else { if (pktin.type == SSH_SMSG_FAILURE && @@ -845,7 +859,8 @@ static int do_ssh_login(unsigned char *in, int inlen, int ispkt) logevent("Authentication refused"); } else if (pktin.type == SSH_MSG_DISCONNECT) { logevent("Received disconnect request"); - crReturn(0); + ssh_state = SSH_STATE_CLOSED; + crReturn(1); } else if (pktin.type != SSH_SMSG_SUCCESS) { fatalbox("Strange packet received, type %d", pktin.type); } @@ -861,8 +876,11 @@ static void ssh_protocol(unsigned char *in, int inlen, int ispkt) { random_init(); - while (!do_ssh_login(in, inlen, ispkt)) + while (!do_ssh_login(in, inlen, ispkt)) { crReturnV; + } + if (ssh_state == SSH_STATE_CLOSED) + crReturnV; if (!cfg.nopty) { send_packet(SSH_CMSG_REQUEST_PTY, @@ -982,6 +1000,11 @@ static int ssh_msg (WPARAM wParam, LPARAM lParam) { return 0; } ssh_gotdata (buf, ret); + if (ssh_state == SSH_STATE_CLOSED) { + closesocket(s); + s = INVALID_SOCKET; + return 0; + } return 1; } return 1; /* shouldn't happen, but WTF */ @@ -1180,7 +1203,13 @@ char *ssh_scp_init(char *host, int port, char *cmd, char **realhost) get_packet(); if (s == INVALID_SOCKET) return "Connection closed by remote host"; - } while (!do_ssh_login(NULL, 0, 1)); + } while (!do_ssh_login(NULL, 0, 1)); + + if (ssh_state == SSH_STATE_CLOSED) { + closesocket(s); + s = INVALID_SOCKET; + return "Session initialisation error"; + } /* Execute command */ sprintf(buf, "Sending command: %.100s", cmd);