algo10 递归、分治与二分 — demo.py 代码详解
运行方式
cd algo10_divide_conquer/code
python demo.py代码逐段详解
第1步:归并排序 — 分治法的标准模板
归并排序完美体现了分治三阶段——分解(Divide)、解决(Conquer)、合并(Combine)。
def merge_sort(arr, left=0, right=None):
if left >= right:
return [arr[left]] # 基线:单元素数组天然有序
mid = (left + right) // 2
left_sorted = merge_sort(arr, left, mid) # 解决左半
right_sorted = merge_sort(arr, mid + 1, right) # 解决右半
return merge(left_sorted, right_sorted) # 合并merge() 使用双指针法合并两个有序数组:两个指针 i 和 j 分别指向左右子数组的起始位置,每次比较 left[i] 和 right[j],取较小的放入结果,并将对应指针后移。当一个子数组耗尽后,直接将另一子数组的剩余部分追加。
复杂度分析:
第2步:逆序对计数 — 在合并过程中"顺手"统计
逆序对的数量衡量了数组的"混乱程度"。
def merge_and_count(left, right):
result = [] ; i = j = 0 ; inv_count = 0
while i < len(left) and j < len(right):
if left[i] <= right[j]:
result.append(left[i]); i += 1
else:
result.append(right[j])
inv_count += len(left) - i # ← 关键行:横跨逆序对!
j += 1
...为什么 len(left) - i? 当 right[j] 被放入结果时,说明 right[j] < left[i]。由于 left 已排序,left[i:] 中的所有元素都大于 right[j],每个都与 right[j] 构成一个逆序对。因此累加 len(left) - i。
运行示例:对 [2, 4, 1, 3, 5]:
- 分解为
[2, 4]和[1, 3, 5] [2, 4]内部:0 个逆序对[1, 3, 5]内部:0 个逆序对- 合并时:
right[0]=1比left[0]=2小 → 横跨 2 个逆序对(2,1), (4,1);然后4 > 3→ 横跨 1 个(4,3);共 3 个。
第3步:快速排序 — 三种 pivot 策略
快速排序的核心在 partition() 上——选择 pivot 并将数组分为"小于 pivot"和"大于 pivot"两部分。
Lomuto 方案
def partition_lomuto(arr, lo, hi):
pivot = arr[hi] # 以最后一个元素为 pivot
i = lo - 1 # "小元素边界"
for j in range(lo, hi):
if arr[j] <= pivot:
i += 1
arr[i], arr[j] = arr[j], arr[i] # 把小的往前换
arr[i + 1], arr[hi] = arr[hi], arr[i + 1] # pivot 归位
return i + 1i 指针始终指向"最后一个 ≤ pivot 的元素"。当 j 扫描到一个 ≤ pivot 的元素时,i 向右移一位(腾出空间),然后交换,将该元素放入"小元素区"。
随机化快排
最关键的改进:随机选择 pivot 来避免最坏情况
pivot_idx = random.randint(lo, hi)
arr[pivot_idx], arr[hi] = arr[hi], arr[pivot_idx] # 随机 swap 到末尾Hoare 方案
从两端向中间双向扫描,比 Lomuto 更高效(约少 30% 的交换次数)。
第4步:二分搜索变体
def lower_bound(arr, target):
lo, hi = 0, len(arr) # 注意 hi = len(arr),不是 len(arr)-1
while lo < hi:
mid = (lo + hi) // 2
if arr[mid] < target:
lo = mid + 1
else:
hi = mid
return lo关键区别:
hi = len(arr)而非len(arr)-1:允许返回len(arr)(所有元素都 < target)lo < hi而非lo <= hi:因为 hi 是半开区间端点arr[mid] < target而非<=:<时lo = mid+1,>=时hi = mid
实例:lower_bound([1,2,4,4,4,6,8,8,10], 4):
lo=0, hi=9, mid=4→arr[4]=4, hi=4lo=0, hi=4, mid=2→arr[2]=4, hi=2lo=0, hi=2, mid=1→arr[1]=2<4, lo=2lo=2, hi=2→ 返回 2(第一个 4 的位置)
第5步:三分搜索
用于寻找单峰函数的极值点。每次取两个三等分点,利用函数值的比较来缩小区间:
f(m1) < f(m2)→ 极值在左半区 → 缩小到[lo, m2]f(m1) > f(m2)→ 极值在右半区 → 缩小到[m1, hi]f(m1) = f(m2)→ 极值在中间 → 缩小到[m1, m2]
每次迭代区间缩小约 33%,
第6步:牛顿迭代法
几何解释:在点
演示:求
关键概念速查表
| 概念 | 要点 | 代码位置 |
|---|---|---|
| 分治三步骤 | Divide → Conquer → Combine | merge_sort() |
| 逆序对计数 | 合并时累加左半剩余元素数 | merge_and_count() |
| Lomuto 划分 | pivot=最后一个,i 维护小元素边界 | partition_lomuto() |
| 随机化快排 | 随机选 pivot 交换到末尾 | quicksort_randomized() |
| 二分下界 | 第一个 ≥ x 的位置,hi=len(arr) | lower_bound() |
| 二分答案 | 对单调性答案空间二分 | binary_search_answer() |
| 三分搜索 | 两个三等分点比较,缩小区间 | ternary_search() |
| 牛顿法 | 切线迭代,二次收敛 | newtons_method() |
完整代码
# -*- coding: utf-8 -*-
"""
algo10 递归、分治与二分 — 演示代码
==================================
功能:归并排序 + 逆序对计数、快速排序(三种 pivot 策略对比)、
二分搜索(含 lower/upper bound)、三分搜索、牛顿法求根。
每个函数都有中文 docstring,每行逻辑代码都有中文注释。
运行方式:在 algo10_divide_conquer/ 目录下执行 python code/demo.py
"""
import os
import random
import time
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams['axes.unicode_minus'] = False
_HERE = os.path.dirname(os.path.abspath(__file__))
_IMAGES = os.path.join(_HERE, '..', 'images')
os.makedirs(_IMAGES, exist_ok=True)
# ============================================================================
# 第一部分:归并排序 + 逆序对计数
# ============================================================================
def merge_sort(arr, left=0, right=None):
"""
归并排序(不修改原数组,返回新数组)。
分治步骤:分解(二分) → 解决(递归排序子数组) → 合并(merge)
"""
if right is None:
right = len(arr) - 1
if left >= right:
return [arr[left]]
mid = (left + right) // 2
# 递归排序左右两部分
left_sorted = merge_sort(arr, left, mid)
right_sorted = merge_sort(arr, mid + 1, right)
# 合并两个有序数组
return merge(left_sorted, right_sorted)
def merge(left, right):
"""合并两个有序数组。两个指针分别遍历,取较小的放入结果。"""
result = []
i = j = 0
while i < len(left) and j < len(right):
if left[i] <= right[j]:
result.append(left[i])
i += 1
else:
result.append(right[j])
j += 1
# 将剩余元素追加到结果
result.extend(left[i:])
result.extend(right[j:])
return result
def count_inversions(arr):
"""
使用分治思想统计逆序对数量。
返回: (sorted_array, inversion_count)
思路:逆序对 = 左半边逆序对 + 右半边逆序对 + 横跨逆序对
横跨逆序对在 merge 过程中计算
"""
n = len(arr)
if n <= 1:
return arr[:], 0
mid = n // 2
# 递归计算左右半边的逆序对
left, inv_left = count_inversions(arr[:mid])
right, inv_right = count_inversions(arr[mid:])
# 合并并计算横跨逆序对
merged, inv_cross = merge_and_count(left, right)
return merged, inv_left + inv_right + inv_cross
def merge_and_count(left, right):
"""
合并两个有序数组,同时统计横跨逆序对。
当 right[j] 被放入结果时,left 中剩余的元素都大于 right[j],
它们都与 right[j] 构成逆序对。
"""
result = []
i = j = 0
inv_count = 0
while i < len(left) and j < len(right):
if left[i] <= right[j]:
result.append(left[i])
i += 1
else:
result.append(right[j])
# 关键:left[i:] 中的所有元素都 > right[j]
inv_count += len(left) - i
j += 1
result.extend(left[i:])
result.extend(right[j:])
return result, inv_count
# ============================================================================
# 第二部分:快速排序(三种 pivot 策略)
# ============================================================================
def quicksort_lomuto(arr, lo, hi):
"""Lomuto 划分方案的快速排序。pivot = arr[hi](最后一个元素)。"""
if lo >= hi:
return
p = partition_lomuto(arr, lo, hi)
quicksort_lomuto(arr, lo, p - 1)
quicksort_lomuto(arr, p + 1, hi)
def partition_lomuto(arr, lo, hi):
"""
Lomuto 划分:以 arr[hi] 为 pivot。
维护 i 指针表示"小元素边界",j 扫描整个区间。
最终 pivot 位于 i+1。
"""
pivot = arr[hi]
i = lo - 1 # i 指向最后一个 ≤ pivot 的元素
for j in range(lo, hi):
if arr[j] <= pivot:
i += 1
arr[i], arr[j] = arr[j], arr[i] # 交换
# 将 pivot 放到正确位置
arr[i + 1], arr[hi] = arr[hi], arr[i + 1]
return i + 1
def quicksort_randomized(arr, lo, hi):
"""随机化快速排序:随机选择 pivot 避免最坏情况。"""
if lo >= hi:
return
# 随机选一个元素与最后一个交换作为 pivot
pivot_idx = random.randint(lo, hi)
arr[pivot_idx], arr[hi] = arr[hi], arr[pivot_idx]
p = partition_lomuto(arr, lo, hi)
quicksort_randomized(arr, lo, p - 1)
quicksort_randomized(arr, p + 1, hi)
def quicksort_hoare(arr, lo, hi):
"""Hoare 划分方案的快速排序(更高效)。"""
if lo >= hi:
return
p = partition_hoare(arr, lo, hi)
quicksort_hoare(arr, lo, p)
quicksort_hoare(arr, p + 1, hi)
def partition_hoare(arr, lo, hi):
"""
Hoare 划分:从两端向中间扫描,交换顺序错误的元素。
pivot = arr[(lo+hi)//2](中间元素)。
"""
pivot = arr[(lo + hi) // 2]
i = lo - 1
j = hi + 1
while True:
# 从左找第一个 ≥ pivot 的元素
i += 1
while arr[i] < pivot:
i += 1
# 从右找第一个 ≤ pivot 的元素
j -= 1
while arr[j] > pivot:
j -= 1
if i >= j:
return j # 注意:Hoare 返回 j,不是 pivot 的最终位置
arr[i], arr[j] = arr[j], arr[i]
# ============================================================================
# 第三部分:二分搜索及其变体
# ============================================================================
def binary_search(arr, target):
"""标准二分搜索:在有序数组中查找 target 的索引,找不到返回 -1。"""
lo, hi = 0, len(arr) - 1
while lo <= hi:
mid = (lo + hi) // 2
if arr[mid] == target:
return mid
elif arr[mid] < target:
lo = mid + 1
else:
hi = mid - 1
return -1
def lower_bound(arr, target):
"""
二分下界:返回第一个 >= target 的位置。
如果所有元素都 < target,返回 len(arr)。
"""
lo, hi = 0, len(arr)
while lo < hi:
mid = (lo + hi) // 2
if arr[mid] < target:
lo = mid + 1
else:
hi = mid
return lo
def upper_bound(arr, target):
"""
二分上界:返回第一个 > target 的位置。
如果所有元素都 <= target,返回 len(arr)。
"""
lo, hi = 0, len(arr)
while lo < hi:
mid = (lo + hi) // 2
if arr[mid] <= target:
lo = mid + 1
else:
hi = mid
return lo
def equal_range(arr, target):
"""返回 target 在有序数组中的起止范围 [first, last)。"""
lo = lower_bound(arr, target)
hi = upper_bound(arr, target)
return lo, hi
def binary_search_answer(check_func, lo, hi, eps=1e-9):
"""
二分答案:在 [lo, hi] 上二分搜索,找到满足 check_func(mid) 的最小值。
适用于答案具有单调性的场景(如最大化最小值)。
check_func(x) 返回 True 表示 x 可行。
"""
while hi - lo > eps:
mid = (lo + hi) / 2.0
if check_func(mid):
hi = mid # mid 可行 → 尝试更小的值
else:
lo = mid # mid 不可行 → 需要更大的值
return lo
# ============================================================================
# 第四部分:三分搜索与牛顿法
# ============================================================================
def ternary_search(func, lo, hi, eps=1e-9):
"""
三分搜索:在单峰函数 func 的 [lo, hi] 区间找最小值点。
每次取两个三等分点 m1, m2,比较函数值来缩小区间。
"""
while hi - lo > eps:
m1 = lo + (hi - lo) / 3.0
m2 = hi - (hi - lo) / 3.0
if func(m1) < func(m2):
hi = m2 # 极小值在左半区
elif func(m1) > func(m2):
lo = m1 # 极小值在右半区
else:
lo, hi = m1, m2 # 相等,极值在中间
return (lo + hi) / 2.0
def newtons_method(f, f_prime, x0, eps=1e-9, max_iter=100):
"""
牛顿迭代法求 f(x)=0 的根。
x_{n+1} = x_n - f(x_n) / f'(x_n)
每次在 (x_n, f(x_n)) 处作切线,切线与 x 轴交点即为 x_{n+1}。
"""
x = x0
history = [x] # 记录迭代轨迹
for _ in range(max_iter):
fx = f(x)
fpx = f_prime(x)
if abs(fx) < eps:
break
if fpx == 0:
print(' 警告: 导数为 0,牛顿法停滞')
break
x = x - fx / fpx # 牛顿迭代公式
history.append(x)
return x, history
# ============================================================================
# 第五部分:可视化与基准测试
# ============================================================================
def benchmark_sorting():
"""对比三种排序算法在不同输入规模下的运行时间。"""
sizes = [100, 500, 1000, 2000, 5000]
merge_times = []
quick_times = []
builtin_times = []
for n in sizes:
arr = [random.randint(0, 100000) for _ in range(n)]
# 归并排序
arr_copy = arr[:]
t0 = time.perf_counter()
merge_sort(arr_copy)
merge_times.append(time.perf_counter() - t0)
# 随机化快排
arr_copy = arr[:]
t0 = time.perf_counter()
quicksort_randomized(arr_copy, 0, n - 1)
quick_times.append(time.perf_counter() - t0)
# Python 内置排序(Timsort)
arr_copy = arr[:]
t0 = time.perf_counter()
arr_copy.sort()
builtin_times.append(time.perf_counter() - t0)
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(sizes, merge_times, 'o-', label='Merge Sort', linewidth=2)
ax.plot(sizes, quick_times, 's-', label='Randomized QuickSort', linewidth=2)
ax.plot(sizes, builtin_times, '^-', label='Python Timsort', linewidth=2)
ax.set_xlabel('Input Size (n)', fontsize=12)
ax.set_ylabel('Time (seconds)', fontsize=12)
ax.set_title('Sorting Algorithm Benchmark', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
plt.tight_layout()
path = os.path.join(_IMAGES, 'algo10_sort_benchmark.png')
plt.savefig(path, dpi=150, bbox_inches='tight')
plt.close()
print(f'[可视化] 排序基准测试图已保存: {path}')
# ============================================================================
# 主程序
# ============================================================================
def main():
print('=' * 60)
print('algo10 递归、分治与二分 — 演示')
print('=' * 60)
# --- 1. 归并排序 ---
print('\n### 1. 归并排序 ###')
arr1 = [38, 27, 43, 3, 9, 82, 10]
print(f'原数组: {arr1}')
sorted_arr = merge_sort(arr1)
print(f'排序后: {sorted_arr}')
# --- 2. 逆序对计数 ---
print('\n### 2. 逆序对计数 ###')
test_arrays = [
[2, 4, 1, 3, 5], # 逆序对: (2,1), (4,1), (4,3) → 3
[5, 4, 3, 2, 1], # 逆序对: C(5,2) = 10
[1, 2, 3, 4, 5], # 逆序对: 0(已排序)
]
for arr in test_arrays:
_, inv = count_inversions(arr)
print(f' {arr} → 逆序对数: {inv}')
# --- 3. 快速排序对比 ---
print('\n### 3. 快速排序(对比三种 pivot 策略) ###')
test_arr = [3, 6, 8, 10, 1, 2, 1]
print(f'原数组: {test_arr}')
for name, sort_func in [('Lomuto', quicksort_lomuto),
('Randomized', quicksort_randomized),
('Hoare', quicksort_hoare)]:
arr_copy = test_arr[:]
sort_func(arr_copy, 0, len(arr_copy) - 1)
print(f' {name} 快排结果: {arr_copy}')
# --- 4. 二分搜索 ---
print('\n### 4. 二分搜索及其变体 ###')
sorted_test = [1, 2, 4, 4, 4, 6, 8, 8, 10]
targets = [4, 5, 8]
print(f'有序数组: {sorted_test}')
for t in targets:
idx = binary_search(sorted_test, t)
lo = lower_bound(sorted_test, t)
hi = upper_bound(sorted_test, t)
rng = equal_range(sorted_test, t)
print(f' target={t}: 精确查找={idx}, lower_bound={lo}, '
f'upper_bound={hi}, range={rng}, 出现次数={rng[1]-rng[0]}')
# --- 5. 二分答案示例 ---
print('\n### 5. 二分答案示例:求 √2 ###')
def can_square_less_than_2(x):
return x * x >= 2.0 # 检查 x^2 是否 >= 2
sqrt2 = binary_search_answer(can_square_less_than_2, 0, 2.0)
print(f' 二分答案求 √2 ≈ {sqrt2:.10f} (真实值: {2.0**0.5:.10f})')
# --- 6. 三分搜索 ---
print('\n### 6. 三分搜索:求单峰函数极小值 ###')
# 示例函数:f(x) = (x-3)^2 + 2,最小值在 x=3 处
def f_quadratic(x):
return (x - 3.0) ** 2 + 2.0
min_x = ternary_search(f_quadratic, -10, 10)
print(f' f(x) = (x-3)^2 + 2 在 [-10, 10] 上的最小值点: x={min_x:.6f}')
print(f' f({min_x:.6f}) = {f_quadratic(min_x):.10f}')
# --- 7. 牛顿法 ---
print('\n### 7. 牛顿迭代法求根 ###')
# 目标:求 x^3 - 2x - 5 = 0 的根(约 2.09455...)
def f_newton(x):
return x**3 - 2*x - 5
def f_prime_newton(x):
return 3*x**2 - 2
root, history = newtons_method(f_newton, f_prime_newton, x0=2.5)
print(f' 方程: x^3 - 2x - 5 = 0')
print(f' 初始猜测 x0 = 2.5')
print(f' 求得的根: x = {root:.10f}')
print(f' 验证 f(root) = {f_newton(root):.2e}')
print(f' 迭代历史 ({len(history)} 步):')
for i, xi in enumerate(history):
print(f' x_{i} = {xi:.10f}')
# --- 8. 基准测试 ---
print('\n### 8. 排序算法基准测试 ###')
benchmark_sorting()
if __name__ == '__main__':
main()