Added a new cached property decorator - threaded_cached_property, for
authorTin Tvrtkovic <tinchester@gmail.com>
Mon, 19 May 2014 20:30:11 +0000 (22:30 +0200)
committerTin Tvrtkovic <tinchester@gmail.com>
Mon, 19 May 2014 20:30:11 +0000 (22:30 +0200)
use in multithreaded environments.
Added tests.
Added Python 3.4 to the list of test environments.

cached_property.py
tests/test_cached_property.py
tests/test_threaded_cached_property.py [new file with mode: 0755]
tox.ini

index 9807777654f8a7b85bdebf8584baee23cad5e0b1..46bcef7ffa145d915a16bc529191f9d3ea690301 100644 (file)
@@ -6,6 +6,7 @@ __version__ = '0.1.4'
 __license__ = 'BSD'
 
 import time
+import threading
 
 
 class cached_property(object):
@@ -25,3 +26,23 @@ class cached_property(object):
             return self
         value = obj.__dict__[self.func.__name__] = self.func(obj)
         return value
+
+
+class threaded_cached_property(cached_property):
+    """ A cached_property version for use in environments where multiple
+        threads might concurrently try to access the property.
+        """
+    def __init__(self, func):
+        super(threaded_cached_property, self).__init__(func)
+        self.lock = threading.RLock()
+
+    def __get__(self, obj, cls):
+        with self.lock:
+            # Double check if the value was computed before the lock was
+            # acquired.
+            prop_name = self.func.__name__
+            if prop_name in obj.__dict__:
+                return obj.__dict__[prop_name]
+
+            # If not, do the calculation and release the lock.
+            return super(threaded_cached_property, self).__get__(obj, cls)
\ No newline at end of file
index ef89730fd958bc74554e035b6400824e5643c511..98357381f34317ae3643076845a211aa42d0af6f 100755 (executable)
@@ -8,7 +8,7 @@ Tests for `cached-property` module.
 """
 
 from time import sleep
-from threading import Thread
+from threading import Lock, Thread
 import unittest
 
 from cached_property import cached_property
@@ -93,16 +93,20 @@ class TestThreadingIssues(unittest.TestCase):
 
             def __init__(self):
                 self.total = 0
+                self.lock = Lock()
 
             @cached_property
             def add_cached(self):
                 sleep(1)
-                self.total += 1
+                # Need to guard this since += isn't atomic.
+                with self.lock:
+                    self.total += 1
                 return self.total
 
         c = Check()
         threads = []
-        for x in range(10):
+        num_threads = 10
+        for x in range(num_threads):
             thread = Thread(target=lambda: c.add_cached)
             thread.start()
             threads.append(thread)
@@ -116,4 +120,10 @@ class TestThreadingIssues(unittest.TestCase):
 
         # TODO: This assertion should be failing.
         # See https://github.com/pydanny/cached-property/issues/6
-        self.assertEqual(c.add_cached, 10)
+        # This assertion hinges on the fact the system executing the test can
+        # spawn and start running num_threads threads within the sleep period
+        # (defined in the Check class as 1 second). If num_threads were to be
+        # massively increased (try 10000), the actual value returned would be
+        # between 1 and num_threads, depending on thread scheduling and
+        # preemption.
+        self.assertEqual(c.add_cached, num_threads)
diff --git a/tests/test_threaded_cached_property.py b/tests/test_threaded_cached_property.py
new file mode 100755 (executable)
index 0000000..f73e45e
--- /dev/null
@@ -0,0 +1,116 @@
+# -*- coding: utf-8 -*-
+
+"""
+test_threaded_cache_property.py
+----------------------------------
+
+Tests for `cached-property` module, threaded_cache_property.
+"""
+
+from time import sleep
+from threading import Thread, Lock
+import unittest
+
+from cached_property import threaded_cached_property
+
+
+class TestCachedProperty(unittest.TestCase):
+
+    def test_cached_property(self):
+
+        class Check(object):
+
+            def __init__(self):
+                self.total1 = 0
+                self.total2 = 0
+
+            @property
+            def add_control(self):
+                self.total1 += 1
+                return self.total1
+
+            @threaded_cached_property
+            def add_cached(self):
+                self.total2 += 1
+                return self.total2
+
+        c = Check()
+
+        # The control shows that we can continue to add 1.
+        self.assertEqual(c.add_control, 1)
+        self.assertEqual(c.add_control, 2)
+
+        # The cached version demonstrates how nothing new is added
+        self.assertEqual(c.add_cached, 1)
+        self.assertEqual(c.add_cached, 1)
+
+    def test_reset_cached_property(self):
+
+        class Check(object):
+
+            def __init__(self):
+                self.total = 0
+
+            @threaded_cached_property
+            def add_cached(self):
+                self.total += 1
+                return self.total
+
+        c = Check()
+
+        # Run standard cache assertion
+        self.assertEqual(c.add_cached, 1)
+        self.assertEqual(c.add_cached, 1)
+
+        # Reset the cache.
+        del c.add_cached
+        self.assertEqual(c.add_cached, 2)
+        self.assertEqual(c.add_cached, 2)
+
+    def test_none_cached_property(self):
+
+        class Check(object):
+
+            def __init__(self):
+                self.total = None
+
+            @threaded_cached_property
+            def add_cached(self):
+                return self.total
+
+        c = Check()
+
+        # Run standard cache assertion
+        self.assertEqual(c.add_cached, None)
+
+
+class TestThreadingIssues(unittest.TestCase):
+
+    def test_threads(self):
+        """ How well does this implementation work with threads?"""
+
+        class Check(object):
+
+            def __init__(self):
+                self.total = 0
+                self.lock = Lock()
+
+            @threaded_cached_property
+            def add_cached(self):
+                sleep(1)
+                # Need to guard this since += isn't atomic.
+                with self.lock:
+                    self.total += 1
+                return self.total
+
+        c = Check()
+        threads = []
+        for x in range(10):
+            thread = Thread(target=lambda: c.add_cached)
+            thread.start()
+            threads.append(thread)
+
+        for thread in threads:
+            thread.join()
+
+        self.assertEqual(c.add_cached, 1)
diff --git a/tox.ini b/tox.ini
index 4bb2d214dbe74c08d29aec192b22c6d9dc4558ee..e385e61fdfa462fae9cc30625ce87bd32c7d1e09 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -1,5 +1,5 @@
 [tox]
-envlist = py26, py27, py33
+envlist = py26, py27, py33, py34
 
 [testenv]
 setenv =