0%

Segment Tree

Segment Tree选编

原理

如何表示线段树?

  1. 直接用node构造一棵树
  2. 用一个数组模拟一棵树

三大功能:

  1. build
  2. querySum
  3. update
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
class NumArray {
SegmentTree tree;
public NumArray(int[] nums) {
tree = new SegmentTree(nums);
}

public void update(int i, int val) {
tree.update(i, val);
}

public int sumRange(int i, int j) {
return tree.querySum(i, j);
}

static class SegmentTree{
Node root = null;
SegmentTree(int[] array){
this.root = build(0, array.length-1, array);
}
Node build(int start, int end, int[] array){
if (start>end){
return null;
}
Node node = new Node(start, end);
if (start==end){
node.sum = array[start];
return node;
}
int mid = start + (end-start)/2;
node.left = build(start, mid, array);
node.right = build(mid+1, end, array);
node.sum = node.left.sum + node.right.sum;
return node;
}
int querySum(int start, int end){
return querySum(root, start, end);
}
int querySum(Node root, int start, int end){
if (start>end){
return 0;
}
if (root.start==start && root.end==end){
return root.sum;
}
int mid = root.start + (root.end-root.start)/2;
int leftsum = 0;
int rightsum = 0;
if (start<=mid){
leftsum = querySum(root.left, start, Math.min(mid, end));
}
if (end>=mid+1){
rightsum = querySum(root.right, Math.max(mid+1, start), end);
}
return leftsum+rightsum;
}
void update(int index, int value){
update(root, index, value);
}
void update(Node root, int index, int value){
if (root.start==index && root.end==index){
root.sum = value;
return;
}
int mid = root.start + (root.end-root.start)/2;
if (index <= mid){
update(root.left, index, value);
}else{
update(root.right, index, value);
}
root.sum = root.left.sum + root.right.sum;
}
}

static class Node{
int start;
int end;
int sum;
Node left;
Node right;
Node(int start, int end){
this.start = start;
this.end = end;
sum = 0;
left = null;
right = null;
}
}
}

327. Count of Range Sum

https://leetcode.com/problems/count-of-range-sum/discuss/1674377/Java-Segment-Tree-With-Explanation

prefix_sum超时,考虑用segment_tree

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
//Time Limit Exceeded
public int countRangeSum(int[] nums, int lower, int upper) {
int n=nums.length;
//pres[i]=[0,i]=pres[i-1]+nums[i]
long[] pres=new long[n];
pres[0]=nums[0];
for (int i = 1; i < n; i++) {
pres[i]=pres[i-1]+nums[i];
}
//sums(i,j)=[0,j]-[0,i-1]=pres[j]-pres[i-1]
int res=0;
for (int i = 0; i < n; i++) {
for (int j = i; j < n; j++) {
long sum= i==0 ? pres[j] : pres[j]-pres[i-1];
if(sum>=lower && sum<=upper){
res++;
}
}
}
return res;
}

https://leetcode.cn/problems/count-of-range-sum/solution/by-ac_oier-b36o/

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
class Solution {
/*
树状数组模板
*/
int idx = 0; // 树状数组索引
int[] tr = new int[3 * 100010]; // 树状数组节点(数目等于元素坑位数)

int lowBit(int x) {
return x & -x;
}

// 往u坑位加上x
void add(int u, int x) {
for (int i = u; i <= idx; i += lowBit(i)) {
tr[i] += x;
}
}

// 查询[1,x]前缀和
int query(int u) {
int res = 0;
for (int i = u; i > 0; i -= lowBit(i)) {
res += tr[i];
}
return res;
}

public int countRangeSum(int[] nums, int lower, int upper) {
/*
离散化+树状数组:
题目原意:给定数组nums,寻找nums中区间和在[lower,upper]的区间sum[i,j]个数
设当前区间区间为[i,j],移动j指针,我们只需要的求出每个以j结尾的区间有多少个是区间和在[lower,upper]即可
朴素的解法中我们可以扫描k∈[0,i-1],求适合的k有多少个,时间复杂度综合为O(N^2)
逆向思维,我们把目光放在[0,k]区间上,设sum[0,i]=sum[0,k]+sum[k+1,i]=s
其中lower<=sum[k+1,i]<=upper --> lower<=s-sum[0,k]<=upper --> s-upper<=sum[0,k]<=s-lower
我们把问题就转化为了求在[0,i-1]区间内前缀和为[s-upper,s-lower]的个数
求前缀和"个数",同时前缀和的值域范围爆炸,我们把可能出现的前缀和直接离散化分散到特定数据结构
然后统计[s-upper,s-lower]范围内的前缀和个数可以抽象成 前面前缀和出现了就在对应坑位数目+1 ,然后再进行区间求和
此处要想在O(logN)内求出某个区间的和(前缀和作差得到)就要用到树状数组
前缀和可能出现的情况数目为:3*1e5
总体时间复杂度:O(NlogN) 空间复杂度:O(N)
*/
// 将前缀和进行去重
HashSet<Long> set = new HashSet<>(); // 存储可能出现的前缀和
long preSum = 0L;
set.add(0L);
for (int num : nums) {
preSum += num;
set.add(preSum);
set.add(preSum - upper);
set.add(preSum - lower);
}
// 排序并离散化
List<Long> list = new ArrayList<>(set);
Collections.sort(list);
HashMap<Long, Integer> map = new HashMap<>(); // 存储某个前缀和对应的索引
for (long sum : list) map.put(sum, ++idx); // 树状数组索引从1开始
preSum = 0L;
add(map.get(preSum), 1); // 初始化前缀和为0的坑位+1
int res = 0;
for (int num : nums) {
preSum += num;
// a 为前缀和preSum-lower对应坑位; b 为前缀和preSum-upper对应坑位
int a = map.get(preSum - lower), b = map.get(preSum - upper);
res += query(a) - query(b - 1); // 累加[preSum-upper,preSum-lower]前缀和个数
// 记得是先求完前缀和个数再更新,保证树状数组是[0,i-1]状态下的
add(map.get(preSum), 1);
}
return res;
}
}