思路

cr. huahua

在做题的过程中,我发现 binary search 的边界条件和各种目标(比如找 upper bound、lower bound、第一个比 target 大的数、最后一个比 target 小的数等等)很容易搞混。碰到过很多次,还是不能快速准确地写出来,所以在这里做一个总结。总结模板的灵感来自 Huahua 的以下原模板:

def binary_search(l, r, g=None, f=None):
    """Return the smallest m in [l, r) such that g(m) is true.

    Generic template: g must be monotone on [l, r), i.e. it evaluates to
    False ... False True ... True as m increases.

    Args:
        l: inclusive lower end of the search range.
        r: exclusive upper end of the search range.
        g: monotone predicate on integers. If omitted, a module-level
            function named ``g`` is used (how the original template was
            written).
        f: optional early-exit predicate. If f(m) is true for a probed m,
            that m is returned immediately. If omitted, a module-level ``f``
            is used when one exists; otherwise there is no early exit.

    Returns:
        The smallest m in [l, r) with g(m) true, or r when no such m exists
        ("not found").

    Raises:
        TypeError: if no predicate g is given and no module-level g exists.
    """
    if g is None:
        g = globals().get("g")
        if g is None:
            raise TypeError("binary_search() needs a predicate g")
    if f is None:
        f = globals().get("f")
    while l < r:
        m = l + (r - l) // 2
        if f is not None and f(m):  # m is already known to be the answer
            return m
        if g(m):
            r = m  # m qualifies; the answer is m or lies to its left
        else:
            l = m + 1  # m fails; the answer lies strictly to the right
    return l  # l == r here; equals the original r when nothing qualifies

计算组合方式

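| 目标 | m | 边界移动 | 计算区间 |
|---|---|---|---|
| 使 g(m) 为真的最小值 | `m = l + (r - l) // 2` | `if g(m): r = m`;`else: l = m + 1` | `[l, r)` |
| 使 g(m) 为真的最大值 | `m = l + (r - l + 1) // 2` | `if g(m): l = m`;`else: r = m - 1` | `(l, r]` |

注意:求最大值时 m 必须向上取整(`+1`)。否则当 `r == l + 1` 时 `m == l`,执行 `l = m` 后区间不会缩小,会导致死循环。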

Bisect

bisect_left

def bisectLeft(A, l, r, val):
    """Leftmost insertion point for val in the sorted slice A[l:r].

    Every element of A[l:result] is < val, so inserting val at the returned
    index places it before any existing equal elements (the same contract
    as bisect.bisect_left). The result is always within [l, r].
    """
    lo, hi = l, r
    while hi > lo:
        mid = (lo + hi) // 2
        if not A[mid] < val:
            hi = mid  # A[mid] >= val: mid itself may be the insertion point
            continue
        lo = mid + 1  # A[mid] < val: insertion point is strictly right of mid
    return lo

bisect_right

def bisectRight(A, l, r, val):
    """Rightmost insertion point for val in the sorted slice A[l:r].

    Every element of A[l:result] is <= val, so inserting val at the returned
    index places it after any existing equal elements (the same contract
    as bisect.bisect_right). The result is always within [l, r].
    """
    lo, hi = l, r
    while hi > lo:
        mid = (lo + hi) // 2
        if not A[mid] > val:
            lo = mid + 1  # A[mid] <= val: insertion point is past mid
            continue
        hi = mid  # A[mid] > val: mid is still a candidate
    return lo

example

A = [0,1,2,5,5,5,7,9]
# indices: 0 1 2 3 4 5 6 7

# Probe both helpers with the same targets. Expected output per target:
#   target:       5  6  0  3  9
#   bisectRight:  6  6  1  3  8
#   bisectLeft:   3  6  0  3  7
for name, search in (('bisectRight', bisectRight), ('bisectLeft', bisectLeft)):
    print(name)
    for target in (5, 6, 0, 3, 9):
        print(search(A, 0, len(A), target))

几种模板

Find the lower bound

def lowerBound(A, l, r, val):
    """Lower bound of val in the sorted half-open range A[l, r).

    Returns the first index i in [l, r) with A[i] >= val. That is the
    first occurrence of val when it is present, otherwise the first
    element bigger than val; r is returned when every element is < val.
    """
    lo, hi = l, r
    while lo < hi:
        mid = lo + (hi - lo) // 2
        # Keep mid in range when it already satisfies A[mid] >= val.
        lo, hi = (lo, mid) if A[mid] >= val else (mid + 1, hi)
    return lo

Find the upper bound

def upperBound(A, l, r, val):
    """Last index i in the range (l, r] of sorted A with A[i] <= val.

    Unlike C++ std::upper_bound, this returns the last occurrence of val
    when present, otherwise the last element smaller than val. l is an
    exclusive sentinel: it is returned when no element of A(l, r] is
    <= val. Calling with l = -1, r = len(A) - 1 searches the whole list;
    the probed index is always > l, so A[l] itself is never read.
    """
    lo, hi = l, r
    while lo < hi:
        # Upper middle: mid > lo, so `lo = mid` always shrinks the range.
        mid = hi - (hi - lo) // 2
        if A[mid] > val:
            hi = mid - 1
        else:
            lo = mid
    return hi

第一个大于 target 的值

def upperInsertId(A, l, r, val):
    """First index i in the sorted half-open range A[l, r) with A[i] > val.

    This is where val would be inserted after any equal elements. r is
    returned when every element of A[l, r) is <= val.
    """
    lo, hi = l, r
    while lo < hi:
        mid = lo + (hi - lo) // 2
        is_bigger = A[mid] > val
        if is_bigger:
            hi = mid  # mid is a candidate; look further left
            continue
        lo = mid + 1  # A[mid] <= val; the answer is right of mid
    return lo

最后一个小于 target 的值

def lowerInsertId(A, l, r, val):
    """Last index i in the range (l, r] of sorted A with A[i] < val.

    l is an exclusive sentinel: it is returned when no element of A(l, r]
    is smaller than val (with l = -1, a result of -1 means val is not
    bigger than any element).
    """
    lo, hi = l, r
    while lo < hi:
        mid = hi - (hi - lo) // 2  # upper middle, always > lo
        if not A[mid] < val:
            hi = mid - 1  # A[mid] >= val: answer is left of mid
        else:
            lo = mid  # A[mid] < val: mid is a candidate
    return lo

示例

A = [0,1,2,5,5,5,7,9]
# indices: 0 1 2 3 4 5 6 7

print('lowerBound')
# half-open [l, r) convention: search the whole list with l = 0, r = len(A)
print(lowerBound(A, 0, len(A), 5)) # 3 the first 5
print(lowerBound(A, 0, len(A), 6)) # 6 the first position bigger than the target
print(lowerBound(A, 0, len(A), 0)) # 0
print(lowerBound(A, 0, len(A), 3)) # 3

print('upperBound')
# (l, r] convention: l = -1 is the "no element qualifies" sentinel
print(upperBound(A, -1, len(A)-1, 5)) # 5
print(upperBound(A, -1, len(A)-1, 6)) # 5 the last position smaller than the target
print(upperBound(A, -1, len(A)-1, 0)) # 0
print(upperBound(A, -1, len(A)-1, 3)) # 2
print(upperBound(A, -1, len(A)-1, 9)) # 7

print('upperInsert')
# insert id, half-open [l, r) convention like lowerBound
print(upperInsertId(A, 0, len(A), 5)) # 6 position to insert another 5 after original 5s
print(upperInsertId(A, 0, len(A), 6)) # 6 position to insert the target
print(upperInsertId(A, 0, len(A), 0)) # 1
print(upperInsertId(A, 0, len(A), 3)) # 3
print(upperInsertId(A, 0, len(A), 8)) # 7
print(upperInsertId(A, 0, len(A), 9)) # 8 bigger than the largest one

print('lowerInsert')
# (l, r] convention like upperBound
print(lowerInsertId(A, -1, len(A)-1, 5)) # 2 the last element smaller than 5
print(lowerInsertId(A, -1, len(A)-1, 6)) # 5 the last smaller than 6
print(lowerInsertId(A, -1, len(A)-1, 0)) # -1 smaller than the first element
print(lowerInsertId(A, -1, len(A)-1, 3)) # 2
print(lowerInsertId(A, -1, len(A)-1, 8)) # 6 the last smaller than 8 is A[6] = 7
print(lowerInsertId(A, -1, len(A)-1, 9)) # 6 9 itself is not smaller than 9