Type safety on Spark DataFrames - Part 1

At 51zero we've been using Apache Spark since early 2014 to deliver Big Data solutions for our clients. We're huge fans of the framework and we've enjoyed watching its development, as well as contributing to the project.

Spark 1.3 introduced the DataFrame API in March 2015. In some respects this was a great improvement over the RDD API, as Spark could now optimize your queries thanks to the Catalyst optimizer.

However, those of us who like to program in Scala because of its strong typing were left with mixed feelings.

For instance, do you know what this function does to your DataFrame?

def computeStuff(df: DataFrame): DataFrame

Without good documentation, it is impossible to know:

  • which columns are required in the input DataFrame?
  • which columns are added to the output DataFrame?
  • what are the types of the input and output columns: are they String, Double, Int?

If you have a non-trivial program which composes several such transformations, it becomes tricky to follow what is going on.

Without proper unit testing, your program becomes brittle and breaks with simple changes.
You start to feel as if you were using some kind of dynamic language. This can be beneficial in some situations, but then why would you use Scala for that?

A good API should let you manipulate data flexibly when you do not know its structure in advance, but it should allow you to put strong constraints when you do know its structure.

For instance, a program which executes a parametrized SQL statement to write data to a directory would use a 'flexible' API such as DataFrame.

On the other hand, a program which executes pre-defined statements to compute a sum or an average should use a 'constrained' API.
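
To make the contrast concrete, here is a hypothetical sketch (the case classes and method names are illustrative, not from any real project, and the typed variant anticipates the DataSet API discussed below) of how a typed signature documents what an untyped one hides:

// 'Flexible': the signature says nothing about the columns consumed or produced.
def computeStuff(df: DataFrame): DataFrame

// 'Constrained': the signature spells out the input and output structure.
case class Customer(name: String, age: Int)
case class CustomerSummary(name: String, ageNextYear: Int)
def computeStuff(customers: Dataset[Customer]): Dataset[CustomerSummary]

With the second signature, the compiler already answers the three questions above; with the first, only documentation or reading the implementation can.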

DataSets to the rescue

These problems have been acknowledged by the Spark development team for quite some time. This is why Spark 1.6 brought us the experimental DataSet API. But do DataSets really bring us the type safety we are looking for? Let's have a look:

Select

spark-shell

case class Person(id: Int, name: String, age: Int, city: Option[String])

val df = sc.makeRDD(Seq(
  (1, "Albert", 72, Some("Paris")),
  (2, "Gerard", 55, Some("London")),
  (3, "Gerard", 65, None),
  (4, "Robert", 63, Some("Paris"))
)).toDF("id", "name", "age", "city")

val ds = df.as[Person]

First of all, let's try to execute a simple select to narrow down the number of columns:

ds.select(col("name").as[String], $"age".as[Int]).collect()

This returns a properly typed collection, but we have to use a String to refer to the fields of the Person case class, and we are forced to cast the columns to the appropriate type.
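
To illustrate the risk, a minimal sketch (the misspelling is deliberate, and the exact error wording may vary between Spark versions): the following compiles happily because "nmae" is just a String, and only fails once Spark analyses the query at runtime.

// Compiles fine; at runtime the analyzer throws an AnalysisException
// along the lines of "cannot resolve 'nmae'".
ds.select(col("nmae").as[String], $"age".as[Int]).collect()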
However, we can alternatively map over the DataSet to keep our type safety:

ds.map(p => (p.name, p.age)).collect()

The inconvenience is that this code does not go through the Catalyst optimizer, and hence we have witnessed it being slower than using select.
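
One way to see the difference for yourself (a sketch, assuming Dataset.explain behaves as it does in Spark 1.6) is to compare the physical plans of the two versions:

// Stays inside Catalyst's expression engine, operating on columns directly.
ds.select($"name".as[String], $"age".as[Int]).explain()

// Has to deserialize full Person objects before extracting the two fields,
// which shows up as extra (de)serialization steps in the plan.
ds.map(p => (p.name, p.age)).explain()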

Option

A mild annoyance of DataFrames is that they encode optional values as nulls. Hence, when we collect the DataFrame above, which was created using Options for the city column, we get:

scala> df.select($"city").collect()
res4: Array[org.apache.spark.sql.Row] = Array([Paris], [London], [null], [Paris])

However, DataSets bring back our beloved Options:

scala> ds.select($"city".as[Option[String]]).collect() res7: Array[Option[String]] = Array(Some(Paris), Some(London), None, Some(Paris))

Aggregation

groupBy can take either a function or a list of columns. If we use the overload that takes a function, we keep our type safety:

scala> val group = ds.groupBy(_.city)
group: org.apache.spark.sql.GroupedDataset[Option[String],Person] = org.apache.spark.sql.GroupedDataset@71377123

Then we can use mapGroups or agg for aggregating our groups:

scala> ds.groupBy(_.city).agg(max($"age").as[Int]).collect()
res13: Array[(Option[String], Int)] = Array((Some(Paris),72), (Some(London),55), (None,65))

scala> group.mapGroups((city, persons) => (city, persons.map(_.age).max)).collect()
res14: Array[(Option[String], Int)] = Array((Some(Paris),72), (Some(London),55), (None,65))

From a type safety point of view, mapGroups seems to be the way to go. However, our experience is that this is slower than using agg with max($"age").as[Int].

The mapGroups scaladoc mentions:

"This function does not support partial aggregation, and as a result requires shuffling all the data in the [[Dataset]].  If an application intends to perform an aggregation over each key, it is best to use the reduce function or an [[Aggregator]]."

If we look at GroupedDataset.reduce, it appears that it calls flatMapGroups, which also requires shuffling all the data in the DataSet.
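For completeness, reduce keeps the typed values, but it returns whole objects rather than just the aggregated field; a minimal sketch might look like this:

// Keep the oldest Person per city. We get back whole Person objects,
// not just the maximum age, and the same full shuffle still applies.
val oldestPerCity = ds.groupBy(_.city).reduce((p1, p2) => if (p1.age >= p2.age) p1 else p2)
oldestPerCity.collect()  // Array[(Option[String], Person)]
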
Let's see how we can use Aggregator:

scala> import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.expressions.Aggregator

scala> val personMaxAgeAgg = new Aggregator[Person, Int, Int] {
     |   def zero: Int = 0
     |   def reduce(b: Int, a: Person): Int = b.max(a.age)
     |   def merge(b1: Int, b2: Int): Int = b1.max(b2)
     |   def finish(r: Int): Int = r
     | }.toColumn
personMaxAgeAgg: org.apache.spark.sql.TypedColumn[Person,Int] = $anon$1()

scala> group.agg(personMaxAgeAgg).collect()
res16: Array[(Option[String], Int)] = Array((Some(Paris),72), (Some(London),55), (None,65))

A bit cumbersome, but we keep our type safety while getting the same performance as with max($"age").as[Int].

User Defined Functions

We can define a UDF for a DataSet but, as with select, we have to cast and use strings for the field names:

scala> val fn = udf {age: Int => age +1}
fn: org.apache.spark.sql.UserDefinedFunction = UserDefinedFunction(<function1>,IntegerType,List(IntegerType))

scala> ds.select(fn($"age").as[Int]).collect()
res0: Array[Int] = Array(73, 56, 66, 64)

As with select, we could keep our types by using map:

scala> ds.map(_.age+1).collect()
res1: Array[Int] = Array(73, 56, 66, 64)

However, the udf solution is more efficient, as the execution plan can optimize the retrieval of the columns.

If our DataSet were backed by a Parquet file, the udf solution would only read the age column, whereas the map solution would have to read and decode the whole row.
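
A quick way to check this (a sketch; the path is illustrative and the exact plan output depends on the Spark version) is to write the data out as Parquet, read it back and inspect the plans:

// Write the data out as Parquet and read it back as a typed DataSet.
ds.toDF().write.parquet("/tmp/people.parquet")
val people = sqlContext.read.parquet("/tmp/people.parquet").as[Person]

// The plan for the udf/select version should reference only the age column...
people.select(fn($"age").as[Int]).explain()

// ...whereas the map version needs every column to rebuild each Person.
people.map(_.age + 1).explain()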

Thoughts & What's next?

It appears that DataSets bring some advantages over DataFrames in terms of type safety. However, for some key functionality you still need to use casts and refer to fields using Strings.
For a large project, these drawbacks could leave you in a situation no better than with the good 'old' DataFrames. Your code will still be brittle and can break on simple refactorings, and many errors will be discovered at runtime instead of at compile time.

Given that DataSet is currently an experimental API and that Spark 2.0 promises to unify the DataSet and DataFrame API, I would think twice about introducing them in a large project right now.

In the next article I will explore the possibilities of Frameless, an exciting open source project which uses Shapeless to bring type safety to DataFrames and DataSets.

We hope you've found this article a useful guide to type safety in Spark. Please let us know how you're managing type safety using the comment box below.