From 07595e1eb278795ca2d7ae70b200be9487cf41e8 Mon Sep 17 00:00:00 2001
From: "Mads M. Pedersen" <mmpe@dtu.dk>
Date: Thu, 2 Nov 2017 09:38:49 +0100
Subject: [PATCH] fixed caching issue

---
 wetb/utils/caching.py                 | 29 +++++++++--------------
 wetb/utils/tests/test_caching.py      | 34 +++++++++++++++------------
 wetb/utils/tests/test_files/test2.csv |  5 ----
 3 files changed, 30 insertions(+), 38 deletions(-)
 delete mode 100644 wetb/utils/tests/test_files/test2.csv

diff --git a/wetb/utils/caching.py b/wetb/utils/caching.py
index 008f184f..df0c369f 100644
--- a/wetb/utils/caching.py
+++ b/wetb/utils/caching.py
@@ -128,36 +128,29 @@ def cache_npsave(f):
             return loadsave()
     return wrap
 
-def cache_npsavez(f):
+def _get_npsavez_wrap(f, compress):
     def wrap(filename,*args,**kwargs):
-        np_filename = os.path.splitext(filename)[0] + ".npy.npz"
+        np_filename = os.path.splitext(filename)[0] + ".npy.npz"+("","c")[compress]
         def loadsave():
             res = f(filename,*args,**kwargs)
-            np.savez(np_filename,*res)
+            if compress:
+                np.savez_compressed(np_filename,*res)
+            else:
+                np.savez(np_filename,*res)
             return res
         if os.path.isfile(np_filename) and (not os.path.isfile(filename) or os.path.getmtime(np_filename) > os.path.getmtime(filename)):
             try:
                 npzfile = np.load(np_filename)
-                return [npzfile['arr_%d'%i] for i in range(len(f.files()))]
+                return [npzfile['arr_%d'%i] for i in range(len(npzfile.files))]
             except:
                 return loadsave()
         else:
             return loadsave()
     return wrap
 
+def cache_npsavez(f):
+    return _get_npsavez_wrap(f,False)
+
 
 def cache_npsavez_compressed(f):
-    def wrap(filename,*args,**kwargs):
-        np_filename = os.path.splitext(filename)[0] + ".npy.npz"
-        def loadsave():
-            res = f(filename,*args,**kwargs)
-            np.savez_compressed(np_filename,*res)
-            return res
-        if os.path.isfile(np_filename) and (not os.path.isfile(filename) or os.path.getmtime(np_filename) > os.path.getmtime(filename)):
-            try:
-                return [f['arr_%d'%i] for i in range(len(f.files()))]
-            except:
-                return loadsave()
-        else:
-            return loadsave()
-    return wrap
\ No newline at end of file
+    return _get_npsavez_wrap(f, True)
\ No newline at end of file
diff --git a/wetb/utils/tests/test_caching.py b/wetb/utils/tests/test_caching.py
index bf308e76..816d3ec1 100644
--- a/wetb/utils/tests/test_caching.py
+++ b/wetb/utils/tests/test_caching.py
@@ -136,28 +136,32 @@ class TestCacheProperty(unittest.TestCase):
         os.remove(tfp+'test.npy')
 
     def test_cache_savez(self):
-        if os.path.isfile(tfp+"test.npy.npy"):
-            os.remove(tfp+'test.npy.npy')
-        A = open_csv2(tfp + "test.csv")
-        self.assertTrue(os.path.isfile(tfp+"test.npy.npz"))
+        npfilename = tfp+"test.npy.npz"
+        func = open_csv2
+        if os.path.isfile(npfilename):
+            os.remove(npfilename)
+        A = func(tfp + "test.csv")
+        self.assertTrue(os.path.isfile(npfilename))
         np.testing.assert_array_equal(A[0],np.loadtxt(tfp + "test.csv"))
         A[0][0]=-1
-        np.save(tfp+"test.npy",A)
-        B = open_csv(tfp + "test.csv")
+        np.savez(npfilename,A[0],A[1])
+        B = func(tfp + "test.csv")
         np.testing.assert_array_equal(A,B)
-        os.remove(tfp+'test.npy')
+        os.remove(npfilename)
         
     def test_cache_savez_compressed(self):
-        if os.path.isfile(tfp+"test2.npy.npy"):
-            os.remove(tfp+'test2.npy.npy')
-        A = open_csv2(tfp + "test2.csv")
-        self.assertTrue(os.path.isfile(tfp+"test2.npy.npz"))
-        np.testing.assert_array_equal(A[0],np.loadtxt(tfp + "test2.csv"))
+        npfilename = tfp+"test.npy.npzc"
+        func = open_csv3
+        if os.path.isfile(npfilename):
+            os.remove(npfilename)
+        A = func(tfp + "test.csv")
+        self.assertTrue(os.path.isfile(npfilename))
+        np.testing.assert_array_equal(A[0],np.loadtxt(tfp + "test.csv"))
         A[0][0]=-1
-        np.save(tfp+"test2.npy",A)
-        B = open_csv(tfp + "test2.csv")
+        np.savez(npfilename,A[0],A[1])
+        B = func(tfp + "test.csv")
         np.testing.assert_array_equal(A,B)
-        os.remove(tfp+'test2.npy')
+        os.remove(npfilename)
                 
 if __name__ == "__main__":
     #import sys;sys.argv = ['', 'Test.testName']
diff --git a/wetb/utils/tests/test_files/test2.csv b/wetb/utils/tests/test_files/test2.csv
deleted file mode 100644
index 158e1a69..00000000
--- a/wetb/utils/tests/test_files/test2.csv
+++ /dev/null
@@ -1,5 +0,0 @@
-0.000000000000000000e+00 5.000000000000000000e+00
-1.000000000000000000e+00 6.000000000000000000e+00
-2.000000000000000000e+00 7.000000000000000000e+00
-3.000000000000000000e+00 8.000000000000000000e+00
-4.000000000000000000e+00 9.000000000000000000e+00
-- 
GitLab