Skip to content

Commit

Permalink
Complete rework of AuthenticationMiddleware...
Browse files Browse the repository at this point in the history
- to align methods and enum with PsychicHttp and Arduino WebServer
- to support hash
- to pre-compute base64 / digest hash to speed up requests
Closes #111
  • Loading branch information
mathieucarbou committed Oct 1, 2024
1 parent c295c7b commit 3e416ac
Show file tree
Hide file tree
Showing 9 changed files with 240 additions and 91 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,12 @@ AuthenticationMiddleware authMiddleware;

// [...]

authMiddleware.setAuthType(AuthenticationMiddleware::AuthType::AUTH_DIGEST);
authMiddleware.setAuthType(AsyncAuthType::AUTH_DIGEST);
authMiddleware.setRealm("My app name");
authMiddleware.setUsername("admin");
authMiddleware.setPassword("admin");
authMiddleware.setAuthFailureMessage("Authentication failed");
authMiddleware.generateHash(); // optimization to avoid generating the hash at each request

// [...]

Expand Down
4 changes: 3 additions & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,12 @@ AuthenticationMiddleware authMiddleware;

// [...]

authMiddleware.setAuthType(AuthenticationMiddleware::AuthType::AUTH_DIGEST);
authMiddleware.setAuthType(AsyncAuthType::AUTH_DIGEST);
authMiddleware.setRealm("My app name");
authMiddleware.setUsername("admin");
authMiddleware.setPassword("admin");
authMiddleware.setAuthFailureMessage("Authentication failed");
authMiddleware.generateHash(); // optimization to avoid generating the hash at each request

// [...]

Expand Down
75 changes: 65 additions & 10 deletions examples/SimpleServer/SimpleServer.ino
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,13 @@ HeaderFilterMiddleware headerFilter;
// remove all headers from the incoming request except the ones provided in the constructor
HeaderFreeMiddleware headerFree;

// basicAuth
AuthenticationMiddleware basicAuth;
AuthenticationMiddleware basicAuthHash;

// simple digest authentication
AuthenticationMiddleware simpleDigestAuth;
AuthenticationMiddleware digestAuth;
AuthenticationMiddleware digestAuthHash;

// complex authentication which adds request attributes for the next middlewares and handler
AsyncMiddlewareFunction complexAuth([](AsyncWebServerRequest* request, ArMiddlewareNext next) {
Expand Down Expand Up @@ -177,9 +182,31 @@ void setup() {

requestLogger.setOutput(Serial);

simpleDigestAuth.setUsername("admin");
simpleDigestAuth.setPassword("admin");
simpleDigestAuth.setRealm("MyApp");
basicAuth.setUsername("admin");
basicAuth.setPassword("admin");
basicAuth.setRealm("MyApp");
basicAuth.setAuthFailureMessage("Authentication failed");
basicAuth.setAuthType(AsyncAuthType::AUTH_BASIC);
basicAuth.generateHash();

basicAuthHash.setUsername("admin");
basicAuthHash.setPasswordHash("YWRtaW46YWRtaW4="); // BASE64(admin:admin)
basicAuthHash.setRealm("MyApp");
basicAuthHash.setAuthFailureMessage("Authentication failed");
basicAuthHash.setAuthType(AsyncAuthType::AUTH_BASIC);

digestAuth.setUsername("admin");
digestAuth.setPassword("admin");
digestAuth.setRealm("MyApp");
digestAuth.setAuthFailureMessage("Authentication failed");
digestAuth.setAuthType(AsyncAuthType::AUTH_DIGEST);
digestAuth.generateHash();

digestAuthHash.setUsername("admin");
digestAuthHash.setPasswordHash("f499b71f9a36d838b79268e145e132f7"); // MD5(user:realm:pass)
digestAuthHash.setRealm("MyApp");
digestAuthHash.setAuthFailureMessage("Authentication failed");
digestAuthHash.setAuthType(AsyncAuthType::AUTH_DIGEST);

rateLimit.setMaxRequests(5);
rateLimit.setWindowSize(10);
Expand Down Expand Up @@ -225,15 +252,37 @@ void setup() {
})
.addMiddleware(&headerFree);

// simple digest authentication
// curl -v -X GET -H "x-remove-me: value" --digest -u admin:admin http://192.168.4.1/middleware/auth-simple
server.on("/middleware/auth-simple", HTTP_GET, [](AsyncWebServerRequest* request) {
// basic authentication method
// curl -v -X GET -H "origin: http://192.168.4.1" -u admin:admin http://192.168.4.1/middleware/auth-basic
server.on("/middleware/auth-basic", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world!");
})
.addMiddleware(&basicAuth);

// basic authentication method with hash
// curl -v -X GET -H "origin: http://192.168.4.1" -u admin:admin http://192.168.4.1/middleware/auth-basic-hash
server.on("/middleware/auth-basic-hash", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world!");
})
.addMiddleware(&basicAuthHash);

// digest authentication
// curl -v -X GET -H "origin: http://192.168.4.1" -u admin:admin --digest http://192.168.4.1/middleware/auth-digest
server.on("/middleware/auth-digest", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world!");
})
.addMiddleware(&digestAuth);

// digest authentication with hash
// curl -v -X GET -H "origin: http://192.168.4.1" -u admin:admin --digest http://192.168.4.1/middleware/auth-digest-hash
server.on("/middleware/auth-digest-hash", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world!");
})
.addMiddleware(&simpleDigestAuth);
.addMiddleware(&digestAuthHash);

// curl -v -X GET -H "x-remove-me: value" --digest -u user:password http://192.168.4.1/middleware/auth-complex
server.on("/middleware/auth-complex", HTTP_GET, [](AsyncWebServerRequest* request) {
// test digest auth with cors
// curl -v -X GET -H "origin: http://192.168.4.1" --digest -u user:password http://192.168.4.1/middleware/auth-custom
server.on("/middleware/auth-custom", HTTP_GET, [](AsyncWebServerRequest* request) {
String buffer = "Hello ";
buffer.concat(request->getAttribute("user"));
buffer.concat(" with role: ");
Expand All @@ -244,6 +293,12 @@ void setup() {

///////////////////////////////////////////////////////////////////////

// curl -v -X GET -H "origin: http://192.168.4.1" http://192.168.4.1/redirect
// curl -v -X POST -H "origin: http://192.168.4.1" http://192.168.4.1/redirect
server.on("/redirect", HTTP_GET | HTTP_POST, [](AsyncWebServerRequest* request) {
request->redirect("/");
});

server.on("/", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world");
});
Expand Down
49 changes: 31 additions & 18 deletions src/ESPAsyncWebServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,15 @@ typedef enum { RCT_NOT_USED = -1,
RCT_EVENT,
RCT_MAX } RequestedConnectionType;

// this enum is similar to Arduino WebServer's AsyncAuthType and PsychicHttp
typedef enum {
AUTH_NONE = 0,
AUTH_BASIC,
AUTH_DIGEST,
AUTH_BEARER,
AUTH_OTHER,
} AsyncAuthType;

typedef std::function<size_t(uint8_t*, size_t, size_t)> AwsResponseFiller;
typedef std::function<String(const String&)> AwsTemplateProcessor;

Expand Down Expand Up @@ -194,7 +203,7 @@ class AsyncWebServerRequest {
String _boundary;
String _authorization;
RequestedConnectionType _reqconntype;
bool _isDigest;
AsyncAuthType _authMethod = AsyncAuthType::AUTH_NONE;
bool _isMultipart;
bool _isPlainPost;
bool _expectingContinue;
Expand Down Expand Up @@ -271,8 +280,9 @@ class AsyncWebServerRequest {
// base64(user:pass) for basic or
// user:realm:md5(user:realm:pass) for digest
bool authenticate(const char* hash);
bool authenticate(const char* username, const char* password, const char* realm = NULL, bool passwordIsHash = false);
void requestAuthentication(const char* realm = NULL, bool isDigest = true);
bool authenticate(const char* username, const char* credentials, const char* realm = NULL, bool isHash = false);
void requestAuthentication(const char* realm = nullptr, bool isDigest = true) { requestAuthentication(isDigest ? AsyncAuthType::AUTH_DIGEST : AsyncAuthType::AUTH_BASIC, realm); }
void requestAuthentication(AsyncAuthType method, const char* realm = nullptr, const char* _authFailMsg = nullptr);

void setHandler(AsyncWebHandler* handler) { _handler = handler; }

Expand Down Expand Up @@ -554,28 +564,31 @@ class AsyncMiddlewareChain {
// AuthenticationMiddleware is a middleware that checks if the request is authenticated
class AuthenticationMiddleware : public AsyncMiddleware {
public:
typedef enum {
AUTH_NONE,
AUTH_BASIC,
AUTH_DIGEST
} AuthType;

void setUsername(const char* username) { _username = username; }
void setPassword(const char* password) { _password = password; }
void setUsername(const char* username);
void setPassword(const char* password);
void setPasswordHash(const char* hash);

void setRealm(const char* realm) { _realm = realm; }
void setPasswordIsHash(bool passwordIsHash) { _hash = passwordIsHash; }
void setAuthType(AuthType authType) { _authType = authType; }
void setAuthFailureMessage(const char* message) { _authFailMsg = message; }
void setAuthType(AsyncAuthType authMethod) { _authMethod = authMethod; }

bool allowed(AsyncWebServerRequest* request) { return _authType == AUTH_NONE || !_username.length() || !_password.length() || request->authenticate(_username.c_str(), _password.c_str(), _realm, _hash); }
// precompute and store the hash value based on the username, realm, and authMethod
// returns true if the hash was successfully generated and replaced
bool generateHash();

void run(AsyncWebServerRequest* request, ArMiddlewareNext next) { return allowed(request) ? next() : request->requestAuthentication(_realm, _authType == AUTH_DIGEST); }
bool allowed(AsyncWebServerRequest* request);

void run(AsyncWebServerRequest* request, ArMiddlewareNext next);

private:
String _username;
String _password;
const char* _realm = nullptr;
String _credentials;
bool _hash = false;
AuthType _authType = AUTH_DIGEST;

String _realm = asyncsrv::T_LOGIN_REQ;
AsyncAuthType _authMethod = AsyncAuthType::AUTH_NONE;
String _authFailMsg;
bool _hasCreds = false;
};

using ArAuthorizeFunction = std::function<bool(AsyncWebServerRequest* request)>;
Expand Down
57 changes: 57 additions & 0 deletions src/Middleware.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "WebAuthentication.h"
#include <ESPAsyncWebServer.h>

AsyncMiddlewareChain::~AsyncMiddlewareChain() {
Expand Down Expand Up @@ -52,6 +53,62 @@ void AsyncMiddlewareChain::_runChain(AsyncWebServerRequest* request, ArMiddlewar
return next();
}

void AuthenticationMiddleware::setUsername(const char* username) {
_username = username;
_hasCreds = _username.length() && _credentials.length();
}

void AuthenticationMiddleware::setPassword(const char* password) {
_credentials = password;
_hash = false;
_hasCreds = _username.length() && _credentials.length();
}

void AuthenticationMiddleware::setPasswordHash(const char* hash) {
_credentials = hash;
_hash = true;
_hasCreds = _username.length() && _credentials.length();
}

bool AuthenticationMiddleware::generateHash() {
// ensure we have all the necessary data
if (!_hasCreds)
return false;

// if we already have a hash, do nothing
if (_hash)
return false;

switch (_authMethod) {
case AsyncAuthType::AUTH_DIGEST:
_credentials = generateDigestHash(_username.c_str(), _credentials.c_str(), _realm.c_str());
_hash = true;
return true;

case AsyncAuthType::AUTH_BASIC:
_credentials = generateBasicHash(_username.c_str(), _credentials.c_str());
_hash = true;
return true;

default:
return false;
}
}

bool AuthenticationMiddleware::allowed(AsyncWebServerRequest* request) {
if (_authMethod == AsyncAuthType::AUTH_NONE)
return true;

if (!_hasCreds)
return false;

return request->authenticate(_username.c_str(), _credentials.c_str(), _realm.c_str(), _hash);
}

void AuthenticationMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
return allowed(request) ? next() : request->requestAuthentication(_authMethod, _realm.c_str(), _authFailMsg.c_str());
}

void HeaderFreeMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
std::vector<const char*> reqHeaders;
request->getHeaderNames(reqHeaders);
Expand Down
Loading

0 comments on commit 3e416ac

Please sign in to comment.