@dfm dfm commented Apr 3, 2024

The quasisep solver is fast on CPU, but its performance is very bad on GPU (and probably TPU) because of the extensive use of lax.scan. It's possible to rewrite at least some of these operations using lax.associative_scan, which (at least in principle) is more accelerator friendly. This approach is similar in spirit to the algorithms derived in https://arxiv.org/abs/1905.13002
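To illustrate the idea (this is a toy sketch, not the tinygp implementation): any recursion whose combine operation is associative, such as a running product of small matrices like the ones that appear inside the quasisep matmul, can be computed either sequentially with lax.scan or in parallel with lax.associative_scan. The function names below are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cumulative_matmul_scan(mats):
    # Sequential version: O(N) dependent steps. Fast on CPU, but the
    # serial dependency chain is a poor fit for GPU/TPU.
    def step(carry, m):
        new = carry @ m
        return new, new

    init = jnp.eye(mats.shape[-1], dtype=mats.dtype)
    _, out = lax.scan(step, init, mats)
    return out


def cumulative_matmul_parallel(mats):
    # Parallel prefix version: O(log N) depth. Matrix multiplication is
    # associative, so associative_scan gives the same prefix products.
    return lax.associative_scan(jnp.matmul, mats)
```

Because matmul is not commutative, agreement between the two versions also checks that the combine order is right: element i of both outputs is mats[0] @ mats[1] @ ... @ mats[i].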

This PR is a WIP to add some of these operations. So far, I've just implemented a parallel matrix multiplication. There are still some precision issues to work out, but the initial performance looks good:

[Screenshot (2024-04-03): benchmark comparing the scan and associative_scan matmul runtimes on CPU and GPU]

On CPU, the scan and associative_scan matmuls take 1.65 ms and 3.59 ms respectively, for a J = 3 lower triangular matrix with N = 50,000 data points. On GPU, the same computations cost 685 ms and 1.32 ms respectively. In other words, moving to the GPU makes the scan version hundreds of times slower, whereas the associative_scan version actually gets slightly faster. These GPU results are not impressive in absolute terms, but it might be worth investigating further in case someone wants to use this solver as part of a larger model that benefits from hardware acceleration.
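For anyone reproducing these numbers, one subtlety is that JAX dispatches asynchronously, so naive timing measures only dispatch cost. A hypothetical harness (not the one used for the figures above) would jit both versions and call block_until_ready inside the timed loop:

```python
import timeit

import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def seq_matmul(mats):
    # Sequential cumulative matrix product via lax.scan.
    def step(carry, m):
        new = carry @ m
        return new, new

    _, out = lax.scan(step, jnp.eye(mats.shape[-1], dtype=mats.dtype), mats)
    return out


@jax.jit
def par_matmul(mats):
    # Parallel cumulative matrix product via lax.associative_scan.
    return lax.associative_scan(jnp.matmul, mats)


def bench(fn, mats, number=10):
    # Warm up (triggers compilation) and wait for the result before timing.
    fn(mats).block_until_ready()
    total = timeit.timeit(lambda: fn(mats).block_until_ready(), number=number)
    return total / number
```

On an accelerator one would compare `bench(seq_matmul, mats)` against `bench(par_matmul, mats)` for large N; the relative numbers will depend heavily on the backend.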
