在scala中使用for表达式做monad运算
在haskell中,我们有语法糖‘do’帮助表达monad运算。scala中我们也有相应语法糖‘for’。
for表达式会被scala compiler做一些变换,简单的例子如下:
for { a <- foo b <- bar } yield (a + b)
===>
foo.flatMap((a) => { bar.map((b) => { a + b }) })
所以我们需要实现两个方法 flatMap和map。
还是用前面的state monad作为例子, 我们给类型State加上flatMap和map。
case class State[S, A](runState: S => (S, A))(implicit m : Monad[({type M[a] = State[S, a]})#M]) { def map[B](f: A => B) : State[S, B] = m.bind(this, (a: A) => m.ret(f(a))) def flatMap[B](f: A => State[S, B]) : State[S, B] = m.bind(this, f) }
这里我们使用了一个隐式参数,然后我们可以直接使用ret和bind。
同时加一个helper简化Monad[({type M[a] = State[S, a]})#M].ret
def ret[S, A](a: A) : State[S, A] = Monad[({type M[a] = State[S, a]})#M].ret(a)
好了,我们可以使用for表达式了,例子如下:
object Main { import StateMonad._ def main(args: Array[String]) { val r = for { a <- ret[Int, Int](3) b <- ret[Int, Int](4) } yield (a+b) println(r.runState(1)) } }