BIT - redultimate/utility GitHub Wiki

概要

  • Binary indexed tree
  • Fenwick Tree
    • Peter M. Fenwick, "A New Data Structure for Cumulative Frequency Tables" (1994)
  • 数列a1, ......, aNに対して次の2つの処理をO(logn)で実現するデータ構造
  1. 1からiまでの和を計算する.
  2. ある値を加算する.
  • 値の更新が必要ない場合には, 単に配列とかに記録すれば済む場合が多い.

計算量

  • O(logn)

実装例

  • 1次元BITクラス
template<class T> class BIT {

   public:
      vector<T> bit;
      int SIZE;

      void init(int N, T a) {
         SIZE = N + 1;
         bit.resize(SIZE);
         for (int i = 0; i < SIZE; i++) {
            bit[i] = a;
         }
      }
      // 1-indexed 
      void add(int a, T w) {
         for (int x = a; x <= SIZE; x += x & -x) bit[x] += w;
      }
      
      // [1, a], 1-indexed
      T sum(int a) {
         T ret = 0;
         for (int x = a; x > 0; x -= x & -x) ret += bit[x];
         return ret;
      }

      // [a, b], 1-indexed
      T sum(int a, int b) {
         return sum(b) - sum(a - 1);
      }

      // get k-th number with binary search, k: 0-indexed
      int getk(T k) {
         k++;
         T ret = 0;
         int N = 1;
         while (N < SIZE) N *= 2;
         for (int i = N / 2; i > 0; i /= 2) {
            if (ret + i < SIZE && bit[ret + i] < k) {
               k = k - bit[ret + i];
               ret = ret + i;
            }
         }
         return ret + 1;
      }
};
  • 2次元BIT
int N;
int bit[1010][1010];
void add(int a, int b, int w) {
   for (int x = a; x <= M; x += x & -x)
      for (int y = b; y <= N; y += y & -y) {
         bit[x][y] += w;
      }
   }
}
int sum(int a, int b) {
   int ret = 0;
   for (int x = a; x > 0; x -= x & -x) {
      for (int y = b; y > 0; y -= y & -y) {
         ret += bit[x][y];
      }
   }
   return ret;
}
  • 2次元BITクラス(ただし, 大きいところで未確認.)
const int NMAX = 200010;
// vector < vector<int> > bit(NMAX, vector<int>(NMAX));
//int bit[NMAX][NMAX];
template<class T> class BIT2 {

      int N, M;
      map<ipair, T> log;
   public:
      // x, y
      void init(int W, int H, T a) {
         N = H;
         M = W;
         // bit.resize(NMAX);
         // for (int i = 0; i < NMAX; i++) bit[i].resize(NMAX);
         // bit.assign(NMAX, vector<T>(NMAX, a));
         // fill(bit[0], bit[M], a);
      }

      void add(int a, int b, T w) {
         for (int x = a; x <= M; x += x & -x) {
            for (int y = b; y <= N; y += y & -y) {
               // bit[x][y] += w;
               log[ipair(x, y)] += w;
            }
         }
      }
      // [1, a] x [1, b]
      T sum(int a, int b) {
         T ret = 0;
         for (int x = a; x > 0; x -= x & -x) {
            for (int y = b; y > 0; y -= y & -y) {
               // ret += bit[x][y];
               ret += log[ipair(x, y)];
            }
         }
         return ret;
      }

      // [a1, a2] x [b1, b2]
      T sum2(int a1, int a2, int b1, int b2) {
         T ret = 0;
         ret += sum(a2, b2);
         ret -= sum(a2, b1 - 1);
         ret -= sum(a1 - 1, b2);
         ret += sum(a1 - 1, b1 - 1);
         return ret;
      }
};

注意

  • 添字は1から!
  • それに伴って, 初期化時にSIZE = N + 1としている.
  • SIZEは2のべきにする必要はない.
  • 2次元BITクラスは大きいところで未確認.

使用例

  • 1次元BITクラスを使用.
  • 追加とk番目に小さい要素を答える問題で, BIT+二分探索をクラスに追加した.
  • 和ではないので, add(x, y)ではなくadd(x, 1)とする. 削除はadd(x, -1).
  • getk確認済み
int main() {   
   int Q;
   cin >> Q;

   BIT<int> bit;
   bit.init(200000, 0);
   for (int q = 0; q < Q; q++) {
      int t, x;
      cin >> t >> x;
      if (t == 1) {
         bit.add(x, 1);
      }
      else {
         int ans = bit.getk(x - 1);
         cout << ans << endl;
         bit.add(ans, -1);
      }
   }
   return 0;
}
  • 1次元BITクラスを使用.
  • sum確認済み
int main() {
   int N, Q;
   cin >> N >> Q;
   BIT<int> bit;
   bit.init(100000, 0);
   for (int q = 0; q < Q; q++) {
      int c, x, y;
      cin >> c >> x >> y;
      if (c == 0) {
         bit.add(x, y);
      }
      else {
         int ans = bit.sum(x, y);
         cout << ans << endl;
      }
   }
   return 0;
}
  • 1次元BITクラスを使用.
  • 単純な反転数に加えて, 要素数よりも取りうる数が大きいので, 座標圧縮っぽいことを前処理としてする必要がある.
  • 反転数の処理は, j番目の要素についてa[j]以下ですでに動かしたやつ以外の数を足しあげる. 「すでに動かしたやつの和」を考えながら更新していくのでBITを使うのが自然.
int main() {   
   int N;
   cin >> N;
   int A[N];
   map<int, int> log; 
   for (int i = 0; i < N; i++) {
      cin >> A[i];
      log[A[i]] = 1;
   }
   int num = 1;
   for (auto& l : log) {
      if(l.second == 0) continue;
      l.second = num;
      num++;
   }
   for (int i = 0; i < N; i++) {
      A[i] = log[A[i]];
   }
   
   BIT<int> bits;
   bits.init(N, 0);
   ll ans = 0;
   for (int i = 0; i < N; i++) {
      ans += i - bits.sum(A[i]);
      bits.add(A[i], 1);
   }
   cout << ans << endl;
   return 0;
}
  • 1次元BITクラスを使用.
  • 3つの要素からなる.
  • 座標圧縮, 2次元っぽいBIT, 包除原理.
  • 2次元っぽいBIT部分が難しかった. 普通の2次元BITだとデータの持ち方をmapなどで工夫したとしても間に合わない.(たとえばこの提出)
  • 冷静にこの問題の制限を見てみると, x, yに重複がなかったりするので, 単に順番にupdateしていけばまさにBITの使いどころ.
  • すなわち, 左側のBITと右側のBITを持っておけば, あらかじめ全ての点は与えられるので, 順番に見ていけば2つのBITで2次元BITを表すことができる.
  • フルのコードはこちら
int main() {
   int N;
   cin >> N;
   
   int x[N], y[N];
   vector<int> xs, ys;

   for (int i = 0; i < N; i++) {
      cin >> x[i] >> y[i];
      xs.push_back(x[i]);
      ys.push_back(y[i]);
   }
   sort(xs.begin(), xs.end());
   xs.erase(unique(xs.begin(), xs.end()), xs.end());
   sort(ys.begin(), ys.end());
   ys.erase(unique(ys.begin(), ys.end()), ys.end());

   for (int i = 0; i < N; i++) {
      x[i] = lower_bound(xs.begin(), xs.end(), x[i]) - xs.begin();
      x[i]++;
      y[i] = lower_bound(ys.begin(), ys.end(), y[i]) - ys.begin();
      y[i]++;
   }
   
   BIT<int> bitL, bitR;
   bitL.init(N + 1, 0);
   bitR.init(N + 1, 0);
   for (int i = 0; i < N; i++) {
      bitR.add(y[i], 1);
   }
   int log[N];
   for (int i = 0; i < N; i++) {
      log[x[i] - 1] = i;
   }
   sort(x, x + N);

   ll ans = 0;
   for (int i = 0; i < N; i++) {
      int y2 = y[log[x[i] - 1]];
      bitR.add(y2, -1);
      int LU = bitL.sum(y2 + 1, N);
      int RU = bitR.sum(y2 + 1, N);
      int LD = bitL.sum(1, y2 - 1);
      int RD = bitR.sum(1, y2 - 1);
      ll tmp;
      tmp = 1;
      mul(tmp, llmodpow(2, LU));
      mul(tmp, llmodpow(2, RU));
      mul(tmp, llmodpow(2, LD));
      mul(tmp, llmodpow(2, RD));
      add(ans, tmp);
      tmp = llmodpow(2, LU) - 1;
      mul(tmp, llmodpow(2, RD) - 1);
      mul(tmp, llmodpow(2, RU));
      mul(tmp, llmodpow(2, LD));
      add(ans, tmp);
      tmp = llmodpow(2, RU) - 1;
      mul(tmp, llmodpow(2, LD) - 1);
      mul(tmp, llmodpow(2, LU));
      mul(tmp, llmodpow(2, RD));
      add(ans, tmp);
      tmp = llmodpow(2, LU) - 1;
      mul(tmp, llmodpow(2, RU) - 1);
      mul(tmp, llmodpow(2, LD) - 1);
      mul(tmp, llmodpow(2, RD) - 1);
      sub(ans, tmp);
      bitL.add(y2, 1);
   }
   ans %= MOD;
   if (ans < 0) ans += MOD; 
   cout << ans << endl;
   return 0;
}

参考資料

⚠️ **GitHub.com Fallback** ⚠️