diff --git a/errsock.c b/errsock.c index c951e27b..14e66456 100644 --- a/errsock.c +++ b/errsock.c @@ -39,7 +39,7 @@ static const char *sk_error_socket_error(Socket *s) return es->error; } -static SocketEndpointInfo *sk_error_peer_info(Socket *s) +static SocketEndpointInfo *sk_error_endpoint_info(Socket *s, bool peer) { return NULL; } @@ -48,7 +48,7 @@ static const SocketVtable ErrorSocket_sockvt = { .plug = sk_error_plug, .close = sk_error_close, .socket_error = sk_error_socket_error, - .peer_info = sk_error_peer_info, + .endpoint_info = sk_error_endpoint_info, /* other methods are NULL */ }; diff --git a/network.h b/network.h index cfe90267..28aa2eea 100644 --- a/network.h +++ b/network.h @@ -34,7 +34,7 @@ struct SocketVtable { void (*set_frozen) (Socket *s, bool is_frozen); /* ignored by tcp, but vital for ssl */ const char *(*socket_error) (Socket *s); - SocketEndpointInfo *(*peer_info) (Socket *s); + SocketEndpointInfo *(*endpoint_info) (Socket *s, bool peer); }; typedef union { void *p; int i; } accept_ctx_t; @@ -297,8 +297,10 @@ static inline void sk_set_frozen(Socket *s, bool is_frozen) * not NULL, then it is dynamically allocated, and should be freed by * a call to sk_free_endpoint_info(). See below for the definition. */ +static inline SocketEndpointInfo *sk_endpoint_info(Socket *s, bool peer) +{ return s->vt->endpoint_info(s, peer); } static inline SocketEndpointInfo *sk_peer_info(Socket *s) -{ return s->vt->peer_info(s); } +{ return sk_endpoint_info(s, true); } /* * The structure returned from sk_endpoint_info, and a function to free diff --git a/proxy/proxy.c b/proxy/proxy.c index d6811f21..ab5b9a71 100644 --- a/proxy/proxy.c +++ b/proxy/proxy.c @@ -415,6 +415,19 @@ SockAddr *name_lookup(const char *host, int port, char **canonicalname, } } +static SocketEndpointInfo *sk_proxy_endpoint_info(Socket *s, bool peer) +{ + ProxySocket *ps = container_of(s, ProxySocket, sock); + + /* We can't reliably find out where we ended up connecting _to_: + * that's at the far end of the proxy, and might be anything. */ + if (peer) + return NULL; + + /* But we can at least tell where we're coming _from_. */ + return sk_endpoint_info(ps->sub_socket, false); +} + static const SocketVtable ProxySocket_sockvt = { .plug = sk_proxy_plug, .close = sk_proxy_close, @@ -423,7 +436,7 @@ static const SocketVtable ProxySocket_sockvt = { .write_eof = sk_proxy_write_eof, .set_frozen = sk_proxy_set_frozen, .socket_error = sk_proxy_socket_error, - .peer_info = NULL, + .endpoint_info = sk_proxy_endpoint_info, }; static const PlugVtable ProxySocket_plugvt = { diff --git a/proxy/sshproxy.c b/proxy/sshproxy.c index 9d3b2318..c04599af 100644 --- a/proxy/sshproxy.c +++ b/proxy/sshproxy.c @@ -123,7 +123,7 @@ static const char *sshproxy_socket_error(Socket *s) return sp->errmsg; } -static SocketEndpointInfo *sshproxy_peer_info(Socket *s) +static SocketEndpointInfo *sshproxy_endpoint_info(Socket *s, bool peer) { return NULL; } @@ -136,7 +136,7 @@ static const SocketVtable SshProxy_sock_vt = { .write_eof = sshproxy_write_eof, .set_frozen = sshproxy_set_frozen, .socket_error = sshproxy_socket_error, - .peer_info = sshproxy_peer_info, + .endpoint_info = sshproxy_endpoint_info, }; static void sshproxy_eventlog(LogPolicy *lp, const char *event) diff --git a/unix/fd-socket.c b/unix/fd-socket.c index 9758a17b..f90d26d5 100644 --- a/unix/fd-socket.c +++ b/unix/fd-socket.c @@ -315,6 +315,11 @@ static void fdsocket_select_result_input_error(int fd, int event) } } +static SocketEndpointInfo *fdsocket_endpoint_info(Socket *s, bool peer) +{ + return NULL; +} + static const SocketVtable FdSocket_sockvt = { .plug = fdsocket_plug, .close = fdsocket_close, @@ -323,7 +328,7 @@ static const SocketVtable FdSocket_sockvt = { .write_eof = fdsocket_write_eof, .set_frozen = fdsocket_set_frozen, .socket_error = fdsocket_socket_error, - .peer_info = NULL, + .endpoint_info = fdsocket_endpoint_info, }; static void fdsocket_connect_success_callback(void *ctx) diff --git a/unix/network.c b/unix/network.c index 2ad1c5af..95ce4b93 100644 --- a/unix/network.c +++ b/unix/network.c @@ -488,7 +488,7 @@ static size_t sk_net_write(Socket *s, const void *data, size_t len); static size_t sk_net_write_oob(Socket *s, const void *data, size_t len); static void sk_net_write_eof(Socket *s); static void sk_net_set_frozen(Socket *s, bool is_frozen); -static SocketEndpointInfo *sk_net_peer_info(Socket *s); +static SocketEndpointInfo *sk_net_endpoint_info(Socket *s, bool peer); static const char *sk_net_socket_error(Socket *s); static const SocketVtable NetSocket_sockvt = { @@ -499,7 +499,7 @@ static const SocketVtable NetSocket_sockvt = { .write_eof = sk_net_write_eof, .set_frozen = sk_net_set_frozen, .socket_error = sk_net_socket_error, - .peer_info = sk_net_peer_info, + .endpoint_info = sk_net_endpoint_info, }; static Socket *sk_net_accept(accept_ctx_t ctx, Plug *plug) @@ -1494,7 +1494,7 @@ static void sk_net_set_frozen(Socket *sock, bool is_frozen) uxsel_tell(s); } -static SocketEndpointInfo *sk_net_peer_info(Socket *sock) +static SocketEndpointInfo *sk_net_endpoint_info(Socket *sock, bool peer) { NetSocket *s = container_of(sock, NetSocket, sock); union sockaddr_union addr; @@ -1504,8 +1504,12 @@ static SocketEndpointInfo *sk_net_peer_info(Socket *sock) #endif SocketEndpointInfo *pi; - if (getpeername(s->s, &addr.sa, &addrlen) < 0) - return NULL; + { + int retd = (peer ? getpeername(s->s, &addr.sa, &addrlen) : + getsockname(s->s, &addr.sa, &addrlen)); + if (retd < 0) + return NULL; + } pi = snew(SocketEndpointInfo); pi->addressfamily = ADDRTYPE_UNSPEC; diff --git a/windows/handle-socket.c b/windows/handle-socket.c index 5a3dd25d..2a5ba166 100644 --- a/windows/handle-socket.c +++ b/windows/handle-socket.c @@ -296,7 +296,7 @@ static const char *sk_handle_socket_error(Socket *s) return hs->error; } -static SocketEndpointInfo *sk_handle_peer_info(Socket *s) +static SocketEndpointInfo *sk_handle_endpoint_info(Socket *s, bool peer) { HandleSocket *hs = container_of(s, HandleSocket, sock); ULONG pid; @@ -304,6 +304,9 @@ static SocketEndpointInfo *sk_handle_peer_info(Socket *s) DECL_WINDOWS_FUNCTION(static, BOOL, GetNamedPipeClientProcessId, (HANDLE, PULONG)); + if (!peer) + return NULL; + if (!kernel32_module) { kernel32_module = load_system32_dll("kernel32.dll"); #if !HAVE_GETNAMEDPIPECLIENTPROCESSID @@ -345,7 +348,7 @@ static const SocketVtable HandleSocket_sockvt = { .write_eof = sk_handle_write_eof, .set_frozen = sk_handle_set_frozen, .socket_error = sk_handle_socket_error, - .peer_info = sk_handle_peer_info, + .endpoint_info = sk_handle_endpoint_info, }; static void sk_handle_connect_success_callback(void *ctx) @@ -431,7 +434,8 @@ static void sk_handle_deferred_set_frozen(Socket *s, bool is_frozen) hs->frozen = is_frozen; } -static SocketEndpointInfo *sk_handle_deferred_peer_info(Socket *s) +static SocketEndpointInfo *sk_handle_deferred_endpoint_info( + Socket *s, bool peer) { return NULL; } @@ -444,7 +448,7 @@ static const SocketVtable HandleSocket_deferred_sockvt = { .write_eof = sk_handle_deferred_write_eof, .set_frozen = sk_handle_deferred_set_frozen, .socket_error = sk_handle_socket_error, - .peer_info = sk_handle_deferred_peer_info, + .endpoint_info = sk_handle_deferred_endpoint_info, }; Socket *make_deferred_handle_socket(DeferredSocketOpener *opener, diff --git a/windows/named-pipe-server.c b/windows/named-pipe-server.c index c1074712..4b0be4db 100644 --- a/windows/named-pipe-server.c +++ b/windows/named-pipe-server.c @@ -63,7 +63,8 @@ static const char *sk_namedpipeserver_socket_error(Socket *s) return ps->error; } -static SocketEndpointInfo *sk_namedpipeserver_peer_info(Socket *s) +static SocketEndpointInfo *sk_namedpipeserver_endpoint_info( + Socket *s, bool peer) { return NULL; } @@ -196,7 +197,7 @@ static const SocketVtable NamedPipeServerSocket_sockvt = { .plug = sk_namedpipeserver_plug, .close = sk_namedpipeserver_close, .socket_error = sk_namedpipeserver_socket_error, - .peer_info = sk_namedpipeserver_peer_info, + .endpoint_info = sk_namedpipeserver_endpoint_info, }; Socket *new_named_pipe_listener(const char *pipename, Plug *plug) diff --git a/windows/network.c b/windows/network.c index f0bf5f9f..94792c85 100644 --- a/windows/network.c +++ b/windows/network.c @@ -209,6 +209,8 @@ DECL_WINDOWS_FUNCTION(static, SOCKET, accept, (SOCKET, struct sockaddr FAR *, int FAR *)); DECL_WINDOWS_FUNCTION(static, int, getpeername, (SOCKET, struct sockaddr FAR *, int FAR *)); +DECL_WINDOWS_FUNCTION(static, int, getsockname, + (SOCKET, struct sockaddr FAR *, int FAR *)); DECL_WINDOWS_FUNCTION(static, int, recv, (SOCKET, char FAR *, int, int)); DECL_WINDOWS_FUNCTION(static, int, WSAIoctl, (SOCKET, DWORD, LPVOID, DWORD, LPVOID, DWORD, @@ -332,6 +334,7 @@ void sk_init(void) GET_WINDOWS_FUNCTION(winsock_module, ioctlsocket); GET_WINDOWS_FUNCTION(winsock_module, accept); GET_WINDOWS_FUNCTION(winsock_module, getpeername); + GET_WINDOWS_FUNCTION(winsock_module, getsockname); GET_WINDOWS_FUNCTION(winsock_module, recv); GET_WINDOWS_FUNCTION(winsock_module, WSAIoctl); @@ -821,7 +824,7 @@ static size_t sk_net_write_oob(Socket *s, const void *data, size_t len); static void sk_net_write_eof(Socket *s); static void sk_net_set_frozen(Socket *s, bool is_frozen); static const char *sk_net_socket_error(Socket *s); -static SocketEndpointInfo *sk_net_peer_info(Socket *s); +static SocketEndpointInfo *sk_net_endpoint_info(Socket *s, bool peer); static const SocketVtable NetSocket_sockvt = { .plug = sk_net_plug, @@ -831,7 +834,7 @@ static const SocketVtable NetSocket_sockvt = { .write_eof = sk_net_write_eof, .set_frozen = sk_net_set_frozen, .socket_error = sk_net_socket_error, - .peer_info = sk_net_peer_info, + .endpoint_info = sk_net_endpoint_info, }; static Socket *sk_net_accept(accept_ctx_t ctx, Plug *plug) @@ -1747,7 +1750,7 @@ static const char *sk_net_socket_error(Socket *sock) return s->error; } -static SocketEndpointInfo *sk_net_peer_info(Socket *sock) +static SocketEndpointInfo *sk_net_endpoint_info(Socket *sock, bool peer) { NetSocket *s = container_of(sock, NetSocket, sock); #ifdef NO_IPV6 @@ -1759,8 +1762,13 @@ static SocketEndpointInfo *sk_net_peer_info(Socket *sock) int addrlen = sizeof(addr); SocketEndpointInfo *pi; - if (p_getpeername(s->s, (struct sockaddr *)&addr, &addrlen) < 0) - return NULL; + { + int retd = (peer ? + p_getpeername(s->s, (struct sockaddr *)&addr, &addrlen) : + p_getsockname(s->s, (struct sockaddr *)&addr, &addrlen)); + if (retd < 0) + return NULL; + } pi = snew(SocketEndpointInfo); pi->addressfamily = ADDRTYPE_UNSPEC;