import hashlib; import os;

#Compute a candidate square root of x modulo p, for p = 3 (mod 4).
#Only a "candidate": the result is not checked here, so callers must
#confirm that its square is x mod p.
def sqrt4k3(x,p): return pow(x,(p + 1)//4,p)

#Compute a candidate square root of x modulo p, for p = 5 (mod 8).
def sqrt8k5(x,p):
    y = pow(x,(p+3)//8,p)
    #If the square root exists, it is either y, or y*2^((p-1)//4) mod p.
    #Only the first candidate is checked; the second one is returned
    #unchecked, so callers must still verify the result.
    if (y * y) % p == x % p: return y
    else:
        z = pow(2,(p - 1)//4,p)
        return (y * z) % p

#Decode a big-endian hexadecimal string into a non-negative integer.
def hexi(s): return int.from_bytes(bytes.fromhex(s),byteorder="big")

#Rotate the 64-bit word x left by b places; the result is masked to 64 bits.
def rol(x,b): return ((x << b) | (x >> (64 - b))) & (2**64-1)

#Decode a little-endian octet string into a non-negative integer.
def from_le(s): return int.from_bytes(s, byteorder="little")

#Do the SHA-3 (Keccak) state transform in place on state s: 24 rounds of
#Theta, Rho, Pi, Chi and Iota over a list of 25 64-bit lanes, where lane i
#is in column i%5 and row i//5. Returns None; s is mutated.
def sha3_transform(s):
    #Per-lane rotation offsets for Rho, indexed by lane.
    ROTATIONS = [0,1,62,28,27,36,44,6,55,20,3,10,43,25,39,41,45,15,\
        21,8,18,2,61,56,14]
    #Lane cycle for Pi: each step the value at PERMUTATION[i+1] moves to
    #PERMUTATION[i], and PERMUTATION[0]'s old value moves to the last slot.
    #Lane 0 is not in the cycle and stays put.
    PERMUTATION = [1,6,9,22,14,20,2,12,13,19,23,15,4,24,21,8,16,5,3,\
        18,17,11,7,10]
    #Round constants for Iota, one per round.
    RC = [0x0000000000000001,0x0000000000008082,0x800000000000808a,\
        0x8000000080008000,0x000000000000808b,0x0000000080000001,\
        0x8000000080008081,0x8000000000008009,0x000000000000008a,\
        0x0000000000000088,0x0000000080008009,0x000000008000000a,\
        0x000000008000808b,0x800000000000008b,0x8000000000008089,\
        0x8000000000008003,0x8000000000008002,0x8000000000000080,\
        0x000000000000800a,0x800000008000000a,0x8000000080008081,\
        0x8000000000008080,0x0000000080000001,0x8000000080008008]
    for rnd in range(0,24):
        #AddColumnParity (Theta): c[x] is the XOR of column x; every lane
        #in column x is then XORed with c[x-1] ^ rol(c[x+1],1).
        c = [0]*5; d = [0]*5;
        for i in range(0,25): c[i%5]^=s[i]
        for i in range(0,5): d[i]=c[(i+4)%5]^rol(c[(i+1)%5],1)
        for i in range(0,25): s[i]^=d[i%5]
        #RotateWords (Rho): rotate each lane by its fixed offset.
        for i in range(0,25): s[i]=rol(s[i],ROTATIONS[i])
        #PermuteWords (Pi): shift lane values one step along the cycle.
        t = s[PERMUTATION[0]]
        for i in range(0,len(PERMUTATION)-1):
            s[PERMUTATION[i]]=s[PERMUTATION[i+1]]
        s[PERMUTATION[-1]]=t;
        #NonlinearMixRows (Chi): within each row of 5 lanes,
        #s[x] ^= ~s[x+1] & s[x+2]. t is a copy of the row with its first
        #two lanes repeated, so x+1/x+2 wrap around without a modulo and
        #every update reads pre-Chi values. (~ on a Python int is negative,
        #but & with the non-negative t[j+2] keeps the result in 64 bits.)
        for i in range(0,25,5):
            t=[s[i],s[i+1],s[i+2],s[i+3],s[i+4],s[i],s[i+1]]
            for j in range(0,5): s[i+j]=t[j]^((~t[j+1])&(t[j+2]))
        #AddRoundConstant (Iota): only lane 0 is affected.
        s[0]^=RC[rnd]

#XOR the octet string b, read as little-endian 64-bit words, into state s
#starting at lane 0. Only len(b)//8 whole words are used; any trailing
#octets (len(b) not a multiple of 8) are ignored.
def reinterpret_to_words_and_xor(s,b):
    for j in range(0,len(b)//8):
        s[j]^=from_le(b[8*j:][:8])

#Serialize word array w as consecutive 8-octet little-endian values and
#return the result as a bytearray.
def reinterpret_to_octets(w):
    mp=bytearray()
    for j in range(0,len(w)):
        mp+=w[j].to_bytes(8,byteorder="little")
    return mp

#(semi-)generic SHA-3 sponge implementation.
#msg: input octets. r_w: rate in 64-bit words (r_b = 8*r_w octets).
#o_p: padding/domain octet appended directly after the message.
#e_b: presumably the requested output length in octets -- its use lies in
#the lost span noted below, TODO confirm.
def sha3_raw(msg,r_w,o_p,e_b):
    r_b=8*r_w
    s=[0]*25
    #Absorb every whole r_b-octet block of msg.
    idx=0
    blocks=len(msg)//r_b
    for i in range(0,blocks):
        reinterpret_to_words_and_xor(s,msg[idx:][:r_b])
        idx+=r_b
        sha3_transform(s)
    #Handle last block padding: the remaining message octets, then o_p,
    #then zeros up to a full block, then the top bit of the final octet
    #set. The remainder is always shorter than r_b, so o_p always fits.
    m=bytearray(msg[idx:])
    m.append(o_p)
    while len(m) < r_b: m.append(0)
    m[len(m)-1]|=128
    #Absorb the padded last block.
    reinterpret_to_words_and_xor(s,m)
    sha3_transform(s)
    #Output.
    out = bytearray()
    # NOTE(review): CORRUPTED SPAN. The next line is not valid Python and
    # nothing after it in this function can run. The text between
    # "while len(out)" and ">((b-1)&7)" has been lost (possibly an HTML
    # sanitizer removed a "<...>"-looking run). The code that follows uses
    # self, b, s and xs as if it were the tail of a point-decoding method,
    # so the lost span presumably held the rest of this output loop plus
    # other definitions. Names used later but defined nowhere in the
    # visible code include shake256, Field, initpoint and the header of the
    # EdwardsPoint class that the curve classes below extend.
    # TODO: restore this span from the upstream source; it cannot be
    # reconstructed from what is visible here.
    while len(out)>((b-1)&7)
        #Decode y. If this fails, fail.
        y = self.base_field.frombytes(s,b)
        if y is None: return (None,None)
        #Try to recover x as a square root of solve_x2(y). If it does not
        #exist, or x is zero and the encoded sign xs disagrees with
        #x.sign(), fail.
        x=self.solve_x2(y).sqrt()
        if x is None or (x.iszero() and xs!=x.sign()):
            return (None,None)
        #If sign of x isn't correct, flip it.
        if x.sign()!=xs: x=-x
        # Return the decoded coordinates.
        return (x,y)
    #Encode this point in b bits: the y coordinate, with octet (b-1)//8,
    #bit (b-1)%8 carrying x.sign(). The point is stored with a z
    #denominator, so x and y are divided by z first.
    def encode_base(self,b):
        xp,yp=self.x/self.z,self.y/self.z
        #Encode y.
        s=bytearray(yp.tobytes(b))
        #Add sign bit of x to encoding (% binds tighter than <<).
        if xp.sign()!=0: s[(b-1)//8]|=1<<(b-1)%8
        return s
    #Scalar multiplication by integer x using double-and-add, least
    #significant bit first; x <= 0 yields the neutral element.
    #NOTE: the branch on each bit of x makes timing depend on the scalar.
    def __mul__(self,x):
        r=self.zero_elem()
        s=self
        while x > 0:
            if (x%2)>0:
                r=r+s
            s=s.double()
            x=x//2
        return r
    #Check two points are equal.
    def __eq__(self,y):
        #Need to check x1/z1 == x2/z2 and similarly for y, so cross-
        #multiply to eliminate divisions.
        xn1=self.x*y.z
        xn2=y.x*self.z
        yn1=self.y*y.z
        yn2=y.y*self.z
        return xn1==xn2 and yn1==yn2
    #Check two points are not equal.
    def __ne__(self,y):
        return not (self==y)

#A point on Edwards25519: -x^2 + y^2 = 1 + d*x^2*y^2 over p = 2**255-19,
#kept in extended coordinates (x, y, z, t).
class Edwards25519Point(EdwardsPoint):
    #Curve constants. Field is defined outside the visible code;
    #base_field.make(v) presumably builds the field element with value v.
    base_field=Field(1,2**255-19)
    d=-base_field.make(121665)/base_field.make(121666)
    f0=base_field.make(0)
    f1=base_field.make(1)
    #Coordinates of the standard base point.
    xb=base_field.make(hexi("216936D3CD6E53FEC0A4E231FDD6DC5C692CC76"+\
        "09525A7B2C9562D608F25D51A"))
    yb=base_field.make(hexi("666666666666666666666666666666666666666"+\
        "6666666666666666666666658"))
    #The standard base point.
    @staticmethod
    def stdbase():
        return Edwards25519Point(Edwards25519Point.xb,\
            Edwards25519Point.yb)
    #Create a new point from coordinates x, y.
    #Raises ValueError if (x, y) is not on the curve.
    def __init__(self,x,y):
        #Check the point is actually on the curve.
        if y*y-x*x!=self.f1+self.d*x*x*y*y:
            raise ValueError("Invalid point")
        #initpoint is defined outside the visible code (presumably sets
        #x, y and z).
        self.initpoint(x, y)
        #Extended coordinate; is_valid_point checks t*z == x*y.
        self.t=x*y
    #Decode a 32-octet point representation; None if it is invalid.
    def decode(self,s):
        x,y=self.decode_base(s,256);
        return Edwards25519Point(x, y) if x is not None else None
    #Encode a point representation (256 bits).
    def encode(self):
        return self.encode_base(256)
    #Construct neutral point (0, 1) on this curve.
    def zero_elem(self):
        return Edwards25519Point(self.f0,self.f1)
    #Solve for x^2 given y: x^2 = (y^2 - 1) / (d*y^2 + 1).
    def solve_x2(self,y):
        return ((y*y-self.f1)/(self.d*y*y+self.f1))
    #Point addition in extended coordinates; returns a new point.
    def __add__(self,y):
        #The formulas are from EFD (d+d is 2*d).
        tmp=self.zero_elem()
        zcp=self.z*y.z
        A=(self.y-self.x)*(y.y-y.x)
        B=(self.y+self.x)*(y.y+y.x)
        C=(self.d+self.d)*self.t*y.t
        D=zcp+zcp
        E,H=B-A,B+A
        F,G=D-C,D+C
        tmp.x,tmp.y,tmp.z,tmp.t=E*F,G*H,F*G,E*H
        return tmp
    #Point doubling; returns a new point.
    def double(self):
        #The formulas are from EFD (with assumption a=-1 propagated).
        tmp=self.zero_elem()
        A=self.x*self.x
        B=self.y*self.y
        Ch=self.z*self.z
        C=Ch+Ch
        H=A+B
        xys=self.x+self.y
        E=H-xys*xys
        G=A-B
        F=C+G
        tmp.x,tmp.y,tmp.z,tmp.t=E*F,G*H,F*G,E*H
        return tmp
    #Order of basepoint.
    def l(self):
        return hexi("1000000000000000000000000000000014def9dea2f79cd"+\
            "65812631a5cf5d3ed")
    #The base-2 logarithm of the cofactor (PureEdDSA clears this many low
    #scalar bits and doubles this many times when verifying).
    def c(self):
        return 3
    #The highest set bit of a clamped scalar.
    def n(self):
        return 254
    #The coding length in bits.
    def b(self):
        return 256
    #Validity check (for debugging): asserts the projective curve equation
    #(y^2 - x^2) z^2 = z^4 + d x^2 y^2 and the invariant t*z == x*y.
    #Uses assert, so it does nothing under python -O.
    def is_valid_point(self):
        x,y,z,t=self.x,self.y,self.z,self.t
        x2=x*x
        y2=y*y
        z2=z*z
        lhs=(y2-x2)*z2
        rhs=z2*z2+self.d*x2*y2
        assert(lhs == rhs)
        assert(t*z == x*y)

#A point on Edwards448: x^2 + y^2 = 1 + d*x^2*y^2 over p = 2**448-2**224-1,
#kept in projective coordinates (x, y, z).
class Edwards448Point(EdwardsPoint):
    #Curve constants (see Field note on Edwards25519Point).
    base_field=Field(1,2**448-2**224-1)
    d=base_field.make(-39081)
    f0=base_field.make(0)
    f1=base_field.make(1)
    #Coordinates of the standard base point.
    xb=base_field.make(hexi("4F1970C66BED0DED221D15A622BF36DA9E14657"+\
        "0470F1767EA6DE324A3D3A46412AE1AF72AB66511433B80E18B00938E26"+\
        "26A82BC70CC05E"))
    yb=base_field.make(hexi("693F46716EB6BC248876203756C9C7624BEA737"+\
        "36CA3984087789C1E05A0C2D73AD3FF1CE67C39C4FDBD132C4ED7C8AD98"+\
        "08795BF230FA14"))
    #The standard base point.
    @staticmethod
    def stdbase():
        return Edwards448Point(Edwards448Point.xb,Edwards448Point.yb)
    #Create a new point from coordinates x, y.
    #Raises ValueError if (x, y) is not on the curve.
    def __init__(self,x,y):
        #Check the point is actually on the curve.
        if y*y+x*x!=self.f1+self.d*x*x*y*y:
            raise ValueError("Invalid point")
        self.initpoint(x, y)
    #Decode a 57-octet point representation; None if it is invalid.
    def decode(self,s):
        x,y=self.decode_base(s,456);
        return Edwards448Point(x, y) if x is not None else None
    #Encode a point representation (456 bits).
    def encode(self):
        return self.encode_base(456)
    #Construct neutral point (0, 1) on this curve.
    def zero_elem(self):
        return Edwards448Point(self.f0,self.f1)
    #Solve for x^2 given y: x^2 = (y^2 - 1) / (d*y^2 - 1).
    def solve_x2(self,y):
        return ((y*y-self.f1)/(self.d*y*y-self.f1))
    #Point addition in projective coordinates; returns a new point.
    def __add__(self,y):
        #The formulas are from EFD.
        tmp=self.zero_elem()
        xcp,ycp,zcp=self.x*y.x,self.y*y.y,self.z*y.z
        B=zcp*zcp
        E=self.d*xcp*ycp
        F,G=B-E,B+E
        tmp.x=zcp*F*((self.x+self.y)*(y.x+y.y)-xcp-ycp)
        tmp.y,tmp.z=zcp*G*(ycp-xcp),F*G
        return tmp
    #Point doubling; returns a new point.
    def double(self):
        #The formulas are from EFD.
        tmp=self.zero_elem()
        x1s,y1s,z1s=self.x*self.x,self.y*self.y,self.z*self.z
        xys=self.x+self.y
        F=x1s+y1s
        J=F-(z1s+z1s)
        tmp.x,tmp.y,tmp.z=(xys*xys-x1s-y1s)*J,F*(x1s-y1s),F*J
        return tmp
    #Order of basepoint.
    def l(self):
        return hexi("3ffffffffffffffffffffffffffffffffffffffffffffff"+\
            "fffffffff7cca23e9c44edb49aed63690216cc2728dc58f552378c2"+\
            "92ab5844f3")
    #The base-2 logarithm of the cofactor.
    def c(self):
        return 2
    #The highest set bit of a clamped scalar.
    def n(self):
        return 447
    #The coding length in bits.
    def b(self):
        return 456
    #Validity check (for debugging): asserts the projective curve equation
    #(x^2 + y^2) z^2 = z^4 + d x^2 y^2. Uses assert (no-op under -O).
    def is_valid_point(self):
        x,y,z=self.x,self.y,self.z
        x2=x*x
        y2=y*y
        z2=z*z
        lhs=(x2+y2)*z2
        rhs=z2*z2+self.d*x2*y2
        assert(lhs == rhs)

#Simple self-check: computes (l+1)*point by double-and-add over b() bits,
#validating every intermediate point, then asserts the result encodes like
#point itself (so l*point acts as the neutral element), differs from the
#final doubled value 2^b*point, and differs from the neutral element.
#Uses assert, so the checks vanish under python -O.
def curve_self_check(point):
    p=point
    q=point.zero_elem()
    z=q
    l=p.l()+1
    p.is_valid_point()
    q.is_valid_point()
    for i in range(0,point.b()):
        if (l>>i)&1 != 0:
            q=q+p
            q.is_valid_point()
        p=p.double()
        p.is_valid_point()
    assert q.encode() == point.encode()
    assert q.encode() != p.encode()
    assert q.encode() != z.encode()

#Simple self-check of both curves' standard base points.
def self_check_curves():
    curve_self_check(Edwards25519Point.stdbase())
    curve_self_check(Edwards448Point.stdbase())

#PureEdDSA scheme.
#Limitation: Only b mod 8 = 0 is handled.
class PureEdDSA:
    #Create a new object.
    #properties["B"]: base point (provides l(), n(), b(), c(), scalar
    #multiplication, decode() and encode()).
    #properties["H"]: hash callable H(data, ctx, hflag) returning octets.
    def __init__(self,properties):
        self.B=properties["B"]
        self.H=properties["H"]
        self.l=self.B.l()
        self.n=self.B.n()
        self.b=self.B.b()
        self.c=self.B.c()
    #Clamp a private scalar: clear the low c bits, set bit n and clear bits
    #n+1 .. b-1 (bit i is bit i%8 of octet i//8). Returns a new bytearray;
    #a itself is not modified.
    def __clamp(self,a):
        _a = bytearray(a)
        for i in range(0,self.c): _a[i//8]&=~(1<<(i%8))
        _a[self.n//8]|=1<<(self.n%8)
        for i in range(self.n+1,self.b): _a[i//8]&=~(1<<(i%8))
        return _a
    #Generate a key. If privkey is None, a random one of b//8 octets is
    #generated with os.urandom. In any case, (privkey, pubkey) is returned.
    def keygen(self,privkey):
        #If no private key data given, generate random.
        if privkey is None: privkey=os.urandom(self.b//8)
        #Expand key: the clamped first b//8 octets of H(privkey) form the
        #secret scalar a.
        khash=self.H(privkey,None,None)
        a=from_le(self.__clamp(khash[:self.b//8]))
        #Return the keypair (public key is A=Enc(aB)).
        return privkey,(self.B*a).encode()
    #Sign msg with keypair; returns R || S (2*(b//8) octets).
    #pubkey is hashed as given; it is not checked against privkey.
    def sign(self,privkey,pubkey,msg,ctx,hflag):
        #Expand key: scalar a from the first half of H(privkey), nonce seed
        #from the rest.
        khash=self.H(privkey,None,None)
        a=from_le(self.__clamp(khash[:self.b//8]))
        seed=khash[self.b//8:]
        #Calculate r and R (R only used in encoded form)
        r=from_le(self.H(seed+msg,ctx,hflag))%self.l
        R=(self.B*r).encode()
        #Calculate h.
        h=from_le(self.H(R+pubkey+msg,ctx,hflag))%self.l
        #Calculate s.
        S=((r+h*a)%self.l).to_bytes(self.b//8,byteorder="little")
        #The final signature is concatenation of R and S.
        return R+S
    #Verify signature with public key. Returns False (rather than raising)
    #for wrong lengths, undecodable R or A, or S >= l. The check is
    #cofactored: both sides are doubled c times before comparing.
    def verify(self,pubkey,msg,sig,ctx,hflag):
        #Sanity-check sizes.
        if len(sig)!=self.b//4: return False
        if len(pubkey)!=self.b//8: return False
        #Split signature into R and S, and parse.
        Rraw,Sraw=sig[:self.b//8],sig[self.b//8:]
        R,S=self.B.decode(Rraw),from_le(Sraw)
        #Parse public key.
        A=self.B.decode(pubkey)
        #Check parse results (S must be reduced mod l).
        if (R is None) or (A is None) or S>=self.l: return False
        #Calculate h.
        h=from_le(self.H(Rraw+pubkey+msg,ctx,hflag))%self.l
        #Calculate left and right sides of check eq.
        rhs=R+(A*h)
        lhs=self.B*S
        for i in range(0, self.c):
            lhs = lhs.double()
            rhs = rhs.double()
        #Check eq. holds?
        return lhs==rhs

#Internal hash for plain Ed25519: SHA-512 of data. This hash adds no
#context prefix, so a non-empty ctx or a truthy hflag raises ValueError.
def Ed25519_inthash(data,ctx,hflag):
    if (ctx is not None and len(ctx) > 0) or hflag:
        raise ValueError("Contexts/hashes not supported")
    return hashlib.sha512(data).digest()

#The base PureEdDSA schemes.
pEd25519 = PureEdDSA({
    "B": Edwards25519Point.stdbase(),
    "H": Ed25519_inthash,
})

#Domain-separation tags placed in front of the context-aware hash inputs.
_ED25519_DOM_TAG = b"SigEd25519 no Ed25519 collisions"
_ED448_DOM_TAG = b"SigEd448"

#Build the hash domain prefix tag || flag || len(ctx) || ctx, where the
#flag octet is 1 when hflag is truthy and 0 otherwise.
#Returns b"" when ctx is None. Raises ValueError when ctx is longer than
#255 octets, since its length has to fit in a single octet.
def _domain_prefix(tag, ctx, hflag):
    if ctx is None:
        return b""
    if len(ctx) > 255:
        raise ValueError("Context too big")
    flag_octet = 1 if hflag else 0
    return tag + bytes((flag_octet, len(ctx))) + ctx

#Internal hash for Ed25519ctx/Ed25519ph: SHA-512 over the domain prefix
#(empty when ctx is None) followed by data.
def Ed25519ctx_inthash(data,ctx,hflag):
    prefix = _domain_prefix(_ED25519_DOM_TAG, ctx, hflag)
    return hashlib.sha512(prefix + data).digest()

pEd25519ctx = PureEdDSA({
    "B": Edwards25519Point.stdbase(),
    "H": Ed25519ctx_inthash,
})

#Internal hash for Ed448/Ed448ph: shake256 over the domain prefix (empty
#when ctx is None) followed by data, with 114 as the length argument.
def Ed448_inthash(data,ctx,hflag):
    prefix = _domain_prefix(_ED448_DOM_TAG, ctx, hflag)
    return shake256(prefix + data, 114)

pEd448 = PureEdDSA({
    "B": Edwards448Point.stdbase(),
    "H": Ed448_inthash,
})

#EdDSA scheme: a PureEdDSA base scheme optionally preceded by a prehash.
class EdDSA:
    #pure_scheme: the PureEdDSA object that does the actual work.
    #prehash: callable (msg, ctx) -> octets applied to the message first,
    #or None for no prehashing. The hflag handed down to the pure scheme
    #is True exactly when a prehash callable was supplied.
    def __init__(self,pure_scheme,prehash):
        self.__pure = pure_scheme
        self.__pflag = prehash is not None
        if self.__pflag:
            self.__prehash = prehash
        else:
            self.__prehash = lambda message, context: message

    #Generate a key pair; privkey None means "generate a random one".
    #Returns (privkey, pubkey).
    def keygen(self,privkey):
        return self.__pure.keygen(privkey)

    #Sign msg with the given key pair; a missing ctx is treated as b"".
    def sign(self,privkey,pubkey,msg,ctx=None):
        context = b"" if ctx is None else ctx
        payload = self.__prehash(msg, context)
        return self.__pure.sign(privkey, pubkey, payload, context,
                                self.__pflag)

    #Verify sig on msg under pubkey; a missing ctx is treated as b"".
    def verify(self,pubkey,msg,sig,ctx=None):
        context = b"" if ctx is None else ctx
        payload = self.__prehash(msg, context)
        return self.__pure.verify(pubkey, payload, sig, context,
                                  self.__pflag)

#Prehash for Ed448ph: shake256 of the message with 64 as the length
#argument; ctx is accepted for interface compatibility but not used.
def Ed448ph_prehash(data,ctx):
    return shake256(data,64)

#Our signature schemes.
#Prehash for Ed25519ph: SHA-512 of the message; ctx is not used.
def _ed25519ph_prehash(message, context):
    return hashlib.sha512(message).digest()

Ed25519 = EdDSA(pEd25519, None)
Ed25519ctx = EdDSA(pEd25519ctx, None)
Ed25519ph = EdDSA(pEd25519ctx, _ed25519ph_prehash)
Ed448 = EdDSA(pEd448, None)
Ed448ph = EdDSA(pEd448, Ed448ph_prehash)

#Look up a signature scheme object by its name.
#Raises NotImplementedError for any name not listed below.
def eddsa_obj(name):
    #Built on every call so the current module-level bindings are returned,
    #and compared with == in the same order as before.
    table = (
        ("Ed25519", Ed25519),
        ("Ed25519ctx", Ed25519ctx),
        ("Ed25519ph", Ed25519ph),
        ("Ed448", Ed448),
        ("Ed448ph", Ed448ph),
    )
    for label, scheme in table:
        if name == label:
            return scheme
    raise NotImplementedError("Algorithm not implemented")