/**
 * A frozen index built from an empty key list holds no keys, reports lookups as absent,
 * and refuses strict index lookups.
 */
@Test
public void testEmpty() {
    KeyIndex idx = new FrozenHashKeyIndex(LongLists.EMPTY_LIST);
    assertThat(idx.getKeyList(), hasSize(0));
    assertThat(idx.containsKey(30), equalTo(false));
    assertThat(idx.tryGetIndex(30), equalTo(-1));

    boolean rejected = false;
    try {
        idx.getIndex(30);
    } catch (IllegalArgumentException expected) {
        rejected = true;
    }
    if (!rejected) {
        fail("getting absent index should fail");
    }
}
/**
 * Randomly split the snapshot's ratings into a training list and a validation list.
 * The validation list holds {@code round(size * proportion)} randomly chosen ratings.
 * The remaining ratings become the training list.
 */
@Override
public RandomDataSplitStrategy get() {
    final int nUsers = snapshot.userIndex().size();
    final int nItems = snapshot.itemIndex().size();
    logger.info("Rating matrix size: {} users and {} items", nUsers, nItems);

    // Copy into a mutable list so it can be shuffled and partitioned in place.
    List<RatingMatrixEntry> ratings = new ArrayList<>(snapshot.getRatings());
    final int total = ratings.size();
    final int holdOut = Math.toIntExact(Math.round(total * proportion));
    logger.info("validation set size: {} ratings", holdOut);

    Collections.shuffle(ratings, random);

    // The first holdOut shuffled ratings become the validation set. Clearing the
    // sublist view then removes them from the training list.
    List<RatingMatrixEntry> head = ratings.subList(0, holdOut);
    final List<RatingMatrixEntry> validation = ImmutableList.copyOf(head);
    head.clear();
    logger.info("validation rating size: {}", validation.size());

    return new RandomDataSplitStrategy(ratings, validation,
                                       snapshot.userIndex(), snapshot.itemIndex());
}
}
/**
 * Resolve this entry's user index back to its user ID through the user key index.
 */
@Override
public long getUserId() {
    final int idx = getUserIndex();
    return userIndex.getKey(idx);
}
/**
 * Create a map from an array and index mapping.
 *
 * <p>The value for key {@code k} is taken from position {@code mapping.getIndex(k)} of
 * {@code values}. The result is keyed by a {@link SortedKeyIndex} built from the
 * mapping's keys.
 *
 * @param mapping The index mapping specifying the keys.
 * @param values The array of values, positioned according to {@code mapping}.
 * @return A map from the IDs in {@code mapping} to the corresponding values in {@code values}.
 * @throws IllegalArgumentException if {@code values} is not the same size as {@code mapping}.
 */
public static Long2DoubleSortedArrayMap fromArray(KeyIndex mapping, DoubleList values) {
    // Template form: the message is only formatted when the check fails.
    Preconditions.checkArgument(values.size() == mapping.size(),
                                "value array and index have different sizes: %s != %s",
                                values.size(), mapping.size());
    final int n = values.size();
    double[] nvs = new double[n];
    SortedKeyIndex index = SortedKeyIndex.fromCollection(mapping.getKeyList());
    for (int i = 0; i < n; i++) {
        // Look up each key's position in the new index and pull its value from the
        // key's position in the original mapping.
        long item = index.getKey(i);
        int origIndex = mapping.getIndex(item);
        nvs[i] = values.getDouble(origIndex);
    }
    return wrap(index, nvs);
}
/**
 * The snapshot's user index contains exactly the expected user IDs, and assigns them
 * consecutive positions in ascending ID order.
 */
@Test
public void testUserIndex() {
    KeyIndex index = snap.userIndex();
    long[] expectedKeys = {1, 3, 4, 5, 6, 7};
    assertEquals(expectedKeys.length, index.size());

    for (long key : expectedKeys) {
        assertTrue(index.getKeyList().contains(key));
    }
    // key -> position
    for (int pos = 0; pos < expectedKeys.length; pos++) {
        assertEquals(pos, index.getIndex(expectedKeys[pos]));
    }
    // position -> key
    for (int pos = 0; pos < expectedKeys.length; pos++) {
        assertEquals(expectedKeys[pos], index.getKey(pos));
    }
}
/**
 * Get the ratings recorded for a user.
 *
 * @param userId The user ID.
 * @return The user's ratings. Returns an empty list if the user is not in the index or
 *     has no entry in the per-user lists.
 */
@Override
public Collection<RatingMatrixEntry> getUserRatings(long userId) {
    final int pos = userIndex().tryGetIndex(userId);
    final List<Collection<RatingMatrixEntry>> byUser = userIndexLists.get();
    boolean known = pos >= 0 && pos < byUser.size();
    if (!known) {
        return Collections.emptyList();
    }
    return byUser.get(pos);
}
long userId = re.getUserId(); long itemId = re.getItemId(); assertThat(user, equalTo(userIndex.tryGetIndex(userId))); assertThat(item, equalTo(itemIndex.tryGetIndex(itemId))); int user = re.getUserIndex(); int item = re.getItemIndex(); long userId = userIndex.getKey(user); long itemId = itemIndex.getKey(item); assertThat(user, equalTo(snapshotUserIndex.tryGetIndex(userId))); assertThat(item, equalTo(snapshotItemIndex.tryGetIndex(itemId))); int size = snapshotUserIndex.getUpperBound(); for (int i = 0; i < size; i++) { assertThat(snapshotUserIndex.getKey(i), equalTo(userIndex.getKey(i))); size = snapshotItemIndex.getUpperBound(); for (int i = 0; i < size; i++) { assertThat(snapshotItemIndex.getKey(i), equalTo(itemIndex.getKey(i)));
/**
 * The snapshot's item index holds exactly the five expected item IDs, in any order.
 */
@Test
public void testItemIndex() {
    KeyIndex items = snap.itemIndex();
    assertEquals(5, items.size());
    assertThat(items.getKeyList(),
               containsInAnyOrder(7L, 8L, 9L, 10L, 11L));
}
/**
 * A frozen copy of a hash key index keeps the ID-to-position assignments made by
 * {@code internId}.
 */
@Test
public void testImmutableCopy() {
    HashKeyIndex mutable = new HashKeyIndex();
    assertThat(mutable.internId(42), equalTo(0));
    assertThat(mutable.internId(39), equalTo(1));

    KeyIndex frozen = mutable.frozenCopy();
    assertThat(frozen.getKey(0), equalTo(42L));
    assertThat(frozen.getKey(1), equalTo(39L));
    assertThat(frozen.getIndex(42), equalTo(0));
}
}
/**
 * Get the item IDs, taken from the key list of the item index.
 */
@Override
public LongCollection getItemIds() {
    LongCollection ids = itemIndex().getKeyList();
    return ids;
}
/**
 * Get the row of the item matrix for an item.
 *
 * @param item The item ID.
 * @return The item's row vector, or {@code null} if the item is not in the item index.
 */
@Nullable
public RealVector getItemVector(long item) {
    final int row = itemIndex.tryGetIndex(item);
    return row >= 0 ? Vectors.matrixRow(itemMatrix, row) : null;
}
/**
 * Create a map from an array and index mapping.
 *
 * <p>The value for key {@code k} is taken from position {@code mapping.getIndex(k)} of
 * {@code values}. The result is keyed by a {@link SortedKeyIndex} built from the
 * mapping's keys.
 *
 * @param mapping The index mapping specifying the keys.
 * @param values The array of values, positioned according to {@code mapping}.
 * @return A map from the IDs in {@code mapping} to the corresponding values in {@code values}.
 * @throws IllegalArgumentException if {@code values} is not the same size as {@code mapping}.
 */
public static Long2DoubleSortedArrayMap fromArray(KeyIndex mapping, DoubleList values) {
    // Template form: the message is only formatted when the check fails.
    Preconditions.checkArgument(values.size() == mapping.size(),
                                "value array and index have different sizes: %s != %s",
                                values.size(), mapping.size());
    final int n = values.size();
    double[] nvs = new double[n];
    SortedKeyIndex index = SortedKeyIndex.fromCollection(mapping.getKeyList());
    for (int i = 0; i < n; i++) {
        long item = index.getKey(i);
        int origIndex = mapping.getIndex(item);
        // getDouble avoids boxing through the deprecated DoubleList.get(int).
        nvs[i] = values.getDouble(origIndex);
    }
    return wrap(index, nvs);
}
/**
 * Get the user IDs, taken from the key list of the user index.
 */
@Override
public LongCollection getUserIds() {
    LongCollection ids = userIndex().getKeyList();
    return ids;
}
/**
 * A frozen index over a single key reports that key at position 0.
 * It treats every other key as absent, and rejects strict lookups of absent keys.
 */
@Test
public void testSingleton() {
    KeyIndex idx = new FrozenHashKeyIndex(LongLists.singleton(42));
    assertThat(idx.getKeyList(), hasSize(1));
    assertThat(idx.getKeyList(), contains(42L));

    assertThat(idx.containsKey(30), equalTo(false));
    assertThat(idx.containsKey(42), equalTo(true));
    assertThat(idx.tryGetIndex(30), equalTo(-1));
    assertThat(idx.tryGetIndex(42), equalTo(0));
    assertThat(idx.getIndex(42), equalTo(0));
    assertThat(idx.getKey(0), equalTo(42L));

    boolean rejected = false;
    try {
        idx.getIndex(30);
    } catch (IllegalArgumentException expected) {
        rejected = true;
    }
    if (!rejected) {
        fail("getting absent index should fail");
    }
}
/**
 * Get the row of the user matrix for a user.
 *
 * @param user The user ID.
 * @return The user's row vector, or {@code null} if the user is not in the user index.
 */
@Nullable
public RealVector getUserVector(long user) {
    final int row = userIndex.tryGetIndex(user);
    return row >= 0 ? Vectors.matrixRow(userMatrix, row) : null;
}
/**
 * Custom deserialization.
 *
 * <p>Reads, in order:
 * <ol>
 *   <li>the feature, user and item counts;</li>
 *   <li>the user matrix and then the item matrix, each as row-major doubles;</li>
 *   <li>the user and item key indexes.</li>
 * </ol>
 *
 * @param input The stream to read from.
 * @throws IOException if reading from the stream fails.
 * @throws InvalidObjectException if the stream holds negative dimensions or null indexes,
 *     or indexes whose sizes do not match the matrix row counts.
 * @throws ClassNotFoundException if an index class cannot be resolved.
 */
private void readObject(ObjectInputStream input) throws IOException, ClassNotFoundException {
    featureCount = input.readInt();
    userCount = input.readInt();
    itemCount = input.readInt();
    // The dimensions come from the stream, so reject corrupt values before allocating
    // anything. They would otherwise surface as unchecked exceptions.
    if (featureCount < 0 || userCount < 0 || itemCount < 0) {
        throw new InvalidObjectException("negative matrix dimension in stream: features="
                                         + featureCount + ", users=" + userCount
                                         + ", items=" + itemCount);
    }

    userMatrix = readRowMajorMatrix(input, userCount, featureCount);
    itemMatrix = readRowMajorMatrix(input, itemCount, featureCount);

    userIndex = (KeyIndex) input.readObject();
    itemIndex = (KeyIndex) input.readObject();
    if (userIndex == null || itemIndex == null) {
        throw new InvalidObjectException("missing user or item index in stream");
    }
    if (userIndex.size() != userMatrix.getRowDimension()) {
        throw new InvalidObjectException("user matrix and index have different row counts");
    }
    if (itemIndex.size() != itemMatrix.getRowDimension()) {
        throw new InvalidObjectException("item matrix and index have different row counts");
    }
}

/**
 * Read a {@code rows} x {@code cols} matrix whose entries are stored as consecutive
 * doubles in row-major order.
 */
private static RealMatrix readRowMajorMatrix(ObjectInputStream input, int rows, int cols) throws IOException {
    RealMatrix mat = MatrixUtils.createRealMatrix(rows, cols);
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cols; j++) {
            mat.setEntry(i, j, input.readDouble());
        }
    }
    return mat;
}
/**
 * Resolve this entry's item index back to its item ID through the item key index.
 */
@Override
public long getItemId() {
    final int idx = getItemIndex();
    return itemIndex.getKey(idx);
}
/**
 * Get the user IDs as exposed by the user index's key list.
 */
@Override
public LongCollection getUserIds() {
    LongCollection keys = userIndex().getKeyList();
    return keys;
}
/**
 * Get a particular feature value for a user.
 *
 * @param uid The user ID.
 * @param feature The feature (column) to read from the user-feature matrix.
 * @return The user-feature value, or 0 if the user is not in the user index
 *     (i.e. was not in the training set).
 */
public double getUserFeature(long uid, int feature) {
    int uidx = userIndex.tryGetIndex(uid);
    if (uidx < 0) {
        // Unknown users contribute nothing.
        return 0;
    } else {
        return userMatrix.getEntry(uidx, feature);
    }
}