from PIL import Image
import bitarray
from Crypto.PublicKey import RSA
from Crypto.Cipher import PKCS1_OAEP


def generate_sk():
    """Generate a 4096-bit RSA private key and save it to ``cobra_private_key.pem``.

    The key is exported without a passphrase, so a newly created file gets
    owner-only permissions (0o600) instead of the default 0o666-minus-umask.
    An already existing file keeps its current permissions.

    Returns:
        The generated RSA key object.
    """
    import os

    def _owner_only(path, flags):
        # open() supplies the flags for "wb"; we only choose the permission bits.
        return os.open(path, flags, 0o600)

    key = RSA.generate(4096)
    with open("cobra_private_key.pem", "wb", opener=_owner_only) as f_out:
        f_out.write(key.exportKey(format='PEM'))
    return key


def import_key(file_name):
    """Read ``file_name`` and parse its contents into an RSA key object.

    Used for both the public key (when hiding) and the private key (when
    seeking or when regenerating the public key).
    """
    with open(file_name, 'rb') as key_file:
        key_data = key_file.read()
    return RSA.importKey(key_data)


def generate_pk(sk):
    """Derive the public half of ``sk`` and store it in ``cobra_public_key.pem``.

    Returns:
        The public key object.
    """
    public_part = sk.publickey()
    with open("cobra_public_key.pem", "wb") as pem_file:
        pem_bytes = public_part.exportKey(format='PEM')
        pem_file.write(pem_bytes)
    return public_part


def encrypt(pk_receiver, message):
    """Encrypt the bytes ``message`` for the owner of ``pk_receiver`` (PKCS#1 OAEP).

    Returns:
        The ciphertext bytes.
    """
    # NOTE(review): OAEP can only encrypt messages somewhat shorter than the
    # key modulus; longer input presumably makes encrypt() raise - confirm.
    oaep = PKCS1_OAEP.new(pk_receiver)
    return oaep.encrypt(message)


def decrypt(sk, c):
    """Decrypt the PKCS#1 OAEP ciphertext ``c`` with the private key ``sk``.

    Returns:
        The recovered plaintext bytes.
    """
    return PKCS1_OAEP.new(sk).decrypt(c)


def generate_new_keys():
    """Create a fresh private/public key pair on disk, then show the menu again.

    Writes ``cobra_private_key.pem`` (via generate_sk) and
    ``cobra_public_key.pem`` (via generate_pk). Data hidden with the old key
    pair can no longer be deciphered with the new private key.
    """
    # generate_sk() already returns the key it saved; no need to re-read the PEM.
    private_key = generate_sk()
    generate_pk(private_key)
    print("\033[42m\033[34mNew Keys generated!\033[00m\n")
    main()


def file_to_bitarray(path, public_key):
    """Read the text file ``path``, encrypt it, and return the ciphertext as bits.

    Args:
        path: Text file to hide. It is read as UTF-8, the same encoding used for
            the encrypted payload, so the result does not depend on the
            platform's default encoding.
        public_key: Path of the receiver's public key file.

    Returns:
        A ``bitarray.bitarray`` holding the ciphertext bytes.
    """
    with open(path, 'r', encoding='utf-8') as file:
        message = file.read()
    ciphertext = encrypt(import_key(public_key), message.encode('UTF-8'))
    result = bitarray.bitarray()
    result.frombytes(ciphertext)
    return result


def bitarray_to_file(path, bits, private_key):
    """Decrypt the ciphertext held in ``bits`` and write the text to ``path``.

    Args:
        path: Output text file. It is written as UTF-8 so every decoded
            character can be stored regardless of the platform's default
            encoding.
        bits: ``bitarray.bitarray`` containing the ciphertext.
        private_key: Path of the private key file.
    """
    ciphertext = bits.tobytes()
    text = decrypt(import_key(private_key), ciphertext).decode('UTF-8')
    with open(path, "w", encoding="utf-8") as file:
        file.write(text)


def set_last_bit(value, integer):
    """Return ``integer`` with its least significant bit replaced by ``value``.

    Args:
        value: The bit to store. Any truthy value stores 1 and a falsy one
            stores 0, so both ``0``/``1`` and ``False``/``True`` items work.
        integer: A non-negative integer, such as a colour-channel value.

    Returns:
        The modified integer.
    """
    # Clear bit 0, then set it; avoids a string round trip through bin().
    return (integer & ~1) | (1 if value else 0)


def get_last_bit(integer):
    """Return the least significant bit (0 or 1) of ``integer``."""
    # Mask bit 0 directly instead of formatting the number as a binary string.
    return integer & 1


def hide(in_file, secret, out_file, public_key):
    """Encrypt the text file ``secret`` and hide it in the image ``in_file``.

    The ciphertext bits go into the least significant bit of the red, green
    and blue values of every pixel, row by row. The bit stream is repeated
    until every pixel has been used. The image is saved to ``out_file`` and
    the main menu is shown again. The prompts ask for PNG; a lossy format such
    as JPEG would alter the low bits and lose the message.

    Args:
        in_file: Path of the cover image. Its pixels must have at least three
            channels.
        secret: Path of the text file to hide.
        out_file: Path of the image to write.
        public_key: Path of the receiver's public key file.

    Raises:
        ValueError: If the image has fewer colour channels than the ciphertext
            has bits, so the message would not fit.
    """
    with Image.open(in_file) as im:
        pic = im.load()
        width, height = im.size
        to_hide = file_to_bitarray(secret, public_key)
        length = len(to_hide)
        capacity = 3 * width * height
        if capacity < length:
            # Writing anyway would silently truncate the ciphertext and still
            # report success, leaving an image nothing can be recovered from.
            raise ValueError(
                "image too small: {} bits needed, only {} available".format(length, capacity))
        k = 0
        for y in range(height):
            for x in range(width):
                pixel = pic[x, y]
                rgb = tuple(set_last_bit(to_hide[(k + i) % length], pixel[i])
                            for i in range(3))
                k += 3
                # Copy channels beyond RGB (e.g. alpha) unchanged so they are
                # not replaced when the pixel is written back.
                pic[x, y] = rgb + tuple(pixel[3:])
        im.save(out_file)
    print("\033[42m\033[34mSuccess!\033[00m\n")
    main()


def _hidden_bits(pic, width, height):
    """Yield the LSB of each pixel's R, G and B value as '0'/'1', row by row."""
    for y in range(height):
        for x in range(width):
            pixel = pic[x, y]
            for channel in range(3):
                yield str(get_last_bit(pixel[channel]))


def seek(in_file, out_file, private_key):
    """Extract the ciphertext hidden in ``in_file``, decrypt it, and save the text.

    The hidden bits are read from the least significant bit of the red, green
    and blue values of each pixel, row by row, starting at the first pixel. An
    RSA-OAEP ciphertext is exactly as long as the key modulus, so the number of
    bits to read comes from the private key. It is no longer guessed from the
    repetition pattern. That guess stopped early whenever the ciphertext began
    with a repeated byte, and never matched on images holding fewer than two
    copies. The main menu is shown again afterwards.

    Args:
        in_file: Path of the image containing the hidden data.
        out_file: Path of the text file to write.
        private_key: Path of the private key file.

    Raises:
        ValueError: If the image holds fewer bits than one ciphertext.
    """
    import itertools

    with Image.open(in_file) as im:
        pic = im.load()
        width, height = im.size
        modulus_bytes = (import_key(private_key).n.bit_length() + 7) // 8
        n_bits = 8 * modulus_bytes
        # Only the first ciphertext's worth of bits is needed; stop reading there.
        msg = ''.join(itertools.islice(_hidden_bits(pic, width, height), n_bits))
    if len(msg) < n_bits:
        raise ValueError(
            "image holds only {} hidden bits, {} needed".format(len(msg), n_bits))
    bitarray_to_file(out_file, bitarray.bitarray(msg), private_key)
    print("\033[42m\033[34mSuccess!\033[00m\n")
    main()


def hide_filler():
    """Prompt for the inputs of hide() and run it, reporting any failure.

    Only Exception subclasses are caught, so Ctrl+C and SystemExit still stop
    the program. hide() re-enters main() after a success, so errors raised by
    later menu interaction can also end up here.
    """
    in_file = input("\033[00mIn which png do you want to hide your txt? \033[32m")
    secret = input("\033[00mWhich txt do you want to hide? \033[32m")
    out_file = input("\033[00mHow do you want to name the output png? \033[32m")
    public_key = input("\033[00mPublic Key \033[32m")
    try:
        hide(in_file, secret, out_file, public_key)
    except Exception as err:
        # Tell the user why it failed instead of a bare, cause-less message.
        print("\033[41m\033[37mAn error occured: {}\033[00m\n".format(err))


def seek_filler():
    """Prompt for the inputs of seek() and run it, reporting any failure.

    Only Exception subclasses are caught, so Ctrl+C and SystemExit still stop
    the program. seek() re-enters main() after a success, so errors raised by
    later menu interaction can also end up here.
    """
    in_file = input("\033[00mIn which png do you want to seek for information? \033[32m")
    out_file = input("\033[00mHow do you want to name the output txt? \033[32m")
    private_key = input("\033[00mPrivate Key \033[32m")
    try:
        seek(in_file, out_file, private_key)
    except Exception as err:
        # Tell the user why it failed instead of a bare, cause-less message.
        print("\033[41m\033[37mAn error occured: {}\033[00m\n".format(err))


def generate_pk_filler():
    """Ask for a private key file, write its public key, then show the menu again.

    A missing or unreadable key file is reported on screen instead of aborting
    the program with a traceback.
    """
    path = input("\033[00mPrivate Key \033[32m")
    try:
        generate_pk(import_key(path))
    except Exception as err:  # user-facing boundary: report, don't crash
        print("\033[41m\033[37mAn error occured: {}\033[00m\n".format(err))
        return
    print("\033[42m\033[34mNew Public Key generated!\033[00m\n")
    main()


def generate_chooser():
    """Show the key-generation submenu and run the selected action.

    1 regenerates the public key from an existing private key, 2 replaces both
    keys, and 3 returns to the main menu. Invalid input is asked for again.
    """
    print("(" + "\033[33m{}\033[00m".format("1") + ") Generate new Public Key")
    print("(" + "\033[33m{}\033[00m".format("2") + ") Generate both Keys. WARNING: YOU CANNOT DECIPHER OLD PNGS WITH A NEW PRIVATE KEY!")
    print("(" + "\033[33m{}\033[00m".format("3") + ") Back\n")
    option = None
    while option not in (1, 2, 3):
        try:
            option = int(input("\033[91m{}\033[32m".format("1/2/3 >> ")))
        except ValueError:
            # Non-numeric input: ask again instead of crashing the program.
            option = None
    if option == 1:
        generate_pk_filler()
    elif option == 2:
        generate_new_keys()
    elif option == 3:
        main()


def main():
    """Print the COBRA banner and main menu, then run the selected action.

    Most actions call main() again when they finish, so the menu is re-shown
    by recursion rather than by a loop here. Invalid input is asked for again.

    Returns:
        0 when the user chooses to quit; None otherwise.
    """
    app_name = "COBRA v1.0 by Andre Mertes"
    # Backslashes that belong to the drawing are written as "\\" so the literal
    # has no invalid escape sequences (a SyntaxWarning on Python 3.12+); the
    # printed picture is unchanged.
    cobra_image = "         ,,'6''-,.\n        <====,.;;--.\n        _`---===. \"\"\"==__\n      //\"\"@@-\\===\\@@@@ \"\"\\\\\n     |( @@@  |===|  @@@  ||\n      \\\\ @@   |===|  @@  //\n        \\\\ @@ |===|@@@ //\n         \\\\  |===|  //\n___________\\\\|===| //_____,----\"\"\"\"\"\"\"\"\"\"-----,_\n  \"\"\"\"---,__`\\===`/ _________,---------,____    `,\n             |==||                           `\\   \\\n            |==| |          pb                 )   |\n           |==| |       _____         ______,--'   '\n           |=|  `----\"\"\"     `\"\"\"\"\"\"\"\"         _,-'\n            `=\\     __,---\"\"\"-------------\"\"\"''\n                \"\"\"\"\";"
    print("\033[36m{}\033[00m".format("\n\n--- "+app_name+" ---\n"))
    print("\033[32m{}\033[32m".format(cobra_image))
    print("\n("+"\033[33m{}\033[00m".format("1")+") Hide a txt in a png")
    print("("+"\033[33m{}\033[00m".format("2")+") Seek in a png for information")
    print("("+"\033[33m{}\033[00m".format("3")+") Generate Keys")
    print("(" + "\033[33m{}\033[00m".format("4") + ") Quit\n")
    option = None
    while option not in (1, 2, 3, 4):
        try:
            option = int(input("\033[91m{}\033[32m".format("1/2/3/4 >> ")))
        except ValueError:
            # Non-numeric input: ask again instead of crashing the program.
            option = None
    if option == 1:
        hide_filler()
    elif option == 2:
        seek_filler()
    elif option == 3:
        generate_chooser()
    elif option == 4:
        print("\033[41m\033[37mQuit "+app_name+"\033[00m\n")
        return 0


if __name__ == "__main__":
    # Run the interactive menu only when executed as a script, not on import.
    main()