/** * Creates a new fixed point based RealNumeric ComputationDirectory * * @param builder a ProtocolBuilder for the numeric computations which will be used to implement * the fixed point operations. * @param precision the precision used for the fixed point numbers. The precision must be in the * range <i>0 ... <code>builder.getMaxBitLength</code> / 4</i>. */ public FixedNumeric(ProtocolBuilderNumeric builder, int precision) { this.builder = builder; this.defaultPrecision = precision; /* * We reserve as many bits the integer part as for the fractional part and to be able to * represent products, we need to be able to hold twice that under the max bit length. */ Objects.requireNonNull(builder); this.maxPrecision = builder.getBasicNumericContext().getMaxBitLength() / 4; if (defaultPrecision < 0 || defaultPrecision > maxPrecision) { throw new IllegalArgumentException( "Precision must be in the range 0 ... " + maxPrecision + " but was " + defaultPrecision); } }
/**
 * Compares two secret values for equality.
 *
 * <p>Delegates to the bit-length-aware {@code equals} overload. The bit length passed is the max
 * bit length of the builder's basic numeric context.
 */
@Override
public DRes<SInt> equals(DRes<SInt> x, DRes<SInt> y) {
  return equals(builder.getBasicNumericContext().getMaxBitLength(), x, y);
}
/**
 * Permutes the rows of {@code values} using the index permutation {@code idxPerm} known to this
 * party.
 *
 * <p>Builds a {@code PermuteRows} computation tagged with this party's id. The final {@code true}
 * flag is forwarded unchanged; presumably it marks this party as the permutation holder (TODO
 * confirm against PermuteRows).
 */
@Override
public DRes<Matrix<DRes<SInt>>> permute(DRes<Matrix<DRes<SInt>>> values, int[] idxPerm) {
  int myId = builder.getBasicNumericContext().getMyId();
  PermuteRows permuteRows = new PermuteRows(values, idxPerm, myId, true);
  return builder.seq(permuteRows);
}
/**
 * Builds a MiMC decryption of {@code cipherText} under {@code encryptionKey}.
 *
 * <p>The number of rounds is the reduced count computed by
 * {@code MimcEncryptionReducedRounds.computeReducedRounds} for the field modulus.
 */
@Override
public DRes<SInt> buildComputation(ProtocolBuilderNumeric builder) {
  int reducedRounds =
      MimcEncryptionReducedRounds.computeReducedRounds(
          builder.getBasicNumericContext().getModulus());
  MiMCDecryption decryption = new MiMCDecryption(cipherText, encryptionKey, reducedRounds);
  return decryption.buildComputation(builder);
}
/**
 * Builds a MiMC encryption of {@code plainText} under {@code encryptionKey}.
 *
 * <p>The number of rounds is the reduced count computed by {@code computeReducedRounds} for the
 * field modulus.
 */
@Override
public DRes<SInt> buildComputation(ProtocolBuilderNumeric builder) {
  int rounds = computeReducedRounds(builder.getBasicNumericContext().getModulus());
  MiMCEncryption encryption = new MiMCEncryption(plainText, encryptionKey, rounds);
  return encryption.buildComputation(builder);
}
/**
 * Computes the square root of a fixed-point secret value.
 *
 * <p>If x = u * 2^-s with unscaled integer u and precision s, then sqrt(x) = sqrt(u) * 2^-(s/2).
 *
 * <ul>
 *   <li>For even s, the integer square root of u is used directly with precision s / 2.
 *   <li>For odd s, it is used with precision floor(s / 2) and then multiplied by 1/sqrt(2) to
 *       account for the remaining half bit.
 * </ul>
 */
@Override
public DRes<SReal> sqrt(DRes<SReal> x) {
  return builder.seq(seq -> {
    SFixed fixed = (SFixed) x.out();
    DRes<SInt> unscaled = fixed.getSInt();
    int precision = fixed.getPrecision();
    DRes<SInt> unscaledRoot =
        seq.advancedNumeric().sqrt(unscaled, seq.getBasicNumericContext().getMaxBitLength());
    DRes<SReal> root = new SFixed(unscaledRoot, Math.floorDiv(precision, 2));
    if (Math.floorMod(precision, 2) != 1) {
      // Even precision: halving the scale is exact.
      return root;
    }
    // Odd precision: compensate for the half bit dropped by floorDiv.
    return seq.realNumeric().mult(BigDecimal.valueOf(1.0 / Math.sqrt(2.0)), root);
  });
}
/**
 * Runs KnownDivisor with a divisor of 2^maxBitLength and the dividend 10.
 *
 * <p>The result of the application is not inspected; the test only requires that it runs.
 */
@Override
public void test() throws Exception {
  Application<SInt, ProtocolBuilderNumeric> app = builder -> {
    int maxBitLength = builder.getBasicNumericContext().getMaxBitLength();
    BigInteger twoToMaxBitLength = BigInteger.valueOf(2).pow(maxBitLength);
    DRes<SInt> ten = builder.numeric().known(BigInteger.TEN);
    return builder.seq(new KnownDivisor(ten, twoToMaxBitLength));
  };
  runApplication(app);
}
@Override public void test() throws Exception { // define input List<BigInteger> input = new ArrayList<>(); int numInputs = 100; for (int i = 0; i < numInputs; i++) { input.add(BigInteger.valueOf(i)); } // define functionality to be tested Application<List<BigInteger>, ProtocolBuilderNumeric> testApplication = root -> { Collections collections = root.collections(); DRes<List<DRes<SInt>>> closed; if (root.getBasicNumericContext().getMyId() == 1) { // party 1 provides input closed = collections.closeList(input, 1); } else { // other parties receive it closed = collections.closeList(numInputs, 1); } DRes<List<DRes<BigInteger>>> opened = collections.openList(closed); return () -> opened.out().stream().map(DRes::out).collect(Collectors.toList()); }; // run test application List<BigInteger> output = runApplication(testApplication); assertThat(output, is(input)); } };
/**
 * Opens a fixed-point secret value to all parties.
 *
 * <p>The underlying unscaled integer is opened. The result is then rescaled with {@code scaled}
 * using the field definition and the value's precision.
 */
@Override
public DRes<BigDecimal> open(DRes<SReal> x) {
  return builder.seq(seq -> {
    SFixed fixed = (SFixed) x.out();
    DRes<BigInteger> openedUnscaled = seq.numeric().open(fixed.getSInt());
    int precision = fixed.getPrecision();
    return () ->
        scaled(
            builder.getBasicNumericContext().getFieldDefinition(),
            openedUnscaled.out(),
            precision);
  });
}
@Override public DRes<Matrix<DRes<SInt>>> buildComputation(ProtocolBuilderNumeric builder) { /* * There is a round for each party in pids. Each party chooses a random permutation (of indexes) * and applies it to values using PermuteRows. */ Matrix<DRes<SInt>> valuesOut = values.out(); final int height = valuesOut.getHeight(); if (height < 2) { return values; } final int pid = builder.getBasicNumericContext().getMyId(); final int numPids = builder.getBasicNumericContext().getNoOfParties(); return builder.seq( (seq) -> new IterationState(0, values) ).whileLoop((state) -> state.round < numPids, (seq, state) -> { int thisRoundPid = state.round + 1; // parties start from 1 DRes<Matrix<DRes<SInt>>> permuted; if (pid == thisRoundPid) { permuted = seq.collections().permute(state.intermediate, getIdxPerm(height)); } else { permuted = seq.collections().permute(state.intermediate, thisRoundPid); } return new IterationState(state.round + 1, permuted); }).seq((seq, state) -> state.intermediate); }
/**
 * Computes the modular inverse of the secret {@code value}.
 *
 * <p>The secret is masked with a random element r and the product value * r is opened. The
 * opened product is inverted in the clear and multiplied back by r, which gives
 * (value * r)^-1 * r = value^-1.
 *
 * @throws ArithmeticException (when the returned computation is evaluated) if the opened product
 *     is zero modulo the field modulus, i.e. {@code value} or the random mask is zero
 */
@Override
public DRes<SInt> buildComputation(ProtocolBuilderNumeric builder) {
  Numeric numeric = builder.numeric();
  DRes<SInt> random = numeric.randomElement();
  DRes<SInt> product = numeric.mult(value, random);
  DRes<BigInteger> open = numeric.open(product);
  return builder.seq((seq) -> {
    // Named maskedProduct (not "value") so it does not shadow the secret input field.
    BigInteger maskedProduct = open.out();
    BigInteger modulus = seq.getBasicNumericContext().getModulus();
    if (maskedProduct.mod(modulus).signum() == 0) {
      throw new ArithmeticException(
          "Cannot invert: the opened masked product is zero (input or random mask is zero)");
    }
    BigInteger inverse = maskedProduct.modInverse(modulus);
    return seq.numeric().mult(inverse, random);
  });
}
}
/**
 * Inputs the value floor(modulus / 2) + 1 from party 1.
 *
 * <p>The value is then opened and must come back unchanged.
 */
@Override
public void test() {
  Application<Pair<BigInteger, BigInteger>, ProtocolBuilderNumeric> app = producer -> {
    BigInteger halfModulus =
        producer.getBasicNumericContext().getModulus().divide(BigInteger.valueOf(2));
    BigInteger input = halfModulus.add(BigInteger.ONE);
    Numeric numeric = producer.numeric();
    DRes<BigInteger> opened = numeric.open(numeric.input(input, 1));
    return () -> new Pair<>(opened.out(), input);
  };
  Pair<BigInteger, BigInteger> actualAndExpected = runApplication(app);
  BigInteger actual = actualAndExpected.getFirst();
  BigInteger expected = actualAndExpected.getSecond();
  Assert.assertEquals(expected, actual);
}
};
/**
 * Opens a fixed-point secret value to {@code outputParty} only.
 *
 * <p>The returned value is null whenever the opened unscaled integer is null. Presumably this
 * happens at every party other than {@code outputParty} (TODO confirm against Numeric.open).
 */
@Override
public DRes<BigDecimal> open(DRes<SReal> x, int outputParty) {
  return builder.seq(seq -> {
    SFixed fixed = (SFixed) x.out();
    DRes<BigInteger> openedUnscaled = seq.numeric().open(fixed.getSInt(), outputParty);
    int precision = fixed.getPrecision();
    return () -> {
      BigInteger unscaled = openedUnscaled.out();
      return unscaled == null
          ? null
          : scaled(builder.getBasicNumericContext().getFieldDefinition(), unscaled, precision);
    };
  });
}
/**
 * Has every party i input the value i and opens all inputs with openList.
 *
 * <p>Checks that one value per party is returned and that position k holds k + 1.
 */
@Override
public void test() {
  Application<Pair<Integer, List<DRes<BigInteger>>>, ProtocolBuilderNumeric> app = producer -> {
    Numeric numeric = producer.numeric();
    int noOfParties = producer.getBasicNumericContext().getNoOfParties();
    List<DRes<SInt>> inputs = new ArrayList<>(noOfParties);
    for (int partyId = 1; partyId <= noOfParties; partyId++) {
      inputs.add(numeric.input(BigInteger.valueOf(partyId), partyId));
    }
    DRes<List<DRes<BigInteger>>> opened = producer.collections().openList(() -> inputs);
    return () -> new Pair<>(noOfParties, opened.out());
  };
  Pair<Integer, List<DRes<BigInteger>>> output = runApplication(app);
  int noOfParties = output.getFirst();
  List<DRes<BigInteger>> openedValues = output.getSecond();
  Assert.assertEquals(noOfParties, openedValues.size());
  for (int index = 0; index < noOfParties; index++) {
    Assert.assertEquals(index + 1, openedValues.get(index).out().intValue());
  }
}
};
/**
 * Divides two known secret values with advancedNumeric().div.
 *
 * <p>The opened result is compared, as a signed value, with the Java integer quotient
 * numerator / denominator.
 */
@Override
public void test() {
  Application<BigInteger, ProtocolBuilderNumeric> app = builder -> {
    // Captured so the signed conversion can be applied after the run.
    fieldDefinition = builder.getBasicNumericContext().getFieldDefinition();
    DRes<SInt> dividend = builder.numeric().known(BigInteger.valueOf(numerator));
    DRes<SInt> divisor = builder.numeric().known(BigInteger.valueOf(denominator));
    DRes<SInt> quotient = builder.advancedNumeric().div(dividend, divisor);
    return builder.numeric().open(quotient);
  };
  BigInteger opened = runApplication(app);
  BigInteger expected = BigInteger.valueOf(numerator / denominator);
  Assert.assertEquals(expected, fieldDefinition.convertToSigned(opened));
}
};
@Override public void test() throws Exception { Application<Matrix<BigInteger>, ProtocolBuilderNumeric> testApplication = root -> { DRes<Matrix<DRes<SInt>>> closed = root.collections().closeMatrix(input, 1); // use package-private constructor to fix randomness DRes<Matrix<DRes<SInt>>> shuffled = root.seq( new ShuffleRows(closed, new Random(42 + root.getBasicNumericContext().getMyId()))); DRes<Matrix<DRes<BigInteger>>> opened = root.collections().openMatrix(shuffled); return () -> new MatrixUtils().unwrapMatrix(opened); }; Matrix<BigInteger> actual = runApplication(testApplication); assertThat(actual.getRows(), is(expected.getRows())); } };
/**
 * Divides a known secret value by the public long divisor {@code denominator}.
 *
 * <p>The opened result is compared, as a signed value, with numerator / denominator.
 */
@Override
public void test() {
  Application<BigInteger, ProtocolBuilderNumeric> app = builder -> {
    // Captured so the signed conversion can be applied after the run.
    fieldDefinition = builder.getBasicNumericContext().getFieldDefinition();
    DRes<SInt> dividend = builder.numeric().known(numerator);
    DRes<SInt> quotient = builder.advancedNumeric().div(dividend, (long) denominator);
    return builder.numeric().open(quotient);
  };
  BigInteger opened = runApplication(app);
  BigInteger expected = BigInteger.valueOf(numerator / denominator);
  Assert.assertEquals(expected, fieldDefinition.convertToSigned(opened));
}
};
@Override public void test() throws Exception { // define functionality to be tested Application<Matrix<BigInteger>, ProtocolBuilderNumeric> testApplication = root -> { Collections collections = root.collections(); DRes<Matrix<DRes<SInt>>> closed = collections.closeMatrix(input, 1); DRes<Matrix<DRes<SInt>>> permuted = null; if (root.getBasicNumericContext().getMyId() == 1) { permuted = collections.permute(closed, idxPerm); } else { permuted = collections.permute(closed, 1); } DRes<Matrix<DRes<BigInteger>>> opened = collections.openMatrix(permuted); return () -> new MatrixUtils().unwrapMatrix(opened); }; Matrix<BigInteger> actual = runApplication(testApplication); assertThat(actual.getRows(), is(expected.getRows())); } };
/**
 * Has party {@code cheatingPartyId} scale its SpdzSInt share of an input by 2 before opening.
 *
 * <p>The run is expected to fail with an exception whose cause is a MaliciousException. If the
 * run completes without any exception, the test now fails. Previously it passed silently.
 */
@Override
public void test() {
  Application<BigInteger, ProtocolBuilderNumeric> app = producer -> {
    Numeric numeric = producer.numeric();
    DRes<SInt> input = numeric.input(BigInteger.ONE, 1);
    return producer.seq(seq -> {
      SInt value = input.out();
      if (seq.getBasicNumericContext().getMyId() == cheatingPartyId) {
        // Tamper with this party's share so the opened value is inconsistent.
        value = ((SpdzSInt) value).multiply(definition.createElement(2));
      }
      final SInt finalSInt = value;
      return seq.numeric().open(() -> finalSInt);
    });
  };
  try {
    runApplication(app);
  } catch (Exception e) {
    assertThat(e.getCause(), IsInstanceOf.instanceOf(MaliciousException.class));
    return;
  }
  // Reaching this point means the cheating went undetected.
  throw new AssertionError(
      "Expected a MaliciousException after party " + cheatingPartyId + " cheated");
}
};
/**
 * Builds the MiMC encryption of {@code plainText} under the key {@code encryptionKey}.
 *
 * <p>Let K be the key, p the plaintext, r_i the round constant for round i, and R the round count
 * from {@code getRequiredRounds(modulus, requestedRounds)}. The cipher text is computed as:
 *
 * <ul>
 *   <li>c_0 = (p + K)^3
 *   <li>c_i = (c_{i-1} + K + r_i)^3 for 1 &lt;= i &lt; R
 *   <li>output = c_last + K
 * </ul>
 */
@Override
public DRes<SInt> buildComputation(ProtocolBuilderNumeric builder) {
  final BigInteger modulus = builder.getBasicNumericContext().getModulus();
  final int rounds = getRequiredRounds(modulus, requestedRounds);
  return builder
      .seq(seq -> {
        // First round: c_0 = (p + K)^3.
        DRes<SInt> keyed = seq.numeric().add(plainText, encryptionKey);
        DRes<SInt> cubed = seq.advancedNumeric().exp(keyed, THREE);
        return new IterationState(1, cubed);
      })
      .whileLoop(
          state -> state.round < rounds,
          (seq, state) -> {
            // Intermediate round i: c_i = (c_{i-1} + K + r_i)^3.
            BigInteger roundConstant = roundConstants.getConstant(state.round, modulus);
            Numeric numeric = seq.numeric();
            DRes<SInt> keyed = numeric.add(state.value, encryptionKey);
            DRes<SInt> masked = numeric.add(roundConstant, keyed);
            DRes<SInt> cubed = seq.advancedNumeric().exp(masked, THREE);
            return new IterationState(state.round + 1, cubed);
          })
      // Final step: add the key once more to the last cipher text.
      .seq((seq, state) -> seq.numeric().add(state.value, encryptionKey));
}