Lambda memoization in Java 8
from https://opencredo.com/lambda-memoization-in-java-8/
Memoization is a technique whereby we trade memory for execution speed. Suppose you have a function which
- Is costly to execute.
- Always returns the same output for the same input.
- May be called many times with the same input.
In this scenario, it may make sense to “remember” the output returned for each distinct input in a hash map, and replace function execution with a lookup in the hash map. Here’s a utility method, memoize, which takes a function and returns a memoized function that remembers the outputs returned for previous invocations:
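    import java.util.HashMap;
    import java.util.Map;
    import java.util.function.Function;

    public final class Memoizer {

        public static <I, O> Function<I, O> memoize(Function<I, O> f) {
            // One lookup map per memoized function, captured by the returned lambda.
            Map<I, O> lookup = new HashMap<>();
            return input -> lookup.computeIfAbsent(input, f);
        }
    }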
This works fine provided that the function is only ever used by one thread at a time. Otherwise the lookup map is accessed concurrently, and since HashMap is not synchronized, the same input may be computed twice or the map may be left in an inconsistent state. If this is a problem, we can replace Map with ConcurrentMap, and HashMap with ConcurrentHashMap:
    public static <I, O> Function<I, O> memoize(Function<I, O> f) {
        ConcurrentMap<I, O> lookup = new ConcurrentHashMap<>();
        return input -> lookup.computeIfAbsent(input, f);
    }
In both cases, if the value is already in the lookup map, it is returned immediately; otherwise it is calculated by passing the key to the supplied function, placed in the map and returned. In the ConcurrentMap case, all of this happens atomically, so it is perfectly safe for multiple threads to access the memoized function at once.
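To make the "computed once per input" behaviour concrete, here is a minimal sketch that counts how often the underlying function actually runs when the ConcurrentMap-based version is called from several threads. The class name MemoizeCallCountDemo and the helper countUnderlyingCalls are hypothetical names introduced for this illustration; they are not part of the original utility.

```java
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.IntStream;

// Hypothetical demo class, for illustration only.
final class MemoizeCallCountDemo {

    // Same shape as the ConcurrentMap-based memoize shown above.
    static <I, O> Function<I, O> memoize(Function<I, O> f) {
        ConcurrentMap<I, O> lookup = new ConcurrentHashMap<>();
        return input -> lookup.computeIfAbsent(input, f);
    }

    // Calls a memoized squaring function `calls` times in parallel, cycling
    // over `distinctInputs` different keys, and returns how many times the
    // underlying (uncached) function was actually executed.
    static int countUnderlyingCalls(int calls, int distinctInputs) {
        AtomicInteger invocations = new AtomicInteger();
        Function<Integer, Integer> square = memoize(x -> {
            invocations.incrementAndGet();
            return x * x;
        });
        IntStream.range(0, calls)
                 .parallel()
                 .forEach(i -> square.apply(i % distinctInputs));
        return invocations.get();
    }
}
```

Because ConcurrentHashMap.computeIfAbsent applies the mapping function at most once per key, countUnderlyingCalls(1000, 10) should report 10 underlying executions no matter how the parallel calls interleave.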
Here’s an example of how to use memoize to memoize an expensive method:
    public class Factorial {

        private static final Function<Long, Long> CACHED = Memoizer.memoize(Factorial::uncached);

        public static long factorial(long n) {
            return CACHED.apply(n);
        }

        // Iterative on purpose: it never calls back into the memoized function.
        private static long uncached(long n) {
            long result = n;
            long m = n - 1;
            while (m > 1) {
                result = result * m;
                m -= 1;
            }
            return result;
        }
    }
If you’ve seen memoization examples along these lines before, you may be wondering why the uncached method isn’t defined recursively, so that it populates (or makes use of) the cache for all values between 1 and n. The problem is that the recursive call would take place inside the call to computeIfAbsent, and try to modify the ConcurrentMap while it is already being modified by the outer call. This is explicitly forbidden by the documentation of ConcurrentHashMap.computeIfAbsent (the mapping function must not attempt to update any other mappings of the map), and leads to undefined behaviour: in practice, on Java 8, the program can hang.
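What if we want to memoize a recursive function? Two options suggest themselves:
- Use a non-thread-safe memoization function backed by a plain HashMap, and make sure it’s never called from more than one thread at a time. The lookup should then use separate get and put calls rather than computeIfAbsent, because HashMap.computeIfAbsent does not support a mapping function that modifies the map either (from Java 9 it throws ConcurrentModificationException, on a best-effort basis, when it detects this).
- Wrap the non-thread-safe memoization function with a re-entrant lock. This is thread-safe, but a lot less efficient than using ConcurrentMap, since ConcurrentHashMap locks on individual hash buckets rather than globally locking the entire map during updates.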
Here is a thread-safe and recursion-safe implementation using a re-entrant lock:
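    public static <I, O> Function<I, O> memoize(Function<I, O> f) {
        Map<I, O> lookup = new HashMap<>();
        ReentrantLock lock = new ReentrantLock();
        return input -> {
            lock.lock();
            try {
                // Separate get and put instead of computeIfAbsent: a recursive f
                // re-enters this lambda and adds entries while the outer call is
                // still in progress, which HashMap.computeIfAbsent does not support.
                O cached = lookup.get(input);
                if (cached == null) {
                    cached = f.apply(input);
                    lookup.put(input, cached);
                }
                return cached;
            } finally {
                lock.unlock();
            }
        };
    }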
And here is uncached rewritten in a recursive style:
    private static long uncached(long n) {
        if (n < 1) {
            return n;
        }
        return n == 1 ? n : n * factorial(n - 1);
    }
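Putting the two pieces together, the following is a minimal, self-contained sketch of a recursive factorial memoized behind a re-entrant lock. The class name RecursiveFactorialSketch and the UNCACHED_CALLS counter are hypothetical additions for illustration. The sketch also makes two assumptions of its own: it uses separate get and put calls for the lookup, so that the recursive call never modifies the map from inside computeIfAbsent, and it simplifies the base case to return 1 for any n <= 1.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Function;

// Hypothetical demo class, for illustration only.
final class RecursiveFactorialSketch {

    // Counts executions of the uncached method (illustration only).
    static final AtomicInteger UNCACHED_CALLS = new AtomicInteger();

    static <I, O> Function<I, O> memoize(Function<I, O> f) {
        Map<I, O> lookup = new HashMap<>();
        ReentrantLock lock = new ReentrantLock();
        return input -> {
            lock.lock(); // re-entrant: a recursive call on the same thread gets through
            try {
                O cached = lookup.get(input);
                if (cached == null) {
                    cached = f.apply(input); // may re-enter this lambda
                    lookup.put(input, cached);
                }
                return cached;
            } finally {
                lock.unlock();
            }
        };
    }

    private static final Function<Long, Long> CACHED =
            memoize(RecursiveFactorialSketch::uncached);

    static long factorial(long n) {
        return CACHED.apply(n);
    }

    // Assumption: base case simplified so that every n <= 1 yields 1.
    private static long uncached(long n) {
        UNCACHED_CALLS.incrementAndGet();
        return n <= 1 ? 1 : n * factorial(n - 1);
    }
}
```

Calling factorial(5) on a fresh class runs the uncached method once for each of 5, 4, 3, 2 and 1; a subsequent factorial(6) then needs only one further execution, because factorial(5) is already in the cache.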
It may be worth having two versions of memoize (one recursion-safe and one recursion-unsafe, but both thread-safe; or one thread-safe and one thread-unsafe, but both recursion-safe) if this is likely to be a problem.
Some further points to bear in mind about this technique:
- Object creation: Every memoized function is effectively a closure over a new Map or ConcurrentMap object: if you memoize lots of functions, you will also be creating lots of new lookup map objects that will only be garbage collected after the memoized functions are themselves collected.
- Value lifetime: Values calculated by the function are stored in the lookup map forever: it will grow continually, for as long as new inputs are supplied. In some circumstances it may be better to use a bounded cache, such as the one provided by Google’s Guava, which allows you to set an eviction policy to get rid of values, either to make room in the cache when it is “full” or when they are considered “stale”.
- Cache control: The returned function completely hides the lookup map it closes over, and there is no way to access that map to clear it out – this might create problems when testing, for example.
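As a rough illustration of the last two points, here is a sketch of a bounded memoizer that also exposes a clear method. It is not Guava's cache: it is a standard-library stand-in that uses LinkedHashMap in access order with removeEldestEntry to evict the least recently used entry. The class name BoundedMemoizer and its maxEntries parameter are hypothetical.

```java
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;

// Hypothetical class, for illustration only; a stand-in for a real cache library.
final class BoundedMemoizer<I, O> implements Function<I, O> {

    private final Function<I, O> f;
    private final Map<I, O> lookup;

    BoundedMemoizer(Function<I, O> f, int maxEntries) {
        this.f = f;
        // accessOrder = true keeps entries ordered from least to most recently used.
        this.lookup = new LinkedHashMap<I, O>(16, 0.75f, true) {
            @Override
            protected boolean removeEldestEntry(Map.Entry<I, O> eldest) {
                // Called after each put: evict the eldest entry once over the limit.
                return size() > maxEntries;
            }
        };
    }

    // synchronized is re-entrant, so a recursive f can call back into apply.
    @Override
    public synchronized O apply(I input) {
        O cached = lookup.get(input);
        if (cached == null) {
            cached = f.apply(input);
            lookup.put(input, cached);
        }
        return cached;
    }

    // Lets tests (or anyone else) empty the cache explicitly.
    synchronized void clear() {
        lookup.clear();
    }

    synchronized int size() {
        return lookup.size();
    }
}
```

The trade-off is the same global locking discussed earlier: every call takes the object's monitor, so this sketch favours simplicity over throughput. An evicted input is simply recomputed the next time it is requested.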
Finally, here is a version of memoize where the lookup map is initialised at the outset with the complete range of input values. It will return null if called with any value that is not in the supplied range. The resulting function is thread-safe (since its access to the underlying lookup is read-only) but not recursion-safe.
    public static <I, O> Function<I, O> memoize(Function<I, O> f, Stream<I> range) {
        return range.collect(Collectors.toMap(Function.identity(), f))::get;
    }
Of passing interest here is the way the method reference to Map.get is automatically converted to a Function (strictly speaking this is a target-typed conversion rather than a cast). This makes it possible to write the entire method body as a very compact one-liner.
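A short usage sketch of the pre-populated variant follows. The class name PrecomputedMemoizeSketch and the SQUARE constant are hypothetical, and the method body is split into two statements here purely for readability.

```java
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import java.util.stream.Stream;

// Hypothetical demo class, for illustration only.
final class PrecomputedMemoizeSketch {

    static <I, O> Function<I, O> memoize(Function<I, O> f, Stream<I> range) {
        // Every value in the range is computed eagerly, before the function is returned.
        Map<I, O> lookup = range.collect(Collectors.toMap(Function.identity(), f));
        return lookup::get;
    }

    // Squares of 1..10, all computed when this class is initialised.
    static final Function<Long, Long> SQUARE =
            memoize(n -> n * n, LongStream.rangeClosed(1, 10).boxed());
}
```

With this setup, SQUARE.apply(3L) yields 9 and SQUARE.apply(11L) yields null, since 11 lies outside the supplied range. Note that Collectors.toMap throws an IllegalStateException if the range contains duplicate values, so the stream should be made distinct first if that can happen.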