EBM Family / Loss Functions #380
Hi @mtl-tony -- Are you interested specifically in adding new loss functions, or more broadly in helping us improve the package? I ask because adding new loss functions is probably the most difficult new feature in our pipeline, and it isn't something I'd recommend someone new to the codebase start with. I'll give you an overview so you can decide whether it's something you'd like to tackle. If you're looking more generally for an area where you could contribute, we'd be happy to suggest other areas that would be both impactful and easier starting points. Firstly, the loss function computations sit inside the most performance-critical loop for EBMs, so they need to be fast, and I'm working hard on making them even faster. In the future this critical loop will be both SIMD vectorized and supported on GPUs. Here is where we currently calculate the gradient and hessian for binary classification: interpret/shared/ebm_native/ApplyUpdate.cpp, line 308 (at commit c9b43de)
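For readers unfamiliar with how boosting computes these quantities, here is a minimal, self-contained sketch of the per-sample math for binary log loss. The function names and structure are illustrative only and are not the identifiers used in ApplyUpdate.cpp:

```cpp
#include <cmath>

// Illustrative sketch (not the actual ApplyUpdate.cpp code) of the
// per-sample quantities for binary log loss. Given a raw score s and
// a label y in {0, 1}:
//   p        = 1 / (1 + exp(-s))   (sigmoid)
//   gradient = p - y
//   hessian  = p * (1 - p)
inline double Sigmoid(double score) {
   return 1.0 / (1.0 + std::exp(-score));
}

inline double LogLossGradient(double score, double target) {
   return Sigmoid(score) - target;
}

inline double LogLossHessian(double score) {
   const double probability = Sigmoid(score);
   return probability * (1.0 - probability);
}
```

The gradient and hessian computed this way are what gets accumulated into the histogram bins that drive the boosting updates.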
And the equivalent for regression is here: interpret/shared/ebm_native/ApplyUpdate.cpp, line 434 (at commit c9b43de)
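The MSE case is much simpler, which is exactly why it gets special-cased (as described below). A hedged sketch of the math, with illustrative names that are not the actual identifiers in the codebase:

```cpp
#include <cstddef>

// Illustrative sketch of the MSE regression update. The gradient of
// the squared error (1/2)*(s - y)^2 with respect to the score s is
// simply the residual (s - y), so the stored per-sample error doubles
// as the gradient.
inline double MseGradient(double prediction, double target) {
   return prediction - target;  // residual == gradient
}

// The hessian of squared error is the constant 1, so no per-sample
// floating point hessian needs to be stored; an integer count per
// histogram bin is sufficient.
constexpr double MseHessian() {
   return 1.0;
}
```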
Unfortunately, adding new loss functions isn't very easy at the moment, in part because both binary classification and regression are special-cased. For MSE regression we can avoid storing the current per-sample predictions: because of the special structure of MSE, we keep just the per-sample error, which also happens to be the gradient. And since the hessian for MSE is always 1, we do not need a floating point number for it, so we keep an integer count in each histogram bin instead. We use C++ templates to specialize the data structures to either regression (without a hessian) or classification (with a hessian):
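To make the template-specialization idea concrete, here is a minimal sketch of how a histogram bin can be specialized so the MSE regression path pays no storage cost for a hessian it never uses. This is not the real Bin definition from the codebase, just an illustration of the technique:

```cpp
#include <cstddef>

// Hypothetical sketch of a specialized histogram bin. The real Bin
// structure in the codebase differs; this only illustrates the idea.
template<bool bClassification>
struct Bin;

// Classification: each bin accumulates both a gradient sum and a
// floating point hessian sum.
template<>
struct Bin<true> {
   double sumGradients = 0.0;
   double sumHessians = 0.0;

   void Add(double gradient, double hessian) {
      sumGradients += gradient;
      sumHessians += hessian;
   }
};

// MSE regression: the hessian is always 1, so an integer sample count
// replaces the floating point hessian sum entirely.
template<>
struct Bin<false> {
   double sumGradients = 0.0;
   size_t countSamples = 0;

   void Add(double gradient) {
      sumGradients += gradient;
      ++countSamples;
   }
};
```

Because the specialization is resolved at compile time, the hot loops instantiated for regression never touch a hessian field at all.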
In order to support other loss functions, these things would need to change, but we don't want to lose the performance benefits we currently enjoy for MSE regression, so non-MSE loss functions would need to be handled in a separate function, and the Bin structure would need to be specialized throughout the codebase to understand non-MSE regression. Further out, we'd like to compute the different loss functions inside vectorized loops. My plan there is to use specially constructed loss function classes, which will allow the loss function calculations to be injected into these highly optimized loops via C++ templating. This work is ongoing but incomplete. The code below is part of the future Loss function architecture: https://github.com/interpretml/interpret/blob/develop/shared/ebm_native/compute/Loss.hpp Each loss function will be written as a class that looks something like this: And then be "registered" here:
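The original comment embedded code snippets that did not survive here. As a purely hypothetical stand-in (none of these names come from Loss.hpp or the actual codebase), the described architecture might look something like a loss class exposing inline gradient methods that a templated inner loop can instantiate and inline:

```cpp
#include <cmath>
#include <cstddef>

// Hypothetical example loss class in the spirit of the described
// architecture: a pseudo-Huber loss whose gradient method the
// compiler can inline into templated hot loops. Illustrative only.
struct LossPseudoHuber {
   double m_delta;

   explicit LossPseudoHuber(double delta) : m_delta(delta) {}

   // Gradient of delta^2 * (sqrt(1 + (r/delta)^2) - 1) with r = pred - target.
   inline double CalcGradient(double prediction, double target) const {
      const double residual = prediction - target;
      return residual / std::sqrt(1.0 + (residual * residual) / (m_delta * m_delta));
   }
};

// A templated loop instantiated per-loss: the compiler sees the
// concrete TLoss type and can inline CalcGradient, preserving
// SIMD-friendly code generation.
template<typename TLoss>
double SumGradients(const TLoss& loss, const double* predictions,
                    const double* targets, size_t count) {
   double sum = 0.0;
   for(size_t i = 0; i < count; ++i) {
      sum += loss.CalcGradient(predictions[i], targets[i]);
   }
   return sum;
}
```

The key design point is that the loss is a compile-time template parameter rather than a runtime function pointer, so adding a loss means adding a class and registering it, without de-optimizing the inner loop.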
Hello, thanks for your reply. For me, the main interest is being able to try the model architecture with different loss functions / families of distributions, but I could potentially tackle an easier problem first to become more familiar with the code base. I appreciate your help with this, and I was wondering what the typical procedure is if I did want to contribute, as I find this a very promising package.
Hi @mtl-tony -- We're happy to receive PRs if you see something during your own use that you'd like to improve. If you're looking for a dialog to match something we think is important with something within your interests, I'd suggest emailing us at interpret@microsoft.com and we'd be happy to have that longer discussion.
Really interesting topic. An example of what I implemented with a CNN and that I'd like to replicate with EBM: with target = the true value and pred = the prediction. Another example I'd like to achieve with EBM would be with binary classification (with BCE loss). Are there plans to allow such custom losses in the near future?
Hi @JDE65 -- Custom losses are probably our most requested feature currently, and we have a couple of issues open on them. They are fairly difficult to implement and require deep changes in the C++ layer, so we do not have an ETA on them. I can say they are not in the cards for the next few months, but we will eventually complete this feature given its importance.
Closing this to consolidate issues. We'll track progress on custom losses in issue #281. |
Hello,
I'm looking to add more functionality to this package as I wanted to incorporate other distributions and loss functions. I was wondering if you could advise / point to where I would need to add to incorporate this.
Thanks