[fix] - Add Size Method to BufferedReadSeeker and Refactor Context Timeout Handling in HandleFile (#3307)
ahrav committed Sep 23, 2024
1 parent c33ab21 commit 97fd2f8
Showing 4 changed files with 175 additions and 10 deletions.
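The substance of the refactor: the 30-second timeout that archiveHandler previously created inside its worker goroutine now wraps the whole HandleFile pipeline, so producing chunks and consuming them share a single deadline. A minimal, self-contained sketch of that pattern using only the standard library — handle, dataChan, and the payload are illustrative names, not trufflehog APIs:

package main

import (
	"context"
	"fmt"
	"time"
)

// handle owns one deadline that covers both the producer goroutine
// (standing in for the archive handler) and the consumer loop
// (standing in for handleChunks).
func handle(ctx context.Context, timeout time.Duration) error {
	ctx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()

	dataChan := make(chan string)
	go func() {
		defer close(dataChan)
		// The producer respects the same ctx the consumer observes.
		select {
		case dataChan <- "chunk":
		case <-ctx.Done():
		}
	}()

	for {
		select {
		case d, ok := <-dataChan:
			if !ok {
				return nil // producer finished
			}
			fmt.Println("got", d)
		case <-ctx.Done():
			return ctx.Err() // one timer cancels the whole pipeline
		}
	}
}

func main() {
	if err := handle(context.Background(), 30*time.Second); err != nil {
		fmt.Println("handle:", err)
	}
}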
10 changes: 2 additions & 8 deletions pkg/handlers/archive.go
@@ -23,9 +23,8 @@ const (
 var (
 	// NOTE: This is a temporary workaround for |openArchive| incrementing depth twice per archive.
 	// See: https://github.com/trufflesecurity/trufflehog/issues/2942
-	maxDepth   = 5 * 2
-	maxSize    = 2 << 30 // 2 GB
-	maxTimeout = time.Duration(30) * time.Second
+	maxDepth = 5 * 2
+	maxSize  = 2 << 30 // 2 GB
 )
 
 // SetArchiveMaxSize sets the maximum size of the archive.
@@ -34,9 +33,6 @@ func SetArchiveMaxSize(size int) { maxSize = size }
 // SetArchiveMaxDepth sets the maximum depth of the archive.
 func SetArchiveMaxDepth(depth int) { maxDepth = depth }
 
-// SetArchiveMaxTimeout sets the maximum timeout for the archive handler.
-func SetArchiveMaxTimeout(timeout time.Duration) { maxTimeout = timeout }
-
 // archiveHandler is a handler for common archive files that are supported by the archiver library.
 type archiveHandler struct{ *defaultHandler }

@@ -57,8 +53,6 @@ func (h *archiveHandler) HandleFile(ctx logContext.Context, input fileReader) (c
 	}
 
 	go func() {
-		ctx, cancel := logContext.WithTimeout(ctx, maxTimeout)
-		defer cancel()
 		defer close(dataChan)
 
 		// Update the metrics for the file processing.
20 changes: 18 additions & 2 deletions pkg/handlers/handlers.go
@@ -5,6 +5,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"time"
 
 	"github.com/gabriel-vasile/mimetype"
 	"github.com/mholt/archiver/v4"
@@ -249,6 +250,11 @@ func selectHandler(mimeT mimeType, isGenericArchive bool) FileHandler {
 	}
 }
 
+var maxTimeout = time.Duration(30) * time.Second
+
+// SetArchiveMaxTimeout sets the maximum timeout for the archive handler.
+func SetArchiveMaxTimeout(timeout time.Duration) { maxTimeout = timeout }
+
 // HandleFile orchestrates the complete file handling process for a given file.
 // It determines the MIME type of the file, selects the appropriate handler based on this type, and processes the file.
 // This function initializes the handling process and delegates to the specific handler to manage file
@@ -279,20 +285,30 @@ func HandleFile(
 	}
 	defer rdr.Close()
 
+	size, err := rdr.Size()
+	if err != nil {
+		ctx.Logger().Error(err, "error getting file size")
+	}
+
+	ctx = logContext.WithValues(ctx, "mime", rdr.mime.String(), "size_bytes", size)
+
 	mimeT := mimeType(rdr.mime.String())
 	config := newFileHandlingConfig(options...)
 	if config.skipArchives && rdr.isGenericArchive {
 		ctx.Logger().V(5).Info("skipping archive file", "mime", mimeT)
 		return nil
 	}
 
+	processingCtx, cancel := logContext.WithTimeout(ctx, maxTimeout)
+	defer cancel()
+
 	handler := selectHandler(mimeT, rdr.isGenericArchive)
-	archiveChan, err := handler.HandleFile(ctx, rdr) // Delegate to the specific handler to process the file.
+	archiveChan, err := handler.HandleFile(processingCtx, rdr) // Delegate to the specific handler to process the file.
 	if err != nil {
 		return fmt.Errorf("error handling file: %w", err)
 	}
 
-	return handleChunks(ctx, archiveChan, chunkSkel, reporter)
+	return handleChunks(processingCtx, archiveChan, chunkSkel, reporter)
 }
 
 // handleChunks reads data from the handlerChan and uses it to fill chunks according to a predefined skeleton (chunkSkel).
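With maxTimeout and SetArchiveMaxTimeout relocated to handlers.go, one knob now bounds the whole pipeline: the delegated handler.HandleFile call and the handleChunks loop both run under processingCtx. A hedged caller-side sketch — the v3 module path is an assumption about the repository layout, while SetArchiveMaxTimeout itself is the function added above:

package main

import (
	"time"

	"github.com/trufflesecurity/trufflehog/v3/pkg/handlers" // assumed import path
)

func main() {
	// Every subsequent HandleFile call now gets a 10s window covering
	// both handler delegation and chunk consumption.
	handlers.SetArchiveMaxTimeout(10 * time.Second)
	// ... construct a reader and call handlers.HandleFile as usual ...
}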
27 changes: 27 additions & 0 deletions pkg/iobuf/bufferedreaderseeker.go
@@ -2,6 +2,7 @@ package iobuf
 
 import (
 	"errors"
+	"fmt"
 	"io"
 	"os"

@@ -355,3 +356,29 @@ func (br *BufferedReadSeeker) Close() error {
 	}
 	return nil
 }
+
+// Size returns the total size of the reader.
+func (br *BufferedReadSeeker) Size() (int64, error) {
+	if br.sizeKnown {
+		return br.totalSize, nil
+	}
+
+	currentPos, err := br.Seek(0, io.SeekCurrent)
+	if err != nil {
+		return 0, fmt.Errorf("failed to get current position: %w", err)
+	}
+
+	endPos, err := br.Seek(0, io.SeekEnd)
+	if err != nil {
+		return 0, fmt.Errorf("failed to seek to end: %w", err)
+	}
+
+	if _, err = br.Seek(currentPos, io.SeekStart); err != nil {
+		return 0, fmt.Errorf("failed to restore position: %w", err)
+	}
+
+	br.totalSize = endPos
+	br.sizeKnown = true
+
+	return br.totalSize, nil
+}
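Size memoizes the answer in totalSize/sizeKnown and uses the seek-to-end-then-restore idiom, so the caller's read offset is left exactly where it was. The same idiom over any io.ReadSeeker, as a standalone sketch — readerSize is an illustrative helper, not part of this package:

package main

import (
	"fmt"
	"io"
	"strings"
)

// readerSize reports the total length of rs without disturbing the
// caller's position: remember the offset, seek to the end to learn
// the length, then seek back.
func readerSize(rs io.ReadSeeker) (int64, error) {
	cur, err := rs.Seek(0, io.SeekCurrent)
	if err != nil {
		return 0, fmt.Errorf("get position: %w", err)
	}
	end, err := rs.Seek(0, io.SeekEnd)
	if err != nil {
		return 0, fmt.Errorf("seek end: %w", err)
	}
	if _, err := rs.Seek(cur, io.SeekStart); err != nil {
		return 0, fmt.Errorf("restore position: %w", err)
	}
	return end, nil
}

func main() {
	r := strings.NewReader("Hello, World!")
	buf := make([]byte, 5)
	_, _ = r.Read(buf) // advance the offset to 5

	size, _ := readerSize(r)
	pos, _ := r.Seek(0, io.SeekCurrent)
	fmt.Println(size, pos) // prints: 13 5
}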
128 changes: 128 additions & 0 deletions pkg/iobuf/bufferedreaderseeker_test.go
@@ -2,6 +2,7 @@ package iobuf
 
 import (
 	"bytes"
+	"errors"
 	"io"
 	"strings"
 	"testing"
@@ -350,3 +351,130 @@ func TestBufferedReaderSeekerReadAt(t *testing.T) {
 		})
 	}
 }
+
+// TestBufferedReadSeekerSize tests the Size method of BufferedReadSeeker.
+func TestBufferedReadSeekerSize(t *testing.T) {
+	tests := []struct {
+		name           string
+		reader         io.Reader
+		setup          func(*BufferedReadSeeker)
+		expectedSize   int64
+		expectError    bool
+		verifyPosition func(*BufferedReadSeeker, int64)
+	}{
+		{
+			name:         "size of seekable reader",
+			reader:       strings.NewReader("Hello, World!"),
+			expectedSize: 13,
+		},
+		{
+			name:         "size of non-seekable reader",
+			reader:       bytes.NewBufferString("Hello, World!"),
+			expectedSize: 13,
+		},
+		{
+			name:         "size of empty seekable reader",
+			reader:       strings.NewReader(""),
+			expectedSize: 0,
+		},
+		{
+			name:         "size of empty non-seekable reader",
+			reader:       bytes.NewBufferString(""),
+			expectedSize: 0,
+		},
+		{
+			name:   "size of non-seekable reader after partial read",
+			reader: bytes.NewBufferString("Partial read data"),
+			setup: func(brs *BufferedReadSeeker) {
+				// Read first 7 bytes ("Partial").
+				buf := make([]byte, 7)
+				_, _ = brs.Read(buf)
+			},
+			expectedSize: 17, // "Partial read data" is 17 bytes
+			expectError:  false,
+			verifyPosition: func(brs *BufferedReadSeeker, expectedSize int64) {
+				// After Size is called, the read position should remain at 7.
+				currentPos, err := brs.Seek(0, io.SeekCurrent)
+				assert.NoError(t, err)
+				assert.Equal(t, int64(7), currentPos)
+			},
+		},
+		{
+			name:         "repeated Size calls",
+			reader:       strings.NewReader("Repeated Size Calls Test"),
+			expectedSize: 24,
+			expectError:  false,
+			setup: func(brs *BufferedReadSeeker) {
+				// Call Size multiple times.
+				size1, err1 := brs.Size()
+				assert.NoError(t, err1)
+				assert.Equal(t, int64(24), size1)
+
+				size2, err2 := brs.Size()
+				assert.NoError(t, err2)
+				assert.Equal(t, int64(24), size2)
+			},
+		},
+		{
+			name: "size with error during reading",
+			reader: &errorReader{
+				data:       "Data before error",
+				errorAfter: 5, // Return error after reading 5 bytes.
+			},
+			expectedSize: 0,
+			expectError:  true,
+		},
+		{
+			name:         "size with limited reader simulating EOF",
+			reader:       io.LimitReader(strings.NewReader("Limited data"), 7),
+			expectedSize: 7,
+			expectError:  false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+
+			brs := NewBufferedReaderSeeker(tt.reader)
+
+			if tt.setup != nil {
+				tt.setup(brs)
+			}
+
+			size, err := brs.Size()
+			if tt.expectError {
+				assert.Error(t, err)
+			} else {
+				assert.NoError(t, err)
+				assert.Equal(t, tt.expectedSize, size)
+			}
+
+			if tt.verifyPosition != nil {
+				tt.verifyPosition(brs, tt.expectedSize)
+			}
+		})
+	}
+}
+
+// errorReader is an io.Reader that returns an error after reading a specified number of bytes.
+// It's used to simulate non-EOF errors during read operations.
+type errorReader struct {
+	data       string
+	errorAfter int // Number of bytes to read before returning an error.
+	readBytes  int
+}
+
+func (er *errorReader) Read(p []byte) (int, error) {
+	if er.readBytes >= er.errorAfter {
+		return 0, errors.New("simulated read error")
+	}
+	remaining := er.errorAfter - er.readBytes
+	toRead := len(p)
+	if toRead > remaining {
+		toRead = remaining
+	}
+	copy(p, er.data[er.readBytes:er.readBytes+toRead])
+	er.readBytes += toRead
+	return toRead, nil
+}
