Skip to content

algo14 线段树与树状数组 — exercise.py 练习指南

Download exercise.py

练习目标

通过三个练习巩固 BIT 和线段树的核心操作:差分 BIT 的区间更新、线段树的区间最大值、BIT 逆序对计数。

预备知识

  • BIT 的 add()query() 操作及 lowbit 原理
  • 差分数组技巧:diff[l]+=x, diff[r+1]=x 实现区间更新
  • 线段树的递归构建和查询结构
  • 逆序对的定义和 BIT 统计方法

任务清单

任务1:差分 BIT — 区间更新 + 点查询

  • 核心原理:维护原数组 A 的差分数组 D 的 BIT。区间 [l,r] 加 val → bit.add(l, val); bit.add(r+1, -val)
  • 点查询bit.query(i) 就是 A[i] 的值(因为 A[i] = sum(D[1..i]))。

任务2:线段树区间最大值

  • 不需要惰性传播!最大值在区间更新时无法简单地用 lazy 标签处理(除非是区间赋值操作)。
  • 单点更新:从叶子向上更新路径,每个节点取左右子树的最大值。
  • 区间查询:标准的三分支递归——完全包含/不相交/部分相交。

任务3:BIT 逆序对计数

  • 离散化:将原始数组的值映射到 1~N 的排名(sorted + dict)。
  • 统计方法
    • 从左到右遍历,BIT 初始为空
    • 对每个元素 x:逆序对 += i - BIT.query(x)(已遍历的元素中 > x 的个数)
    • 然后 BIT.add(x, 1)
  • 复杂度O(nlogn)

提示

  1. 差分 BIT 中注意 n+2 数组大小,为 r+1 越界留空间。
  2. 线段树最大值查询中,不相交时返回 -inf(而非 0)。
  3. BIT 逆序对需要离散化,否则值域太大 BIT 数组放不下。
py
# -*- coding: utf-8 -*-
"""
algo14 线段树与树状数组 — 练习代码
==================================
请完成以下 TODO 任务,巩固 BIT 和线段树的理解。
"""


# ============================================================================
# TODO 1: 实现差分 BIT 支持区间更新 + 点查询
# ============================================================================
class DiffBIT:
    """
    差分 BIT:维护差分数组的 BIT,支持区间更新 + 点查询。
    区间 [l, r] + val → diff[l] += val, diff[r+1] -= val
    点查询 arr[i] = sum(diff[1..i])
    """
    def __init__(self, n):
        self.n = n
        self.tree = [0] * (n + 2)  # +2 为了 diff[r+1] 不越界

    # TODO: 实现 add 和 query 方法
    # def _add(self, i, delta):
    #     while i <= self.n:
    #         self.tree[i] += delta
    #         i += i & -i

    # def _query(self, i):
    #     total = 0
    #     while i > 0:
    #         total += self.tree[i]
    #         i -= i & -i
    #     return total

    def range_add(self, l, r, val):
        """区间 [l, r] 所有元素 +val"""
        # TODO: self._add(l, val)
        # TODO: self._add(r + 1, -val)
        pass

    def point_query(self, i):
        """查询 arr[i] 的值"""
        # TODO: return self._query(i)
        pass


# ============================================================================
# TODO 2: 线段树 — 区间最大值(带单点更新)
# ============================================================================
class MaxSegmentTree:
    """
    线段树维护区间最大值(不需要惰性传播,因为最大值不可差分)。
    支持:单点更新、区间最大值查询。
    """
    def __init__(self, arr):
        self.n = len(arr)
        self.tree = [0] * (4 * self.n)
        if self.n > 0:
            self._build(1, 0, self.n - 1, arr)

    def _build(self, p, l, r, arr):
        """构建线段树。"""
        # TODO: 实现
        # if l == r:
        #     self.tree[p] = arr[l]
        #     return
        # mid = (l + r) // 2
        # self._build(p*2, l, mid, arr)
        # self._build(p*2+1, mid+1, r, arr)
        # self.tree[p] = max(self.tree[p*2], self.tree[p*2+1])
        pass

    def update(self, idx, val):
        """单点更新:arr[idx] = val"""
        self._update(1, 0, self.n - 1, idx, val)

    def _update(self, p, l, r, idx, val):
        # TODO: 实现
        pass

    def query_max(self, ql, qr):
        """区间最大值查询"""
        return self._query_max(1, 0, self.n - 1, ql, qr)

    def _query_max(self, p, l, r, ql, qr):
        # TODO: 实现
        pass


# ============================================================================
# TODO 3: BIT 上的逆序对计数
# ============================================================================
def count_inversions_bit(arr):
    """
    使用 BIT 统计逆序对。遍历数组,对每个元素 arr[i]:
    1. 查询 BIT 中有多少元素 > arr[i](即 i - rank(arr[i]))
    2. 将 arr[i] 插入 BIT
    需要先对值进行离散化(坐标压缩)。

    参数:
        arr: 整数数组
    返回:
        逆序对总数
    """
    # TODO: 1. 对 arr 进行离散化(sorted(set(arr)) → 排名)
    # TODO: 2. 从右到左遍历(或从左到右,使用不同的计数方式)
    # TODO: 3. 使用 BIT 查询 + 插入
    return 0


# ============================================================================
# 测试代码
# ============================================================================
if __name__ == '__main__':
    print('=' * 50)
    print('algo14 线段树与树状数组 — 练习')
    print('=' * 50)

    # 测试 TODO 1: 差分 BIT
    print('\n--- TODO 1: 差分 BIT — 区间更新 + 点查询 ---')
    db = DiffBIT(10)
    db.range_add(2, 5, 3)
    db.range_add(4, 7, 2)
    for i in range(1, 9):
        print(f'  arr[{i}] = {db.point_query(i)}')
    # 期望: arr[1]=0, arr[2]=3, arr[3]=3, arr[4]=5, arr[5]=5, arr[6]=2, arr[7]=2, arr[8]=0

    # 测试 TODO 2: 区间最大值线段树
    print('\n--- TODO 2: 线段树 — 区间最大值 ---')
    arr = [4, 2, 7, 1, 9, 3, 6]
    mst = MaxSegmentTree(arr)
    print(f'  数组: {arr}')
    print(f'  max[1..4] = {mst.query_max(1, 4)} (期望: 9)')
    mst.update(3, 10)
    print(f'  更新 arr[3]=10 后 max[1..4] = {mst.query_max(1, 4)} (期望: 10)')

    # 测试 TODO 3: BIT 逆序对
    print('\n--- TODO 3: BIT 逆序对计数 ---')
    test_arrs = [[2, 4, 1, 3, 5], [5, 4, 3, 2, 1], [1, 2, 3, 4, 5]]
    for a in test_arrs:
        print(f'  {a} → 逆序对: {count_inversions_bit(a)}')

    print('\n提示: 请补全上方所有 TODO 函数后再运行测试。')