Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support MSK IAM Authentication #74

Merged
merged 7 commits into from
Nov 29, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion kafka/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class SimpleClient(object):
# socket timeout.
def __init__(self, hosts, client_id=CLIENT_ID,
timeout=DEFAULT_SOCKET_TIMEOUT_SECONDS,
correlation_id=0, metrics=None):
correlation_id=0, metrics=None, **kwargs):
# We need one connection to bootstrap
self.client_id = client_id
self.timeout = timeout
Expand All @@ -90,6 +90,10 @@ def __init__(self, hosts, client_id=CLIENT_ID,
self.topics_to_brokers = {} # TopicPartition -> BrokerMetadata
self.topic_partitions = {} # topic -> partition -> leader

# Support arbitrary kwargs to be provided as config to BrokerConnection
# This will allow advanced features like Authentication to work
self.config = kwargs

self.load_metadata_for_topics() # bootstrap with all metadata

##################
Expand All @@ -108,6 +112,7 @@ def _get_conn(self, host, port, afi, node_id='bootstrap'):
metrics=self._metrics_registry,
metric_group_prefix='simple-client',
node_id=node_id,
**self.config,
)

conn = self._conns[host_key]
Expand Down
50 changes: 49 additions & 1 deletion kafka/conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import kafka.errors as Errors
from kafka.future import Future
from kafka.metrics.stats import Avg, Count, Max, Rate
from kafka.msk import AwsMskIamClient
from kafka.oauth.abstract import AbstractTokenProvider
from kafka.protocol.admin import SaslHandShakeRequest
from kafka.protocol.commit import OffsetFetchRequest
Expand Down Expand Up @@ -81,6 +82,12 @@ class SSLWantWriteError(Exception):
gssapi = None
GSSError = None

# needed for AWS_MSK_IAM authentication:
try:
from botocore.session import Session as BotoSession
except ImportError:
# no botocore available, will disable AWS_MSK_IAM mechanism
BotoSession = None

AFI_NAMES = {
socket.AF_UNSPEC: "unspecified",
Expand Down Expand Up @@ -224,7 +231,7 @@ class BrokerConnection(object):
'sasl_oauth_token_provider': None
}
SECURITY_PROTOCOLS = ('PLAINTEXT', 'SSL', 'SASL_PLAINTEXT', 'SASL_SSL')
SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER')
SASL_MECHANISMS = ('PLAIN', 'GSSAPI', 'OAUTHBEARER', 'AWS_MSK_IAM')

def __init__(self, host, port, afi, **configs):
self.host = host
Expand Down Expand Up @@ -269,6 +276,11 @@ def __init__(self, host, port, afi, **configs):
token_provider = self.config['sasl_oauth_token_provider']
assert token_provider is not None, 'sasl_oauth_token_provider required for OAUTHBEARER sasl'
assert callable(getattr(token_provider, "token", None)), 'sasl_oauth_token_provider must implement method #token()'

if self.config['sasl_mechanism'] == 'AWS_MSK_IAM':
assert BotoSession is not None, 'AWS_MSK_IAM requires the "botocore" package'
assert self.config['security_protocol'] == 'SASL_SSL', 'AWS_MSK_IAM requires SASL_SSL'

# This is not a general lock / this class is not generally thread-safe yet
# However, to avoid pushing responsibility for maintaining
# per-connection locks to the upstream client, we will use this lock to
Expand Down Expand Up @@ -552,6 +564,8 @@ def _handle_sasl_handshake_response(self, future, response):
return self._try_authenticate_gssapi(future)
elif self.config['sasl_mechanism'] == 'OAUTHBEARER':
return self._try_authenticate_oauth(future)
elif self.config['sasl_mechanism'] == 'AWS_MSK_IAM':
return self._try_authenticate_aws_msk_iam(future)
else:
return future.failure(
Errors.UnsupportedSaslMechanismError(
Expand Down Expand Up @@ -652,6 +666,40 @@ def _try_authenticate_plain(self, future):
log.info('%s: Authenticated as %s via PLAIN', self, self.config['sasl_plain_username'])
return future.success(True)

def _try_authenticate_aws_msk_iam(self, future):
session = BotoSession()
danielpops marked this conversation as resolved.
Show resolved Hide resolved
client = AwsMskIamClient(
host=self.host,
boto_session=session,
)

msg = client.first_message()
size = Int32.encode(len(msg))

err = None
close = False
with self._lock:
if not self._can_send_recv():
err = Errors.NodeNotReadyError(str(self))
close = False
else:
try:
self._send_bytes_blocking(size + msg)
data = self._recv_bytes_blocking(4)
data = self._recv_bytes_blocking(struct.unpack('4B', data)[-1])
except (ConnectionError, TimeoutError) as e:
log.exception("%s: Error receiving reply from server", self)
err = Errors.KafkaConnectionError("%s: %s" % (self, e))
close = True

if err is not None:
if close:
self.close(error=err)
return future.failure(err)

log.info('%s: Authenticated via AWS_MSK_IAM %s', self, data.decode('utf-8'))
return future.success(True)

def _try_authenticate_gssapi(self, future):
kerberos_damin_name = self.config['sasl_kerberos_domain_name'] or self.host
auth_id = self.config['sasl_kerberos_service_name'] + '@' + kerberos_damin_name
Expand Down
213 changes: 213 additions & 0 deletions kafka/msk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
import datetime
import hashlib
import hmac
import json
import string

from kafka.errors import IllegalArgumentError
from kafka.vendor.six.moves import urllib


class AwsMskIamClient:
UNRESERVED_CHARS = string.ascii_letters + string.digits + '-._~'

def __init__(self, host, boto_session):
"""
Arguments:
host (str): The hostname of the broker.
boto_session (botocore.BotoSession) the boto session
"""
self.algorithm = 'AWS4-HMAC-SHA256'
self.expires = '900'
self.hashfunc = hashlib.sha256
self.headers = [
('host', host)
]
self.version = '2020_10_22'

self.service = 'kafka-cluster'
self.action = '{}:Connect'.format(self.service)

now = datetime.datetime.utcnow()
self.datestamp = now.strftime('%Y%m%d')
self.timestamp = now.strftime('%Y%m%dT%H%M%SZ')

self.host = host
self.boto_session = boto_session

# This will raise if the region can't be determined
# Do this during init instead of waiting for failures downstream
if self.region:
pass

@property
def access_key(self):
return self.boto_session.get_credentials().access_key

@property
def secret_key(self):
return self.boto_session.get_credentials().secret_key

@property
def token(self):
return self.boto_session.get_credentials().token

@property
def region(self):
# Try to get the region information from the broker hostname
for host in self.host.split(','):
if 'amazonaws.com' in host:
return host.split('.')[-3]

# If the region can't be determined from hostname, try the boto session
# This will only have a value if:
# - `AWS_DEFAULT_REGION` environment variable is set
# - `~/.aws/config` region variable is set
region = self.boto_session.get_config_variable('region')
if region:
return region

# Otherwise give up
raise IllegalArgumentError('Could not determine region from broker host(s) or aws configuration')

@property
def _credential(self):
return '{0.access_key}/{0._scope}'.format(self)

@property
def _scope(self):
return '{0.datestamp}/{0.region}/{0.service}/aws4_request'.format(self)

@property
def _signed_headers(self):
"""
Returns (str):
An alphabetically sorted, semicolon-delimited list of lowercase
request header names.
"""
return ';'.join(sorted(k.lower() for k, _ in self.headers))

@property
def _canonical_headers(self):
"""
Returns (str):
A newline-delited list of header names and values.
Header names are lowercased.
"""
return '\n'.join(map(':'.join, self.headers)) + '\n'

@property
def _canonical_request(self):
"""
Returns (str):
An AWS Signature Version 4 canonical request in the format:
<Method>\n
<Path>\n
<CanonicalQueryString>\n
<CanonicalHeaders>\n
<SignedHeaders>\n
<HashedPayload>
"""
# The hashed_payload is always an empty string for MSK.
hashed_payload = self.hashfunc(b'').hexdigest()
return '\n'.join((
'GET',
'/',
self._canonical_querystring,
self._canonical_headers,
self._signed_headers,
hashed_payload,
))

@property
def _canonical_querystring(self):
"""
Returns (str):
A '&'-separated list of URI-encoded key/value pairs.
"""
params = []
params.append(('Action', self.action))
params.append(('X-Amz-Algorithm', self.algorithm))
params.append(('X-Amz-Credential', self._credential))
params.append(('X-Amz-Date', self.timestamp))
params.append(('X-Amz-Expires', self.expires))
if self.token:
params.append(('X-Amz-Security-Token', self.token))
params.append(('X-Amz-SignedHeaders', self._signed_headers))

return '&'.join(self._uriencode(k) + '=' + self._uriencode(v) for k, v in params)

@property
def _signing_key(self):
"""
Returns (bytes):
An AWS Signature V4 signing key generated from the secret_key, date,
region, service, and request type.
"""
key = self._hmac(('AWS4' + self.secret_key).encode('utf-8'), self.datestamp)
key = self._hmac(key, self.region)
key = self._hmac(key, self.service)
key = self._hmac(key, 'aws4_request')
return key

@property
def _signing_str(self):
"""
Returns (str):
A string used to sign the AWS Signature V4 payload in the format:
<Algorithm>\n
<Timestamp>\n
<Scope>\n
<CanonicalRequestHash>
"""
canonical_request_hash = self.hashfunc(self._canonical_request.encode('utf-8')).hexdigest()
return '\n'.join((self.algorithm, self.timestamp, self._scope, canonical_request_hash))

def _uriencode(self, msg):
"""
Arguments:
msg (str): A string to URI-encode.

Returns (str):
The URI-encoded version of the provided msg, following the encoding
rules specified: https://github.com/aws/aws-msk-iam-auth#uriencode
"""
return urllib.parse.quote(msg, safe=self.UNRESERVED_CHARS)

def _hmac(self, key, msg):
"""
Arguments:
key (bytes): A key to use for the HMAC digest.
msg (str): A value to include in the HMAC digest.
Returns (bytes):
An HMAC digest of the given key and msg.
"""
return hmac.new(key, msg.encode('utf-8'), digestmod=self.hashfunc).digest()

def first_message(self):
"""
Returns (bytes):
An encoded JSON authentication payload that can be sent to the
broker.
"""
signature = hmac.new(
self._signing_key,
self._signing_str.encode('utf-8'),
digestmod=self.hashfunc,
).hexdigest()
msg = {
'version': self.version,
'host': self.host,
'user-agent': 'kafka-python',
'action': self.action,
'x-amz-algorithm': self.algorithm,
'x-amz-credential': self._credential,
'x-amz-date': self.timestamp,
'x-amz-signedheaders': self._signed_headers,
'x-amz-expires': self.expires,
'x-amz-signature': signature,
}
if self.token:
msg['x-amz-security-token'] = self.token

return json.dumps(msg, separators=(',', ':')).encode('utf-8')
Loading