Skip to content

Commit

Permalink
Auto detect authentication type
Browse files Browse the repository at this point in the history
Signed-off-by: utkarshm <utkarshm@jfrog.com>
  • Loading branch information
utkarshm committed Jun 30, 2023
1 parent 812c77e commit 1829554
Showing 1 changed file with 63 additions and 30 deletions.
93 changes: 63 additions & 30 deletions email.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/mail"
"net/textproto"
"strconv"
"strings"
"time"

"github.com/toorop/go-dkim"
Expand Down Expand Up @@ -134,8 +135,22 @@ const (
AuthCRAMMD5
// AuthNone for SMTP servers without authentication
AuthNone
AuthAuto
)

func (at AuthType) String() string {
switch at {
case 0:
return "PLAIN"
case 1:
return "LOGIN"
case 2:
return "CRAM-MD5"
default:
return ""
}
}

// NewMSG creates a new email. It uses UTF-8 by default. All charsets: http://webcheatsheet.com/HTML/character_sets_list.php
func NewMSG() *Email {
email := &Email{
Expand All @@ -152,7 +167,7 @@ func NewMSG() *Email {
// NewSMTPClient returns the client for send email
func NewSMTPClient() *SMTPServer {
server := &SMTPServer{
Authentication: AuthPlain,
Authentication: AuthAuto,
Encryption: EncryptionNone,
ConnectTimeout: 10 * time.Second,
SendTimeout: 10 * time.Second,
Expand Down Expand Up @@ -717,7 +732,7 @@ func dial(host string, port string, encryption Encryption, config *tls.Config) (

// smtpConnect connects to the smtp server and starts TLS and passes auth
// if necessary
func smtpConnect(host, port, helo string, a auth, at AuthType, encryption Encryption, config *tls.Config) (*smtpClient, error) {
func smtpConnect(host, port, helo string, encryption Encryption, config *tls.Config) (*smtpClient, error) {
// connect to the mail server
c, err := dial(host, port, encryption, config)

Expand Down Expand Up @@ -746,42 +761,60 @@ func smtpConnect(host, port, helo string, a auth, at AuthType, encryption Encryp
}
}

// only pass authentication if defined
if at != AuthNone {
// pass the authentication if necessary
if a != nil {
if ok, _ := c.extension("AUTH"); ok {
if err = c.authenticate(a); err != nil {
c.close()
return nil, fmt.Errorf("Mail Error on Auth: %w", err)
}
}
}
}

return c, nil
}

// Connect returns the smtp client
func (server *SMTPServer) Connect() (*SMTPClient, error) {

var a auth

switch server.Authentication {
case AuthPlain:
func (server *SMTPServer) getAuth(a string) (auth, error) {
var afn auth
switch {
case strings.Contains(a, AuthPlain.String()):
if server.Username != "" || server.Password != "" {
a = plainAuthfn("", server.Username, server.Password, server.Host)
afn = plainAuthfn("", server.Username, server.Password, server.Host)
}
case AuthLogin:
case strings.Contains(a, AuthLogin.String()):
if server.Username != "" || server.Password != "" {
a = loginAuthfn("", server.Username, server.Password, server.Host)
afn = loginAuthfn("", server.Username, server.Password, server.Host)
}
case AuthCRAMMD5:
case strings.Contains(a, AuthCRAMMD5.String()):
if server.Username != "" || server.Password != "" {
a = cramMD5Authfn(server.Username, server.Password)
afn = cramMD5Authfn(server.Username, server.Password)
}
default:
return nil, fmt.Errorf("Mail Error on determining auth type, %s is not supported", a)
}
return afn, nil
}

func (server *SMTPServer) validateAuth(c *smtpClient) error {
var err error
var afn auth
switch {
case server.Authentication == AuthNone || server.Username == "":
return nil
case server.Authentication != AuthAuto:
afn, err = server.getAuth(server.Authentication.String())
if err != nil {
return err
}
}
if ok, a := c.extension("AUTH"); ok {
// Determine Auth type automatically from extension
if afn == nil {
afn, err = server.getAuth(a)
if err != nil {
return err
}
}
if err = c.authenticate(afn); err != nil {
c.close()
return fmt.Errorf("Mail Error on Auth: %w", err)
}
}
return nil
}

// Connect returns the smtp client
func (server *SMTPServer) Connect() (*SMTPClient, error) {
var smtpConnectChannel chan error
var c *smtpClient
var err error
Expand All @@ -795,7 +828,7 @@ func (server *SMTPServer) Connect() (*SMTPClient, error) {
if server.ConnectTimeout != 0 {
smtpConnectChannel = make(chan error, 2)
go func() {
c, err = smtpConnect(server.Host, fmt.Sprintf("%d", server.Port), server.Helo, a, server.Authentication, server.Encryption, tlsConfig)
c, err = smtpConnect(server.Host, fmt.Sprintf("%d", server.Port), server.Helo, server.Encryption, tlsConfig)
// send the result
smtpConnectChannel <- err
}()
Expand All @@ -810,7 +843,7 @@ func (server *SMTPServer) Connect() (*SMTPClient, error) {
}
} else {
// no ConnectTimeout, just fire the connect
c, err = smtpConnect(server.Host, fmt.Sprintf("%d", server.Port), server.Helo, a, server.Authentication, server.Encryption, tlsConfig)
c, err = smtpConnect(server.Host, fmt.Sprintf("%d", server.Port), server.Helo, server.Encryption, tlsConfig)
if err != nil {
return nil, err
}
Expand All @@ -820,7 +853,7 @@ func (server *SMTPServer) Connect() (*SMTPClient, error) {
Client: c,
KeepAlive: server.KeepAlive,
SendTimeout: server.SendTimeout,
}, nil
}, server.validateAuth(c)
}

// Reset send RSET command to smtp client
Expand Down

0 comments on commit 1829554

Please sign in to comment.