Skip to content
Open
174 changes: 174 additions & 0 deletions Lib/test/test_sqlite3/test_userfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,34 @@ def value(self): return 1 << 65
self.assertRaisesRegex(sqlite.DataError, "string or blob too big",
self.cur.execute, self.query % "err_val_ret")

def test_close_conn_in_window_func_value(self):
# gh-145040: closing connection in window function value() callback.
con = sqlite.connect(":memory:", autocommit=True)
con.execute("CREATE TABLE t(x INTEGER)")
con.executemany("INSERT INTO t VALUES(?)",
[(i,) for i in range(20)])

class CloseConnWindow:
def step(self, value):
pass
def finalize(self):
return 0
def value(self):
con.close()
return 0
def inverse(self, value):
pass

con.create_window_function("evil_win", 1, CloseConnWindow)
msg = "from within a callback"
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
cursor = con.execute(
"SELECT evil_win(x) OVER "
"(ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM t"
)
list(cursor)
con.close()


class AggregateTests(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -723,6 +751,152 @@ def test_agg_keyword_args(self):
'takes exactly 3 positional arguments'):
self.con.create_aggregate("test", 1, aggregate_class=AggrText)

def test_aggr_close_conn_in_step(self):
# Connection.close() in an aggregate step callback must not crash.
con = sqlite.connect(":memory:", autocommit=True)
cur = con.cursor()
cur.execute("CREATE TABLE t(x INTEGER)")
for i in range(50):
cur.execute("INSERT INTO t VALUES (?)", (i,))

class CloseConnAgg:
def __init__(self):
self.total = 0

def step(self, value):
self.total += value
con.close()

def finalize(self):
return self.total

con.create_aggregate("agg_close", 1, CloseConnAgg)
msg = "from within a callback"
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
con.execute("SELECT agg_close(x) FROM t")
con.close()

def test_close_conn_in_nested_callback(self):
# gh-145040: close() must be prevented even in nested callbacks.
con = sqlite.connect(":memory:", autocommit=True)
con.execute("CREATE TABLE t(x INTEGER)")
for i in range(5):
con.execute("INSERT INTO t VALUES(?)", (i,))

def outer_func(x):
con.close()
return x

def inner_func(x):
return x * 10

con.create_function("outer_func", 1, outer_func)
con.create_function("inner_func", 1, inner_func)
msg = "from within a callback"
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
con.execute("SELECT outer_func(inner_func(x)) FROM t")
# Connection must still be usable after the failed close attempt.
self.assertEqual(con.execute("SELECT 1").fetchone(), (1,))
con.close()

def test_close_conn_in_nested_callback_caught(self):
# gh-145040: close attempt must propagate even if the exception
# is caught inside the callback and a nested execute consumes
# the flag.
con = sqlite.connect(":memory:", autocommit=True)
con.execute("CREATE TABLE t(x INTEGER)")
con.execute("INSERT INTO t VALUES(1)")

def swallow_close(x):
try:
con.close()
except sqlite.ProgrammingError:
pass
try:
con.execute("SELECT 1")
except sqlite.ProgrammingError:
pass
return x

con.create_function("swallow_close", 1, swallow_close)
msg = "from within a callback"
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
con.execute("SELECT swallow_close(x) FROM t")
# Connection must still be usable.
self.assertEqual(con.execute("SELECT 1").fetchone(), (1,))
con.close()

def test_close_conn_in_udf_during_executemany(self):
# gh-145040: closing connection in UDF during executemany.
con = sqlite.connect(":memory:", autocommit=True)
con.execute("CREATE TABLE t(x)")

def close_conn(x):
con.close()
return x

con.create_function("close_conn", 1, close_conn)
msg = "from within a callback"
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
con.executemany("INSERT INTO t VALUES(close_conn(?))",
[(i,) for i in range(10)])
con.close()

def test_close_conn_in_progress_handler_during_iternext(self):
# gh-145040: closing connection in progress handler during iteration.
con = sqlite.connect(":memory:", autocommit=True)
con.execute("CREATE TABLE t(x)")
con.executemany("INSERT INTO t VALUES(?)",
[(i,) for i in range(100)])

count = 0
def close_progress():
nonlocal count
count += 1
if count >= 5:
con.close()
return 1
return 0

cursor = con.execute("SELECT * FROM t")
con.set_progress_handler(close_progress, 1)
msg = "from within a callback"
import test.support
with test.support.catch_unraisable_exception():
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
for row in cursor:
pass
del cursor
gc_collect()
con.close()

def test_close_conn_in_collation_callback(self):
# gh-145040: closing connection in collation callback.
con = sqlite.connect(":memory:", autocommit=True)
con.execute("CREATE TABLE t(x TEXT)")
con.executemany("INSERT INTO t VALUES(?)",
[(f"item_{i}",) for i in range(50)])

count = 0
def evil_collation(a, b):
nonlocal count
count += 1
if count == 10:
con.close()
if a < b:
return -1
elif a > b:
return 1
return 0

con.create_collation("evil_coll", evil_collation)
msg = "from within a callback"
with self.assertRaisesRegex(sqlite.ProgrammingError, msg):
con.execute(
"SELECT * FROM t ORDER BY x COLLATE evil_coll"
)
con.close()


class AuthorizerTests(unittest.TestCase):
@staticmethod
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Fixed a crash in the :mod:`sqlite3` module caused by closing the database
connection from within a callback function invoked during
``sqlite3_step()`` (e.g., an aggregate ``step``, a user-defined function
via :meth:`~sqlite3.Connection.create_function`, a progress handler, or a
collation callback). Raise :exc:`~sqlite3.ProgrammingError` instead of
crashing.
40 changes: 39 additions & 1 deletion Modules/_sqlite/connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "pycore_pyerrors.h" // _PyErr_ChainExceptions1()
#include "pycore_pylifecycle.h" // _Py_IsInterpreterFinalizing()
#include "pycore_unicodeobject.h" // _PyUnicode_AsUTF8NoNUL
#include "pycore_weakref.h"

#include <stdbool.h>

Expand Down Expand Up @@ -283,10 +284,17 @@ pysqlite_connection_init_impl(pysqlite_Connection *self, PyObject *database,
goto error;
}

/* Create lists of weak references to blobs */
/* Create lists of weak references to cursors and blobs */
PyObject *cursors = PyList_New(0);
if (cursors == NULL) {
Py_DECREF(statement_cache);
goto error;
}

PyObject *blobs = PyList_New(0);
if (blobs == NULL) {
Py_DECREF(statement_cache);
Py_DECREF(cursors);
goto error;
}

Expand All @@ -299,7 +307,9 @@ pysqlite_connection_init_impl(pysqlite_Connection *self, PyObject *database,
self->check_same_thread = check_same_thread;
self->thread_ident = PyThread_get_thread_ident();
self->statement_cache = statement_cache;
self->cursors = cursors;
self->blobs = blobs;
self->close_attempted_in_callback = 0;
self->row_factory = Py_NewRef(Py_None);
self->text_factory = Py_NewRef(&PyUnicode_Type);
self->trace_ctx = NULL;
Expand Down Expand Up @@ -381,6 +391,7 @@ connection_traverse(PyObject *op, visitproc visit, void *arg)
pysqlite_Connection *self = _pysqlite_Connection_CAST(op);
Py_VISIT(Py_TYPE(self));
Py_VISIT(self->statement_cache);
Py_VISIT(self->cursors);
Py_VISIT(self->blobs);
Py_VISIT(self->row_factory);
Py_VISIT(self->text_factory);
Expand All @@ -405,6 +416,7 @@ connection_clear(PyObject *op)
{
pysqlite_Connection *self = _pysqlite_Connection_CAST(op);
Py_CLEAR(self->statement_cache);
Py_CLEAR(self->cursors);
Py_CLEAR(self->blobs);
Py_CLEAR(self->row_factory);
Py_CLEAR(self->text_factory);
Expand Down Expand Up @@ -655,6 +667,32 @@ pysqlite_connection_close_impl(pysqlite_Connection *self)
return NULL;
}

/* Check if any cursor is locked (actively executing a query);
* closing during a callback is illegal per the SQLite C API docs. */
assert(PyList_CheckExact(self->cursors));
Py_ssize_t n = PyList_GET_SIZE(self->cursors);
for (Py_ssize_t i = 0; i < n; i++) {
PyObject *weakref = PyList_GET_ITEM(self->cursors, i);
if (_PyWeakref_IsDead(weakref)) {
continue;
}
PyObject *obj;
if (!PyWeakref_GetRef(weakref, &obj)) {
continue;
}
int locked = ((pysqlite_Cursor *)obj)->locked;
Py_DECREF(obj);
if (locked) {
self->close_attempted_in_callback = 1;
PyTypeObject *tp = Py_TYPE(self);
pysqlite_state *state = pysqlite_get_state_by_type(tp);
PyErr_SetString(state->ProgrammingError,
"Cannot close the database connection "
"from within a callback function.");
return NULL;
}
}

pysqlite_close_all_blobs(self);
Py_CLEAR(self->statement_cache);
if (connection_close(self) < 0) {
Expand Down
8 changes: 7 additions & 1 deletion Modules/_sqlite/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,18 @@ typedef struct

int initialized;

/* set to 1 when close() is attempted while a cursor is locked (actively
* executing); checked after stmt_step() returns to raise the appropriate
* ProgrammingError */
int close_attempted_in_callback;

/* thread identification of the thread the connection was created in */
unsigned long thread_ident;

PyObject *statement_cache;

/* Lists of weak references to blobs used within this connection */
/* Lists of weak references to cursors and blobs used within this connection */
PyObject *cursors;
PyObject *blobs;

PyObject* row_factory;
Expand Down
Loading
Loading