Some time ago, I was doing a problem on HackerRank that introduced me to two new data structures that I want to write about. The problem is called Cross the River.

The premise is this:

You're standing on a shore of a river. You'd like to reach the opposite shore.

The river can be described with two straight lines on the Cartesian plane, describing the shores. The shore you're standing on is $Y=0$ and another one is $Y=H$.

There are some rocks in the river. Each rock is described with its coordinates and the number of points you'll gain in case you step on this rock.

You can choose the starting position arbitrarily on the first shore. Then, you will make jumps. More precisely, you can jump to the position $(X_2,Y_2)$ from the position $(X_1,Y_1)$ in case $\left|Y_2-Y_1\right| \leq dH$, $\left|X_2-X_1\right| \leq dW$ and $Y_2>Y_1$. You can jump only on the rocks and the shores.

What is the maximal sum of scores of all the used rocks you can obtain so that you cross the river, i.e. get to the opposite shore?

No two rocks share the same position, and it is guaranteed that there exists a way to cross the river.

Now, my first instinct was to use dynamic programming. If $Z_i$ is the point value of rock $i$, and $S_i$ is the max score at rock $i$, then $$ S_i = \begin{cases} Z_i + \max\{S_j : 1 \leq Y_i - Y_j \leq dH,~|X_i - X_j| \leq dW\} &\text{if rock $i$ is reachable} \\ -\infty &\text{otherwise,} \end{cases} $$ where we assume the existence of rocks with $Y$ coordinate $0$ and point value $0$ for all $X.$

Thus, we can sort the rocks by their $Y$ coordinate and visit them in order. However, we run into the problem that if $dW$ and $dH$ are large we may need to check a large number of rocks visited previously, so this approach is $O(N^2).$
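For reference, the quadratic approach can be sketched as follows. This is only a minimal illustration of the recurrence, not the submitted solution: the class name `NaiveCrossTheRiver`, the `{x, y, z}` array layout, and the `UNREACHABLE` sentinel are all made up here, and it assumes every rock satisfies $0 < Y < H$ and that a direct shore-to-shore jump is allowed when $H \leq dH$.

```java
import java.util.Arrays;

// Hypothetical helper: evaluates the recurrence directly in O(N^2).
final class NaiveCrossTheRiver {
    static final long UNREACHABLE = Long.MIN_VALUE;

    /** rocks[i] = {x, y, z}; returns the best total score, or UNREACHABLE. */
    static long solve(int[][] rocks, int H, int dH, int dW) {
        int[][] sorted = rocks.clone();
        Arrays.sort(sorted, (a, b) -> Integer.compare(a[1], b[1])); // by Y
        int n = sorted.length;
        long[] s = new long[n];
        // Assumption: jumping straight from shore to shore scores 0.
        long best = (H <= dH) ? 0 : UNREACHABLE;
        for (int i = 0; i < n; ++i) {
            int x = sorted[i][0], y = sorted[i][1], z = sorted[i][2];
            // The starting shore acts as a zero-score rock at every X.
            long prev = (y <= dH) ? 0 : UNREACHABLE;
            for (int j = 0; j < i; ++j) { // scan every earlier rock
                int dy = y - sorted[j][1];
                if (s[j] != UNREACHABLE && dy >= 1 && dy <= dH
                        && Math.abs(x - sorted[j][0]) <= dW) {
                    prev = Math.max(prev, s[j]);
                }
            }
            s[i] = (prev == UNREACHABLE) ? UNREACHABLE : prev + z;
            // A reachable rock within dH of the far shore can finish the crossing.
            if (s[i] != UNREACHABLE && H - y <= dH) best = Math.max(best, s[i]);
        }
        return best;
    }
}
```

The inner loop over `j` is exactly the part that the two data structures below are meant to replace.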

My dynamic programming approach was the right idea, but it needs some improvements. Somehow, we need to speed up the process of looking through the previous rocks. To do this, we do two things:

  1. Implement a way to quickly find the max score in a range $[X-dW, X + dW]$
  2. Only store the scores of rocks in range $[Y-dH, Y)$

To accomplish these tasks, we use two specialized data structures.

Segment Trees

Segment trees solve the first problem. They provide a way to query a value (such as a maximum or minimum) over a range and to update these values in $O(\log N)$ time. The key idea is to use a binary tree, where the nodes correspond to segments instead of indices.

For example, suppose that we have $N$ indices $i = 0,1,\ldots, N-1$ with corresponding values $v_i.$ Let $k$ be the smallest integer such that $2^k \geq N.$ The root node of our binary tree will be the interval $[0,2^k).$ The first left child will be $[0,2^{k-1}),$ and the first right child will be $[2^{k-1},2^k).$ In general, for a node $[a,b)$ with $b - a > 1$, let $m = a + (b-a)/2$ be the midpoint; then the left child is $[a,m),$ and the right child is $[m,b).$ Otherwise, if $b - a = 1$, there are no children, and the node is a leaf. For example, if $5 \leq N \leq 8$, our segment tree looks like this.

    [0, 8)
        [0, 4)
            [0, 2)
                [0, 1)
                [1, 2)
            [2, 4)
                [2, 3)
                [3, 4)
        [4, 8)
            [4, 6)
                [4, 5)
                [5, 6)
            [6, 8)
                [6, 7)
                [7, 8)

In general, there are $2^0 + 2^1 + 2^2 + \cdots + 2^k = 2^{k+1} - 1$ nodes needed. Since $2^{k-1} < N \leq 2^k$ when $N \geq 2$, we have $2N - 1 \leq 2^{k+1} - 1 \leq 4(N-1) - 1$, so the amount of memory needed is $O(N).$ Here's the code for constructing the tree.

import java.util.Arrays;

class MaxSegmentTree {
    private long[] maxes; // maxes[node] is the max over that node's segment
    private int size;
    public MaxSegmentTree(int size) {
        int actualSize = 1;
        while (actualSize < size) actualSize *= 2;
        this.size = actualSize;
        // if size is 2^k, we need 2^(k+1) - 1 nodes for all the intervals
        maxes = new long[2*actualSize - 1];
        Arrays.fill(maxes, Long.MIN_VALUE);
    }
    ...
}
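To make the array layout concrete, here is a small sketch that lists each node's segment in array order. The class name `SegmentLayout` is made up for illustration. It assumes the root sits at index 0 and the children of node `n` sit at indices `2n + 1` and `2n + 2`, which is the same as the `2*(node + 1) - 1` and `2*(node + 1)` used by the recursive methods in this post.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper: shows which segment each array slot represents.
final class SegmentLayout {
    /** size must be a power of two; returns the segments in array order. */
    static List<String> segments(int size) {
        String[] out = new String[2 * size - 1];
        fill(out, 0, 0, size);
        return Arrays.asList(out);
    }

    private static void fill(String[] out, int node, int l, int r) {
        out[node] = "[" + l + ", " + r + ")";
        if (r - l == 1) return;           // leaf: no children
        int mid = l + (r - l) / 2;
        fill(out, 2 * node + 1, l, mid);  // left child
        fill(out, 2 * node + 2, mid, r);  // right child
    }
}
```

For `size = 4`, slots 0 through 6 hold `[0, 4)`, `[0, 2)`, `[2, 4)`, `[0, 1)`, `[1, 2)`, `[2, 3)`, `[3, 4)`, so the leaves end up in the last `size` slots.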

Now, for each node $[a,b),$ we store the value $\max(v_a,v_{a+1},\ldots,v_{b-1}).$ An update call takes two parameters, an index $k$ and a new value $v_k.$ We traverse the binary tree until we reach the leaf $[k, k+1)$ and update that node. Then, we update the max of each ancestor by taking the max of its left and right children, since each child's segment is always contained in its parent's segment. In practice, this is done recursively like this.

class MaxSegmentTree {
    ...
    public long set(int key, long value) {
        return set(key, value, 0, 0, this.size);
    }
    /** 
     * @param node index of node since binary tree is implemented with array
     * @param l    lower bound of segment (inclusive)
     * @param r    upper bound of segment (exclusive)
     */
    private long set(int key, long value,
                     int node, int l, int r) {
        // if not in range, do not set anything
        if (key < l || key >= r) return maxes[node]; 
        if (l + 1 == r) {
            // return when you reach a leaf
            maxes[node] = value;
            return value;
        }
        int mid = l + (r-l)/2;
        // left node
        long left = set(key, value, 2*(node + 1) - 1, l, mid);
        // right node
        long right = set(key, value, 2*(node + 1), mid, r);
        maxes[node] = Math.max(left, right);
        return maxes[node];
    }
    ...
}

A range max query takes two parameters: the lower bound and the upper bound of the range, in the form $[i,j).$ We obtain the max recursively. Let $[l,r)$ be the segment corresponding to a node. If $[l,r) \subseteq [i,j),$ we return the max associated with $[l,r)$. If $[l,r) \cap [i,j) = \emptyset,$ we ignore this node. Otherwise, $[l,r) \cap [i,j) \neq \emptyset,$ and $\exists k \in [l,r)$ such that $k \not\in [i,j),$ so $l < i < r$ or $l < j < r.$ In this case, we descend to the child nodes. The algorithm looks like this.

class MaxSegmentTree {
    ...
    /** 
     * @param i from index, inclusive
     * @param j to index, exclusive
     * @return the max value in a segment.
     */
    public long max(int i, int j) {
        return max(i, j, 0, 0, this.size);
    }

    private long max(int i, int j, int node, int l, int r) {
        // if in interval
        if (i <= l && r <= j) return maxes[node];
        // if completely outside interval
        if (j <= l || i >= r ) return Long.MIN_VALUE;
        int mid = l + (r-l)/2;    
        long left = max(i, j, 2*(node+1) - 1, l, mid);
        long right = max(i, j, 2*(node+1), mid, r);
        return Math.max(left, right);
    }
    ...
}
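To experiment with updates and queries, here is a compact, self-contained variant of the same idea. It is a sketch and not the class above: the name `TinyMaxSegmentTree` is made up, `set` returns nothing and only descends into the child that contains the key, and the caller is assumed to pass keys within the tree's range.

```java
import java.util.Arrays;

// Hypothetical compact variant of the max segment tree, for experimentation.
final class TinyMaxSegmentTree {
    private final long[] maxes;
    private final int size; // number of leaves, a power of two

    TinyMaxSegmentTree(int n) {
        int s = 1;
        while (s < n) s *= 2;
        size = s;
        maxes = new long[2 * s - 1];
        Arrays.fill(maxes, Long.MIN_VALUE);
    }

    /** Assumes 0 <= key < size. */
    void set(int key, long value) { set(key, value, 0, 0, size); }

    private void set(int key, long value, int node, int l, int r) {
        if (l + 1 == r) { maxes[node] = value; return; } // leaf [key, key+1)
        int mid = l + (r - l) / 2;
        if (key < mid) set(key, value, 2 * node + 1, l, mid);
        else set(key, value, 2 * node + 2, mid, r);
        // recompute this ancestor from its two children
        maxes[node] = Math.max(maxes[2 * node + 1], maxes[2 * node + 2]);
    }

    /** Max over [i, j); Long.MIN_VALUE if nothing has been set there. */
    long max(int i, int j) { return max(i, j, 0, 0, size); }

    private long max(int i, int j, int node, int l, int r) {
        if (i <= l && r <= j) return maxes[node];        // fully covered
        if (j <= l || i >= r) return Long.MIN_VALUE;     // disjoint
        int mid = l + (r - l) / 2;
        return Math.max(max(i, j, 2 * node + 1, l, mid),
                        max(i, j, 2 * node + 2, mid, r));
    }
}
```

After `set(0, 3)`, `set(1, -2)`, `set(2, 7)` and `set(4, 1)` on a tree built with `n = 5`, the query `max(0, 2)` returns 3 and `max(2, 5)` returns 7. Query bounds that extend past the ends of the tree are harmless, because the out-of-range part never intersects any node.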

Next, I'll prove that this operation is $O(\log_2 N).$ To simplify things, let us assume that $N$ is a power of $2$, so $2^k = N.$ I claim that the worst case is $[i,j) = [1, 2^k - 1).$ Clearly this is true when $k = 2$ since we'll have to visit all the nodes but $[0,1)$ and $[3,4),$ which are disjoint from the query and return immediately, so we visit $5 = 4k - 3 = 4\log_2 N - 3$ nodes.

Now, for our induction hypothesis we assume that the operation is $O(\log_2 N)$ for $1,2,\ldots, k - 1$. Then, for some $k$, we can assume that $i < 2^{k-1}$ and $j > 2^{k-1}$ since otherwise, we only descend one half of the tree, and it reduces to the $k - 1$ case. Now, given $[i, j)$ and some node $[l,r)$, we'll stop there if $[i,j) \cap [l,r) = \emptyset$ or $[l,r) \subseteq [i,j).$ Otherwise, we'll descend to the node's children. Now, we have assumed that $i < 2^{k-1} < j,$ so if we're on the left side of the tree, $j > r$ for all such nodes. We're not going to visit any nodes with $r \leq i,$ we'll stop at nodes with $l \geq i$ and compare their max, and we'll descend into nodes with $l < i < r$. At any given node on the left side, if $[l,r)$ is not a leaf and $l < i < r$, we'll choose to descend. Let the left child be $[l_l, r_l)$ and the right child be $[l_r,r_r)$. The two child segments are disjoint, so we will only choose to descend one of them since only one of $l_l < i < r_l$ or $l_r < i < r_r$ can be true. Since $l_l = l < i$, we'll stop only at the right child if $l_r = i.$ If $i$ is not odd, we'll stop before we reach a leaf. Thus, the worst case is when $i$ is odd.

On the right side, we reach a similar conclusion, where we stop when $r_l = j,$ and so the worst case is when $j$ is odd. To see this visually, here's an example of the query $[1,7)$ when $k = 3.$ Nodes where we visit the children are marked "descend". Nodes where we compare a max are marked "max". Unmarked nodes are disjoint from the query.

    [0, 8)  descend
        [0, 4)  descend
            [0, 2)  descend
                [0, 1)
                [1, 2)  max
            [2, 4)  max
                [2, 3)
                [3, 4)
        [4, 8)  descend
            [4, 6)  max
                [4, 5)
                [5, 6)
            [6, 8)  descend
                [6, 7)  max
                [7, 8)

Thus, we'll descend at $2k - 1 = 2\log_2 N - 1$ nodes and compare maxes at $2(k-1) = 2(\log_2 N - 1)$ nodes, so $4\log_2 N - 3$ nodes are visited.
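The count is easy to check numerically. The sketch below (the class name `QueryCost` is made up for illustration) mirrors the structure of the range query but returns the number of nodes touched instead of a max. As in the argument above, nodes that are disjoint from the query are not counted.

```java
// Hypothetical helper: counts the nodes a range max query descends into
// or compares, ignoring nodes that are disjoint from the query.
final class QueryCost {
    /** Query [i, j) against the subtree whose root segment is [l, r). */
    static int visited(int i, int j, int l, int r) {
        if (j <= l || i >= r) return 0; // disjoint: not counted
        if (i <= l && r <= j) return 1; // fully covered: compare its max
        int mid = l + (r - l) / 2;
        // partially covered: count this node, then descend into both children
        return 1 + visited(i, j, l, mid) + visited(i, j, mid, r);
    }
}
```

For example, `QueryCost.visited(1, 7, 0, 8)` gives 9, which matches $4k - 3$ for $k = 3$, and `QueryCost.visited(1, 1023, 0, 1024)` gives 37 for $k = 10$.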

Max Queues

Now, the segment tree contains the max score at each $X$ coordinate, but we want our segment tree to only contain values corresponding to rocks that are within range of our current position. If our current height is $Y$, we want rocks $j$ if $0 < Y - Y_j \leq dH.$

Recall that we visit the rocks in order of their $Y$ coordinate. Thus, for each $X$ coordinate we add the rock to some data structure when we visit it, and we remove it when it becomes out of range. Since rocks with smaller $Y$ coordinates become out of range first, this is a first in, first out (FIFO) situation, so we use a queue.

However, when removing a rock, we need to know when to update the segment tree. So, the queue needs to keep track of maxes. We can do this with two queues. The primary queue is a normal queue. The second queue will contain a monotone decreasing (non-strict) sequence. Upon adding to the queue, we maintain this invariant by removing all the smaller elements from its tail. In this way, the head of the second queue will always contain the max element since it would have been removed otherwise. When we remove an element from the max queue, if the two heads are equal in value, we remove the head of each queue. Here is the code.

import java.util.ArrayDeque;
import java.util.Deque;

class MaxQueue<E extends Comparable<? super E>> extends ArrayDeque<E> {
    // non-increasing sequence of candidate maxes; the head is the current max
    private Deque<E> q;
    public MaxQueue() {
        super();
        q = new ArrayDeque<E>();
    }

    @Override
    public void clear() {
        q.clear();
        super.clear();
    }

    @Override
    public E poll() {
        // the departing element is the current max only if the heads match
        if (!super.isEmpty() && q.peek().equals(super.peek())) q.poll();
        return super.poll();
    }

    @Override
    public E remove() {
        if (!super.isEmpty() && q.peek().equals(super.peek())) q.remove();
        return super.remove();
    }

    @Override
    public boolean add(E e) {
        // remove all the smaller elements from the tail so q stays non-increasing
        while (!q.isEmpty() && q.peekLast().compareTo(e) < 0) q.pollLast();
        q.addLast(e);
        return super.add(e);
    }

    @Override
    public boolean offer(E e) {
        // remove all the smaller elements from the tail so q stays non-increasing
        while (!q.isEmpty() && q.peekLast().compareTo(e) < 0) q.pollLast();
        q.offerLast(e);
        return super.offer(e);
    }

    /** @return the largest element currently in the queue */
    public E max() {
        return q.element();
    }
}
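The two-queue idea is easier to follow on a small trace. Below is a stripped-down sketch specialised to `long` values; the class name `LongMaxQueue` and its methods are made up for illustration and are not part of the solution. Smaller candidates are discarded from the tail of the second deque, so its head is always the max of whatever is still in the primary queue.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Hypothetical minimal max queue over long values.
final class LongMaxQueue {
    private final Deque<Long> items = new ArrayDeque<>();
    // candidate maxes, non-increasing from head to tail
    private final Deque<Long> maxes = new ArrayDeque<>();

    void add(long v) {
        // anything smaller than v can never be the max again
        while (!maxes.isEmpty() && maxes.peekLast() < v) maxes.pollLast();
        maxes.addLast(v);
        items.addLast(v);
    }

    /** Removes and returns the oldest element; throws if the queue is empty. */
    long poll() {
        long head = items.removeFirst();
        // drop the candidate only if the departing element is the current max
        if (maxes.peekFirst() == head) maxes.pollFirst();
        return head;
    }

    /** Throws NoSuchElementException if the queue is empty. */
    long max() { return maxes.getFirst(); }

    boolean isEmpty() { return items.isEmpty(); }
}
```

Adding 5, 3, 4 in that order leaves the candidates as 5, 4 because 3 is discarded when 4 arrives. After polling the 5, `max()` returns 4 even though 3 is now the oldest element in the primary queue.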

Solution

With these two data structures the solution is pretty short. We keep one segment tree that stores the current max at each $X$ coordinate. For each $X$, we keep a queue to keep track of all possible maxes. The one tricky part is to make sure that we look at all rocks at a certain height before updating the segment tree since lateral moves are not possible. Each rock is only added and removed from a queue once, and we can find the max in $\log$ time, so the running time is $O(N\log N)$, where $N$ is the number of rocks. Here's the code.

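import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.StringTokenizer;

public class CrossTheRiver {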

    private static final int MAX_X = 100000;
    ...
    public static void main(String[] args) throws IOException {
        BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
        PrintWriter out = new PrintWriter(new BufferedWriter(new OutputStreamWriter(System.out)));
        StringTokenizer st = new StringTokenizer(in.readLine());
        int N = Integer.parseInt(st.nextToken()); // rocks
        int H = Integer.parseInt(st.nextToken()); // height
        int dH = Integer.parseInt(st.nextToken()); // max y jump
        int dW = Integer.parseInt(st.nextToken()); // max x jump        
        Rock[] rocks = new Rock[N];
        for (int i = 0; i < N; ++i) { // read through rocks
            st = new StringTokenizer(in.readLine());           
            int Y = Integer.parseInt(st.nextToken());
            int X = Integer.parseInt(st.nextToken()); // 0 index
            int Z = Integer.parseInt(st.nextToken());
            rocks[i] = new Rock(X, Y, Z);
        }        
        Arrays.sort(rocks);                
        long[] cumulativeScore = new long[N];
        // Long.MIN_VALUE marks a rock that cannot be reached (and was never queued)
        Arrays.fill(cumulativeScore, Long.MIN_VALUE);
        MaxSegmentTree sTree = new MaxSegmentTree(MAX_X + 1);
        ArrayList<MaxQueue<Long>> maxX = new ArrayList<MaxQueue<Long>>(MAX_X + 1);
        for (int i = 0; i <= MAX_X; ++i) maxX.add(new MaxQueue<Long>());
        int i = 0; // current rock
        int j = 0; // in range rocks
        while (i < N) {
            int currentY = rocks[i].y;
            while (rocks[j].y < currentY - dH) {
                // clear out rocks that are out of range; unreachable rocks
                // were never added to a queue, so skip them
                if (cumulativeScore[j] > Long.MIN_VALUE) {
                    maxX.get(rocks[j].x).poll();
                    if (maxX.get(rocks[j].x).isEmpty()) {
                        sTree.set(rocks[j].x, Long.MIN_VALUE);
                    } else {
                        sTree.set(rocks[j].x, maxX.get(rocks[j].x).max());
                    }
                }
                ++j;
            }
            while (i < N && rocks[i].y == currentY) {
                // get previous max score from segment tree
                long previousScore = sTree.max(rocks[i].x - dW, rocks[i].x + dW + 1);
                // rocks within dH of the starting shore can also start from score 0
                if (rocks[i].y <= dH && previousScore < 0) previousScore = 0;
                if (previousScore > Long.MIN_VALUE) {  // make sure rock is reachable
                    cumulativeScore[i] = rocks[i].score + previousScore;
                    // keep max queue up to date
                    maxX.get(rocks[i].x).add(cumulativeScore[i]);
                }
                ++i;
            }
            // now update segment tree, skipping unreachable rocks
            for (int k = i - 1; k >= 0 && rocks[k].y == currentY; --k) {
                if (cumulativeScore[k] > Long.MIN_VALUE
                        && cumulativeScore[k] == maxX.get(rocks[k].x).max()) {
                    sTree.set(rocks[k].x, cumulativeScore[k]);
                }
            }
        }

        long maxScore = Long.MIN_VALUE;
        for (i = N - 1; i >= 0 && H - rocks[i].y <= dH; --i) {
            if (maxScore < cumulativeScore[i]) maxScore = cumulativeScore[i];
        }
        out.println(maxScore);
        in.close();
        out.close();
    }
}