Fix the previous commit and add more test assertions to show why it was wrong.
authorGeorge Sakkis <george.sakkis@gmail.com>
Mon, 20 Apr 2015 00:46:09 +0000 (03:46 +0300)
committerGeorge Sakkis <george.sakkis@gmail.com>
Mon, 20 Apr 2015 00:56:44 +0000 (03:56 +0300)
Although the previous commit correctly cached and returned only the first computed
value (since dict.setdefault() is atomic), the actual computation could be performed
more than once in a multithreaded environment, with all but the first computed
value being discarded.

cached_property.py
tests/test_cached_property.py

index 8e07fde3c708fd148c509d99a85bf1430d006e84..50b40ac09f114f374e9fec73d8c98c3fc6469786 100644 (file)
@@ -23,11 +23,34 @@ class cached_property(object):
     def __get__(self, obj, cls):
         if obj is None:
             return self
-        return obj.__dict__.setdefault(self.func.__name__, self.func(obj))
+        value = obj.__dict__[self.func.__name__] = self.func(obj)
+        return value
 
 
-# Leave for backwards compatibility
-threaded_cached_property = cached_property
+class threaded_cached_property(object):
+    """
+    A cached_property version for use in environments where multiple threads
+    might concurrently try to access the property.
+    """
+
+    def __init__(self, func):
+        self.__doc__ = getattr(func, '__doc__')
+        self.func = func
+        self.lock = threading.RLock()
+
+    def __get__(self, obj, cls):
+        if obj is None:
+            return self
+
+        obj_dict = obj.__dict__
+        name = self.func.__name__
+        with self.lock:
+            try:
+                # check if the value was computed before the lock was acquired
+                return obj_dict[name]
+            except KeyError:
+                # if not, do the calculation and release the lock
+                return obj_dict.setdefault(name, self.func(obj))
 
 
 class cached_property_with_ttl(object):
index 56678cf2c122c9999cb8d179c72fbe4c98af776e..cd046633de1ca257929d43775df94a72be1ab0de 100644 (file)
@@ -1,15 +1,18 @@
 # -*- coding: utf-8 -*-
 
-"""Tests for cached_property"""
+"""Tests for cached_property and threaded_cached_property"""
 
 from time import sleep
 from threading import Lock, Thread
 import unittest
 
-from cached_property import cached_property
+from cached_property import cached_property, threaded_cached_property
 
 
 class TestCachedProperty(unittest.TestCase):
+    """Tests for cached_property"""
+
+    cached_property_factory = cached_property
 
     def test_cached_property(self):
 
@@ -24,7 +27,7 @@ class TestCachedProperty(unittest.TestCase):
                 self.total1 += 1
                 return self.total1
 
-            @cached_property
+            @self.cached_property_factory
             def add_cached(self):
                 self.total2 += 1
                 return self.total2
@@ -38,10 +41,11 @@ class TestCachedProperty(unittest.TestCase):
         # The cached version demonstrates how nothing new is added
         self.assertEqual(c.add_cached, 1)
         self.assertEqual(c.add_cached, 1)
+        self.assertEqual(c.total2, 1)
 
         # It's customary for descriptors to return themselves if accessed
         # though the class, rather than through an instance.
-        self.assertTrue(isinstance(Check.add_cached, cached_property))
+        self.assertTrue(isinstance(Check.add_cached, self.cached_property_factory))
 
     def test_reset_cached_property(self):
 
@@ -50,7 +54,7 @@ class TestCachedProperty(unittest.TestCase):
             def __init__(self):
                 self.total = 0
 
-            @cached_property
+            @self.cached_property_factory
             def add_cached(self):
                 self.total += 1
                 return self.total
@@ -60,11 +64,13 @@ class TestCachedProperty(unittest.TestCase):
         # Run standard cache assertion
         self.assertEqual(c.add_cached, 1)
         self.assertEqual(c.add_cached, 1)
+        self.assertEqual(c.total, 1)
 
         # Reset the cache.
         del c.add_cached
         self.assertEqual(c.add_cached, 2)
         self.assertEqual(c.add_cached, 2)
+        self.assertEqual(c.total, 2)
 
     def test_none_cached_property(self):
 
@@ -73,7 +79,7 @@ class TestCachedProperty(unittest.TestCase):
             def __init__(self):
                 self.total = None
 
-            @cached_property
+            @self.cached_property_factory
             def add_cached(self):
                 return self.total
 
@@ -81,17 +87,33 @@ class TestCachedProperty(unittest.TestCase):
 
         # Run standard cache assertion
         self.assertEqual(c.add_cached, None)
+        self.assertEqual(c.total, None)
 
     def test_threads(self):
-        """How well does this implementation work with threads?"""
-
+        """
+        How well does the standard cached_property implementation work with
+        threads? It doesn't, use threaded_cached_property instead!
+        """
+        num_threads = 10
+        check = self._run_threads(num_threads)
+        # Threads means that caching is bypassed.
+        # This assertion hinges on the fact the system executing the test can
+        # spawn and start running num_threads threads within the sleep period
+        # (defined in the Check class as 1 second). If num_threads were to be
+        # massively increased (try 10000), the actual value returned would be
+        # between 1 and num_threads, depending on thread scheduling and
+        # preemption.
+        self.assertEqual(check.add_cached, num_threads)
+        self.assertEqual(check.total, num_threads)
+
+    def _run_threads(self, num_threads):
         class Check(object):
 
             def __init__(self):
                 self.total = 0
                 self.lock = Lock()
 
-            @cached_property
+            @self.cached_property_factory
             def add_cached(self):
                 sleep(1)
                 # Need to guard this since += isn't atomic.
@@ -100,13 +122,26 @@ class TestCachedProperty(unittest.TestCase):
                 return self.total
 
         c = Check()
+
         threads = []
-        for x in range(10):
+        for _ in range(num_threads):
             thread = Thread(target=lambda: c.add_cached)
             thread.start()
             threads.append(thread)
-
         for thread in threads:
             thread.join()
 
-        self.assertEqual(c.add_cached, 1)
+        return c
+
+
+class TestThreadedCachedProperty(TestCachedProperty):
+    """Tests for threaded_cached_property"""
+
+    cached_property_factory = threaded_cached_property
+
+    def test_threads(self):
+        """How well does this implementation work with threads?"""
+        num_threads = 10
+        check = self._run_threads(num_threads)
+        self.assertEqual(check.add_cached, 1)
+        self.assertEqual(check.total, 1)