Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introducing ClientRequestFilter.java: A New Plugin for Applying Request Headers in the Authentication Filter Class #23380

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions presto-main/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,11 @@
<artifactId>ratis-common</artifactId>
<optional>true</optional>
</dependency>
<dependency>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do not use Mockito in this project. Can you please refactor your code to avoid its use?

Copy link
Author

@SthuthiGhosh9400 SthuthiGhosh9400 Sep 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tdcmeehan
Instead of using Mockito to create mock objects and define their behavior, can I manually write simple classes (stubs) that simulate the behavior of the actual classes?

Do you have any suggestion here in order to verify the codebase if value correctly passed to request header?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that is how we go about this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tdcmeehan
I have refactored the code that was added for the test case.

<groupId>com.squareup.okhttp3</groupId>
<artifactId>mockwebserver</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto;

import com.facebook.presto.spi.ClientRequestFilter;
import com.google.common.collect.ImmutableList;

import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;

public class ClientRequestFilterManager
{
private final List<ClientRequestFilter> clientRequestFilters = new CopyOnWriteArrayList<>();

public List<ClientRequestFilter> getClientRequestFilters()
{
return ImmutableList.copyOf(clientRequestFilters);
}

public void registerClientRequestFilter(ClientRequestFilter clientRequestFilter)
{
clientRequestFilters.add(clientRequestFilter);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto;

import com.google.inject.Binder;
import com.google.inject.Module;
import com.google.inject.Scopes;

public class ClientRequestFilterModule
implements Module
{
@Override
public void configure(Binder binder)
{
binder.bind(ClientRequestFilterManager.class).in(Scopes.SINGLETON);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import com.facebook.airlift.log.Logger;
import com.facebook.airlift.node.NodeInfo;
import com.facebook.presto.ClientRequestFilterManager;
import com.facebook.presto.common.block.BlockEncoding;
import com.facebook.presto.common.block.BlockEncodingManager;
import com.facebook.presto.common.type.ParametricType;
Expand All @@ -27,6 +28,7 @@
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.security.AccessControlManager;
import com.facebook.presto.server.security.PasswordAuthenticatorManager;
import com.facebook.presto.spi.ClientRequestFilter;
import com.facebook.presto.spi.CoordinatorPlugin;
import com.facebook.presto.spi.Plugin;
import com.facebook.presto.spi.analyzer.AnalyzerProvider;
Expand Down Expand Up @@ -131,6 +133,7 @@ public class PluginManager
private final AnalyzerProviderManager analyzerProviderManager;
private final QueryPreparerProviderManager queryPreparerProviderManager;
private final NodeStatusNotificationManager nodeStatusNotificationManager;
private final ClientRequestFilterManager clientRequestFilterManager;

@Inject
public PluginManager(
Expand All @@ -152,7 +155,8 @@ public PluginManager(
ClusterTtlProviderManager clusterTtlProviderManager,
HistoryBasedPlanStatisticsManager historyBasedPlanStatisticsManager,
TracerProviderManager tracerProviderManager,
NodeStatusNotificationManager nodeStatusNotificationManager)
NodeStatusNotificationManager nodeStatusNotificationManager,
ClientRequestFilterManager clientRequestFilterManager)
{
requireNonNull(nodeInfo, "nodeInfo is null");
requireNonNull(config, "config is null");
Expand Down Expand Up @@ -184,6 +188,7 @@ public PluginManager(
this.analyzerProviderManager = requireNonNull(analyzerProviderManager, "analyzerProviderManager is null");
this.queryPreparerProviderManager = requireNonNull(queryPreparerProviderManager, "queryPreparerProviderManager is null");
this.nodeStatusNotificationManager = requireNonNull(nodeStatusNotificationManager, "nodeStatusNotificationManager is null");
this.clientRequestFilterManager = requireNonNull(clientRequestFilterManager, "clientRequestFilterManager is null");
}

public void loadPlugins()
Expand Down Expand Up @@ -348,6 +353,11 @@ public void installPlugin(Plugin plugin)
log.info("Registering node status notification provider %s", nodeStatusNotificationProviderFactory.getName());
nodeStatusNotificationManager.addNodeStatusNotificationProviderFactory(nodeStatusNotificationProviderFactory);
}

for (ClientRequestFilter clientRequestFilter : plugin.getClientRequestFilters()) {
log.info("Registering client request filter");
clientRequestFilterManager.registerClientRequestFilter(clientRequestFilter);
}
}

public void installCoordinatorPlugin(CoordinatorPlugin plugin)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import com.facebook.airlift.tracetoken.TraceTokenModule;
import com.facebook.drift.server.DriftServer;
import com.facebook.drift.transport.netty.server.DriftNettyServerTransport;
import com.facebook.presto.ClientRequestFilterModule;
import com.facebook.presto.dispatcher.QueryPrerequisitesManager;
import com.facebook.presto.dispatcher.QueryPrerequisitesManagerModule;
import com.facebook.presto.eventlistener.EventListenerManager;
Expand Down Expand Up @@ -133,7 +134,8 @@ public void run()
new TempStorageModule(),
new QueryPrerequisitesManagerModule(),
new NodeTtlFetcherManagerModule(),
new ClusterTtlProviderManagerModule());
new ClusterTtlProviderManagerModule(),
new ClientRequestFilterModule());

modules.addAll(getAdditionalModules());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@

import com.facebook.airlift.http.server.AuthenticationException;
import com.facebook.airlift.http.server.Authenticator;
import com.facebook.presto.ClientRequestFilterManager;
import com.facebook.presto.spi.ClientRequestFilter;
import com.facebook.presto.spi.PrestoException;
import com.google.common.base.Joiner;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.net.HttpHeaders;

import javax.inject.Inject;
Expand All @@ -35,10 +39,18 @@
import java.io.InputStream;
import java.io.PrintWriter;
import java.security.Principal;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import static com.facebook.presto.spi.StandardErrorCode.HEADER_MODIFICATION_ATTEMPT;
import static com.google.common.io.ByteStreams.copy;
import static com.google.common.io.ByteStreams.nullOutputStream;
import static com.google.common.net.HttpHeaders.WWW_AUTHENTICATE;
Expand All @@ -52,12 +64,15 @@ public class AuthenticationFilter
private static final String HTTPS_PROTOCOL = "https";
private final List<Authenticator> authenticators;
private final boolean allowForwardedHttps;
private final ClientRequestFilterManager clientRequestFilterManager;
private final List<String> headersBlockList = ImmutableList.of("X-Presto-Transaction-Id", "X-Presto-Started-Transaction-Id", "X-Presto-Clear-Transaction-Id", "X-Presto-Trace-Token");

@Inject
public AuthenticationFilter(List<Authenticator> authenticators, SecurityConfig securityConfig)
public AuthenticationFilter(List<Authenticator> authenticators, SecurityConfig securityConfig, ClientRequestFilterManager clientRequestFilterManager)
{
this.authenticators = ImmutableList.copyOf(requireNonNull(authenticators, "authenticators is null"));
this.allowForwardedHttps = requireNonNull(securityConfig, "securityConfig is null").getAllowForwardedHttps();
this.clientRequestFilterManager = requireNonNull(clientRequestFilterManager, "clientRequestFilterManager is null");
}

@Override
Expand Down Expand Up @@ -95,9 +110,37 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo
e.getAuthenticateHeader().ifPresent(authenticateHeaders::add);
continue;
}

// authentication succeeded
nextFilter.doFilter(withPrincipal(request, principal), response);
CustomHttpServletRequestWrapper wrappedRequest = new CustomHttpServletRequestWrapper(request);
Map<String, String> extraHeadersMap = new HashMap<>();
Set<String> globallyAddedHeaders = new HashSet<>();

for (ClientRequestFilter requestFilter : clientRequestFilterManager.getClientRequestFilters()) {
boolean headersPresent = requestFilter.getHeaderNames().stream()
.allMatch(headerName -> request.getHeader(headerName) != null);

if (!headersPresent) {
Optional<Map<String, String>> extraHeaderValueMap = requestFilter.getExtraHeaders(principal);

extraHeaderValueMap.ifPresent(map -> {
for (Map.Entry<String, String> extraHeaderEntry : map.entrySet()) {
String headerKey = extraHeaderEntry.getKey();
if (headersBlockList.contains(headerKey)) {
throw new PrestoException(HEADER_MODIFICATION_ATTEMPT, "Modification attempt detected: The header " + headerKey + " is present in the blocked headers list.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of mentioning a "blocked headers list", just list out the headers that are not allowed to be modified.

}
if (globallyAddedHeaders.contains(headerKey)) {
throw new RuntimeException("Header conflict detected: " + headerKey + " already added by another filter.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not throw RuntimeException, throw PrestoException.

}
if (request.getHeader(headerKey) == null && requestFilter.getHeaderNames().contains(headerKey)) {
extraHeadersMap.putIfAbsent(headerKey, extraHeaderEntry.getValue());
globallyAddedHeaders.add(headerKey);
}
}
});
}
}
wrappedRequest.setHeaders(extraHeadersMap);
nextFilter.doFilter(withPrincipal(wrappedRequest, principal), response);
return;
}

Expand Down Expand Up @@ -140,7 +183,7 @@ private boolean doesRequestSupportAuthentication(HttpServletRequest request)
return false;
}

private static ServletRequest withPrincipal(HttpServletRequest request, Principal principal)
public static ServletRequest withPrincipal(HttpServletRequest request, Principal principal)
{
requireNonNull(principal, "principal is null");
return new HttpServletRequestWrapper(request)
Expand All @@ -166,4 +209,51 @@ private static void skipRequestBody(HttpServletRequest request)
copy(inputStream, nullOutputStream());
}
}

public static class CustomHttpServletRequestWrapper
extends HttpServletRequestWrapper
{
private final Map<String, String> customHeaders = new ConcurrentHashMap<>();

public CustomHttpServletRequestWrapper(HttpServletRequest request)
{
super(request);
}

@Override
public String getHeader(String name)
{
String headerValue = customHeaders.get(name);
if (headerValue != null) {
return headerValue;
}
return super.getHeader(name);
}

@Override
public Enumeration<String> getHeaderNames()
{
ImmutableSet.Builder<String> headerNamesBuilder = ImmutableSet.builder();
headerNamesBuilder.addAll(customHeaders.keySet());
Enumeration<String> originalHeaderNames = super.getHeaderNames();
while (originalHeaderNames.hasMoreElements()) {
headerNamesBuilder.add(originalHeaderNames.nextElement());
}
return Collections.enumeration(headerNamesBuilder.build());
}

@Override
public Enumeration<String> getHeaders(String name)
{
if (customHeaders.containsKey(name)) {
return Collections.enumeration(Collections.singleton(customHeaders.get(name)));
}
return super.getHeaders(name);
}

public void setHeaders(Map<String, String> headers)
{
this.customHeaders.putAll(headers);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import com.facebook.airlift.tracetoken.TraceTokenModule;
import com.facebook.drift.server.DriftServer;
import com.facebook.drift.transport.netty.server.DriftNettyServerTransport;
import com.facebook.presto.ClientRequestFilterManager;
import com.facebook.presto.ClientRequestFilterModule;
import com.facebook.presto.connector.ConnectorManager;
import com.facebook.presto.cost.StatsCalculator;
import com.facebook.presto.dispatcher.DispatchManager;
Expand Down Expand Up @@ -59,6 +61,7 @@
import com.facebook.presto.server.ServerMainModule;
import com.facebook.presto.server.ShutdownAction;
import com.facebook.presto.server.security.ServerSecurityModule;
import com.facebook.presto.spi.ClientRequestFilter;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.CoordinatorPlugin;
import com.facebook.presto.spi.Plugin;
Expand Down Expand Up @@ -306,6 +309,7 @@ public TestingPrestoServer(
.add(new QueryPrerequisitesManagerModule())
.add(new NodeTtlFetcherManagerModule())
.add(new ClusterTtlProviderManagerModule())
.add(new ClientRequestFilterModule())
.add(binder -> {
binder.bind(TestingAccessControlManager.class).in(Scopes.SINGLETON);
binder.bind(TestingEventListenerManager.class).in(Scopes.SINGLETON);
Expand Down Expand Up @@ -839,4 +843,11 @@ private static int driftServerPort(DriftServer server)
{
return ((DriftNettyServerTransport) server.getServerTransport()).getPort();
}

public static ClientRequestFilterManager getClientRequestFilterManager(List<ClientRequestFilter> requestFilters)
{
ClientRequestFilterManager manager = new ClientRequestFilterManager();
requestFilters.forEach(manager::registerClientRequestFilter);
return manager;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
package com.facebook.presto.testing;

import com.facebook.airlift.node.NodeInfo;
import com.facebook.presto.ClientRequestFilterManager;
import com.facebook.presto.GroupByHashPageIndexerFactory;
import com.facebook.presto.PagesIndexPageSorter;
import com.facebook.presto.Session;
Expand Down Expand Up @@ -515,7 +516,8 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig,
new ThrowingClusterTtlProviderManager(),
historyBasedPlanStatisticsManager,
new TracerProviderManager(new TracingConfig()),
new NodeStatusNotificationManager());
new NodeStatusNotificationManager(),
new ClientRequestFilterManager());

connectorManager.addConnectorFactory(globalSystemConnectorFactory);
connectorManager.createConnection(GlobalSystemConnector.NAME, GlobalSystemConnector.NAME, ImmutableMap.of());
Expand Down
Loading
Loading