/**
 * Invokes the rewrite logic on children while a node is being processed.
 *
 * @param node the plan node to rewrite; dispatched through {@code nodeRewriter}
 * @param userContext caller-supplied context threaded into the rewrite
 * @return the rewritten plan node, guaranteed non-null (verified)
 */
public PlanNode rewrite(PlanNode node, C userContext)
{
    RewriteContext<C> rewriteContext = new RewriteContext<>(nodeRewriter, userContext);
    PlanNode rewritten = node.accept(nodeRewriter, rewriteContext);
    // A rewriter that returns null is a programming error; fail loudly with the offending node type.
    verify(rewritten != null, "nodeRewriter returned null for %s", node.getClass().getName());
    return rewritten;
}
// NOTE(review): this fragment appears corrupted — the same
// `return context.defaultRewrite(node, Optional.empty());` statement is duplicated four times
// (every statement after the first return is unreachable), and the enclosing method signature and
// closing braces are missing, so the text does not parse as-is. Preserved byte-for-byte below;
// recover the original method from version control before editing.
Set<Symbol> uniqueMasks = ImmutableSet.copyOf(masks); if (uniqueMasks.size() != 1 || masks.size() == node.getAggregations().size()) { return context.defaultRewrite(node, Optional.empty()); return context.defaultRewrite(node, Optional.empty()); return context.defaultRewrite(node, Optional.empty()); return context.defaultRewrite(node, Optional.empty()); PlanNode source = context.rewrite(node.getSource(), Optional.of(aggregateInfo)); return context.defaultRewrite(node, Optional.empty());
// NOTE(review): this method body appears truncated/jumbled — the opening brace after the
// signature is missing and the statements (an unconditional return followed by a dead
// `context.rewrite` call) do not form a coherent control flow; the text does not parse as-is.
// Preserved byte-for-byte below; recover the original method from version control before editing.
@Override public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext<Optional<AggregateInfo>> context) Optional<AggregateInfo> aggregateInfo = context.get(); return context.defaultRewrite(node, Optional.empty()); PlanNode source = context.rewrite(node.getSource(), Optional.empty());
/**
 * Invokes the rewrite logic on children while a node is being processed.
 *
 * @param node the plan node to rewrite; dispatched through {@code nodeRewriter}
 * @param userContext caller-supplied context threaded into the rewrite
 * @return the rewritten plan node, guaranteed non-null
 * @throws NullPointerException if the rewriter produced a null node
 */
public PlanNode rewrite(PlanNode node, C userContext)
{
    PlanNode rewritten = node.accept(nodeRewriter, new RewriteContext<>(nodeRewriter, userContext));
    // The message is built eagerly in the original as well, so extracting it changes nothing.
    String nullMessage = format("nodeRewriter returned null for %s", node.getClass().getName());
    requireNonNull(rewritten, nullMessage);
    return rewritten;
}
/**
 * An inline VALUES relation is evaluated in place, so the enclosing fragment
 * must run on a single node.
 */
@Override
public PlanNode visitValues(ValuesNode node, RewriteContext<FragmentProperties> context)
{
    FragmentProperties properties = context.get();
    properties.setSingleNodeDistribution();
    return context.defaultRewrite(node, properties);
}
/**
 * A table scan pins the fragment to a source (leaf) distribution keyed by the
 * scan node's id.
 */
@Override
public PlanNode visitTableScan(TableScanNode node, RewriteContext<FragmentProperties> context)
{
    FragmentProperties properties = context.get();
    properties.setSourceDistribution(node.getId());
    return context.defaultRewrite(node, properties);
}
/**
 * A filter sitting directly on a table scan can be planned as an index lookup
 * driven by the filter predicate; any other filter is rewritten generically,
 * propagating the current lookup symbols and success flag.
 */
@Override
public PlanNode visitFilter(FilterNode node, RewriteContext<Context> context)
{
    PlanNode source = node.getSource();
    if (source instanceof TableScanNode) {
        return planTableScan((TableScanNode) source, node.getPredicate(), context.get());
    }
    Context currentContext = context.get();
    return context.defaultRewrite(node, new Context(currentContext.getLookupSymbols(), currentContext.getSuccess()));
}
// NOTE(review): this fragment appears corrupted — four duplicated free-standing
// `return context.defaultRewrite(node);` statements with no enclosing method, followed by a
// dangling `if (value == null)` guard; the text does not parse as-is. Preserved byte-for-byte
// below; recover the original method from version control before editing.
return context.defaultRewrite(node); return context.defaultRewrite(node); return context.defaultRewrite(node); return context.defaultRewrite(node); if (value == null) { return context.defaultRewrite(node);
// NOTE(review): this method body appears truncated/jumbled — the opening brace after the
// signature is missing and statements are out of order (e.g. `equalityPartition` is referenced
// without a visible definition); the text does not parse as-is. Preserved byte-for-byte below;
// recover the original method from version control before editing.
private PlanNode visitNonFilteringSemiJoin(SemiJoinNode node, RewriteContext<Expression> context) Expression inheritedPredicate = context.get(); List<Expression> sourceConjuncts = new ArrayList<>(); List<Expression> postJoinConjuncts = new ArrayList<>(); PlanNode rewrittenFilteringSource = context.defaultRewrite(node.getFilteringSource(), TRUE_LITERAL); postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities()); PlanNode rewrittenSource = context.rewrite(node.getSource(), combineConjuncts(sourceConjuncts));
@Override public PlanNode visitProject(ProjectNode node, RewriteContext<Expression> context) { Set<Symbol> deterministicSymbols = node.getAssignments().entrySet().stream() .filter(entry -> DeterminismEvaluator.isDeterministic(entry.getValue())) .map(Map.Entry::getKey) .collect(Collectors.toSet()); Predicate<Expression> deterministic = conjunct -> SymbolsExtractor.extractUnique(conjunct).stream() .allMatch(deterministicSymbols::contains); Map<Boolean, List<Expression>> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(deterministic)); // Push down conjuncts from the inherited predicate that only depend on deterministic assignments with // certain limitations. List<Expression> deterministicConjuncts = conjuncts.get(true); // We partition the expressions in the deterministicConjuncts into two lists, and only inline the // expressions that are in the inlining targets list. Map<Boolean, List<Expression>> inlineConjuncts = deterministicConjuncts.stream() .collect(Collectors.partitioningBy(expression -> isInliningCandidate(expression, node))); List<Expression> inlinedDeterministicConjuncts = inlineConjuncts.get(true).stream() .map(entry -> inlineSymbols(node.getAssignments().getMap(), entry)) .collect(Collectors.toList()); PlanNode rewrittenNode = context.defaultRewrite(node, combineConjuncts(inlinedDeterministicConjuncts)); // All deterministic conjuncts that contains non-inlining targets, and non-deterministic conjuncts, // if any, will be in the filter node. List<Expression> nonInliningConjuncts = inlineConjuncts.get(false); nonInliningConjuncts.addAll(conjuncts.get(false)); if (!nonInliningConjuncts.isEmpty()) { rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, combineConjuncts(nonInliningConjuncts)); } return rewrittenNode; }
/**
 * Attempts to fold a LIMIT into a row-number-producing source: a RowNumberNode
 * absorbs the limit directly, and an eligible WindowNode is converted into a
 * TopNRowNumberNode. If the merged node is unpartitioned it alone enforces the
 * limit; otherwise the LimitNode is kept above the rewritten source.
 */
@Override
public PlanNode visitLimit(LimitNode node, RewriteContext<Void> context)
{
    // Operators can handle MAX_VALUE rows per page, so do not optimize if count is greater than this value
    if (node.getCount() > Integer.MAX_VALUE) {
        return context.defaultRewrite(node);
    }
    PlanNode source = context.rewrite(node.getSource());
    // Safe: the guard above ensures the count fits in an int.
    int limit = toIntExact(node.getCount());
    if (source instanceof RowNumberNode) {
        // Fold the limit into the row-number computation.
        RowNumberNode rowNumberNode = mergeLimit(((RowNumberNode) source), limit);
        if (rowNumberNode.getPartitionBy().isEmpty()) {
            // No partitioning: the merged node alone enforces the limit globally.
            return rowNumberNode;
        }
        source = rowNumberNode;
    }
    else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source) && isOptimizeTopNRowNumber(session)) {
        WindowNode windowNode = (WindowNode) source;
        // verify that unordered row_number window functions are replaced by RowNumberNode
        verify(windowNode.getOrderingScheme().isPresent());
        TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber(windowNode, limit);
        if (windowNode.getPartitionBy().isEmpty()) {
            return topNRowNumberNode;
        }
        source = topNRowNumberNode;
    }
    // Partitioned source: keep the LimitNode above the (possibly merged) source.
    return replaceChildren(node, ImmutableList.of(source));
}
/**
 * Splits the plan at a remote exchange: each source of the exchange becomes its
 * own sub-plan (fragment), and the exchange itself is replaced with a
 * RemoteSourceNode referencing the child fragment ids. Local exchanges are
 * rewritten in place.
 */
@Override
public PlanNode visitExchange(ExchangeNode exchange, RewriteContext<FragmentProperties> context)
{
    if (exchange.getScope() != REMOTE) {
        // Local exchanges do not create fragment boundaries.
        return context.defaultRewrite(exchange, context.get());
    }
    PartitioningScheme partitioningScheme = exchange.getPartitioningScheme();
    // The exchange type dictates the distribution of the fragment ABOVE it.
    if (exchange.getType() == ExchangeNode.Type.GATHER) {
        context.get().setSingleNodeDistribution();
    }
    else if (exchange.getType() == ExchangeNode.Type.REPARTITION) {
        context.get().setDistribution(partitioningScheme.getPartitioning().getHandle(), metadata, session);
    }
    // Build one child sub-plan per exchange source, translating the output layout
    // through the exchange's per-source input mapping.
    ImmutableList.Builder<SubPlan> builder = ImmutableList.builder();
    for (int sourceIndex = 0; sourceIndex < exchange.getSources().size(); sourceIndex++) {
        FragmentProperties childProperties = new FragmentProperties(partitioningScheme.translateOutputLayout(exchange.getInputs().get(sourceIndex)));
        builder.add(buildSubPlan(exchange.getSources().get(sourceIndex), childProperties, context));
    }
    List<SubPlan> children = builder.build();
    context.get().addChildren(children);
    List<PlanFragmentId> childrenIds = children.stream()
            .map(SubPlan::getFragment)
            .map(PlanFragment::getId)
            .collect(toImmutableList());
    // Replace the exchange with a remote source reading from the child fragments.
    return new RemoteSourceNode(exchange.getId(), childrenIds, exchange.getOutputSymbols(), exchange.getOrderingScheme(), exchange.getType());
}
@Override public PlanNode visitWindow(WindowNode node, RewriteContext<Context> context) { if (!node.getWindowFunctions().values().stream() .map(function -> function.getFunctionCall().getName()) .allMatch(metadata.getFunctionRegistry()::isAggregationFunction)) { return node; } // Don't need this restriction if we can prove that all order by symbols are deterministically produced if (node.getOrderingScheme().isPresent()) { return node; } // Only RANGE frame type currently supported for aggregation functions because it guarantees the // same value for each peer group. // ROWS frame type requires the ordering to be fully deterministic (e.g. deterministically sorted on all columns) if (node.getFrames().stream().map(WindowNode.Frame::getType).anyMatch(type -> type != WindowFrame.Type.RANGE)) { // TODO: extract frames of type RANGE and allow optimization on them return node; } // Lookup symbols can only be passed through if they are part of the partitioning Set<Symbol> partitionByLookupSymbols = context.get().getLookupSymbols().stream() .filter(node.getPartitionBy()::contains) .collect(toImmutableSet()); if (partitionByLookupSymbols.isEmpty()) { return node; } return context.defaultRewrite(node, new Context(partitionByLookupSymbols, context.get().getSuccess())); }
@Override public PlanNode visitWindow(WindowNode node, RewriteContext<Expression> context) { List<Symbol> partitionSymbols = node.getPartitionBy(); // TODO: This could be broader. We can push down conjucts if they are constant for all rows in a window partition. // The simplest way to guarantee this is if the conjucts are deterministic functions of the partitioning symbols. // This can leave out cases where they're both functions of some set of common expressions and the partitioning // function is injective, but that's a rare case. The majority of window nodes are expected to be partitioned by // pre-projected symbols. Predicate<Expression> isSupported = conjunct -> DeterminismEvaluator.isDeterministic(conjunct) && SymbolsExtractor.extractUnique(conjunct).stream() .allMatch(partitionSymbols::contains); Map<Boolean, List<Expression>> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(isSupported)); PlanNode rewrittenNode = context.defaultRewrite(node, combineConjuncts(conjuncts.get(true))); if (!conjuncts.get(false).isEmpty()) { rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, combineConjuncts(conjuncts.get(false))); } return rewrittenNode; }
/**
 * Converts LIMIT over a pure DISTINCT aggregation (no aggregation functions,
 * outputs exactly the grouping keys) into a DistinctLimitNode; otherwise
 * re-applies the limit above the aggregation, since limits cannot be pushed
 * through it.
 */
@Override
@Deprecated
public PlanNode visitAggregation(AggregationNode node, RewriteContext<LimitContext> context)
{
    LimitContext limit = context.get();
    // DISTINCT-only shape: no aggregations and the outputs are exactly the grouping keys.
    if (limit != null
            && node.getAggregations().isEmpty()
            && node.getOutputSymbols().size() == node.getGroupingKeys().size()
            && node.getOutputSymbols().containsAll(node.getGroupingKeys())) {
        PlanNode rewrittenSource = context.rewrite(node.getSource());
        return new DistinctLimitNode(idAllocator.getNextId(), rewrittenSource, limit.getCount(), false, rewrittenSource.getOutputSymbols(), Optional.empty());
    }
    PlanNode rewrittenNode = context.defaultRewrite(node);
    if (limit != null) {
        // Drop in a LimitNode b/c limits cannot be pushed through aggregations
        rewrittenNode = new LimitNode(idAllocator.getNextId(), rewrittenNode, limit.getCount(), limit.isPartial());
    }
    return rewrittenNode;
}
/**
 * Rewrites a DELETE plan into a metadata-only delete when the underlying
 * connector supports it: requires a DeleteNode over a TableScanNode whose
 * layout is resolved and whose table supports metadata deletes. Any other
 * shape falls back to the default rewrite.
 */
@Override
public PlanNode visitTableFinish(TableFinishNode node, RewriteContext<Void> context)
{
    Optional<DeleteNode> delete = findNode(node.getSource(), DeleteNode.class);
    if (!delete.isPresent()) {
        return context.defaultRewrite(node);
    }
    Optional<TableScanNode> tableScan = findNode(delete.get().getSource(), TableScanNode.class);
    if (!tableScan.isPresent()) {
        return context.defaultRewrite(node);
    }
    TableScanNode tableScanNode = tableScan.get();
    // Fix: the original called tableScanNode.getLayout().get() unchecked, which throws
    // NoSuchElementException when the scan has no resolved layout; bail out instead.
    if (!tableScanNode.getLayout().isPresent()) {
        return context.defaultRewrite(node);
    }
    if (!metadata.supportsMetadataDelete(session, tableScanNode.getTable(), tableScanNode.getLayout().get())) {
        return context.defaultRewrite(node);
    }
    return new MetadataDeleteNode(idAllocator.getNextId(), delete.get().getTarget(), Iterables.getOnlyElement(node.getOutputSymbols()), tableScanNode.getLayout().get());
}
/**
 * Prunes a MarkDistinctNode whose marker symbol is unreferenced downstream;
 * otherwise rewrites its source to request only the symbols the node needs
 * (distinct symbols, the still-referenced outputs minus the marker, and the
 * optional hash symbol).
 */
@Override
public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext<Set<Symbol>> context)
{
    Symbol marker = node.getMarkerSymbol();
    // Marker unused: drop the node entirely and rewrite the source with the same demand set.
    if (!context.get().contains(marker)) {
        return context.rewrite(node.getSource(), context.get());
    }

    ImmutableSet.Builder<Symbol> requiredInputs = ImmutableSet.<Symbol>builder()
            .addAll(node.getDistinctSymbols())
            .addAll(context.get().stream()
                    .filter(symbol -> !symbol.equals(marker))
                    .collect(toImmutableList()));
    node.getHashSymbol().ifPresent(requiredInputs::add);

    PlanNode rewrittenSource = context.rewrite(node.getSource(), requiredInputs.build());
    return new MarkDistinctNode(node.getId(), rewrittenSource, marker, node.getDistinctSymbols(), node.getHashSymbol());
}
/**
 * Prunes a table scan to the output symbols actually referenced downstream,
 * trimming the column assignments to match; all other scan properties are
 * carried over unchanged.
 */
@Override
public PlanNode visitTableScan(TableScanNode node, RewriteContext<Set<Symbol>> context)
{
    Set<Symbol> referenced = context.get();
    List<Symbol> prunedOutputs = node.getOutputSymbols().stream()
            .filter(referenced::contains)
            .collect(toImmutableList());
    Map<Symbol, ColumnHandle> prunedAssignments = prunedOutputs.stream()
            .collect(Collectors.toMap(Function.identity(), node.getAssignments()::get));
    return new TableScanNode(
            node.getId(),
            node.getTable(),
            prunedOutputs,
            prunedAssignments,
            node.getLayout(),
            node.getCurrentConstraint(),
            node.getEnforcedConstraint());
}
/**
 * Prunes an UnnestNode's replicated symbols and ordinality output to those
 * referenced downstream; the unnest symbols themselves are always kept, and
 * the source is rewritten to supply the pruned replicated symbols plus every
 * unnest input.
 */
@Override
public PlanNode visitUnnest(UnnestNode node, RewriteContext<Set<Symbol>> context)
{
    Set<Symbol> referenced = context.get();

    List<Symbol> prunedReplicateSymbols = node.getReplicateSymbols().stream()
            .filter(referenced::contains)
            .collect(toImmutableList());

    // Keep the ordinality output only if something downstream reads it.
    Optional<Symbol> ordinalitySymbol = node.getOrdinalitySymbol()
            .filter(referenced::contains);

    Map<Symbol, List<Symbol>> unnestSymbols = node.getUnnestSymbols();
    ImmutableSet.Builder<Symbol> requiredInputs = ImmutableSet.<Symbol>builder()
            .addAll(prunedReplicateSymbols)
            .addAll(unnestSymbols.keySet());

    PlanNode rewrittenSource = context.rewrite(node.getSource(), requiredInputs.build());
    return new UnnestNode(node.getId(), rewrittenSource, prunedReplicateSymbols, unnestSymbols, ordinalitySymbol);
}
@Override public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext<Context> context) { // Lookup symbols can only be passed through the probe side of an index join Set<Symbol> probeLookupSymbols = context.get().getLookupSymbols().stream() .filter(node.getProbeSource().getOutputSymbols()::contains) .collect(toImmutableSet()); if (probeLookupSymbols.isEmpty()) { return node; } PlanNode rewrittenProbeSource = context.rewrite(node.getProbeSource(), new Context(probeLookupSymbols, context.get().getSuccess())); PlanNode source = node; if (rewrittenProbeSource != node.getProbeSource()) { source = new IndexJoinNode(node.getId(), node.getType(), rewrittenProbeSource, node.getIndexSource(), node.getCriteria(), node.getProbeHashSymbol(), node.getIndexHashSymbol()); } return source; }