Skip to content

Commit

Permalink
Use common chunker for archive handler (#1717)
Browse files Browse the repository at this point in the history
* optimize the ReadToMax.

* add comment.

* remove dumb comment.

* update comment.

* fix test.

* lint.

* Expired invite link fix (#1713)

* Use common chunker for archive handler.

---------

Co-authored-by: Zachary Rice <zachary.rice@trufflesec.com>
  • Loading branch information
ahrav and zricethezav committed Sep 6, 2023
1 parent bf581ae commit f6512ac
Showing 1 changed file with 53 additions and 34 deletions.
87 changes: 53 additions & 34 deletions pkg/handlers/archive.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,15 @@ import (

"github.com/trufflesecurity/trufflehog/v3/pkg/common"
logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context"
"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
)

type ctxKey int

const (
	// depthKey is the context key under which the current archive nesting
	// depth is stored (see openArchive's archiver.Extractor case).
	depthKey ctxKey = iota

	// errMaxArchiveDepthReached is the error message used when extraction
	// would exceed the maximum supported archive nesting depth (maxDepth).
	errMaxArchiveDepthReached = "max archive depth reached"
)

var (
Expand Down Expand Up @@ -81,51 +84,66 @@ func (a *Archive) FromFile(originalCtx context.Context, data io.Reader) chan []b
return archiveChan
}

// decompressorInfo bundles the state needed to process a single-file
// compressed stream (see handleDecompressor).
type decompressorInfo struct {
	// depth is the archive nesting level at which the stream was found.
	depth int
	// reader supplies the still-compressed bytes.
	reader io.Reader
	// archiveChan receives the bytes of each extracted file.
	archiveChan chan []byte
	// archiver is the format-specific decompressor identified for the stream.
	archiver archiver.Decompressor
}

// openArchive takes a reader and extracts the contents up to the maximum depth.
func (a *Archive) openArchive(ctx context.Context, depth int, reader io.Reader, archiveChan chan []byte) error {
if depth >= maxDepth {
return fmt.Errorf("max archive depth reached")
return fmt.Errorf(errMaxArchiveDepthReached)
}
format, reader, err := archiver.Identify("", reader)
if err != nil {
if errors.Is(err, archiver.ErrNoMatch) && depth > 0 {
chunkSize := 10 * 1024
for {
chunk := make([]byte, chunkSize)
n, _ := reader.Read(chunk)
archiveChan <- chunk
if n < chunkSize {
break
}
}
return nil
}
return err

format, arReader, err := archiver.Identify("", reader)
if errors.Is(err, archiver.ErrNoMatch) && depth > 0 {
return a.handleNonArchiveContent(ctx, arReader, archiveChan)
}

switch archive := format.(type) {
case archiver.Decompressor:
compReader, err := archive.OpenReader(reader)
if err != nil {
return err
}
fileBytes, err := a.ReadToMax(ctx, compReader)
if err != nil {
return err
}
newReader := bytes.NewReader(fileBytes)
return a.openArchive(ctx, depth+1, newReader, archiveChan)
info := decompressorInfo{depth: depth, reader: arReader, archiveChan: archiveChan, archiver: archive}
return a.handleDecompressor(ctx, info)
case archiver.Extractor:
err := archive.Extract(context.WithValue(ctx, depthKey, depth+1), reader, nil, a.extractorHandler(archiveChan))
if err != nil {
return archive.Extract(context.WithValue(ctx, depthKey, depth+1), reader, nil, a.extractorHandler(archiveChan))
default:
return fmt.Errorf("unknown archive type: %s", format.Name())
}
}

// handleNonArchiveContent reads non-archive data with the common chunk
// reader and forwards each chunk to archiveChan.
//
// A chunk-level read error is logged and the chunk skipped, so one bad read
// does not abort the rest of the stream; a cancelled write to archiveChan
// aborts and returns the write error.
func (a *Archive) handleNonArchiveContent(ctx context.Context, reader io.Reader, archiveChan chan []byte) error {
	aCtx := logContext.AddLogger(ctx)
	chunkReader := sources.NewChunkReader()
	chunkResChan := chunkReader(aCtx, reader)
	for data := range chunkResChan {
		if err := data.Error(); err != nil {
			aCtx.Logger().Error(err, "error reading chunk")
			continue
		}
		if err := common.CancellableWrite(ctx, archiveChan, data.Bytes()); err != nil {
			return err
		}
	}
	return nil
}

// handleDecompressor decompresses a single-file compressed stream, buffers
// the result (bounded by ReadToMax), and feeds it back through openArchive
// one nesting level deeper, since the payload may itself be an archive.
func (a *Archive) handleDecompressor(ctx context.Context, info decompressorInfo) error {
	decompressed, err := info.archiver.OpenReader(info.reader)
	if err != nil {
		return err
	}
	contents, err := a.ReadToMax(ctx, decompressed)
	if err != nil {
		return err
	}
	return a.openArchive(ctx, info.depth+1, bytes.NewReader(contents), info.archiveChan)
}

// IsFiletype returns true if the provided reader is an archive.
func (a *Archive) IsFiletype(ctx context.Context, reader io.Reader) (io.Reader, bool) {
func (a *Archive) IsFiletype(_ context.Context, reader io.Reader) (io.Reader, bool) {
format, readerB, err := archiver.Identify("", reader)
if err != nil {
return readerB, false
Expand All @@ -135,8 +153,9 @@ func (a *Archive) IsFiletype(ctx context.Context, reader io.Reader) (io.Reader,
return readerB, true
case archiver.Decompressor:
return readerB, true
default:
return readerB, false
}
return readerB, false
}

// extractorHandler is applied to each file in an archiver.Extractor file.
Expand Down Expand Up @@ -180,7 +199,7 @@ func (a *Archive) ReadToMax(ctx context.Context, reader io.Reader) (data []byte,
if e, ok := r.(error); ok {
err = e
} else {
err = fmt.Errorf("Panic occurred: %v", r)
err = fmt.Errorf("panic occurred: %v", r)
}
logger.Error(err, "Panic occurred when reading archive")
}
Expand Down Expand Up @@ -283,7 +302,7 @@ func (a *Archive) HandleSpecialized(ctx logContext.Context, reader io.Reader) (i
// The caller is responsible for closing the returned reader.
func (a *Archive) extractDebContent(ctx logContext.Context, file io.Reader) (io.ReadCloser, error) {
if a.currentDepth >= maxDepth {
return nil, fmt.Errorf("max archive depth reached")
return nil, fmt.Errorf(errMaxArchiveDepthReached)
}

tmpEnv, err := a.createTempEnv(ctx, file)
Expand Down

0 comments on commit f6512ac

Please sign in to comment.