/**
 * Assembles the three-stage processing pipeline.
 *
 * <p>Stages, in order:
 * <ol>
 *   <li>a {@code Tokenizer} reading {@code "featureStrings"} and writing {@code "tokens"};</li>
 *   <li>a {@code CountVectorizer} reading {@code "tokens"} and writing {@code "features"},
 *       with its minimum document frequency and vocabulary size looked up in
 *       {@code params};</li>
 *   <li>a {@code StringIndexer} reading {@code "tag"} and writing {@code "label"}.</li>
 * </ol>
 *
 * @return a new, unfitted pipeline holding the stages above
 */
private Pipeline createPipeline() {
  Tokenizer tokenStage = new Tokenizer()
      .setInputCol("featureStrings")
      .setOutputCol("tokens");

  // Both settings come from params, falling back to the parameter's default.
  double minDocFrequency = (Double) params.getOrDefault(params.getMinFF());
  int vocabularySize = (Integer) params.getOrDefault(params.getNumFeatures());

  CountVectorizer vectorStage = new CountVectorizer()
      .setInputCol("tokens")
      .setOutputCol("features")
      .setMinDF(minDocFrequency)
      .setVocabSize(vocabularySize);

  StringIndexer labelStage = new StringIndexer()
      .setInputCol("tag")
      .setOutputCol("label");

  PipelineStage[] stages = {tokenStage, vectorStage, labelStage};
  return new Pipeline().setStages(stages);
}
/**
 * Assembles the three-stage processing pipeline.
 *
 * <p>Stages, in order:
 * <ol>
 *   <li>a {@code Tokenizer} reading {@code "text"} and writing {@code "tokens"};</li>
 *   <li>a {@code CountVectorizer} reading {@code "tokens"} and writing {@code "features"},
 *       with its minimum document frequency and vocabulary size looked up in
 *       {@code params};</li>
 *   <li>a {@code StringIndexer} reading {@code "transition"} and writing
 *       {@code "label"}.</li>
 * </ol>
 *
 * @return a new, unfitted pipeline holding the stages above
 */
protected Pipeline createPipeline() {
  Tokenizer tokenStage = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("tokens");

  // Both settings come from params, falling back to the parameter's default.
  double minDocFrequency = (Double) params.getOrDefault(params.getMinFF());
  int vocabularySize = (Integer) params.getOrDefault(params.getNumFeatures());

  CountVectorizer vectorStage = new CountVectorizer()
      .setInputCol("tokens")
      .setOutputCol("features")
      .setMinDF(minDocFrequency)
      .setVocabSize(vocabularySize);

  StringIndexer labelStage = new StringIndexer()
      .setInputCol("transition")
      .setOutputCol("label");

  PipelineStage[] stages = {tokenStage, vectorStage, labelStage};
  return new Pipeline().setStages(stages);
}
/**
 * Fits a {@code StringIndexer} on a six-row (id, label) dataset and checks the
 * index written to {@code "labelIndex"} for every row.
 *
 * <p>The input holds "a" three times, "c" twice and "b" once; the expected
 * indices are 0.0, 1.0 and 2.0 respectively, which matches numbering labels by
 * descending frequency.
 */
@Test
public void testStringIndexer() {
  StructField idField = createStructField("id", IntegerType, false);
  StructField labelField = createStructField("label", StringType, false);
  StructType inputSchema = createStructType(new StructField[]{idField, labelField});

  List<Row> inputRows = Arrays.asList(
      cr(0, "a"),
      cr(1, "b"),
      cr(2, "c"),
      cr(3, "a"),
      cr(4, "a"),
      cr(5, "c"));
  Dataset<Row> input = spark.createDataFrame(inputRows, inputSchema);

  StringIndexer labelIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex");
  Dataset<Row> indexed = labelIndexer.fit(input).transform(input);

  List<Row> expectedRows = Arrays.asList(
      cr(0, 0.0),
      cr(1, 2.0),
      cr(2, 1.0),
      cr(3, 0.0),
      cr(4, 0.0),
      cr(5, 1.0));
  // Sort by id so the comparison does not depend on the dataset's row order.
  List<Row> actualRows = indexed
      .orderBy("id")
      .select("id", "labelIndex")
      .collectAsList();

  Assert.assertEquals(expectedRows, actualRows);
}
/**
 * Verifies the indices a fitted {@code StringIndexer} assigns to a small
 * string column.
 *
 * <p>Input rows are (id, label) pairs with labels "a" (ids 0, 3, 4),
 * "b" (id 1) and "c" (ids 2, 5). The indexer reads {@code "label"} and writes
 * {@code "labelIndex"}; the test expects "a" to map to 0.0, "c" to 1.0 and "b" to 2.0.
 */
@Test
public void testStringIndexer() {
  StructField[] columns = {
      createStructField("id", IntegerType, false),
      createStructField("label", StringType, false)
  };
  StructType rowSchema = createStructType(columns);

  List<Row> rows = Arrays.asList(
      cr(0, "a"), cr(1, "b"), cr(2, "c"),
      cr(3, "a"), cr(4, "a"), cr(5, "c"));
  Dataset<Row> frame = spark.createDataFrame(rows, rowSchema);

  StringIndexer stringIndexer = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex");

  // Fit and apply on the same data, then keep only the columns under test.
  Dataset<Row> transformed = stringIndexer.fit(frame).transform(frame);
  Dataset<Row> byId = transformed.orderBy("id");
  Dataset<Row> projected = byId.select("id", "labelIndex");

  List<Row> expected = Arrays.asList(
      cr(0, 0.0), cr(1, 2.0), cr(2, 1.0),
      cr(3, 0.0), cr(4, 0.0), cr(5, 1.0));
  Assert.assertEquals(expected, projected.collectAsList());
}
/**
 * Checks {@code StringIndexer} end to end on an in-memory dataset.
 *
 * <p>Builds six (id, label) rows, fits an indexer from {@code "label"} to
 * {@code "labelIndex"}, transforms the same rows, and compares the
 * (id, labelIndex) pairs, ordered by id, against the expected values:
 * 0.0 for "a", 1.0 for "c" and 2.0 for "b".
 */
@Test
public void testStringIndexer() {
  // Two non-nullable columns: an integer id and the string label to index.
  StructType tableSchema = createStructType(
      new StructField[]{
          createStructField("id", IntegerType, false),
          createStructField("label", StringType, false)
      });

  List<Row> sourceRows = Arrays.asList(
      cr(0, "a"),
      cr(1, "b"),
      cr(2, "c"),
      cr(3, "a"),
      cr(4, "a"),
      cr(5, "c"));
  Dataset<Row> source = spark.createDataFrame(sourceRows, tableSchema);

  StringIndexer indexerUnderTest = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex");
  Dataset<Row> result = indexerUnderTest.fit(source).transform(source);

  List<Row> observed = result.orderBy("id").select("id", "labelIndex").collectAsList();
  Assert.assertEquals(
      Arrays.asList(
          cr(0, 0.0),
          cr(1, 2.0),
          cr(2, 1.0),
          cr(3, 0.0),
          cr(4, 0.0),
          cr(5, 1.0)),
      observed);
}