Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix data race sending mail #82

Merged
merged 6 commits into from
Jul 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions email.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net/textproto"
"strconv"
"strings"
"sync"
"time"

"github.com/toorop/go-dkim"
Expand Down Expand Up @@ -55,6 +56,7 @@ type SMTPServer struct {

// SMTPClient represents a SMTP Client for send email
type SMTPClient struct {
mu sync.Mutex
Client *smtpClient
KeepAlive bool
SendTimeout time.Duration
Expand Down Expand Up @@ -865,21 +867,29 @@ func (server *SMTPServer) Connect() (*SMTPClient, error) {

// Reset send RSET command to smtp client
func (smtpClient *SMTPClient) Reset() error {
smtpClient.mu.Lock()
defer smtpClient.mu.Unlock()
return smtpClient.Client.reset()
}

// Noop send NOOP command to smtp client
func (smtpClient *SMTPClient) Noop() error {
smtpClient.mu.Lock()
defer smtpClient.mu.Unlock()
return smtpClient.Client.noop()
}

// Quit send QUIT command to smtp client
func (smtpClient *SMTPClient) Quit() error {
smtpClient.mu.Lock()
defer smtpClient.mu.Unlock()
return smtpClient.Client.quit()
}

// Close closes the connection
func (smtpClient *SMTPClient) Close() error {
smtpClient.mu.Lock()
defer smtpClient.mu.Unlock()
return smtpClient.Client.close()
}

Expand Down Expand Up @@ -909,14 +919,14 @@ func send(from string, to []string, msg string, client *SMTPClient) error {
if client.SendTimeout != 0 {
smtpSendChannel = make(chan error, 1)

go func(from string, to []string, msg string, c *smtpClient) {
smtpSendChannel <- sendMailProcess(from, to, msg, c)
}(from, to, msg, client.Client)
go func(from string, to []string, msg string, client *SMTPClient) {
smtpSendChannel <- sendMailProcess(from, to, msg, client)
}(from, to, msg, client)
}

if client.SendTimeout == 0 {
// no SendTimeout, just fire the sendMailProcess
return sendMailProcess(from, to, msg, client.Client)
return sendMailProcess(from, to, msg, client)
}

// get the send result or timeout result, which ever happens first
Expand All @@ -928,35 +938,36 @@ func send(from string, to []string, msg string, client *SMTPClient) error {
checkKeepAlive(client)
return errors.New("Mail Error: SMTP Send timed out")
}

}
}

return errors.New("Mail Error: No SMTP Client Provided")
}

func sendMailProcess(from string, to []string, msg string, c *smtpClient) error {
func sendMailProcess(from string, to []string, msg string, c *SMTPClient) error {
c.mu.Lock()
defer c.mu.Unlock()

cmdArgs := make(map[string]string)

if _, ok := c.ext["SIZE"]; ok {
if _, ok := c.Client.ext["SIZE"]; ok {
cmdArgs["SIZE"] = strconv.Itoa(len(msg))
}

// Set the sender
if err := c.mail(from, cmdArgs); err != nil {
if err := c.Client.mail(from, cmdArgs); err != nil {
return err
}

// Set the recipients
for _, address := range to {
if err := c.rcpt(address); err != nil {
if err := c.Client.rcpt(address); err != nil {
return err
}
}

// Send the data command
w, err := c.data()
w, err := c.Client.data()
if err != nil {
return err
}
Expand All @@ -978,9 +989,9 @@ func sendMailProcess(from string, to []string, msg string, c *smtpClient) error
// check if keepAlive for close or reset
func checkKeepAlive(client *SMTPClient) {
if client.KeepAlive {
client.Client.reset()
client.Reset()
} else {
client.Client.quit()
client.Client.close()
client.Quit()
client.Close()
}
}
112 changes: 112 additions & 0 deletions email_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package mail

import (
"fmt"
"log"
"net"
"testing"
"time"
)

func TestSendRace(t *testing.T) {
port := 56666
port2 := 56667
timeout := 1 * time.Second

responses := []string{
`220 test connected`,
`250 after helo`,
`250 after mail from`,
`250 after rcpt to`,
`354 after data`,
}

startService(port, responses, 5*time.Second)
startService(port2, responses, 0)

server := NewSMTPClient()
server.ConnectTimeout = timeout
server.SendTimeout = timeout
server.KeepAlive = false
server.Host = `127.0.0.1`
server.Port = port

smtpClient, err := server.Connect()
if err != nil {
log.Fatalf("couldn't connect: %s", err.Error())
}
defer smtpClient.Close()

// create another server in other port to test timeouts
server.Port = port2
smtpClient2, err := server.Connect()
if err != nil {
log.Fatalf("couldn't connect: %s", err.Error())
}
defer smtpClient2.Close()

msg := NewMSG().
SetFrom(`foo@bar`).
AddTo(`rcpt@bar`).
SetSubject("subject").
SetBody(TextPlain, "body")

// the smtpClient2 has not timeout
err = msg.Send(smtpClient2)
if err != nil {
log.Fatalf("couldn't send: %s", err.Error())
}

// the smtpClient send to listener with the last response is after SendTimeout, so when this error is returned the test succeed.
err = msg.Send(smtpClient)
if err != nil && err.Error() != "Mail Error: SMTP Send timed out" {
log.Fatalf("couldn't send: %s", err.Error())
}
}

func startService(port int, responses []string, timeout time.Duration) {
log.Printf("starting service at %d...\n", port)
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
if err != nil {
log.Fatalf("couldn't listen to port %d: %s", port, err)
}

go func() {
for {
conn, err := listener.Accept()
if err != nil {
log.Fatalf("couldn't listen accept the request in port %d", port)
}
go respond(conn, responses, timeout)
}
}()
}

func respond(conn net.Conn, responses []string, timeout time.Duration) {
buf := make([]byte, 1024)
for _, resp := range responses {
write(conn, resp)
n, err := conn.Read(buf)
if err != nil {
log.Println("couldn't read data")
return
}
readStr := string(buf[:n])
log.Printf("READ:%s", string(readStr))
}

// if timeout, sleep for that time, otherwise sent a 250 OK
if timeout > 0 {
time.Sleep(timeout)
} else {
write(conn, "250 OK")
}

conn.Close()
fmt.Print("\n\n")
}

func write(conn net.Conn, command string) {
log.Printf("WRITE:%s", command)
conn.Write([]byte(command + "\n"))
}
Loading