diff --git a/Modules/_ssl.c b/Modules/_ssl.c index 7dd57e7892af41..4c062994d987c3 100644 --- a/Modules/_ssl.c +++ b/Modules/_ssl.c @@ -420,26 +420,6 @@ typedef enum { #define ERRSTR1(x,y,z) (x ":" y ": " z) #define ERRSTR(x) ERRSTR1("_ssl.c", Py_STRINGIFY(__LINE__), x) -// Get the socket from a PySSLSocket, if it has one. -// Return a borrowed reference. -static inline PySocketSockObject* GET_SOCKET(PySSLSocket *obj) { - if (obj->Socket) { - PyObject *sock; - if (PyWeakref_GetRef(obj->Socket, &sock)) { - // GET_SOCKET() returns a borrowed reference - Py_DECREF(sock); - } - else { - // dead weak reference - sock = Py_None; - } - return (PySocketSockObject *)sock; // borrowed reference - } - else { - return NULL; - } -} - /* If sock is NULL, use a timeout of 0 second */ #define GET_SOCKET_TIMEOUT(sock) \ ((sock != NULL) ? (sock)->sock_timeout : 0) @@ -791,6 +771,35 @@ _ssl_deprecated(const char* msg, int stacklevel) { #define PY_SSL_DEPRECATED(name, stacklevel, ret) \ if (_ssl_deprecated((name), (stacklevel)) == -1) return (ret) +// Get the socket from a PySSLSocket, if it has one. +// Stores a strong reference in out_sock. +static int +get_socket(PySSLSocket *obj, PySocketSockObject **out_sock, + const char *filename, int lineno) +{ + if (!obj->Socket) { + *out_sock = NULL; + return 0; + } + PySocketSockObject *sock; + int res = PyWeakref_GetRef(obj->Socket, (PyObject **)&sock); + if (res == 0 || sock->sock_fd == INVALID_SOCKET) { + _setSSLError(get_state_sock(obj), + "Underlying socket connection gone", + PY_SSL_ERROR_NO_SOCKET, filename, lineno); + *out_sock = NULL; + return -1; + } + if (sock != NULL) { + /* just in case the blocking state of the socket has been changed */ + int nonblocking = (sock->sock_timeout >= 0); + BIO_set_nbio(SSL_get_rbio(obj->ssl), nonblocking); + BIO_set_nbio(SSL_get_wbio(obj->ssl), nonblocking); + } + *out_sock = sock; + return res; +} + /* * SSL objects */ @@ -1018,24 +1027,13 @@ _ssl__SSLSocket_do_handshake_impl(PySSLSocket *self) int ret; _PySSLError err; PyObject *exc = NULL; - int sockstate, nonblocking; - PySocketSockObject *sock = GET_SOCKET(self); + int sockstate; PyTime_t timeout, deadline = 0; int has_timeout; - if (sock) { - if (((PyObject*)sock) == Py_None) { - _setSSLError(get_state_sock(self), - "Underlying socket connection gone", - PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__); - return NULL; - } - Py_INCREF(sock); - - /* just in case the blocking state of the socket has been changed */ - nonblocking = (sock->sock_timeout >= 0); - BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); - BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); + PySocketSockObject *sock = NULL; + if (get_socket(self, &sock, __FILE__, __LINE__) < 0) { + return NULL; } timeout = GET_SOCKET_TIMEOUT(sock); @@ -2607,22 +2605,12 @@ _ssl__SSLSocket_sendfile_impl(PySSLSocket *self, int fd, Py_off_t offset, int sockstate; _PySSLError err; PyObject *exc = NULL; - PySocketSockObject *sock = GET_SOCKET(self); PyTime_t timeout, deadline = 0; int has_timeout; - if (sock != NULL) { - if ((PyObject *)sock == Py_None) { - _setSSLError(get_state_sock(self), - "Underlying socket connection gone", - PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__); - return NULL; - } - Py_INCREF(sock); - /* just in case the blocking state of the socket has been changed */ - int nonblocking = (sock->sock_timeout >= 0); - BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); - BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); + PySocketSockObject *sock = NULL; + if (get_socket(self, &sock, __FILE__, __LINE__) < 0) { + return NULL; } timeout = GET_SOCKET_TIMEOUT(sock); @@ -2744,26 +2732,12 @@ _ssl__SSLSocket_write_impl(PySSLSocket *self, Py_buffer *b) int sockstate; _PySSLError err; PyObject *exc = NULL; - int nonblocking; - PySocketSockObject *sock = GET_SOCKET(self); PyTime_t timeout, deadline = 0; int has_timeout; - if (sock != NULL) { - if (((PyObject*)sock) == Py_None) { - _setSSLError(get_state_sock(self), - "Underlying socket connection gone", - PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__); - return NULL; - } - Py_INCREF(sock); - } - - if (sock != NULL) { - /* just in case the blocking state of the socket has been changed */ - nonblocking = (sock->sock_timeout >= 0); - BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); - BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); + PySocketSockObject *sock = NULL; + if (get_socket(self, &sock, __FILE__, __LINE__) < 0) { + return NULL; } timeout = GET_SOCKET_TIMEOUT(sock); @@ -2893,8 +2867,6 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len, int sockstate; _PySSLError err; PyObject *exc = NULL; - int nonblocking; - PySocketSockObject *sock = GET_SOCKET(self); PyTime_t timeout, deadline = 0; int has_timeout; @@ -2903,14 +2875,9 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len, return NULL; } - if (sock != NULL) { - if (((PyObject*)sock) == Py_None) { - _setSSLError(get_state_sock(self), - "Underlying socket connection gone", - PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__); - return NULL; - } - Py_INCREF(sock); + PySocketSockObject *sock = NULL; + if (get_socket(self, &sock, __FILE__, __LINE__) < 0) { + return NULL; } if (!group_right_1) { @@ -2941,13 +2908,6 @@ _ssl__SSLSocket_read_impl(PySSLSocket *self, Py_ssize_t len, } } - if (sock != NULL) { - /* just in case the blocking state of the socket has been changed */ - nonblocking = (sock->sock_timeout >= 0); - BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); - BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); - } - timeout = GET_SOCKET_TIMEOUT(sock); has_timeout = (timeout > 0); if (has_timeout) @@ -3038,26 +2998,14 @@ _ssl__SSLSocket_shutdown_impl(PySSLSocket *self) { _PySSLError err; PyObject *exc = NULL; - int sockstate, nonblocking, ret; + int sockstate, ret; int zeros = 0; - PySocketSockObject *sock = GET_SOCKET(self); PyTime_t timeout, deadline = 0; int has_timeout; - if (sock != NULL) { - /* Guard against closed socket */ - if ((((PyObject*)sock) == Py_None) || (sock->sock_fd == INVALID_SOCKET)) { - _setSSLError(get_state_sock(self), - "Underlying socket connection gone", - PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__); - return NULL; - } - Py_INCREF(sock); - - /* Just in case the blocking state of the socket has been changed */ - nonblocking = (sock->sock_timeout >= 0); - BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking); - BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking); + PySocketSockObject *sock = NULL; + if (get_socket(self, &sock, __FILE__, __LINE__) < 0) { + return NULL; } timeout = GET_SOCKET_TIMEOUT(sock);