From f3a49064efe7310224030f77ca7f3c8136120238 Mon Sep 17 00:00:00 2001 From: "Mads M. Pedersen" <mmpe@dtu.dk> Date: Mon, 22 May 2017 14:01:47 +0200 Subject: [PATCH] cachemethod added in caching.py --- wetb/utils/caching.py | 16 ++++++++++++ wetb/utils/tests/test_caching.py | 42 ++++++++++++++++++++++++++------ 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/wetb/utils/caching.py b/wetb/utils/caching.py index 1ca38d4b..a4333a2b 100644 --- a/wetb/utils/caching.py +++ b/wetb/utils/caching.py @@ -9,6 +9,7 @@ from __future__ import division from __future__ import absolute_import from future import standard_library import sys +from collections import OrderedDict standard_library.install_aliases() import inspect @@ -91,3 +92,18 @@ def cache_function(f): raise AttributeError("Functions decorated with cache_function are not allowed to take a parameter called 'reload'") return wrap +class cache_method(): + def __init__(self, N): + self.N = N + self.cache_dict = OrderedDict() + + def __call__(self, f): + def wrapped(*args): + name = "_" + f.__name__ + arg_id = ";".join([str(a) for a in args]) + if arg_id not in self.cache_dict: + self.cache_dict[arg_id] = f(*args) + if len(self.cache_dict)>self.N: + self.cache_dict.popitem(last=False) + return self.cache_dict[arg_id] + return wrapped \ No newline at end of file diff --git a/wetb/utils/tests/test_caching.py b/wetb/utils/tests/test_caching.py index ac529321..417087a1 100644 --- a/wetb/utils/tests/test_caching.py +++ b/wetb/utils/tests/test_caching.py @@ -8,6 +8,7 @@ from __future__ import print_function from __future__ import division from __future__ import absolute_import from future import standard_library +from collections import OrderedDict standard_library.install_aliases() import multiprocessing import time @@ -15,10 +16,12 @@ import unittest from wetb.utils.timing import get_time -from wetb.utils.caching import cache_function, set_cache_property +from wetb.utils.caching import cache_function, set_cache_property, cache_method import pdb + + class Example(object): def __init__(self, *args, **kwargs): object.__init__(self, *args, **kwargs) @@ -38,6 +41,16 @@ class Example(object): def prop(self, prop): return getattr(self, prop) + + @cache_method(2) + def test_cache_method1(self, x): + time.sleep(1) + return x + + @cache_method(2) + def test_cache_method2(self, x): + time.sleep(1) + return x*2 @@ -52,14 +65,14 @@ class TestCacheProperty(unittest.TestCase): e = Example() self.assertAlmostEqual(e.prop("test")[1], 1, 1) self.assertAlmostEqual(e.prop("test")[1], 0, 2) - + def testcache_property_pool(self): e = Example() e.prop("pool") #load pool self.assertAlmostEqual(e.prop("pool")[1], 0, places=4) #print (get_time(e.pool.map)(f, range(10))) - - + + def test_cache_function(self): e = Example() self.assertAlmostEqual(get_time(e.test_cache_function)()[1], 1, places=1) @@ -68,11 +81,24 @@ class TestCacheProperty(unittest.TestCase): self.assertAlmostEqual(get_time(e.test_cache_function)()[1], 0, places=1) e.clear_cache() self.assertAlmostEqual(get_time(e.test_cache_function)()[1], 1, places=1) + - - - - + def test_cache_function_with_arguments(self): + e = Example() + self.assertAlmostEqual(get_time(e.test_cache_method1)(3)[1], 1, places=1) + self.assertAlmostEqual(get_time(e.test_cache_method1)(3)[1], 0, places=1) + self.assertEqual(e.test_cache_method1(3), 3) + self.assertAlmostEqual(get_time(e.test_cache_method1)(4)[1], 1, places=1) + self.assertAlmostEqual(get_time(e.test_cache_method1)(3)[1], 0, places=1) + self.assertAlmostEqual(get_time(e.test_cache_method1)(4)[1], 0, places=1) + self.assertAlmostEqual(get_time(e.test_cache_method1)(5)[1], 1, places=1) + self.assertAlmostEqual(get_time(e.test_cache_method1)(4)[1], 0, places=1) + self.assertAlmostEqual(get_time(e.test_cache_method1)(5)[1], 0, places=1) + self.assertAlmostEqual(get_time(e.test_cache_method1)(3)[1], 1, places=1) + self.assertEqual(e.test_cache_method2(3), 6) + self.assertEqual(e.test_cache_method1(3), 3) + self.assertEqual(e.test_cache_method1(5), 5) + self.assertAlmostEqual(get_time(e.test_cache_method1)(5)[1], 0, places=1) if __name__ == "__main__": #import sys;sys.argv = ['', 'Test.testName'] unittest.main() -- GitLab