diff --git a/pkg/cmd/config.go b/pkg/cmd/config.go index 077aff4..ef3c9e8 100644 --- a/pkg/cmd/config.go +++ b/pkg/cmd/config.go @@ -11,6 +11,8 @@ import ( "github.com/stefanoj3/dirstalk/pkg/scan" ) +const failedToReadPropertyError = "failed to read %s" + func scanConfigFromCmd(cmd *cobra.Command) (*scan.Config, error) { c := &scan.Config{} @@ -18,28 +20,32 @@ func scanConfigFromCmd(cmd *cobra.Command) (*scan.Config, error) { c.DictionaryPath = cmd.Flag(flagScanDictionary).Value.String() + if c.DictionaryTimeoutInMilliseconds, err = cmd.Flags().GetInt(flagScanDictionaryGetTimeout); err != nil { + return nil, errors.Wrapf(err, failedToReadPropertyError, flagScanDictionaryGetTimeout) + } + if c.HTTPMethods, err = cmd.Flags().GetStringSlice(flagScanHTTPMethods); err != nil { - return nil, errors.Wrapf(err, "failed to read %s", flagScanHTTPMethods) + return nil, errors.Wrapf(err, failedToReadPropertyError, flagScanHTTPMethods) } if c.HTTPStatusesToIgnore, err = cmd.Flags().GetIntSlice(flagScanHTTPStatusesToIgnore); err != nil { - return nil, errors.Wrapf(err, "failed to read %s", flagScanHTTPStatusesToIgnore) + return nil, errors.Wrapf(err, failedToReadPropertyError, flagScanHTTPStatusesToIgnore) } if c.Threads, err = cmd.Flags().GetInt(flagScanThreads); err != nil { - return nil, errors.Wrapf(err, "failed to read %s", flagScanThreads) + return nil, errors.Wrapf(err, failedToReadPropertyError, flagScanThreads) } if c.TimeoutInMilliseconds, err = cmd.Flags().GetInt(flagScanHTTPTimeout); err != nil { - return nil, errors.Wrapf(err, "failed to read %s", flagScanHTTPTimeout) + return nil, errors.Wrapf(err, failedToReadPropertyError, flagScanHTTPTimeout) } if c.CacheRequests, err = cmd.Flags().GetBool(flagScanHTTPCacheRequests); err != nil { - return nil, errors.Wrapf(err, "failed to read %s", flagScanHTTPCacheRequests) + return nil, errors.Wrapf(err, failedToReadPropertyError, flagScanHTTPCacheRequests) } if c.ScanDepth, err = cmd.Flags().GetInt(flagScanScanDepth); err != nil { - return nil, errors.Wrapf(err, "failed to read %s", flagScanScanDepth) + return nil, errors.Wrapf(err, failedToReadPropertyError, flagScanScanDepth) } socks5Host := cmd.Flag(flagScanSocks5Host).Value.String() @@ -52,12 +58,12 @@ func scanConfigFromCmd(cmd *cobra.Command) (*scan.Config, error) { c.UserAgent = cmd.Flag(flagScanUserAgent).Value.String() if c.UseCookieJar, err = cmd.Flags().GetBool(flagScanCookieJar); err != nil { - return nil, errors.Wrapf(err, "failed to read %s", flagScanCookieJar) + return nil, errors.Wrapf(err, failedToReadPropertyError, flagScanCookieJar) } rawCookies, err := cmd.Flags().GetStringArray(flagScanCookie) if err != nil { - return nil, errors.Wrapf(err, "failed to read %s", flagScanCookie) + return nil, errors.Wrapf(err, failedToReadPropertyError, flagScanCookie) } if c.Cookies, err = rawCookiesToCookies(rawCookies); err != nil { @@ -66,7 +72,7 @@ func scanConfigFromCmd(cmd *cobra.Command) (*scan.Config, error) { rawHeaders, err := cmd.Flags().GetStringArray(flagScanHeader) if err != nil { - return nil, errors.Wrapf(err, "failed to read %s", flagScanHeader) + return nil, errors.Wrapf(err, failedToReadPropertyError, flagScanHeader) } if c.Headers, err = rawHeadersToHeaders(rawHeaders); err != nil { @@ -77,7 +83,7 @@ func scanConfigFromCmd(cmd *cobra.Command) (*scan.Config, error) { c.ShouldSkipSSLCertificatesValidation, err = cmd.Flags().GetBool(flagShouldSkipSSLCertificatesValidation) if err != nil { - return nil, errors.Wrapf(err, "failed to read %s", flagShouldSkipSSLCertificatesValidation) + return nil, errors.Wrapf(err, failedToReadPropertyError, flagShouldSkipSSLCertificatesValidation) } return c, nil diff --git a/pkg/cmd/flags.go b/pkg/cmd/flags.go index d10cd05..3888ec9 100644 --- a/pkg/cmd/flags.go +++ b/pkg/cmd/flags.go @@ -8,6 +8,7 @@ const ( // Scan flags flagScanDictionary = "dictionary" flagScanDictionaryShort = "d" + flagScanDictionaryGetTimeout = "dictionary-get-timeout" flagScanHTTPMethods = "http-methods" flagScanHTTPStatusesToIgnore = "http-statuses-to-ignore" flagScanHTTPTimeout = "http-timeout" diff --git a/pkg/cmd/scan.go b/pkg/cmd/scan.go index 870ec69..50ceb26 100644 --- a/pkg/cmd/scan.go +++ b/pkg/cmd/scan.go @@ -39,6 +39,13 @@ func NewScanCommand(logger *logrus.Logger) *cobra.Command { common.Must(cmd.MarkFlagFilename(flagScanDictionary)) common.Must(cmd.MarkFlagRequired(flagScanDictionary)) + cmd.Flags().IntP( + flagScanDictionaryGetTimeout, + "", + 50000, + "timeout in milliseconds (used when fetching remote dictionary)", + ) + cmd.Flags().StringSlice( flagScanHTTPMethods, []string{"GET"}, @@ -164,39 +171,16 @@ func getURL(args []string) (*url.URL, error) { // startScan is a convenience method that wires together all the dependencies needed to start a scan func startScan(logger *logrus.Logger, cnf *scan.Config, u *url.URL) error { - c, err := client.NewClientFromConfig( - cnf.TimeoutInMilliseconds, - cnf.Socks5Url, - cnf.UserAgent, - cnf.UseCookieJar, - cnf.Cookies, - cnf.Headers, - cnf.CacheRequests, - cnf.ShouldSkipSSLCertificatesValidation, - u, - ) + dict, err := buildDictionary(cnf, u) if err != nil { - return errors.Wrap(err, "failed to build client") + return err } - dict, err := dictionary.NewDictionaryFrom(cnf.DictionaryPath, c) + s, err := buildScanner(cnf, dict, u, logger) if err != nil { - return errors.Wrap(err, "failed to build dictionary") + return err } - targetProducer := producer.NewDictionaryProducer(cnf.HTTPMethods, dict, cnf.ScanDepth) - reproducer := producer.NewReProducer(targetProducer) - - resultFilter := filter.NewHTTPStatusResultFilter(cnf.HTTPStatusesToIgnore) - - s := scan.NewScanner( - c, - targetProducer, - reproducer, - resultFilter, - logger, - ) - logger.WithFields(logrus.Fields{ "url": u.String(), "threads": cnf.Threads, @@ -265,6 +249,80 @@ func startScan(logger *logrus.Logger, cnf *scan.Config, u *url.URL) error { } } +func buildScanner(cnf *scan.Config, dict []string, u *url.URL, logger *logrus.Logger) (*scan.Scanner, error) { + targetProducer := producer.NewDictionaryProducer(cnf.HTTPMethods, dict, cnf.ScanDepth) + reproducer := producer.NewReProducer(targetProducer) + + resultFilter := filter.NewHTTPStatusResultFilter(cnf.HTTPStatusesToIgnore) + + scannerClient, err := buildScannerClient(cnf, u) + if err != nil { + return nil, err + } + + s := scan.NewScanner( + scannerClient, + targetProducer, + reproducer, + resultFilter, + logger, + ) + + return s, nil +} + +func buildDictionary(cnf *scan.Config, u *url.URL) ([]string, error) { + c, err := buildDictionaryClient(cnf, u) + if err != nil { + return nil, err + } + + dict, err := dictionary.NewDictionaryFrom(cnf.DictionaryPath, c) + if err != nil { + return nil, errors.Wrap(err, "failed to build dictionary") + } + + return dict, nil +} + +func buildScannerClient(cnf *scan.Config, u *url.URL) (*http.Client, error) { + c, err := client.NewClientFromConfig( + cnf.TimeoutInMilliseconds, + cnf.Socks5Url, + cnf.UserAgent, + cnf.UseCookieJar, + cnf.Cookies, + cnf.Headers, + cnf.CacheRequests, + cnf.ShouldSkipSSLCertificatesValidation, + u, + ) + if err != nil { + return nil, errors.Wrap(err, "failed to build scanner client") + } + + return c, nil +} + +func buildDictionaryClient(cnf *scan.Config, u *url.URL) (*http.Client, error) { + c, err := client.NewClientFromConfig( + cnf.DictionaryTimeoutInMilliseconds, + cnf.Socks5Url, + cnf.UserAgent, + cnf.UseCookieJar, + cnf.Cookies, + cnf.Headers, + cnf.CacheRequests, + cnf.ShouldSkipSSLCertificatesValidation, + u, + ) + if err != nil { + return nil, errors.Wrap(err, "failed to build dictionary client") + } + + return c, nil +} + func newOutputSaver(path string) (OutputSaver, error) { if path == "" { return output.NewNullSaver(), nil diff --git a/pkg/cmd/scan_integration_test.go b/pkg/cmd/scan_integration_test.go index 40718b3..334fcf6 100644 --- a/pkg/cmd/scan_integration_test.go +++ b/pkg/cmd/scan_integration_test.go @@ -760,3 +760,79 @@ func startSocks5TestServer(t *testing.T) net.Listener { return listener } + +func TestScanShouldFailIfDictionaryFetchExceedTimeout(t *testing.T) { + logger, _ := test.NewLogger() + + c := createCommand(logger) + assert.NotNil(t, c) + + testServer, serverAssertion := test.NewServerWithAssertion( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + }), + ) + defer testServer.Close() + + dictionaryTestServer, _ := test.NewServerWithAssertion( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(time.Second) + w.Write([]byte("/dictionary/entry")) //nolint + }), + ) + defer dictionaryTestServer.Close() + + err := executeCommand( + c, + "scan", + testServer.URL, + "--dictionary", + dictionaryTestServer.URL, + "--dictionary-get-timeout", + "5", + ) + assert.Error(t, err) + + assert.Contains(t, err.Error(), "dictionary: failed to get") + assert.Contains(t, err.Error(), "Client.Timeout exceeded while awaiting headers") + + assert.Equal(t, 0, serverAssertion.Len()) +} + +func TestScanShouldBeAbleToFetchRemoteDictionary(t *testing.T) { + logger, _ := test.NewLogger() + + c := createCommand(logger) + assert.NotNil(t, c) + + testServer, serverAssertion := test.NewServerWithAssertion( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + }), + ) + defer testServer.Close() + + dictionaryTestServer, dictionaryTestServerAssertion := test.NewServerWithAssertion( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("/dictionary/entry")) //nolint + }), + ) + defer dictionaryTestServer.Close() + + err := executeCommand( + c, + "scan", + testServer.URL, + "--dictionary", + dictionaryTestServer.URL, + "--dictionary-get-timeout", + "500", + ) + assert.NoError(t, err) + + assert.Equal(t, 1, dictionaryTestServerAssertion.Len()) + assert.Equal(t, 1, serverAssertion.Len()) + + serverAssertion.At(0, func(r http.Request) { + assert.Equal(t, "/dictionary/entry", r.URL.Path) + }) +} diff --git a/pkg/scan/config.go b/pkg/scan/config.go index ce1ed53..7e7bb2c 100644 --- a/pkg/scan/config.go +++ b/pkg/scan/config.go @@ -8,6 +8,7 @@ import ( // Config represents the configuration needed to perform a scan type Config struct { DictionaryPath string + DictionaryTimeoutInMilliseconds int HTTPMethods []string HTTPStatusesToIgnore []int Threads int