/**
 * Build a split strategy by shuffling a copy of the snapshot's ratings and carving
 * off the leading {@code proportion} of them as the validation set.
 *
 * @return A strategy holding the remaining ratings, the validation ratings, and the
 *         snapshot's user and item indexes.
 */
@Override
public RandomDataSplitStrategy get() {
    final KeyIndex userIndex = snapshot.userIndex();
    final KeyIndex itemIndex = snapshot.itemIndex();
    logger.info("Rating matrix size: {} users and {} items", userIndex.size(), itemIndex.size());

    // Shuffle and trim a private copy; the list returned by the snapshot is never mutated.
    final List<RatingMatrixEntry> remaining = new ArrayList<>(snapshot.getRatings());
    final long rounded = Math.round(remaining.size() * proportion);
    final int validationSize = Math.toIntExact(rounded);
    logger.info("validation set size: {} ratings", validationSize);

    Collections.shuffle(remaining, random);

    // The head of the shuffled list becomes the validation set. Clearing the sublist
    // view removes those same entries from the backing list, so the two sets are disjoint.
    final List<RatingMatrixEntry> head = remaining.subList(0, validationSize);
    final List<RatingMatrixEntry> validationRatings = ImmutableList.copyOf(head);
    head.clear();
    logger.info("validation rating size: {}", validationRatings.size());

    return new RandomDataSplitStrategy(remaining, validationRatings, userIndex, itemIndex);
}
}
/**
 * Restore this object's fields from a serialization stream.
 *
 * <p>The stream is consumed in this order: feature count, user count, item count (ints);
 * the user matrix and then the item matrix, each as row-major doubles; finally the user
 * index and the item index as serialized objects.
 *
 * @param input The stream to read from.
 * @throws IOException if reading fails, including {@link InvalidObjectException} when an
 *         index's size differs from the row count of its matrix.
 * @throws ClassNotFoundException if a serialized index class cannot be resolved.
 */
private void readObject(ObjectInputStream input) throws IOException, ClassNotFoundException {
    featureCount = input.readInt();
    userCount = input.readInt();
    itemCount = input.readInt();

    // User feature matrix: userCount rows of featureCount doubles each.
    final RealMatrix users = MatrixUtils.createRealMatrix(userCount, featureCount);
    for (int row = 0; row < userCount; row++) {
        for (int col = 0; col < featureCount; col++) {
            users.setEntry(row, col, input.readDouble());
        }
    }
    userMatrix = users;

    // Item feature matrix: itemCount rows of featureCount doubles each.
    final RealMatrix items = MatrixUtils.createRealMatrix(itemCount, featureCount);
    for (int row = 0; row < itemCount; row++) {
        for (int col = 0; col < featureCount; col++) {
            items.setEntry(row, col, input.readDouble());
        }
    }
    itemMatrix = items;

    userIndex = (KeyIndex) input.readObject();
    itemIndex = (KeyIndex) input.readObject();

    // Reject streams whose indexes do not line up with the matrices just read.
    final boolean usersConsistent = userIndex.size() == userMatrix.getRowDimension();
    if (!usersConsistent) {
        throw new InvalidObjectException("user matrix and index have different row counts");
    }
    final boolean itemsConsistent = itemIndex.size() == itemMatrix.getRowDimension();
    if (!itemsConsistent) {
        throw new InvalidObjectException("item matrix and index have different row counts");
    }
}
/**
 * Group the ratings by user.
 *
 * @return A list with one collection per user index; each collection is a
 *         {@code PackedRatingCollection} over the indexes of that user's ratings.
 */
@Override
@Nonnull
public List<Collection<RatingMatrixEntry>> get() {
    final int nUsers = data.getUserIndex().size();

    // One (initially empty) bucket of rating indexes per user, addressed by user index.
    final ArrayList<IntArrayList> buckets = new ArrayList<>(nUsers);
    for (int u = 0; u < nUsers; u++) {
        buckets.add(new IntArrayList());
    }

    // Drop each rating's own index into the bucket of the user who made it.
    for (RatingMatrixEntry rating : getRatings()) {
        final int userPos = rating.getUserIndex();
        final int ratingPos = rating.getIndex();
        buckets.get(userPos).add(ratingPos);
    }

    // Trim each bucket before wrapping it in a rating collection.
    final ArrayList<Collection<RatingMatrixEntry>> perUser = new ArrayList<>(nUsers);
    for (IntArrayList bucket : buckets) {
        bucket.trim();
        perUser.add(new PackedRatingCollection(data, bucket));
    }
    return perUser;
}
}
/** * Construct a matrix factorization model. The matrices are not copied, so the caller should * make sure they won't be modified by anyone else. * * @param umat The user feature matrix (users x features). * @param imat The item feature matrix (items x features). * @param uidx The user index mapping. * @param iidx The item index mapping. */ public MFModel(RealMatrix umat, RealMatrix imat, KeyIndex uidx, KeyIndex iidx) { Preconditions.checkArgument(umat.getColumnDimension() == imat.getColumnDimension(), "mismatched matrix sizes"); featureCount = umat.getColumnDimension(); userCount = uidx.size(); itemCount = iidx.size(); Preconditions.checkArgument(umat.getRowDimension() == userCount, "user matrix has %s rows, expected %s", umat.getRowDimension(), userCount); Preconditions.checkArgument(imat.getRowDimension() == itemCount, "item matrix has %s rows, expected %s", imat.getRowDimension(), itemCount); userMatrix = umat; itemMatrix = imat; userIndex = uidx; itemIndex = iidx; }
/**
 * Build a sorted map from a list of values laid out according to a key index.
 *
 * @param mapping The index mapping specifying the keys and the position of each key's value.
 * @param values The values, positioned by {@code mapping}'s indexes.
 * @return A sorted map associating each key in {@code mapping} with its value from {@code values}.
 * @throws IllegalArgumentException if {@code values} is not the same size as {@code mapping}.
 */
public static Long2DoubleSortedArrayMap fromArray(KeyIndex mapping, DoubleList values) {
    Preconditions.checkArgument(values.size() == mapping.size(),
                                "value array and index have different sizes: " + values.size() + " != " + mapping.size());

    final SortedKeyIndex sortedKeys = SortedKeyIndex.fromCollection(mapping.getKeyList());
    final int count = values.size();
    final double[] sortedValues = new double[count];

    // Walk the keys in sorted order and pull each value from its position in the
    // original (mapping-ordered) list.
    for (int pos = 0; pos < count; pos++) {
        final long key = sortedKeys.getKey(pos);
        sortedValues[pos] = values.getDouble(mapping.getIndex(key));
    }
    return wrap(sortedKeys, sortedValues);
}
/** Checks that the snapshot's item index holds exactly the item keys 7 through 11. */
@Test
public void testItemIndex() {
    final KeyIndex items = snap.itemIndex();

    assertEquals(5, items.size());
    // Only membership is checked here; the order of the keys is not.
    assertThat(items.getKeyList(),
               containsInAnyOrder(7L, 8L, 9L, 10L, 11L));
}
/**
 * Deserialize this object's state.
 *
 * <p>Reads, in order: the feature, user, and item counts as ints; the user matrix and
 * the item matrix as row-major doubles; and the user and item indexes as objects.
 *
 * @param input The stream to read from.
 * @throws IOException if reading fails; an {@link InvalidObjectException} is thrown when
 *         an index's size does not equal the row count of the corresponding matrix.
 * @throws ClassNotFoundException if the class of a serialized index cannot be found.
 */
private void readObject(ObjectInputStream input) throws IOException, ClassNotFoundException {
    featureCount = input.readInt();
    userCount = input.readInt();
    itemCount = input.readInt();

    // Fill the user matrix one row at a time, in the order the doubles appear on the stream.
    final RealMatrix userFeatures = MatrixUtils.createRealMatrix(userCount, featureCount);
    for (int u = 0; u < userCount; u++) {
        for (int f = 0; f < featureCount; f++) {
            final double value = input.readDouble();
            userFeatures.setEntry(u, f, value);
        }
    }
    userMatrix = userFeatures;

    // The item matrix follows immediately, in the same layout.
    final RealMatrix itemFeatures = MatrixUtils.createRealMatrix(itemCount, featureCount);
    for (int it = 0; it < itemCount; it++) {
        for (int f = 0; f < featureCount; f++) {
            final double value = input.readDouble();
            itemFeatures.setEntry(it, f, value);
        }
    }
    itemMatrix = itemFeatures;

    userIndex = (KeyIndex) input.readObject();
    itemIndex = (KeyIndex) input.readObject();

    // Each index must have exactly one entry per row of its matrix.
    if (userMatrix.getRowDimension() != userIndex.size()) {
        throw new InvalidObjectException("user matrix and index have different row counts");
    }
    if (itemMatrix.getRowDimension() != itemIndex.size()) {
        throw new InvalidObjectException("item matrix and index have different row counts");
    }
}
/**
 * Partition the ratings into one collection per user.
 *
 * @return A list indexed by user index, where each element is a
 *         {@code PackedRatingCollection} built from the indexes of that user's ratings.
 */
@Override
@Nonnull
public List<Collection<RatingMatrixEntry>> get() {
    final int total = data.getUserIndex().size();

    // Pre-create an empty list of rating indexes for every user index.
    final ArrayList<IntArrayList> indexesByUser = new ArrayList<>(total);
    for (int n = 0; n < total; n++) {
        indexesByUser.add(new IntArrayList());
    }

    // Record every rating's index under the user index it belongs to.
    for (RatingMatrixEntry entry : getRatings()) {
        final int owner = entry.getUserIndex();
        final int position = entry.getIndex();
        indexesByUser.get(owner).add(position);
    }

    // Trim each index list, then wrap it as a collection over the shared data.
    final ArrayList<Collection<RatingMatrixEntry>> collections = new ArrayList<>(total);
    for (IntArrayList indexes : indexesByUser) {
        indexes.trim();
        collections.add(new PackedRatingCollection(data, indexes));
    }
    return collections;
}
}
/** * Construct a matrix factorization model. The matrices are not copied, so the caller should * make sure they won't be modified by anyone else. * * @param umat The user feature matrix (users x features). * @param imat The item feature matrix (items x features). * @param uidx The user index mapping. * @param iidx The item index mapping. */ public MFModel(RealMatrix umat, RealMatrix imat, KeyIndex uidx, KeyIndex iidx) { Preconditions.checkArgument(umat.getColumnDimension() == imat.getColumnDimension(), "mismatched matrix sizes"); featureCount = umat.getColumnDimension(); userCount = uidx.size(); itemCount = iidx.size(); Preconditions.checkArgument(umat.getRowDimension() == userCount, "user matrix has %s rows, expected %s", umat.getRowDimension(), userCount); Preconditions.checkArgument(imat.getRowDimension() == itemCount, "item matrix has %s rows, expected %s", imat.getRowDimension(), itemCount); userMatrix = umat; itemMatrix = imat; userIndex = uidx; itemIndex = iidx; }
/**
 * Create a map from an array and index mapping.
 *
 * @param mapping The index mapping specifying the keys.
 * @param values The array of values, positioned according to {@code mapping}'s indexes.
 * @return A sorted map associating the keys in {@code mapping} with the values in {@code values}.
 * @throws IllegalArgumentException if {@code values} is not the same size as {@code mapping}.
 */
public static Long2DoubleSortedArrayMap fromArray(KeyIndex mapping, DoubleList values) {
    Preconditions.checkArgument(values.size() == mapping.size(),
                                "value array and index have different sizes: " + values.size() + " != " + mapping.size());
    final int n = values.size();
    double[] nvs = new double[n];
    SortedKeyIndex index = SortedKeyIndex.fromCollection(mapping.getKeyList());
    for (int i = 0; i < n; i++) {
        // Keys are visited in sorted order; each value is fetched from its original position.
        long item = index.getKey(i);
        int origIndex = mapping.getIndex(item);
        // Use the primitive accessor: DoubleList.get(int) returns a boxed Double that
        // would be allocated and immediately unboxed on every iteration.
        nvs[i] = values.getDouble(origIndex);
    }
    return wrap(index, nvs);
}
/**
 * Checks that the snapshot's user index contains the keys 1, 3, 4, 5, 6, 7 and that
 * they occupy positions 0 through 5 in that order, looking up in both directions
 * (key to index and index to key).
 */
@Test
public void testUserIndex() {
    KeyIndex ind = snap.userIndex();

    // Keys are typed as long so the containment check is always made with a long value
    // and can never bind to contains(Object) with a boxed Integer, which would not
    // equal a Long key.
    final long[] expectedKeys = {1L, 3L, 4L, 5L, 6L, 7L};
    assertEquals(expectedKeys.length, ind.size());

    for (int i = 0; i < expectedKeys.length; i++) {
        final long key = expectedKeys[i];
        assertTrue(ind.getKeyList().contains(key));
        assertEquals(i, ind.getIndex(key));
        assertEquals(key, ind.getKey(i));
    }
}