Synchronized in Python

14/10/2011

Some more fun with Python decorators, this time about threading.

Java gives us some nice threading primitives built into the language, including the synchronized keyword. What this keyword ensures is that no two threads can be inside blocks synchronized on the same lock object at the same time. We can build almost identical functionality into Python using decorators and some basic meta-programming.

First we'll tackle the synchronized block syntax in Java. It looks like this:

 

/* Enter critical section */
synchronized(mLock) {
	
    /* Do critical work */
		
}
/* Exit critical section */

 

This can appear pretty much anywhere in Java code, and any object (anything that isn't a primitive value) can be used as a lock. To get the same in Python we don't have to add anything. It already has this built in, but instead of the synchronized keyword it uses the with keyword.

 

# Enter critical section
with self.lock:
    # Do critical work
	
# Exit critical section

 

Pretty straightforward, and the lovely thing about the pythonic way is that the behaviour of the "with" keyword can be defined for any class using the __enter__ and __exit__ methods. You might have already seen it being used for files and a host of other things.
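To make that concrete, here is a minimal sketch of a class that supports "with". The LoggedLock name and its events list are made up for illustration; only __enter__ and __exit__ are part of the actual protocol.

```python
import threading

class LoggedLock:
    """Hypothetical wrapper that records when a lock is taken and released."""

    def __init__(self):
        self._lock = threading.Lock()
        self.events = []

    def __enter__(self):
        # Called when the with block is entered
        self._lock.acquire()
        self.events.append("acquired")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Called when the block is left, normally or through an exception
        self.events.append("released")
        self._lock.release()
        return False  # do not swallow exceptions

lock = LoggedLock()
with lock:
    lock.events.append("working")

print(lock.events)  # ['acquired', 'working', 'released']
```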

There is also a second advantage to using the with keyword over simply wrapping the critical section in self.lock.acquire() and self.lock.release(). Using with ensures that the lock is always released, even if an exception occurs within the critical section. It will ensure that __exit__ is always called, much like the finally statement in a try block.
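A small sketch of that guarantee: the function below raises inside the critical section, and afterwards we check that the lock is free again. The hand-written try/finally version is roughly what "with" does for us. The function names are just placeholders.

```python
import threading

lock = threading.Lock()

def risky():
    with lock:
        raise ValueError("something went wrong in the critical section")

def risky_by_hand():
    # Roughly the same thing written out without the with statement
    lock.acquire()
    try:
        raise ValueError("something went wrong in the critical section")
    finally:
        lock.release()

try:
    risky()
except ValueError:
    pass

# A non-blocking acquire only succeeds if the lock was released
released_by_with = lock.acquire(False)
lock.release()

try:
    risky_by_hand()
except ValueError:
    pass

released_by_finally = lock.acquire(False)
lock.release()

# Both flags are True: the lock was released in each case
print(released_by_with)
print(released_by_finally)
```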

The second way in which synchronized appears in Java is in a method declaration. It looks something like this.

 

public class Counter {
		
    int mTotal;
	
    public Counter() {
        mTotal = 0;
    }

    public synchronized void addOne() {
        int val = mTotal;
        val++;
        mTotal = val;
    }
}

 

This ensures that no two threads can be inside synchronized methods of the same class instance at the same time. For this example it ensures that the total always gets updated correctly.

For Python let us start with something a little simpler than a method. We can build this syntax for plain old functions using our own hand made decorators. The decorator looks something like this:

 

import threading
	
def synchronized(func):
	
    func.__lock__ = threading.Lock()
		
    def synced_func(*args, **kws):
        with func.__lock__:
            return func(*args, **kws)

    return synced_func

 

As you can see, it takes a function, attaches a lock to that function, and wraps the function within that lock. As in Java this ensures that no two threads can be inside the function at the same time. Here is an example of it at work.
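One optional refinement that the decorator above doesn't include: because the wrapper replaces the original function, the original's name and docstring are lost. A sketch using functools.wraps from the standard library keeps them. Here the lock lives in the closure rather than on the function, and greet is just a placeholder.

```python
import functools
import threading

def synchronized(func):
    lock = threading.Lock()

    @functools.wraps(func)  # copy __name__, __doc__ and friends onto the wrapper
    def synced_func(*args, **kws):
        with lock:
            return func(*args, **kws)

    return synced_func

@synchronized
def greet(name):
    """Return a greeting."""
    return "hello " + name

print(greet.__name__)   # greet
print(greet("world"))   # hello world
```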

 

import time
	
total = 0
		
@synchronized
def count():
    global total
    curr = total + 1
    time.sleep(0.1)
    total = curr
		
def counter():
    for i in range(0,10): count()
		
thread1 = threading.Thread(target = counter)
thread2 = threading.Thread(target = counter)

thread1.start()
thread2.start()
    
thread1.join()
thread2.join()

print(total)

 

With the function count synchronized this script should correctly print 20 for the value of total. If the decorator is removed then technically the value of total is unpredictable (it can be anything up to 20), but in this script it will tend to be 10, as the timing almost always ensures both threads enter the critical section at the same time and the total is updated incorrectly.

But unfortunately the decorator above won't work well for methods. The reason is that a method's function object exists in only one place, on the class, and is shared by every instance. If you were to apply the above to a method it would create a single lock shared across all instances. If we want a lock on a per-instance basis we have to do something else.
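To see the problem concretely, here is a small sketch that applies the function decorator to a method and inspects what two instances end up with. Copying the lock onto the wrapper is my addition, purely so the demo can look at it.

```python
import threading

def synchronized(func):
    func.__lock__ = threading.Lock()

    def synced_func(*args, **kws):
        with func.__lock__:
            return func(*args, **kws)

    # Exposed on the wrapper only so this demo can inspect the lock
    synced_func.__lock__ = func.__lock__
    return synced_func

class Counter:
    def __init__(self):
        self.total = 0

    @synchronized
    def add_one(self):
        self.total += 1

a = Counter()
b = Counter()

# Both bound methods wrap the one function object stored on the class...
same_function = a.add_one.__func__ is b.add_one.__func__
# ...so they also share the one lock the decorator created
same_lock = a.add_one.__lock__ is b.add_one.__lock__

print(same_function)
print(same_lock)
```

Both checks come out true, which means a thread working on one counter would block threads working on a completely unrelated one.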

 

def synchronized_method(method):
    
    outer_lock = threading.Lock()
    lock_name = "__"+method.__name__+"_lock"+"__"
    
    def sync_method(self, *args, **kws):
        # Hold the outer lock only while finding or creating the instance lock,
        # otherwise calls on different instances would block each other
        with outer_lock:
            if not hasattr(self, lock_name): setattr(self, lock_name, threading.Lock())
            lock = getattr(self, lock_name)
        with lock:
            return method(self, *args, **kws)

    return sync_method

class Counter:
		
    def __init__(self):
        self.total = 0
			
    @synchronized_method
    def add_one(self):
        val = self.total
        val += 1
        self.total = val

 

This is a bit more messy, but we can solve it by taking advantage of the fact that all methods take self as their first parameter, representing the class instance. This new decorator checks whether the instance already has a lock for the method to use. If a lock exists it uses it, otherwise it creates a new one and attaches it to the instance. It's worth noting that because this instance lock is created at run-time we need an outer lock, created before any instance exists, to ensure there is no race over who gets to create the instance lock (imagine two threads calling this method for the first time, at the same time). For this we create an outer lock when the class is defined (when the decorator is applied) and use it to ensure only one instance lock is created.
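A quick way to convince yourself the locks really are per-instance is to look them up after a few calls. This is only a sketch. It restates the decorator so it runs on its own, and in this version the outer lock is released before the method body runs so that different instances don't block each other.

```python
import threading

def synchronized_method(method):
    outer_lock = threading.Lock()
    lock_name = "__" + method.__name__ + "_lock__"

    def sync_method(self, *args, **kws):
        # Hold the outer lock only long enough to find or create the instance lock
        with outer_lock:
            if not hasattr(self, lock_name):
                setattr(self, lock_name, threading.Lock())
            lock = getattr(self, lock_name)
        with lock:
            return method(self, *args, **kws)

    return sync_method

class Counter:
    def __init__(self):
        self.total = 0

    @synchronized_method
    def add_one(self):
        self.total += 1

a = Counter()
b = Counter()
a.add_one()
a.add_one()
b.add_one()

# The lock is stored on each instance under a name derived from the method
lock_a = getattr(a, "__add_one_lock__")
lock_b = getattr(b, "__add_one_lock__")
print(lock_a is lock_b)  # False: each instance has its own lock
```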

In Java, all synchronized methods of an object share that object's single lock. Our Python example above is slightly different: each of our methods uses its own lock.

We can create something more similar though. What we can do is to lock whenever we wish to manipulate a certain member of a class instance.

 

import threading
import time

def synchronized_with_attr(lock_name):
    
    def decorator(method):
			
        def synced_method(self, *args, **kws):
            lock = getattr(self, lock_name)
            with lock:
                return method(self, *args, **kws)
                
        return synced_method
		
    return decorator
    

class Counter:
    
    def __init__(self):
        self.lock = threading.RLock()
        self.total = 0
		
    @synchronized_with_attr("lock")
    def add_one(self):
        val = self.total
        val += 1
        time.sleep(0.1)
        self.total = val
	
    @synchronized_with_attr("lock")
    def add_two(self):
        val = self.total
        val += 2
        time.sleep(0.1)
        self.total = val


The most obvious difference is that we are now passing the decorator a string with the name of the lock attribute. The reason for this is that decorators are applied at the time the class is defined, not at the point at which it is instanced. Before it is instanced Python has no idea that self.lock actually exists. Because of this we have to give the name as a string and rely on the attribute being there once the class is instanced.

The second thing to notice is that I'm now using an RLock instead of a normal lock. This is a re-entrant lock and allows a thread to call acquire on a lock more than once if it already holds the lock. This ensures that if a method calls another method internally it will not cause deadlock as a thread tries to acquire a lock it already owns. In fact we might have wanted to use an RLock for our previous method decorator - otherwise recursion would cause a deadlock.
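The difference is easy to see with non-blocking acquires, which return False instead of hanging when the lock can't be taken. A small sketch:

```python
import threading

plain = threading.Lock()
reentrant = threading.RLock()

plain.acquire()
# A second blocking acquire would hang forever, so ask without blocking
second_plain = plain.acquire(False)          # False: the lock is already held
plain.release()

reentrant.acquire()
second_reentrant = reentrant.acquire(False)  # True: the owning thread may re-acquire
reentrant.release()                          # one release per acquire
reentrant.release()

print(second_plain)
print(second_reentrant)
```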

Finally, if we take a look at synchronized_with_attr we realize it isn't a decorator in itself, but a function which returns a decorator. This is the key to how we pass arguments to decorators, and it begins to show their full power.
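It can help to write out what the @ line is shorthand for. In this sketch the decorator is applied by hand inside the class body: first the outer function is called with its argument, then the decorator it returns is applied to the method.

```python
import threading

def synchronized_with_attr(lock_name):
    def decorator(method):
        def synced_method(self, *args, **kws):
            with getattr(self, lock_name):
                return method(self, *args, **kws)
        return synced_method
    return decorator

class Counter:
    def __init__(self):
        self.lock = threading.RLock()
        self.total = 0

    def add_one(self):
        self.total += 1
    # Same effect as writing @synchronized_with_attr("lock") above the def
    add_one = synchronized_with_attr("lock")(add_one)

c = Counter()
c.add_one()
print(c.total)  # 1
```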

Sometimes we wish to lock over all of the methods of a class, creating what is essentially a thread-safe data structure. Java has no synchronized modifier for classes, so there we have to mark every method as synchronized ourselves.

 

public class Counter {

    int mTotal;

    public Counter() {
        mTotal = 0;
    }

    public synchronized void addOne() {
        int val = mTotal;
        val += 1;
        mTotal = val;
    }

    public synchronized void addTwo() {
        int val = mTotal;
        val += 2;
        mTotal = val;
    }
}

 

In Python we can get the same effect with a single class decorator, though the solution becomes a bit more complicated. Let us try and create an appropriate decorator for a class.

 

import threading
import types

def synchronized_with(lock):
		
    def decorator(func):
			
        def synced_func(*args, **kws):
            with lock:
                return func(*args, **kws)
        return synced_func
		
    return decorator


def synchronized_class(sync_class):
	
    lock = threading.RLock()
	
    orig_init = sync_class.__init__
    def __init__(self, *args, **kws):
        self.__lock__ = lock
        orig_init(self, *args, **kws)
    sync_class.__init__ = __init__
	
    # Copy the items first, since setattr modifies the class as we loop
    for key, val in list(sync_class.__dict__.items()):
        if type(val) is types.FunctionType:
            decorator = synchronized_with(lock)
            setattr(sync_class, key, decorator(val))
    
    return sync_class

 

Let me explain what is happening here.

The first thing that happens is we create a new lock for the class to use. Note that this one lock is shared by every instance of the class. We then override the __init__ method of the class so that it first assigns this lock to the instance as the attribute __lock__ before calling the old __init__ function. We then loop over all of the items in the class dictionary. We check which ones are functions, and if they are we apply our synchronized_with decorator to them with the lock we created at the beginning. We then return the modified class. And that's it! We have a synchronized class.
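Here is a sketch of the class decorator in use. It restates the two functions so it runs on its own, and uses setattr to replace the methods (my change, since a class __dict__ can't be assigned to directly on newer Pythons). The add_three method is a made-up example of one synchronized method calling another.

```python
import threading
import types

def synchronized_with(lock):
    def decorator(func):
        def synced_func(*args, **kws):
            with lock:
                return func(*args, **kws)
        return synced_func
    return decorator

def synchronized_class(sync_class):
    lock = threading.RLock()

    orig_init = sync_class.__init__
    def __init__(self, *args, **kws):
        self.__lock__ = lock
        orig_init(self, *args, **kws)
    sync_class.__init__ = __init__

    # Copy the items first, since setattr modifies the class as we loop
    for key, val in list(vars(sync_class).items()):
        if isinstance(val, types.FunctionType):
            setattr(sync_class, key, synchronized_with(lock)(val))

    return sync_class

@synchronized_class
class Counter:
    def __init__(self):
        self.total = 0

    def add_one(self):
        self.total += 1

    def add_three(self):
        # Calls another synchronized method; fine because the lock is an RLock
        self.add_one()
        self.add_one()
        self.add_one()

counter = Counter()
workers = [threading.Thread(target=counter.add_three) for _ in range(4)]
for w in workers:
    w.start()
for w in workers:
    w.join()

other = Counter()
print(counter.total)                        # 12
print(counter.__lock__ is other.__lock__)   # True: one lock for the whole class
```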

There is one more nice tweak we can do, which is to combine all of these new decorators into one function that decides which decorator is appropriate to apply. This is fairly straightforward: we simply look at the argument to the function. If the argument is a lock then we know we must return a new decorator using that lock. If the argument is a string we return the attribute-based decorator for that name. If we get anything else (such as a function or a class) then we apply the decorator directly, as usual.

 

import inspect
import threading
import types

# Lock and RLock are factory functions, so take the types from real instances
LOCK_TYPES = (type(threading.Lock()), type(threading.RLock()))

def synchronized_with_attr(lock_name):
    
    def decorator(method):
        
        def synced_method(self, *args, **kws):
            lock = getattr(self, lock_name)
            with lock:
                return method(self, *args, **kws)
                
        return synced_method
        
    return decorator

    
def synchronized_with(lock):
    
    def synchronized_obj(obj):
        
        if isinstance(obj, types.FunctionType):
            
            obj.__lock__ = lock
            
            def func(*args, **kws):
                with lock:
                    # Pass the wrapped function's result back to the caller
                    return obj(*args, **kws)
            return func
            
        elif inspect.isclass(obj):
            
            orig_init = obj.__init__
            def __init__(self, *args, **kws):
                self.__lock__ = lock
                orig_init(self, *args, **kws)
            obj.__init__ = __init__
            
            # Copy the items first, since setattr modifies the class as we loop
            for key, val in list(vars(obj).items()):
                if isinstance(val, types.FunctionType):
                    decorator = synchronized_with(lock)
                    setattr(obj, key, decorator(val))
            
            return obj
    
    return synchronized_obj
    
    
def synchronized(item):
    
    if isinstance(item, str):
        # Used as @synchronized("lock"): hand back the decorator itself
        return synchronized_with_attr(item)
    
    if isinstance(item, LOCK_TYPES):
        # Used as @synchronized(some_lock): hand back the decorator itself
        return synchronized_with(item)
    
    # Used as a bare @synchronized on a function or class. An RLock lets
    # synchronized methods call each other without deadlocking.
    new_lock = threading.RLock()
    decorator = synchronized_with(new_lock)
    return decorator(item)

 

And with about 50 lines of code we've added the synchronization primitives to Python, beautiful!

 

import time

@synchronized
class Counter:
	
    def __init__(self):
        self.counter = 0
		
    def add_one(self):
        val = self.counter
        val += 1
        time.sleep(0.1)
        self.counter = val
		
    def add_two(self):
        val = self.counter
        val += 2
        time.sleep(0.1)
        self.counter = val
		
		
my_counter = Counter()

def class_counter1():
    global my_counter
    for i in range(0,10): my_counter.add_one()
	
def class_counter2():
    global my_counter
    for i in range(0,10): my_counter.add_two()

thread1 = threading.Thread(target = class_counter1)
thread2 = threading.Thread(target = class_counter2)

thread1.start()
thread2.start()

thread1.join()
thread2.join()

print(my_counter.counter)
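The example above only exercises the bare @synchronized form. Below is a sketch of the other two forms, passing an explicit lock and passing an attribute name. It uses a condensed variant of the combined decorator so it runs on its own (it leaves out the __init__ override), and the record function and log list are made up for the demo.

```python
import inspect
import threading
import types

LOCK_TYPES = (type(threading.Lock()), type(threading.RLock()))

def synchronized_with_attr(lock_name):
    def decorator(method):
        def synced_method(self, *args, **kws):
            with getattr(self, lock_name):
                return method(self, *args, **kws)
        return synced_method
    return decorator

def synchronized_with(lock):
    def synchronized_obj(obj):
        if inspect.isclass(obj):
            # Wrap every plain function defined on the class
            for key, val in list(vars(obj).items()):
                if isinstance(val, types.FunctionType):
                    setattr(obj, key, synchronized_with(lock)(val))
            return obj
        def func(*args, **kws):
            with lock:
                return obj(*args, **kws)
        return func
    return synchronized_obj

def synchronized(item):
    if isinstance(item, str):
        return synchronized_with_attr(item)            # @synchronized("lock")
    if isinstance(item, LOCK_TYPES):
        return synchronized_with(item)                 # @synchronized(some_lock)
    return synchronized_with(threading.RLock())(item)  # bare @synchronized

shared_lock = threading.Lock()
log = []

@synchronized(shared_lock)
def record(value):
    log.append(value)
    return len(log)

class Counter:
    def __init__(self):
        self.lock = threading.RLock()
        self.total = 0

    @synchronized("lock")
    def add_one(self):
        self.total += 1

c = Counter()
threads = [threading.Thread(target=c.add_one) for _ in range(5)]
for t in threads:
    t.start()
for t in threads:
    t.join()

first = record("a")
print(first)    # 1
print(c.total)  # 5
```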