From 46002e31fc34c746c01bcaa7ade999187068ad3c Mon Sep 17 00:00:00 2001 From: Jin Ma Date: Wed, 28 Sep 2016 09:46:05 +0800 Subject: [PATCH] add infer stage for lstm_ocr. Signed-off-by: Jin Ma --- example/warpctc/infer_ocr.py | 100 ++++++++++++++++++++++++++++++++++ example/warpctc/lstm.py | 39 +++++++++++++ example/warpctc/lstm_model.py | 54 ++++++++++++++++++ 3 files changed, 193 insertions(+) create mode 100644 example/warpctc/infer_ocr.py create mode 100644 example/warpctc/lstm_model.py diff --git a/example/warpctc/infer_ocr.py b/example/warpctc/infer_ocr.py new file mode 100644 index 000000000000..8451acc4a0d5 --- /dev/null +++ b/example/warpctc/infer_ocr.py @@ -0,0 +1,100 @@ +# coding=utf-8 +# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme +# pylint: disable=superfluous-parens, no-member, invalid-name +import sys + +sys.path.insert(0, "../../python") +import numpy as np +import mxnet as mx + +from lstm_model import LSTMInferenceModel + +import cv2, random +from captcha.image import ImageCaptcha + +BATCH_SIZE = 32 +SEQ_LENGTH = 80 + + +def ctc_label(p): + ret = [] + p1 = [0] + p + for i in range(len(p)): + c1 = p1[i] + c2 = p1[i + 1] + if c2 == 0 or c2 == c1: + continue + ret.append(c2) + return ret + + +def remove_blank(l): + ret = [] + for i in range(len(l)): + if l[i] == 0: + break + ret.append(l[i]) + return ret + + +def gen_rand(): + buf = "" + max_len = random.randint(3,4) + for i in range(max_len): + buf += str(random.randint(0,9)) + return buf + +if __name__ == '__main__': + num_hidden = 100 + num_lstm_layer = 2 + + num_epoch = 10 + learning_rate = 0.001 + momentum = 0.9 + num_label = 4 + + n_channel = 1 + contexts = [mx.context.gpu(0)] + _, arg_params, __ = mx.model.load_checkpoint('ocr', num_epoch) + + num = gen_rand() + print 'Generated number: ' + num + # change the fonts accordingly + captcha = ImageCaptcha(fonts=['./data/OpenSans-Regular.ttf']) + img = captcha.generate(num) + img = np.fromstring(img.getvalue(), dtype='uint8') + img = cv2.imdecode(img, cv2.IMREAD_GRAYSCALE) + img = cv2.resize(img, (80, 30)) + + img = img.transpose(1, 0) + + img = img.reshape((1, 80 * 30)) + img = np.multiply(img, 1 / 255.0) + + data_shape = [('data', (1, n_channel * 80 * 30))] + input_shapes = dict(data_shape) + + model = LSTMInferenceModel(num_lstm_layer, + SEQ_LENGTH, + num_hidden=num_hidden, + num_label=num_label, + arg_params=arg_params, + data_size = n_channel * 30 * 80, + ctx=contexts[0]) + + prob = model.forward(mx.nd.array(img)) + + p = [] + for k in range(SEQ_LENGTH): + p.append(np.argmax(prob[k])) + + p = ctc_label(p) + print 'Predicted label: ' + str(p) + + pred = '' + for c in p: + pred += str((int(c) - 1)) + + print 'Predicted number: ' + pred + + diff --git a/example/warpctc/lstm.py b/example/warpctc/lstm.py index 32ba2455e11d..4be4a0d914f1 100644 --- a/example/warpctc/lstm.py +++ b/example/warpctc/lstm.py @@ -77,3 +77,42 @@ def lstm_unroll(num_lstm_layer, seq_len, sm = mx.sym.WarpCTC(data=pred, label=label, label_length = num_label, input_length = seq_len) return sm + +def lstm_inference_symbol(num_lstm_layer, seq_len, num_hidden, num_label): + param_cells = [] + last_states = [] + for i in range(num_lstm_layer): + param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable("l%d_i2h_weight" % i), + i2h_bias=mx.sym.Variable("l%d_i2h_bias" % i), + h2h_weight=mx.sym.Variable("l%d_h2h_weight" % i), + h2h_bias=mx.sym.Variable("l%d_h2h_bias" % i))) + state = LSTMState(c=mx.sym.Variable("l%d_init_c" % i), + h=mx.sym.Variable("l%d_init_h" % i)) + last_states.append(state) + assert (len(last_states) == num_lstm_layer) + + # embeding layer + data = mx.sym.Variable('data') + wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1) + + hidden_all = [] + for seqidx in range(seq_len): + hidden = wordvec[seqidx] + for i in range(num_lstm_layer): + next_state = lstm(num_hidden, indata=hidden, + prev_state=last_states[i], + param=param_cells[i], + seqidx=seqidx, layeridx=i) + hidden = next_state.h + last_states[i] = next_state + hidden_all.append(hidden) + + hidden_concat = mx.sym.Concat(*hidden_all, dim=0) + fc = mx.sym.FullyConnected(data=hidden_concat, num_hidden=11) + sm = mx.sym.SoftmaxOutput(data=fc, name='softmax') + + output = [sm] + for state in last_states: + output.append(state.c) + output.append(state.h) + return mx.sym.Group(output) diff --git a/example/warpctc/lstm_model.py b/example/warpctc/lstm_model.py new file mode 100644 index 000000000000..e9c8aa74365f --- /dev/null +++ b/example/warpctc/lstm_model.py @@ -0,0 +1,54 @@ + +# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme +# pylint: disable=superfluous-parens, no-member, invalid-name +import sys +sys.path.insert(0, "../../python") +import numpy as np +import mxnet as mx + +from lstm import LSTMState, LSTMParam, lstm, lstm_inference_symbol + + +class LSTMInferenceModel(object): + def __init__(self, + num_lstm_layer, + seq_len, + num_hidden, + num_label, + arg_params, + data_size, + ctx=mx.cpu()): + self.sym = lstm_inference_symbol(num_lstm_layer, + seq_len, + num_hidden, + num_label) + + batch_size = 1 + init_c = [('l%d_init_c'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)] + init_h = [('l%d_init_h'%l, (batch_size, num_hidden)) for l in range(num_lstm_layer)] + data_shape = [("data", (batch_size, data_size))] + input_shapes = dict(init_c + init_h + data_shape) + self.executor = self.sym.simple_bind(ctx=ctx, **input_shapes) + + for key in self.executor.arg_dict.keys(): + if key in arg_params: + arg_params[key].copyto(self.executor.arg_dict[key]) + + state_name = [] + for i in range(num_lstm_layer): + state_name.append("l%d_init_c" % i) + state_name.append("l%d_init_h" % i) + + self.states_dict = dict(zip(state_name, self.executor.outputs[1:])) + self.input_arr = mx.nd.zeros(data_shape[0][1]) + + def forward(self, input_data, new_seq=False): + if new_seq == True: + for key in self.states_dict.keys(): + self.executor.arg_dict[key][:] = 0. + input_data.copyto(self.executor.arg_dict["data"]) + self.executor.forward() + for key in self.states_dict.keys(): + self.states_dict[key].copyto(self.executor.arg_dict[key]) + prob = self.executor.outputs[0].asnumpy() + return prob \ No newline at end of file