Trie树实战:三道典型例题

前言:

 给出模板代码,并推荐三个不错的Trie数题目,前两个比较简单,第三个比较有难度。

模板代码:

最常用的模板代码,插入和查询操作:

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const ll N = 500010; 

struct Node{
    ll son[26]; // 孩子节点
    ll cnt;     // 以当前节点为前缀的字符串数 (或作为单词结束标记)
} tr[N * 2];
ll tot;       // 节点总数

void init() // 初始化Trie
{
    memset(tr[0].son, 0, sizeof(tr[0].son));
    tr[0].cnt = 0;
    tot = 0; 
}

void insert(string s){  // 插入字符串
    ll u = 0;
    for(auto c : s){ 
        ll c_ = c - 'a';
        if(!tr[u].son[c_]){
            tr[u].son[c_] = ++tot;
            // memset(tr[tot].son, 0, sizeof(tr[tot].son));   多测需要
            // tr[tot].cnt = 0;
        }
        u = tr[u].son[c_];
        tr[u].cnt++; 
    }
}

ll query(string s){   // 查询前缀是 s 的字符串数量
    ll u = 0;
    for(auto c : s){
        ll cur = c - 'a';
        if(!tr[u].son[cur]){
            return 0; // 路径中断,不存在
        }
        u = tr[u].son[cur];
    }
    return tr[u].cnt; // 返回以s为前缀的字符串数
}
补充代码:
删除操作:        
bool del(int u, const string& s, int dep) {
    if (dep == s.length()) {
        // 1. 到达字符串末尾
        tr[u].cnt--;
        // 只有当 cnt 减为 0 且没有其他分支时,当前节点才能被删除
        return tr[u].cnt == 0; 
    }

    int cur = s[dep] - 'a';
    int nex = tr[u].son[cur]; // 下一个节点

    if (nex == 0) return false; // 字符串不存在

        // 2. 递归删除下一个节点
    if (del(nex, s, dep + 1)) {
        // 3. 子节点被删除,则清空当前节点指向它的指针
        tr[u].son[cur] = 0;

        for (int i = 0; i < 26; ++i) {
            if (tr[u].son[i] != 0) {
                return false; // 还有其他分支,不能删除
            }
        }
        return tr[u].cnt == 0; // 否则,只有当它不是其他单词的前缀时才能删除
    }

    return false; // 子节点没有被删除,则当前节点不能被删除
}
某个数的最大异或和:
// 01-Trie 节点结构 (使用 int 而不是 LL 节省空间)
struct Otr{
    int son[2]; // 只有 0 和 1 两个分支
} tr[N * 32]; // N个数字 * 32位
int tot;     // 01-Trie 节点总数

// 01-Trie 初始化
void init() { // 多测
    memset(tr[0].son, 0, sizeof(tr[0].son));
    tot = 0;
}

// 插入数字 num 的二进制表示 (从高位 M 到低位 0)
// M 是位数,例如 30
void oins(int num, int M = 30) {
    int u = 0;
    for (int i = M; i >= 0; --i) {
        int bit = (num >> i) & 1; // 当前位是 0 还是 1
        if (!tr[u].son[bit]) {
            tr[u].son[bit] = ++tot;
            // memset(tr[tot].son, 0, sizeof(tr[tot].son));  //多测需要
        }
        u = tr[u].son[bit];
    }
}

// 查询与数字 num 异或的最大值
int qmax(int num, int M = 30) {
    int u = 0;
    int ans = 0;
    for (int i = M; i >= 0; --i) {
        int bit = (num >> i) & 1;
        int rev = bit ^ 1; // 目标:走向相反路径以最大化异或结果

        if (tr[u].son[rev]) {
            // 存在相反路径,走相反路径 (异或结果为 1)
            ans |= (1 << i);
            u = tr[u].son[rev];
        } else {
            // 不存在相反路径,只能走相同路径 (异或结果为 0)
            u = tr[u].son[bit];
        }
    }
    return ans;
}
合并操作:

合并操作通常用于树形 DP 或处理树上路径问题。这里实现一个将 v 树合并到 u 树的函数。        

// 递归合并操作
// 将 v 树的节点信息合并到 u 树的节点
ll mer(ll u, ll v) {
    if (!u) return v; // u 节点空,返回 v 节点
    if (!v) return u; // v 节点空,返回 u 节点

    // 合并当前节点的信息(例如,前缀计数)
    tr[u].cnt += tr[v].cnt; 

    // 递归合并所有孩子
    for (ll i = 0; i < 26; ++i) {
        tr[u].son[i] = mer(tr[u].son[i], tr[v].son[i]);
    }
    
    // **注意:** 在某些复杂场景中,如果节点 v 是动态创建的,可能需要在这里回收 v 节点。
    // 在这里我们不处理回收,仅完成信息合并。
    return u; // 返回合并后的 u 节点
}

 MC0489黛玉葬花

分析:很前缀字符串的思路,找到一个点就代表一个字母,根据题意就是这一位置的相同前缀的数量为x , 计算一下结果 ans  +=  x *(x - 1) , 对每一个节点都这样操作一次。

做法一:Trie

#include <bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 1000010;
int nodes;
const int mod = 998244353;

struct TrieNode{
    int child[26];
    int sz;
}tr[N];

void add(string s)
{
    int u = 0;
    tr[u].sz++;
    for (int i = 0; i < s.size(); i++)
    {
        int c_num = s[i] - 'a';
        if (!tr[u].child[c_num])
        {
            tr[u].child[c_num] = ++nodes;
        }
        u = tr[u].child[c_num];
        tr[u].sz++;
    }   
}

ll C(ll x)
{
    if (x == 1)
        return 0;
    return (x - 1) * x;
}

// 2 2 1 2 3 1 1 1 1
// 2 1 1

int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int n;
    cin >> n;
    ll ans = 0;
    for (int i = 0; i < n; i++)
    {
        string s;
        cin >> s;
        ans += s.size();
        add(s);
    }
    for (int u = 1; u <= nodes; u++)
    {
        ans = (ans + C(tr[u].sz)) % mod;
    }
    cout << ans << endl;

    return 0;
}

做法二:字符串哈希(略显麻烦了)思路是大致一样的

#include<bits/stdc++.h>
using namespace std;
#define int long long
typedef unsigned long long ull;
const int N = 1e6 + 10 , P = 131;
const int mod = 998244353;
ull h[N], p[N];
ull find( int l, int r){
    return h[r] - h[l-1] * p[r-l+1];
}
int cal(int x){
    return x * (x - 1) / 2;
}
void solve(){
    unordered_map<ull, int> mp;
    int n , ans = 0;
    cin >> n;
    for (int i = 0; i < n; i++){
        string s;
        cin >> s;
        ans += s.length();
        p[0] = 1;
        for (int i = 1; i <= s.size(); i++){
            h[i] = h[i-1] * P + s[i - 1];
            p[i] = p[i-1] * P;
        }
        for (int i = 0; i < s.size(); i++){
            mp[find(1, i + 1)] ++;
        }
    }
    for (auto p : mp){
        int cnt = p.second;
        if(cnt > 1){
            ans = (ans + cal(cnt) * 2) % mod;
        }
    }
    cout << ans << endl;
}
signed main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    solve();
    return 0;
}

D1. Max Sum OR (Easy Version)

题意:题意:给出了两组0到r个数字(一个 2(r +1)个数字 , l == 0 )求两个个组中的数相互异或的和的最大值,想到前缀树,各位取反的思路,比如10101就需要去找01010或者00000才能没有损失。

做法一:Trie

求两个数异或的最大值要想到前缀树,本身就是各位取反的思维。需要注意倒序遍历,因为我们要先满足大的数字,再去满足小的数字,可以认为大的数字更加重要,先满足高位一定会产生最优解具体证明略。

#include <bits/stdc++.h>
using namespace std;
#define endl "\n"
#define ll long long
const int N = 200010, B = 19, M = N * B;
int vis[N], p[N]; // 节点的个数 每个点的每一位置代表一个节点
struct TrieNode
{
    int child[2];
    int sz;
} tr[M];

int nodes;

void init()
{
    tr[0].child[0] = tr[0].child[1] = 0;
    tr[0].sz = 0;
    nodes = 0;
}

void add(int x)
{
    int u = 0;
    tr[u].sz++; // 根节点开始
    for (int i = B - 1; i >= 0; i--)
    {
        int b_t = (x >> i) & 1;
        if (!tr[u].child[b_t])
        {
            tr[u].child[b_t] = ++nodes;
            tr[nodes].child[0] = tr[nodes].child[1] = 0;
            tr[nodes].sz = 0;
        }
        u = tr[u].child[b_t];
        tr[u].sz++;
    }
}

void del(int x)
{
    int u = 0;
    tr[u].sz;
    for (int i = B - 1; i >= 0; i--)
    {
        int b_t = (x >> i) & 1;
        u = tr[u].child[b_t];
        tr[u].sz--;
    }
}

int find(int x) // del负责修改
{
    int u = 0;
    int res = 0;
    for (int i = B - 1; i >= 0; i--)
    {
        int b_t = (x >> i) & 1;
        int b_p = 1 - b_t;
        if (tr[u].child[b_p] && tr[tr[u].child[b_p]].sz > 0)
        {
            u = tr[u].child[b_p];
            res = res | (b_p << i);
        }
        else
        {
            u = tr[u].child[b_t];
            res = res | (b_t << i);
        }
    }
    return res;
}

void solve()
{
    init();
    int l, r;
    cin >> l >> r;
    for (int i = l; i <= r; i++)
    {
        add(i);
        vis[i] = 0, p[i] = 0;
    }
    for (int i = r; i >= l; i--)
    {
        if (!vis[i])
        {
            del(i);
            int j = find(i);
            del(j);
            p[i] = j;
            p[j] = i;
            vis[i] = vis[j] = 1;
        }
    }
    cout << (ll)(r + 1) * r << endl;
    for (int i = l; i <= r; i++)
    {
        cout << p[i] << " ";
    }
    cout << endl;
}

int main()
{
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int t;
    cin >> t;
    while (t--)
        solve();
    return 0;
}

做法二:找规律。

对于任何一个配对 (i, j),我们有 j = K ^ i。根据异或的性质 A ^ B = C 等价于 A ^ C = B,我们可以推导出: i ^ j = i ^ (K ^ i) = K 这意味着,通过这种方法找到的每一对数字 (i, j),它们的异或和都等于那个掩码 K

  • i[8, 15] 之间时,它的二进制长度是4位,K 就是 15 (1111_2)。所有这个范围内的数字都会和另一个数字配对,使得异或和为15。

    • 15 ^ 0 = 15

    • 14 ^ 1 = 15

    • 13 ^ 2 = 15

    • ...

    • 8 ^ 7 = 15

  • i[4, 7] 之间时,它的二进制长度是3位,K 就是 7 (111_2)。

    • 7 ^ 0 = 7

    • 6 ^ 1 = 7

    • ...

  • 以此类推。


#include <bits/stdc++.h>
using namespace std;
#define ll long long
int masks[N] , ans[N];

int main(){
    int t;
    cin >> t;
    while(t --){
        int l, r;
        cin >> l >> r;
        cout << (ll)r * (r + 1) << endl;
        for (int i = l; i <= r; i ++)
            masks[i] = 0;
        for (int i = r; i > 0; i--)
        {
            if (!masks[i])
            {
                int j = ((1 << (32 - __builtin_clz(i))) - 1) ^ i;
                // 统计出来0的个数
                ans[i] = j;
                ans[j] = i;
                masks[i] = 1;
                masks[j] = 1;
            }
        }
        if(!masks[0]) // 如果前面没有数字配对 就代表自己和自己配对
            ans[0] = 0;
        for (int i = l; i <= r; i++){
            cout << ans[i] << " ";
        }
        cout << endl;
    }

    return 0;
}

G. Yelkrab

这题比较有难度一点,是2024香港区域赛银牌题,(AI)解析如下(太菜乐~):

解题思路
第1步:理解问题核心 f(i, j)

首先,我们需要解决最核心的子问题:对于给定的 i 个字符串和分组大小 j,如何求得 f(i, j),也就是最大的总评分?

评分规则:一个小组的评分是组内所有字符串的“最长公共前缀”(LCP) 的长度 。总评分是所有小组评分之和 。

贪心策略:一个直观且正确的想法是采用贪心策略。我们应该不断地从当前还未分组的字符串中,找出能组成最长 LCP 的 j 个字符串,把它们分为一组,然后对剩下的字符串重复此过程 。

第2步:用前缀树(Trie)优化 f(i, j) 的计算

贪心策略虽然正确,但实现起来很慢。处理 LCP 问题,最自然的工具就是前缀树

  • 前缀树的性质:我们将所有字符串插入前缀树。树上的每个节点代表一个前缀,节点的深度就等于前缀的长度。所有经过同一个节点的字符串,都共享这个节点代表的前缀。

  • 计算总评分:假设前缀树上有一个节点 x,有 sz_x 个字符串经过它(即最终都落在了 x 的子树中) 。这意味着这

    sz_x 个字符串都拥有 x 所代表的那个前缀。对于分组大小为 j 的情况,我们可以在 x 这个节点层面形成 floor(sz_x / j) 个分组。

  • 一个重要的转换:一个 LCP 长度为 L 的分组,其评分为 L。这可以看作是在其 LCP 路径上的 L 个节点(深度从1到L)处,每个节点都贡献了 1 的评分。因此,f(i, j) 的总评分可以等价于所有节点 xfloor(sz_x / j) 的总和

第3步:处理动态增加的字符串

题目要求我们为每个 i (从1到n) 都输出一个结果,这是一个动态过程。我们需要在加入第 i 个字符串后,快速更新所有 f(i, j) 的值。

增量更新:当我们加入一个新的字符串 s_i 时,它会在前缀树中走过一条路径。路径上每个节点的 sz 计数都会加 1 。

  • 关键发现:当一个节点的 szk-1 变为 k 时,floor(sz / j) 的值并不会对所有的 j 都改变。只有当 jk 的因子(除数)时,floor((k-1) / j) 才会比 floor(k / j) 小 1。

  • 更新策略:因此,在 s_i 路径上的每个节点 x,当它的 sz_x 更新后,我们只需要遍历新 sz_x 的所有因子 d,并将对应的 F[d](即 f(i,d))的值加 1 即可 。
第4步:高效计算异或和

我们需要计算的最终结果是

(f(i,1)×1) ⊕ (f(i,2)×2) ⊕ ... ⊕ (f(i,i)×i),其中 是异或操作 。

  • 问题:在加入第 i 个字符串后,很多 f(i, j) 的值都可能发生变化。如果每次都从头计算这个异或和,总复杂度会过高。

  • 解决方案:我们需要一个能够支持“单点更新”和“前缀查询”的数据结构。树状数组(Fenwick Tree) 是完美的选择。

    • 单点更新:当 F[d] 的值从 old 变为 new 时,我们需要更新的项是 F[d] * d。我们向树状数组的第 d 个位置异或上 (old * d) ^ (new * d),就可以消除旧值的影响,并加入新值。

    • 前缀查询:使用树状数组,我们可以在 O(log n) 的时间内查询到 term[1] ⊕ term[2] ⊕ ... ⊕ term[i] 的结果。

知识点:前缀异或,树状数组,Trie

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const ll N = 500010;   
ll n;
struct TrieNode{
    ll child[26];
    ll sz;
} tr[N * 2];
vector<int> divi[N];
ll F[N] , bit[N] , nodes;

void cal(){   //预处理因子
    for (ll i = 1; i < N; i++)
    {
        for (ll j = i; j < N; j += i)
        {
            divi[j].push_back(i);
        }
    }
}

void init(int n)  // 针对多组测试数据初始化
{
    memset(tr[0].child, 0, sizeof(tr[0].child));
    memset(bit , 0,  sizeof(ll) *(n + 1));
    memset(F, 0, sizeof(ll) * (n + 1));
    nodes = 0 ; 
}

void update(ll i  ,ll val){      // 双庄数组模板
    while(i <= n){
        bit[i] ^= val;
        i += i & -i;
    }
}

ll query(ll i){
    ll sum = 0 ; 
    while(i > 0){
        sum ^= bit[i];
        i -= i & -i;
    }
    return sum;
}

void add(string s){
    ll u = 0;
    tr[u].sz++;
    for(auto c : s){ 
        ll c_num = c - 'a';
        if(!tr[u].child[c_num]){
            tr[u].child[c_num] = ++nodes;
            memset(tr[nodes].child , 0 , sizeof(tr[nodes].child));
            tr[nodes].sz = 0;
        }
        u = tr[u].child[c_num];
        tr[u].sz++;
        // 上面是tire模板  下面是更新新加入的字符串对结果的影响

        for(auto d : divi[tr[u].sz]){
            ll old = F[d];
            ll neww = ++F[d];
            update(d, (old * d) ^ (neww * d));
        }
    }
}

void solve(){
    
    cin >> n;
    init(n);
    for (ll i = 1; i <= n; i ++){
        string s;
        cin >> s;
        add(s);
        // 找到所有的字符串进行操作 树状数组快速求和是怎么求的呢
        cout << query(i) << (i == n ? "\n" : " ");
    }
    return;
}

int main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    cal();
    ll t;
    cin >> t;
    while( t -- ){
        solve();
    }
    return 0;
}

K. Master of Both

2025.10.22 补充一道22杭州icpc区域赛的好题吧,中期题里面偏简单的。

题目:

       给定 n 个字符串让你求出其中的字符串的逆序对的数量,字符串的大小是按字典序比较的,相同前缀不同长度的时候,短的字符串更小,让我们求出其中的逆序的字符串的个数,并且会给你q个映射,可能26个字符的相对大小不是按照字母表来的,是按照给出字母的顺序来的。

分析:

        我们可以发现任意的两个字符串的相对大小都是由第一个不一样的字符决定的,因此我们去思考怎样得出第一个不相同的字母的数量呢,我们就想到了 trie 树去统计出现在自己前面的某个字符串的下一位是i,而当前字符串的当前位是 i ,f[i][idx] + 前面统计到的前缀字符串就代表相对顺序是先 i 后 idx 的,注意是+cnt,然后根据输入的映射进行输出即可,需要注意的是对于ab,abc这种我们没办法在上述思路去分辨出大小的,需要引入一个比 a 更小的字符在末尾,就是 char('a' - 1),这是这个题比较难想的地方之一。

#include<bits/stdc++.h>
#define ll long long
using namespace std;
// 最大节点数 (N * 2 比较大)
const ll N = 1500010; 
// Trie 节点结构体,变量名不超过四个字符
struct Nod {
    ll son[27]; // 26个子节点索引
    ll cnt;     // 记录以当前节点为前缀的单词数
    // bool end; // 如果需要标记单词结尾,可以添加
} tr[N];

ll f[27][27];

ll tot; // 当前已使用的节点总数 (代替 nodes)

void insert(string s) {
    ll u = 0;
    for (char ch : s) { 
        ll idx = ch - 'a' + 1;
        if (!tr[u].son[idx]) {
            tr[u].son[idx] = ++tot;
        }
		for(ll i = 0 ; i < 27 ; i ++){
			if(idx == i) continue;	
			if(tr[u].son[i])
				f[i][idx] += tr[tr[u].son[i]].cnt; //前缀树就自动已经默认了前面的值都是相等哒
		}
        u = tr[u].son[idx];
        tr[u].cnt++; // 前缀计数增加
	}
}

void solve(){
	ll n  ,q ;
	cin >> n >> q;
	for(ll i = 0 ; i <n ; i ++){
		string s;
		cin >> s;
		s = s + char('a' - 1);
		insert(s);
	}
	while( q -- ){
		string s;
		cin >> s;
		s = char('a' - 1) + s;
		ll ans= 0 ; 
		for(ll i = 1; i < 27 ; i ++ ){
			for(ll j = 0 ; j < i ;  j ++){
				ans += f[s[i] - 'a' + 1][s[j] - 'a' + 1] ; 
			}
		}
		cout << ans <<endl;
	}
}

signed main(){
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
	cout.tie(nullptr);
	solve();
    return 0;
}

评论
成就一亿技术人!
拼手气红包6.0元
还能输入1000个字符
 
红包 添加红包
表情包 插入表情
 条评论被折叠 查看
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值