/**
 * Computes the hash from the sources, the set of conjuncts of the filter, and the output symbols.
 * The filter contributes as a set of its conjuncts rather than as the raw expression.
 */
@Override
public int hashCode() {
    Object filterConjuncts = ImmutableSet.copyOf(extractConjuncts(filter));
    return Objects.hash(sources, filterConjuncts, outputSymbols);
}
/**
 * Keeps only the conjuncts of {@code expression} that satisfy {@code predicate} and
 * recombines them with {@code combineConjuncts}.
 *
 * @param expression expression that is split into conjuncts
 * @param predicate test applied to every conjunct; conjuncts that pass are retained
 * @return the combination of the retained conjuncts
 */
public static Expression filterConjuncts(Expression expression, Predicate<Expression> predicate) {
    return combineConjuncts(extractConjuncts(expression).stream()
            .filter(predicate)
            .collect(toList()));
}
/**
 * Two nodes are equal when their sources and output symbols are equal and their filters
 * have equal sets of conjuncts.
 *
 * @param obj object to compare against; anything that is not a {@code MultiJoinNode} is unequal
 * @return {@code true} when all three components match
 */
@Override
public boolean equals(Object obj) {
    if (!(obj instanceof MultiJoinNode)) {
        return false;
    }

    MultiJoinNode that = (MultiJoinNode) obj;
    if (!sources.equals(that.sources)) {
        return false;
    }
    // Filters are compared as sets of conjuncts, not as raw expressions.
    if (!ImmutableSet.copyOf(extractConjuncts(filter)).equals(ImmutableSet.copyOf(extractConjuncts(that.filter)))) {
        return false;
    }
    return outputSymbols.equals(that.outputSymbols);
}
/**
 * Provides a convenience Iterable of the conjuncts of {@code expression} that are not
 * inference candidates, i.e. those rejected by {@code isInferenceCandidate()}.
 *
 * @param expression expression that is split into conjuncts
 * @return the conjuncts that have not been added to the inference
 */
public static Iterable<Expression> nonInferrableConjuncts(Expression expression) {
    Iterable<Expression> nonCandidates = filter(extractConjuncts(expression), not(isInferenceCandidate()));
    return nonCandidates;
}
/**
 * Returns the subset of conjuncts matching one of the following shapes:
 * - ST_Contains(...)
 * - ST_Within(...)
 * - ST_Intersects(...)
 * <p>
 * Doesn't check or guarantee anything about function arguments.
 *
 * @param filterExpression expression that is split into conjuncts
 * @return the conjuncts that are function calls accepted by {@code isSupportedSpatialFunction}
 */
public static List<FunctionCall> extractSupportedSpatialFunctions(Expression filterExpression) {
    List<Expression> conjuncts = extractConjuncts(filterExpression);
    return conjuncts.stream()
            .filter(conjunct -> conjunct instanceof FunctionCall)
            .map(conjunct -> (FunctionCall) conjunct)
            .filter(functionCall -> SpatialJoinUtils.isSupportedSpatialFunction(functionCall))
            .collect(toImmutableList());
}
/**
 * Returns the subset of conjuncts matching one of the following shapes:
 * - ST_Distance(...) <= ...
 * - ST_Distance(...) < ...
 * - ... >= ST_Distance(...)
 * - ... > ST_Distance(...)
 * <p>
 * Doesn't check or guarantee anything about ST_Distance functions arguments
 * or the other side of the comparison.
 *
 * @param filterExpression expression that is split into conjuncts
 * @return the conjuncts that are comparisons accepted by {@code isSupportedSpatialComparison}
 */
public static List<ComparisonExpression> extractSupportedSpatialComparisons(Expression filterExpression) {
    List<Expression> conjuncts = extractConjuncts(filterExpression);
    return conjuncts.stream()
            .filter(conjunct -> conjunct instanceof ComparisonExpression)
            .map(conjunct -> (ComparisonExpression) conjunct)
            .filter(comparison -> SpatialJoinUtils.isSupportedSpatialComparison(comparison))
            .collect(toImmutableList());
}
/**
 * Combines the given expressions into a single conjunction.
 * <p>
 * Each input is first split into its own conjuncts, {@code TRUE_LITERAL} terms are dropped,
 * and the remainder is passed through {@code removeDuplicates}. If {@code FALSE_LITERAL} is
 * among the remaining terms the result is {@code FALSE_LITERAL}; otherwise the terms are
 * joined with {@code and}.
 *
 * @param expressions expressions to combine
 * @return the combined expression
 * @throws NullPointerException if {@code expressions} is null
 */
public static Expression combineConjuncts(Collection<Expression> expressions) {
    requireNonNull(expressions, "expressions is null");

    List<Expression> flattened = expressions.stream()
            .flatMap(expression -> ExpressionUtils.extractConjuncts(expression).stream())
            .filter(conjunct -> !conjunct.equals(TRUE_LITERAL))
            .collect(toList());
    List<Expression> unique = removeDuplicates(flattened);

    // A single FALSE term makes the whole conjunction FALSE.
    return unique.contains(FALSE_LITERAL) ? FALSE_LITERAL : and(unique);
}
public static Optional<SortExpressionContext> extractSortExpression(Set<Symbol> buildSymbols, Expression filter) { List<Expression> filterConjuncts = ExpressionUtils.extractConjuncts(filter); SortExpressionVisitor visitor = new SortExpressionVisitor(buildSymbols); List<SortExpressionContext> sortExpressionCandidates = filterConjuncts.stream() .filter(DeterminismEvaluator::isDeterministic) .map(visitor::process) .filter(Optional::isPresent) .map(Optional::get) .collect(toMap(SortExpressionContext::getSortExpression, identity(), SortExpressionExtractor::merge)) .values() .stream() .collect(toImmutableList()); // For now heuristically pick sort expression which has most search expressions assigned to it. // TODO: make it cost based decision based on symbol statistics return sortExpressionCandidates.stream() .sorted(comparing(context -> -1 * context.getSearchExpressions().size())) .findFirst(); }
/**
 * Adds to this builder, via {@code addAllEqualities}, every conjunct of {@code expression}
 * accepted by {@code isInferenceCandidate()}.
 *
 * @param expression expression that is split into conjuncts
 * @return the value returned by {@code addAllEqualities}
 */
public Builder extractInferenceCandidates(Expression expression) {
    Iterable<Expression> candidates = filter(extractConjuncts(expression), isInferenceCandidate());
    return addAllEqualities(candidates);
}
private boolean canConvertOuterToInner(List<Symbol> innerSymbolsForOuterJoin, Expression inheritedPredicate) { Set<Symbol> innerSymbols = ImmutableSet.copyOf(innerSymbolsForOuterJoin); for (Expression conjunct : extractConjuncts(inheritedPredicate)) { if (DeterminismEvaluator.isDeterministic(conjunct)) { // Ignore a conjunct for this test if we can not deterministically get responses from it Object response = nullInputEvaluator(innerSymbols, conjunct); if (response == null || response instanceof NullLiteral || Boolean.FALSE.equals(response)) { // If there is a single conjunct that returns FALSE or NULL given all NULL inputs for the inner side symbols of an outer join // then this conjunct removes all effects of the outer join, and effectively turns this into an equivalent of an inner join. // So, let's just rewrite this join as an INNER join return true; } } } return false; }
/**
 * Splits {@code predicate} into the single conjunct that references the semi-join output
 * and the remaining conjuncts.
 *
 * @param predicate predicate that is split into conjuncts
 * @param semiJoinOutput semi-join output symbol passed to {@code isSemiJoinOutputReference}
 * @return empty unless exactly one conjunct is a semi-join output reference; otherwise a
 *         filter holding whether that conjunct is a {@code NotExpression} and the
 *         combination of all other conjuncts
 */
private static Optional<SemiJoinOutputFilter> extractSemiJoinOutputFilter(Expression predicate, Symbol semiJoinOutput) {
    List<Expression> allConjuncts = extractConjuncts(predicate);
    List<Expression> outputReferences = allConjuncts.stream()
            .filter(conjunct -> isSemiJoinOutputReference(conjunct, semiJoinOutput))
            .collect(toImmutableList());

    if (outputReferences.size() != 1) {
        return Optional.empty();
    }

    Expression outputReference = outputReferences.get(0);
    // Identity comparison: drop exactly the instance that was picked above.
    List<Expression> otherConjuncts = allConjuncts.stream()
            .filter(conjunct -> conjunct != outputReference)
            .collect(toImmutableList());
    Expression remainingPredicate = combineConjuncts(otherConjuncts);

    return Optional.of(new SemiJoinOutputFilter(outputReference instanceof NotExpression, remainingPredicate));
}
private Expression deriveCommonPredicates(PlanNode node, Function<Integer, Collection<Map.Entry<Symbol, SymbolReference>>> mapping) { // Find the predicates that can be pulled up from each source List<Set<Expression>> sourceOutputConjuncts = new ArrayList<>(); for (int i = 0; i < node.getSources().size(); i++) { Expression underlyingPredicate = node.getSources().get(i).accept(this, null); List<Expression> equalities = mapping.apply(i).stream() .filter(SYMBOL_MATCHES_EXPRESSION.negate()) .map(ENTRY_TO_EQUALITY) .collect(toImmutableList()); sourceOutputConjuncts.add(ImmutableSet.copyOf(extractConjuncts(pullExpressionThroughSymbols(combineConjuncts( ImmutableList.<Expression>builder() .addAll(equalities) .add(underlyingPredicate) .build()), node.getOutputSymbols())))); } // Find the intersection of predicates across all sources // TODO: use a more precise way to determine overlapping conjuncts (e.g. commutative predicates) Iterator<Set<Expression>> iterator = sourceOutputConjuncts.iterator(); Set<Expression> potentialOutputConjuncts = iterator.next(); while (iterator.hasNext()) { potentialOutputConjuncts = Sets.intersection(potentialOutputConjuncts, iterator.next()); } return combineConjuncts(potentialOutputConjuncts); }
/**
 * Pushes the inherited conjuncts that reference only the node's distinct symbols below the
 * node; any other conjuncts are kept in a {@code FilterNode} placed on top of the rewritten node.
 *
 * @param node mark-distinct node being rewritten
 * @param context rewrite context carrying the inherited predicate
 * @return the rewritten node, wrapped in a filter when some conjuncts could not be pushed down
 */
@Override
public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext<Expression> context) {
    Set<Symbol> distinctSymbols = ImmutableSet.copyOf(node.getDistinctSymbols());

    Map<Boolean, List<Expression>> partitioned = extractConjuncts(context.get()).stream()
            .collect(Collectors.partitioningBy(conjunct -> SymbolsExtractor.extractUnique(conjunct).stream()
                    .allMatch(symbol -> distinctSymbols.contains(symbol))));
    List<Expression> pushable = partitioned.get(true);
    List<Expression> retained = partitioned.get(false);

    PlanNode rewritten = context.defaultRewrite(node, combineConjuncts(pushable));
    if (retained.isEmpty()) {
        return rewritten;
    }
    return new FilterNode(idAllocator.getNextId(), rewritten, combineConjuncts(retained));
}
@Override public PlanNode visitGroupId(GroupIdNode node, RewriteContext<Expression> context) { Map<Symbol, SymbolReference> commonGroupingSymbolMapping = node.getGroupingColumns().entrySet().stream() .filter(entry -> node.getCommonGroupingColumns().contains(entry.getKey())) .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); Predicate<Expression> pushdownEligiblePredicate = conjunct -> SymbolsExtractor.extractUnique(conjunct).stream() .allMatch(commonGroupingSymbolMapping.keySet()::contains); Map<Boolean, List<Expression>> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(pushdownEligiblePredicate)); // Push down conjuncts from the inherited predicate that apply to common grouping symbols PlanNode rewrittenNode = context.defaultRewrite(node, inlineSymbols(commonGroupingSymbolMapping, combineConjuncts(conjuncts.get(true)))); // All other conjuncts, if any, will be in the filter node. if (!conjuncts.get(false).isEmpty()) { rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, combineConjuncts(conjuncts.get(false))); } return rewrittenNode; }
private void assertGetSortExpression(Expression expression, String expectedSymbol) { // for now we expect that search expressions contain all the conjuncts from filterExpression as more complex cases are not supported yet. assertGetSortExpression(expression, expectedSymbol, extractConjuncts(expression)); }
@Override public PlanNode visitProject(ProjectNode node, RewriteContext<Expression> context) { Set<Symbol> deterministicSymbols = node.getAssignments().entrySet().stream() .filter(entry -> DeterminismEvaluator.isDeterministic(entry.getValue())) .map(Map.Entry::getKey) .collect(Collectors.toSet()); Predicate<Expression> deterministic = conjunct -> SymbolsExtractor.extractUnique(conjunct).stream() .allMatch(deterministicSymbols::contains); Map<Boolean, List<Expression>> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(deterministic)); // Push down conjuncts from the inherited predicate that only depend on deterministic assignments with // certain limitations. List<Expression> deterministicConjuncts = conjuncts.get(true); // We partition the expressions in the deterministicConjuncts into two lists, and only inline the // expressions that are in the inlining targets list. Map<Boolean, List<Expression>> inlineConjuncts = deterministicConjuncts.stream() .collect(Collectors.partitioningBy(expression -> isInliningCandidate(expression, node))); List<Expression> inlinedDeterministicConjuncts = inlineConjuncts.get(true).stream() .map(entry -> inlineSymbols(node.getAssignments().getMap(), entry)) .collect(Collectors.toList()); PlanNode rewrittenNode = context.defaultRewrite(node, combineConjuncts(inlinedDeterministicConjuncts)); // All deterministic conjuncts that contains non-inlining targets, and non-deterministic conjuncts, // if any, will be in the filter node. List<Expression> nonInliningConjuncts = inlineConjuncts.get(false); nonInliningConjuncts.addAll(conjuncts.get(false)); if (!nonInliningConjuncts.isEmpty()) { rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, combineConjuncts(nonInliningConjuncts)); } return rewrittenNode; }
/**
 * Dispatches to the filtering or non-filtering semi-join rewrite, depending on whether the
 * symbol reference of the semi-join output is one of the conjuncts of the inherited predicate.
 *
 * @param node semi-join node being rewritten
 * @param context rewrite context carrying the inherited predicate
 * @return the result of the selected rewrite
 */
@Override
public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext<Expression> context) {
    Expression inheritedPredicate = context.get();
    boolean filtersOnSemiJoinOutput = extractConjuncts(inheritedPredicate)
            .contains(node.getSemiJoinOutput().toSymbolReference());

    if (filtersOnSemiJoinOutput) {
        return visitFilteringSemiJoin(node, context);
    }
    return visitNonFilteringSemiJoin(node, context);
}
@Override public PlanNode visitWindow(WindowNode node, RewriteContext<Expression> context) { List<Symbol> partitionSymbols = node.getPartitionBy(); // TODO: This could be broader. We can push down conjucts if they are constant for all rows in a window partition. // The simplest way to guarantee this is if the conjucts are deterministic functions of the partitioning symbols. // This can leave out cases where they're both functions of some set of common expressions and the partitioning // function is injective, but that's a rare case. The majority of window nodes are expected to be partitioned by // pre-projected symbols. Predicate<Expression> isSupported = conjunct -> DeterminismEvaluator.isDeterministic(conjunct) && SymbolsExtractor.extractUnique(conjunct).stream() .allMatch(partitionSymbols::contains); Map<Boolean, List<Expression>> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(isSupported)); PlanNode rewrittenNode = context.defaultRewrite(node, combineConjuncts(conjuncts.get(true))); if (!conjuncts.get(false).isEmpty()) { rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, combineConjuncts(conjuncts.get(false))); } return rewrittenNode; }
/**
 * Derives the predicate of a spatial join from the predicates of its two inputs.
 * <p>
 * For an INNER join both input predicates are pulled through the output symbols and combined.
 * For a LEFT join the left predicate is pulled through as is, while the conjuncts of the right
 * predicate go through {@code pullNullableConjunctsThroughOuterJoin}.
 *
 * @param node spatial join node being visited
 * @param context visitor context, forwarded to both inputs
 * @return the combined predicate
 * @throws IllegalArgumentException for any join type other than INNER or LEFT
 */
@Override
public Expression visitSpatialJoin(SpatialJoinNode node, Void context) {
    Expression leftPredicate = node.getLeft().accept(this, context);
    Expression rightPredicate = node.getRight().accept(this, context);

    switch (node.getType()) {
        case INNER: {
            Expression pulledLeft = pullExpressionThroughSymbols(leftPredicate, node.getOutputSymbols());
            Expression pulledRight = pullExpressionThroughSymbols(rightPredicate, node.getOutputSymbols());
            return combineConjuncts(ImmutableList.of(pulledLeft, pulledRight));
        }
        case LEFT: {
            ImmutableList.Builder<Expression> pulledConjuncts = ImmutableList.builder();
            pulledConjuncts.add(pullExpressionThroughSymbols(leftPredicate, node.getOutputSymbols()));
            pulledConjuncts.addAll(pullNullableConjunctsThroughOuterJoin(
                    extractConjuncts(rightPredicate),
                    node.getOutputSymbols(),
                    node.getRight().getOutputSymbols()::contains));
            return combineConjuncts(pulledConjuncts.build());
        }
        default:
            throw new IllegalArgumentException("Unsupported spatial join type: " + node.getType());
    }
}
return combineConjuncts(ImmutableList.<Expression>builder() .add(pullExpressionThroughSymbols(leftPredicate, node.getOutputSymbols())) .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputSymbols(), node.getRight().getOutputSymbols()::contains)) .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputSymbols(), node.getRight().getOutputSymbols()::contains)) .build()); return combineConjuncts(ImmutableList.<Expression>builder() .add(pullExpressionThroughSymbols(rightPredicate, node.getOutputSymbols())) .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), node.getOutputSymbols(), node.getLeft().getOutputSymbols()::contains)) .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputSymbols(), node.getLeft().getOutputSymbols()::contains)) .build()); case FULL: return combineConjuncts(ImmutableList.<Expression>builder() .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), node.getOutputSymbols(), node.getLeft().getOutputSymbols()::contains)) .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputSymbols(), node.getRight().getOutputSymbols()::contains)) .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputSymbols(), node.getLeft().getOutputSymbols()::contains, node.getRight().getOutputSymbols()::contains)) .build());