Dataset<Row> df = spark.table("testData");

df.select("key", "value");
df.select(col("key"), col("value"));
df.selectExpr("key", "value + 1");

df.sort("key", "value");
df.sort(col("key"), col("value"));
df.orderBy("key", "value");
df.orderBy(col("key"), col("value"));

df.groupBy("key", "value").agg(col("key"), col("value"), sum("value"));
df.groupBy(col("key"), col("value")).agg(col("key"), col("value"), sum("value"));
df.agg(first("key"), sum("value"));

df.groupBy().agg(countDistinct("key", "value"));
df.groupBy().agg(countDistinct(col("key"), col("value")));
df.select(coalesce(col("key")));

df2.select(exp("a"), exp("b"));
df2.select(exp(log("a")));
df2.select(pow("a", "a"), pow("b", 2.0));
df2.select(pow(col("a"), col("b")), exp("b"));
df2.select(sin("a"), acos("b"));
df2.select(rand(), acos("b"));
df2.select(col("*"), randn(5L));
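// A minimal, self-contained sketch showing one of the chains above actually evaluated.
// The two-column layout mirrors the "testData" table assumed above; the generated rows
// here are illustrative, not the suite's data.
Dataset<Row> demo = spark.range(1, 4)
    .select(col("id").as("key"), col("id").cast("string").as("value"));

demo.groupBy("key")
    .agg(countDistinct("value").as("distinct_values"))
    .orderBy("key")
    .show();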
/**
 * Returns the values for the given URI and version.
 *
 * @param uri the uri of the value set for which we get values
 * @param version the version of the value set for which we get values
 * @return a dataset of values for the given URI and version.
 */
public Dataset<Value> getValues(String uri, String version) {

  return this.values.where(col("valueseturi").equalTo(lit(uri))
      .and(col("valuesetversion").equalTo(lit(version))));
}
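// Usage sketch: the URI and version below are hypothetical, and "valueSets" stands in
// for an instance of the enclosing class; neither is defined in this excerpt.
Dataset<Value> exampleValues =
    valueSets.getValues("urn:example:valueset:conditions", "0.0.1");

exampleValues.show();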
@Test
public void testUDF() {
  UserDefinedFunction foo = udf((Integer i, String s) -> i.toString() + s, DataTypes.StringType);
  Dataset<Row> df = spark.table("testData").select(foo.apply(col("key"), col("value")));

  String[] result = df.collectAsList().stream().map(row -> row.getString(0))
      .toArray(String[]::new);
  String[] expected = spark.table("testData").collectAsList().stream()
      .map(row -> row.get(0).toString() + row.getString(1)).toArray(String[]::new);

  Assert.assertArrayEquals(expected, result);
}
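// A hedged companion sketch: the same two-argument UDF registered for use from SQL.
// The name "concatKeyValue" is illustrative; requires org.apache.spark.sql.api.java.UDF2.
spark.udf().register("concatKeyValue",
    (UDF2<Integer, String, String>) (i, s) -> i.toString() + s,
    DataTypes.StringType);

spark.sql("SELECT concatKeyValue(key, value) FROM testData").show();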
@Test
public void testSelect() {
  List<Integer> data = Arrays.asList(2, 6);
  Dataset<Integer> ds = spark.createDataset(data, Encoders.INT());

  Dataset<Tuple2<Integer, String>> selected = ds.select(
      expr("value + 1"),
      col("value").cast("string")).as(Encoders.tuple(Encoders.INT(), Encoders.STRING()));

  Assert.assertEquals(
      Arrays.asList(tuple2(3, "2"), tuple2(7, "6")),
      selected.collectAsList());
}
@Test
public void testSampleBy() {
  Dataset<Row> df = spark.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
  Dataset<Row> sampled = df.stat().sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
  List<Row> actual = sampled.groupBy("key").count().orderBy("key").collectAsList();

  Assert.assertEquals(0, actual.get(0).getLong(0));
  Assert.assertTrue(0 <= actual.get(0).getLong(1) && actual.get(0).getLong(1) <= 8);
  Assert.assertEquals(1, actual.get(1).getLong(0));
  Assert.assertTrue(2 <= actual.get(1).getLong(1) && actual.get(1).getLong(1) <= 13);
}
protected C withConceptMaps(Dataset<T> newMaps, Dataset<Mapping> newMappings) {

  Dataset<UrlAndVersion> newMembers = getUrlAndVersions(newMaps);

  // Instantiating a new composite ConceptMaps requires a new timestamp
  Timestamp timestamp = new Timestamp(System.currentTimeMillis());

  Dataset<T> newMapsWithTimestamp = newMaps
      .withColumn("timestamp", lit(timestamp.toString()).cast("timestamp"))
      .as(conceptMapEncoder);

  return newInstance(spark,
      this.members.union(newMembers),
      this.conceptMaps.union(newMapsWithTimestamp),
      this.mappings.union(newMappings));
}
/**
 * Reads a SNOMED relationship file and converts it to a {@link HierarchicalElement} dataset.
 *
 * @param spark the Spark session
 * @param snomedRelationshipPath path to the SNOMED relationship file
 * @return a dataset of {@link HierarchicalElement} representing the hierarchical relationship.
 */
public static Dataset<HierarchicalElement> readRelationshipFile(SparkSession spark,
    String snomedRelationshipPath) {

  return spark.read()
      .option("header", true)
      .option("delimiter", "\t")
      .csv(snomedRelationshipPath)
      .where(col("typeId").equalTo(lit(SNOMED_ISA_RELATIONSHIP_ID)))
      .where(col("active").equalTo(lit("1")))
      .select(col("destinationId"), col("sourceId"))
      .where(col("destinationId").isNotNull()
          .and(col("destinationId").notEqual(lit(""))))
      .where(col("sourceId").isNotNull()
          .and(col("sourceId").notEqual(lit(""))))
      .map((MapFunction<Row, HierarchicalElement>) row -> {

        HierarchicalElement element = new HierarchicalElement();

        element.setAncestorSystem(SNOMED_CODE_SYSTEM_URI);
        element.setAncestorValue(row.getString(0));
        element.setDescendantSystem(SNOMED_CODE_SYSTEM_URI);
        element.setDescendantValue(row.getString(1));

        return element;
      }, Hierarchies.getHierarchicalElementEncoder());
}
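// Usage sketch: a local session reading a hypothetical SNOMED CT relationship snapshot;
// the path below is illustrative only.
SparkSession spark = SparkSession.builder()
    .appName("snomed-hierarchy-example")
    .master("local[*]")
    .getOrCreate();

Dataset<HierarchicalElement> isaRelationships =
    readRelationshipFile(spark, "/data/snomed/sct2_Relationship_Snapshot.txt");

isaRelationships.show(5, false);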
.as("toload"); .as("present") .join( referencesToLoad, col("present.valueSetUri").equalTo(col("toload.valueSetUri")) .and(col("present.valueSetVersion").equalTo(col("toload.valueSetVersion")))) .select("referenceName", "system", "value") .collectAsList(); .join( ancestorsToLoad, col("present.uri").equalTo(col("toload.uri")) .and(col("present.version").equalTo(col("toload.version"))) .and(col("present.ancestorSystem").equalTo(col("toload.ancestorSystem"))) .and(col("present.ancestorValue").equalTo(col("toload.ancestorValue")))) .select("referenceName", "descendantSystem", "descendantValue") .collectAsList();
@Test
public void testJoin() {
  List<Integer> data = Arrays.asList(1, 2, 3);
  Dataset<Integer> ds = spark.createDataset(data, Encoders.INT()).as("a");
  List<Integer> data2 = Arrays.asList(2, 3, 4);
  Dataset<Integer> ds2 = spark.createDataset(data2, Encoders.INT()).as("b");

  Dataset<Tuple2<Integer, Integer>> joined =
      ds.joinWith(ds2, col("a.value").equalTo(col("b.value")));

  Assert.assertEquals(
      Arrays.asList(tuple2(2, 2), tuple2(3, 3)),
      joined.collectAsList());
}
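// A hedged variant of the test above: joinWith also accepts an explicit join type, so
// unmatched rows can be kept (the missing side of the pair comes back as null).
Dataset<Tuple2<Integer, Integer>> leftJoined =
    ds.joinWith(ds2, col("a.value").equalTo(col("b.value")), "left_outer");

leftJoined.show();  // e.g. (1, null), (2, 2), (3, 3)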
/**
 * Returns the latest versions of a given set of concept maps.
 *
 * @param urls a set of URLs to retrieve the latest version for, or null to load them all.
 * @param includeExperimental flag to include concept maps marked as experimental
 *
 * @return a map of concept map URLs to the latest version for them.
 */
public Map<String,String> getLatestVersions(final Set<String> urls,
    boolean includeExperimental) {

  // Reduce by the concept map URI to return only the latest version
  // per concept map. Spark's provided max aggregation function
  // only works on numeric types, so we jump into RDDs and perform
  // the reduce by hand.
  JavaRDD<UrlAndVersion> changes = this.conceptMaps
      .select(col("url"), col("version"), col("experimental"))
      .toJavaRDD()
      .filter(row -> (urls == null || urls.contains(row.getString(0)))
          && (includeExperimental || row.isNullAt(2) || !row.getBoolean(2)))
      .mapToPair(row -> new Tuple2<>(row.getString(0), row.getString(1)))
      .reduceByKey((leftVersion, rightVersion) ->
          leftVersion.compareTo(rightVersion) > 0 ? leftVersion : rightVersion)
      .map(tuple -> new UrlAndVersion(tuple._1, tuple._2));

  return this.spark.createDataset(changes.rdd(), URL_AND_VERSION_ENCODER)
      .collectAsList()
      .stream()
      .collect(Collectors.toMap(UrlAndVersion::getUrl,
          UrlAndVersion::getVersion));
}
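// Usage sketch: "conceptMaps" stands in for an instance of the enclosing class, and the
// URLs are illustrative.
Set<String> mapUrls = ImmutableSet.of(
    "http://hl7.org/fhir/ConceptMap/example-1",
    "http://hl7.org/fhir/ConceptMap/example-2");

Map<String, String> latest = conceptMaps.getLatestVersions(mapUrls, false);

latest.forEach((url, version) ->
    System.out.println(url + " -> latest version " + version));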
@Test
public void testUDAF() {
  Dataset<Row> df = hc.range(0, 100).union(hc.range(0, 100)).select(col("id").as("value"));
  UserDefinedAggregateFunction udaf = new MyDoubleSum();
  UserDefinedAggregateFunction registeredUDAF = hc.udf().register("mydoublesum", udaf);

  // Create Columns for the UDAF. For now, callUDF does not take an argument to specify
  // whether we want to use distinct aggregation.
  Dataset<Row> aggregatedDF = df.groupBy()
      .agg(
          udaf.distinct(col("value")),
          udaf.apply(col("value")),
          registeredUDAF.apply(col("value")),
          callUDF("mydoublesum", col("value")));

  List<Row> expectedResult = new ArrayList<>();
  expectedResult.add(RowFactory.create(4950.0, 9900.0, 9900.0, 9900.0));

  checkAnswer(
      aggregatedDF,
      expectedResult);
}
/**
 * Returns the value set with the given uri and version, or null if there is no such value set.
 *
 * @param uri the uri of the value set to return
 * @param version the version of the value set to return
 * @return the specified value set.
 */
public T getValueSet(String uri, String version) {

  // Load the value sets, which may contain zero items if the value set does not exist.
  // Typecast necessary to placate the Java compiler calling this Scala function.
  T[] valueSets = (T[]) this.valueSets.filter(
      col("url").equalTo(lit(uri))
          .and(col("version").equalTo(lit(version))))
      .head(1);

  if (valueSets.length == 0) {

    return null;

  } else {

    T valueSet = valueSets[0];

    Dataset<Value> filteredValues = getValues(uri, version);

    addToValueSet(valueSet, filteredValues);

    return valueSet;
  }
}
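// Usage sketch: the URI and version are hypothetical, "valueSets" stands in for an
// instance of the enclosing class, and ValueSet is assumed to be the concrete type
// bound to T.
ValueSet valueSet = valueSets.getValueSet("urn:example:valueset:conditions", "0.0.1");

if (valueSet == null) {
  System.out.println("No such value set has been loaded.");
}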
/**
 * Returns all value sets that are disjoint with value sets stored in the given database and
 * adds them to our collection. The directory may be anything readable from a Spark path,
 * including local filesystems, HDFS, S3, or others.
 *
 * @param path a path from which disjoint value sets will be loaded
 * @param database the database to check value sets against
 * @return an instance of ValueSets that includes content from that directory that is disjoint
 *     with content already contained in the given database.
 */
public C withDisjointValueSetsFromDirectory(String path, String database) {

  Dataset<UrlAndVersion> currentMembers = this.spark
      .table(database + "." + VALUE_SETS_TABLE)
      .select("url", "version")
      .distinct()
      .as(URL_AND_VERSION_ENCODER)
      .alias("current");

  Dataset<T> valueSets = valueSetDatasetFromDirectory(path)
      .alias("new")
      .join(currentMembers,
          col("new.url").equalTo(col("current.url"))
              .and(col("new.version").equalTo(col("current.version"))),
          "leftanti")
      .as(valueSetEncoder);

  return withValueSets(valueSets);
}
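// Usage sketch: the directory, database name, and ValueSets return type are assumed for
// illustration; the "leftanti" join above is what keeps only value sets not already in
// the database.
ValueSets updated = valueSets.withDisjointValueSetsFromDirectory(
    "s3a://example-bucket/fhir/valuesets",
    "ontologies");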
@Test
public void pivot() {
  Dataset<Row> df = spark.table("courseSales");
  List<Row> actual = df.groupBy("year")
      .pivot("course", Arrays.asList("dotNET", "Java"))
      .agg(sum("earnings")).orderBy("year").collectAsList();

  Assert.assertEquals(2012, actual.get(0).getInt(0));
  Assert.assertEquals(15000.0, actual.get(0).getDouble(1), 0.01);
  Assert.assertEquals(20000.0, actual.get(0).getDouble(2), 0.01);

  Assert.assertEquals(2013, actual.get(1).getInt(0));
  Assert.assertEquals(48000.0, actual.get(1).getDouble(1), 0.01);
  Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01);
}
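// A hedged variant: pivot can also infer the distinct pivot values itself, at the cost
// of an extra pass over the data to compute them.
spark.table("courseSales")
    .groupBy("year")
    .pivot("course")
    .agg(sum("earnings"))
    .orderBy("year")
    .show();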
ds.printSchema();
ds.show();

ds.where(col("from").getField("x").gt(7.0)).select(col("to")).show();

// Chain excerpts on a dataset with "name" and "points" columns; the receiver is elided
// in this excerpt.
    .where(col("points").getItem(2).getField("y").gt(7.0))
    .select(col("name"), size(col("points")).as("count")).show();

    .where(size(col("points")).gt(1))
    .select(col("name"), size(col("points")).as("count"), col("points").getItem("p1")).show();
@Test
public void saveTableAndQueryIt() {
  checkAnswer(
      df.select(avg("key").over(
          Window.partitionBy("value").orderBy("key").rowsBetween(-1, 1))),
      hc.sql("SELECT avg(key) " +
          "OVER (PARTITION BY value " +
          "  ORDER BY key " +
          "  ROWS BETWEEN 1 preceding and 1 following) " +
          "FROM window_table").collectAsList());
}