JAX running in CPU only mode only uses a single core

When running JAX installed from pip on a CPU-only host and monitoring core usage, I only ever see a single core reach 100% utilization. All other cores stay idle.

I have observed the same behavior on multiple separate machines with different python versions.

Following threads https://github.com/google/jax/issues/743 and https://github.com/google/jax/issues/1539, I tried setting XLA_FLAGS="--xla_cpu_multi_thread_eigen=true intra_op_parallelism_threads=16", but this makes no difference.

Following the comment https://github.com/google/jax/issues/1539#issuecomment-578496962 about thread affinity, I have also tried running my script under taskset -c 0-15, both with and without the above XLA_FLAGS setting, but again this makes no difference.
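To rule out an affinity restriction, it may help to check from inside the Python process how many cores it is actually allowed to run on. The following is a minimal sketch using only the standard library; `usable_cores` is a hypothetical helper name, and `os.sched_getaffinity` is available on Linux only, so the sketch falls back to `os.cpu_count()` elsewhere.

```python
import os


def usable_cores() -> int:
    """Return the number of cores this process is allowed to run on."""
    # On Linux, the affinity mask reflects restrictions such as `taskset -c 0-15`.
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(0))
    # Other platforms: fall back to the total core count.
    return os.cpu_count() or 1


if __name__ == "__main__":
    print(f"cores on machine: {os.cpu_count()}, "
          f"cores this process may use: {usable_cores()}")
```

If the second number is already equal to the full core count, affinity is probably not what limits the process to one core.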

There must be something I am missing here. All documentation and support threads here on GitHub imply that the default behavior is to detect and use all cores as a single local device using intra-op parallelism, and yet I can only observe single-threaded behavior in practice.

Any help would be much appreciated 😃

Additionally, does anyone have a link to documentation for all of the available XLA_FLAGS options?

Issue Analytics

  • State: open
  • Created: 3 years ago
  • Reactions: 31
  • Comments: 15 (1 by maintainers)

Top GitHub Comments

5 reactions
mattja commented, Jul 8, 2022

Same issue with jax-0.3.14 on Intel i7-10875H, 8 physical cores, Linux 5.18.9

Using @nestordemeure’s nice test script from #5506 shows a single core used:

numpy: cpu usage 1.0/16 wall_time:0.5s
vmap: cpu usage 1.0/16 wall_time:3.2s
xmap: cpu usage 1.0/16 wall_time:3.3s
dot: cpu usage 1.0/16 wall_time:20.4s
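A "cpu usage" figure like the one above can be approximated as process CPU time divided by wall-clock time, which gives the average number of busy cores. The sketch below is one plausible way to measure it with the standard library and is not taken from the referenced test script; `cpu_usage` and `busy_loop` are hypothetical names.

```python
import os
import time


def cpu_usage(fn, *args):
    """Run fn(*args) once; return (average busy cores, wall-clock seconds)."""
    wall_start = time.perf_counter()
    cpu_start = time.process_time()  # CPU time summed over all threads of this process
    fn(*args)
    cpu_used = time.process_time() - cpu_start
    wall_used = time.perf_counter() - wall_start
    return cpu_used / wall_used, wall_used


def busy_loop(n):
    """Single-threaded pure-Python workload used as a stand-in."""
    total = 0
    for i in range(n):
        total += i * i
    return total


if __name__ == "__main__":
    cores, wall = cpu_usage(busy_loop, 2_000_000)
    print(f"pure python: cpu usage {cores:.1f}/{os.cpu_count()} wall_time:{wall:.1f}s")
```

When timing a JAX function this way, the callable should wait for the result (for example by calling `.block_until_ready()` on the returned array), because JAX dispatches work asynchronously and the call may otherwise return before the computation finishes.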

Similarly, numpyro only heats a single core. Tested both the cuda/CPU and the CPU-only jaxlib wheels, with various combinations of XLA_FLAGS="--xla_force_host_platform_device_count=8", taskset -c 0-7, numpyro.set_host_device_count(8), etc., with no effect.

5 reactions
mrtupek commented, Feb 22, 2021
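One thing worth checking with these settings is ordering: XLA_FLAGS generally has to be in the environment before JAX initializes its CPU backend, so setting it after JAX has already been used has no effect. The sketch below sets the flag from Python before importing jax; `set_host_device_count` here is a hypothetical local helper, not numpyro's function. Note that this flag exposes several CPU devices to JAX; it does not by itself make a single operation use more threads.

```python
import os
import re


def set_host_device_count(n: int) -> str:
    """Put --xla_force_host_platform_device_count=n into XLA_FLAGS."""
    flags = os.environ.get("XLA_FLAGS", "")
    # Drop any earlier value of the same flag, keep every other flag untouched.
    others = re.sub(r"--xla_force_host_platform_device_count=\S+", "", flags).split()
    os.environ["XLA_FLAGS"] = " ".join(
        [f"--xla_force_host_platform_device_count={n}"] + others
    )
    return os.environ["XLA_FLAGS"]


if __name__ == "__main__":
    set_host_device_count(8)
    try:
        import jax  # imported only after the flag is in the environment
        print(jax.devices())  # lists the CPU devices JAX actually sees
    except ImportError:
        print("jax is not installed; XLA_FLAGS =", os.environ["XLA_FLAGS"])
```

Printing `jax.devices()` is a quick way to confirm whether the flag was picked up before drawing conclusions from core utilization.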

I am also having the same trouble. I can only get 1 cpu core utilized. Any help here would be much appreciated!
