Binary Exponentiation

June 5, 2009

(The code examples in this post are also available here.)

Although I have already written something on binary exponentiation, as it applied to modular exponentiation, there’s plenty more to be said about it in more general settings. Basically, today we are after computing the n-th power of a, but a need not be an integer, rational, real or even complex number: it can be any more general mathematical object that can be raised to a power, such as a square matrix or a polynomial…

For an n-th power to exist, the object being exponentiated must fulfill certain conditions, which basically boil down to belonging to a more general set that has a well-behaved multiplication operation defined. Good behavior translates into closure and associativity, which in mathematical slang means that the object belongs to a semigroup. Often the object will also belong to a monoid, and even to a ring or rng… But I’ll follow V.I. Arnold’s advice and not get too carried away with abstract generalizations, so enough said: the thing must be multipliable, and the result must be of the same kind.

Defining Multiplication

Because we are dealing with general mathematical objects, we first need to define the multiplication operation. There are two ways of doing this in Python. To work through an example, say we want to deal with 2×2 matrices. We could define a Matrix2x2 class and write a __mul__ method for it, so that whenever we write a*b, what gets evaluated is a.__mul__(b). For the same price, I’ve thrown in an __add__ method…

class Matrix2x2(object) :
    def __init__(self,elements = None) :
        if elements :
            self._data_ = elements[:4]
        else :
            self._data_ = [0] * 4

    def __repr__(self) :
        as_string = [str(j) for j in self._data_]
        str_length = [len(j) for j in as_string]
        longest = max(str_length)
        # Right-align all entries to the width of the longest one
        for j in range(4) :
            as_string[j] = ' '*(longest - str_length[j]) + as_string[j]
        ret = 'Matrix2x2 object:'
        for j in range(2) :
            ret += '\n[ %s %s ]' % (as_string[2*j], as_string[2*j+1])
        return ret

    def __mul__(self, b) :
        ret = [self._data_[0]*b._data_[0] + self._data_[1]*b._data_[2]]
        ret += [self._data_[0]*b._data_[1] + self._data_[1]*b._data_[3]]
        ret += [self._data_[2]*b._data_[0] + self._data_[3]*b._data_[2]]
        ret += [self._data_[2]*b._data_[1] + self._data_[3]*b._data_[3]]
        return Matrix2x2(ret)

    def __add__(self, b) :
        return Matrix2x2([self._data_[j] + b._data_[j] for j in range(4)])

Or, if we don’t want to go through the hassle of dealing with objects, we can agree to keep the matrix’s data in a four-element list in row-major order, and define a standalone multiplication function…

def mat_mul(a, b) :
    """
    Returns the product of two 2x2 square matrices.

    Computes the product of two 2x2 matrices, each stored in a four element
    list in row major order.
    """

    ret = [a[0]*b[0] + a[1]*b[2]]
    ret += [a[0]*b[1] + a[1]*b[3]]
    ret += [a[2]*b[0] + a[3]*b[2]]
    ret += [a[2]*b[1] + a[3]*b[3]]
    return ret
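The same standalone-function approach works for the polynomials mentioned at the start. As a minimal sketch, assuming a convention of our own (not used elsewhere in this post) where a polynomial is stored as a list of coefficients, lowest degree first, a hypothetical poly_mul could look like this:

```python
def poly_mul(p, q):
    """
    Returns the product of two non-empty polynomials.

    Each polynomial is a list of coefficients, lowest degree first,
    so [1, 2] stands for 1 + 2x.
    """
    ret = [0] * (len(p) + len(q) - 1)
    for i, p_i in enumerate(p):
        for j, q_j in enumerate(q):
            # The x**i term times the x**j term lands on x**(i+j)
            ret[i + j] += p_i * q_j
    return ret

# (1 + x) * (1 + x) = 1 + 2x + x**2
print(poly_mul([1, 1], [1, 1]))  # [1, 2, 1]
```

Any two-argument function of this shape, taking two objects and returning one of the same kind, is what the mul keyword in the functions that follow expects.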

Direct Multiplication

To set a baseline for performance, the starting point has to be the most direct implementation of exponentiation: repeated multiplication. There really isn’t much mystery to it…

import nrp_base  # this blog's own helper module, providing performance_timer

@nrp_base.performance_timer
def direct_pow(a, n, **kwargs) :
    """
    Computes a**n by direct multiplication.

    Arguments
     a - The object to be exponentiated.
     n - The integral exponent.

    Keyword Argument
     mul - A function taking two arguments of the same type as a, and
           returning their product. If undefined, a's __mul__ method will
           be used.
           
    """
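    mul = kwargs.pop('mul',None)
    if n < 1 :
        # With no general identity element available, only n >= 1 makes sense
        raise ValueError('n must be a positive integer')
    ret = a
    if mul is None :
        mul = lambda x,y : x*y
    for j in range(n-1) :
        ret = mul(ret,a)
    return ret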

The Two Flavors of Binary Exponentiation

OK, so now we are ready for the fun to begin… The basic idea behind speeding exponentiation up is that $a^4$ can be computed with two, instead of three, multiplications if, rather than doing it as $a^4 = a \cdot a \cdot a \cdot a$, we do it as $a^2 = a \cdot a$, $a^4 = a^2 \cdot a^2$, storing the intermediate result for $a^2$.
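To make the saving concrete, here is a small sketch that counts multiplications using plain integers; the helper counting_mul is hypothetical, introduced only for this illustration:

```python
def counting_mul(counter):
    """Returns a multiplication function that tallies its calls in counter[0]."""
    def mul(x, y):
        counter[0] += 1
        return x * y
    return mul

a = 3

# Naive: a**4 = a*a*a*a
naive = [0]
mul = counting_mul(naive)
naive_result = mul(mul(mul(a, a), a), a)

# Squaring: a**2 = a*a, then a**4 = a**2 * a**2
squaring = [0]
mul = counting_mul(squaring)
a2 = mul(a, a)
squaring_result = mul(a2, a2)

print(naive_result, naive[0])        # 81 3
print(squaring_result, squaring[0])  # 81 2
```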

To extend this idea to exponents other than 4, we basically need to write the exponent in binary. Examples are great for figuring out these kinds of things, so we’ll take computing $a^{19}$ as ours. Knowing that 19 in binary is 10011, we now have a choice of how to proceed, which basically amounts to using the bits of the exponent starting either from the least significant one or from the most significant one…
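Python can show the binary expansion directly. The sketch below, using only built-ins, lists the bits of 19 in both of the orders just mentioned:

```python
n = 19
print(bin(n))  # 0b10011

# Least significant bit first: repeatedly test the lowest bit and shift right
lsb_first = []
m = n
while m:
    lsb_first.append(m & 1)
    m >>= 1
print(lsb_first)  # [1, 1, 0, 0, 1]

# Most significant bit first is simply the reverse
msb_first = lsb_first[::-1]
print(msb_first)  # [1, 0, 0, 1, 1]
```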

Least Significant Bit First

That 19 in binary is 10011 can also be interpreted as $19 = 2^0 + 2^1 + 2^4$, and so $a^{19} = a^{2^0} \cdot a^{2^1} \cdot a^{2^4}$… The algorithm then is pretty straightforward:

  • starting from j = 0, sequentially compute the values of $a^{2^j}$, where each new value is the square of the preceding one,
  • when the first non-zero bit is found, set the return value to the corresponding $a^{2^j}$ value,
  • for subsequent non-zero bits, multiply the return value by the corresponding $a^{2^j}$ value,
  • once all bits of the exponent have been processed, the return value holds the sought value.

If the exponent is n, then there will be $\lfloor \log_2 n \rfloor$ squarings to compute the $a^{2^j}$ values, plus one multiplication fewer than there are 1’s in the binary expansion of the exponent.
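The steps and the operation count can be checked with a stripped-down sketch on plain integers. lsb_pow is a hypothetical name; it assumes n >= 1 and returns the power together with the number of squarings and multiplications it performed:

```python
def lsb_pow(a, n):
    """
    Least significant bit first exponentiation of a to the n (n >= 1).

    Returns (a**n, squarings, multiplications).
    """
    squarings = multiplications = 0
    val = a      # holds a**(2**j) for the current bit position j
    ret = None   # set when the first 1 bit is found
    while True:
        if n & 1:
            if ret is None:
                ret = val
            else:
                ret *= val
                multiplications += 1
        n >>= 1
        if not n:
            break
        val *= val   # next a**(2**j) is the square of the previous one
        squarings += 1
    return ret, squarings, multiplications

print(lsb_pow(3, 19))  # (1162261467, 4, 2)
```

The squarings come out as n.bit_length() - 1, the integer part of $\log_2 n$, and the multiplications as the number of 1 bits minus one.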

Most Significant Bit First

It is also possible to write $a^{19} = \left(a^{2^3} \cdot a\right)^2 \cdot a$. While it is easy to verify that this is the case, it may be harder to see where it comes from. The algorithm is as follows:

  • Take the exponent bits from most to least significant,
  • since the first bit is always a 1, set the return value to a,
  • for every subsequent bit, first square the return value and then, if the bit is a 1, multiply the return value by a,
  • again, once the last bit is used, the return value holds the sought value.
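One way to see where the formula for $a^{19}$ comes from is to follow only the exponent: squaring doubles it and multiplying by a adds one, which is Horner’s rule applied to the binary digits, 19 = (((1·2 + 0)·2 + 0)·2 + 1)·2 + 1. The hypothetical msb_exponent_trace below records the exponent of the return value after each step:

```python
def msb_exponent_trace(n):
    """
    Follows the exponent of the return value through the most significant
    bit first algorithm (n >= 1). Squaring doubles the exponent and
    multiplying by a adds one.
    """
    bits = bin(n)[2:]   # e.g. '10011' for 19
    exponent = 1        # the leading bit is always 1: start with a
    trace = [exponent]
    for bit in bits[1:]:
        exponent *= 2   # square the return value
        trace.append(exponent)
        if bit == '1':
            exponent += 1   # multiply the return value by a
            trace.append(exponent)
    return trace

print(msb_exponent_trace(19))  # [1, 2, 4, 8, 9, 18, 19]
```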

The analysis of this algorithm is similar to the previous one: there will again be $\lfloor \log_2 n \rfloor$ squarings and, as before, one multiplication fewer than there are 1’s in the binary expansion of the exponent.

This second flavor has one clear disadvantage: there’s no simple way to generate the bits of a number from most to least significant. So this approach requires computing and storing the bits from least to most significant, and then retrieving them in reverse order. Not that it is anything too complicated, but there are additional time costs involved.
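As a hedged aside: Python versions newer than this post (2.7 and 3.1 onwards) provide int.bit_length(), which lets the bits be read from most to least significant by shifting, with no list to store and reverse. A minimal sketch on plain integers, under the hypothetical name msb_pow and assuming n >= 1:

```python
def msb_pow(a, n):
    """
    Most significant bit first exponentiation of a to the n (n >= 1),
    reading bits with n.bit_length() instead of storing them in a list.
    """
    ret = a   # the leading bit is always 1
    # Walk the remaining bit positions from high to low
    for j in range(n.bit_length() - 2, -1, -1):
        ret = ret * ret
        if (n >> j) & 1:
            ret = ret * a
    return ret

print(msb_pow(3, 19))  # 1162261467
```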

So why would one even worry about this scheme then? Well, it also has a potential advantage: the multiplications performed always involve the original object being exponentiated. This can prove beneficial in two different ways:

  1. it lends itself better to optimization for particular cases, and
  2. if the original element involves small numbers, but the end result involves large numbers, the multiplications are usually faster than in the previous algorithm.
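To illustrate the second point with plain integers, this sketch (hypothetical helper names, assuming n >= 1) records, for each non-squaring multiplication, the bit length of the factor being multiplied in: the original a in the msb flavor, versus the growing $a^{2^j}$ in the lsb flavor:

```python
def lsb_factor_sizes(a, n):
    """Bit lengths of the a**(2**j) factors multiplied in, lsb first."""
    sizes = []
    val, ret = a, None
    while True:
        if n & 1:
            if ret is None:
                ret = val
            else:
                sizes.append(val.bit_length())
                ret *= val
        n >>= 1
        if not n:
            return sizes
        val *= val   # the factor keeps growing

def msb_factor_sizes(a, n):
    """Bit lengths of the factors multiplied in, msb first: always a."""
    sizes = []
    ret = a
    for bit in bin(n)[3:]:   # skip '0b' and the leading 1
        ret *= ret
        if bit == '1':
            sizes.append(a.bit_length())
            ret *= a
    return sizes

n = 2**10 - 1
print(lsb_factor_sizes(3, n))  # sizes roughly double at every step
print(msb_factor_sizes(3, n))  # [2, 2, 2, 2, 2, 2, 2, 2, 2]
```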

To see the differences in action, it is best to first write the Python code…

@nrp_base.performance_timer
def bin_pow(a, n, **kwargs) :
    """
    Computes a**n by binary exponentiation.

    Arguments
     a - The object to be exponentiated.
     n - The integral exponent.

    Keyword Argument
     mul - A function taking two arguments of the same type as a, and
           returning their product. If undefined, a's __mul__ method will
           be used.
     sqr - A function taking a single argument of the same type as a, and
           returning its square. If undefined, mul(a,a) will be used. To
           use a method of the object's class, use sqr=ClassName.function_name.
     lsb - Set to True to use least significant bits first. Default is False.
     msb - Set to True to use most significant bits first. Overrides lsb.
           Default is True.
           
    """
    mul = kwargs.pop('mul',None)
    sqr = kwargs.pop('sqr',None)
    lsb = kwargs.pop('lsb',None)
    msb = kwargs.pop('msb',None)
    if mul is None :
        mul = lambda x, y : x * y
    if sqr is None :
        sqr = lambda x : mul(x, x)
    if lsb is None and msb is None :
        msb = True
        lsb = False
    elif msb is None :
        msb = not lsb
    else :
        lsb = not msb

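    if n < 1 :
        # No general identity element to return for n == 0, and the lsb
        # loop below would never terminate for n == 0.
        raise ValueError('n must be a positive integer')
    if msb :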
        bits = []
        while n > 1 : # The last bit is always a 1...
            if n & 1 :
                bits += [1]
            else :
                bits += [0]
            n >>= 1
        ret = a
        while bits :
            ret = sqr(ret)
            if bits.pop() :
                ret = mul(ret,a)
    else :
        val = a
        while not n & 1 :
            n >>= 1
            val = sqr(val)
        ret = val
        while n > 1:
            n >>= 1
            val = sqr(val)
            if n & 1 :
                ret = mul(ret,val)
    return ret
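As an example of the sqr keyword, and of the first advantage listed earlier (optimizing particular cases), squaring a 2×2 matrix can be done with five multiplications instead of eight, because $\begin{pmatrix} a & b \\ c & d \end{pmatrix}^2 = \begin{pmatrix} a^2 + bc & b(a + d) \\ c(a + d) & d^2 + bc \end{pmatrix}$. Here is a sketch for the row-major list representation used by mat_mul; mat_sqr is a hypothetical name, and mat_mul is repeated so the block runs on its own:

```python
def mat_mul(a, b):
    """Product of two 2x2 matrices stored as row-major 4-element lists."""
    return [a[0]*b[0] + a[1]*b[2], a[0]*b[1] + a[1]*b[3],
            a[2]*b[0] + a[3]*b[2], a[2]*b[1] + a[3]*b[3]]

def mat_sqr(a):
    """Square of a 2x2 row-major matrix using 5 multiplications instead of 8."""
    bc = a[1] * a[2]
    trace = a[0] + a[3]
    return [a[0]*a[0] + bc, a[1]*trace,
            a[2]*trace, a[3]*a[3] + bc]

m = [1, 1, 1, 0]
print(mat_sqr(m))                   # [2, 1, 1, 1]
print(mat_sqr(m) == mat_mul(m, m))  # True
```

With bin_pow as defined above, this could be passed as bin_pow(a, n, mul=mat_mul, sqr=mat_sqr); whether it pays off in practice has not been timed here.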

Taking bin_pow for a Ride…

To get a better grip on how the algorithm performs, we’ll take a very simple matrix, which we will meet again in a future post when we discuss Fibonacci numbers.

>>> a = Matrix2x2([1,1,1,0])
>>> a
Matrix2x2 object:
[ 1 1 ]
[ 1 0 ]

So let’s first raise it to a small power, and see how the two algorithms perform…

>>> val = bin_pow(a,10,lsb=True,verbose=True,timer_loops=100)
Call to bin_pow took:
2.84952e-05 sec. (min)
0.000205613 sec. (max)
3.363e-05 sec. (avg)
0.000178629 sec. (st_dev)
>>> val = bin_pow(a,10,msb=True,verbose=True,timer_loops=100)
Call to bin_pow took:
3.12889e-05 sec. (min)
5.16825e-05 sec. (max)
3.32891e-05 sec. (avg)
2.93951e-05 sec. (st_dev)

Not many surprises here: judging by the minimum times, the algorithm that has to store and retrieve bits is a little slower than the other. But let’s see what happens if we scale things up…

>>> val = bin_pow(a,1000000,lsb=True,verbose=True)
Call to bin_pow took 2.599 sec.
>>> val = bin_pow(a,1000000,msb=True,verbose=True)
Call to bin_pow took 1.76194 sec.

This shows that what gets multiplied can be very relevant: the msb version, whose multiplications always involve the small original matrix, is now clearly faster. A last little experiment gives more insight into the workings of this…

>>> 2**20
1048576
>>> val = bin_pow(a,2**20,lsb=True,verbose=True)
Call to bin_pow took 1.95651 sec.
>>> val = bin_pow(a,2**20,msb=True,verbose=True)
Call to bin_pow took 1.90171 sec.

Funny, isn’t it? Why would the time difference be so much smaller now, even though the exponents are roughly the same? These other timings may help answer that…

>>> val = bin_pow(a,2**20-1,lsb=True,verbose=True)
Call to bin_pow took 2.61069 sec.
>>> val = bin_pow(a,2**20-1,msb=True,verbose=True)
Call to bin_pow took 1.89222 sec.

Got it now? $2^{20}$ has a binary expression full of 0’s, 20 of them, preceded by a single 1. This means that neither algorithm performs any multiplication, just squarings, so their performance is almost identical. On the other hand, $2^{20}-1$ has a binary expression full of 1’s, 20 of them, with no 0 anywhere. This means that, although there will be one squaring fewer than in the previous case, there will be 19 more multiplications. And this is where the msb approach excels, because each of its multiplications involves the original, small-valued matrix a, while the lsb approach multiplies two matrices with ever larger entries.
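These counts can be checked without doing any actual arithmetic, just by looking at the bits. op_counts is a hypothetical helper; it assumes n >= 1 and returns (squarings, multiplications), which are the same for both flavors:

```python
def op_counts(n):
    """(squarings, multiplications) used by binary exponentiation for n >= 1."""
    bits = bin(n)[3:]                  # bits after the leading 1
    squarings = len(bits)              # one squaring per remaining bit
    multiplications = bits.count('1')  # one extra multiplication per 1 bit
    return squarings, multiplications

print(op_counts(2**20))      # (20, 0)
print(op_counts(2**20 - 1))  # (19, 19)
```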


Modular Exponentiation

March 20, 2009

I have a post cooking on wheel factorization, yet another sieving algorithm for computing all primes below a given number n. It turns out that, while implementing it, one comes across a number of not-so-obvious preliminary questions. It is a rule of this blog not to resort to any non-standard library, and I feel it would be breaking the spirit of that rule to post code which hasn’t been thoroughly explained in advance. So, in order not to make that soon-to-be post on prime numbers longer than anyone could bear, I need to cover a couple of preliminaries, and so here comes this thrilling post on modular exponentiation.

What we are after is determining $a^b \pmod{m}$, or to put it in plain English, the remainder of a to the power of b when divided by m. Doesn’t seem like much, does it? We could just go ahead and write:

def modExp1(a, b, m) :
    """
    Computes a to the power b, modulo m, directly
    """
    return a**b % m

The problems begin if a and b are both big. Python’s integers don’t overflow, and we all love it for that, but the memory juggling it has to do to accommodate a huge number does come at a speed cost. Especially if m is small, it definitely seems a little overkill to keep a zillion-digit number in memory, given that the answer will be smaller than m…

We can, and will, use modular arithmetic to our advantage. See, Wikipedia says that “the congruence relation (introduced by modular arithmetic) on the integers is compatible with the multiplication operation of the ring of integers,” and if it’s on the internet it has to be true, right? Translated to plain English, the idea is that if we want to take the product of two numbers, a and b, modulo m, we can first multiply them and then take the m-modulo of the result, or we can take the m-modulo of a and/or b prior to multiplication, and the m-modulo of the result will still be the right answer. So the least a conscientious programmer should do is rewrite the previous function as:
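In code, this compatibility means that the two expressions below always agree. This is a quick spot check with the numbers used later in this post, not a proof:

```python
a, b, m = 1234, 5678, 90

direct = (a * b) % m                      # multiply first, reduce after
reduced_first = ((a % m) * (b % m)) % m   # reduce both factors first
print(direct, reduced_first)              # 62 62
```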

def modExp2(a, b, m) :
    """
    Computes a to the power b, modulo m, a little less directly
    """
    return (a%m)**b % m

Of course, if b is large enough, you may still find yourself in the realm of zillion-digit numbers, which is a place we don’t really want to go. So it may be a good option to remember that exponentiation is just repeated multiplication. If we multiply a by itself b times, and take the m-modulo of the result after every multiplication, the largest number we will ever have in memory will be smaller than $m^2$. It doesn’t really seem right to put a loop of b iterations into the code, since what we are worried about are large b’s, but let’s do it anyway:

def modExp3(a, b, m) :
    """
    Computes a to the power b, modulo m, by iterative multiplication
    """
    ret = 1
    a %= m
    while b :
        ret *= a
        ret %= m
        b -= 1
    return ret

The only reason this last function deserves a place here is that it helps introduce the really cool algorithm this whole post is about: binary exponentiation. John D. Cook’s blog has a great explanation of how it can be implemented iteratively, based on the binary representation of the exponent, but I’ll go for a recursive version, which is essentially the same but which I find better suited to my brain’s wiring…
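For comparison with the recursive version, an iterative take on the same idea, processing the exponent’s bits from least significant up and reducing modulo m after every product, might look like this. It is a sketch of the general technique, not a transcription of John D. Cook’s code, and modExpIter is a hypothetical name:

```python
def modExpIter(a, b, m):
    """
    Computes a to the power b, modulo m, by iterative binary
    exponentiation (least significant bit first). Assumes b >= 0, m >= 1.
    """
    ret = 1 % m     # handles m == 1, where every result is 0
    a %= m
    while b:
        if b & 1:               # this bit contributes the current a**(2**j)
            ret = ret * a % m
        a = a * a % m           # square to get the next a**(2**j), reduced
        b >>= 1
    return ret

print(modExpIter(1234, 5678, 90))  # 46
```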

So what we are about to do, when asked to compute $a^b$, is check whether b is even, in which case we will compute it as the square of $a^{b/2}$. If b is odd, we will compute it as $a \cdot a^{b-1}$. If b is 0, we will of course return 1. If we also throw in all the modulo m stuff I have been discussing, we finally arrive at THE function to do modular exponentiation:

def modExp(a, b, m) :
    """
    Computes a to the power b, modulo m, using binary exponentiation
    """
    a %= m
    ret = None
    if b == 0 :
        ret = 1
    elif b%2 :
        ret = a * modExp(a,b-1,m)
    else :
        ret = modExp(a,b//2,m)
        ret *= ret
    return ret%m

Now that we have all four versions of the same function written, we can do a little testing. I will leave writing the testing code as an exercise for the reader (I’ve been wanting to say this since my first year of university…), but when computing 1234 to the power 5678, modulo 90, 1000 times with each function, this is the performance I get:


>>> testModExp(1234,5678,90,1000)
Got 46 using method 1 in 3.73399996758 sec.
Got 46 using method 2 in 0.546999931335 sec.
Got 46 using method 3 in 1.25 sec.
Got 46 using THE method in 0.0150001049042 sec.

Results which really don’t require much commenting…

EDIT: There is some very interesting information in the comments: Python comes with a built-in pow function, which can take two or three arguments, the third being, precisely, the modulus. It probably implements an algorithm similar to the one presented here, but it is more than an order of magnitude faster.
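For reference, the built-in three-argument form is pow(base, exp, mod). With the numbers used in the timings above:

```python
# Built-in modular exponentiation: the third argument is the modulus
print(pow(1234, 5678, 90))  # 46

# Same answer as the direct (and much slower) approach
print(1234 ** 5678 % 90)    # 46
```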