From f3a49064efe7310224030f77ca7f3c8136120238 Mon Sep 17 00:00:00 2001
From: "Mads M. Pedersen" <mmpe@dtu.dk>
Date: Mon, 22 May 2017 14:01:47 +0200
Subject: [PATCH] cachemethod added in caching.py

---
 wetb/utils/caching.py            | 16 ++++++++++++
 wetb/utils/tests/test_caching.py | 42 ++++++++++++++++++++++++++------
 2 files changed, 50 insertions(+), 8 deletions(-)

diff --git a/wetb/utils/caching.py b/wetb/utils/caching.py
index 1ca38d4b..a4333a2b 100644
--- a/wetb/utils/caching.py
+++ b/wetb/utils/caching.py
@@ -9,6 +9,7 @@ from __future__ import division
 from __future__ import absolute_import
 from future import standard_library
 import sys
+from collections import OrderedDict
 standard_library.install_aliases()
 import inspect
 
@@ -91,3 +92,18 @@ def cache_function(f):
         raise AttributeError("Functions decorated with cache_function are not allowed to take a parameter called 'reload'")
     return wrap
 
+class cache_method():
+    def __init__(self, N):
+        self.N = N
+        self.cache_dict = OrderedDict()
+        
+    def __call__(self, f):
+        def wrapped(*args):
+            name = "_" + f.__name__
+            arg_id = ";".join([str(a) for a in args])
+            if arg_id not in self.cache_dict: 
+                self.cache_dict[arg_id] = f(*args)
+                if len(self.cache_dict)>self.N:
+                    self.cache_dict.popitem(last=False)
+            return self.cache_dict[arg_id]
+        return wrapped
\ No newline at end of file
diff --git a/wetb/utils/tests/test_caching.py b/wetb/utils/tests/test_caching.py
index ac529321..417087a1 100644
--- a/wetb/utils/tests/test_caching.py
+++ b/wetb/utils/tests/test_caching.py
@@ -8,6 +8,7 @@ from __future__ import print_function
 from __future__ import division
 from __future__ import absolute_import
 from future import standard_library
+from collections import OrderedDict
 standard_library.install_aliases()
 import multiprocessing
 import time
@@ -15,10 +16,12 @@ import unittest
 
 
 from wetb.utils.timing import get_time
-from wetb.utils.caching import cache_function, set_cache_property
+from wetb.utils.caching import cache_function, set_cache_property, cache_method
 import pdb
 
 
+
+
 class Example(object):
     def __init__(self, *args, **kwargs):
         object.__init__(self, *args, **kwargs)
@@ -38,6 +41,16 @@ class Example(object):
     def prop(self, prop):
         return getattr(self, prop)
 
+    
+    @cache_method(2)
+    def test_cache_method1(self, x):
+        time.sleep(1)
+        return x
+
+    @cache_method(2)
+    def test_cache_method2(self, x):
+        time.sleep(1)
+        return x*2
 
 
 
@@ -52,14 +65,14 @@ class TestCacheProperty(unittest.TestCase):
         e = Example()
         self.assertAlmostEqual(e.prop("test")[1], 1, 1)
         self.assertAlmostEqual(e.prop("test")[1], 0, 2)
-
+ 
     def testcache_property_pool(self):
         e = Example()
         e.prop("pool")  #load pool
         self.assertAlmostEqual(e.prop("pool")[1], 0, places=4)
         #print (get_time(e.pool.map)(f, range(10)))
-
-
+ 
+ 
     def test_cache_function(self):
         e = Example()
         self.assertAlmostEqual(get_time(e.test_cache_function)()[1], 1, places=1)
@@ -68,11 +81,24 @@ class TestCacheProperty(unittest.TestCase):
         self.assertAlmostEqual(get_time(e.test_cache_function)()[1], 0, places=1)
         e.clear_cache()
         self.assertAlmostEqual(get_time(e.test_cache_function)()[1], 1, places=1)
+ 
 
-
-
-
-
+    def test_cache_function_with_arguments(self):
+        e = Example()
+        self.assertAlmostEqual(get_time(e.test_cache_method1)(3)[1], 1, places=1)
+        self.assertAlmostEqual(get_time(e.test_cache_method1)(3)[1], 0, places=1)
+        self.assertEqual(e.test_cache_method1(3), 3)
+        self.assertAlmostEqual(get_time(e.test_cache_method1)(4)[1], 1, places=1)
+        self.assertAlmostEqual(get_time(e.test_cache_method1)(3)[1], 0, places=1)
+        self.assertAlmostEqual(get_time(e.test_cache_method1)(4)[1], 0, places=1)
+        self.assertAlmostEqual(get_time(e.test_cache_method1)(5)[1], 1, places=1)
+        self.assertAlmostEqual(get_time(e.test_cache_method1)(4)[1], 0, places=1)
+        self.assertAlmostEqual(get_time(e.test_cache_method1)(5)[1], 0, places=1)
+        self.assertAlmostEqual(get_time(e.test_cache_method1)(3)[1], 1, places=1)
+        self.assertEqual(e.test_cache_method2(3), 6)
+        self.assertEqual(e.test_cache_method1(3), 3)
+        self.assertEqual(e.test_cache_method1(5), 5)
+        self.assertAlmostEqual(get_time(e.test_cache_method1)(5)[1], 0, places=1)
 if __name__ == "__main__":
     #import sys;sys.argv = ['', 'Test.testName']
     unittest.main()
-- 
GitLab