/**
 * Runs bottom-up chart parsing over the tokenized input and returns every
 * derivation covering the whole input whose rule is a grammar root.
 *
 * @param input raw input text, split into tokens by this parser's tokenizer
 * @return root derivations spanning tokens [0, n); empty when none exist
 */
public List<Derivation> parseSyntactic(String input) {
    List<String> tokens = tokenizer.tokenize(input);
    int n = tokens.size();

    // Lexical rules are matched against lower-cased tokens. Locale.ROOT makes the
    // lower-casing independent of the JVM default locale (e.g. Turkish dotted/dotless i).
    List<String> tokensLower = new ArrayList<>(n);
    for (String token : tokens) {
        tokensLower.add(token.toLowerCase(java.util.Locale.ROOT));
    }

    Chart chart = new Chart(n + 1);
    // Spans [s, e) are filled by increasing end point and, for each end point, from
    // the shortest span to the longest, so the sub-spans (s, m) and (m, e) that a
    // binary rule combines are already populated when (s, e) is processed.
    for (int e = 1; e <= n; e++) {
        for (int s = e - 1; s >= 0; s--) {
            applyAnnotators(chart, tokens, s, e);
            applyLexicalRules(chart, tokensLower, s, e);
            applyBinaryRules(chart, s, e);
            applyUnaryRules(chart, s, e);
        }
    }

    // Callers index into the result (get(0)), so use a random-access list.
    List<Derivation> derivations = new ArrayList<>();
    for (Derivation d : chart.getDerivations(0, n)) {
        if (grammar.isRoot(d.rule)) {
            derivations.add(d);
        }
    }
    return derivations;
}
/**
 * Builds the logical form for a single derivation and scores it.
 *
 * <p>The first semantic map produced by {@code applySemantics} becomes the
 * logical form's content. {@code applySemantics} also accumulates
 * {@code d.score} as a side effect, and it runs before the form is scored.
 *
 * @param d a complete syntactic derivation
 * @return the scored logical form for {@code d}
 */
public LogicalForm computeLogicalForm(Derivation d) {
    final LogicalForm result = new LogicalForm(d, applySemantics(d).get(0));
    result.updateScore(weights);
    return result;
}
/**
 * Runs a read-eval-print loop on standard input.
 *
 * <p>It loads the model weights, then for each input line prints the first
 * (highest-ranked) logical form, or reports a parse failure on stderr. The
 * loop ends cleanly when standard input is exhausted.
 */
public void interactive() {
    model.loadWeights(weightsPath);
    Parser parser = model.getParser();
    // System.in is intentionally left open: the scanner lives for the whole session.
    Scanner scanner = new Scanner(System.in);
    while (true) {
        System.out.print(">> ");
        // Stop on end-of-input (Ctrl-D, closed pipe) instead of letting
        // nextLine() throw NoSuchElementException.
        if (!scanner.hasNextLine()) {
            System.out.println();
            break;
        }
        String input = scanner.nextLine();
        List<LogicalForm> results = parser.parse(input);
        if (!results.isEmpty()) {
            System.out.println(results.get(0));
        } else {
            System.err.println("Failed to parse input\n");
        }
    }
}
} // closes the enclosing class (its header is outside this view)
/**
 * Parses {@code input} into logical forms.
 *
 * <p>Every root derivation from {@link #parseSyntactic} is turned into a scored
 * logical form. The results are sorted in reverse natural order, presumably
 * highest score first; this depends on LogicalForm's compareTo.
 *
 * @param input raw input text
 * @return the logical forms in descending order; empty if nothing parses
 */
public List<LogicalForm> parse(String input) {
    List<LogicalForm> logicalForms = new ArrayList<>();
    for (Derivation derivation : parseSyntactic(input)) {
        LogicalForm form = computeLogicalForm(derivation);
        logicalForms.add(form);
    }
    logicalForms.sort(Collections.reverseOrder());
    return logicalForms;
}
/** The phrase annotator's first rule should land in the chart cell covering all three tokens. */
@Test
void applyAnnotators() {
    Parser parser = new Parser(null, null, Collections.singletonList(PhraseAnnotator.INSTANCE));
    Parser.Chart chart = parser.new Chart(10);
    List<String> tokens = Arrays.asList("A", "B", "C");
    Rule expectedRule = PhraseAnnotator.INSTANCE.annotate(tokens).get(0);

    parser.applyAnnotators(chart, tokens, 0, 3);

    assertEquals(expectedRule, chart.getDerivations(0, 3).get(0).rule);
}
/** A lexical rule whose right-hand side is "B C" should fill span [1, 3) of the tokens A B C. */
@Test
void applyLexicalRules() {
    Rule lexicalRule = new Rule("$A", "B C");
    List<Rule> rules = Collections.singletonList(lexicalRule);
    Parser parser = new Parser(new Grammar(rules, "$ROOT"), null, null);
    Parser.Chart chart = parser.new Chart(10);
    List<String> tokens = Arrays.asList("A", "B", "C");

    parser.applyLexicalRules(chart, tokens, 1, 3);

    assertEquals(lexicalRule, chart.getDerivations(1, 3).get(0).rule);
}
/** "a b" should parse as $C -> $A $B, with leaf children $A -> a and $B -> b. */
@Test
void parseSyntactic() {
    Rule ruleA = new Rule("$A", "a");
    Rule ruleB = new Rule("$B", "b");
    Rule ruleC = new Rule("$C", "$A $B");
    Grammar grammar = new Grammar(Arrays.asList(ruleA, ruleB, ruleC), "$C");
    Parser parser = new Parser(grammar, s -> Arrays.asList(s.split(" ")), Collections.emptyList());
    Derivation expected = new Derivation(
            ruleC,
            Arrays.asList(new Derivation(ruleA, null), new Derivation(ruleB, null)));

    Derivation actual = parser.parseSyntactic("a b").get(0);

    assertEquals(expected.rule, actual.rule);
    for (int i = 0; i < 2; i++) {
        assertEquals(expected.children.get(i).rule, actual.children.get(i).rule);
    }
}
/** End-to-end: "a b" yields exactly one logical form, {e: "a", f: "b"}. */
@Test
void parse() {
    List<Rule> rules = Arrays.asList(
            new Rule("$A", "a"),
            new Rule("$B", "b"),
            new Rule("$C", "$A $B", "{e:@first, f:@last}"));
    Grammar grammar = new Grammar(rules, "$C");
    Parser p = new Parser(grammar, s -> Arrays.asList(s.split(" ")), Collections.emptyList());
    // Plain HashMap rather than double-brace initialization, which would define an
    // anonymous HashMap subclass that captures a reference to this test instance.
    Map<String, Object> expected = new HashMap<>();
    expected.put("e", "a");
    expected.put("f", "b");

    List<LogicalForm> actual = p.parse("a b");

    assertEquals(1, actual.size());
    assertEquals(expected, actual.get(0).getMap());
}
/**
 * Semantics {b:@1} on $B -> $A $A should produce {b: "a"}.
 * Both children are the leaf $A -> a, so their value is "a".
 */
@Test
void applySemantics() {
    Rule r1 = new Rule("$A", "a");
    Rule r2 = new Rule("$B", "$A $A", "{b:@1}");
    Derivation dc1 = new Derivation(r1, null);
    Derivation dc2 = new Derivation(r1, null);
    Derivation d = new Derivation(r2, Arrays.asList(dc1, dc2));
    // Plain HashMap rather than double-brace initialization (anonymous subclass
    // capturing the test instance).
    Map<String, Object> expected = new HashMap<>();
    expected.put("b", "a");
    Parser p = new Parser(null, null, null);

    assertEquals(expected, p.applySemantics(d).get(0));
}
/**
 * Unary rules should chain over the same span. Starting from a $D derivation,
 * the cell gains $E -> $D and then $F -> $E, appended in that order.
 */
@Test
void applyUnaryRules() {
    Rule fFromE = new Rule("$F", "$E");
    Rule eFromD = new Rule("$E", "$D");
    List<Rule> rules = Arrays.asList(fFromE, eFromD);
    Parser parser = new Parser(new Grammar(rules, "$ROOT"), null, null);
    Parser.Chart chart = parser.new Chart(10);
    chart.addDerivation(1, 3, new Derivation(new Rule("$D", "$B $C"), null));

    parser.applyUnaryRules(chart, 1, 3);

    assertEquals(3, chart.getDerivations(1, 3).size());
    assertEquals(fFromE, chart.getDerivations(1, 3).get(2).rule);
    assertEquals(eFromD, chart.getDerivations(1, 3).get(1).rule);
}
/** $C -> $A $B should combine adjacent spans [0, 1) = $A and [1, 2) = $B into [0, 2). */
@Test
void applyBinaryRules() {
    Rule binaryRule = new Rule("$C", "$A $B");
    List<Rule> rules = Collections.singletonList(binaryRule);
    Parser parser = new Parser(new Grammar(rules, "$ROOT"), null, null);
    Parser.Chart chart = parser.new Chart(10);
    chart.addDerivation(0, 1, new Derivation(new Rule("$A", "A"), null));
    chart.addDerivation(1, 2, new Derivation(new Rule("$B", "B"), null));

    parser.applyBinaryRules(chart, 0, 2);

    assertEquals(binaryRule, chart.getDerivations(0, 2).get(0).rule);
}
} // closes the enclosing test class (its header is outside this view)
/**
 * Assembles the reminder-domain model.
 *
 * <p>The grammar combines {@code Rules.BASE}, the rules read from
 * "model/reminders.rules", and the rules supplied by {@code DateTimeAnnotator}.
 * Parsing uses a basic tokenizer plus the token, phrase, number and date/time
 * annotators.
 *
 * @return a new model wrapping the configured parser
 */
public static Model makeReminderModel() {
    List<Rule> rules = new LinkedList<>(Rules.BASE);
    rules.addAll(Rules.fromFile("model/reminders.rules"));
    rules.addAll(DateTimeAnnotator.rules());

    List<Annotator> annotators = Arrays.asList(
            TokenAnnotator.INSTANCE,
            PhraseAnnotator.INSTANCE,
            NumberAnnotator.INSTANCE,
            DateTimeAnnotator.INSTANCE);

    Parser parser = new Parser(new Grammar(rules, "$ROOT"), new BasicTokenizer(), annotators);
    return new Model(parser);
}
/**
 * Trains the parser for {@code epochs} passes over {@code d}. It keeps the
 * weights from the epoch with the highest accuracy on {@code d}.
 *
 * <p>Each epoch shuffles the dataset, runs the optimizer on every example and
 * prints the mean training loss. It then evaluates the current weights.
 *
 * <p>If no epoch reaches an accuracy above 0, or {@code epochs == 0}, the
 * parser keeps its current weights. Previously, {@code null} was installed in
 * that case.
 *
 * @param d         training (and evaluation) dataset
 * @param optimizer per-example update strategy
 * @param epochs    number of passes over the data
 */
void train(Dataset d, Optimizer optimizer, int epochs) {
    float maxAcc = 0;
    Map<String, Float> bestWeights = null;
    for (int i = 0; i < epochs; i++) {
        float loss = 0;
        int count = 0;
        optimizer.onEpochStart();
        for (Example e : d.shuffle()) {
            loss += optimizer.optimize(e);
            count++;
        }
        optimizer.onEpochComplete(count);
        System.out.println(String.format("Epoch %d: Train loss = %f", i + 1, loss / count));

        float acc = evaluate(d, 0);
        if (acc > maxAcc) {
            maxAcc = acc;
            // Copy, so that later epochs' updates to the parser's map do not alter
            // the snapshot (presumably getWeights() returns the live map).
            bestWeights = new HashMap<>(parser.getWeights());
        }
        System.out.println();
    }
    System.out.println("Max accuracy: " + maxAcc);

    if (bestWeights != null) {
        weights = bestWeights;
        parser.setWeights(bestWeights);
    } else {
        // No improving epoch: keep the parser's current weights instead of
        // installing null. Re-sync the model's reference as the constructor does.
        weights = parser.getWeights();
    }
}
/**
 * Wraps {@code parser}. The model starts from whatever weights the parser
 * currently exposes.
 *
 * @param parser the parser this model drives
 */
public Model(Parser parser) {
    this.parser = parser;
    this.weights = this.parser.getWeights();
}
/**
 * Pins the flat index that mapSpan assigns to span (2, 3) in a size-10 chart.
 * It then checks that a stored derivation is returned from its cell.
 */
@Test
void chartRetrieval() {
    Parser parser = new Parser(null, null, null);
    Parser.Chart chart = parser.new Chart(10);

    assertEquals(32, chart.mapSpan(2, 3));

    Derivation stored = new Derivation(null, null);
    chart.addDerivation(3, 5, stored);

    assertEquals(1, chart.getDerivations(3, 5).size());
    assertEquals(stored, chart.getDerivations(3, 5).get(0));
}
boolean firstCorrect = false; List<LogicalForm> lfs = parser.parse(e.input); for(LogicalForm lf: lfs){ if(first)
/**
 * Recursively evaluates the semantics of a derivation.
 *
 * <p>For a leaf (no children), the rule's semantic function is applied to an
 * empty argument list. If the rule has no semantic function, the leaf yields
 * a single value built from its right-hand side text.
 *
 * <p>For an internal node, the children's semantics are evaluated in order
 * and concatenated, and the result is passed to the rule's semantic function.
 * As a side effect, {@code d.score} is set to this rule's weight (0 when
 * unweighted) plus the children's scores. Leaf scores are not assigned here.
 *
 * @param d derivation to evaluate
 * @return the semantic maps produced for {@code d}
 */
List<Map<String, Object>> applySemantics(Derivation d) {
    SemanticFunction fn = d.rule.getSemantics();

    if (d.children == null) {
        if (fn != null) {
            return fn.apply(Collections.emptyList());
        }
        return Collections.singletonList(SemanticUtils.value(d.rule.getRHS().toString()));
    }

    String ruleKey = d.rule.toString();
    d.score = weights.containsKey(ruleKey) ? weights.get(ruleKey) : 0f;

    List<Map<String, Object>> childSemantics = new LinkedList<>();
    for (Derivation child : d.children) {
        childSemantics.addAll(applySemantics(child));
        d.score += child.score;
    }
    return fn.apply(childSemantics);
}
/**
 * Performs one training step on a single example.
 *
 * <p>One candidate parse is chosen via {@code randomCandidate}. Its features
 * are the derivation's rule features plus an indicator of 1 for each field of
 * the logical form. A weight update happens only when the hinge loss is
 * positive. The L2 penalty update runs on every step that has a candidate.
 *
 * @param e training example
 * @return the hinge loss, or 1 when the input produced no candidates
 */
@Override
public float optimize(Example e) {
    List<LogicalForm> candidates = parser.parse(e.input);
    if (candidates.isEmpty()) {
        // Nothing parsed: report a fixed loss of 1 and skip all updates.
        return 1;
    }

    LogicalForm chosen = randomCandidate(e, candidates);
    Map<String, Integer> features = chosen.getDerivation().getRuleFeatures();
    Map<String, Integer> fieldIndicators = chosen.fields().stream()
            .collect(Collectors.toMap(Function.identity(), k -> 1));
    features.putAll(fieldIndicators);

    float target = computeTarget(chosen, e);
    float loss = hingeLoss(features, target);
    if (loss > 0) {
        updateWeights(features, target, learnRate);
    }
    updateL2Penalty(learnRate, l2Penalty);
    return loss;
}