0%

Trie

Prefix Tree选编

原理 (from labuladong)

https://mp.weixin.qq.com/s/hGrTUmM1zusPZZ0nA9aaNw

Trie 树又叫字典树、前缀树、单词查找树,是一种二叉树衍生出来的高级数据结构,主要应用场景是处理字符串前缀相关的操作。

Trie 树用「树枝」存储字符串(键),用「节点」存储字符串(键)对应的数据(值)

关于 MapSet,是两个抽象数据结构(接口),Map 存储一个键值对集合,其中键不重复,Set 存储一个不重复的元素集合。

常见的 MapSet 的底层实现原理有哈希表和二叉搜索树两种,比如 Java 的 HashMap/HashSet 和 C++ 的 unorderd_map/unordered_set 底层就是用哈希表实现,而 Java 的 TreeMap/TreeSet 和 C++ 的 map/set 底层使用红黑树这种自平衡 BST 实现的。

  • HashMap: 迭代顺序和插入顺序无关,迭代顺序不可预知

  • LinkedHashMap: 迭代顺序和插入顺序相同,迭代顺序可预知

  • TreeMap: 基于红黑树,映射基于key的自然顺序进行排序,或根据创建映射时提供的Comparator

而本文实现的 TrieSet/TrieMap 底层则用 Trie 树这种结构来实现。

了解数据结构的读者应该知道,本质上 Set 可以视为一种特殊的 MapSet 其实就是 Map 中的键。

HashMap<K, V> 的优势是能够在 O(1) 时间通过键查找对应的值,但要求键的类型 K 必须是「可哈希」的;而 TreeMap<K, V> 的特点是方便根据键的大小进行操作,但要求键的类型 K 必须是「可比较」的。

本文要实现的 TrieMap 也是类似的,由于 Trie 树原理,我们要求 TrieMap<V> 的键必须是字符串类型,值的类型 V 可以随意。

TrieMap 中的树节点 TrieNode 的代码实现是这样:

1
2
3
4
5
/* Trie 树节点实现 */
class TrieNode<V> {
V val = null;
TrieNode<V>[] children = new TrieNode[256];
}

这个 val 字段存储键对应的值,children 数组存储指向子节点的指针。

但是和之前的普通多叉树节点不同,TrieNodechildren 数组的索引是有意义的,代表键中的一个字符

比如说 children[97] 如果非空,说明这里存储了一个字符 'a',因为 'a' 的 ASCII 码为 97。

我们的模板只考虑处理 ASCII 字符,所以 children 数组的大小设置为 256。不过这个可以根据具体问题修改,比如改成更小的数组或者 HashMap<Character, TrieNode> 都是一样的效果。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
// 底层用 Trie 树实现的键值映射
// 键为 String 类型,值为类型 Vclass TrieMap<V> {

/***** 增/改 *****/

// 在 Map 中添加 key
public void put(String key, V val);

/***** 删 *****/

// 删除键 key 以及对应的值
public void remove(String key);

/***** 查 *****/

// 搜索 key 对应的值,不存在则返回 null
// get("the") -> 4
// get("tha") -> null
public V get(String key);

// 判断 key 是否存在在 Map 中
// containsKey("tea") -> false
// containsKey("team") -> true
public boolean containsKey(String key);

// 在 Map 的所有键中搜索 query 的最短前缀
// shortestPrefixOf("themxyz") -> "the"
public String shortestPrefixOf(String query);

// 在 Map 的所有键中搜索 query 的最长前缀
// longestPrefixOf("themxyz") -> "them"
public String longestPrefixOf(String query);

// 搜索所有前缀为 prefix 的键
// keysWithPrefix("th") -> ["that", "the", "them"]
public List<String> keysWithPrefix(String prefix);

// 判断是和否存在前缀为 prefix 的键
// hasKeyWithPrefix("tha") -> true
// hasKeyWithPrefix("apple") -> false
public boolean hasKeyWithPrefix(String prefix);

// 通配符 . 匹配任意字符,搜索所有匹配的键
// keysWithPattern("t.a.") -> ["team", "that"]
public List<String> keysWithPattern(String pattern);

// 通配符 . 匹配任意字符,判断是否存在匹配的键
// hasKeyWithPattern(".ip") -> true
// hasKeyWithPattern(".i") -> false
public boolean hasKeyWithPattern(String pattern);

// 返回 Map 中键值对的数量
public int size();
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
class TrieMap<V> {
// ASCII 码个数
private static final int R = 256;
// 当前存在 Map 中的键值对个数
private int size = 0;
// Trie 树的根节点
private TrieNode<V> root = null;

private static class TrieNode<V> {
V val = null;
TrieNode<V>[] children = new TrieNode[R];
}

/***** 增/改 *****/

// 在 map 中添加或修改键值对
public void put(String key, V val) {
if (!containsKey(key)) {
// 新增键值对
size++;
}
// 需要一个额外的辅助函数,并接收其返回值
root = put(root, key, val, 0);
}

// 定义:向以 node 为根的 Trie 树中插入 key[i..],返回插入完成后的根节点
private TrieNode<V> put(TrieNode<V> node, String key, V val, int i) {
if (node == null) {
// 如果树枝不存在,新建
node = new TrieNode<>();
}
if (i == key.length()) {
// key 的路径已插入完成,将值 val 存入节点
node.val = val;
return node;
}
char c = key.charAt(i);
// 递归插入子节点,并接收返回值
node.children[c] = put(node.children[c], key, val, i + 1);
return node;
}

/***** 删 *****/

// 在 Map 中删除 key
public void remove(String key) {
if (!containsKey(key)) {
return;
}
// 递归修改数据结构要接收函数的返回值
root = remove(root, key, 0);
size--;
}

// 定义:在以 node 为根的 Trie 树中删除 key[i..],返回删除后的根节点
private TrieNode<V> remove(TrieNode<V> node, String key, int i) {
if (node == null) {
return null;
}
if (i == key.length()) {
// 找到了 key 对应的 TrieNode,删除 val
node.val = null;
} else {
char c = key.charAt(i);
// 递归去子树进行删除
node.children[c] = remove(node.children[c], key, i + 1);
}
// 后序位置,递归路径上的节点可能需要被清理
if (node.val != null) {
// 如果该 TireNode 存储着 val,不需要被清理
return node;
}
// 检查该 TrieNode 是否还有后缀
for (int c = 0; c < R; c++) {
if (node.children[c] != null) {
// 只要存在一个子节点(后缀树枝),就不需要被清理
return node;
}
}
// 既没有存储 val,也没有后缀树枝,则该节点需要被清理
return null;
}

/***** 查 *****/

// 搜索 key 对应的值,不存在则返回 null
public V get(String key) {
// 从 root 开始搜索 key
TrieNode<V> x = getNode(root, key);
if (x == null || x.val == null) {
// x 为空或 x 的 val 字段为空都说明 key 没有对应的值
return null;
}
return x.val;
}

// 判断 key 是否存在在 Map 中
public boolean containsKey(String key) {
return get(key) != null;
}

// 判断是和否存在前缀为 prefix 的键
public boolean hasKeyWithPrefix(String prefix) {
// 只要能找到一个节点,就是存在前缀
return getNode(root, prefix) != null;
}

// 在所有键中寻找 query 的最短前缀
public String shortestPrefixOf(String query) {
TrieNode<V> p = root;
// 从节点 node 开始搜索 key
for (int i = 0; i < query.length(); i++) {
if (p == null) {
// 无法向下搜索
return "";
}
if (p.val != null) {
// 找到一个键是 query 的前缀
return query.substring(0, i);
}
// 向下搜索
char c = query.charAt(i);
p = p.children[c];
}

if (p != null && p.val != null) {
// 如果 query 本身就是一个键
return query;
}
return "";
}

// 在所有键中寻找 query 的最长前缀
public String longestPrefixOf(String query) {
TrieNode<V> p = root;
// 记录前缀的最大长度
int max_len = 0;

// 从节点 node 开始搜索 key
for (int i = 0; i < query.length(); i++) {
if (p == null) {
// 无法向下搜索
break;
}
if (p.val != null) {
// 找到一个键是 query 的前缀,更新前缀的最大长度
max_len = i;
}
// 向下搜索
char c = query.charAt(i);
p = p.children[c];
}

if (p != null && p.val != null) {
// 如果 query 本身就是一个键
return query;
}
return query.substring(0, max_len);
}

// 搜索前缀为 prefix 的所有键
public List<String> keysWithPrefix(String prefix) {
List<String> res = new LinkedList<>();
// 找到匹配 prefix 在 Trie 树中的那个节点
TrieNode<V> x = getNode(root, prefix);
if (x == null) {
return res;
}
// DFS 遍历以 x 为根的这棵 Trie 树
traverse(x, new StringBuilder(prefix), res);
return res;
}

// 遍历以 node 节点为根的 Trie 树,找到所有键
private void traverse(TrieNode<V> node, StringBuilder path, List<String> res) {
if (node == null) {
// 到达 Trie 树底部叶子结点
return;
}

if (node.val != null) {
// 找到一个 key,添加到结果列表中
res.add(path.toString());
}

// 回溯算法遍历框架
for (char c = 0; c < R; c++) {
// 做选择
path.append(c);
traverse(node.children[c], path, res);
// 撤销选择
path.deleteCharAt(path.length() - 1);
}
}

// 通配符 . 匹配任意字符
public List<String> keysWithPattern(String pattern) {
List<String> res = new LinkedList<>();
traverse(root, new StringBuilder(), pattern, 0, res);
return res;
}

// 遍历函数,尝试在「以 node 为根的 Trie 树中」匹配 pattern[i..]
private void traverse(TrieNode<V> node, StringBuilder path, String pattern, int i, List<String> res) {
if (node == null) {
// 树枝不存在,即匹配失败
return;
}
if (i == pattern.length()) {
// pattern 匹配完成
if (node.val != null) {
// 如果这个节点存储着 val,则找到一个匹配的键
res.add(path.toString());
}
return;
}
char c = pattern.charAt(i);
if (c == '.') {
// pattern[i] 是通配符,可以变化成任意字符
// 多叉树(回溯算法)遍历框架
for (char j = 0; j < R; j++) {
path.append(j);
traverse(node.children[j], path, pattern, i + 1, res);
path.deleteCharAt(path.length() - 1);
}
} else {
// pattern[i] 是普通字符 c
path.append(c);
traverse(node.children[c], path, pattern, i + 1, res);
path.deleteCharAt(path.length() - 1);
}
}

// 判断是和否存在前缀为 prefix 的键
public boolean hasKeyWithPattern(String pattern) {
// 从 root 节点开始匹配 pattern[0..]
return hasKeyWithPattern(root, pattern, 0);
}

// 函数定义:从 node 节点开始匹配 pattern[i..],返回是否成功匹配
private boolean hasKeyWithPattern(TrieNode<V> node, String pattern, int i) {
if (node == null) {
// 树枝不存在,即匹配失败
return false;
}
if (i == pattern.length()) {
// 模式串走到头了,看看匹配到的是否是一个键
return node.val != null;
}
char c = pattern.charAt(i);
// 没有遇到通配符
if (c != '.') {
// 从 node.children[c] 节点开始匹配 pattern[i+1..]
return hasKeyWithPattern(node.children[c], pattern, i + 1);
}
// 遇到通配符
for (int j = 0; j < R; j++) {
// pattern[i] 可以变化成任意字符,尝试所有可能,只要遇到一个匹配成功就返回
if (hasKeyWithPattern(node.children[j], pattern, i + 1)) {
return true;
}
}
// 都没有匹配
return false;
}

// 从节点 node 开始搜索 key,如果存在返回对应节点,否则返回 null
private TrieNode<V> getNode(TrieNode<V> node, String key) {
TrieNode<V> p = node;
// 从节点 node 开始搜索 key
for (int i = 0; i < key.length(); i++) {
if (p == null) {
// 无法向下搜索
return null;
}
// 向下搜索
char c = key.charAt(i);
p = p.children[c];
}
return p;
}

public int size() {
return size;
}
}

208. Implement Trie (Prefix Tree)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
public class LeetCode208 {
class Trie {

TrieNode root;
public Trie() {
root=null;
}

public void insert(String word) {
root=put(root,word,0);
}

private TrieNode put(TrieNode node, String word, int i){
if(node==null){
node=new TrieNode();
}
if(i==word.length()){
node.val=1;
return node;
}
char c=word.charAt(i);
node.children[c-'a']=put(node.children[c-'a'],word,i+1 );
return node;
}

public boolean search(String word) {
TrieNode x=getNode(root,word);
if(x==null || x.val==-1){
return false;
}
return true;
}

public boolean startsWith(String prefix) {
return getNode(root,prefix)!=null;
}

private TrieNode getNode(TrieNode root, String word){
TrieNode p=root;
for (int i = 0; i < word.length(); i++) {
if(p==null){
return null;
}
char c=word.charAt(i);
p=p.children[c-'a'];
}
return p;
}
}

class TrieNode{
int val;
TrieNode[] children;

TrieNode(){
val=-1;
children=new TrieNode[26]; //'a'-'z'
}
}
}

1804. Implement Trie II (Prefix Tree)

优化labuladong:

  1. 初始root=new TrieNode() ,其值无意义
  2. put该递归为循环,与get相统一
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
public class LeetCode1804 {
class Trie {

TrieNode root;
public Trie() {
root=new TrieNode();
}

public void insert(String word) {
TrieNode p=root;
for (int i = 0; i < word.length(); i++) {
char c=word.charAt(i);
if(p.children[c-'a']==null){
p.children[c-'a']=new TrieNode();
}
p=p.children[c-'a'];
p.preCount++;
}
p.wordCount++;
}


private TrieNode getNode(TrieNode node, String word){
TrieNode p=root;
for (int i = 0; i < word.length(); i++) {
if(p==null){
break;
}
char c=word.charAt(i);
p=p.children[c-'a'];
}
return p;
}

public int countWordsEqualTo(String word) {
TrieNode p=getNode(root,word);
return p==null ? 0 : p.wordCount;
}

public int countWordsStartingWith(String prefix) {
TrieNode p=getNode(root,prefix);
return p==null ? 0 : p.preCount;
}

public void erase(String word) {
TrieNode p=root;
for (int i = 0; i < word.length(); i++) {
char c=word.charAt(i);
p=p.children[c-'a'];
p.preCount--;
}
p.wordCount--;
}
}

class TrieNode {
int wordCount;
int preCount;
TrieNode[] children;

TrieNode(){
wordCount=0;
preCount=0;
children=new TrieNode[26];
}
}
}

648. Replace Words

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
public class LeetCode648 {
public String replaceWords(List<String> dictionary, String sentence) {
StringBuilder sb=new StringBuilder();
Trie trie=new Trie();
for (String s : dictionary) {
trie.insert(s);
}
String[] strings = sentence.split(" ");
for (String string : strings) {
boolean found=false;
for (int i = 0; i < string.length(); i++) {
if(trie.containsWord(string.substring(0,i+1))){
sb.append(string.substring(0,i+1));
found=true;
break;
}
}
if(!found){
sb.append(string);
}
sb.append(" ");
}
sb.deleteCharAt(sb.length()-1);
return sb.toString();
}

class Trie{

Node root;

Trie(){
root=new Node();
}

void insert(String word){
Node p=root;
for (int i = 0; i < word.length(); i++) {
char c=word.charAt(i);
if(p.children[c-'a']==null){
p.children[c-'a']=new Node();
}
p=p.children[c-'a'];
}
p.word++;
}

boolean containsWord(String word){
Node p=getNode(root,word);
if(p==null){
return false;
}
return p.word==1 ? true : false;
}

private Node getNode(Node node, String word){
Node p=root;
for (int i = 0; i < word.length(); i++) {
if(p==null){
break;
}
char c=word.charAt(i);
p=p.children[c-'a'];
}
return p;
}
}

class Node{
int word;
Node[] children;

Node(){
word=0;
children=new Node[26];
}
}
}

211. Design Add and Search Words Data Structure

dfs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class WordDictionary {

Node root;

public WordDictionary() {
root=new Node();
}

public void addWord(String word) {
Node p=root;
for (int i = 0; i < word.length(); i++) {
char c=word.charAt(i);
if(p.children[c-'a']==null){
p.children[c-'a']=new Node();
}
p=p.children[c-'a'];
}
p.word=1;
}

public boolean search(String word) {
return dfs(root,word,0);
}

private boolean dfs(Node node, String word, int index){
if(index==word.length()){
return node.word==1;
}
char c=word.charAt(index);
if(c!='.'){//是字母就继续找
Node next=node.children[c-'a'];
if(next!=null && dfs(next,word,index+1)){
return true;
}
}else{//是通配符就逐一试,通一个就行
for (int i = 0; i < 26; i++) {
if(node.children[i]!=null && dfs(node.children[i],word,index+1 )){
return true;
}
}
}
return false; //全都走不通则false
}
}

class Node {
int word;
Node[] children;

Node() {
word=0;
children=new Node[26];
}
}

677. Map Sum Pairs

注意:sum的prefix可能不存在

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
class MapSum {

Node root;
public MapSum() {
root=new Node();
}

public void insert(String key, int val) {
Node p=root;
for (int i = 0; i < key.length(); i++) {
char c=key.charAt(i);
if(p.children[c-'a']==null){
p.children[c-'a']=new Node();
}
p=p.children[c-'a'];
}
p.val=val;
}

public int sum(String prefix) {
Node p=root;
for (int i = 0; i < prefix.length(); i++) {
char c=prefix.charAt(i);
p=p.children[c-'a'];
if(p==null){
return 0;
}
}
return dfs(p);
}

private int dfs(Node node){
if(node==null){
return 0;
}
int sum=node.val;
for (int i = 0; i < 26; i++) {
sum+=dfs(node.children[i]);
}
return sum;
}


}

class Node {
int val;
Node[] children;

Node(){
val=0;
children=new Node[26];
}
}

212. Word Search II

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
List<String> res=new ArrayList<>();
TrieNode root=new TrieNode();
int row;
int col;
char[][] board;
String[] words;
int[] dx=new int[]{0,0,1,-1};
int[] dy=new int[]{1,-1,0,0};
public List<String> findWords(char[][] board, String[] words) {
row= board.length;
col=board[0].length;
this.board=board;
this.words=words;
for (String word : words) {
TrieNode node=root;
for (int i = 0; i < word.length(); i++) {
char c=word.charAt(i);
if(!node.children.containsKey(c)){
node.children.put(c,new TrieNode());
}
node=node.children.get(c);
}
node.word=word;
}
for (int i = 0; i < row; i++) {
for (int j = 0; j < col; j++) {
if(root.children.containsKey(board[i][j])){
backtracking(root,i,j);
}
}
}
return res;
}

void backtracking(TrieNode parent, int i, int j){
char c=board[i][j];
TrieNode node=parent.children.get(c);
if(node.word!=null){
res.add(node.word);
node.word=null;
}

board[i][j]='#';
for (int k = 0; k < 4; k++) {
int x=i+dx[k];
int y=j+dy[k];
if(x<0 || x>=row || y<0 || y>=col || !node.children.containsKey(board[x][y])){
continue;
}
backtracking(node,x,y);
}
board[i][j]=c;

//important!
//Optimization: incrementally remove the leaf nodes
if(node.children.isEmpty()){
parent.children.remove(c);
}
}

class TrieNode{
Map<Character,TrieNode> children;
String word;
TrieNode(){
children=new HashMap<>();
}

}

425. Word Squares

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
List<List<String>> res;
int k;
Trie trie;
public List<List<String>> wordSquares(String[] words) {
trie=new Trie();
for (String word : words) {
trie.add(word);
}
k=words[0].length();
res=new ArrayList<>();
LinkedList<String> list=new LinkedList<>();
for (String word : words) {
list.add(word);
backtracking(list);
list.removeLast();
}
return res;
}

void backtracking(LinkedList<String> list){
if(list.size()==k){
res.add(new ArrayList<>(list));
return;
}
StringBuilder sb=new StringBuilder();
int n=list.size();
for (String s : list) {
sb.append(s.charAt(n));
}
String prefix=sb.toString();
for (String word : trie.getWords(prefix)) {
list.add(word);
backtracking(list);
list.removeLast();
}
}


class Trie{
Node root;
Trie(){
root=new Node();
}

void add(String word){
Node node=root;
for (int i = 0; i < word.length(); i++) {
char c=word.charAt(i);
if(!node.children.containsKey(c)){
node.children.put(c,new Node());
}
node=node.children.get(c);
node.words.add(word);
}
}

List<String> getWords(String prefix){
Node node=root;
for (int i = 0; i < prefix.length(); i++) {
char c=prefix.charAt(i);
if(!node.children.containsKey(c)){
return new ArrayList<>();
}
node=node.children.get(c);
}
return node.words;
}
}

class Node{
List<String> words;
Map<Character,Node> children;
Node(){
words=new ArrayList<>();
children=new HashMap<>();
}
}