/**
 * Computes the multiplicative inverse of the secret {@code value}.
 *
 * <p>The secret is multiplied by a fresh random element {@code r}, and only the masked product
 * {@code value * r} is opened. The opened product is inverted modulo the field modulus, and that
 * public inverse is multiplied back with the secret {@code r}:
 * {@code (value * r)^-1 * r = value^-1}.
 *
 * <p>{@link BigInteger#modInverse} throws {@link ArithmeticException} when the opened product
 * has no inverse (e.g. it is zero).
 */
@Override
public DRes<SInt> buildComputation(ProtocolBuilderNumeric builder) {
  Numeric numeric = builder.numeric();
  DRes<SInt> mask = numeric.randomElement();
  DRes<BigInteger> openedProduct = numeric.open(numeric.mult(value, mask));
  return builder.seq(seq -> {
    BigInteger maskedValue = openedProduct.out();
    BigInteger productInverse =
        maskedValue.modInverse(seq.getBasicNumericContext().getModulus());
    // Multiplying by the mask again cancels it out, leaving value^-1.
    return seq.numeric().mult(productInverse, mask);
  });
}
}
/*
 * Appears intended to select a score from `scores` using the secret indicator list
 * `comparisons`. Index 0 uses comparisons[0]. Each inner index i uses the difference
 * comparisons[i] - comparisons[i-1]. A final complement (1 - last comparison) is appended to
 * `comparisons`. The selected products are then summed.
 *
 * NOTE(review): this fragment does not compile as shown and looks truncated or merged from two
 * versions:
 *  - `numericBuilder` is declared twice in the same lambda scope;
 *  - `builder` is a `Comparison` but `builder.numeric()` is called on it (presumably no such
 *    method exists; confirm);
 *  - `res` is computed but never added to `innerScores`;
 *  - `return () -> innerScores;` sits inside the for loop, and the braces are unbalanced
 *    (the par-lambda and the method are never closed);
 *  - `result` is unused;
 *  - `comparisons.add(...)` mutates a field each time the computation is built.
 * Restore this from version control rather than patching by hand.
 */
@Override public DRes<SInt> buildComputation(ProtocolBuilderNumeric rootBuilder) { return rootBuilder.par((parallelBuilder) -> { List<DRes<SInt>> result = new ArrayList<>(); Comparison builder = parallelBuilder.comparison(); Numeric numericBuilder = builder.numeric(); DRes<SInt> lastComparison = comparisons.get(comparisons.size() - 1); comparisons.add(numericBuilder.sub(BigInteger.ONE, lastComparison)); Numeric numericBuilder = parallelBuilder.numeric(); List<DRes<SInt>> innerScores = new ArrayList<>(); innerScores.add(numericBuilder.mult(comparisons.get(0), scores.get(0))); for (int i = 1; i < scores.size() - 1; i++) { int finalI = i; final DRes<SInt> res = parallelBuilder.seq(seq -> { DRes<SInt> hit = seq.numeric() .sub(comparisons.get(finalI), comparisons.get(finalI - 1)); return seq.numeric().mult(hit, scores.get(finalI)); }); return () -> innerScores; }).seq((seq, list) -> seq.advancedNumeric().sum(list));
/**
 * Appends a {@code SpdzAddProtocol} over {@code a} and {@code b} to the protocol builder.
 *
 * @return the deferred result produced by the appended protocol
 */
@Override
public DRes<SInt> add(DRes<SInt> a, DRes<SInt> b) {
  return protocolBuilder.append(new SpdzAddProtocol(a, b));
}
/**
 * Has {@code inputParty} input every entry of {@code a} as a secret real value.
 *
 * <p>All per-entry inputs are built inside a single {@code par} step. The returned vector keeps
 * the order of {@code a}.
 */
@Override
public DRes<Vector<DRes<SReal>>> input(Vector<BigDecimal> a, int inputParty) {
  return builder.par(par -> {
    Vector<DRes<SReal>> closed = new Vector<>();
    for (BigDecimal entry : a) {
      closed.add(par.realNumeric().input(entry, inputParty));
    }
    return () -> closed;
  });
}
/**
 * Permutes the rows of {@code values} according to {@code idxPerm}.
 *
 * <p>This delegates to {@code PermuteRows}, passing this party's id and the constant flag
 * {@code true}. NOTE(review): the meaning of the flag is defined by {@code PermuteRows}.
 */
@Override
public DRes<Matrix<DRes<SInt>>> permute(DRes<Matrix<DRes<SInt>>> values, int[] idxPerm) {
  int myId = builder.getBasicNumericContext().getMyId();
  PermuteRows permutation = new PermuteRows(values, idxPerm, myId, true);
  return builder.seq(permutation);
}
/**
 * Subtracts {@code right} from {@code left} and compares the difference against zero using
 * {@code compareZero} with {@code bitLength} bits.
 */
@Override
public DRes<SInt> buildComputation(ProtocolBuilderNumeric builder) {
  // The subtraction is appended before the comparison consumes it.
  DRes<SInt> difference = builder.numeric().sub(left, right);
  return builder.comparison().compareZero(difference, bitLength);
}
}
/*
 * Comparison test: builds known values and comparison results, opens the list `comps`, and
 * returns the opened BigIntegers.
 *
 * NOTE(review): this fragment is truncated or garbled and does not compile as shown:
 *  - `input.known(BigInteger.valueOf(1))) );` has unmatched closing parentheses, which
 *    suggests the construction of `comps` was lost;
 *  - `comps` is never declared in the visible code;
 *  - `comparison` and `maxBitLength` are unused in what remains;
 *  - the application lambda and the method are never closed.
 */
@Override public void test() throws Exception { Application<List<BigInteger>, ProtocolBuilderNumeric> app = builder -> { Numeric input = builder.numeric(); Comparison comparison = builder.comparison(); int maxBitLength = builder.getBasicNumericContext().getMaxBitLength(); input.known(BigInteger.valueOf(1))) ); DRes<List<DRes<BigInteger>>> opened = builder.collections().openList(() -> comps); return builder.seq((seq) -> { return () -> opened.out().stream().map(DRes::out).collect(Collectors.toList()); });
/**
 * Square root of a fixed-point value.
 *
 * <p>The underlying integer is square-rooted, and the precision is halved (rounded down). When
 * the precision is odd, the halving leaves out half a scale step. That is compensated by
 * multiplying with {@code 1/sqrt(2)}.
 *
 * <p>NOTE(review): the 1/sqrt(2) correction is only right if {@code SFixed}'s precision is a
 * base-2 exponent; confirm against {@code SFixed}.
 */
@Override
public DRes<SReal> sqrt(DRes<SReal> x) {
  return builder.seq(seq -> {
    SFixed fixed = (SFixed) x.out();
    DRes<SInt> mantissa = fixed.getSInt();
    int precision = fixed.getPrecision();
    int bits = seq.getBasicNumericContext().getMaxBitLength();
    DRes<SInt> root = seq.advancedNumeric().sqrt(mantissa, bits);
    DRes<SReal> halved = new SFixed(root, Math.floorDiv(precision, 2));
    if (Math.floorMod(precision, 2) != 1) {
      return halved;
    }
    // Odd precision: account for the dropped half step.
    return seq.realNumeric().mult(BigDecimal.valueOf(1.0 / Math.sqrt(2.0)), halved);
  });
}
/**
 * Looks up {@code lookUpKey} among {@code keys} and returns the corresponding entry of
 * {@code values}, or {@code notFoundValue} if no key matches.
 *
 * <p>All equality tests run in parallel. The selections are then chained sequentially, starting
 * from {@code notFoundValue}. If several keys match, the one with the highest index wins.
 */
@Override
public DRes<SInt> buildComputation(ProtocolBuilderNumeric builder) {
  return builder.par(par -> {
    List<DRes<SInt>> matches = new ArrayList<>(keys.size());
    for (DRes<SInt> candidateKey : keys) {
      matches.add(par.comparison().equals(lookUpKey, candidateKey));
    }
    return () -> matches;
  }).seq((seq, matches) -> {
    DRes<SInt> selected = notFoundValue;
    int count = values.size();
    for (int pos = 0; pos < count; pos++) {
      selected = seq.seq(new ConditionalSelect(matches.get(pos), values.get(pos), selected));
    }
    return selected;
  });
}
}
/**
 * Computes the remainder as {@code dividend - divisor * div(dividend, divisor)}, using the
 * builder's advanced-numeric division.
 */
@Override
public DRes<SInt> buildComputation(ProtocolBuilderNumeric builder) {
  DRes<SInt> quotient = builder.advancedNumeric().div(dividend, divisor);
  Numeric numeric = builder.numeric();
  DRes<SInt> multiple = numeric.mult(divisor, quotient);
  return numeric.sub(dividend, multiple);
}
}
public void compareAndSwap(ProtocolBuilderNumeric builder, int a, int b, List<DRes<SInt>> values) { //Non splitting version Numeric numeric = builder.numeric(); DRes<SInt> valueA = values.get(a); DRes<SInt> valueB = values.get(b); DRes<SInt> comparison = builder.comparison().compareLEQ(valueA, valueB); DRes<SInt> sub = numeric.sub(valueA, valueB); DRes<SInt> c = numeric.mult(comparison, sub); DRes<SInt> d = numeric.mult(minusOne, c); //a = comparison*a+(1-comparison)*b ==> comparison*(a-b)+b //b = comparison*b+(1-comparison)*a ==> -comparison*(a-b)+a builder.par(par -> { values.set(a, par.numeric().add(c, valueB)); values.set(b, par.numeric().add(d, valueA)); return null; }); }
/**
 * Sums {@code data} and divides the total by {@code degreesOfFreedom}, using the builder's
 * advanced-numeric division.
 */
@Override
public DRes<SInt> buildComputation(ProtocolBuilderNumeric builder) {
  return builder
      .seq(seq -> () -> this.data)
      .seq((seq, elements) -> seq.advancedNumeric().sum(elements))
      .seq((seq, total) -> {
        BigInteger divisor = BigInteger.valueOf(this.degreesOfFreedom);
        DRes<SInt> totalResult = () -> total;
        return seq.advancedNumeric().div(totalResult, divisor);
      });
}
/*
 * Start of a signed-division computation: takes the sign of `dividend` and multiplies by it to
 * get |dividend|, then right-shifts `quotientAbs` by `shifts`.
 *
 * NOTE(review): this fragment is truncated. The method body is never closed, and `quotientAbs`
 * and `shifts` are not declared in the visible code (presumably computed in lines lost between
 * `dividendAbs` and `q`). In what remains, `basicNumericContext` and `dividendAbs` are unused.
 */
@Override public DRes<SInt> buildComputation(ProtocolBuilderNumeric builder) { BasicNumericContext basicNumericContext = builder.getBasicNumericContext(); Numeric numeric = builder.numeric(); DRes<SInt> dividendSign = builder.comparison().sign(dividend); DRes<SInt> dividendAbs = numeric.mult(dividend, dividendSign); DRes<SInt> q = builder.advancedNumeric().rightShift(quotientAbs, shifts);
/**
 * Divides the known values {@code numerator} and {@code denominator} in MPC. It opens the
 * quotient and checks it against Java's {@code numerator / denominator}, after converting the
 * opened field element to its signed representation.
 */
@Override
public void test() {
  Application<BigInteger, ProtocolBuilderNumeric> app = builder -> {
    fieldDefinition = builder.getBasicNumericContext().getFieldDefinition();
    DRes<SInt> closedNumerator = builder.numeric().known(BigInteger.valueOf(numerator));
    DRes<SInt> closedDenominator = builder.numeric().known(BigInteger.valueOf(denominator));
    DRes<SInt> quotient = builder.advancedNumeric().div(closedNumerator, closedDenominator);
    return builder.numeric().open(quotient);
  };
  BigInteger opened = runApplication(app);
  BigInteger expectedQuotient = BigInteger.valueOf(numerator / denominator);
  Assert.assertEquals(expectedQuotient, fieldDefinition.convertToSigned(opened));
}
};
/** * Creates a ProtocolProducer to compute the first half of a simplex iteration. * <p> * This finds the variable to enter the basis, based on the pivot rule of most negative entry in * the <i>F</i> vector. Also tests if no negative entry in the <i>F</i> vector is present. If this * is the case we should terminate the simplex method. * </p> * * @return a delayed result of the phaseOne computation of the first half of a simplex iteration */ private DRes<Pair<List<DRes<SInt>>, BigInteger>> phaseOneDanzig( ProtocolBuilderNumeric builder, LpState state, DRes<SInt> zero) { return builder .seq( // Compute potential entering variable index and corresponding value of // entry in F new EnteringVariable(state.tableau, state.updateMatrix)) .seq((seq, enteringAndMinimum) -> { List<DRes<SInt>> entering = enteringAndMinimum.getFirst(); SInt minimum = enteringAndMinimum.getSecond(); // Check if the entry in F is non-negative DRes<SInt> positive = seq.comparison().compareLEQLong(zero, () -> minimum); DRes<BigInteger> terminationOut = seq.numeric().open(positive); return () -> new Pair<>(entering, terminationOut.out()); }); }
@Override public void test() throws Exception { Application<Matrix<BigInteger>, ProtocolBuilderNumeric> testApplication = root -> { DRes<Matrix<DRes<SInt>>> closed = root.collections().closeMatrix(input, 1); // use package-private constructor to fix randomness DRes<Matrix<DRes<SInt>>> shuffled = root.seq( new ShuffleRows(closed, new Random(42 + root.getBasicNumericContext().getMyId()))); DRes<Matrix<DRes<BigInteger>>> opened = root.collections().openMatrix(shuffled); return () -> new MatrixUtils().unwrapMatrix(opened); }; Matrix<BigInteger> actual = runApplication(testApplication); assertThat(actual.getRows(), is(expected.getRows())); } };
/**
 * Converts {@code value} with {@code factory.createElement} and runs a
 * {@code Spdz2kInputComputation} on it, with {@code inputParty} as the providing party.
 *
 * @param value the plain value to input
 * @param inputParty id of the party providing the input
 * @return the deferred secret result of the input computation
 */
@Override
public DRes<SInt> input(BigInteger value, int inputParty) {
  return builder.seq(
      new Spdz2kInputComputation<>(factory.createElement(value), inputParty)
  );
}
/**
 * Every party {@code i} (1-based) inputs the value {@code i}. All inputs are opened, and the
 * test checks that exactly one value per party was revealed, in party order.
 */
@Override
public void test() {
  Application<Pair<Integer, List<DRes<BigInteger>>>, ProtocolBuilderNumeric> app = producer -> {
    Numeric numeric = producer.numeric();
    int partyCount = producer.getBasicNumericContext().getNoOfParties();
    List<DRes<SInt>> secrets = new ArrayList<>(partyCount);
    for (int party = 1; party <= partyCount; party++) {
      secrets.add(numeric.input(BigInteger.valueOf(party), party));
    }
    DRes<List<DRes<BigInteger>>> opened = producer.collections().openList(() -> secrets);
    return () -> new Pair<>(partyCount, opened.out());
  };
  Pair<Integer, List<DRes<BigInteger>>> output = runApplication(app);
  int parties = output.getFirst();
  List<DRes<BigInteger>> revealed = output.getSecond();
  Assert.assertEquals(parties, revealed.size());
  for (int idx = 0; idx < parties; idx++) {
    int expectedValue = idx + 1;
    Assert.assertEquals(expectedValue, revealed.get(idx).out().intValue());
  }
}
};
/**
 * Tests whether {@code values} is sorted in non-decreasing order.
 *
 * <p>Every adjacent pair is compared with {@code compareLEQ} in parallel, and the comparison
 * results are multiplied together. Assuming each comparison yields 0/1, the product is 1 exactly
 * when every adjacent pair is ordered.
 *
 * <p>A list with fewer than two elements is trivially sorted and yields a known 1. This avoids
 * handing an empty comparison list to {@code product(...)}. NOTE(review): confirm how
 * {@code AdvancedNumeric.product} handles empty input; if it already returns 1, this guard is
 * merely a shortcut.
 *
 * @param builder the builder the computation is appended to
 * @param values the secret values to check
 * @return a deferred secret that is 1 if sorted and 0 otherwise
 */
public DRes<SInt> isSorted(ProtocolBuilderNumeric builder, List<DRes<SInt>> values) {
  if (values.size() < 2) {
    return builder.numeric().known(java.math.BigInteger.ONE);
  }
  return builder.par(par -> {
    Comparison comparison = par.comparison();
    List<DRes<SInt>> comparisons = new ArrayList<>(values.size() - 1);
    boolean first = true;
    DRes<SInt> previous = null;
    for (DRes<SInt> value : values) {
      if (first) {
        first = false;
      } else {
        comparisons.add(comparison.compareLEQ(previous, value));
      }
      previous = value;
    }
    return () -> comparisons;
  }).seq((seq, comparisons) -> seq.advancedNumeric().product(comparisons));
}