Skip to content
Snippets Groups Projects
Commit 07595e1e authored by Mads M. Pedersen's avatar Mads M. Pedersen
Browse files

fixed caching issue

parent fda48e00
No related branches found
No related tags found
1 merge request!44Mmpe
Pipeline #
......@@ -128,36 +128,29 @@ def cache_npsave(f):
return loadsave()
return wrap
def cache_npsavez(f):
def _get_npsavez_wrap(f, compress):
def wrap(filename,*args,**kwargs):
np_filename = os.path.splitext(filename)[0] + ".npy.npz"
np_filename = os.path.splitext(filename)[0] + ".npy.npz"+("","c")[compress]
def loadsave():
res = f(filename,*args,**kwargs)
np.savez(np_filename,*res)
if compress:
np.savez_compressed(np_filename,*res)
else:
np.savez(np_filename,*res)
return res
if os.path.isfile(np_filename) and (not os.path.isfile(filename) or os.path.getmtime(np_filename) > os.path.getmtime(filename)):
try:
npzfile = np.load(np_filename)
return [npzfile['arr_%d'%i] for i in range(len(f.files()))]
return [npzfile['arr_%d'%i] for i in range(len(npzfile.files))]
except:
return loadsave()
else:
return loadsave()
return wrap
def cache_npsavez(f):
return _get_npsavez_wrap(f,False)
def cache_npsavez_compressed(f):
def wrap(filename,*args,**kwargs):
np_filename = os.path.splitext(filename)[0] + ".npy.npz"
def loadsave():
res = f(filename,*args,**kwargs)
np.savez_compressed(np_filename,*res)
return res
if os.path.isfile(np_filename) and (not os.path.isfile(filename) or os.path.getmtime(np_filename) > os.path.getmtime(filename)):
try:
return [f['arr_%d'%i] for i in range(len(f.files()))]
except:
return loadsave()
else:
return loadsave()
return wrap
\ No newline at end of file
return _get_npsavez_wrap(f, True)
\ No newline at end of file
......@@ -136,28 +136,32 @@ class TestCacheProperty(unittest.TestCase):
os.remove(tfp+'test.npy')
def test_cache_savez(self):
if os.path.isfile(tfp+"test.npy.npy"):
os.remove(tfp+'test.npy.npy')
A = open_csv2(tfp + "test.csv")
self.assertTrue(os.path.isfile(tfp+"test.npy.npz"))
npfilename = tfp+"test.npy.npz"
func = open_csv2
if os.path.isfile(npfilename):
os.remove(npfilename)
A = func(tfp + "test.csv")
self.assertTrue(os.path.isfile(npfilename))
np.testing.assert_array_equal(A[0],np.loadtxt(tfp + "test.csv"))
A[0][0]=-1
np.save(tfp+"test.npy",A)
B = open_csv(tfp + "test.csv")
np.savez(npfilename,A[0],A[1])
B = func(tfp + "test.csv")
np.testing.assert_array_equal(A,B)
os.remove(tfp+'test.npy')
os.remove(npfilename)
def test_cache_savez_compressed(self):
if os.path.isfile(tfp+"test2.npy.npy"):
os.remove(tfp+'test2.npy.npy')
A = open_csv2(tfp + "test2.csv")
self.assertTrue(os.path.isfile(tfp+"test2.npy.npz"))
np.testing.assert_array_equal(A[0],np.loadtxt(tfp + "test2.csv"))
npfilename = tfp+"test.npy.npzc"
func = open_csv3
if os.path.isfile(npfilename):
os.remove(npfilename)
A = func(tfp + "test.csv")
self.assertTrue(os.path.isfile(npfilename))
np.testing.assert_array_equal(A[0],np.loadtxt(tfp + "test.csv"))
A[0][0]=-1
np.save(tfp+"test2.npy",A)
B = open_csv(tfp + "test2.csv")
np.savez(npfilename,A[0],A[1])
B = func(tfp + "test.csv")
np.testing.assert_array_equal(A,B)
os.remove(tfp+'test2.npy')
os.remove(npfilename)
if __name__ == "__main__":
#import sys;sys.argv = ['', 'Test.testName']
......
0.000000000000000000e+00 5.000000000000000000e+00
1.000000000000000000e+00 6.000000000000000000e+00
2.000000000000000000e+00 7.000000000000000000e+00
3.000000000000000000e+00 8.000000000000000000e+00
4.000000000000000000e+00 9.000000000000000000e+00
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment