Add a script to verify safetensor checkpoint #2903
🤖 Hi @RissyRan, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This pull request introduces a valuable verification script to ensure the correctness of converted safetensor checkpoints. The script is well-structured and provides clear logging. The refactoring of helper functions into a shared utils.py file is a good improvement for code organization.
🔍 General Feedback
- The new verification script is a great addition for improving the reliability of the checkpoint conversion process.
- The parallel loading of safetensor files using ThreadPoolExecutor is a good choice for performance.
- The detailed logging in the verification script is helpful for debugging potential mismatches.
One minor logging issue was found and commented on. Overall, this is a solid contribution.
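As an illustration of the parallel-loading point above, here is a minimal sketch of reading several checkpoint shards with ThreadPoolExecutor and merging them into one dict. It is not the PR's code: `load_shards_parallel`, `load_shard`, and the shard file names are hypothetical, and the loader is passed in as a callable so the sketch stays independent of any particular safetensors API.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, Iterable, TypeVar

T = TypeVar("T")


def load_shards_parallel(
    shard_paths: Iterable[str],
    load_shard: Callable[[str], Dict[str, T]],
    max_workers: int = 8,
) -> Dict[str, T]:
    """Load every shard concurrently and merge the results into one dict.

    `load_shard` is a hypothetical callable that maps a shard path to a
    {tensor_name: tensor} dict; a real script would plug in its own loader.
    """
    paths = list(shard_paths)
    merged: Dict[str, T] = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map yields results in input order, so merging is deterministic.
        for path, tensors in zip(paths, pool.map(load_shard, paths)):
            duplicates = merged.keys() & tensors.keys()
            if duplicates:
                raise ValueError(f"{path} repeats tensor names: {sorted(duplicates)}")
            merged.update(tensors)
    return merged


# Usage with an in-memory stand-in for shard files (assumed names and values).
fake_shards = {
    "model-00001.safetensors": {"embed.weight": [0.1, 0.2]},
    "model-00002.safetensors": {"lm_head.weight": [0.3, 0.4]},
}
state = load_shards_parallel(fake_shards, fake_shards.__getitem__)
```

Threads suit this workload because shard loading is mostly I/O bound; the duplicate-name check is an extra safeguard added for this sketch.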
shuningjin
left a comment
Thank you for adding this utility! Looks good overall.
hengtaoguo
left a comment
Thanks! I also wonder whether this is equivalent to the forward logits check, which also requires all layers to be identical for the results to match. It's good to have the extra safeguard, though.
Thanks Hengtao! No, it's not the same. The forward logit test checks an Orbax checkpoint loaded into MaxText against HF. This test specifically checks a safetensor checkpoint (converted via to_huggingface) against the reference HF checkpoint.
Description
Verify the converted safetensor checkpoint (GCS or local) matches the remote HuggingFace checkpoint reference.
Shared helper functions are refactored into utils.py. We may enable lazy mode as a follow-up to save memory.
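The kind of check described here can be sketched roughly as follows. This is an assumption-laden illustration, not the script added in this PR: `compare_checkpoints` is a hypothetical name, the tolerances are arbitrary, and tensors are represented as flat lists of floats to keep the sketch dependency-free (a real script would compare arrays and their shapes and dtypes).

```python
import math
from typing import Dict, List, Sequence


def compare_checkpoints(
    converted: Dict[str, Sequence[float]],
    reference: Dict[str, Sequence[float]],
    rel_tol: float = 1e-5,  # assumed tolerance, not taken from the PR
    abs_tol: float = 1e-8,  # assumed tolerance, not taken from the PR
) -> List[str]:
    """Return a list of human-readable mismatches (empty means they match)."""
    problems: List[str] = []
    # Tensor names present on only one side.
    for name in sorted(reference.keys() - converted.keys()):
        problems.append(f"missing in converted: {name}")
    for name in sorted(converted.keys() - reference.keys()):
        problems.append(f"unexpected in converted: {name}")
    # Tensors present on both sides: compare size first, then values.
    for name in sorted(converted.keys() & reference.keys()):
        got, want = converted[name], reference[name]
        if len(got) != len(want):
            problems.append(f"size mismatch for {name}: {len(got)} vs {len(want)}")
            continue
        if not all(
            math.isclose(g, w, rel_tol=rel_tol, abs_tol=abs_tol)
            for g, w in zip(got, want)
        ):
            problems.append(f"value mismatch for {name}")
    return problems


# Usage with toy data (made-up tensor names and values).
reference = {"a": [1.0, 2.0], "b": [3.0]}
converted = {"a": [1.0, 2.0000001], "b": [3.5], "c": [0.0]}
problems = compare_checkpoints(converted, reference)
```

Collecting all mismatches instead of stopping at the first one makes the log more useful when several tensors differ.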
Tests
Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.