-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathEncDecModel.py
More file actions
51 lines (36 loc) · 1.57 KB
/
EncDecModel.py
File metadata and controls
51 lines (36 loc) · 1.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch.nn as nn
from Utils import *
from modules.Generations import *
class EncDecModel(nn.Module):
    """Base class for encoder-decoder models with pluggable decoding strategies.

    Subclasses supply the model-specific pieces (``encode``, ``decode``,
    ``generate``, ``to_word``, ``loss``, ``to_sentence``, ...). This base class
    stores the decoding configuration and exposes ``greedy`` and ``beam``
    entry points. Those delegate to module-level ``greedy`` / ``beam`` helpers
    brought in by the star imports at the top of the file (presumably from
    ``modules.Generations`` -- confirm).

    Hooks that raise ``NotImplementedError`` must be overridden. Hooks that
    return a default (``init_decoder_states``, ``init_feedback_states``,
    ``generation_to_decoder_input``) may be overridden.
    """

    def __init__(self, vocab2id, max_dec_len=120, beam_width=1, eps=1e-10):
        """Store decoding configuration.

        Args:
            vocab2id: Vocabulary lookup, passed through to the ``greedy`` /
                ``beam`` helpers (presumably a token -> id mapping; verify).
            max_dec_len: Maximum number of decoding steps given to the helpers.
            beam_width: Beam size used by ``beam``.
            eps: Small constant stored for subclasses (its use is not shown
                here -- presumably numerical stability; confirm).
        """
        super(EncDecModel, self).__init__()
        self.eps = eps
        self.beam_width = beam_width
        self.max_dec_len = max_dec_len
        self.vocab2id = vocab2id

    def encode(self, data):
        """Encode the input batch. Must be overridden."""
        raise NotImplementedError

    def init_decoder_states(self, data, encode_output):
        """Return the initial decoder state; the default is ``None``."""
        return None

    def init_feedback_states(self, encode_outputs, init_decoder_states):
        """Return the initial feedback state; the default is ``None``."""
        return None

    # NOTE: ``previous_deocde_outputs`` is misspelled, but it is kept as-is
    # because subclasses or callers may pass it by keyword.
    def decode(self, data, previous_word, encode_outputs, previous_deocde_outputs, feedback_outputs):
        """Run one decoding step. Must be overridden."""
        raise NotImplementedError

    def generate(self, data, encode_outputs, decode_outputs, softmax=False):
        """Produce generation outputs from decoder outputs. Must be overridden."""
        raise NotImplementedError

    def to_word(self, data, gen_outputs, k=5, sampling=False):
        """Select output tokens from generation outputs. Must be overridden."""
        raise NotImplementedError

    def generation_to_decoder_input(self, data, indices):
        """Map generated indices to the next decoder input.

        The default is the identity: it returns ``indices`` unchanged.
        """
        return indices

    def decoder_to_encoder(self, data, encoder_outputs, decoder_outputs):
        """Map decoder outputs back to the encoder side. Must be overridden."""
        # Previously ``return NotImplementedError``, which silently handed the
        # exception class back as a value; raise like every other stub.
        raise NotImplementedError

    def loss(self, data, encode_output, decode_outputs, gen_outputs, reduction='mean'):
        """Compute the training loss. Must be overridden."""
        raise NotImplementedError

    def to_sentence(self, data, batch_indice):
        """Convert a batch of indices into sentences. Must be overridden."""
        raise NotImplementedError

    def sample(self, data):
        """Sample outputs for ``data``. Must be overridden."""
        raise NotImplementedError

    def greedy(self, data):
        """Greedy decoding via the module-level ``greedy`` helper.

        Inside this method the bare name ``greedy`` resolves to the global
        helper, not to this method.
        """
        return greedy(self, data, self.vocab2id, self.max_dec_len)

    def beam(self, data):
        """Beam-search decoding via the module-level ``beam`` helper.

        The helper receives ``self.beam_width`` as the beam size.
        """
        return beam(self, data, self.vocab2id, self.max_dec_len, self.beam_width)