It's currently a rainy day up here in Saratoga Springs, NY, so I've decided to write about a problem that introduced me to Fenwick trees, otherwise known as binary indexed trees. A Fenwick tree exploits the binary representation of indices to speed up prefix-sum computations from $O(N)$ to $O(\log N)$, in a similar way to how binary lifting speeds up the lowest common ancestor problem.

Consider a vector $X = (x_1,x_2,\ldots,x_N)$. A common problem is to find the sum of a range of elements. Define $S_{jk} = \sum_{i=j}^kx_i$. We want to compute $S_{jk}$ quickly for any $j$ and $k$. The first thing to do is define $F(k) = \sum_{i=1}^kx_i$, where we let $F(0) = 0$. Then, we can rewrite $S_{jk} = F(k) - F(j-1)$, so our problem reduces to finding a way to compute $F$ quickly.

If we compute $F$ directly from $X$, evaluating $F(k)$ is an $O(N)$ operation, but incrementing some $x_i$ is an $O(1)$ operation. On the other hand, we can precompute $F$ with some dynamic programming, since $F(k) = x_k + F(k-1)$. Then evaluating $F$ is an $O(1)$ operation, but if we increment $x_i$, we have to update $F(j)$ for all $j \geq i$, which is an $O(N)$ operation. The Fenwick tree makes both of these operations $O(\log N)$.

Fenwick Tree

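The key idea of the Fenwick tree is to cache certain $S_{jk}$. Suppose that we wanted to compute $F(n)$. First, we write $n$ in binary, $$n = d_02^0 + d_12^1 + \cdots + d_m2^m,~\text{where}~d_i \in \{0, 1\}.$$ If we remove all the terms where $d_i = 0$, we can rewrite this as $$ n = 2^{i_1} + 2^{i_2} + \cdots + 2^{i_l},~\text{where}~i_1 < i_2 < \cdots < i_l. $$ Let $n_j = \sum_{k = j}^l 2^{i_k}$, and $n_{l + 1} = 0$, so that $n_1 = n$ and each $n_{k+1}$ is $n_k$ with its lowest term $2^{i_k}$ removed. Then, we have that $$ F(n) = \sum_{k=1}^{l} S_{n_{k+1}+1,n_k}. $$ Writing $(a,b]$ for the indices $i$ with $a < i \leq b$, the $k$th term sums the elements in $(n_{k+1}, n_k]$. Since $n$ has at most $\lfloor \log_2 n \rfloor + 1$ bits, there are $O(\log N)$ terms. For example, for $n = 12 = 2^2 + 2^3$, we first sum the elements in $(8,12]$ and then the elements in $(0,8]$.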

We can represent these intervals as nodes in a tree. Like a binary heap, the tree is stored as an array: node $i$ is stored at index $i$, and its value is the sum of $x_j$ for $j$ in the interval $(p_i,i]$, where $p_i$ is the parent of node $i$. Node $0$ is the root and covers no elements.

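Now, suppose $X = (58,62,96,87,9,46,64,54,87,7,51,84,33,69,43)$, so $N = 15$. Writing each node as index: value, our tree holds 1: 58, 2: 120, 3: 96, 4: 303, 5: 9, 6: 55, 7: 64, 8: 476, 9: 87, 10: 94, 11: 51, 12: 229, 13: 33, 14: 102, and 15: 43. For instance, node $12$ covers $(8,12]$, so it holds $x_9 + x_{10} + x_{11} + x_{12} = 87 + 7 + 51 + 84 = 229$.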

Calculating $F(n)$

Suppose we wanted to calculate $F(13)$. We start at node $13$ and walk up towards the root, adding the values of all the nodes that we visit. In this case, we find that $F(13) = x_{13} + (x_9 + \cdots + x_{12}) + (x_1 + \cdots + x_8) = 33 + 229 + 476 = 738$. Writing the visited nodes in binary reveals a curious thing: \begin{align*} 13 &= \left(1101\right)_2 \\ 12 &= \left(1100\right)_2 \\ 8 &= \left(1000\right)_2 \\ 0 &= \left(0000\right)_2. \end{align*} At each step, we simply remove the rightmost 1 bit, so finding the parent node is easy.

Updating a Fenwick Tree

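This is a little trickier, but it uses the same idea. Suppose that we want to increment $x_n$ by $\delta$. First, we increase the value of node $n$ by $\delta$. Recall that we can write $$ n = 2^{i_1} + 2^{i_2} + \cdots + 2^{i_l},~\text{where}~i_1 < i_2 < \cdots < i_l. $$ Now if $j < i_1$, node $n + 2^{j}$ is a child of node $n$: its interval is $(n, n + 2^j]$, which does not contain $x_n$, and the same holds for every node strictly between $n$ and $n + 2^{i_1}$. Thus, the next node we need to update is $n + 2^{i_1}$. Adding $2^{i_1}$ carries into a higher bit, so that node's interval starts at or below $n - 2^{i_1}$ and does contain $x_n$. We repeat this process of adding the rightmost 1 bit and updating the value of the node until the index exceeds $N$. For instance, if we add $4$ to $x_5$, we update nodes $5 = (101)_2$, $6 = (110)_2$, and $8 = (1000)_2$; the next index would be $16 > 15$.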

Two's Complement

If you read the above carefully, you'll note that we often need to find the rightmost 1 bit: we subtract it when summing and add it when updating. Because signed integers are represented in two's complement, there's an elegant trick that makes finding the rightmost 1 bit easy.

Consider a 32-bit signed integer with bits $n = b_{31}b_{30}\cdots b_{0}$. For $0 \leq i \leq 30$, $b_i = 1$ indicates a term of $2^i$. If $b_{31} = 0$, then $n$ is nonnegative and $$ n = \sum_{i \in \left\{0 \leq i \leq 30~:~b_i = 1\right\}}2^i. $$ On the other hand, if $b_{31} = 1$, we still have the same terms, but we subtract $2^{31}$, so $$ n = -2^{31} + \sum_{i \in \left\{0 \leq i \leq 30~:~b_i = 1\right\}}2^i, $$ which makes $n$ negative.

As an example of the result of flipping $b_{31}$, we have \begin{align*} 49 &= (00000000000000000000000000110001)_{2} \\ -2147483599 &= (10000000000000000000000000110001)_{2}. \end{align*}

Now, consider the operation of negation. Fix $x$ to be a positive integer. Let $y$ be such that $-2^{31} + y = -x$; solving, we find that $$y = -x + 2^{31} = -x + 1 + \sum_{i=0}^{30}2^i.$$ Since $\sum_{i=0}^{30}2^i - x$ flips bits $b_0$ through $b_{30}$ of $x$, $y$ is the positive integer we get by flipping all the bits of $x$ except $b_{31}$ and adding $1$. Then $-x = -2^{31} + y$ has $b_{31}$ set as well, so $-x$ is obtained by flipping every bit of $x$ and adding $1$. Using $49$ as an example again, we see that \begin{align*} 49 &= (00000000000000000000000000110001)_{2} \\ -49 &= (11111111111111111111111111001111)_{2}. \end{align*}

This process looks something like this, where $a$ stands for the bits above the rightmost 1 and $\bar{a}$ for those bits flipped: $$ x = (a\,1\,0\cdots0)_2 \xrightarrow{\text{Flip bits}} (\bar{a}\,0\,1\cdots1)_2 \xrightarrow{+1} (\bar{a}\,1\,0\cdots0)_2 = -x. $$ In this way, $x$ and $-x$ share the same rightmost 1 bit and the zeros below it, while every bit above it differs. Thus, the bitwise AND $x \mathbin{\&} (-x)$ gives us the rightmost 1 bit.

Fenwick Tree Implementation

Using this trick, the implementation of the Fenwick tree is just a couple dozen lines. My implementation is adapted for an $X$ that is $0$-indexed.

#include <vector>
using namespace std;

class FenwickTree {
  vector<int> tree;
public:
  explicit FenwickTree(int N) : tree(N + 1, 0) {}
  // sums the elements from 0 to i inclusive
  int sum(int i) {
    if (i < 0) return 0;
    ++i; // use 1-indexing, we're actually summing first i + 1 elements
    int n = static_cast<int>(tree.size()) - 1; // number of elements stored
    if (i > n) i = n;
    int res = 0;
    while (i > 0) {
      res += tree[i];
      i -= (i & -i); // clear the lowest set bit (two's complement trick)
    }
    return res;
  }
  // sums the elements from i to j inclusive
  int sum(int i, int j) {
    return sum(j) - sum(i - 1);
  }
  // adds delta to element i
  void update(int i, int delta) {
    ++i;  // convert to 1-indexing
    while (i < static_cast<int>(tree.size())) {
      tree[i] += delta;
      i += (i & -i); // move to the next node whose interval contains i
    }
  }
};

Vika and Segments

The problem that originally motivated me to learn about Fenwick trees is Vika and Segments. Here's the problem statement:

Vika has an infinite sheet of squared paper. Initially all squares are white. She introduced a two-dimensional coordinate system on this sheet and drew $n$ black horizontal and vertical segments parallel to the coordinate axes. All segments have width equal to $1$ square, that means every segment occupy some set of neighbouring squares situated in one row or one column.

Your task is to calculate the number of painted cells. If a cell was painted more than once, it should be calculated exactly once.

The first thing to do is merge horizontal segments in the same row that overlap (or touch), and likewise vertical segments in the same column. The basic idea is to count all the squares painted by the horizontal segments, which no longer overlap one another. Then, we sort the vertical segments by their x-coordinates and sweep from left to right.

For each vertical segment, we add the squares that it covers and subtract the squares it shares with horizontal segments, since those were already counted. This is where the Fenwick tree comes into play.

Each vertical segment has endpoints $(x,y_1)$ and $(x,y_2)$, where $y_1 < y_2$. As we sweep from left to right, we keep track of which horizontal segments are active, that is, which ones cross the current column $x$. Let $Y$ be an array of $0$s and $1$s indexed by $y$. We set $Y[y] = 1$ when a horizontal segment in row $y$ starts, and $Y[y] = 0$ when it ends; since merged segments in the same row don't overlap, at most one of them is active at a time. Every time we encounter a vertical segment, we want to compute $\sum_{y = y_1}^{y_2}Y[y]$, which we can do quickly with the Fenwick tree.

Now, the range of possible coordinates is far too large to index $Y$ directly, so we first compress the $y$-coordinates into buckets. I believe the comments in the code explain the details clearly enough.

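#include <algorithm>
#include <iostream>
#include <map>
#include <unordered_map>
#include <utility>
#include <vector>
using namespace std;
// assumes the FenwickTree class above is in the same file

struct LineEndpoint {
  int x, y;
  bool start, end;
  LineEndpoint(int x, int y, bool start) : x(x), y(y), start(start), end(!start)
  {}
};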

void joinLines(map<int, vector<pair<int, int>>> &lines) {
  for (auto &entry : lines) {
    vector<pair<int, int>> &segments = entry.second;
    sort(segments.begin(), segments.end());
    vector<pair<int, int>> newLineSegments;
    newLineSegments.push_back(segments.front());
    for (size_t i = 1; i < segments.size(); ++i) {
      if (newLineSegments.back().second + 1 >= segments[i].first) { // overlapping or adjacent: join
        // make line as large as possible
        newLineSegments.back().second = max(segments[i].second, newLineSegments.back().second);
      } else { // start a new segment
        newLineSegments.push_back(segments[i]);
      }
    }
    segments.swap(newLineSegments);
  }
}

int main(int argc, char *argv[]) {
  ios::sync_with_stdio(false); cin.tie(NULL);
  int N; cin >> N; // number of segments
  map<int, vector<pair<int, int>>> horizontalLines; // index by y coordinate
  map<int, vector<pair<int, int>>> verticalLines; // index by x coordinate
  for (int n = 0; n < N; ++n) { // go through segments
    int x1, y1, x2, y2;
    cin >> x1 >> y1 >> x2 >> y2;
    if (y1 == y2) {
      if (x1 > x2) swap(x1, x2);
      horizontalLines[y1].emplace_back(x1, x2);
    } else if (x1 == x2) {
      if (y1 > y2) swap(y1, y2);
      verticalLines[x1].emplace_back(y1, y2);
    }
  }  
  // first join horizontal and vertical segments that coincide  
  joinLines(horizontalLines); joinLines(verticalLines);
  /* now compress coordinates
   * partition range so that queries can be answered exactly
   */
  vector<int> P; 
  for (const auto &lineSegments : verticalLines) {
    for (const auto &lineSegment : lineSegments.second) {
      P.push_back(lineSegment.first - 1);
      P.push_back(lineSegment.second);
    }
  }
  sort(P.begin(), P.end());
  P.resize(unique(P.begin(), P.end()) - P.begin());
  /* Now let M = P.size(). We have M + 1 partitions.
   * (-INF, P[0]], (P[0], P[1]], (P[1], P[2]], ..., (P[M - 2], P[M-1]], (P[M-1],INF]
   */
  unordered_map<int, int> coordinateBucket;
  for (int i = 0; i < P.size(); ++i) coordinateBucket[P[i]] = i;
  // running count of blackened squares
  long long blackenedSquares = 0;
  // collect the horizontal segment endpoints to prepare for a sweep;
  // each LineEndpoint records (x-coordinate, y-coordinate, start or end flag)
  vector<LineEndpoint> horizontalLineEndpoints;
  for (const auto &lineSegments : horizontalLines) {
    for (const auto &lineSegment : lineSegments.second) {
      horizontalLineEndpoints.emplace_back(lineSegment.first, lineSegments.first, true); // start
      horizontalLineEndpoints.emplace_back(lineSegment.second, lineSegments.first, false); //end
      // horizontal lines don't coincide with one another, count them all
      blackenedSquares += lineSegment.second - lineSegment.first + 1;
    }
  }  
  // now prepare to scan vertical lines from left to right
  sort(horizontalLineEndpoints.begin(), horizontalLineEndpoints.end(), 
       [](const LineEndpoint &a, const LineEndpoint &b) -> bool {
         if (a.x != b.x) return a.x < b.x;
         if (a.start != b.start) return a.start; // add lines before removing them
         return a.y < b.y;
       });
  FenwickTree horizontalLineState(P.size() + 1);
  vector<LineEndpoint>::iterator lineEndpoint = horizontalLineEndpoints.begin();
  for (const auto &lineSegments : verticalLines) {
    /* update the horizontal line state
     * process endpoints that occur before vertical line      
     * add line if it occurs at the vertical line
     */
    while (lineEndpoint != horizontalLineEndpoints.end() && 
           (lineEndpoint -> x < lineSegments.first ||
            (lineEndpoint -> x == lineSegments.first && lineEndpoint -> start))) {
      int bucketIdx = lower_bound(P.begin(), P.end(), lineEndpoint -> y) - P.begin();
      if (lineEndpoint -> start) { // add the line
        horizontalLineState.update(bucketIdx, 1);
      } else if (lineEndpoint -> end) { // remove the line
        horizontalLineState.update(bucketIdx, -1);
      }
      ++lineEndpoint;
    }
    for (const auto &lineSegment : lineSegments.second) {
      // count all squares
      blackenedSquares += lineSegment.second - lineSegment.first + 1;
      // subtract away intersections, make sure we start at first bucket that intersects with line
      blackenedSquares -= horizontalLineState.sum(coordinateBucket[lineSegment.first - 1] + 1, 
                                                  coordinateBucket[lineSegment.second]);
    }    
  }
  cout << blackenedSquares << endl;
  return 0;
}
