@Override protected double[] trainWeights(int[][][][] data, int[][] labels, Evaluator[] evaluators, int pruneFeatureItr, double[][][][] featureVals) { CRFLogConditionalObjectiveFloatFunction func = new CRFLogConditionalObjectiveFloatFunction(data, labels, windowSize, classIndex, labelIndices, map, flags.backgroundSymbol, flags.sigma); cliquePotentialFunctionHelper = func; initialWeights = func.initial(); } else { try {
/**
 * Flattens a two-dimensional weight table into a single array.
 *
 * <p>The rows of {@code weights} are copied one after another, in order, into a new array whose
 * length is {@code domainDimension()}.
 *
 * @param weights the rows to concatenate
 * @return a newly allocated array holding the rows of {@code weights} back to back
 */
public float[] to1D(float[][] weights) {
  final float[] flat = new float[domainDimension()];
  int offset = 0;
  for (int row = 0; row < weights.length; row++) {
    final float[] source = weights[row];
    final int length = source.length;
    // Each row lands directly after the previous one.
    System.arraycopy(source, 0, flat, offset, length);
    offset += length;
  }
  return flat;
}
/**
 * Creates the objective function over the given training data and immediately computes the
 * empirical counts via {@code empiricalCounts(data, labels)}.
 *
 * @param data training feature data, stored as given
 * @param labels training labels, stored as given
 * @param window label window size, stored as given
 * @param classIndex index of class names; its size becomes {@code numClasses}
 * @param labelIndices per-clique label indices, stored as given
 * @param map stored as given
 * @param prior stored as given
 * @param backgroundSymbol background class symbol, stored as given
 * @param sigma cast to {@code float} before being stored
 */
CRFLogConditionalObjectiveFloatFunction(int[][][][] data, int[][] labels, int window, Index<String> classIndex, List<Index<CRFLabel>> labelIndices, int[] map, int prior, String backgroundSymbol, double sigma) {
  // Training set.
  this.data = data;
  this.labels = labels;

  // Label space.
  this.classIndex = classIndex;
  this.numClasses = classIndex.size();
  this.labelIndices = labelIndices;
  this.backgroundSymbol = backgroundSymbol;

  // Remaining configuration.
  this.window = window;
  this.map = map;
  this.prior = prior;
  this.sigma = (float) sigma;

  // Runs last: it reads fields assigned above (window, classIndex, backgroundSymbol, labelIndices).
  empiricalCounts(data, labels);
}
// Setup for one document: to2D(x) reshapes x into a 2D weight table, prob starts at 0,
// E is a fresh table from empty2D(), and getCalibratedCliqueTree builds the factor tables
// for data[m] from those weights.
// NOTE(review): the enclosing method and the definitions of x and m are outside this view;
// presumably m is the document index of a surrounding loop — confirm.
float[][] weights = to2D(x); float prob = 0; float[][] E = empty2D(); FloatFactorTable[] factorTables = getCalibratedCliqueTree(weights, data[m], labelIndices, numClasses);
// Calls calculateWeird(x) and then returns immediately.
// NOTE(review): as written, every statement after `return;` (to2D, empty2D,
// getCalibratedCliqueTree) is unreachable; if these really share one method body, javac
// rejects it as an unreachable statement. This looks like a leftover debugging
// short-circuit or two merged snippets — confirm against the enclosing method, which is
// outside this view.
calculateWeird(x); return; float[][] weights = to2D(x); float prob = 0; float[][] E = empty2D(); FloatFactorTable[] factorTables = getCalibratedCliqueTree(weights, data[m], labelIndices, numClasses);
CRFLogConditionalObjectiveFloatFunction func = new CRFLogConditionalObjectiveFloatFunction(data, labels, featureIndex, windowSize, classIndex, labelIndices, map, flags.backgroundSymbol, flags.sigma); func.crfType = flags.crfType; initialWeights = func.initial(); } else { try { this.weights = ArrayMath.floatArrayToDoubleArray(func.to2D(weights));
private void empiricalCounts(int[][][][] data, int[][] labels) { Ehat = empty2D(); for (int m = 0; m < data.length; m++) { int[][][] dataDoc = data[m]; int[] labelsDoc = labels[m]; int[] label = new int[window]; //Arrays.fill(label, classIndex.indexOf("O")); Arrays.fill(label, classIndex.indexOf(backgroundSymbol)); for (int i = 0; i < dataDoc.length; i++) { System.arraycopy(label, 1, label, 0, window - 1); label[window - 1] = labelsDoc[i]; for (int j = 0; j < dataDoc[i].length; j++) { int[] cliqueLabel = new int[j + 1]; System.arraycopy(label, window - 1 - j, cliqueLabel, 0, j + 1); CRFLabel crfLabel = new CRFLabel(cliqueLabel); int labelIndex = labelIndices.get(j).indexOf(crfLabel); //log.info(crfLabel + " " + labelIndex); for (int k = 0; k < dataDoc[i][j].length; k++) { Ehat[dataDoc[i][j][k]][labelIndex]++; } } } } }
factorTables[i] = getFloatFactorTable(weights, data[i], labelIndices, numClasses); if (VERBOSE) { log.info(i + ": " + factorTables[i]);
// Setup for one document: to2D(x) reshapes x into a 2D weight table, prob starts at 0,
// E is a fresh table from empty2D(), and getCalibratedCliqueTree builds the factor tables
// for data[m] from those weights.
// NOTE(review): the enclosing method and the definitions of x and m are outside this view;
// presumably m is the document index of a surrounding loop — confirm.
float[][] weights = to2D(x); float prob = 0; float[][] E = empty2D(); FloatFactorTable[] factorTables = getCalibratedCliqueTree(weights, data[m], labelIndices, numClasses);
private void empiricalCounts(int[][][][] data, int[][] labels) { Ehat = empty2D(); for (int m = 0; m < data.length; m++) { int[][][] dataDoc = data[m]; int[] labelsDoc = labels[m]; int[] label = new int[window]; //Arrays.fill(label, classIndex.indexOf("O")); Arrays.fill(label, classIndex.indexOf(backgroundSymbol)); for (int i = 0; i < dataDoc.length; i++) { System.arraycopy(label, 1, label, 0, window - 1); label[window - 1] = labelsDoc[i]; for (int j = 0; j < dataDoc[i].length; j++) { int[] cliqueLabel = new int[j + 1]; System.arraycopy(label, window - 1 - j, cliqueLabel, 0, j + 1); CRFLabel crfLabel = new CRFLabel(cliqueLabel); int labelIndex = labelIndices.get(j).indexOf(crfLabel); //log.info(crfLabel + " " + labelIndex); for (int k = 0; k < dataDoc[i][j].length; k++) { Ehat[dataDoc[i][j][k]][labelIndex]++; } } } } }
factorTables[i] = getFloatFactorTable(weights, data[i], labelIndices, numClasses); if (VERBOSE) { System.err.println(i + ": " + factorTables[i]);
// Setup for one document: to2D(x) reshapes x into a 2D weight table, prob starts at 0,
// E is a fresh table from empty2D(), and getCalibratedCliqueTree builds the factor tables
// for data[m] from those weights.
// NOTE(review): the enclosing method and the definitions of x and m are outside this view;
// presumably m is the document index of a surrounding loop — confirm.
float[][] weights = to2D(x); float prob = 0; float[][] E = empty2D(); FloatFactorTable[] factorTables = getCalibratedCliqueTree(weights, data[m], labelIndices, numClasses);
@Override protected double[] trainWeights(int[][][][] data, int[][] labels, Evaluator[] evaluators, int pruneFeatureItr, double[][][][] featureVals) { CRFLogConditionalObjectiveFloatFunction func = new CRFLogConditionalObjectiveFloatFunction(data, labels, featureIndex, windowSize, classIndex, labelIndices, map, flags.backgroundSymbol, flags.sigma); cliquePotentialFunctionHelper = func; initialWeights = func.initial(); } else { try {
/**
 * Concatenates the rows of a 2D weight table into one flat array.
 *
 * <p>The result has length {@code domainDimension()}; rows are written consecutively in the
 * order they appear in {@code weights}.
 *
 * @param weights the rows to lay out end to end
 * @return a new array containing all rows of {@code weights} in sequence
 */
public float[] to1D(float[][] weights) {
  final float[] result = new float[domainDimension()];
  int written = 0;
  int rowIndex = 0;
  while (rowIndex < weights.length) {
    final float[] row = weights[rowIndex];
    final int count = row.length;
    // Append this row right after everything written so far.
    System.arraycopy(row, 0, result, written, count);
    written += count;
    rowIndex++;
  }
  return result;
}
/**
 * Stores the training data and configuration, then computes the empirical counts by calling
 * {@code empiricalCounts(data, labels)}.
 *
 * @param data training feature data, stored as given
 * @param labels training labels, stored as given
 * @param window label window size, stored as given
 * @param classIndex index of class names; {@code numClasses} is set to its size
 * @param labelIndices per-clique label indices, stored as given
 * @param map stored as given
 * @param prior stored as given
 * @param backgroundSymbol background class symbol, stored as given
 * @param sigma stored after a cast to {@code float}
 */
CRFLogConditionalObjectiveFloatFunction(int[][][][] data, int[][] labels, int window, Index<String> classIndex, List<Index<CRFLabel>> labelIndices, int[] map, int prior, String backgroundSymbol, double sigma) {
  // Configuration values.
  this.window = window;
  this.map = map;
  this.prior = prior;
  this.sigma = (float) sigma;
  this.backgroundSymbol = backgroundSymbol;

  // Label space.
  this.classIndex = classIndex;
  this.numClasses = classIndex.size();
  this.labelIndices = labelIndices;

  // Training set.
  this.data = data;
  this.labels = labels;

  // Kept as the final step because it reads the fields initialised above.
  empiricalCounts(data, labels);
}
private void empiricalCounts(int[][][][] data, int[][] labels) { Ehat = empty2D(); for (int m = 0; m < data.length; m++) { int[][][] dataDoc = data[m]; int[] labelsDoc = labels[m]; int[] label = new int[window]; //Arrays.fill(label, classIndex.indexOf("O")); Arrays.fill(label, classIndex.indexOf(backgroundSymbol)); for (int i = 0; i < dataDoc.length; i++) { System.arraycopy(label, 1, label, 0, window - 1); label[window - 1] = labelsDoc[i]; for (int j = 0; j < dataDoc[i].length; j++) { int[] cliqueLabel = new int[j + 1]; System.arraycopy(label, window - 1 - j, cliqueLabel, 0, j + 1); CRFLabel crfLabel = new CRFLabel(cliqueLabel); int labelIndex = labelIndices.get(j).indexOf(crfLabel); //System.err.println(crfLabel + " " + labelIndex); for (int k = 0; k < dataDoc[i][j].length; k++) { Ehat[dataDoc[i][j][k]][labelIndex]++; } } } } }
factorTables[i] = getFloatFactorTable(weights, data[i], labelIndices, numClasses); if (VERBOSE) { System.err.println(i + ": " + factorTables[i]);