catalog = trimEmptyToNull(servletRequest.getHeader(PRESTO_CATALOG)); schema = trimEmptyToNull(servletRequest.getHeader(PRESTO_SCHEMA)); path = trimEmptyToNull(servletRequest.getHeader(PRESTO_PATH)); assertRequest((catalog != null) || (schema == null), "Schema is set but catalog is not"); String user = trimEmptyToNull(servletRequest.getHeader(PRESTO_USER)); assertRequest(user != null, "User must be set"); identity = new Identity(user, Optional.ofNullable(servletRequest.getUserPrincipal())); traceToken = Optional.ofNullable(trimEmptyToNull(servletRequest.getHeader(PRESTO_TRACE_TOKEN))); userAgent = servletRequest.getHeader(USER_AGENT); remoteUserAddress = servletRequest.getRemoteAddr(); language = servletRequest.getHeader(PRESTO_LANGUAGE); clientInfo = servletRequest.getHeader(PRESTO_CLIENT_INFO); clientTags = parseClientTags(servletRequest); clientCapabilities = parseClientCapabilities(servletRequest); resourceEstimates = parseResourceEstimate(servletRequest); for (Entry<String, String> entry : parseSessionHeaders(servletRequest).entrySet()) { String fullPropertyName = entry.getKey(); String propertyValue = entry.getValue(); assertRequest(!propertyName.isEmpty(), "Invalid %s header", PRESTO_SESSION); assertRequest(!catalogName.isEmpty(), "Invalid %s header", PRESTO_SESSION); assertRequest(!propertyName.isEmpty(), "Invalid %s header", PRESTO_SESSION); throw badRequest(format("Invalid %s header", PRESTO_SESSION));
/**
 * Interprets the transaction-id header value.
 * <p>
 * A missing or blank value, or the literal {@code none} (compared case-insensitively), means
 * "no transaction". Any other value is handed to {@code TransactionId.valueOf}; a failure there
 * is turned into a bad-request error carrying the original exception's message.
 *
 * @param transactionId raw header value, possibly {@code null}
 * @return the parsed id, or {@link Optional#empty()} when no transaction was requested
 */
private static Optional<TransactionId> parseTransactionId(String transactionId)
{
    String trimmed = trimEmptyToNull(transactionId);
    boolean noTransaction = (trimmed == null) || "none".equalsIgnoreCase(trimmed);
    if (noTransaction) {
        return Optional.empty();
    }

    try {
        TransactionId parsed = TransactionId.valueOf(trimmed);
        return Optional.of(parsed);
    }
    catch (Exception e) {
        throw badRequest(e.getMessage());
    }
}
/**
 * Collects the session-property headers of the request into a mutable map.
 * <p>
 * Each entry produced by {@code splitSessionHeader} is split on its first {@code '='} and both
 * halves are trimmed. A later entry with the same name replaces an earlier one. An entry with no
 * {@code '='} fails the request through {@code assertRequest}.
 *
 * @param servletRequest the incoming request
 * @return property name to property value, as sent by the client
 */
private static Map<String, String> parseSessionHeaders(HttpServletRequest servletRequest)
{
    Splitter nameValueSplitter = Splitter.on('=').limit(2).trimResults();
    Map<String, String> properties = new HashMap<>();

    for (String entry : splitSessionHeader(servletRequest.getHeaders(PRESTO_SESSION))) {
        List<String> parts = nameValueSplitter.splitToList(entry);
        assertRequest(parts.size() == 2, "Invalid %s header", PRESTO_SESSION);

        String propertyName = parts.get(0);
        String propertyValue = parts.get(1);
        properties.put(propertyName, propertyValue);
    }
    return properties;
}
private static Map<String, String> parsePreparedStatementsHeaders(HttpServletRequest servletRequest) { ImmutableMap.Builder<String, String> preparedStatements = ImmutableMap.builder(); for (String header : splitSessionHeader(servletRequest.getHeaders(PRESTO_PREPARED_STATEMENT))) { List<String> nameValue = Splitter.on('=').limit(2).trimResults().splitToList(header); assertRequest(nameValue.size() == 2, "Invalid %s header", PRESTO_PREPARED_STATEMENT); String statementName; String sqlString; try { statementName = urlDecode(nameValue.get(0)); sqlString = urlDecode(nameValue.get(1)); } catch (IllegalArgumentException e) { throw badRequest(format("Invalid %s header: %s", PRESTO_PREPARED_STATEMENT, e.getMessage())); } // Validate statement SqlParser sqlParser = new SqlParser(); try { sqlParser.createStatement(sqlString, new ParsingOptions(AS_DOUBLE /* anything */)); } catch (ParsingException e) { throw badRequest(format("Invalid %s header: %s", PRESTO_PREPARED_STATEMENT, e.getMessage())); } preparedStatements.put(statementName, sqlString); } return preparedStatements.build(); }
@Test
public void testSessionContext()
{
    // Every header that HttpRequestSessionContext reads directly, including multi-valued
    // session and prepared-statement headers.
    ImmutableListMultimap<String, String> headers = ImmutableListMultimap.<String, String>builder()
            .put(PRESTO_USER, "testUser")
            .put(PRESTO_SOURCE, "testSource")
            .put(PRESTO_CATALOG, "testCatalog")
            .put(PRESTO_SCHEMA, "testSchema")
            .put(PRESTO_PATH, "testPath")
            .put(PRESTO_LANGUAGE, "zh-TW")
            .put(PRESTO_TIME_ZONE, "Asia/Taipei")
            .put(PRESTO_CLIENT_INFO, "client-info")
            .put(PRESTO_SESSION, QUERY_MAX_MEMORY + "=1GB")
            .put(PRESTO_SESSION, JOIN_DISTRIBUTION_TYPE + "=partitioned," + HASH_PARTITION_COUNT + " = 43")
            .put(PRESTO_PREPARED_STATEMENT, "query1=select * from foo,query2=select * from bar")
            .build();
    HttpServletRequest servletRequest = new MockHttpServletRequest(headers, "testRemote");

    HttpRequestSessionContext sessionContext = new HttpRequestSessionContext(servletRequest);

    assertEquals(sessionContext.getSource(), "testSource");
    assertEquals(sessionContext.getCatalog(), "testCatalog");
    assertEquals(sessionContext.getSchema(), "testSchema");
    assertEquals(sessionContext.getPath(), "testPath");
    assertEquals(sessionContext.getIdentity(), new Identity("testUser", Optional.empty()));
    assertEquals(sessionContext.getClientInfo(), "client-info");
    assertEquals(sessionContext.getLanguage(), "zh-TW");
    assertEquals(sessionContext.getTimeZoneId(), "Asia/Taipei");

    // Whitespace around '=' in the second session header is expected to be trimmed away.
    ImmutableMap<String, String> expectedSystemProperties = ImmutableMap.of(
            QUERY_MAX_MEMORY, "1GB",
            JOIN_DISTRIBUTION_TYPE, "partitioned",
            HASH_PARTITION_COUNT, "43");
    assertEquals(sessionContext.getSystemProperties(), expectedSystemProperties);

    ImmutableMap<String, String> expectedPreparedStatements = ImmutableMap.of(
            "query1", "select * from foo",
            "query2", "select * from bar");
    assertEquals(sessionContext.getPreparedStatements(), expectedPreparedStatements);
}
SessionContext sessionContext = new HttpRequestSessionContext(servletRequest);
/**
 * Parses the resource-estimate headers of the request into a {@link ResourceEstimates}.
 * <p>
 * Each entry produced by {@code splitSessionHeader} must have the form {@code name=value}. The
 * name is upper-cased with a fixed locale and compared with {@code ResourceEstimates.EXECUTION_TIME},
 * {@code CPU_TIME} and {@code PEAK_MEMORY}. Time values are parsed with {@code Duration.valueOf}
 * and memory values with {@code DataSize.valueOf}.
 * <p>
 * The request is failed through {@code badRequest} / {@code assertRequest} when:
 * <ul>
 * <li>an entry has no {@code '='};</li>
 * <li>the name is not one of the supported resources;</li>
 * <li>the value cannot be parsed.</li>
 * </ul>
 *
 * @param servletRequest the incoming request
 * @return the estimates collected from the request
 */
private static ResourceEstimates parseResourceEstimate(HttpServletRequest servletRequest)
{
    ResourceEstimateBuilder builder = new ResourceEstimateBuilder();
    for (String header : splitSessionHeader(servletRequest.getHeaders(PRESTO_RESOURCE_ESTIMATE))) {
        List<String> nameValue = Splitter.on('=').limit(2).trimResults().splitToList(header);
        assertRequest(nameValue.size() == 2, "Invalid %s header", PRESTO_RESOURCE_ESTIMATE);

        String name = nameValue.get(0);
        String value = nameValue.get(1);

        try {
            // Use a fixed locale. The no-arg toUpperCase() uses the JVM default locale, and in
            // Turkish/Azeri locales 'i' maps to a dotted capital I, so names containing 'i'
            // could fail to match the case labels below.
            switch (name.toUpperCase(java.util.Locale.ENGLISH)) {
                case ResourceEstimates.EXECUTION_TIME:
                    builder.setExecutionTime(Duration.valueOf(value));
                    break;
                case ResourceEstimates.CPU_TIME:
                    builder.setCpuTime(Duration.valueOf(value));
                    break;
                case ResourceEstimates.PEAK_MEMORY:
                    builder.setPeakMemory(DataSize.valueOf(value));
                    break;
                default:
                    throw badRequest(format("Unsupported resource name %s", name));
            }
        }
        catch (IllegalArgumentException e) {
            // Report only the parser's message (as the other header parsers do), not the exception's toString().
            throw badRequest(format("Unsupported format for resource estimate '%s': %s", value, e.getMessage()));
        }
    }
    return builder.build();
}
@Test
public void testEmptyClientTags()
{
    // Case 1: the client-tags header is absent altogether.
    HttpServletRequest withoutTagsHeader = new MockHttpServletRequest(
            ImmutableListMultimap.<String, String>builder()
                    .put(PRESTO_USER, "testUser")
                    .build(),
            "remoteAddress");
    HttpRequestSessionContext contextWithoutTags = new HttpRequestSessionContext(withoutTagsHeader);
    assertEquals(contextWithoutTags.getClientTags(), ImmutableSet.of());

    // Case 2: the client-tags header is present but empty.
    HttpServletRequest withBlankTagsHeader = new MockHttpServletRequest(
            ImmutableListMultimap.<String, String>builder()
                    .put(PRESTO_USER, "testUser")
                    .put(PRESTO_CLIENT_TAGS, "")
                    .build(),
            "remoteAddress");
    HttpRequestSessionContext contextWithBlankTags = new HttpRequestSessionContext(withBlankTagsHeader);
    assertEquals(contextWithBlankTags.getClientTags(), ImmutableSet.of());
}
@Test
public void testClientCapabilities()
{
    // A comma-separated capabilities header; the space after the comma is expected to be dropped.
    ImmutableListMultimap<String, String> capabilityHeaders = ImmutableListMultimap.<String, String>builder()
            .put(PRESTO_USER, "testUser")
            .put(PRESTO_CLIENT_CAPABILITIES, "foo, bar")
            .build();
    HttpServletRequest requestWithCapabilities = new MockHttpServletRequest(capabilityHeaders, "remoteAddress");
    HttpRequestSessionContext contextWithCapabilities = new HttpRequestSessionContext(requestWithCapabilities);
    assertEquals(contextWithCapabilities.getClientCapabilities(), ImmutableSet.of("foo", "bar"));

    // No capabilities header: the context reports an empty set.
    ImmutableListMultimap<String, String> userOnlyHeaders = ImmutableListMultimap.<String, String>builder()
            .put(PRESTO_USER, "testUser")
            .build();
    HttpServletRequest requestWithoutCapabilities = new MockHttpServletRequest(userOnlyHeaders, "remoteAddress");
    HttpRequestSessionContext contextWithoutCapabilities = new HttpRequestSessionContext(requestWithoutCapabilities);
    assertEquals(contextWithoutCapabilities.getClientCapabilities(), ImmutableSet.of());
}
/**
 * Fails the request with a bad-request error unless {@code expression} holds.
 *
 * @param expression condition that must be true for the request to be valid
 * @param format {@link String#format} pattern for the error message
 * @param args arguments substituted into {@code format}
 */
private static void assertRequest(boolean expression, String format, Object... args)
{
    if (expression) {
        return;
    }
    String message = format(format, args);
    throw badRequest(message);
}
@Test public void testCreateSession() HttpRequestSessionContext context = new HttpRequestSessionContext(TEST_REQUEST); QuerySessionSupplier sessionSupplier = new QuerySessionSupplier( createTestTransactionManager(),
@Test(expectedExceptions = WebApplicationException.class)
public void testPreparedStatementsHeaderDoesNotParse()
{
    // "abcdefg" is not valid SQL, so building the context from these headers is expected to
    // throw WebApplicationException.
    ImmutableListMultimap<String, String> headers = ImmutableListMultimap.<String, String>builder()
            .put(PRESTO_USER, "testUser")
            .put(PRESTO_SOURCE, "testSource")
            .put(PRESTO_CATALOG, "testCatalog")
            .put(PRESTO_SCHEMA, "testSchema")
            .put(PRESTO_PATH, "testPath")
            .put(PRESTO_LANGUAGE, "zh-TW")
            .put(PRESTO_TIME_ZONE, "Asia/Taipei")
            .put(PRESTO_CLIENT_INFO, "null")
            .put(PRESTO_PREPARED_STATEMENT, "query1=abcdefg")
            .build();
    HttpServletRequest servletRequest = new MockHttpServletRequest(headers, "testRemote");
    new HttpRequestSessionContext(servletRequest);
}
}
@Test(expectedExceptions = PrestoException.class)
public void testInvalidTimeZone()
{
    ImmutableListMultimap<String, String> headers = ImmutableListMultimap.<String, String>builder()
            .put(PRESTO_USER, "testUser")
            .put(PRESTO_TIME_ZONE, "unknown_timezone")
            .build();
    HttpServletRequest servletRequest = new MockHttpServletRequest(headers, "testRemote");
    HttpRequestSessionContext sessionContext = new HttpRequestSessionContext(servletRequest);

    QuerySessionSupplier supplier = new QuerySessionSupplier(
            createTestTransactionManager(),
            new AllowAllAccessControl(),
            new SessionPropertyManager(),
            new SqlEnvironmentConfig());

    // The test expects a PrestoException because "unknown_timezone" is not a valid zone id.
    supplier.createSession(new QueryId("test_query_id"), sessionContext);
}