Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(deflate): Non working deflate transfer mode #461

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 60 additions & 3 deletions client_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ftpserver

import (
"bufio"
"compress/flate"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -34,6 +35,14 @@ const (
TransferTypeBinary
)

// TransferMode is the enumerable that represents the transfer mode (stream, block, compressed, deflate)
type TransferMode int8

const (
TransferModeStream TransferMode = iota // TransferModeStream is the standard mode
TransferModeDeflate // TransferModeDeflate is the deflate mode
)

// DataChannel is the enumerable that represents the data channel (active or passive)
type DataChannel int8

Expand Down Expand Up @@ -99,6 +108,7 @@ type clientHandler struct {
selectedHashAlgo HASHAlgo // algorithm used when we receive the HASH command
logger log.Logger // Client handler logging
currentTransferType TransferType // current transfer type
transferMode TransferMode // Transfer mode (stream, block, compressed)
transferWg sync.WaitGroup // wait group for command that open a transfer connection
transferMu sync.Mutex // this mutex will protect the transfer parameters
transfer transferHandler // Transfer connection (passive or active)s
Expand Down Expand Up @@ -627,7 +637,7 @@ func (c *clientHandler) GetTranferInfo() string {
return c.transfer.GetInfo()
}

func (c *clientHandler) TransferOpen(info string) (net.Conn, error) {
func (c *clientHandler) TransferOpen(info string) (io.ReadWriter, error) {
c.transferMu.Lock()
defer c.transferMu.Unlock()

Expand Down Expand Up @@ -663,6 +673,17 @@ func (c *clientHandler) TransferOpen(info string) (net.Conn, error) {
return nil, err
}

var transferStream io.ReadWriter = conn

if c.transferMode == TransferModeDeflate {
transferStream, err = newDeflateTransfer(transferStream, c.server.settings.DeflateCompressionLevel)
if err != nil {
c.writeMessage(StatusActionNotTaken, fmt.Sprintf("Could not switch to deflate mode: %v", err))

return nil, fmt.Errorf("could not switch to deflate mode: %w", err)
}
}

c.isTransferOpen = true
c.transfer.SetInfo(info)

Expand All @@ -675,13 +696,27 @@ func (c *clientHandler) TransferOpen(info string) (net.Conn, error) {
"localAddr", conn.LocalAddr().String())
}

return conn, nil
return transferStream, nil
}

func (c *clientHandler) TransferClose(err error) {
// Flusher is the interface that wraps the basic Flush method.
type Flusher interface {
Flush() error
}

func (c *clientHandler) TransferClose(transfer io.ReadWriter, err error) {
c.transferMu.Lock()
defer c.transferMu.Unlock()

if flush, ok := transfer.(Flusher); ok {
if errFlush := flush.Flush(); errFlush != nil {
c.logger.Warn(
"Error flushing transfer connection",
"err", errFlush,
)
}
}

errClose := c.closeTransfer()
if errClose != nil {
c.logger.Warn(
Expand Down Expand Up @@ -788,3 +823,25 @@ func getMessageLines(message string) []string {

return lines
}

// We check that it implements flusher
var _ Flusher = (*deflateReadWriter)(nil)

type deflateReadWriter struct {
io.Reader
*flate.Writer
}

func newDeflateTransfer(conn io.ReadWriter, level int) (io.ReadWriter, error) {
writer, err := flate.NewWriter(conn, level)
if err != nil {
return nil, fmt.Errorf("could not create deflate writer: %w", err)
}

reader := flate.NewReader(conn)

return &deflateReadWriter{
Reader: reader,
Writer: writer,
}, nil
}
1 change: 1 addition & 0 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ type Settings struct {
DisableSTAT bool // Disable Server STATUS, STAT on files and directories will still work
DisableSYST bool // Disable SYST
EnableCOMB bool // Enable COMB support
DeflateCompressionLevel int // Deflate compression level (1-9)
DefaultTransferType TransferType // Transfer type to use if the client don't send the TYPE command
// ActiveConnectionsCheck defines the security requirements for active connections
ActiveConnectionsCheck DataConnectionRequirement
Expand Down
7 changes: 4 additions & 3 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ func (driver *TestServerDriver) Init() {
}

{
dir, _ := os.MkdirTemp("", "example")
if err := os.MkdirAll(dir, 0o750); err != nil {
driver.serverDir, _ = os.MkdirTemp("", "example")
if err := os.MkdirAll(driver.serverDir, 0o750); err != nil {
panic(err)
}

driver.fs = afero.NewBasePathFs(afero.NewOsFs(), dir)
driver.fs = afero.NewBasePathFs(afero.NewOsFs(), driver.serverDir)
}
}

Expand Down Expand Up @@ -126,6 +126,7 @@ type TestServerDriver struct {
CloseOnConnect bool // disconnect the client as soon as it connects

Settings *Settings // Settings
serverDir string
fs afero.Fs
clientMU sync.Mutex
Clients []ClientContext
Expand Down
2 changes: 1 addition & 1 deletion errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestCustomErrorsCode(t *testing.T) {
func TestTransferCloseStorageExceeded(t *testing.T) {
buf := bytes.Buffer{}
h := clientHandler{writer: bufio.NewWriter(&buf)}
h.TransferClose(ErrStorageExceeded)
h.TransferClose(nil, ErrStorageExceeded)
require.Equal(t, "552 Issue during transfer: storage limit exceeded\r\n", buf.String())
}

Expand Down
6 changes: 3 additions & 3 deletions handle_dirs.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func (c *clientHandler) handleLIST(param string) error {
if files, _, err := c.getFileList(param, true); err == nil || errors.Is(err, io.EOF) {
if tr, errTr := c.TransferOpen(info); errTr == nil {
err = c.dirTransferLIST(tr, files)
c.TransferClose(err)
c.TransferClose(tr, err)

return nil
}
Expand All @@ -210,7 +210,7 @@ func (c *clientHandler) handleNLST(param string) error {
if files, parentDir, err := c.getFileList(param, true); err == nil || errors.Is(err, io.EOF) {
if tr, errTrOpen := c.TransferOpen(info); errTrOpen == nil {
err = c.dirTransferNLST(tr, files, parentDir)
c.TransferClose(err)
c.TransferClose(tr, err)

return nil
}
Expand Down Expand Up @@ -257,7 +257,7 @@ func (c *clientHandler) handleMLSD(param string) error {
if files, _, err := c.getFileList(param, false); err == nil || errors.Is(err, io.EOF) {
if tr, errTr := c.TransferOpen(info); errTr == nil {
err = c.dirTransferMLSD(tr, files)
c.TransferClose(err)
c.TransferClose(tr, err)

return nil
}
Expand Down
5 changes: 2 additions & 3 deletions handle_files.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"hash"
"hash/crc32"
"io"
"net"
"os"
"runtime"
"strconv"
Expand Down Expand Up @@ -118,10 +117,10 @@ func (c *clientHandler) transferFile(write bool, appendFile bool, param, info st
}

// closing the transfer we also send the response message to the FTP client
c.TransferClose(err)
c.TransferClose(fileTransferConn, err)
}

func (c *clientHandler) doFileTransfer(transferConn net.Conn, file io.ReadWriter, write bool) error {
func (c *clientHandler) doFileTransfer(transferConn io.ReadWriter, file io.ReadWriter, write bool) error {
var err error
var reader io.Reader
var writer io.Writer
Expand Down
9 changes: 7 additions & 2 deletions handle_misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,14 @@ func (c *clientHandler) handleTYPE(param string) error {
}

func (c *clientHandler) handleMODE(param string) error {
if param == "S" {
switch param {
case "S":
c.transferMode = TransferModeStream
c.writeMessage(StatusOK, "Using stream mode")
} else {
case "Z":
c.transferMode = TransferModeDeflate
c.writeMessage(StatusOK, "Using deflate mode")
default:
c.writeMessage(StatusNotImplementedParam, "Unsupported mode")
}

Expand Down
4 changes: 4 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ func (server *FtpServer) loadSettings() error {
settings.Banner = "ftpserver - golang FTP server"
}

if settings.DeflateCompressionLevel == 0 {
settings.DeflateCompressionLevel = 5
}

server.settings = settings

return nil
Expand Down
3 changes: 3 additions & 0 deletions transfer_active.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ func (c *clientHandler) handlePORT(param string) error {
return nil
}

// activeTransferHandler implements the transferHandler interface
var _ transferHandler = (*activeTransferHandler)(nil)

// Active connection
type activeTransferHandler struct {
raddr *net.TCPAddr // Remote address of the client
Expand Down
3 changes: 3 additions & 0 deletions transfer_pasv.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ type transferHandler interface {
GetInfo() string
}

// activeTransferHandler implements the transferHandler interface
var _ transferHandler = (*passiveTransferHandler)(nil)

// Passive connection
type passiveTransferHandler struct {
listener net.Listener // TCP or SSL Listener
Expand Down
Loading
Loading