diff --git a/wetb/utils/caching.py b/wetb/utils/caching.py index 008f184f3943efc19945237f1fc0335c956e903f..df0c369f8e6cffe70b4783d49ffe4c408ec97538 100644 --- a/wetb/utils/caching.py +++ b/wetb/utils/caching.py @@ -128,36 +128,29 @@ def cache_npsave(f): return loadsave() return wrap -def cache_npsavez(f): +def _get_npsavez_wrap(f, compress): def wrap(filename,*args,**kwargs): - np_filename = os.path.splitext(filename)[0] + ".npy.npz" + np_filename = os.path.splitext(filename)[0] + ".npy.npz"+("","c")[compress] def loadsave(): res = f(filename,*args,**kwargs) - np.savez(np_filename,*res) + if compress: + np.savez_compressed(np_filename,*res) + else: + np.savez(np_filename,*res) return res if os.path.isfile(np_filename) and (not os.path.isfile(filename) or os.path.getmtime(np_filename) > os.path.getmtime(filename)): try: npzfile = np.load(np_filename) - return [npzfile['arr_%d'%i] for i in range(len(f.files()))] + return [npzfile['arr_%d'%i] for i in range(len(npzfile.files))] except: return loadsave() else: return loadsave() return wrap +def cache_npsavez(f): + return _get_npsavez_wrap(f,False) + def cache_npsavez_compressed(f): - def wrap(filename,*args,**kwargs): - np_filename = os.path.splitext(filename)[0] + ".npy.npz" - def loadsave(): - res = f(filename,*args,**kwargs) - np.savez_compressed(np_filename,*res) - return res - if os.path.isfile(np_filename) and (not os.path.isfile(filename) or os.path.getmtime(np_filename) > os.path.getmtime(filename)): - try: - return [f['arr_%d'%i] for i in range(len(f.files()))] - except: - return loadsave() - else: - return loadsave() - return wrap \ No newline at end of file + return _get_npsavez_wrap(f, True) \ No newline at end of file diff --git a/wetb/utils/tests/test_caching.py b/wetb/utils/tests/test_caching.py index bf308e765fd55ec1c0bf825d3b30bfb5d4db9f00..816d3ec1ba2ddd20e2c192df4d9d2a9b336e5f7e 100644 --- a/wetb/utils/tests/test_caching.py +++ b/wetb/utils/tests/test_caching.py @@ -136,28 +136,32 @@ class TestCacheProperty(unittest.TestCase): os.remove(tfp+'test.npy') def test_cache_savez(self): - if os.path.isfile(tfp+"test.npy.npy"): - os.remove(tfp+'test.npy.npy') - A = open_csv2(tfp + "test.csv") - self.assertTrue(os.path.isfile(tfp+"test.npy.npz")) + npfilename = tfp+"test.npy.npz" + func = open_csv2 + if os.path.isfile(npfilename): + os.remove(npfilename) + A = func(tfp + "test.csv") + self.assertTrue(os.path.isfile(npfilename)) np.testing.assert_array_equal(A[0],np.loadtxt(tfp + "test.csv")) A[0][0]=-1 - np.save(tfp+"test.npy",A) - B = open_csv(tfp + "test.csv") + np.savez(npfilename,A[0],A[1]) + B = func(tfp + "test.csv") np.testing.assert_array_equal(A,B) - os.remove(tfp+'test.npy') + os.remove(npfilename) def test_cache_savez_compressed(self): - if os.path.isfile(tfp+"test2.npy.npy"): - os.remove(tfp+'test2.npy.npy') - A = open_csv2(tfp + "test2.csv") - self.assertTrue(os.path.isfile(tfp+"test2.npy.npz")) - np.testing.assert_array_equal(A[0],np.loadtxt(tfp + "test2.csv")) + npfilename = tfp+"test.npy.npzc" + func = open_csv3 + if os.path.isfile(npfilename): + os.remove(npfilename) + A = func(tfp + "test.csv") + self.assertTrue(os.path.isfile(npfilename)) + np.testing.assert_array_equal(A[0],np.loadtxt(tfp + "test.csv")) A[0][0]=-1 - np.save(tfp+"test2.npy",A) - B = open_csv(tfp + "test2.csv") + np.savez(npfilename,A[0],A[1]) + B = func(tfp + "test.csv") np.testing.assert_array_equal(A,B) - os.remove(tfp+'test2.npy') + os.remove(npfilename) if __name__ == "__main__": #import sys;sys.argv = ['', 'Test.testName'] diff --git a/wetb/utils/tests/test_files/test2.csv b/wetb/utils/tests/test_files/test2.csv deleted file mode 100644 index 158e1a69bddc0c5c4fd4e79de42aedb45125f2e5..0000000000000000000000000000000000000000 --- a/wetb/utils/tests/test_files/test2.csv +++ /dev/null @@ -1,5 +0,0 @@ -0.000000000000000000e+00 5.000000000000000000e+00 -1.000000000000000000e+00 6.000000000000000000e+00 -2.000000000000000000e+00 7.000000000000000000e+00 -3.000000000000000000e+00 8.000000000000000000e+00 -4.000000000000000000e+00 9.000000000000000000e+00