Conversation
rollin_samples = random.sample(sample_ids, nr_samples)
with torch.no_grad():
    # restore
    for id_, sample in sample_stack:
@peter-makarov Here my idea was to sample a subset of increasing size each epoch, drawing a different subset every epoch (and resetting the previous subset to its earlier state, i.e. teacher forcing).
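The schedule described above could be sketched roughly as follows (a pure-Python sketch; the function name `rollin_subset` and the linear `growth` schedule are illustrative assumptions, not the PR's code):

```python
import random

def rollin_subset(sample_ids, epoch, growth=10, seed=None):
    """Draw a fresh roll-in subset each epoch, growing linearly in size.

    Samples outside the returned subset are assumed to be reset to
    teacher forcing elsewhere, as the comment above describes.
    """
    rng = random.Random(seed)
    nr_samples = min(len(sample_ids), growth * (epoch + 1))
    return rng.sample(sample_ids, nr_samples)

# Each epoch draws a different, larger subset of sample ids.
ids = list(range(100))
epoch0 = rollin_subset(ids, epoch=0, seed=1)   # 10 ids
epoch1 = rollin_subset(ids, epoch=1, seed=2)   # 20 ids, independently drawn
```

The size is capped at the dataset size, so late epochs simply roll in everything.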
trans/transducer.py
Outdated
return Output(action_history, self.decode_encoded_output(input_, action_history),
              log_p, None)

def roll_in(self, sample: utils.Sample, rollin: int) -> None:
The implementation handles a single sample at the moment; the decoder steps could probably be batched (with sampling / exploration as a loop). This would likely increase speed.
rollin should probably be a float representing a probability?
# expert prediction
expert_actions = self.expert_rollout(sample.input, sample.target,
                                     current_alignment.item(), output)
optimal_actions.append(expert_actions)
The targets are always generated by the expert.
optimal_actions.append(expert_actions)

# update states
if np.random.rand() <= rollin:
The next state is either sampled from the expert or from the model itself. Is this the idea? Alternatively, one could always execute the model (and rely on the expert only for the targets).
Yes, that's the idea!
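The mixed roll-in just confirmed — follow the model's sampled action with probability rollin, otherwise follow the expert — can be sketched as follows (pure Python standing in for the numpy/torch code; the function name `next_action` is illustrative):

```python
import random

def next_action(rollin, sample_action, expert_action, rng=random):
    """With probability `rollin`, advance the state with the model's own
    sampled action (exploration); otherwise advance with the expert's
    action (teacher forcing on the state). The training target is the
    expert's action either way."""
    if rng.random() <= rollin:
        return sample_action
    return expert_action

# rollin=0.0 is (almost surely) pure teacher forcing;
# rollin=1.0 always rolls in with the model's own predictions.
```

Making `rollin` a float in [0, 1] also addresses the type question raised earlier in the thread.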
@peter-makarov So I've drafted an implementation. I've tested it a bit and it seems to work to some extent; however, I think I am still missing something...
transducer_.roll_in(training_data.samples[id_], rollin)

j = 0
for j, batch in enumerate(training_data_loader):
With the current setup, roll-in is performed before each epoch. A problem with this approach could be that the model is updated during the epoch, so the rolled-in target sequences are no longer representative of the model. So maybe it would make more sense to roll in after every training step, with some probability.
Agreed, it would be more sound and more useful to sample for each batch (this will address the errors of the current model checkpoint, not errors that may have already resolved themselves due to the recent parameter updates).
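Per-batch roll-in as discussed could look roughly like this (a pure-Python sketch; `do_rollin` and `do_update` are hypothetical stand-ins for the actual roll-in and training-step calls):

```python
import random

def train_epoch(batches, rollin_prob, do_rollin, do_update, rng=random):
    """Before each batch, refresh that batch's rolled-in targets with
    probability `rollin_prob`, so the roll-ins reflect the current model
    checkpoint rather than the start-of-epoch one."""
    rolled_in = 0
    for batch in batches:
        if rng.random() <= rollin_prob:
            do_rollin(batch)   # re-run roll-in for this batch only
            rolled_in += 1
        do_update(batch)       # parameter update on the (possibly refreshed) batch
    return rolled_in
```

The epoch-level loop in the diff above would then shrink to this per-batch decision.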
    action = sample_action
else:
    action = expert_actions[
        int(np.argmax([log_probs_np[a] for a in expert_actions]))
So this does not over-correct the model if it already predicts an optimal action, in case multiple actions are optimal.
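The tie-breaking the comment describes — when several actions are optimal, follow the one the current model already prefers — could be sketched like this (pure Python in place of `np.argmax`; the function name is illustrative):

```python
def best_expert_action(log_probs, expert_actions):
    """Among the expert's (possibly many) optimal actions, return the one
    the model assigns the highest log-probability, so training does not
    push the model away from an optimal action it already favours."""
    return max(expert_actions, key=lambda a: log_probs[a])

# Actions 1 and 2 are both optimal; the model prefers 1, so 1 is chosen.
log_probs = {0: -2.0, 1: -0.5, 2: -1.0}
chosen = best_expert_action(log_probs, [1, 2])
```

With a numpy array of log-probabilities this is exactly the `np.argmax` over the restricted index list in the diff above.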
| if char != "": | ||
| output.append(char) | ||
|
|
||
| alignment_history = torch.cat( |
It doesn't seem like this has to be done via concatenation in a loop. Could this not be rewritten using list append?
]

action_history = torch.cat(
    (action_history, torch.tensor([[[action]]], device=self.device)),
Same comment regarding concatenation in the loop.
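The append-then-concatenate idiom both comments suggest, sketched in pure Python (with tensors, `history` would hold per-step tensors and a single `torch.cat(history)` would run once after the loop instead of copying the whole history at every step):

```python
from itertools import chain

def collect_then_concat(steps):
    """Accumulate per-step results in a Python list (O(1) append) and
    concatenate once at the end, instead of concatenating inside the
    loop, which re-copies the full history on every iteration."""
    history = []
    for step in steps:
        history.append([step])          # stand-in for a per-step tensor
    return list(chain.from_iterable(history))  # stand-in for torch.cat(history)
```

The same restructuring applies to both `alignment_history` and `action_history` above.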
Is there no improvement? Test it on some small dataset (e.g. morphological inflection with 100 samples) and batch size 1. Try with roll-in and without roll-in. You should be seeing a consistent improvement.