diff --git a/sdcm/cluster.py b/sdcm/cluster.py index 6fa12c43d8..4efefc0cd0 100644 --- a/sdcm/cluster.py +++ b/sdcm/cluster.py @@ -17,6 +17,7 @@ import logging import os import shutil +import ssl import sys import random import re @@ -64,7 +65,8 @@ from sdcm.provision.scylla_yaml.certificate_builder import ScyllaYamlCertificateAttrBuilder from sdcm.provision.scylla_yaml.cluster_builder import ScyllaYamlClusterAttrBuilder from sdcm.provision.scylla_yaml.scylla_yaml import ScyllaYaml -from sdcm.provision.helpers.certificate import install_client_certificate, install_encryption_at_rest_files +from sdcm.provision.helpers.certificate import install_client_certificate, install_encryption_at_rest_files, CLIENT_KEYFILE, \ + CLIENT_CERTFILE, CLIENT_TRUSTSTORE from sdcm.remote import RemoteCmdRunnerBase, LOCALRUNNER, NETWORK_EXCEPTIONS, shell_script_cmd, RetryableNetworkException from sdcm.remote.libssh2_client import UnexpectedExit as Libssh2_UnexpectedExit from sdcm.remote.remote_file import remote_file, yaml_file_to_dict, dict_to_yaml_file @@ -3389,11 +3391,17 @@ def get_node_by_ip(self, node_ip, datacenter=None): return node return None - def _create_session(self, node, keyspace, user, password, compression, - # pylint: disable=too-many-arguments, too-many-locals - protocol_version, load_balancing_policy=None, - port=None, ssl_opts=None, node_ips=None, connect_timeout=None, - verbose=True, connection_bundle_file=None): + @staticmethod + def create_ssl_context(keyfile: str, certfile: str, truststore: str): + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_context.load_cert_chain(certfile=certfile, keyfile=keyfile) + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + ssl_context.load_verify_locations(cafile=truststore) + return ssl_context + + def _create_session(self, node, keyspace, user, password, compression, protocol_version, load_balancing_policy=None, port=None, + ssl_context=None, node_ips=None, connect_timeout=None, verbose=True, connection_bundle_file=None): if not port: port = node.CQL_PORT @@ -3409,10 +3417,12 @@ def _create_session(self, node, keyspace, user, password, compression, else: auth_provider = None - if ssl_opts is None and self.params.get('client_encrypt'): - ssl_opts = {'ca_certs': './data_dir/ssl_conf/client/catest.pem'} - self.log.debug(str(ssl_opts)) - kwargs = dict(contact_points=node_ips, port=port, ssl_options=ssl_opts) + if ssl_context is None and self.params.get('client_encrypt'): + ssl_context = self.create_ssl_context( + keyfile=CLIENT_KEYFILE, certfile=CLIENT_CERTFILE, truststore=CLIENT_TRUSTSTORE) + self.log.debug("ssl_context: %s", str(ssl_context)) + + kwargs = dict(contact_points=node_ips, port=port, ssl_context=ssl_context) if connection_bundle_file: kwargs = dict(scylla_cloud=connection_bundle_file) cluster_driver = ClusterDriver(auth_provider=auth_provider, @@ -3438,23 +3448,22 @@ def _create_session(self, node, keyspace, user, password, compression, def cql_connection(self, node, keyspace=None, user=None, # pylint: disable=too-many-arguments password=None, compression=True, protocol_version=None, - port=None, ssl_opts=None, connect_timeout=100, verbose=True): + port=None, ssl_context=None, connect_timeout=100, verbose=True): if connection_bundle_file := node.parent_cluster.connection_bundle_file: wlrr = None node_ips = [] else: node_ips = self.get_node_cql_ips() wlrr = WhiteListRoundRobinPolicy(node_ips) - return self._create_session(node=node, keyspace=keyspace, user=user, password=password, - compression=compression, protocol_version=protocol_version, - load_balancing_policy=wlrr, port=port, ssl_opts=ssl_opts, node_ips=node_ips, - connect_timeout=connect_timeout, verbose=verbose, + return self._create_session(node=node, keyspace=keyspace, user=user, password=password, compression=compression, + protocol_version=protocol_version, load_balancing_policy=wlrr, port=port, ssl_context=ssl_context, + node_ips=node_ips, connect_timeout=connect_timeout, verbose=verbose, connection_bundle_file=connection_bundle_file) def cql_connection_exclusive(self, node, keyspace=None, user=None, # pylint: disable=too-many-arguments,too-many-locals password=None, compression=True, protocol_version=None, port=None, - ssl_opts=None, connect_timeout=100, verbose=True): + ssl_context=None, connect_timeout=100, verbose=True): if connection_bundle_file := node.parent_cluster.connection_bundle_file: # TODO: handle the case of multiple datacenters bundle_yaml = yaml.safe_load(connection_bundle_file.open('r', encoding='utf-8')) @@ -3470,10 +3479,9 @@ def host_filter(host): else: node_ips = [node.cql_address] wlrr = WhiteListRoundRobinPolicy(node_ips) - return self._create_session(node=node, keyspace=keyspace, user=user, password=password, - compression=compression, protocol_version=protocol_version, - load_balancing_policy=wlrr, port=port, ssl_opts=ssl_opts, node_ips=node_ips, - connect_timeout=connect_timeout, verbose=verbose, + return self._create_session(node=node, keyspace=keyspace, user=user, password=password, compression=compression, + protocol_version=protocol_version, load_balancing_policy=wlrr, port=port, ssl_context=ssl_context, + node_ips=node_ips, connect_timeout=connect_timeout, verbose=verbose, connection_bundle_file=connection_bundle_file) @retrying(n=8, sleep_time=15, allowed_exceptions=(NoHostAvailable,)) @@ -3481,7 +3489,7 @@ def cql_connection_patient(self, node, keyspace=None, # pylint: disable=too-many-arguments,unused-argument user=None, password=None, compression=True, protocol_version=None, - port=None, ssl_opts=None, connect_timeout=100, verbose=True): + port=None, ssl_context=None, connect_timeout=100, verbose=True): """ Returns a connection after it stops throwing NoHostAvailables. @@ -3497,7 +3505,7 @@ def cql_connection_patient_exclusive(self, node, keyspace=None, user=None, password=None, compression=True, protocol_version=None, - port=None, ssl_opts=None, connect_timeout=100, verbose=True): + port=None, ssl_context=None, connect_timeout=100, verbose=True): """ Returns a connection after it stops throwing NoHostAvailables. diff --git a/sdcm/provision/helpers/certificate.py b/sdcm/provision/helpers/certificate.py index b02ddb9201..61fe549571 100644 --- a/sdcm/provision/helpers/certificate.py +++ b/sdcm/provision/helpers/certificate.py @@ -16,6 +16,10 @@ from sdcm.remote import shell_script_cmd from sdcm.utils.common import get_data_dir_path +CLIENT_KEYFILE = get_data_dir_path('ssl_conf', "client/test.key") +CLIENT_CERTFILE = get_data_dir_path('ssl_conf', "client/test.crt") +CLIENT_TRUSTSTORE = get_data_dir_path('ssl_conf', "client/catest.pem") + def install_client_certificate(remoter): if remoter.run('ls /etc/scylla/ssl_conf', ignore_status=True).ok: diff --git a/sdcm/provision/scylla_yaml/certificate_builder.py b/sdcm/provision/scylla_yaml/certificate_builder.py index 8285d509f8..c644f47c22 100644 --- a/sdcm/provision/scylla_yaml/certificate_builder.py +++ b/sdcm/provision/scylla_yaml/certificate_builder.py @@ -10,13 +10,13 @@ # See LICENSE for more details. # # Copyright (c) 2021 ScyllaDB - +import os from functools import cached_property from typing import Optional, Any from pydantic import Field -from sdcm.provision.helpers.certificate import install_client_certificate +from sdcm.provision.helpers.certificate import install_client_certificate, CLIENT_CERTFILE, CLIENT_KEYFILE, CLIENT_TRUSTSTORE from sdcm.provision.scylla_yaml.auxiliaries import ScyllaYamlAttrBuilderBase, ClientEncryptionOptions, \ ServerEncryptionOptions @@ -40,9 +40,9 @@ def client_encryption_options(self) -> Optional[ClientEncryptionOptions]: return None return ClientEncryptionOptions( enabled=True, - certificate=self._ssl_files_path + '/client/test.crt', - keyfile=self._ssl_files_path + '/client/test.key', - truststore=self._ssl_files_path + '/client/catest.pem', + certificate=os.path.join(self._ssl_files_path, 'client', os.path.basename(CLIENT_CERTFILE)), + keyfile=os.path.join(self._ssl_files_path, 'client', os.path.basename(CLIENT_KEYFILE)), + truststore=os.path.join(self._ssl_files_path, 'client', os.path.basename(CLIENT_TRUSTSTORE)), ) @property diff --git a/unit_tests/test_python_driver.py b/unit_tests/test_python_driver.py index 48f10474d0..721d73ad47 100644 --- a/unit_tests/test_python_driver.py +++ b/unit_tests/test_python_driver.py @@ -42,3 +42,26 @@ def test_01_test_python_driver_serverless_connectivity(params): output = res.all() log.debug(output) assert len(output) == 1 + + +@pytest.mark.parametrize('encrypted', [ + pytest.param(True, marks=pytest.mark.docker_scylla_args(ssl=True), id='encrypted'), + pytest.param(False, marks=pytest.mark.docker_scylla_args(ssl=False), id='clear') +]) +def test_02_test_python_driver(docker_scylla, params, encrypted): + + params['client_encrypt'] = encrypted + node = docker_scylla + db_cluster = DummyDbCluster(nodes=[node], params=params) + node.parent_cluster = db_cluster + + for func in [db_cluster.cql_connection_patient, + db_cluster.cql_connection_patient_exclusive]: + + with func(node) as session: + for host in session.cluster.metadata.all_hosts(): + log.debug(host) + res = session.execute("SELECT * FROM system.local") + output = res.all() + log.debug(output) + assert len(output) == 1