/**
 * Builds one factor table per position of {@code data} and calibrates them with
 * a front-to-back sweep followed by a back-to-front sweep.
 * This function assumes a LinearCliquePotentialFunction is used for wrapping the weights.
 *
 * @param weights          passed through to {@code getFactorTable}
 * @param wscale           passed through to {@code getFactorTable}
 * @param weightIndices    passed through to {@code getFactorTable}
 * @param data             per-position data; one factor table is built for each entry
 * @param labelIndices     passed through to {@code getFactorTable}
 * @param numClasses       passed through to {@code getFactorTable}
 * @param classIndex       handed to the resulting clique tree
 * @param backgroundSymbol handed to the resulting clique tree
 * @return a new CRFCliqueTree for the weights on the data
 */
public static <E> CRFCliqueTree<E> getCalibratedCliqueTree(double[] weights, double wscale, int[][] weightIndices,
    int[][][] data, List<Index<CRFLabel>> labelIndices, int numClasses, Index<E> classIndex, E backgroundSymbol) {
  final int numPositions = data.length;
  FactorTable[] tables = new FactorTable[numPositions];
  // One message between each pair of neighbouring tables.
  FactorTable[] forwardMessages = new FactorTable[numPositions - 1];

  // Front-to-back sweep: build each table, then fold in the message from its predecessor.
  for (int pos = 0; pos < numPositions; pos++) {
    tables[pos] = getFactorTable(weights, wscale, weightIndices, data[pos], labelIndices, numClasses);
    if (pos == 0) {
      continue;
    }
    FactorTable incoming = tables[pos - 1].sumOutFront();
    forwardMessages[pos - 1] = incoming;
    tables[pos].multiplyInFront(incoming);
  }

  // Back-to-front sweep: the forward message is divided out of each table's
  // summed-out end before it is multiplied into the preceding table.
  for (int pos = numPositions - 1; pos > 0; pos--) {
    FactorTable backward = tables[pos].sumOutEnd();
    backward.divideBy(forwardMessages[pos - 1]);
    tables[pos - 1].multiplyInEnd(backward);
  }

  return new CRFCliqueTree<>(tables, classIndex, backgroundSymbol);
}
public TestSequenceModel(CRFCliqueTree<? extends CharSequence> cliqueTree, LabelDictionary labelDictionary, List<? extends CoreMap> document) { // this.factorTables = factorTables; this.cliqueTree = cliqueTree; // this.window = factorTables[0].windowSize(); this.window = cliqueTree.window(); // this.numClasses = factorTables[0].numClasses(); int numClasses = cliqueTree.getNumClasses(); this.backgroundTag = new int[] { cliqueTree.backgroundIndex() }; allTags = new int[numClasses]; for (int i = 0; i < allTags.length; i++) { allTags[i] = i; } if (labelDictionary != null) { // Constrained allowedTagsAtPosition = new int[document.size()][]; for (int i = 0; i < allowedTagsAtPosition.length; ++i) { CoreMap token = document.get(i); String observation = token.get(CoreAnnotations.TextAnnotation.class); allowedTagsAtPosition[i] = labelDictionary.isConstrained(observation) ? labelDictionary.getConstrainedSet(observation) : allTags; } } else { allowedTagsAtPosition = null; } }
/**
 * Exponentiates the value of {@code condLogProbGivenNext} for the same arguments.
 *
 * @param position   the position of {@code label}
 * @param label      the label being scored
 * @param nextLabels the labels being conditioned on
 * @return {@code Math.exp} of the corresponding log value
 */
public double condProbGivenNext(int position, E label, E[] nextLabels) {
  final double logValue = condLogProbGivenNext(position, label, nextLabels);
  return Math.exp(logValue);
}
/**
 * Returns the log probability for the given labels, where the last label
 * corresponds to the label at the specified position. For instance, calling
 * logProb(5, {"O", "PER", "ORG"}) returns the marginal log prob that the label
 * at position 3 is "O", the label at position 4 is "PER" and the label at
 * position 5 is "ORG".
 *
 * @param position the position of the last label
 * @param labels   the labels, converted with {@code objectArrayToIntArray} before delegating
 * @return the value of the int-array overload of {@code logProb}
 */
public double logProb(int position, E[] labels) {
  final int[] labelIndices = objectArrayToIntArray(labels);
  return logProb(position, labelIndices);
}
/**
 * Object-typed convenience overload: maps {@code label} through {@code classIndex}
 * and {@code nextLabels} through {@code objectArrayToIntArray}, then delegates to
 * the int-based overload.
 *
 * @param position   the position of {@code label}
 * @param label      the label being scored
 * @param nextLabels the labels being conditioned on
 * @return the value returned by the int-based overload
 */
public double condLogProbGivenNext(int position, E label, E[] nextLabels) {
  final int labelIndex = classIndex.indexOf(label);
  final int[] nextIndices = objectArrayToIntArray(nextLabels);
  return condLogProbGivenNext(position, labelIndex, nextIndices);
}
/**
 * Object-typed convenience overload: maps {@code label} through {@code classIndex}
 * and {@code prevLabels} through {@code objectArrayToIntArray}, then delegates to
 * the int-based overload.
 *
 * @param position   the position of {@code label}
 * @param label      the label being scored
 * @param prevLabels the labels being conditioned on
 * @return the value returned by the int-based overload
 */
public double condLogProbGivenPrevious(int position, E label, E[] prevLabels) {
  final int labelIndex = classIndex.indexOf(label);
  final int[] prevIndices = objectArrayToIntArray(prevLabels);
  return condLogProbGivenPrevious(position, labelIndex, prevIndices);
}
/**
 * Returns the log probability of this sequence given the CRF. Does so by
 * summing, position by position, the conditional log probability of each label
 * given the previous {@code window() - 1} labels. The history starts out filled
 * with the index of the background symbol.
 *
 * @param sequence The sequence to compute a score for; must hold at least
 *                 {@code length()} entries
 * @return the score for the sequence
 */
@Override
public double scoreOf(int[] sequence) {
  int[] given = new int[window() - 1];
  Arrays.fill(given, classIndex.indexOf(backgroundSymbol));
  double logProb = 0.0;
  for (int i = 0, length = length(); i < length; i++) {
    int label = sequence[i];
    logProb += condLogProbGivenPrevious(i, label, given);
    // With a window of 1 there is no history to maintain. Without this guard
    // System.arraycopy would be called with a negative length and given[-1]
    // would be written, both of which throw.
    if (given.length > 0) {
      // Slide the history one step left and append the label just scored.
      System.arraycopy(given, 1, given, 0, given.length - 1);
      given[given.length - 1] = label;
    }
  }
  return logProb;
}
CRFCliqueTree<String> cliqueTree = CRFCliqueTree.getCalibratedCliqueTree(docData, labelIndices, numClasses, classIndex, backgroundSymbol, cliquePotentialFunc, null); double p = cliqueTree.condLogProbGivenPrevious(i, label, given); if (VERBOSE) { log.info("P(" + label + "|" + ArrayMath.toString(given) + ")=" + p); double p = cliqueTree.prob(i, label); // probability of these labels occurring in this clique with these features for (int lopIter = 0; lopIter < numLopExpert; lopIter++) { Set<Integer> indicesSet = featureIndicesSetArray.get(lopIter);
CRFCliqueTree<String> cliqueTree = CRFCliqueTree.getCalibratedCliqueTree(docData, labelIndices, numClasses, classIndex, backgroundSymbol, cliquePotentialFunc, featureVal3DArr); double startPosLogProb = cliqueTree.logProbStartPos(); if (VERBOSE) System.err.printf("P_-1(Background) = % 5.3f\n", startPosLogProb); double p = cliqueTree.condLogProbGivenPrevious(i, label, given); if (VERBOSE) { log.info("P(" + label + "|" + ArrayMath.toString(given) + ")=" + Math.exp(p)); double p = cliqueTree.prob(i, label); // probability of these labels occurring in this clique with these features if (dropoutApprox) increScore(EForADocPosAtI, fIndex, k, fVal * p);
/**
 * For every position of the document's clique tree, collects the value of
 * {@code cliqueTree.prob(position, classIndex)} for each label in {@code classIndex}.
 *
 * @param document the document to score
 * @return one counter per position, mapping each label to its value at that position
 */
public List<Counter<String>> zeroOrderProbabilities(List<IN> document) {
  Triple<int[][][], int[], double[][][]> dataAndLabels = documentToDataAndLabels(document);
  CRFCliqueTree<String> tree = getCliqueTree(dataAndLabels);

  List<Counter<String>> perPosition = new ArrayList<>();
  for (int pos = 0; pos < tree.length(); pos++) {
    Counter<String> labelProbs = new ClassicCounter<>();
    for (String label : classIndex) {
      labelProbs.setCount(label, tree.prob(pos, classIndex.indexOf(label)));
    }
    perPosition.add(labelProbs);
  }
  return perPosition;
}
/**
 * Returns the probability for the given labels (indexed using classIndex),
 * where the last label corresponds to the label at the specified position.
 * For instance, calling prob(5, {1,2,3}) returns the marginal prob that the
 * label at position 3 is 1, the label at position 4 is 2 and the label at
 * position 5 is 3.
 *
 * @param position the position of the last label
 * @param labels   the label indices
 * @return {@code Math.exp} of {@code logProb(position, labels)}
 */
public double prob(int position, int[] labels) {
  final double logValue = logProb(position, labels);
  return Math.exp(logValue);
}
/**
 * Return the score of the proposed tags for position given.
 * The {@code window - 1} tags immediately before {@code pos} are used as the
 * conditioning history; the clique tree is queried at {@code pos - window + 1}.
 *
 * @param tags is an array indicating the assignment of labels to score.
 * @param pos is the position to return a score for.
 * @return the clique tree's conditional log probability of {@code tags[pos]} given that history
 */
@Override
public double scoreOf(int[] tags, int pos) {
  final int[] history = new int[window - 1];
  final int realPos = pos - window + 1;
  // Copy tags[realPos .. pos-1] into the history buffer.
  for (int src = realPos, dst = 0; dst < window - 1; src++, dst++) {
    history[dst] = tags[src];
  }
  return cliqueTree.condLogProbGivenPrevious(realPos, tags[pos], history);
}
/** * Takes a {@link List} of something that extends {@link CoreMap} and prints the likelihood * of each possible label at each point. * * @param document A {@link List} of something that extends CoreMap. */ @Override public void printProbsDocument(List<IN> document) { Pair<int[][][],int[]> p = documentToDataAndLabels(document); int[][][] data = p.first(); //FactorTable[] factorTables = CRFLogConditionalObjectiveFunction.getCalibratedCliqueTree(weights, data, labelIndices, classIndex.size()); CRFCliqueTree cliqueTree = CRFCliqueTree.getCalibratedCliqueTree(weights, data, labelIndices, classIndex.size(), classIndex, flags.backgroundSymbol); // for (int i = 0; i < factorTables.length; i++) { for (int i = 0; i < cliqueTree.length(); i++) { IN wi = document.get(i); System.out.print(wi.get(CoreAnnotations.TextAnnotation.class) + '\t'); for (Iterator<String> iter = classIndex.iterator(); iter.hasNext();) { String label = iter.next(); int index = classIndex.indexOf(label); // double prob = Math.pow(Math.E, factorTables[i].logProbEnd(index)); double prob = cliqueTree.prob(i, index); System.out.print(label + '=' + prob); if (iter.hasNext()) { System.out.print("\t"); } else { System.out.print("\n"); } } } }
@Override protected double expectedAndEmpiricalCountsAndValueForADoc(double[][] E, double[][] Ehat, int docIndex) { int[][][] docData = data[docIndex]; double[][][] featureVal3DArr = null; if (featureVal != null) { featureVal3DArr = featureVal[docIndex]; } // make a clique tree for this document CRFCliqueTree<String> cliqueTreeNoisyLabel = CRFCliqueTree.getCalibratedCliqueTree(docData, labelIndices, numClasses, classIndex, backgroundSymbol, getFunc(docIndex), featureVal3DArr); CRFCliqueTree<String> cliqueTree = CRFCliqueTree.getCalibratedCliqueTree(docData, labelIndices, numClasses, classIndex, backgroundSymbol, cliquePotentialFunc, featureVal3DArr); double prob = cliqueTreeNoisyLabel.totalMass() - cliqueTree.totalMass(); documentExpectedCounts(E, docData, featureVal3DArr, cliqueTree); documentExpectedCounts(Ehat, docData, featureVal3DArr, cliqueTreeNoisyLabel); return prob; }
protected double expectedCountsAndValueForADoc(double[][] E, int docIndex, boolean doExpectedCountCalc, boolean doValueCalc) { int[][][] docData = data[docIndex]; double[][][] featureVal3DArr = null; if (featureVal != null) { featureVal3DArr = featureVal[docIndex]; } // make a clique tree for this document CRFCliqueTree<String> cliqueTree = CRFCliqueTree.getCalibratedCliqueTree(docData, labelIndices, numClasses, classIndex, backgroundSymbol, cliquePotentialFunc, featureVal3DArr); double prob = 0.0; if (doValueCalc) { prob = documentLogProbability(docData, docIndex, cliqueTree); } if (doExpectedCountCalc) { documentExpectedCounts(E, docData, featureVal3DArr, cliqueTree); } return prob; }
/**
 * @return the length reported by the underlying clique tree
 */
@Override
public int length() {
  final int treeLength = this.cliqueTree.length();
  return treeLength;
}
public GeneralizedCounter<E> logProbs(int position, int window) { GeneralizedCounter<E> gc = new GeneralizedCounter<>(window); int[] labels = new int[window]; // cdm july 2005: below array initialization isn't necessary: JLS (3rd ed.) // 4.12.5 // Arrays.fill(labels, 0); OUTER: while (true) { List<E> labelsList = intArrayToListE(labels); gc.incrementCount(labelsList, logProb(position, labels)); for (int i = 0; i < labels.length; i++) { labels[i]++; if (labels[i] < numClasses) { break; } if (i == labels.length - 1) { break OUTER; } labels[i] = 0; } } return gc; }
double startPosLogProb = cliqueTree.logProbStartPos(); if (VERBOSE) { System.err.printf("P_-1(Background) = % 5.3f%n", startPosLogProb); double p = cliqueTree.condLogProbGivenPrevious(i, label, given); if (VERBOSE) { log.info("P(" + label + "|" + ArrayMath.toString(given) + ")=" + p);
/**
 * Takes a {@link List} of something that extends {@link CoreMap} and prints
 * the factor table at each point. One line is written per position: the token
 * text, a tab, then a space-terminated {@code first:second:logProb} triple for
 * every entry of that position's factor table.
 * Only the first two label indices of each entry are printed.
 *
 * @param document A {@link List} of something that extends {@link CoreMap}.
 */
@SuppressWarnings("WeakerAccess")
public void printFactorTableDocument(List<IN> document) {
  CRFCliqueTree<String> tree = getCliqueTree(document);
  FactorTable[] tables = tree.getFactorTables();

  StringBuilder out = new StringBuilder();
  for (int pos = 0; pos < tables.length; pos++) {
    IN token = document.get(pos);
    out.append(token.get(CoreAnnotations.TextAnnotation.class)).append('\t');

    FactorTable table = tables[pos];
    for (int entry = 0; entry < table.size(); entry++) {
      int[] labelIdx = table.toArray(entry);
      out.append(classIndex.get(labelIdx[0])).append(':');
      out.append(classIndex.get(labelIdx[1])).append(':');
      out.append(tree.logProb(pos, labelIdx)).append(' ');
    }
    out.append('\n');
  }
  System.out.print(out);
}