// Ed25519 using 32-bit int // // R. Perry, June 2016 #include #include #include #include #include "impl.h" #ifndef DEBUG #define DEBUG 0 #endif // prevent timing attacks, see Point_mul() and mod_L() // #ifndef CONSTANT_TIME #define CONSTANT_TIME 1 #endif int debug = DEBUG; //------------------------------------------------------------------------------ // constants // m = 2^255 - 19 // E m2 = // m-2 { 0xeb, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f }; // d*2 // C d2 = { 0x159, 0xb2f, 0x426, 0x9b9, 0xbd6, 0x56e, 0x3b1, 0x828, 0x49a, 0xe01, 0x000, 0xd13, 0xef3, 0xf2e, 0xe80, 0x198, 0xce7, 0xdff, 0xc56, 0xd9d, 0x406, 0x2 }; // curve neutral element (0, 1, 1, 0) // Point Z = { { 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x0 }, { 0x001, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x0 }, { 0x001, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x0 }, { 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x0 } }; // curve base point // Point G = { { 0x51a, 0x25d, 0x08f, 0x2d6, 0x956, 0xb2c, 0x5a7, 0x952, 0x760, 0x2cc, 0xc69, 0xdc5, 0xdd6, 0x31f, 0x4e2, 0xc0a, 0x3fe, 0x6e5, 0x3cd, 0x36d, 0x169, 0x2 }, { 0x658, 0x666, 0x666, 0x666, 0x666, 0x666, 0x666, 0x666, 0x666, 0x666, 0x666, 0x666, 0x666, 0x666, 0x666, 0x666, 0x666, 0x666, 0x666, 0x666, 0x666, 0x6 }, { 0x001, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x0 }, { 0xda3, 0xb7d, 0x3a5, 0x8ab, 0xdde, 0xf56, 0x152, 0x775, 0xf80, 0xf09, 0xd20, 0xe37, 0x4ab, 0x8e6, 0xa4e, 0x66e, 0x665, 0x8b7, 0xfd7, 0x5f0, 0x787, 0x6 } }; // curve group order L = 2^252 + 27742317777372353535851937790883648493 // C L = { 0x3ed, 0xf5d, 0xa5c, 0x631, 0x812, 0xd65, 0x79c, 0xa2f, 0x9de, 0xdef, 0x014, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x000, 0x1 }; //------------------------------------------------------------------------------ // operations mod (2^255 - 19) // add: w = u+v with no reduction // void add( C w, C u, C v) { for( int i = 0; i < N; ++i) w[i] = u[i] + v[i]; } // sub: w = u-v with no reduction // // *** assumes that B is 12 and N is 22, i.e. 2^255 = 2^3 * (2^12)^21 *** // // does not assume that v is in the range 0 to m-1: // // -v = m-v mod m, but since v could be >= m, subtract it from 2m // // the extra m will be eliminated later when mod is performed // // m = 2^255 - 19 = 7 fff fff ... fff fed // 2m = 2^256 - 38 = f fff fff ... fff fda // void sub( C w, C u, C v) { // w = (2's complement of v) + 2m = (1's complement of v) + 1 + 2m // w[0] = (~v[0] & 0xfff) + 0xfda + 1; w[N-1] = ~v[N-1] + 0xf; // mask out overflow below for( int i = 1; i < N-1; ++i) w[i] = (~v[i] & 0xfff) + 0xfff; if( debug) print( "sub1", w, N); // reduce and mask out any overflow // reduce( w, N); w[N-1] &= 0xf; if( debug) print( "sub2", w, N); // now w = 2m-v, add u // add( w, u, w); } // mul: r = u*v with reduction and mod // // w = u[21]*x^21 + u[20]*x^20 + ... + u[1]*x + u[0] // * v[21]*x^21 + v[20]*x^20 + ... + v[1]*x + v[0] // // = u[21]*v[21]*x^42 + ... + (u[1]*v[0]+u[0]*v[1])*x + u[0]*v[0] // // = w[42]*x^42 + ... + w[1]*x + w[0] // // r = w mod (2^255 - 19) // void mul( C r, C u, C v) { D w; for( int k = 0; k < NN; ++k) w[k] = 0; for( int i = 0; i < N; ++i) for( int j = 0; j < N; ++j) w[i+j] += u[i]*v[j]; reduce( w, NN); mod( r, w); } // mul2: r = 2*u*v with reduction and mod // void mul2( C r, C u, C v) { D w; for( int k = 0; k < NN; ++k) w[k] = 0; for( int i = 0; i < N; ++i) { UI t = u[i] << 1; // 2*u for( int j = 0; j < N; ++j) w[i+j] += t*v[j]; } reduce( w, NN); mod( r, w); } // sqr: r = u*u with reduction and mod // void sqr( C r, C u) { D w; for( int k = 0; k < NN; ++k) w[k] = 0; for( int i = 0; i < N; ++i) w[i+i] += u[i]*u[i]; // uses about half the number of multiplies compared to u*v // since inner loop on j only goes up to i-1 instead of N-1 // for( int i = 0; i < N; ++i) for( int j = 0; j < i; ++j) w[i+j] += 2*u[i]*u[j]; reduce( w, NN); mod( r, w); } // reduce a C or D value by propagating the carries // void reduce( UI *w, int n) { for( int i = 0; i < n-1; ++i) { w[i+1] += w[i] >> B; w[i] &= MASK; } } // fast_reduce: reduce a C or D value by propagating the carries // starting at a particular offset and stoping once the carry becomes zero. // // This should only be used when a reduced w has been increased at only // the offset position (w[offset] += something), e.g. in mod() // void fast_reduce( UI *w, int n, int offset) { for( int i = offset; i < n-1; ++i) { int carry = w[i] >> B; if( carry == 0) break; w[i+1] += carry; w[i] &= MASK; } } // mod: u = w mod (2^255 - 19) // // *** assumes that B is 12 and N is 22, i.e. 2^255 = 2^3 * (2^12)^21 *** // // NOTE: u may not be fully reduced but will not exceed (2^255 - 1) // // see mod.pdf notes // void mod( C u, D w) { UI t; u[0] = w[0]; /*** for( int i = 1; i < N; ++i) { t = 19*w[i+N-1]; u[i] = w[i] + (t >> 3); u[i-1] += (t & 7) << (B-3); } reduce( u, N); ***/ // combined mod and reduce: // for( int i = 1; i < N; ++i) { t = 19*w[i+N-1]; u[i-1] += (t & 7) << (B-3); u[i] = w[i] + (u[i-1] >> B) + (t >> 3); // carry + remainder u[i-1] &= MASK; } t = u[N-1] >> 3; if( debug && t) printf( "mod: t = %x\n", t); while( t) { u[N-1] -= t << 3; u[0] += 19*t; // if( debug) print( "mod3", u, N); fast_reduce( u, N, 0); // if( debug) print( "mod4", u, N); t = u[N-1] >> 3; // in rare cases this will be non-zero if( debug && t) printf( "mod: new t = %x\n", t); } // defer this adjustment until the final value mod m is needed // adjust( u); // if u >= m then subtract m } // adjust: if u >= m then subtract m // // *** assumes that B is 12 and N is 22, i.e. 2^255 = 2^3 * (2^12)^21 *** // // also assumes that mod has already been performed // so that u is in the range 0 ... (2^255 - 1) // void adjust( C u) { // m = 2^255 - 19 = 7 fff fff ... fff fed // if( u[N-1] < 7) return; // u < m for( int i = N-2; i > 0; --i) if( u[i] < 0xfff) return; // u < m if( u[0] < 0xfed) return; // u < m if( debug) printf( "adjust: u >= m\n"); // u >= m, subtract m // u[N-1] -= 7; u[0] -= 0xfed; for( int i = 1; i < N-1; ++i) u[i] -= 0xfff; } // modPow: u = v^e mod (2^255 - 19) // void modPow( C u, C v, E e) { C t; UC b; // initialize: u = 1, t = v // u[0] = 1; t[0] = v[0]; for( int i = 1; i < N; ++i) { u[i] = 0; t[i] = v[i]; } // for each byte in the exponent // for( int i = 0; i < M; ++i) { b = e[i]; for( int j = 0; j < 8; ++j) // for each bit in the exponent byte { if( b & 1) mul( u, u, t); // u = u*t b >>= 1; if( i == M-1 && b == 0) break; // done with MSB sqr( t, t); // t = t*t } } } // modInverse: u = v^(m-2) mod m, m = 2^255 - 19 // void modInverse( C u, C v) { modPow( u, v, m2); } //------------------------------------------------------------------------------ // curve operations // Point_add: p3 = p1 + p2 // void Point_add( Point *p3, Point *p1, Point *p2) { C a, b, c, d, e, f, g, h, r, s; // A = (Y1-X1)*(Y2-X2) // sub( r, p1->y, p1->x); sub( s, p2->y, p2->x); mul( a, r, s); // B = (Y1+X1)*(Y2+X2) // add( r, p1->y, p1->x); add( s, p2->y, p2->x); mul( b, r, s); // C = T1*2*d*T2 // mul( r, p1->t, d2); mul( c, r, p2->t); // D = Z1*2*Z2 // mul2( d, p1->z, p2->z); // E = B-A, F = D-C, G = D+C, H = B+A // sub( e, b, a); sub( f, d, c); add( g, d, c); add( h, b, a); // X3 = E*F // mul( p3->x, e, f); // Y3 = G*H // mul( p3->y, g, h); // T3 = E*H // mul( p3->t, e, h); // Z3 = F*G // mul( p3->z, f, g); } // Point_mul: q = e*r // void Point_mul( Point *q, E e, Point *r) { Point t = *r, *p = &t; UC b; *q = Z; for( int i = 0; i < M; ++i) // for each byte in the exponent { b = e[i]; for( int j = 0; j < 8; ++j) // for each bit in the exponent byte { #if CONSTANT_TIME int alpha = b & 1, beta = !alpha; Point d; Point_add( &d, q, p); for( int k = 0; k < N; ++k) // q = d or q = q { q->x[k] = alpha*d.x[k] + beta*q->x[k]; q->y[k] = alpha*d.y[k] + beta*q->y[k]; q->z[k] = alpha*d.z[k] + beta*q->z[k]; q->t[k] = alpha*d.t[k] + beta*q->t[k]; } #else if( b & 1) Point_add( q, q, p); // q = q + p #endif b >>= 1; #if CONSTANT_TIME if( i == M-1 && j == 7) break; // skip final/useless Point_add #else if( i == M-1 && b == 0) break; // done with MSB #endif Point_add( p, p, p); // p = p + p } } adjust( q->x); adjust( q->y); } // encode point (x,y) values into E value // void Point_encode( E e, Point *p) { C x, y, zi; modInverse( zi, p->z); mul( x, p->x, zi); adjust( x); mul( y, p->y, zi); adjust( y); // if( debug) print( "encode: x", x, N); // if( debug) print( "encode: y", y, N); CtoE( e, y); e[M-1] |= (x[0] & 1) << 7; // lsb of x -> msb of e } //------------------------------------------------------------------------------ // operations mod L, the curve group order // // L = 2^252 + 27742317777372353535851937790883648493 // // = 1 000 000 000 000 000 000 000 000 000 000 014 def 9de a2f 79c d65 812 631 a5c f5d 3ed // // 21 20 19 18 17 16 15 14 13 12 11 10 9 8 7 6 5 4 3 2 1 0 // // note that L[20:11] are all zero // mod_L: e = u mod L // // *** assumes that B is 12 and N is 22 *** // // NOTE: this routine overwrites u // void mod_L( E e, D u) { // w here references input u as signed int // int s, *w = (int *) u; UI t, carry; int k, neg; // poly(w[42:0]) = w[42]*x^42 + ... + w[1]*x + w[0], x = 2^12 // // NN = 43, N = 22 // // i = 42, remainder = poly(w[41:0]) + w[42]*x^21*poly(L[10:0]) // i = 41, remainder = poly(w[40:0]) + w[41]*x^20*poly(L[10:0]) // ... // i = 21, remainder = poly(w[20:0]) + w[21]*x^0*poly(L[10:0]) for( int i = NN-1; i >= N-1; --i) // w index { s = w[i]; w[i] = 0; for( int j = 0; j < 13; ++j) // L index, L[20:11] are all zero { k = i+j-N+1; // remainder index // // i = 42, j = 0:20, k = 21:41 // i = 41, j = 0:20, k = 20:40 // ... // i = 21, j = 0:20, k = 0:20 // #if CONSTANT_TIME w[k] -= s*L[j]; t = abs( w[k]); carry = (t >> B); neg = 2*(w[k] >= 0) - 1; // -1 or +1 w[k+1] += neg*carry; t &= MASK; w[k] = neg*t; #else w[k] -= s*L[j]; t = abs( w[k]); carry = (t >> B); if( j >= 10 && carry == 0) break; // leave loop early neg = w[k] < 0; w[k+1] += neg ? -carry : carry; t &= MASK; w[k] = neg ? -t : t; #endif } } if( debug) print( "mod_L", u, NN); // adjust negative values // for( int k = 0; k < N-1; ++k) if( w[k] < 0) { w[k] += 0x1000; w[k+1] -= 1; } // now w[0]...w[N-2] are all non-negative; if w[N-1] is negative, add L // if( w[N-1] < 0) { if( debug) { printf( "neg\n"); print( "w", u, NN); } for( int k = 0; k < N; ++k) { w[k] += L[k]; w[k+1] += w[k] >> B; w[k] &= MASK; } } // copy lower half of w to e // CtoE( e, u); } // sign: s = (r + k*a) mod L // void sign( E s, E re, E ke, E ae) { D w; C r, k, a; EtoC( r, re); EtoC( k, ke); EtoC( a, ae); if( debug) { print_E( "sign: re", re, M); print( "sign: r", r, N); } // start with w = r // for( int i = 0; i < N; ++i) w[i] = r[i]; for( int i = N; i < NN; ++i) w[i] = 0; // add on k*a // for( int i = 0; i < N; ++i) for( int j = 0; j < N; ++j) w[i+j] += k[i]*a[j]; // reduce and mod // reduce( w, NN); mod_L( s, w); } //------------------------------------------------------------------------------ // I/O and conversion routines // print a C or D value, MSB first, *** assumes that B is 12 *** // void print( const char *msg, UI *u, int n) { if( msg) printf( "%s = ", msg); for( int i = n-1; i >= 0; --i) printf( " %03x", u[i]); // %03x for 12 bits putchar( '\n'); } // print an E or EE value, LSB first // void print_E( const char *msg, UC *e, int n) { if( msg) printf( "%s = ", msg); for( int i = 0; i < n; ++i) printf( "%02x", e[i]); putchar( '\n'); } static char digits[] = "0123456789ABCDEFabcdef"; // convert one hex digit character to an int // UI hex2int( int h) { if( !isxdigit( h) ) { fprintf( stderr, "hex2int: bad data\n"); exit(1); } return strchr( digits, toupper(h)) - digits; } // convert hex string (big-endian) to C value *** assumes that B is 12 *** // // 3 hex digits -> one 12-bit piece // void convert( C u, const char *s) { if( strspn(s,digits) != M*2) { fprintf( stderr, "convert: bad strlen\n"); exit(1); } u[N-1] = hex2int( s[0]); for( int i = N-2, j = 1; i >= 0; --i, j += 3) u[i] = (hex2int(s[j]) << 8) | (hex2int(s[j+1]) << 4) | hex2int(s[j+2]); } // convert hex string (big-endian) to D value *** assumes that B is 12 *** // // 3 hex digits -> one 12-bit piece // // 512 = 8 + 42*12 // void convert_D( D w, const char *s) { if( strspn(s,digits) != M*4) { fprintf( stderr, "convert_D: bad strlen\n"); exit(1); } w[NN-1] = (hex2int(s[0]) << 4) | hex2int(s[1]); for( int i = NN-2, j = 2; i >= 0; --i, j += 3) w[i] = (hex2int(s[j]) << 8) | (hex2int(s[j+1]) << 4) | hex2int(s[j+2]); } // convert hex string (little-endian) to E value // // 2 hex digits -> one byte // void convert_E( E e, const char *s) { if( strspn(s,digits) != M*2) { fprintf( stderr, "convert_E: bad strlen\n"); exit(1); } for( int i = 0, j = 0; i < M; ++i, j += 2) e[i] = (hex2int(s[j]) << 4) | hex2int(s[j+1]); } // CtoE: e = c, convert C value to E value *** assumes that B is 12 *** // // c: x xxx xxx xxx ... xxx N = 22 // 21 20 19 18 ... 0 // // e: xx xx xx xx ... xx M = 32 // 31 30 29 28 ... 0 // // c: 21 20 19 18 // x|x xx| xx x|x xx|... // e: 31 30 29 28 27 // void CtoE( E e, C c) { e[M-1] = (c[N-1] << 4) | (c[N-2] >> 8); e[M-2] = c[N-2] & 0xFF; for( int i = N-3, j = M-3; i > 0; i -= 2, j -= 3) { e[j] = c[i] >> 4; e[j-1] = ((c[i] & 0xF) << 4) | (c[i-1] >> 8); e[j-2] = c[i-1] & 0xFF; } } // EtoC: c = e, convert E value to C value *** assumes that B is 12 *** // // c: 21 20 19 18 // x|x xx| xx x|x xx|... // e: 31 30 29 28 27 // void EtoC( C c, E e) { c[N-1] = e[M-1] >> 4; c[N-2] = ((e[M-1] & 0xF) << 8) | e[M-2]; for( int i = N-3, j = M-3; i > 0; i -=2, j -=3) { c[i] = (e[j] << 4) | (e[j-1] >> 4); c[i-1] = ((e[j-1] & 0xF) << 8) | e[j-2]; } } // EEtoD: w = e, convert EE value to D value *** assumes that B is 12 *** // // w: xx xxx xxx xxx ... xxx NN = 43 // 42 41 40 39 ... 0 // // e: xx xx xx xx ... xx MM = 64 // 63 62 61 60 ... 0 // // w: 42 41 40 39 // xx|xx x|x xx|xx x|x.. // e: 63 62 61 60 59 58 // void EEtoD( D w, EE e) { w[NN-1] = e[MM-1]; for( int i = NN-2, j = MM-2; i > 0; i -= 2, j -= 3) { w[i] = (e[j] << 4) | (e[j-1] >> 4); w[i-1] = ((e[j-1] & 0xF) << 8) | e[j-2]; } }