JAX running in CPU only mode only uses a single core
When running JAX installed from pip on a CPU-only host while monitoring core usage, I only ever see a single core go to 100% utilization. All other cores are idle.
I have observed the same behavior on multiple separate machines with different python versions.
Following threads https://github.com/google/jax/issues/743 and https://github.com/google/jax/issues/1539, I have tried setting XLA_FLAGS="--xla_cpu_multi_thread_eigen=true intra_op_parallelism_threads=16", but this makes no difference.
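For anyone reproducing this from inside a script: XLA_FLAGS is one space-separated string, with each flag written as its own --name=value token. As far as I know it is only read once, when JAX first initializes its backend, so it has to be in the environment before that happens. The safest place is before import jax. A minimal sketch, using a hypothetical helper named add_xla_flags (not part of JAX), that merges flags into whatever the shell already set instead of overwriting it:

```python
import os


def add_xla_flags(*flags: str) -> str:
    """Merge flags into the XLA_FLAGS environment variable (hypothetical helper).

    XLA_FLAGS is a single space-separated string. Flags already present are
    kept and not duplicated. Call this before importing jax, since the
    variable is read when JAX initializes its backend.
    """
    current = os.environ.get("XLA_FLAGS", "").split()
    for flag in flags:
        if flag not in current:
            current.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(current)
    return os.environ["XLA_FLAGS"]
```

Usage: call add_xla_flags("--xla_cpu_multi_thread_eigen=true") on the first lines of the script, then import jax. Exporting XLA_FLAGS in the shell before launching Python has the same effect.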
Following the comment https://github.com/google/jax/issues/1539#issuecomment-578496962 about thread affinity, I have also tried prefixing my script with taskset -c 0-15, with and without the above XLA_FLAGS setting, but again this makes no difference.
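To rule out the affinity angle, it helps to check from inside the Python process which CPUs it is actually allowed to run on, since threads started later (including XLA's) inherit that mask on Linux. A small sketch using only the standard library; usable_cpu_count is a hypothetical helper name. os.sched_getaffinity exists on Linux and reflects taskset or cpuset restrictions; elsewhere the sketch falls back to os.cpu_count(), which only reports the host total:

```python
import os


def usable_cpu_count() -> int:
    """Number of CPUs this process may be scheduled on (hypothetical helper).

    On Linux, os.sched_getaffinity(0) returns the affinity mask of the
    current process, which taskset or a cpuset can restrict. On platforms
    without it (e.g. macOS, Windows), fall back to the host's CPU count.
    """
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(0))
    return os.cpu_count() or 1


if __name__ == "__main__":
    print(f"logical CPUs on host:      {os.cpu_count()}")
    print(f"CPUs this process may use: {usable_cpu_count()}")
```

If the second number is 1, the single-core behavior most likely comes from the environment (container, job scheduler, or launcher) rather than from JAX itself.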
There must be something I am missing here. All the documentation and support threads here on GitHub imply that the default behavior is to detect all cores and use them as a single local device via intra-op parallelism, yet in practice I only observe single-threaded behavior.
Any help would be much appreciated 😃
Additionally, does anyone have a link to documentation for all of the available XLA_FLAGS options?
Same issue with jax 0.3.14 on an Intel i7-10875H (8 physical cores), Linux 5.18.9.
Running @nestordemeure’s nice test script from #5506 shows only a single core in use.
Similarly, numpyro only heats a single core. I tested both the CUDA/CPU and the CPU-only jaxlib wheels, with various combinations of
- XLA_FLAGS="--xla_force_host_platform_device_count=8"
- taskset -c 0-7
- numpyro.set_host_device_count(8)
etc., with no effect.

I am also having the same trouble. I can only get 1 CPU core utilized. Any help here would be much appreciated!
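One thing worth checking regarding the numpyro comment above: my understanding is that --xla_force_host_platform_device_count=8 (which numpyro.set_host_device_count(8) sets for you) splits the host into 8 separate CPU devices, but a single jit-compiled computation still runs on one device. Work only spreads across them when it is explicitly mapped or sharded over the devices, e.g. with jax.pmap, which is what numpyro's chain_method="parallel" relies on for multiple chains. A minimal sketch of that mechanism; the function name row_sum_of_squares and the array sizes are illustrative only, and JAX_PLATFORMS=cpu is set just to keep a GPU from being picked as the default backend:

```python
import os

# Both variables must be set before JAX creates its backends,
# so they are set before importing jax.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
).strip()
os.environ["JAX_PLATFORMS"] = "cpu"  # use only the CPU backend for this demo

import jax
import jax.numpy as jnp

cpu_devices = jax.devices("cpu")
print(f"{len(cpu_devices)} CPU devices: {cpu_devices}")

# pmap maps the leading axis across devices: one row per device.
n = len(cpu_devices)
x = jnp.arange(n * 4, dtype=jnp.float32).reshape(n, 4)


@jax.pmap
def row_sum_of_squares(row):
    return jnp.sum(row * row)


result = row_sum_of_squares(x)
print(result)
```

If jax.devices("cpu") reports 8 devices but a plain jitted function still loads only one core, that would be consistent with this behavior rather than with the flag being ignored. Newer JAX releases also offer sharding-based alternatives to pmap.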