/**
 * Re-expresses the given predicate using only the supplied symbols.
 *
 * <p>Each deterministic conjunct that is not an inference candidate is rewritten
 * into the symbol scope via the equality inference built from {@code expression};
 * conjuncts that are non-deterministic or cannot be rewritten are dropped. The
 * scope equalities produced by the inference are appended afterwards, and the
 * whole list is combined into a single conjunction.
 *
 * @param expression predicate to pull through
 * @param symbols symbols the resulting predicate is allowed to reference
 * @return the conjunction of all conjuncts that could be expressed in scope
 */
private static Expression pullExpressionThroughSymbols(Expression expression, Collection<Symbol> symbols) {
    EqualityInference inference = createEqualityInference(expression);

    ImmutableList.Builder<Expression> pulledConjuncts = ImmutableList.builder();
    for (Expression candidate : EqualityInference.nonInferrableConjuncts(expression)) {
        // Non-deterministic conjuncts are never rewritten.
        if (!DeterminismEvaluator.isDeterministic(candidate)) {
            continue;
        }
        Expression inScope = inference.rewriteExpression(candidate, in(symbols));
        if (inScope == null) {
            // Could not be expressed with the allowed symbols only.
            continue;
        }
        pulledConjuncts.add(inScope);
    }

    pulledConjuncts.addAll(inference.generateEqualitiesPartitionedBy(in(symbols)).getScopeEqualities());

    return combineConjuncts(pulledConjuncts.build());
}
}
private Expression rewriteExpression(Expression expression, Predicate<Symbol> symbolScope, boolean allowFullReplacement) { Iterable<Expression> subExpressions = SubExpressionExtractor.extract(expression); if (!allowFullReplacement) { subExpressions = filter(subExpressions, not(equalTo(expression))); } ImmutableMap.Builder<Expression, Expression> expressionRemap = ImmutableMap.builder(); for (Expression subExpression : subExpressions) { Expression canonical = getScopedCanonical(subExpression, symbolScope); if (canonical != null) { expressionRemap.put(subExpression, canonical); } } // Perform a naive single-pass traversal to try to rewrite non-compliant portions of the tree. Prefers to replace // larger subtrees over smaller subtrees // TODO: this rewrite can probably be made more sophisticated Expression rewritten = ExpressionTreeRewriter.rewriteWith(new ExpressionNodeInliner(expressionRemap.build()), expression); if (!symbolToExpressionPredicate(symbolScope).apply(rewritten)) { // If the rewritten is still not compliant with the symbol scope, just give up return null; } return rewritten; }
/**
 * Returns the conjuncts of {@code expression} that are not inference candidates,
 * i.e. those that have not been added to the inference.
 *
 * @param expression predicate whose conjuncts are inspected
 * @return a filtered view over the extracted conjuncts
 */
public static Iterable<Expression> nonInferrableConjuncts(Expression expression) {
    Predicate<Expression> notACandidate = not(isInferenceCandidate());
    Iterable<Expression> allConjuncts = extractConjuncts(expression);
    return filter(allConjuncts, notACandidate);
}
/**
 * Returns a canonical expression that is fully contained by the symbolScope and that is
 * equivalent to the specified expression. Returns null if unable to find a canonical.
 *
 * @param expression expression to look up in the canonical map
 * @param symbolScope predicate selecting the symbols the canonical may reference
 * @return the scoped canonical; {@code null} when {@code expression} has no entry in the canonical map
 */
@VisibleForTesting
Expression getScopedCanonical(Expression expression, Predicate<Symbol> symbolScope) {
    Expression setKey = canonicalMap.get(expression);
    if (setKey != null) {
        // Restrict the equality set to members expressible within the scope and
        // let getCanonical choose among them.
        return getCanonical(filter(equalitySets.get(setKey), symbolToExpressionPredicate(symbolScope)));
    }
    return null;
}
/**
 * Produces the enumeration result for a set of source nodes.
 *
 * <p>For more than one node this delegates to {@code chooseJoinOrder}. For a single
 * node, the parts of the overall filter that can be expressed purely in terms of
 * {@code outputSymbols} are collected; if they amount to anything other than
 * {@code TRUE}, the node is wrapped in a {@code FilterNode} carrying that predicate.
 *
 * @param nodes source nodes to join
 * @param outputSymbols symbols the resulting plan may reference
 * @return the enumeration result for the given sources
 */
private JoinEnumerationResult getJoinSource(LinkedHashSet<PlanNode> nodes, List<Symbol> outputSymbols) {
    if (nodes.size() != 1) {
        return chooseJoinOrder(nodes, outputSymbols);
    }

    PlanNode source = getOnlyElement(nodes);

    ImmutableList.Builder<Expression> applicable = ImmutableList.builder();
    applicable.addAll(allFilterInference.generateEqualitiesPartitionedBy(outputSymbols::contains).getScopeEqualities());

    // NOTE(review): rewriteExpression rejects non-deterministic input, so this relies on
    // allFilter containing only deterministic conjuncts — confirm against the callers.
    for (Expression conjunct : nonInferrableConjuncts(allFilter)) {
        Expression inScope = allFilterInference.rewriteExpression(conjunct, outputSymbols::contains);
        if (inScope != null) {
            applicable.add(inScope);
        }
    }

    Expression sourceFilter = combineConjuncts(applicable.build());
    if (TRUE_LITERAL.equals(sourceFilter)) {
        // Nothing to filter on; use the source as-is.
        return createJoinEnumerationResult(source);
    }
    return createJoinEnumerationResult(new FilterNode(idAllocator.getNextId(), source, sourceFilter));
}
EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(symbolBeginsWith("a", "b")); assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), EqualityInference.isInferenceCandidate())); assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), EqualityInference.isInferenceCandidate())); assertTrue(Iterables.all(equalityPartition.getScopeStraddlingEqualities(), EqualityInference.isInferenceCandidate())); .build(); EqualityInference.EqualityPartition newEqualityPartition = newInference.generateEqualitiesPartitionedBy(symbolBeginsWith("a", "b"));
/**
 * Attempts to rewrite an Expression in terms of the symbols allowed by the symbol scope
 * given the known equalities. This variant performs no determinism check on the input,
 * so non-deterministic expressions are accepted.
 *
 * @param expression expression to rewrite
 * @param symbolScope predicate selecting the symbols the result may reference
 * @return the rewritten expression, or {@code null} if unsuccessful
 */
public Expression rewriteExpressionAllowNonDeterministic(Expression expression, Predicate<Symbol> symbolScope) {
    // The expression as a whole may itself be swapped for a canonical.
    final boolean allowFullReplacement = true;
    return rewriteExpression(expression, symbolScope, allowFullReplacement);
}
@Test public void testConstantEqualities() { EqualityInference.Builder builder = new EqualityInference.Builder(); addEquality("a1", "b1", builder); addEquality("b1", "c1", builder); builder.addEquality(nameReference("c1"), number(1)); EqualityInference inference = builder.build(); // Should always prefer a constant if available (constant is part of all scopes) assertEquals(inference.rewriteExpression(nameReference("a1"), matchesSymbols("a1", "b1")), number(1)); // All scope equalities should utilize the constant if possible EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesSymbols("a1", "b1")); assertEquals(equalitiesAsSets(equalityPartition.getScopeEqualities()), set(set(nameReference("a1"), number(1)), set(nameReference("b1"), number(1)))); assertEquals(equalitiesAsSets(equalityPartition.getScopeComplementEqualities()), set(set(nameReference("c1"), number(1)))); // There should be no scope straddling equalities as the full set of equalities should be already represented by the scope and inverse scope assertTrue(equalityPartition.getScopeStraddlingEqualities().isEmpty()); }
inference.rewriteExpression(someExpression("a1", "a2"), matchesSymbols("d1", "d2")), someExpression("d1", "d2")); inference.rewriteExpression(someExpression("a1", "c1"), matchesSymbols("b1")), someExpression("b1", "b1")); inference.rewriteExpression(someExpression("a1", "a2"), matchesSymbols("b1", "d2", "c3")), someExpression("b1", "d2")); inference.getScopedCanonical(nameReference("a2"), matchesSymbols("c2", "d2")), inference.getScopedCanonical(nameReference("b2"), matchesSymbols("c2", "d2"))); Expression canonical = inference.getScopedCanonical(nameReference("a2"), matchesSymbols("c2", "d2")); inference.rewriteExpression(someExpression("a2", "b2"), matchesSymbols("c2", "d2")), someExpression(canonical, canonical));
Expression scopeRewritten = rewriteExpression(expression, symbolScope, false); if (scopeRewritten != null) { scopeExpressions.add(scopeRewritten); Expression scopeComplementRewritten = rewriteExpression(expression, not(symbolScope), false); if (scopeComplementRewritten != null) { scopeComplementExpressions.add(scopeComplementRewritten); Expression matchingCanonical = getCanonical(scopeExpressions); if (scopeExpressions.size() >= 2) { for (Expression expression : filter(scopeExpressions, not(equalTo(matchingCanonical)))) { Expression complementCanonical = getCanonical(scopeComplementExpressions); if (scopeComplementExpressions.size() >= 2) { for (Expression expression : filter(scopeComplementExpressions, not(equalTo(complementCanonical)))) { connectingExpressions.addAll(scopeStraddlingExpressions); connectingExpressions = ImmutableList.copyOf(filter(connectingExpressions, Predicates.notNull())); Expression connectingCanonical = getCanonical(connectingExpressions); if (connectingCanonical != null) { for (Expression expression : filter(connectingExpressions, not(equalTo(connectingCanonical)))) {
@Test public void testEqualityGeneration() { EqualityInference.Builder builder = new EqualityInference.Builder(); builder.addEquality(nameReference("a1"), add("b", "c")); // a1 = b + c builder.addEquality(nameReference("e1"), add("b", "d")); // e1 = b + d addEquality("c", "d", builder); EqualityInference inference = builder.build(); Expression scopedCanonical = inference.getScopedCanonical(nameReference("e1"), symbolBeginsWith("a")); assertEquals(scopedCanonical, nameReference("a1")); }
/**
 * Creates an enumerator over the given join filter.
 *
 * <p>The session, cost provider, id allocator and lookup are taken from {@code context};
 * an equality inference is built once from {@code filter} and kept alongside it.
 *
 * @param costComparator comparator used (per session) to order enumeration results by cost
 * @param filter the overall join filter
 * @param context source of the session, cost provider, id allocator and lookup
 * @throws NullPointerException if any argument, or any required value obtained from
 *         {@code context}, is null
 */
@VisibleForTesting
JoinEnumerator(CostComparator costComparator, Expression filter, Context context) {
    // Validate explicitly so a null argument fails with a descriptive message,
    // consistent with the other checks in this constructor.
    requireNonNull(costComparator, "costComparator is null");
    this.context = requireNonNull(context, "context is null");
    this.session = requireNonNull(context.getSession(), "session is null");
    this.costProvider = requireNonNull(context.getCostProvider(), "costProvider is null");
    this.resultComparator = costComparator.forSession(session).onResultOf(result -> result.cost);
    this.idAllocator = requireNonNull(context.getIdAllocator(), "idAllocator is null");
    this.allFilter = requireNonNull(filter, "filter is null");
    this.allFilterInference = createEqualityInference(filter);
    this.lookup = requireNonNull(context.getLookup(), "lookup is null");
}
@Test public void testExpressionsThatMayReturnNullOnNonNullInput() { List<Expression> candidates = ImmutableList.of( new Cast(nameReference("b"), "BIGINT", true), // try_cast new FunctionCall(QualifiedName.of("try"), ImmutableList.of(nameReference("b"))), new NullIfExpression(nameReference("b"), number(1)), new IfExpression(nameReference("b"), number(1), new NullLiteral()), new DereferenceExpression(nameReference("b"), identifier("x")), new InPredicate(nameReference("b"), new InListExpression(ImmutableList.of(new NullLiteral()))), new SearchedCaseExpression(ImmutableList.of(new WhenClause(new IsNotNullPredicate(nameReference("b")), new NullLiteral())), Optional.empty()), new SimpleCaseExpression(nameReference("b"), ImmutableList.of(new WhenClause(number(1), new NullLiteral())), Optional.empty()), new SubscriptExpression(new ArrayConstructor(ImmutableList.of(new NullLiteral())), nameReference("b"))); for (Expression candidate : candidates) { EqualityInference.Builder builder = new EqualityInference.Builder(); builder.extractInferenceCandidates(equals(nameReference("b"), nameReference("x"))); builder.extractInferenceCandidates(equals(nameReference("a"), candidate)); EqualityInference inference = builder.build(); List<Expression> equalities = inference.generateEqualitiesPartitionedBy(matchesSymbols("b")).getScopeStraddlingEqualities(); assertEquals(equalities.size(), 1); assertTrue(equalities.get(0).equals(equals(nameReference("x"), nameReference("b"))) || equalities.get(0).equals(equals(nameReference("b"), nameReference("x")))); } }
/**
 * Finalizes this builder: runs {@code generateMoreEquivalences()} and then creates an
 * {@code EqualityInference} from the resulting equivalence classes and the derived expressions.
 *
 * @return the constructed inference
 */
public EqualityInference build() {
    // Must run before the equivalence classes are read below.
    generateMoreEquivalences();

    EqualityInference inference = new EqualityInference(
            equalities.getEquivalentClasses(),
            derivedExpressions);
    return inference;
}
}
/** * Determines whether an Expression may be successfully applied to the equality inference */ public static Predicate<Expression> isInferenceCandidate() { return expression -> { expression = normalizeInPredicateToEquality(expression); if (expression instanceof ComparisonExpression && isDeterministic(expression) && !mayReturnNullOnNonNullInput(expression)) { ComparisonExpression comparison = (ComparisonExpression) expression; if (comparison.getOperator() == ComparisonExpression.Operator.EQUAL) { // We should only consider equalities that have distinct left and right components return !comparison.getLeft().equals(comparison.getRight()); } } return false; }; }
EqualityInference.EqualityPartition emptyScopePartition = inference.generateEqualitiesPartitionedBy(Predicates.alwaysFalse()); EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesSymbols("c1")); assertTrue(Iterables.all(equalityPartition.getScopeEqualities(), EqualityInference.isInferenceCandidate())); assertTrue(Iterables.all(equalityPartition.getScopeComplementEqualities(), EqualityInference.isInferenceCandidate())); assertTrue(Iterables.all(equalityPartition.getScopeStraddlingEqualities(), EqualityInference.isInferenceCandidate())); .build(); EqualityInference.EqualityPartition newEqualityPartition = newInference.generateEqualitiesPartitionedBy(matchesSymbols("c1"));
/**
 * Attempts to rewrite an Expression in terms of the symbols allowed by the symbol scope
 * given the known equalities. The input expression must be deterministic.
 *
 * @param expression deterministic expression to rewrite
 * @param symbolScope predicate selecting the symbols the result may reference
 * @return the rewritten expression, or {@code null} if unsuccessful
 * @throws IllegalArgumentException if {@code expression} is not deterministic
 */
public Expression rewriteExpression(Expression expression, Predicate<Symbol> symbolScope) {
    if (!isDeterministic(expression)) {
        throw new IllegalArgumentException("Only deterministic expressions may be considered for rewrite");
    }
    return rewriteExpression(expression, symbolScope, true);
}
@Test public void testConstantEqualities() throws Exception { EqualityInference.Builder builder = new EqualityInference.Builder(); addEquality("a1", "b1", builder); addEquality("b1", "c1", builder); builder.addEquality(nameReference("c1"), number(1)); EqualityInference inference = builder.build(); // Should always prefer a constant if available (constant is part of all scopes) assertEquals(inference.rewriteExpression(nameReference("a1"), matchesSymbols("a1", "b1")), number(1)); // All scope equalities should utilize the constant if possible EqualityInference.EqualityPartition equalityPartition = inference.generateEqualitiesPartitionedBy(matchesSymbols("a1", "b1")); assertEquals(equalitiesAsSets(equalityPartition.getScopeEqualities()), set(set(nameReference("a1"), number(1)), set(nameReference("b1"), number(1)))); assertEquals(equalitiesAsSets(equalityPartition.getScopeComplementEqualities()), set(set(nameReference("c1"), number(1)))); // There should be no scope straddling equalities as the full set of equalities should be already represented by the scope and inverse scope assertTrue(equalityPartition.getScopeStraddlingEqualities().isEmpty()); }
inference.rewriteExpression(someExpression("a1", "a2"), matchesSymbols("d1", "d2")), someExpression("d1", "d2")); inference.rewriteExpression(someExpression("a1", "c1"), matchesSymbols("b1")), someExpression("b1", "b1")); inference.rewriteExpression(someExpression("a1", "a2"), matchesSymbols("b1", "d2", "c3")), someExpression("b1", "d2")); inference.getScopedCanonical(nameReference("a2"), matchesSymbols("c2", "d2")), inference.getScopedCanonical(nameReference("b2"), matchesSymbols("c2", "d2"))); Expression canonical = inference.getScopedCanonical(nameReference("a2"), matchesSymbols("c2", "d2")); inference.rewriteExpression(someExpression("a2", "b2"), matchesSymbols("c2", "d2")), someExpression(canonical, canonical));
Expression scopeRewritten = rewriteExpression(expression, symbolScope, false); if (scopeRewritten != null) { scopeExpressions.add(scopeRewritten); Expression scopeComplementRewritten = rewriteExpression(expression, not(symbolScope), false); if (scopeComplementRewritten != null) { scopeComplementExpressions.add(scopeComplementRewritten); Expression matchingCanonical = getCanonical(scopeExpressions); if (scopeExpressions.size() >= 2) { for (Expression expression : filter(scopeExpressions, not(equalTo(matchingCanonical)))) { Expression complementCanonical = getCanonical(scopeComplementExpressions); if (scopeComplementExpressions.size() >= 2) { for (Expression expression : filter(scopeComplementExpressions, not(equalTo(complementCanonical)))) { connectingExpressions.addAll(scopeStraddlingExpressions); connectingExpressions = ImmutableList.copyOf(filter(connectingExpressions, Predicates.notNull())); Expression connectingCanonical = getCanonical(connectingExpressions); if (connectingCanonical != null) { for (Expression expression : filter(connectingExpressions, not(equalTo(connectingCanonical)))) {