-
Notifications
You must be signed in to change notification settings - Fork 1
/
tree.py
238 lines (181 loc) · 10.1 KB
/
tree.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
import torch
import torch.nn as nn
import numpy as np
import collections
import copy
class Tree(nn.Module):
    """Dynamic topic hierarchy used to derive document-topic and topic-word
    distributions from embeddings.

    Topic ids encode tree position: the children of ``parent_id`` can only take
    the ids ``parent_id * max_children_per_parent + i`` for
    ``i = 1 .. max_children_per_parent``. Hence a parent has at most
    ``max_children_per_parent`` children and every child id is larger than its
    parent id. Topic 0 is the root.

    Lookup tables rebuilt by :meth:`construct_tree` from ``par2child``:

    * ``child2par``: child id -> parent id
    * ``topic_ids``: sorted list of every topic id
    * ``topic2level`` / ``level2topic``: depth of each topic (root = 0) and its inverse
    * ``num_levels``: number of distinct depths
    * ``leaf2ancestor``: leaf id -> ids on its root-to-leaf path, root first
    * ``topic2descendant``: topic id -> the topic itself followed by all its descendants
    """

    def __init__(self, args, manifold):
        super(Tree, self).__init__()
        # Object exposing ``sqdist(x, y, c)``; every doc/topic distance goes through it.
        self.manifold = manifold
        self.parse_args(args)
        self.generate_modules(args)
        self.init_tree()

    def parse_args(self, args):
        """Copy the hyper-parameters the tree needs from ``args`` (attribute access)."""
        self.device = args.device
        self.max_levels = args.max_levels  # NOTE(review): not read anywhere in this class
        self.max_children_per_parent = args.max_children_per_parent
        self.add_threshold = args.add_threshold
        self.remove_threshold = args.remove_threshold
        self.emb_dim = args.emb_dim  # NOTE(review): not read anywhere in this class
        # Number of children each node receives at depth 0, 1, ... of the initial tree.
        self.init_num_children_per_parent = [3, 3]
        self.init_num_levels = len(self.init_num_children_per_parent) + 1

    def generate_modules(self, args):
        """Create the parameter-free activation modules (``args`` is unused)."""
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=-1)

    def init_tree(self):
        """Build the initial tree and all derived lookup tables.

        Raises:
            ValueError: if an initial branching factor exceeds
                ``max_children_per_parent``; the id scheme would then give the
                same child id to two different parents.
        """
        if any(n > self.max_children_per_parent for n in self.init_num_children_per_parent):
            raise ValueError(
                f"init_num_children_per_parent={self.init_num_children_per_parent} exceeds "
                f"max_children_per_parent={self.max_children_per_parent}")
        self.par2child = collections.defaultdict(list)
        self.level2topic = collections.defaultdict(list)
        self.level2topic[0] = [0]
        for level in range(self.init_num_levels - 1):
            num_children = self.init_num_children_per_parent[level]
            for parent_id in self.level2topic[level]:
                child_ids = [parent_id * self.max_children_per_parent + i
                             for i in range(1, num_children + 1)]
                self.par2child[parent_id] = child_ids
                self.level2topic[level + 1].extend(child_ids)
        self.construct_tree(self.par2child)

    def construct_tree(self, par2child):
        """Rebuild every derived lookup table from ``par2child``.

        Lookups use ``.get`` so that a ``defaultdict`` passed in is never
        mutated by reading a missing key. A node with no children (including
        one whose child list became empty) is treated as a leaf.
        """
        def get_topic2level(parent_id=0, topic2level=None, level=0):
            # Depth-first walk from ``parent_id`` recording each topic's depth.
            if topic2level is None:
                topic2level = {0: level}
            for child_id in par2child.get(parent_id, []):
                topic2level[child_id] = level + 1
                get_topic2level(child_id, topic2level, level + 1)
            return topic2level

        def get_topic2descendant(parent_id, descendant_ids=None):
            # ``parent_id`` followed by its children, then each child's subtree.
            if descendant_ids is None:
                descendant_ids = [parent_id]
            child_ids = par2child.get(parent_id, [])
            descendant_ids += child_ids
            for child_id in child_ids:
                get_topic2descendant(child_id, descendant_ids)
            return descendant_ids

        self.topic_ids = sorted([0] + [child_id for child_ids in par2child.values()
                                       for child_id in child_ids])
        self.child2par = {child_id: parent_id for parent_id, child_ids in par2child.items()
                          for child_id in child_ids}
        self.topic2level = get_topic2level()
        self.level2topic = collections.defaultdict(list)
        for topic_id, level in self.topic2level.items():
            self.level2topic[level].append(topic_id)
        self.num_levels = len(self.level2topic)

        self.leaf2ancestor = collections.defaultdict(list)
        for topic_id in self.topic_ids:
            if par2child.get(topic_id):
                continue  # has children -> internal node
            path = [topic_id]
            node = topic_id
            while node != 0:
                node = self.child2par[node]
                path.append(node)
            # Child ids exceed parent ids, so sorting orders the path root -> leaf.
            self.leaf2ancestor[topic_id] = sorted(path)
        self.topic2descendant = {topic_id: get_topic2descendant(topic_id)
                                 for topic_id in self.topic_ids}

    @staticmethod
    def _similarity(sq_dist, dist_type):
        """Map squared distances to unnormalised similarities.

        ``'gauss'`` -> ``exp(-0.5 * d)``; ``'inv'`` -> ``1 / (1 + d)``.

        Raises:
            ValueError: for any other ``dist_type``.
        """
        if dist_type == 'gauss':
            return torch.exp(- 0.5 * sq_dist)
        if dist_type == 'inv':
            return 1.0 / (1.0 + sq_dist)
        raise ValueError(f"unknown dist_type {dist_type!r}; expected 'gauss' or 'inv'")

    def evaluate_level_dist(self, doc_emb, topic_emb, c=1.0, dist_type='gauss'):
        """Soft assignment of each document to a tree depth.

        For every level, a document's similarity is taken to its nearest topic
        on that level. The per-level similarities are then normalised across
        levels.

        Args:
            doc_emb: document embeddings, passed to ``manifold.sqdist``.
            topic_emb: indexable by topic id (``topic_emb[topic_id]``).
            c: forwarded unchanged to ``manifold.sqdist``.
            dist_type: ``'gauss'`` or ``'inv'`` (see :meth:`_similarity`).

        Returns:
            Tensor whose column ``i`` is depth ``i``; shape
            ``(num_docs, num_levels)`` assuming ``sqdist`` yields one value per
            document.

        Raises:
            ValueError: for an unknown ``dist_type``.
        """
        eta_topic = {}
        for level in sorted(self.level2topic):
            sq_dists = torch.concat(
                [torch.reshape(self.manifold.sqdist(doc_emb, topic_emb[topic_id], c), [1, -1])
                 for topic_id in self.level2topic[level]], dim=0)
            nearest = torch.min(sq_dists, dim=0, keepdim=True)[0]
            eta_topic[level] = self._similarity(nearest, dist_type)
        sum_eta = sum(eta_topic.values())
        level_dist = torch.concat([eta / (sum_eta + 1e-20) for eta in eta_topic.values()], dim=0).T
        return level_dist

    def evaluate_path_dist(self, doc_emb, topic_emb, c=1.0, dist_type='gauss'):
        """Probability of each document following each root-to-leaf path.

        Under every parent, the children's similarities are normalised among
        siblings (``gamma``; the root's ``gamma`` is 1). A leaf's path
        probability is the product of ``gamma`` along its path.

        Returns:
            dict mapping leaf id -> tensor of per-document path probabilities.

        Raises:
            ValueError: for an unknown ``dist_type`` (when any parent has children).
        """
        gamma_topic = {0: torch.ones([1, doc_emb.size(0)], dtype=torch.float32).to(self.device)}
        for child_ids in self.par2child.values():
            sims = {
                child_id: self._similarity(
                    torch.reshape(self.manifold.sqdist(doc_emb, topic_emb[child_id], c), [1, -1]),
                    dist_type)
                for child_id in child_ids}
            total = sum(sims.values())
            for child_id, sim in sims.items():
                gamma_topic[child_id] = sim / (total + 1e-20)
        path_dist = {}
        for leaf_id, ancestor_ids in self.leaf2ancestor.items():
            path_dist[leaf_id] = torch.prod(
                torch.concat([gamma_topic[ancestor_id] for ancestor_id in ancestor_ids], dim=0),
                dim=0)
        return path_dist

    def evaluate_doc_topic_dist(self, doc_emb, topic_emb, c=1.0, dist_type='inv'):
        """Combine path and level distributions into a document-topic distribution.

        Topic ``t`` receives, from every leaf path through it, the path
        probability times the probability of ``t``'s depth.

        Returns:
            Tensor with one column per topic in ``self.topic_ids`` order
            (``(num_docs, num_topics)`` assuming one ``sqdist`` value per document).
        """
        level_dist = self.evaluate_level_dist(doc_emb, topic_emb, c=c, dist_type=dist_type)
        path_dist = self.evaluate_path_dist(doc_emb, topic_emb, c=c, dist_type=dist_type)
        doc_topic_dist = collections.defaultdict(float)
        for leaf_id, ancestor_ids in self.leaf2ancestor.items():
            path_prob = path_dist[leaf_id]
            # ancestor_ids runs root -> leaf, so position ``depth`` is that node's level.
            for depth, ancestor_id in enumerate(ancestor_ids):
                doc_topic_dist[ancestor_id] += torch.reshape(path_prob * level_dist[:, depth], [-1, 1])
        return torch.concat([doc_topic_dist[topic_id] for topic_id in self.topic_ids], dim=-1)

    def evaluate_topic_word_dist(self, drnn, word_emb, depth_temperature=10.0):
        """Softmax word distribution for every topic.

        ``drnn`` is called with ``self.par2child`` and must return embeddings
        indexable by topic id. The logits of a topic at depth ``level`` are
        divided by ``depth_temperature ** (1 / (level + 1))``. For
        ``depth_temperature > 1`` that temperature shrinks with depth, giving
        sharper distributions for deeper topics.

        Returns:
            ``(topic_word_dist, topic_emb)``; the per-topic distributions are
            concatenated along dim 0 in ``self.topic_ids`` order.
        """
        topic_emb = drnn(self.par2child)
        topic_word_dist = {}
        for topic_id, level in self.topic2level.items():
            temperature = depth_temperature ** (1.0 / (level + 1.0))
            logits = torch.matmul(topic_emb[topic_id], torch.transpose(word_emb, 0, 1))
            topic_word_dist[topic_id] = self.softmax(logits / temperature)
        topic_word_dist = torch.concat([topic_word_dist[topic_id] for topic_id in self.topic_ids], dim=0)
        return topic_word_dist, topic_emb

    def update_tree(self, doc_topic_dist, doc_length):
        """Grow and prune the tree from corpus-level topic usage, then rebuild it.

        ``p`` is the length-weighted average of ``doc_topic_dist`` rows. It is
        presumably ``(num_docs, num_topics)`` in ``self.topic_ids`` order with
        ``doc_length`` of shape ``(num_docs,)`` (TODO confirm with caller).

        * Every current parent with ``p > add_threshold`` gains one child in its
          lowest free id slot. Parents already holding
          ``max_children_per_parent`` children are skipped.
        * Every child whose subtree mass is below ``remove_threshold`` is
          removed together with its subtree. A non-root parent left with no
          children is removed from its own parent; the root is always kept.

        Sets ``self.update_tree_flg`` to True iff a topic was added or pruned.
        """
        self.update_tree_flg = False
        p = np.sum(np.multiply(np.expand_dims(doc_length, -1), doc_topic_dist), axis=0) / np.sum(doc_length)
        p_dict = {topic_id: p[i] for i, topic_id in enumerate(self.topic_ids)}
        recur_p_topic = {parent_id: np.sum([p_dict[child_id] for child_id in recur_child_ids])
                         for parent_id, recur_child_ids in self.topic2descendant.items()}
        width = self.max_children_per_parent

        def add_topic(topic_id, par2child):
            # Return (new child id, par2child); child id is None when every slot is taken.
            taken = par2child.get(topic_id, [])
            free_ids = [width * topic_id + i for i in range(1, width + 1)
                        if width * topic_id + i not in taken]
            if not free_ids:
                return None, par2child
            child_id = free_ids[0]  # ascending, so this is the lowest free slot
            par2child.setdefault(topic_id, []).append(child_id)
            return child_id, par2child

        def remove_topic(parent_id, child_id, par2child):
            # Detach child_id from parent_id and drop its whole subtree, so no
            # orphaned entries survive to break construct_tree's ancestor walk.
            siblings = par2child.get(parent_id)
            if siblings is not None and child_id in siblings:
                siblings.remove(child_id)
            pending = [child_id]
            while pending:
                topic_id = pending.pop()
                pending.extend(par2child.pop(topic_id, []))
            return par2child

        added_par2child = copy.deepcopy(self.par2child)
        for parent_id in self.par2child:
            if p_dict[parent_id] > self.add_threshold:
                new_child_id, added_par2child = add_topic(parent_id, added_par2child)
                if new_child_id is not None:
                    self.update_tree_flg = True

        removed_par2child = copy.deepcopy(added_par2child)
        for parent_id, child_ids in self.par2child.items():
            for child_id in child_ids:
                if recur_p_topic[child_id] < self.remove_threshold:
                    self.update_tree_flg = True
                    removed_par2child = remove_topic(parent_id, child_id, removed_par2child)
            # A parent that lost all its children is pruned from its own parent;
            # the root has no parent (not in child2par) and is kept.
            if (parent_id in removed_par2child and not removed_par2child[parent_id]
                    and parent_id in self.child2par):
                removed_par2child = remove_topic(self.child2par[parent_id], parent_id, removed_par2child)

        self.par2child = removed_par2child
        self.construct_tree(self.par2child)