Skip to content

Commit

Permalink
UCT/TCP/GTEST: Protect against connection from non-UCX sock-based app
Browse files Browse the repository at this point in the history
  • Loading branch information
dmitrygx committed Jan 8, 2020
1 parent 230daff commit eaec8bd
Show file tree
Hide file tree
Showing 8 changed files with 426 additions and 73 deletions.
2 changes: 1 addition & 1 deletion src/ucs/sys/sock.c
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ ucs_status_t ucs_socket_setopt(int fd, int level, int optname,
return UCS_OK;
}

static const char *ucs_socket_getname_str(int fd, char *str, size_t max_size)
const char *ucs_socket_getname_str(int fd, char *str, size_t max_size)
{
struct sockaddr_storage sock_addr = {0}; /* Suppress Clang false-positive */
socklen_t addr_size;
Expand Down
13 changes: 13 additions & 0 deletions src/ucs/sys/sock.h
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,19 @@ const char* ucs_sockaddr_str(const struct sockaddr *sock_addr,
char *str, size_t max_size);


/**
* Extract the IP address from a given socket fd and return it as a string.
*
* @param [in] fd Socket fd.
* @param [out] str A string filled with the IP address.
* @param [in] max_size Size of a string (considering '\0'-terminated symbol)
*
* @return ip_str if the sock_addr has a valid IP address or 'Invalid address'
* otherwise.
*/
const char *ucs_socket_getname_str(int fd, char *str, size_t max_size);


/**
* Return a value indicating the relationships between passed sockaddr structures.
*
Expand Down
28 changes: 9 additions & 19 deletions src/uct/tcp/tcp.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@

#define UCT_TCP_CONFIG_PREFIX "TCP_"

/* Magic number that is used by TCP to identify its peers */
#define UCT_TCP_MAGIC_NUMBER 0xCAFEBABE12345678lu

/* Maximum number of events to wait on event set */
#define UCT_TCP_MAX_EVENTS 16

Expand Down Expand Up @@ -86,9 +89,13 @@ typedef enum uct_tcp_ep_conn_state {
* After it is done, it sends `UCT_TCP_CM_CONN_REQ` to the peer.
* All AM operations return `UCS_ERR_NO_RESOURCE` error to a caller. */
UCT_TCP_EP_CONN_STATE_CONNECTING,
/* EP is receiving the magic number in order to verify a peer. EP is moved
* to this state after accept() completed. */
UCT_TCP_EP_CONN_STATE_RECV_MAGIC_NUMBER,
/* EP is accepting connection from a peer, i.e. accept() returns socket fd
* on which a connection was accepted, this EP was created using this socket
* and now it is waiting for `UCT_TCP_CM_CONN_REQ` from a peer. */
* fd and the magic number was received and verified by EP and now it is
* waiting for `UCT_TCP_CM_CONN_REQ` from a peer. */
UCT_TCP_EP_CONN_STATE_ACCEPTING,
/* EP is waiting for `UCT_TCP_CM_CONN_ACK` message from a peer after sending
* `UCT_TCP_CM_CONN_REQ`.
Expand Down Expand Up @@ -144,6 +151,7 @@ KHASH_INIT(uct_tcp_cm_eps, struct sockaddr_in, ucs_list_link_t*,
typedef struct uct_tcp_cm_state {
const char *name; /* CM state name */
uct_tcp_ep_progress_t tx_progress; /* TX progress function */
uct_tcp_ep_progress_t rx_progress; /* RX progress function */
} uct_tcp_cm_state_t;


Expand Down Expand Up @@ -386,10 +394,6 @@ void uct_tcp_iface_remove_ep(uct_tcp_ep_t *ep);

ucs_status_t uct_tcp_ep_handle_dropped_connect(uct_tcp_ep_t *ep, int io_errno);

unsigned uct_tcp_ep_progress_am_rx(uct_tcp_ep_t *ep);

unsigned uct_tcp_ep_progress_put_rx(uct_tcp_ep_t *ep);

ucs_status_t uct_tcp_ep_init(uct_tcp_iface_t *iface, int fd,
const struct sockaddr_in *dest_addr,
uct_tcp_ep_t **ep_p);
Expand Down Expand Up @@ -480,20 +484,6 @@ ucs_status_t uct_tcp_cm_handle_incoming_conn(uct_tcp_iface_t *iface,

ucs_status_t uct_tcp_cm_conn_start(uct_tcp_ep_t *ep);

static inline unsigned uct_tcp_ep_progress_tx(uct_tcp_ep_t *ep)
{
return uct_tcp_ep_cm_state[ep->conn_state].tx_progress(ep);
}

static inline unsigned uct_tcp_ep_progress_rx(uct_tcp_ep_t *ep)
{
if (!(ep->ctx_caps & UCS_BIT(UCT_TCP_EP_CTX_TYPE_PUT_RX))) {
return uct_tcp_ep_progress_am_rx(ep);
} else {
return uct_tcp_ep_progress_put_rx(ep);
}
}

static inline void uct_tcp_iface_outstanding_inc(uct_tcp_iface_t *iface)
{
iface->outstanding++;
Expand Down
82 changes: 49 additions & 33 deletions src/uct/tcp/tcp_cm.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ void uct_tcp_cm_change_conn_state(uct_tcp_ep_t *ep,

switch(ep->conn_state) {
case UCT_TCP_EP_CONN_STATE_CONNECTING:
ucs_assertv(iface->config.conn_nb, "ep=%p", ep);
/* Fall through */
case UCT_TCP_EP_CONN_STATE_WAITING_ACK:
if (old_conn_state == UCT_TCP_EP_CONN_STATE_CLOSED) {
uct_tcp_iface_outstanding_inc(iface);
Expand Down Expand Up @@ -61,13 +59,15 @@ void uct_tcp_cm_change_conn_state(uct_tcp_ep_t *ep,
(old_conn_state == UCT_TCP_EP_CONN_STATE_WAITING_ACK) ||
(old_conn_state == UCT_TCP_EP_CONN_STATE_WAITING_REQ)) {
uct_tcp_iface_outstanding_dec(iface);
} else if (old_conn_state == UCT_TCP_EP_CONN_STATE_ACCEPTING) {
} else if ((old_conn_state == UCT_TCP_EP_CONN_STATE_ACCEPTING) ||
(old_conn_state == UCT_TCP_EP_CONN_STATE_RECV_MAGIC_NUMBER)) {
/* Since ep::peer_addr is 0'ed, we have to print w/o peer's address */
full_log = 0;
}
break;
default:
ucs_assert(ep->conn_state == UCT_TCP_EP_CONN_STATE_ACCEPTING);
ucs_assert((ep->conn_state == UCT_TCP_EP_CONN_STATE_ACCEPTING) ||
(ep->conn_state == UCT_TCP_EP_CONN_STATE_RECV_MAGIC_NUMBER));
/* Since ep::peer_addr is 0'ed and client's <address:port>
* has already been logged, print w/o peer's address */
full_log = 0;
Expand Down Expand Up @@ -143,8 +143,9 @@ static void uct_tcp_cm_trace_conn_pkt(const uct_tcp_ep_t *ep,

ucs_status_t uct_tcp_cm_send_event(uct_tcp_ep_t *ep, uct_tcp_cm_conn_event_t event)
{
uct_tcp_iface_t *iface = ucs_derived_of(ep->super.super.iface,
uct_tcp_iface_t);
uct_tcp_iface_t *iface = ucs_derived_of(ep->super.super.iface,
uct_tcp_iface_t);
size_t magic_number_length = 0;
void *pkt_buf;
size_t pkt_length, cm_pkt_length;
uct_tcp_cm_conn_req_pkt_t *conn_pkt;
Expand All @@ -160,20 +161,30 @@ ucs_status_t uct_tcp_cm_send_event(uct_tcp_ep_t *ep, uct_tcp_cm_conn_event_t eve
(ep->conn_state != UCT_TCP_EP_CONN_STATE_CONNECTED),
"ep=%p", ep);

pkt_length = sizeof(*pkt_hdr);
pkt_length = sizeof(*pkt_hdr);
if (event == UCT_TCP_CM_CONN_REQ) {
cm_pkt_length = sizeof(*conn_pkt);
cm_pkt_length = sizeof(*conn_pkt);

if (ep->conn_state == UCT_TCP_EP_CONN_STATE_CONNECTING) {
magic_number_length = sizeof(uint64_t);
}
} else {
cm_pkt_length = sizeof(event);
cm_pkt_length = sizeof(event);
}
pkt_length += cm_pkt_length;
pkt_buf = ucs_alloca(pkt_length);

pkt_hdr = (uct_tcp_am_hdr_t*)pkt_buf;
pkt_length += cm_pkt_length + magic_number_length;
pkt_buf = ucs_alloca(pkt_length);
pkt_hdr = (uct_tcp_am_hdr_t*)(UCS_PTR_BYTE_OFFSET(pkt_buf,
magic_number_length));
pkt_hdr->am_id = UCT_AM_ID_MAX;
pkt_hdr->length = cm_pkt_length;

if (event == UCT_TCP_CM_CONN_REQ) {
if (ep->conn_state == UCT_TCP_EP_CONN_STATE_CONNECTING) {
ucs_assert(magic_number_length == sizeof(uint64_t));
*(uint64_t*)pkt_buf = UCT_TCP_MAGIC_NUMBER;
}

conn_pkt = (uct_tcp_cm_conn_req_pkt_t*)(pkt_hdr + 1);
conn_pkt->event = UCT_TCP_CM_CONN_REQ;
conn_pkt->iface_addr = iface->config.ifaddr;
Expand Down Expand Up @@ -500,27 +511,40 @@ unsigned uct_tcp_cm_handle_conn_pkt(uct_tcp_ep_t **ep_p, void *pkt, uint32_t len
return 0;
}

unsigned uct_tcp_cm_conn_progress(uct_tcp_ep_t *ep)
static ucs_status_t uct_tcp_cm_conn_complete(uct_tcp_ep_t *ep,
unsigned *progress_count_p)
{
ucs_status_t status;

if (!ucs_socket_is_connected(ep->fd)) {
ucs_error("tcp_ep %p: connection establishment for "
"socket fd %d was unsuccessful", ep, ep->fd);
goto err;
}

status = uct_tcp_cm_send_event(ep, UCT_TCP_CM_CONN_REQ);
if (status != UCS_OK) {
return 0;
goto out;
}

uct_tcp_cm_change_conn_state(ep, UCT_TCP_EP_CONN_STATE_WAITING_ACK);
uct_tcp_ep_mod_events(ep, UCS_EVENT_SET_EVREAD, 0);

ucs_assertv((ep->tx.length == 0) && (ep->tx.offset == 0) &&
(ep->tx.buf == NULL), "ep=%p", ep);
return 1;
out:
if (progress_count_p != NULL) {
*progress_count_p = (status == UCS_OK);
}
return status;
}

unsigned uct_tcp_cm_conn_progress(uct_tcp_ep_t *ep)
{
unsigned progress_count;

if (!ucs_socket_is_connected(ep->fd)) {
ucs_error("tcp_ep %p: connection establishment for "
"socket fd %d was unsuccessful", ep, ep->fd);
goto err;
}

uct_tcp_cm_conn_complete(ep, &progress_count);
return progress_count;

err:
uct_tcp_ep_set_failed(ep);
Expand All @@ -539,13 +563,13 @@ ucs_status_t uct_tcp_cm_conn_start(uct_tcp_ep_t *ep)
return UCS_ERR_TIMED_OUT;
}

uct_tcp_cm_change_conn_state(ep, UCT_TCP_EP_CONN_STATE_CONNECTING);

status = ucs_socket_connect(ep->fd, (const struct sockaddr*)&ep->peer_addr);
if (UCS_STATUS_IS_ERR(status)) {
return status;
} else if (status == UCS_INPROGRESS) {
uct_tcp_cm_change_conn_state(ep, UCT_TCP_EP_CONN_STATE_CONNECTING);
uct_tcp_ep_mod_events(ep, UCS_EVENT_SET_EVWRITE, 0);

return UCS_OK;
}

Expand All @@ -558,15 +582,7 @@ ucs_status_t uct_tcp_cm_conn_start(uct_tcp_ep_t *ep)
}
}

status = uct_tcp_cm_send_event(ep, UCT_TCP_CM_CONN_REQ);
if (status != UCS_OK) {
return status;
}

uct_tcp_cm_change_conn_state(ep, UCT_TCP_EP_CONN_STATE_WAITING_ACK);
uct_tcp_ep_mod_events(ep, UCS_EVENT_SET_EVREAD, 0);

return UCS_OK;
return uct_tcp_cm_conn_complete(ep, NULL);
}

/* This function is called from async thread */
Expand Down Expand Up @@ -594,7 +610,7 @@ ucs_status_t uct_tcp_cm_handle_incoming_conn(uct_tcp_iface_t *iface,
return status;
}

uct_tcp_cm_change_conn_state(ep, UCT_TCP_EP_CONN_STATE_ACCEPTING);
uct_tcp_cm_change_conn_state(ep, UCT_TCP_EP_CONN_STATE_RECV_MAGIC_NUMBER);
uct_tcp_ep_mod_events(ep, UCS_EVENT_SET_EVREAD, 0);

ucs_debug("tcp_iface %p: accepted connection from "
Expand Down
Loading

0 comments on commit eaec8bd

Please sign in to comment.