diff --git a/main.go b/main.go index 801e30e8a4e6..452dd8ad9b75 100644 --- a/main.go +++ b/main.go @@ -17,9 +17,11 @@ import ( "github.com/gorilla/mux" "github.com/jpillora/overseer" "github.com/sirupsen/logrus" + "gopkg.in/alecthomas/kingpin.v2" + + "github.com/trufflesecurity/trufflehog/v3/pkg/sources" "github.com/trufflesecurity/trufflehog/v3/pkg/updater" "github.com/trufflesecurity/trufflehog/v3/pkg/version" - "gopkg.in/alecthomas/kingpin.v2" "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/decoders" @@ -189,36 +191,75 @@ func run(state overseer.State) { if remote { defer os.RemoveAll(repoPath) } - err = e.ScanGit(ctx, repoPath, *gitScanBranch, *gitScanSinceCommit, *gitScanMaxDepth, filter) - if err != nil { - logrus.WithError(err).Fatal("Failed to scan git.") + + g := func(c *sources.Config) { + c.RepoPath = repoPath + c.HeadRef = *gitScanBranch + c.BaseRef = *gitScanSinceCommit + c.MaxDepth = *gitScanMaxDepth + c.Filter = filter + } + + if err = e.ScanGit(ctx, sources.NewConfig(g)); err != nil { + logrus.WithError(err).Fatal("Failed to scan Git.") } case githubScan.FullCommand(): if len(*githubScanOrgs) == 0 && len(*githubScanRepos) == 0 { log.Fatal("You must specify at least one organization or repository.") } - err = e.ScanGitHub(ctx, *githubScanEndpoint, *githubScanRepos, *githubScanOrgs, *githubScanToken, *githubIncludeForks, filter, *concurrency, *githubIncludeMembers) - if err != nil { - logrus.WithError(err).Fatal("Failed to scan git.") + + github := func(c *sources.Config) { + c.Endpoint = *githubScanEndpoint + c.Repos = *githubScanRepos + c.Orgs = *githubScanOrgs + c.Token = *githubScanToken + c.IncludeForks = *githubIncludeForks + c.IncludeMembers = *githubIncludeMembers + c.Concurrency = *concurrency + } + + if err = e.ScanGitHub(ctx, sources.NewConfig(github)); err != nil { + logrus.WithError(err).Fatal("Failed to scan Github.") } case gitlabScan.FullCommand(): - err := e.ScanGitLab(ctx, *gitlabScanEndpoint, *gitlabScanToken, *gitlabScanRepos) - if err != nil { + gitlab := func(c *sources.Config) { + c.Endpoint = *gitlabScanEndpoint + c.Token = *gitlabScanToken + c.Repos = *gitlabScanRepos + } + + if err = e.ScanGitLab(ctx, sources.NewConfig(gitlab)); err != nil { logrus.WithError(err).Fatal("Failed to scan GitLab.") } case filesystemScan.FullCommand(): - err := e.ScanFileSystem(ctx, *filesystemDirectories) - if err != nil { + fs := func(c *sources.Config) { + c.Directories = *filesystemDirectories + } + + if err = e.ScanFileSystem(ctx, sources.NewConfig(fs)); err != nil { logrus.WithError(err).Fatal("Failed to scan filesystem") } case s3Scan.FullCommand(): - err := e.ScanS3(ctx, *s3ScanKey, *s3ScanSecret, *s3ScanCloudEnv, *s3ScanBuckets) - if err != nil { + s3 := func(c *sources.Config) { + c.Key = *s3ScanKey + c.Secret = *s3ScanSecret + c.Buckets = *s3ScanBuckets + } + + if err = e.ScanS3(ctx, sources.NewConfig(s3)); err != nil { logrus.WithError(err).Fatal("Failed to scan S3.") } case syslogScan.FullCommand(): - err := e.ScanSyslog(ctx, *syslogAddress, *syslogProtocol, *syslogTLSCert, *syslogTLSKey, *syslogFormat, *concurrency) - if err != nil { + syslog := func(c *sources.Config) { + c.Address = *syslogAddress + c.Protocol = *syslogProtocol + c.CertPath = *syslogTLSCert + c.KeyPath = *syslogTLSKey + c.Format = *syslogFormat + c.Concurrency = *concurrency + } + + if err = e.ScanSyslog(ctx, sources.NewConfig(syslog)); err != nil { logrus.WithError(err).Fatal("Failed to scan syslog.") } } diff --git a/pkg/engine/filesystem.go b/pkg/engine/filesystem.go index e245183c1ac8..5b3f3e46ef9f 100644 --- a/pkg/engine/filesystem.go +++ b/pkg/engine/filesystem.go @@ -6,15 +6,22 @@ import ( "github.com/go-errors/errors" "github.com/sirupsen/logrus" - "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" - "github.com/trufflesecurity/trufflehog/v3/pkg/sources/filesystem" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" + + "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" + "github.com/trufflesecurity/trufflehog/v3/pkg/sources" + "github.com/trufflesecurity/trufflehog/v3/pkg/sources/filesystem" ) -func (e *Engine) ScanFileSystem(ctx context.Context, directories []string) error { +// ScanFileSystem scans a given file system. +func (e *Engine) ScanFileSystem(ctx context.Context, c *sources.Config) error { + if c == nil { + return errors.New("nil config provided to ScanFileSystem") + } + connection := &sourcespb.Filesystem{ - Directories: directories, + Directories: c.Directories, } var conn anypb.Any err := anypb.MarshalFrom(&conn, connection, proto.MarshalOptions{}) diff --git a/pkg/engine/git.go b/pkg/engine/git.go index 2f434334a50e..f18ac416d8e9 100644 --- a/pkg/engine/git.go +++ b/pkg/engine/git.go @@ -11,34 +11,39 @@ import ( "github.com/go-git/go-git/v5/plumbing/object" "github.com/sirupsen/logrus" - "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" + "github.com/trufflesecurity/trufflehog/v3/pkg/sources" "github.com/trufflesecurity/trufflehog/v3/pkg/sources/git" ) -func (e *Engine) ScanGit(ctx context.Context, repoPath, headRef, baseRef string, maxDepth int, filter *common.Filter) error { +// ScanGit scans any git source. +func (e *Engine) ScanGit(ctx context.Context, c *sources.Config) error { + if c == nil { + return errors.New("nil config for ScanGit") + } + logOptions := &gogit.LogOptions{} opts := []git.ScanOption{ - git.ScanOptionFilter(filter), + git.ScanOptionFilter(c.Filter), git.ScanOptionLogOptions(logOptions), } - repo, err := gogit.PlainOpenWithOptions(repoPath, &gogit.PlainOpenOptions{DetectDotGit: true}) + repo, err := gogit.PlainOpenWithOptions(c.RepoPath, &gogit.PlainOpenOptions{DetectDotGit: true}) if err != nil { - return fmt.Errorf("could open repo: %s: %w", repoPath, err) + return fmt.Errorf("could open repo: %s: %w", c.RepoPath, err) } var baseCommit *object.Commit - if len(baseRef) > 0 { - baseHash := plumbing.NewHash(baseRef) - if !plumbing.IsHash(baseRef) { - base, err := git.TryAdditionalBaseRefs(repo, baseRef) + if len(c.BaseRef) > 0 { + baseHash := plumbing.NewHash(c.BaseRef) + if !plumbing.IsHash(c.BaseRef) { + base, err := git.TryAdditionalBaseRefs(repo, c.BaseRef) if err != nil { return errors.WrapPrefix(err, "unable to resolve base ref", 0) } else { - baseRef = base.String() - baseCommit, _ = repo.CommitObject(plumbing.NewHash(baseRef)) + c.BaseRef = base.String() + baseCommit, _ = repo.CommitObject(plumbing.NewHash(c.BaseRef)) } } else { baseCommit, err = repo.CommitObject(baseHash) @@ -49,15 +54,15 @@ func (e *Engine) ScanGit(ctx context.Context, repoPath, headRef, baseRef string, } var headCommit *object.Commit - if len(headRef) > 0 { - headHash := plumbing.NewHash(headRef) - if !plumbing.IsHash(headRef) { - head, err := git.TryAdditionalBaseRefs(repo, headRef) + if len(c.HeadRef) > 0 { + headHash := plumbing.NewHash(c.HeadRef) + if !plumbing.IsHash(c.HeadRef) { + head, err := git.TryAdditionalBaseRefs(repo, c.HeadRef) if err != nil { return errors.WrapPrefix(err, "unable to resolve head ref", 0) } else { - headRef = head.String() - headCommit, _ = repo.CommitObject(plumbing.NewHash(baseRef)) + c.HeadRef = head.String() + headCommit, _ = repo.CommitObject(plumbing.NewHash(c.BaseRef)) } } else { headCommit, err = repo.CommitObject(headHash) @@ -67,23 +72,23 @@ func (e *Engine) ScanGit(ctx context.Context, repoPath, headRef, baseRef string, } } - // If baseCommit is an ancestor of headCommit, update baseRef to be the common ancestor. + // If baseCommit is an ancestor of headCommit, update c.BaseRef to be the common ancestor. if headCommit != nil && baseCommit != nil { mergeBase, err := headCommit.MergeBase(baseCommit) if err != nil || len(mergeBase) < 1 { return errors.WrapPrefix(err, "could not find common base between the given references", 0) } - baseRef = mergeBase[0].Hash.String() + c.BaseRef = mergeBase[0].Hash.String() } - if maxDepth != 0 { - opts = append(opts, git.ScanOptionMaxDepth(int64(maxDepth))) + if c.MaxDepth != 0 { + opts = append(opts, git.ScanOptionMaxDepth(int64(c.MaxDepth))) } - if baseRef != "" { - opts = append(opts, git.ScanOptionBaseHash(baseRef)) + if c.BaseRef != "" { + opts = append(opts, git.ScanOptionBaseHash(c.BaseRef)) } - if headRef != "" { - opts = append(opts, git.ScanOptionHeadCommit(headRef)) + if c.HeadRef != "" { + opts = append(opts, git.ScanOptionHeadCommit(c.HeadRef)) } scanOptions := git.NewScanOptions(opts...) @@ -106,7 +111,7 @@ func (e *Engine) ScanGit(ctx context.Context, repoPath, headRef, baseRef string, e.sourcesWg.Add(1) go func() { defer e.sourcesWg.Done() - err := gitSource.ScanRepo(ctx, repo, repoPath, scanOptions, e.ChunksChan()) + err := gitSource.ScanRepo(ctx, repo, c.RepoPath, scanOptions, e.ChunksChan()) if err != nil { logrus.WithError(err).Fatal("could not scan repo") } diff --git a/pkg/engine/git_test.go b/pkg/engine/git_test.go index ce0035e8a602..8ef368019e13 100644 --- a/pkg/engine/git_test.go +++ b/pkg/engine/git_test.go @@ -8,6 +8,7 @@ import ( "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/decoders" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb" + "github.com/trufflesecurity/trufflehog/v3/pkg/sources" "github.com/trufflesecurity/trufflehog/v3/pkg/sources/git" ) @@ -56,7 +57,14 @@ func TestGitEngine(t *testing.T) { WithDecoders(decoders.DefaultDecoders()...), WithDetectors(false, DefaultDetectors()...), ) - if err := e.ScanGit(ctx, path, tTest.branch, tTest.base, tTest.maxDepth, tTest.filter); err != nil { + cfg := sources.Config{ + RepoPath: path, + HeadRef: tTest.branch, + BaseRef: tTest.base, + MaxDepth: tTest.maxDepth, + Filter: tTest.filter, + } + if err := e.ScanGit(ctx, &cfg); err != nil { return } go e.Finish() @@ -106,7 +114,11 @@ func BenchmarkGitEngine(b *testing.B) { for i := 0; i < b.N; i++ { // TODO: this is measuring the time it takes to initialize the source // and not to do the full scan - if err := e.ScanGit(ctx, path, "", "", 0, common.FilterEmpty()); err != nil { + cfg := sources.Config{ + RepoPath: path, + Filter: common.FilterEmpty(), + } + if err := e.ScanGit(ctx, &cfg); err != nil { return } } diff --git a/pkg/engine/github.go b/pkg/engine/github.go index c3ebe9b5f40d..4a8b3e434ceb 100644 --- a/pkg/engine/github.go +++ b/pkg/engine/github.go @@ -3,38 +3,44 @@ package engine import ( "context" + "github.com/go-errors/errors" "github.com/sirupsen/logrus" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" - "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" + "github.com/trufflesecurity/trufflehog/v3/pkg/sources" "github.com/trufflesecurity/trufflehog/v3/pkg/sources/github" ) -func (e *Engine) ScanGitHub(ctx context.Context, endpoint string, repos, orgs []string, token string, includeForks bool, filter *common.Filter, concurrency int, includeMembers bool) error { +// ScanGitHub scans Github with the provided options. +func (e *Engine) ScanGitHub(ctx context.Context, c *sources.Config) error { + if c == nil { + return errors.New("nil config provided for ScanGitHub") + } + source := github.Source{} connection := sourcespb.GitHub{ - Endpoint: endpoint, - Organizations: orgs, - Repositories: repos, - ScanUsers: includeMembers, + Endpoint: c.Endpoint, + Organizations: c.Orgs, + Repositories: c.Repos, + ScanUsers: c.IncludeMembers, } - if len(token) > 0 { + if len(c.Token) > 0 { connection.Credential = &sourcespb.GitHub_Token{ - Token: token, + Token: c.Token, } } else { connection.Credential = &sourcespb.GitHub_Unauthenticated{} } - connection.IncludeForks = includeForks + connection.IncludeForks = c.IncludeForks var conn anypb.Any err := anypb.MarshalFrom(&conn, &connection, proto.MarshalOptions{}) if err != nil { logrus.WithError(err).Error("failed to marshal github connection") return err } - err = source.Init(ctx, "trufflehog - github", 0, 0, false, &conn, concurrency) + err = source.Init(ctx, "trufflehog - github", 0, 0, false, &conn, c.Concurrency) if err != nil { logrus.WithError(err).Error("failed to initialize github source") return err diff --git a/pkg/engine/gitlab.go b/pkg/engine/gitlab.go index ea89be94d326..fd508fdb1f9a 100644 --- a/pkg/engine/gitlab.go +++ b/pkg/engine/gitlab.go @@ -6,31 +6,38 @@ import ( "github.com/go-errors/errors" "github.com/sirupsen/logrus" - "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" - "github.com/trufflesecurity/trufflehog/v3/pkg/sources/gitlab" "golang.org/x/net/context" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" + + "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" + "github.com/trufflesecurity/trufflehog/v3/pkg/sources" + "github.com/trufflesecurity/trufflehog/v3/pkg/sources/gitlab" ) -func (e *Engine) ScanGitLab(ctx context.Context, endpoint, token string, repositories []string) error { +// ScanGitLab scans GitLab with the provided configuration. +func (e *Engine) ScanGitLab(ctx context.Context, c *sources.Config) error { + if c == nil { + return errors.New("config is nil for ScanGitlab") + } + connection := &sourcespb.GitLab{} switch { - case len(token) > 0: + case len(c.Token) > 0: connection.Credential = &sourcespb.GitLab_Token{ - Token: token, + Token: c.Token, } default: return fmt.Errorf("must provide token") } - if len(endpoint) > 0 { - connection.Endpoint = endpoint + if len(c.Endpoint) > 0 { + connection.Endpoint = c.Endpoint } - if len(repositories) > 0 { - connection.Repositories = repositories + if len(c.Repos) > 0 { + connection.Repositories = c.Repos } var conn anypb.Any diff --git a/pkg/engine/s3.go b/pkg/engine/s3.go index 84d31bf5c4d5..de25b8f39ab5 100644 --- a/pkg/engine/s3.go +++ b/pkg/engine/s3.go @@ -7,33 +7,40 @@ import ( "github.com/go-errors/errors" "github.com/sirupsen/logrus" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/anypb" + "github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" + "github.com/trufflesecurity/trufflehog/v3/pkg/sources" "github.com/trufflesecurity/trufflehog/v3/pkg/sources/s3" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/anypb" ) -func (e *Engine) ScanS3(ctx context.Context, key, secret string, cloudCred bool, buckets []string) error { +// ScanS3 scans S3 buckets. +func (e *Engine) ScanS3(ctx context.Context, c *sources.Config) error { + if c == nil { + return errors.New("nil config provided for ScanS3") + } + connection := &sourcespb.S3{ Credential: &sourcespb.S3_Unauthenticated{}, } - if cloudCred { - if len(key) > 0 || len(secret) > 0 { + if c.CloudCred { + if len(c.Key) > 0 || len(c.Secret) > 0 { return fmt.Errorf("cannot use cloud credentials and basic auth together") } connection.Credential = &sourcespb.S3_CloudEnvironment{} } - if len(key) > 0 && len(secret) > 0 { + if len(c.Key) > 0 && len(c.Secret) > 0 { connection.Credential = &sourcespb.S3_AccessKey{ AccessKey: &credentialspb.KeySecret{ - Key: key, - Secret: secret, + Key: c.Key, + Secret: c.Secret, }, } } - if len(buckets) > 0 { - connection.Buckets = buckets + if len(c.Buckets) > 0 { + connection.Buckets = c.Buckets } var conn anypb.Any err := anypb.MarshalFrom(&conn, connection, proto.MarshalOptions{}) diff --git a/pkg/engine/syslog.go b/pkg/engine/syslog.go index 92b011ae4b6b..8fc874d34d65 100644 --- a/pkg/engine/syslog.go +++ b/pkg/engine/syslog.go @@ -10,24 +10,30 @@ import ( "google.golang.org/protobuf/types/known/anypb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" + "github.com/trufflesecurity/trufflehog/v3/pkg/sources" "github.com/trufflesecurity/trufflehog/v3/pkg/sources/syslog" ) -func (e *Engine) ScanSyslog(ctx context.Context, address, protocol, certPath, keyPath, format string, concurrency int) error { +// ScanSyslog is a source that scans syslog files. +func (e *Engine) ScanSyslog(ctx context.Context, c *sources.Config) error { + if c == nil { + return errors.New("nil config provided for ScanSyslog") + } + connection := &sourcespb.Syslog{ - Protocol: protocol, - ListenAddress: address, - Format: format, + Protocol: c.Protocol, + ListenAddress: c.Address, + Format: c.Format, } - if certPath != "" && keyPath != "" { - cert, err := os.ReadFile(certPath) + if c.CertPath != "" && c.KeyPath != "" { + cert, err := os.ReadFile(c.CertPath) if err != nil { return errors.WrapPrefix(err, "could not open TLS cert file", 0) } connection.TlsCert = string(cert) - key, err := os.ReadFile(keyPath) + key, err := os.ReadFile(c.KeyPath) if err != nil { return errors.WrapPrefix(err, "could not open TLS key file", 0) } @@ -40,7 +46,7 @@ func (e *Engine) ScanSyslog(ctx context.Context, address, protocol, certPath, ke return errors.WrapPrefix(err, "error unmarshalling connection", 0) } source := syslog.Source{} - err = source.Init(ctx, "trufflehog - syslog", 0, 0, false, &conn, concurrency) + err = source.Init(ctx, "trufflehog - syslog", 0, 0, false, &conn, c.Concurrency) source.InjectConnection(connection) if err != nil { logrus.WithError(err).Error("failed to initialize syslog source") diff --git a/pkg/sources/filesystem/filesystem_test.go b/pkg/sources/filesystem/filesystem_test.go index ce4b1b8bb572..21c39427bda1 100644 --- a/pkg/sources/filesystem/filesystem_test.go +++ b/pkg/sources/filesystem/filesystem_test.go @@ -6,9 +6,8 @@ import ( "time" "github.com/kylelemons/godebug/pretty" - "google.golang.org/protobuf/types/known/anypb" - log "github.com/sirupsen/logrus" + "google.golang.org/protobuf/types/known/anypb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" diff --git a/pkg/sources/github/github_integration_test.go b/pkg/sources/github/github_integration_test.go index 78477c627326..a5ae67f69082 100644 --- a/pkg/sources/github/github_integration_test.go +++ b/pkg/sources/github/github_integration_test.go @@ -39,7 +39,7 @@ func TestSource_Scan(t *testing.T) { // For the personal access token test githubToken := secret.MustGetField("GITHUB_TOKEN") - //For the NEW github app test (+Member enum) + // For the NEW github app test (+Member enum) githubPrivateKeyB64New := secret.MustGetField("GITHUB_PRIVATE_KEY_NEW") githubPrivateKeyBytesNew, err := base64.StdEncoding.DecodeString(githubPrivateKeyB64New) if err != nil { @@ -49,7 +49,7 @@ func TestSource_Scan(t *testing.T) { githubInstallationIDNew := secret.MustGetField("GITHUB_INSTALLATION_ID_NEW") githubAppIDNew := secret.MustGetField("GITHUB_APP_ID_NEW") - //OLD app for breaking app change tests + // OLD app for breaking app change tests // githubPrivateKeyB64 := secret.MustGetField("GITHUB_PRIVATE_KEY") // githubPrivateKeyBytes, err := base64.StdEncoding.DecodeString(githubPrivateKeyB64) // if err != nil { @@ -343,7 +343,7 @@ func TestSource_Scan(t *testing.T) { s := Source{} log.SetLevel(log.DebugLevel) - //uncomment for windows Testing + // uncomment for windows Testing log.SetFormatter(&log.TextFormatter{ForceColors: true}) log.SetOutput(colorable.NewColorableStdout()) @@ -368,7 +368,7 @@ func TestSource_Scan(t *testing.T) { return } }() - if err = common.HandleTestChannel(chunksCh, basicCheckFunc(tt.minOrg, tt.minRepo, tt.wantChunk, &s)); err != nil { + if err = sources.HandleTestChannel(chunksCh, basicCheckFunc(tt.minOrg, tt.minRepo, tt.wantChunk, &s)); err != nil { t.Error(err) } }) @@ -386,7 +386,7 @@ func TestSource_paginateGists(t *testing.T) { if err != nil { t.Fatal(fmt.Errorf("failed to access secret: %v", err)) } - //For the NEW github app test (+Member enum) + // For the NEW github app test (+Member enum) githubPrivateKeyB64New := secret.MustGetField("GITHUB_PRIVATE_KEY_NEW") githubPrivateKeyBytesNew, err := base64.StdEncoding.DecodeString(githubPrivateKeyB64New) if err != nil { @@ -492,7 +492,7 @@ func TestSource_paginateGists(t *testing.T) { s := Source{} log.SetLevel(log.DebugLevel) - //uncomment for windows Testing + // uncomment for windows Testing log.SetFormatter(&log.TextFormatter{ForceColors: true}) log.SetOutput(colorable.NewColorableStdout()) @@ -515,14 +515,14 @@ func TestSource_paginateGists(t *testing.T) { if tt.wantChunk != nil { wantedRepo = tt.wantChunk.SourceMetadata.GetGithub().Repository } - if err = common.HandleTestChannel(chunksCh, gistsCheckFunc(wantedRepo, tt.minRepos, &s)); err != nil { + if err = sources.HandleTestChannel(chunksCh, gistsCheckFunc(wantedRepo, tt.minRepos, &s)); err != nil { t.Error(err) } }) } } -func gistsCheckFunc(expected string, minRepos int, s *Source) common.ChunkFunc { +func gistsCheckFunc(expected string, minRepos int, s *Source) sources.ChunkFunc { return func(chunk *sources.Chunk) error { if minRepos != 0 && minRepos > len(s.repos) { return fmt.Errorf("didn't find enough repos. expected: %d, got :%d", minRepos, len(s.repos)) @@ -539,7 +539,7 @@ func gistsCheckFunc(expected string, minRepos int, s *Source) common.ChunkFunc { } } -func basicCheckFunc(minOrg, minRepo int, wantChunk *sources.Chunk, s *Source) common.ChunkFunc { +func basicCheckFunc(minOrg, minRepo int, wantChunk *sources.Chunk, s *Source) sources.ChunkFunc { return func(chunk *sources.Chunk) error { if minOrg != 0 && minOrg > len(s.orgs) { return fmt.Errorf("incorrect number of orgs. expected at least: %d, got %d", minOrg, len(s.orgs)) @@ -551,7 +551,7 @@ func basicCheckFunc(minOrg, minRepo int, wantChunk *sources.Chunk, s *Source) co if diff := pretty.Compare(chunk.SourceMetadata.GetGithub().Repository, wantChunk.SourceMetadata.GetGithub().Repository); diff == "" { return nil } - return common.MatchError + return sources.MatchError } return nil } diff --git a/pkg/sources/gitlab/gitlab_test.go b/pkg/sources/gitlab/gitlab_test.go index c23087116997..8aa373f2798f 100644 --- a/pkg/sources/gitlab/gitlab_test.go +++ b/pkg/sources/gitlab/gitlab_test.go @@ -8,14 +8,12 @@ import ( "testing" "github.com/kylelemons/godebug/pretty" + log "github.com/sirupsen/logrus" "google.golang.org/protobuf/types/known/anypb" + "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" - - log "github.com/sirupsen/logrus" - - "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/sources" ) diff --git a/pkg/sources/s3/s3_test.go b/pkg/sources/s3/s3_test.go index 993ebfce91d9..5ce7cd2d3aea 100644 --- a/pkg/sources/s3/s3_test.go +++ b/pkg/sources/s3/s3_test.go @@ -11,11 +11,12 @@ import ( "github.com/kylelemons/godebug/pretty" log "github.com/sirupsen/logrus" + "google.golang.org/protobuf/types/known/anypb" + "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/credentialspb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" "github.com/trufflesecurity/trufflehog/v3/pkg/sources" - "google.golang.org/protobuf/types/known/anypb" ) func TestSource_Chunks(t *testing.T) { diff --git a/pkg/sources/sources.go b/pkg/sources/sources.go index eaecae20c658..7e2eb4669443 100644 --- a/pkg/sources/sources.go +++ b/pkg/sources/sources.go @@ -4,9 +4,11 @@ import ( "context" "sync" + "google.golang.org/protobuf/types/known/anypb" + + "github.com/trufflesecurity/trufflehog/v3/pkg/common" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb" "github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb" - "google.golang.org/protobuf/types/known/anypb" ) // Chunk contains data to be decoded and scanned along with context on where it came from. @@ -42,6 +44,66 @@ type Source interface { GetProgress() *Progress } +// Config defines the optional configuration for a source. +type Config struct { + // Endpoint is the endpoint of the source. + Endpoint, + // Repo is the repository to scan. + Repo, + // Token is the token to use to authenticate with the source. + Token, + // Key is any key to use to authenticate with the source. (ex: S3) + Key, + // Secret is any secret to use to authenticate with the source. (ex: S3) + Secret, + // Address used to connect to the source. (ex: syslog) + Address, + // Protocol used to connect to the source. + Protocol, + // CertPath is the path to the certificate to use to connect to the source. + CertPath, + // KeyPath is the path to the key to use to connect to the source. + KeyPath, + // Format is the format used to connect to the source. + Format, + // RepoPath is the path to the repository to scan. + RepoPath, + // HeadRef is the head reference to use to scan from. + HeadRef, + // BaseRef is the base reference to use to scan from. + BaseRef string + // Concurrency is the number of concurrent workers to use to scan the source. + Concurrency, + // MaxDepth is the maximum depth to scan the source. + MaxDepth int + // IncludeForks indicates whether to include forks in the scan. + IncludeForks, + // IncludeMembers indicates whether to include members in the scan. + IncludeMembers, + // CloudCred determines whether to use cloud credentials. + // This can NOT be used with a secret. + CloudCred bool + // Repos is the list of repositories to scan. + Repos, + // Orgs is the list of organizations to scan. + Orgs, + // Buckets is the list of buckets to scan. + Buckets, + // Directories is the list of directories to scan. + Directories []string + // Filter is the filter to use to scan the source. + Filter *common.Filter +} + +// NewConfig returns a new Config with optional values. +func NewConfig(opts ...func(*Config)) *Config { + c := &Config{} + for _, opt := range opts { + opt(c) + } + return c +} + // PercentComplete is used to update job completion percentages across sources type Progress struct { mut sync.Mutex @@ -68,7 +130,7 @@ func (p *Progress) SetProgressComplete(i, scope int, message, encodedResumeInfo p.PercentComplete = int64((float64(i) / float64(scope)) * 100) } -//GetProgressComplete gets job completion percentage for metrics reporting +// GetProgressComplete gets job completion percentage for metrics reporting func (p *Progress) GetProgress() *Progress { p.mut.Lock() defer p.mut.Unlock() diff --git a/pkg/common/test_helpers.go b/pkg/sources/test_helpers.go similarity index 66% rename from pkg/common/test_helpers.go rename to pkg/sources/test_helpers.go index e08e08effba2..967b2fd0137d 100644 --- a/pkg/common/test_helpers.go +++ b/pkg/sources/test_helpers.go @@ -1,18 +1,16 @@ -package common +package sources import ( "errors" "fmt" "time" - - "github.com/trufflesecurity/trufflehog/v3/pkg/sources" ) -type ChunkFunc func(chunk *sources.Chunk) error +type ChunkFunc func(chunk *Chunk) error var MatchError = errors.New("chunk doesn't match") -func HandleTestChannel(chunksCh chan *sources.Chunk, cf ChunkFunc) error { +func HandleTestChannel(chunksCh chan *Chunk, cf ChunkFunc) error { for { select { case gotChunk := <-chunksCh: