@Override
public void collectionProcessComplete() throws AnalysisEngineProcessException {
    try {
        // Build and configure a MALLET topic model over all collected instances.
        ParallelTopicModel topicModel = new ParallelTopicModel(nTopics, alphaSum, beta);
        topicModel.addInstances(getInstanceList());
        topicModel.setNumThreads(getNumThreads());
        topicModel.setNumIterations(nIterations);
        topicModel.setBurninPeriod(burninPeriod);
        topicModel.setOptimizeInterval(optimizeInterval);
        topicModel.setRandomSeed(randomSeed);
        topicModel.setSaveSerializedModel(saveInterval, getTargetLocation());
        topicModel.setSymmetricAlpha(useSymmetricAlpha);
        topicModel.setTopicDisplay(displayInterval, displayNTopicWords);

        // Run the sampler; this is the expensive step.
        topicModel.estimate();

        getLogger().info("Writing model to " + getTargetLocation());
        File modelFile = new File(getTargetLocation());
        File parentDir = modelFile.getParentFile();
        if (parentDir != null) {
            // Make sure the output directory exists before serializing.
            parentDir.mkdirs();
        }
        topicModel.write(modelFile);
    }
    catch (IOException | SecurityException e) {
        throw new AnalysisEngineProcessException(e);
    }
}
// NOTE(review): fragment — the enclosing method is not visible in this extract;
// the variables declared below are presumably consumed further down. TODO confirm.
ArrayList<Integer> features = new ArrayList<Integer>();
// Alphabet mapping feature (word) indices to word forms of the trained model.
Alphabet seqAlphabet = lda.getAlphabet();
int numTopics = lda.getNumTopics();
// Requesting alphabet.size() words per topic returns every word, ranked by weight.
Object[][] sorted = lda.getTopWords(seqAlphabet.size());
/**
 * Write an XML report listing the top words of every topic.
 *
 * @param out
 *            the writer the XML document is printed to (not closed here)
 * @param numWords
 *            the maximum number of words to list per topic
 */
public void topicXMLReport (PrintWriter out, int numWords) {
    ArrayList<TreeSet<IDSorter>> topicSortedWords = getSortedWords();
    out.println("<?xml version='1.0' ?>");
    out.println("<topicModel>");
    for (int topic = 0; topic < numTopics; topic++) {
        out.println(" <topic id='" + topic + "' alpha='" + alpha[topic] +
                "' totalTokens='" + tokensPerTopic[topic] + "'>");
        int word = 1;
        Iterator<IDSorter> iterator = topicSortedWords.get(topic).iterator();
        while (iterator.hasNext() && word <= numWords) {
            IDSorter info = iterator.next();
            // FIX: escape reserved XML characters in the word form; tokens such
            // as "&" or "<" previously produced malformed XML output.
            out.println(" <word rank='" + word + "'>" +
                    escapeXml(String.valueOf(alphabet.lookupObject(info.getID()))) +
                    "</word>");
            word++;
        }
        out.println(" </topic>");
    }
    out.println("</topicModel>");
}

/** Escape the five characters reserved by XML 1.0 in element content. */
private static String escapeXml(String s) {
    StringBuilder sb = new StringBuilder(s.length());
    for (int i = 0; i < s.length(); i++) {
        char c = s.charAt(i);
        switch (c) {
            case '&':  sb.append("&amp;");  break;
            case '<':  sb.append("&lt;");   break;
            case '>':  sb.append("&gt;");   break;
            case '\'': sb.append("&apos;"); break;
            case '"':  sb.append("&quot;"); break;
            default:   sb.append(c);
        }
    }
    return sb.toString();
}
/**
 * Initialize.
 *
 * @param modelFile
 *            the file containing the serialized model
 * @param nWords
 *            the number of words that should be written for each topic
 * @throws IOException
 *             if the model cannot be read
 */
public PrintTopicWordWeights(File modelFile, int nWords) throws IOException {
    this.nWords = nWords;
    try {
        model = ParallelTopicModel.read(modelFile);
    }
    catch (Exception cause) {
        // ParallelTopicModel.read declares a broad checked Exception;
        // re-wrap so callers only have to handle IOException.
        throw new IOException(cause);
    }
    alphabet = model.getAlphabet();
}
/**
 * Estimate a topic model for collaborative filtering data.
 *
 * @param <U> user type
 * @param <I> item type
 * @param preferences preference data
 * @param k number of topics
 * @param alpha alpha in model
 * @param beta beta in model
 * @param numIterations number of iterations
 * @param burninPeriod burnin period
 * @return a topic model
 * @throws IOException when internal IO error occurs
 */
public static <U, I> ParallelTopicModel estimate(FastPreferenceData<U, I> preferences, int k, double alpha, double beta, int numIterations, int burninPeriod) throws IOException {
    // MALLET takes the summed alpha over all k topics.
    ParallelTopicModel lda = new ParallelTopicModel(k, alpha * k, beta);
    lda.addInstances(new LDAInstanceList<>(preferences));
    lda.setNumIterations(numIterations);
    lda.setBurninPeriod(burninPeriod);
    // Scheduling the first display beyond the last iteration suppresses
    // per-iteration topic printing entirely.
    lda.setTopicDisplay(numIterations + 1, 0);
    lda.setNumThreads(Runtime.getRuntime().availableProcessors());
    lda.estimate();
    return lda;
}
/**
 * Train a topic model on a serialized InstanceList and dump the sampler state.
 *
 * Usage: main INSTANCE_FILE [NUM_TOPICS] [NUM_THREADS]
 */
public static void main (String[] args) {
    try {
        InstanceList training = InstanceList.load (new File(args[0]));

        int numTopics = args.length > 1 ? Integer.parseInt(args[1]) : 200;
        // FIX: args[2] was parsed unconditionally even though args[1] is
        // optional, so running with fewer than three arguments threw an
        // ArrayIndexOutOfBoundsException. Default to all available cores.
        int numThreads = args.length > 2
                ? Integer.parseInt(args[2])
                : Runtime.getRuntime().availableProcessors();

        ParallelTopicModel lda = new ParallelTopicModel (numTopics, 50.0, 0.01);
        lda.printLogLikelihood = true;
        lda.setTopicDisplay(50, 7);
        lda.addInstances(training);
        lda.setNumThreads(numThreads);

        lda.estimate();
        logger.info("printing state");
        lda.printState(new File("state.gz"));
        logger.info("finished printing");
    } catch (Exception e) {
        e.printStackTrace();
    }
}
// NOTE(review): fragment of a larger try/catch — lines are missing from this
// extract and the logger.warning(...) call below is truncated mid-expression.
topicModel = ParallelTopicModel.read(new File(inputModelFilename.value));
} catch (Exception e) {
    logger.warning("Unable to restore saved topic model " +
    // NOTE(review): the remainder of the warning message and the close of this
    // catch block are not visible here — the extract jumps to model construction.
    topicModel = new ParallelTopicModel (numTopics.value, alpha.value, beta.value);
    topicModel.setRandomSeed(randomSeed.value);
    topicModel.addInstances(training);
    // Resume Gibbs sampling from a previously saved state file.
    topicModel.initializeFromState(new File(inputStateFilename.value));
    topicModel.setTopicDisplay(showTopicsInterval.value, topWords.value);
    topicModel.setNumIterations(numIterations.value);
    topicModel.setOptimizeInterval(optimizeInterval.value);
    topicModel.setBurninPeriod(optimizeBurnIn.value);
    topicModel.setSymmetricAlpha(useSymmetricAlpha.value);
    topicModel.setSaveState(outputStateInterval.value, stateFile.value);
    topicModel.setSaveSerializedModel(outputModelInterval.value, outputModelFilename.value);
    topicModel.setNumThreads(numThreads.value);
    topicModel.estimate();
    // Post-sampling maximization pass; presumably an iterated conditional
    // modes step — TODO confirm against ParallelTopicModel.maximize docs.
    topicModel.maximize(numMaximizationIterations.value);
// NOTE(review): fragment — enclosing method and the declarations of
// `instances`, `testing`, and `topicZeroText` are not visible here.
ParallelTopicModel model = new ParallelTopicModel(numTopics, 1.0, 0.01);
model.addInstances(instances);
model.setNumThreads(2);
model.setNumIterations(50);
model.estimate();
// Inspect the token/topic assignments of the first training document.
FeatureSequence tokens = (FeatureSequence) model.getData().get(0).instance.getData();
LabelSequence topics = model.getData().get(0).topicSequence;
double[] topicDistribution = model.getTopicProbabilities(0);
ArrayList<TreeSet<IDSorter>> topicSortedWords = model.getSortedWords();
// Run the trained inferencer on a fresh (held-out) instance.
testing.addThruPipe(new Instance(topicZeroText.toString(), null, "test instance", null));
TopicInferencer inferencer = model.getInferencer();
// 10 sampling iterations, thinning 1, burn-in 5.
double[] testProbabilities = inferencer.getSampledDistribution(testing.get(0), 10, 1, 5);
System.out.println("0\t" + testProbabilities[0]);
// NOTE(review): this extract splices together several disjoint pieces of a
// larger method — `out` is closed repeatedly and re-declared after first use,
// which cannot compile as-is; intervening lines are missing.
ParallelTopicModel topicModel = new ParallelTopicModel(labeledLDA.topicAlphabet, labeledLDA.alpha * labeledLDA.numTopics, labeledLDA.beta);
// Transplant the state learned by the labeled-LDA trainer into a
// ParallelTopicModel so its reporting/serialization methods can be reused.
topicModel.data = labeledLDA.data;
topicModel.alphabet = labeledLDA.alphabet;
topicModel.numTypes = labeledLDA.numTypes;
topicModel.betaSum = labeledLDA.betaSum;
topicModel.buildInitialTypeTopicCounts();
topicModel.topicXMLReport(out, numTopWords.value);
out.close();
topicModel.topicPhraseXMLReport(out, numTopWords.value);
out.close();
topicModel.printState (new File(stateFile.value));
topicModel.printTopicDocuments(out, numTopDocs.value);
out.close();
PrintWriter out = new PrintWriter (new FileWriter ((new File(docTopicsFile.value))));
// Threshold 0.0 means "print the full dense topic distribution per document".
if (docTopicsThreshold.value == 0.0) {
    topicModel.printDenseDocumentTopics(out);
topicModel.printDocumentTopics(out, docTopicsThreshold.value, docTopicsMax.value);
topicModel.printTopicWordWeights(new File (topicWordWeightsFile.value));
topicModel.printTypeTopicCounts(new File (wordTopicCountsFile.value));
ObjectOutputStream oos =
/**
 * Train a 50-topic model on the instance list named by args[0] and write a
 * JSON summary of its top 20 words per topic to summary.json.
 */
public static void main (String[] args) throws Exception {
    InstanceList instances = InstanceList.load(new File(args[0]));

    ParallelTopicModel topicModel = new ParallelTopicModel(50, 5.0, 0.01);
    topicModel.addInstances(instances);
    topicModel.setNumIterations(100);
    topicModel.estimate();

    TopicReports reports = new JSONTopicReports(topicModel);
    reports.printSummary(new File("summary.json"), 20);
}
// NOTE(review): fragment — the enclosing method and the loop body consuming
// `iterator` are not visible in this extract.
ParallelTopicModel model = new ParallelTopicModel(numTopics, 1.0, 0.01);
model.addInstances(instances);
model.setNumThreads(1); // Important, since this is being run in the reduce, just use one thread
// Disable progress display entirely.
model.setTopicDisplay(0,0);
model.setNumIterations(2000);
model.estimate();
ArrayList<TopicAssignment> assignments = model.getData();
// Walk each topic's words in descending weight order.
for (int topicNum = 0; topicNum < model.getNumTopics(); topicNum++) {
    TreeSet<IDSorter> sortedWords = model.getSortedWords().get(topicNum);
    Iterator<IDSorter> iterator = sortedWords.iterator();
@Override
protected void execute(JobSettings settings) throws AnalysisEngineProcessException {
    // Pull the corpus from Mongo and push it through the topic-model pipe.
    InstanceList instances = new InstanceList(new TopicModelPipe(stopwords));
    instances.addThruPipe(getDocumentsFromMongo());

    ParallelTopicModel topicModel = new ParallelTopicModel(numTopics, 1.0, 0.01);
    topicModel.setNumThreads(numThreads);
    topicModel.setNumIterations(numIterations);
    topicModel.addInstances(instances);

    try {
        topicModel.estimate();
    } catch (IOException e) {
        getMonitor().warn("Couldn't estimate topic model");
        throw new AnalysisEngineProcessException(e);
    }

    File serializedModelFile = new File(modelFile);
    try {
        // Ensure the parent directory exists, then persist the model and
        // write the per-document topic assignments back to Mongo.
        Files.createDirectories(serializedModelFile.toPath().getParent());
        topicModel.write(serializedModelFile);
        writeTopicAssignmentsToMongo(instances, new TopicWords(topicModel), topicModel);
    } catch (IOException e) {
        throw new AnalysisEngineProcessException("Error writing model", new Object[0], e);
    }
}
// NOTE(review): fragment — this extract begins mid-expression (the tail of a
// log-message concatenation) and ends mid-method; surrounding lines are missing.
+ optimizationInterval);
ParallelTopicModel malletParallelModel = new ParallelTopicModel(numTopics, alphaSum, beta);
Model model = new Model();
try {
    LOGGER.info("Start preprocessing");
    malletParallelModel.addInstances(instances);
    malletParallelModel.setNumThreads(numThreads);
    malletParallelModel.setNumIterations(numIterations);
    malletParallelModel.setOptimizeInterval(optimizationInterval);
    LOGGER.info("Start training");
    malletParallelModel.estimate();
    // Wrap the trained MALLET model in the application's own Model holder.
    model.malletModel = malletParallelModel;
    model.modelId = modelId;
// NOTE(review): garbled extraction — several variables used below (doclen,
// feature, topic, prevtopic, sb, withBigrams) are declared on lines missing
// from this extract, and the pieces shown are disjoint. Code left untouched.
public void topicPhraseXMLReport(PrintWriter out, int numWords) {
    int numTopics = this.getNumTopics();
    // One phrase-count map per topic (raw generic array: unavoidable with trove).
    gnu.trove.TObjectIntHashMap<String>[] phrases = new gnu.trove.TObjectIntHashMap[numTopics];
    Alphabet alphabet = this.getAlphabet();
    // Scan every document's token/topic sequence to collect adjacent same-topic
    // tokens as candidate phrases.
    for (int di = 0; di < this.getData().size(); di++) {
        TopicAssignment t = this.getData().get(di);
        Instance instance = t.instance;
        FeatureSequence fvs = (FeatureSequence) instance.getData();
        for (int pi = 0; pi < doclen; pi++) {
            feature = fvs.getIndexAtPosition(pi);
            topic = this.getData().get(di).topicSequence.getIndexAtPosition(pi);
            // Extend the current phrase only while the topic stays the same and,
            // when bigram info is present, the bigram link exists.
            if (topic == prevtopic && (!withBigrams || ((FeatureSequenceWithBigrams)fvs).getBiIndexAtPosition(pi) != -1)) {
                if (sb == null)
    out.println("<topics>");
    ArrayList<TreeSet<IDSorter>> topicSortedWords = getSortedWords();
    double[] probs = new double[alphabet.size()];
    for (int ti = 0; ti < numTopics; ti++) {
// NOTE(review): fragment — the rest of this constructor is not visible here.
public TopicModelDiagnostics (ParallelTopicModel model, int numTopWords) {
    numTopics = model.getNumTopics();
    this.numTopWords = numTopWords;
    alphabet = model.getAlphabet();
    // Cache the per-topic word rankings once; later diagnostics reuse them.
    topicSortedWords = model.getSortedWords();
// NOTE(review): fragment — the try block is not closed in this extract and the
// loop body filling `topicWords` is missing.
ParallelTopicModel model;
try {
    model = ParallelTopicModel.read(modelFile);
    Alphabet alphabet = model.getAlphabet();
    // One word->weight map per topic.
    List<Map<String, Double>> topics = new ArrayList<>(model.getNumTopics());
    for (TreeSet<IDSorter> topic : model.getSortedWords()) {
        Map<String, Double> topicWords = new HashMap<>(nWords);
@Override
public void initialize(UimaContext context) throws ResourceInitializationException {
    super.initialize(context);

    // Load the serialized topic model up front; failure is fatal for this engine.
    ParallelTopicModel topicModel;
    try {
        getLogger().info("Loading model file " + modelLocation);
        topicModel = ParallelTopicModel.read(modelLocation);
        if (maxTopicAssignments <= 0) {
            // Default: assign at most a tenth of the available topics.
            maxTopicAssignments = topicModel.getNumTopics() / 10;
        }
    } catch (Exception e) {
        throw new ResourceInitializationException(e);
    }
    getLogger().info("Model loaded.");

    inferencer = topicModel.getInferencer();
    // Reuse the model's alphabet so inference sees the training vocabulary.
    malletPipe = new TokenSequence2FeatureSequence(topicModel.getAlphabet());

    try {
        sequenceGenerator = new PhraseSequenceGenerator.Builder()
                .featurePath(tokenFeaturePath)
                .minTokenLength(minTokenLength)
                .lowercase(lowercase)
                .buildStringSequenceGenerator();
    } catch (IOException e) {
        throw new ResourceInitializationException(e);
    }
}
/**
 * Look up the highest-weighted words for a topic.
 *
 * @param topic the topic index
 * @param number the number of keywords required
 * @return up to {@code number} key words for the topic, best first
 */
public List<String> forTopic(int topic, int number) {
    return model
            .getSortedWords()
            .get(topic)
            .stream()
            // FIX: truncate before the per-word maps; previously every word in
            // the topic was looked up in the alphabet only to be discarded by
            // limit(). Order-preserving, so the result is identical.
            .limit(number)
            .map(IDSorter::getID)
            .map(model.getAlphabet()::lookupObject)
            .map(Object::toString)
            .collect(Collectors.toList());
}
}
/**
 * Print the top n words of each topic into a file.
 *
 * @param modelFile
 *            the model file
 * @param targetFile
 *            the file in which the topic words are written
 * @param nWords
 *            the number of words to extract
 * @throws IOException
 *             if the model file cannot be read or if the target file cannot be written
 */
public static void printTopicWords(File modelFile, File targetFile, int nWords) throws IOException {
    ParallelTopicModel model;
    try {
        model = ParallelTopicModel.read(modelFile);
    }
    catch (Exception cause) {
        // read() declares a broad checked Exception; surface it as IOException.
        throw new IOException(cause);
    }
    // Keep all of a topic's words on one line.
    boolean newLineAfterEachWord = false;
    model.printTopWords(targetFile, nWords, newLineAfterEachWord);
}
// NOTE(review): fragment — `model` and `testing` are declared on lines not
// visible in this extract.
TopicInferencer inferencer = model.getInferencer();
// Sample a topic distribution for the first held-out instance:
// 10 iterations, thinning 1, burn-in 5.
double[] testProbabilities = inferencer
        .getSampledDistribution(testing.get(0), 10, 1, 5);