Skip to content

Fix stopping_criteria result check in coca_model#860

Merged
rwightman merged 2 commits into
mlfoundations:mainfrom
MengqingCao:coca_fix
Jun 22, 2024
Merged

Fix stopping_criteria result check in coca_model#860
rwightman merged 2 commits into
mlfoundations:mainfrom
MengqingCao:coca_fix

Conversation

@MengqingCao

Copy link
Copy Markdown
Contributor

fix #847

The stopping criteria is updated in the latest transformers(V4.39.3 now). The return result is modified to a tensor (torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool) ) instead of a bool value, which causes the bug in #847

the related code in transformers
https://github.com/huggingface/transformers/blob/0bd58f1ce0573c0e3269de4215a17d318add49b9/src/transformers/generation/stopping_criteria.py#L76

@MengqingCao

Copy link
Copy Markdown
Contributor Author

@rwightman

rwightman commented May 9, 2024

Copy link
Copy Markdown
Collaborator

@MengqingCao hmm, yeah this needs fixing. One q though as I'm not intimately familiar with the gen code. Does the implementation here support multiple sentences in a batch? If yes, using any() will stop too early as the bool tensor is done state for each row (sentence) no? If it's only supporting one sentence at a time this is fine for now..
@gpucce ?

@MengqingCao

Copy link
Copy Markdown
Contributor Author

@MengqingCao hmm, yeah this needs fixing. One q though as I'm not intimately familiar with the gen code. Does the implementation here support multiple sentences in a batch? If yes, using any() will stop too early as the bool tensor is done state for each row (sentence) no? If it's only supporting one sentence at a time this is fine for now.. @gpucce ?

Your concerns are right in the cases when users using StopStringCriteria and EosTokenCriteria, which I ignored before. I only noticed the default StoppingCriteria method MaxLengthCriteria before, which returns a boolTensor filled with one single bool value is_done. Thus, I think use any() brings bigger operating efficiency than all().

The related code in Transformers:
image

To adapt to the situation of StopStringCriteria and EosTokenCriteria at the same time, I think we have two choices:

  1. change to use all() here
  2. checking if there is StopStringCriteria and EosTokenCriteria in stopping_criteria, if no, use any(), otherwise, use all(). This may run faster but bring more changes than 1

@MengqingCao

Copy link
Copy Markdown
Contributor Author

@rwightman @gpucce , I have implemented option 2 and updated the code, give me some suggestions plz, thanks!

@rwightman rwightman merged commit 45b43c9 into mlfoundations:main Jun 22, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

2 participants