字典树模板

本文介绍了ACWing和LeetCode中的字典树模板,展示了如何使用字典树结构实现字符串插入和查询操作,以及在LeetCode 720题中应用字典树解决最长单词问题。

介绍

字典树,也叫trie树,是用来高效存储字符串和查询字符串的数据结构!

在这里插入图片描述

1 acwing字典树模板

C++代码如下,

const int N = 100010;

int son[N][26], cnt[N], idx;
char str[N];


void insert(char str[]){
    int p = 0;
    for(int i = 0; str[i]; i++){
        int u = str[i] - 'a';
        if(!son[p][u]) son[p][u] = ++idx;
        p = son[p][u];
    }
    cnt[p]++;//以p结尾的单词数目,p表示结点编号
    return;
}


int query(char str[]){
    int p = 0;
    for(int i = 0; str[i]; i++){
        int u = str[i] - 'a';
        if(!son[p][u]) return 0;
        p = son[p][u];
    }
    return cnt[p];
}

2 leetcode字典树模板

题目1720. 词典中最长的单词

解题思路:字典树

C++代码如下,

class Trie {
public:
    Trie() {
        this->children = vector<Trie *>(26, nullptr);
        this->isEnd = false;
    }

    bool insert(const string &word) {
        Trie* node = this;
        for (const auto &ch : word) {
            int index = ch - 'a';
            if (node->children[index] == nullptr) {
                node->children[index] = new Trie();
            }
            node = node->children[index];
        }
        node->isEnd = true;
        return true;
    }

    bool search(const string& word) {
        Trie* node = this;
        for (const auto& ch : word) {
            int index = ch - 'a';
            if (node->children[index] == nullptr || !node->children[index]->isEnd) {
                return false;
            }
            node = node->children[index];
        }
        return node != nullptr && node->isEnd;
    }

private:
    vector<Trie*> children;
    bool isEnd;
};

class Solution {
public:
    string longestWord(vector<string>& words) {
        Trie trie;
        for (const auto& word : words) {
            trie.insert(word);
        }
        string longest = "";
        for (const auto& word : words) {
            if (trie.search(word)) {
                if (word.size() > longest.size() || (word.size() == longest.size() && word < longest)) {
                    longest = word;
                }
            }
        }
        return longest;
    }
};

python3代码如下,

class Trie:
    def __init__(self):
        self.children = [None] * 26
        self.isEnd = False

    def insert(self, word: str) -> None:
        node = self
        for ch in word:
            ch = ord(ch) - ord('a')
            if not node.children[ch]:
                node.children[ch] = Trie()
            node = node.children[ch]
        node.isEnd = True
    
    def search(self, word: str) -> bool:
        node = self
        for ch in word:
            ch = ord(ch) - ord('a')
            if node.children[ch] is None or not node.children[ch].isEnd:
                return False
            node = node.children[ch]
        return True

class Solution:
    def longestWord(self, words: List[str]) -> str:
        t = Trie()
        for word in words:
            t.insert(word)
        longest = ""
        for word in words:
            if t.search(word) and (len(word) > len(longest) or len(word) == len(longest) and word < longest):
                longest = word
        return longest

题目21803. 统计异或值在范围内的数对有多少

解题思路:前缀字典树。

C++代码如下,

class TrieNode {
public:
    TrieNode* children[2] = {nullptr, nullptr};
    int sum = 0;
    TrieNode() : sum(0) {}
};

class Trie {
private:
    TrieNode* root = new TrieNode();
    static constexpr int HIGH_BIT = 14;
public:
    void add(int x) {
        //将数x插入到前缀字典树中
        TrieNode* cur = root;
        for (int i = HIGH_BIT; i >= 0; --i) {
            int bit = (x >> i) & 1;
            if (cur->children[bit] == nullptr) {
                cur->children[bit] = new TrieNode();
            }
            cur = cur->children[bit];
            cur->sum += 1;
        }
        return;
    }

    int get(int x, int limit) {
        //从前缀字典树中找到y,使得y^x <= limit,求有多少个这样的y
        int res = 0;
        TrieNode* cur = root;
        for (int i = HIGH_BIT; i >= 0; --i) {
            int bit = (x >> i) & 1;
            if ((limit >> i) & 1) {
                //limit的第i为1,从高位14开始数
                if (cur->children[bit] != nullptr) {
                    res += cur->children[bit]->sum;
                }
                if (cur->children[bit^1] == nullptr) {
                    return res;
                }
                cur = cur->children[bit^1];
            } else {
                //limit的第i位为0,从高位14开始数
                if (cur->children[bit] == nullptr) {
                    return res;
                }
                cur = cur->children[bit];
            }
        }
        res += cur->sum;
        return res;
    }
};

class Solution {
public:
    int countPairs(vector<int>& nums, int low, int high) {
        function<int(int)> f =[&] (int limit) -> int {
            int res = 0;
            int n = nums.size();
            Trie trie = Trie();
            for (int i = 0; i < n-1; ++i) {
                trie.add(nums[i]);
                res += trie.get(nums[i+1], limit);
            }
            return res;
        };   
        return f(high) - f(low-1);
    }
};

python3代码如下,

HIGH_BIT = 14

class TrieNode:
    def __init__(self):
        self.children = [None, None]
        self.sum = 0

class Trie:
    def __init__(self):
        self.root = TrieNode()
    
    def add(self, x: int) -> None:
        #往前缀trie树中插入数x
        cur = self.root
        for i in range(HIGH_BIT, -1, -1):
            bit = (x >> i) & 1
            if not cur.children[bit]:
                cur.children[bit] = TrieNode()
            cur = cur.children[bit]
            cur.sum += 1
        return 
    
    def get(self, x: int, limit: int) -> int:
        #从前缀trie树中找到y,使得y^x <= limit,返回这样的y的个数
        res = 0
        cur = self.root 
        for i in range(HIGH_BIT, -1, -1):
            bit = (x >> i) & 1
            if (limit >> i) & 1:
                if cur.children[bit]:
                    #字典树中y^x的值,与limit进行比较
                    #记HIGH_BIT-i为k
                    #前k-1位相同,但第k位不同,y^x的第k为0,而limit的第k位为1
                    res += cur.children[bit].sum 
                if not cur.children[bit^1]:
                    #字典树中y^x的值,与limit进行比较
                    #记HIGH_BIT-i为k
                    #前k位都相同的数,它不存在
                    return res 
                cur = cur.children[bit^1]
            else:
                #记HIGH_BIT-i为k
                #limit的第k位为0
                #y^X的第k位的数只能为0
                if not cur.children[bit]:
                    #字典树中y^x的值,与limit进行比较
                    #记HIGH_BIT-i为k
                    #前k位都相同的数,它不存在
                    return res 
                cur = cur.children[bit]
        res += cur.sum #加上末尾结点的值 
        return res 
    
class Solution:
    def countPairs(self, nums: List[int], low: int, high: int) -> int:
        def f(limit: int) -> int:
            #nums中nums[i]^nums[j]<=limit,且i<j的(i,j)的个数
            res = 0
            trie = Trie()
            n = len(nums)
            for i in range(n-1):
                trie.add(nums[i])
                res += trie.get(nums[i+1], limit)
            return res 
        return f(high) - f(low-1)

题目3421. 数组中两个数的最大异或值

C++字典树模板,只需要定义一个TrieNode结构体即可。

struct TrieNode {
    TrieNode* children[2] = {nullptr, nullptr};
    TrieNode() {
        children[0] = nullptr;
        children[1] = nullptr;
    }
};

class Solution {
public:
    int findMaximumXOR(vector<int>& nums) {
        TrieNode* trie = new TrieNode();
        int res = 0;
        for (auto num : nums) {
            //在trie中查找num
            TrieNode* cur = trie;
            int ans = 0;
            for (int i = 31; i >= 0; --i) {
                if (cur == nullptr) break;
                int x = (num >> i) & 1;
                if (cur->children[1-x] != nullptr) {
                    ans += 1 << i;
                    cur = cur->children[1-x];
                } else {
                    cur = cur->children[x];
                }
            }
            res = max(res, ans);

            //把num写入到trie中
            cur = trie; //重新将cur置为起始结点
            for (int i = 31; i >= 0; --i) {
                int x = (num >> i) & 1;
                if (cur->children[x] == nullptr) {
                    cur->children[x] = new TrieNode();
                }
                cur = cur->children[x];
            }
        }
        return res;
    }
};

python3简洁版字典树(纯字典实现),

class Solution:
    def findMaximumXOR(self, nums: List[int]) -> int:
        trie = {}
        res = 0
        for num in nums:
            #从trie树中查找元素
            ans = 0
            cur = trie 
            for i in range(31,-1,-1):
                x = (num >> i) & 1
                if (1-x) in cur:
                    ans += (1 << i)
                    cur = cur[1-x]
                else:
                    if x in cur:
                        cur = cur[x]
                    else:
                        break 
            res = max(res, ans) #更新res
            #将num插入trie树中
            cur = trie
            for i in range(31,-1,-1):
                x = (num >> i) & 1
                if x in cur:
                    pass 
                else:
                    cur[x] = {}
                cur = cur[x]
        return res 
                

参考

acwing算法基础之数据结构–trie算法或字典树算法

oi-wiki字典树

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

YMWM_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值