diff --git a/wetb/utils/caching.py b/wetb/utils/caching.py index 1ca38d4bf4d77251898cfdd9921fbdb3c5828702..a4333a2b44aba2fd03bf7fdaff708b7754ab882e 100644 --- a/wetb/utils/caching.py +++ b/wetb/utils/caching.py @@ -9,6 +9,7 @@ from __future__ import division from __future__ import absolute_import from future import standard_library import sys +from collections import OrderedDict standard_library.install_aliases() import inspect @@ -91,3 +92,18 @@ def cache_function(f): raise AttributeError("Functions decorated with cache_function are not allowed to take a parameter called 'reload'") return wrap +class cache_method(): + def __init__(self, N): + self.N = N + self.cache_dict = OrderedDict() + + def __call__(self, f): + def wrapped(*args): + name = "_" + f.__name__ + arg_id = ";".join([str(a) for a in args]) + if arg_id not in self.cache_dict: + self.cache_dict[arg_id] = f(*args) + if len(self.cache_dict)>self.N: + self.cache_dict.popitem(last=False) + return self.cache_dict[arg_id] + return wrapped \ No newline at end of file diff --git a/wetb/utils/tests/test_caching.py b/wetb/utils/tests/test_caching.py index ac529321356b3b26072364b2c9e64cc9576dfc5e..417087a1424374ee21c6770a5c56f0f9a5efeab1 100644 --- a/wetb/utils/tests/test_caching.py +++ b/wetb/utils/tests/test_caching.py @@ -8,6 +8,7 @@ from __future__ import print_function from __future__ import division from __future__ import absolute_import from future import standard_library +from collections import OrderedDict standard_library.install_aliases() import multiprocessing import time @@ -15,10 +16,12 @@ import unittest from wetb.utils.timing import get_time -from wetb.utils.caching import cache_function, set_cache_property +from wetb.utils.caching import cache_function, set_cache_property, cache_method import pdb + + class Example(object): def __init__(self, *args, **kwargs): object.__init__(self, *args, **kwargs) @@ -38,6 +41,16 @@ class Example(object): def prop(self, prop): return getattr(self, prop) + + @cache_method(2) + def test_cache_method1(self, x): + time.sleep(1) + return x + + @cache_method(2) + def test_cache_method2(self, x): + time.sleep(1) + return x*2 @@ -52,14 +65,14 @@ class TestCacheProperty(unittest.TestCase): e = Example() self.assertAlmostEqual(e.prop("test")[1], 1, 1) self.assertAlmostEqual(e.prop("test")[1], 0, 2) - + def testcache_property_pool(self): e = Example() e.prop("pool") #load pool self.assertAlmostEqual(e.prop("pool")[1], 0, places=4) #print (get_time(e.pool.map)(f, range(10))) - - + + def test_cache_function(self): e = Example() self.assertAlmostEqual(get_time(e.test_cache_function)()[1], 1, places=1) @@ -68,11 +81,24 @@ class TestCacheProperty(unittest.TestCase): self.assertAlmostEqual(get_time(e.test_cache_function)()[1], 0, places=1) e.clear_cache() self.assertAlmostEqual(get_time(e.test_cache_function)()[1], 1, places=1) + - - - - + def test_cache_function_with_arguments(self): + e = Example() + self.assertAlmostEqual(get_time(e.test_cache_method1)(3)[1], 1, places=1) + self.assertAlmostEqual(get_time(e.test_cache_method1)(3)[1], 0, places=1) + self.assertEqual(e.test_cache_method1(3), 3) + self.assertAlmostEqual(get_time(e.test_cache_method1)(4)[1], 1, places=1) + self.assertAlmostEqual(get_time(e.test_cache_method1)(3)[1], 0, places=1) + self.assertAlmostEqual(get_time(e.test_cache_method1)(4)[1], 0, places=1) + self.assertAlmostEqual(get_time(e.test_cache_method1)(5)[1], 1, places=1) + self.assertAlmostEqual(get_time(e.test_cache_method1)(4)[1], 0, places=1) + self.assertAlmostEqual(get_time(e.test_cache_method1)(5)[1], 0, places=1) + self.assertAlmostEqual(get_time(e.test_cache_method1)(3)[1], 1, places=1) + self.assertEqual(e.test_cache_method2(3), 6) + self.assertEqual(e.test_cache_method1(3), 3) + self.assertEqual(e.test_cache_method1(5), 5) + self.assertAlmostEqual(get_time(e.test_cache_method1)(5)[1], 0, places=1) if __name__ == "__main__": #import sys;sys.argv = ['', 'Test.testName'] unittest.main()