// Class that declared the aggregation; its loader is combined with this class's loader so the
// state serializer/factory generation below can resolve types from both.
Class<?> definitionClass = concreteImplementation.getDefinitionClass();
DynamicClassLoader classLoader = new DynamicClassLoader(definitionClass.getClassLoader(), getClass().getClassLoader());
Class<?> stateClass = concreteImplementation.getStateClass();
// Serializer comes from the implementation's declared factory if present, otherwise it is generated for stateClass.
AccumulatorStateSerializer<?> stateSerializer = getAccumulatorStateSerializer(concreteImplementation, variables, typeManager, functionRegistry, stateClass, classLoader);
AccumulatorStateFactory<?> stateFactory = StateCompiler.generateStateFactory(stateClass, classLoader);
// Bind each function's declared dependencies using the bound variables; presumably this yields
// handles whose dependency arguments are already filled in (TODO confirm against bindDependencies).
MethodHandle inputHandle = bindDependencies(concreteImplementation.getInputFunction(), concreteImplementation.getInputDependencies(), variables, typeManager, functionRegistry);
MethodHandle combineHandle = bindDependencies(concreteImplementation.getCombineFunction(), concreteImplementation.getCombineDependencies(), variables, typeManager, functionRegistry);
MethodHandle outputHandle = bindDependencies(concreteImplementation.getOutputFunction(), concreteImplementation.getOutputDependencies(), variables, typeManager, functionRegistry);
// Pair the declared input parameter kinds (STATE, INPUT_CHANNEL, ...) with the concrete input types.
List<ParameterMetadata> parametersMetadata = buildParameterMetadata(concreteImplementation.getInputParameterMetadataTypes(), inputTypes);
/**
 * Returns the state serializer for an aggregation implementation.
 * If the implementation declares a serializer factory, that factory is bound to its dependencies
 * and invoked; otherwise a serializer is generated for {@code stateClass} in {@code classLoader}.
 * Checked exceptions from invoking the factory are wrapped in a {@link RuntimeException}.
 */
private static AccumulatorStateSerializer<?> getAccumulatorStateSerializer(AggregationImplementation implementation, BoundVariables variables, TypeManager typeManager, FunctionRegistry functionRegistry, Class<?> stateClass, DynamicClassLoader classLoader)
{
    Optional<MethodHandle> declaredFactory = implementation.getStateSerializerFactory();
    if (!declaredFactory.isPresent()) {
        // No custom factory: fall back to a generated serializer.
        return generateStateSerializer(stateClass, classLoader);
    }
    try {
        MethodHandle boundFactory = bindDependencies(declaredFactory.get(), implementation.getStateSerializerFactoryDependencies(), variables, typeManager, functionRegistry);
        return (AccumulatorStateSerializer<?>) boundFactory.invoke();
    }
    catch (Throwable t) {
        // MethodHandle.invoke throws Throwable: rethrow unchecked ones as-is, wrap the rest.
        throwIfUnchecked(t);
        throw new RuntimeException(t);
    }
}
/**
 * Assembles the parsed pieces held by this builder into an {@link AggregationImplementation},
 * using an AGGREGATE signature named after the header.
 */
private AggregationImplementation get()
{
    Signature implementationSignature = new Signature(
            header.getName(),
            FunctionKind.AGGREGATE,
            typeVariableConstraints,
            longVariableConstraints,
            returnType,
            inputTypes,
            false);
    return new AggregationImplementation(
            implementationSignature,
            aggregationDefinition,
            stateClass,
            inputHandle,
            outputHandle,
            combineHandle,
            stateSerializerFactoryHandle,
            argumentNativeContainerTypes,
            inputDependencies,
            combineDependencies,
            outputDependencies,
            stateSerializerFactoryDependencies,
            parameterMetadataTypes);
}
/**
 * Asserts how many dependencies are declared for the input, combine and output functions
 * of {@code implementation}, in that order.
 */
void assertDependencyCount(AggregationImplementation implementation, int input, int combine, int output)
{
    int actualInputCount = implementation.getInputDependencies().size();
    assertEquals(actualInputCount, input);
    int actualCombineCount = implementation.getCombineDependencies().size();
    assertEquals(actualCombineCount, combine);
    int actualOutputCount = implementation.getOutputDependencies().size();
    assertEquals(actualOutputCount, output);
}
// closes the enclosing scope opened above this method
}
// The operator-injection aggregate declares a custom state serializer factory.
assertTrue(implementation.getStateSerializerFactory().isPresent());
assertEquals(implementation.getDefinitionClass(), InjectOperatorAggregateFunction.class);
assertEquals(implementation.getStateSerializerFactoryDependencies().size(), 1);
// Every function (input, combine, output, serializer factory) receives an injected operator dependency first.
assertTrue(implementation.getInputDependencies().get(0) instanceof OperatorImplementationDependency);
assertTrue(implementation.getCombineDependencies().get(0) instanceof OperatorImplementationDependency);
assertTrue(implementation.getOutputDependencies().get(0) instanceof OperatorImplementationDependency);
assertTrue(implementation.getStateSerializerFactoryDependencies().get(0) instanceof OperatorImplementationDependency);
assertFalse(implementation.hasSpecializedTypeParameters());
// Input function takes the state first, then one input channel.
List<AggregationMetadata.ParameterMetadata.ParameterType> expectedMetadataTypes = ImmutableList.of(AggregationMetadata.ParameterMetadata.ParameterType.STATE, AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL);
assertTrue(implementation.getInputParameterMetadataTypes().equals(expectedMetadataTypes));
/**
 * Parses an aggregate whose injected type parameter is fixed to DOUBLE and checks that it yields
 * exactly one exact (non-generic) implementation with the expected dependencies, parameter layout
 * and state class.
 */
@Test
public void testFixedTypeParameterInjectionAggregateFunctionParse()
{
    Signature expectedSignature = new Signature(
            "fixed_type_parameter_injection",
            FunctionKind.AGGREGATE,
            ImmutableList.of(),
            ImmutableList.of(),
            DoubleType.DOUBLE.getTypeSignature(),
            ImmutableList.of(DoubleType.DOUBLE.getTypeSignature()),
            false);

    ParametricAggregation aggregation = parseFunctionDefinition(FixedTypeParameterInjectionAggregateFunction.class);
    assertEquals(aggregation.getDescription(), "Simple aggregate with fixed parameter type injected");
    assertTrue(aggregation.isDeterministic());
    assertEquals(aggregation.getSignature(), expectedSignature);

    ParametricImplementationsGroup<AggregationImplementation> implementations = aggregation.getImplementations();
    assertImplementationCount(implementations, 1, 0, 0);

    // Check the key first so a mismatch fails as an assertion rather than an NPE on a null lookup.
    assertTrue(implementations.getExactImplementations().containsKey(expectedSignature));
    AggregationImplementation implementationDouble = implementations.getExactImplementations().get(expectedSignature);
    assertFalse(implementationDouble.getStateSerializerFactory().isPresent());
    assertEquals(implementationDouble.getDefinitionClass(), FixedTypeParameterInjectionAggregateFunction.class);
    assertDependencyCount(implementationDouble, 1, 1, 1);
    assertFalse(implementationDouble.hasSpecializedTypeParameters());

    List<AggregationMetadata.ParameterMetadata.ParameterType> expectedMetadataTypes = ImmutableList.of(AggregationMetadata.ParameterMetadata.ParameterType.STATE, AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL);
    // assertEquals reports both lists on failure; assertTrue(a.equals(b)) only reports "expected true".
    assertEquals(implementationDouble.getInputParameterMetadataTypes(), expectedMetadataTypes);
    assertEquals(implementationDouble.getStateClass(), NullableDoubleState.class);
}
/**
 * Parses an aggregate with one explicitly specialized implementation and one generic fallback,
 * checks how they are grouped, then specializes it for DOUBLE.
 */
public void testSimpleExplicitSpecializedAggregationParse()
{
    Signature expectedSignature = new Signature(
            "explicit_specialized_aggregate",
            FunctionKind.AGGREGATE,
            ImmutableList.of(typeVariable("T")),
            ImmutableList.of(),
            parseTypeSignature("T"),
            ImmutableList.of(new TypeSignature(ARRAY, TypeSignatureParameter.of(parseTypeSignature("T")))),
            false);

    ParametricAggregation aggregation = parseFunctionDefinition(ExplicitSpecializedAggregationFunction.class);
    assertEquals(aggregation.getDescription(), "Simple explicit specialized aggregate");
    assertTrue(aggregation.isDeterministic());
    assertEquals(aggregation.getSignature(), expectedSignature);

    ParametricImplementationsGroup<AggregationImplementation> implementations = aggregation.getImplementations();
    // No exact implementation, one specialized, one generic.
    assertImplementationCount(implementations, 0, 1, 1);

    List<AggregationMetadata.ParameterMetadata.ParameterType> expectedMetadataTypes = ImmutableList.of(AggregationMetadata.ParameterMetadata.ParameterType.STATE, AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL);

    // The single specialized implementation must report specialized type parameters.
    AggregationImplementation implementation1 = implementations.getSpecializedImplementations().get(0);
    assertTrue(implementation1.hasSpecializedTypeParameters());
    assertEquals(implementation1.getInputParameterMetadataTypes(), expectedMetadataTypes);

    // The second implementation is the generic one; the specialized group holds only one entry,
    // so it is read from the generic group.
    AggregationImplementation implementation2 = implementations.getGenericImplementations().get(0);
    assertFalse(implementation2.hasSpecializedTypeParameters());
    assertEquals(implementation2.getInputParameterMetadataTypes(), expectedMetadataTypes);

    InternalAggregationFunction specialized = aggregation.specialize(BoundVariables.builder().setTypeVariable("T", DoubleType.DOUBLE).build(), 1, new TypeRegistry(), null);
    assertEquals(specialized.getFinalType(), DoubleType.DOUBLE);
    assertTrue(specialized.isDecomposable());
    // The specialized function keeps the aggregation's own name.
    assertEquals(specialized.name(), "explicit_specialized_aggregate");
}
/**
 * Parses an aggregate whose state parameter is not the first input parameter and checks that the
 * recorded parameter layout follows the declared order.
 */
@Test
public void testStateOnDifferentThanFirstPositionAggregationParse()
{
    Signature expectedSignature = new Signature(
            "simple_exact_aggregate_aggregation_state_moved",
            FunctionKind.AGGREGATE,
            DoubleType.DOUBLE.getTypeSignature(),
            ImmutableList.of(DoubleType.DOUBLE.getTypeSignature()));

    ParametricAggregation aggregation = parseFunctionDefinition(StateOnDifferentThanFirstPositionAggregationFunction.class);
    assertEquals(aggregation.getSignature(), expectedSignature);

    AggregationImplementation implementation = getOnlyElement(aggregation.getImplementations().getExactImplementations().values());
    assertEquals(implementation.getDefinitionClass(), StateOnDifferentThanFirstPositionAggregationFunction.class);

    // Expected order: input channel first, then state.
    List<AggregationMetadata.ParameterMetadata.ParameterType> expectedMetadataTypes = ImmutableList.of(AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL, AggregationMetadata.ParameterMetadata.ParameterType.STATE);
    // assertEquals shows both lists on failure, unlike assertTrue(a.equals(b)).
    assertEquals(implementation.getInputParameterMetadataTypes(), expectedMetadataTypes);
}
/**
 * Parses an aggregate whose state parameter carries no annotation and checks that it is still
 * recognized as STATE and that the function specializes to a DOUBLE-returning aggregate.
 */
@Test
public void testNotAnnotatedAggregateStateAggregationParse()
{
    ParametricAggregation aggregation = parseFunctionDefinition(NotAnnotatedAggregateStateAggregationFunction.class);

    AggregationImplementation implementation = getOnlyElement(aggregation.getImplementations().getExactImplementations().values());
    List<AggregationMetadata.ParameterMetadata.ParameterType> expectedMetadataTypes = ImmutableList.of(AggregationMetadata.ParameterMetadata.ParameterType.STATE, AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL);
    // assertEquals shows both lists on failure, unlike assertTrue(a.equals(b)).
    assertEquals(implementation.getInputParameterMetadataTypes(), expectedMetadataTypes);

    InternalAggregationFunction specialized = aggregation.specialize(BoundVariables.builder().build(), 1, new TypeRegistry(), null);
    assertEquals(specialized.getFinalType(), DoubleType.DOUBLE);
    assertTrue(specialized.isDecomposable());
    assertEquals(specialized.name(), "no_aggregation_state_aggregate");
}
/**
 * Checks that an aggregate declaring a custom state serializer factory ends up, after
 * specialization, with a serializer of the factory's declared return type.
 */
@Test
public void testCustomStateSerializerAggregationParse()
{
    ParametricAggregation aggregation = parseFunctionDefinition(CustomStateSerializerAggregationFunction.class);
    AggregationImplementation implementation = getOnlyElement(aggregation.getImplementations().getExactImplementations().values());
    assertTrue(implementation.getStateSerializerFactory().isPresent());

    InternalAggregationFunction specialized = aggregation.specialize(BoundVariables.builder().build(), 1, new TypeRegistry(), null);
    LazyAccumulatorFactoryBinder lazyBinder = (LazyAccumulatorFactoryBinder) specialized.getAccumulatorFactoryBinder();
    AccumulatorStateSerializer<?> createdSerializer = getOnlyElement(lazyBinder.getGenericAccumulatorFactoryBinder().getStateDescriptors()).getSerializer();

    // Declared return type of the factory method handle, i.e. the serializer class it produces.
    Class<?> serializerClass = implementation.getStateSerializerFactory().get().type().returnType();
    assertTrue(serializerClass.isInstance(createdSerializer));
}
/**
 * Chooses the implementation for {@code boundSignature}. An exact-signature implementation is
 * used when present; otherwise the generic implementations are scanned, and exactly one of them
 * must have assignable types. Specialized implementations are not consulted here.
 *
 * @throws PrestoException with AMBIGUOUS_FUNCTION_CALL if two or more generic candidates match,
 *         or FUNCTION_IMPLEMENTATION_MISSING if none does
 */
private AggregationImplementation findMatchingImplementation(Signature boundSignature, BoundVariables variables, TypeManager typeManager, FunctionRegistry functionRegistry)
{
    Optional<AggregationImplementation> match = Optional.empty();
    if (implementations.getExactImplementations().containsKey(boundSignature)) {
        match = Optional.of(implementations.getExactImplementations().get(boundSignature));
    }
    else {
        for (AggregationImplementation candidate : implementations.getGenericImplementations()) {
            if (!candidate.areTypesAssignable(boundSignature, variables, typeManager, functionRegistry)) {
                continue;
            }
            // A second assignable candidate means the call cannot be resolved uniquely.
            if (match.isPresent()) {
                throw new PrestoException(AMBIGUOUS_FUNCTION_CALL, format("Ambiguous function call (%s) for %s", variables, getSignature()));
            }
            match = Optional.of(candidate);
        }
    }
    return match.orElseThrow(() -> new PrestoException(FUNCTION_IMPLEMENTATION_MISSING, format("Unsupported type parameters (%s) for %s", variables, getSignature())));
}
// The operator-injection aggregate declares a custom state serializer factory.
assertTrue(implementation.getStateSerializerFactory().isPresent());
assertEquals(implementation.getDefinitionClass(), InjectOperatorAggregateFunction.class);
assertEquals(implementation.getStateSerializerFactoryDependencies().size(), 1);
// Every function (input, combine, output, serializer factory) receives an injected operator dependency first.
assertTrue(implementation.getInputDependencies().get(0) instanceof OperatorImplementationDependency);
assertTrue(implementation.getCombineDependencies().get(0) instanceof OperatorImplementationDependency);
assertTrue(implementation.getOutputDependencies().get(0) instanceof OperatorImplementationDependency);
assertTrue(implementation.getStateSerializerFactoryDependencies().get(0) instanceof OperatorImplementationDependency);
assertFalse(implementation.hasSpecializedTypeParameters());
// Input function takes the state first, then one input channel.
List<AggregationMetadata.ParameterMetadata.ParameterType> expectedMetadataTypes = ImmutableList.of(AggregationMetadata.ParameterMetadata.ParameterType.STATE, AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL);
assertTrue(implementation.getInputParameterMetadataTypes().equals(expectedMetadataTypes));
/**
 * Parses an aggregate whose injected type parameter is fixed to DOUBLE and checks that it yields
 * exactly one exact (non-generic) implementation with the expected dependencies, parameter layout
 * and state class.
 */
@Test
public void testFixedTypeParameterInjectionAggregateFunctionParse()
{
    Signature expectedSignature = new Signature(
            "fixed_type_parameter_injection",
            FunctionKind.AGGREGATE,
            ImmutableList.of(),
            ImmutableList.of(),
            DoubleType.DOUBLE.getTypeSignature(),
            ImmutableList.of(DoubleType.DOUBLE.getTypeSignature()),
            false);

    ParametricAggregation aggregation = parseFunctionDefinition(FixedTypeParameterInjectionAggregateFunction.class);
    assertEquals(aggregation.getDescription(), "Simple aggregate with fixed parameter type injected");
    assertTrue(aggregation.isDeterministic());
    assertEquals(aggregation.getSignature(), expectedSignature);

    ParametricImplementationsGroup<AggregationImplementation> implementations = aggregation.getImplementations();
    assertImplementationCount(implementations, 1, 0, 0);

    // Check the key first so a mismatch fails as an assertion rather than an NPE on a null lookup.
    assertTrue(implementations.getExactImplementations().containsKey(expectedSignature));
    AggregationImplementation implementationDouble = implementations.getExactImplementations().get(expectedSignature);
    assertFalse(implementationDouble.getStateSerializerFactory().isPresent());
    assertEquals(implementationDouble.getDefinitionClass(), FixedTypeParameterInjectionAggregateFunction.class);
    assertDependencyCount(implementationDouble, 1, 1, 1);
    assertFalse(implementationDouble.hasSpecializedTypeParameters());

    List<AggregationMetadata.ParameterMetadata.ParameterType> expectedMetadataTypes = ImmutableList.of(AggregationMetadata.ParameterMetadata.ParameterType.STATE, AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL);
    // assertEquals reports both lists on failure; assertTrue(a.equals(b)) only reports "expected true".
    assertEquals(implementationDouble.getInputParameterMetadataTypes(), expectedMetadataTypes);
    assertEquals(implementationDouble.getStateClass(), NullableDoubleState.class);
}
/**
 * Asserts how many dependencies are declared for the input, combine and output functions
 * of {@code implementation}, in that order.
 */
void assertDependencyCount(AggregationImplementation implementation, int input, int combine, int output)
{
    int actualInputCount = implementation.getInputDependencies().size();
    assertEquals(actualInputCount, input);
    int actualCombineCount = implementation.getCombineDependencies().size();
    assertEquals(actualCombineCount, combine);
    int actualOutputCount = implementation.getOutputDependencies().size();
    assertEquals(actualOutputCount, output);
}
// closes the enclosing scope opened above this method
}
/**
 * Parses an aggregate whose implementations are only implicitly specialized, checks how they are
 * grouped, then specializes it for DOUBLE.
 */
public void testSimpleImplicitSpecializedAggregationParse()
{
    Signature expectedSignature = new Signature(
            "implicit_specialized_aggregate",
            FunctionKind.AGGREGATE,
            ImmutableList.of(typeVariable("T")),
            ImmutableList.of(),
            parseTypeSignature("T"),
            ImmutableList.of(new TypeSignature(ARRAY, TypeSignatureParameter.of(parseTypeSignature("T"))), parseTypeSignature("T")),
            false);

    ParametricAggregation aggregation = parseFunctionDefinition(ImplicitSpecializedAggregationFunction.class);
    assertEquals(aggregation.getDescription(), "Simple implicit specialized aggregate");
    assertTrue(aggregation.isDeterministic());
    assertEquals(aggregation.getSignature(), expectedSignature);

    ParametricImplementationsGroup<AggregationImplementation> implementations = aggregation.getImplementations();
    // No exact and no specialized implementations; both are generic.
    assertImplementationCount(implementations, 0, 0, 2);

    List<AggregationMetadata.ParameterMetadata.ParameterType> expectedMetadataTypes = ImmutableList.of(AggregationMetadata.ParameterMetadata.ParameterType.STATE, AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL, AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL);

    // Read both implementations from the generic group, matching the counts above (the specialized
    // group is empty). Generic-group entries are not expected to carry specialized type parameters.
    AggregationImplementation implementation1 = implementations.getGenericImplementations().get(0);
    assertFalse(implementation1.hasSpecializedTypeParameters());
    assertEquals(implementation1.getInputParameterMetadataTypes(), expectedMetadataTypes);

    AggregationImplementation implementation2 = implementations.getGenericImplementations().get(1);
    assertFalse(implementation2.hasSpecializedTypeParameters());
    assertEquals(implementation2.getInputParameterMetadataTypes(), expectedMetadataTypes);

    InternalAggregationFunction specialized = aggregation.specialize(BoundVariables.builder().setTypeVariable("T", DoubleType.DOUBLE).build(), 1, new TypeRegistry(), null);
    assertEquals(specialized.getFinalType(), DoubleType.DOUBLE);
    assertTrue(specialized.isDecomposable());
    assertEquals(specialized.name(), "implicit_specialized_aggregate");
}
/**
 * Parses an aggregate whose state parameter is not the first input parameter and checks that the
 * recorded parameter layout follows the declared order.
 */
@Test
public void testStateOnDifferentThanFirstPositionAggregationParse()
{
    Signature expectedSignature = new Signature(
            "simple_exact_aggregate_aggregation_state_moved",
            FunctionKind.AGGREGATE,
            DoubleType.DOUBLE.getTypeSignature(),
            ImmutableList.of(DoubleType.DOUBLE.getTypeSignature()));

    ParametricAggregation aggregation = parseFunctionDefinition(StateOnDifferentThanFirstPositionAggregationFunction.class);
    assertEquals(aggregation.getSignature(), expectedSignature);

    AggregationImplementation implementation = getOnlyElement(aggregation.getImplementations().getExactImplementations().values());
    assertEquals(implementation.getDefinitionClass(), StateOnDifferentThanFirstPositionAggregationFunction.class);

    // Expected order: input channel first, then state.
    List<AggregationMetadata.ParameterMetadata.ParameterType> expectedMetadataTypes = ImmutableList.of(AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL, AggregationMetadata.ParameterMetadata.ParameterType.STATE);
    // assertEquals shows both lists on failure, unlike assertTrue(a.equals(b)).
    assertEquals(implementation.getInputParameterMetadataTypes(), expectedMetadataTypes);
}
/**
 * Parses an aggregate whose state parameter carries no annotation and checks that it is still
 * recognized as STATE and that the function specializes to a DOUBLE-returning aggregate.
 */
@Test
public void testNotAnnotatedAggregateStateAggregationParse()
{
    ParametricAggregation aggregation = parseFunctionDefinition(NotAnnotatedAggregateStateAggregationFunction.class);

    AggregationImplementation implementation = getOnlyElement(aggregation.getImplementations().getExactImplementations().values());
    List<AggregationMetadata.ParameterMetadata.ParameterType> expectedMetadataTypes = ImmutableList.of(AggregationMetadata.ParameterMetadata.ParameterType.STATE, AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL);
    // assertEquals shows both lists on failure, unlike assertTrue(a.equals(b)).
    assertEquals(implementation.getInputParameterMetadataTypes(), expectedMetadataTypes);

    InternalAggregationFunction specialized = aggregation.specialize(BoundVariables.builder().build(), 1, new TypeRegistry(), null);
    assertEquals(specialized.getFinalType(), DoubleType.DOUBLE);
    assertTrue(specialized.isDecomposable());
    assertEquals(specialized.name(), "no_aggregation_state_aggregate");
}
/**
 * Checks that an aggregate declaring a custom state serializer factory ends up, after
 * specialization, with a serializer of the factory's declared return type.
 */
@Test
public void testCustomStateSerializerAggregationParse()
{
    ParametricAggregation aggregation = parseFunctionDefinition(CustomStateSerializerAggregationFunction.class);
    AggregationImplementation implementation = getOnlyElement(aggregation.getImplementations().getExactImplementations().values());
    assertTrue(implementation.getStateSerializerFactory().isPresent());

    InternalAggregationFunction specialized = aggregation.specialize(BoundVariables.builder().build(), 1, new TypeRegistry(), null);
    LazyAccumulatorFactoryBinder lazyBinder = (LazyAccumulatorFactoryBinder) specialized.getAccumulatorFactoryBinder();
    AccumulatorStateSerializer<?> createdSerializer = getOnlyElement(lazyBinder.getGenericAccumulatorFactoryBinder().getStateDescriptors()).getSerializer();

    // Declared return type of the factory method handle, i.e. the serializer class it produces.
    Class<?> serializerClass = implementation.getStateSerializerFactory().get().type().returnType();
    assertTrue(serializerClass.isInstance(createdSerializer));
}
/**
 * Chooses the implementation for {@code boundSignature}. An exact-signature implementation is
 * used when present; otherwise the generic implementations are scanned, and exactly one of them
 * must have assignable types. Specialized implementations are not consulted here.
 *
 * @throws PrestoException with AMBIGUOUS_FUNCTION_CALL if two or more generic candidates match,
 *         or FUNCTION_IMPLEMENTATION_MISSING if none does
 */
private AggregationImplementation findMatchingImplementation(Signature boundSignature, BoundVariables variables, TypeManager typeManager, FunctionRegistry functionRegistry)
{
    Optional<AggregationImplementation> match = Optional.empty();
    if (implementations.getExactImplementations().containsKey(boundSignature)) {
        match = Optional.of(implementations.getExactImplementations().get(boundSignature));
    }
    else {
        for (AggregationImplementation candidate : implementations.getGenericImplementations()) {
            if (!candidate.areTypesAssignable(boundSignature, variables, typeManager, functionRegistry)) {
                continue;
            }
            // A second assignable candidate means the call cannot be resolved uniquely.
            if (match.isPresent()) {
                throw new PrestoException(AMBIGUOUS_FUNCTION_CALL, format("Ambiguous function call (%s) for %s", variables, getSignature()));
            }
            match = Optional.of(candidate);
        }
    }
    return match.orElseThrow(() -> new PrestoException(FUNCTION_IMPLEMENTATION_MISSING, format("Unsupported type parameters (%s) for %s", variables, getSignature())));
}
// The literal-injection aggregate is generic; take its first (presumably only) generic implementation.
AggregationImplementation implementation = implementations.getGenericImplementations().get(0);
assertTrue(implementation.getStateSerializerFactory().isPresent());
assertEquals(implementation.getDefinitionClass(), InjectLiteralAggregateFunction.class);
assertEquals(implementation.getStateSerializerFactoryDependencies().size(), 1);
// Every function (input, combine, output, serializer factory) receives an injected literal dependency first.
assertTrue(implementation.getInputDependencies().get(0) instanceof LiteralImplementationDependency);
assertTrue(implementation.getCombineDependencies().get(0) instanceof LiteralImplementationDependency);
assertTrue(implementation.getOutputDependencies().get(0) instanceof LiteralImplementationDependency);
assertTrue(implementation.getStateSerializerFactoryDependencies().get(0) instanceof LiteralImplementationDependency);
assertFalse(implementation.hasSpecializedTypeParameters());
// Input function takes the state first, then one input channel.
List<AggregationMetadata.ParameterMetadata.ParameterType> expectedMetadataTypes = ImmutableList.of(AggregationMetadata.ParameterMetadata.ParameterType.STATE, AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL);
assertTrue(implementation.getInputParameterMetadataTypes().equals(expectedMetadataTypes));