Friday, November 20, 2015

Avoiding vars in Scala

One of the difficulties everyone new to functional programming faces is avoiding vars and mutable types. These aren't the same thing, of course, but they're similar enough for the purposes of this discussion.

Of course, the first thing to understand is why such mutable objects are considered harmful: why would we want to avoid them? Without getting into an essay on referential transparency (RT), we should aim to preserve RT in our code because referentially transparent programs are reliably testable and potentially provable.

So, let's say we have some code like the following, and we want to avoid using vars and make it more functional (so that we can write unit tests for it):

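  def printStatistics(xs: Seq[Double]): Unit = {
    var c = 0    // count of in-range values
    var s = 0.0  // running sum
    var v = 0.0  // running sum of squares
    for (x <- xs) {
      if (x >= 0 && x <= 1) {
        c += 1
        s += x
        v += x*x
      }
    }
    val mean = s/c
    // population standard deviation: sqrt(mean of squares - square of mean)
    println(s"Mean: $mean, Std. Dev: ${math.sqrt(v/c - mean*mean)}")
  }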

I chose this particular bit of code because it involves the two features that seem most to get in the way of functionalizing code, viz. calculating more than one quantity and having complex logic. Here, we want to calculate both the mean and the standard deviation of a sequence of values. However, we also want to exclude, for some reason, outliers, which here means any value outside the range [0, 1].

We can deal with the multiplicity of quantities by turning to one of the functional programmer's best friends: the tuple. Let's group our three quantities together in a 3-tuple, (c,s,v), for the purposes of evaluation, and then return a 2-tuple, (mu, sigma), as the result of the method. Second, the logic that was complicating the calculations can, in this case, easily be refactored out: we simply filter the values of xs before doing any arithmetic.
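As a small sketch of these two refactorings (the helper names inRange and summarize are hypothetical, not part of the original code), the outlier test can be pulled out as a pure predicate to hand to filter, and the final arithmetic can become a function from the (c,s,v) triple to the (mu, sigma) pair. The standard deviation here is the population one: the square root of the mean of the squares minus the square of the mean.

```scala
object FilterAndTuple {
  // Hypothetical helper: the outlier rule from printStatistics as a pure predicate
  def inRange(x: Double): Boolean = x >= 0 && x <= 1

  // Hypothetical helper: turn the accumulated (count, sum, sum of squares)
  // triple into the (mean, standard deviation) pair
  def summarize(acc: (Int, Double, Double)): (Double, Double) = {
    val (c, s, v) = acc
    val mu = s / c
    (mu, math.sqrt(v / c - mu * mu))
  }
}
```

Both functions are pure, so each can be tested on its own, e.g. `List(-1.0, 0.2, 2.0).filter(FilterAndTuple.inRange)` keeps only `0.2`.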

Finally, there is one aspect of rolling up vars that is definitely tricky: because iteration depends on vars, we need to map the iteration-with-vars code into recursive code. Again, without getting into a formal proof, we can say that recursion (with immutable values) is the counterpart of iteration (with mutable variables): each var becomes a parameter whose updated value is passed to the next recursive call.
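To make the iteration-to-recursion step concrete, here is a sketch of the same computation written as an explicitly tail-recursive helper, before any higher-order functions are involved. The object and method names are hypothetical; the point is that each of the three vars becomes a parameter of loop, and @tailrec asks the compiler to confirm that the recursion compiles down to a loop.

```scala
import scala.annotation.tailrec

object RecursiveStats {
  def statistics(xs: Seq[Double]): (Double, Double) = {
    // c, s and v used to be vars; now they are parameters threaded through the calls
    @tailrec
    def loop(rest: List[Double], c: Int, s: Double, v: Double): (Int, Double, Double) =
      rest match {
        case Nil => (c, s, v)                                   // no elements left
        case x :: tail if x >= 0 && x <= 1 =>                   // in range: accumulate
          loop(tail, c + 1, s + x, v + x * x)
        case _ :: tail => loop(tail, c, s, v)                   // outlier: skip it
      }

    val (c, s, v) = loop(xs.toList, 0, 0.0, 0.0)
    val mu = s / c
    (mu, math.sqrt(v / c - mu * mu))
  }
}
```

For example, `RecursiveStats.statistics(Seq(0.0, 1.0, 5.0))` ignores `5.0` and returns `(0.5, 0.5)`.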

How are we going to turn this particular code fragment into something recursive? What's the appropriate pattern? Well, here another feature of functional programming comes to the rescue: higher-order functions. In particular, the fold functions, foldLeft and foldRight, capture exactly the recursive pattern we need. Functional libraries for Scala such as Cats and Scalaz abstract these into a type class called Foldable, which looks something like this:

  trait Foldable[F[_]] {
    def foldLeft[A,B](fa: F[A])(z: B)(f: (B, A) => B): B
    def foldRight[A,B](fa: F[A])(z: B)(f: (A, B) => B): B
  }

Here F[_] stands for a type constructor such as Seq or List, so that F[A] is a concrete collection type such as Seq[A] or List[A]. z is the initial ("zero") value of our accumulator, and f combines the current accumulator with the current element (for foldLeft the accumulator is f's first parameter; for foldRight it is the second). After every element of the collection has been visited, the accumulator is returned as the result. If A happens to be a monoid (a type that has its own zero and combine operations), then we can define a reduce method that needs neither a zero nor a combiner to be passed in:

      def reduce[A](fa: F[A])(implicit m: Monoid[A]): A
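One way to flesh this out, assuming simplified Monoid and Foldable traits of our own, is a List instance that delegates to the standard library's foldLeft and foldRight, with reduce written once in terms of foldLeft. These traits are illustrative and are not the Cats or Scalaz definitions (Cats, for instance, makes foldRight lazy and calls the monoid-based version fold or combineAll).

```scala
object FoldableSketch {
  // Simplified monoid: an identity element plus an associative combine
  trait Monoid[A] {
    def empty: A
    def combine(x: A, y: A): A
  }

  trait Foldable[F[_]] {
    def foldLeft[A, B](fa: F[A])(z: B)(f: (B, A) => B): B
    def foldRight[A, B](fa: F[A])(z: B)(f: (A, B) => B): B

    // The monoid supplies both the zero and the combiner, so the caller passes neither
    def reduce[A](fa: F[A])(implicit m: Monoid[A]): A =
      foldLeft(fa)(m.empty)((b, a) => m.combine(b, a))
  }

  // A List instance that simply delegates to the standard library's folds
  implicit val listFoldable: Foldable[List] = new Foldable[List] {
    def foldLeft[A, B](fa: List[A])(z: B)(f: (B, A) => B): B = fa.foldLeft(z)(f)
    def foldRight[A, B](fa: List[A])(z: B)(f: (A, B) => B): B = fa.foldRight(z)(f)
  }

  // Doubles under addition form a monoid with 0.0 as the identity
  implicit val doubleSum: Monoid[Double] = new Monoid[Double] {
    def empty: Double = 0.0
    def combine(x: Double, y: Double): Double = x + y
  }
}
```

With the monoid passed explicitly, `FoldableSketch.listFoldable.reduce(List(1.0, 2.0, 3.0))(FoldableSketch.doubleSum)` evaluates to `6.0`.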

But foldLeft and foldRight don't actually require monoids, because we provide the zero value and the combining function explicitly. This is what enables us, in the present instance, to keep three running totals going in a single pass: the count, the sum (for the mean), and the sum of squares (for the standard deviation). This is what the final code looks like:

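  def getStatistics(xs: Seq[Double]): (Double, Double) = {
    // remove outliers: keep only values in [0, 1]
    val y = xs filter { x => x >= 0 && x <= 1 }
    // one pass accumulating (count, sum, sum of squares)
    val r = y.foldLeft[(Int,Double,Double)]((0,0,0)){case ((c,s,v),x) => (c+1,s+x,v+x*x)}
    val mean = r._2/r._1
    // population standard deviation: sqrt(mean of squares - square of mean)
    (mean, math.sqrt(r._3/r._1 - mean*mean))
  }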

Here, we are explicitly counting the elements of y (although y.size would do just as well). First we filter the input to remove outliers, and then we perform our calculations using foldLeft. The initial value is easy, of course: (0,0,0). The combiner function is also quite straightforward. Strictly speaking, it is a pattern-matching anonymous function (the { case ... } syntax): since foldLeft expects a two-argument function, the compiler matches the pattern ((c,s,v),x) against the pair of arguments, binding (c,s,v) to the accumulated values and x to the current element of the sequence. The returned value is simply (c+1,s+x,v+x*x).
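The parenthetical about y.size can be sketched as a hypothetical variant in which the count comes from y.size, so the fold only has to carry the two sums. The trade-off is that for a List, y.size walks the list a second time, whereas counting inside the fold keeps everything to one pass.

```scala
object SizeVariant {
  def getStatistics(xs: Seq[Double]): (Double, Double) = {
    val y = xs filter { x => x >= 0 && x <= 1 }
    // The fold now accumulates only (sum, sum of squares)
    val (sum, sumSq) = y.foldLeft((0.0, 0.0)) { case ((s, v), x) => (s + x, v + x * x) }
    val n = y.size // count taken from the collection instead of the fold
    val mu = sum / n
    (mu, math.sqrt(sumSq / n - mu * mu))
  }
}
```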

Finally, we return a tuple of the mean and standard deviation. We can, of course, print these values as printStatistics did before, but we can also check them in unit tests to make sure we have our logic right.
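A minimal sketch of such checks, using plain assert so that it needs no test framework (in a real project ScalaTest or MUnit would be the more usual choice). It carries its own copy of getStatistics, computing the population standard deviation as the square root of the mean of the squares minus the square of the mean, and compares doubles with a small tolerance. The names StatisticsCheck, approx and runChecks are hypothetical.

```scala
object StatisticsCheck {
  // Self-contained copy of getStatistics so this sketch runs on its own
  def getStatistics(xs: Seq[Double]): (Double, Double) = {
    val y = xs filter { x => x >= 0 && x <= 1 }
    val (c, s, v) = y.foldLeft((0, 0.0, 0.0)) { case ((c, s, v), x) => (c + 1, s + x, v + x * x) }
    val mean = s / c
    (mean, math.sqrt(v / c - mean * mean))
  }

  // Compare doubles with a tolerance rather than exact equality
  def approx(a: Double, b: Double, eps: Double = 1e-9): Boolean = math.abs(a - b) < eps

  def runChecks(): Unit = {
    // Outliers (-3.0 and 7.5) must be ignored: only 0.0 and 1.0 count
    val (mu1, sigma1) = getStatistics(Seq(-3.0, 0.0, 1.0, 7.5))
    assert(approx(mu1, 0.5) && approx(sigma1, 0.5))

    // Identical values have zero spread
    val (mu2, sigma2) = getStatistics(Seq(0.25, 0.25, 0.25, 0.25))
    assert(approx(mu2, 0.25) && approx(sigma2, 0.0))

    // If everything is filtered out, 0.0 / 0 gives NaN rather than throwing
    assert(getStatistics(Seq(2.0, 3.0))._1.isNaN)
  }
}
```

Because getStatistics is pure, these checks need no console capture, which is exactly what the println-based printStatistics made difficult.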
