线段树用于解决什么问题
线段树用于解决区间更新并查询的问题,比如以下操作交替进行
- updateRange
- updatePoint
- queryRange
线段树图解
下面是对长度为 7 的数组建立的线段树,每一段管理一个区间,每一段又分成两个子区间进行管理,从图中可以看出,这样的树是一颗完全二叉树,可以使用数组+递归的方式创建。
线段树的构建
使用数组建树。
void build(int s, int t, int p) {
// 对 [s,t] 区间建立线段树,当前根的编号为 p
if (s == t) {
d[p] = a[s];
return;
}
int m = s + ((t - s) >> 1);
// 移位运算符的优先级小于加减法,所以加上括号
// 如果写成 (s + t) >> 1 可能会超出 int 范围
build(s, m, p * 2), build(m + 1, t, p * 2 + 1);
// 递归对左右区间建树
d[p] = d[p * 2] + d[(p * 2) + 1]; // 根据需要可以选择不同的操作
}
数组大小保险期间是 4n,证明方法如下:
区间查询
构建线段树的主要目的是区间查询,区间查询,比如求区间的总和、求区间最大值/最小值等操作。基本思路是将其拆分成几个不相交的区间,这些不相交的区间正好就是线段树中的节点,然后合并多个区间的值。
int getsum(int l, int r, int s, int t, int p) {
// [l, r] 为查询区间, [s, t] 为当前节点包含的区间, p 为当前节点的编号
if (l <= s && t <= r)
return d[p]; // 当前区间为询问区间的子集时直接返回当前区间的和
int m = s + ((t - s) >> 1), sum = 0;
if (l <= m) sum += getsum(l, r, s, m, p * 2);
// 如果左儿子代表的区间 [s, m] 与询问区间有交集, 则递归查询左儿子
if (r > m) sum += getsum(l, r, m + 1, t, p * 2 + 1);
// 如果右儿子代表的区间 [m + 1, t] 与询问区间有交集, 则递归查询右儿子
return sum;
}
单点更新
单点更新的思路是从根节点出发,找到其所在的叶子节点,修改叶子节点的值,然后向上更新。这个过程可以递归表示。
void update_point(int index, int val) {
// 将a[index] = val
a[index] = val;
__update_point(index, val, 0, n-1, 1);
}
void __update_point(int index, int val, int s, int t, int p) {
if (s == t) {
d[p] = val;
return;
}
int mid = s + (t-s) / 2;
if (index <= mid) {
// 在左儿子上
__update_point(index, val, s, mid, p*2);
} else {
// 在右儿子上
__update_point(index, val ,mid+1, t, p*2+1);
}
// 左儿子或右儿子修改完之后,更新当前节点的值
d[p] = d[p*2] + d[p*2+1]
}
区间更新
如果要求修改区间 ,把所有包含在区间 中的节点都遍历一次、修改一次,时间复杂度无法承受。我们这里要引入一个叫做 「懒惰标记」 的东西。
每次修改,通过打标记的方式表明该节点对应的区间在某一次操作中被修改,但是不更新该节点的子节点,实质上的修改等到下一次访问带有标记的节点才进行。
void update(int l, int r, int c, int s, int t, int p) {
// [l, r] 为修改区间, c 为被修改的元素的变化量, [s, t] 为当前节点包含的区间, p
// 为当前节点的编号
if (l <= s && t <= r) {
d[p] += (t - s + 1) * c, b[p] += c;
return;
} // 当前区间为修改区间的子集时直接修改当前节点的值,然后打标记,结束修改
int m = s + ((t - s) >> 1);
if (b[p] && s != t) {
// 如果当前节点的懒标记非空,则更新当前节点两个子节点的值和懒标记值
d[p * 2] += b[p] * (m - s + 1), d[p * 2 + 1] += b[p] * (t - m);
b[p * 2] += b[p], b[p * 2 + 1] += b[p]; // 将标记下传给子节点
b[p] = 0; // 清空当前节点的标记
}
if (l <= m) update(l, r, c, s, m, p * 2);
if (r > m) update(l, r, c, m + 1, t, p * 2 + 1);
d[p] = d[p * 2] + d[p * 2 + 1];
}
在访问的时候,如果遇到带有标记的节点,那么将标记下推。
int getsum(int l, int r, int s, int t, int p) {
// [l, r] 为查询区间, [s, t] 为当前节点包含的区间, p 为当前节点的编号
if (l <= s && t <= r) return d[p];
// 当前区间为询问区间的子集时直接返回当前区间的和
int m = s + ((t - s) >> 1);
if (b[p]) {
// 如果当前节点的懒标记非空,则更新当前节点两个子节点的值和懒标记值
d[p * 2] += b[p] * (m - s + 1), d[p * 2 + 1] += b[p] * (t - m);
b[p * 2] += b[p], b[p * 2 + 1] += b[p]; // 将标记下传给子节点
b[p] = 0; // 清空当前节点的标记
}
int sum = 0;
if (l <= m) sum = getsum(l, r, s, m, p * 2);
if (r > m) sum += getsum(l, r, m + 1, t, p * 2 + 1);
return sum;
}
如果是要将区间修改为一个指定的值,而不是加上某一个值:
void update(int l, int r, int c, int s, int t, int p) {
if (l <= s && t <= r) {
d[p] = (t - s + 1) * c, b[p] = c, v[p] = 1;
return;
}
int m = s + ((t - s) >> 1);
// 额外数组储存是否修改值
if (v[p]) {
d[p * 2] = b[p] * (m - s + 1), d[p * 2 + 1] = b[p] * (t - m);
b[p * 2] = b[p * 2 + 1] = b[p];
v[p * 2] = v[p * 2 + 1] = 1;
v[p] = 0;
}
if (l <= m) update(l, r, c, s, m, p * 2);
if (r > m) update(l, r, c, m + 1, t, p * 2 + 1);
d[p] = d[p * 2] + d[p * 2 + 1];
}
int getsum(int l, int r, int s, int t, int p) {
if (l <= s && t <= r) return d[p];
int m = s + ((t - s) >> 1);
if (v[p]) {
d[p * 2] = b[p] * (m - s + 1), d[p * 2 + 1] = b[p] * (t - m);
b[p * 2] = b[p * 2 + 1] = b[p];
v[p * 2] = v[p * 2 + 1] = 1;
v[p] = 0;
}
int sum = 0;
if (l <= m) sum = getsum(l, r, s, m, p * 2);
if (r > m) sum += getsum(l, r, m + 1, t, p * 2 + 1);
return sum;
}
总结修改和访问的逻辑就是,如果当前区间在修改区间内,那么直接增加标记,如果不在,则看下自身有没有修改标记,有就下推,然后递归左右儿子,最后更新自身节点值。
代表题目类型
区间最大值/最小值/和
区间合并问题
区间合并,区间查询问题:
- 修改某一个区间的值。
- 查询区间为 中满足条件的连续最长区间值。
这类问题需要在区间更新和区间查询的基础上增加变动,在进行向上更新时需要对左右子节点的区间进行合并。