-
Notifications
You must be signed in to change notification settings - Fork 5.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introducing ClientRequestFilter.java: A New Plugin for Applying Request Headers in the Authentication Filter Class #23380
base: master
Are you sure you want to change the base?
Changes from 10 commits
55c4278
fb5d812
0b6e62c
a3d875d
87c12ec
1506807
0def3a4
5639435
57d0d17
7c28ec3
dc15a70
09a5031
0c2092e
3a2e4ed
dd1dd36
77e4df0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
/* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package com.facebook.presto; | ||
|
||
import com.facebook.presto.spi.ClientRequestFilter; | ||
import com.google.common.collect.ImmutableList; | ||
|
||
import java.util.List; | ||
import java.util.concurrent.CopyOnWriteArrayList; | ||
|
||
public class ClientRequestFilterManager | ||
{ | ||
private final List<ClientRequestFilter> clientRequestFilters = new CopyOnWriteArrayList<>(); | ||
|
||
public List<ClientRequestFilter> getClientRequestFilters() | ||
{ | ||
return ImmutableList.copyOf(clientRequestFilters); | ||
} | ||
|
||
public void registerClientRequestFilter(ClientRequestFilter clientRequestFilter) | ||
{ | ||
clientRequestFilters.add(clientRequestFilter); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
/* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package com.facebook.presto; | ||
|
||
import com.google.inject.Binder; | ||
import com.google.inject.Module; | ||
import com.google.inject.Scopes; | ||
|
||
public class ClientRequestFilterModule | ||
implements Module | ||
{ | ||
@Override | ||
public void configure(Binder binder) | ||
{ | ||
binder.bind(ClientRequestFilterManager.class).in(Scopes.SINGLETON); | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,9 +15,13 @@ | |
|
||
import com.facebook.airlift.http.server.AuthenticationException; | ||
import com.facebook.airlift.http.server.Authenticator; | ||
import com.facebook.presto.ClientRequestFilterManager; | ||
import com.facebook.presto.spi.ClientRequestFilter; | ||
import com.facebook.presto.spi.PrestoException; | ||
import com.google.common.base.Joiner; | ||
import com.google.common.base.Strings; | ||
import com.google.common.collect.ImmutableList; | ||
import com.google.common.collect.ImmutableSet; | ||
import com.google.common.net.HttpHeaders; | ||
|
||
import javax.inject.Inject; | ||
|
@@ -35,10 +39,18 @@ | |
import java.io.InputStream; | ||
import java.io.PrintWriter; | ||
import java.security.Principal; | ||
import java.util.Collections; | ||
import java.util.Enumeration; | ||
import java.util.HashMap; | ||
import java.util.HashSet; | ||
import java.util.LinkedHashSet; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Optional; | ||
import java.util.Set; | ||
import java.util.concurrent.ConcurrentHashMap; | ||
|
||
import static com.facebook.presto.spi.StandardErrorCode.HEADER_MODIFICATION_ATTEMPT; | ||
import static com.google.common.io.ByteStreams.copy; | ||
import static com.google.common.io.ByteStreams.nullOutputStream; | ||
import static com.google.common.net.HttpHeaders.WWW_AUTHENTICATE; | ||
|
@@ -52,12 +64,15 @@ public class AuthenticationFilter | |
private static final String HTTPS_PROTOCOL = "https"; | ||
private final List<Authenticator> authenticators; | ||
private final boolean allowForwardedHttps; | ||
private final ClientRequestFilterManager clientRequestFilterManager; | ||
private final List<String> headersBlockList = ImmutableList.of("X-Presto-Transaction-Id", "X-Presto-Started-Transaction-Id", "X-Presto-Clear-Transaction-Id", "X-Presto-Trace-Token"); | ||
|
||
@Inject | ||
public AuthenticationFilter(List<Authenticator> authenticators, SecurityConfig securityConfig) | ||
public AuthenticationFilter(List<Authenticator> authenticators, SecurityConfig securityConfig, ClientRequestFilterManager clientRequestFilterManager) | ||
{ | ||
this.authenticators = ImmutableList.copyOf(requireNonNull(authenticators, "authenticators is null")); | ||
this.allowForwardedHttps = requireNonNull(securityConfig, "securityConfig is null").getAllowForwardedHttps(); | ||
this.clientRequestFilterManager = requireNonNull(clientRequestFilterManager, "clientRequestFilterManager is null"); | ||
} | ||
|
||
@Override | ||
|
@@ -95,9 +110,37 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo | |
e.getAuthenticateHeader().ifPresent(authenticateHeaders::add); | ||
continue; | ||
} | ||
|
||
// authentication succeeded | ||
nextFilter.doFilter(withPrincipal(request, principal), response); | ||
CustomHttpServletRequestWrapper wrappedRequest = new CustomHttpServletRequestWrapper(request); | ||
Map<String, String> extraHeadersMap = new HashMap<>(); | ||
Set<String> globallyAddedHeaders = new HashSet<>(); | ||
|
||
for (ClientRequestFilter requestFilter : clientRequestFilterManager.getClientRequestFilters()) { | ||
boolean headersPresent = requestFilter.getHeaderNames().stream() | ||
.allMatch(headerName -> request.getHeader(headerName) != null); | ||
|
||
if (!headersPresent) { | ||
Optional<Map<String, String>> extraHeaderValueMap = requestFilter.getExtraHeaders(principal); | ||
|
||
extraHeaderValueMap.ifPresent(map -> { | ||
for (Map.Entry<String, String> extraHeaderEntry : map.entrySet()) { | ||
String headerKey = extraHeaderEntry.getKey(); | ||
if (headersBlockList.contains(headerKey)) { | ||
throw new PrestoException(HEADER_MODIFICATION_ATTEMPT, "Modification attempt detected: The header " + headerKey + " is present in the blocked headers list."); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of mentioning a "blocked headers list", just list out the headers that are not allowed to be modified. |
||
} | ||
if (globallyAddedHeaders.contains(headerKey)) { | ||
throw new RuntimeException("Header conflict detected: " + headerKey + " already added by another filter."); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do not throw RuntimeException, throw PrestoException. |
||
} | ||
if (request.getHeader(headerKey) == null && requestFilter.getHeaderNames().contains(headerKey)) { | ||
extraHeadersMap.putIfAbsent(headerKey, extraHeaderEntry.getValue()); | ||
globallyAddedHeaders.add(headerKey); | ||
} | ||
} | ||
}); | ||
} | ||
} | ||
wrappedRequest.setHeaders(extraHeadersMap); | ||
nextFilter.doFilter(withPrincipal(wrappedRequest, principal), response); | ||
return; | ||
} | ||
|
||
|
@@ -140,7 +183,7 @@ private boolean doesRequestSupportAuthentication(HttpServletRequest request) | |
return false; | ||
} | ||
|
||
private static ServletRequest withPrincipal(HttpServletRequest request, Principal principal) | ||
public static ServletRequest withPrincipal(HttpServletRequest request, Principal principal) | ||
{ | ||
requireNonNull(principal, "principal is null"); | ||
return new HttpServletRequestWrapper(request) | ||
|
@@ -166,4 +209,51 @@ private static void skipRequestBody(HttpServletRequest request) | |
copy(inputStream, nullOutputStream()); | ||
} | ||
} | ||
|
||
public static class CustomHttpServletRequestWrapper | ||
extends HttpServletRequestWrapper | ||
{ | ||
private final Map<String, String> customHeaders = new ConcurrentHashMap<>(); | ||
|
||
public CustomHttpServletRequestWrapper(HttpServletRequest request) | ||
{ | ||
super(request); | ||
} | ||
|
||
@Override | ||
public String getHeader(String name) | ||
{ | ||
String headerValue = customHeaders.get(name); | ||
if (headerValue != null) { | ||
return headerValue; | ||
} | ||
return super.getHeader(name); | ||
} | ||
|
||
@Override | ||
public Enumeration<String> getHeaderNames() | ||
{ | ||
ImmutableSet.Builder<String> headerNamesBuilder = ImmutableSet.builder(); | ||
headerNamesBuilder.addAll(customHeaders.keySet()); | ||
Enumeration<String> originalHeaderNames = super.getHeaderNames(); | ||
while (originalHeaderNames.hasMoreElements()) { | ||
headerNamesBuilder.add(originalHeaderNames.nextElement()); | ||
} | ||
return Collections.enumeration(headerNamesBuilder.build()); | ||
} | ||
|
||
@Override | ||
public Enumeration<String> getHeaders(String name) | ||
{ | ||
if (customHeaders.containsKey(name)) { | ||
return Collections.enumeration(Collections.singleton(customHeaders.get(name))); | ||
} | ||
return super.getHeaders(name); | ||
} | ||
|
||
public void setHeaders(Map<String, String> headers) | ||
{ | ||
this.customHeaders.putAll(headers); | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We do not use Mockito in this project. Can you please refactor your code to avoid its use?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tdcmeehan
Instead of using Mockito to create mock objects and define their behavior, can I manually write simple classes (stubs) that simulate the behavior of the actual classes?
Do you have any suggestion here in order to verify the codebase if value correctly passed to request header?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that is how we go about this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@tdcmeehan
I have refactored the code that was added for the test case.