You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
The stopping criteria is updated in the latest transformers(V4.39.3 now). The return result is modified to a tensor (torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool) ) instead of a bool value, which causes the bug in #847
@MengqingCao hmm, yeah this needs fixing. One q though as I'm not intimately familiar with the gen code. Does the implementation here support multiple sentences in a batch? If yes, using any() will stop too early as the bool tensor is done state for each row (sentence) no? If it's only supporting one sentence at a time this is fine for now.. @gpucce ?
@MengqingCao hmm, yeah this needs fixing. One q though as I'm not intimately familiar with the gen code. Does the implementation here support multiple sentences in a batch? If yes, using any() will stop too early as the bool tensor is done state for each row (sentence) no? If it's only supporting one sentence at a time this is fine for now.. @gpucce ?
Your concerns are right in the cases when users using StopStringCriteria and EosTokenCriteria, which I ignored before. I only noticed the default StoppingCriteria method MaxLengthCriteria before, which returns a boolTensor filled with one single bool value is_done. Thus, I think use any() brings bigger operating efficiency than all().
To adapt to the situation of StopStringCriteria and EosTokenCriteria at the same time, I think we have two choices:
change to use all() here
checking if there is StopStringCriteria and EosTokenCriteria in stopping_criteria, if no, use any(), otherwise, use all(). This may run faster but bring more changes than 1
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
fix #847
The stopping criteria is updated in the latest
transformers(V4.39.3 now). The return result is modified to a tensor (torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)) instead of a bool value, which causes the bug in #847the related code in
transformershttps://github.com/huggingface/transformers/blob/0bd58f1ce0573c0e3269de4215a17d318add49b9/src/transformers/generation/stopping_criteria.py#L76