Skip to content

Commit

Permalink
Complete rework of AuthenticationMiddleware...
Browse files Browse the repository at this point in the history
- to align methods and enum with PsychicHttp and Arduino WebServer
- to support hash
- to pre-compute base64 / digest hash to speed up requests
Closes #111
  • Loading branch information
mathieucarbou committed Sep 30, 2024
1 parent c295c7b commit 6a19dd0
Show file tree
Hide file tree
Showing 9 changed files with 233 additions and 93 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,12 @@ AuthenticationMiddleware authMiddleware;

// [...]

authMiddleware.setAuthType(AuthenticationMiddleware::AuthType::AUTH_DIGEST);
authMiddleware.setRealm("My app name");
authMiddleware.setUsername("admin");
authMiddleware.setPassword("admin");
authMiddleware.setRealm("My app name");
authMiddleware.setAuthMethod(HTTPAuthMethod::DIGEST_AUTH);
authMiddleware.setAuthFailureMessage("Authentication failed");
authMiddleware.generateHash(); // optimization to avoid generating the hash at each request

// [...]

Expand Down
6 changes: 4 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,12 @@ AuthenticationMiddleware authMiddleware;

// [...]

authMiddleware.setAuthType(AuthenticationMiddleware::AuthType::AUTH_DIGEST);
authMiddleware.setRealm("My app name");
authMiddleware.setUsername("admin");
authMiddleware.setPassword("admin");
authMiddleware.setRealm("My app name");
authMiddleware.setAuthMethod(HTTPAuthMethod::DIGEST_AUTH);
authMiddleware.setAuthFailureMessage("Authentication failed");
authMiddleware.generateHash(); // optimization to avoid generating the hash at each request

// [...]

Expand Down
75 changes: 65 additions & 10 deletions examples/SimpleServer/SimpleServer.ino
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,13 @@ HeaderFilterMiddleware headerFilter;
// remove all headers from the incoming request except the ones provided in the constructor
HeaderFreeMiddleware headerFree;

// basicAuth
AuthenticationMiddleware basicAuth;
AuthenticationMiddleware basicAuthHash;

// simple digest authentication
AuthenticationMiddleware simpleDigestAuth;
AuthenticationMiddleware digestAuth;
AuthenticationMiddleware digestAuthHash;

// complex authentication which adds request attributes for the next middlewares and handler
AsyncMiddlewareFunction complexAuth([](AsyncWebServerRequest* request, ArMiddlewareNext next) {
Expand Down Expand Up @@ -177,9 +182,31 @@ void setup() {

requestLogger.setOutput(Serial);

simpleDigestAuth.setUsername("admin");
simpleDigestAuth.setPassword("admin");
simpleDigestAuth.setRealm("MyApp");
basicAuth.setUsername("admin");
basicAuth.setPassword("admin");
basicAuth.setRealm("MyApp");
basicAuth.setAuthFailureMessage("Authentication failed");
basicAuth.setAuthMethod(HTTPAuthMethod::BASIC_AUTH);
basicAuth.generateHash();

basicAuthHash.setUsername("admin");
basicAuthHash.setPasswordHash("YWRtaW46YWRtaW4="); // BASE64(admin:admin)
basicAuthHash.setRealm("MyApp");
basicAuthHash.setAuthFailureMessage("Authentication failed");
basicAuthHash.setAuthMethod(HTTPAuthMethod::BASIC_AUTH);

digestAuth.setUsername("admin");
digestAuth.setPassword("admin");
digestAuth.setRealm("MyApp");
digestAuth.setAuthFailureMessage("Authentication failed");
digestAuth.setAuthMethod(HTTPAuthMethod::DIGEST_AUTH);
digestAuth.generateHash();

digestAuthHash.setUsername("admin");
digestAuthHash.setPasswordHash("f499b71f9a36d838b79268e145e132f7"); // MD5(user:realm:pass)
digestAuthHash.setRealm("MyApp");
digestAuthHash.setAuthFailureMessage("Authentication failed");
digestAuthHash.setAuthMethod(HTTPAuthMethod::DIGEST_AUTH);

rateLimit.setMaxRequests(5);
rateLimit.setWindowSize(10);
Expand Down Expand Up @@ -225,15 +252,37 @@ void setup() {
})
.addMiddleware(&headerFree);

// simple digest authentication
// curl -v -X GET -H "x-remove-me: value" --digest -u admin:admin http://192.168.4.1/middleware/auth-simple
server.on("/middleware/auth-simple", HTTP_GET, [](AsyncWebServerRequest* request) {
// basic authentication method
// curl -v -X GET -H "origin: http://192.168.4.1" -u admin:admin http://192.168.4.1/middleware/auth-basic
server.on("/middleware/auth-basic", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world!");
})
.addMiddleware(&basicAuth);

// basic authentication method with hash
// curl -v -X GET -H "origin: http://192.168.4.1" -u admin:admin http://192.168.4.1/middleware/auth-basic-hash
server.on("/middleware/auth-basic-hash", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world!");
})
.addMiddleware(&basicAuthHash);

// digest authentication
// curl -v -X GET -H "origin: http://192.168.4.1" -u admin:admin --digest http://192.168.4.1/middleware/auth-digest
server.on("/middleware/auth-digest", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world!");
})
.addMiddleware(&digestAuth);

// digest authentication with hash
// curl -v -X GET -H "origin: http://192.168.4.1" -u admin:admin --digest http://192.168.4.1/middleware/auth-digest-hash
server.on("/middleware/auth-digest-hash", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world!");
})
.addMiddleware(&simpleDigestAuth);
.addMiddleware(&digestAuthHash);

// curl -v -X GET -H "x-remove-me: value" --digest -u user:password http://192.168.4.1/middleware/auth-complex
server.on("/middleware/auth-complex", HTTP_GET, [](AsyncWebServerRequest* request) {
// test digest auth with cors
// curl -v -X GET -H "origin: http://192.168.4.1" --digest -u user:password http://192.168.4.1/middleware/auth-custom
server.on("/middleware/auth-custom", HTTP_GET, [](AsyncWebServerRequest* request) {
String buffer = "Hello ";
buffer.concat(request->getAttribute("user"));
buffer.concat(" with role: ");
Expand All @@ -244,6 +293,12 @@ void setup() {

///////////////////////////////////////////////////////////////////////

// curl -v -X GET -H "origin: http://192.168.4.1" http://192.168.4.1/redirect
// curl -v -X POST -H "origin: http://192.168.4.1" http://192.168.4.1/redirect
server.on("/redirect", HTTP_GET | HTTP_POST, [](AsyncWebServerRequest* request) {
request->redirect("/");
});

server.on("/", HTTP_GET, [](AsyncWebServerRequest* request) {
request->send(200, "text/plain", "Hello, world");
});
Expand Down
48 changes: 30 additions & 18 deletions src/ESPAsyncWebServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,15 @@ typedef enum { RCT_NOT_USED = -1,
RCT_EVENT,
RCT_MAX } RequestedConnectionType;

// this enum is similar to Arduino WebServer's HTTPAuthMethod and PsychicHttp
typedef enum {
NO_AUTH = 0,
BASIC_AUTH,
DIGEST_AUTH,
BEARER_AUTH,
OTHER_AUTH
} HTTPAuthMethod;

typedef std::function<size_t(uint8_t*, size_t, size_t)> AwsResponseFiller;
typedef std::function<String(const String&)> AwsTemplateProcessor;

Expand Down Expand Up @@ -194,7 +203,7 @@ class AsyncWebServerRequest {
String _boundary;
String _authorization;
RequestedConnectionType _reqconntype;
bool _isDigest;
HTTPAuthMethod _authMethod = NO_AUTH;
bool _isMultipart;
bool _isPlainPost;
bool _expectingContinue;
Expand Down Expand Up @@ -271,8 +280,9 @@ class AsyncWebServerRequest {
// base64(user:pass) for basic or
// user:realm:md5(user:realm:pass) for digest
bool authenticate(const char* hash);
bool authenticate(const char* username, const char* password, const char* realm = NULL, bool passwordIsHash = false);
void requestAuthentication(const char* realm = NULL, bool isDigest = true);
bool authenticate(const char* username, const char* credentials, const char* realm = NULL, bool isHash = false);
void requestAuthentication(const char* realm = nullptr, bool isDigest = true) { requestAuthentication(isDigest ? DIGEST_AUTH : BASIC_AUTH, realm); }
void requestAuthentication(HTTPAuthMethod method, const char* realm = nullptr, const char* _authFailMsg = nullptr);

void setHandler(AsyncWebHandler* handler) { _handler = handler; }

Expand Down Expand Up @@ -554,28 +564,30 @@ class AsyncMiddlewareChain {
// AuthenticationMiddleware is a middleware that checks if the request is authenticated
class AuthenticationMiddleware : public AsyncMiddleware {
public:
typedef enum {
AUTH_NONE,
AUTH_BASIC,
AUTH_DIGEST
} AuthType;

void setUsername(const char* username) { _username = username; }
void setPassword(const char* password) { _password = password; }
void setUsername(const char* username);
void setPassword(const char* password);
void setPasswordHash(const char* hash);

void setRealm(const char* realm) { _realm = realm; }
void setPasswordIsHash(bool passwordIsHash) { _hash = passwordIsHash; }
void setAuthType(AuthType authType) { _authType = authType; }
void setAuthFailureMessage(const char* message) { _authFailMsg = message; }
void setAuthMethod(HTTPAuthMethod authMethod) { _authMethod = authMethod; }

bool allowed(AsyncWebServerRequest* request) { return _authType == AUTH_NONE || !_username.length() || !_password.length() || request->authenticate(_username.c_str(), _password.c_str(), _realm, _hash); }
// precompute and store the hash value based on the username, realm, and authMethod
void generateHash();

void run(AsyncWebServerRequest* request, ArMiddlewareNext next) { return allowed(request) ? next() : request->requestAuthentication(_realm, _authType == AUTH_DIGEST); }
bool allowed(AsyncWebServerRequest* request);

void run(AsyncWebServerRequest* request, ArMiddlewareNext next);

private:
String _username;
String _password;
const char* _realm = nullptr;
String _credentials;
bool _hash = false;
AuthType _authType = AUTH_DIGEST;

String _realm = asyncsrv::T_LOGIN_REQ;
HTTPAuthMethod _authMethod = DIGEST_AUTH;
String _authFailMsg;
bool _hasCreds = false;
};

using ArAuthorizeFunction = std::function<bool(AsyncWebServerRequest* request)>;
Expand Down
49 changes: 49 additions & 0 deletions src/Middleware.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "WebAuthentication.h"
#include <ESPAsyncWebServer.h>

AsyncMiddlewareChain::~AsyncMiddlewareChain() {
Expand Down Expand Up @@ -52,6 +53,54 @@ void AsyncMiddlewareChain::_runChain(AsyncWebServerRequest* request, ArMiddlewar
return next();
}

void AuthenticationMiddleware::setUsername(const char* username) {
_username = username;
_hasCreds = !_username.isEmpty() && !_credentials.isEmpty();
}

void AuthenticationMiddleware::setPassword(const char* password) {
_credentials = password;
_hash = false;
_hasCreds = !_username.isEmpty() && !_credentials.isEmpty();
}

void AuthenticationMiddleware::setPasswordHash(const char* hash) {
_credentials = hash;
_hash = true;
_hasCreds = !_username.isEmpty() && !_credentials.isEmpty();
}

void AuthenticationMiddleware::generateHash() {
switch (_authMethod) {
case DIGEST_AUTH:
_credentials = generateDigestHash(_username.c_str(), _credentials.c_str(), _realm.c_str());
_hash = true;
break;
case BASIC_AUTH:
_credentials = generateBasicHash(_username.c_str(), _credentials.c_str());
_hash = true;
break;
default:
break;
}
// log_d("Generated hash: %s", _credentials.c_str());
_hasCreds = !_username.isEmpty() && !_credentials.isEmpty();
}

bool AuthenticationMiddleware::allowed(AsyncWebServerRequest* request) {
if (_authMethod == NO_AUTH)
return true;

if (!_hasCreds)
return false;

return request->authenticate(_username.c_str(), _credentials.c_str(), _realm.c_str(), _hash);
}

void AuthenticationMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
return allowed(request) ? next() : request->requestAuthentication(_authMethod, _realm.c_str(), _authFailMsg.c_str());
}

void HeaderFreeMiddleware::run(AsyncWebServerRequest* request, ArMiddlewareNext next) {
std::vector<const char*> reqHeaders;
request->getHeaderNames(reqHeaders);
Expand Down
66 changes: 27 additions & 39 deletions src/WebAuthentication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,36 +34,34 @@ using namespace asyncsrv;
bool checkBasicAuthentication(const char* hash, const char* username, const char* password) {
if (username == NULL || password == NULL || hash == NULL)
return false;
return generateBasicHash(username, password).equalsIgnoreCase(hash);
}

String generateBasicHash(const char* username, const char* password) {
if (username == NULL || password == NULL)
return emptyString;

size_t toencodeLen = strlen(username) + strlen(password) + 1;
size_t encodedLen = base64_encode_expected_len(toencodeLen);
if (strlen(hash) != encodedLen)
// Fix from https://github.com/me-no-dev/ESPAsyncWebServer/issues/667
#ifdef ARDUINO_ARCH_ESP32
if (strlen(hash) != encodedLen)
#else
if (strlen(hash) != encodedLen - 1)
#endif
return false;

char* toencode = new char[toencodeLen + 1];
if (toencode == NULL) {
return false;
return emptyString;
}
char* encoded = new char[base64_encode_expected_len(toencodeLen) + 1];
if (encoded == NULL) {
delete[] toencode;
return false;
return emptyString;
}
sprintf_P(toencode, PSTR("%s:%s"), username, password);
if (base64_encode_chars(toencode, toencodeLen, encoded) > 0 && memcmp(hash, encoded, encodedLen) == 0) {
if (base64_encode_chars(toencode, toencodeLen, encoded) > 0) {
String res = String(encoded);
delete[] toencode;
delete[] encoded;
return true;
return res;
}
delete[] toencode;
delete[] encoded;
return false;
return emptyString;
}

static bool getMD5(uint8_t* data, uint16_t len, char* output) { // 33 bytes or more
Expand Down Expand Up @@ -94,7 +92,7 @@ static bool getMD5(uint8_t* data, uint16_t len, char* output) { // 33 bytes or m
return true;
}

static String genRandomMD5() {
String genRandomMD5() {
#ifdef ESP8266
uint32_t r = RANDOM_REG32;
#else
Expand Down Expand Up @@ -122,31 +120,21 @@ String generateDigestHash(const char* username, const char* password, const char
return emptyString;
}
char* out = (char*)malloc(33);
String res = String(username);
res += ':';
res.concat(realm);
res += ':';
String in = res;

String in;
in.reserve(strlen(username) + strlen(realm) + strlen(password) + 2);
in.concat(username);
in.concat(':');
in.concat(realm);
in.concat(':');
in.concat(password);

if (out == NULL || !getMD5((uint8_t*)(in.c_str()), in.length(), out))
return emptyString;
res.concat(out);
free(out);
return res;
}

String requestDigestAuthentication(const char* realm) {
String header(T_realm__);
if (realm == NULL)
header.concat(T_asyncesp);
else
header.concat(realm);
header.concat(T_auth_nonce);
header.concat(genRandomMD5());
header.concat(T__opaque);
header.concat(genRandomMD5());
header += (char)0x22; // '"'
return header;
in = String(out);
free(out);
return in;
}

#ifndef ESP8266
Expand Down Expand Up @@ -235,9 +223,9 @@ bool checkDigestAuthentication(const char* header, const __FlashStringHelper* me
}
} while (nextBreak > 0);

String ha1 = (passwordIsHash) ? String(password) : stringMD5(myUsername + ':' + myRealm + ':' + password);
String ha2 = String(method) + ':' + myUri;
String response = ha1 + ':' + myNonce + ':' + myNc + ':' + myCnonce + ':' + myQop + ':' + stringMD5(ha2);
String ha1 = passwordIsHash ? password : stringMD5(myUsername + ':' + myRealm + ':' + password).c_str();
String ha2 = stringMD5(String(method) + ':' + myUri);
String response = ha1 + ':' + myNonce + ':' + myNc + ':' + myCnonce + ':' + myQop + ':' + ha2;

if (myResponse.equals(stringMD5(response))) {
// os_printf("AUTH SUCCESS\n");
Expand Down
Loading

0 comments on commit 6a19dd0

Please sign in to comment.