Skip to content

Commit

Permalink
Add authenticate to IdentityPlugin interface
Browse files Browse the repository at this point in the history
Signed-off-by: Craig Perkins <cwperx@amazon.com>
  • Loading branch information
cwperks committed Aug 26, 2024
1 parent f195285 commit e6b82ba
Show file tree
Hide file tree
Showing 25 changed files with 355 additions and 164 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.identity.shiro;

import org.opensearch.client.node.NodeClient;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;

import java.util.List;
import java.util.Objects;

/**
* Delegating RestHandler that delegates all implementations to original handler
*/
public class DelegatingRestHandler implements RestHandler {

protected final RestHandler delegate;

public DelegatingRestHandler(RestHandler delegate) {
Objects.requireNonNull(delegate, "RestHandler delegate can not be null");
this.delegate = delegate;
}

@Override
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) throws Exception {
delegate.handleRequest(request, channel, client);
}

@Override
public boolean canTripCircuitBreaker() {
return delegate.canTripCircuitBreaker();
}

@Override
public boolean supportsContentStream() {
return delegate.supportsContentStream();
}

@Override
public boolean allowsUnsafeBuffers() {
return delegate.allowsUnsafeBuffers();
}

@Override
public List<Route> routes() {
return delegate.routes();
}

@Override
public List<DeprecatedRoute> deprecatedRoutes() {
return delegate.deprecatedRoutes();
}

@Override
public List<ReplacedRoute> replacedRoutes() {
return delegate.replacedRoutes();
}

@Override
public boolean allowSystemIndexAccessByDefault() {
return delegate.allowSystemIndexAccessByDefault();
}

@Override
public String toString() {
return delegate.toString();
}

@Override
public boolean supportsStreaming() {
return delegate.supportsStreaming();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,25 @@
import org.apache.logging.log4j.Logger;
import org.apache.shiro.SecurityUtils;
import org.apache.shiro.mgt.SecurityManager;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.identity.Subject;
import org.opensearch.identity.tokens.AuthToken;
import org.opensearch.identity.tokens.RestTokenExtractor;
import org.opensearch.identity.tokens.TokenManager;
import org.opensearch.plugins.IdentityPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;

import java.util.function.UnaryOperator;

/**
* Identity implementation with Shiro
*
* @opensearch.experimental
*/
public final class ShiroIdentityPlugin extends Plugin implements IdentityPlugin {
private Logger log = LogManager.getLogger(this.getClass());
Expand Down Expand Up @@ -61,4 +70,35 @@ public Subject getSubject() {
public TokenManager getTokenManager() {
return this.authTokenHandler;
}

@Override
public UnaryOperator<RestHandler> authenticate(ThreadContext threadContext) {
return AuthcRestHandler::new;
}

class AuthcRestHandler extends DelegatingRestHandler {
public AuthcRestHandler(RestHandler original) {
super(original);
}

@Override
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) throws Exception {
final AuthToken token = RestTokenExtractor.extractToken(request);
// If no token was found, continue executing the request
if (token == null) {
// Authentication did not fail so return true. Authorization is handled at the action level.
delegate.handleRequest(request, channel, client);
return;
}
try {
ShiroSubject shiroSubject = (ShiroSubject) getSubject();
shiroSubject.authenticate(token);
// Caller was authorized, forward the request to the handler
delegate.handleRequest(request, channel, client);
} catch (final Exception e) {
final BytesRestResponse bytesRestResponse = new BytesRestResponse(RestStatus.UNAUTHORIZED, e.getMessage());
channel.sendResponse(bytesRestResponse);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

/**
* OpenSearch specific security manager implementation
*
* @opensearch.experimental
*/
public class ShiroSecurityManager extends DefaultSecurityManager {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

/**
* Subject backed by Shiro
*
* @opensearch.experimental
*/
public class ShiroSubject implements Subject {
private final ShiroTokenManager authTokenHandler;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@

/**
* Extracts Shiro's {@link AuthenticationToken} from different types of auth headers
*
* @opensearch.experimental
*/
class ShiroTokenManager implements TokenManager {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

/**
* Password matcher for BCrypt
*
* @opensearch.experimental
*/
public class BCryptPasswordMatcher implements CredentialsMatcher {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@

/**
* Internal Realm is a custom realm using the internal OpenSearch IdP
*
* @opensearch.experimental
*/
public class OpenSearchRealm extends AuthenticatingRealm {
private static final String DEFAULT_REALM_NAME = "internal";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

/**
* A non-volatile and immutable object in the storage.
*
* @opensearch.experimental
*/
public class User {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,22 @@
import org.opensearch.identity.IdentityService;
import org.opensearch.plugins.IdentityPlugin;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;

import java.util.List;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.mock;

public class ShiroIdentityPluginTests extends OpenSearchTestCase {

public void testSingleIdentityPluginSucceeds() {
IdentityPlugin identityPlugin1 = new ShiroIdentityPlugin(Settings.EMPTY);
List<IdentityPlugin> pluginList1 = List.of(identityPlugin1);
IdentityService identityService1 = new IdentityService(Settings.EMPTY, pluginList1);
IdentityService identityService1 = new IdentityService(Settings.EMPTY, mock(ThreadPool.class), pluginList1);
assertThat(identityService1.getTokenManager(), is(instanceOf(ShiroTokenManager.class)));
}

Expand All @@ -35,7 +37,10 @@ public void testMultipleIdentityPluginsFail() {
IdentityPlugin identityPlugin2 = new ShiroIdentityPlugin(Settings.EMPTY);
IdentityPlugin identityPlugin3 = new ShiroIdentityPlugin(Settings.EMPTY);
List<IdentityPlugin> pluginList = List.of(identityPlugin1, identityPlugin2, identityPlugin3);
Exception ex = assertThrows(OpenSearchException.class, () -> new IdentityService(Settings.EMPTY, pluginList));
Exception ex = assertThrows(
OpenSearchException.class,
() -> new IdentityService(Settings.EMPTY, mock(ThreadPool.class), pluginList)
);
assert (ex.getMessage().contains("Multiple identity plugins are not supported,"));
}

Expand Down
28 changes: 20 additions & 8 deletions server/src/main/java/org/opensearch/action/ActionModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListSet;
Expand All @@ -491,6 +492,7 @@

import static java.util.Collections.unmodifiableMap;
import static java.util.Objects.requireNonNull;
import static org.opensearch.rest.RestController.PASS_THROUGH_REST_HANDLER_WRAPPER;

/**
* Builds and binds the generic action map, all {@link TransportAction}s, and {@link ActionFilters}.
Expand Down Expand Up @@ -522,6 +524,7 @@ public class ActionModule extends AbstractModule {
private final AutoCreateIndex autoCreateIndex;
private final DestructiveOperations destructiveOperations;
private final RestController restController;
private final UnaryOperator<RestHandler> restWrapper;
private final RequestValidators<PutMappingRequest> mappingRequestValidators;
private final RequestValidators<IndicesAliasesRequest> indicesAliasesRequestRequestValidators;
private final ThreadPool threadPool;
Expand Down Expand Up @@ -559,25 +562,29 @@ public ActionModule(
actionPlugins.stream().flatMap(p -> p.getRestHeaders().stream()),
Stream.of(new RestHeaderDefinition(Task.X_OPAQUE_ID, false))
).collect(Collectors.toSet());
UnaryOperator<RestHandler> restWrapper = null;
for (ActionPlugin plugin : actionPlugins) {
UnaryOperator<RestHandler> newRestWrapper = plugin.getRestHandlerWrapper(threadPool.getThreadContext());
if (newRestWrapper != null) {
logger.debug("Using REST wrapper from plugin " + plugin.getClass().getName());
if (restWrapper != null) {
UnaryOperator<RestHandler> restWrapper = identityService.authenticate();
// Check if implementation is provided by one of the actionPlugins
if (PASS_THROUGH_REST_HANDLER_WRAPPER.equals(restWrapper)) {
List<UnaryOperator<RestHandler>> restWrappers = actionPlugins.stream()
.map(p -> p.getRestHandlerWrapper(threadPool.getThreadContext()))
.filter(Objects::nonNull)
.collect(Collectors.toUnmodifiableList());
if (!restWrappers.isEmpty()) {
if (restWrappers.size() > 1) {
throw new IllegalArgumentException("Cannot have more than one plugin implementing a REST wrapper");
}
restWrapper = newRestWrapper;
restWrapper = restWrappers.get(0);
}
}
this.restWrapper = restWrapper;
mappingRequestValidators = new RequestValidators<>(
actionPlugins.stream().flatMap(p -> p.mappingRequestValidators().stream()).collect(Collectors.toList())
);
indicesAliasesRequestRequestValidators = new RequestValidators<>(
actionPlugins.stream().flatMap(p -> p.indicesAliasesRequestValidators().stream()).collect(Collectors.toList())
);

restController = new RestController(headers, restWrapper, nodeClient, circuitBreakerService, usageService, identityService);
restController = new RestController(headers, restWrapper, nodeClient, circuitBreakerService, usageService);
}

public Map<String, ActionHandler<?, ?>> getActions() {
Expand Down Expand Up @@ -1059,6 +1066,11 @@ public RestController getRestController() {
return restController;
}

// Visible for testing
UnaryOperator<RestHandler> getRestWrapper() {
return restWrapper;
}

/**
* The DynamicActionRegistry maintains a registry mapping {@link ActionType} instances to {@link TransportAction} instances.
* <p>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import org.opensearch.transport.TransportService;

import java.io.IOException;
import java.util.List;
import java.util.Optional;
import java.util.Set;

Expand All @@ -31,8 +30,8 @@
*/
public class NoopExtensionsManager extends ExtensionsManager {

public NoopExtensionsManager() throws IOException {
super(Set.of(), new IdentityService(Settings.EMPTY, List.of()));
public NoopExtensionsManager(IdentityService identityService) throws IOException {
super(Set.of(), identityService);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
import org.opensearch.identity.noop.NoopIdentityPlugin;
import org.opensearch.identity.tokens.TokenManager;
import org.opensearch.plugins.IdentityPlugin;
import org.opensearch.rest.RestHandler;
import org.opensearch.threadpool.ThreadPool;

import java.util.List;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;

/**
Expand All @@ -26,9 +29,11 @@ public class IdentityService {

private final Settings settings;
private final IdentityPlugin identityPlugin;
private final ThreadPool threadPool;

public IdentityService(final Settings settings, final List<IdentityPlugin> identityPlugins) {
public IdentityService(final Settings settings, final ThreadPool threadPool, final List<IdentityPlugin> identityPlugins) {
this.settings = settings;
this.threadPool = threadPool;

if (identityPlugins.size() == 0) {
log.debug("Identity plugins size is 0");
Expand Down Expand Up @@ -57,4 +62,11 @@ public Subject getSubject() {
public TokenManager getTokenManager() {
return identityPlugin.getTokenManager();
}

/**
* Gets the RestHandlerWrapper to authenticate the request
*/
public UnaryOperator<RestHandler> authenticate() {
return identityPlugin.authenticate(this.threadPool.getThreadContext());
}
}
Loading

0 comments on commit e6b82ba

Please sign in to comment.