Skip to content
53 changes: 53 additions & 0 deletions Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1606,6 +1606,59 @@ def dummycallback(sock, servername, ctx, cycle=ctx):
gc.collect()
self.assertIs(wr(), None)

@unittest.skipUnless(support.Py_GIL_DISABLED,
"test is only useful if the GIL is disabled")
@threading_helper.requires_working_threading()
def test_sni_callback_race(self):
# Replacing sni_callback while handshakes are in-flight must not
# crash (use-after-free on the callback in free-threaded builds).
client_ctx, server_ctx, hostname = testing_context()

server_ctx.sni_callback = lambda *a: None
done = threading.Event()

def do_handshakes():
while not done.is_set():
c_in = ssl.MemoryBIO()
c_out = ssl.MemoryBIO()
s_in = ssl.MemoryBIO()
s_out = ssl.MemoryBIO()
client = client_ctx.wrap_bio(
c_in, c_out, server_hostname=hostname)
server = server_ctx.wrap_bio(s_in, s_out, server_side=True)
for _ in range(50):
try:
client.do_handshake()
except ssl.SSLWantReadError:
pass
except ssl.SSLError:
break
if c_out.pending:
s_in.write(c_out.read())
try:
server.do_handshake()
except ssl.SSLWantReadError:
pass
except ssl.SSLError:
break
if s_out.pending:
c_in.write(s_out.read())

def toggle_callback():
while not done.is_set():
server_ctx.sni_callback = lambda *a: None
server_ctx.sni_callback = None

workers = max(4, (os.cpu_count() or 4) * 2)
threads = [threading.Thread(target=do_handshakes)
for _ in range(workers)]
threads.append(threading.Thread(target=toggle_callback))

with threading_helper.catch_threading_exception() as cm:
with threading_helper.start_threads(threads):
done.set()
self.assertIsNone(cm.exc_value)

def test_cert_store_stats(self):
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
self.assertEqual(ctx.cert_store_stats(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix race condition in :attr:`ssl.SSLContext.sni_callback`
36 changes: 20 additions & 16 deletions Modules/_ssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#define OPENSSL_NO_DEPRECATED 1

#include "Python.h"
#include "pycore_critical_section.h" // Py_BEGIN_CRITICAL_SECTION()
#include "pycore_fileutils.h" // _PyIsSelectable_fd()
#include "pycore_long.h" // _PyLong_UnsignedLongLong_Converter()
#include "pycore_pyerrors.h" // _PyErr_ChainExceptions1()
Expand Down Expand Up @@ -5153,12 +5154,15 @@ _servername_callback(SSL *s, int *al, void *args)
PyObject *result;
/* The high-level ssl.SSLSocket object */
PyObject *ssl_socket;
PyObject *sni_cb;
const char *servername = SSL_get_servername(s, TLSEXT_NAMETYPE_host_name);
PyGILState_STATE gstate = PyGILState_Ensure();

if (sslctx->set_sni_cb == NULL) {
/* remove race condition in this the call back while if removing the
* callback is in progress */
Py_BEGIN_CRITICAL_SECTION(sslctx);
sni_cb = Py_XNewRef(sslctx->set_sni_cb);
Py_END_CRITICAL_SECTION();

if (sni_cb == NULL) {
PyGILState_Release(gstate);
return SSL_TLSEXT_ERR_OK;
}
Expand All @@ -5185,7 +5189,7 @@ _servername_callback(SSL *s, int *al, void *args)
goto error;

if (servername == NULL) {
result = PyObject_CallFunctionObjArgs(sslctx->set_sni_cb, ssl_socket,
result = PyObject_CallFunctionObjArgs(sni_cb, ssl_socket,
Py_None, sslctx, NULL);
}
else {
Expand All @@ -5212,7 +5216,7 @@ _servername_callback(SSL *s, int *al, void *args)
}
Py_DECREF(servername_bytes);
result = PyObject_CallFunctionObjArgs(
sslctx->set_sni_cb, ssl_socket, servername_str,
sni_cb, ssl_socket, servername_str,
sslctx, NULL);
Py_DECREF(servername_str);
}
Expand All @@ -5222,7 +5226,7 @@ _servername_callback(SSL *s, int *al, void *args)
PyErr_FormatUnraisable("Exception ignored "
"in ssl servername callback "
"while calling set SNI callback %R",
sslctx->set_sni_cb);
sni_cb);
*al = SSL_AD_HANDSHAKE_FAILURE;
ret = SSL_TLSEXT_ERR_ALERT_FATAL;
}
Expand All @@ -5247,11 +5251,13 @@ _servername_callback(SSL *s, int *al, void *args)
Py_DECREF(result);
}

Py_DECREF(sni_cb);
PyGILState_Release(gstate);
return ret;

error:
Py_XDECREF(ssl_socket);
Py_XDECREF(sni_cb);
*al = SSL_AD_INTERNAL_ERROR;
ret = SSL_TLSEXT_ERR_ALERT_FATAL;
PyGILState_Release(gstate);
Expand Down Expand Up @@ -5301,20 +5307,18 @@ _ssl__SSLContext_sni_callback_set_impl(PySSLContext *self, PyObject *value)
"sni_callback cannot be set on TLS_CLIENT context");
return -1;
}
Py_CLEAR(self->set_sni_cb);
if (value == Py_None) {
if (!PyCallable_Check(value)) {
SSL_CTX_set_tlsext_servername_callback(self->ctx, NULL);
}
else {
if (!PyCallable_Check(value)) {
SSL_CTX_set_tlsext_servername_callback(self->ctx, NULL);
PyErr_SetString(PyExc_TypeError,
"not a callable object");
Py_CLEAR(self->set_sni_cb);
if (value != Py_None) {
PyErr_SetString(PyExc_TypeError, "not a callable object");
return -1;
}
self->set_sni_cb = Py_NewRef(value);
SSL_CTX_set_tlsext_servername_callback(self->ctx, _servername_callback);
}
else {
Py_XSETREF(self->set_sni_cb, Py_NewRef(value));
SSL_CTX_set_tlsext_servername_arg(self->ctx, self);
SSL_CTX_set_tlsext_servername_callback(self->ctx, _servername_callback);
Comment on lines +5310 to +5321
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes look unnecessary if set_sni_cb is always accessed in a critical section.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately just the critical section didn't help (thread sanitizer kept complaining). Probably because it's getting accessed from within ssl library itself, so we can't guard against that.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Never mind, I think I got it working

}
return 0;
}
Expand Down
Loading