forked from jina-ai/clip-as-service
-
Notifications
You must be signed in to change notification settings - Fork 0
/
example7.py
122 lines (105 loc) · 4.11 KB
/
example7.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Han Xiao <[email protected]> <https://hanxiao.github.io>
# NOTE: First install bert-as-service via
# $
# $ pip install bert-serving-server
# $ pip install bert-serving-client
# $
# visualizing a 12-layer BERT
import time
from collections import namedtuple
import numpy as np
import pandas as pd
# from MulticoreTSNE import MulticoreTSNE as TSNE
from bert_serving.client import BertClient
from bert_serving.server import BertServer
from bert_serving.server.graph import PoolingStrategy
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.decomposition import PCA
# Load only the two columns this script uses: the headline text and its category code.
data = pd.read_csv('/data/cips/data/lab/data/dataset/uci-news-aggregator.csv', usecols=['TITLE', 'CATEGORY'])

# Dataset preparation adapted from a Kaggle kernel ->
num_of_categories = 5000  # upper bound on the number of rows kept per category
shuffled = data.reindex(np.random.permutation(data.index))

# After shuffling, keep at most `num_of_categories` rows of each category code.
category_column = shuffled['CATEGORY']
e = shuffled[category_column == 'e'].head(num_of_categories)
b = shuffled[category_column == 'b'].head(num_of_categories)
t = shuffled[category_column == 't'].head(num_of_categories)
m = shuffled[category_column == 'm'].head(num_of_categories)
concated = pd.concat([e, b, t, m], ignore_index=True)

# Shuffle again so the four categories are interleaved.
concated = concated.reindex(np.random.permutation(concated.index))

# Integer class id per category code (a plain label encoding, not one-hot).
label_of_category = {'e': 0, 'b': 1, 't': 2, 'm': 3}
concated['LABEL'] = concated['CATEGORY'].map(label_of_category)

subset_text = list(concated['TITLE'].values)
subset_label = list(concated['LABEL'].values)
num_label = len(set(subset_label))
# <- dataset preparation adapted from a Kaggle kernel

# Whitespace-token count of every headline, computed once for both statistics.
token_counts = [len(title.split()) for title in subset_text]
print('min_seq_len: %d' % min(token_counts))
print('max_seq_len: %d' % max(token_counts))
print('unique label: %d' % num_label)
# One entry per pooling layer, appended in the order -1, -2, ..., -12.
subset_vec_all_layers = []

# Server settings shared by every run; only 'pooling_layer' is overridden
# inside the loop below.
common = {
    'model_dir': '//data/cips/data/lab/data/model/uncased_L-12_H-768_A-12',
    'num_worker': 2,
    'num_repeat': 5,
    'port': 6006,
    'port_out': 6007,
    'max_seq_len': 20,
    'client_batch_size': 2048,
    'max_batch_size': 256,
    'num_client': 1,
    'pooling_strategy': PoolingStrategy.REDUCE_MEAN,
    'pooling_layer': [-2],
    'gpu_memory_fraction': 0.5,
    'xla': False,
    'cpu': False,
    'verbose': False,
    'device_map': []
}

# `args` is the settings object handed to BertServer. The namedtuple *class*
# itself serves as the attribute bag (it is never instantiated): every setting
# is attached to it as a class attribute.
args = namedtuple('args_namedtuple', ','.join(common.keys()))
for k, v in common.items():
    setattr(args, k, v)

# Start a fresh server for each pooling layer, encode the whole text subset
# with it, then tear it down before moving on to the next layer.
for pool_layer in range(1, 13):
    setattr(args, 'pooling_layer', [-pool_layer])
    server = BertServer(args)
    server.start()
    try:
        print('wait until server is ready...')
        time.sleep(15)  # fixed pause before the client connects
        print('encoding...')
        bc = BertClient(port=common['port'], port_out=common['port_out'], show_server_config=True)
        try:
            subset_vec_all_layers.append(bc.encode(subset_text))
        finally:
            # Close the client even if encoding fails.
            bc.close()
    finally:
        # Always shut the server down; every iteration reuses the same two
        # ports, so a server left running would get in the way of the next one.
        server.close()
    print('done at layer -%d' % pool_layer)
def vis(embed, vis_alg='PCA', pool_alg='REDUCE_MEAN'):
    """Scatter-plot each 2-D embedding in its own panel and show the figure.

    Panels are placed on a 2 x 6 subplot grid. Point colours are taken from
    the module-level ``subset_label`` and the colour-bar ticks from the
    module-level ``num_label``. As a side effect the global matplotlib
    default ``figure.figsize`` is set to 21 x 7.

    Args:
        embed: iterable of arrays indexable as ``ebd[:, 0]`` / ``ebd[:, 1]``
            (x and y coordinates); entry ``idx`` is titled
            ``pool_layer=-(idx + 1)``.
        vis_alg: name of the projection method; used only in the figure title.
        pool_alg: name of the pooling strategy; used only in the figure title.
    """
    plt.close()
    # The default size has to be set *before* the figure is created: it is
    # read at creation time, so assigning it afterwards left the first figure
    # at the previous default size.
    plt.rcParams['figure.figsize'] = [21, 7]
    fig = plt.figure()

    # One colour per label id 0..3; built once rather than once per panel.
    label_cmap = ListedColormap(["blue", "green", "yellow", "red"])
    for idx, ebd in enumerate(embed):
        ax = plt.subplot(2, 6, idx + 1)
        vis_x = ebd[:, 0]
        vis_y = ebd[:, 1]
        plt.scatter(vis_x, vis_y, c=subset_label, cmap=label_cmap, marker='.',
                    alpha=0.7, s=2)
        ax.set_title('pool_layer=-%d' % (idx + 1))

    plt.tight_layout()
    # Leave a strip on the right-hand side for the colour bar.
    plt.subplots_adjust(bottom=0.1, right=0.95, top=0.9)
    cax = plt.axes([0.96, 0.1, 0.01, 0.3])
    cbar = plt.colorbar(cax=cax, ticks=range(num_label))
    # Replace the numeric ticks with one text label per category.
    cbar.ax.get_yaxis().set_ticks([])
    for j, lab in enumerate(['ent.', 'bus.', 'sci.', 'heal.']):
        # (2j + 1) / 8 for j = 0..3 gives 1/8, 3/8, 5/8, 7/8.
        cbar.ax.text(.5, (2 * j + 1) / 8.0, lab, ha='center', va='center', rotation=270)
    fig.suptitle('%s visualization of BERT layers using "bert-as-service" (-pool_strategy=%s)' % (vis_alg, pool_alg),
                 fontsize=14)
    plt.show()
# Reduce each layer's vectors to two PCA components (a fresh PCA is fitted
# per layer), then plot all layers in one figure.
pca_embed = []
for layer_vectors in subset_vec_all_layers:
    projector = PCA(n_components=2)
    pca_embed.append(projector.fit_transform(layer_vectors))
vis(pca_embed)

# Disabled t-SNE variant. NOTE: `TSNE` is not defined while the MulticoreTSNE
# import at the top of the file stays commented out; restore that import
# before enabling this branch.
if False:
    tsne_embed = [TSNE(n_jobs=8).fit_transform(layer_vectors) for layer_vectors in subset_vec_all_layers]
    vis(tsne_embed, 't-SNE')