Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add expiry check to DataPlaneTokenRefreshServiceImpl #1124

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.eclipse.edc.runtime.metamodel.annotation.Extension;
import org.eclipse.edc.runtime.metamodel.annotation.Inject;
import org.eclipse.edc.runtime.metamodel.annotation.Provider;
import org.eclipse.edc.runtime.metamodel.annotation.Setting;
import org.eclipse.edc.spi.security.PrivateKeyResolver;
import org.eclipse.edc.spi.system.ServiceExtension;
import org.eclipse.edc.spi.system.ServiceExtensionContext;
Expand All @@ -34,6 +35,7 @@
import org.jetbrains.annotations.NotNull;

import java.security.PrivateKey;
import java.time.Clock;
import java.util.function.Supplier;

import static org.eclipse.edc.connector.dataplane.spi.TransferDataPlaneConfig.TOKEN_SIGNER_PRIVATE_KEY_ALIAS;
Expand All @@ -42,15 +44,19 @@
@Extension(value = NAME)
public class DataPlaneTokenRefreshServiceExtension implements ServiceExtension {
public static final String NAME = "DataPlane Token Refresh Service extension";
public static final int DEFAULT_TOKEN_EXPIRY_TOLERANCE_SECONDS = 5;
@Setting(value = "Token expiry tolerance period in seconds to allow for clock skew", defaultValue = "" + DEFAULT_TOKEN_EXPIRY_TOLERANCE_SECONDS)
public static final String TOKEN_EXPIRY_TOLERANCE_SECONDS_PROPERTY = "edc.dataplane.api.token.expiry.tolerance";
@Inject
private TokenValidationService tokenValidationService;
@Inject
private DidPublicKeyResolver didPkResolver;
@Inject
private AccessTokenDataStore accessTokenDataStore;

@Inject
private PrivateKeyResolver privateKeyResolver;
@Inject
private Clock clock;
private DataPlaneTokenRefreshServiceImpl tokenRefreshService;

@Override
Expand All @@ -73,7 +79,8 @@ public DataPlaneTokenRefreshService createRefreshTokenService(ServiceExtensionCo
@NotNull
private DataPlaneTokenRefreshServiceImpl getTokenRefreshService(ServiceExtensionContext context) {
if (tokenRefreshService == null) {
tokenRefreshService = new DataPlaneTokenRefreshServiceImpl(tokenValidationService, didPkResolver, accessTokenDataStore, new JwtGenerationService(), getPrivateKeySupplier(context), context.getMonitor(), "foo.bar");
var epsilon = context.getConfig().getInteger(TOKEN_EXPIRY_TOLERANCE_SECONDS_PROPERTY, DEFAULT_TOKEN_EXPIRY_TOLERANCE_SECONDS);
tokenRefreshService = new DataPlaneTokenRefreshServiceImpl(clock, tokenValidationService, didPkResolver, accessTokenDataStore, new JwtGenerationService(), getPrivateKeySupplier(context), context.getMonitor(), null, epsilon);
}
return tokenRefreshService;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@
import org.eclipse.edc.connector.dataplane.spi.iam.DataPlaneAccessTokenService;
import org.eclipse.edc.connector.dataplane.spi.store.AccessTokenDataStore;
import org.eclipse.edc.iam.did.spi.resolution.DidPublicKeyResolver;
import org.eclipse.edc.jwt.spi.JwtRegisteredClaimNames;
import org.eclipse.edc.spi.iam.ClaimToken;
import org.eclipse.edc.spi.iam.TokenParameters;
import org.eclipse.edc.spi.iam.TokenRepresentation;
import org.eclipse.edc.spi.monitor.Monitor;
import org.eclipse.edc.spi.result.Result;
import org.eclipse.edc.spi.types.domain.DataAddress;
import org.eclipse.edc.token.rules.ExpirationIssuedAtValidationRule;
import org.eclipse.edc.token.spi.TokenDecorator;
import org.eclipse.edc.token.spi.TokenGenerationService;
import org.eclipse.edc.token.spi.TokenValidationRule;
Expand All @@ -40,6 +42,7 @@
import org.eclipse.tractusx.edc.dataplane.tokenrefresh.spi.model.TokenResponse;

import java.security.PrivateKey;
import java.time.Clock;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand All @@ -51,6 +54,7 @@
import java.util.stream.Stream;

import static org.eclipse.edc.jwt.spi.JwtRegisteredClaimNames.AUDIENCE;
import static org.eclipse.edc.jwt.spi.JwtRegisteredClaimNames.EXPIRATION_TIME;

/**
* This implementation of the {@link DataPlaneTokenRefreshService} validates an incoming authentication token.
Expand All @@ -60,33 +64,41 @@ public class DataPlaneTokenRefreshServiceImpl implements DataPlaneTokenRefreshSe
public static final String TOKEN_ID_CLAIM = "jti";
public static final String REFRESH_TOKEN_PROPERTY = "refreshToken";
private static final Long DEFAULT_EXPIRY_IN_SECONDS = 60 * 5L;
private final List<TokenValidationRule> authenticationTokenValidationRules = List.of(new IssuerEqualsSubjectRule(),
new ClaimIsPresentRule(AUDIENCE), // we don't check the contents, only it is present
new ClaimIsPresentRule(ACCESS_TOKEN_CLAIM),
new ClaimIsPresentRule(TOKEN_ID_CLAIM));
private final List<TokenValidationRule> authenticationTokenValidationRules;
private final List<TokenValidationRule> accessTokenRules;
private final TokenValidationService tokenValidationService;
private final DidPublicKeyResolver publicKeyResolver;
private final AccessTokenDataStore accessTokenDataStore;
private final TokenGenerationService tokenGenerationService;
private final Supplier<PrivateKey> privateKeySupplier;
private final Monitor monitor;
private final String refreshEndpoint;
private final Clock clock;


public DataPlaneTokenRefreshServiceImpl(TokenValidationService tokenValidationService,
public DataPlaneTokenRefreshServiceImpl(Clock clock, TokenValidationService tokenValidationService,
DidPublicKeyResolver publicKeyResolver,
AccessTokenDataStore accessTokenDataStore,
TokenGenerationService tokenGenerationService,
Supplier<PrivateKey> privateKeySupplier,
Monitor monitor,
String refreshEndpoint) {
String refreshEndpoint, int tokenExpiryToleranceSeconds) {
this.tokenValidationService = tokenValidationService;
this.publicKeyResolver = publicKeyResolver;
this.accessTokenDataStore = accessTokenDataStore;
this.tokenGenerationService = tokenGenerationService;
this.privateKeySupplier = privateKeySupplier;
this.monitor = monitor;
this.refreshEndpoint = refreshEndpoint;
this.clock = clock;
authenticationTokenValidationRules = List.of(new IssuerEqualsSubjectRule(),
new ClaimIsPresentRule(AUDIENCE), // we don't check the contents, only it is present
new ClaimIsPresentRule(ACCESS_TOKEN_CLAIM),
new ClaimIsPresentRule(TOKEN_ID_CLAIM));
accessTokenRules = List.of(new IssuerEqualsSubjectRule(),
new ClaimIsPresentRule(AUDIENCE),
new ClaimIsPresentRule(TOKEN_ID_CLAIM),
new ExpirationIssuedAtValidationRule(clock, tokenExpiryToleranceSeconds));
}

/**
Expand Down Expand Up @@ -192,7 +204,7 @@ public Result<TokenRepresentation> obtainToken(TokenParameters tokenParameters,

@Override
public Result<AccessTokenData> resolve(String token) {
return resolveToken(token, authenticationTokenValidationRules);
return resolveToken(token, accessTokenRules);
}

/**
Expand All @@ -213,6 +225,12 @@ private Result<TokenRepresentationWithId> createToken(TokenParameters tokenParam
TokenDecorator tokenIdDecorator = params -> params.claims(TOKEN_ID_CLAIM, tokenId.get());
allDecorators.add(tokenIdDecorator);
}
//if there is not "exp" header on the token params, we'll configure one
if (!tokenParameters.getClaims().containsKey(JwtRegisteredClaimNames.EXPIRATION_TIME)) {
monitor.info("No '%s' claim found on TokenParameters. Will use the default of %d seconds".formatted(EXPIRATION_TIME, DEFAULT_EXPIRY_IN_SECONDS));
var exp = clock.instant().plusSeconds(DEFAULT_EXPIRY_IN_SECONDS).getEpochSecond();
allDecorators.add(tp -> tp.claims(JwtRegisteredClaimNames.EXPIRATION_TIME, exp));
}

return tokenGenerationService.generate(privateKeySupplier, allDecorators.toArray(new TokenDecorator[0]))
.map(tr -> new TokenRepresentationWithId(tokenId.get(), tr));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.junit.jupiter.api.Test;

import java.text.ParseException;
import java.time.Clock;
import java.time.Instant;
import java.util.List;
import java.util.Map;
Expand All @@ -71,23 +72,26 @@ class DataPlaneTokenRefreshServiceImplComponentTest {
private DataPlaneTokenRefreshServiceImpl tokenRefreshService;
private InMemoryAccessTokenDataStore tokenDataStore;
private ECKey consumerKey;
private ECKey providerKey;

@BeforeEach
void setup() throws JOSEException {

var providerKey = new ECKeyGenerator(Curve.P_384).keyID(PROVIDER_BPN + "#provider-key").keyUse(KeyUse.SIGNATURE).generate();
providerKey = new ECKeyGenerator(Curve.P_384).keyID(PROVIDER_BPN + "#provider-key").keyUse(KeyUse.SIGNATURE).generate();
consumerKey = new ECKeyGenerator(Curve.P_384).keyID(CONSUMER_DID + "#consumer-key").keyUse(KeyUse.SIGNATURE).generate();

var privateKey = providerKey.toPrivateKey();

tokenDataStore = new InMemoryAccessTokenDataStore(CriterionOperatorRegistryImpl.ofDefaults());
tokenRefreshService = new DataPlaneTokenRefreshServiceImpl(new TokenValidationServiceImpl(),
tokenRefreshService = new DataPlaneTokenRefreshServiceImpl(Clock.systemUTC(),
new TokenValidationServiceImpl(),
didPkResolverMock,
tokenDataStore,
new JwtGenerationService(),
() -> privateKey,
mock(),
TEST_REFRESH_ENDPOINT);
TEST_REFRESH_ENDPOINT,
1);

when(didPkResolverMock.resolveKey(eq(consumerKey.getKeyID()))).thenReturn(Result.success(consumerKey.toPublicKey()));
when(didPkResolverMock.resolveKey(eq(providerKey.getKeyID()))).thenReturn(Result.success(providerKey.toPublicKey()));
Expand All @@ -106,7 +110,8 @@ void obtainToken() {
.containsKey("agreement_id")
.containsEntry("iss", PROVIDER_BPN)
.containsEntry("sub", PROVIDER_BPN)
.containsEntry("aud", List.of(CONSUMER_BPN));
.containsEntry("aud", List.of(CONSUMER_BPN))
.containsKey("exp");

// assert additional properties -> refresh token
assertThat(edr.getContent().getAdditional())
Expand Down Expand Up @@ -235,6 +240,52 @@ void refresh_whenIssNotEqualToSub() throws JOSEException {
.isEqualTo("The 'iss' and 'sub' claims must be non-null and identical.");
}

@DisplayName("Verify that resolving an expired token fails")
@Test
void resolve_whenExpired_shouldFail() {
var tokenId = "test-token-id";
var edr = tokenRefreshService.obtainToken(tokenParamsBuilder(tokenId)
//token was issued 10min ago, and expired 5min ago
.claims(JwtRegisteredClaimNames.ISSUED_AT, Instant.now().minusSeconds(600).getEpochSecond())
.claims(JwtRegisteredClaimNames.EXPIRATION_TIME, Instant.now().minusSeconds(300).getEpochSecond())
.build(),
DataAddress.Builder.newInstance().type("test-type").build(), Map.of("audience", CONSUMER_DID))
.orElseThrow(f -> new RuntimeException(f.getFailureDetail()));

assertThat(tokenRefreshService.resolve(edr.getToken())).isFailed()
.detail().isEqualTo("Token has expired (exp)");

}

@DisplayName("Verify that resolving a valid token succeeds")
@Test
void resolve_success() {
var tokenId = "test-token-id";
var edr = tokenRefreshService.obtainToken(tokenParamsBuilder(tokenId)
.claims(JwtRegisteredClaimNames.ISSUED_AT, Instant.now().getEpochSecond())
.build(),
DataAddress.Builder.newInstance().type("test-type").build(), Map.of("audience", CONSUMER_DID))
.orElseThrow(f -> new RuntimeException(f.getFailureDetail()));

assertThat(tokenRefreshService.resolve(edr.getToken())).isSucceeded();
}

@DisplayName("Verify that attempting to resolve a non-existing token results in a failure")
@Test
void resolve_notFound() {
var tokenId = "test-token-id";
var edr = tokenRefreshService.obtainToken(tokenParamsBuilder(tokenId)
.claims(JwtRegisteredClaimNames.ISSUED_AT, Instant.now().getEpochSecond())
.build(),
DataAddress.Builder.newInstance().type("test-type").build(), Map.of("audience", CONSUMER_DID))
.orElseThrow(f -> new RuntimeException(f.getFailureDetail()));
tokenDataStore.deleteById(tokenId).orElseThrow(f -> new AssertionError(f.getFailureDetail()));

assertThat(tokenRefreshService.resolve(edr.getToken()))
.isFailed()
.detail().isEqualTo("AccessTokenData with ID '%s' does not exist.".formatted(tokenId));
}

private JWTClaimsSet.Builder getAuthTokenClaims(String tokenId, String accessToken) {
return new JWTClaimsSet.Builder()
.jwtID(tokenId)
Expand All @@ -245,17 +296,23 @@ private JWTClaimsSet.Builder getAuthTokenClaims(String tokenId, String accessTok
}

private TokenParameters tokenParams(String id) {
return tokenParamsBuilder(id).build();
}

private TokenParameters.Builder tokenParamsBuilder(String id) {
return TokenParameters.Builder.newInstance()
.claims(JwtRegisteredClaimNames.JWT_ID, id)
.claims(JwtRegisteredClaimNames.AUDIENCE, CONSUMER_BPN)
.claims(JwtRegisteredClaimNames.ISSUER, PROVIDER_BPN)
.claims(JwtRegisteredClaimNames.SUBJECT, PROVIDER_BPN)
.claims(JwtRegisteredClaimNames.ISSUED_AT, Instant.now().toEpochMilli()) // todo: milli or second?
.claims(JwtRegisteredClaimNames.EXPIRATION_TIME, Instant.now().plusSeconds(60).getEpochSecond())
.claims(JwtRegisteredClaimNames.ISSUED_AT, Instant.now().getEpochSecond())
.claims(CLAIM_AGREEMENT_ID, "test-agreement-id")
.claims(CLAIM_ASSET_ID, "test-asset-id")
.claims(CLAIM_PROCESS_ID, "test-process-id")
.claims(CLAIM_FLOW_TYPE, FlowType.PULL.toString())
.build();
.header("kid", providerKey.getKeyID());

}

private Map<String, Object> asClaims(String serializedJwt) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.eclipse.edc.token.spi.TokenValidationService;
import org.junit.jupiter.api.Test;

import java.time.Clock;
import java.util.Map;
import java.util.regex.Pattern;

Expand All @@ -58,7 +59,7 @@ class DataPlaneTokenRefreshServiceImplTest {
private final TokenValidationService tokenValidationService = mock();
private final DidPublicKeyResolver didPublicKeyResolver = mock();

private final DataPlaneTokenRefreshServiceImpl accessTokenService = new DataPlaneTokenRefreshServiceImpl(tokenValidationService, didPublicKeyResolver, accessTokenDataStore, tokenGenService, mock(), mock(), "https://example.com");
private final DataPlaneTokenRefreshServiceImpl accessTokenService = new DataPlaneTokenRefreshServiceImpl(Clock.systemUTC(), tokenValidationService, didPublicKeyResolver, accessTokenDataStore, tokenGenService, mock(), mock(), "https://example.com", 1);

@Test
void obtainToken() {
Expand Down
Loading