diff --git a/middlewares/auth.go b/middlewares/auth.go index a5c8775..d7ead81 100644 --- a/middlewares/auth.go +++ b/middlewares/auth.go @@ -3,10 +3,30 @@ package middlewares import ( "strings" + "crypto/sha256" + "crypto/subtle" + "github.com/gin-gonic/gin" "github.com/vsouza/go-gin-boilerplate/config" ) +func sha256Sum(s string) []byte { + sum := sha256.Sum256([]byte(s)) + arr := make([]byte, len(sum)) + copy(arr, sum[:]) + + return arr +} + +// secureCompare calculates sha256 hash of parameters a and b and does constant time comparison +// to avoid time based attacks. +func secureCompare(a, b string) int { + aSum := sha256Sum(a) + bSum := sha256Sum(b) + + return subtle.ConstantTimeCompare(aSum, bSum) +} + func AuthMiddleware() gin.HandlerFunc { return func(c *gin.Context) { config := config.GetConfig() @@ -21,7 +41,10 @@ func AuthMiddleware() gin.HandlerFunc { if secret = config.GetString("http.auth.secret"); len(strings.TrimSpace(secret)) == 0 { c.AbortWithStatus(401) } - if key != reqKey || secret != reqSecret { + + isKeysEqual := secureCompare(key, reqKey) == 1 + isSecretsEqual := secureCompare(secret, reqSecret) == 1 + if !isKeysEqual || !isSecretsEqual { c.AbortWithStatus(401) return } diff --git a/middlewares/auth_test.go b/middlewares/auth_test.go new file mode 100644 index 0000000..ec7c62f --- /dev/null +++ b/middlewares/auth_test.go @@ -0,0 +1,47 @@ +package middlewares + +import "testing" + +func Test_secureCompare(t *testing.T) { + type args struct { + a string + b string + } + tests := []struct { + name string + args args + want int + }{ + { + name: "Parameters a and b are equal", + args: args{ + a: "abc123", + b: "abc123", + }, + want: 1, + }, + { + name: "Parameters a and b are not equal", + args: args{ + a: "123abc", + b: "abc123", + }, + want: 0, + }, + { + name: "Parameters a and b are almost equal", + args: args{ + a: "abc123", + b: "abd123", + }, + want: 0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := secureCompare(tt.args.a, tt.args.b); got != tt.want { + t.Errorf("secureCompare() = %v, want %v", got, tt.want) + } + }) + } +}