在scala中使用for表达式做monad运算

在haskell中,我们有语法糖‘do’帮助表达monad运算。scala中我们也有相应语法糖‘for’。

for表达式会被scala compiler做一些变换,简单的例子如下:

for {
  a <- foo
  b <- bar
} yield (a + b)

===>

foo.flatMap((a) => {
  bar.map((b) => {
     a + b
  })
})

所以我们需要实现两个方法 flatMap和map。

还是用前面的state monad作为例子, 我们给类型State加上flatMap和map。

case class State[S, A](runState: S => (S, A))(implicit m : Monad[({type M[a] = State[S, a]})#M]) {
  def map[B](f: A => B) : State[S, B] = m.bind(this, (a: A) => m.ret(f(a)))
  def flatMap[B](f: A => State[S, B]) : State[S, B] = m.bind(this, f)
}

这里我们使用了一个隐式参数,然后我们可以直接使用ret和bind。

同时加一个helper简化Monad[({type M[a] = State[S, a]})#M].ret

def ret[S, A](a: A) : State[S, A] = Monad[({type M[a] = State[S, a]})#M].ret(a)

好了,我们可以使用for表达式了,例子如下:

object Main {
  
  import StateMonad._
  
  def main(args: Array[String]) {
    val r = for {
        a <- ret[Int, Int](3)
        b <- ret[Int, Int](4)
    } yield (a+b)
    println(r.runState(1))
  }

}
posted @ 2015-03-12 13:21  wehu  阅读(381)  评论(0编辑  收藏  举报