Skip to content
Snippets Groups Projects
Commit 07595e1e authored by Mads M. Pedersen's avatar Mads M. Pedersen
Browse files

fixed caching issue

parent fda48e00
No related branches found
No related tags found
No related merge requests found
...@@ -128,36 +128,29 @@ def cache_npsave(f): ...@@ -128,36 +128,29 @@ def cache_npsave(f):
return loadsave() return loadsave()
return wrap return wrap
def cache_npsavez(f): def _get_npsavez_wrap(f, compress):
def wrap(filename,*args,**kwargs): def wrap(filename,*args,**kwargs):
np_filename = os.path.splitext(filename)[0] + ".npy.npz" np_filename = os.path.splitext(filename)[0] + ".npy.npz"+("","c")[compress]
def loadsave(): def loadsave():
res = f(filename,*args,**kwargs) res = f(filename,*args,**kwargs)
np.savez(np_filename,*res) if compress:
np.savez_compressed(np_filename,*res)
else:
np.savez(np_filename,*res)
return res return res
if os.path.isfile(np_filename) and (not os.path.isfile(filename) or os.path.getmtime(np_filename) > os.path.getmtime(filename)): if os.path.isfile(np_filename) and (not os.path.isfile(filename) or os.path.getmtime(np_filename) > os.path.getmtime(filename)):
try: try:
npzfile = np.load(np_filename) npzfile = np.load(np_filename)
return [npzfile['arr_%d'%i] for i in range(len(f.files()))] return [npzfile['arr_%d'%i] for i in range(len(npzfile.files))]
except: except:
return loadsave() return loadsave()
else: else:
return loadsave() return loadsave()
return wrap return wrap
def cache_npsavez(f):
return _get_npsavez_wrap(f,False)
def cache_npsavez_compressed(f): def cache_npsavez_compressed(f):
def wrap(filename,*args,**kwargs): return _get_npsavez_wrap(f, True)
np_filename = os.path.splitext(filename)[0] + ".npy.npz" \ No newline at end of file
def loadsave():
res = f(filename,*args,**kwargs)
np.savez_compressed(np_filename,*res)
return res
if os.path.isfile(np_filename) and (not os.path.isfile(filename) or os.path.getmtime(np_filename) > os.path.getmtime(filename)):
try:
return [f['arr_%d'%i] for i in range(len(f.files()))]
except:
return loadsave()
else:
return loadsave()
return wrap
\ No newline at end of file
...@@ -136,28 +136,32 @@ class TestCacheProperty(unittest.TestCase): ...@@ -136,28 +136,32 @@ class TestCacheProperty(unittest.TestCase):
os.remove(tfp+'test.npy') os.remove(tfp+'test.npy')
def test_cache_savez(self): def test_cache_savez(self):
if os.path.isfile(tfp+"test.npy.npy"): npfilename = tfp+"test.npy.npz"
os.remove(tfp+'test.npy.npy') func = open_csv2
A = open_csv2(tfp + "test.csv") if os.path.isfile(npfilename):
self.assertTrue(os.path.isfile(tfp+"test.npy.npz")) os.remove(npfilename)
A = func(tfp + "test.csv")
self.assertTrue(os.path.isfile(npfilename))
np.testing.assert_array_equal(A[0],np.loadtxt(tfp + "test.csv")) np.testing.assert_array_equal(A[0],np.loadtxt(tfp + "test.csv"))
A[0][0]=-1 A[0][0]=-1
np.save(tfp+"test.npy",A) np.savez(npfilename,A[0],A[1])
B = open_csv(tfp + "test.csv") B = func(tfp + "test.csv")
np.testing.assert_array_equal(A,B) np.testing.assert_array_equal(A,B)
os.remove(tfp+'test.npy') os.remove(npfilename)
def test_cache_savez_compressed(self): def test_cache_savez_compressed(self):
if os.path.isfile(tfp+"test2.npy.npy"): npfilename = tfp+"test.npy.npzc"
os.remove(tfp+'test2.npy.npy') func = open_csv3
A = open_csv2(tfp + "test2.csv") if os.path.isfile(npfilename):
self.assertTrue(os.path.isfile(tfp+"test2.npy.npz")) os.remove(npfilename)
np.testing.assert_array_equal(A[0],np.loadtxt(tfp + "test2.csv")) A = func(tfp + "test.csv")
self.assertTrue(os.path.isfile(npfilename))
np.testing.assert_array_equal(A[0],np.loadtxt(tfp + "test.csv"))
A[0][0]=-1 A[0][0]=-1
np.save(tfp+"test2.npy",A) np.savez(npfilename,A[0],A[1])
B = open_csv(tfp + "test2.csv") B = func(tfp + "test.csv")
np.testing.assert_array_equal(A,B) np.testing.assert_array_equal(A,B)
os.remove(tfp+'test2.npy') os.remove(npfilename)
if __name__ == "__main__": if __name__ == "__main__":
#import sys;sys.argv = ['', 'Test.testName'] #import sys;sys.argv = ['', 'Test.testName']
......
0.000000000000000000e+00 5.000000000000000000e+00
1.000000000000000000e+00 6.000000000000000000e+00
2.000000000000000000e+00 7.000000000000000000e+00
3.000000000000000000e+00 8.000000000000000000e+00
4.000000000000000000e+00 9.000000000000000000e+00
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment