
For discussion: numba_scipy.stats #42

Open
luk-f-a opened this issue May 17, 2020 · 12 comments
luk-f-a commented May 17, 2020

Hi everyone!

This is meant as a way to gather feedback on the current status of numba_scipy.stats. I'm pinging people who have expressed interest in numba_scipy.stats and/or are involved in numba and scipy. I'd like to share what I've learned so far, and hopefully you'll share your perspective on this.

Since last year I've been looking into scipy.stats with a view to getting numba-scipy.stats started. I created a prototype in #16. It's viable, but the experience has led me to question the cost/benefit tradeoff of following that path.

The main technical complication with scipy.stats is that it is not function-based but object-based. It relies on several of Python's OO features, like inheritance and operator overloading. Numba has great support for function-based libraries (or method-based, when the number of objects is limited) like Numpy. However, the support for classes (via jitclasses) is more limited and experimental. Outside of jitclasses, the only other option is to use the extending module, with the added effort that it implies.
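
For context, here is a minimal sketch of that function-based extension path using numba.extending.overload, applied to a hypothetical free function norm_pdf (the names and formula are illustrative, not numba-scipy or scipy code). This is the pattern that works well for NumPy-style APIs; scipy.stats' distribution objects do not map onto it directly.

import math
from numba import njit
from numba.extending import overload

# Hypothetical pure-Python reference implementation (illustrative only).
def norm_pdf(x, loc=0.0, scale=1.0):
    z = (x - loc) / scale
    return math.exp(-0.5 * z * z) / (scale * math.sqrt(2.0 * math.pi))

# Register a nopython implementation so norm_pdf "just works" inside @njit.
@overload(norm_pdf)
def _norm_pdf_overload(x, loc=0.0, scale=1.0):
    def impl(x, loc=0.0, scale=1.0):
        z = (x - loc) / scale
        return math.exp(-0.5 * z * z) / (scale * math.sqrt(2.0 * math.pi))
    return impl

@njit
def demo(x):
    return norm_pdf(x, 0.0, 1.0)

# Example: demo(1.5)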

The consequence of the above is that it will not be possible to fully imitate the behaviour of scipy.stats, at least not in the medium term and not without a lot of work.

Even if jitclasses worked exactly like Python classes, scipy.stats has more than a hundred distributions, each of them with more than 10 methods. If we followed the approach numba uses for numpy, we would be talking about 1000+ methods to re-write. In some cases there would be performance improvements, but in some cases there wouldn't.

Look at the following example:

import numpy as np
import numba as nb
from numba import njit
from scipy.stats import norm

def foo():
    k = 20000
    x = np.zeros(k)
    for m in range(100):
        x += norm.rvs(m, 1, size=k)
    return x

foo_jit = njit(foo)

@njit
def bar():
    k = 20000
    x = np.zeros(k)
    for m in range(100):
        # call scipy in object mode and hand the result back to nopython code
        with nb.objmode(y='float64[:]'):
            y = norm.rvs(m, 1, size=k)
        x += y
    return x

%timeit foo() #66 ms ± 277 µs

foo_jit()
%timeit foo_jit() #73.7 ms ± 214 µs

bar()
%timeit bar() #65.8 ms ± 208 µs

There's no performance improvement at all, because most of the work is already done in C. This will be the case for many scipy.stats functions.

To summarize, I see a few ways forward, each with pros and cons:

  • jitclass based solution

    • pros: easy for people to contribute (not much more than being competent with Python and having used numba before)
    • cons: won't replicate scipy's behaviour, will regularly run into jitclass limitations and have to find workarounds, and will require thousands of hours to build. All that effort does not build anything new, just a copy of existing scipy features.
  • low-level numba extension (http://numba.pydata.org/numba-doc/latest/extending/low-level.html)

    • pros: should be able to reproduce all or most behaviour
    • cons: harder to work with: would increase the effort required and limit the number of contributors. All that effort does not build anything new, just a copy of existing scipy features.
  • objmode approach = no jitted solution (call scipy in object mode, as in bar() above)

I personally lean towards option 3 at the moment. I might write some custom code that calls special functions if I really need performance. But I'm not very attracted to the idea of re-implementing a module as large as scipy.stats.

It would be great to hear your perspective on this.

cc: @gioxc88 @francoislauger @stnatter @remidebette @rileymcdowell @person142 @stuartarchibald

LordGav commented Jan 1, 2021

I really wanted to use stats.skew() and stats.kurtosis(); is there any way to do that?

stnatter commented Jan 4, 2021

None of these alternative routes seems particularly appealing. The upside seems very limited indeed. Thanks for your perspective, @luk-f-a.

Happy New Year!

@HDembinski

To apply the method of maximum likelihood, fast implementations of the pdfs and cdfs are needed; option 3 would not do. There are speed gains of a factor of 100 currently if, for example, norm.cdf is replaced by a custom implementation based on the erf in scipy.special.cython_special.
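
For a sense of what such a replacement might look like, here is a minimal sketch of a hand-written norm.cdf built on math.erf, which numba supports in nopython mode (illustrative only, not the numba-stats code mentioned below, which wraps scipy.special.cython_special instead):

import math
import numpy as np
from numba import njit

# Hand-written replacement for scipy.stats.norm.cdf on a 1-D float array
# (illustrative sketch only).
@njit
def norm_cdf(x, loc, scale):
    out = np.empty_like(x)
    inv = 1.0 / (scale * math.sqrt(2.0))
    for i in range(x.size):
        out[i] = 0.5 * (1.0 + math.erf((x[i] - loc) * inv))
    return out

# Example: norm_cdf(np.linspace(-3.0, 3.0, 10_000), 0.0, 1.0)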

HDembinski commented Feb 2, 2021

I started a repository with fast implementations here https://github.com/HDembinski/numba-stats that work for me. It would be great to merge this into numba-scipy, but it is not straightforward, since I did not implement the scipy API, just added some fast versions of norm.pdf, etc.

For now, numba-stats wraps the special functions from scipy.special.cython_special independently of numba-scipy, but eventually once numba-scipy is stable, I would prefer to depend on numba-scipy.

@HDembinski

Adding to that, the speed gains are dramatic, as mentioned before: I see up to a factor of 100 in some cases, less for large arrays. There seems to be a very large call overhead in scipy. I added some benchmarks with pytest-benchmark to my repo; just run pytest and see what you get.

@HDembinski

In my field (high-energy physics), having fast stats translates directly into fast turnaround when developing non-linear fits, which is the default for us. The speed-up in the stats functions translates very nicely into equivalent speed-ups of the fits, which means we can build more complex fits and bootstrap our fit results.

luk-f-a commented Feb 3, 2021

If you need fast code for stats and don't need to follow the scipy API, then rvlib is a good library. It is sadly unmaintained, but the code is there if you want to use it.

HDembinski commented Feb 10, 2021

if you need fast code for stats, and don't need to follow the scipy API, then rvlib is a good library.

Thank you for pointing this out. rvlib claims to have a better API than scipy, but I could not see that from a quick look. I really want numba-scipy to offer this functionality.

In the meantime, I realize that wrapping scipy is not that hard; a lot of scipy's implementations just call some C function. I am puzzled why it is so slow if the actual work is done in C anyway.

luk-f-a commented Feb 11, 2021

I really want numba-scipy to offer this functionality.

What functionality? Fast pdf and cdf under a scipy API?

I am puzzled why it is so slow if the actual work is done in C anyway.

There might be some work being done that you are not considering. You say that it's slow; are you comparing the speed to a pure C implementation?

dlee992 commented Aug 26, 2022

Hi, @luk-f-a. Using the current master code in this project, could I implement scipy.stats.truncnorm.rvs easily? I also found that JAX has support for this functionality (rewriting with jax.random.truncated_normal and related functions) and achieves a 10~1000x speedup (it seems to use multithreading and vectorization? I don't know it well).

Based on your example with norm.rvs(), numba didn't achieve any speedup, so I guess I have to give up on the idea of using numba to speed up truncnorm.rvs() as well.

luk-f-a commented Aug 28, 2022

@dlee992 I don't think you can take what I did for norm.rvs() as a reference for what could happen with truncnorm.rvs(). My argument was that:

  • it is impossible to fully replicate the scipy API, in particular under the constraint of "jit transparency", which is what numba has for numpy: functions "just work" inside a jitted function, without the user making any adjustment (see the snippet below);
  • even getting close to the scipy API would require more work than is justified by the performance gain.
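
For illustration, this is what "jit transparency" looks like for numpy: the numpy calls need no changes inside the jitted function (a minimal, self-contained example, unrelated to scipy.stats):

import numpy as np
from numba import njit

# numpy functions "just work" inside a jitted function, with no adjustments.
@njit
def summary(x):
    return np.mean(x), np.std(x)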

It does not sound like either of these points applies to you, so you shouldn't read too much into my conclusions; the starting point is not the same.

On the first point, it seems that you don't need an identical API or jit transparency. Please note that JAX supports the truncated_normal distribution, but this is different from saying that it supports the scipy API. It does not look to me like it does the latter, because the functions have different names. If you are willing to depart from scipy's API, there's a lot that numba can do.

Also, each function in the stats module has a different degree of performance in scipy. Functions which are already fast, like norm.rvs(), are unlikely to be sped up by Numba or JAX. I don't know how truncnorm.rvs is implemented, but if JAX can improve it 10-1000x, then it's not efficiently implemented. This means that my second point also does not apply to your case. There are performance gains to be achieved, even if it means re-implementing the method entirely. Since you were willing to do that work in JAX, the performance gain is clearly worth it to you, and there's nothing that stops Numba from achieving that performance.

Going back to your question:

could I implement scipy.stats.truncnorm.rvs easily?

If you want to implement the functionality, i.e. build a function that produces the same results, the effort will be similar to building it in JAX. If you want to implement the functionality and replicate the API when called from normal Python code, i.e. a non-jit function that calls my_truncnorm.rvs(), which is itself jitted, then it won't be hard either, again similar to building it in JAX.
If you want to perfectly replicate the scipy API and behaviour even when calling methods from inside another jit function, then yes, it will be hard. But the problem won't be the code of truncnorm.rvs; the problem is simulating the scipy API under jit transparency.
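
As a sketch of the first option (same results, no scipy API), here is a minimal jitted truncated-normal sampler using naive rejection. It follows scipy's convention that a and b are the bounds of the standard normal before loc/scale are applied, but it is not scipy.stats.truncnorm's algorithm and is slow when [a, b] covers little probability mass:

import numpy as np
from numba import njit

# Naive rejection sampler for truncated-normal draws
# (illustrative sketch, not scipy.stats.truncnorm's algorithm).
@njit
def truncnorm_rvs(a, b, loc, scale, size):
    out = np.empty(size)
    for i in range(size):
        while True:
            z = np.random.standard_normal()
            if a <= z and z <= b:
                out[i] = loc + scale * z
                break
    return out

# Example: truncnorm_rvs(-1.0, 2.0, 0.0, 1.0, 100_000)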

dlee992 commented Aug 29, 2022

@luk-f-a, many thanks for this thoughtful and detailed explanation! Now I understand my situation very well. I will figure out the pros and cons of implementing it with numba. In fact, JAX uses multithreading and sometimes creates too many threads, even more than the max limit on Linux... Some JAX issues (e.g., jax-ml/jax#11168) confirm this point, and the JAX developers don't seem to want to "fix" this.
