# this code is public domain
#
# written by Bram Cohen 2001-02-09

import os
import xdrlib
from time import time, sleep
from sha import sha
from cStringIO import StringIO

def make_random_seed():
    """Return a new 14-byte random seed for make_key().

    The bytes come from os.urandom(), the operating system's cryptographic
    random source.  Every key make_key() derives is a deterministic hash of
    this seed, so the seed itself must be unpredictable.  Hashing str(time())
    (the previous approach) leaves only a small, guessable set of candidates
    and returns the same seed for calls made within one clock tick.
    """
    return os.urandom(14)

def test_make_random():
    # Two seeds taken a second apart must both be 14 bytes long and must differ.
    first = make_random_seed()
    sleep(1)
    second = make_random_seed()
    for seed in (first, second):
        assert len(seed) == 14
    assert first != second

def make_key(seed):
    """Derive a (private_key, public_key, key_identifier) triple from seed.

    seed must be exactly 14 characters long (the length make_random_seed
    returns).  Each element is the SHA digest of the element before it:
    private_key = sha(seed), public_key = sha(private_key) and
    key_identifier = sha(public_key).  The derivation is deterministic, so
    the same seed always yields the same triple.
    """
    assert len(seed) == 14
    derived = []
    current = seed
    for _ in range(3):
        current = sha(current).digest()
        derived.append(current)
    return tuple(derived)

def test_make_key_repeatable():
    # make_key must be a pure function of its seed.
    seed = make_random_seed()
    first, second = make_key(seed), make_key(seed)
    assert first == second

def xor(string1, string2):
    """Return the character-wise exclusive-or of two equal-length strings.

    Raises ValueError if the lengths differ.  This is an explicit check
    rather than an assert because asserts vanish under `python -O`; the loop
    would then either truncate to the shorter input or raise IndexError.
    """
    if len(string1) != len(string2):
        raise ValueError('xor operands differ in length: %d != %d'
                         % (len(string1), len(string2)))
    return ''.join([chr(ord(a) ^ ord(b)) for a, b in zip(string1, string2)])

def test_xor():
    left = 'satnoheurgf.ypzaboxksaqjzvowe.gf,'
    right = 'saoepcr.,gydoawbxkm;q/j,r.cgyaod '
    # xor is commutative.
    assert xor(left, right) == xor(right, left)
    # Any string xor-ed with itself is all zero bytes.
    assert xor(left, left) == '\000' * len(left)
    # xor-ing with the same operand twice restores the original.
    assert xor(xor(left, right), left) == right

def encrypt(plaintext, public_keys):
    """Encrypt plaintext for every recipient in public_keys.

    public_keys may be any iterable of 20-byte public keys as produced by
    make_key(); xor() requires each one to match the 20-byte session key.
    The result is an XDR-packed string holding, in order:

      1. an array of key identifiers, sha(public_key) for each recipient;
      2. an array of key encodings, xor(public_key, symmetric_key);
      3. plaintext xor-ed with the symmetric key repeated to its length.

    decrypt() reverses this; get_key_identifiers() reads field 1 only.

    NOTE(review): because each key encoding is xor(public_key, symmetric_key),
    anyone who knows a recipient's public key can recover the session key.
    The body is also a repeating-key xor.  Treat this as obfuscation, not
    real public-key encryption.
    """
    # Materialize once: the keys are walked twice below, and a one-shot
    # iterator would otherwise silently produce no key encodings.
    public_keys = list(public_keys)
    # Fresh OS-supplied session key.  Hashing str(time()) was guessable and
    # repeated for every message encrypted within the same clock tick.
    symmetric_key = os.urandom(20)
    p = xdrlib.Packer()
    p.pack_array([sha(k).digest() for k in public_keys], p.pack_string)
    p.pack_array([xor(k, symmetric_key) for k in public_keys], p.pack_string)
    # Repeat the key until it covers the plaintext, then trim to length.
    reps = len(plaintext) // len(symmetric_key) + 1
    keystream = (symmetric_key * reps)[:len(plaintext)]
    p.pack_string(xor(keystream, plaintext))
    return p.get_buffer()

def decrypt(ciphertext, private_key):
    """Decrypt a string produced by encrypt() using one recipient's private key.

    The public key and key identifier are re-derived from private_key the
    same way make_key() derives them (sha of the private key, then sha of
    that).

    Returns (plaintext, None) on success.  Returns (None, error_message)
    instead of raising when the ciphertext is truncated, mis-framed, has
    trailing data, is inconsistent, or was not encrypted for this key.
    """
    try:
        u = xdrlib.Unpacker(ciphertext)
        key_identifiers = u.unpack_array(u.unpack_string)
        key_encodings = u.unpack_array(u.unpack_string)
        main_ciphertext = u.unpack_string()
        u.done()
    except EOFError:
        # xdrlib reports truncated input with the builtin EOFError, which is
        # not a subclass of xdrlib.Error.
        return (None, 'ciphertext is truncated')
    except xdrlib.Error as e:
        return (None, str(e))
    if len(key_identifiers) != len(key_encodings):
        return (None, 'number of key ids (' + str(len(key_identifiers)) + ') and number of key encodings (' + str(len(key_encodings)) + ") don't match up")
    public_key = sha(private_key).digest()
    kid = sha(public_key).digest()
    try:
        i = key_identifiers.index(kid)
    except ValueError:
        return (None, 'not encrypted with given private key')
    key_encoding = key_encodings[i]
    # The encoding comes straight from the ciphertext.  xor() requires equal
    # lengths, so validate here to report an error instead of raising.
    if len(key_encoding) != len(public_key):
        return (None, 'key encoding has wrong length')
    symmetric_key = xor(public_key, key_encoding)
    # Repeat the key until it covers the body, then trim to length.
    reps = len(main_ciphertext) // len(symmetric_key) + 1
    plaintext = xor((symmetric_key * reps)[:len(main_ciphertext)], main_ciphertext)
    return (plaintext, None)

def get_key_identifiers(ciphertext):
    """Return the key identifiers listed at the front of an encrypt() result.

    Only the leading key-identifier array is parsed; the rest of ciphertext
    is not examined.

    Returns (tuple_of_key_identifiers, None) on success, or
    (None, error_message) if that array cannot be unpacked.
    """
    u = xdrlib.Unpacker(ciphertext)
    try:
        key_identifiers = u.unpack_array(u.unpack_string)
    except EOFError:
        # xdrlib reports truncated input with the builtin EOFError, which is
        # not a subclass of xdrlib.Error.
        return (None, 'ciphertext is truncated')
    except xdrlib.Error as e:
        return (None, str(e))
    return (tuple(key_identifiers), None)

def test_encrypt_decrypt():
    # Sleep between seed generations so consecutive seeds differ even if the
    # seed source is clock-based.
    seeds = [make_random_seed()]
    for _ in range(2):
        sleep(1)
        seeds.append(make_random_seed())
    assert seeds[0] != seeds[1]
    keys = [make_key(seed) for seed in seeds]
    (priv1, pub1, kid1), (priv2, pub2, kid2), (priv3, _, _) = keys
    message = 'I will not make my computer do busywork for me. ' * 100
    sealed = encrypt(message, [pub1, pub2])
    assert get_key_identifiers(sealed) == ((kid1, kid2), None)
    # Both intended recipients can decrypt...
    for priv in (priv1, priv2):
        assert decrypt(sealed, priv) == (message, None)
    # ...an outsider cannot, and trailing junk is rejected.
    assert decrypt(sealed, priv3)[0] is None
    assert decrypt(sealed + 'a', priv1)[0] is None

# Module-level marker constant.
# NOTE(review): presumably read by an external test harness to decide whether
# to run this module's test_* functions -- confirm against the harness.
mojo_test_flag = 1

