-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
70 lines (56 loc) · 1.88 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import pickle
import numpy as np
import keras.preprocessing.image as image_processing
def pickle_save(object, path):
try:
print('save data to {} successfully'.format(path))
with open(path, "wb") as f:
return pickle.dump(object, f)
except:
print('save data to {} failed'.format(path))
def pickle_load(path):
try:
print("Loading data from {}".format(path))
with open(path, "rb") as f:
data = pickle.load(f)
print('load data successfully'.format(path))
return data
except Exception as e:
print(str(e))
return None
def prune(x, y, labels, prune_classes):
"""
prune data by give classes
"""
for class_to_prune in range(len(prune_classes)):
remove_size = prune_classes[class_to_prune]
if remove_size <= 0:
continue
print(class_to_prune)
all_ids = list(np.arange(len(x)))
mask = [lc == class_to_prune for lc in labels]
all_ids_c = np.array(all_ids)[mask]
np.random.shuffle(all_ids_c)
to_delete = all_ids_c[:remove_size]
x = np.delete(x, to_delete, axis=0)
y = np.delete(y, to_delete, axis=0)
labels = np.delete(labels, to_delete, axis=0)
print('Remove {} items in class {}'.format(remove_size, class_to_prune))
return x, y, labels
def norm(imgs):
return (imgs - 127.5) / 127.5
def de_norm(imgs):
return imgs * 127.5 + 127.5
def transform(x, seed=0):
np.random.seed(seed)
img = image_processing.random_rotation(x, 0.2)
img = image_processing.random_shear(img, 30)
img = image_processing.random_zoom(img, (0.5, 1.1))
if np.random.rand() >= 0.5:
img = np.fliplr(img)
return img
def weighted_samples(labels, class_weight):
w = []
for i in range(len(labels)):
w.append(class_weight[labels[i]])
return np.array(w)