@@ -206,6 +206,7 @@ enum rs_state {
rs_connect_error = 0x0800,
rs_disconnected = 0x1000,
rs_error = 0x2000,
+ rs_shutdown = 0x4000,
};
#define RS_OPT_SWAP_SGL (1 << 0)
@@ -1786,9 +1787,15 @@ static int rs_poll_cq(struct rsocket *rs)
case RS_OP_CTRL:
if (rs_msg_data(msg) == RS_CTRL_DISCONNECT) {
rs->state = rs_disconnected;
+ rshutdown(rs->index, SHUT_RDWR);
return 0;
} else if (rs_msg_data(msg) == RS_CTRL_SHUTDOWN) {
- rs->state &= ~rs_readable;
+ if (rs->state & rs_writable) {
+ rs->state &= ~rs_readable;
+ } else {
+ rs->state = rs_disconnected;
+ return 0;
+ }
}
break;
case RS_OP_WRITE:
@@ -2914,10 +2921,12 @@ static int rs_poll_events(struct pollfd *rfds, struct pollfd *fds, nfds_t nfds)
rs = idm_lookup(&idm, fds[i].fd);
if (rs) {
+ fastlock_acquire(&rs->cq_wait_lock);
if (rs->type == SOCK_STREAM)
rs_get_cq_event(rs);
else
ds_get_cq_event(rs);
+ fastlock_release(&rs->cq_wait_lock);
fds[i].revents = rs_poll_rs(rs, fds[i].events, 1, rs_poll_all);
} else {
fds[i].revents = rfds[i].revents;
@@ -3064,7 +3073,8 @@ int rselect(int nfds, fd_set *readfds, fd_set *writefds,
/*
* For graceful disconnect, notify the remote side that we're
- * disconnecting and wait until all outstanding sends complete.
+ * disconnecting and wait until all outstanding sends complete, provided
+ * that the remote side has not sent a disconnect message.
*/
int rshutdown(int socket, int how)
{
@@ -3072,11 +3082,6 @@ int rshutdown(int socket, int how)
int ctrl, ret = 0;
rs = idm_at(&idm, socket);
- if (how == SHUT_RD) {
- rs->state &= ~rs_readable;
- return 0;
- }
-
if (rs->fd_flags & O_NONBLOCK)
rs_set_nonblocking(rs, 0);
@@ -3084,15 +3089,20 @@ int rshutdown(int socket, int how)
if (how == SHUT_RDWR) {
ctrl = RS_CTRL_DISCONNECT;
rs->state &= ~(rs_readable | rs_writable);
- } else {
+ } else if (how == SHUT_WR) {
rs->state &= ~rs_writable;
ctrl = (rs->state & rs_readable) ?
RS_CTRL_SHUTDOWN : RS_CTRL_DISCONNECT;
+ } else {
+ rs->state &= ~rs_readable;
+ if (rs->state & rs_writable)
+ goto out;
+ ctrl = RS_CTRL_DISCONNECT;
}
if (!rs->ctrl_avail) {
ret = rs_process_cq(rs, 0, rs_conn_can_send_ctrl);
if (ret)
- return ret;
+ goto out;
}
if ((rs->state & rs_connected) && rs->ctrl_avail) {
@@ -3104,10 +3114,19 @@ int rshutdown(int socket, int how)
if (rs->state & rs_connected)
rs_process_cq(rs, 0, rs_conn_all_sends_done);
+out:
if ((rs->fd_flags & O_NONBLOCK) && (rs->state & rs_connected))
rs_set_nonblocking(rs, rs->fd_flags);
- return 0;
+ if (rs->state & rs_disconnected) {
+ /* Generate event by flushing receives to unblock rpoll */
+ ibv_req_notify_cq(rs->cm_id->recv_cq, 0);
+ rdma_disconnect(rs->cm_id);
+ }
+
+ rs->state = rs_shutdown;
+
+ return ret;
}
static void ds_shutdown(struct rsocket *rs)