Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(python driver): use 'ssl_context' instead 'ssl_options' #7142

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 30 additions & 22 deletions sdcm/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import logging
import os
import shutil
import ssl
import sys
import random
import re
Expand Down Expand Up @@ -64,7 +65,8 @@
from sdcm.provision.scylla_yaml.certificate_builder import ScyllaYamlCertificateAttrBuilder
from sdcm.provision.scylla_yaml.cluster_builder import ScyllaYamlClusterAttrBuilder
from sdcm.provision.scylla_yaml.scylla_yaml import ScyllaYaml
from sdcm.provision.helpers.certificate import install_client_certificate, install_encryption_at_rest_files
from sdcm.provision.helpers.certificate import install_client_certificate, install_encryption_at_rest_files, CLIENT_KEYFILE, \
fruch marked this conversation as resolved.
Show resolved Hide resolved
CLIENT_CERTFILE, CLIENT_TRUSTSTORE
from sdcm.remote import RemoteCmdRunnerBase, LOCALRUNNER, NETWORK_EXCEPTIONS, shell_script_cmd, RetryableNetworkException
from sdcm.remote.libssh2_client import UnexpectedExit as Libssh2_UnexpectedExit
from sdcm.remote.remote_file import remote_file, yaml_file_to_dict, dict_to_yaml_file
Expand Down Expand Up @@ -3389,11 +3391,17 @@ def get_node_by_ip(self, node_ip, datacenter=None):
return node
return None

def _create_session(self, node, keyspace, user, password, compression,
# pylint: disable=too-many-arguments, too-many-locals
protocol_version, load_balancing_policy=None,
port=None, ssl_opts=None, node_ips=None, connect_timeout=None,
verbose=True, connection_bundle_file=None):
@staticmethod
def create_ssl_context(keyfile: str, certfile: str, truststore: str):
ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_context.load_cert_chain(certfile=certfile, keyfile=keyfile)
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
ssl_context.load_verify_locations(cafile=truststore)
return ssl_context

def _create_session(self, node, keyspace, user, password, compression, protocol_version, load_balancing_policy=None, port=None,
ssl_context=None, node_ips=None, connect_timeout=None, verbose=True, connection_bundle_file=None):
if not port:
port = node.CQL_PORT

Expand All @@ -3409,10 +3417,12 @@ def _create_session(self, node, keyspace, user, password, compression,
else:
auth_provider = None

if ssl_opts is None and self.params.get('client_encrypt'):
ssl_opts = {'ca_certs': './data_dir/ssl_conf/client/catest.pem'}
self.log.debug(str(ssl_opts))
kwargs = dict(contact_points=node_ips, port=port, ssl_options=ssl_opts)
if ssl_context is None and self.params.get('client_encrypt'):
ssl_context = self.create_ssl_context(
keyfile=CLIENT_KEYFILE, certfile=CLIENT_CERTFILE, truststore=CLIENT_TRUSTSTORE)
self.log.debug("ssl_context: %s", str(ssl_context))

kwargs = dict(contact_points=node_ips, port=port, ssl_context=ssl_context)
if connection_bundle_file:
kwargs = dict(scylla_cloud=connection_bundle_file)
cluster_driver = ClusterDriver(auth_provider=auth_provider,
Expand All @@ -3438,23 +3448,22 @@ def _create_session(self, node, keyspace, user, password, compression,

def cql_connection(self, node, keyspace=None, user=None, # pylint: disable=too-many-arguments
password=None, compression=True, protocol_version=None,
port=None, ssl_opts=None, connect_timeout=100, verbose=True):
port=None, ssl_context=None, connect_timeout=100, verbose=True):
if connection_bundle_file := node.parent_cluster.connection_bundle_file:
wlrr = None
node_ips = []
else:
node_ips = self.get_node_cql_ips()
wlrr = WhiteListRoundRobinPolicy(node_ips)
return self._create_session(node=node, keyspace=keyspace, user=user, password=password,
compression=compression, protocol_version=protocol_version,
load_balancing_policy=wlrr, port=port, ssl_opts=ssl_opts, node_ips=node_ips,
connect_timeout=connect_timeout, verbose=verbose,
return self._create_session(node=node, keyspace=keyspace, user=user, password=password, compression=compression,
protocol_version=protocol_version, load_balancing_policy=wlrr, port=port, ssl_context=ssl_context,
node_ips=node_ips, connect_timeout=connect_timeout, verbose=verbose,
connection_bundle_file=connection_bundle_file)

def cql_connection_exclusive(self, node, keyspace=None, user=None, # pylint: disable=too-many-arguments,too-many-locals
password=None, compression=True,
protocol_version=None, port=None,
ssl_opts=None, connect_timeout=100, verbose=True):
ssl_context=None, connect_timeout=100, verbose=True):
if connection_bundle_file := node.parent_cluster.connection_bundle_file:
# TODO: handle the case of multiple datacenters
bundle_yaml = yaml.safe_load(connection_bundle_file.open('r', encoding='utf-8'))
Expand All @@ -3470,18 +3479,17 @@ def host_filter(host):
else:
node_ips = [node.cql_address]
wlrr = WhiteListRoundRobinPolicy(node_ips)
return self._create_session(node=node, keyspace=keyspace, user=user, password=password,
compression=compression, protocol_version=protocol_version,
load_balancing_policy=wlrr, port=port, ssl_opts=ssl_opts, node_ips=node_ips,
connect_timeout=connect_timeout, verbose=verbose,
return self._create_session(node=node, keyspace=keyspace, user=user, password=password, compression=compression,
protocol_version=protocol_version, load_balancing_policy=wlrr, port=port, ssl_context=ssl_context,
node_ips=node_ips, connect_timeout=connect_timeout, verbose=verbose,
connection_bundle_file=connection_bundle_file)

@retrying(n=8, sleep_time=15, allowed_exceptions=(NoHostAvailable,))
def cql_connection_patient(self, node, keyspace=None,
# pylint: disable=too-many-arguments,unused-argument
user=None, password=None,
compression=True, protocol_version=None,
port=None, ssl_opts=None, connect_timeout=100, verbose=True):
port=None, ssl_context=None, connect_timeout=100, verbose=True):
"""
Returns a connection after it stops throwing NoHostAvailables.

Expand All @@ -3497,7 +3505,7 @@ def cql_connection_patient_exclusive(self, node, keyspace=None,
user=None, password=None,
compression=True,
protocol_version=None,
port=None, ssl_opts=None, connect_timeout=100, verbose=True):
port=None, ssl_context=None, connect_timeout=100, verbose=True):
"""
Returns a connection after it stops throwing NoHostAvailables.

Expand Down
4 changes: 4 additions & 0 deletions sdcm/provision/helpers/certificate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
from sdcm.remote import shell_script_cmd
from sdcm.utils.common import get_data_dir_path

CLIENT_KEYFILE = get_data_dir_path('ssl_conf', "client/test.key")
CLIENT_CERTFILE = get_data_dir_path('ssl_conf', "client/test.crt")
CLIENT_TRUSTSTORE = get_data_dir_path('ssl_conf', "client/catest.pem")


def install_client_certificate(remoter):
if remoter.run('ls /etc/scylla/ssl_conf', ignore_status=True).ok:
Expand Down
10 changes: 5 additions & 5 deletions sdcm/provision/scylla_yaml/certificate_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
# See LICENSE for more details.
#
# Copyright (c) 2021 ScyllaDB

import os
from functools import cached_property
from typing import Optional, Any

from pydantic import Field

from sdcm.provision.helpers.certificate import install_client_certificate
from sdcm.provision.helpers.certificate import install_client_certificate, CLIENT_CERTFILE, CLIENT_KEYFILE, CLIENT_TRUSTSTORE
from sdcm.provision.scylla_yaml.auxiliaries import ScyllaYamlAttrBuilderBase, ClientEncryptionOptions, \
ServerEncryptionOptions

Expand All @@ -40,9 +40,9 @@ def client_encryption_options(self) -> Optional[ClientEncryptionOptions]:
return None
return ClientEncryptionOptions(
enabled=True,
certificate=self._ssl_files_path + '/client/test.crt',
keyfile=self._ssl_files_path + '/client/test.key',
truststore=self._ssl_files_path + '/client/catest.pem',
certificate=os.path.join(self._ssl_files_path, 'client', os.path.basename(CLIENT_CERTFILE)),
keyfile=os.path.join(self._ssl_files_path, 'client', os.path.basename(CLIENT_KEYFILE)),
truststore=os.path.join(self._ssl_files_path, 'client', os.path.basename(CLIENT_TRUSTSTORE)),
)

@property
Expand Down
23 changes: 23 additions & 0 deletions unit_tests/test_python_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,26 @@ def test_01_test_python_driver_serverless_connectivity(params):
output = res.all()
log.debug(output)
assert len(output) == 1


@pytest.mark.parametrize('encrypted', [
pytest.param(True, marks=pytest.mark.docker_scylla_args(ssl=True), id='encrypted'),
pytest.param(False, marks=pytest.mark.docker_scylla_args(ssl=False), id='clear')
])
def test_02_test_python_driver(docker_scylla, params, encrypted):

params['client_encrypt'] = encrypted
node = docker_scylla
db_cluster = DummyDbCluster(nodes=[node], params=params)
node.parent_cluster = db_cluster

for func in [db_cluster.cql_connection_patient,
db_cluster.cql_connection_patient_exclusive]:

with func(node) as session:
for host in session.cluster.metadata.all_hosts():
log.debug(host)
res = session.execute("SELECT * FROM system.local")
output = res.all()
log.debug(output)
assert len(output) == 1