diff --git a/client_handler.go b/client_handler.go index 011ddf6c..c519aa91 100644 --- a/client_handler.go +++ b/client_handler.go @@ -2,6 +2,7 @@ package ftpserver import ( "bufio" + "compress/flate" "errors" "fmt" "io" @@ -34,6 +35,14 @@ const ( TransferTypeBinary ) +// TransferMode is the enumerable that represents the transfer mode (stream, block, compressed, deflate) +type TransferMode int8 + +const ( + TransferModeStream TransferMode = iota // TransferModeStream is the standard mode + TransferModeDeflate // TransferModeDeflate is the deflate mode +) + // DataChannel is the enumerable that represents the data channel (active or passive) type DataChannel int8 @@ -99,6 +108,7 @@ type clientHandler struct { selectedHashAlgo HASHAlgo // algorithm used when we receive the HASH command logger log.Logger // Client handler logging currentTransferType TransferType // current transfer type + transferMode TransferMode // Transfer mode (stream, block, compressed) transferWg sync.WaitGroup // wait group for command that open a transfer connection transferMu sync.Mutex // this mutex will protect the transfer parameters transfer transferHandler // Transfer connection (passive or active)s @@ -627,7 +637,7 @@ func (c *clientHandler) GetTranferInfo() string { return c.transfer.GetInfo() } -func (c *clientHandler) TransferOpen(info string) (net.Conn, error) { +func (c *clientHandler) TransferOpen(info string) (io.ReadWriter, error) { c.transferMu.Lock() defer c.transferMu.Unlock() @@ -663,6 +673,17 @@ func (c *clientHandler) TransferOpen(info string) (net.Conn, error) { return nil, err } + var transferStream io.ReadWriter = conn + + if c.transferMode == TransferModeDeflate { + transferStream, err = newDeflateTransfer(transferStream, c.server.settings.DeflateCompressionLevel) + if err != nil { + c.writeMessage(StatusActionNotTaken, fmt.Sprintf("Could not switch to deflate mode: %v", err)) + + return nil, fmt.Errorf("could not switch to deflate mode: %w", err) + } + } + c.isTransferOpen = true c.transfer.SetInfo(info) @@ -675,13 +696,27 @@ func (c *clientHandler) TransferOpen(info string) (net.Conn, error) { "localAddr", conn.LocalAddr().String()) } - return conn, nil + return transferStream, nil } -func (c *clientHandler) TransferClose(err error) { +// Flusher is the interface that wraps the basic Flush method. +type Flusher interface { + Flush() error +} + +func (c *clientHandler) TransferClose(transfer io.ReadWriter, err error) { c.transferMu.Lock() defer c.transferMu.Unlock() + if flush, ok := transfer.(Flusher); ok { + if errFlush := flush.Flush(); errFlush != nil { + c.logger.Warn( + "Error flushing transfer connection", + "err", errFlush, + ) + } + } + errClose := c.closeTransfer() if errClose != nil { c.logger.Warn( @@ -788,3 +823,25 @@ func getMessageLines(message string) []string { return lines } + +// We check that it implements flusher +var _ Flusher = (*deflateReadWriter)(nil) + +type deflateReadWriter struct { + io.Reader + *flate.Writer +} + +func newDeflateTransfer(conn io.ReadWriter, level int) (io.ReadWriter, error) { + writer, err := flate.NewWriter(conn, level) + if err != nil { + return nil, fmt.Errorf("could not create deflate writer: %w", err) + } + + reader := flate.NewReader(conn) + + return &deflateReadWriter{ + Reader: reader, + Writer: writer, + }, nil +} diff --git a/driver.go b/driver.go index 80316a77..5423a94e 100644 --- a/driver.go +++ b/driver.go @@ -265,6 +265,7 @@ type Settings struct { DisableSTAT bool // Disable Server STATUS, STAT on files and directories will still work DisableSYST bool // Disable SYST EnableCOMB bool // Enable COMB support + DeflateCompressionLevel int // Deflate compression level (1-9) DefaultTransferType TransferType // Transfer type to use if the client don't send the TYPE command // ActiveConnectionsCheck defines the security requirements for active connections ActiveConnectionsCheck DataConnectionRequirement diff --git a/driver_test.go b/driver_test.go index b9a7ccb1..bdfac19f 100644 --- a/driver_test.go +++ b/driver_test.go @@ -56,12 +56,12 @@ func (driver *TestServerDriver) Init() { } { - dir, _ := os.MkdirTemp("", "example") - if err := os.MkdirAll(dir, 0o750); err != nil { + driver.serverDir, _ = os.MkdirTemp("", "example") + if err := os.MkdirAll(driver.serverDir, 0o750); err != nil { panic(err) } - driver.fs = afero.NewBasePathFs(afero.NewOsFs(), dir) + driver.fs = afero.NewBasePathFs(afero.NewOsFs(), driver.serverDir) } } @@ -126,6 +126,7 @@ type TestServerDriver struct { CloseOnConnect bool // disconnect the client as soon as it connects Settings *Settings // Settings + serverDir string fs afero.Fs clientMU sync.Mutex Clients []ClientContext diff --git a/errors_test.go b/errors_test.go index f789ac16..f4cfc2ba 100644 --- a/errors_test.go +++ b/errors_test.go @@ -24,7 +24,7 @@ func TestCustomErrorsCode(t *testing.T) { func TestTransferCloseStorageExceeded(t *testing.T) { buf := bytes.Buffer{} h := clientHandler{writer: bufio.NewWriter(&buf)} - h.TransferClose(ErrStorageExceeded) + h.TransferClose(nil, ErrStorageExceeded) require.Equal(t, "552 Issue during transfer: storage limit exceeded\r\n", buf.String()) } diff --git a/handle_dirs.go b/handle_dirs.go index d3473e58..5a962b63 100644 --- a/handle_dirs.go +++ b/handle_dirs.go @@ -191,7 +191,7 @@ func (c *clientHandler) handleLIST(param string) error { if files, _, err := c.getFileList(param, true); err == nil || errors.Is(err, io.EOF) { if tr, errTr := c.TransferOpen(info); errTr == nil { err = c.dirTransferLIST(tr, files) - c.TransferClose(err) + c.TransferClose(tr, err) return nil } @@ -210,7 +210,7 @@ func (c *clientHandler) handleNLST(param string) error { if files, parentDir, err := c.getFileList(param, true); err == nil || errors.Is(err, io.EOF) { if tr, errTrOpen := c.TransferOpen(info); errTrOpen == nil { err = c.dirTransferNLST(tr, files, parentDir) - c.TransferClose(err) + c.TransferClose(tr, err) return nil } @@ -257,7 +257,7 @@ func (c *clientHandler) handleMLSD(param string) error { if files, _, err := c.getFileList(param, false); err == nil || errors.Is(err, io.EOF) { if tr, errTr := c.TransferOpen(info); errTr == nil { err = c.dirTransferMLSD(tr, files) - c.TransferClose(err) + c.TransferClose(tr, err) return nil } diff --git a/handle_files.go b/handle_files.go index 0b41ba12..63436171 100644 --- a/handle_files.go +++ b/handle_files.go @@ -12,7 +12,6 @@ import ( "hash" "hash/crc32" "io" - "net" "os" "runtime" "strconv" @@ -118,10 +117,10 @@ func (c *clientHandler) transferFile(write bool, appendFile bool, param, info st } // closing the transfer we also send the response message to the FTP client - c.TransferClose(err) + c.TransferClose(fileTransferConn, err) } -func (c *clientHandler) doFileTransfer(transferConn net.Conn, file io.ReadWriter, write bool) error { +func (c *clientHandler) doFileTransfer(transferConn io.ReadWriter, file io.ReadWriter, write bool) error { var err error var reader io.Reader var writer io.Writer diff --git a/handle_misc.go b/handle_misc.go index 6c3a7cd1..a63b46a6 100644 --- a/handle_misc.go +++ b/handle_misc.go @@ -284,9 +284,14 @@ func (c *clientHandler) handleTYPE(param string) error { } func (c *clientHandler) handleMODE(param string) error { - if param == "S" { + switch param { + case "S": + c.transferMode = TransferModeStream c.writeMessage(StatusOK, "Using stream mode") - } else { + case "Z": + c.transferMode = TransferModeDeflate + c.writeMessage(StatusOK, "Using deflate mode") + default: c.writeMessage(StatusNotImplementedParam, "Unsupported mode") } diff --git a/server.go b/server.go index c8d84eba..8804159a 100644 --- a/server.go +++ b/server.go @@ -137,6 +137,10 @@ func (server *FtpServer) loadSettings() error { settings.Banner = "ftpserver - golang FTP server" } + if settings.DeflateCompressionLevel == 0 { + settings.DeflateCompressionLevel = 5 + } + server.settings = settings return nil diff --git a/transfer_active.go b/transfer_active.go index 8ba853a6..dd28516a 100644 --- a/transfer_active.go +++ b/transfer_active.go @@ -72,6 +72,9 @@ func (c *clientHandler) handlePORT(param string) error { return nil } +// activeTransferHandler implements the transferHandler interface +var _ transferHandler = (*activeTransferHandler)(nil) + // Active connection type activeTransferHandler struct { raddr *net.TCPAddr // Remote address of the client diff --git a/transfer_pasv.go b/transfer_pasv.go index bd0c0324..17c0f689 100644 --- a/transfer_pasv.go +++ b/transfer_pasv.go @@ -26,6 +26,9 @@ type transferHandler interface { GetInfo() string } +// activeTransferHandler implements the transferHandler interface +var _ transferHandler = (*passiveTransferHandler)(nil) + // Passive connection type passiveTransferHandler struct { listener net.Listener // TCP or SSL Listener diff --git a/transfer_test.go b/transfer_test.go index 20aee28f..959cb646 100644 --- a/transfer_test.go +++ b/transfer_test.go @@ -10,6 +10,7 @@ import ( "math/rand" "net" "os" + "path" "path/filepath" "runtime" "strconv" @@ -114,12 +115,21 @@ func ftpDownloadAndHash(t *testing.T, ftp *goftp.Client, filename string) string return hex.EncodeToString(hasher.Sum(nil)) } -func ftpDownloadAndHashWithRawConnection(t *testing.T, raw goftp.RawConn, fileName string) string { +type ftpDownloadOptions struct { + deflateMode bool + otherWriter io.Writer +} + +func ftpDownloadAndHashWithRawConnection(t *testing.T, raw goftp.RawConn, fileName string, options *ftpDownloadOptions) string { t.Helper() req := require.New(t) hasher := sha256.New() + if options == nil { + options = &ftpDownloadOptions{} + } + dcGetter, err := raw.PrepareDataConn() req.NoError(err) @@ -130,7 +140,20 @@ func ftpDownloadAndHashWithRawConnection(t *testing.T, raw goftp.RawConn, fileNa dataConn, err := dcGetter() req.NoError(err) - _, err = io.Copy(hasher, dataConn) + var transfer io.ReadWriter = dataConn + + if options.deflateMode { + transfer, err = newDeflateTransfer(transfer, 5) + req.NoError(err) + } + + var writer io.Writer = hasher + + if options.otherWriter != nil { + writer = io.MultiWriter(writer, options.otherWriter) + } + + _, err = io.Copy(writer, transfer) req.NoError(err) err = dataConn.Close() @@ -143,15 +166,24 @@ func ftpDownloadAndHashWithRawConnection(t *testing.T, raw goftp.RawConn, fileNa return hex.EncodeToString(hasher.Sum(nil)) } -func ftpUploadWithRawConnection(t *testing.T, raw goftp.RawConn, file io.Reader, fileName string, appendFile bool) { +type ftpUploadOptions struct { + appendFile bool + deflateMode bool +} + +func ftpUploadWithRawConnection(t *testing.T, raw goftp.RawConn, file io.Reader, fileName string, options *ftpUploadOptions) { t.Helper() req := require.New(t) dcGetter, err := raw.PrepareDataConn() req.NoError(err) + if options == nil { + options = &ftpUploadOptions{} + } + cmd := "STOR" - if appendFile { + if options.appendFile { cmd = "APPE" } @@ -162,9 +194,20 @@ func ftpUploadWithRawConnection(t *testing.T, raw goftp.RawConn, file io.Reader, dataConn, err := dcGetter() req.NoError(err) - _, err = io.Copy(dataConn, file) + var transfer io.ReadWriter = dataConn + + if options.deflateMode { + transfer, err = newDeflateTransfer(transfer, 5) + req.NoError(err) + } + + _, err = io.Copy(transfer, file) req.NoError(err) + if transferFlusher, ok := transfer.(Flusher); ok { + req.NoError(transferFlusher.Flush()) + } + err = dataConn.Close() req.NoError(err) @@ -536,7 +579,7 @@ func TestAPPEExistingFile(t *testing.T) { _, err = file.Seek(1024, io.SeekStart) require.NoError(t, err) - ftpUploadWithRawConnection(t, raw, file, fileName, true) + ftpUploadWithRawConnection(t, raw, file, fileName, &ftpUploadOptions{appendFile: true}) info, err := client.Stat(fileName) require.NoError(t, err) @@ -572,7 +615,7 @@ func TestAPPENewFile(t *testing.T) { fileName := filepath.Base(file.Name()) - ftpUploadWithRawConnection(t, raw, file, fileName, true) + ftpUploadWithRawConnection(t, raw, file, fileName, &ftpUploadOptions{appendFile: true}) localHash := hashFile(t, file) remoteHash := ftpDownloadAndHash(t, client, fileName) @@ -927,7 +970,7 @@ func TestASCIITransfers(t *testing.T) { _, err = file.Seek(0, io.SeekStart) require.NoError(t, err) - ftpUploadWithRawConnection(t, raw, file, "file.txt", false) + ftpUploadWithRawConnection(t, raw, file, "file.txt", nil) files, err := client.ReadDir("/") require.NoError(t, err) @@ -939,7 +982,7 @@ func TestASCIITransfers(t *testing.T) { require.Equal(t, int64(len(contents)), files[0].Size()) } - remoteHash := ftpDownloadAndHashWithRawConnection(t, raw, "file.txt") + remoteHash := ftpDownloadAndHashWithRawConnection(t, raw, "file.txt", nil) localHash := hashFile(t, file) require.Equal(t, localHash, remoteHash) } @@ -979,9 +1022,9 @@ func TestASCIITransfersInvalidFiles(t *testing.T) { require.NoError(t, err) require.Equal(t, StatusOK, rc, response) - ftpUploadWithRawConnection(t, raw, file, "file.bin", false) + ftpUploadWithRawConnection(t, raw, file, "file.bin", nil) - remoteHash := ftpDownloadAndHashWithRawConnection(t, raw, "file.bin") + remoteHash := ftpDownloadAndHashWithRawConnection(t, raw, "file.bin", nil) require.Equal(t, localHash, remoteHash) } @@ -1231,3 +1274,84 @@ func getPortFromPASVResponse(t *testing.T, resp string) int { return port } + +func TestTransferModeDeflate(t *testing.T) { + driver := &TestServerDriver{Debug: true} + server := NewTestServerWithTestDriver(t, driver) + + conf := goftp.Config{ + User: authUser, + Password: authPass, + } + client, err := goftp.DialConfig(conf, server.Addr()) + require.NoError(t, err, "Couldn't connect") + + defer func() { require.NoError(t, client.Close()) }() + + raw, err := client.OpenRawConn() + require.NoError(t, err) + + defer func() { require.NoError(t, raw.Close()) }() + + file, err := os.CreateTemp("", "ftpserver") + require.NoError(t, err) + + contents := []byte("line1\r\n\r\nline3\r\n,line4") + _, err = file.Write(contents) + require.NoError(t, err) + localHash := hashFile(t, file) + + defer func() { require.NoError(t, file.Close()) }() + + { + rc, response, errMode := raw.SendCommand("MODE Z") + require.NoError(t, errMode) + require.Equal(t, StatusOK, rc, response) + } + + _, err = file.Seek(0, io.SeekStart) + require.NoError(t, err) + + ftpUploadWithRawConnection(t, raw, file, "file.txt", &ftpUploadOptions{deflateMode: true}) + + files, err := client.ReadDir("/") + require.NoError(t, err) + require.Len(t, files, 1) + + { // Check on server dir + fp, err := os.Open(path.Join(driver.serverDir, "file.txt")) + require.NoError(t, err) + + defer func() { require.NoError(t, fp.Close()) }() + + readContents, err := io.ReadAll(fp) + require.NoError(t, err) + require.Equal(t, string(contents), string(readContents)) + } + + /*{ + rc, response, errMode := raw.SendCommand("MODE S") + require.NoError(t, errMode) + require.Equal(t, StatusOK, rc, response) + } + + { // Hash on standard connection + writer := bytes.NewBuffer(nil) + remoteHash := ftpDownloadAndHashWithRawConnection(t, raw, "file.txt", &ftpDownloadOptions{otherWriter: writer}) + require.Equal(t, string(contents), writer.String()) + require.Equal(t, localHash, remoteHash) + } + + { + rc, response, errMode := raw.SendCommand("MODE Z") + require.NoError(t, errMode) + require.Equal(t, StatusOK, rc, response) + }*/ + + { // Hash on deflate connection + writer := bytes.NewBuffer(nil) + remoteHash := ftpDownloadAndHashWithRawConnection(t, raw, "file.txt", &ftpDownloadOptions{deflateMode: true, otherWriter: writer}) + require.Equal(t, string(contents), writer.String()) + require.Equal(t, localHash, remoteHash) + } +}