Unimplemented: DNN library is not found.


I'm working on a local RTX 2060 Super GPU with CUDA 11.1 and got this error.

JAX was installed successfully with the following command:

pip install --upgrade jax jaxlib==0.1.57+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html

and this symlink:

sudo ln -s /path/to/cuda /usr/local/cuda-11.1

JAX reports the GPU backend with:

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

and it can run basic operations such as:

from jax import random
rng_key = random.PRNGKey(0)

However, running the model still fails:

evaluate(model, test_ds)

FilteredStackTrace: RuntimeError: Unimplemented: DNN library is not found.

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-8-0f8618edbb7d> in <module>()
     13   return compute_metrics(logits, eval_ds['label'])
     14 
---> 15 evaluate(model, test_ds)

/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/traceback_util.py in reraise_with_filtered_traceback(*args, **kwargs)
    131   def reraise_with_filtered_traceback(*args, **kwargs):
    132     try:
--> 133       return fun(*args, **kwargs)
    134     except Exception as e:
    135       if not is_under_reraiser(e):

/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/api.py in f_jitted(*args, **kwargs)
    221         backend=backend,
    222         name=flat_fun.__name__,
--> 223         donated_invars=donated_invars)
    224     return tree_unflatten(out_tree(), out)
    225 

/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/core.py in bind(self, fun, *args, **params)
   1175 
   1176   def bind(self, fun, *args, **params):
-> 1177     return call_bind(self, fun, *args, **params)
   1178 
   1179   def process(self, trace, fun, tracers, params):

/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/core.py in call_bind(primitive, fun, *args, **params)
   1166   tracers = map(top_trace.full_raise, args)
   1167   with maybe_new_sublevel(top_trace):
-> 1168     outs = primitive.process(top_trace, fun, tracers, params)
   1169   return map(full_lower, apply_todos(env_trace_todo(), outs))
   1170 

/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/core.py in process(self, trace, fun, tracers, params)
   1178 
   1179   def process(self, trace, fun, tracers, params):
-> 1180     return trace.process_call(self, fun, tracers, params)
   1181 
   1182   def post_process(self, trace, out_tracers, params):

/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/core.py in process_call(self, primitive, f, tracers, params)
    577 
    578   def process_call(self, primitive, f, tracers, params):
--> 579     return primitive.impl(f, *tracers, **params)
    580   process_map = process_call
    581 

/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/interpreters/xla.py in _xla_call_impl(fun, device, backend, name, donated_invars, *args)
    557                                *unsafe_map(arg_spec, args))
    558   try:
--> 559     return compiled_fun(*args)
    560   except FloatingPointError:
    561     assert FLAGS.jax_debug_nans  # compiled_fun can only raise in this case

/home/xxx/anaconda3/envs/flax/lib/python3.7/site-packages/jax/interpreters/xla.py in _execute_compiled(compiled, avals, handlers, *args)
    805   device, = compiled.local_devices()
    806   input_bufs = list(it.chain.from_iterable(device_put(x, device) for x in args if x is not token))
--> 807   out_bufs = compiled.execute(input_bufs)
    808   if FLAGS.jax_debug_nans: check_nans(xla_call_p, out_bufs)
    809   return [handler(*bs) for handler, bs in zip(handlers, _partition_outputs(avals, out_bufs))]

RuntimeError: Unimplemented: DNN library is not found.

Issue Analytics

  • State: open
  • Created 3 years ago
  • Comments: 23 (4 by maintainers)

Top GitHub Comments

9 reactions
xidulu commented, Jun 21, 2022

I was able to solve this problem by adding these 4 lines of code at the head of the file:

import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

6 reactions
hawkinsp commented, Nov 17, 2020

About 5 months ago (https://github.com/google/jax/commit/a141cc6e8d36ff10e28180683588bedf5432df1a) we switched how we link GPU libraries to be the same as TensorFlow, namely, we use dlopen() to find libraries like CuDNN rather than linking against them directly. dlopen() looks for libraries using LD_LIBRARY_PATH, so that’s ultimately the cause of this error: we can’t find the libraries.
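
One way to check this from Python is to ask the dynamic loader directly, since ctypes.CDLL calls dlopen() on Linux. The snippet below is only a rough sketch: the helper name can_dlopen is made up for illustration, and the candidate library names are assumptions (the cuDNN major version depends on what is actually installed). Note also that the loader normally reads LD_LIBRARY_PATH when the process starts, so it should be exported in the shell before launching Python rather than set from inside the script.

```python
import ctypes
import os


def can_dlopen(soname):
    """Return True if the dynamic loader can open the shared library `soname`."""
    try:
        ctypes.CDLL(soname)
        return True
    except OSError:
        # Raised when the loader cannot find or load the library.
        return False


# Candidate names are assumptions; adjust them to the cuDNN version you installed.
for name in ("libcudnn.so.8", "libcudnn.so"):
    print(name, "loadable:", can_dlopen(name))

# Show the search path the process was started with.
print("LD_LIBRARY_PATH =", os.environ.get("LD_LIBRARY_PATH", "(unset)"))
```

If every candidate prints False, that would be consistent with the explanation above: the loader cannot see cuDNN from this process.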

I suspect you would see the exact same behavior with tensorflow with GPU support: as far as I am aware, it uses the same code to find the GPU libraries. It might be interesting to verify that hypothesis: install a GPU version of TF and try running a convolution. You should see the same error as JAX (if you haven’t set LD_LIBRARY_PATH).

I also suspect if you set TF_CPP_MIN_LOG_LEVEL=0 then you may see some better logging that more clearly indicates what the real problem is.
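
A minimal way to try both suggestions in JAX might look like the following. This is a sketch under assumptions, not a confirmed reproduction: the array shapes are arbitrary, and it assumes that a convolution on the GPU backend goes through cuDNN, which is why it would hit the error while simpler operations such as random.PRNGKey(0) do not.

```python
import os

# Set before importing jax so the variable is in place when the XLA runtime
# initializes its logging.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

import jax.numpy as jnp
from jax import lax

# Tiny convolution: input in NCHW layout, kernel in OIHW layout
# (shapes chosen only for illustration).
lhs = jnp.ones((1, 1, 8, 8), dtype=jnp.float32)
rhs = jnp.ones((1, 1, 3, 3), dtype=jnp.float32)

out = lax.conv(lhs, rhs, window_strides=(1, 1), padding="SAME")
print(out.shape)
```

If cuDNN cannot be loaded, the lax.conv call is where the "DNN library is not found" error would be expected, and the extra log output may name the library that failed to load.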

I agree the error message isn’t very helpful; we should probably fix that.


