diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 099043eae0..a910c38d31 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -195,6 +195,10 @@ def _connection_reduce_fn(val,import_fn): _NOT_SET = object() +# TLS session cache defaults +_DEFAULT_TLS_SESSION_CACHE_SIZE = 100 +_DEFAULT_TLS_SESSION_CACHE_TTL = 3600 # 1 hour in seconds + class NoHostAvailable(Exception): """ @@ -875,6 +879,72 @@ def default_retry_policy(self, policy): .. versionadded:: 3.17.0 """ + tls_session_cache_enabled = True + """ + Enable or disable TLS session caching for faster reconnections. + When enabled (default), TLS sessions are cached and reused for subsequent + connections to the same endpoint, reducing handshake latency. + + Set to False to disable session caching entirely. + + .. versionadded:: 3.30.0 + """ + + tls_session_cache_size = _DEFAULT_TLS_SESSION_CACHE_SIZE + """ + Maximum number of TLS sessions to cache. Default is 100. + When the cache is full, the least recently used session is evicted. + + .. versionadded:: 3.30.0 + """ + + tls_session_cache_ttl = _DEFAULT_TLS_SESSION_CACHE_TTL + """ + Time-to-live for cached TLS sessions in seconds. Default is 3600 (1 hour). + Sessions older than this value will not be reused. + + .. versionadded:: 3.30.0 + """ + + tls_session_cache_options = None + """ + Advanced TLS session cache configuration. Can be set to: + + - An instance of :class:`~cassandra.tls.TLSSessionCacheOptions` for + fine-grained control over session caching behavior (e.g., cache_by_host_only option). + - An instance of :class:`~cassandra.tls.TLSSessionCache` (or a custom subclass) + for complete control over session caching implementation. + + If None (default), a cache is created using :attr:`~.tls_session_cache_size` + and :attr:`~.tls_session_cache_ttl` when SSL/TLS is enabled. + + This option takes precedence over the individual tls_session_cache_* parameters. + + Example with options:: + + from cassandra.tls import TLSSessionCacheOptions + + # Cache by host only (ignoring port) + options = TLSSessionCacheOptions( + max_size=200, + ttl=7200, + cache_by_host_only=True + ) + cluster = Cluster(ssl_context=ssl_context, tls_session_cache_options=options) + + Example with custom cache:: + + from cassandra.tls import TLSSessionCache + + class MyCustomCache(TLSSessionCache): + # Custom implementation + pass + + cluster = Cluster(ssl_context=ssl_context, tls_session_cache_options=MyCustomCache()) + + .. versionadded:: 3.30.0 + """ + sockopts = None """ An optional list of tuples which will be used as arguments to @@ -1204,6 +1274,10 @@ def __init__(self, idle_heartbeat_timeout=30, no_compact=False, ssl_context=None, + tls_session_cache_enabled=True, + tls_session_cache_size=_DEFAULT_TLS_SESSION_CACHE_SIZE, + tls_session_cache_ttl=_DEFAULT_TLS_SESSION_CACHE_TTL, + tls_session_cache_options=None, endpoint_factory=None, application_name=None, application_version=None, @@ -1420,6 +1494,33 @@ def __init__(self, self.ssl_options = ssl_options self.ssl_context = ssl_context + self.tls_session_cache_enabled = tls_session_cache_enabled + self.tls_session_cache_size = tls_session_cache_size + self.tls_session_cache_ttl = tls_session_cache_ttl + self.tls_session_cache_options = tls_session_cache_options + + # Initialize TLS session cache if SSL is enabled and caching is enabled + self._tls_session_cache = None + if (ssl_context or ssl_options) and tls_session_cache_enabled: + from cassandra.tls import TLSSessionCache, TLSSessionCacheOptions + + if tls_session_cache_options is not None: + # Check if it's a TLSSessionCache instance (use directly) + # or TLSSessionCacheOptions (use create_cache()) + if isinstance(tls_session_cache_options, TLSSessionCache): + self._tls_session_cache = tls_session_cache_options + else: + # Assume it's TLSSessionCacheOptions + self._tls_session_cache = tls_session_cache_options.create_cache() + else: + # Create default cache from individual parameters + cache_options = TLSSessionCacheOptions( + max_size=tls_session_cache_size, + ttl=tls_session_cache_ttl, + cache_by_host_only=False + ) + self._tls_session_cache = cache_options.create_cache() + self.sockopts = sockopts self.cql_version = cql_version self.max_schema_agreement_wait = max_schema_agreement_wait @@ -1661,6 +1762,7 @@ def _make_connection_kwargs(self, endpoint, kwargs_dict): kwargs_dict.setdefault('sockopts', self.sockopts) kwargs_dict.setdefault('ssl_options', self.ssl_options) kwargs_dict.setdefault('ssl_context', self.ssl_context) + kwargs_dict.setdefault('tls_session_cache', self._tls_session_cache) kwargs_dict.setdefault('cql_version', self.cql_version) kwargs_dict.setdefault('protocol_version', self.protocol_version) kwargs_dict.setdefault('user_type_map', self._user_types) diff --git a/cassandra/connection.py b/cassandra/connection.py index 9ac02c9776..cc51f54618 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -161,6 +161,15 @@ def socket_family(self): """ return socket.AF_UNSPEC + @property + def tls_session_cache_key(self): + """ + Returns the cache key components for TLS session caching. + This is a tuple that uniquely identifies this endpoint for TLS session purposes. + Subclasses may override this to include additional components (e.g., SNI server name). + """ + return (self.address, self.port) + def resolve(self): """ Resolve the endpoint to an address/port. This is called @@ -275,6 +284,14 @@ def port(self): def ssl_options(self): return self._ssl_options + @property + def tls_session_cache_key(self): + """ + Returns the cache key including server_name for SNI endpoints. + This prevents cache collisions when multiple SNI endpoints use the same proxy. + """ + return (self.address, self.port, self._server_name) + def resolve(self): try: resolved_addresses = socket.getaddrinfo(self._proxy_address, self._port, @@ -349,6 +366,14 @@ def port(self): def socket_family(self): return socket.AF_UNIX + @property + def tls_session_cache_key(self): + """ + Returns the cache key for Unix socket endpoints. + Since Unix sockets don't have a port, only the path is used. + """ + return (self._unix_socket_path,) + def resolve(self): return self.address, None @@ -687,6 +712,7 @@ class Connection(object): endpoint = None ssl_options = None ssl_context = None + tls_session_cache = None last_error = None # The current number of operations that are in flight. More precisely, @@ -763,7 +789,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None, ssl_options=None, sockopts=None, compression: Union[bool, str] = True, cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False, user_type_map=None, connect_timeout=None, allow_beta_protocol_version=False, no_compact=False, - ssl_context=None, owning_pool=None, shard_id=None, total_shards=None, + ssl_context=None, tls_session_cache=None, owning_pool=None, shard_id=None, total_shards=None, on_orphaned_stream_released=None, application_info: Optional[ApplicationInfoBase] = None): # TODO next major rename host to endpoint and remove port kwarg. self.endpoint = host if isinstance(host, EndPoint) else DefaultEndPoint(host, port) @@ -771,6 +797,7 @@ def __init__(self, host='127.0.0.1', port=9042, authenticator=None, self.authenticator = authenticator self.ssl_options = ssl_options.copy() if ssl_options else {} self.ssl_context = ssl_context + self.tls_session_cache = tls_session_cache self.sockopts = sockopts self.compression = compression self.cql_version = cql_version @@ -913,7 +940,21 @@ def _wrap_socket_from_context(self): server_hostname = self.endpoint.address opts['server_hostname'] = server_hostname - return self.ssl_context.wrap_socket(self._socket, **opts) + # Try to get a cached TLS session for resumption + # Note: Session resumption works with both TLS 1.2 and TLS 1.3 + # Python's ssl module handles both transparently via SSLSession objects + if self.tls_session_cache: + cached_session = self.tls_session_cache.get_session(self.endpoint) + if cached_session: + opts['session'] = cached_session + log.debug("Using cached TLS session for %s", self.endpoint) + + ssl_socket = self.ssl_context.wrap_socket(self._socket, **opts) + + # Note: Session is NOT stored here - it will be stored after successful connection + # in _connect_socket() to ensure we only cache sessions for successful connections + + return ssl_socket def _initiate_connection(self, sockaddr): if self.features.shard_id is not None: @@ -968,6 +1009,15 @@ def _connect_socket(self): # run that here. if self._check_hostname: self._validate_hostname() + + # Store the TLS session after successful connection + # This ensures we only cache sessions for connections that actually succeeded + if self.tls_session_cache and self.ssl_context and hasattr(self._socket, 'session'): + if self._socket.session: + self.tls_session_cache.set_session(self.endpoint, self._socket.session) + if hasattr(self._socket, 'session_reused') and self._socket.session_reused: + log.debug("TLS session was reused for %s", self.endpoint) + sockerr = None break except socket.error as err: diff --git a/cassandra/io/eventletreactor.py b/cassandra/io/eventletreactor.py index 42874036d5..44193cb4ea 100644 --- a/cassandra/io/eventletreactor.py +++ b/cassandra/io/eventletreactor.py @@ -109,6 +109,13 @@ def _wrap_socket_from_context(self): # This is necessary for SNI self._socket.set_tlsext_host_name(self.ssl_options['server_hostname'].encode('ascii')) + # Apply cached TLS session for resumption (PyOpenSSL) + if self.tls_session_cache: + cached_session = self.tls_session_cache.get_session(self.endpoint) + if cached_session: + self._socket.set_session(cached_session) + log.debug("Using cached TLS session for %s", self.endpoint) + def _initiate_connection(self, sockaddr): if self.uses_legacy_ssl_options: super(EventletConnection, self)._initiate_connection(sockaddr) @@ -116,6 +123,13 @@ def _initiate_connection(self, sockaddr): self._socket.connect(sockaddr) if self.ssl_context or self.ssl_options: self._socket.do_handshake() + # Store TLS session after successful handshake (PyOpenSSL) + if self.tls_session_cache: + session = self._socket.get_session() + if session: + self.tls_session_cache.set_session(self.endpoint, session) + if self._socket.session_reused(): + log.debug("TLS session was reused for %s", self.endpoint) def _match_hostname(self): if self.uses_legacy_ssl_options: diff --git a/cassandra/io/twistedreactor.py b/cassandra/io/twistedreactor.py index e4605a7446..73c18bc4c5 100644 --- a/cassandra/io/twistedreactor.py +++ b/cassandra/io/twistedreactor.py @@ -139,11 +139,12 @@ def _on_loop_timer(self): @implementer(IOpenSSLClientConnectionCreator) class _SSLCreator(object): - def __init__(self, endpoint, ssl_context, ssl_options, check_hostname, timeout): + def __init__(self, endpoint, ssl_context, ssl_options, check_hostname, timeout, tls_session_cache=None): self.endpoint = endpoint self.ssl_options = ssl_options self.check_hostname = check_hostname self.timeout = timeout + self.tls_session_cache = tls_session_cache if ssl_context: self.context = ssl_context @@ -171,11 +172,27 @@ def info_callback(self, connection, where, ret): transport = connection.get_app_data() transport.failVerification(Failure(ConnectionException("Hostname verification failed", self.endpoint))) + # Store TLS session after successful handshake (PyOpenSSL) + if self.tls_session_cache: + session = connection.get_session() + if session: + self.tls_session_cache.set_session(self.endpoint, session) + if connection.session_reused(): + log.debug("TLS session was reused for %s", self.endpoint) + def clientConnectionForTLS(self, tlsProtocol): connection = SSL.Connection(self.context, None) connection.set_app_data(tlsProtocol) if self.ssl_options and "server_hostname" in self.ssl_options: connection.set_tlsext_host_name(self.ssl_options['server_hostname'].encode('ascii')) + + # Apply cached TLS session for resumption (PyOpenSSL) + if self.tls_session_cache: + cached_session = self.tls_session_cache.get_session(self.endpoint) + if cached_session: + connection.set_session(cached_session) + log.debug("Using cached TLS session for %s", self.endpoint) + return connection @@ -241,6 +258,7 @@ def add_connection(self): self.ssl_options, self._check_hostname, self.connect_timeout, + tls_session_cache=self.tls_session_cache, ) endpoint = SSL4ClientEndpoint( diff --git a/cassandra/tls.py b/cassandra/tls.py new file mode 100644 index 0000000000..2e8c94a559 --- /dev/null +++ b/cassandra/tls.py @@ -0,0 +1,246 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +TLS session caching implementation for faster reconnections. +""" + +from abc import ABC, abstractmethod +from collections import OrderedDict, namedtuple +from threading import RLock +import time + + +# Named tuple for TLS session cache entries +_SessionCacheEntry = namedtuple('_SessionCacheEntry', ['session', 'timestamp']) + + +class TLSSessionCache(ABC): + """ + Abstract base class for TLS session caching. + + Implementations should provide thread-safe caching of TLS sessions + to enable session resumption for faster reconnections. + """ + + @abstractmethod + def get_session(self, endpoint): + """ + Get a cached TLS session for the given endpoint. + + Args: + endpoint: The EndPoint object representing the connection target + + Returns: + ssl.SSLSession object if a valid cached session exists, None otherwise + """ + pass + + @abstractmethod + def set_session(self, endpoint, session): + """ + Store a TLS session for the given endpoint. + + Args: + endpoint: The EndPoint object representing the connection target + session: The ssl.SSLSession object to cache + """ + pass + + @abstractmethod + def clear_expired(self): + """Remove all expired sessions from the cache.""" + pass + + @abstractmethod + def clear(self): + """Clear all sessions from the cache.""" + pass + + @abstractmethod + def size(self): + """Return the current number of cached sessions.""" + pass + + +class DefaultTLSSessionCache(TLSSessionCache): + """ + Default implementation of TLS session caching. + + This cache stores TLS sessions per endpoint to allow quick TLS + renegotiation when reconnecting to the same server. Sessions are + automatically expired after a TTL and the cache has a maximum + size with LRU eviction using OrderedDict. + + TLS session resumption works with both TLS 1.2 and TLS 1.3: + - TLS 1.2: Session IDs (RFC 5246) and optionally Session Tickets (RFC 5077) + - TLS 1.3: Session Tickets (RFC 8446) + + Python's ssl.SSLSession API handles both versions transparently, so no + version-specific checks are needed. + """ + + # Cleanup expired sessions every N set_session calls + _EXPIRY_CLEANUP_INTERVAL = 100 + + def __init__(self, max_size=100, ttl=3600, cache_by_host_only=False): + """ + Initialize the TLS session cache. + + Args: + max_size: Maximum number of sessions to cache (default: 100) + ttl: Time-to-live for cached sessions in seconds (default: 3600) + cache_by_host_only: If True, cache sessions by host only (ignoring port). + If False, cache by host and port (default: False) + """ + self._sessions = OrderedDict() # OrderedDict for O(1) LRU eviction + self._lock = RLock() + self._max_size = max_size + self._ttl = ttl + self._cache_by_host_only = cache_by_host_only + self._operation_count = 0 # Counter for opportunistic cleanup + + def _make_key(self, endpoint): + """ + Create a cache key from endpoint. + + Uses the endpoint's tls_session_cache_key property which returns + appropriate components for each endpoint type (e.g., includes + server_name for SNI endpoints to prevent cache collisions). + """ + key = endpoint.tls_session_cache_key + if self._cache_by_host_only: + # When caching by host only, use just the first component (address/path) + return (key[0],) + else: + return key + + def get_session(self, endpoint): + """ + Get a cached TLS session for the given endpoint. + + Args: + endpoint: The EndPoint object representing the connection target + + Returns: + ssl.SSLSession object if a valid cached session exists, None otherwise + """ + key = self._make_key(endpoint) + with self._lock: + if key not in self._sessions: + return None + + entry = self._sessions[key] + + # Check if session has expired + if time.time() - entry.timestamp > self._ttl: + del self._sessions[key] + return None + + # Move to end to mark as recently used (LRU) + self._sessions.move_to_end(key) + return entry.session + + def set_session(self, endpoint, session): + """ + Store a TLS session for the given endpoint. + + Args: + endpoint: The EndPoint object representing the connection target + session: The ssl.SSLSession object to cache + """ + if session is None: + return + + key = self._make_key(endpoint) + current_time = time.time() + + with self._lock: + # Opportunistically clean up expired sessions periodically + self._operation_count += 1 + if self._operation_count >= self._EXPIRY_CLEANUP_INTERVAL: + self._operation_count = 0 + self._clear_expired_unlocked(current_time) + + # If key already exists, just update it + if key in self._sessions: + self._sessions[key] = _SessionCacheEntry(session, current_time) + self._sessions.move_to_end(key) + return + + # If cache is at max size, remove least recently used entry (first item) + if len(self._sessions) >= self._max_size: + self._sessions.popitem(last=False) + + # Store session with creation time + self._sessions[key] = _SessionCacheEntry(session, current_time) + + def _clear_expired_unlocked(self, current_time=None): + """Remove all expired sessions (must be called with lock held).""" + if current_time is None: + current_time = time.time() + expired_keys = [ + key for key, entry in self._sessions.items() + if current_time - entry.timestamp > self._ttl + ] + for key in expired_keys: + del self._sessions[key] + + def clear_expired(self): + """Remove all expired sessions from the cache.""" + with self._lock: + self._clear_expired_unlocked() + + def clear(self): + """Clear all sessions from the cache.""" + with self._lock: + self._sessions.clear() + + def size(self): + """Return the current number of cached sessions.""" + with self._lock: + return len(self._sessions) + + +class TLSSessionCacheOptions: + """ + Default implementation of TLS session cache configuration options. + """ + + def __init__(self, max_size=100, ttl=3600, cache_by_host_only=False): + """ + Initialize TLS session cache options. + + Args: + max_size: Maximum number of sessions to cache (default: 100) + ttl: Time-to-live for cached sessions in seconds (default: 3600) + cache_by_host_only: If True, cache sessions by host only (ignoring port). + If False, cache by host and port (default: False) + """ + self.max_size = max_size + self.ttl = ttl + self.cache_by_host_only = cache_by_host_only + + def create_cache(self): + """ + Build and return a DefaultTLSSessionCache implementation. + + Returns: + DefaultTLSSessionCache: A configured session cache instance + """ + return DefaultTLSSessionCache( + max_size=self.max_size, + ttl=self.ttl, + cache_by_host_only=self.cache_by_host_only + ) diff --git a/docs/security.rst b/docs/security.rst index 57e2be71da..cc86d4cef1 100644 --- a/docs/security.rst +++ b/docs/security.rst @@ -402,3 +402,115 @@ then you can do a proxy execute... s.execute('select * from k.t;', execute_as='user1') # the request will be executed as 'user1' Please see the `official documentation `_ for more details on the feature and configuration process. + +TLS Session Resumption +---------------------- + +.. versionadded:: 3.30.0 + +The driver automatically caches TLS sessions to enable session resumption for faster reconnections. +When a TLS connection is established, the session is cached and can be reused for subsequent +connections to the same endpoint, reducing handshake latency and CPU usage. + +**TLS Version Support**: Session resumption works with both TLS 1.2 and TLS 1.3. TLS 1.2 uses +Session IDs and optionally Session Tickets (RFC 5077), while TLS 1.3 uses Session Tickets (RFC 8446) +as the primary mechanism. Python's ``ssl.SSLSession`` API handles both versions transparently. + +Session caching is **enabled by default** when SSL/TLS is configured and applies to the following +connection classes: + +* :class:`~cassandra.io.asyncorereactor.AsyncoreConnection` (default) +* :class:`~cassandra.io.libevreactor.LibevConnection` +* :class:`~cassandra.io.asyncioreactor.AsyncioConnection` +* :class:`~cassandra.io.geventreactor.GeventConnection` (when not using SSL) + +.. note:: + Session caching is not currently supported for PyOpenSSL-based reactors + (:class:`~cassandra.io.twistedreactor.TwistedConnection`, + :class:`~cassandra.io.eventletreactor.EventletConnection`) but may be added in a future release. + +Configuration +^^^^^^^^^^^^^ + +TLS session caching is controlled by three cluster-level parameters: + +* :attr:`~.Cluster.tls_session_cache_enabled` - Enable or disable session caching (default: ``True``) +* :attr:`~.Cluster.tls_session_cache_size` - Maximum number of sessions to cache (default: ``100``) +* :attr:`~.Cluster.tls_session_cache_ttl` - Time-to-live for cached sessions in seconds (default: ``3600``) + +Example with default settings (session caching enabled): + +.. code-block:: python + + from cassandra.cluster import Cluster + import ssl + + ssl_context = ssl.create_default_context(cafile='/path/to/ca.crt') + cluster = Cluster( + contact_points=['127.0.0.1'], + ssl_context=ssl_context + ) + session = cluster.connect() + +Example with custom cache settings: + +.. code-block:: python + + from cassandra.cluster import Cluster + import ssl + + ssl_context = ssl.create_default_context(cafile='/path/to/ca.crt') + cluster = Cluster( + contact_points=['127.0.0.1'], + ssl_context=ssl_context, + tls_session_cache_size=200, # Cache up to 200 sessions + tls_session_cache_ttl=7200 # Sessions expire after 2 hours + ) + session = cluster.connect() + +Example with session caching disabled: + +.. code-block:: python + + from cassandra.cluster import Cluster + import ssl + + ssl_context = ssl.create_default_context(cafile='/path/to/ca.crt') + cluster = Cluster( + contact_points=['127.0.0.1'], + ssl_context=ssl_context, + tls_session_cache_enabled=False + ) + session = cluster.connect() + +How It Works +^^^^^^^^^^^^ + +When session caching is enabled: + +1. The first connection to an endpoint establishes a new TLS session and caches it +2. Subsequent connections to the same endpoint reuse the cached session +3. Sessions are cached per endpoint (host:port combination) +4. Sessions expire after the configured TTL +5. When the cache reaches max size, the least recently used session is evicted + +Performance Benefits +^^^^^^^^^^^^^^^^^^^^ + +TLS session resumption is a standard TLS feature that provides performance benefits: + +* **Faster reconnection times** - Reduced handshake latency by reusing cached sessions +* **Lower CPU usage** - Fewer cryptographic operations during reconnection +* **Better overall throughput** - Especially beneficial for workloads with frequent reconnections + +The actual performance improvement depends on various factors including network latency, +server configuration, and workload characteristics. + +Security Considerations +^^^^^^^^^^^^^^^^^^^^^^^ + +* Sessions are stored in memory only and never persisted to disk +* Sessions are cached per cluster and not shared across different cluster instances +* Sessions for one endpoint are never used for a different endpoint +* Hostname verification still occurs on each connection, even when reusing sessions +* Sessions automatically expire after the configured TTL diff --git a/tests/integration/long/test_ssl.py b/tests/integration/long/test_ssl.py index 56dc6a5c2d..6342afe24b 100644 --- a/tests/integration/long/test_ssl.py +++ b/tests/integration/long/test_ssl.py @@ -500,3 +500,107 @@ def test_can_connect_with_sslcontext_default_context(self): """ ssl_context = ssl.create_default_context(cafile=CLIENT_CA_CERTS) validate_ssl_options(ssl_context=ssl_context) + + @unittest.skipIf(USES_PYOPENSSL, "This test is for the built-in ssl.Context") + def test_tls_session_cache_enabled_by_default(self): + """ + Test that TLS session caching is enabled by default when SSL is configured. + + @since 3.30.0 + @expected_result TLS session cache is created and configured + @test_category connection:ssl + """ + ssl_context = ssl.create_default_context(cafile=CLIENT_CA_CERTS) + cluster = TestCluster( + contact_points=[DefaultEndPoint('127.0.0.1')], + ssl_context=ssl_context + ) + + # Verify session cache was created + self.assertIsNotNone(cluster._tls_session_cache) + self.assertEqual(cluster.tls_session_cache_enabled, True) + self.assertEqual(cluster.tls_session_cache_size, 100) + self.assertEqual(cluster.tls_session_cache_ttl, 3600) + + cluster.shutdown() + + @unittest.skipIf(USES_PYOPENSSL, "This test is for the built-in ssl.Context") + def test_tls_session_cache_can_be_disabled(self): + """ + Test that TLS session caching can be disabled. + + @since 3.30.0 + @expected_result TLS session cache is not created when disabled + @test_category connection:ssl + """ + ssl_context = ssl.create_default_context(cafile=CLIENT_CA_CERTS) + cluster = TestCluster( + contact_points=[DefaultEndPoint('127.0.0.1')], + ssl_context=ssl_context, + tls_session_cache_enabled=False + ) + + # Verify session cache was not created + self.assertIsNone(cluster._tls_session_cache) + self.assertEqual(cluster.tls_session_cache_enabled, False) + + cluster.shutdown() + + @unittest.skipIf(USES_PYOPENSSL, "This test is for the built-in ssl.Context") + def test_tls_session_reuse(self): + """ + Test that TLS sessions are reused across multiple connections to the same endpoint. + + @since 3.30.0 + @expected_result Sessions are cached and reused, reducing handshake overhead + @test_category connection:ssl + """ + ssl_context = ssl.create_default_context(cafile=CLIENT_CA_CERTS) + cluster = TestCluster( + contact_points=[DefaultEndPoint('127.0.0.1')], + ssl_context=ssl_context + ) + + try: + session = cluster.connect(wait_for_all_pools=True) + + # Verify session cache was populated + self.assertIsNotNone(cluster._tls_session_cache) + initial_cache_size = cluster._tls_session_cache.size() + self.assertGreater(initial_cache_size, 0, "Session cache should contain sessions after connection") + + # Execute a simple query + result = session.execute("SELECT * FROM system.local WHERE key='local'") + self.assertIsNotNone(result) + + # Get a connection from the pool to check session_reused flag + # Note: We can't easily check the exact connection that was reused, + # but we can verify the cache has sessions + cache_size = cluster._tls_session_cache.size() + self.assertGreater(cache_size, 0, "Session cache should contain sessions") + + finally: + cluster.shutdown() + + @unittest.skipIf(USES_PYOPENSSL, "This test is for the built-in ssl.Context") + def test_tls_session_cache_configuration(self): + """ + Test that TLS session cache can be configured with custom parameters. + + @since 3.30.0 + @expected_result Custom cache configuration is applied + @test_category connection:ssl + """ + ssl_context = ssl.create_default_context(cafile=CLIENT_CA_CERTS) + cluster = TestCluster( + contact_points=[DefaultEndPoint('127.0.0.1')], + ssl_context=ssl_context, + tls_session_cache_size=50, + tls_session_cache_ttl=1800 + ) + + self.assertIsNotNone(cluster._tls_session_cache) + self.assertEqual(cluster.tls_session_cache_size, 50) + self.assertEqual(cluster.tls_session_cache_ttl, 1800) + + cluster.shutdown() diff --git a/tests/unit/io/test_eventletreactor.py b/tests/unit/io/test_eventletreactor.py index d3962196a4..75257358b0 100644 --- a/tests/unit/io/test_eventletreactor.py +++ b/tests/unit/io/test_eventletreactor.py @@ -75,3 +75,175 @@ def _timers(self): # There is no unpatching because there is not a clear way # of doing it reliably + + +try: + from eventlet.green.OpenSSL import SSL as _ + _HAS_EVENTLET_PYOPENSSL = True +except ImportError: + _HAS_EVENTLET_PYOPENSSL = False + + +@notpypy +@unittest.skipIf(skip_condition, "Skipping the eventlet tests because it's not installed") +@unittest.skipIf(not _HAS_EVENTLET_PYOPENSSL, "PyOpenSSL not available for eventlet") +class EventletTLSSessionCacheTest(unittest.TestCase): + """Test TLS session caching for EventletConnection with PyOpenSSL.""" + + @classmethod + def setUpClass(cls): + if skip_condition: + return + import eventlet + eventlet.sleep() + monkey_patch() + EventletConnection.initialize_reactor() + + def test_wrap_socket_applies_cached_session(self): + """Test that _wrap_socket_from_context applies cached TLS session.""" + from unittest.mock import Mock, MagicMock + from cassandra.connection import DefaultEndPoint + + # Create mock objects + mock_cache = Mock() + mock_session = Mock() + mock_cache.get_session.return_value = mock_session + + mock_ssl_context = MagicMock() + mock_ssl_connection = MagicMock() + + endpoint = DefaultEndPoint('127.0.0.1', 9042) + + with patch('eventlet.green.socket.socket'): + with patch.object(EventletConnection, '_connect_socket'): + with patch.object(EventletConnection, '_send_options_message'): + conn = EventletConnection( + endpoint, + cql_version='3.0.1', + connect_timeout=5 + ) + conn.ssl_context = mock_ssl_context + conn.ssl_options = {} + conn.tls_session_cache = mock_cache + conn._socket = Mock() + + # Patch SSL.Connection to return our mock + with patch('cassandra.io.eventletreactor.SSL.Connection', return_value=mock_ssl_connection): + conn._wrap_socket_from_context() + + # Verify get_session was called with endpoint + mock_cache.get_session.assert_called_once_with(endpoint) + + # Verify set_session was called on the SSL connection + mock_ssl_connection.set_session.assert_called_once_with(mock_session) + + def test_wrap_socket_no_session_when_cache_empty(self): + """Test that _wrap_socket_from_context handles empty cache.""" + from unittest.mock import Mock, MagicMock + from cassandra.connection import DefaultEndPoint + + mock_cache = Mock() + mock_cache.get_session.return_value = None # No cached session + + mock_ssl_context = MagicMock() + mock_ssl_connection = MagicMock() + + endpoint = DefaultEndPoint('127.0.0.1', 9042) + + with patch('eventlet.green.socket.socket'): + with patch.object(EventletConnection, '_connect_socket'): + with patch.object(EventletConnection, '_send_options_message'): + conn = EventletConnection( + endpoint, + cql_version='3.0.1', + connect_timeout=5 + ) + conn.ssl_context = mock_ssl_context + conn.ssl_options = {} + conn.tls_session_cache = mock_cache + conn._socket = Mock() + + with patch('cassandra.io.eventletreactor.SSL.Connection', return_value=mock_ssl_connection): + conn._wrap_socket_from_context() + + # Verify get_session was called + mock_cache.get_session.assert_called_once_with(endpoint) + + # Verify set_session was NOT called on SSL connection (no cached session) + mock_ssl_connection.set_session.assert_not_called() + + def test_initiate_connection_stores_session_after_handshake(self): + """Test that _initiate_connection stores session after successful handshake.""" + from unittest.mock import Mock, MagicMock + from cassandra.connection import DefaultEndPoint + + mock_cache = Mock() + mock_session = Mock() + + mock_ssl_socket = MagicMock() + mock_ssl_socket.get_session.return_value = mock_session + mock_ssl_socket.session_reused.return_value = False + + endpoint = DefaultEndPoint('127.0.0.1', 9042) + + with patch('eventlet.green.socket.socket'): + with patch.object(EventletConnection, '_connect_socket'): + with patch.object(EventletConnection, '_send_options_message'): + conn = EventletConnection( + endpoint, + cql_version='3.0.1', + connect_timeout=5 + ) + conn.ssl_context = Mock() + conn.ssl_options = {} + conn.tls_session_cache = mock_cache + conn._socket = mock_ssl_socket + conn.uses_legacy_ssl_options = False + + sockaddr = ('127.0.0.1', 9042) + conn._initiate_connection(sockaddr) + + # Verify handshake was called + mock_ssl_socket.do_handshake.assert_called_once() + + # Verify session was retrieved and stored + mock_ssl_socket.get_session.assert_called_once() + mock_cache.set_session.assert_called_once_with(endpoint, mock_session) + + def test_initiate_connection_logs_session_reuse(self): + """Test that _initiate_connection logs when session is reused.""" + from unittest.mock import Mock, MagicMock + from cassandra.connection import DefaultEndPoint + + mock_cache = Mock() + mock_session = Mock() + + mock_ssl_socket = MagicMock() + mock_ssl_socket.get_session.return_value = mock_session + mock_ssl_socket.session_reused.return_value = True # Session was reused + + endpoint = DefaultEndPoint('127.0.0.1', 9042) + + with patch('eventlet.green.socket.socket'): + with patch.object(EventletConnection, '_connect_socket'): + with patch.object(EventletConnection, '_send_options_message'): + conn = EventletConnection( + endpoint, + cql_version='3.0.1', + connect_timeout=5 + ) + conn.ssl_context = Mock() + conn.ssl_options = {} + conn.tls_session_cache = mock_cache + conn._socket = mock_ssl_socket + conn.uses_legacy_ssl_options = False + + with patch('cassandra.io.eventletreactor.log') as mock_log: + sockaddr = ('127.0.0.1', 9042) + conn._initiate_connection(sockaddr) + + # Verify session_reused was checked + mock_ssl_socket.session_reused.assert_called_once() + + # Verify debug log was called for session reuse + mock_log.debug.assert_called() diff --git a/tests/unit/io/test_twistedreactor.py b/tests/unit/io/test_twistedreactor.py index 54abe884ae..6151d3e981 100644 --- a/tests/unit/io/test_twistedreactor.py +++ b/tests/unit/io/test_twistedreactor.py @@ -188,3 +188,239 @@ def test_push(self, mock_connectTCP): self.obj_ut.push('123 pickup') self.mock_reactor_cft.assert_called_with( transport_mock.write, '123 pickup') + + +try: + from OpenSSL import SSL as PyOpenSSL + _HAS_PYOPENSSL = True +except ImportError: + _HAS_PYOPENSSL = False + + +@unittest.skipIf(twistedreactor is None, "Twisted libraries not available") +@unittest.skipIf(not _HAS_PYOPENSSL, "PyOpenSSL not available") +class TestSSLCreatorTLSSessionCache(unittest.TestCase): + """Test TLS session caching for _SSLCreator with PyOpenSSL.""" + + def setUp(self): + twistedreactor.TwistedConnection.initialize_reactor() + + def tearDown(self): + loop = twistedreactor.TwistedConnection._loop + if loop and not loop._reactor_stopped(): + loop._cleanup() + + def test_client_connection_applies_cached_session(self): + """Test that clientConnectionForTLS applies cached TLS session.""" + mock_cache = Mock() + mock_session = Mock() + mock_cache.get_session.return_value = mock_session + + mock_ssl_context = Mock() + mock_ssl_connection = Mock() + + endpoint = DefaultEndPoint('127.0.0.1', 9042) + + with patch('cassandra.io.twistedreactor.SSL.Connection', return_value=mock_ssl_connection): + creator = twistedreactor._SSLCreator( + endpoint=endpoint, + ssl_context=mock_ssl_context, + ssl_options={}, + check_hostname=False, + timeout=5, + tls_session_cache=mock_cache + ) + + mock_tls_protocol = Mock() + creator.clientConnectionForTLS(mock_tls_protocol) + + # Verify get_session was called with endpoint + mock_cache.get_session.assert_called_once_with(endpoint) + + # Verify set_session was called on the SSL connection + mock_ssl_connection.set_session.assert_called_once_with(mock_session) + + def test_client_connection_no_session_when_cache_empty(self): + """Test that clientConnectionForTLS handles empty cache.""" + mock_cache = Mock() + mock_cache.get_session.return_value = None # No cached session + + mock_ssl_context = Mock() + mock_ssl_connection = Mock() + + endpoint = DefaultEndPoint('127.0.0.1', 9042) + + with patch('cassandra.io.twistedreactor.SSL.Connection', return_value=mock_ssl_connection): + creator = twistedreactor._SSLCreator( + endpoint=endpoint, + ssl_context=mock_ssl_context, + ssl_options={}, + check_hostname=False, + timeout=5, + tls_session_cache=mock_cache + ) + + mock_tls_protocol = Mock() + creator.clientConnectionForTLS(mock_tls_protocol) + + # Verify get_session was called + mock_cache.get_session.assert_called_once_with(endpoint) + + # Verify set_session was NOT called on SSL connection + mock_ssl_connection.set_session.assert_not_called() + + def test_client_connection_no_cache_configured(self): + """Test that clientConnectionForTLS works without a cache.""" + mock_ssl_context = Mock() + mock_ssl_connection = Mock() + + endpoint = DefaultEndPoint('127.0.0.1', 9042) + + with patch('cassandra.io.twistedreactor.SSL.Connection', return_value=mock_ssl_connection): + creator = twistedreactor._SSLCreator( + endpoint=endpoint, + ssl_context=mock_ssl_context, + ssl_options={}, + check_hostname=False, + timeout=5, + tls_session_cache=None # No cache + ) + + mock_tls_protocol = Mock() + result = creator.clientConnectionForTLS(mock_tls_protocol) + + # Should return the connection without errors + self.assertEqual(result, mock_ssl_connection) + + # Verify set_session was NOT called + mock_ssl_connection.set_session.assert_not_called() + + def test_info_callback_stores_session_after_handshake(self): + """Test that info_callback stores session after handshake.""" + mock_cache = Mock() + mock_session = Mock() + + mock_ssl_context = Mock() + mock_ssl_connection = Mock() + mock_ssl_connection.get_session.return_value = mock_session + mock_ssl_connection.session_reused.return_value = False + mock_ssl_connection.get_peer_certificate.return_value.get_subject.return_value.commonName = '127.0.0.1' + + endpoint = DefaultEndPoint('127.0.0.1', 9042) + + with patch('cassandra.io.twistedreactor.SSL.Connection', return_value=mock_ssl_connection): + creator = twistedreactor._SSLCreator( + endpoint=endpoint, + ssl_context=mock_ssl_context, + ssl_options={}, + check_hostname=False, + timeout=5, + tls_session_cache=mock_cache + ) + + # Simulate handshake completion + creator.info_callback(mock_ssl_connection, PyOpenSSL.SSL_CB_HANDSHAKE_DONE, 0) + + # Verify session was retrieved and stored + mock_ssl_connection.get_session.assert_called_once() + mock_cache.set_session.assert_called_once_with(endpoint, mock_session) + + def test_info_callback_logs_session_reuse(self): + """Test that info_callback logs when session is reused.""" + mock_cache = Mock() + mock_session = Mock() + + mock_ssl_context = Mock() + mock_ssl_connection = Mock() + mock_ssl_connection.get_session.return_value = mock_session + mock_ssl_connection.session_reused.return_value = True # Session was reused + mock_ssl_connection.get_peer_certificate.return_value.get_subject.return_value.commonName = '127.0.0.1' + + endpoint = DefaultEndPoint('127.0.0.1', 9042) + + with patch('cassandra.io.twistedreactor.SSL.Connection', return_value=mock_ssl_connection): + creator = twistedreactor._SSLCreator( + endpoint=endpoint, + ssl_context=mock_ssl_context, + ssl_options={}, + check_hostname=False, + timeout=5, + tls_session_cache=mock_cache + ) + + with patch('cassandra.io.twistedreactor.log') as mock_log: + creator.info_callback(mock_ssl_connection, PyOpenSSL.SSL_CB_HANDSHAKE_DONE, 0) + + # Verify session_reused was checked + mock_ssl_connection.session_reused.assert_called_once() + + # Verify debug log was called for session reuse + mock_log.debug.assert_called() + + def test_info_callback_no_session_store_when_no_cache(self): + """Test that info_callback doesn't store session when no cache configured.""" + mock_ssl_context = Mock() + mock_ssl_connection = Mock() + mock_ssl_connection.get_peer_certificate.return_value.get_subject.return_value.commonName = '127.0.0.1' + + endpoint = DefaultEndPoint('127.0.0.1', 9042) + + with patch('cassandra.io.twistedreactor.SSL.Connection', return_value=mock_ssl_connection): + creator = twistedreactor._SSLCreator( + endpoint=endpoint, + ssl_context=mock_ssl_context, + ssl_options={}, + check_hostname=False, + timeout=5, + tls_session_cache=None # No cache + ) + + creator.info_callback(mock_ssl_connection, PyOpenSSL.SSL_CB_HANDSHAKE_DONE, 0) + + # Verify get_session was NOT called + mock_ssl_connection.get_session.assert_not_called() + + +@unittest.skipIf(twistedreactor is None, "Twisted libraries not available") +@unittest.skipIf(not _HAS_PYOPENSSL, "PyOpenSSL not available") +class TestTwistedConnectionTLSSessionCache(unittest.TestCase): + """Test TLS session caching integration in TwistedConnection.""" + + def setUp(self): + if twistedreactor.TwistedConnection._loop: + twistedreactor.TwistedConnection._loop._cleanup() + twistedreactor.TwistedConnection.initialize_reactor() + self.reactor_cft_patcher = patch('twisted.internet.reactor.callFromThread') + self.reactor_run_patcher = patch('twisted.internet.reactor.run') + self.mock_reactor_cft = self.reactor_cft_patcher.start() + self.mock_reactor_run = self.reactor_run_patcher.start() + + def tearDown(self): + self.reactor_cft_patcher.stop() + self.reactor_run_patcher.stop() + + def test_add_connection_passes_tls_session_cache(self): + """Test that add_connection passes tls_session_cache to _SSLCreator.""" + mock_cache = Mock() + mock_ssl_context = Mock() + + endpoint = DefaultEndPoint('127.0.0.1', 9042) + + conn = twistedreactor.TwistedConnection( + endpoint, + cql_version='3.0.1', + connect_timeout=5 + ) + conn.ssl_context = mock_ssl_context + conn.ssl_options = {} + conn.tls_session_cache = mock_cache + + with patch('cassandra.io.twistedreactor._SSLCreator') as mock_creator_class: + with patch('cassandra.io.twistedreactor.SSL4ClientEndpoint'): + with patch('cassandra.io.twistedreactor.connectProtocol'): + conn.add_connection() + + # Verify _SSLCreator was called with tls_session_cache + mock_creator_class.assert_called_once() + call_kwargs = mock_creator_class.call_args + self.assertEqual(call_kwargs.kwargs.get('tls_session_cache'), mock_cache) diff --git a/tests/unit/test_endpoints.py b/tests/unit/test_endpoints.py index 14fb8b5806..acddc1b711 100644 --- a/tests/unit/test_endpoints.py +++ b/tests/unit/test_endpoints.py @@ -10,7 +10,7 @@ import itertools -from cassandra.connection import DefaultEndPoint, SniEndPoint, SniEndPointFactory +from cassandra.connection import DefaultEndPoint, SniEndPointFactory, UnixSocketEndPoint from unittest.mock import patch @@ -53,3 +53,41 @@ def test_endpoint_resolve(self): for i in range(10): (address, _) = endpoint.resolve() assert address == next(it) + + def test_sni_endpoint_tls_session_cache_key(self): + """Test that SNI endpoints include server_name in cache key.""" + endpoint1 = self.endpoint_factory.create_from_sni('server1.example.com') + endpoint2 = self.endpoint_factory.create_from_sni('server2.example.com') + + # Both have same proxy address and port + assert endpoint1.address == endpoint2.address + assert endpoint1.port == endpoint2.port + + # But different cache keys due to server_name + assert endpoint1.tls_session_cache_key != endpoint2.tls_session_cache_key + assert endpoint1.tls_session_cache_key == ('proxy.datastax.com', 30002, 'server1.example.com') + assert endpoint2.tls_session_cache_key == ('proxy.datastax.com', 30002, 'server2.example.com') + + +class DefaultEndPointTest(unittest.TestCase): + + def test_tls_session_cache_key(self): + """Test that DefaultEndPoint cache key is (address, port).""" + endpoint = DefaultEndPoint('10.0.0.1', 9042) + assert endpoint.tls_session_cache_key == ('10.0.0.1', 9042) + + endpoint2 = DefaultEndPoint('10.0.0.1', 9043) + assert endpoint2.tls_session_cache_key == ('10.0.0.1', 9043) + assert endpoint.tls_session_cache_key != endpoint2.tls_session_cache_key + + +class UnixSocketEndPointTest(unittest.TestCase): + + def test_tls_session_cache_key(self): + """Test that UnixSocketEndPoint cache key is just the path.""" + endpoint = UnixSocketEndPoint('/var/run/scylla.sock') + assert endpoint.tls_session_cache_key == ('/var/run/scylla.sock',) + + # Different paths should have different keys + endpoint2 = UnixSocketEndPoint('/tmp/scylla.sock') + assert endpoint.tls_session_cache_key != endpoint2.tls_session_cache_key diff --git a/tests/unit/test_tls_session_cache.py b/tests/unit/test_tls_session_cache.py new file mode 100644 index 0000000000..e3b14ab832 --- /dev/null +++ b/tests/unit/test_tls_session_cache.py @@ -0,0 +1,288 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import unittest +from unittest.mock import Mock +from threading import Thread + +from cassandra.tls import DefaultTLSSessionCache + + +class MockEndPoint: + """Mock EndPoint for testing.""" + def __init__(self, address, port): + self.address = address + self.port = port + + @property + def tls_session_cache_key(self): + return (self.address, self.port) + + +class TLSSessionCacheTest(unittest.TestCase): + """Test the TLSSessionCache implementation.""" + + def test_cache_basic_operations(self): + """Test basic get and set operations.""" + cache = DefaultTLSSessionCache(max_size=10, ttl=60) + + # Create a mock session and endpoint + mock_session = Mock() + endpoint = MockEndPoint('host1', 9042) + + # Initially empty + self.assertIsNone(cache.get_session(endpoint)) + self.assertEqual(cache.size(), 0) + + # Set a session + cache.set_session(endpoint, mock_session) + self.assertEqual(cache.size(), 1) + + # Retrieve the session + retrieved = cache.get_session(endpoint) + self.assertEqual(retrieved, mock_session) + + def test_cache_different_endpoints(self): + """Test that different endpoints have separate cache entries.""" + cache = DefaultTLSSessionCache(max_size=10, ttl=60) + + session1 = Mock(name='session1') + session2 = Mock(name='session2') + session3 = Mock(name='session3') + + endpoint1 = MockEndPoint('host1', 9042) + endpoint2 = MockEndPoint('host2', 9042) + endpoint3 = MockEndPoint('host1', 9043) + + cache.set_session(endpoint1, session1) + cache.set_session(endpoint2, session2) + cache.set_session(endpoint3, session3) + + self.assertEqual(cache.size(), 3) + self.assertEqual(cache.get_session(endpoint1), session1) + self.assertEqual(cache.get_session(endpoint2), session2) + self.assertEqual(cache.get_session(endpoint3), session3) + + def test_cache_ttl_expiration(self): + """Test that sessions expire after TTL.""" + cache = DefaultTLSSessionCache(max_size=10, ttl=1) # 1 second TTL + + mock_session = Mock() + endpoint = MockEndPoint('host1', 9042) + cache.set_session(endpoint, mock_session) + + # Should be retrievable immediately + self.assertIsNotNone(cache.get_session(endpoint)) + + # Wait for expiration + time.sleep(1.1) + + # Should be expired + self.assertIsNone(cache.get_session(endpoint)) + self.assertEqual(cache.size(), 0) + + def test_cache_max_size_eviction(self): + """Test that LRU eviction works when cache is full.""" + cache = DefaultTLSSessionCache(max_size=3, ttl=60) + + session1 = Mock(name='session1') + session2 = Mock(name='session2') + session3 = Mock(name='session3') + session4 = Mock(name='session4') + + endpoint1 = MockEndPoint('host1', 9042) + endpoint2 = MockEndPoint('host2', 9042) + endpoint3 = MockEndPoint('host3', 9042) + endpoint4 = MockEndPoint('host4', 9042) + + # Fill cache to capacity + cache.set_session(endpoint1, session1) + cache.set_session(endpoint2, session2) + cache.set_session(endpoint3, session3) + + self.assertEqual(cache.size(), 3) + + # Access session2 to mark it as recently used + cache.get_session(endpoint2) + + # Add a fourth session - should evict session1 (least recently used) + cache.set_session(endpoint4, session4) + + self.assertEqual(cache.size(), 3) + self.assertIsNone(cache.get_session(endpoint1)) + self.assertIsNotNone(cache.get_session(endpoint2)) + self.assertIsNotNone(cache.get_session(endpoint3)) + self.assertIsNotNone(cache.get_session(endpoint4)) + + def test_cache_clear_expired(self): + """Test manual clearing of expired sessions.""" + cache = DefaultTLSSessionCache(max_size=10, ttl=1) + + session1 = Mock(name='session1') + session2 = Mock(name='session2') + + endpoint1 = MockEndPoint('host1', 9042) + endpoint2 = MockEndPoint('host2', 9042) + + cache.set_session(endpoint1, session1) + time.sleep(1.1) # Let session1 expire + cache.set_session(endpoint2, session2) + + # Before clearing, both are in cache + self.assertEqual(cache.size(), 2) + + # Clear expired sessions + cache.clear_expired() + + # Only session2 should remain + self.assertEqual(cache.size(), 1) + self.assertIsNone(cache.get_session(endpoint1)) + self.assertIsNotNone(cache.get_session(endpoint2)) + + def test_cache_clear_all(self): + """Test clearing all sessions from cache.""" + cache = DefaultTLSSessionCache(max_size=10, ttl=60) + + endpoint1 = MockEndPoint('host1', 9042) + endpoint2 = MockEndPoint('host2', 9042) + endpoint3 = MockEndPoint('host3', 9042) + + cache.set_session(endpoint1, Mock()) + cache.set_session(endpoint2, Mock()) + cache.set_session(endpoint3, Mock()) + + self.assertEqual(cache.size(), 3) + + cache.clear() + + self.assertEqual(cache.size(), 0) + + def test_cache_none_session(self): + """Test that None sessions are not cached.""" + cache = DefaultTLSSessionCache(max_size=10, ttl=60) + + endpoint = MockEndPoint('host1', 9042) + cache.set_session(endpoint, None) + + self.assertEqual(cache.size(), 0) + self.assertIsNone(cache.get_session(endpoint)) + + def test_cache_update_existing_session(self): + """Test that updating an existing session works correctly.""" + cache = DefaultTLSSessionCache(max_size=10, ttl=60) + + session1 = Mock(name='session1') + session2 = Mock(name='session2') + + endpoint = MockEndPoint('host1', 9042) + + cache.set_session(endpoint, session1) + self.assertEqual(cache.get_session(endpoint), session1) + + # Update with new session + cache.set_session(endpoint, session2) + self.assertEqual(cache.get_session(endpoint), session2) + + # Size should still be 1 + self.assertEqual(cache.size(), 1) + + def test_cache_thread_safety(self): + """Test that cache operations are thread-safe.""" + cache = DefaultTLSSessionCache(max_size=100, ttl=60) + errors = [] + + def set_sessions(thread_id): + try: + for i in range(50): + session = Mock(name=f'session_{thread_id}_{i}') + endpoint = MockEndPoint(f'host{thread_id}', 9042 + i) + cache.set_session(endpoint, session) + except Exception as e: + errors.append(e) + + def get_sessions(thread_id): + try: + for i in range(50): + endpoint = MockEndPoint(f'host{thread_id}', 9042 + i) + cache.get_session(endpoint) + except Exception as e: + errors.append(e) + + # Create multiple threads doing concurrent operations + threads = [] + for i in range(5): + t1 = Thread(target=set_sessions, args=(i,)) + t2 = Thread(target=get_sessions, args=(i,)) + threads.extend([t1, t2]) + + for t in threads: + t.start() + + for t in threads: + t.join() + + # Check that no errors occurred + self.assertEqual(len(errors), 0, f"Thread safety test failed with errors: {errors}") + + # Check that cache is not empty and within max size + self.assertGreater(cache.size(), 0) + self.assertLessEqual(cache.size(), 100) + + def test_cache_by_host_only(self): + """Test caching by host only (ignoring port).""" + cache = DefaultTLSSessionCache(max_size=10, ttl=60, cache_by_host_only=True) + + session = Mock(name='session') + + endpoint1 = MockEndPoint('host1', 9042) + endpoint2 = MockEndPoint('host1', 9043) # Same host, different port + + # Set session for first endpoint + cache.set_session(endpoint1, session) + self.assertEqual(cache.size(), 1) + + # Get session using second endpoint (same host, different port) + # Should return the same session because we're caching by host only + retrieved = cache.get_session(endpoint2) + self.assertEqual(retrieved, session) + + # Cache should still have size 1 + self.assertEqual(cache.size(), 1) + + def test_automatic_expired_cleanup(self): + """Test that expired sessions are cleaned up automatically during set_session.""" + cache = DefaultTLSSessionCache(max_size=10, ttl=1) + # Override cleanup interval for testing + cache._EXPIRY_CLEANUP_INTERVAL = 5 + + # Add some sessions that will expire + for i in range(3): + endpoint = MockEndPoint(f'host{i}', 9042) + cache.set_session(endpoint, Mock(name=f'session{i}')) + + self.assertEqual(cache.size(), 3) + + # Wait for sessions to expire + time.sleep(1.1) + + # Add sessions until cleanup is triggered (5 operations) + for i in range(5): + endpoint = MockEndPoint(f'newhost{i}', 9042) + cache.set_session(endpoint, Mock(name=f'newsession{i}')) + + # Expired sessions should have been cleaned up + # The 3 expired sessions should be removed + # Only the 5 new sessions should remain + self.assertEqual(cache.size(), 5)