From b6ff99fb1785219ab37fe65dc2981278f8805a08 Mon Sep 17 00:00:00 2001 From: Subrata Paitandi Date: Tue, 27 Jan 2026 10:47:00 +0000 Subject: [PATCH 1/2] segfault fix --- mssql_python/pybind/connection/connection.cpp | 28 ++++++++++++++++++- mssql_python/pybind/connection/connection.h | 4 +++ mssql_python/pybind/ddbc_bindings.cpp | 27 ++++++++++++------ mssql_python/pybind/ddbc_bindings.h | 6 ++++ 4 files changed, 55 insertions(+), 10 deletions(-) diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index 1fe4d213..d61971b3 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -94,6 +94,19 @@ void Connection::connect(const py::dict& attrs_before) { void Connection::disconnect() { if (_dbcHandle) { LOG("Disconnecting from database"); + + // CRITICAL FIX: Mark all child statement handles as implicitly freed + // When we free the DBC handle below, the ODBC driver will automatically free + // all child STMT handles. We need to tell the SqlHandle objects about this + // so they don't try to free the handles again during their destruction. + LOG("Marking %zu child statement handles as implicitly freed", _childStatementHandles.size()); + for (auto& weakHandle : _childStatementHandles) { + if (auto handle = weakHandle.lock()) { + handle->markImplicitlyFreed(); + } + } + _childStatementHandles.clear(); + SQLRETURN ret = SQLDisconnect_ptr(_dbcHandle->get()); checkError(ret); // triggers SQLFreeHandle via destructor, if last owner @@ -173,7 +186,20 @@ SqlHandlePtr Connection::allocStatementHandle() { SQLHANDLE stmt = nullptr; SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_STMT, _dbcHandle->get(), &stmt); checkError(ret); - return std::make_shared(static_cast(SQL_HANDLE_STMT), stmt); + auto stmtHandle = std::make_shared(static_cast(SQL_HANDLE_STMT), stmt); + + // Track this child handle so we can mark it as implicitly freed when connection closes + // Use weak_ptr to avoid circular references and allow normal cleanup + _childStatementHandles.push_back(stmtHandle); + + // Clean up expired weak_ptrs periodically to avoid unbounded growth + // Remove entries where the weak_ptr is expired (object was already destroyed) + _childStatementHandles.erase( + std::remove_if(_childStatementHandles.begin(), _childStatementHandles.end(), + [](const std::weak_ptr& wp) { return wp.expired(); }), + _childStatementHandles.end()); + + return stmtHandle; } SQLRETURN Connection::setAttribute(SQLINTEGER attribute, py::object value) { diff --git a/mssql_python/pybind/connection/connection.h b/mssql_python/pybind/connection/connection.h index d007106a..6bdb596b 100644 --- a/mssql_python/pybind/connection/connection.h +++ b/mssql_python/pybind/connection/connection.h @@ -61,6 +61,10 @@ class Connection { std::chrono::steady_clock::time_point _lastUsed; std::wstring wstrStringBuffer; // wstr buffer for string attribute setting std::string strBytesBuffer; // string buffer for byte attributes setting + + // Track child statement handles to mark them as implicitly freed when connection closes + // Uses weak_ptr to avoid circular references and allow normal cleanup + std::vector> _childStatementHandles; }; class ConnectionHandle { diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index f49d860a..32baaf0f 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -1144,6 +1144,10 @@ SQLSMALLINT SqlHandle::type() const { return _type; } +void SqlHandle::markImplicitlyFreed() { + _implicitly_freed = true; +} + /* * IMPORTANT: Never log in destructors - it causes segfaults. * During program exit, C++ destructors may run AFTER Python shuts down. @@ -1169,16 +1173,19 @@ void SqlHandle::free() { return; } - // Always clean up ODBC resources, regardless of Python state + // CRITICAL FIX: Check if handle was already implicitly freed by parent handle + // When Connection::disconnect() frees the DBC handle, the ODBC driver automatically + // frees all child STMT handles. We track this state to avoid double-free attempts. + // This approach avoids calling ODBC functions on potentially-freed handles, which + // would cause use-after-free errors. + if (_implicitly_freed) { + _handle = nullptr; // Just clear the pointer, don't call ODBC functions + return; + } + + // Handle is valid and not implicitly freed, proceed with normal freeing SQLFreeHandle_ptr(_type, _handle); _handle = nullptr; - - // Only log if Python is not shutting down (to avoid segfault) - if (!pythonShuttingDown) { - // Don't log during destruction - even in normal cases it can be - // problematic If logging is needed, use explicit close() methods - // instead - } } } @@ -4360,7 +4367,9 @@ PYBIND11_MODULE(ddbc_bindings, m) { .def_readwrite("ddbcErrorMsg", &ErrorInfo::ddbcErrorMsg); py::class_(m, "SqlHandle") - .def("free", &SqlHandle::free, "Free the handle"); + .def("free", &SqlHandle::free, "Free the handle") + .def("markImplicitlyFreed", &SqlHandle::markImplicitlyFreed, + "Mark handle as implicitly freed by parent handle"); py::class_(m, "Connection") .def(py::init(), py::arg("conn_str"), diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 391903ef..87391b55 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -378,10 +378,16 @@ class SqlHandle { SQLHANDLE get() const; SQLSMALLINT type() const; void free(); + + // Mark this handle as implicitly freed (freed by parent handle) + // This prevents double-free attempts when the ODBC driver automatically + // frees child handles (e.g., STMT handles when DBC handle is freed) + void markImplicitlyFreed(); private: SQLSMALLINT _type; SQLHANDLE _handle; + bool _implicitly_freed = false; // Tracks if handle was freed by parent }; using SqlHandlePtr = std::shared_ptr; From b7f105f1297a7b5a18818b5b67f36b9d20f99074 Mon Sep 17 00:00:00 2001 From: Subrata Paitandi Date: Tue, 27 Jan 2026 11:06:46 +0000 Subject: [PATCH 2/2] test addition and linting fix --- mssql_python/pybind/connection/connection.cpp | 17 +- mssql_python/pybind/connection/connection.h | 2 +- mssql_python/pybind/ddbc_bindings.cpp | 11 +- mssql_python/pybind/ddbc_bindings.h | 2 +- ...st_016_connection_invalidation_segfault.py | 351 ++++++++++++++++++ 5 files changed, 366 insertions(+), 17 deletions(-) create mode 100644 tests/test_016_connection_invalidation_segfault.py diff --git a/mssql_python/pybind/connection/connection.cpp b/mssql_python/pybind/connection/connection.cpp index d61971b3..7c5e756a 100644 --- a/mssql_python/pybind/connection/connection.cpp +++ b/mssql_python/pybind/connection/connection.cpp @@ -94,19 +94,20 @@ void Connection::connect(const py::dict& attrs_before) { void Connection::disconnect() { if (_dbcHandle) { LOG("Disconnecting from database"); - + // CRITICAL FIX: Mark all child statement handles as implicitly freed // When we free the DBC handle below, the ODBC driver will automatically free // all child STMT handles. We need to tell the SqlHandle objects about this // so they don't try to free the handles again during their destruction. - LOG("Marking %zu child statement handles as implicitly freed", _childStatementHandles.size()); + LOG("Marking %zu child statement handles as implicitly freed", + _childStatementHandles.size()); for (auto& weakHandle : _childStatementHandles) { if (auto handle = weakHandle.lock()) { handle->markImplicitlyFreed(); } } _childStatementHandles.clear(); - + SQLRETURN ret = SQLDisconnect_ptr(_dbcHandle->get()); checkError(ret); // triggers SQLFreeHandle via destructor, if last owner @@ -187,18 +188,18 @@ SqlHandlePtr Connection::allocStatementHandle() { SQLRETURN ret = SQLAllocHandle_ptr(SQL_HANDLE_STMT, _dbcHandle->get(), &stmt); checkError(ret); auto stmtHandle = std::make_shared(static_cast(SQL_HANDLE_STMT), stmt); - + // Track this child handle so we can mark it as implicitly freed when connection closes // Use weak_ptr to avoid circular references and allow normal cleanup _childStatementHandles.push_back(stmtHandle); - + // Clean up expired weak_ptrs periodically to avoid unbounded growth // Remove entries where the weak_ptr is expired (object was already destroyed) _childStatementHandles.erase( std::remove_if(_childStatementHandles.begin(), _childStatementHandles.end(), [](const std::weak_ptr& wp) { return wp.expired(); }), _childStatementHandles.end()); - + return stmtHandle; } @@ -334,7 +335,7 @@ bool Connection::reset() { disconnect(); return false; } - + // SQL_ATTR_RESET_CONNECTION does NOT reset the transaction isolation level. // Explicitly reset it to the default (SQL_TXN_READ_COMMITTED) to prevent // isolation level settings from leaking between pooled connection usages. @@ -346,7 +347,7 @@ bool Connection::reset() { disconnect(); return false; } - + updateLastUsed(); return true; } diff --git a/mssql_python/pybind/connection/connection.h b/mssql_python/pybind/connection/connection.h index 6bdb596b..4bda21a0 100644 --- a/mssql_python/pybind/connection/connection.h +++ b/mssql_python/pybind/connection/connection.h @@ -61,7 +61,7 @@ class Connection { std::chrono::steady_clock::time_point _lastUsed; std::wstring wstrStringBuffer; // wstr buffer for string attribute setting std::string strBytesBuffer; // string buffer for byte attributes setting - + // Track child statement handles to mark them as implicitly freed when connection closes // Uses weak_ptr to avoid circular references and allow normal cleanup std::vector> _childStatementHandles; diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 32baaf0f..52b9d21e 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -2900,7 +2900,6 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p // Cache decimal separator to avoid repeated system calls - for (SQLSMALLINT i = 1; i <= colCount; ++i) { SQLWCHAR columnName[256]; SQLSMALLINT columnNameLen; @@ -3622,8 +3621,6 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum columnInfos[col].processedColumnSize + 1; // +1 for null terminator } - - // Performance: Build function pointer dispatch table (once per batch) // This eliminates the switch statement from the hot loop - 10,000 rows × 10 // cols reduces from 100,000 switch evaluations to just 10 switch @@ -4040,8 +4037,8 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch lobColumns.push_back(i + 1); // 1-based } } - - // Initialized to 0 for LOB path counter; overwritten by ODBC in non-LOB path; + + // Initialized to 0 for LOB path counter; overwritten by ODBC in non-LOB path; SQLULEN numRowsFetched = 0; // If we have LOBs → fall back to row-by-row fetch + SQLGetData_wrap if (!lobColumns.empty()) { @@ -4073,7 +4070,7 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch LOG("FetchMany_wrap: Error when binding columns - SQLRETURN=%d", ret); return ret; } - + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)(intptr_t)fetchSize, 0); SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0); @@ -4368,7 +4365,7 @@ PYBIND11_MODULE(ddbc_bindings, m) { py::class_(m, "SqlHandle") .def("free", &SqlHandle::free, "Free the handle") - .def("markImplicitlyFreed", &SqlHandle::markImplicitlyFreed, + .def("markImplicitlyFreed", &SqlHandle::markImplicitlyFreed, "Mark handle as implicitly freed by parent handle"); py::class_(m, "Connection") diff --git a/mssql_python/pybind/ddbc_bindings.h b/mssql_python/pybind/ddbc_bindings.h index 87391b55..190c9bd1 100644 --- a/mssql_python/pybind/ddbc_bindings.h +++ b/mssql_python/pybind/ddbc_bindings.h @@ -378,7 +378,7 @@ class SqlHandle { SQLHANDLE get() const; SQLSMALLINT type() const; void free(); - + // Mark this handle as implicitly freed (freed by parent handle) // This prevents double-free attempts when the ODBC driver automatically // frees child handles (e.g., STMT handles when DBC handle is freed) diff --git a/tests/test_016_connection_invalidation_segfault.py b/tests/test_016_connection_invalidation_segfault.py new file mode 100644 index 00000000..b92826aa --- /dev/null +++ b/tests/test_016_connection_invalidation_segfault.py @@ -0,0 +1,351 @@ +""" +Test for connection invalidation segfault scenario (Issue: Use-after-free on statement handles) + +This test reproduces the segfault that occurred in SQLAlchemy's RealReconnectTest +when connection invalidation triggered automatic freeing of child statement handles +by the ODBC driver, followed by Python GC attempting to free the same handles again. + +The fix uses state tracking where Connection marks child handles as "implicitly freed" +before disconnecting, preventing SqlHandle::free() from calling ODBC functions on +already-freed handles. + +Background: +- When Connection::disconnect() frees a DBC handle, ODBC automatically frees all child STMT handles +- Python SqlHandle objects weren't aware of this implicit freeing +- GC later tried to free those handles again via SqlHandle::free(), causing use-after-free +- Fix: Connection tracks children in _childStatementHandles vector and marks them as + implicitly freed before DBC is freed +""" + +import gc +import pytest +from mssql_python import connect, DatabaseError, OperationalError + + +def test_connection_invalidation_with_multiple_cursors(conn_str): + """ + Test connection invalidation scenario that previously caused segfaults. + + This test: + 1. Creates a connection with multiple active cursors + 2. Executes queries on those cursors to create statement handles + 3. Simulates connection invalidation by closing the connection + 4. Forces garbage collection to trigger handle cleanup + 5. Verifies no segfault occurs during the cleanup process + + Previously, this would crash because: + - Closing connection freed the DBC handle + - ODBC driver automatically freed all child STMT handles + - Python GC later tried to free those same STMT handles + - Result: use-after-free crash (segfault) + + With the fix: + - Connection marks all child handles as "implicitly freed" before closing + - SqlHandle::free() checks the flag and skips ODBC calls if true + - Result: No crash, clean shutdown + """ + # Create connection + conn = connect(conn_str) + + # Create multiple cursors with statement handles + cursors = [] + for i in range(5): + cursor = conn.cursor() + cursor.execute("SELECT 1 AS id, 'test' AS name") + cursor.fetchall() # Fetch results to complete the query + cursors.append(cursor) + + # Close connection without explicitly closing cursors first + # This simulates the invalidation scenario where connection is lost + conn.close() + + # Force garbage collection to trigger cursor cleanup + # This is where the segfault would occur without the fix + cursors = None + gc.collect() + + # If we reach here without crashing, the fix is working + assert True + + +def test_connection_invalidation_without_cursor_close(conn_str): + """ + Test that cursors are properly cleaned up when connection is closed + without explicitly closing the cursors. + + This mimics the SQLAlchemy scenario where connection pools may + invalidate connections without first closing all cursors. + """ + conn = connect(conn_str) + + # Create cursors and execute queries + cursor1 = conn.cursor() + cursor1.execute("SELECT 1") + cursor1.fetchone() + + cursor2 = conn.cursor() + cursor2.execute("SELECT 2") + cursor2.fetchone() + + cursor3 = conn.cursor() + cursor3.execute("SELECT 3") + cursor3.fetchone() + + # Close connection with active cursors + conn.close() + + # Trigger GC - should not crash + del cursor1, cursor2, cursor3 + gc.collect() + + assert True + + +def test_repeated_connection_invalidation_cycles(conn_str): + """ + Test repeated connection invalidation cycles to ensure no memory + corruption or handle leaks occur across multiple iterations. + + This stress test simulates the scenario from SQLAlchemy's + RealReconnectTest which ran multiple invalidation tests in sequence. + """ + for iteration in range(10): + # Create connection + conn = connect(conn_str) + + # Create and use cursors + for cursor_num in range(3): + cursor = conn.cursor() + cursor.execute(f"SELECT {iteration} AS iteration, {cursor_num} AS cursor_num") + result = cursor.fetchone() + assert result[0] == iteration + assert result[1] == cursor_num + + # Close connection (invalidate) + conn.close() + + # Force GC after each iteration + gc.collect() + + # Final GC to clean up any remaining references + gc.collect() + assert True + + +def test_connection_close_with_uncommitted_transaction(conn_str): + """ + Test that closing a connection with an uncommitted transaction + properly cleans up statement handles without crashing. + """ + conn = connect(conn_str) + cursor = conn.cursor() + + try: + # Start a transaction + cursor.execute("CREATE TABLE #temp_test (id INT, name VARCHAR(50))") + cursor.execute("INSERT INTO #temp_test VALUES (1, 'test')") + # Don't commit - leave transaction open + + # Close connection without commit or rollback + conn.close() + + # Trigger GC + del cursor + gc.collect() + + assert True + except Exception as e: + pytest.fail(f"Unexpected exception during connection close: {e}") + + +def test_cursor_after_connection_invalidation_raises_error(conn_str): + """ + Test that attempting to use a cursor after connection is closed + raises an appropriate error rather than crashing. + """ + conn = connect(conn_str) + cursor = conn.cursor() + cursor.execute("SELECT 1") + cursor.fetchone() + + # Close connection + conn.close() + + # Attempting to execute on cursor should raise an error, not crash + with pytest.raises((DatabaseError, OperationalError)): + cursor.execute("SELECT 2") + + # GC should not crash + del cursor + gc.collect() + + +def test_multiple_connections_concurrent_invalidation(conn_str): + """ + Test that multiple connections can be invalidated concurrently + without interfering with each other's handle cleanup. + """ + connections = [] + all_cursors = [] + + # Create multiple connections with cursors + for conn_num in range(5): + conn = connect(conn_str) + connections.append(conn) + + for cursor_num in range(3): + cursor = conn.cursor() + cursor.execute(f"SELECT {conn_num} AS conn, {cursor_num} AS cursor") + cursor.fetchone() + all_cursors.append(cursor) + + # Close all connections + for conn in connections: + conn.close() + + # Clear references and force GC + connections = None + all_cursors = None + gc.collect() + + assert True + + +def test_connection_invalidation_with_prepared_statements(conn_str): + """ + Test connection invalidation when cursors have prepared statements. + This ensures statement handles are properly marked as implicitly freed. + """ + conn = connect(conn_str) + + # Create cursor with parameterized query (prepared statement) + cursor = conn.cursor() + cursor.execute("SELECT ? AS value", (42,)) + result = cursor.fetchone() + assert result[0] == 42 + + # Execute another parameterized query + cursor.execute("SELECT ? AS name, ? AS age", ("John", 30)) + result = cursor.fetchone() + assert result[0] == "John" + assert result[1] == 30 + + # Close connection with prepared statements + conn.close() + + # GC should handle cleanup without crash + del cursor + gc.collect() + + assert True + + +def test_verify_markImplicitlyFreed_method_exists(): + """ + Verify that the markImplicitlyFreed method exists on SqlHandle. + This is the core of the segfault fix. + """ + from mssql_python import ddbc_bindings + + # Verify the method exists + assert hasattr( + ddbc_bindings.SqlHandle, "markImplicitlyFreed" + ), "SqlHandle should have markImplicitlyFreed method" + + # Verify free method also exists + assert hasattr(ddbc_bindings.SqlHandle, "free"), "SqlHandle should have free method" + + +def test_connection_invalidation_with_fetchall(conn_str): + """ + Test connection invalidation when cursors have fetched all results. + This ensures all statement handle states are properly cleaned up. + """ + conn = connect(conn_str) + + cursor = conn.cursor() + cursor.execute("SELECT number FROM (VALUES (1), (2), (3), (4), (5)) AS numbers(number)") + results = cursor.fetchall() + assert len(results) == 5 + + # Close connection after fetchall + conn.close() + + # GC cleanup should work without issues + del cursor + gc.collect() + + assert True + + +def test_nested_connection_cursor_cleanup(conn_str): + """ + Test nested connection/cursor creation and cleanup pattern. + This mimics complex application patterns where connections + and cursors are created in nested scopes. + """ + + def inner_function(connection): + cursor = connection.cursor() + cursor.execute("SELECT 'inner' AS scope") + return cursor.fetchone() + + def outer_function(conn_str): + conn = connect(conn_str) + result = inner_function(conn) + conn.close() + return result + + # Run multiple times to ensure no accumulated state issues + for _ in range(5): + result = outer_function(conn_str) + assert result[0] == "inner" + gc.collect() + + # Final cleanup + gc.collect() + assert True + + +if __name__ == "__main__": + # Allow running this test file directly for quick verification + import sys + + if len(sys.argv) > 1: + conn_str = sys.argv[1] + print("Running connection invalidation segfault tests...") + + test_connection_invalidation_with_multiple_cursors(conn_str) + print("✓ test_connection_invalidation_with_multiple_cursors passed") + + test_connection_invalidation_without_cursor_close(conn_str) + print("✓ test_connection_invalidation_without_cursor_close passed") + + test_repeated_connection_invalidation_cycles(conn_str) + print("✓ test_repeated_connection_invalidation_cycles passed") + + test_connection_close_with_uncommitted_transaction(conn_str) + print("✓ test_connection_close_with_uncommitted_transaction passed") + + test_cursor_after_connection_invalidation_raises_error(conn_str) + print("✓ test_cursor_after_connection_invalidation_raises_error passed") + + test_multiple_connections_concurrent_invalidation(conn_str) + print("✓ test_multiple_connections_concurrent_invalidation passed") + + test_connection_invalidation_with_prepared_statements(conn_str) + print("✓ test_connection_invalidation_with_prepared_statements passed") + + test_verify_markImplicitlyFreed_method_exists() + print("✓ test_verify_markImplicitlyFreed_method_exists passed") + + test_connection_invalidation_with_fetchall(conn_str) + print("✓ test_connection_invalidation_with_fetchall passed") + + test_nested_connection_cursor_cleanup(conn_str) + print("✓ test_nested_connection_cursor_cleanup passed") + + print("\n✓✓✓ All connection invalidation segfault tests passed! ✓✓✓") + else: + print("Usage: python test_016_connection_invalidation_segfault.py ") + print("Or run with pytest: pytest test_016_connection_invalidation_segfault.py")