catalog = trimEmptyToNull(servletRequest.getHeader(PRESTO_CATALOG)); schema = trimEmptyToNull(servletRequest.getHeader(PRESTO_SCHEMA)); path = trimEmptyToNull(servletRequest.getHeader(PRESTO_PATH)); assertRequest((catalog != null) || (schema == null), "Schema is set but catalog is not"); String user = trimEmptyToNull(servletRequest.getHeader(PRESTO_USER)); assertRequest(user != null, "User must be set"); identity = new Identity( user, Optional.ofNullable(servletRequest.getUserPrincipal()), parseRoleHeaders(servletRequest), parseExtraCredentials(servletRequest)); traceToken = Optional.ofNullable(trimEmptyToNull(servletRequest.getHeader(PRESTO_TRACE_TOKEN))); userAgent = servletRequest.getHeader(USER_AGENT); remoteUserAddress = servletRequest.getRemoteAddr(); language = servletRequest.getHeader(PRESTO_LANGUAGE); clientInfo = servletRequest.getHeader(PRESTO_CLIENT_INFO); clientTags = parseClientTags(servletRequest); clientCapabilities = parseClientCapabilities(servletRequest); resourceEstimates = parseResourceEstimate(servletRequest); for (Entry<String, String> entry : parseSessionHeaders(servletRequest).entrySet()) { String fullPropertyName = entry.getKey(); String propertyValue = entry.getValue(); assertRequest(!propertyName.isEmpty(), "Invalid %s header", PRESTO_SESSION);
/**
 * Parses every {@code name=value} entry of the given header into a mutable map.
 * <p>
 * Only the first {@code '='} separates name from value, so values may themselves contain
 * {@code '='} (as the sibling role/session/prepared-statement parsers already allow).
 * Both sides are trimmed. A later entry with the same name overwrites an earlier one.
 *
 * @param servletRequest request whose headers are read
 * @param headerName header to parse; entries are obtained through {@code splitSessionHeader}
 * @return map from entry name to entry value (empty when the header is absent)
 */
private static Map<String, String> parseProperty(HttpServletRequest servletRequest, String headerName)
{
    Map<String, String> properties = new HashMap<>();
    for (String header : splitSessionHeader(servletRequest.getHeaders(headerName))) {
        // limit(2): split only on the first '=' so the value keeps any '=' it contains
        List<String> nameValue = Splitter.on('=').limit(2).trimResults().splitToList(header);
        assertRequest(nameValue.size() == 2, "Invalid %s header", headerName);
        properties.put(nameValue.get(0), nameValue.get(1));
    }
    return properties;
}
/**
 * Converts the raw transaction-id header value into an optional {@link TransactionId}.
 * <p>
 * A missing/blank value, or the literal {@code none} (case-insensitive), yields
 * {@link Optional#empty()}. Any failure while parsing is reported as a bad request
 * carrying the parser's message.
 *
 * @param transactionId raw header value, possibly {@code null}
 * @return the parsed id, or empty when no transaction was requested
 */
private static Optional<TransactionId> parseTransactionId(String transactionId)
{
    String candidate = trimEmptyToNull(transactionId);
    boolean noTransaction = (candidate == null) || candidate.equalsIgnoreCase("none");
    if (noTransaction) {
        return Optional.empty();
    }

    try {
        TransactionId parsed = TransactionId.valueOf(candidate);
        return Optional.of(parsed);
    }
    catch (Exception e) {
        throw badRequest(e.getMessage());
    }
}
/**
 * Parses the {@code PRESTO_ROLE} headers into an immutable map of selected roles.
 * <p>
 * Each entry must have the form {@code name=role}; the map is keyed by the name part
 * (connector/catalog names such as {@code foo_connector} in the tests), and the role part is
 * URL-decoded and passed to {@link SelectedRole#valueOf}.
 *
 * @param servletRequest request whose {@code PRESTO_ROLE} headers are read
 * @return map from name to selected role (empty when the header is absent)
 */
private static Map<String, SelectedRole> parseRoleHeaders(HttpServletRequest servletRequest)
{
    ImmutableMap.Builder<String, SelectedRole> roles = ImmutableMap.builder();
    for (String header : splitSessionHeader(servletRequest.getHeaders(PRESTO_ROLE))) {
        List<String> nameValue = Splitter.on('=').limit(2).trimResults().splitToList(header);
        assertRequest(nameValue.size() == 2, "Invalid %s header", PRESTO_ROLE);
        SelectedRole role;
        try {
            // urlDecode signals malformed escapes with IllegalArgumentException (handled the same
            // way in parsePreparedStatementsHeaders); SelectedRole.valueOf presumably does too for
            // unparseable text — TODO confirm. Either way this is client error, not a server fault.
            role = SelectedRole.valueOf(urlDecode(nameValue.get(1)));
        }
        catch (IllegalArgumentException e) {
            throw badRequest(format("Invalid %s header: %s", PRESTO_ROLE, e.getMessage()));
        }
        roles.put(nameValue.get(0), role);
    }
    return roles.build();
}
"testRemote"); HttpRequestSessionContext context = new HttpRequestSessionContext(request); assertEquals(context.getSource(), "testSource"); assertEquals(context.getCatalog(), "testCatalog"); assertEquals(context.getSchema(), "testSchema"); assertEquals(context.getPath(), "testPath"); assertEquals(context.getIdentity(), new Identity("testUser", Optional.empty())); assertEquals(context.getClientInfo(), "client-info"); assertEquals(context.getLanguage(), "zh-TW"); assertEquals(context.getTimeZoneId(), "Asia/Taipei"); assertEquals(context.getSystemProperties(), ImmutableMap.of(QUERY_MAX_MEMORY, "1GB", JOIN_DISTRIBUTION_TYPE, "partitioned", HASH_PARTITION_COUNT, "43")); assertEquals(context.getPreparedStatements(), ImmutableMap.of("query1", "select * from foo", "query2", "select * from bar")); assertEquals(context.getIdentity().getRoles(), ImmutableMap.of( "foo_connector", new SelectedRole(SelectedRole.Type.ALL, Optional.empty()), "bar_connector", new SelectedRole(SelectedRole.Type.NONE, Optional.empty()),
SessionContext sessionContext = new HttpRequestSessionContext(servletRequest);
private static Map<String, String> parsePreparedStatementsHeaders(HttpServletRequest servletRequest) { ImmutableMap.Builder<String, String> preparedStatements = ImmutableMap.builder(); for (String header : splitSessionHeader(servletRequest.getHeaders(PRESTO_PREPARED_STATEMENT))) { List<String> nameValue = Splitter.on('=').limit(2).trimResults().splitToList(header); assertRequest(nameValue.size() == 2, "Invalid %s header", PRESTO_PREPARED_STATEMENT); String statementName; String sqlString; try { statementName = urlDecode(nameValue.get(0)); sqlString = urlDecode(nameValue.get(1)); } catch (IllegalArgumentException e) { throw badRequest(format("Invalid %s header: %s", PRESTO_PREPARED_STATEMENT, e.getMessage())); } // Validate statement SqlParser sqlParser = new SqlParser(); try { sqlParser.createStatement(sqlString, new ParsingOptions(AS_DOUBLE /* anything */)); } catch (ParsingException e) { throw badRequest(format("Invalid %s header: %s", PRESTO_PREPARED_STATEMENT, e.getMessage())); } preparedStatements.put(statementName, sqlString); } return preparedStatements.build(); }
/**
 * Client tags default to the empty set both when the header is missing and when it is blank.
 */
@Test
public void testEmptyClientTags()
{
    HttpServletRequest withoutTagsHeader = new MockHttpServletRequest(
            ImmutableListMultimap.<String, String>builder()
                    .put(PRESTO_USER, "testUser")
                    .build(),
            "remoteAddress");
    assertEquals(new HttpRequestSessionContext(withoutTagsHeader).getClientTags(), ImmutableSet.of());

    HttpServletRequest withBlankTagsHeader = new MockHttpServletRequest(
            ImmutableListMultimap.<String, String>builder()
                    .put(PRESTO_USER, "testUser")
                    .put(PRESTO_CLIENT_TAGS, "")
                    .build(),
            "remoteAddress");
    assertEquals(new HttpRequestSessionContext(withBlankTagsHeader).getClientTags(), ImmutableSet.of());
}
/**
 * A comma-separated capabilities header is split and trimmed; a missing header yields no
 * capabilities.
 */
@Test
public void testClientCapabilities()
{
    HttpServletRequest withCapabilities = new MockHttpServletRequest(
            ImmutableListMultimap.<String, String>builder()
                    .put(PRESTO_USER, "testUser")
                    .put(PRESTO_CLIENT_CAPABILITIES, "foo, bar")
                    .build(),
            "remoteAddress");
    assertEquals(new HttpRequestSessionContext(withCapabilities).getClientCapabilities(), ImmutableSet.of("foo", "bar"));

    HttpServletRequest withoutCapabilities = new MockHttpServletRequest(
            ImmutableListMultimap.<String, String>builder()
                    .put(PRESTO_USER, "testUser")
                    .build(),
            "remoteAddress");
    assertEquals(new HttpRequestSessionContext(withoutCapabilities).getClientCapabilities(), ImmutableSet.of());
}
/**
 * Parses the {@code PRESTO_RESOURCE_ESTIMATE} headers into {@link ResourceEstimates}.
 * <p>
 * Each entry must have the form {@code name=value}. The name is matched case-insensitively
 * against {@code EXECUTION_TIME} and {@code CPU_TIME} (parsed with {@code Duration.valueOf}) and
 * {@code PEAK_MEMORY} (parsed with {@code DataSize.valueOf}). Unknown names or unparseable
 * values are reported as a bad request.
 *
 * @param servletRequest request whose {@code PRESTO_RESOURCE_ESTIMATE} headers are read
 * @return the estimates collected from all entries
 */
private ResourceEstimates parseResourceEstimate(HttpServletRequest servletRequest)
{
    ResourceEstimateBuilder builder = new ResourceEstimateBuilder();
    for (String header : splitSessionHeader(servletRequest.getHeaders(PRESTO_RESOURCE_ESTIMATE))) {
        List<String> nameValue = Splitter.on('=').limit(2).trimResults().splitToList(header);
        assertRequest(nameValue.size() == 2, "Invalid %s header", PRESTO_RESOURCE_ESTIMATE);
        String name = nameValue.get(0);
        String value = nameValue.get(1);
        try {
            // Use a fixed locale: with the default-locale overload a Turkish JVM turns
            // "execution_time" into "EXECUTİON_TİME" (dotted capital I), matching no case below.
            switch (name.toUpperCase(java.util.Locale.ENGLISH)) {
                case ResourceEstimates.EXECUTION_TIME:
                    builder.setExecutionTime(Duration.valueOf(value));
                    break;
                case ResourceEstimates.CPU_TIME:
                    builder.setCpuTime(Duration.valueOf(value));
                    break;
                case ResourceEstimates.PEAK_MEMORY:
                    builder.setPeakMemory(DataSize.valueOf(value));
                    break;
                default:
                    throw badRequest(format("Unsupported resource name %s", name));
            }
        }
        catch (IllegalArgumentException e) {
            throw badRequest(format("Unsupported format for resource estimate '%s': %s", value, e));
        }
    }
    return builder.build();
}
/**
 * Rejects the request when a precondition on its headers does not hold.
 *
 * @param expression condition that must be true for the request to be accepted
 * @param format message template, formatted with {@code args}
 * @param args template arguments
 */
private static void assertRequest(boolean expression, String format, Object... args)
{
    if (expression) {
        return;
    }
    String message = format(format, args);
    throw badRequest(message);
}
"testRemote"); HttpRequestSessionContext context = new HttpRequestSessionContext(request); assertEquals(context.getSource(), "testSource"); assertEquals(context.getCatalog(), "testCatalog"); assertEquals(context.getSchema(), "testSchema"); assertEquals(context.getPath(), "testPath"); assertEquals(context.getIdentity(), new Identity("testUser", Optional.empty())); assertEquals(context.getClientInfo(), "client-info"); assertEquals(context.getLanguage(), "zh-TW"); assertEquals(context.getTimeZoneId(), "Asia/Taipei"); assertEquals(context.getSystemProperties(), ImmutableMap.of(QUERY_MAX_MEMORY, "1GB", JOIN_DISTRIBUTION_TYPE, "partitioned", HASH_PARTITION_COUNT, "43")); assertEquals(context.getPreparedStatements(), ImmutableMap.of("query1", "select * from foo", "query2", "select * from bar")); assertEquals(context.getIdentity().getRoles(), ImmutableMap.of( "foo_connector", new SelectedRole(SelectedRole.Type.ALL, Optional.empty()), "bar_connector", new SelectedRole(SelectedRole.Type.NONE, Optional.empty()), "foobar_connector", new SelectedRole(SelectedRole.Type.ROLE, Optional.of("role")))); assertEquals(context.getIdentity().getExtraCredentials(), ImmutableMap.of("test.token.foo", "bar", "test.token.abc", "xyz"));
SessionContext sessionContext = new HttpRequestSessionContext(servletRequest);
private static Map<String, String> parsePreparedStatementsHeaders(HttpServletRequest servletRequest) { ImmutableMap.Builder<String, String> preparedStatements = ImmutableMap.builder(); for (String header : splitSessionHeader(servletRequest.getHeaders(PRESTO_PREPARED_STATEMENT))) { List<String> nameValue = Splitter.on('=').limit(2).trimResults().splitToList(header); assertRequest(nameValue.size() == 2, "Invalid %s header", PRESTO_PREPARED_STATEMENT); String statementName; String sqlString; try { statementName = urlDecode(nameValue.get(0)); sqlString = urlDecode(nameValue.get(1)); } catch (IllegalArgumentException e) { throw badRequest(format("Invalid %s header: %s", PRESTO_PREPARED_STATEMENT, e.getMessage())); } // Validate statement SqlParser sqlParser = new SqlParser(); try { sqlParser.createStatement(sqlString, new ParsingOptions(AS_DOUBLE /* anything */)); } catch (ParsingException e) { throw badRequest(format("Invalid %s header: %s", PRESTO_PREPARED_STATEMENT, e.getMessage())); } preparedStatements.put(statementName, sqlString); } return preparedStatements.build(); }
/**
 * Parses the {@code PRESTO_ROLE} headers into an immutable map of selected roles.
 * <p>
 * Each entry must have the form {@code name=role}; the map is keyed by the name part
 * (connector/catalog names such as {@code foo_connector} in the tests), and the role part is
 * URL-decoded and passed to {@link SelectedRole#valueOf}.
 *
 * @param servletRequest request whose {@code PRESTO_ROLE} headers are read
 * @return map from name to selected role (empty when the header is absent)
 */
private static Map<String, SelectedRole> parseRoleHeaders(HttpServletRequest servletRequest)
{
    ImmutableMap.Builder<String, SelectedRole> roles = ImmutableMap.builder();
    for (String header : splitSessionHeader(servletRequest.getHeaders(PRESTO_ROLE))) {
        List<String> nameValue = Splitter.on('=').limit(2).trimResults().splitToList(header);
        assertRequest(nameValue.size() == 2, "Invalid %s header", PRESTO_ROLE);
        SelectedRole role;
        try {
            // urlDecode signals malformed escapes with IllegalArgumentException (handled the same
            // way in parsePreparedStatementsHeaders); SelectedRole.valueOf presumably does too for
            // unparseable text — TODO confirm. Either way this is client error, not a server fault.
            role = SelectedRole.valueOf(urlDecode(nameValue.get(1)));
        }
        catch (IllegalArgumentException e) {
            throw badRequest(format("Invalid %s header: %s", PRESTO_ROLE, e.getMessage()));
        }
        roles.put(nameValue.get(0), role);
    }
    return roles.build();
}
/**
 * Client tags default to the empty set both when the header is missing and when it is blank.
 */
@Test
public void testEmptyClientTags()
{
    HttpServletRequest withoutTagsHeader = new MockHttpServletRequest(
            ImmutableListMultimap.<String, String>builder()
                    .put(PRESTO_USER, "testUser")
                    .build(),
            "remoteAddress");
    assertEquals(new HttpRequestSessionContext(withoutTagsHeader).getClientTags(), ImmutableSet.of());

    HttpServletRequest withBlankTagsHeader = new MockHttpServletRequest(
            ImmutableListMultimap.<String, String>builder()
                    .put(PRESTO_USER, "testUser")
                    .put(PRESTO_CLIENT_TAGS, "")
                    .build(),
            "remoteAddress");
    assertEquals(new HttpRequestSessionContext(withBlankTagsHeader).getClientTags(), ImmutableSet.of());
}
/**
 * A comma-separated capabilities header is split and trimmed; a missing header yields no
 * capabilities.
 */
@Test
public void testClientCapabilities()
{
    HttpServletRequest withCapabilities = new MockHttpServletRequest(
            ImmutableListMultimap.<String, String>builder()
                    .put(PRESTO_USER, "testUser")
                    .put(PRESTO_CLIENT_CAPABILITIES, "foo, bar")
                    .build(),
            "remoteAddress");
    assertEquals(new HttpRequestSessionContext(withCapabilities).getClientCapabilities(), ImmutableSet.of("foo", "bar"));

    HttpServletRequest withoutCapabilities = new MockHttpServletRequest(
            ImmutableListMultimap.<String, String>builder()
                    .put(PRESTO_USER, "testUser")
                    .build(),
            "remoteAddress");
    assertEquals(new HttpRequestSessionContext(withoutCapabilities).getClientCapabilities(), ImmutableSet.of());
}
/**
 * Parses the {@code PRESTO_RESOURCE_ESTIMATE} headers into {@link ResourceEstimates}.
 * <p>
 * Each entry must have the form {@code name=value}. The name is matched case-insensitively
 * against {@code EXECUTION_TIME} and {@code CPU_TIME} (parsed with {@code Duration.valueOf}) and
 * {@code PEAK_MEMORY} (parsed with {@code DataSize.valueOf}). Unknown names or unparseable
 * values are reported as a bad request.
 *
 * @param servletRequest request whose {@code PRESTO_RESOURCE_ESTIMATE} headers are read
 * @return the estimates collected from all entries
 */
private ResourceEstimates parseResourceEstimate(HttpServletRequest servletRequest)
{
    ResourceEstimateBuilder builder = new ResourceEstimateBuilder();
    for (String header : splitSessionHeader(servletRequest.getHeaders(PRESTO_RESOURCE_ESTIMATE))) {
        List<String> nameValue = Splitter.on('=').limit(2).trimResults().splitToList(header);
        assertRequest(nameValue.size() == 2, "Invalid %s header", PRESTO_RESOURCE_ESTIMATE);
        String name = nameValue.get(0);
        String value = nameValue.get(1);
        try {
            // Use a fixed locale: with the default-locale overload a Turkish JVM turns
            // "execution_time" into "EXECUTİON_TİME" (dotted capital I), matching no case below.
            switch (name.toUpperCase(java.util.Locale.ENGLISH)) {
                case ResourceEstimates.EXECUTION_TIME:
                    builder.setExecutionTime(Duration.valueOf(value));
                    break;
                case ResourceEstimates.CPU_TIME:
                    builder.setCpuTime(Duration.valueOf(value));
                    break;
                case ResourceEstimates.PEAK_MEMORY:
                    builder.setPeakMemory(DataSize.valueOf(value));
                    break;
                default:
                    throw badRequest(format("Unsupported resource name %s", name));
            }
        }
        catch (IllegalArgumentException e) {
            throw badRequest(format("Unsupported format for resource estimate '%s': %s", value, e));
        }
    }
    return builder.build();
}
/**
 * Collects the {@code PRESTO_SESSION} header entries into a mutable map of raw session
 * properties.
 * <p>
 * Each entry must look like {@code name=value}; only the first {@code '='} splits, both halves
 * are trimmed, and a repeated name keeps the last value seen.
 *
 * @param servletRequest request whose {@code PRESTO_SESSION} headers are read
 * @return property name to value (empty when the header is absent)
 */
private static Map<String, String> parseSessionHeaders(HttpServletRequest servletRequest)
{
    Splitter nameValueSplitter = Splitter.on('=').limit(2).trimResults();
    Map<String, String> properties = new HashMap<>();
    for (String entry : splitSessionHeader(servletRequest.getHeaders(PRESTO_SESSION))) {
        List<String> parts = nameValueSplitter.splitToList(entry);
        assertRequest(parts.size() == 2, "Invalid %s header", PRESTO_SESSION);
        String propertyName = parts.get(0);
        String propertyValue = parts.get(1);
        properties.put(propertyName, propertyValue);
    }
    return properties;
}
/**
 * Converts the raw transaction-id header value into an optional {@link TransactionId}.
 * <p>
 * A missing/blank value, or the literal {@code none} (case-insensitive), yields
 * {@link Optional#empty()}. Any failure while parsing is reported as a bad request
 * carrying the parser's message.
 *
 * @param transactionId raw header value, possibly {@code null}
 * @return the parsed id, or empty when no transaction was requested
 */
private static Optional<TransactionId> parseTransactionId(String transactionId)
{
    String candidate = trimEmptyToNull(transactionId);
    boolean noTransaction = (candidate == null) || candidate.equalsIgnoreCase("none");
    if (noTransaction) {
        return Optional.empty();
    }

    try {
        TransactionId parsed = TransactionId.valueOf(candidate);
        return Optional.of(parsed);
    }
    catch (Exception e) {
        throw badRequest(e.getMessage());
    }
}