@Test public void testGetElementPosition() { int elementCount = 100; // Set initialTypedSetEntryCount to a small number to trigger rehash() int initialTypedSetEntryCount = 10; TypedSet typedSet = new TypedSet(BIGINT, initialTypedSetEntryCount, FUNCTION_NAME); BlockBuilder blockBuilder = BIGINT.createFixedSizeBlockBuilder(elementCount); for (int i = 0; i < elementCount; i++) { BIGINT.writeLong(blockBuilder, i); typedSet.add(blockBuilder, i); } assertEquals(typedSet.size(), elementCount); for (int j = 0; j < blockBuilder.getPositionCount(); j++) { assertEquals(typedSet.positionOf(blockBuilder, j), j); } }
/**
 * Returns the elements of {@code leftArray} that do not occur in {@code rightArray}.
 * Each surviving element is emitted once, in the order of its first occurrence in
 * {@code leftArray}. An empty {@code leftArray} is returned as is.
 *
 * @param type element type used for hashing, comparing and copying elements
 * @param leftArray elements to keep
 * @param rightArray elements to exclude
 * @return a block with the distinct elements of {@code leftArray} missing from {@code rightArray}
 */
@TypeParameter("E")
@SqlType("array(E)")
public static Block except(
        @TypeParameter("E") Type type,
        @SqlType("array(E)") Block leftArray,
        @SqlType("array(E)") Block rightArray)
{
    final int leftCount = leftArray.getPositionCount();
    final int rightCount = rightArray.getPositionCount();
    if (leftCount == 0) {
        return leftArray;
    }

    TypedSet seen = new TypedSet(type, leftCount + rightCount, "array_except");
    BlockBuilder result = type.createBlockBuilder(null, leftCount);

    // Seed the set with everything that has to be excluded
    int rightIndex = 0;
    while (rightIndex < rightCount) {
        seen.add(rightArray, rightIndex);
        rightIndex++;
    }

    for (int leftIndex = 0; leftIndex < leftCount; leftIndex++) {
        if (seen.contains(leftArray, leftIndex)) {
            continue;
        }
        // Recording the emitted element keeps later duplicates in leftArray out of the result
        seen.add(leftArray, leftIndex);
        type.appendTo(leftArray, leftIndex, result);
    }
    return result.build();
}
}
public void add(Block block, int position) { requireNonNull(block, "block must not be null"); checkArgument(position >= 0, "position must be >= 0"); // containsNullElement flag is maintained so contains() method can have shortcut for null value if (block.isNull(position)) { containsNullElement = true; } int hashPosition = getHashPositionOfElement(block, position); if (blockPositionByHash.get(hashPosition) == EMPTY_SLOT) { addNewElement(hashPosition, block, position); } }
/**
 * Sums the products of the values whose keys occur in both blocks.
 * Both blocks are read as alternating entries: every even position is treated as a VARCHAR key
 * and the position right after it as the DOUBLE value belonging to that key.
 *
 * @param leftMap first block of key/value entries
 * @param rightMap second block of key/value entries
 * @return the accumulated sum, {@code 0.0} when no key is shared
 */
private static double mapDotProduct(Block leftMap, Block rightMap)
{
    final int rightPositions = rightMap.getPositionCount();
    TypedSet rightMapKeys = new TypedSet(VARCHAR, rightPositions, "cosine_similarity");
    for (int keyIndex = 0; keyIndex < rightPositions; keyIndex += 2) {
        rightMapKeys.add(rightMap, keyIndex);
    }

    double sum = 0.0;
    final int leftPositions = leftMap.getPositionCount();
    for (int keyIndex = 0; keyIndex < leftPositions; keyIndex += 2) {
        int matchIndex = rightMapKeys.positionOf(leftMap, keyIndex);
        if (matchIndex == -1) {
            // Key is absent from the right block
            continue;
        }
        // matchIndex is the key's index within the set, so its value sits at 2 * matchIndex + 1
        double leftValue = DOUBLE.getDouble(leftMap, keyIndex + 1);
        double rightValue = DOUBLE.getDouble(rightMap, 2 * matchIndex + 1);
        sum += leftValue * rightValue;
    }
    return sum;
}
/**
 * Feeds every value of {@code longBlock} into {@code typedSet} and into a {@link HashSet},
 * asserting before and after each insertion that both agree on membership and on size.
 */
private static void testBigintFor(TypedSet typedSet, Block longBlock)
{
    Set<Long> reference = new HashSet<>();
    int index = 0;
    while (index < longBlock.getPositionCount()) {
        long value = BIGINT.getLong(longBlock, index);

        // Before insertion
        boolean knownBefore = reference.contains(value);
        assertEquals(typedSet.contains(longBlock, index), knownBefore);
        assertEquals(typedSet.size(), reference.size());

        reference.add(value);
        typedSet.add(longBlock, index);

        // After insertion
        boolean knownAfter = reference.contains(value);
        assertEquals(typedSet.contains(longBlock, index), knownAfter);
        assertEquals(typedSet.size(), reference.size());

        index++;
    }
}
}
/**
 * Adds the keys "hello", "bye" and "abc" to {@code set}, then probes it with a second block
 * holding the same strings in another order, an unknown string and a null. Also checks that
 * the null is reported as contained once it has been added.
 */
private void testGetElementPositionRandomFor(TypedSet set)
{
    BlockBuilder keys = VARCHAR.createBlockBuilder(null, 5);
    for (String key : new String[] {"hello", "bye", "abc"}) {
        VARCHAR.writeSlice(keys, utf8Slice(key));
    }
    int keyCount = keys.getPositionCount();
    for (int keyIndex = 0; keyIndex < keyCount; keyIndex++) {
        set.add(keys, keyIndex);
    }

    BlockBuilder values = VARCHAR.createBlockBuilder(null, 5);
    for (String probe : new String[] {"bye", "abc", "hello", "bad"}) {
        VARCHAR.writeSlice(values, utf8Slice(probe));
    }
    values.appendNull();

    // Position 4 is the null, which has not been added yet
    assertEquals(set.positionOf(values, 4), -1);
    // Known strings resolve to the index at which they were added as keys
    assertEquals(set.positionOf(values, 2), 0);
    assertEquals(set.positionOf(values, 1), 2);
    assertEquals(set.positionOf(values, 0), 1);
    // "bad" was never added
    assertFalse(set.contains(values, 3));

    set.add(values, 4);
    assertTrue(set.contains(values, 4));
}
/**
 * Keeps adding distinct BIGINT values to a {@link TypedSet} and expects a
 * {@link PrestoException} carrying {@code EXCEEDED_FUNCTION_MEMORY_LIMIT}. The test fails
 * when the loop completes without that exception.
 */
@Test
public void testMemoryExceeded()
{
    try {
        TypedSet typedSet = new TypedSet(BIGINT, 10, FUNCTION_NAME);
        int value = 0;
        while (value <= TypedSet.FOUR_MEGABYTES + 1) {
            // A fresh single-element block per value; the element is always at position 0
            Block singleValue = createLongsBlock(nCopies(1, (long) value));
            typedSet.add(singleValue, 0);
            value++;
        }
        fail("expected exception");
    }
    catch (PrestoException exception) {
        assertEquals(exception.getErrorCode(), EXCEEDED_FUNCTION_MEMORY_LIMIT.toErrorCode());
    }
}
@Test public void testConstructor() { for (int i = -2; i <= -1; i++) { try { //noinspection ResultOfObjectAllocationIgnored new TypedSet(BIGINT, i, FUNCTION_NAME); fail("Should throw exception if expectedSize < 0"); } catch (IllegalArgumentException e) { // ignored } } try { //noinspection ResultOfObjectAllocationIgnored new TypedSet(null, 1, FUNCTION_NAME); fail("Should throw exception if type is null"); } catch (NullPointerException | IllegalArgumentException e) { // ignored } }
/**
 * Doubles the hash table capacity, clears every slot and re-inserts the positions of all
 * elements of {@code elementBlock} starting at {@code initialElementBlockOffset}.
 *
 * @throws PrestoException with {@code GENERIC_INSUFFICIENT_RESOURCES} when the doubled
 *         capacity would exceed {@link Integer#MAX_VALUE}
 */
private void rehash()
{
    // Computed as a long so that the overflow check itself cannot overflow
    long doubled = hashCapacity * 2L;
    if (doubled > Integer.MAX_VALUE) {
        throw new PrestoException(GENERIC_INSUFFICIENT_RESOURCES, "Size of hash table cannot exceed 1 billion entries");
    }

    final int capacity = (int) doubled;
    hashCapacity = capacity;
    hashMask = capacity - 1;
    maxFill = calculateMaxFill(capacity);

    blockPositionByHash.size(capacity);
    for (int slot = 0; slot < capacity; slot++) {
        blockPositionByHash.set(slot, EMPTY_SLOT);
    }

    // Slots must be recomputed only after the mask above has been updated
    int blockPosition = initialElementBlockOffset;
    while (blockPosition < elementBlock.getPositionCount()) {
        blockPositionByHash.set(getHashPositionOfElement(elementBlock, blockPosition), blockPosition);
        blockPosition++;
    }
}
/**
 * Looks up the element at {@code position} of {@code block} and returns the value stored in
 * the hash slot resolved for it, which is {@code EMPTY_SLOT} when that slot is unoccupied.
 *
 * @param block block holding the element to look up
 * @param position position of the element inside {@code block}
 * @return the content of the element's hash slot
 */
public int positionOf(Block block, int position)
{
    int slot = getHashPositionOfElement(block, position);
    return blockPositionByHash.get(slot);
}
/**
 * Resizes the hash table for {@code size} elements, clears every slot and then hands
 * {@code elementBlock} to {@code rehashBlock}.
 *
 * @param size number of elements the resized table should be able to hold
 */
private void rehash(int size)
{
    int newHashSize = arraySize(size + 1, FILL_RATIO);
    hashMask = newHashSize - 1;
    maxFill = calculateMaxFill(newHashSize);

    // size(int) changes the logical length of the list. ensureCapacity(int) only grows the backing
    // array, so set(i, ...) below would throw IndexOutOfBoundsException for every index at or
    // beyond the previous length.
    blockPositionByHash.size(newHashSize);
    for (int i = 0; i < newHashSize; i++) {
        blockPositionByHash.set(i, EMPTY_SLOT);
    }

    // NOTE(review): rehashBlock is defined elsewhere; presumably it re-inserts the existing elements
    rehashBlock(elementBlock);
}
private void addNewElement(int hashPosition, Block block, int position) { elementType.appendTo(block, position, elementBlock); if (elementBlock.getSizeInBytes() - initialElementBlockSizeInBytes > FOUR_MEGABYTES) { throw new PrestoException( EXCEEDED_FUNCTION_MEMORY_LIMIT, format("The input to %s is too large. More than %s of memory is needed to hold the intermediate hash set.\n", functionName, MAX_FUNCTION_MEMORY)); } blockPositionByHash.set(hashPosition, elementBlock.getPositionCount() - 1); // increase capacity, if necessary size++; if (size >= maxFill) { rehash(); } }
public TypedSet(Type elementType, BlockBuilder blockBuilder, int expectedSize, String functionName) { checkArgument(expectedSize >= 0, "expectedSize must not be negative"); this.elementType = requireNonNull(elementType, "elementType must not be null"); this.elementBlock = requireNonNull(blockBuilder, "blockBuilder must not be null"); this.functionName = functionName; initialElementBlockOffset = elementBlock.getPositionCount(); initialElementBlockSizeInBytes = elementBlock.getSizeInBytes(); this.size = 0; this.hashCapacity = arraySize(expectedSize, FILL_RATIO); this.maxFill = calculateMaxFill(hashCapacity); this.hashMask = hashCapacity - 1; blockPositionByHash = new IntArrayList(hashCapacity); blockPositionByHash.size(hashCapacity); for (int i = 0; i < hashCapacity; i++) { blockPositionByHash.set(i, EMPTY_SLOT); } this.containsNullElement = false; }
/** * Get slot position of element at {@code position} of {@code block} */ private int getHashPositionOfElement(Block block, int position) { int hashPosition = getMaskedHash(hashPosition(elementType, block, position)); while (true) { int blockPosition = blockPositionByHash.get(hashPosition); // Doesn't have this element if (blockPosition == EMPTY_SLOT) { return hashPosition; } // Already has this element else if (positionEqualsPosition(elementType, elementBlock, blockPosition, block, position)) { return hashPosition; } hashPosition = getMaskedHash(hashPosition + 1); } }
/**
 * Adds 100 BIGINT values to a {@link TypedSet} sized for 100 elements and checks that
 * {@code positionOf} returns, for every value, the index at which it was added.
 */
@Test
public void testGetElementPosition()
        throws Exception
{
    final int total = 100;
    TypedSet set = new TypedSet(BIGINT, total);
    BlockBuilder values = BIGINT.createFixedSizeBlockBuilder(total);

    int next = 0;
    while (next < total) {
        BIGINT.writeLong(values, next);
        set.add(values, next);
        next++;
    }

    for (int index = 0; index < values.getPositionCount(); index++) {
        int reported = set.positionOf(values, index);
        assertEquals(reported, index);
    }
}
/**
 * Appends a key/value pair to the builders. In multi-value mode every pair is appended;
 * otherwise a pair whose key is already in {@code keySet} is dropped silently.
 * A null value is appended as a null entry.
 */
public void add(Block key, Block value, int keyPosition, int valuePosition)
{
    // Outside multi-value mode only the first pair for each key is kept
    if (!isMultiValue && keySet.contains(key, keyPosition)) {
        return;
    }

    keySet.add(key, keyPosition);
    keyType.appendTo(key, keyPosition, keyBlockBuilder);

    if (!value.isNull(valuePosition)) {
        valueType.appendTo(value, valuePosition, valueBlockBuilder);
    }
    else {
        valueBlockBuilder.appendNull();
    }
}
}
/**
 * Runs {@code testBigintFor} against two sets of the same expected size: one constructed
 * without a caller-supplied builder and one constructed on top of an empty
 * {@link BlockBuilder} passed in by the test.
 */
private static void testBigint(Block longBlock, int expectedSetSize)
{
    // Set constructed without a caller-supplied builder
    testBigintFor(new TypedSet(BIGINT, expectedSetSize, FUNCTION_NAME), longBlock);

    // Set constructed on a builder supplied by the caller
    BlockBuilder suppliedBuilder = BIGINT.createBlockBuilder(null, expectedSetSize);
    TypedSet setOnSuppliedBuilder = new TypedSet(BIGINT, suppliedBuilder, expectedSetSize, FUNCTION_NAME);
    testBigintFor(setOnSuppliedBuilder, longBlock);
}
/**
 * Tells whether the set holds an element equal to the one at {@code position} of {@code block}.
 * A null element is answered from the {@code containsNullElement} flag without a hash lookup.
 *
 * @param block block holding the element; must not be null
 * @param position position of the element inside {@code block}; must be >= 0
 * @return {@code true} when an equal element (or, for a null, any null) is in the set
 */
public boolean contains(Block block, int position)
{
    requireNonNull(block, "block must not be null");
    checkArgument(position >= 0, "position must be >= 0");

    if (block.isNull(position)) {
        return containsNullElement;
    }

    int slot = getHashPositionOfElement(block, position);
    return blockPositionByHash.get(slot) != EMPTY_SLOT;
}
private void addNewElement(int hashPosition, Block block, int position) { elementType.appendTo(block, position, elementBlock); if (elementBlock.getSizeInBytes() > FOUR_MEGABYTES) { throw exceededLocalLimit(new DataSize(4, MEGABYTE)); } blockPositionByHash.set(hashPosition, elementBlock.getPositionCount() - 1); // increase capacity, if necessary if (elementBlock.getPositionCount() >= maxFill) { rehash(maxFill * 2); } }