diff --git a/sys/netinet/sctp_constants.h b/sys/netinet/sctp_constants.h index 66f2cca5ab6d..3df6ad6db2aa 100644 --- a/sys/netinet/sctp_constants.h +++ b/sys/netinet/sctp_constants.h @@ -968,7 +968,7 @@ __FBSDID("$FreeBSD$"); #define sctp_sowwakeup(inp, so) \ do { \ if (inp->sctp_flags & SCTP_PCB_FLAGS_DONT_WAKE) { \ - inp->sctp_flags |= SCTP_PCB_FLAGS_WAKEOUTPUT; \ + sctp_pcb_add_flags(inp, SCTP_PCB_FLAGS_WAKEOUTPUT); \ } else { \ sowwakeup(so); \ } \ @@ -977,8 +977,8 @@ do { \ #define sctp_sowwakeup_locked(inp, so) \ do { \ if (inp->sctp_flags & SCTP_PCB_FLAGS_DONT_WAKE) { \ + sctp_pcb_add_flags(inp, SCTP_PCB_FLAGS_WAKEOUTPUT); \ SOCKBUF_UNLOCK(&((so)->so_snd)); \ - inp->sctp_flags |= SCTP_PCB_FLAGS_WAKEOUTPUT; \ } else { \ sowwakeup_locked(so); \ } \ @@ -987,7 +987,7 @@ do { \ #define sctp_sorwakeup(inp, so) \ do { \ if (inp->sctp_flags & SCTP_PCB_FLAGS_DONT_WAKE) { \ - inp->sctp_flags |= SCTP_PCB_FLAGS_WAKEINPUT; \ + sctp_pcb_add_flags(inp, SCTP_PCB_FLAGS_WAKEINPUT); \ } else { \ sorwakeup(so); \ } \ @@ -996,7 +996,7 @@ do { \ #define sctp_sorwakeup_locked(inp, so) \ do { \ if (inp->sctp_flags & SCTP_PCB_FLAGS_DONT_WAKE) { \ - inp->sctp_flags |= SCTP_PCB_FLAGS_WAKEINPUT; \ + sctp_pcb_add_flags(inp, SCTP_PCB_FLAGS_WAKEINPUT); \ SOCKBUF_UNLOCK(&((so)->so_rcv)); \ } else { \ sorwakeup_locked(so); \ diff --git a/sys/netinet/sctp_input.c b/sys/netinet/sctp_input.c index ff16654968d5..46b818c9983e 100644 --- a/sys/netinet/sctp_input.c +++ b/sys/netinet/sctp_input.c @@ -1491,8 +1491,7 @@ sctp_process_cookie_existing(struct mbuf *m, int iphlen, int offset, * init/init-ack/cookie done before the * init-ack came back.. */ - stcb->sctp_ep->sctp_flags |= - SCTP_PCB_FLAGS_CONNECTED; + sctp_pcb_add_flags(stcb->sctp_ep, SCTP_PCB_FLAGS_CONNECTED); soisconnected(stcb->sctp_socket); } /* notify upper layer */ @@ -1689,7 +1688,7 @@ sctp_process_cookie_existing(struct mbuf *m, int iphlen, int offset, if (((stcb->sctp_ep->sctp_flags & SCTP_PCB_FLAGS_TCPTYPE) || (stcb->sctp_ep->sctp_flags & SCTP_PCB_FLAGS_IN_TCPPOOL)) && (!SCTP_IS_LISTENING(inp))) { - stcb->sctp_ep->sctp_flags |= SCTP_PCB_FLAGS_CONNECTED; + sctp_pcb_add_flags(stcb->sctp_ep, SCTP_PCB_FLAGS_CONNECTED); soisconnected(stcb->sctp_socket); } if (SCTP_GET_STATE(stcb) == SCTP_STATE_COOKIE_ECHOED) @@ -2182,7 +2181,7 @@ sctp_process_cookie_new(struct mbuf *m, int iphlen, int offset, * * XXXMJ unlocked */ - stcb->sctp_ep->sctp_flags |= SCTP_PCB_FLAGS_CONNECTED; + sctp_pcb_add_flags(stcb->sctp_ep, SCTP_PCB_FLAGS_CONNECTED); soisconnected(stcb->sctp_socket); } else if ((stcb->sctp_ep->sctp_flags & SCTP_PCB_FLAGS_TCPTYPE) && (SCTP_IS_LISTENING(inp))) { @@ -2793,7 +2792,7 @@ sctp_handle_cookie_ack(struct sctp_cookie_ack_chunk *cp SCTP_UNUSED, sctp_ulp_notify(SCTP_NOTIFY_ASSOC_UP, stcb, 0, NULL, SCTP_SO_NOT_LOCKED); if ((stcb->sctp_ep->sctp_flags & SCTP_PCB_FLAGS_TCPTYPE) || (stcb->sctp_ep->sctp_flags & SCTP_PCB_FLAGS_IN_TCPPOOL)) { - stcb->sctp_ep->sctp_flags |= SCTP_PCB_FLAGS_CONNECTED; + sctp_pcb_add_flags(stcb->sctp_ep, SCTP_PCB_FLAGS_CONNECTED); if ((stcb->asoc.state & SCTP_STATE_CLOSED_SOCKET) == 0) { soisconnected(stcb->sctp_socket); } diff --git a/sys/netinet/sctp_pcb.c b/sys/netinet/sctp_pcb.c index 38c88d8ae8e4..bbbec5385c3c 100644 --- a/sys/netinet/sctp_pcb.c +++ b/sys/netinet/sctp_pcb.c @@ -7067,3 +7067,18 @@ sctp_initiate_iterator(inp_func inpf, /* sa_ignore MEMLEAK {memory is put on the tailq for the iterator} */ return (0); } + +/* + * Atomically add flags to the sctp_flags of an inp. + * To be used when the write lock of the inp is not held. + */ +void +sctp_pcb_add_flags(struct sctp_inpcb *inp, uint32_t flags) +{ + uint32_t old_flags, new_flags; + + do { + old_flags = inp->sctp_flags; + new_flags = old_flags | flags; + } while (atomic_cmpset_int(&inp->sctp_flags, old_flags, new_flags) == 0); +} diff --git a/sys/netinet/sctp_pcb.h b/sys/netinet/sctp_pcb.h index 736b0f9d54e9..687ccf6a1c50 100644 --- a/sys/netinet/sctp_pcb.h +++ b/sys/netinet/sctp_pcb.h @@ -619,6 +619,9 @@ int sctp_swap_inpcb_for_listen(struct sctp_inpcb *inp); void sctp_clean_up_stream(struct sctp_tcb *stcb, struct sctp_readhead *rh); +void + sctp_pcb_add_flags(struct sctp_inpcb *, uint32_t); + /*- * Null in last arg inpcb indicate run on ALL ep's. Specific inp in last arg * indicates run on ONLY assoc's of the specified endpoint. diff --git a/sys/netinet/sctputil.c b/sys/netinet/sctputil.c index 23f95353323f..bdb35b988ae6 100644 --- a/sys/netinet/sctputil.c +++ b/sys/netinet/sctputil.c @@ -4340,7 +4340,7 @@ sctp_abort_notification(struct sctp_tcb *stcb, bool from_peer, bool timeout, if ((stcb->sctp_ep->sctp_flags & SCTP_PCB_FLAGS_IN_TCPPOOL) || ((stcb->sctp_ep->sctp_flags & SCTP_PCB_FLAGS_TCPTYPE) && (stcb->sctp_ep->sctp_flags & SCTP_PCB_FLAGS_CONNECTED))) { - stcb->sctp_ep->sctp_flags |= SCTP_PCB_FLAGS_WAS_ABORTED; + sctp_pcb_add_flags(stcb->sctp_ep, SCTP_PCB_FLAGS_WAS_ABORTED); } if ((stcb->sctp_ep->sctp_flags & SCTP_PCB_FLAGS_SOCKET_GONE) || (stcb->sctp_ep->sctp_flags & SCTP_PCB_FLAGS_SOCKET_ALLGONE) ||