/**
 * Tells whether a type is available for the given expression.
 *
 * @param expr the expression to query
 * @return {@code true} exactly when {@link #lookupType} yields a non-null type for {@code expr}
 */
public boolean hasType(Expr expr) {
  final Type found = lookupType(expr);
  return found != null;
}
/**
 * Visits a parenthesized expression and then gives it the same type as the expression it wraps,
 * provided that inner expression has a known type.
 *
 * @param parenExpr the parenthesized expression being visited
 */
@Override
public void visitParenExpr(ParenExpr parenExpr) {
  super.visitParenExpr(parenExpr);
  final Type innerType = lookupType(parenExpr.getExpr());
  if (innerType == null) {
    // Nothing is recorded for the parentheses when the wrapped expression is untyped.
    return;
  }
  types.put(parenExpr, innerType);
}
/**
 * Determines whether {@code expr} is a {@code vec2}/{@code vec3}/{@code vec4} constructor
 * whose arguments all satisfy {@code isOneFloat} or are, recursively, such constructors
 * themselves.
 *
 * @param expr the expression to inspect
 * @return {@code true} if {@code expr} is a one-valued float vector constructor
 */
private boolean isOneFloatVec(Expr expr) {
  final boolean isFloatVectorConstructor = expr instanceof TypeConstructorExpr
      && Arrays.asList(BasicType.VEC2, BasicType.VEC3, BasicType.VEC4)
          .contains(typer.lookupType(expr));
  if (!isFloatVectorConstructor) {
    return false;
  }
  for (Expr arg : ((TypeConstructorExpr) expr).getArgs()) {
    if (!isOneFloat(arg) && !isOneFloatVec(arg)) {
      return false;
    }
  }
  return true;
}
/**
 * Determines whether {@code expr} is a float vector ({@code vec2..vec4}) or square matrix
 * ({@code mat2x2}, {@code mat3x3}, {@code mat4x4}) constructor. Every argument must satisfy
 * {@code isZeroFloat} or, recursively, be such a zero-valued constructor itself.
 *
 * @param expr the expression to inspect
 * @return {@code true} if {@code expr} is a zero-valued vector or square-matrix constructor
 */
private boolean isZeroFloatVecOrSquareMat(Expr expr) {
  if (!(expr instanceof TypeConstructorExpr)) {
    return false;
  }
  final Type constructedType = typer.lookupType(expr);
  final boolean isVecOrSquareMat = Arrays.asList(
      BasicType.VEC2, BasicType.VEC3, BasicType.VEC4,
      BasicType.MAT2X2, BasicType.MAT3X3, BasicType.MAT4X4).contains(constructedType);
  if (!isVecOrSquareMat) {
    return false;
  }
  for (Expr arg : ((TypeConstructorExpr) expr).getArgs()) {
    if (!(isZeroFloat(arg) || isZeroFloatVecOrSquareMat(arg))) {
      return false;
    }
  }
  return true;
}
/**
 * Records opportunities to fold an addition in which {@code thisHandSide} is zero, replacing
 * {@code child} with {@code thatHandSide}.
 *
 * <p>If {@code isZeroFloat(thisHandSide)} holds, the fold is always recorded. If instead
 * {@code thisHandSide} is a zero vector / square-matrix constructor, the fold is recorded only
 * when {@code child} and {@code thatHandSide} have the same unqualified type. That way the
 * replacement does not change the expression's type.
 *
 * @param parent the parent of {@code child}
 * @param child the addition expression
 * @param thisHandSide the operand tested for being zero
 * @param thatHandSide the operand that would replace {@code child}
 */
private void findFoldAddZeroOpportunities(IAstNode parent, Expr child, Expr thisHandSide,
    Expr thatHandSide) {
  if (isZeroFloat(thisHandSide)) {
    addReplaceWithExpr(parent, child, thatHandSide);
  }
  final Type resultType = typer.lookupType(child);
  final Type survivorType = typer.lookupType(thatHandSide);
  if (resultType == null || survivorType == null) {
    return;
  }
  final boolean typesAgree =
      resultType.getWithoutQualifiers().equals(survivorType.getWithoutQualifiers());
  if (typesAgree && isZeroFloatVecOrSquareMat(thisHandSide)) {
    addReplaceWithExpr(parent, child, thatHandSide);
  }
}
/**
 * Records opportunities to fold {@code lhs - rhs} to {@code lhs} when {@code rhs} is zero.
 *
 * <p>A zero given by {@code isZeroFloat(rhs)} always qualifies. A zero vector / square-matrix
 * constructor only qualifies when {@code child} and {@code lhs} share the same unqualified
 * type, so the replacement keeps the expression's type.
 *
 * @param parent the parent of {@code child}
 * @param child the subtraction expression
 * @param lhs the minuend, which would replace {@code child}
 * @param rhs the subtrahend, tested for being zero
 */
private void findFoldSomethingSubZeroOpportunities(IAstNode parent, Expr child, Expr lhs,
    Expr rhs) {
  if (isZeroFloat(rhs)) {
    addReplaceWithExpr(parent, child, lhs);
  }
  final Type resultType = typer.lookupType(child);
  final Type minuendType = typer.lookupType(lhs);
  if (resultType == null || minuendType == null) {
    return;
  }
  final boolean typesAgree =
      resultType.getWithoutQualifiers().equals(minuendType.getWithoutQualifiers());
  if (typesAgree && isZeroFloatVecOrSquareMat(rhs)) {
    addReplaceWithExpr(parent, child, lhs);
  }
}
/**
 * Records opportunities to replace a multiplication with zero (via {@code addReplaceWithZero})
 * when {@code thisHandSide} is zero.
 *
 * <p>If {@code isZeroFloat(thisHandSide)} holds, the opportunity is always recorded. If instead
 * {@code thisHandSide} is a zero vector / square-matrix constructor, it is recorded only when
 * {@code child} and {@code thatHandSide} have the same unqualified type.
 *
 * @param parent the parent of {@code child}
 * @param child the multiplication expression
 * @param thisHandSide the operand tested for being zero
 * @param thatHandSide the other operand
 */
private void findFoldMulZeroOpportunities(IAstNode parent, Expr child, Expr thisHandSide,
    Expr thatHandSide) {
  if (isZeroFloat(thisHandSide)) {
    addReplaceWithZero(parent, child);
  }
  final Type resultType = typer.lookupType(child);
  final Type otherOperandType = typer.lookupType(thatHandSide);
  if (resultType == null || otherOperandType == null) {
    return;
  }
  final boolean typesAgree =
      resultType.getWithoutQualifiers().equals(otherOperandType.getWithoutQualifiers());
  if (typesAgree && isZeroFloatVecOrSquareMat(thisHandSide)) {
    addReplaceWithZero(parent, child);
  }
}
/**
 * Records opportunities to fold {@code lhs / rhs} to {@code lhs} when {@code rhs} is one.
 *
 * <p>A one given by {@code isOneFloat(rhs)} always qualifies. A one-valued vector constructor
 * only qualifies when {@code child} and {@code lhs} share the same unqualified type, so the
 * replacement keeps the expression's type.
 *
 * @param parent the parent of {@code child}
 * @param child the division expression
 * @param lhs the dividend, which would replace {@code child}
 * @param rhs the divisor, tested for being one
 */
private void findFoldSomethingDivOneOpportunities(IAstNode parent, Expr child, Expr lhs,
    Expr rhs) {
  if (isOneFloat(rhs)) {
    addReplaceWithExpr(parent, child, lhs);
  }
  final Type resultType = typer.lookupType(child);
  final Type dividendType = typer.lookupType(lhs);
  if (resultType == null || dividendType == null) {
    return;
  }
  final boolean typesAgree =
      resultType.getWithoutQualifiers().equals(dividendType.getWithoutQualifiers());
  if (typesAgree && isOneFloatVec(rhs)) {
    addReplaceWithExpr(parent, child, lhs);
  }
}
private Set<FunctionPrototype> findPossibleMatchesForCall(FunctionCallExpr functionCallExpr) { Set<FunctionPrototype> candidates = declaredFunctions.stream() .filter(proto -> proto.getName().equals(functionCallExpr.getCallee())) .filter(proto -> proto.getNumParameters() == functionCallExpr.getNumArgs()) .collect(Collectors.toSet()); for (int i = 0; i < functionCallExpr.getNumArgs(); i++) { if (!typer.hasType(functionCallExpr.getArg(i))) { // If we don't have a type for this argument, we're OK with any function prototype's type continue; } final int currentIndex = i; // Capture i in final variable so it can be used in lambda. candidates = candidates.stream().filter(proto -> typer.lookupType(functionCallExpr.getArg(currentIndex)).getWithoutQualifiers() .equals(proto.getParameters().get(currentIndex).getType().getWithoutQualifiers())) .collect(Collectors.toSet()); } return candidates; }
/**
 * Records opportunities to fold a multiplication by an identity value, replacing {@code child}
 * with {@code thatHandSide}.
 *
 * <p>If {@code isOneFloat(thisHandSide)} holds, the fold is always recorded. If instead
 * {@code thisHandSide} is a one-valued vector constructor or an identity matrix, the fold is
 * recorded only when {@code child} has the same unqualified type as {@code thatHandSide}, the
 * expression that will take its place. Checking against the identity operand instead would be
 * wrong in both directions:
 * <ul>
 *   <li>{@code vec3(1.0) * 2.0} is a {@code vec3} and must not become the {@code float} {@code 2.0};
 *   <li>an identity {@code mat2} times a {@code vec2 v} is a {@code vec2} and can become {@code v}.
 * </ul>
 *
 * @param parent the parent of {@code child}
 * @param child the multiplication expression
 * @param thisHandSide the operand tested for being an identity
 * @param thatHandSide the operand that would replace {@code child}
 */
private void findFoldMulIdentityOpportunities(IAstNode parent, Expr child, Expr thisHandSide,
    Expr thatHandSide) {
  if (isOneFloat(thisHandSide)) {
    addReplaceWithExpr(parent, child, thatHandSide);
  }
  final Type childType = typer.lookupType(child);
  // The replacement is thatHandSide, so it is thatHandSide's type that must match the result's.
  final Type thatHandSideType = typer.lookupType(thatHandSide);
  if (childType != null && thatHandSideType != null
      && childType.getWithoutQualifiers().equals(thatHandSideType.getWithoutQualifiers())) {
    if (isOneFloatVec(thisHandSide) || isIdentityMatrix(thisHandSide)) {
      addReplaceWithExpr(parent, child, thatHandSide);
    }
  }
}
/**
 * Offers to replace {@code child} with the canonical constant of its type.
 *
 * <p>The opportunity is offered only when all of these hold: the expression may be reduced,
 * we are not in an l-value context, its type is reducible to a constant, and it is not already
 * a fully reduced constant.
 *
 * @param parent the parent of {@code child}
 * @param child the expression that may be replaced
 */
@Override
void identifyReductionOpportunitiesForChild(IAstNode parent, Expr child) {
  if (!allowedToReduceExpr(parent, child) || inLValueContext()) {
    return;
  }
  final Type childType = typer.lookupType(child);
  if (!typeIsReducibleToConst(childType) || isFullyReducedConstant(child)) {
    return;
  }
  addOpportunity(new SimplifyExprReductionOpportunity(
      parent,
      childType.getCanonicalConstant(),
      child,
      getVistitationDepth()));
}
/**
 * After visiting, renames the member of a struct field lookup according to
 * {@code structFieldRenaming}.
 *
 * <p>Lookups whose structure has no known type, or whose unqualified type is not a
 * {@code StructNameType}, are left untouched. Assertions check that a renaming exists for both
 * the struct and the field.
 *
 * @param memberLookupExpr the member lookup being visited
 */
@Override
public void visitMemberLookupExpr(MemberLookupExpr memberLookupExpr) {
  super.visitMemberLookupExpr(memberLookupExpr);
  final Type structureType = typer.lookupType(memberLookupExpr.getStructure());
  if (structureType == null) {
    return;
  }
  final Type unqualified = structureType.getWithoutQualifiers();
  if (unqualified instanceof StructNameType) {
    final String structName = ((StructNameType) unqualified).getName();
    assert structFieldRenaming.containsKey(structName);
    final Map<String, String> renamingForStruct = structFieldRenaming.get(structName);
    final String oldMember = memberLookupExpr.getMember();
    assert renamingForStruct.containsKey(oldMember);
    memberLookupExpr.setMember(renamingForStruct.get(oldMember));
  }
}
/**
 * After visiting, records the type of an indexing expression, when the indexed expression has
 * a known type:
 * <ul>
 *   <li>indexing a vector yields its element type;
 *   <li>indexing a matrix yields its column type;
 *   <li>anything else is asserted to be an array and yields its base type.
 * </ul>
 *
 * @param arrayIndexExpr the indexing expression being visited
 */
@Override
public void visitArrayIndexExpr(ArrayIndexExpr arrayIndexExpr) {
  super.visitArrayIndexExpr(arrayIndexExpr);
  final Type indexedType = lookupType(arrayIndexExpr.getArray());
  if (indexedType == null) {
    return;
  }
  final Type unqualified = indexedType.getWithoutQualifiers();
  final Type resultType;
  if (BasicType.allVectorTypes().contains(unqualified)) {
    resultType = ((BasicType) unqualified).getElementType();
  } else if (BasicType.allMatrixTypes().contains(unqualified)) {
    resultType = ((BasicType) unqualified).getColumnType();
  } else {
    assert unqualified instanceof ArrayType;
    resultType = ((ArrayType) unqualified).getBaseType();
  }
  types.put(arrayIndexExpr, resultType);
}
/**
 * Records a {@code MutationPoint} for each child of {@code expr} that has a known type, except
 * children whose index is in {@code indicesToSkip}.
 *
 * <p>Nothing is recorded unless {@code insideLValueCount} is zero (judging by its name, when we
 * are not inside an l-value). Each mutation point receives a shallow clone of the current scope.
 * When the shading language version restricts for-loops, the names in
 * {@code forLoopIterators} are removed from that clone. This is presumably so that generated
 * code cannot reference loop iterators; confirm.
 *
 * @param expr the expression whose children are candidate mutation points
 * @param indicesToSkip child indices that must not become mutation points
 */
private void identifyMutationPoints(Expr expr, Set<Integer> indicesToSkip) {
  if (insideLValueCount != 0) {
    return;
  }
  for (int childIndex = 0; childIndex < expr.getNumChildren(); childIndex++) {
    if (indicesToSkip.contains(childIndex)) {
      continue;
    }
    final Expr candidate = expr.getChild(childIndex);
    if (!typer.hasType(candidate)) {
      continue;
    }
    final Scope visibleScope = currentScope.shallowClone();
    if (shadingLanguageVersion.restrictedForLoops()) {
      for (Set<String> iteratorNames : forLoopIterators) {
        iteratorNames.forEach(visibleScope::remove);
      }
    }
    mutationPoints.add(new MutationPoint(
        expr,
        childIndex,
        typer.lookupType(candidate),
        visibleScope,
        isConstContext(),
        shadingLanguageVersion,
        generator,
        generationParams));
  }
}
private boolean callMatchesPrototype(FunctionCallExpr call, FunctionPrototype prototype) { assert call.getNumArgs() == prototype.getNumParameters(); for (int i = 0; i < call.getNumArgs(); i++) { Type argType = typer.lookupType(call.getArg(i)); if (argType == null) { // With incomplete information we say there is a match continue; } if (!typesMatchWithoutQualifiers(argType, prototype.getParameters().get(i).getType())) { return false; } } return true; }
/**
 * Checks whether {@code declaration} could be the target of the field {@code call}.
 *
 * <p>The prototype's name must equal the callee and its parameter count must equal the
 * argument count. Each argument with a known type must also equal its parameter's type with
 * qualifiers stripped. Arguments with no known type match any parameter.
 *
 * @param declaration the function definition to test
 * @return {@code true} if the definition is compatible with {@code call}
 */
private boolean functionMatches(FunctionDefinition declaration) {
  final FunctionPrototype prototype = declaration.getPrototype();
  final boolean sameNameAndArity = prototype.getName().equals(call.getCallee())
      && prototype.getNumParameters() == call.getNumArgs();
  if (!sameNameAndArity) {
    return false;
  }
  for (int paramIndex = 0; paramIndex < prototype.getNumParameters(); paramIndex++) {
    final Type argType = typer.lookupType(call.getArg(paramIndex));
    if (argType != null && !argType.getWithoutQualifiers()
        .equals(prototype.getParameter(paramIndex).getType().getWithoutQualifiers())) {
      return false;
    }
  }
  return true;
}
/**
 * Records opportunities to fold {@code lhs - rhs} to {@code (-rhs)} when {@code lhs} is zero.
 *
 * <p>A zero given by {@code isZeroFloat(lhs)} always qualifies. A zero vector / square-matrix
 * constructor only qualifies when {@code child} and {@code rhs} share the same unqualified
 * type, so the replacement keeps the expression's type. Each recorded opportunity gets its own
 * freshly built negation node.
 *
 * @param parent the parent of {@code child}
 * @param child the subtraction expression
 * @param lhs the minuend, tested for being zero
 * @param rhs the subtrahend, whose negation would replace {@code child}
 */
private void findFoldZeroSubSomethingOpportunities(IAstNode parent, Expr child, Expr lhs,
    Expr rhs) {
  if (isZeroFloat(lhs)) {
    addReplaceWithExpr(parent, child, new ParenExpr(new UnaryExpr(rhs, UnOp.MINUS)));
  }
  final Type resultType = typer.lookupType(child);
  final Type subtrahendType = typer.lookupType(rhs);
  if (resultType == null || subtrahendType == null) {
    return;
  }
  final boolean typesAgree =
      resultType.getWithoutQualifiers().equals(subtrahendType.getWithoutQualifiers());
  if (typesAgree && isZeroFloatVecOrSquareMat(lhs)) {
    addReplaceWithExpr(parent, child, new ParenExpr(new UnaryExpr(rhs, UnOp.MINUS)));
  }
}
final Type exprType = typer.lookupType(expr); if (!Arrays.asList(BasicType.MAT2X2, BasicType.MAT3X3, BasicType.MAT4X4) .contains(exprType)) {
@Override void identifyReductionOpportunitiesForChild(IAstNode parent, Expr child) { if (!allowedToReduceExpr(parent, child)) { return; } if (inLValueContext()) { return; } final Type resultType = typer.lookupType(child); if (resultType == null) { return; } for (int i = 0; i < child.getNumChildren(); i++) { final Expr subExpr = child.getChild(i); final Type subExprType = typer.lookupType(subExpr); if (subExprType == null) { continue; } if (!subExprType.getWithoutQualifiers().equals(resultType.getWithoutQualifiers())) { continue; } addOpportunity(new SimplifyExprReductionOpportunity( parent, subExpr, child, // We mark this as deeper since we would prefer to reduce the root expression // to a constant. getVistitationDepth().deeper())); } }
/**
 * Rewrites an index that is not statically known to be in bounds into
 * {@code ((i) >= 0 && (i) < size) ? i : 0}, so that out-of-range indices select element 0,
 * and then visits the (possibly rewritten) expression.
 *
 * <p>If the indexed expression has no known type, the method returns immediately without
 * rewriting and without calling the superclass visitor.
 *
 * <p>NOTE(review): the index expression appears three times in the rewritten form (two clones
 * plus the original). An index with side effects would therefore be evaluated more than once;
 * confirm that callers only see side-effect-free indices.
 *
 * @param arrayIndexExpr the indexing expression being visited
 */
@Override
public void visitArrayIndexExpr(ArrayIndexExpr arrayIndexExpr) {
  final Type indexedType = typer.lookupType(arrayIndexExpr.getArray());
  if (indexedType == null) {
    return;
  }
  final Type unqualifiedType = indexedType.getWithoutQualifiers();
  assert isArrayVectorOrMatrix(unqualifiedType);
  final Expr index = arrayIndexExpr.getIndex();
  if (!staticallyInBounds(index, unqualifiedType)) {
    final BinaryExpr atLeastZero = new BinaryExpr(
        new ParenExpr(index.clone()),
        new IntConstantExpr("0"),
        BinOp.GE);
    final BinaryExpr belowSize = new BinaryExpr(
        new ParenExpr(index.clone()),
        new IntConstantExpr(getSize(unqualifiedType).toString()),
        BinOp.LT);
    arrayIndexExpr.setIndex(new TernaryExpr(
        new BinaryExpr(atLeastZero, belowSize, BinOp.LAND),
        index,
        new IntConstantExpr("0")));
  }
  super.visitArrayIndexExpr(arrayIndexExpr);
}