"""
Public key signatures based on secure hashes

Because of how it's constructed, the security of this algorithm 
is very easy to analyze. It has at least 80 bits of security, and 
probably quite a bit more in practice depending on how it's used.

See the documentation of the functions sign() and verify() for
more information.
"""

# written March 2001 by Bram Cohen
# this file is public domain

import hashlib
from cStringIO import StringIO
from sha import sha

def sign(data, private_key, iterations, steps, numbers):
    """
    data is the thing to be signed, it must be of length 20
    
    the easiest way to find the public key for a private key is to 
    sign something with it

    it's safe to use the same private key with different 
    iterations and steps, but not the exact same numbers array.
    using different iterations or steps will produce a different 
    public key though

    time to compute signatures is (2 ** iterations) * steps
    
    signature size is mostly proportional to steps
    
    numbers is an array of length steps containing nonnegative 
    integers less than 2 ** iterations
    
    signing two different values using the same iterations, steps, and 
    numbers could make it possible for an attacker to forge signatures. 
    to avoid this, either keep careful track of what numbers arrays you've 
    used in the past, or make it large enough that picking them randomly 
    will be unlikely to cause a collision

    the signature is chr(iterations) + chr(steps) followed by `steps`
    single signatures. The first one signs data. Each later one signs the
    public key the previous one produced. The last one is made with the
    root key, which depends only on private_key, steps and iterations.
    The root's public key is the one returned.

    each intermediate level's key is derived from the level above by a
    g/h path encoding numbers[i] and then one final 'i' hash. The 'i'
    terminator makes the derivation prefix-free. Without it,
    numbers[i] == 0 reuses the parent's key, and a path at one level can
    equal a path spread over two levels (e.g. [3, ...] versus [1, 1, ...]).
    Either case lets two different trees share one-time keys.

    NOTE: signatures made by versions without the 'i' terminator used
    different intermediate keys, although the public key is the same.
    If you signed with such a version using steps >= 2, do not reuse any
    numbers[0] value from those signatures with the same private key,
    iterations and steps. Otherwise the root would sign two different
    subtree keys at one position.
    
    returns (signature, public key)
    """
    assert 0 <= iterations and iterations < 256
    assert 0 < steps and steps < 256
    assert len(numbers) == steps
    assert len(data) == 20
    sig = StringIO()
    sig.write(chr(iterations))
    sig.write(chr(steps))
    # chr(steps) is to make it safe to use the same key with different numbers of steps
    keys = [sh(chr(steps) + 'f' + private_key)]
    for i in xrange(steps - 1):
        key = keys[-1]
        # walk the bits of numbers[i] + 1 below its leading one bit, so each
        # nonnegative number gets its own g/h path
        num = numbers[i] + 1
        while num > 1:
            if num % 2 == 0:
                key = sh('g' + key)
            else:
                key = sh('h' + key)
            num = num // 2
        # terminator hash: keeps a level's key distinct from its parent's and
        # from keys reachable at any other depth (see docstring)
        keys.append(sh('i' + key))
    # keys[0] is now the deepest level (signs data), keys[-1] the root
    keys.reverse()
    value = data
    for i in xrange(steps):
        signature, value = sign_single(value, keys[i], iterations, numbers[steps - 1 - i])
        sig.write(signature)
    return sig.getvalue(), value

def test_sign():
    msg = 'a' * 20
    first = sign(msg, 'abc', 3, 2, [5, 2])
    second = sign(msg, 'abc', 3, 2, [5, 4])
    # only a non-root number differs: same public key, different signature
    assert first[1] == second[1]
    assert first[0] != second[0]
    # numbers[0] is shared, so the trailing root-level single signature
    # (one sign_single output with 3 iterations) is byte-identical
    tail = 20 + 165 * 20 + 82 * 20 + 21 * 3
    assert first[0][-tail:] == second[0][-tail:]
    # another message under the same key keeps the public key
    third = sign('b' * 20, 'abc', 3, 2, [5, 2])
    assert third[1] == first[1]
    assert third[0] != first[0]
    # another private key changes both parts
    fourth = sign(msg, 'ab', 3, 2, [5, 2])
    assert fourth[0] != first[0]
    assert fourth[1] != first[1]

def verify(sig):
    """
    returns (signed value, public key, iterations, numbers)

    mostly does the inverse of what sign() does. sig must be
    chr(iterations) + chr(steps) followed by exactly `steps` single
    signatures of equal length. Each single signature after the first must
    sign the public key the previous one led to.
    
    raises a ValueError if the signature doesn't check out
    
    You probably aren't interested in iterations or numbers, but 
    make sure to double-check the signed value and public key!
    """
    if len(sig) < 2:
        raise ValueError('signature lengths must be at least two')
    iterations = ord(sig[0])
    steps = ord(sig[1])
    if steps == 0:
        raise ValueError('must have at least one step')
    l = 20 + 165 * 20 + 82 * 20 + 21 * iterations
    # the header's step count must match the number of single signatures
    if len(sig) != 2 + l * steps:
        raise ValueError('bad signature length')
    data, last, number = verify_single(sig[2:2 + l], iterations)
    numbers = [number]
    for i in xrange(2 + l, len(sig), l):
        signed, key, number = verify_single(sig[i:i + l], iterations)
        if signed != last:
            raise ValueError('wrong value signed')
        numbers.append(number)
        last = key
    # segments run from the data level up to the root, numbers from the root down
    numbers.reverse()
    return data, last, iterations, numbers

def test_sign_and_verify():
    # round-trips a two-step sign() through verify(). Named differently from
    # the single-signature test_verify below: that later definition rebinds
    # the name test_verify, which kept this test from ever running
    x = sign('a' * 20, 'abc', 3, 2, [5, 2])
    y = verify(x[0])
    assert y[0] == 'a' * 20
    assert y[1] == x[1]
    assert y[2] == 3
    assert y[3] == [5, 2]

def sign_single(data, private_key, iterations, number):
    """
    signs the 20-byte string data with leaf `number` of a hash tree that
    has 2 ** iterations leaves, all derived from private_key

    returns (signature, public key). The public key is the tree's root
    hash. The signature is the concatenation of:
      data                                       20 bytes
      the leaf's 165 stretch hashes              165 * 20 bytes
      the pre-hashes splitbits() selects         82 * 20 bytes
      per level, leaf first: a direction byte
        (chr(0) if the sibling is on the right,
         chr(1) if on the left) plus the sibling  21 * iterations bytes
    """
    assert int(iterations) == iterations
    size = 2 ** iterations
    assert 0 <= number and number < size
    assert int(number) == number
    assert len(data) == 20
    # chr(iterations) makes it safe to use the same private key with
    # different numbers of iterations
    seeds = [sh(chr(iterations) + 'b' + private_key)]
    seeds.append(sh('c' + seeds[0]))
    # the remaining leaf keys branch off earlier ones like a binary tree
    # (the original author's note: this is so signing can be in parallel)
    for parent in xrange(1, size // 2):
        seeds.extend([sh('d' + seeds[parent]), sh('e' + seeds[parent])])
    # rather than 160 pairs of hashes with one of each revealed (320
    # hashes), reveal a subset of 82 out of 165: 165 choose 82 is greater
    # than 2 ** 160, and this almost halves the size of signatures
    reveal = splitbits(make_number(data), 165, 83)
    nodes = [None] * (2 * size)
    stretch_part = pre_part = None
    for leaf in xrange(size):
        k = seeds[leaf]
        signing = leaf == number
        images = []
        opened = []
        for pos in xrange(165):
            preimage = sh(k)
            images.append(sh(preimage))
            if signing and reveal[pos]:
                opened.append(preimage)
            k = sh('a' + k)
        stretch = ''.join(images)
        if signing:
            stretch_part = stretch
            pre_part = ''.join(opened)
        # leaves sit at indices size .. 2 * size - 1 of the heap-ordered tree
        nodes[size + leaf] = sh(stretch)
    for node in xrange(size - 1, 0, -1):
        nodes[node] = sh(nodes[2 * node] + nodes[2 * node + 1])
    # authentication path from the signing leaf up to the root
    path = []
    node = int(number) + size
    while node > 1:
        if node % 2 == 0:
            path.append(chr(0) + nodes[node + 1])
        else:
            path.append(chr(1) + nodes[node - 1])
        node //= 2
    return data + stretch_part + pre_part + ''.join(path), nodes[1]

def verify_single(sig, iterations):
    """
    checks one signature produced by sign_single() with the given iterations

    returns (signed hash, public key, number). The public key is the root
    the authentication path leads to. This function does not compare it
    against anything; that is up to the caller.

    raises ValueError on a wrong length, a bad pre-hash, or a direction
    byte other than 0 or 1
    """
    stretch_end = 20 + 165 * 20
    pre_end = stretch_end + 82 * 20
    if len(sig) != pre_end + 21 * iterations:
        raise ValueError('wrong length for signature')
    data = sig[:20]
    stretch = sig[20:stretch_end]
    pre = sig[stretch_end:pre_end]
    chain = sig[pre_end:]
    # every revealed pre-hash must hash to the stretch entry it opens
    revealed = splitbits(make_number(data), 165, 83)
    used = 0
    for pos in xrange(165):
        if not revealed[pos]:
            continue
        opened = pre[20 * used:20 * used + 20]
        if sh(opened) != stretch[20 * pos:20 * pos + 20]:
            raise ValueError('invalid prehash, signature check failed')
        used += 1
    # climb from the leaf to the root; the direction byte at `level`
    # is bit `level` of the leaf number
    key = sh(stretch)
    num = 0
    for level in xrange(iterations):
        entry = chain[21 * level:21 * level + 21]
        direction = ord(entry[0])
        sibling = entry[1:]
        if direction == 0:
            key = sh(key + sibling)
        elif direction == 1:
            key = sh(sibling + key)
        else:
            raise ValueError('illegal character in signature')
        num |= direction << level
    return (data, key, num)

def test_single_signature():
    first_key = ';slkjfds;ldf'
    second_key = 'lksjdf;kljfds'
    msg_a = sh('auhe')
    msg_b = sh('boug')
    s1 = sign_single(msg_a, first_key, 6, 0)
    s2 = sign_single(msg_a, first_key, 6, 63)
    s3 = sign_single(msg_a, second_key, 6, 0)
    s4 = sign_single(msg_b, second_key, 6, 0)
    s5 = sign_single(msg_a, first_key, 4, 0)
    # public keys depend on the private key and iterations only
    for left, right in [(s1, s2), (s3, s4)]:
        assert left[1] == right[1]
    for left, right in [(s1, s3), (s1, s5)]:
        assert left[1] != right[1]
    # any change of leaf, key or message changes the signature
    for left, right in [(s1, s2), (s1, s3), (s1, s4), (s3, s4)]:
        assert left[0] != right[0]

def test_verify():
    # single-signature round trip plus tamper checks for verify_single()
    data = sh(',.prcgvwvqjk')
    key = sh('oeug.,,.p')
    x = sign_single(data, key, 4, 3)
    y = verify_single(x[0], 4)
    assert y[0] == data
    assert y[1] == x[1]
    assert y[2] == 3

    b = splitbits(make_number(data), 165, 83)
    # corruptions verify_single must reject outright with ValueError
    rejected = [
        splat(x[0], 0),                       # signed data changed
        splat(x[0], 170 * 20),                # a revealed pre-hash changed
        splat(x[0], -21),                     # last direction byte made illegal
        x[0] + '\000',                        # wrong length
        splat(x[0], (b.index(1) + 1) * 20),   # a stretch hash that gets checked
    ]
    for bad in rejected:
        try:
            verify_single(bad, 4)
        except ValueError:
            pass
        else:
            # a real exception; raising a plain string is not valid
            raise AssertionError('corrupted signature was accepted')

    # corruptions that still parse but must lead to a different public key
    y = verify_single(splat(x[0], (b.index(0) + 1) * 20), 4)
    assert y[0] == data
    assert y[1] != x[1]
    assert y[2] == 3
    y = verify_single(splat(x[0], len(x[0]) - 1), 4)
    assert y[0] == data
    assert y[1] != x[1]
    assert y[2] == 3
    # flipping a direction bit also moves the leaf number
    y = verify_single(splat(x[0], -21, 1), 4)
    assert y[0] == data
    assert y[1] != x[1]
    assert y[2] != 3

def splat(s, pos, x = 0xFF):
    """
    returns a copy of s with the character at index pos xor-ed with x

    pos may be negative and then counts from the end, like a normal
    index. Raises IndexError when pos is out of range, and
    AssertionError if the xor leaves s unchanged (x == 0).
    """
    if not -len(s) <= pos < len(s):
        raise IndexError('splat position out of range')
    if pos < 0:
        # without this, pos == -1 makes s[pos + 1:] the whole string
        pos = pos + len(s)
    r = s[:pos] + chr(ord(s[pos]) ^ x) + s[pos + 1:]
    assert r != s
    return r

def sh(s):
    """
    returns the 20-byte SHA-1 digest of the string s

    uses hashlib.sha1, which gives the same digest as the deprecated sha
    module (removed in Python 3)
    """
    return hashlib.sha1(s).digest()

def make_number(s):
    r = 0l
    for c in s:
        r = (r << 8) | ord(c)
    return r

def test_make_number():
    for text, expected in (('\001\002', 258), ('\200', 128)):
        assert make_number(text) == expected

def choose(a, b):
    """
    returns the binomial coefficient a choose b, the number of ways to pick
    b items out of a

    returns 0 when b is negative or larger than a
    """
    if b < 0 or b > a:
        return 0
    # C(a, b) == C(a, a - b); loop over the smaller side
    k = min(b, a - b)
    r = 1
    for i in range(1, k + 1):
        # afterwards r == C(a - k + i, i), so this division is always exact
        r = r * (a - k + i) // i
    return r

def test_choose():
    # edges of Pascal's triangle
    assert choose(50, 0) == 1
    assert choose(50, 50) == 1
    assert choose(50, 1) == 50
    assert choose(50, 49) == 50
    # symmetry C(n, k) == C(n, n - k)
    assert choose(50, 10) == choose(50, 40)

def splitbits(n, bound, size):
    """
    returns the n-th (counting from 0) list of `bound` bits containing
    exactly `size` zeros and `bound - size` ones, in lexicographic order
    with 0 before 1

    sign_single() and verify_single() call this with (165, 83) to pick
    which 82 of the 165 pre-hashes a signature reveals

    n must satisfy 0 <= n < choose(bound, size) (asserted)
    """
    assert 0 <= n
    limit = choose(bound, size)
    assert n < limit
    # uppertotal / lowertotal: how many of the remaining arrangements put a
    # 1 / a 0 in the next position. Every division in this function is
    # exact, and relies on `/` being Python 2 integer division.
    # NOTE(review): under Python 3 these would need to be `//`.
    uppertotal = limit * (bound - size) / bound
    lowertotal = limit * size / bound
    # zeros and ones still to be placed
    n0 = size
    n1 = bound - size
    result = []
    while n0 > 0 and n1 > 0:
        if n < lowertotal:
            # n falls among the arrangements that start with 0 here;
            # those lowertotal arrangements become the new total
            # (uppertotal must be computed before lowertotal is overwritten)
            result.append(0)
            uppertotal = (lowertotal * n1) / (n0 + n1 - 1)
            lowertotal = (lowertotal * (n0 - 1)) / (n0 + n1 - 1)
            n0 = n0 - 1
        else:
            # skip past every arrangement with a 0 here and place a 1;
            # the uppertotal arrangements become the new total
            result.append(1)
            n = n - lowertotal
            lowertotal = (uppertotal * n0) / (n0 + n1 - 1)
            uppertotal = (uppertotal * (n1 - 1)) / (n0 + n1 - 1)
            n1 = n1 - 1
    # only one kind of bit is left; fill the remaining positions with it
    if n1 == 0:
        while len(result) < bound:
            result.append(0)
    else:
        while len(result) < bound:
            result.append(1)
    return result

def test_splitbits(b = 7, s = 3):
    total = choose(b, s)
    layouts = [splitbits(rank, b, s) for rank in xrange(total)]
    # lexicographic extremes: zeros first, then ones first
    assert layouts[0] == [0] * s + [1] * (b - s)
    assert layouts[-1] == [1] * (b - s) + [0] * s
    for layout in layouts:
        assert len(layout) == b
        assert layout.count(0) == s
        assert layout.count(1) == b - s
    # every rank maps to a distinct layout
    assert len(set(map(tuple, layouts))) == total
