From 55c4278115e72cd6f6b905931ac55266a060b838 Mon Sep 17 00:00:00 2001 From: sthuthighoshm Date: Thu, 27 Jun 2024 01:48:16 +0530 Subject: [PATCH 01/13] Added new interface for the plugin - Incorporated the changes for using the plugin in authentication filter. --- .../presto/RequestModifierManager.java | 38 +++++++ .../presto/RequestModifierModule.java | 28 +++++ .../facebook/presto/server/PluginManager.java | 12 +- .../facebook/presto/server/PrestoServer.java | 4 +- .../server/security/AuthenticationFilter.java | 107 ++++++++++++++++-- .../server/testing/TestingPrestoServer.java | 2 + .../presto/testing/LocalQueryRunner.java | 4 +- .../java/com/facebook/presto/spi/Plugin.java | 5 + .../facebook/presto/spi/RequestModifier.java | 25 ++++ 9 files changed, 210 insertions(+), 15 deletions(-) create mode 100644 presto-main/src/main/java/com/facebook/presto/RequestModifierManager.java create mode 100644 presto-main/src/main/java/com/facebook/presto/RequestModifierModule.java create mode 100644 presto-spi/src/main/java/com/facebook/presto/spi/RequestModifier.java diff --git a/presto-main/src/main/java/com/facebook/presto/RequestModifierManager.java b/presto-main/src/main/java/com/facebook/presto/RequestModifierManager.java new file mode 100644 index 000000000000..e598caaa68b8 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/RequestModifierManager.java @@ -0,0 +1,38 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto; + +import com.facebook.presto.spi.RequestModifier; + +import java.util.ArrayList; +import java.util.List; + +public class RequestModifierManager +{ + private final List requestModifiers; + public RequestModifierManager() + { + this.requestModifiers = new ArrayList<>(); + } + + public List getRequestModifiers() + { + return new ArrayList<>(requestModifiers); + } + + public void registerRequestModifier(RequestModifier requestModifier) + { + requestModifiers.add(requestModifier); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/RequestModifierModule.java b/presto-main/src/main/java/com/facebook/presto/RequestModifierModule.java new file mode 100644 index 000000000000..bf95affbc374 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/RequestModifierModule.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto; + +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +public class RequestModifierModule + implements Module +{ + @Override + public void configure(Binder binder) + { + binder.bind(RequestModifierManager.class).in(Scopes.SINGLETON); + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java b/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java index 0f72c608bb3b..176380747107 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java +++ b/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java @@ -15,6 +15,7 @@ import com.facebook.airlift.log.Logger; import com.facebook.airlift.node.NodeInfo; +import com.facebook.presto.RequestModifierManager; import com.facebook.presto.common.block.BlockEncoding; import com.facebook.presto.common.block.BlockEncodingManager; import com.facebook.presto.common.type.ParametricType; @@ -28,6 +29,7 @@ import com.facebook.presto.security.AccessControlManager; import com.facebook.presto.server.security.PasswordAuthenticatorManager; import com.facebook.presto.spi.Plugin; +import com.facebook.presto.spi.RequestModifier; import com.facebook.presto.spi.analyzer.AnalyzerProvider; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; import com.facebook.presto.spi.connector.ConnectorFactory; @@ -125,6 +127,7 @@ public class PluginManager private final TracerProviderManager tracerProviderManager; private final AnalyzerProviderManager analyzerProviderManager; private final NodeStatusNotificationManager nodeStatusNotificationManager; + private final RequestModifierManager requestModifierManager; @Inject public PluginManager( @@ -145,7 +148,8 @@ public PluginManager( ClusterTtlProviderManager clusterTtlProviderManager, HistoryBasedPlanStatisticsManager historyBasedPlanStatisticsManager, TracerProviderManager tracerProviderManager, - NodeStatusNotificationManager nodeStatusNotificationManager) + NodeStatusNotificationManager nodeStatusNotificationManager, + RequestModifierManager requestModifierManager) { requireNonNull(nodeInfo, "nodeInfo is null"); requireNonNull(config, "config is null"); @@ -176,6 +180,7 @@ public PluginManager( this.tracerProviderManager = requireNonNull(tracerProviderManager, "tracerProviderManager is null"); this.analyzerProviderManager = requireNonNull(analyzerProviderManager, "analyzerProviderManager is null"); this.nodeStatusNotificationManager = requireNonNull(nodeStatusNotificationManager, "nodeStatusNotificationManager is null"); + this.requestModifierManager = requireNonNull(requestModifierManager, "requestModifierManager is null"); } public void loadPlugins() @@ -326,6 +331,11 @@ public void installPlugin(Plugin plugin) log.info("Registering node status notification provider %s", nodeStatusNotificationProviderFactory.getName()); nodeStatusNotificationManager.addNodeStatusNotificationProviderFactory(nodeStatusNotificationProviderFactory); } + + for (RequestModifier requestModifier : plugin.getRequestModifiers()) { + log.info("Registering request modifier"); + requestModifierManager.registerRequestModifier(requestModifier); + } } private URLClassLoader buildClassLoader(String plugin) diff --git a/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java b/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java index 8b3aaa3009cc..f4e37268edc6 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java @@ -30,6 +30,7 @@ import com.facebook.airlift.tracetoken.TraceTokenModule; import com.facebook.drift.server.DriftServer; import com.facebook.drift.transport.netty.server.DriftNettyServerTransport; +import com.facebook.presto.RequestModifierModule; import com.facebook.presto.dispatcher.QueryPrerequisitesManager; import com.facebook.presto.dispatcher.QueryPrerequisitesManagerModule; import com.facebook.presto.eventlistener.EventListenerManager; @@ -133,7 +134,8 @@ public void run() new TempStorageModule(), new QueryPrerequisitesManagerModule(), new NodeTtlFetcherManagerModule(), - new ClusterTtlProviderManagerModule()); + new ClusterTtlProviderManagerModule(), + new RequestModifierModule()); modules.addAll(getAdditionalModules()); diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java b/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java index 6e43f6739164..c5d9372403c8 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java @@ -15,6 +15,8 @@ import com.facebook.airlift.http.server.AuthenticationException; import com.facebook.airlift.http.server.Authenticator; +import com.facebook.presto.RequestModifierManager; +import com.facebook.presto.spi.RequestModifier; import com.google.common.base.Joiner; import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; @@ -34,8 +36,14 @@ import java.io.IOException; import java.io.InputStream; import java.security.Principal; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.Set; import static com.google.common.io.ByteStreams.copy; @@ -50,12 +58,14 @@ public class AuthenticationFilter private static final String HTTPS_PROTOCOL = "https"; private final List authenticators; private final boolean allowForwardedHttps; + private final RequestModifierManager requestModifierManager; @Inject - public AuthenticationFilter(List authenticators, SecurityConfig securityConfig) + public AuthenticationFilter(List authenticators, SecurityConfig securityConfig, RequestModifierManager requestModifierManager) { this.authenticators = ImmutableList.copyOf(requireNonNull(authenticators, "authenticators is null")); this.allowForwardedHttps = requireNonNull(securityConfig, "securityConfig is null").getAllowForwardedHttps(); + this.requestModifierManager = requireNonNull(requestModifierManager, "requestModifierManager is null"); } @Override @@ -93,9 +103,28 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo e.getAuthenticateHeader().ifPresent(authenticateHeaders::add); continue; } - // authentication succeeded - nextFilter.doFilter(withPrincipal(request, principal), response); + CustomHttpServletRequestWrapper wrappedRequest = withPrincipal(request, principal); + Map extraHeadersMap = new HashMap<>(); + + for (RequestModifier modifier : requestModifierManager.getRequestModifiers()) { + boolean headersPresent = modifier.getHeaderNames().stream() + .allMatch(headerName -> request.getHeaders(headerName) != null); + + if (!headersPresent) { + Optional> extraHeaderValueMap = modifier.getExtraHeaders(principal); + + extraHeaderValueMap.ifPresent(map -> { + for (Map.Entry extraHeaderEntry : map.entrySet()) { + if (request.getHeaders(extraHeaderEntry.getKey()) == null) { + extraHeadersMap.putIfAbsent(extraHeaderEntry.getKey(), extraHeaderEntry.getValue()); + } + } + }); + } + } + wrappedRequest.setHeaders(extraHeadersMap); + nextFilter.doFilter(wrappedRequest, response); return; } @@ -126,17 +155,10 @@ private boolean doesRequestSupportAuthentication(HttpServletRequest request) return false; } - private static ServletRequest withPrincipal(HttpServletRequest request, Principal principal) + private static CustomHttpServletRequestWrapper withPrincipal(HttpServletRequest request, Principal principal) { requireNonNull(principal, "principal is null"); - return new HttpServletRequestWrapper(request) - { - @Override - public Principal getUserPrincipal() - { - return principal; - } - }; + return new CustomHttpServletRequestWrapper(request, principal); } private static void skipRequestBody(HttpServletRequest request) @@ -152,4 +174,65 @@ private static void skipRequestBody(HttpServletRequest request) copy(inputStream, nullOutputStream()); } } + + public static class CustomHttpServletRequestWrapper + extends HttpServletRequestWrapper + { + private final Map customHeaders; + + private final Principal principal; + + public CustomHttpServletRequestWrapper(HttpServletRequest request, Principal principal) + { + super(request); + this.principal = principal; + this.customHeaders = new HashMap<>(); + } + + public void addHeader(String name, String value) + { + customHeaders.put(name, value); + } + + @Override + public String getHeader(String name) + { + String headerValue = customHeaders.get(name); + if (headerValue != null) { + return headerValue; + } + return super.getHeader(name); + } + + @Override + public Enumeration getHeaderNames() + { + Set headerNames = new HashSet<>(customHeaders.keySet()); + Enumeration originalHeaderNames = super.getHeaderNames(); + while (originalHeaderNames.hasMoreElements()) { + headerNames.add(originalHeaderNames.nextElement()); + } + return Collections.enumeration(headerNames); + } + + @Override + public Enumeration getHeaders(String name) + { + if (customHeaders.containsKey(name)) { + return Collections.enumeration(Collections.singleton(customHeaders.get(name))); + } + return super.getHeaders(name); + } + + @Override + public Principal getUserPrincipal() + { + return principal; + } + + public void setHeaders(Map headers) + { + this.customHeaders.putAll(headers); + } + } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java index 0b10eceff089..c7b38071d888 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java @@ -32,6 +32,7 @@ import com.facebook.airlift.tracetoken.TraceTokenModule; import com.facebook.drift.server.DriftServer; import com.facebook.drift.transport.netty.server.DriftNettyServerTransport; +import com.facebook.presto.RequestModifierModule; import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.cost.StatsCalculator; import com.facebook.presto.dispatcher.DispatchManager; @@ -303,6 +304,7 @@ public TestingPrestoServer( .add(new QueryPrerequisitesManagerModule()) .add(new NodeTtlFetcherManagerModule()) .add(new ClusterTtlProviderManagerModule()) + .add(new RequestModifierModule()) .add(binder -> { binder.bind(TestingAccessControlManager.class).in(Scopes.SINGLETON); binder.bind(TestingEventListenerManager.class).in(Scopes.SINGLETON); diff --git a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index 4c0d205a4572..dd65e3f4d6a7 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -16,6 +16,7 @@ import com.facebook.airlift.node.NodeInfo; import com.facebook.presto.GroupByHashPageIndexerFactory; import com.facebook.presto.PagesIndexPageSorter; +import com.facebook.presto.RequestModifierManager; import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.client.NodeVersion; @@ -508,7 +509,8 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, new ThrowingClusterTtlProviderManager(), historyBasedPlanStatisticsManager, new TracerProviderManager(new TracingConfig()), - new NodeStatusNotificationManager()); + new NodeStatusNotificationManager(), + new RequestModifierManager()); connectorManager.addConnectorFactory(globalSystemConnectorFactory); connectorManager.createConnection(GlobalSystemConnector.NAME, GlobalSystemConnector.NAME, ImmutableMap.of()); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java b/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java index 64d58edba0db..bfef0ca25e89 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java @@ -136,4 +136,9 @@ default Iterable getNodeStatusNotificatio { return emptyList(); } + + default Iterable getRequestModifiers() + { + return emptyList(); + } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/RequestModifier.java b/presto-spi/src/main/java/com/facebook/presto/spi/RequestModifier.java new file mode 100644 index 000000000000..aaf80e38470e --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/RequestModifier.java @@ -0,0 +1,25 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi; + +import java.util.List; +import java.util.Map; +import java.util.Optional; + +public interface RequestModifier +{ + List getHeaderNames(); + + Optional> getExtraHeaders(T additionalInfo); +} From fb5d8126ca3c3a55ed0c61ce2b87ca286f051a1f Mon Sep 17 00:00:00 2001 From: SthuthiGhosh9400 Date: Sun, 25 Aug 2024 19:28:02 +0530 Subject: [PATCH 02/13] - Add a test case to check the modification of the request header. --- presto-main/pom.xml | 5 + .../server/security/AuthenticationFilter.java | 2 +- .../presto/TestRequestModifierPlugin.java | 122 ++++++++++++++++++ 3 files changed, 128 insertions(+), 1 deletion(-) create mode 100644 presto-main/src/test/java/com/facebook/presto/TestRequestModifierPlugin.java diff --git a/presto-main/pom.xml b/presto-main/pom.xml index 824132c1001b..0d0404812096 100644 --- a/presto-main/pom.xml +++ b/presto-main/pom.xml @@ -499,6 +499,11 @@ ratis-common true + + org.mockito + mockito-core + test + diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java b/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java index c5d9372403c8..8a02aa624a17 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java @@ -155,7 +155,7 @@ private boolean doesRequestSupportAuthentication(HttpServletRequest request) return false; } - private static CustomHttpServletRequestWrapper withPrincipal(HttpServletRequest request, Principal principal) + public CustomHttpServletRequestWrapper withPrincipal(HttpServletRequest request, Principal principal) { requireNonNull(principal, "principal is null"); return new CustomHttpServletRequestWrapper(request, principal); diff --git a/presto-main/src/test/java/com/facebook/presto/TestRequestModifierPlugin.java b/presto-main/src/test/java/com/facebook/presto/TestRequestModifierPlugin.java new file mode 100644 index 000000000000..3ea65b672152 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/TestRequestModifierPlugin.java @@ -0,0 +1,122 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto; + +import com.facebook.airlift.http.server.AuthenticationException; +import com.facebook.airlift.http.server.Authenticator; +import com.facebook.presto.server.security.AuthenticationFilter; +import com.facebook.presto.server.security.SecurityConfig; +import com.facebook.presto.spi.RequestModifier; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import java.io.IOException; +import java.security.Principal; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.testng.Assert.assertEquals; + +public class TestRequestModifierPlugin +{ + @Mock + private HttpServletRequest request; + + @Mock + private HttpServletResponse response; + + @Mock + private FilterChain filterChain; + + @Mock + private Authenticator authenticator; + + @Mock + private RequestModifierManager requestModifierManager; + + @Mock + private RequestModifier requestModifier; + + @Mock + private SecurityConfig securityConfig; + + private AuthenticationFilter filter; + private List authenticators; + + @BeforeClass + public void setup() + { + MockitoAnnotations.initMocks(this); + authenticators = new ArrayList<>(); + authenticators.add(authenticator); + filter = spy(new AuthenticationFilter(authenticators, securityConfig, requestModifierManager)); + } + + @Test + public void testDoFilter_SuccessfulAuthenticationWithHeaderModification() throws ServletException, IOException, AuthenticationException + { + Principal testPrincipal = mock(Principal.class); + when(authenticator.authenticate(request)).thenReturn(testPrincipal); + + // Set up RequestModifierManager to return a RequestModifier + when(requestModifierManager.getRequestModifiers()).thenReturn(Collections.singletonList(requestModifier)); + + // Mock behavior of RequestModifier and HttpServletRequest + when(requestModifier.getHeaderNames()).thenReturn(Collections.singletonList("Authorization")); + when(request.getHeaders("Authorization")).thenReturn(null); + when(request.getPathInfo()).thenReturn("/oauth2/token-value/"); + when(request.isSecure()).thenReturn(true); + + // Set up the extra header to be returned by the RequestModifier + Map extraHeaders = new HashMap<>(); + extraHeaders.put("X-Custom-Header", "CustomValue"); + when(requestModifier.getExtraHeaders(testPrincipal)).thenReturn(Optional.of(extraHeaders)); + + AuthenticationFilter.CustomHttpServletRequestWrapper wrappedRequest = spy(new AuthenticationFilter.CustomHttpServletRequestWrapper(request, testPrincipal)); + doNothing().when(wrappedRequest).setHeaders(any(Map.class)); + + doReturn(wrappedRequest).when(filter).withPrincipal(request, testPrincipal); + + filter.doFilter(request, response, filterChain); + + ArgumentCaptor headersCaptor = ArgumentCaptor.forClass(Map.class); + verify(wrappedRequest).setHeaders(headersCaptor.capture()); + Map capturedHeaders = headersCaptor.getValue(); + assertEquals("CustomValue", capturedHeaders.get("X-Custom-Header")); + + verify(filterChain).doFilter(eq(wrappedRequest), eq(response)); + verify(authenticator).authenticate(request); + verify(requestModifier).getExtraHeaders(testPrincipal); + } +} From 0b6e62c139978c639e49c8ec3e557f6dbd646ad6 Mon Sep 17 00:00:00 2001 From: SthuthiGhosh9400 Date: Fri, 6 Sep 2024 00:54:18 +0530 Subject: [PATCH 03/13] - Modified the test case to avoid the use of Mockito --- .../server/security/AuthenticationFilter.java | 4 +- .../presto/TestRequestHeaderModifier.java | 1211 +++++++++++++++++ 2 files changed, 1213 insertions(+), 2 deletions(-) create mode 100644 presto-main/src/test/java/com/facebook/presto/TestRequestHeaderModifier.java diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java b/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java index 8a02aa624a17..17ce499f2fb3 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java @@ -109,14 +109,14 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo for (RequestModifier modifier : requestModifierManager.getRequestModifiers()) { boolean headersPresent = modifier.getHeaderNames().stream() - .allMatch(headerName -> request.getHeaders(headerName) != null); + .allMatch(headerName -> request.getHeader(headerName) != null); if (!headersPresent) { Optional> extraHeaderValueMap = modifier.getExtraHeaders(principal); extraHeaderValueMap.ifPresent(map -> { for (Map.Entry extraHeaderEntry : map.entrySet()) { - if (request.getHeaders(extraHeaderEntry.getKey()) == null) { + if (request.getHeader(extraHeaderEntry.getKey()) == null) { extraHeadersMap.putIfAbsent(extraHeaderEntry.getKey(), extraHeaderEntry.getValue()); } } diff --git a/presto-main/src/test/java/com/facebook/presto/TestRequestHeaderModifier.java b/presto-main/src/test/java/com/facebook/presto/TestRequestHeaderModifier.java new file mode 100644 index 000000000000..16f1fcf18c65 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/TestRequestHeaderModifier.java @@ -0,0 +1,1211 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto; + +import com.facebook.airlift.http.server.Authenticator; +import com.facebook.presto.server.security.AuthenticationFilter; +import com.facebook.presto.server.security.SecurityConfig; +import com.facebook.presto.spi.RequestModifier; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import javax.servlet.AsyncContext; +import javax.servlet.DispatcherType; +import javax.servlet.Filter; +import javax.servlet.FilterChain; +import javax.servlet.FilterRegistration; +import javax.servlet.RequestDispatcher; +import javax.servlet.Servlet; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.ServletInputStream; +import javax.servlet.ServletOutputStream; +import javax.servlet.ServletRegistration; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.SessionCookieConfig; +import javax.servlet.SessionTrackingMode; +import javax.servlet.descriptor.JspConfigDescriptor; +import javax.servlet.http.Cookie; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; +import javax.servlet.http.HttpUpgradeHandler; +import javax.servlet.http.Part; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.PrintWriter; +import java.net.URL; +import java.security.Principal; +import java.util.Collection; +import java.util.Collections; +import java.util.Enumeration; +import java.util.EventListener; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static org.testng.Assert.assertEquals; + +public class TestRequestHeaderModifier +{ + private MockWebServer mockWebServer; + private HttpServletResponse response; + private FilterChainStub filterChain; + private AuthenticationFilter filter; + private AuthenticatorStub authenticator; + private RequestModifierManagerStub requestModifierManager; + private RequestModifierStub requestModifier; + + @BeforeMethod + public void setUp() throws IOException + { + mockWebServer = new MockWebServer(); + mockWebServer.start(); + + response = new ConcreteHttpServletResponse(); + filterChain = new FilterChainStub(); + + authenticator = new AuthenticatorStub(); + requestModifierManager = new RequestModifierManagerStub(); + requestModifier = new RequestModifierStub(); + + List authenticators = Collections.singletonList(authenticator); + filter = new AuthenticationFilter(authenticators, new SecurityConfigStub(), requestModifierManager); + } + + @AfterMethod + public void tearDown() throws IOException + { + mockWebServer.shutdown(); + } + + @Test + public void testDoFilter_SuccessfulAuthenticationWithHeaderModification() throws ServletException, IOException + { + mockWebServer.enqueue(new MockResponse().setBody("Mocked Body").setResponseCode(200)); + + ConcreteHttpServletRequest request = new ConcreteHttpServletRequest(); + request.setPathInfo("/oauth2/token-value/"); + request.setSecure(true); + + PrincipalStub testPrincipal = new PrincipalStub(); + authenticator.setPrincipal(testPrincipal); + + requestModifierManager.setModifiers(Collections.singletonList(requestModifier)); + requestModifier.setHeaderNames(Collections.singletonList("Extra-credential")); + requestModifier.setExtraHeaders(Collections.singletonMap("X-Custom-Header", "CustomValue")); + + filter.doFilter(request, response, filterChain); + + HttpServletRequest wrappedRequest = (HttpServletRequest) filterChain.getCapturedRequest(); + assertEquals("CustomValue", wrappedRequest.getHeader("X-Custom-Header")); + } + + abstract static class HttpServletRequestAdapter + implements HttpServletRequest + { + @Override + public String getAuthType() + { + throw new UnsupportedOperationException(); + } + @Override + public Cookie[] getCookies() + { + throw new UnsupportedOperationException(); + } + @Override + public long getDateHeader(String name) + { + throw new UnsupportedOperationException(); + } + @Override + public String getHeader(String name) + { + throw new UnsupportedOperationException(); + } + @Override + public Enumeration getHeaders(String name) + { + throw new UnsupportedOperationException(); + } + @Override + public Enumeration getHeaderNames() + { + throw new UnsupportedOperationException(); + } + @Override + public int getIntHeader(String name) + { + throw new UnsupportedOperationException(); + } + @Override + public String getMethod() + { + throw new UnsupportedOperationException(); + } + @Override + public String getPathInfo() + { + throw new UnsupportedOperationException(); + } + @Override + public String getPathTranslated() + { + throw new UnsupportedOperationException(); + } + @Override + public String getContextPath() + { + throw new UnsupportedOperationException(); + } + @Override + public String getQueryString() + { + throw new UnsupportedOperationException(); + } + @Override + public String getRemoteUser() + { + throw new UnsupportedOperationException(); + } + @Override + public boolean isUserInRole(String role) + { + throw new UnsupportedOperationException(); + } + @Override + public Principal getUserPrincipal() + { + throw new UnsupportedOperationException(); + } + @Override + public String getRequestedSessionId() + { + throw new UnsupportedOperationException(); + } + @Override + public String getRequestURI() + { + throw new UnsupportedOperationException(); + } + @Override + public StringBuffer getRequestURL() + { + throw new UnsupportedOperationException(); + } + @Override + public String getServletPath() + { + throw new UnsupportedOperationException(); + } + @Override + public HttpSession getSession(boolean create) + { + throw new UnsupportedOperationException(); + } + @Override + public HttpSession getSession() + { + throw new UnsupportedOperationException(); + } + @Override + public String changeSessionId() + { + throw new UnsupportedOperationException(); + } + @Override + public boolean isRequestedSessionIdValid() + { + throw new UnsupportedOperationException(); + } + @Override + public boolean isRequestedSessionIdFromCookie() + { + throw new UnsupportedOperationException(); + } + @Override + public boolean isRequestedSessionIdFromURL() + { + throw new UnsupportedOperationException(); + } + @Override + public boolean isRequestedSessionIdFromUrl() + { + throw new UnsupportedOperationException(); + } + @Override + public Object getAttribute(String name) + { + throw new UnsupportedOperationException(); + } + @Override + public Enumeration getAttributeNames() + { + throw new UnsupportedOperationException(); + } + @Override + public String getCharacterEncoding() + { + throw new UnsupportedOperationException(); + } + @Override + public void setCharacterEncoding(String env) + { + throw new UnsupportedOperationException(); + } + @Override + public int getContentLength() + { + throw new UnsupportedOperationException(); + } + @Override + public long getContentLengthLong() + { + throw new UnsupportedOperationException(); + } + @Override + public String getContentType() + { + throw new UnsupportedOperationException(); + } + @Override + public ServletInputStream getInputStream() + { + throw new UnsupportedOperationException(); + } + @Override + public String getParameter(String name) + { + throw new UnsupportedOperationException(); + } + @Override + public Enumeration getParameterNames() + { + throw new UnsupportedOperationException(); + } + @Override + public String[] getParameterValues(String name) + { + throw new UnsupportedOperationException(); + } + @Override + public Map getParameterMap() + { + throw new UnsupportedOperationException(); + } + @Override + public String getProtocol() + { + throw new UnsupportedOperationException(); + } + @Override + public String getScheme() + { + throw new UnsupportedOperationException(); + } + @Override + public String getServerName() + { + throw new UnsupportedOperationException(); + } + @Override + public int getServerPort() + { + throw new UnsupportedOperationException(); + } + @Override + public BufferedReader getReader() + { + throw new UnsupportedOperationException(); + } + @Override + public String getRemoteAddr() + { + throw new UnsupportedOperationException(); + } + @Override + public String getRemoteHost() + { + throw new UnsupportedOperationException(); + } + @Override + public void setAttribute(String name, Object o) + { + throw new UnsupportedOperationException(); + } + @Override + public void removeAttribute(String name) + { + throw new UnsupportedOperationException(); + } + @Override + public Locale getLocale() + { + throw new UnsupportedOperationException(); + } + @Override + public Enumeration getLocales() + { + throw new UnsupportedOperationException(); + } + @Override + public boolean isSecure() + { + throw new UnsupportedOperationException(); + } + @Override + public RequestDispatcher getRequestDispatcher(String path) + { + throw new UnsupportedOperationException(); + } + @Override + public String getRealPath(String path) + { + throw new UnsupportedOperationException(); + } + @Override + public int getRemotePort() + { + throw new UnsupportedOperationException(); + } + @Override + public String getLocalName() + { + throw new UnsupportedOperationException(); + } + @Override + public String getLocalAddr() + { + throw new UnsupportedOperationException(); + } + @Override + public int getLocalPort() + { + throw new UnsupportedOperationException(); + } + @Override + public ServletContext getServletContext() + { + throw new UnsupportedOperationException(); + } + @Override + public AsyncContext startAsync() throws IllegalStateException + { + throw new UnsupportedOperationException(); + } + @Override + public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException + { + throw new UnsupportedOperationException(); + } + @Override + public boolean isAsyncStarted() + { + throw new UnsupportedOperationException(); + } + @Override + public boolean isAsyncSupported() + { + throw new UnsupportedOperationException(); + } + @Override + public AsyncContext getAsyncContext() + { + throw new UnsupportedOperationException(); + } + @Override + public DispatcherType getDispatcherType() + { + throw new UnsupportedOperationException(); + } + } + + abstract static class HttpServletResponseAdapter + implements HttpServletResponse + { + @Override + public void addCookie(Cookie cookie) + { + throw new UnsupportedOperationException(); + } + @Override + public boolean containsHeader(String name) + { + throw new UnsupportedOperationException(); + } + @Override + public String encodeRedirectURL(String url) + { + throw new UnsupportedOperationException(); + } + @Override + public String encodeRedirectUrl(String url) + { + throw new UnsupportedOperationException(); + } + @Override + public String encodeURL(String url) + { + throw new UnsupportedOperationException(); + } + @Override + public String encodeUrl(String url) + { + throw new UnsupportedOperationException(); + } + @Override + public void sendError(int sc, String msg) + { + throw new UnsupportedOperationException(); + } + @Override + public void sendError(int sc) + { + throw new UnsupportedOperationException(); + } + @Override + public void sendRedirect(String location) + { + throw new UnsupportedOperationException(); + } + @Override + public void setDateHeader(String name, long date) + { + throw new UnsupportedOperationException(); + } + @Override + public void addDateHeader(String name, long date) + { + throw new UnsupportedOperationException(); + } + @Override + public void setHeader(String name, String value) + { + throw new UnsupportedOperationException(); + } + @Override + public void addHeader(String name, String value) + { + throw new UnsupportedOperationException(); + } + @Override + public void setIntHeader(String name, int value) + { + throw new UnsupportedOperationException(); + } + @Override + public void addIntHeader(String name, int value) + { + throw new UnsupportedOperationException(); + } + @Override + public void setContentLength(int len) + { + throw new UnsupportedOperationException(); + } + @Override + public void setContentLengthLong(long len) + { + throw new UnsupportedOperationException(); + } + @Override + public void setContentType(String type) + { + throw new UnsupportedOperationException(); + } + @Override + public void setBufferSize(int size) + { + throw new UnsupportedOperationException(); + } + @Override + public int getBufferSize() + { + throw new UnsupportedOperationException(); + } + @Override + public void flushBuffer() + { + throw new UnsupportedOperationException(); + } + @Override + public void resetBuffer() + { + throw new UnsupportedOperationException(); + } + @Override + public boolean isCommitted() + { + throw new UnsupportedOperationException(); + } + @Override + public void reset() + { + throw new UnsupportedOperationException(); + } + @Override + public void setLocale(Locale loc) + { + throw new UnsupportedOperationException(); + } + @Override + public Locale getLocale() + { + throw new UnsupportedOperationException(); + } + } + + static class ConcreteHttpServletRequest + extends HttpServletRequestAdapter + { + private final Map headers = new HashMap<>(); + private Principal principal; + private final Map attributes = new HashMap<>(); + private final ServletContext servletContext = new ServletContext() + { + @Override + public String getContextPath() + { + return null; + } + + @Override + public ServletContext getContext(String s) + { + return null; + } + + @Override + public int getMajorVersion() + { + return 0; + } + + @Override + public int getMinorVersion() + { + return 0; + } + + @Override + public int getEffectiveMajorVersion() + { + return 0; + } + + @Override + public int getEffectiveMinorVersion() + { + return 0; + } + + @Override + public String getMimeType(String s) + { + return null; + } + + @Override + public Set getResourcePaths(String s) + { + return null; + } + + @Override + public URL getResource(String s) + { + return null; + } + + @Override + public InputStream getResourceAsStream(String s) + { + return null; + } + + @Override + public RequestDispatcher getRequestDispatcher(String s) + { + return null; + } + + @Override + public RequestDispatcher getNamedDispatcher(String s) + { + return null; + } + + @Override + public Servlet getServlet(String s) + { + return null; + } + + @Override + public Enumeration getServlets() + { + return null; + } + + @Override + public Enumeration getServletNames() + { + return null; + } + + @Override + public void log(String s) + { + } + + @Override + public void log(Exception e, String s) + { + } + + @Override + public void log(String s, Throwable throwable) + { + } + + @Override + public String getRealPath(String s) + { + return null; + } + + @Override + public String getServerInfo() + { + return null; + } + + @Override + public String getInitParameter(String s) + { + return null; + } + + @Override + public Enumeration getInitParameterNames() + { + return null; + } + + @Override + public boolean setInitParameter(String s, String s1) + { + return false; + } + + @Override + public Object getAttribute(String s) + { + return null; + } + + @Override + public Enumeration getAttributeNames() + { + return null; + } + + @Override + public void setAttribute(String s, Object o) + { + } + + @Override + public void removeAttribute(String s) + { + } + + @Override + public String getServletContextName() + { + return null; + } + + @Override + public ServletRegistration.Dynamic addServlet(String s, String s1) + { + return null; + } + + @Override + public ServletRegistration.Dynamic addServlet(String s, Servlet servlet) + { + return null; + } + + @Override + public ServletRegistration.Dynamic addServlet(String s, Class aClass) + { + return null; + } + + @Override + public T createServlet(Class aClass) + { + return null; + } + + @Override + public ServletRegistration getServletRegistration(String s) + { + return null; + } + + @Override + public Map getServletRegistrations() + { + return null; + } + + @Override + public FilterRegistration.Dynamic addFilter(String s, String s1) + { + return null; + } + + @Override + public FilterRegistration.Dynamic addFilter(String s, Filter filter) + { + return null; + } + + @Override + public FilterRegistration.Dynamic addFilter(String s, Class aClass) + { + return null; + } + + @Override + public T createFilter(Class aClass) + { + return null; + } + + @Override + public FilterRegistration getFilterRegistration(String s) + { + return null; + } + + @Override + public Map getFilterRegistrations() + { + return null; + } + + @Override + public SessionCookieConfig getSessionCookieConfig() + { + return null; + } + + @Override + public void setSessionTrackingModes(Set set) + { + } + + @Override + public Set getDefaultSessionTrackingModes() + { + return null; + } + + @Override + public Set getEffectiveSessionTrackingModes() + { + return null; + } + + @Override + public void addListener(String s) + { + } + + @Override + public void addListener(T t) + { + } + + @Override + public void addListener(Class aClass) + { + } + + @Override + public T createListener(Class aClass) + { + return null; + } + + @Override + public JspConfigDescriptor getJspConfigDescriptor() + { + return null; + } + + @Override + public ClassLoader getClassLoader() + { + return null; + } + + @Override + public void declareRoles(String... strings) + { + } + + @Override + public String getVirtualServerName() + { + return null; + } + }; + + private boolean secure = true; + private String pathInfo = "/oauth2/token-value/"; + + @Override + public String getHeader(String name) + { + return headers.get(name); + } + + @Override + public Enumeration getHeaders(String name) + { + String header = headers.get(name); + return header != null ? Collections.enumeration(Collections.singletonList(header)) : Collections.enumeration(Collections.emptyList()); + } + + @Override + public Enumeration getHeaderNames() + { + return Collections.enumeration(headers.keySet()); + } + + @Override + public void setAttribute(String name, Object o) + { + attributes.put(name, o); + } + + @Override + public void removeAttribute(String name) + { + attributes.remove(name); + } + + @Override + public ServletContext getServletContext() + { + return servletContext; + } + + @Override + public boolean isSecure() + { + return secure; + } + + public void setSecure(boolean secure) + { + this.secure = secure; + } + + @Override + public String getPathInfo() + { + return pathInfo; + } + + public void setPathInfo(String pathInfo) + { + this.pathInfo = pathInfo; + } + + public void setHeader(String name, String value) + { + headers.put(name, value); + } + + @Override + public String getRequestURI() + { + return "/example"; + } + + @Override + public StringBuffer getRequestURL() + { + return new StringBuffer("http://example.com"); + } + + @Override + public boolean authenticate(HttpServletResponse httpServletResponse) + { + return false; + } + + @Override + public void login(String s, String s1) + { + } + + @Override + public void logout() + { + } + + @Override + public Collection getParts() + { + return null; + } + + @Override + public Part getPart(String s) + { + return null; + } + + @Override + public T upgrade(Class aClass) + { + return null; + } + } + + static class ConcreteHttpServletResponse + extends HttpServletResponseAdapter + { + private final PrintWriter writer = new PrintWriter(System.out); + private int status; + private String contentType; + + @Override + public void setStatus(int sc) + { + this.status = sc; + } + + @Override + public void setStatus(int i, String s) + { + } + + @Override + public int getStatus() + { + return 0; + } + + @Override + public String getHeader(String s) + { + return null; + } + + @Override + public Collection getHeaders(String s) + { + return null; + } + + @Override + public Collection getHeaderNames() + { + return null; + } + + @Override + public void setContentType(String type) + { + this.contentType = type; + } + + @Override + public String getCharacterEncoding() + { + return null; + } + + @Override + public String getContentType() + { + return null; + } + + @Override + public ServletOutputStream getOutputStream() + { + return null; + } + + @Override + public PrintWriter getWriter() + { + return writer; + } + + @Override + public void setCharacterEncoding(String s) + { + } + + @Override + public int getBufferSize() + { + return 0; + } + + @Override + public void setBufferSize(int size) + { + } + + @Override + public boolean isCommitted() + { + return false; + } + + @Override + public void resetBuffer() + { + } + } + + static class FilterChainStub + implements FilterChain + { + private boolean filterCalled = true; + private ServletRequest capturedRequest; + + @Override + public void doFilter(ServletRequest request, ServletResponse response) + { + this.capturedRequest = request; + } + + public ServletRequest getCapturedRequest() + { + return capturedRequest; + } + + public boolean isFilterCalled() + { + return filterCalled; + } + } + + static class AuthenticatorStub + implements Authenticator + { + private Principal principal; + private boolean authenticateCalled; + + @Override + public Principal authenticate(HttpServletRequest request) + { + authenticateCalled = true; + return principal; + } + + public void setPrincipal(Principal principal) + { + this.principal = principal; + } + + public boolean isAuthenticateCalled() + { + return authenticateCalled; + } + } + + static class RequestModifierManagerStub + extends RequestModifierManager + { + private List modifiers; + + @Override + public List getRequestModifiers() + { + return modifiers; + } + + public void setModifiers(List modifiers) + { + this.modifiers = modifiers; + } + } + + static class RequestModifierStub + implements RequestModifier + { + private Map extraHeaders; + private List headerNames; + + @Override + public List getHeaderNames() + { + return Collections.singletonList("Authorization"); + } + + @Override + public Optional> getExtraHeaders(T additionalInfo) + { + return Optional.of(Collections.singletonMap("X-Custom-Header", "CustomValue")); + } + + public void setExtraHeaders(Map extraHeaders) + { + this.extraHeaders = extraHeaders; + } + + public void setHeaderNames(List headerNames) + { + this.headerNames = headerNames; + } + } + + static class SecurityConfigStub + extends SecurityConfig + { + } + + static class PrincipalStub + implements Principal + { + @Override + public String getName() + { + return "TestPrincipal"; + } + } +} From a3d875da831132a6a4e9365fbd09dbc523916d15 Mon Sep 17 00:00:00 2001 From: SthuthiGhosh9400 Date: Fri, 6 Sep 2024 01:02:39 +0530 Subject: [PATCH 04/13] - Removed Mockito dependency and added mockwebserver --- presto-main/pom.xml | 4 +- .../presto/TestRequestModifierPlugin.java | 122 ------------------ 2 files changed, 2 insertions(+), 124 deletions(-) delete mode 100644 presto-main/src/test/java/com/facebook/presto/TestRequestModifierPlugin.java diff --git a/presto-main/pom.xml b/presto-main/pom.xml index 0d0404812096..0bddf4b75f1e 100644 --- a/presto-main/pom.xml +++ b/presto-main/pom.xml @@ -500,8 +500,8 @@ true - org.mockito - mockito-core + com.squareup.okhttp3 + mockwebserver test diff --git a/presto-main/src/test/java/com/facebook/presto/TestRequestModifierPlugin.java b/presto-main/src/test/java/com/facebook/presto/TestRequestModifierPlugin.java deleted file mode 100644 index 3ea65b672152..000000000000 --- a/presto-main/src/test/java/com/facebook/presto/TestRequestModifierPlugin.java +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto; - -import com.facebook.airlift.http.server.AuthenticationException; -import com.facebook.airlift.http.server.Authenticator; -import com.facebook.presto.server.security.AuthenticationFilter; -import com.facebook.presto.server.security.SecurityConfig; -import com.facebook.presto.spi.RequestModifier; -import org.mockito.ArgumentCaptor; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.testng.annotations.BeforeClass; -import org.testng.annotations.Test; - -import javax.servlet.FilterChain; -import javax.servlet.ServletException; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - -import java.io.IOException; -import java.security.Principal; -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; - -import static org.mockito.Matchers.any; -import static org.mockito.Matchers.eq; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import static org.testng.Assert.assertEquals; - -public class TestRequestModifierPlugin -{ - @Mock - private HttpServletRequest request; - - @Mock - private HttpServletResponse response; - - @Mock - private FilterChain filterChain; - - @Mock - private Authenticator authenticator; - - @Mock - private RequestModifierManager requestModifierManager; - - @Mock - private RequestModifier requestModifier; - - @Mock - private SecurityConfig securityConfig; - - private AuthenticationFilter filter; - private List authenticators; - - @BeforeClass - public void setup() - { - MockitoAnnotations.initMocks(this); - authenticators = new ArrayList<>(); - authenticators.add(authenticator); - filter = spy(new AuthenticationFilter(authenticators, securityConfig, requestModifierManager)); - } - - @Test - public void testDoFilter_SuccessfulAuthenticationWithHeaderModification() throws ServletException, IOException, AuthenticationException - { - Principal testPrincipal = mock(Principal.class); - when(authenticator.authenticate(request)).thenReturn(testPrincipal); - - // Set up RequestModifierManager to return a RequestModifier - when(requestModifierManager.getRequestModifiers()).thenReturn(Collections.singletonList(requestModifier)); - - // Mock behavior of RequestModifier and HttpServletRequest - when(requestModifier.getHeaderNames()).thenReturn(Collections.singletonList("Authorization")); - when(request.getHeaders("Authorization")).thenReturn(null); - when(request.getPathInfo()).thenReturn("/oauth2/token-value/"); - when(request.isSecure()).thenReturn(true); - - // Set up the extra header to be returned by the RequestModifier - Map extraHeaders = new HashMap<>(); - extraHeaders.put("X-Custom-Header", "CustomValue"); - when(requestModifier.getExtraHeaders(testPrincipal)).thenReturn(Optional.of(extraHeaders)); - - AuthenticationFilter.CustomHttpServletRequestWrapper wrappedRequest = spy(new AuthenticationFilter.CustomHttpServletRequestWrapper(request, testPrincipal)); - doNothing().when(wrappedRequest).setHeaders(any(Map.class)); - - doReturn(wrappedRequest).when(filter).withPrincipal(request, testPrincipal); - - filter.doFilter(request, response, filterChain); - - ArgumentCaptor headersCaptor = ArgumentCaptor.forClass(Map.class); - verify(wrappedRequest).setHeaders(headersCaptor.capture()); - Map capturedHeaders = headersCaptor.getValue(); - assertEquals("CustomValue", capturedHeaders.get("X-Custom-Header")); - - verify(filterChain).doFilter(eq(wrappedRequest), eq(response)); - verify(authenticator).authenticate(request); - verify(requestModifier).getExtraHeaders(testPrincipal); - } -} From 87c12ecc53f97513c68877c5de3d7514973dddac Mon Sep 17 00:00:00 2001 From: SthuthiGhosh9400 Date: Fri, 6 Sep 2024 11:25:03 +0530 Subject: [PATCH 05/13] - used MockHttpServletRequest --- .../presto/TestRequestHeaderModifier.java | 343 +----------------- 1 file changed, 10 insertions(+), 333 deletions(-) diff --git a/presto-main/src/test/java/com/facebook/presto/TestRequestHeaderModifier.java b/presto-main/src/test/java/com/facebook/presto/TestRequestHeaderModifier.java index 16f1fcf18c65..148657820b5f 100644 --- a/presto-main/src/test/java/com/facebook/presto/TestRequestHeaderModifier.java +++ b/presto-main/src/test/java/com/facebook/presto/TestRequestHeaderModifier.java @@ -14,17 +14,18 @@ package com.facebook.presto; import com.facebook.airlift.http.server.Authenticator; +import com.facebook.presto.server.MockHttpServletRequest; import com.facebook.presto.server.security.AuthenticationFilter; import com.facebook.presto.server.security.SecurityConfig; import com.facebook.presto.spi.RequestModifier; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ListMultimap; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; -import javax.servlet.AsyncContext; -import javax.servlet.DispatcherType; import javax.servlet.Filter; import javax.servlet.FilterChain; import javax.servlet.FilterRegistration; @@ -32,7 +33,6 @@ import javax.servlet.Servlet; import javax.servlet.ServletContext; import javax.servlet.ServletException; -import javax.servlet.ServletInputStream; import javax.servlet.ServletOutputStream; import javax.servlet.ServletRegistration; import javax.servlet.ServletRequest; @@ -43,11 +43,9 @@ import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; -import javax.servlet.http.HttpSession; import javax.servlet.http.HttpUpgradeHandler; import javax.servlet.http.Part; -import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; import java.io.PrintWriter; @@ -104,7 +102,7 @@ public void testDoFilter_SuccessfulAuthenticationWithHeaderModification() throws { mockWebServer.enqueue(new MockResponse().setBody("Mocked Body").setResponseCode(200)); - ConcreteHttpServletRequest request = new ConcreteHttpServletRequest(); + ConcreteHttpServletRequest request = new ConcreteHttpServletRequest(ImmutableListMultimap.of("X-Custom-Header1", "CustomValue1"), "http://request-modifier.com", Collections.singletonMap("attribute", "attribute1")); request.setPathInfo("/oauth2/token-value/"); request.setSecure(true); @@ -121,326 +119,6 @@ public void testDoFilter_SuccessfulAuthenticationWithHeaderModification() throws assertEquals("CustomValue", wrappedRequest.getHeader("X-Custom-Header")); } - abstract static class HttpServletRequestAdapter - implements HttpServletRequest - { - @Override - public String getAuthType() - { - throw new UnsupportedOperationException(); - } - @Override - public Cookie[] getCookies() - { - throw new UnsupportedOperationException(); - } - @Override - public long getDateHeader(String name) - { - throw new UnsupportedOperationException(); - } - @Override - public String getHeader(String name) - { - throw new UnsupportedOperationException(); - } - @Override - public Enumeration getHeaders(String name) - { - throw new UnsupportedOperationException(); - } - @Override - public Enumeration getHeaderNames() - { - throw new UnsupportedOperationException(); - } - @Override - public int getIntHeader(String name) - { - throw new UnsupportedOperationException(); - } - @Override - public String getMethod() - { - throw new UnsupportedOperationException(); - } - @Override - public String getPathInfo() - { - throw new UnsupportedOperationException(); - } - @Override - public String getPathTranslated() - { - throw new UnsupportedOperationException(); - } - @Override - public String getContextPath() - { - throw new UnsupportedOperationException(); - } - @Override - public String getQueryString() - { - throw new UnsupportedOperationException(); - } - @Override - public String getRemoteUser() - { - throw new UnsupportedOperationException(); - } - @Override - public boolean isUserInRole(String role) - { - throw new UnsupportedOperationException(); - } - @Override - public Principal getUserPrincipal() - { - throw new UnsupportedOperationException(); - } - @Override - public String getRequestedSessionId() - { - throw new UnsupportedOperationException(); - } - @Override - public String getRequestURI() - { - throw new UnsupportedOperationException(); - } - @Override - public StringBuffer getRequestURL() - { - throw new UnsupportedOperationException(); - } - @Override - public String getServletPath() - { - throw new UnsupportedOperationException(); - } - @Override - public HttpSession getSession(boolean create) - { - throw new UnsupportedOperationException(); - } - @Override - public HttpSession getSession() - { - throw new UnsupportedOperationException(); - } - @Override - public String changeSessionId() - { - throw new UnsupportedOperationException(); - } - @Override - public boolean isRequestedSessionIdValid() - { - throw new UnsupportedOperationException(); - } - @Override - public boolean isRequestedSessionIdFromCookie() - { - throw new UnsupportedOperationException(); - } - @Override - public boolean isRequestedSessionIdFromURL() - { - throw new UnsupportedOperationException(); - } - @Override - public boolean isRequestedSessionIdFromUrl() - { - throw new UnsupportedOperationException(); - } - @Override - public Object getAttribute(String name) - { - throw new UnsupportedOperationException(); - } - @Override - public Enumeration getAttributeNames() - { - throw new UnsupportedOperationException(); - } - @Override - public String getCharacterEncoding() - { - throw new UnsupportedOperationException(); - } - @Override - public void setCharacterEncoding(String env) - { - throw new UnsupportedOperationException(); - } - @Override - public int getContentLength() - { - throw new UnsupportedOperationException(); - } - @Override - public long getContentLengthLong() - { - throw new UnsupportedOperationException(); - } - @Override - public String getContentType() - { - throw new UnsupportedOperationException(); - } - @Override - public ServletInputStream getInputStream() - { - throw new UnsupportedOperationException(); - } - @Override - public String getParameter(String name) - { - throw new UnsupportedOperationException(); - } - @Override - public Enumeration getParameterNames() - { - throw new UnsupportedOperationException(); - } - @Override - public String[] getParameterValues(String name) - { - throw new UnsupportedOperationException(); - } - @Override - public Map getParameterMap() - { - throw new UnsupportedOperationException(); - } - @Override - public String getProtocol() - { - throw new UnsupportedOperationException(); - } - @Override - public String getScheme() - { - throw new UnsupportedOperationException(); - } - @Override - public String getServerName() - { - throw new UnsupportedOperationException(); - } - @Override - public int getServerPort() - { - throw new UnsupportedOperationException(); - } - @Override - public BufferedReader getReader() - { - throw new UnsupportedOperationException(); - } - @Override - public String getRemoteAddr() - { - throw new UnsupportedOperationException(); - } - @Override - public String getRemoteHost() - { - throw new UnsupportedOperationException(); - } - @Override - public void setAttribute(String name, Object o) - { - throw new UnsupportedOperationException(); - } - @Override - public void removeAttribute(String name) - { - throw new UnsupportedOperationException(); - } - @Override - public Locale getLocale() - { - throw new UnsupportedOperationException(); - } - @Override - public Enumeration getLocales() - { - throw new UnsupportedOperationException(); - } - @Override - public boolean isSecure() - { - throw new UnsupportedOperationException(); - } - @Override - public RequestDispatcher getRequestDispatcher(String path) - { - throw new UnsupportedOperationException(); - } - @Override - public String getRealPath(String path) - { - throw new UnsupportedOperationException(); - } - @Override - public int getRemotePort() - { - throw new UnsupportedOperationException(); - } - @Override - public String getLocalName() - { - throw new UnsupportedOperationException(); - } - @Override - public String getLocalAddr() - { - throw new UnsupportedOperationException(); - } - @Override - public int getLocalPort() - { - throw new UnsupportedOperationException(); - } - @Override - public ServletContext getServletContext() - { - throw new UnsupportedOperationException(); - } - @Override - public AsyncContext startAsync() throws IllegalStateException - { - throw new UnsupportedOperationException(); - } - @Override - public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException - { - throw new UnsupportedOperationException(); - } - @Override - public boolean isAsyncStarted() - { - throw new UnsupportedOperationException(); - } - @Override - public boolean isAsyncSupported() - { - throw new UnsupportedOperationException(); - } - @Override - public AsyncContext getAsyncContext() - { - throw new UnsupportedOperationException(); - } - @Override - public DispatcherType getDispatcherType() - { - throw new UnsupportedOperationException(); - } - } - abstract static class HttpServletResponseAdapter implements HttpServletResponse { @@ -577,7 +255,7 @@ public Locale getLocale() } static class ConcreteHttpServletRequest - extends HttpServletRequestAdapter + extends MockHttpServletRequest { private final Map headers = new HashMap<>(); private Principal principal; @@ -890,6 +568,11 @@ public String getVirtualServerName() private boolean secure = true; private String pathInfo = "/oauth2/token-value/"; + public ConcreteHttpServletRequest(ListMultimap headers, String remoteAddress, Map attributes) + { + super(headers, remoteAddress, attributes); + } + @Override public String getHeader(String name) { @@ -960,12 +643,6 @@ public String getRequestURI() return "/example"; } - @Override - public StringBuffer getRequestURL() - { - return new StringBuffer("http://example.com"); - } - @Override public boolean authenticate(HttpServletResponse httpServletResponse) { From 1506807159a162d1503b7200de7fdfdb30d1d2ba Mon Sep 17 00:00:00 2001 From: SthuthiGhosh9400 Date: Sun, 8 Sep 2024 18:44:13 +0530 Subject: [PATCH 06/13] - Refactored the code according to the review comment --- .../server/testing/TestingPrestoServer.java | 24 + .../presto/TestRequestHeaderModifier.java | 888 ------------------ .../TestRequestHeaderModifierPlugin.java | 110 +++ 3 files changed, 134 insertions(+), 888 deletions(-) delete mode 100644 presto-main/src/test/java/com/facebook/presto/TestRequestHeaderModifier.java create mode 100644 presto-main/src/test/java/com/facebook/presto/TestRequestHeaderModifierPlugin.java diff --git a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java index c7b38071d888..be3141f8f947 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java @@ -32,6 +32,7 @@ import com.facebook.airlift.tracetoken.TraceTokenModule; import com.facebook.drift.server.DriftServer; import com.facebook.drift.transport.netty.server.DriftNettyServerTransport; +import com.facebook.presto.RequestModifierManager; import com.facebook.presto.RequestModifierModule; import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.cost.StatsCalculator; @@ -63,6 +64,7 @@ import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.Plugin; import com.facebook.presto.spi.QueryId; +import com.facebook.presto.spi.RequestModifier; import com.facebook.presto.spi.eventlistener.EventListener; import com.facebook.presto.spi.memory.ClusterMemoryPoolManager; import com.facebook.presto.spi.security.AccessControl; @@ -108,6 +110,7 @@ import java.io.UncheckedIOException; import java.net.URI; import java.nio.file.Path; +import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.LinkedHashSet; @@ -827,4 +830,25 @@ private static int driftServerPort(DriftServer server) { return ((DriftNettyServerTransport) server.getServerTransport()).getPort(); } + + public RequestModifierManager getRequestModifierManager() + { + RequestModifierManager manager = new RequestModifierManager(); + RequestModifier sampleModifier = new RequestModifier() { + @Override + public List getHeaderNames() + { + return Collections.singletonList("Extra-credential"); + } + @Override + public Optional> getExtraHeaders(T additionalInfo) + { + Map headers = new HashMap<>(); + headers.put("X-Custom-Header", "CustomValue"); + return Optional.of(headers); + } + }; + manager.registerRequestModifier(sampleModifier); + return manager; + } } diff --git a/presto-main/src/test/java/com/facebook/presto/TestRequestHeaderModifier.java b/presto-main/src/test/java/com/facebook/presto/TestRequestHeaderModifier.java deleted file mode 100644 index 148657820b5f..000000000000 --- a/presto-main/src/test/java/com/facebook/presto/TestRequestHeaderModifier.java +++ /dev/null @@ -1,888 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.facebook.presto; - -import com.facebook.airlift.http.server.Authenticator; -import com.facebook.presto.server.MockHttpServletRequest; -import com.facebook.presto.server.security.AuthenticationFilter; -import com.facebook.presto.server.security.SecurityConfig; -import com.facebook.presto.spi.RequestModifier; -import com.google.common.collect.ImmutableListMultimap; -import com.google.common.collect.ListMultimap; -import okhttp3.mockwebserver.MockResponse; -import okhttp3.mockwebserver.MockWebServer; -import org.testng.annotations.AfterMethod; -import org.testng.annotations.BeforeMethod; -import org.testng.annotations.Test; - -import javax.servlet.Filter; -import javax.servlet.FilterChain; -import javax.servlet.FilterRegistration; -import javax.servlet.RequestDispatcher; -import javax.servlet.Servlet; -import javax.servlet.ServletContext; -import javax.servlet.ServletException; -import javax.servlet.ServletOutputStream; -import javax.servlet.ServletRegistration; -import javax.servlet.ServletRequest; -import javax.servlet.ServletResponse; -import javax.servlet.SessionCookieConfig; -import javax.servlet.SessionTrackingMode; -import javax.servlet.descriptor.JspConfigDescriptor; -import javax.servlet.http.Cookie; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import javax.servlet.http.HttpUpgradeHandler; -import javax.servlet.http.Part; - -import java.io.IOException; -import java.io.InputStream; -import java.io.PrintWriter; -import java.net.URL; -import java.security.Principal; -import java.util.Collection; -import java.util.Collections; -import java.util.Enumeration; -import java.util.EventListener; -import java.util.HashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Optional; -import java.util.Set; - -import static org.testng.Assert.assertEquals; - -public class TestRequestHeaderModifier -{ - private MockWebServer mockWebServer; - private HttpServletResponse response; - private FilterChainStub filterChain; - private AuthenticationFilter filter; - private AuthenticatorStub authenticator; - private RequestModifierManagerStub requestModifierManager; - private RequestModifierStub requestModifier; - - @BeforeMethod - public void setUp() throws IOException - { - mockWebServer = new MockWebServer(); - mockWebServer.start(); - - response = new ConcreteHttpServletResponse(); - filterChain = new FilterChainStub(); - - authenticator = new AuthenticatorStub(); - requestModifierManager = new RequestModifierManagerStub(); - requestModifier = new RequestModifierStub(); - - List authenticators = Collections.singletonList(authenticator); - filter = new AuthenticationFilter(authenticators, new SecurityConfigStub(), requestModifierManager); - } - - @AfterMethod - public void tearDown() throws IOException - { - mockWebServer.shutdown(); - } - - @Test - public void testDoFilter_SuccessfulAuthenticationWithHeaderModification() throws ServletException, IOException - { - mockWebServer.enqueue(new MockResponse().setBody("Mocked Body").setResponseCode(200)); - - ConcreteHttpServletRequest request = new ConcreteHttpServletRequest(ImmutableListMultimap.of("X-Custom-Header1", "CustomValue1"), "http://request-modifier.com", Collections.singletonMap("attribute", "attribute1")); - request.setPathInfo("/oauth2/token-value/"); - request.setSecure(true); - - PrincipalStub testPrincipal = new PrincipalStub(); - authenticator.setPrincipal(testPrincipal); - - requestModifierManager.setModifiers(Collections.singletonList(requestModifier)); - requestModifier.setHeaderNames(Collections.singletonList("Extra-credential")); - requestModifier.setExtraHeaders(Collections.singletonMap("X-Custom-Header", "CustomValue")); - - filter.doFilter(request, response, filterChain); - - HttpServletRequest wrappedRequest = (HttpServletRequest) filterChain.getCapturedRequest(); - assertEquals("CustomValue", wrappedRequest.getHeader("X-Custom-Header")); - } - - abstract static class HttpServletResponseAdapter - implements HttpServletResponse - { - @Override - public void addCookie(Cookie cookie) - { - throw new UnsupportedOperationException(); - } - @Override - public boolean containsHeader(String name) - { - throw new UnsupportedOperationException(); - } - @Override - public String encodeRedirectURL(String url) - { - throw new UnsupportedOperationException(); - } - @Override - public String encodeRedirectUrl(String url) - { - throw new UnsupportedOperationException(); - } - @Override - public String encodeURL(String url) - { - throw new UnsupportedOperationException(); - } - @Override - public String encodeUrl(String url) - { - throw new UnsupportedOperationException(); - } - @Override - public void sendError(int sc, String msg) - { - throw new UnsupportedOperationException(); - } - @Override - public void sendError(int sc) - { - throw new UnsupportedOperationException(); - } - @Override - public void sendRedirect(String location) - { - throw new UnsupportedOperationException(); - } - @Override - public void setDateHeader(String name, long date) - { - throw new UnsupportedOperationException(); - } - @Override - public void addDateHeader(String name, long date) - { - throw new UnsupportedOperationException(); - } - @Override - public void setHeader(String name, String value) - { - throw new UnsupportedOperationException(); - } - @Override - public void addHeader(String name, String value) - { - throw new UnsupportedOperationException(); - } - @Override - public void setIntHeader(String name, int value) - { - throw new UnsupportedOperationException(); - } - @Override - public void addIntHeader(String name, int value) - { - throw new UnsupportedOperationException(); - } - @Override - public void setContentLength(int len) - { - throw new UnsupportedOperationException(); - } - @Override - public void setContentLengthLong(long len) - { - throw new UnsupportedOperationException(); - } - @Override - public void setContentType(String type) - { - throw new UnsupportedOperationException(); - } - @Override - public void setBufferSize(int size) - { - throw new UnsupportedOperationException(); - } - @Override - public int getBufferSize() - { - throw new UnsupportedOperationException(); - } - @Override - public void flushBuffer() - { - throw new UnsupportedOperationException(); - } - @Override - public void resetBuffer() - { - throw new UnsupportedOperationException(); - } - @Override - public boolean isCommitted() - { - throw new UnsupportedOperationException(); - } - @Override - public void reset() - { - throw new UnsupportedOperationException(); - } - @Override - public void setLocale(Locale loc) - { - throw new UnsupportedOperationException(); - } - @Override - public Locale getLocale() - { - throw new UnsupportedOperationException(); - } - } - - static class ConcreteHttpServletRequest - extends MockHttpServletRequest - { - private final Map headers = new HashMap<>(); - private Principal principal; - private final Map attributes = new HashMap<>(); - private final ServletContext servletContext = new ServletContext() - { - @Override - public String getContextPath() - { - return null; - } - - @Override - public ServletContext getContext(String s) - { - return null; - } - - @Override - public int getMajorVersion() - { - return 0; - } - - @Override - public int getMinorVersion() - { - return 0; - } - - @Override - public int getEffectiveMajorVersion() - { - return 0; - } - - @Override - public int getEffectiveMinorVersion() - { - return 0; - } - - @Override - public String getMimeType(String s) - { - return null; - } - - @Override - public Set getResourcePaths(String s) - { - return null; - } - - @Override - public URL getResource(String s) - { - return null; - } - - @Override - public InputStream getResourceAsStream(String s) - { - return null; - } - - @Override - public RequestDispatcher getRequestDispatcher(String s) - { - return null; - } - - @Override - public RequestDispatcher getNamedDispatcher(String s) - { - return null; - } - - @Override - public Servlet getServlet(String s) - { - return null; - } - - @Override - public Enumeration getServlets() - { - return null; - } - - @Override - public Enumeration getServletNames() - { - return null; - } - - @Override - public void log(String s) - { - } - - @Override - public void log(Exception e, String s) - { - } - - @Override - public void log(String s, Throwable throwable) - { - } - - @Override - public String getRealPath(String s) - { - return null; - } - - @Override - public String getServerInfo() - { - return null; - } - - @Override - public String getInitParameter(String s) - { - return null; - } - - @Override - public Enumeration getInitParameterNames() - { - return null; - } - - @Override - public boolean setInitParameter(String s, String s1) - { - return false; - } - - @Override - public Object getAttribute(String s) - { - return null; - } - - @Override - public Enumeration getAttributeNames() - { - return null; - } - - @Override - public void setAttribute(String s, Object o) - { - } - - @Override - public void removeAttribute(String s) - { - } - - @Override - public String getServletContextName() - { - return null; - } - - @Override - public ServletRegistration.Dynamic addServlet(String s, String s1) - { - return null; - } - - @Override - public ServletRegistration.Dynamic addServlet(String s, Servlet servlet) - { - return null; - } - - @Override - public ServletRegistration.Dynamic addServlet(String s, Class aClass) - { - return null; - } - - @Override - public T createServlet(Class aClass) - { - return null; - } - - @Override - public ServletRegistration getServletRegistration(String s) - { - return null; - } - - @Override - public Map getServletRegistrations() - { - return null; - } - - @Override - public FilterRegistration.Dynamic addFilter(String s, String s1) - { - return null; - } - - @Override - public FilterRegistration.Dynamic addFilter(String s, Filter filter) - { - return null; - } - - @Override - public FilterRegistration.Dynamic addFilter(String s, Class aClass) - { - return null; - } - - @Override - public T createFilter(Class aClass) - { - return null; - } - - @Override - public FilterRegistration getFilterRegistration(String s) - { - return null; - } - - @Override - public Map getFilterRegistrations() - { - return null; - } - - @Override - public SessionCookieConfig getSessionCookieConfig() - { - return null; - } - - @Override - public void setSessionTrackingModes(Set set) - { - } - - @Override - public Set getDefaultSessionTrackingModes() - { - return null; - } - - @Override - public Set getEffectiveSessionTrackingModes() - { - return null; - } - - @Override - public void addListener(String s) - { - } - - @Override - public void addListener(T t) - { - } - - @Override - public void addListener(Class aClass) - { - } - - @Override - public T createListener(Class aClass) - { - return null; - } - - @Override - public JspConfigDescriptor getJspConfigDescriptor() - { - return null; - } - - @Override - public ClassLoader getClassLoader() - { - return null; - } - - @Override - public void declareRoles(String... strings) - { - } - - @Override - public String getVirtualServerName() - { - return null; - } - }; - - private boolean secure = true; - private String pathInfo = "/oauth2/token-value/"; - - public ConcreteHttpServletRequest(ListMultimap headers, String remoteAddress, Map attributes) - { - super(headers, remoteAddress, attributes); - } - - @Override - public String getHeader(String name) - { - return headers.get(name); - } - - @Override - public Enumeration getHeaders(String name) - { - String header = headers.get(name); - return header != null ? Collections.enumeration(Collections.singletonList(header)) : Collections.enumeration(Collections.emptyList()); - } - - @Override - public Enumeration getHeaderNames() - { - return Collections.enumeration(headers.keySet()); - } - - @Override - public void setAttribute(String name, Object o) - { - attributes.put(name, o); - } - - @Override - public void removeAttribute(String name) - { - attributes.remove(name); - } - - @Override - public ServletContext getServletContext() - { - return servletContext; - } - - @Override - public boolean isSecure() - { - return secure; - } - - public void setSecure(boolean secure) - { - this.secure = secure; - } - - @Override - public String getPathInfo() - { - return pathInfo; - } - - public void setPathInfo(String pathInfo) - { - this.pathInfo = pathInfo; - } - - public void setHeader(String name, String value) - { - headers.put(name, value); - } - - @Override - public String getRequestURI() - { - return "/example"; - } - - @Override - public boolean authenticate(HttpServletResponse httpServletResponse) - { - return false; - } - - @Override - public void login(String s, String s1) - { - } - - @Override - public void logout() - { - } - - @Override - public Collection getParts() - { - return null; - } - - @Override - public Part getPart(String s) - { - return null; - } - - @Override - public T upgrade(Class aClass) - { - return null; - } - } - - static class ConcreteHttpServletResponse - extends HttpServletResponseAdapter - { - private final PrintWriter writer = new PrintWriter(System.out); - private int status; - private String contentType; - - @Override - public void setStatus(int sc) - { - this.status = sc; - } - - @Override - public void setStatus(int i, String s) - { - } - - @Override - public int getStatus() - { - return 0; - } - - @Override - public String getHeader(String s) - { - return null; - } - - @Override - public Collection getHeaders(String s) - { - return null; - } - - @Override - public Collection getHeaderNames() - { - return null; - } - - @Override - public void setContentType(String type) - { - this.contentType = type; - } - - @Override - public String getCharacterEncoding() - { - return null; - } - - @Override - public String getContentType() - { - return null; - } - - @Override - public ServletOutputStream getOutputStream() - { - return null; - } - - @Override - public PrintWriter getWriter() - { - return writer; - } - - @Override - public void setCharacterEncoding(String s) - { - } - - @Override - public int getBufferSize() - { - return 0; - } - - @Override - public void setBufferSize(int size) - { - } - - @Override - public boolean isCommitted() - { - return false; - } - - @Override - public void resetBuffer() - { - } - } - - static class FilterChainStub - implements FilterChain - { - private boolean filterCalled = true; - private ServletRequest capturedRequest; - - @Override - public void doFilter(ServletRequest request, ServletResponse response) - { - this.capturedRequest = request; - } - - public ServletRequest getCapturedRequest() - { - return capturedRequest; - } - - public boolean isFilterCalled() - { - return filterCalled; - } - } - - static class AuthenticatorStub - implements Authenticator - { - private Principal principal; - private boolean authenticateCalled; - - @Override - public Principal authenticate(HttpServletRequest request) - { - authenticateCalled = true; - return principal; - } - - public void setPrincipal(Principal principal) - { - this.principal = principal; - } - - public boolean isAuthenticateCalled() - { - return authenticateCalled; - } - } - - static class RequestModifierManagerStub - extends RequestModifierManager - { - private List modifiers; - - @Override - public List getRequestModifiers() - { - return modifiers; - } - - public void setModifiers(List modifiers) - { - this.modifiers = modifiers; - } - } - - static class RequestModifierStub - implements RequestModifier - { - private Map extraHeaders; - private List headerNames; - - @Override - public List getHeaderNames() - { - return Collections.singletonList("Authorization"); - } - - @Override - public Optional> getExtraHeaders(T additionalInfo) - { - return Optional.of(Collections.singletonMap("X-Custom-Header", "CustomValue")); - } - - public void setExtraHeaders(Map extraHeaders) - { - this.extraHeaders = extraHeaders; - } - - public void setHeaderNames(List headerNames) - { - this.headerNames = headerNames; - } - } - - static class SecurityConfigStub - extends SecurityConfig - { - } - - static class PrincipalStub - implements Principal - { - @Override - public String getName() - { - return "TestPrincipal"; - } - } -} diff --git a/presto-main/src/test/java/com/facebook/presto/TestRequestHeaderModifierPlugin.java b/presto-main/src/test/java/com/facebook/presto/TestRequestHeaderModifierPlugin.java new file mode 100644 index 000000000000..38c9bb66555d --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/TestRequestHeaderModifierPlugin.java @@ -0,0 +1,110 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto; + +import com.facebook.presto.server.MockHttpServletRequest; +import com.facebook.presto.server.testing.TestingPrestoServer; +import com.facebook.presto.spi.RequestModifier; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ListMultimap; +import org.testng.annotations.Test; + +import java.security.Principal; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +import static org.testng.Assert.assertEquals; + +public class TestRequestHeaderModifierPlugin +{ + @Test + public void testCustomRequestModifierWithHeaders() throws Exception + { + ConcreteHttpServletRequest request = new ConcreteHttpServletRequest(ImmutableListMultimap.of("X-Custom-Header1", "CustomValue1"), "http://request-modifier.com", Collections.singletonMap("attribute", "attribute1")); + TestingPrestoServer server = new TestingPrestoServer(); + RequestModifierManager requestModifierManager = server.getRequestModifierManager(); + PrincipalStub testPrincipal = new PrincipalStub(); + + Map extraHeadersMap = new HashMap<>(); + + for (RequestModifier modifier : requestModifierManager.getRequestModifiers()) { + boolean headersPresent = modifier.getHeaderNames().stream() + .allMatch(headerName -> request.getHeader(headerName) != null); + + if (!headersPresent) { + Optional> extraHeaderValueMap = modifier.getExtraHeaders(testPrincipal); + + extraHeaderValueMap.ifPresent(map -> { + for (Map.Entry extraHeaderEntry : map.entrySet()) { + if (request.getHeader(extraHeaderEntry.getKey()) == null) { + extraHeadersMap.putIfAbsent(extraHeaderEntry.getKey(), extraHeaderEntry.getValue()); + } + } + }); + } + } + request.setHeaders(extraHeadersMap); + assertEquals("CustomValue", request.getHeader("X-Custom-Header")); + } + + static class ConcreteHttpServletRequest + extends MockHttpServletRequest + { + public ConcreteHttpServletRequest(ListMultimap headers, String remoteAddress, Map attributes) + { + super(headers, remoteAddress, attributes); + this.customHeaders = new HashMap<>(); + } + + private final Map customHeaders; + + @Override + public boolean isSecure() + { + return true; + } + + @Override + public String getPathInfo() + { + return "/oauth2/token-value/"; + } + + public void setHeaders(Map headers) + { + this.customHeaders.putAll(headers); + } + + @Override + public Enumeration getHeaders(String name) + { + if (customHeaders.containsKey(name)) { + return Collections.enumeration(Collections.singleton(customHeaders.get(name))); + } + return super.getHeaders(name); + } + } + static class PrincipalStub + implements Principal + { + @Override + public String getName() + { + return "TestPrincipal"; + } + } +} From 0def3a4f36f26be82c04023fa7b442fd6a96e5ed Mon Sep 17 00:00:00 2001 From: SthuthiGhosh9400 Date: Tue, 10 Sep 2024 10:02:30 +0530 Subject: [PATCH 07/13] - Renamed RequestModifier to ClientRequestFilter - Improved code flow based on the review comments --- ...r.java => ClientRequestFilterManager.java} | 21 ++++---- ...le.java => ClientRequestFilterModule.java} | 4 +- .../facebook/presto/server/PluginManager.java | 16 +++--- .../facebook/presto/server/PrestoServer.java | 4 +- .../server/security/AuthenticationFilter.java | 50 ++++++++----------- .../server/testing/TestingPrestoServer.java | 29 +++-------- .../presto/testing/LocalQueryRunner.java | 4 +- ...ava => TestClientRequestFilterPlugin.java} | 30 ++++++++--- ...Modifier.java => ClientRequestFilter.java} | 2 +- .../java/com/facebook/presto/spi/Plugin.java | 2 +- 10 files changed, 79 insertions(+), 83 deletions(-) rename presto-main/src/main/java/com/facebook/presto/{RequestModifierManager.java => ClientRequestFilterManager.java} (51%) rename presto-main/src/main/java/com/facebook/presto/{RequestModifierModule.java => ClientRequestFilterModule.java} (87%) rename presto-main/src/test/java/com/facebook/presto/{TestRequestHeaderModifierPlugin.java => TestClientRequestFilterPlugin.java} (77%) rename presto-spi/src/main/java/com/facebook/presto/spi/{RequestModifier.java => ClientRequestFilter.java} (95%) diff --git a/presto-main/src/main/java/com/facebook/presto/RequestModifierManager.java b/presto-main/src/main/java/com/facebook/presto/ClientRequestFilterManager.java similarity index 51% rename from presto-main/src/main/java/com/facebook/presto/RequestModifierManager.java rename to presto-main/src/main/java/com/facebook/presto/ClientRequestFilterManager.java index e598caaa68b8..65a642ec9908 100644 --- a/presto-main/src/main/java/com/facebook/presto/RequestModifierManager.java +++ b/presto-main/src/main/java/com/facebook/presto/ClientRequestFilterManager.java @@ -13,26 +13,27 @@ */ package com.facebook.presto; -import com.facebook.presto.spi.RequestModifier; +import com.facebook.presto.spi.ClientRequestFilter; -import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; -public class RequestModifierManager +public class ClientRequestFilterManager { - private final List requestModifiers; - public RequestModifierManager() + private final CopyOnWriteArrayList clientRequestFilters; + public ClientRequestFilterManager() { - this.requestModifiers = new ArrayList<>(); + this.clientRequestFilters = new CopyOnWriteArrayList<>(); } - public List getRequestModifiers() + public List getClientRequestFilters() { - return new ArrayList<>(requestModifiers); + return Collections.unmodifiableList(clientRequestFilters); } - public void registerRequestModifier(RequestModifier requestModifier) + public void registerClientRequestFilter(ClientRequestFilter clientRequestFilter) { - requestModifiers.add(requestModifier); + clientRequestFilters.add(clientRequestFilter); } } diff --git a/presto-main/src/main/java/com/facebook/presto/RequestModifierModule.java b/presto-main/src/main/java/com/facebook/presto/ClientRequestFilterModule.java similarity index 87% rename from presto-main/src/main/java/com/facebook/presto/RequestModifierModule.java rename to presto-main/src/main/java/com/facebook/presto/ClientRequestFilterModule.java index bf95affbc374..4beaa15db8e6 100644 --- a/presto-main/src/main/java/com/facebook/presto/RequestModifierModule.java +++ b/presto-main/src/main/java/com/facebook/presto/ClientRequestFilterModule.java @@ -17,12 +17,12 @@ import com.google.inject.Module; import com.google.inject.Scopes; -public class RequestModifierModule +public class ClientRequestFilterModule implements Module { @Override public void configure(Binder binder) { - binder.bind(RequestModifierManager.class).in(Scopes.SINGLETON); + binder.bind(ClientRequestFilterManager.class).in(Scopes.SINGLETON); } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java b/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java index 176380747107..684daf2a1367 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java +++ b/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java @@ -15,7 +15,7 @@ import com.facebook.airlift.log.Logger; import com.facebook.airlift.node.NodeInfo; -import com.facebook.presto.RequestModifierManager; +import com.facebook.presto.ClientRequestFilterManager; import com.facebook.presto.common.block.BlockEncoding; import com.facebook.presto.common.block.BlockEncodingManager; import com.facebook.presto.common.type.ParametricType; @@ -28,8 +28,8 @@ import com.facebook.presto.metadata.Metadata; import com.facebook.presto.security.AccessControlManager; import com.facebook.presto.server.security.PasswordAuthenticatorManager; +import com.facebook.presto.spi.ClientRequestFilter; import com.facebook.presto.spi.Plugin; -import com.facebook.presto.spi.RequestModifier; import com.facebook.presto.spi.analyzer.AnalyzerProvider; import com.facebook.presto.spi.classloader.ThreadContextClassLoader; import com.facebook.presto.spi.connector.ConnectorFactory; @@ -127,7 +127,7 @@ public class PluginManager private final TracerProviderManager tracerProviderManager; private final AnalyzerProviderManager analyzerProviderManager; private final NodeStatusNotificationManager nodeStatusNotificationManager; - private final RequestModifierManager requestModifierManager; + private final ClientRequestFilterManager clientRequestFilterManager; @Inject public PluginManager( @@ -149,7 +149,7 @@ public PluginManager( HistoryBasedPlanStatisticsManager historyBasedPlanStatisticsManager, TracerProviderManager tracerProviderManager, NodeStatusNotificationManager nodeStatusNotificationManager, - RequestModifierManager requestModifierManager) + ClientRequestFilterManager clientRequestFilterManager) { requireNonNull(nodeInfo, "nodeInfo is null"); requireNonNull(config, "config is null"); @@ -180,7 +180,7 @@ public PluginManager( this.tracerProviderManager = requireNonNull(tracerProviderManager, "tracerProviderManager is null"); this.analyzerProviderManager = requireNonNull(analyzerProviderManager, "analyzerProviderManager is null"); this.nodeStatusNotificationManager = requireNonNull(nodeStatusNotificationManager, "nodeStatusNotificationManager is null"); - this.requestModifierManager = requireNonNull(requestModifierManager, "requestModifierManager is null"); + this.clientRequestFilterManager = requireNonNull(clientRequestFilterManager, "clientRequestFilterManager is null"); } public void loadPlugins() @@ -332,9 +332,9 @@ public void installPlugin(Plugin plugin) nodeStatusNotificationManager.addNodeStatusNotificationProviderFactory(nodeStatusNotificationProviderFactory); } - for (RequestModifier requestModifier : plugin.getRequestModifiers()) { - log.info("Registering request modifier"); - requestModifierManager.registerRequestModifier(requestModifier); + for (ClientRequestFilter clientRequestFilter : plugin.getClientRequestFilters()) { + log.info("Registering client request filter"); + clientRequestFilterManager.registerClientRequestFilter(clientRequestFilter); } } diff --git a/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java b/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java index f4e37268edc6..57127c981100 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/PrestoServer.java @@ -30,7 +30,7 @@ import com.facebook.airlift.tracetoken.TraceTokenModule; import com.facebook.drift.server.DriftServer; import com.facebook.drift.transport.netty.server.DriftNettyServerTransport; -import com.facebook.presto.RequestModifierModule; +import com.facebook.presto.ClientRequestFilterModule; import com.facebook.presto.dispatcher.QueryPrerequisitesManager; import com.facebook.presto.dispatcher.QueryPrerequisitesManagerModule; import com.facebook.presto.eventlistener.EventListenerManager; @@ -135,7 +135,7 @@ public void run() new QueryPrerequisitesManagerModule(), new NodeTtlFetcherManagerModule(), new ClusterTtlProviderManagerModule(), - new RequestModifierModule()); + new ClientRequestFilterModule()); modules.addAll(getAdditionalModules()); diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java b/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java index 17ce499f2fb3..c249c020fcd8 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java @@ -15,8 +15,8 @@ import com.facebook.airlift.http.server.AuthenticationException; import com.facebook.airlift.http.server.Authenticator; -import com.facebook.presto.RequestModifierManager; -import com.facebook.presto.spi.RequestModifier; +import com.facebook.presto.ClientRequestFilterManager; +import com.facebook.presto.spi.ClientRequestFilter; import com.google.common.base.Joiner; import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; @@ -45,6 +45,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import static com.google.common.io.ByteStreams.copy; import static com.google.common.io.ByteStreams.nullOutputStream; @@ -58,14 +59,14 @@ public class AuthenticationFilter private static final String HTTPS_PROTOCOL = "https"; private final List authenticators; private final boolean allowForwardedHttps; - private final RequestModifierManager requestModifierManager; + private final ClientRequestFilterManager clientRequestFilterManager; @Inject - public AuthenticationFilter(List authenticators, SecurityConfig securityConfig, RequestModifierManager requestModifierManager) + public AuthenticationFilter(List authenticators, SecurityConfig securityConfig, ClientRequestFilterManager clientRequestFilterManager) { this.authenticators = ImmutableList.copyOf(requireNonNull(authenticators, "authenticators is null")); this.allowForwardedHttps = requireNonNull(securityConfig, "securityConfig is null").getAllowForwardedHttps(); - this.requestModifierManager = requireNonNull(requestModifierManager, "requestModifierManager is null"); + this.clientRequestFilterManager = requireNonNull(clientRequestFilterManager, "clientRequestFilterManager is null"); } @Override @@ -104,15 +105,15 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo continue; } // authentication succeeded - CustomHttpServletRequestWrapper wrappedRequest = withPrincipal(request, principal); + CustomHttpServletRequestWrapper wrappedRequest = new CustomHttpServletRequestWrapper(request); Map extraHeadersMap = new HashMap<>(); - for (RequestModifier modifier : requestModifierManager.getRequestModifiers()) { - boolean headersPresent = modifier.getHeaderNames().stream() + for (ClientRequestFilter requestFilter : clientRequestFilterManager.getClientRequestFilters()) { + boolean headersPresent = requestFilter.getHeaderNames().stream() .allMatch(headerName -> request.getHeader(headerName) != null); if (!headersPresent) { - Optional> extraHeaderValueMap = modifier.getExtraHeaders(principal); + Optional> extraHeaderValueMap = requestFilter.getExtraHeaders(principal); extraHeaderValueMap.ifPresent(map -> { for (Map.Entry extraHeaderEntry : map.entrySet()) { @@ -124,7 +125,7 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo } } wrappedRequest.setHeaders(extraHeadersMap); - nextFilter.doFilter(wrappedRequest, response); + nextFilter.doFilter(withPrincipal(wrappedRequest, principal), response); return; } @@ -155,10 +156,17 @@ private boolean doesRequestSupportAuthentication(HttpServletRequest request) return false; } - public CustomHttpServletRequestWrapper withPrincipal(HttpServletRequest request, Principal principal) + public static ServletRequest withPrincipal(HttpServletRequest request, Principal principal) { requireNonNull(principal, "principal is null"); - return new CustomHttpServletRequestWrapper(request, principal); + return new HttpServletRequestWrapper(request) + { + @Override + public Principal getUserPrincipal() + { + return principal; + } + }; } private static void skipRequestBody(HttpServletRequest request) @@ -180,18 +188,10 @@ public static class CustomHttpServletRequestWrapper { private final Map customHeaders; - private final Principal principal; - - public CustomHttpServletRequestWrapper(HttpServletRequest request, Principal principal) + public CustomHttpServletRequestWrapper(HttpServletRequest request) { super(request); - this.principal = principal; - this.customHeaders = new HashMap<>(); - } - - public void addHeader(String name, String value) - { - customHeaders.put(name, value); + this.customHeaders = new ConcurrentHashMap<>(); } @Override @@ -224,12 +224,6 @@ public Enumeration getHeaders(String name) return super.getHeaders(name); } - @Override - public Principal getUserPrincipal() - { - return principal; - } - public void setHeaders(Map headers) { this.customHeaders.putAll(headers); diff --git a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java index be3141f8f947..a8f0cdbed990 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java @@ -32,8 +32,8 @@ import com.facebook.airlift.tracetoken.TraceTokenModule; import com.facebook.drift.server.DriftServer; import com.facebook.drift.transport.netty.server.DriftNettyServerTransport; -import com.facebook.presto.RequestModifierManager; -import com.facebook.presto.RequestModifierModule; +import com.facebook.presto.ClientRequestFilterManager; +import com.facebook.presto.ClientRequestFilterModule; import com.facebook.presto.connector.ConnectorManager; import com.facebook.presto.cost.StatsCalculator; import com.facebook.presto.dispatcher.DispatchManager; @@ -61,10 +61,10 @@ import com.facebook.presto.server.ServerMainModule; import com.facebook.presto.server.ShutdownAction; import com.facebook.presto.server.security.ServerSecurityModule; +import com.facebook.presto.spi.ClientRequestFilter; import com.facebook.presto.spi.ConnectorId; import com.facebook.presto.spi.Plugin; import com.facebook.presto.spi.QueryId; -import com.facebook.presto.spi.RequestModifier; import com.facebook.presto.spi.eventlistener.EventListener; import com.facebook.presto.spi.memory.ClusterMemoryPoolManager; import com.facebook.presto.spi.security.AccessControl; @@ -110,7 +110,6 @@ import java.io.UncheckedIOException; import java.net.URI; import java.nio.file.Path; -import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.LinkedHashSet; @@ -307,7 +306,7 @@ public TestingPrestoServer( .add(new QueryPrerequisitesManagerModule()) .add(new NodeTtlFetcherManagerModule()) .add(new ClusterTtlProviderManagerModule()) - .add(new RequestModifierModule()) + .add(new ClientRequestFilterModule()) .add(binder -> { binder.bind(TestingAccessControlManager.class).in(Scopes.SINGLETON); binder.bind(TestingEventListenerManager.class).in(Scopes.SINGLETON); @@ -831,24 +830,10 @@ private static int driftServerPort(DriftServer server) return ((DriftNettyServerTransport) server.getServerTransport()).getPort(); } - public RequestModifierManager getRequestModifierManager() + public static ClientRequestFilterManager getClientRequestFilterManager(ClientRequestFilter customModifier) { - RequestModifierManager manager = new RequestModifierManager(); - RequestModifier sampleModifier = new RequestModifier() { - @Override - public List getHeaderNames() - { - return Collections.singletonList("Extra-credential"); - } - @Override - public Optional> getExtraHeaders(T additionalInfo) - { - Map headers = new HashMap<>(); - headers.put("X-Custom-Header", "CustomValue"); - return Optional.of(headers); - } - }; - manager.registerRequestModifier(sampleModifier); + ClientRequestFilterManager manager = new ClientRequestFilterManager(); + manager.registerClientRequestFilter(customModifier); return manager; } } diff --git a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java index dd65e3f4d6a7..61d1390d57e8 100644 --- a/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java +++ b/presto-main/src/main/java/com/facebook/presto/testing/LocalQueryRunner.java @@ -14,9 +14,9 @@ package com.facebook.presto.testing; import com.facebook.airlift.node.NodeInfo; +import com.facebook.presto.ClientRequestFilterManager; import com.facebook.presto.GroupByHashPageIndexerFactory; import com.facebook.presto.PagesIndexPageSorter; -import com.facebook.presto.RequestModifierManager; import com.facebook.presto.Session; import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.client.NodeVersion; @@ -510,7 +510,7 @@ private LocalQueryRunner(Session defaultSession, FeaturesConfig featuresConfig, historyBasedPlanStatisticsManager, new TracerProviderManager(new TracingConfig()), new NodeStatusNotificationManager(), - new RequestModifierManager()); + new ClientRequestFilterManager()); connectorManager.addConnectorFactory(globalSystemConnectorFactory); connectorManager.createConnection(GlobalSystemConnector.NAME, GlobalSystemConnector.NAME, ImmutableMap.of()); diff --git a/presto-main/src/test/java/com/facebook/presto/TestRequestHeaderModifierPlugin.java b/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java similarity index 77% rename from presto-main/src/test/java/com/facebook/presto/TestRequestHeaderModifierPlugin.java rename to presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java index 38c9bb66555d..3e173080e028 100644 --- a/presto-main/src/test/java/com/facebook/presto/TestRequestHeaderModifierPlugin.java +++ b/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java @@ -15,7 +15,7 @@ import com.facebook.presto.server.MockHttpServletRequest; import com.facebook.presto.server.testing.TestingPrestoServer; -import com.facebook.presto.spi.RequestModifier; +import com.facebook.presto.spi.ClientRequestFilter; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ListMultimap; import org.testng.annotations.Test; @@ -24,29 +24,45 @@ import java.util.Collections; import java.util.Enumeration; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Optional; import static org.testng.Assert.assertEquals; -public class TestRequestHeaderModifierPlugin +public class TestClientRequestFilterPlugin { @Test public void testCustomRequestModifierWithHeaders() throws Exception { ConcreteHttpServletRequest request = new ConcreteHttpServletRequest(ImmutableListMultimap.of("X-Custom-Header1", "CustomValue1"), "http://request-modifier.com", Collections.singletonMap("attribute", "attribute1")); - TestingPrestoServer server = new TestingPrestoServer(); - RequestModifierManager requestModifierManager = server.getRequestModifierManager(); + ClientRequestFilter customModifier = new ClientRequestFilter() { + @Override + public List getHeaderNames() + { + return Collections.singletonList("Extra-credential"); + } + @Override + public Optional> getExtraHeaders(T additionalInfo) + { + Map headers = new HashMap<>(); + headers.put("X-Custom-Header", "CustomValue"); + return Optional.of(headers); + } + }; + + ClientRequestFilterManager clientRequestFilterManager = TestingPrestoServer.getClientRequestFilterManager(customModifier); + PrincipalStub testPrincipal = new PrincipalStub(); Map extraHeadersMap = new HashMap<>(); - for (RequestModifier modifier : requestModifierManager.getRequestModifiers()) { - boolean headersPresent = modifier.getHeaderNames().stream() + for (ClientRequestFilter requestFilter : clientRequestFilterManager.getClientRequestFilters()) { + boolean headersPresent = requestFilter.getHeaderNames().stream() .allMatch(headerName -> request.getHeader(headerName) != null); if (!headersPresent) { - Optional> extraHeaderValueMap = modifier.getExtraHeaders(testPrincipal); + Optional> extraHeaderValueMap = requestFilter.getExtraHeaders(testPrincipal); extraHeaderValueMap.ifPresent(map -> { for (Map.Entry extraHeaderEntry : map.entrySet()) { diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/RequestModifier.java b/presto-spi/src/main/java/com/facebook/presto/spi/ClientRequestFilter.java similarity index 95% rename from presto-spi/src/main/java/com/facebook/presto/spi/RequestModifier.java rename to presto-spi/src/main/java/com/facebook/presto/spi/ClientRequestFilter.java index aaf80e38470e..65cbf6ed32e5 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/RequestModifier.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/ClientRequestFilter.java @@ -17,7 +17,7 @@ import java.util.Map; import java.util.Optional; -public interface RequestModifier +public interface ClientRequestFilter { List getHeaderNames(); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java b/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java index bfef0ca25e89..659347b0d98c 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java @@ -137,7 +137,7 @@ default Iterable getNodeStatusNotificatio return emptyList(); } - default Iterable getRequestModifiers() + default Iterable getClientRequestFilters() { return emptyList(); } From 57d0d17cb013cbd1f66507c02d4d868e69366192 Mon Sep 17 00:00:00 2001 From: SthuthiGhosh9400 Date: Wed, 25 Sep 2024 23:27:11 +0530 Subject: [PATCH 08/13] Incorporated review comments --- .../server/security/AuthenticationFilter.java | 17 +- .../server/testing/TestingPrestoServer.java | 4 +- .../presto/TestClientRequestFilterPlugin.java | 186 +++++++++++++++++- 3 files changed, 194 insertions(+), 13 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java b/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java index 73b776199e3a..62f21c3862c7 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java @@ -62,6 +62,7 @@ public class AuthenticationFilter private final List authenticators; private final boolean allowForwardedHttps; private final ClientRequestFilterManager clientRequestFilterManager; + private final List headersBlockList = ImmutableList.of("X-Presto-Transaction-Id", "X-Presto-Started-Transaction-Id", "X-Presto-Clear-Transaction-Id", "X-Presto-Trace-Token"); @Inject public AuthenticationFilter(List authenticators, SecurityConfig securityConfig, ClientRequestFilterManager clientRequestFilterManager) @@ -109,6 +110,7 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo // authentication succeeded CustomHttpServletRequestWrapper wrappedRequest = new CustomHttpServletRequestWrapper(request); Map extraHeadersMap = new HashMap<>(); + Set globallyAddedHeaders = new HashSet<>(); for (ClientRequestFilter requestFilter : clientRequestFilterManager.getClientRequestFilters()) { boolean headersPresent = requestFilter.getHeaderNames().stream() @@ -119,8 +121,16 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo extraHeaderValueMap.ifPresent(map -> { for (Map.Entry extraHeaderEntry : map.entrySet()) { - if (request.getHeader(extraHeaderEntry.getKey()) == null) { - extraHeadersMap.putIfAbsent(extraHeaderEntry.getKey(), extraHeaderEntry.getValue()); + String headerKey = extraHeaderEntry.getKey(); + if (headersBlockList.contains(headerKey)) { + throw new RuntimeException("Modification attempt detected: The header " + headerKey + " is present in the blocked headers list."); + } + if (globallyAddedHeaders.contains(headerKey)) { + throw new RuntimeException("Header conflict detected: " + headerKey + " already added by another filter."); + } + if (request.getHeader(headerKey) == null && requestFilter.getHeaderNames().contains(headerKey)) { + extraHeadersMap.putIfAbsent(headerKey, extraHeaderEntry.getValue()); + globallyAddedHeaders.add(headerKey); } } }); @@ -200,12 +210,11 @@ private static void skipRequestBody(HttpServletRequest request) public static class CustomHttpServletRequestWrapper extends HttpServletRequestWrapper { - private final Map customHeaders; + private final Map customHeaders = new ConcurrentHashMap<>(); public CustomHttpServletRequestWrapper(HttpServletRequest request) { super(request); - this.customHeaders = new ConcurrentHashMap<>(); } @Override diff --git a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java index c11893491e08..31ad79fc9642 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/testing/TestingPrestoServer.java @@ -844,10 +844,10 @@ private static int driftServerPort(DriftServer server) return ((DriftNettyServerTransport) server.getServerTransport()).getPort(); } - public static ClientRequestFilterManager getClientRequestFilterManager(ClientRequestFilter customModifier) + public static ClientRequestFilterManager getClientRequestFilterManager(List requestFilters) { ClientRequestFilterManager manager = new ClientRequestFilterManager(); - manager.registerClientRequestFilter(customModifier); + requestFilters.forEach(manager::registerClientRequestFilter); return manager; } } diff --git a/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java b/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java index 3e173080e028..f3d81d348d98 100644 --- a/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java +++ b/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java @@ -16,31 +16,79 @@ import com.facebook.presto.server.MockHttpServletRequest; import com.facebook.presto.server.testing.TestingPrestoServer; import com.facebook.presto.spi.ClientRequestFilter; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ListMultimap; import org.testng.annotations.Test; import java.security.Principal; +import java.util.ArrayList; import java.util.Collections; import java.util.Enumeration; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import static org.testng.Assert.assertEquals; public class TestClientRequestFilterPlugin { + private final List headersBlockList = ImmutableList.of("X-Presto-Transaction-Id", "X-Presto-Started-Transaction-Id", "X-Presto-Clear-Transaction-Id", "X-Presto-Trace-Token"); + @Test - public void testCustomRequestModifierWithHeaders() throws Exception + public void testCustomRequestFilterWithHeaders() throws Exception { ConcreteHttpServletRequest request = new ConcreteHttpServletRequest(ImmutableListMultimap.of("X-Custom-Header1", "CustomValue1"), "http://request-modifier.com", Collections.singletonMap("attribute", "attribute1")); - ClientRequestFilter customModifier = new ClientRequestFilter() { + List requestFilters = getClientRequestFilter(); + + ClientRequestFilterManager clientRequestFilterManager = TestingPrestoServer.getClientRequestFilterManager(requestFilters); + + PrincipalStub testPrincipal = new PrincipalStub(); + + Map extraHeadersMap = new HashMap<>(); + Set globallyAddedHeaders = new HashSet<>(); + + for (ClientRequestFilter requestFilter : clientRequestFilterManager.getClientRequestFilters()) { + boolean headersPresent = requestFilter.getHeaderNames().stream() + .allMatch(headerName -> request.getHeader(headerName) != null); + + if (!headersPresent) { + Optional> extraHeaderValueMap = requestFilter.getExtraHeaders(testPrincipal); + + extraHeaderValueMap.ifPresent(map -> { + for (Map.Entry extraHeaderEntry : map.entrySet()) { + String headerKey = extraHeaderEntry.getKey(); + if (headersBlockList.contains(headerKey)) { + throw new RuntimeException("Modification attempt detected: The header " + headerKey + " is present in the blocked headers list."); + } + if (globallyAddedHeaders.contains(headerKey)) { + throw new RuntimeException("Header conflict detected: " + headerKey + " already added by another filter."); + } + + if (request.getHeader(headerKey) == null && requestFilter.getHeaderNames().contains(headerKey)) { + extraHeadersMap.putIfAbsent(headerKey, extraHeaderEntry.getValue()); + globallyAddedHeaders.add(headerKey); + } + } + }); + } + } + request.setHeaders(extraHeadersMap); + assertEquals("CustomValue", request.getHeader("X-Custom-Header")); + } + + private List getClientRequestFilter() + { + List requestFilters = new ArrayList<>(); + ClientRequestFilter customModifier = new ClientRequestFilter() + { @Override public List getHeaderNames() { - return Collections.singletonList("Extra-credential"); + return Collections.singletonList("X-Custom-Header"); } @Override public Optional> getExtraHeaders(T additionalInfo) @@ -50,12 +98,23 @@ public Optional> getExtraHeaders(T additionalInfo) return Optional.of(headers); } }; + requestFilters.add(customModifier); + return requestFilters; + } - ClientRequestFilterManager clientRequestFilterManager = TestingPrestoServer.getClientRequestFilterManager(customModifier); + @Test( + expectedExceptions = RuntimeException.class, + expectedExceptionsMessageRegExp = "Modification attempt detected: The header X-Presto-Transaction-Id is present in the blocked headers list.") + public void testCustomRequestFilterWithHeadersInBlockList() + { + ConcreteHttpServletRequest request = new ConcreteHttpServletRequest(ImmutableListMultimap.of("X-Custom-Header", "CustomValue1"), "http://request-modifier.com", Collections.singletonMap("attribute", "attribute1")); + List requestFilters = getClientRequestFilterInBlockList(); + ClientRequestFilterManager clientRequestFilterManager = TestingPrestoServer.getClientRequestFilterManager(requestFilters); PrincipalStub testPrincipal = new PrincipalStub(); Map extraHeadersMap = new HashMap<>(); + Set globallyAddedHeaders = new HashSet<>(); for (ClientRequestFilter requestFilter : clientRequestFilterManager.getClientRequestFilters()) { boolean headersPresent = requestFilter.getHeaderNames().stream() @@ -66,15 +125,128 @@ public Optional> getExtraHeaders(T additionalInfo) extraHeaderValueMap.ifPresent(map -> { for (Map.Entry extraHeaderEntry : map.entrySet()) { - if (request.getHeader(extraHeaderEntry.getKey()) == null) { - extraHeadersMap.putIfAbsent(extraHeaderEntry.getKey(), extraHeaderEntry.getValue()); + String headerKey = extraHeaderEntry.getKey(); + if (headersBlockList.contains(headerKey)) { + throw new RuntimeException("Modification attempt detected: The header " + headerKey + " is present in the blocked headers list."); + } + if (globallyAddedHeaders.contains(headerKey)) { + throw new RuntimeException("Header conflict detected: " + headerKey + " already added by another filter."); + } + + if (request.getHeader(headerKey) == null && requestFilter.getHeaderNames().contains(headerKey)) { + extraHeadersMap.putIfAbsent(headerKey, extraHeaderEntry.getValue()); + globallyAddedHeaders.add(headerKey); } } }); } } request.setHeaders(extraHeadersMap); - assertEquals("CustomValue", request.getHeader("X-Custom-Header")); + } + + private List getClientRequestFilterInBlockList() + { + List requestFilters = new ArrayList<>(); + ClientRequestFilter customModifier = new ClientRequestFilter() + { + @Override + public List getHeaderNames() + { + return Collections.singletonList("X-Presto-Transaction-Id"); + } + @Override + public Optional> getExtraHeaders(T additionalInfo) + { + Map headers = new HashMap<>(); + headers.put("X-Presto-Transaction-Id", "CustomValue"); + return Optional.of(headers); + } + }; + requestFilters.add(customModifier); + return requestFilters; + } + + @Test( + expectedExceptions = RuntimeException.class, + expectedExceptionsMessageRegExp = "Header conflict detected: X-Custom-Header already added by another filter.") + public void testCustomRequestFilterHandlesConflict() + { + ConcreteHttpServletRequest request = new ConcreteHttpServletRequest(ImmutableListMultimap.of("X-Custom-Header1", "CustomValue1"), "http://request-modifier.com", Collections.singletonMap("attribute", "attribute1")); + List requestFilters = getClientRequestFilters(); + + ClientRequestFilterManager clientRequestFilterManager = TestingPrestoServer.getClientRequestFilterManager(requestFilters); + + PrincipalStub testPrincipal = new PrincipalStub(); + + Map extraHeadersMap = new HashMap<>(); + Set globallyAddedHeaders = new HashSet<>(); + + for (ClientRequestFilter requestFilter : clientRequestFilterManager.getClientRequestFilters()) { + boolean headersPresent = requestFilter.getHeaderNames().stream() + .allMatch(headerName -> request.getHeader(headerName) != null); + + if (!headersPresent) { + Optional> extraHeaderValueMap = requestFilter.getExtraHeaders(testPrincipal); + + extraHeaderValueMap.ifPresent(map -> { + for (Map.Entry extraHeaderEntry : map.entrySet()) { + String headerKey = extraHeaderEntry.getKey(); + if (headersBlockList.contains(headerKey)) { + throw new RuntimeException("Modification attempt detected: The header " + headerKey + " is present in the blocked headers list."); + } + if (globallyAddedHeaders.contains(headerKey)) { + throw new RuntimeException("Header conflict detected: " + headerKey + " already added by another filter."); + } + + if (request.getHeader(headerKey) == null && requestFilter.getHeaderNames().contains(headerKey)) { + extraHeadersMap.putIfAbsent(headerKey, extraHeaderEntry.getValue()); + globallyAddedHeaders.add(headerKey); + } + } + }); + } + } + request.setHeaders(extraHeadersMap); + } + + private List getClientRequestFilters() + { + List requestFilters = new ArrayList<>(); + ClientRequestFilter customModifier = new ClientRequestFilter() + { + @Override + public List getHeaderNames() + { + return Collections.singletonList("X-Custom-Header"); + } + @Override + public Optional> getExtraHeaders(T additionalInfo) + { + Map headers = new HashMap<>(); + headers.put("X-Custom-Header", "CustomValue_1"); + return Optional.of(headers); + } + }; + + ClientRequestFilter customModifierConflict = new ClientRequestFilter() + { + @Override + public List getHeaderNames() + { + return Collections.singletonList("X-Custom-Header"); + } + @Override + public Optional> getExtraHeaders(T additionalInfo) + { + Map headers = new HashMap<>(); + headers.put("X-Custom-Header", "CustomValue_2"); + return Optional.of(headers); + } + }; + + requestFilters.add(customModifier); + requestFilters.add(customModifierConflict); + return requestFilters; } static class ConcreteHttpServletRequest From 7c28ec321972d88b207bff544bc50e7ed3e8628b Mon Sep 17 00:00:00 2001 From: SthuthiGhosh9400 Date: Thu, 26 Sep 2024 21:48:13 +0530 Subject: [PATCH 09/13] Incorporated review comments --- .../facebook/presto/ClientRequestFilterManager.java | 10 +++------- .../presto/server/security/AuthenticationFilter.java | 12 ++++++++---- .../com/facebook/presto/spi/StandardErrorCode.java | 1 + 3 files changed, 12 insertions(+), 11 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/ClientRequestFilterManager.java b/presto-main/src/main/java/com/facebook/presto/ClientRequestFilterManager.java index 65a642ec9908..bb2eae4ae59d 100644 --- a/presto-main/src/main/java/com/facebook/presto/ClientRequestFilterManager.java +++ b/presto-main/src/main/java/com/facebook/presto/ClientRequestFilterManager.java @@ -14,22 +14,18 @@ package com.facebook.presto; import com.facebook.presto.spi.ClientRequestFilter; +import com.google.common.collect.ImmutableList; -import java.util.Collections; import java.util.List; import java.util.concurrent.CopyOnWriteArrayList; public class ClientRequestFilterManager { - private final CopyOnWriteArrayList clientRequestFilters; - public ClientRequestFilterManager() - { - this.clientRequestFilters = new CopyOnWriteArrayList<>(); - } + private final List clientRequestFilters = new CopyOnWriteArrayList<>(); public List getClientRequestFilters() { - return Collections.unmodifiableList(clientRequestFilters); + return ImmutableList.copyOf(clientRequestFilters); } public void registerClientRequestFilter(ClientRequestFilter clientRequestFilter) diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java b/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java index 62f21c3862c7..96eabf9cb365 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java @@ -17,9 +17,11 @@ import com.facebook.airlift.http.server.Authenticator; import com.facebook.presto.ClientRequestFilterManager; import com.facebook.presto.spi.ClientRequestFilter; +import com.facebook.presto.spi.PrestoException; import com.google.common.base.Joiner; import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.common.net.HttpHeaders; import javax.inject.Inject; @@ -48,6 +50,7 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import static com.facebook.presto.spi.StandardErrorCode.HEADER_MODIFICATION_ATTEMPT; import static com.google.common.io.ByteStreams.copy; import static com.google.common.io.ByteStreams.nullOutputStream; import static com.google.common.net.HttpHeaders.WWW_AUTHENTICATE; @@ -123,7 +126,7 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo for (Map.Entry extraHeaderEntry : map.entrySet()) { String headerKey = extraHeaderEntry.getKey(); if (headersBlockList.contains(headerKey)) { - throw new RuntimeException("Modification attempt detected: The header " + headerKey + " is present in the blocked headers list."); + throw new PrestoException(HEADER_MODIFICATION_ATTEMPT, "Modification attempt detected: The header " + headerKey + " is present in the blocked headers list."); } if (globallyAddedHeaders.contains(headerKey)) { throw new RuntimeException("Header conflict detected: " + headerKey + " already added by another filter."); @@ -230,12 +233,13 @@ public String getHeader(String name) @Override public Enumeration getHeaderNames() { - Set headerNames = new HashSet<>(customHeaders.keySet()); + ImmutableSet.Builder headerNamesBuilder = ImmutableSet.builder(); + headerNamesBuilder.addAll(customHeaders.keySet()); Enumeration originalHeaderNames = super.getHeaderNames(); while (originalHeaderNames.hasMoreElements()) { - headerNames.add(originalHeaderNames.nextElement()); + headerNamesBuilder.add(originalHeaderNames.nextElement()); } - return Collections.enumeration(headerNames); + return Collections.enumeration(headerNamesBuilder.build()); } @Override diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/StandardErrorCode.java b/presto-spi/src/main/java/com/facebook/presto/spi/StandardErrorCode.java index 900a72526e54..894a8abea512 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/StandardErrorCode.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/StandardErrorCode.java @@ -140,6 +140,7 @@ public enum StandardErrorCode EXCEEDED_WRITTEN_INTERMEDIATE_BYTES_LIMIT(0x0002_0012, INSUFFICIENT_RESOURCES), TOO_MANY_SIDECARS(0x0002_0013, INTERNAL_ERROR), NO_CPP_SIDECARS(0x0002_0014, INTERNAL_ERROR), + HEADER_MODIFICATION_ATTEMPT(0x0002_0015, INTERNAL_ERROR), /**/; // Error code range 0x0003 is reserved for Presto-on-Spark From dc15a70ee028cb5f2bb78ff224c0cf8d0c945518 Mon Sep 17 00:00:00 2001 From: SthuthiGhosh9400 Date: Wed, 9 Oct 2024 01:21:48 +0530 Subject: [PATCH 10/13] incorporated review comments --- .../presto/TestClientRequestFilterPlugin.java | 34 ++++++++----------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java b/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java index f3d81d348d98..7bf483a19e38 100644 --- a/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java +++ b/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java @@ -212,40 +212,36 @@ public void testCustomRequestFilterHandlesConflict() private List getClientRequestFilters() { List requestFilters = new ArrayList<>(); - ClientRequestFilter customModifier = new ClientRequestFilter() + + class CustomHeaderFilter + implements ClientRequestFilter { - @Override - public List getHeaderNames() + private final String headerName; + private final String headerValue; + + public CustomHeaderFilter(String headerName, String headerValue) { - return Collections.singletonList("X-Custom-Header"); + this.headerName = headerName; + this.headerValue = headerValue; } - @Override - public Optional> getExtraHeaders(T additionalInfo) - { - Map headers = new HashMap<>(); - headers.put("X-Custom-Header", "CustomValue_1"); - return Optional.of(headers); - } - }; - ClientRequestFilter customModifierConflict = new ClientRequestFilter() - { @Override public List getHeaderNames() { - return Collections.singletonList("X-Custom-Header"); + return Collections.singletonList(headerName); } + @Override public Optional> getExtraHeaders(T additionalInfo) { Map headers = new HashMap<>(); - headers.put("X-Custom-Header", "CustomValue_2"); + headers.put(headerName, headerValue); return Optional.of(headers); } - }; + } + requestFilters.add(new CustomHeaderFilter("X-Custom-Header", "CustomValue_1")); + requestFilters.add(new CustomHeaderFilter("X-Custom-Header", "CustomValue_2")); - requestFilters.add(customModifier); - requestFilters.add(customModifierConflict); return requestFilters; } From 0c2092ec99b0a9fe2f6d5509576783d4f47f0760 Mon Sep 17 00:00:00 2001 From: SthuthiGhosh9400 Date: Fri, 11 Oct 2024 04:22:48 +0530 Subject: [PATCH 11/13] Refactored code for test cases --- .../server/security/AuthenticationFilter.java | 67 +++--- .../presto/TestClientRequestFilterPlugin.java | 218 +++++++----------- 2 files changed, 120 insertions(+), 165 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java b/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java index 96eabf9cb365..abf26b6a5bb6 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/AuthenticationFilter.java @@ -111,35 +111,7 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo continue; } // authentication succeeded - CustomHttpServletRequestWrapper wrappedRequest = new CustomHttpServletRequestWrapper(request); - Map extraHeadersMap = new HashMap<>(); - Set globallyAddedHeaders = new HashSet<>(); - - for (ClientRequestFilter requestFilter : clientRequestFilterManager.getClientRequestFilters()) { - boolean headersPresent = requestFilter.getHeaderNames().stream() - .allMatch(headerName -> request.getHeader(headerName) != null); - - if (!headersPresent) { - Optional> extraHeaderValueMap = requestFilter.getExtraHeaders(principal); - - extraHeaderValueMap.ifPresent(map -> { - for (Map.Entry extraHeaderEntry : map.entrySet()) { - String headerKey = extraHeaderEntry.getKey(); - if (headersBlockList.contains(headerKey)) { - throw new PrestoException(HEADER_MODIFICATION_ATTEMPT, "Modification attempt detected: The header " + headerKey + " is present in the blocked headers list."); - } - if (globallyAddedHeaders.contains(headerKey)) { - throw new RuntimeException("Header conflict detected: " + headerKey + " already added by another filter."); - } - if (request.getHeader(headerKey) == null && requestFilter.getHeaderNames().contains(headerKey)) { - extraHeadersMap.putIfAbsent(headerKey, extraHeaderEntry.getValue()); - globallyAddedHeaders.add(headerKey); - } - } - }); - } - } - wrappedRequest.setHeaders(extraHeadersMap); + CustomHttpServletRequestWrapper wrappedRequest = mergeExtraHeaders(request, principal); nextFilter.doFilter(withPrincipal(wrappedRequest, principal), response); return; } @@ -169,6 +141,43 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo } } + public CustomHttpServletRequestWrapper mergeExtraHeaders(HttpServletRequest request, Principal principal) + { + CustomHttpServletRequestWrapper wrappedRequest = new CustomHttpServletRequestWrapper(request); + Map extraHeadersMap = new HashMap<>(); + Set globallyAddedHeaders = new HashSet<>(); + + for (ClientRequestFilter requestFilter : clientRequestFilterManager.getClientRequestFilters()) { + boolean headersPresent = requestFilter.getHeaderNames().stream() + .allMatch(headerName -> request.getHeader(headerName) != null); + + if (!headersPresent) { + Optional> extraHeaderValueMap = requestFilter.getExtraHeaders(principal); + + extraHeaderValueMap.ifPresent(map -> { + for (Map.Entry extraHeaderEntry : map.entrySet()) { + String headerKey = extraHeaderEntry.getKey(); + if (headersBlockList.contains(headerKey)) { + throw new PrestoException(HEADER_MODIFICATION_ATTEMPT, + "Modification attempt detected: The header " + headerKey + " is not allowed to be modified. The following headers cannot be modified: " + + String.join(", ", headersBlockList)); + } + if (globallyAddedHeaders.contains(headerKey)) { + throw new PrestoException(HEADER_MODIFICATION_ATTEMPT, "Header conflict detected: " + headerKey + " already added by another filter."); + } + if (request.getHeader(headerKey) == null && requestFilter.getHeaderNames().contains(headerKey)) { + extraHeadersMap.putIfAbsent(headerKey, extraHeaderEntry.getValue()); + globallyAddedHeaders.add(headerKey); + } + } + }); + } + } + + wrappedRequest.setHeaders(extraHeadersMap); + return wrappedRequest; + } + private boolean doesRequestSupportAuthentication(HttpServletRequest request) { if (authenticators.isEmpty()) { diff --git a/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java b/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java index 7bf483a19e38..d9a7774ef785 100644 --- a/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java +++ b/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java @@ -13,10 +13,12 @@ */ package com.facebook.presto; +import com.facebook.airlift.http.server.Authenticator; import com.facebook.presto.server.MockHttpServletRequest; +import com.facebook.presto.server.security.AuthenticationFilter; +import com.facebook.presto.server.security.SecurityConfig; import com.facebook.presto.server.testing.TestingPrestoServer; import com.facebook.presto.spi.ClientRequestFilter; -import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ListMultimap; import org.testng.annotations.Test; @@ -26,60 +28,14 @@ import java.util.Collections; import java.util.Enumeration; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; import static org.testng.Assert.assertEquals; public class TestClientRequestFilterPlugin { - private final List headersBlockList = ImmutableList.of("X-Presto-Transaction-Id", "X-Presto-Started-Transaction-Id", "X-Presto-Clear-Transaction-Id", "X-Presto-Trace-Token"); - - @Test - public void testCustomRequestFilterWithHeaders() throws Exception - { - ConcreteHttpServletRequest request = new ConcreteHttpServletRequest(ImmutableListMultimap.of("X-Custom-Header1", "CustomValue1"), "http://request-modifier.com", Collections.singletonMap("attribute", "attribute1")); - List requestFilters = getClientRequestFilter(); - - ClientRequestFilterManager clientRequestFilterManager = TestingPrestoServer.getClientRequestFilterManager(requestFilters); - - PrincipalStub testPrincipal = new PrincipalStub(); - - Map extraHeadersMap = new HashMap<>(); - Set globallyAddedHeaders = new HashSet<>(); - - for (ClientRequestFilter requestFilter : clientRequestFilterManager.getClientRequestFilters()) { - boolean headersPresent = requestFilter.getHeaderNames().stream() - .allMatch(headerName -> request.getHeader(headerName) != null); - - if (!headersPresent) { - Optional> extraHeaderValueMap = requestFilter.getExtraHeaders(testPrincipal); - - extraHeaderValueMap.ifPresent(map -> { - for (Map.Entry extraHeaderEntry : map.entrySet()) { - String headerKey = extraHeaderEntry.getKey(); - if (headersBlockList.contains(headerKey)) { - throw new RuntimeException("Modification attempt detected: The header " + headerKey + " is present in the blocked headers list."); - } - if (globallyAddedHeaders.contains(headerKey)) { - throw new RuntimeException("Header conflict detected: " + headerKey + " already added by another filter."); - } - - if (request.getHeader(headerKey) == null && requestFilter.getHeaderNames().contains(headerKey)) { - extraHeadersMap.putIfAbsent(headerKey, extraHeaderEntry.getValue()); - globallyAddedHeaders.add(headerKey); - } - } - }); - } - } - request.setHeaders(extraHeadersMap); - assertEquals("CustomValue", request.getHeader("X-Custom-Header")); - } - private List getClientRequestFilter() { List requestFilters = new ArrayList<>(); @@ -88,13 +44,13 @@ private List getClientRequestFilter() @Override public List getHeaderNames() { - return Collections.singletonList("X-Custom-Header"); + return Collections.singletonList("ExpectedExtraHeader"); } @Override public Optional> getExtraHeaders(T additionalInfo) { Map headers = new HashMap<>(); - headers.put("X-Custom-Header", "CustomValue"); + headers.put("ExpectedExtraHeader", "ExpectedExtraValue"); return Optional.of(headers); } }; @@ -102,48 +58,6 @@ public Optional> getExtraHeaders(T additionalInfo) return requestFilters; } - @Test( - expectedExceptions = RuntimeException.class, - expectedExceptionsMessageRegExp = "Modification attempt detected: The header X-Presto-Transaction-Id is present in the blocked headers list.") - public void testCustomRequestFilterWithHeadersInBlockList() - { - ConcreteHttpServletRequest request = new ConcreteHttpServletRequest(ImmutableListMultimap.of("X-Custom-Header", "CustomValue1"), "http://request-modifier.com", Collections.singletonMap("attribute", "attribute1")); - List requestFilters = getClientRequestFilterInBlockList(); - ClientRequestFilterManager clientRequestFilterManager = TestingPrestoServer.getClientRequestFilterManager(requestFilters); - - PrincipalStub testPrincipal = new PrincipalStub(); - - Map extraHeadersMap = new HashMap<>(); - Set globallyAddedHeaders = new HashSet<>(); - - for (ClientRequestFilter requestFilter : clientRequestFilterManager.getClientRequestFilters()) { - boolean headersPresent = requestFilter.getHeaderNames().stream() - .allMatch(headerName -> request.getHeader(headerName) != null); - - if (!headersPresent) { - Optional> extraHeaderValueMap = requestFilter.getExtraHeaders(testPrincipal); - - extraHeaderValueMap.ifPresent(map -> { - for (Map.Entry extraHeaderEntry : map.entrySet()) { - String headerKey = extraHeaderEntry.getKey(); - if (headersBlockList.contains(headerKey)) { - throw new RuntimeException("Modification attempt detected: The header " + headerKey + " is present in the blocked headers list."); - } - if (globallyAddedHeaders.contains(headerKey)) { - throw new RuntimeException("Header conflict detected: " + headerKey + " already added by another filter."); - } - - if (request.getHeader(headerKey) == null && requestFilter.getHeaderNames().contains(headerKey)) { - extraHeadersMap.putIfAbsent(headerKey, extraHeaderEntry.getValue()); - globallyAddedHeaders.add(headerKey); - } - } - }); - } - } - request.setHeaders(extraHeadersMap); - } - private List getClientRequestFilterInBlockList() { List requestFilters = new ArrayList<>(); @@ -166,49 +80,6 @@ public Optional> getExtraHeaders(T additionalInfo) return requestFilters; } - @Test( - expectedExceptions = RuntimeException.class, - expectedExceptionsMessageRegExp = "Header conflict detected: X-Custom-Header already added by another filter.") - public void testCustomRequestFilterHandlesConflict() - { - ConcreteHttpServletRequest request = new ConcreteHttpServletRequest(ImmutableListMultimap.of("X-Custom-Header1", "CustomValue1"), "http://request-modifier.com", Collections.singletonMap("attribute", "attribute1")); - List requestFilters = getClientRequestFilters(); - - ClientRequestFilterManager clientRequestFilterManager = TestingPrestoServer.getClientRequestFilterManager(requestFilters); - - PrincipalStub testPrincipal = new PrincipalStub(); - - Map extraHeadersMap = new HashMap<>(); - Set globallyAddedHeaders = new HashSet<>(); - - for (ClientRequestFilter requestFilter : clientRequestFilterManager.getClientRequestFilters()) { - boolean headersPresent = requestFilter.getHeaderNames().stream() - .allMatch(headerName -> request.getHeader(headerName) != null); - - if (!headersPresent) { - Optional> extraHeaderValueMap = requestFilter.getExtraHeaders(testPrincipal); - - extraHeaderValueMap.ifPresent(map -> { - for (Map.Entry extraHeaderEntry : map.entrySet()) { - String headerKey = extraHeaderEntry.getKey(); - if (headersBlockList.contains(headerKey)) { - throw new RuntimeException("Modification attempt detected: The header " + headerKey + " is present in the blocked headers list."); - } - if (globallyAddedHeaders.contains(headerKey)) { - throw new RuntimeException("Header conflict detected: " + headerKey + " already added by another filter."); - } - - if (request.getHeader(headerKey) == null && requestFilter.getHeaderNames().contains(headerKey)) { - extraHeadersMap.putIfAbsent(headerKey, extraHeaderEntry.getValue()); - globallyAddedHeaders.add(headerKey); - } - } - }); - } - } - request.setHeaders(extraHeadersMap); - } - private List getClientRequestFilters() { List requestFilters = new ArrayList<>(); @@ -239,8 +110,8 @@ public Optional> getExtraHeaders(T additionalInfo) return Optional.of(headers); } } - requestFilters.add(new CustomHeaderFilter("X-Custom-Header", "CustomValue_1")); - requestFilters.add(new CustomHeaderFilter("X-Custom-Header", "CustomValue_2")); + requestFilters.add(new CustomHeaderFilter("ExpectedExtraValue", "ExpectedExtraHeader_1")); + requestFilters.add(new CustomHeaderFilter("ExpectedExtraValue", "ExpectedExtraHeader_2")); return requestFilters; } @@ -291,4 +162,79 @@ public String getName() return "TestPrincipal"; } } + + @Test + public void testCustomRequestFilterWithHeaders() + { + ConcreteHttpServletRequest request = createRequest(); + List requestFilters = getClientRequestFilter(); + AuthenticationFilter filter = setupAuthenticationFilter(requestFilters); + PrincipalStub testPrincipal = new PrincipalStub(); + + AuthenticationFilter.CustomHttpServletRequestWrapper wrappedRequest = filter.collectExtraHeadersMap(request, testPrincipal); + + assertEquals("CustomValue", wrappedRequest.getHeader("X-Custom-Header")); + assertEquals("ExpectedExtraValue", wrappedRequest.getHeader("ExpectedExtraHeader")); + } + + @Test( + expectedExceptions = RuntimeException.class, + expectedExceptionsMessageRegExp = "Modification attempt detected: The header X-Presto-Transaction-Id is not allowed to be modified. The following headers cannot be modified: " + + "X-Presto-Transaction-Id, X-Presto-Started-Transaction-Id, X-Presto-Clear-Transaction-Id, X-Presto-Trace-Token") + public void testCustomRequestFilterWithHeadersInBlockList() + { + ConcreteHttpServletRequest request = createRequest(); + List requestFilters = getClientRequestFilterInBlockList(); + AuthenticationFilter filter = setupAuthenticationFilter(requestFilters); + PrincipalStub testPrincipal = new PrincipalStub(); + + filter.collectExtraHeadersMap(request, testPrincipal); + } + + @Test( + expectedExceptions = RuntimeException.class, + expectedExceptionsMessageRegExp = "Header conflict detected: ExpectedExtraValue already added by another filter.") + public void testCustomRequestFilterHandlesConflict() + { + ConcreteHttpServletRequest request = createRequest(); + List requestFilters = getClientRequestFilters(); + AuthenticationFilter filter = setupAuthenticationFilter(requestFilters); + PrincipalStub testPrincipal = new PrincipalStub(); + + filter.collectExtraHeadersMap(request, testPrincipal); + } + + private AuthenticationFilter setupAuthenticationFilter(List requestFilters) + { + ClientRequestFilterManager clientRequestFilterManager = TestingPrestoServer.getClientRequestFilterManager(requestFilters); + + List authenticators = createAuthenticators(); + SecurityConfig securityConfig = createSecurityConfig(); + + return new AuthenticationFilter(authenticators, securityConfig, clientRequestFilterManager); + } + + private ConcreteHttpServletRequest createRequest() + { + return new ConcreteHttpServletRequest( + ImmutableListMultimap.of("X-Custom-Header", "CustomValue"), + "http://request-modifier.com", + Collections.singletonMap("attribute", "attribute1")); + } + + private List createAuthenticators() + { + return Collections.emptyList(); + } + + private SecurityConfig createSecurityConfig() + { + return new SecurityConfig() { + @Override + public boolean getAllowForwardedHttps() + { + return true; + } + }; + } } From 3a2e4ed704286e6efe0cf58f3fd11b3b95af50b5 Mon Sep 17 00:00:00 2001 From: SthuthiGhosh9400 Date: Fri, 11 Oct 2024 04:23:41 +0530 Subject: [PATCH 12/13] Renamed a method --- .../com/facebook/presto/TestClientRequestFilterPlugin.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java b/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java index d9a7774ef785..be85bdf53887 100644 --- a/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java +++ b/presto-main/src/test/java/com/facebook/presto/TestClientRequestFilterPlugin.java @@ -171,7 +171,7 @@ public void testCustomRequestFilterWithHeaders() AuthenticationFilter filter = setupAuthenticationFilter(requestFilters); PrincipalStub testPrincipal = new PrincipalStub(); - AuthenticationFilter.CustomHttpServletRequestWrapper wrappedRequest = filter.collectExtraHeadersMap(request, testPrincipal); + AuthenticationFilter.CustomHttpServletRequestWrapper wrappedRequest = filter.mergeExtraHeaders(request, testPrincipal); assertEquals("CustomValue", wrappedRequest.getHeader("X-Custom-Header")); assertEquals("ExpectedExtraValue", wrappedRequest.getHeader("ExpectedExtraHeader")); @@ -188,7 +188,7 @@ public void testCustomRequestFilterWithHeadersInBlockList() AuthenticationFilter filter = setupAuthenticationFilter(requestFilters); PrincipalStub testPrincipal = new PrincipalStub(); - filter.collectExtraHeadersMap(request, testPrincipal); + filter.mergeExtraHeaders(request, testPrincipal); } @Test( @@ -201,7 +201,7 @@ public void testCustomRequestFilterHandlesConflict() AuthenticationFilter filter = setupAuthenticationFilter(requestFilters); PrincipalStub testPrincipal = new PrincipalStub(); - filter.collectExtraHeadersMap(request, testPrincipal); + filter.mergeExtraHeaders(request, testPrincipal); } private AuthenticationFilter setupAuthenticationFilter(List requestFilters) From 77e4df062e79893efe090228eb82544ce357524d Mon Sep 17 00:00:00 2001 From: SthuthiGhosh9400 Date: Fri, 11 Oct 2024 04:51:13 +0530 Subject: [PATCH 13/13] corrected checkstyle --- .../src/main/java/com/facebook/presto/server/PluginManager.java | 2 +- presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java b/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java index e10ca5d09d11..f8d286f3108c 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java +++ b/presto-main/src/main/java/com/facebook/presto/server/PluginManager.java @@ -363,7 +363,7 @@ public void installPlugin(Plugin plugin) log.info("Registering plan checker provider factory %s", planCheckerProviderFactory.getName()); planCheckerProviderManager.addPlanCheckerProviderFactory(planCheckerProviderFactory); } - + for (ClientRequestFilter clientRequestFilter : plugin.getClientRequestFilters()) { log.info("Registering client request filter"); clientRequestFilterManager.registerClientRequestFilter(clientRequestFilter); diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java b/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java index 691655ad019f..34db9e031963 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/Plugin.java @@ -148,7 +148,7 @@ default Iterable getPlanCheckerProviderFactories() { return emptyList(); } - + default Iterable getClientRequestFilters() { return emptyList();