QQ linear #2931
Conversation
This reverts commit 867e0fc.
python/mlx/nn/layers/quantized.py (Outdated)

    mode=self.mode,
)

def train(self):
Can you make that API consistent with Module::train?
I think if you do that we can get rid of the eval override above and just use the base class eval.
Of course, this is a very good point. Now the weights will be quantized on the first qq_linear.__call__() after calling Module::eval(), and likewise dequantized on the first qq_linear.__call__() after calling Module::train(). This seems to be the only way to keep it consistent with the current API without changing Module::train().
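For illustration, here is a minimal sketch of that lazy switch, not the actual QQLinear code from this PR: the class structure and the _is_quantized bookkeeping flag are made up for this example; only mx.quantize, mx.dequantize, and mx.quantized_matmul are existing mlx ops.

```python
import mlx.core as mx

class LazyQQLinearSketch:
    """Toy illustration of the lazy switch described above (not QQLinear):
    the weight changes representation on the first call after the mode flips."""

    def __init__(self, in_dims, out_dims, group_size=64, bits=4):
        self.training = True          # mirrors Module.training
        self._is_quantized = False    # made-up bookkeeping flag
        self.group_size, self.bits = group_size, bits
        self.weight = mx.random.normal((out_dims, in_dims))

    def __call__(self, x):
        if not self.training and not self._is_quantized:
            # first call after eval(): switch to the quantized representation
            self.weight, self.scales, self.biases = mx.quantize(
                self.weight, self.group_size, self.bits
            )
            self._is_quantized = True
        elif self.training and self._is_quantized:
            # first call after train(): switch back to full precision
            self.weight = mx.dequantize(
                self.weight, self.scales, self.biases, self.group_size, self.bits
            )
            self._is_quantized = False
        if self._is_quantized:
            return mx.quantized_matmul(
                x, self.weight, self.scales, self.biases,
                transpose=True, group_size=self.group_size, bits=self.bits
            )
        return x @ self.weight.T
```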
I see why you had to do it this way but I'm not crazy about how it works... I'm wondering if there is a better way to do it.
Basically the behavior that would be good to have is if we do:
qq_module.eval()
qq_module.parameters() # should give me the quantized params
qq_module.load_weights(quantized_weights) # should be able to load the quantized params
I think that should work and right now it won't.
I think we have some other options:
- Break away from the train/eval API and have something like QQLinear.quantize / QQLinear.dequantize, which either quantizes/dequantizes the module in-place (or maybe returns a copy that is quantized/dequantized).
- Change the base class Module to make it easier to override train (e.g. call the submodules' train as well as setting the local module's _training).
Wdyt?
Fully agree with the desired behavior you described. Between the options, I prefer (2). I would expect model.train() / model.eval() to recursively propagate mode changes through the tree, and adding a separate quantize() / dequantize() API would create a parallel mode system that would likely need to be kept aligned with train / eval as well.
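As a rough, purely hypothetical sketch of what option (2) could look like (the on_mode_change hook and the simplified traversal below are not existing mlx API): the base train() flips _training on every submodule and gives each one a chance to react, so a QQLinear could quantize/dequantize eagerly when the mode changes.

```python
# Purely hypothetical sketch of option (2); none of these names are existing
# mlx API. The idea: Module.train() recursively flips the mode and calls a
# hook, so QQLinear can switch representations eagerly instead of lazily.
class ModuleSketch:
    def __init__(self):
        self._training = True
        self._children = []

    def modules(self):
        # stand-in for the real recursive traversal
        out = [self]
        for c in self._children:
            out.extend(c.modules())
        return out

    def train(self, mode: bool = True):
        for m in self.modules():
            m._training = mode
            m.on_mode_change(mode)   # hypothetical per-module hook
        return self

    def eval(self):
        return self.train(False)

    def on_mode_change(self, mode: bool):
        pass  # default: nothing to do


class QQLinearSketch(ModuleSketch):
    def on_mode_change(self, mode: bool):
        # quantize when entering eval, dequantize when going back to training,
        # so parameters() / load_weights() immediately see the right arrays
        if mode:
            self._dequantize_in_place()
        else:
            self._quantize_in_place()

    def _quantize_in_place(self):
        ...  # placeholder

    def _dequantize_in_place(self):
        ...  # placeholder
```

With something along these lines, the eval() → parameters() → load_weights() flow from the earlier comment would see the quantized arrays right away, without waiting for the first forward call.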
awni left a comment:
Looks very nice! Just a few more cosmetic comments, then we should get it merged!
awni left a comment:
Awesome, thanks!!
QQLinear layer