//----------------------------------------------------------------------------------------------------------------------------------------------
// to compile on command line:
//
// cc -O3 -o sha256mt sha256mt.c
//
//----------------------------------------------------------------------------------------------------------------------------------------------
// Nathan Mariels
// input :  ivec <8 words> block <16 words>
// output:  sha256 hash of input block
// note 1:  output is identical to SHA256, but computation is done by a mathematically different method
// note 2:  sha256 block is <N bits - data> <1 bit - end bit = 1> <(447 - N) bits zero padding> <64 bit - length in bits>
// note 3:  for a single block containing an ascii string of less than 4 bytes, the characters are stored in reverse order
//----------------------------------------------------------------------------------------------------------------------------------------------

//----------------------------------------------------------------------------------------------------------------------------------------------
// include

#include <stdio.h>

//----------------------------------------------------------------------------------------------------------------------------------------------
// typedef

typedef    unsigned int    uint32;

// SHA-256 works on words of exactly 32 bits: rightrotate() shifts by (32 - n) and every
// addition in main relies on wrap-around modulo 2^32.  C only guarantees that unsigned int
// has *at least* 16 bits, so refuse to compile (negative array size) when uint32 is any other
// width.  (uint32)-1 is the all-ones value of the type, which equals 0xFFFFFFFF only at 32 bits.
typedef    char            mt_uint32_width_check[((uint32) -1 == 0xFFFFFFFFu) ? 1 : -1];

//----------------------------------------------------------------------------------------------------------------------------------------------
// macros

// loop(X,Y): run the following statement with X = 0, 1, ..., (Y)-1.
#define    loop(X,Y)           for(X=0;X<(Y);X++)

// c_char_to_hex(X): numeric value (0-15) of the hex digit X ('0'-'9', 'a'-'f' or 'A'-'F').
// Any other character silently yields 0; no error is reported.  X is evaluated more than
// once, so it must not have side effects.
#define    c_char_to_hex(X)    ((((X)>='0')&&((X)<='9'))?((X)-'0'):(((((X)|32)>='a')&&(((X)|32)<='f'))?(((X)|32)-'a'+10):0))

// rightrotate(X,Y): rotate the 32-bit unsigned word X right by Y bits.
// Both shift counts are masked to 0..31.  The unmasked form shifted left by 32 when Y == 0,
// which is undefined behaviour in C; results for Y in 1..31 are unchanged.
#define    rightrotate(X,Y)    (((X)>>((Y)&31))|((X)<<((32-(Y))&31)))

// rightshift(X,Y): right shift by Y bits (a logical shift for the unsigned uint32 words used here).
#define    rightshift(X,Y)     (((X)>>(Y)))

// SHA-256 message-schedule functions sigma0 / sigma1 (lower-case sigma in FIPS 180-4).
#define    sha256_ws0(X)       (rightrotate((X), 7)^rightrotate((X),18)^rightshift((X),3))
#define    sha256_ws1(X)       (rightrotate((X),17)^rightrotate((X),19)^rightshift((X),10))

// SHA-256 compression functions Sigma0 / Sigma1 (upper-case Sigma in FIPS 180-4).
#define    sha256_s0(X)        (rightrotate((X), 2)^rightrotate((X),13)^rightrotate((X),22))
#define    sha256_s1(X)        (rightrotate((X), 6)^rightrotate((X),11)^rightrotate((X),25))

//----------------------------------------------------------------------------------------------------------------------------------------------
// const

// Length of the round-constant table below and of the a[]/b[] arrays in mt_sha256_state.
// main's 64 rounds index sha256_k only at 0..63; the highest state index main touches is b[83]
// (b[i+20] with i = 63), so 96 leaves some headroom.
#define         k_hash_array_size      96

// SHA-256 round constants K[0..63]; round i of main adds sha256_k[i].
// The trailing 32 zeros pad the table to k_hash_array_size entries -- presumably so that
// offsets past 63 (e.g. sha256_k[i+16] in the derivation note after main) read 0 rather
// than running off the end; confirm.
const uint32    sha256_k[k_hash_array_size] = {
    0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5,
    0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174,
    0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da,
    0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967,
    0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85,
    0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070,
    0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3,
    0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2,
    // entries 64..95: zero padding up to k_hash_array_size
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };

//----------------------------------------------------------------------------------------------------------------------------------------------
// structs

// Working storage for one compression.  Instead of shifting eight working variables each
// round, every round appends a fresh word, so the whole history stays in the arrays:
//   a[]  : a[0..3] hold IV words 3,2,1,0; round i reads a[i..i+3] and writes the new
//          "a" word to a[i+4].
//   b[]  : b[0..3] hold IV words 7,6,5,4 and b[4..19] the 16 block words; round i
//          overwrites b[i+4] with the new "e" word, while b[20..83] hold the expanded
//          message schedule as it is being accumulated.
//   out  : the eight output hash words printed by main.
struct mt_sha256_state {
    uint32        a[k_hash_array_size];
    uint32        b[k_hash_array_size];
    uint32        out[8];
};

typedef struct mt_sha256_state    mt_sha256_state;

//----------------------------------------------------------------------------------------------------------------------------------------------
// prototypes

// Helpers used only inside this file are given internal linkage; their definitions below
// inherit it from these earlier static declarations.

// Parse up to 8 hex digits of ptr into a 32-bit word (non-hex characters count as 0).
static uint32 mt_cmd_line_get_hex (const char *ptr);
// Print the optional label str, then buf[start..end] as %08X words (descending when start >= end).
static void printh (int start, int end, char *str, uint32 *buf);
int main (int argc, const char * argv[]);

//----------------------------------------------------------------------------------------------------------------------------------------------
// code

// Entry point.  Expects 24 hex words on the command line: the 8-word IV followed by one
// 16-word, already padded block.  It prints the input, runs the 64 SHA-256 rounds using the
// append-only a[]/b[] arrays of mt_sha256_state, adds the IV back and prints the 8-word hash.
// With any other argument count it prints usage: exit status 0 when run with no arguments
// (a help request), 1 otherwise, so a mistyped invocation is not reported as success.
int main (int argc, const char * argv[])
{
    mt_sha256_state  *hs, sha_state;
    int              i;
    uint32           x;
    
    hs = &sha_state;
    
    if (argc != 25) {
        printf ("\nsha256mt - computes the SHA256 hash of a single block of data in raw format\n\n");
        printf ("input : ivec <8 words> block <16 words>, data is 32 bit hex words, ");
        printf ("block : <N bits - data> <1 bit - end bit = 1> <(447 - N) bits zero padding> <64 bit - length in bits>\n\n");
        
        printf ("example 1: ./sha256mt 6A09E667 BB67AE85 3C6EF372 A54FF53A 510E527F 9B05688C 1F83D9AB 5BE0CD19 61626380 0 0 0 0 0 0 0 0 0 0 0 0 0 0 18\n");
        printf ("This will compute the SHA256 hash of the string \"abc\" using the default ivec.  It should be: BA7816BF .... F20015AD\n\n");
        
        printf ("example 2: ./sha256mt A51483D3 C96D8DDE 85FF79C1 025A5327 38005DA7 CEFAF9CE 60451A63 FA06237F 4D 6F 75 73 65 54 72 61 70 00 00 00 00 80 00 01B8\n");
        printf ("This will compute the SHA256 hash of a partial free start self collision.  It should be: A51483D3 .... F5E622FE\n\n");
        
        printf ("example 3: ./sha256mt 2EC557A2 0B6E2499 0CF13E72 2CDD2309 CD4AB124 B54D3298 9FBAAA26 595767F4   4E617468 616E4D61 7269656C 73800000 00 00 00 00 00 00 00 00 00 00 00 68\n");
        printf ("This will compute the SHA256 hash of a free start self collision.  It should be: 2EC557A2 .... 595767F4\n\n");
        
        // No arguments at all is a help request; any other wrong count is a usage error.
        return (argc == 1) ? 0 : 1;
    }
    
    // argv[1..4]  -> a[3..0]   (IV words 0..3)
    // argv[5..8]  -> b[3..0]   (IV words 4..7)
    // argv[9..24] -> b[4..19]  (the 16 block words)
    loop (i, 24) {
        x = mt_cmd_line_get_hex (argv[i+1]);
        
        if      (i < 4) hs->a[3-i] = x;
        else if (i < 8) hs->b[7-i] = x;
        else            hs->b[i-4] = x;
    }
    
    printh (3, 0,  "ivec:   ",   hs->a);
    printh (3, 0,  "",           hs->b);
    printh (4, 19, "\nblock:  ", hs->b);
    
    // Round i: a[i+3],a[i+2],a[i+1],a[i] play a,b,c,d and b[i+3],b[i+2],b[i+1],b[i] play e,f,g,h;
    // b[i+4] holds the schedule word W[i] on entry.
    loop (i, 64) {
        // Fold W[i] into the later schedule words still being accumulated:
        // W[i+2] += sigma1(W[i]), W[i+7] += W[i], W[i+15] += sigma0(W[i]).
        // The guards skip targets below W[16], which are the message words themselves.
        if (i > 13) hs->b[i+6]  += sha256_ws1(hs->b[i+4]);
        if (i > 8)  hs->b[i+11] += hs->b[i+4];
        if (i > 0)  hs->b[i+19] += sha256_ws0(hs->b[i+4]);
        // Start W[i+16] with its W[i] term.
        hs->b[i+20] = hs->b[i+4];
        // New e = W[i] + K[i] + h + Sigma1(e) + Ch(e,f,g) + d  (Ch written with + on disjoint bits).
        hs->b[i+4]  = hs->b[i+4] + sha256_k[i] + hs->a[i] + hs->b[i] + sha256_s1 (hs->b[i+3]) + (hs->b[i+3] & hs->b[i+2]) + (~hs->b[i+3] & hs->b[i+1]);
        // New a = (new e - d) + Sigma0(a) + Maj(a,b,c).
        hs->a[i+4]  = hs->b[i+4] - hs->a[i] + sha256_s0 (hs->a[i+3]) + ((((hs->a[i+3] ^ hs->a[i+1]) & hs->a[i+2]) ^ (hs->a[i+1] & hs->a[i+3])));
    }
    
    // Final hash = IV (a[3..0], b[3..0]) + working variables after round 63 (a[67..64], b[67..64]).
    loop (i, 4) {
        hs->out[i]   = hs->a [3-i] + hs->a[67-i];
        hs->out[i+4] = hs->b [3-i] + hs->b[67-i];
    }
    
    printh (0, 7, "\nhash:   ", hs->out);
    printf ("\n\n");
    
    return (0);
}

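/*
 note: derivation notes, not compilable code.  The relations below run the round computation
 in main backwards: they express the round-i words a[i], b[i] and the schedule word w[i]
 largely in terms of words produced by later rounds.  hs->w[i] stands for the message
 schedule word W[i] consumed in round i; mt_sha256_state has no w member.
 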
 hs->a[i] = hs->b[i+4]                   - hs->a[i+4]                   + sha256_s0 (hs->a[i+3])      + ((((hs->a[i+3] ^ hs->a[i+1]) & hs->a[i+2]) ^ (hs->a[i+1] & hs->a[i+3])));
 hs->b[i] = sha256_s1 (hs->b[i+19])      - sha256_s1 (hs->b[i+12])      - sha256_s1 (hs->b[i+3])
          + sha256_k[i+16]               - sha256_k[i+9]                - sha256_k[i]
          - hs->b[i+20]                  + hs->b[i+16]                  + hs->b[i+13]                 - hs->b[i+9]              + hs->b[i+4]
          + hs->a[i+16]                  - hs->a[i+9]                   - hs->a[i]
          + ( hs->b[i+19] & hs->b[i+18]) - ( hs->b[i+12] & hs->b[i+11]) - ( hs->b[i+3] & hs->b[i+2])
          + (~hs->b[i+19] & hs->b[i+17]) - (~hs->b[i+12] & hs->b[i+10]) - (~hs->b[i+3] & hs->b[i+1])
          + sha256_ws0 (- hs->b[i+1]     - sha256_k[i+1]                - hs->a[i+1]                  - sha256_s1 (hs->b[i+4])  - (hs->b[i+4]  & hs->b[i+3])  - (~hs->b[i+4]  & hs->b[i+2])  + hs->b[i+5])
          + sha256_ws1 (- hs->b[i+14]    - sha256_k[i+14]               - hs->a[i+14]                 - sha256_s1 (hs->b[i+17]) - (hs->b[i+17] & hs->b[i+16]) - (~hs->b[i+17] & hs->b[i+15]) + hs->b[i+18]);
          
 hs->w[i] = hs->b[i+4] - hs->b[i] - sha256_k[i] - hs->a[i] - sha256_s1 (hs->b[i+3]) - (hs->b[i+3] & hs->b[i+2]) - (~hs->b[i+3] & hs->b[i+1]);
*/

//----------------------------------------------------------------------------------------------------------------------------------------------

// Convert the leading hex digits of a command-line word into a 32-bit value.
// At most 8 characters are consumed; anything after that is ignored.  Characters that are
// not hex digits contribute a 0 nibble (see c_char_to_hex), so no error is ever reported.
uint32 mt_cmd_line_get_hex (const char *ptr)
{
    uint32    value = 0;
    int       digits;
    
    for (digits = 0; digits < 8 && ptr[digits] != '\0'; digits++)
        value = (value << 4) | (uint32) c_char_to_hex (ptr[digits]);
    
    return value;
}

//----------------------------------------------------------------------------------------------------------------------------------------------

// Print an optional label followed by buf[start] .. buf[end] (inclusive) as "%08X " words.
// The range is walked upward when start < end and downward otherwise; start == end prints
// a single word.  A NULL or empty label prints nothing before the words.
void printh (int start, int end, char *str, uint32 *buf)
{
    int    step = (start < end) ? 1 : -1;
    int    idx;
    
    if (str != NULL && str[0] != '\0')
        printf ("%s", str);
    
    for (idx = start; ; idx += step) {
        printf ("%08X ", buf[idx]);
        if (idx == end)
            break;
    }
}

//----------------------------------------------------------------------------------------------------------------------------------------------
