
Draft for rollin implementation #12

Draft
slvnwhrl wants to merge 2 commits into development from feature/roll-in

Conversation

@slvnwhrl (Owner)

No description provided.

@slvnwhrl slvnwhrl requested a review from peter-makarov June 16, 2022 08:39
@slvnwhrl slvnwhrl self-assigned this Jun 16, 2022
rollin_samples = random.sample(sample_ids, nr_samples)
with torch.no_grad():
    # restore
    for id_, sample in sample_stack:
slvnwhrl (Owner Author):

@peter-makarov Here my idea was to sample an increasing number of samples each epoch, drawing a different subset each time (and resetting the previous subset to its earlier state, i.e. teacher forcing).

peter-makarov (Collaborator):

Makes sense.
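The growing-subset schedule described above could be sketched roughly like this (a hypothetical helper, not the PR's actual code; `step` controls how fast the rolled-in subset grows):

```python
import random

# Hypothetical sketch of the schedule discussed above: each epoch, draw a
# fresh, larger subset of sample ids to roll in. Samples from the previous
# subset are assumed to be restored to teacher forcing elsewhere (the
# `sample_stack` loop in the quoted code).
def rollin_schedule(sample_ids, epoch, step=10):
    # grow the subset by `step` samples per epoch, capped at the data size
    nr_samples = min(len(sample_ids), (epoch + 1) * step)
    return random.sample(sample_ids, nr_samples)
```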

return Output(action_history, self.decode_encoded_output(input_, action_history),
              log_p, None)

def roll_in(self, sample: utils.Sample, rollin: int) -> None:
slvnwhrl (Owner Author):

The implementation handles a single sample at the moment; the decoder steps could probably be batched (with sampling / exploration as a loop). This would likely increase speed.

peter-makarov (Collaborator):

rollin should probably be a float representing a probability?
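Treating `rollin` as a probability could look like the following sketch (all names are illustrative stand-ins, not the project's API): at each decoder step, the executed action comes from the model with probability `rollin`, otherwise from the expert.

```python
import numpy as np

# Sketch of `rollin` as a float probability in [0, 1], as suggested.
# `sample_action` / `expert_action` are illustrative names.
def choose_action(sample_action, expert_action, rollin: float, rng=np.random):
    assert 0.0 <= rollin <= 1.0, "rollin is a probability"
    # follow the model's sampled action with probability `rollin`
    return sample_action if rng.rand() < rollin else expert_action
```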

# expert prediction
expert_actions = self.expert_rollout(sample.input, sample.target,
                                     current_alignment.item(), output)
optimal_actions.append(expert_actions)
slvnwhrl (Owner Author):

The targets are always generated by the expert.

optimal_actions.append(expert_actions)

# update states
if np.random.rand() <= rollin:
slvnwhrl (Owner Author):

The next state is either sampled from the expert or from the model itself. Is this the idea? Alternatively, one could always execute the model (and only rely on the expert for the targets).

peter-makarov (Collaborator):

Yes, that's the idea!
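A minimal sketch of the confirmed scheme, under the assumption that `log_probs` is the model's full normalized log distribution over actions and `expert_actions` the expert's optimal set (all names illustrative): the training target always comes from the expert, while the executed action, which determines the next state, comes from the model with probability `rollin`.

```python
import numpy as np

# Sketch: targets always from the expert; state transitions mixed between
# model and expert with probability `rollin`. Illustrative names only.
def rollin_step(log_probs, expert_actions, rollin, rng=np.random):
    target = expert_actions  # training target: always from the expert
    if rng.rand() < rollin:
        # execute the model: sample from its action distribution
        action = rng.choice(len(log_probs), p=np.exp(log_probs))
    else:
        # execute the expert: among optimal actions, take the model's favorite
        action = expert_actions[
            int(np.argmax([log_probs[a] for a in expert_actions]))]
    return action, target
```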

@slvnwhrl (Owner Author)

@peter-makarov So I've drafted an implementation. I've tested it a bit and it seems to work somehow, however, I think I am still missing something...

transducer_.roll_in(training_data.samples[id_], rollin)

j = 0
for j, batch in enumerate(training_data_loader):
slvnwhrl (Owner Author):

With the current setup, roll-in is performed before each epoch. A problem with this approach could be that the model performs updates during the epoch, so the rolled-in target sequences are no longer representative of the model. Maybe it would make more sense to roll in after every training step with some probability.

peter-makarov (Collaborator):

Agreed, it would be more sound and more useful to sample for each batch (this will address the errors of the current model checkpoint, not errors that may already have resolved themselves due to the recent parameter updates).
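Per-batch roll-in could be sketched like this (a hypothetical helper: `roll_in_fn` stands in for `transducer_.roll_in` applied to a sample id, and `rollin_prob` is the per-batch re-roll probability; neither name is from the PR):

```python
import numpy as np

# Sketch of per-batch roll-in: before each batch, decide with probability
# `rollin_prob` whether to re-roll the batch's samples against the
# *current* checkpoint, so rolled-in targets reflect the latest updates.
def maybe_rollin_batch(batch_ids, roll_in_fn, rollin_prob, rng=np.random):
    if rng.rand() < rollin_prob:
        for id_ in batch_ids:
            roll_in_fn(id_)  # e.g. transducer_.roll_in on that sample
        return True
    return False
```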

@peter-makarov (Collaborator) left a comment

LGTM

    action = sample_action
else:
    action = expert_actions[
        int(np.argmax([log_probs_np[a] for a in expert_actions]))
peter-makarov (Collaborator):

So this does not over-correct the model if it already predicts an optimal action, in case multiple actions are optimal.
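The tie-breaking in the quoted line can be isolated in a small sketch (illustrative names): among the expert's optimal actions, follow the one the model already assigns the highest log probability, so the loss does not push the model away from an optimal action it would have taken anyway.

```python
import numpy as np

# Sketch of the tie-breaking discussed above: when several expert actions
# are optimal, pick the one the model currently prefers.
def best_expert_action(log_probs, expert_actions):
    return expert_actions[
        int(np.argmax([log_probs[a] for a in expert_actions]))]
```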

if char != "":
    output.append(char)

alignment_history = torch.cat(
peter-makarov (Collaborator):

This doesn't seem like it has to be done via concatenation in a loop. Can it not be rewritten using list append?

]

action_history = torch.cat(
    (action_history, torch.tensor([[[action]]], device=self.device)),
peter-makarov (Collaborator):

Same comment regarding concatenation in the loop.
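The suggested rewrite, shown here with NumPy for illustration (`torch.cat` behaves analogously): append per-step results to a plain Python list and concatenate once after the loop, instead of re-concatenating the growing tensor every iteration.

```python
import numpy as np

# Sketch of the list-append pattern: O(1) appends in the loop, a single
# concatenation at the end, rather than reallocating the whole history
# tensor on every step.
def collect_history(steps):
    history = []
    for a in steps:
        history.append(np.array([[[a]]]))  # one (1, 1, 1) entry per step
    return np.concatenate(history, axis=-1)  # one concatenation
```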

@peter-makarov (Collaborator)

> @peter-makarov So I've drafted an implementation. I've tested it a bit and it seems to work somehow, however, I think I am still missing something...

Is there no improvement?

Test it on some small data (e.g. morphological inflection with 100 samples) and batch size 1. Try with roll-in and without roll-in. You should see consistent improvement.
