# -*- coding: utf-8 -*-
# Copyright (C) 2019 Anaconda, Inc
# SPDX-License-Identifier: BSD-3-Clause

def cuda_detect():
    '''Attempt to detect the version of CUDA present in the operating system.

    On Windows and Linux, the CUDA library is installed by the NVIDIA
    driver package, and is typically found in the standard library path,
    rather than with the CUDA SDK (which is optional for running CUDA apps).

    On macOS, the CUDA library is only installed with the CUDA SDK, and
    might not be in the library path.

    Returns: version string (Ex: '9.2') or None if CUDA not found.
    '''
    # platform specific libcuda location
    import platform
    system = platform.system()
    if system == 'Darwin':
        lib_filenames = [
            'libcuda.dylib',  # check library path first
            '/usr/local/cuda/lib/libcuda.dylib'
        ]
    elif system == 'Linux':
        lib_filenames = [
            'libcuda.so',  # check library path first
            '/usr/lib64/nvidia/libcuda.so',  # Redhat/CentOS/Fedora
            '/usr/lib/x86_64-linux-gnu/libcuda.so',  # Ubuntu
        ]
    elif system == 'Windows':
        lib_filenames = ['nvcuda.dll']
    else:
        return None  # CUDA not available for other operating systems

    # open library
    import ctypes
    if system == 'Windows':
        dll = ctypes.windll
    else:
        dll = ctypes.cdll
    libcuda = None
    for lib_filename in lib_filenames:
        try:
            libcuda = dll.LoadLibrary(lib_filename)
            break
        except:
            pass
    if libcuda is None:
        return None

    # Get CUDA version
    try:
        cuInit = libcuda.cuInit
        flags = ctypes.c_uint(0)
        ret = cuInit(flags)
        if ret != 0:
            return None

        cuDriverGetVersion = libcuda.cuDriverGetVersion
        version_int = ctypes.c_int(0)
        ret = cuDriverGetVersion(ctypes.byref(version_int))
        if ret != 0:
            return None

        # Convert version integer to version string
        value = version_int.value
        return '%d.%d' % (value // 1000, (value % 1000) // 10)
    except:
        return None
