diff --git a/common-headers.txt b/common-headers.txt index 4be1f15..009d1f9 100644 --- a/common-headers.txt +++ b/common-headers.txt @@ -2,4 +2,5 @@ typ jku kid x5u -x5t \ No newline at end of file +x5t +url diff --git a/jwt_tool.py b/jwt_tool.py old mode 100644 new mode 100755 index 99ca1ab..d15cbb7 --- a/jwt_tool.py +++ b/jwt_tool.py @@ -17,12 +17,13 @@ import base64 import json import random -from urllib.parse import urljoin, urlparse +import tempfile import argparse from datetime import datetime import configparser from http.cookies import SimpleCookie from collections import OrderedDict +from urllib.parse import urljoin, urlparse try: from Cryptodome.Signature import PKCS1_v1_5, DSS, pss from Cryptodome.Hash import SHA256, SHA384, SHA512 @@ -56,6 +57,10 @@ def cprintc(textval, colval): if not args.bare: cprint(textval, colval) +def b64pad(buf): + """ Restore stripped B64 padding """ + return buf + '=' * (4 - len(buf) % 4 if len(buf) % 4 in (2, 3) else 0) + def createConfig(): privKeyName = path+"/jwttool_custom_private_RSA.pem" pubkeyName = path+"/jwttool_custom_public_RSA.pem" @@ -867,34 +872,13 @@ def verifyTokenRSA(headDict, paylDict, sig, pubKey): key = RSA.importKey(open(pubKey).read()) newContents = genContents(headDict, paylDict) newContents = newContents.encode('UTF-8') - if "-" in sig: - try: - sig = base64.urlsafe_b64decode(sig) - except: - pass - try: - sig = base64.urlsafe_b64decode(sig+"=") - except: - pass - try: - sig = base64.urlsafe_b64decode(sig+"==") - except: - pass - elif "+" in sig: - try: - sig = base64.b64decode(sig) - except: - pass - try: - sig = base64.b64decode(sig+"=") - except: - pass + try: + sig = base64.urlsafe_b64decode(b64pad(sig)) + except ValueError: try: - sig = base64.b64decode(sig+"==") - except: - pass - else: - cprintc("Signature not Base64 encoded HEX", "red") + sig = base64.b64decode(b64pad(sig)) + except ValueError: + cprintc("Signature not Base64 encoded HEX", "red") if headDict['alg'] == "RS256": h = SHA256.new(newContents) elif headDict['alg'] == "RS384": @@ -919,50 +903,33 @@ def verifyTokenRSA(headDict, paylDict, sig, pubKey): def verifyTokenEC(headDict, paylDict, sig, pubKey): newContents = genContents(headDict, paylDict) message = newContents.encode('UTF-8') - if "-" in str(sig): - try: - signature = base64.urlsafe_b64decode(sig) - except: - pass - try: - signature = base64.urlsafe_b64decode(sig+"=") - except: - pass - try: - signature = base64.urlsafe_b64decode(sig+"==") - except: - pass - elif "+" in str(sig): - try: - signature = base64.b64decode(sig) - except: - pass - try: - signature = base64.b64decode(sig+"=") - except: - pass + try: + sig = base64.urlsafe_b64decode(b64pad(sig)) + except ValueError: try: - signature = base64.b64decode(sig+"==") - except: - pass - else: - cprintc("Signature not Base64 encoded HEX", "red") + sig = base64.b64decode(b64pad(sig)) + except ValueError: + cprintc("Signature not Base64 encoded HEX", "red") + if headDict['alg'] == "ES256": - h = SHA256.new(message) + h, curvename = SHA256.new(message), 'P-256' elif headDict['alg'] == "ES384": - h = SHA384.new(message) + h, curvename = SHA384.new(message), 'P-384' elif headDict['alg'] == "ES512": - h = SHA512.new(message) + h, curvename = SHA512.new(message), 'P-521' else: cprintc("Invalid ECDSA algorithm", "red") pubkey = open(pubKey, "r") pub_key = ECC.import_key(pubkey.read()) + cprintc("[ ] loaded ECC pubkey on the curve {}".format(pub_key.curve), "cyan") + assert pub_key.curve == 'NIST ' + curvename, "Key on unexpected curve loaded" + verifier = DSS.new(pub_key, 'fips-186-3') try: - verifier.verify(h, signature) + verifier.verify(h, sig) cprintc("ECC Signature is VALID", "green") valid = True - except: + except ValueError: cprintc("ECC Signature is INVALID", "red") valid = False return valid @@ -971,34 +938,13 @@ def verifyTokenPSS(headDict, paylDict, sig, pubKey): key = RSA.importKey(open(pubKey).read()) newContents = genContents(headDict, paylDict) newContents = newContents.encode('UTF-8') - if "-" in sig: - try: - sig = base64.urlsafe_b64decode(sig) - except: - pass - try: - sig = base64.urlsafe_b64decode(sig+"=") - except: - pass - try: - sig = base64.urlsafe_b64decode(sig+"==") - except: - pass - elif "+" in sig: - try: - sig = base64.b64decode(sig) - except: - pass - try: - sig = base64.b64decode(sig+"=") - except: - pass + try: + sig = base64.urlsafe_b64decode(b64pad(sig)) + except ValueError: try: - sig = base64.b64decode(sig+"==") - except: - pass - else: - cprintc("Signature not Base64 encoded HEX", "red") + sig = base64.b64decode(b64pad(sig)) + except ValueError: + cprintc("Signature not Base64 encoded HEX", "red") if headDict['alg'] == "PS256": h = SHA256.new(newContents) elif headDict['alg'] == "PS384": @@ -1027,100 +973,57 @@ def exportJWKS(jku): return newContents, newSig def parseJWKS(jwksfile): - jwks = open(jwksfile, "r").read() - jwksDict = json.loads(jwks, object_pairs_hook=OrderedDict) + jwksDict = json.load(open(jwksfile, 'r'), object_pairs_hook=OrderedDict) nowtime = int(datetime.now().timestamp()) cprintc("JWKS Contents:", "cyan") try: keyLen = len(jwksDict["keys"]) cprintc("Number of keys: "+str(keyLen), "cyan") - i = -1 - for jkey in range(0,keyLen): - i += 1 + kids_seen = set() + new_kid = lambda: 1 + max([x for x in kids_seen if isinstance(x, int)], default=0) + any1valid = False + for d in jwksDict["keys"]: cprintc("\n--------", "white") - try: - cprintc("Key "+str(i+1), "cyan") - kid = str(jwksDict["keys"][i]["kid"]) - cprintc("kid: "+kid, "cyan") - except: - kid = i - cprintc("Key "+str(i+1), "cyan") - for keyVal in jwksDict["keys"][i].items(): - keyVal = keyVal[0] - cprintc("[+] "+keyVal+" = "+str(jwksDict["keys"][i][keyVal]), "green") - try: - x = str(jwksDict["keys"][i]["x"]) - y = str(jwksDict["keys"][i]["y"]) - cprintc("\nFound ECC key factors, generating a public key", "cyan") - pubkeyName = genECPubFromJWKS(x, y, kid, nowtime) - cprintc("[+] "+pubkeyName, "green") - cprintc("\nAttempting to verify token using "+pubkeyName, "cyan") - valid = verifyTokenEC(headDict, paylDict, sig, pubkeyName) - except: - pass - try: - n = str(jwksDict["keys"][i]["n"]) - e = str(jwksDict["keys"][i]["e"]) - cprintc("\nFound RSA key factors, generating a public key", "cyan") - pubkeyName = genRSAPubFromJWKS(n, e, kid, nowtime) - cprintc("[+] "+pubkeyName, "green") - cprintc("\nAttempting to verify token using "+pubkeyName, "cyan") - valid = verifyTokenRSA(headDict, paylDict, sig, pubkeyName) - except: - pass - except: + kid = d['kid'] if 'kid' in d else new_kid() + kids_seen.add(kid) + cprintc(f"Key kid {kid}", "cyan") + for k, v in d.items(): + cprintc(f"[+] {k} = {v}", "green") + if parseSingleJWK(d, nowtime, kid=kid): + any1valid = True + return any1valid + except ValueError: cprintc("Single key file", "white") for jkey in jwksDict: cprintc("[+] "+jkey+" = "+str(jwksDict[jkey]), "green") + return parseSingleJWK(jwksDict, nowtime) + +def parseSingleJWK(jwksDict, nowtime, kid=1): try: - kid = 1 x = str(jwksDict["x"]) y = str(jwksDict["y"]) cprintc("\nFound ECC key factors, generating a public key", "cyan") - pubkeyName = genECPubFromJWKS(x, y, kid, nowtime) + pubkeyName = genECPubFromJWKS(x, y, kid, nowtime, curve=jwksDict.get('crv')) cprintc("[+] "+pubkeyName, "green") cprintc("\nAttempting to verify token using "+pubkeyName, "cyan") - valid = verifyTokenEC(headDict, paylDict, sig, pubkeyName) - except: + return verifyTokenEC(headDict, paylDict, sig, pubkeyName) + except KeyError: pass try: - kid = 1 n = str(jwksDict["n"]) e = str(jwksDict["e"]) cprintc("\nFound RSA key factors, generating a public key", "cyan") pubkeyName = genRSAPubFromJWKS(n, e, kid, nowtime) cprintc("[+] "+pubkeyName, "green") cprintc("\nAttempting to verify token using "+pubkeyName, "cyan") - valid = verifyTokenRSA(headDict, paylDict, sig, pubkeyName) + return verifyTokenRSA(headDict, paylDict, sig, pubkeyName) except: pass -def genECPubFromJWKS(x, y, kid, nowtime): - try: - x = int.from_bytes(base64.urlsafe_b64decode(x), byteorder='big') - except: - pass - try: - x = int.from_bytes(base64.urlsafe_b64decode(x+"="), byteorder='big') - except: - pass - try: - x = int.from_bytes(base64.urlsafe_b64decode(x+"=="), byteorder='big') - except: - pass - try: - y = int.from_bytes(base64.urlsafe_b64decode(y), byteorder='big') - except: - pass - try: - y = int.from_bytes(base64.urlsafe_b64decode(y+"="), byteorder='big') - except: - pass - try: - y = int.from_bytes(base64.urlsafe_b64decode(y+"=="), byteorder='big') - except: - pass - new_key = ECC.construct(curve='P-256', point_x=x, point_y=y) +def genECPubFromJWKS(x, y, kid, nowtime, curve=None): + x = int.from_bytes(base64.urlsafe_b64decode(b64pad(x)), byteorder='big') + y = int.from_bytes(base64.urlsafe_b64decode(b64pad(y)), byteorder='big') + new_key = ECC.construct(curve=curve or 'P-256', point_x=x, point_y=y) pubKey = new_key.public_key().export_key(format="PEM")+"\n" pubkeyName = "kid_"+str(kid)+"_"+str(nowtime)+".pem" with open(pubkeyName, 'w') as test_pub_out: @@ -1128,30 +1031,8 @@ def genECPubFromJWKS(x, y, kid, nowtime): return pubkeyName def genRSAPubFromJWKS(n, e, kid, nowtime): - try: - n = int.from_bytes(base64.urlsafe_b64decode(n), byteorder='big') - except: - pass - try: - n = int.from_bytes(base64.urlsafe_b64decode(n+"="), byteorder='big') - except: - pass - try: - n = int.from_bytes(base64.urlsafe_b64decode(n+"=="), byteorder='big') - except: - pass - try: - e = int.from_bytes(base64.urlsafe_b64decode(e), byteorder='big') - except: - pass - try: - e = int.from_bytes(base64.urlsafe_b64decode(e+"="), byteorder='big') - except: - pass - try: - e = int.from_bytes(base64.urlsafe_b64decode(e+"=="), byteorder='big') - except: - pass + n = int.from_bytes(base64.urlsafe_b64decode(b64pad(n)), byteorder='big') + e = int.from_bytes(base64.urlsafe_b64decode(b64pad(e)), byteorder='big') new_key = RSA.construct((n, e)) pubKey = new_key.publickey().exportKey(format="PEM") pubkeyName = "kid_"+str(kid)+"_"+str(nowtime)+".pem" @@ -1786,8 +1667,20 @@ def runActions(): else: cprintc("Algorithm not supported for verification", "red") exit(1) + elif args.jwksfile: parseJWKS(config['crypto']['jwks']) + + elif args.jwksurl: + resp = requests.get(args.jwksurl) + assert resp.ok + + with tempfile.NamedTemporaryFile() as tmp: + tmp.write(resp.content) + tmp.flush() + tmp.seek(0) + valid = parseJWKS(tmp.name) + exit(0 if valid else 1) else: cprintc("No Public Key or JWKS file provided (-pk/-jw)\n", "red") parser.print_usage() @@ -1908,8 +1801,6 @@ def printLogo(): os.rename(configFileName, path+"/old_("+config['services']['jwt_tool_version']+")_jwtconf.ini") createConfig() exit(1) - with open(path+"/null.txt", 'w') as nullfile: - pass findJWT = "" if args.request: