-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathSegmentTreeIntervalAddMax.java
68 lines (60 loc) · 1.83 KB
/
SegmentTreeIntervalAddMax.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
package Data_Structures;
public class SegmentTreeIntervalAddMax {
int n;
int[] tmax;
int[] tadd; // tadd[i] affects tmax[i], tadd[2*i+1] and tadd[2*i+2]
void push(int root) {
tmax[root] += tadd[root];
tadd[2 * root + 1] += tadd[root];
tadd[2 * root + 2] += tadd[root];
tadd[root] = 0;
}
public SegmentTreeIntervalAddMax(int n) {
this.n = n;
tmax = new int[4 * n];
tadd = new int[4 * n];
}
public int max(int from, int to) {
return max(from, to, 0, 0, n - 1);
}
int max(int from, int to, int root, int left, int right) {
if (from == left && to == right) {
return tmax[root] + tadd[root];
}
push(root);
int mid = (left + right) >> 1;
int res = Integer.MIN_VALUE;
if (from <= mid)
res = Math.max(res, max(from, Math.min(to, mid), 2 * root + 1, left, mid));
else if (to > mid)
res = Math.max(res, max(Math.max(from, mid + 1), to, 2 * root + 2, mid + 1, right));
return res;
}
public void add(int from, int to, int delta) {
add(from, to, delta, 0, 0, n - 1);
}
void add(int from, int to, int delta, int root, int left, int right) {
if (from == left && to == right) {
tadd[root] += delta;
return;
}
// push can be skipped for add, but is necessary for other operations such as set
push(root);
int mid = (left + right) >> 1;
if (from <= mid)
add(from, Math.min(to, mid), delta, 2 * root + 1, left, mid);
if (to > mid)
add(Math.max(from, mid + 1), to, delta, 2 * root + 2, mid + 1, right);
tmax[root] = Math.max(tmax[2 * root + 1] + tadd[2 * root + 1], tmax[2 * root + 2] + tadd[2 * root + 2]);
}
// tests
public static void main(String[] args) {
SegmentTreeIntervalAddMax t = new SegmentTreeIntervalAddMax(10);
t.add(0, 9, 1);
t.add(2, 4, 2);
t.add(3, 5, 3);
System.out.println(t.max(0, 9));
System.out.println(t.tmax[0] + t.tadd[0]);
System.out.println(t.max(0, 0));
}
}