Skip to content

Commit

Permalink
Adding flag to specify dictionary fetch timeout separately
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanoj3 committed Dec 15, 2019
1 parent 1072117 commit 10f7a63
Show file tree
Hide file tree
Showing 5 changed files with 179 additions and 37 deletions.
26 changes: 16 additions & 10 deletions pkg/cmd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,41 @@ import (
"github.com/stefanoj3/dirstalk/pkg/scan"
)

const failedToReadPropertyError = "failed to read %s"

func scanConfigFromCmd(cmd *cobra.Command) (*scan.Config, error) {
c := &scan.Config{}

var err error

c.DictionaryPath = cmd.Flag(flagScanDictionary).Value.String()

if c.DictionaryTimeoutInMilliseconds, err = cmd.Flags().GetInt(flagScanDictionaryGetTimeout); err != nil {
return nil, errors.Wrapf(err, failedToReadPropertyError, flagScanDictionaryGetTimeout)
}

if c.HTTPMethods, err = cmd.Flags().GetStringSlice(flagScanHTTPMethods); err != nil {
return nil, errors.Wrapf(err, "failed to read %s", flagScanHTTPMethods)
return nil, errors.Wrapf(err, failedToReadPropertyError, flagScanHTTPMethods)
}

if c.HTTPStatusesToIgnore, err = cmd.Flags().GetIntSlice(flagScanHTTPStatusesToIgnore); err != nil {
return nil, errors.Wrapf(err, "failed to read %s", flagScanHTTPStatusesToIgnore)
return nil, errors.Wrapf(err, failedToReadPropertyError, flagScanHTTPStatusesToIgnore)
}

if c.Threads, err = cmd.Flags().GetInt(flagScanThreads); err != nil {
return nil, errors.Wrapf(err, "failed to read %s", flagScanThreads)
return nil, errors.Wrapf(err, failedToReadPropertyError, flagScanThreads)
}

if c.TimeoutInMilliseconds, err = cmd.Flags().GetInt(flagScanHTTPTimeout); err != nil {
return nil, errors.Wrapf(err, "failed to read %s", flagScanHTTPTimeout)
return nil, errors.Wrapf(err, failedToReadPropertyError, flagScanHTTPTimeout)
}

if c.CacheRequests, err = cmd.Flags().GetBool(flagScanHTTPCacheRequests); err != nil {
return nil, errors.Wrapf(err, "failed to read %s", flagScanHTTPCacheRequests)
return nil, errors.Wrapf(err, failedToReadPropertyError, flagScanHTTPCacheRequests)
}

if c.ScanDepth, err = cmd.Flags().GetInt(flagScanScanDepth); err != nil {
return nil, errors.Wrapf(err, "failed to read %s", flagScanScanDepth)
return nil, errors.Wrapf(err, failedToReadPropertyError, flagScanScanDepth)
}

socks5Host := cmd.Flag(flagScanSocks5Host).Value.String()
Expand All @@ -52,12 +58,12 @@ func scanConfigFromCmd(cmd *cobra.Command) (*scan.Config, error) {
c.UserAgent = cmd.Flag(flagScanUserAgent).Value.String()

if c.UseCookieJar, err = cmd.Flags().GetBool(flagScanCookieJar); err != nil {
return nil, errors.Wrapf(err, "failed to read %s", flagScanCookieJar)
return nil, errors.Wrapf(err, failedToReadPropertyError, flagScanCookieJar)
}

rawCookies, err := cmd.Flags().GetStringArray(flagScanCookie)
if err != nil {
return nil, errors.Wrapf(err, "failed to read %s", flagScanCookie)
return nil, errors.Wrapf(err, failedToReadPropertyError, flagScanCookie)
}

if c.Cookies, err = rawCookiesToCookies(rawCookies); err != nil {
Expand All @@ -66,7 +72,7 @@ func scanConfigFromCmd(cmd *cobra.Command) (*scan.Config, error) {

rawHeaders, err := cmd.Flags().GetStringArray(flagScanHeader)
if err != nil {
return nil, errors.Wrapf(err, "failed to read %s", flagScanHeader)
return nil, errors.Wrapf(err, failedToReadPropertyError, flagScanHeader)
}

if c.Headers, err = rawHeadersToHeaders(rawHeaders); err != nil {
Expand All @@ -77,7 +83,7 @@ func scanConfigFromCmd(cmd *cobra.Command) (*scan.Config, error) {

c.ShouldSkipSSLCertificatesValidation, err = cmd.Flags().GetBool(flagShouldSkipSSLCertificatesValidation)
if err != nil {
return nil, errors.Wrapf(err, "failed to read %s", flagShouldSkipSSLCertificatesValidation)
return nil, errors.Wrapf(err, failedToReadPropertyError, flagShouldSkipSSLCertificatesValidation)
}

return c, nil
Expand Down
1 change: 1 addition & 0 deletions pkg/cmd/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ const (
// Scan flags
flagScanDictionary = "dictionary"
flagScanDictionaryShort = "d"
flagScanDictionaryGetTimeout = "dictionary-get-timeout"
flagScanHTTPMethods = "http-methods"
flagScanHTTPStatusesToIgnore = "http-statuses-to-ignore"
flagScanHTTPTimeout = "http-timeout"
Expand Down
112 changes: 85 additions & 27 deletions pkg/cmd/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@ func NewScanCommand(logger *logrus.Logger) *cobra.Command {
common.Must(cmd.MarkFlagFilename(flagScanDictionary))
common.Must(cmd.MarkFlagRequired(flagScanDictionary))

cmd.Flags().IntP(
flagScanDictionaryGetTimeout,
"",
50000,
"timeout in milliseconds (used when fetching remote dictionary)",
)

cmd.Flags().StringSlice(
flagScanHTTPMethods,
[]string{"GET"},
Expand Down Expand Up @@ -164,39 +171,16 @@ func getURL(args []string) (*url.URL, error) {

// startScan is a convenience method that wires together all the dependencies needed to start a scan
func startScan(logger *logrus.Logger, cnf *scan.Config, u *url.URL) error {
c, err := client.NewClientFromConfig(
cnf.TimeoutInMilliseconds,
cnf.Socks5Url,
cnf.UserAgent,
cnf.UseCookieJar,
cnf.Cookies,
cnf.Headers,
cnf.CacheRequests,
cnf.ShouldSkipSSLCertificatesValidation,
u,
)
dict, err := buildDictionary(cnf, u)
if err != nil {
return errors.Wrap(err, "failed to build client")
return err
}

dict, err := dictionary.NewDictionaryFrom(cnf.DictionaryPath, c)
s, err := buildScanner(cnf, dict, u, logger)
if err != nil {
return errors.Wrap(err, "failed to build dictionary")
return err
}

targetProducer := producer.NewDictionaryProducer(cnf.HTTPMethods, dict, cnf.ScanDepth)
reproducer := producer.NewReProducer(targetProducer)

resultFilter := filter.NewHTTPStatusResultFilter(cnf.HTTPStatusesToIgnore)

s := scan.NewScanner(
c,
targetProducer,
reproducer,
resultFilter,
logger,
)

logger.WithFields(logrus.Fields{
"url": u.String(),
"threads": cnf.Threads,
Expand Down Expand Up @@ -265,6 +249,80 @@ func startScan(logger *logrus.Logger, cnf *scan.Config, u *url.URL) error {
}
}

func buildScanner(cnf *scan.Config, dict []string, u *url.URL, logger *logrus.Logger) (*scan.Scanner, error) {
targetProducer := producer.NewDictionaryProducer(cnf.HTTPMethods, dict, cnf.ScanDepth)
reproducer := producer.NewReProducer(targetProducer)

resultFilter := filter.NewHTTPStatusResultFilter(cnf.HTTPStatusesToIgnore)

scannerClient, err := buildScannerClient(cnf, u)
if err != nil {
return nil, err
}

s := scan.NewScanner(
scannerClient,
targetProducer,
reproducer,
resultFilter,
logger,
)

return s, nil
}

func buildDictionary(cnf *scan.Config, u *url.URL) ([]string, error) {
c, err := buildDictionaryClient(cnf, u)
if err != nil {
return nil, err
}

dict, err := dictionary.NewDictionaryFrom(cnf.DictionaryPath, c)
if err != nil {
return nil, errors.Wrap(err, "failed to build dictionary")
}

return dict, nil
}

func buildScannerClient(cnf *scan.Config, u *url.URL) (*http.Client, error) {
c, err := client.NewClientFromConfig(
cnf.TimeoutInMilliseconds,
cnf.Socks5Url,
cnf.UserAgent,
cnf.UseCookieJar,
cnf.Cookies,
cnf.Headers,
cnf.CacheRequests,
cnf.ShouldSkipSSLCertificatesValidation,
u,
)
if err != nil {
return nil, errors.Wrap(err, "failed to build scanner client")
}

return c, nil
}

func buildDictionaryClient(cnf *scan.Config, u *url.URL) (*http.Client, error) {
c, err := client.NewClientFromConfig(
cnf.DictionaryTimeoutInMilliseconds,
cnf.Socks5Url,
cnf.UserAgent,
cnf.UseCookieJar,
cnf.Cookies,
cnf.Headers,
cnf.CacheRequests,
cnf.ShouldSkipSSLCertificatesValidation,
u,
)
if err != nil {
return nil, errors.Wrap(err, "failed to build dictionary client")
}

return c, nil
}

func newOutputSaver(path string) (OutputSaver, error) {
if path == "" {
return output.NewNullSaver(), nil
Expand Down
76 changes: 76 additions & 0 deletions pkg/cmd/scan_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -760,3 +760,79 @@ func startSocks5TestServer(t *testing.T) net.Listener {

return listener
}

func TestScanShouldFailIfDictionaryFetchExceedTimeout(t *testing.T) {
logger, _ := test.NewLogger()

c := createCommand(logger)
assert.NotNil(t, c)

testServer, serverAssertion := test.NewServerWithAssertion(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}),
)
defer testServer.Close()

dictionaryTestServer, _ := test.NewServerWithAssertion(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Second)
w.Write([]byte("/dictionary/entry")) //nolint
}),
)
defer dictionaryTestServer.Close()

err := executeCommand(
c,
"scan",
testServer.URL,
"--dictionary",
dictionaryTestServer.URL,
"--dictionary-get-timeout",
"5",
)
assert.Error(t, err)

assert.Contains(t, err.Error(), "dictionary: failed to get")
assert.Contains(t, err.Error(), "Client.Timeout exceeded while awaiting headers")

assert.Equal(t, 0, serverAssertion.Len())
}

func TestScanShouldBeAbleToFetchRemoteDictionary(t *testing.T) {
logger, _ := test.NewLogger()

c := createCommand(logger)
assert.NotNil(t, c)

testServer, serverAssertion := test.NewServerWithAssertion(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
}),
)
defer testServer.Close()

dictionaryTestServer, dictionaryTestServerAssertion := test.NewServerWithAssertion(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("/dictionary/entry")) //nolint
}),
)
defer dictionaryTestServer.Close()

err := executeCommand(
c,
"scan",
testServer.URL,
"--dictionary",
dictionaryTestServer.URL,
"--dictionary-get-timeout",
"500",
)
assert.NoError(t, err)

assert.Equal(t, 1, dictionaryTestServerAssertion.Len())
assert.Equal(t, 1, serverAssertion.Len())

serverAssertion.At(0, func(r http.Request) {
assert.Equal(t, "/dictionary/entry", r.URL.Path)
})
}
1 change: 1 addition & 0 deletions pkg/scan/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
// Config represents the configuration needed to perform a scan
type Config struct {
DictionaryPath string
DictionaryTimeoutInMilliseconds int
HTTPMethods []string
HTTPStatusesToIgnore []int
Threads int
Expand Down

0 comments on commit 10f7a63

Please sign in to comment.