Skip to content
Snippets Groups Projects
Commit f3a49064 authored by Mads M. Pedersen's avatar Mads M. Pedersen
Browse files

cachemethod added in caching.py

parent 00be8be8
No related branches found
No related tags found
No related merge requests found
...@@ -9,6 +9,7 @@ from __future__ import division ...@@ -9,6 +9,7 @@ from __future__ import division
from __future__ import absolute_import from __future__ import absolute_import
from future import standard_library from future import standard_library
import sys import sys
from collections import OrderedDict
standard_library.install_aliases() standard_library.install_aliases()
import inspect import inspect
...@@ -91,3 +92,18 @@ def cache_function(f): ...@@ -91,3 +92,18 @@ def cache_function(f):
raise AttributeError("Functions decorated with cache_function are not allowed to take a parameter called 'reload'") raise AttributeError("Functions decorated with cache_function are not allowed to take a parameter called 'reload'")
return wrap return wrap
class cache_method():
def __init__(self, N):
self.N = N
self.cache_dict = OrderedDict()
def __call__(self, f):
def wrapped(*args):
name = "_" + f.__name__
arg_id = ";".join([str(a) for a in args])
if arg_id not in self.cache_dict:
self.cache_dict[arg_id] = f(*args)
if len(self.cache_dict)>self.N:
self.cache_dict.popitem(last=False)
return self.cache_dict[arg_id]
return wrapped
\ No newline at end of file
...@@ -8,6 +8,7 @@ from __future__ import print_function ...@@ -8,6 +8,7 @@ from __future__ import print_function
from __future__ import division from __future__ import division
from __future__ import absolute_import from __future__ import absolute_import
from future import standard_library from future import standard_library
from collections import OrderedDict
standard_library.install_aliases() standard_library.install_aliases()
import multiprocessing import multiprocessing
import time import time
...@@ -15,10 +16,12 @@ import unittest ...@@ -15,10 +16,12 @@ import unittest
from wetb.utils.timing import get_time from wetb.utils.timing import get_time
from wetb.utils.caching import cache_function, set_cache_property from wetb.utils.caching import cache_function, set_cache_property, cache_method
import pdb import pdb
class Example(object): class Example(object):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
object.__init__(self, *args, **kwargs) object.__init__(self, *args, **kwargs)
...@@ -38,6 +41,16 @@ class Example(object): ...@@ -38,6 +41,16 @@ class Example(object):
def prop(self, prop): def prop(self, prop):
return getattr(self, prop) return getattr(self, prop)
@cache_method(2)
def test_cache_method1(self, x):
time.sleep(1)
return x
@cache_method(2)
def test_cache_method2(self, x):
time.sleep(1)
return x*2
...@@ -52,14 +65,14 @@ class TestCacheProperty(unittest.TestCase): ...@@ -52,14 +65,14 @@ class TestCacheProperty(unittest.TestCase):
e = Example() e = Example()
self.assertAlmostEqual(e.prop("test")[1], 1, 1) self.assertAlmostEqual(e.prop("test")[1], 1, 1)
self.assertAlmostEqual(e.prop("test")[1], 0, 2) self.assertAlmostEqual(e.prop("test")[1], 0, 2)
def testcache_property_pool(self): def testcache_property_pool(self):
e = Example() e = Example()
e.prop("pool") #load pool e.prop("pool") #load pool
self.assertAlmostEqual(e.prop("pool")[1], 0, places=4) self.assertAlmostEqual(e.prop("pool")[1], 0, places=4)
#print (get_time(e.pool.map)(f, range(10))) #print (get_time(e.pool.map)(f, range(10)))
def test_cache_function(self): def test_cache_function(self):
e = Example() e = Example()
self.assertAlmostEqual(get_time(e.test_cache_function)()[1], 1, places=1) self.assertAlmostEqual(get_time(e.test_cache_function)()[1], 1, places=1)
...@@ -68,11 +81,24 @@ class TestCacheProperty(unittest.TestCase): ...@@ -68,11 +81,24 @@ class TestCacheProperty(unittest.TestCase):
self.assertAlmostEqual(get_time(e.test_cache_function)()[1], 0, places=1) self.assertAlmostEqual(get_time(e.test_cache_function)()[1], 0, places=1)
e.clear_cache() e.clear_cache()
self.assertAlmostEqual(get_time(e.test_cache_function)()[1], 1, places=1) self.assertAlmostEqual(get_time(e.test_cache_function)()[1], 1, places=1)
def test_cache_function_with_arguments(self):
e = Example()
self.assertAlmostEqual(get_time(e.test_cache_method1)(3)[1], 1, places=1)
self.assertAlmostEqual(get_time(e.test_cache_method1)(3)[1], 0, places=1)
self.assertEqual(e.test_cache_method1(3), 3)
self.assertAlmostEqual(get_time(e.test_cache_method1)(4)[1], 1, places=1)
self.assertAlmostEqual(get_time(e.test_cache_method1)(3)[1], 0, places=1)
self.assertAlmostEqual(get_time(e.test_cache_method1)(4)[1], 0, places=1)
self.assertAlmostEqual(get_time(e.test_cache_method1)(5)[1], 1, places=1)
self.assertAlmostEqual(get_time(e.test_cache_method1)(4)[1], 0, places=1)
self.assertAlmostEqual(get_time(e.test_cache_method1)(5)[1], 0, places=1)
self.assertAlmostEqual(get_time(e.test_cache_method1)(3)[1], 1, places=1)
self.assertEqual(e.test_cache_method2(3), 6)
self.assertEqual(e.test_cache_method1(3), 3)
self.assertEqual(e.test_cache_method1(5), 5)
self.assertAlmostEqual(get_time(e.test_cache_method1)(5)[1], 0, places=1)
if __name__ == "__main__": if __name__ == "__main__":
#import sys;sys.argv = ['', 'Test.testName'] #import sys;sys.argv = ['', 'Test.testName']
unittest.main() unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment