private PhysicalOperation planGlobalAggregation(int operatorId, AggregationNode node, PhysicalOperation source) { int outputChannel = 0; ImmutableMap.Builder<Symbol, Integer> outputMappings = ImmutableMap.builder(); List<AccumulatorFactory> accumulatorFactories = new ArrayList<>(); for (Map.Entry<Symbol, FunctionCall> entry : node.getAggregations().entrySet()) { Symbol symbol = entry.getKey(); accumulatorFactories.add(buildAccumulatorFactory(source, node.getFunctions().get(symbol), entry.getValue(), node.getMasks().get(entry.getKey()), Optional.<Integer>empty(), node.getSampleWeight(), node.getConfidence())); outputMappings.put(symbol, outputChannel); // one aggregation per channel outputChannel++; } OperatorFactory operatorFactory = new AggregationOperatorFactory(operatorId, node.getId(), node.getStep(), accumulatorFactories); return new PhysicalOperation(operatorFactory, outputMappings.build(), source); }
functions.put(symbol, node.getFunctions().get(symbol));
/**
 * Rewrites the source of an aggregation node and rebuilds the node with its
 * aggregation outputs, function calls, masks, group-by keys, sample weight
 * and hash symbol passed through {@code canonicalize}.
 *
 * @param node aggregation node to rewrite
 * @param context rewrite context used to rewrite the node's source
 * @return a new aggregation node with the same id, step and confidence
 */
@Override
public PlanNode visitAggregation(AggregationNode node, RewriteContext<Void> context)
{
    PlanNode rewrittenSource = context.rewrite(node.getSource());

    ImmutableMap.Builder<Symbol, Signature> signaturesByOutput = ImmutableMap.builder();
    ImmutableMap.Builder<Symbol, FunctionCall> callsByOutput = ImmutableMap.builder();
    for (Map.Entry<Symbol, FunctionCall> aggregation : node.getAggregations().entrySet()) {
        Symbol originalOutput = aggregation.getKey();
        Symbol canonicalOutput = canonicalize(originalOutput);

        callsByOutput.put(canonicalOutput, (FunctionCall) canonicalize(aggregation.getValue()));
        // the signature is looked up under the original symbol but stored under the canonical one
        signaturesByOutput.put(canonicalOutput, node.getFunctions().get(originalOutput));
    }

    ImmutableMap.Builder<Symbol, Symbol> masksByOutput = ImmutableMap.builder();
    for (Map.Entry<Symbol, Symbol> mask : node.getMasks().entrySet()) {
        Symbol maskedOutput = canonicalize(mask.getKey());
        Symbol maskSymbol = canonicalize(mask.getValue());
        masksByOutput.put(maskedOutput, maskSymbol);
    }

    List<Symbol> canonicalGroupBy = canonicalizeAndDistinct(node.getGroupBy());

    return new AggregationNode(
            node.getId(),
            rewrittenSource,
            canonicalGroupBy,
            callsByOutput.build(),
            signaturesByOutput.build(),
            masksByOutput.build(),
            node.getStep(),
            canonicalize(node.getSampleWeight()),
            node.getConfidence(),
            canonicalize(node.getHashSymbol()));
}
@Override public PlanNode visitAggregation(AggregationNode node, RewriteContext<Optional<Symbol>> context) { // optimize if and only if // all aggregation functions have a single common distinct mask symbol // AND all aggregation functions have mask Set<Symbol> masks = ImmutableSet.copyOf(node.getMasks().values()); if (masks.size() != 1 || node.getMasks().size() != node.getAggregations().size()) { return context.defaultRewrite(node, Optional.empty()); } PlanNode source = context.rewrite(node.getSource(), Optional.of(Iterables.getOnlyElement(masks))); Map<Symbol, FunctionCall> aggregations = ImmutableMap.copyOf(Maps.transformValues(node.getAggregations(), call -> new FunctionCall(call.getName(), call.getWindow(), false, call.getArguments()))); return new AggregationNode(idAllocator.getNextId(), source, node.getGroupBy(), aggregations, node.getFunctions(), Collections.emptyMap(), node.getStep(), node.getSampleWeight(), node.getConfidence(), node.getHashSymbol()); }
/**
 * Rewrites the source of an aggregation node and, when the rewritten source
 * is a {@code ProjectNode}, replaces every aggregation accepted by
 * {@code isCountConstant} with an argument-less call of the same name paired
 * with a {@code count} signature returning {@code BIGINT}.
 *
 * <p>Aggregations that are not replaced keep their original call and signature.
 *
 * @param node aggregation node to rewrite
 * @param context rewrite context used to rewrite the node's source
 * @return a new aggregation node with the same id over the rewritten source
 */
@Override
public PlanNode visitAggregation(AggregationNode node, RewriteContext<Void> context)
{
    // start from copies of the originals so untouched aggregations are carried over as-is
    Map<Symbol, FunctionCall> rewrittenCalls = new LinkedHashMap<>(node.getAggregations());
    Map<Symbol, Signature> rewrittenSignatures = new LinkedHashMap<>(node.getFunctions());

    PlanNode rewrittenSource = context.rewrite(node.getSource());
    if (rewrittenSource instanceof ProjectNode) {
        ProjectNode projection = (ProjectNode) rewrittenSource;
        for (Entry<Symbol, FunctionCall> aggregation : node.getAggregations().entrySet()) {
            Symbol output = aggregation.getKey();
            FunctionCall call = aggregation.getValue();
            if (!isCountConstant(projection, call, node.getFunctions().get(output))) {
                continue;
            }

            FunctionCall argumentless = new FunctionCall(
                    call.getName(),
                    call.isDistinct(),
                    ImmutableList.<Expression>of());
            rewrittenCalls.put(output, argumentless);
            rewrittenSignatures.put(output, new Signature("count", AGGREGATE, StandardTypes.BIGINT));
        }
    }

    return new AggregationNode(
            node.getId(),
            rewrittenSource,
            node.getGroupBy(),
            rewrittenCalls,
            rewrittenSignatures,
            node.getMasks(),
            node.getStep(),
            node.getSampleWeight(),
            node.getConfidence(),
            node.getHashSymbol());
}
Signature signature = node.getFunctions().get(entry.getKey()); InternalAggregationFunction function = metadata.getFunctionRegistry().getAggregateFunctionImplementation(signature); node.getGroupBy(), finalCalls, node.getFunctions(), ImmutableMap.of(), INTERMEDIATE,
/**
 * Creates a copy of the aggregation node whose source is the single element
 * of {@code newChildren}; every other property is taken unchanged from
 * {@code node}.
 *
 * @param node aggregation node to copy
 * @param newChildren replacement children; {@code Iterables.getOnlyElement} requires exactly one
 * @return a new aggregation node over the replacement child
 */
@Override
public PlanNode visitAggregation(AggregationNode node, List<PlanNode> newChildren)
{
    PlanNode onlyChild = Iterables.getOnlyElement(newChildren);
    return new AggregationNode(
            node.getId(),
            onlyChild,
            node.getGroupBy(),
            node.getAggregations(),
            node.getFunctions(),
            node.getMasks(),
            node.getStep(),
            node.getSampleWeight(),
            node.getConfidence(),
            node.getHashSymbol());
}
Map<Symbol, Symbol> intermediateMask = new HashMap<>(); for (Map.Entry<Symbol, FunctionCall> entry : node.getAggregations().entrySet()) { Signature signature = node.getFunctions().get(entry.getKey()); InternalAggregationFunction function = metadata.getFunctionRegistry().getAggregateFunctionImplementation(signature); node.getGroupBy(), finalCalls, node.getFunctions(), ImmutableMap.of(), FINAL,
node.getGroupBy(), node.getAggregations(), node.getFunctions(), node.getMasks(), node.getStep(),
node.getGroupBy(), node.getAggregations(), node.getFunctions(), node.getMasks(), node.getStep(), node.getGroupBy(), node.getAggregations(), node.getFunctions(), node.getMasks(), node.getStep(),
@Override public PlanNode visitAggregation(AggregationNode node, RewriteContext<Boolean> context) { boolean distinct = isDistinctOperator(node); PlanNode rewrittenNode = context.rewrite(node.getSource(), distinct); if (context.get() && distinct) { // Assumes underlying node has same output symbols as this distinct node return rewrittenNode; } return new AggregationNode( node.getId(), rewrittenNode, node.getGroupBy(), node.getAggregations(), node.getFunctions(), node.getMasks(), node.getStep(), node.getSampleWeight(), node.getConfidence(), node.getHashSymbol()); }
boolean decomposable = node.getFunctions() .values().stream() .map(functionRegistry::getAggregateFunctionImplementation)