-
Notifications
You must be signed in to change notification settings - Fork 503
/
model.py
37 lines (25 loc) · 1.12 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import better_exceptions
from keras.applications import ResNet50, InceptionResNetV2
from keras.layers import Dense
from keras.models import Model
from keras import backend as K
def age_mae(y_true, y_pred):
    """Mean absolute error between the expected ages of two age distributions.

    Each input is weighted against the integer ages 0..100 along its last
    axis, and the weighted sum is taken as that sample's expected age. The
    result is the mean absolute difference between the two expected ages.

    NOTE(review): presumably ``y_true`` is a one-hot (or soft) label vector
    and ``y_pred`` the 101-way softmax from ``get_model``. The last axis must
    broadcast against a length-101 vector; confirm against the training data.

    Args:
        y_true: target tensor whose last axis covers ages 0..100.
        y_pred: predicted tensor whose last axis covers ages 0..100.

    Returns:
        Scalar tensor: mean of |E[age | y_true] - E[age | y_pred]|.
    """
    # One shared grid of candidate ages [0, 1, ..., 100] as float32.
    age_grid = K.arange(0, 101, dtype="float32")
    expected_true = K.sum(y_true * age_grid, axis=-1)
    expected_pred = K.sum(y_pred * age_grid, axis=-1)
    return K.mean(K.abs(expected_true - expected_pred))
def get_model(model_name="ResNet50"):
    """Build an age-classification model on an ImageNet-pretrained backbone.

    The backbone is loaded without its classification top and with global
    average pooling. A bias-free 101-unit softmax layer named ``pred_age`` is
    attached on top. The 101 outputs line up with the 0..100 age grid used by
    ``age_mae``.

    Args:
        model_name: ``"ResNet50"`` (224x224x3 input) or
            ``"InceptionResNetV2"`` (299x299x3 input).

    Returns:
        A keras ``Model`` from the backbone's input to the ``pred_age``
        softmax output.

    Raises:
        ValueError: if ``model_name`` is not one of the supported backbones.
    """
    if model_name == "ResNet50":
        base_model = ResNet50(include_top=False, weights='imagenet', input_shape=(224, 224, 3), pooling="avg")
    elif model_name == "InceptionResNetV2":
        base_model = InceptionResNetV2(include_top=False, weights='imagenet', input_shape=(299, 299, 3), pooling="avg")
    else:
        # Fail fast with a clear message. Otherwise the code would reach
        # `base_model.output` with no backbone and raise an opaque
        # AttributeError on None.
        raise ValueError(
            f"Unsupported model_name {model_name!r}; "
            "expected 'ResNet50' or 'InceptionResNetV2'"
        )

    # One unit per integer age 0..100 (same grid as age_mae).
    prediction = Dense(units=101, kernel_initializer="he_normal", use_bias=False, activation="softmax",
                       name="pred_age")(base_model.output)
    model = Model(inputs=base_model.input, outputs=prediction)
    return model
def main():
    """Build the InceptionResNetV2-based age model and print its summary."""
    get_model("InceptionResNetV2").summary()


if __name__ == '__main__':
    main()