/**
 * Builds the SASL extensions from the configured module options and hands them to the callback.
 * <p>
 * Every module option whose key starts with {@code EXTENSION_PREFIX} becomes an extension whose
 * name is the remainder of the key. The collected extensions are then checked with
 * {@link OAuthBearerClientInitialResponse#validateExtensions(SaslExtensions)}: names and values must
 * pass its regex validation, and no name may equal the reserved key
 * {@link OAuthBearerClientInitialResponse#AUTH_KEY}.
 *
 * @param callback the callback that receives the validated extensions
 * @throws ConfigException if any configured extension fails validation
 */
private void handleExtensionsCallback(SaslExtensionsCallback callback) {
    Map<String, String> collected = new HashMap<>();
    for (Map.Entry<String, String> option : this.moduleOptions.entrySet()) {
        String optionName = option.getKey();
        if (optionName.startsWith(EXTENSION_PREFIX)) {
            String extensionName = optionName.substring(EXTENSION_PREFIX.length());
            collected.put(extensionName, option.getValue());
        }
    }

    SaslExtensions saslExtensions = new SaslExtensions(collected);
    try {
        OAuthBearerClientInitialResponse.validateExtensions(saslExtensions);
    } catch (SaslException e) {
        throw new ConfigException(e.getMessage());
    }
    callback.extensions(saslExtensions);
}
/**
 * Marks an extension from the original {@code inputExtensions} map as valid, carrying its original value over.
 *
 * @param extensionName the name of the extension that was validated
 * @throws IllegalArgumentException if {@code extensionName} is not a key of the original extensions
 */
public void valid(String extensionName) {
    Map<String, String> originalExtensions = inputExtensions.map();
    if (originalExtensions.containsKey(extensionName)) {
        validatedExtensions.put(extensionName, originalExtensions.get(extensionName));
        return;
    }
    throw new IllegalArgumentException(String.format("Extension %s was not found in the original extensions", extensionName));
}
/**
/**
 * The map exposed by {@link SaslExtensions} must reject modification.
 */
@Test(expected = UnsupportedOperationException.class)
public void testReturnedMapIsImmutable() {
    Map<String, String> exposed = new SaslExtensions(this.map).map();
    exposed.put("hello", "test");
}
/**
 * Adding an entry to the map after it was passed to {@link SaslExtensions} must not make that
 * entry visible through the extensions.
 */
@Test
public void testCannotAddValueToMapReferenceAndGetFromExtensions() {
    SaslExtensions extensions = new SaslExtensions(this.map);
    assertNull(extensions.map().get("hello"));

    this.map.put("hello", "42");

    // the later mutation of the source map must not be observed
    assertNull(extensions.map().get("hello"));
}
}
/**
 * Renders the SASL extensions as an OAuth protocol-friendly string.
 * <p>
 * Built with {@code Utils.mkString}, using {@code "="} between each key and value and
 * {@code SEPARATOR} between entries, with no leading or trailing text.
 */
private String extensionsMessage() {
    Map<String, String> extensionMap = saslExtensions.map();
    return Utils.mkString(extensionMap, "", "", "=", SEPARATOR);
}
}
/**
 * Parses an OAUTHBEARER client initial response.
 *
 * @param response the raw client first message, decoded as UTF-8
 * @throws SaslException if the message does not match the expected format, has no {@code auth}
 *         key, carries invalid extensions, has a malformed {@code auth} value, or uses a scheme
 *         other than {@code bearer} (case-insensitive)
 */
public OAuthBearerClientInitialResponse(byte[] response) throws SaslException {
    String responseMsg = new String(response, StandardCharsets.UTF_8);
    Matcher matcher = CLIENT_INITIAL_RESPONSE_PATTERN.matcher(responseMsg);
    if (!matcher.matches())
        throw new SaslException("Invalid OAUTHBEARER client first message");
    String authzid = matcher.group("authzid");
    this.authorizationId = authzid == null ? "" : authzid;
    String kvPairs = matcher.group("kvpairs");
    Map<String, String> properties = Utils.parseMap(kvPairs, "=", SEPARATOR);
    String auth = properties.get(AUTH_KEY);
    if (auth == null)
        throw new SaslException("Invalid OAUTHBEARER client first message: 'auth' not specified");
    // Every key/value pair other than 'auth' is treated as a SASL extension.
    properties.remove(AUTH_KEY);
    SaslExtensions extensions = new SaslExtensions(properties);
    validateExtensions(extensions);
    this.saslExtensions = extensions;

    Matcher authMatcher = AUTH_PATTERN.matcher(auth);
    if (!authMatcher.matches())
        throw new SaslException("Invalid OAUTHBEARER client first message: invalid 'auth' format");
    String scheme = authMatcher.group("scheme");
    if (!"bearer".equalsIgnoreCase(scheme)) {
        // Report the scheme captured from the 'auth' value. The previous code read "scheme" from the
        // outer client-message matcher, which is the wrong matcher; if that pattern has no such group,
        // Matcher.group throws IllegalArgumentException instead of this SaslException.
        String msg = String.format("Invalid scheme in OAUTHBEARER client first message: %s", scheme);
        throw new SaslException(msg);
    }
    this.tokenValue = authMatcher.group("token");
}
/**
 * Looks up a property negotiated by the completed authentication exchange.
 *
 * @param propName the property key
 * @return the token for {@code NEGOTIATED_PROPERTY_KEY_TOKEN}; the token's {@code lifetimeMs()} for
 *         the credential-lifetime key; otherwise the value stored under {@code propName} in the
 *         server's extensions (possibly {@code null})
 * @throws IllegalStateException if the exchange has not completed
 */
@Override
public Object getNegotiatedProperty(String propName) {
    if (!complete)
        throw new IllegalStateException("Authentication exchange has not completed");

    Object value;
    if (NEGOTIATED_PROPERTY_KEY_TOKEN.equals(propName)) {
        value = tokenForNegotiatedProperty;
    } else if (SaslInternalConfigs.CREDENTIAL_LIFETIME_MS_SASL_NEGOTIATED_PROPERTY_KEY.equals(propName)) {
        value = tokenForNegotiatedProperty.lifetimeMs();
    } else {
        value = extensions.map().get(propName);
    }
    return value;
}
@Test(expected = SaslException.class) public void testThrowsSaslExceptionOnInvalidExtensionKey() throws Exception { Map<String, String> extensions = new HashMap<>(); extensions.put("19", "42"); // keys can only be a-z new OAuthBearerClientInitialResponse("123.345.567", new SaslExtensions(extensions)); }
/**
 * Returns the extensions that were neither validated nor invalidated.
 *
 * @return an unmodifiable map of the original extensions minus the invalid ones and the validated ones
 */
public Map<String, String> ignoredExtensions() {
    Map<String, String> notInvalid = subtractMap(inputExtensions.map(), invalidExtensions);
    Map<String, String> untouched = subtractMap(notInvalid, validatedExtensions);
    return Collections.unmodifiableMap(untouched);
}
/**
 * Marking as valid a name that was not among the input extensions must be rejected.
 */
@Test(expected = IllegalArgumentException.class)
public void testCannotValidateExtensionWhichWasNotGiven() {
    Map<String, String> given = new HashMap<>();
    given.put("hello", "bye");
    OAuthBearerExtensionsValidatorCallback callback =
            new OAuthBearerExtensionsValidatorCallback(TOKEN, new SaslExtensions(given));
    // "???" was never supplied, so valid() is expected to throw
    callback.valid("???");
}
}
/**
 * Validates that the given extensions conform to the standard and do not use the reserved key
 * name {@link OAuthBearerClientInitialResponse#AUTH_KEY}.
 *
 * @param extensions
 *            optional extensions to validate; {@code null} is accepted and ignored
 * @throws SaslException
 *             if the reserved {@code auth} name appears as a key, or if any
 *             extension name or value fails to match the regular expression
 *             required by the specification
 *
 * @see <a href="https://tools.ietf.org/html/rfc7628#section-3.1">RFC 7628,
 *      Section 3.1</a>
 */
public static void validateExtensions(SaslExtensions extensions) throws SaslException {
    if (extensions == null)
        return;

    Map<String, String> extensionMap = extensions.map();
    if (extensionMap.containsKey(OAuthBearerClientInitialResponse.AUTH_KEY))
        throw new SaslException("Extension name " + OAuthBearerClientInitialResponse.AUTH_KEY + " is invalid");

    for (Map.Entry<String, String> extension : extensionMap.entrySet()) {
        String name = extension.getKey();
        String value = extension.getValue();
        boolean nameOk = EXTENSION_KEY_PATTERN.matcher(name).matches();
        if (!nameOk)
            throw new SaslException("Extension name " + name + " is invalid");
        boolean valueOk = EXTENSION_VALUE_PATTERN.matcher(value).matches();
        if (!valueOk)
            throw new SaslException("Extension value (" + value + ") for extension " + name + " is invalid");
    }
}
/**
 * A client response with one extension must serialize to the expected wire message.
 */
@Test
public void testBuildClientResponseToBytes() throws Exception {
    Map<String, String> extensions = new HashMap<>();
    extensions.put("nineteen", "42");
    OAuthBearerClientInitialResponse response =
            new OAuthBearerClientInitialResponse("123.345.567", new SaslExtensions(extensions));

    byte[] encoded = response.toBytes();

    String expectedMessage = "n,,\u0001auth=Bearer 123.345.567\u0001nineteen=42\u0001\u0001";
    assertEquals(expectedMessage, new String(encoded, StandardCharsets.UTF_8));
}
/**
 * Handles token-validation and extension-validation callbacks.
 * <p>
 * If {@code handleCallback} throws {@link OAuthBearerIllegalTokenException}, the failure is reported
 * on the callback as {@code insufficient_scope} when it names a failure scope, and as
 * {@code invalid_token} otherwise. Every extension offered in an
 * {@link OAuthBearerExtensionsValidatorCallback} is marked valid.
 *
 * @param callbacks the callbacks to handle
 * @throws IllegalStateException if this handler has not been configured
 * @throws UnsupportedCallbackException if a callback of any other type is encountered
 */
@Override
public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
    if (!configured())
        throw new IllegalStateException("Callback handler not configured");

    for (Callback callback : callbacks) {
        if (callback instanceof OAuthBearerValidatorCallback) {
            OAuthBearerValidatorCallback validationCallback = (OAuthBearerValidatorCallback) callback;
            try {
                handleCallback(validationCallback);
            } catch (OAuthBearerIllegalTokenException e) {
                OAuthBearerValidationResult failureReason = e.reason();
                String failureScope = failureReason.failureScope();
                String errorStatus = failureScope == null ? "invalid_token" : "insufficient_scope";
                validationCallback.error(errorStatus, failureScope, failureReason.failureOpenIdConfig());
            }
        } else if (callback instanceof OAuthBearerExtensionsValidatorCallback) {
            OAuthBearerExtensionsValidatorCallback extensionsCallback = (OAuthBearerExtensionsValidatorCallback) callback;
            // accept every extension that was offered
            for (String extensionName : extensionsCallback.inputExtensions().map().keySet())
                extensionsCallback.valid(extensionName);
        } else {
            throw new UnsupportedCallbackException(callback);
        }
    }
}
/**
 * With no configured extensions, the client's first message must equal a bare initial response
 * carrying no extensions.
 */
@Test
public void testNoExtensionsDoesNotAttachAnythingToFirstClientMessage() throws Exception {
    TEST_PROPERTIES.clear();
    testExtensions = new SaslExtensions(TEST_PROPERTIES);

    OAuthBearerClientInitialResponse bareResponse =
            new OAuthBearerClientInitialResponse("", new SaslExtensions(TEST_PROPERTIES));
    String expectedToken = new String(bareResponse.toBytes(), StandardCharsets.UTF_8);

    OAuthBearerSaslClient client = new OAuthBearerSaslClient(new ExtensionsCallbackHandler(false));
    byte[] firstMessage = client.evaluateChallenge("".getBytes());

    assertEquals(expectedToken, new String(firstMessage, StandardCharsets.UTF_8));
}
/**
 * Extensions placed before and after the {@code auth} pair must be parsed; a value may contain a comma.
 */
@Test
public void testExtensions() throws Exception {
    String message = "n,,\u0001propA=valueA1, valueA2\u0001auth=Bearer 567\u0001propB=valueB\u0001\u0001";
    byte[] raw = message.getBytes(StandardCharsets.UTF_8);
    OAuthBearerClientInitialResponse response = new OAuthBearerClientInitialResponse(raw);

    assertEquals("567", response.tokenValue());
    assertEquals("", response.authorizationId());

    Map<String, String> parsed = response.extensions().map();
    assertEquals("valueA1, valueA2", parsed.get("propA"));
    assertEquals("valueB", parsed.get("propB"));
}
/**
 * Validates the client's token and, if it is accepted, processes the extensions and completes the exchange.
 *
 * @param tokenValue the bearer token sent by the client
 * @param authorizationId the authorization id requested by the client; may be empty
 * @param extensions the SASL extensions sent by the client
 * @return an empty array on success, or the UTF-8 bytes of the error response built by
 *         {@code jsonErrorResponse} when the callback handler yields no token
 * @throws SaslException presumably propagated from helpers such as {@code handleCallbackError}
 *         or {@code processExtensions} (not visible here; verify)
 * @throws SaslAuthenticationException if a non-empty authorization id differs from the token's principal name
 */
private byte[] process(String tokenValue, String authorizationId, SaslExtensions extensions) throws SaslException {
    OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(tokenValue);
    try {
        callbackHandler.handle(new Callback[] {callback});
    } catch (IOException | UnsupportedCallbackException e) {
        handleCallbackError(e);
    }
    OAuthBearerToken token = callback.token();
    if (token == null) {
        // The token was rejected: remember the error and send it back to the client.
        errorMessage = jsonErrorResponse(callback.errorStatus(), callback.errorScope(),
                callback.errorOpenIDConfiguration());
        log.debug(errorMessage);
        return errorMessage.getBytes(StandardCharsets.UTF_8);
    }
    /*
     * We support the client specifying an authorization ID as per the SASL
     * specification, but it must match the principal name if it is specified.
     */
    if (!authorizationId.isEmpty() && !authorizationId.equals(token.principalName()))
        throw new SaslAuthenticationException(String.format(
                "Authentication failed: Client requested an authorization id (%s) that is different from the token's principal name (%s)",
                authorizationId, token.principalName()));

    Map<String, String> validExtensions = processExtensions(token, extensions);

    tokenForNegotiatedProperty = token;
    this.extensions = new SaslExtensions(validExtensions);
    complete = true;
    log.debug("Successfully authenticated User={}", token.principalName());
    return new byte[0];
}
/**
 * Passing {@code null} extensions must yield an empty extensions map.
 */
@Test
public void testNoExtensionsFromTokenAndNullExtensions() throws Exception {
    SaslExtensions extensions = new OAuthBearerClientInitialResponse("token", null).extensions();
    assertTrue(extensions.map().isEmpty());
}
/**
 * Extensions that are neither validated nor invalidated must appear in neither of those maps,
 * only in the ignored map.
 */
@Test
public void testUnvalidatedExtensionsAreIgnored() {
    Map<String, String> input = new HashMap<>();
    input.put("valid", "valid");
    input.put("error", "error");
    input.put("nothing", "nothing");
    OAuthBearerExtensionsValidatorCallback callback =
            new OAuthBearerExtensionsValidatorCallback(TOKEN, new SaslExtensions(input));

    callback.error("error", "error");
    callback.valid("valid");

    assertFalse(callback.validatedExtensions().containsKey("nothing"));
    assertFalse(callback.invalidExtensions().containsKey("nothing"));
    assertEquals("nothing", callback.ignoredExtensions().get("nothing"));
}
/**
 * Parses a client first message modelled on the OAUTHBEARER example in RFC 7628. The "7688" in
 * the method name appears to be a typo for 7628.
 */
@Test
public void testRfc7688Example() throws Exception {
    String message = "n,a=user@example.com,\u0001host=server.example.com\u0001port=143\u0001"
            + "auth=Bearer vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg\u0001\u0001";
    OAuthBearerClientInitialResponse response =
            new OAuthBearerClientInitialResponse(message.getBytes(StandardCharsets.UTF_8));

    assertEquals("vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg", response.tokenValue());
    assertEquals("user@example.com", response.authorizationId());

    Map<String, String> parsed = response.extensions().map();
    assertEquals("server.example.com", parsed.get("host"));
    assertEquals("143", parsed.get("port"));
}
/**
 * A validated extension must appear in the validated map with its original value and never in
 * the invalid map.
 */
@Test
public void testValidatedExtensionsAreReturned() {
    Map<String, String> given = new HashMap<>();
    given.put("hello", "bye");
    OAuthBearerExtensionsValidatorCallback callback =
            new OAuthBearerExtensionsValidatorCallback(TOKEN, new SaslExtensions(given));

    // nothing has been classified yet
    assertTrue(callback.validatedExtensions().isEmpty());
    assertTrue(callback.invalidExtensions().isEmpty());

    callback.valid("hello");

    assertFalse(callback.validatedExtensions().isEmpty());
    assertEquals("bye", callback.validatedExtensions().get("hello"));
    assertTrue(callback.invalidExtensions().isEmpty());
}