diff --git a/persephone/keras_model.py b/persephone/keras_model.py new file mode 100644 index 0000000..5c9402c --- /dev/null +++ b/persephone/keras_model.py @@ -0,0 +1,23 @@ +import keras + +class RNN_CTC_model: + """ An acoustic model with a LSTM/CTC architecture. + + Uses Keras to define the model""" + + def __init__(self, exp_dir: str, corpus_reader, num_layers: int = 3, + hidden_size: int=250, beam_width: int = 100, + decoding_merge_repeated: bool = True) -> None: + """Initialize a new model + + Arguments: + exp_dir: Path that the experiment directory is located at + corpus_reader: `CorpusReader` object that provides access to the corpus + this model is being trained on. + num_layers: number of layers in the network + hidden_size: the size, in nodes, of the hidden layers + beam_width: size of the beam width (used for the decoding) + decoding_merge_repeated: A flag to toggle behavior of repeating characters + if true "a b b b b c" becomes "a b c" + """ + raise NotImplementedError \ No newline at end of file diff --git a/setup.py b/setup.py index f95da2b..bb5b41b 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,7 @@ 'nltk==3.2.5', 'numpy==1.15.0', 'python-speech-features==0.6', + 'keras==2.2.2', 'scipy==1.1.0', 'tensorflow==1.10.0', 'scikit-learn==0.19.1',