Chain ‐ an immutable List with O(1) concatenation - google/mug GitHub Wiki
Agent Aragorn (Son of Arathorn), Johnny English and Smith collaborated on a bunch of missions.
The `Mission` class has a signature like:
```java
abstract class Mission {
  abstract MissionId id();
  abstract Range<LocalDate> timeWindow();
  abstract ImmutableSet<Agent> agents();
}
```
The goal is to create an `ImmutableRangeMap<LocalDate, Set<Agent>>` that accounts for all the agents during each time window. Note that missions can have overlapping time windows, and an agent can work on multiple missions at the same time.
So for missions like:
```
missions = [
  {
    timeWindow: [10/01..10/30],
    agents: [Aragorn, English]
  },
  {
    timeWindow: [10/15..11/15],
    agents: [Aragorn, Smith]
  }
]
```
I want the result to be:
```
[10/01..10/15): [Aragorn, English]
[10/15..10/30]: [Aragorn, English, Smith]
(10/30..11/15]: [Aragorn, Smith]
```
At first I thought of using Guava's `toImmutableRangeMap()` collector:
```java
missions.stream()
    .collect(toImmutableRangeMap(Mission::timeWindow, Mission::agents));
```
Voila, done, right?
No. My colleague pointed out that `toImmutableRangeMap()` doesn't allow overlapping ranges: it requires all input time windows to be disjoint.
Luckily, `TreeRangeMap` has a `merge()` method that already does the heavy lifting: it finds the overlapping ranges, splits them, and then merges the values mapped to each overlapping subrange.
With some effort, I built a `toImmutableRangeMap(merger)` `BiCollector` on top of that `merge()` method.
So if all I needed was to count the agents, I could have done:
```java
import static ...BiStream.biStream;

ImmutableRangeMap<LocalDate, Integer> agentCounts =
    biStream(missions)
        .mapKeys(Mission::timeWindow)
        .mapValues(mission -> mission.agents().size())
        .collect(toImmutableRangeMap(Integer::sum));
```
(It will double-count agents who are on more than one overlapping mission, though.)
Anyhoo, here comes the interesting part: how do I merge the `Set<Agent>`?
I could use Guava's `Sets.union()`:
```java
import com.google.common.collect.Sets;

ImmutableRangeMap<LocalDate, ImmutableSet<Agent>> agentsTimeline =
    biStream(missions)
        .mapKeys(Mission::timeWindow)
        .mapValues(mission -> mission.agents())
        .collect(toImmutableRangeMap((set1, set2) -> Sets.union(set1, set2).immutableCopy()));
```
The gotcha is that every merge copies both sets into a new one, which costs O(n), where n is the total number of agents in the two overlapping ranges. If we are unlucky, a time window keeps being found to overlap with yet another time window, and we copy the accumulated set over and over again. In the worst case, the total time is quadratic.
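To put a number on the quadratic cost, here is a small JDK-only model rather than Guava code. It assumes the worst case described above: one accumulated set absorbs a single new agent at each of k merges, and every merge copies both inputs into a fresh set, as `Sets.union(...).immutableCopy()` would. `CopyOnMerge` and `totalCopies` are hypothetical names.

```java
import java.util.HashSet;
import java.util.Set;

/** Hypothetical model of copy-on-merge: counts how many elements get copied in total. */
final class CopyOnMerge {
  /** Merges k single-agent sets into one accumulated set, copying both inputs every time. */
  static long totalCopies(int k) {
    long copied = 0;
    Set<Integer> accumulated = new HashSet<>();
    for (int agent = 0; agent < k; agent++) {
      Set<Integer> incoming = Set.of(agent);
      // Copy both inputs into a brand-new set, like union(...).immutableCopy().
      Set<Integer> merged = new HashSet<>(accumulated);
      merged.addAll(incoming);
      copied += accumulated.size() + incoming.size();
      accumulated = merged;
    }
    return copied;
  }
}
```

Merge number i copies i elements, so the total is 1 + 2 + ... + k = k(k + 1)/2, which is O(k²).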
Could I remove the `.immutableCopy()`? `Sets.union()` returns a view in constant time, so we should be good, right?
Well, not really. We don't know how many times merging will happen: a set can be unioned, then unioned again, an unknown number of times. In the worst case, we'd create a union-of-union-of-union view N levels deep. If N is large, we'll get a stack overflow when we try to access the final `SetView`!
The same thing happens if, for example, I use `Iterables.concat()`. The `Stream.concat()` javadoc warns about this problem.
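Here is a minimal JDK-only demo of that warning. The depth of 100,000 is an arbitrary choice meant to exceed a typical default thread stack; the exact depth at which `StackOverflowError` appears depends on the JVM and its `-Xss` setting. `DeepConcat` is a hypothetical name.

```java
import java.util.stream.Stream;

/** Demo: repeatedly concatenating streams builds a deeply nested structure. */
final class DeepConcat {
  /** Returns true if traversing a stream concatenated {@code depth} times overflowed the stack. */
  static boolean overflows(int depth) {
    Stream<Integer> stream = Stream.of(0);
    for (int i = 1; i < depth; i++) {
      // Each call wraps the previous stream one level deeper; building it is cheap.
      stream = Stream.concat(stream, Stream.of(i));
    }
    try {
      // Traversal recurses once per nesting level.
      stream.forEach(x -> {});
      return false;
    } catch (StackOverflowError e) {
      return true;
    }
  }
}
```

Building the nested stream is cheap; it is the traversal that recurses once per level of nesting, just like reading through a deeply nested union view.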
I slept on this problem for about two days before an idea came to me: could we use something like Haskell's `List`?
TL;DR: Haskell's `List` is like `LinkedList`, except that it's immutable. So given a list `[2, 3]`, you can `cons` the number 1 onto it to get a new list `[1, 2, 3]`. Under the hood, that's as simple as creating a new node whose internal `tail` pointer points to the old `[2, 3]` list.
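The cons idea translates directly to Java. This is an illustrative immutable singly linked list (`ConsList` is a made-up name, and to keep it short it cannot represent the empty list): consing allocates one node whose tail points at the existing list, so the old list is shared rather than copied.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical immutable singly linked list, in the spirit of Haskell's list. */
final class ConsList<T> {
  private final T head;
  private final ConsList<T> tail; // null marks the end of the list

  private ConsList(T head, ConsList<T> tail) {
    this.head = head;
    this.tail = tail;
  }

  static <T> ConsList<T> of(T value) {
    return new ConsList<>(value, null);
  }

  /** O(1): the new node simply points at this list, which is shared, not copied. */
  ConsList<T> cons(T value) {
    return new ConsList<>(value, this);
  }

  /** Walks the tail pointers to copy the elements out, in order. */
  List<T> toList() {
    List<T> out = new ArrayList<>();
    for (ConsList<T> node = this; node != null; node = node.tail) {
      out.add(node.head);
    }
    return out;
  }
}
```

For example, `ConsList.of(3).cons(2)` is `[2, 3]`, and calling `.cons(1)` on it yields `[1, 2, 3]` while the original `[2, 3]` stays intact.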
If I can do this, each merge costs only O(1). The resulting object is probably less efficient for random access than `ArrayList` or Guava's `ImmutableList` because of all the pointers and indirection. But that's okay: when the whole split-merge process is done, I can perform one final O(n) copy into an `ImmutableList`.
The only problem? Haskell's `cons` only adds one element, while I have two `List<Agent>`s to concatenate. (I can't cons every element of one list onto the other, because that brings me right back to quadratic time.)
To support `concat(list1, list2)`, I decided to represent the list's state as a binary tree:
```java
private static final class Tree<T> {
  final T mid;
  @Nullable final Tree<T> left;  // null means empty
  @Nullable final Tree<T> right; // null means empty

  Tree(T mid, @Nullable Tree<T> left, @Nullable Tree<T> right) {
    this.mid = mid;
    this.left = left;
    this.right = right;
  }
}
```
In the list, the elements in `left` come first, followed by `mid`, then the elements in `right`. In other words, an in-order traversal gives us back the list.
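To make the element order concrete, here is a standalone copy of the node class with a recursive in-order traversal. The `TreeOrderDemo` holder and the `toList` helper are hypothetical. The recursion uses one stack frame per tree level, which is fine for a small example but is exactly what a deep, unbalanced tree cannot afford.

```java
import java.util.ArrayList;
import java.util.List;

final class TreeOrderDemo {
  /** Mirrors the Tree class: elements of left, then mid, then elements of right. */
  static final class Tree<T> {
    final T mid;
    final Tree<T> left;  // null means empty
    final Tree<T> right; // null means empty

    Tree(T mid, Tree<T> left, Tree<T> right) {
      this.mid = mid;
      this.left = left;
      this.right = right;
    }

    /** Flattens the tree with a recursive in-order traversal. */
    List<T> toList() {
      List<T> out = new ArrayList<>();
      appendInOrder(this, out);
      return out;
    }

    // One stack frame per tree level: fine here, risky for a deep, unbalanced tree.
    private static <E> void appendInOrder(Tree<E> tree, List<E> out) {
      if (tree == null) {
        return;
      }
      appendInOrder(tree.left, out);
      out.add(tree.mid);
      appendInOrder(tree.right, out);
    }
  }
}
```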
The key trick is figuring out how to concatenate two binary trees into one. Intuitively, I need to find the new "mid point" value, which can be either the left tree's last element or the right tree's first element. Say I take the right tree's first element: then the new tree's `left` stays the old left tree, while the new tree's `right` has to be the old right tree with its first element removed.
Since `Tree` is immutable, how do I remove anything from it? And finding the first element of a binary tree can take up to O(n) time, because the tree isn't balanced.
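For comparison, this sketch shows why removing the first element of an immutable, unbalanced tree is costly: you walk down the left spine to find it, and because nodes can't be modified, every node on that spine has to be rebuilt. `RemoveFirstDemo`, `first` and `removeFirst` are hypothetical helpers, not part of `Chain`.

```java
import java.util.ArrayList;
import java.util.List;

final class RemoveFirstDemo {
  static final class Tree<T> {
    final T mid;
    final Tree<T> left;  // null means empty
    final Tree<T> right; // null means empty

    Tree(T mid, Tree<T> left, Tree<T> right) {
      this.mid = mid;
      this.left = left;
      this.right = right;
    }
  }

  /** First element: walk the left spine. O(depth), which is O(n) for a degenerate tree. */
  static <T> T first(Tree<T> tree) {
    while (tree.left != null) {
      tree = tree.left;
    }
    return tree.mid;
  }

  /**
   * Returns a new tree without the first element, or null if nothing is left. Every node on
   * the left spine is copied; untouched right subtrees are shared. Recursive, so it also uses
   * one stack frame per level.
   */
  static <T> Tree<T> removeFirst(Tree<T> tree) {
    if (tree.left == null) {
      return tree.right;
    }
    return new Tree<>(tree.mid, removeFirst(tree.left), tree.right);
  }

  /** Recursive in-order flattening, for checking results on small trees. */
  static <T> List<T> toList(Tree<T> tree) {
    List<T> out = new ArrayList<>();
    appendInOrder(tree, out);
    return out;
  }

  private static <T> void appendInOrder(Tree<T> tree, List<T> out) {
    if (tree == null) {
      return;
    }
    appendInOrder(tree.left, out);
    out.add(tree.mid);
    appendInOrder(tree.right, out);
  }
}
```

Both operations are O(depth), and for a degenerate tree the depth can be O(n).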
It turns out there's a law in computer science:

> All problems in computer science can be solved by another level of indirection.

In plain English: if a problem can't be solved with one layer of indirection, add a second one. :)
Here is my second layer of indirection, which handles the "remove the first element from an immutable list" task:
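```java
public final class Chain<T> {
  private final T head;
  @Nullable private final Tree<T> tail; // null means the chain holds only head

  private Chain(T head, @Nullable Tree<T> tail) {
    this.head = head;
    this.tail = tail;
  }

  public static <T> Chain<T> of(T value) {
    return new Chain<>(value, null);
  }

  public static <T> Chain<T> concat(Chain<T> left, Chain<T> right) {
    // Keep left's head in front; the rest reads left.tail, right.head, right.tail in order.
    T newHead = left.head;
    Tree<T> newTail = new Tree<>(right.head, left.tail, right.tail);
    return new Chain<>(newHead, newTail);
  }
}
```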
It takes a bit of brain gymnastics, but if you sit down and think about it for a minute, it's actually pretty straightforward.
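If the gymnastics still feel slippery, here is a self-contained sketch that combines the two classes and checks the ordering on a tiny example. It is an illustration, not Mug's code: `ChainSketch` and its `toList` helper are made-up names, and `toList` recurses for brevity, so it is only suitable for shallow trees.

```java
import java.util.ArrayList;
import java.util.List;

/** Minimal sketch of a Chain: a head element plus an optional binary tree for the rest. */
final class ChainSketch<T> {
  private static final class Tree<T> {
    final T mid;
    final Tree<T> left;  // null means empty
    final Tree<T> right; // null means empty

    Tree(T mid, Tree<T> left, Tree<T> right) {
      this.mid = mid;
      this.left = left;
      this.right = right;
    }
  }

  private final T head;
  private final Tree<T> tail; // null means the chain holds only head

  private ChainSketch(T head, Tree<T> tail) {
    this.head = head;
    this.tail = tail;
  }

  static <T> ChainSketch<T> of(T value) {
    return new ChainSketch<>(value, null);
  }

  /** O(1): one new Tree node and one new ChainSketch, whatever the input sizes. */
  static <T> ChainSketch<T> concat(ChainSketch<T> left, ChainSketch<T> right) {
    // In-order the new tail reads left.tail, right.head, right.tail; left.head stays first.
    return new ChainSketch<>(left.head, new Tree<>(right.head, left.tail, right.tail));
  }

  /** Recursive for brevity; deep trees would need an iterative walk instead. */
  List<T> toList() {
    List<T> out = new ArrayList<>();
    out.add(head);
    appendInOrder(tail, out);
    return out;
  }

  private static <E> void appendInOrder(Tree<E> tree, List<E> out) {
    if (tree == null) {
      return;
    }
    appendInOrder(tree.left, out);
    out.add(tree.mid);
    appendInOrder(tree.right, out);
  }
}
```

Concatenating `[1, 2]` with `[3, 4]` keeps 1 as the head and builds a single new tree node with 3 in the middle, `[2]` on its left and `[4]` on its right, which reads back as `[1, 2, 3, 4]`.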
This solves O(1) concatenation. Better still, no matter how deeply `concat()` calls are nested, the result is always a single `Chain` object pointing to a heap-allocated `Tree`.
Now we just need to make sure that reading a `Chain` takes no more than O(n) time and constant stack space.
My secret weapon is `Walker.inBinaryTree()` from Mug. It already does everything I need:
- O(n)-time in-order traversal.
- Constant stack space.
Using it is pretty simple. First, we add a `stream()` method to the `Tree` class:
```java
private static final class Tree<T> {
  ...

  Stream<T> stream() {
    return Walker.<Tree<T>>inBinaryTree(t -> t.left, t -> t.right)
        .inOrderFrom(this)
        .map(t -> t.mid);
  }
}
```
The `inOrderFrom()` method returns a lazy stream, which in the worst case takes O(n) heap space and constant stack space.
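The post relies on Mug's `Walker` for this. To show the underlying idea without that dependency, here is a common substitute: an iterative in-order traversal that keeps pending nodes in a heap-allocated `ArrayDeque` instead of on the call stack, wrapped as a `Stream` with `Spliterators.spliteratorUnknownSize`. This is a swapped-in technique, not Walker's implementation, and `LazyInOrder` is a hypothetical name. The deque holds at most one node per tree level, so a degenerate tree costs O(n) heap while the call stack stays flat.

```java
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

final class LazyInOrder {
  static final class Tree<T> {
    final T mid;
    final Tree<T> left;  // null means empty
    final Tree<T> right; // null means empty

    Tree(T mid, Tree<T> left, Tree<T> right) {
      this.mid = mid;
      this.left = left;
      this.right = right;
    }

    /**
     * In-order stream that yields elements on demand. Pending ancestors are kept in a
     * heap-allocated deque, so the call stack stays flat however deep the tree is.
     */
    Stream<T> stream() {
      Iterator<T> inOrder = new Iterator<T>() {
        private final Deque<Tree<T>> pending = new ArrayDeque<>();

        { pushLeftSpine(Tree.this); } // start at the leftmost (first) element

        private void pushLeftSpine(Tree<T> node) {
          for (; node != null; node = node.left) {
            pending.push(node);
          }
        }

        @Override public boolean hasNext() {
          return !pending.isEmpty();
        }

        @Override public T next() {
          if (pending.isEmpty()) {
            throw new NoSuchElementException();
          }
          Tree<T> node = pending.pop();
          pushLeftSpine(node.right); // the right subtree follows this node
          return node.mid;
        }
      };
      return StreamSupport.stream(
          Spliterators.spliteratorUnknownSize(inOrder, Spliterator.ORDERED), false);
    }
  }
}
```

A left-leaning tree 500,000 levels deep, which a recursive traversal would most likely overflow on, streams without trouble this way.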
Then we wrap it up and polish it in the `Chain` wrapper class:
```java
public final class Chain<T> {
  ...

  /**
   * Returns a <em>lazy</em> stream of the elements in this list.
   * The returned stream is lazy in that concatenated chains aren't consumed until the stream
   * reaches their elements.
   */
  public Stream<T> stream() {
    return tail == null
        ? Stream.of(head)
        : Stream.concat(Stream.of(head), tail.stream());
  }
}
```
That gives me O(n)-time read access to the whole chain, and I can easily `collect()` it into an `ImmutableList`.
In the actual implementation, I also made `Chain` implement `List` to make it nicer to use, and used lazy initialization so that the cost is paid only once. But that's just extra API makeup. The meat is all here.
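One way to combine "implements `List`" with lazy initialization is to extend `java.util.AbstractList` and flatten the tree into a list the first time anyone reads from it. This is a hypothetical sketch: `LazyListChain`, the `volatile` cache, and the explicit-stack flattening are assumptions, not necessarily how Mug's `Chain` does it.

```java
import java.util.AbstractList;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.List;

/** Hypothetical Chain variant that implements List and flattens itself lazily. */
final class LazyListChain<T> extends AbstractList<T> {
  private static final class Tree<T> {
    final T mid;
    final Tree<T> left;  // null means empty
    final Tree<T> right; // null means empty

    Tree(T mid, Tree<T> left, Tree<T> right) {
      this.mid = mid;
      this.left = left;
      this.right = right;
    }
  }

  private final T head;
  private final Tree<T> tail; // null means the chain holds only head
  private volatile List<T> cache; // filled on first read

  private LazyListChain(T head, Tree<T> tail) {
    this.head = head;
    this.tail = tail;
  }

  static <T> LazyListChain<T> of(T value) {
    return new LazyListChain<>(value, null);
  }

  /** O(1): same trick as Chain.concat, one new tree node. */
  static <T> LazyListChain<T> concat(LazyListChain<T> left, LazyListChain<T> right) {
    return new LazyListChain<>(left.head, new Tree<>(right.head, left.tail, right.tail));
  }

  @Override public T get(int index) {
    return flattened().get(index);
  }

  @Override public int size() {
    return flattened().size();
  }

  /** O(n) on the first call, then O(1); iterative in-order walk, so no deep recursion. */
  private List<T> flattened() {
    List<T> result = cache;
    if (result == null) {
      List<T> out = new ArrayList<>();
      out.add(head);
      Deque<Tree<T>> pending = new ArrayDeque<>();
      Tree<T> node = tail;
      while (node != null || !pending.isEmpty()) {
        if (node != null) {
          pending.push(node); // visit the left subtree first
          node = node.left;
        } else {
          Tree<T> next = pending.pop();
          out.add(next.mid);
          node = next.right;
        }
      }
      result = Collections.unmodifiableList(out);
      cache = result;
    }
    return result;
  }
}
```

Since `AbstractList` derives `iterator()`, `equals()` and `hashCode()` from `get()` and `size()`, the chain compares equal to any other list with the same elements. Two threads racing on the first read may both flatten the tree; that is wasted work but harmless, because each produces an equivalent list.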
A bit of googling shows that others have run into similar needs, but I didn't find an existing implementation that handles both O(1) concatenation and the stack overflow concern.
So to build the `RangeMap`, we can first wrap each `Mission` in a `Chain`, let the merge process run, and finally flatten each merged chain of missions:
```java
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.mu.util.stream.BiStream.biStream;
import static com.google.mu.util.stream.GuavaCollectors.toDisjointRanges;

ImmutableRangeMap<LocalDate, ImmutableSet<Agent>> agentsTimeline =
    biStream(missions)
        .mapKeys(Mission::timeWindow)
        .mapValues(Chain::of)
        .collect(toDisjointRanges(Chain::concat)) // BiStream<Range, Chain<Mission>>
        .mapValues(
            overlapping -> overlapping.stream()
                .flatMap(mission -> mission.agents().stream())
                .collect(toImmutableSet()))
        .collect(ImmutableRangeMap::toImmutableRangeMap);
```
It still feels a bit verbose to first wrap each `Mission` in a `Chain` and then unwrap them at the end. So I created another `toDisjointRanges()` overload that hides some of the implementation details. The above code then simplifies to:
```java
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.mu.util.stream.BiStream.biStream;
import static com.google.mu.util.stream.GuavaCollectors.toDisjointRanges;

ImmutableRangeMap<LocalDate, ImmutableSet<Agent>> agentsTimeline =
    biStream(missions)
        .mapKeys(Mission::timeWindow)
        .collect(toDisjointRanges()) // BiStream<Range, Chain<Mission>>
        .mapValues(
            overlapping -> overlapping.stream()
                .flatMap(mission -> mission.agents().stream())
                .collect(toImmutableSet()))
        .collect(ImmutableRangeMap::toImmutableRangeMap);
```
To be honest, I simplified the use case a bit to make it easier to explain. The original motivating use case also involves nested `Map`s.
Suppose that instead of `ImmutableRangeMap<LocalDate, Set<Agent>>`, which says nothing about which mission(s) each agent was working on at the time, we want `ImmutableRangeMap<LocalDate, SetMultimap<Agent, MissionId>>`, which records, for each time window, both the agents and the missions they worked on.
With the `BiStream` API, we can change the above code to turn each `Mission` into a stream of mappings from `Agent` to `MissionId`:
```java
import static com.google.mu.util.stream.BiStream.biStream;
import static com.google.mu.util.stream.GuavaCollectors.toDisjointRanges;
import static com.google.mu.util.stream.GuavaCollectors.toImmutableSetMultimap;

ImmutableRangeMap<LocalDate, ImmutableSetMultimap<Agent, MissionId>> agentsTimeline =
    biStream(missions)
        .mapKeys(Mission::timeWindow)
        .collect(toDisjointRanges()) // BiStream<Range, Chain<Mission>>
        .mapValues(
            overlapping -> biStream(overlapping)
                .flatMapKeys(mission -> mission.agents().stream())
                .mapValues(Mission::id)
                .collect(toImmutableSetMultimap()))
        .collect(ImmutableRangeMap::toImmutableRangeMap);
```
Before we conclude, there is one more gotcha: even with all that effort, you can still end up with quadratic range merging.
Consider this:
```
{
  [0..1]: a,
  [0..2]: b,
  [0..3]: c,
  [0..4]: d,
  [0..5]: e,
  ...
}
```
When `TreeRangeMap` merges them in the given order, the number of range merges is 1 + 2 + 3 + 4 + ..., which is quadratic. There is not much the `Chain` class can do to help, because merging is managed and triggered by `TreeRangeMap` itself.
To avoid the quadratic merging problem, it may be prudent to first sort the ranges by length in descending order. For the example above, merging in descending order cuts the number of merges to 4 for five ranges (n - 1 for n such nested ranges); in other words, O(n).
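To see the difference in numbers, here is a simplified JDK-only model. It treats each input range `[0..end]` on integers as `(-1, end]`, keeps the disjoint subranges as `(lowerExclusive, upperInclusive]` pairs, and counts one merge for each existing subrange that a new range overlaps, splitting it when needed. This mirrors the splitting behavior described above, but it is not `TreeRangeMap`'s actual code, and `MergeCounter` is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.List;

/**
 * Simplified model (not TreeRangeMap's real code): merges ranges [0..end] one at a time into
 * a sorted list of disjoint subranges and counts how many existing subranges get merged.
 */
final class MergeCounter {
  static int countMerges(int... ends) {
    // Disjoint subranges as {lowerExclusive, upperInclusive}; [0..e] is modeled as (-1, e].
    List<int[]> segments = new ArrayList<>();
    int merges = 0;
    for (int end : ends) {
      List<int[]> next = new ArrayList<>();
      int covered = -1; // highest point already covered by some subrange
      for (int[] seg : segments) {
        int lo = seg[0];
        int hi = seg[1];
        covered = Math.max(covered, hi);
        if (lo < end) {
          merges++; // this subrange overlaps [0..end], so its value gets merged
          if (hi > end) {
            next.add(new int[] {lo, end}); // merged part
            next.add(new int[] {end, hi}); // untouched remainder
            continue;
          }
        }
        next.add(seg);
      }
      if (end > covered) {
        next.add(new int[] {covered, end}); // uncovered gap: stored as-is, nothing to merge
      }
      segments = next;
    }
    return merges;
  }
}
```

With five ranges, the ascending order yields 0 + 1 + 2 + 3 + 4 = 10 merges and the descending order yields 4. For 1,000 nested ranges the counts are 499,500 versus 999.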