跳转至

01 Trie 树与异或问题实战

本文记录了两道经典的 异或(XOR) 相关的算法题,分别涵盖了 01 Trie 树 的标准应用、Python 的内存优化技巧,以及将树上路径问题转化为数组最大异或对问题的思维转换。

题目1:CF706D Vasiliy's Multiset - 洛谷

核心思路

这道题是 “求最大异或值” 的模板题。核心思想是建立一颗 01 Trie 树(字典树),将每个数字的二进制位从高到低插入树中。 查询时,利用贪心策略:对于每一位,尽量走与当前位相反的分支(0走1,1走0),这样能让异或结果该位为1,从而最大化数值。

❌ 原始版本:二维数组写法 (MLE)

类 C++ 的写法,使用 tree[p][0]tree[p][1] 存储左右子节点。 痛点:在 Python 中,list 对象也是对象,存储大量的小列表会消耗巨大的内存头(Overhead)。这种写法在 Codeforces 上会直接 MLE (Memory Limit Exceeded)

import sys  
data = sys.stdin.read().split()  
it = iter(data)  
q = int(next(it))  
max_n = 200005 * 31  # q 和 位数  
max_bit = 29  # 从第29位枚举到第0位 (共30位)  

tree = [[0, 0] for _ in range(max_n)]  
pas = [0] * max_n  
node = 0  


def insert(x):  
    global node  
    p = 0  
    for i in range(max_bit, -1, -1):  
        u = (x >> i) & 1  
        if not tree[p][u]:  
            node += 1  
            tree[p][u] = node  
        p = tree[p][u]  
        pas[p] += 1  

def delete(x):  
    p = 0  
    for i in range(max_bit, -1, -1):  
        u = (x >> i) & 1  
        p = tree[p][u]  # 简化了删除步骤,不管是不是0,直接减  
        pas[p] -= 1  

def query(x):   # 求最大异或值通用函数  
    res = 0  
    p = 0  
    for i in range(max_bit, -1, -1):  
        u = (x >> i) & 1  
        want = 1 - u  # 想要相反的位  

        if tree[p][want] and pas[tree[p][want]] > 0:  
            res |= (1 << i)  
            p = tree[p][want]  
        else:  
            p = tree[p][u]  
    return res  


insert(0)  # 题目说集合初始包含一个0  
ans = []  
for _ in range(q):  
    opt, x = next(it), int(next(it))  
    if opt == '+':  
        insert(x)  
    elif opt == '-':  
        delete(x)  
    else:  
        ans.append(str(query(x)))  

print('\n'.join(ans))

✅ 优化版本:扁平化数组 (AC)

为了解决内存问题,我们采用 “数组扁平化” 的技巧。 我们不再使用列表套列表,而是直接开两个大的一维数组: * tree0:专门存左孩子(bit=0 的路径)。 * tree1:专门存右孩子(bit=1 的路径)。

PyPy3 解释器下,这种纯一维整数数组会被优化成接近 C 语言原生数组的内存布局,极大节省空间。 注意: 此解法必须使用 PyPy3 提交

import sys  
data = sys.stdin.read().split()  
it = iter(data)  
q = int(next(it))  
max_n = 200005 * 31  # q 和 位数  
max_bit = 29  # 从第29位枚举到第0位 (共30位)  

pas = [0] * max_n  
node = 0  

# --- 关键修改:用两个一维数组代替二维数组 ---    
# tree0[p] 相当于 tree[p][0]    
# tree1[p] 相当于 tree[p][1]    
# 这样写在 PyPy3 下极其省内存  
tree0 = [0] * max_n  
tree1 = [0] * max_n  
def insert(x):  
    global node  
    p = 0  
    for i in range(max_bit, -1, -1):  
        u = (x >> i) & 1  

        # 这里的逻辑要做对应修改:判断 u 是 0 还是 1        if u == 0:  
            if not tree0[p]:  
                node += 1  
                tree0[p] = node  
            p = tree0[p]  
        else:  
            if not tree1[p]:  
                node += 1  
                tree1[p] = node  
            p = tree1[p]  

        pas[p] += 1  


def delete(x):  
    p = 0  
    for i in range(max_bit, -1, -1):  
        u = (x >> i) & 1  
        if u == 0:  
            p = tree0[p]  
        else:  
            p = tree1[p]  
        pas[p] -= 1  

def query(x):   # 求最大异或值通用函数  
    res = 0  
    p = 0  
    for i in range(max_bit, -1, -1):  
        u = (x >> i) & 1  
        want = 1 - u  # 想要相反的位  

        # 检查 want 方向是否有路且 pas > 0        if want == 0:  
            if tree0[p] and pas[tree0[p]] > 0:  
                res |= (1 << i)  
                p = tree0[p]  
            else:  
                p = tree1[p]  # 只能走 1        else:  # want == 1  
            if tree1[p] and pas[tree1[p]] > 0:  
                res |= (1 << i)  
                p = tree1[p]  
            else:  
                p = tree0[p]  # 只能走 0    return res  


insert(0)  # 题目说集合初始包含一个0  
ans = []  
for _ in range(q):  
    opt, x = next(it), int(next(it))  
    if opt == '+':  
        insert(x)  
    elif opt == '-':  
        delete(x)  
    else:  
        ans.append(str(query(x)))  

print('\n'.join(ans))

题目2:树上最长异或路径 P4551 - 洛谷

核心思路

这道题的标准做法是 DFS 预处理 + 01 Trie 树,但我们可以借鉴 LeetCode 421 的思路,用 哈希表 + 贪心 来解决。

  1. 转化问题: 先用 DFS 求出从根节点到每个节点 \(i\) 的路径异或和 \(D[i]\)。 根据异或性质:\(u \to v\) 的路径异或和 = \(D[u] \oplus D[v]\)(公共祖先部分的异或和会被抵消为 0)。

    此时,问题就转化为了:在数组 \(D\) 中选出两个数,使它们的异或值最大。

  2. 哈希表贪心: 也就是 421. 数组中两个数的最大异或值 的解法。

    • 从高位到低位枚举答案。
    • 对于每一位,假设当前能达到的最大前缀是 new_ans
    • 利用性质:如果 \(A \oplus B = C\),那么 \(A \oplus C = B\)。我们检查是否存在一个数 \(pre\),使得 \(pre \oplus \text{new\_ans}\) 也在哈希表中。
import sys  

sys.setrecursionlimit(200000)  
data = map(int, sys.stdin.read().split())  
it = iter(data)  
n = next(it)  

path = [[] for _ in range(n + 1)]  
for _ in range(n - 1):  
    u, v, w = next(it), next(it), next(it)  
    path[u].append([w, v])  

d = [0] * (n + 1)  

def dfs(u, fa):  
    for w, v in path[u]:  
        if v != fa:  
            d[v] = d[u] ^ w  
            dfs(v, u)  

dfs(1, 0)  

high = max(d).bit_length() - 1  
ans = 0  
mask = 0  
for i in range(high, -1, -1):  
    mask |= 1 << i  
    new_ans = (1 << i) | ans  
    hashs = set()  
    for s in d:  
        pre = s & mask  
        if pre ^ new_ans in hashs:  
            ans = new_ans  
            break  
        hashs.add(pre)  
print(ans)

参考资料:灵神题解 421. 数组中两个数的最大异或值 - 力扣(LeetCode)