-
Notifications
You must be signed in to change notification settings - Fork 30
/
mnnl.py
executable file
·78 lines (65 loc) · 2.58 KB
/
mnnl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
https://github.com/bplank/bilstm-aux/blob/master/src/lib/mnnl.py
"""
import dynet
import numpy as np
import sys
## NN classes
class SequencePredictor:
    """Common interface for models that turn an input sequence into outputs.

    Concrete predictors override ``predict_sequence``; calling it on this
    base class raises ``NotImplementedError``.
    """

    def __init__(self):
        # The interface carries no state of its own.
        return None

    def predict_sequence(self, inputs):
        """Subclass hook: map ``inputs`` to a sequence of predictions."""
        message = "SequencePredictor predict_sequence: Not Implemented"
        raise NotImplementedError(message)
class FFSequencePredictor(SequencePredictor):
    """Apply a feed-forward network to every element of a sequence independently.

    ``network_builder`` is any callable; it is invoked once per input element.
    """

    def __init__(self, network_builder):
        self.network_builder = network_builder

    def predict_sequence(self, inputs):
        """Return a list holding ``network_builder(x)`` for each ``x`` in ``inputs``, in order."""
        apply_net = self.network_builder
        return list(map(apply_net, inputs))
class RNNSequencePredictor(SequencePredictor):
    """Run a single-direction recurrent builder over an input sequence."""

    def __init__(self, rnn_builder):
        """
        rnn_builder: a LSTMBuilder/SimpleRNNBuilder or GRU builder object
        """
        self.builder = rnn_builder

    def predict_sequence(self, inputs):
        """Feed ``inputs`` through a freshly initialised state of the builder.

        Returns whatever the state's ``transduce`` returns (per the DyNet
        API, one output expression per input, in feed order).
        """
        return self.builder.initial_state().transduce(inputs)
class BiRNNSequencePredictor(SequencePredictor):
    """A bidirectional RNN (LSTM/GRU) made of two independent builders.

    Note: unlike ``SequencePredictor.predict_sequence``, this class's
    ``predict_sequence`` takes two input sequences (forward and backward).
    """

    def __init__(self, f_builder, b_builder):
        """
        f_builder: recurrent builder run left-to-right over ``f_inputs``.
        b_builder: recurrent builder run right-to-left over ``b_inputs``.
        """
        self.f_builder = f_builder
        self.b_builder = b_builder

    def predict_sequence(self, f_inputs, b_inputs):
        """Run both directions and return ``(forward_sequence, backward_sequence)``.

        f_inputs: iterable of inputs fed to the forward builder in order.
        b_inputs: iterable of inputs fed to the backward builder in reverse
            order. Any iterable is accepted, including generators.

        ``backward_sequence`` is what ``transduce`` returns for the reversed
        inputs. Assuming DyNet semantics (outputs in feed order), element 0
        therefore corresponds to the LAST item of ``b_inputs``. Callers that
        need position-aligned outputs must reverse it themselves.
        """
        f_init = self.f_builder.initial_state()
        b_init = self.b_builder.initial_state()
        forward_sequence = f_init.transduce(f_inputs)
        # reversed() requires a real sequence; materialise first so that
        # generators and other one-shot iterables no longer raise TypeError.
        backward_sequence = b_init.transduce(reversed(list(b_inputs)))
        return forward_sequence, backward_sequence