diff --git a/broker.go b/broker.go index 2d75a8cd3..648e355f9 100644 --- a/broker.go +++ b/broker.go @@ -1,12 +1,16 @@ package sarama import ( + "bytes" "crypto/tls" "encoding/binary" + "encoding/json" "errors" "fmt" "io" "net" + "net/http" + "net/url" "sort" "strconv" "strings" @@ -14,6 +18,12 @@ import ( "sync/atomic" "time" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/credentials/stscreds" + "github.com/aws/aws-sdk-go/aws/session" + sign "github.com/aws/aws-sdk-go/aws/signer/v4" + "github.com/aws/aws-sdk-go/service/sts" "github.com/rcrowley/go-metrics" ) @@ -68,6 +78,8 @@ const ( // SASLTypeSCRAMSHA512 represents the SCRAM-SHA-512 mechanism. SASLTypeSCRAMSHA512 = "SCRAM-SHA-512" SASLTypeGSSAPI = "GSSAPI" + // SASLTypeAWSMSKIAM represents the SASL IAM mechanism + SASLTypeAWSMSKIAM = "AWS_MSK_IAM" // SASLHandshakeV0 is v0 of the Kafka SASL handshake protocol. Client and // server negotiate SASL auth using opaque packets. SASLHandshakeV0 = int16(0) @@ -77,6 +89,8 @@ const ( // SASLExtKeyAuth is the reserved extension key name sent as part of the // SASL/OAUTHBEARER initial client response SASLExtKeyAuth = "auth" + + IAMAuthVersion = "2020_10_22" ) // AccessToken contains an access token used to authenticate a @@ -1100,6 +1114,8 @@ func (b *Broker) authenticateViaSASL() error { return b.sendAndReceiveSASLOAuth(b.conf.Net.SASL.TokenProvider) case SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512: return b.sendAndReceiveSASLSCRAM() + case SASLTypeAWSMSKIAM: + return b.sendAndReceiveSASLIAM() case SASLTypeGSSAPI: return b.sendAndReceiveKerberos() default: @@ -1490,6 +1506,7 @@ func (b *Broker) receiveSaslAuthenticateResponse(correlationID int32) ([]byte, e if err := versionedDecode(buf, res, 0); err != nil { return nil, err } + if !errors.Is(res.Err, ErrNoError) { return nil, res.Err } @@ -1714,3 +1731,130 @@ func validServerNameTLS(addr string, cfg *tls.Config) *tls.Config { c.ServerName = sn return c } + +func (b *Broker) sendAndReceiveSASLIAM() error { + if err := b.sendAndReceiveSASLHandshake(SASLTypeAWSMSKIAM, SASLHandshakeV1); err != nil { + return err + } + + msg, err := getIAMPayload( + b.addr, + b.conf.ClientID, + b.conf.Net.SASL.AWSMSKIAM, + ) + if err != nil { + return err + } + + requestTime := time.Now() + // Will be decremented in updateIncomingCommunicationMetrics (except error) + b.addRequestInFlightMetrics(1) + correlationID := b.correlationID + + bytesWritten, err := b.sendSaslAuthenticateRequest(correlationID, msg) + b.updateOutgoingCommunicationMetrics(bytesWritten) + if err != nil { + b.addRequestInFlightMetrics(-1) + Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error()) + return err + } + b.correlationID++ + bytesRead, err := b.receiveSaslAuthenticateResponse(correlationID) + if err != nil { + b.addRequestInFlightMetrics(-1) + Logger.Printf("Failed to read response while authenticating with SASL to broker %s: %s\n", b.addr, err.Error()) + return err + } + + resp := struct { + Version string `json:"version"` + RequestID string `json:"request-id"` + }{} + err = json.NewDecoder(bytes.NewReader(bytesRead)).Decode(&resp) + if err != nil { + return fmt.Errorf("unable to process msk response: %w", err) + } + if resp.Version != IAMAuthVersion { + return fmt.Errorf("unknown version found in response") + } + + requestLatency := time.Since(requestTime) + b.updateIncomingCommunicationMetrics(len(bytesRead), requestLatency) + + DebugLogger.Println("SASL authentication succeeded") + return nil +} + +func getIAMPayload(addr, useragent string, cfg AWSMSKIAMConfig) ([]byte, error) { + sess, err := session.NewSession(&aws.Config{Region: &cfg.Region}) + if err != nil { + return nil, err + } + + signer := sign.NewSigner( + credentials.NewChainCredentials([]credentials.Provider{ + &credentials.EnvProvider{}, + &credentials.StaticProvider{ + Value: credentials.Value{ + AccessKeyID: cfg.AccessKeyID, + SecretAccessKey: cfg.SecretAccessKey, + SessionToken: cfg.SessionToken, + }, + }, + stscreds.NewWebIdentityRoleProviderWithOptions( + sts.New(sess), + cfg.RoleArn, + useragent, + stscreds.FetchTokenPath(cfg.WebIdentityTokenFile), + ), + }), + ) + + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + + var action = "kafka-cluster:Connect" + + q := url.Values{ + "Action": {action}, + } + + u := url.URL{ + Host: host, + Path: "/", + RawQuery: q.Encode(), + } + + req, err := http.NewRequest("GET", u.String(), nil) + if err != nil { + return nil, err + } + + if cfg.Expiry == time.Duration(0) { + cfg.Expiry = 5 * time.Minute + } + + header, err := signer.Presign(req, nil, "kafka-cluster", cfg.Region, cfg.Expiry, time.Now().UTC()) + if err != nil { + return nil, err + } + + payload := map[string]string{ + "version": IAMAuthVersion, + "host": host, + "user-agent": useragent, + "action": action, + } + + for key, vals := range header { + payload[strings.ToLower(key)] = vals[0] + } + + for key, vals := range req.URL.Query() { + payload[strings.ToLower(key)] = vals[0] + } + + return json.Marshal(payload) +} diff --git a/config.go b/config.go index ef8bf8df6..d4e5210e5 100644 --- a/config.go +++ b/config.go @@ -17,6 +17,16 @@ const defaultClientID = "sarama" var validID = regexp.MustCompile(`\A[A-Za-z0-9._-]+\z`) +type AWSMSKIAMConfig struct { + Region string + AccessKeyID string + SecretAccessKey string + SessionToken string + RoleArn string + WebIdentityTokenFile string + Expiry time.Duration +} + // Config is used to pass multiple configuration options to Sarama's constructors. type Config struct { // Admin is the namespace for ClusterAdmin properties used by the administrative Kafka client. @@ -97,6 +107,8 @@ type Config struct { TokenProvider AccessTokenProvider GSSAPI GSSAPIConfig + + AWSMSKIAM AWSMSKIAMConfig } // KeepAlive specifies the keep-alive period for an active network connection (defaults to 0). @@ -586,6 +598,10 @@ func (c *Config) Validate() error { if c.Net.SASL.TokenProvider == nil { return ConfigurationError("An AccessTokenProvider instance must be provided to Net.SASL.TokenProvider") } + case SASLTypeAWSMSKIAM: + if c.Net.SASL.AWSMSKIAM.Region == "" { + return ConfigurationError("AWSMSKIAM.Region must be set when AWS_MSK_IAM is enabled") + } case SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512: if c.Net.SASL.User == "" { return ConfigurationError("Net.SASL.User must not be empty when SASL is enabled") @@ -624,8 +640,8 @@ func (c *Config) Validate() error { return ConfigurationError("Net.SASL.GSSAPI.Realm must not be empty when GSS-API mechanism is used") } default: - msg := fmt.Sprintf("The SASL mechanism configuration is invalid. Possible values are `%s`, `%s`, `%s`, `%s` and `%s`", - SASLTypeOAuth, SASLTypePlaintext, SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512, SASLTypeGSSAPI) + msg := fmt.Sprintf("The SASL mechanism configuration is invalid. Possible values are `%s`, `%s`, `%s`, `%s`, `%s` and `%s`", + SASLTypeAWSMSKIAM, SASLTypeOAuth, SASLTypePlaintext, SASLTypeSCRAMSHA256, SASLTypeSCRAMSHA512, SASLTypeGSSAPI) return ConfigurationError(msg) } } diff --git a/go.mod b/go.mod index c7978d7d2..9871c4b15 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.16 require ( github.com/Shopify/toxiproxy/v2 v2.3.0 + github.com/aws/aws-sdk-go v1.43.22 github.com/davecgh/go-spew v1.1.1 github.com/eapache/go-resiliency v1.2.0 github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 diff --git a/go.sum b/go.sum index 45772e7b5..721a019db 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/Shopify/toxiproxy/v2 v2.3.0 h1:62YkpiP4bzdhKMH+6uC5E95y608k3zDwdzuBMsnn3uQ= github.com/Shopify/toxiproxy/v2 v2.3.0/go.mod h1:KvQTtB6RjCJY4zqNJn7C7JDFgsG5uoHYDirfUfpIm0c= +github.com/aws/aws-sdk-go v1.43.22 h1:QY9/1TZB73UDEVQ68sUVJXf/7QUiHZl7zbbLF1wpqlc= +github.com/aws/aws-sdk-go v1.43.22/go.mod h1:y4AeaBuwd2Lk+GepC1E9v0qOiTws0MIWAX4oIKwKHZo= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -41,6 +43,10 @@ github.com/jcmturner/gokrb5/v8 v8.4.2 h1:6ZIM6b/JJN0X8UM43ZOM6Z4SJzla+a/u7scXFJz github.com/jcmturner/gokrb5/v8 v8.4.2/go.mod h1:sb+Xq/fTY5yktf/VxLsE3wlfPqQjp0aWNYyvBVK62bc= github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/klauspost/compress v1.15.0 h1:xqfchp4whNFxn5A4XFyyYtitiWI8Hy5EW59jEwcyL6U= github.com/klauspost/compress v1.15.0/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -53,6 +59,7 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/pierrec/lz4 v2.6.1+incompatible h1:9UY3+iC23yxF0UfGaYrGplQ+79Rg+h/q9FV9ix19jjM= github.com/pierrec/lz4 v2.6.1+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= @@ -82,6 +89,7 @@ golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0 golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -109,6 +117,8 @@ gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=