
// I retain copyright in this code but I encourage its free use provided
// that I don't carry any responsibility for the results. I am especially 
// happy to see it used in free and open source software. If you do use 
// it I would appreciate an acknowledgement of its origin in the code or
// the product that results and I would also appreciate knowing a little
// about the use to which it is being put.
//
// Dr B. R. Gladman <brg@gladman.uk.net> 20th December 2000.

// Modifications for StegFS by Andrew McDonald, January 2001.

// This is an implementation of the AES encryption algorithm (Rijndael)
// designed by Joan Daemen and Vincent Rijmen. This version is designed
// for a fixed block length of 128 bits (Ncol = 4) and can run with either
// big or little endian internal byte order (see aes.h). It implements
// the three key lengths of 128, 192 and 256 bits, the length being set
// by Nk in the key set interface where Nk = 4, 6 or 8 respectively.
// Note that the validity of input parameters is NOT checked in this
// implementation.
//
// To compile AES (Rijndael) in C
//   a. Exclude (remove or comment out) the AES_IN_CPP define in aes.h
//   b. Exclude the AES_DLL define in aes.h
//   c. Rename aes.cpp to aes.c (if necessary) and compile as a C file
//
// To compile AES (Rijndael) in C++
//   a. Include the AES_IN_CPP define in aes.h
//   b. Exclude the AES_DLL define in aes.h
//   c. Rename aes.c to aes.cpp (if necessary) and compile as a C++ file
//
// To compile AES (Rijndael) in C as a Dynamic Link Library
//   a. Exclude the AES_IN_CPP define in aes.h
//   b. Include the AES_DLL define in aes.h
//   c. Rename aes.cpp to aes.c (if necessary) and compile as a C file


#include <linux/slab.h>


// define COMPACT for a compact but slow version of the cipher (the 
// other options below do not apply to the compact version)

//#define COMPACT

// define UNROLL to enable loop unrolling in encrypt/decrypt

#define UNROLL

// define FIXED_TABLES for static tables built within the executable
// image - otherwise tables are built dynamically on first use

#define FIXED_TABLES

// define FOUR_TABLES for high speed but increased memory use (uses four
// tables in place of one to reduce inner loop instruction counts)

#define FOUR_TABLES

// define LAST_ROUND_TABLES for ultimate speed but still more memory use
// (uses tables for the last round)

#define LAST_ROUND_TABLES

// define SMALL_PTABLE for slower finite field multiply (smaller power table)

#define SMALL_PTABLE

#include "aes.h"

#if defined(FIXED_TABLES) || defined(COMPACT)
#include "aes_tab.h"
#endif

// finite field multiply of a and b in GF(2^8) using the log/antilog tables:
// a * b = FFpow[(FFlog[a] + FFlog[b]) mod 255]; zero is special-cased
// because it has no logarithm.
//
// With SMALL_PTABLE FFpow has only 256 entries, so the exponent sum is
// reduced by hand: the caller must declare u and v as 8-bit unsigned
// (byte) scratch variables so that the addition wraps mod 256, and adding
// 1 back when it wrapped (v < u) turns that into a reduction mod 255.
// Without SMALL_PTABLE FFpow is stored twice over (512 entries) so the
// unreduced sum can be used directly as the index.

#ifdef  SMALL_PTABLE
#define FFmul(a, b) ((a) && (b) ? \
        FFpow[(u = FFlog[a], v = FFlog[b] + u, v + (v < u ? 1 : 0))] : 0)
#else
#define FFmul(a, b) ((a) && (b) ? FFpow[FFlog[a] + FFlog[b]] : 0)
#endif

// multiply four bytes in GF(2^8) by 'x' {02} in parallel
//
// Each byte is shifted left one bit (m2 clears the top bit of every byte
// first so nothing carries into the neighbouring byte); bytes whose top
// bit was set are then reduced by XORing in 0x1b. u, which the caller must
// declare as a word, holds those top bits: u - (u >> 7) turns every 0x80
// into 0x7f, and masking with m3 leaves 0x1b in exactly those bytes.
//
// The XOR is kept inside the right-hand side of the comma operator so the
// assignment to u is sequenced before u is read. If the read of u were in
// a separate operand of ^, it would be unsequenced relative to the write
// (undefined behaviour). x is evaluated twice, so it must not have side
// effects.

#define m1  0x80808080
#define m2  0x7f7f7f7f
#define m3  0x1b1b1b1b
#define FFmulX(x) (u = (x) & m1, ((((x) & m2) << 1) ^ ((u - (u >> 7)) & m3)))

// perform column mix operation on four bytes in parallel: the result is
// 2*x ^ rot3(3*x) ^ rot2(x) ^ rot1(x), i.e. the MixColumns circulant
// (2, 3, 1, 1) applied to the column held in x. Needs word scratch
// variables u and f2 in the caller's scope; x is evaluated several times.

#define mix_col(x) (f2 = FFmulX(x), f2 ^ rot3((x) ^ f2) ^ rot2(x) ^ rot1(x))

// perform inverse column mix operation on four bytes in parallel: with
// f2 = 2x, f4 = 4x, f8 = 8x and f9 = 9x this yields the inverse circulant
// (14, 11, 13, 9). Needs word scratch variables u, f2, f4, f8 and f9 in
// the caller's scope.

#define inv_mix_col(x)                                                      \
    (f9 = (x),f2 = FFmulX(f9), f4 = FFmulX(f2), f8 = FFmulX(f4), f9 ^= f8,  \
    f2 ^ f4 ^ f8 ^ rot3(f2 ^ f9) ^ rot2(f4 ^ f9) ^ rot1(f9))

#ifndef COMPACT

#ifndef FIXED_TABLES

#ifdef  SMALL_PTABLE
static byte  FFpow[256];        // powers of generator (0x03) in GF(2^8)
#else
static byte  FFpow[512];        // powers of 0x03 stored twice over, so FFmul
                                // can index with an unreduced sum of two logs
#endif
static byte  FFlog[256];        // log: map element to power of generator
                                // (FFlog[0] is never used: 0 has no log)

static byte  s_box[256];        // the S box
static byte  inv_s_box[256];    // the inverse S box
static word  rcon_tab[28];      // table of round constants (can be reduced
                                // to a length of 10 for 128-bit blocks)

// ft_tab: forward round table, S box combined with the MixColumns column
// (0x02*s, s, s, 0x03*s); it_tab: inverse round table, inverse S box
// combined with (0x0e*p, 0x09*p, 0x0d*p, 0x0b*p). With FOUR_TABLES the
// three byte rotations of each entry are stored as well, trading memory
// for fewer run-time rotations.
#ifdef FOUR_TABLES
static word  ft_tab[4][256];
static word  it_tab[4][256];
#else
static word  ft_tab[256];
static word  it_tab[256];
#endif

// fl_tab/il_tab: S box / inverse S box value in the first bytes2word()
// position (no MixColumns), used for the final round and the key schedule
#ifdef LAST_ROUND_TABLES
#ifdef FOUR_TABLES
    static word  fl_tab[4][256];
    static word  il_tab[4][256];
#else
    static word  fl_tab[256];
    static word  il_tab[256];
#endif
#endif

// NOTE(review): plain flag with no locking around the lazy gen_tabs() call
// in set_key(); confirm the first key setup cannot race with another.
static byte tab_gen = 0;    // non-zero if tables have been generated

// Build every lookup table this file needs when FIXED_TABLES is not
// defined: the GF(2^8) power/log tables, the round constants, the S box and
// its inverse, and the round tables ft_tab/it_tab (plus fl_tab/il_tab with
// LAST_ROUND_TABLES). Sets tab_gen once everything is filled in.
static void gen_tabs(void)
{   word  i, t;
    byte  p, q;
#ifdef  SMALL_PTABLE
    byte  u, v;     // FFmul scratch; must be 8 bits so the log sum wraps
#endif

    // log and power tables for GF(2**8) finite field with
    // 0x011b as modular polynomial - the simplest primitive
    // root is 0x03, used here to generate the tables

    for(i = 0,p = 1; i < 256; ++i)
    {
        FFpow[i] = p;
#ifndef SMALL_PTABLE
        FFpow[i + 255] = p;
#endif
        FFlog[p] = (byte)(i);
        // p *= 0x03: (p << 1) reduced by 0x1b is p * 0x02, XORed with p
        p ^=  (p << 1) ^ (p & 0x80 ? 0x01b : 0);
    }

    // 0x03 has order 255, so the last pass (i = 255, p = 1) overwrote
    // FFlog[1] with 255; restore the canonical logarithm 0
    FFlog[1] = 0;

    // round constants: successive powers of 0x02 in GF(2^8), placed in the
    // first bytes2word() position
    for(i = 0,p = 1; i < 28; ++i)
    {
        rcon_tab[i] = bytes2word(p, 0, 0, 0);
        p = (p << 1) ^ (p & 0x80 ? 0x01b : 0);
    }

    // S box: multiplicative inverse (0 maps to 0) followed by the affine
    // transform p ^ rotl1(p) ^ rotl2(p) ^ rotl3(p) ^ rotl4(p) ^ 0x63, with
    // q = rotl1(p) ^ rotl2(p) and the last term adding rotl2(q). The
    // shift/or pairs act as 8-bit rotations because the results are
    // stored back into the byte variables p and q.
    for(i = 0; i < 256; ++i)
    {
        p = (i ? FFpow[255 - FFlog[i]] : 0);
        q  = ((p >> 7) | (p << 1)) ^ ((p >> 6) | (p << 2));
        p ^= 0x63 ^ q ^ ((q >> 6) | (q << 2));
        s_box[i] = p;
        inv_s_box[p] = (byte)(i);
    }

    // round tables, one entry per input byte i:
    //   fl/il: S box / inverse S box value alone (final round)
    //   ft:    (0x02*s, s, s, 0x03*s) for s = s_box[i]
    //   it:    (0x0e*p, 0x09*p, 0x0d*p, 0x0b*p) for p = inv_s_box[i]
    // with FOUR_TABLES the rot1/rot2/rot3 versions are stored as well
    for(i = 0; i < 256; ++i)
    {
        p = s_box[i];
#ifdef LAST_ROUND_TABLES
        t = bytes2word(p, 0, 0, 0);
#ifdef FOUR_TABLES
        fl_tab[0][i] = t;
        fl_tab[1][i] = rot1(t);
        fl_tab[2][i] = rot2(t);
        fl_tab[3][i] = rot3(t);
#else
        fl_tab[i] = t;
#endif
#endif
        t = bytes2word(FFmul(0x02, p), p, p, FFmul(0x03, p));
#ifdef FOUR_TABLES
        ft_tab[0][i] = t;
        ft_tab[1][i] = rot1(t);
        ft_tab[2][i] = rot2(t);
        ft_tab[3][i] = rot3(t);
#else
        ft_tab[i] = t;
#endif
        p = inv_s_box[i];
#ifdef LAST_ROUND_TABLES
        t = bytes2word(p, 0, 0, 0);
#ifdef FOUR_TABLES
        il_tab[0][i] = t;
        il_tab[1][i] = rot1(t);
        il_tab[2][i] = rot2(t);
        il_tab[3][i] = rot3(t);
#else
        il_tab[i] = t;
#endif
#endif
        t = bytes2word(FFmul(0x0e, p), FFmul(0x09, p),
                            FFmul(0x0d, p), FFmul(0x0b, p));
#ifdef FOUR_TABLES
        it_tab[0][i] = t;
        it_tab[1][i] = rot1(t);
        it_tab[2][i] = rot2(t);
        it_tab[3][i] = rot3(t);
#else
        it_tab[i] = t;
#endif
    }

    // NOTE(review): set by a plain store with no locking; confirm that the
    // first call into set_key() cannot run concurrently with another
    tab_gen = 1;
}

#endif

// Last-round helpers (SubBytes/ShiftRows without MixColumns):
//   ls_box(x)    - S box applied to each byte of word x (key schedule)
//   lf_rnd(x, n) - column n of the forward final round on state array x:
//                  byte k comes from column (n + k) % Ncol
//   li_rnd(x, n) - column n of the inverse final round: bytes 1..3 come
//                  from columns (n + 3), (n + 2), (n + 1) % Ncol
// Three variants trade memory for speed: four pre-rotated tables, one
// table rotated at run time, or plain S box lookups packed with bytes2word.
#ifdef LAST_ROUND_TABLES
#ifdef FOUR_TABLES

#define ls_box(x)       \
 (  fl_tab[0][byte0(x)] \
  ^ fl_tab[1][byte1(x)] \
  ^ fl_tab[2][byte2(x)] \
  ^ fl_tab[3][byte3(x)] )

#define lf_rnd(x, n)                    \
  ( fl_tab[0][byte0(x[n])]              \
  ^ fl_tab[1][byte1(x[(n + 1) % Ncol])] \
  ^ fl_tab[2][byte2(x[(n + 2) % Ncol])] \
  ^ fl_tab[3][byte3(x[(n + 3) % Ncol])] )

#define li_rnd(x, n)                    \
  ( il_tab[0][byte0(x[n])]              \
  ^ il_tab[1][byte1(x[(n + 3) % Ncol])] \
  ^ il_tab[2][byte2(x[(n + 2) % Ncol])] \
  ^ il_tab[3][byte3(x[(n + 1) % Ncol])] )

#else

#define ls_box(x)           \
 (  fl_tab[byte0(x)]        \
  ^ rot1(fl_tab[byte1(x)])  \
  ^ rot2(fl_tab[byte2(x)])  \
  ^ rot3(fl_tab[byte3(x)]) )

#define lf_rnd(x, n)                        \
  ( fl_tab[byte0(x[n])]                     \
  ^ rot1(fl_tab[byte1(x[(n + 1) % Ncol])])  \
  ^ rot2(fl_tab[byte2(x[(n + 2) % Ncol])])  \
  ^ rot3(fl_tab[byte3(x[(n + 3) % Ncol])]) )

#define li_rnd(x, n)                        \
  ( il_tab[byte0(x[n])]                     \
  ^ rot1(il_tab[byte1(x[(n + 3) % Ncol])])  \
  ^ rot2(il_tab[byte2(x[(n + 2) % Ncol])])  \
  ^ rot3(il_tab[byte3(x[(n + 1) % Ncol])]) )

#endif
#else

#define ls_box(x) bytes2word(           \
    s_box[byte0(x)], s_box[byte1(x)],   \
    s_box[byte2(x)], s_box[byte3(x)])

#define lf_rnd(x, n)    bytes2word(     \
    s_box[byte0(x[n])],                 \
    s_box[byte1(x[(n + 1) % Ncol])],    \
    s_box[byte2(x[(n + 2) % Ncol])],    \
    s_box[byte3(x[(n + 3) % Ncol])])

#define li_rnd(x, n)    bytes2word(     \
    inv_s_box[byte0(x[n])],             \
    inv_s_box[byte1(x[(n + 3) % Ncol])],\
    inv_s_box[byte2(x[(n + 2) % Ncol])],\
    inv_s_box[byte3(x[(n + 1) % Ncol])])

#endif

// initialise the key schedule from the user supplied key, where Nk
// is the key length (bits) divided by 32 with a value of 4, 6 or 8

// Table-driven build. Records f and Nk in cx, loads the first Nk key words
// with word_in() and expands them into f_dat(cx,e_key).
//
// Expansion rules:
//   - Each group of Nk new words starts with ls_box() (S box on every byte)
//     of rot3() of the previous word, XORed with the next round constant
//     and the word Nk places earlier.
//   - The remaining words are the word Nk places earlier XORed with the
//     word just before. For Nk = 8 the middle word of each group also
//     passes through ls_box().
//
// Unless the mode is enc, f_dat(cx,d_key) is also built for decrypt(): the
// first and last four words are copied unchanged and every word in between
// is passed through inv_mix_col().
//
// NOTE(review): the expansion loops always finish a whole group of Nk
// words, so they write 44 words for Nk = 4, 54 for Nk = 6 and 64 for
// Nk = 8, although only 4 * Nk + 28 are used. e_key must have room for at
// least 64 words - confirm against aes.h.
// An Nk other than 4, 6 or 8 matches no case and leaves the schedule
// beyond word 3 unset.
static void set_key(const byte in_key[], const word Nk, const enum aes_key f, aes *cx)
{   word    i, t, u, f2, f4, f8, f9, *k1, *rcp;     // u, f2..f9: inv_mix_col scratch

#ifndef FIXED_TABLES

    // build the lookup tables on first use
    if(!tab_gen)

        gen_tabs();
#endif

    f_dat(cx,mode) = f;               // encryption mode = enc, dec or both
    f_dat(cx,Nkey) = Nk;              // only 4, 6 or 8 valid (not checked)

    f_dat(cx,e_key)[0] = word_in(in_key     );
    f_dat(cx,e_key)[1] = word_in(in_key +  4);
    f_dat(cx,e_key)[2] = word_in(in_key +  8);
    f_dat(cx,e_key)[3] = word_in(in_key + 12);

    k1 = f_dat(cx,e_key); rcp = rcon_tab;

    switch(f_dat(cx,Nkey))
    {
    // 128-bit key: 10 groups of 4 words, 44 words in total
    case 4: while(k1 < f_dat(cx,e_key) + 40)
            {   t = rot3(k1[3]);
                k1[4] = k1[0] ^ ls_box(t) ^ *rcp++;
                k1[5] = k1[1] ^ k1[4];
                k1[6] = k1[2] ^ k1[5];
                k1[7] = k1[3] ^ k1[6];
                k1 += 4;
            }
            break;

    // 192-bit key: 8 groups of 6 words (54 written, 52 used)
    case 6: f_dat(cx,e_key)[4] = word_in(in_key + 16);
            f_dat(cx,e_key)[5] = word_in(in_key + 20);
            while(k1 < f_dat(cx,e_key) + 48)
            {   t = rot3(k1[5]);
                k1[ 6] = k1[0] ^ ls_box(t) ^ *rcp++;
                k1[ 7] = k1[1] ^ k1[ 6];
                k1[ 8] = k1[2] ^ k1[ 7];
                k1[ 9] = k1[3] ^ k1[ 8];
                k1[10] = k1[4] ^ k1[ 9];
                k1[11] = k1[5] ^ k1[10];
                k1 += 6;
            }
            break;

    // 256-bit key: 7 groups of 8 words (64 written, 60 used)
    case 8: f_dat(cx,e_key)[4] = word_in(in_key + 16);
            f_dat(cx,e_key)[5] = word_in(in_key + 20);
            f_dat(cx,e_key)[6] = word_in(in_key + 24);
            f_dat(cx,e_key)[7] = word_in(in_key + 28);
            while(k1 < f_dat(cx,e_key) + 56)
            {   t = rot3(k1[7]);
                k1[ 8] = k1[0] ^ ls_box(t) ^ *rcp++;
                k1[ 9] = k1[1] ^ k1[ 8];
                k1[10] = k1[2] ^ k1[ 9];
                k1[11] = k1[3] ^ k1[10];
                // extra S box step (no rotation, no constant) mid-group
                k1[12] = k1[4] ^ ls_box(k1[11]);
                k1[13] = k1[5] ^ k1[12];
                k1[14] = k1[6] ^ k1[13];
                k1[15] = k1[7] ^ k1[14];
                k1 += 8;
            }
            break;
    }

    // decryption schedule: round keys 1 .. Nr-1 (words 4 .. 4*Nk+23) go
    // through inv_mix_col so decrypt() can add them after its table lookups;
    // the first and last round keys are used as-is
    if(f_dat(cx,mode) != enc)
    {
        f_dat(cx,d_key)[0] = f_dat(cx,e_key)[0];
        f_dat(cx,d_key)[1] = f_dat(cx,e_key)[1];
        f_dat(cx,d_key)[2] = f_dat(cx,e_key)[2];
        f_dat(cx,d_key)[3] = f_dat(cx,e_key)[3];

        for(i = 4; i < 4 * f_dat(cx,Nkey) + 24; ++i)

            f_dat(cx,d_key)[i] = inv_mix_col(f_dat(cx,e_key)[i]);

        f_dat(cx,d_key)[4 * f_dat(cx,Nkey) + 24] = f_dat(cx,e_key)[4 * f_dat(cx,Nkey) + 24];
        f_dat(cx,d_key)[4 * f_dat(cx,Nkey) + 25] = f_dat(cx,e_key)[4 * f_dat(cx,Nkey) + 25];
        f_dat(cx,d_key)[4 * f_dat(cx,Nkey) + 26] = f_dat(cx,e_key)[4 * f_dat(cx,Nkey) + 26];
        f_dat(cx,d_key)[4 * f_dat(cx,Nkey) + 27] = f_dat(cx,e_key)[4 * f_dat(cx,Nkey) + 27];
    }

    return;
}

// encrypt a block of text

// f_rnd(x, n): column n of one full encryption round on the state array x.
// SubBytes and MixColumns are folded into the ft_tab lookups, and ShiftRows
// is the choice of source column: byte k comes from column (n + k) % Ncol.
// FOUR_TABLES uses four pre-rotated tables; otherwise one table is rotated
// at run time.

#ifdef FOUR_TABLES

#define f_rnd(x, n)                         \
  ( ft_tab[0][byte0((x)[(n)])]              \
  ^ ft_tab[1][byte1((x)[((n) + 1) % Ncol])] \
  ^ ft_tab[2][byte2((x)[((n) + 2) % Ncol])] \
  ^ ft_tab[3][byte3((x)[((n) + 3) % Ncol])] )

#else

#define f_rnd(x, n)                             \
  ( ft_tab[byte0((x)[(n)])]                     \
  ^ rot1(ft_tab[byte1((x)[((n) + 1) % Ncol])])  \
  ^ rot2(ft_tab[byte2((x)[((n) + 2) % Ncol])])  \
  ^ rot3(ft_tab[byte3((x)[((n) + 3) % Ncol])]) )

#endif

// f_round(bo, bi, k): full round on all four columns of bi, XORed with the
// four round-key words at k, written to bo. bo must not alias bi, because
// each output column reads several input columns.
// Wrapped in do { } while (0) so the macro behaves as a single statement,
// e.g. as the body of an unbraced if/else.

#define f_round(bo, bi, k)                  \
    do {                                    \
        (bo)[0] = f_rnd(bi, 0) ^ (k)[0];    \
        (bo)[1] = f_rnd(bi, 1) ^ (k)[1];    \
        (bo)[2] = f_rnd(bi, 2) ^ (k)[2];    \
        (bo)[3] = f_rnd(bi, 3) ^ (k)[3];    \
    } while (0)

#ifdef UNROLL

// Table-driven, fully unrolled encryption of one 16-byte block.
// After the initial AddRoundKey:
//   - keys with Nkey > 6 (256-bit) run two extra rounds;
//   - keys with Nkey > 4 (192/256-bit) run two more;
//   - nine further full rounds follow, giving Nkey + 5 full rounds in all;
//   - the final round uses lf_rnd(), which has no MixColumns.
// The state ping-pongs between b0 and b1, so no copies are needed.
static void encrypt(const byte in_blk[16], byte out_blk[16], const aes *cx)
{   word        b0[4], b1[4];
    const word  *kp = f_dat(cx,e_key);

    // initial AddRoundKey
    b0[0] = word_in(in_blk     ) ^ kp[0];
    b0[1] = word_in(in_blk +  4) ^ kp[1];
    b0[2] = word_in(in_blk +  8) ^ kp[2];
    b0[3] = word_in(in_blk + 12) ^ kp[3]; kp += 4;

    // two extra rounds for 256-bit keys
    if(f_dat(cx,Nkey) > 6)
    {
        f_round(b1, b0, kp);
        f_round(b0, b1, kp + 4); kp += 8;
    }

    // two extra rounds for 192- and 256-bit keys
    if(f_dat(cx,Nkey) > 4)
    {
        f_round(b1, b0, kp);
        f_round(b0, b1, kp + 4); kp += 8;
    }

    // the nine full rounds common to every key size
    f_round(b1, b0, kp);      f_round(b0, b1, kp +  4);
    f_round(b1, b0, kp +  8); f_round(b0, b1, kp + 12);
    f_round(b1, b0, kp + 16); f_round(b0, b1, kp + 20);
    f_round(b1, b0, kp + 24); f_round(b0, b1, kp + 28);
    f_round(b1, b0, kp + 32); kp += 36;

    // final round: no MixColumns
    word_out(out_blk,      lf_rnd(b1, 0) ^ kp[0]);
    word_out(out_blk +  4, lf_rnd(b1, 1) ^ kp[1]);
    word_out(out_blk +  8, lf_rnd(b1, 2) ^ kp[2]);
    word_out(out_blk + 12, lf_rnd(b1, 3) ^ kp[3]);
}

#else

// Rolled (non-UNROLL) encryption of one 16-byte block. The rounds run as:
//   - key whitening;
//   - Nkey / 2 + 2 pairs of full rounds;
//   - one more full round;
//   - the final round without MixColumns (lf_rnd).
// That is Nkey + 6 rounds in total.
static void encrypt(const byte in_blk[16], byte out_blk[16], const aes *cx)
{   const word  *kp = f_dat(cx,e_key);
    word        col, pairs, b0[4], b1[4];

    for(col = 0; col < 4; ++col)
    {
        b0[col] = word_in(in_blk + 4 * col) ^ kp[col];
    }
    kp += 4;

    for(pairs = 2 + (f_dat(cx,Nkey) >> 1); pairs > 0; --pairs)
    {
        f_round(b1, b0, kp);        // b0 -> b1
        f_round(b0, b1, kp + 4);    // b1 -> b0
        kp += 8;
    }

    f_round(b1, b0, kp);
    kp += 4;

    for(col = 0; col < 4; ++col)
    {
        word_out(out_blk + 4 * col, lf_rnd(b1, col) ^ kp[col]);
    }
}

#endif

// decrypt a block of text

// i_rnd(x, n): column n of one full decryption round on the state array x.
// InvSubBytes and InvMixColumns are folded into the it_tab lookups, and
// InvShiftRows is the choice of source column: bytes 1..3 come from
// columns (n + 3), (n + 2) and (n + 1) % Ncol. FOUR_TABLES uses four
// pre-rotated tables; otherwise one table is rotated at run time.

#ifdef FOUR_TABLES

#define i_rnd(x, n)                         \
  ( it_tab[0][byte0((x)[(n)])]              \
  ^ it_tab[1][byte1((x)[((n) + 3) % Ncol])] \
  ^ it_tab[2][byte2((x)[((n) + 2) % Ncol])] \
  ^ it_tab[3][byte3((x)[((n) + 1) % Ncol])] )

#else

#define i_rnd(x, n)                             \
  ( it_tab[byte0((x)[(n)])]                     \
  ^ rot1(it_tab[byte1((x)[((n) + 3) % Ncol])])  \
  ^ rot2(it_tab[byte2((x)[((n) + 2) % Ncol])])  \
  ^ rot3(it_tab[byte3((x)[((n) + 1) % Ncol])]) )

#endif

// i_round(bo, bi, k): full inverse round on all four columns of bi (columns
// 3 down to 0), XORed with the four round-key words at k, written to bo.
// bo must not alias bi. Wrapped in do { } while (0) so the macro behaves as
// a single statement, e.g. as the body of an unbraced if/else.

#define i_round(bo, bi, k)                  \
    do {                                    \
        (bo)[3] = i_rnd(bi, 3) ^ (k)[3];    \
        (bo)[2] = i_rnd(bi, 2) ^ (k)[2];    \
        (bo)[1] = i_rnd(bi, 1) ^ (k)[1];    \
        (bo)[0] = i_rnd(bi, 0) ^ (k)[0];    \
    } while (0)

#ifdef  UNROLL

// Table-driven, fully unrolled decryption of one 16-byte block, mirroring
// encrypt(). It uses the inverse-mixed schedule in d_key: starting from the
// last round key at word 4 * (Nkey + 6), kp walks backwards through the
// schedule.
// NOTE: set_key() only fills d_key when the mode is not enc, so a context
// set up for enc alone must not be passed here.
static void decrypt(const byte in_blk[16], byte out_blk[16], const aes *cx)
{   word        b0[4], b1[4];
    const word  *kp = f_dat(cx,d_key) + 4 * (f_dat(cx,Nkey) + 6);

    // AddRoundKey with the last round key
    b0[3] = word_in(in_blk + 12) ^ kp[3];
    b0[2] = word_in(in_blk +  8) ^ kp[2];
    b0[1] = word_in(in_blk +  4) ^ kp[1];
    b0[0] = word_in(in_blk     ) ^ kp[0]; kp -= 4;

    // two extra rounds for 256-bit keys
    if(f_dat(cx,Nkey) > 6)
    {
        i_round(b1, b0, kp);
        i_round(b0, b1, kp - 4); kp -= 8;
    }

    // two extra rounds for 192- and 256-bit keys
    if(f_dat(cx,Nkey) > 4)
    {
        i_round(b1, b0, kp);
        i_round(b0, b1, kp - 4); kp -= 8;
    }

    // the nine full inverse rounds common to every key size
    i_round(b1, b0, kp);      i_round(b0, b1, kp -  4);
    i_round(b1, b0, kp -  8); i_round(b0, b1, kp - 12);
    i_round(b1, b0, kp - 16); i_round(b0, b1, kp - 20);
    i_round(b1, b0, kp - 24); i_round(b0, b1, kp - 28);
    i_round(b1, b0, kp - 32); kp -= 36;

    // final inverse round (no InvMixColumns) with the first round key
    word_out(out_blk + 12, li_rnd(b1, 3) ^ kp[3]);
    word_out(out_blk +  8, li_rnd(b1, 2) ^ kp[2]);
    word_out(out_blk +  4, li_rnd(b1, 1) ^ kp[1]);
    word_out(out_blk,      li_rnd(b1, 0) ^ kp[0]);
}

#else

// Rolled (non-UNROLL) decryption of one 16-byte block. kp starts at the
// last round key in d_key and steps backwards four words per round. The
// full inverse rounds are run Nkey / 2 + 2 pairs at a time, plus one more,
// and finish with the final inverse round (li_rnd) using the first key.
static void decrypt(const byte in_blk[16], byte out_blk[16], const aes *cx)
{   const word  *kp = f_dat(cx,d_key) + 4 * (f_dat(cx,Nkey) + 6);
    word        col, pairs, b0[4], b1[4];

    for(col = 0; col < 4; ++col)
    {
        b0[col] = word_in(in_blk + 4 * col) ^ kp[col];
    }
    kp -= 4;

    for(pairs = 2 + (f_dat(cx,Nkey) >> 1); pairs > 0; --pairs)
    {
        i_round(b1, b0, kp);        // b0 -> b1
        i_round(b0, b1, kp - 4);    // b1 -> b0
        kp -= 8;
    }

    i_round(b1, b0, kp);
    kp -= 4;

    for(col = 0; col < 4; ++col)
    {
        word_out(out_blk + 4 * col, li_rnd(b1, col) ^ kp[col]);
    }
}

#endif

#else

// Compact-build helpers: plain S box lookups combined with bytes2word().
//   ls_box(x)      - S box applied to each byte of word x (key schedule)
//   sbx_row(i)     - column i of SubBytes + ShiftRows: byte k comes from
//                    column (i + k) % Ncol
//   inv_sbx_row(i) - column i of InvSubBytes + InvShiftRows: bytes 1..3
//                    come from columns (i + 3), (i + 2), (i + 1) % Ncol
// sbx_row and inv_sbx_row read the caller's local state array b0 directly.
#define ls_box(x) bytes2word(           \
    s_box[byte0(x)], s_box[byte1(x)],   \
    s_box[byte2(x)], s_box[byte3(x)])

#define sbx_row(i) bytes2word(          \
    s_box[byte0(b0[i])],                \
    s_box[byte1(b0[(i + 1) % Ncol])],   \
    s_box[byte2(b0[(i + 2) % Ncol])],   \
    s_box[byte3(b0[(i + 3) % Ncol])])

#define inv_sbx_row(i)  bytes2word(         \
    inv_s_box[byte0(b0[i])],                \
    inv_s_box[byte1(b0[(i + 3) % Ncol])],   \
    inv_s_box[byte2(b0[(i + 2) % Ncol])],   \
    inv_s_box[byte3(b0[(i + 1) % Ncol])])

// initialise the key schedule from the user supplied key, where Nk
// is the key length (bits) divided by 32 with a value of 4, 6 or 8

// Compact build. Records f and Nk in cx, loads the first Nk key words with
// word_in() and expands them into f_dat(cx,e_key).
//
// Expansion rules:
//   - Each group of Nk new words starts with ls_box() of rot3() of the
//     previous word, XORed with the next round constant and the word Nk
//     places earlier.
//   - The other words chain by XOR. For Nk = 8 the middle word of each
//     group also passes through ls_box().
//
// No d_key is built here: the compact decrypt() reads e_key directly and
// applies inv_mix_col() to the state instead. f is stored but does not
// change what is computed.
//
// NOTE(review): the loops write 44, 54 or 64 words for Nk = 4, 6, 8
// (4 * Nk + 28 are used), so e_key must hold at least 64 words - confirm
// against aes.h. Any other Nk leaves the schedule beyond word 3 unset.
static void set_key(const byte in_key[], const word Nk, const enum aes_key f, aes *cx)
{   word    *k1, *rcp, t;

    f_dat(cx,mode) = f;               // encryption mode = enc, dec or both
    f_dat(cx,Nkey) = Nk;              // only 4, 6 or 8 valid (not checked)

    f_dat(cx,e_key)[0] = word_in(in_key     );
    f_dat(cx,e_key)[1] = word_in(in_key +  4);
    f_dat(cx,e_key)[2] = word_in(in_key +  8);
    f_dat(cx,e_key)[3] = word_in(in_key + 12);

    k1 = f_dat(cx,e_key); rcp = rcon_tab;

    switch(f_dat(cx,Nkey))
    {
    // 128-bit key: 10 groups of 4 words
    case 4: while(k1 < f_dat(cx,e_key) + 40)
            {   t = rot3(k1[3]);
                k1[4] = k1[0] ^ ls_box(t) ^ *rcp++;
                k1[5] = k1[1] ^ k1[4];
                k1[6] = k1[2] ^ k1[5];
                k1[7] = k1[3] ^ k1[6];
                k1 += 4;
            }
            break;

    // 192-bit key: 8 groups of 6 words
    case 6: f_dat(cx,e_key)[4] = word_in(in_key + 16);
            f_dat(cx,e_key)[5] = word_in(in_key + 20);
            while(k1 < f_dat(cx,e_key) + 48)
            {   t = rot3(k1[5]);
                k1[ 6] = k1[0] ^ ls_box(t) ^ *rcp++;
                k1[ 7] = k1[1] ^ k1[ 6];
                k1[ 8] = k1[2] ^ k1[ 7];
                k1[ 9] = k1[3] ^ k1[ 8];
                k1[10] = k1[4] ^ k1[ 9];
                k1[11] = k1[5] ^ k1[10];
                k1 += 6;
            }
            break;

    // 256-bit key: 7 groups of 8 words
    case 8: f_dat(cx,e_key)[4] = word_in(in_key + 16);
            f_dat(cx,e_key)[5] = word_in(in_key + 20);
            f_dat(cx,e_key)[6] = word_in(in_key + 24);
            f_dat(cx,e_key)[7] = word_in(in_key + 28);
            while(k1 < f_dat(cx,e_key) + 56)
            {   t = rot3(k1[7]);
                k1[ 8] = k1[0] ^ ls_box(t) ^ *rcp++;
                k1[ 9] = k1[1] ^ k1[ 8];
                k1[10] = k1[2] ^ k1[ 9];
                k1[11] = k1[3] ^ k1[10];
                // extra S box step (no rotation, no constant) mid-group
                k1[12] = k1[4] ^ ls_box(k1[11]);
                k1[13] = k1[5] ^ k1[12];
                k1[14] = k1[6] ^ k1[13];
                k1[15] = k1[7] ^ k1[14];
                k1 += 8;
            }
            break;
    }

    return;
}

// Compact encryption of one 16-byte block.
// After key whitening, each of the Nkey + 5 full rounds does two steps:
//   - sbx_row() (SubBytes + ShiftRows) into b1;
//   - mix_col() plus AddRoundKey back into b0.
// The final round skips MixColumns. mix_col() needs the scratch words u and
// f2 declared here.
static void encrypt(const byte in_blk[16], byte out_blk[16], const aes *cx)
{   const word  *kp = f_dat(cx,e_key);
    word        col, rnd, u, f2, b0[4], b1[4];

    for(col = 0; col < 4; ++col)
    {
        b0[col] = word_in(in_blk + 4 * col) ^ kp[col];
    }
    kp += 4;

    for(rnd = f_dat(cx,Nkey) + 5; rnd > 0; --rnd)
    {
        // all of b1 must be built before b0 is overwritten
        for(col = 0; col < 4; ++col)
        {
            b1[col] = sbx_row(col);
        }
        for(col = 0; col < 4; ++col)
        {
            b0[col] = mix_col(b1[col]) ^ kp[col];
        }
        kp += 4;
    }

    for(col = 0; col < 4; ++col)
    {
        word_out(out_blk + 4 * col, sbx_row(col) ^ kp[col]);
    }
}

// decrypt a block of text

// Compact decryption of one 16-byte block. It works straight from the
// encryption schedule e_key: kp starts at the last round key and moves back
// four words per round.
// Each of the Nkey + 5 full inverse rounds does two steps:
//   - inv_sbx_row() plus AddRoundKey into b1;
//   - inv_mix_col() back into b0.
// inv_mix_col() needs the scratch words u, f2, f4, f8 and f9.
static void decrypt(const byte in_blk[16], byte out_blk[16], const aes *cx)
{   const word  *kp = f_dat(cx,e_key) + 4 * (f_dat(cx,Nkey) + 6);
    word        col, rnd, u, f2, f4, f8, f9, b0[4], b1[4];

    for(col = 0; col < 4; ++col)
    {
        b0[col] = word_in(in_blk + 4 * col) ^ kp[col];
    }

    for(rnd = f_dat(cx,Nkey) + 5; rnd > 0; --rnd)
    {
        kp -= 4;
        // all of b1 must be built before b0 is overwritten
        for(col = 0; col < 4; ++col)
        {
            b1[col] = inv_sbx_row(col) ^ kp[col];
        }
        for(col = 0; col < 4; ++col)
        {
            b0[col] = inv_mix_col(b1[col]);
        }
    }

    kp -= 4;
    for(col = 0; col < 4; ++col)
    {
        word_out(out_blk + 4 * col, inv_sbx_row(col) ^ kp[col]);
    }
}

#endif


/* Stegfs AES stubs */
/* Interfaces to Brian Gladman's AES implementation to avoid
   rewriting the existing AES finalists included in StegFS */

/* Copyright (c)2001 Andrew McDonald (andrew@mcdonald.org.uk) */
/* $Id$ */

char *stegfs_aes_set_key(const char *cin_key, const unsigned int key_len)
{
	unsigned int key_len2;
	aes *keyinfo;

	key_len2 = key_len / 32;
	keyinfo = (aes *)kmalloc(sizeof(aes), GFP_KERNEL);

	set_key(cin_key, key_len2, both, keyinfo);

	return (char *)keyinfo;
}


void stegfs_aes_encrypt(const char *cl_key, const char *cin_blk, char *cout_blk)
{

	encrypt(cin_blk, cout_blk, (aes *)cl_key);

}


void stegfs_aes_decrypt(const char *cl_key, const char *cin_blk, char *cout_blk)
{

	decrypt(cin_blk, cout_blk, (aes *)cl_key);

}
