诗海

我们越谦卑,就离真理越近

0%

线段树模板

本文主要来介绍一种竞赛题中常见的数据结构:线段树(Segment Tree)。线段树可以在O(logN)的时间复杂度内实现单点修改/区间修改,区间查询(区间求和,求区间最大值,求区间最小值)等操作。

对于普通的数组,比如说a = [1, 3, -2, 8, 7],我们可以在O(1)时间内完成单个元素值a[i]的修改,但是如果想进行区间求和运算,就必须用O(N)时间遍历数组;同样地,如果我们使用前缀和数组presum,我们可以在O(1)的时间内完成a[l...r]区间的求和运算,但是如果说原始数组a[i]的值发生了改变,那么就需要O(N)时间来修改前缀和数组presum[i + 1...N]。因此,线段树就是一种更好的数据结构,能够兼顾两者。

Template

如果原始数组a的长度为n,那么用树状图表示a[0...n-1]的区间范围,一共需要2n - 1个树节点。

对于单点修改(对任意的i,修改a[i]的值为val),区间求和(对任意区间a[l...r]求和)的线段树,可以自下而上构造,下面给出的通用模板:

class SegmentTree {
private:
// here n is the size of the original input
int _n;
// _arr[1, 2 * _n - 1] are available nodes
// _arr[_n, 2 * _n - 1] are leaf nodes
vector<int> _arr;

public:
// build the segment tree, O(N)
SegmentTree(const vector<int> input) : _n(input.size()), _arr(_n * 2) {
for (int i = _n, j = 0; j < _n; ++i, ++j) {
_arr[i] = input[j];
}
for (int i = _n - 1; i > 0; --i) {
_arr[i] = _arr[i * 2] + _arr[i * 2 + 1];
}
}

// single point update, O(logN)
void update(int pos, int val) {
pos += _n;
_arr[pos] = val;

while (pos > 1) {
int left = pos;
int right = pos;
// pos is multiple of 2, means pos is left node
if (pos % 2 == 0) {
right = pos + 1;
} else {
left = pos - 1;
}
// update parent
_arr[pos / 2] = _arr[left] + _arr[right];
pos /= 2;
}
}

// range sum query, O(logN)
int sumRange(int l, int r) {
l += _n;
r += _n;
int sum = 0;
while (l <= r) {
// l is on the right node
if ((l % 2) == 1) {
sum += _arr[l];
++l;
}
// r is on the left node
if ((r % 2) == 0) {
sum += _arr[r];
--r;
}
l /= 2;
r /= 2;
}
return sum;
}
};

注意:以上的写法不能保证区间连续,即对于一个idx = v的父节点(如果它的表示范围是[tl, tr]),它的左子节点idx = 2v不一定表示范围[tl, mid],右子节点idx = 2v + 1不一定表示范围[mid + 1, tr],因此只适用于单点修改,而不是区间修改。

区间求最大值线段树,区间求最小值线段树与此类似,只不过在p节点时存储的值(_arr[p])不是左右子节点之和(_arr[2 * p] + _arr[2 * p + 1]),而是左右子节点的最大值/最小值。如果在一个线段树中想要同时维护区间求和和求最大值/最小值操作,那么_arr存储的值就不再是单一的int而是struct NodeNode中包括sum, min, max等成员变量。

对于区间修改(a[ul..ur]的所有元素值+val),区间查询(查询a[ul...ur]的所有元素值之和)的线段树,需要自上而下构造,下面给出通用模板:

class SegmentTree {
private:
int _n;
vector<int> _arr;
vector<int> _lazy;
// tl, tr means the fixed available range represented by _arr[idx] node
void build(int idx, const vector<int> &input, int tl, int tr) {
if (tl == tr) {
// leaf node
_arr[idx] = input[tl];
} else if (tl < tr) {
int mid = tl + (tr - tl) / 2;
build(idx * 2, input, tl, mid);
build(idx * 2 + 1, input, mid + 1, tr);
_arr[idx] = _arr[idx * 2] + _arr[idx * 2 + 1];
}
}

// for parent node _arr[idx], renew its left/right child node
void pushdown(int idx, int tl, int tr) {
int mid = tl + (tr - tl) / 2;
if (tl != tr && _lazy[idx] != 0) {
_lazy[idx * 2] += _lazy[idx];
_arr[idx * 2] += _lazy[idx] * (mid - tl + 1);

_lazy[idx * 2 + 1] += _lazy[idx];
_arr[idx * 2 + 1] += _lazy[idx] * (tr - mid);

_lazy[idx] = 0;
}
}

// ul, ur means the range specified by user calling, keep constant
void add(int idx, int tl, int tr, int ul, int ur, int val) {
if (ul <= tl && tr <= ur) {
// current range[tl, tr] is entirely included
// store add record at _lazy and not go deeper
_lazy[idx] += val;
_arr[idx] += (tr - tl + 1) * val;
} else {
// currrent range [tl, tr] is partially included
int mid = tl + (tr - tl) / 2;
pushdown(idx, tl, tr);
if (ul <= mid) add(idx * 2, tl, mid, ul, ur, val);
if (ur > mid) add(idx * 2 + 1, mid + 1, tr, ul, ur, val);
_arr[idx] = _arr[idx * 2] + _arr[idx * 2 + 1];
}
}

int sum(int idx, int tl, int tr, int ul, int ur) {
if (ul <= tl && tr <= ur) {
// current range[tl, tr] is entirely included
return _arr[idx];
} else {
int mid = tl + (tr - tl) / 2;
int res = 0;
pushdown(idx, tl, tr);
if (ul <= mid) res += sum(idx * 2, tl, mid, ul, ur);
if (ur > mid) res += sum(idx * 2 + 1, mid + 1, tr, ul, ur);
return res;
}
}

public:
SegmentTree(const vector<int> &input)
: _n(input.size()), _arr(_n * 4), _lazy(_n * 4, 0) {
build(1, input, 0, _n - 1);
}

void addRange(int ul, int ur, int val) { add(1, 0, _n - 1, ul, ur, val); }

int sumRange(int ul, int ur) { return sum(1, 0, _n - 1, ul, ur); }
};

为了保证idx = v的节点(表示范围[tl, tr])其左右子节点为idx = 2v, idx = 2v + 1,同时左节点表示范围一定是[tl, mid],右节点表示范围一定是[mid + 1, tr],整个线段树要求的空间是4n(虽然一共只有2n - 1个有效的节点,_arr数组有空隙)

自上而下构造的线段树,需要4n的空间;自下而上构造的线段树,需要2n的空间

这里的tl, tr指的是当前这个Node节点(_arr[idx])所覆盖的范围,如果是Root Node,那么就是[0, _n - 1]

我们的add操作保证了最后停止的Node节点们(也许有多个,因为输入的[ul, ur]范围会被分割成若干个Node节点范围的相加),以及它们所有的父节点,存储的求和值(_arr[idx])都是最新正确的值,但是对于下面更小的子节点,则不做立刻更新,而是把待更新的值放入_lazy[idx],这种策略称作Lazy Propagation:

Lazy Propagation, which means when you perform update operations over a range, the update process affects the least nodes as possible so that the bigger the range you want to update the less time it consumes to update it. Eventually those changes will be propagates to the children and the whole array will be up to date.

References