From dbae89cf5cf79e4845303dbe202531afae4e13a6 Mon Sep 17 00:00:00 2001
From: Stuart Archibald <redacted>
Date: Fri, 22 Feb 2019 10:25:01 +0000
Subject: [PATCH 07/21] use mklfft when available

---
 numpy/core/tests/test_overrides.py |  1 +
 numpy/fft/__init__.py              | 41 ++++++++++++++++++++++++++++++
 numpy/tests/test_public_api.py     |  2 +-
 3 files changed, 43 insertions(+), 1 deletion(-)

diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py
index 63b0e4539..2c3ce5ce5 100644
--- a/numpy/core/tests/test_overrides.py
+++ b/numpy/core/tests/test_overrides.py
@@ -372,6 +372,7 @@ class TestNDArrayMethods(object):
 
 class TestNumPyFunctions(object):
 
+    @pytest.mark.xfail(reason="numpy.fft is patched in MKL-optimized NumPy")
     def test_set_module(self):
         assert_equal(np.sum.__module__, 'numpy')
         assert_equal(np.char.equal.__module__, 'numpy.char')
diff --git a/numpy/fft/__init__.py b/numpy/fft/__init__.py
index 64b35bc19..825180373 100644
--- a/numpy/fft/__init__.py
+++ b/numpy/fft/__init__.py
@@ -9,3 +9,44 @@ from .helper import *
 from numpy._pytesttester import PytestTester
 test = PytestTester(__name__)
 del PytestTester
+
+
+try:
+    import mkl_fft._numpy_fft as _nfft
+    patch_fft = True
+    __patched_functions__ = _nfft.__all__
+except ImportError:
+    patch_fft = False
+
+if patch_fft:
+    _restore_dict = {}
+    import sys
+
+    def register_func(name, func):
+        if name not in __patched_functions__:
+            raise ValueError("%s not an mkl_fft function." % name)
+        f = sys._getframe(0).f_globals
+        _restore_dict[name] = f[name]
+        f[name] = func
+
+    def restore_func(name):
+        if name not in __patched_functions__:
+            raise ValueError("%s not an mkl_fft function." % name)
+        try:
+            val = _restore_dict[name]
+        except KeyError:
+            print('failed to restore')
+            return
+        else:
+            print('found and restoring...')
+            sys._getframe(0).f_globals[name] = val
+
+    def restore_all():
+        for name in _restore_dict.keys():
+            restore_func(name)
+
+    for f in __patched_functions__:
+        register_func(f, getattr(_nfft, f))
+    del _nfft
+
+del patch_fft
diff --git a/numpy/tests/test_public_api.py b/numpy/tests/test_public_api.py
index 807c98652..488661f78 100644
--- a/numpy/tests/test_public_api.py
+++ b/numpy/tests/test_public_api.py
@@ -73,7 +73,7 @@ def test_numpy_linalg():
     bad_results = check_dir(np.linalg)
     assert bad_results == {}
 
-
+@pytest.mark.xfail(reason="numpy.fft is patched in MKL-optimized NumPy", run=False)
 def test_numpy_fft():
     bad_results = check_dir(np.fft)
     assert bad_results == {}
-- 
2.20.1

