Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion common-headers.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@ typ
jku
kid
x5u
x5t
x5t
url
253 changes: 72 additions & 181 deletions jwt_tool.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
import base64
import json
import random
from urllib.parse import urljoin, urlparse
import tempfile
import argparse
from datetime import datetime
import configparser
from http.cookies import SimpleCookie
from collections import OrderedDict
from urllib.parse import urljoin, urlparse
try:
from Cryptodome.Signature import PKCS1_v1_5, DSS, pss
from Cryptodome.Hash import SHA256, SHA384, SHA512
Expand Down Expand Up @@ -56,6 +57,10 @@ def cprintc(textval, colval):
if not args.bare:
cprint(textval, colval)

def b64pad(buf):
""" Restore stripped B64 padding """
return buf + '=' * (4 - len(buf) % 4 if len(buf) % 4 in (2, 3) else 0)

def createConfig():
privKeyName = path+"/jwttool_custom_private_RSA.pem"
pubkeyName = path+"/jwttool_custom_public_RSA.pem"
Expand Down Expand Up @@ -867,34 +872,13 @@ def verifyTokenRSA(headDict, paylDict, sig, pubKey):
key = RSA.importKey(open(pubKey).read())
newContents = genContents(headDict, paylDict)
newContents = newContents.encode('UTF-8')
if "-" in sig:
try:
sig = base64.urlsafe_b64decode(sig)
except:
pass
try:
sig = base64.urlsafe_b64decode(sig+"=")
except:
pass
try:
sig = base64.urlsafe_b64decode(sig+"==")
except:
pass
elif "+" in sig:
try:
sig = base64.b64decode(sig)
except:
pass
try:
sig = base64.b64decode(sig+"=")
except:
pass
try:
sig = base64.urlsafe_b64decode(b64pad(sig))
except ValueError:
try:
sig = base64.b64decode(sig+"==")
except:
pass
else:
cprintc("Signature not Base64 encoded HEX", "red")
sig = base64.b64decode(b64pad(sig))
except ValueError:
cprintc("Signature not Base64 encoded HEX", "red")
if headDict['alg'] == "RS256":
h = SHA256.new(newContents)
elif headDict['alg'] == "RS384":
Expand All @@ -919,50 +903,33 @@ def verifyTokenRSA(headDict, paylDict, sig, pubKey):
def verifyTokenEC(headDict, paylDict, sig, pubKey):
newContents = genContents(headDict, paylDict)
message = newContents.encode('UTF-8')
if "-" in str(sig):
try:
signature = base64.urlsafe_b64decode(sig)
except:
pass
try:
signature = base64.urlsafe_b64decode(sig+"=")
except:
pass
try:
signature = base64.urlsafe_b64decode(sig+"==")
except:
pass
elif "+" in str(sig):
try:
signature = base64.b64decode(sig)
except:
pass
try:
signature = base64.b64decode(sig+"=")
except:
pass
try:
sig = base64.urlsafe_b64decode(b64pad(sig))
except ValueError:
try:
signature = base64.b64decode(sig+"==")
except:
pass
else:
cprintc("Signature not Base64 encoded HEX", "red")
sig = base64.b64decode(b64pad(sig))
except ValueError:
cprintc("Signature not Base64 encoded HEX", "red")

if headDict['alg'] == "ES256":
h = SHA256.new(message)
h, curvename = SHA256.new(message), 'P-256'
elif headDict['alg'] == "ES384":
h = SHA384.new(message)
h, curvename = SHA384.new(message), 'P-384'
elif headDict['alg'] == "ES512":
h = SHA512.new(message)
h, curvename = SHA512.new(message), 'P-521'
else:
cprintc("Invalid ECDSA algorithm", "red")
pubkey = open(pubKey, "r")
pub_key = ECC.import_key(pubkey.read())
cprintc("[ ] loaded ECC pubkey on the curve {}".format(pub_key.curve), "cyan")
assert pub_key.curve == 'NIST ' + curvename, "Key on unexpected curve loaded"

verifier = DSS.new(pub_key, 'fips-186-3')
try:
verifier.verify(h, signature)
verifier.verify(h, sig)
cprintc("ECC Signature is VALID", "green")
valid = True
except:
except ValueError:
cprintc("ECC Signature is INVALID", "red")
valid = False
return valid
Expand All @@ -971,34 +938,13 @@ def verifyTokenPSS(headDict, paylDict, sig, pubKey):
key = RSA.importKey(open(pubKey).read())
newContents = genContents(headDict, paylDict)
newContents = newContents.encode('UTF-8')
if "-" in sig:
try:
sig = base64.urlsafe_b64decode(sig)
except:
pass
try:
sig = base64.urlsafe_b64decode(sig+"=")
except:
pass
try:
sig = base64.urlsafe_b64decode(sig+"==")
except:
pass
elif "+" in sig:
try:
sig = base64.b64decode(sig)
except:
pass
try:
sig = base64.b64decode(sig+"=")
except:
pass
try:
sig = base64.urlsafe_b64decode(b64pad(sig))
except ValueError:
try:
sig = base64.b64decode(sig+"==")
except:
pass
else:
cprintc("Signature not Base64 encoded HEX", "red")
sig = base64.b64decode(b64pad(sig))
except ValueError:
cprintc("Signature not Base64 encoded HEX", "red")
if headDict['alg'] == "PS256":
h = SHA256.new(newContents)
elif headDict['alg'] == "PS384":
Expand Down Expand Up @@ -1027,131 +973,66 @@ def exportJWKS(jku):
return newContents, newSig

def parseJWKS(jwksfile):
jwks = open(jwksfile, "r").read()
jwksDict = json.loads(jwks, object_pairs_hook=OrderedDict)
jwksDict = json.load(open(jwksfile, 'r'), object_pairs_hook=OrderedDict)
nowtime = int(datetime.now().timestamp())
cprintc("JWKS Contents:", "cyan")
try:
keyLen = len(jwksDict["keys"])
cprintc("Number of keys: "+str(keyLen), "cyan")
i = -1
for jkey in range(0,keyLen):
i += 1
kids_seen = set()
new_kid = lambda: 1 + max([x for x in kids_seen if isinstance(x, int)], default=0)
any1valid = False
for d in jwksDict["keys"]:
cprintc("\n--------", "white")
try:
cprintc("Key "+str(i+1), "cyan")
kid = str(jwksDict["keys"][i]["kid"])
cprintc("kid: "+kid, "cyan")
except:
kid = i
cprintc("Key "+str(i+1), "cyan")
for keyVal in jwksDict["keys"][i].items():
keyVal = keyVal[0]
cprintc("[+] "+keyVal+" = "+str(jwksDict["keys"][i][keyVal]), "green")
try:
x = str(jwksDict["keys"][i]["x"])
y = str(jwksDict["keys"][i]["y"])
cprintc("\nFound ECC key factors, generating a public key", "cyan")
pubkeyName = genECPubFromJWKS(x, y, kid, nowtime)
cprintc("[+] "+pubkeyName, "green")
cprintc("\nAttempting to verify token using "+pubkeyName, "cyan")
valid = verifyTokenEC(headDict, paylDict, sig, pubkeyName)
except:
pass
try:
n = str(jwksDict["keys"][i]["n"])
e = str(jwksDict["keys"][i]["e"])
cprintc("\nFound RSA key factors, generating a public key", "cyan")
pubkeyName = genRSAPubFromJWKS(n, e, kid, nowtime)
cprintc("[+] "+pubkeyName, "green")
cprintc("\nAttempting to verify token using "+pubkeyName, "cyan")
valid = verifyTokenRSA(headDict, paylDict, sig, pubkeyName)
except:
pass
except:
kid = d['kid'] if 'kid' in d else new_kid()
kids_seen.add(kid)
cprintc(f"Key kid {kid}", "cyan")
for k, v in d.items():
cprintc(f"[+] {k} = {v}", "green")
if parseSingleJWK(d, nowtime, kid=kid):
any1valid = True
return any1valid
except ValueError:
cprintc("Single key file", "white")
for jkey in jwksDict:
cprintc("[+] "+jkey+" = "+str(jwksDict[jkey]), "green")
return parseSingleJWK(jwksDict, nowtime)

def parseSingleJWK(jwksDict, nowtime, kid=1):
try:
kid = 1
x = str(jwksDict["x"])
y = str(jwksDict["y"])
cprintc("\nFound ECC key factors, generating a public key", "cyan")
pubkeyName = genECPubFromJWKS(x, y, kid, nowtime)
pubkeyName = genECPubFromJWKS(x, y, kid, nowtime, curve=jwksDict.get('crv'))
cprintc("[+] "+pubkeyName, "green")
cprintc("\nAttempting to verify token using "+pubkeyName, "cyan")
valid = verifyTokenEC(headDict, paylDict, sig, pubkeyName)
except:
return verifyTokenEC(headDict, paylDict, sig, pubkeyName)
except KeyError:
pass
try:
kid = 1
n = str(jwksDict["n"])
e = str(jwksDict["e"])
cprintc("\nFound RSA key factors, generating a public key", "cyan")
pubkeyName = genRSAPubFromJWKS(n, e, kid, nowtime)
cprintc("[+] "+pubkeyName, "green")
cprintc("\nAttempting to verify token using "+pubkeyName, "cyan")
valid = verifyTokenRSA(headDict, paylDict, sig, pubkeyName)
return verifyTokenRSA(headDict, paylDict, sig, pubkeyName)
except:
pass

def genECPubFromJWKS(x, y, kid, nowtime):
try:
x = int.from_bytes(base64.urlsafe_b64decode(x), byteorder='big')
except:
pass
try:
x = int.from_bytes(base64.urlsafe_b64decode(x+"="), byteorder='big')
except:
pass
try:
x = int.from_bytes(base64.urlsafe_b64decode(x+"=="), byteorder='big')
except:
pass
try:
y = int.from_bytes(base64.urlsafe_b64decode(y), byteorder='big')
except:
pass
try:
y = int.from_bytes(base64.urlsafe_b64decode(y+"="), byteorder='big')
except:
pass
try:
y = int.from_bytes(base64.urlsafe_b64decode(y+"=="), byteorder='big')
except:
pass
new_key = ECC.construct(curve='P-256', point_x=x, point_y=y)
def genECPubFromJWKS(x, y, kid, nowtime, curve=None):
x = int.from_bytes(base64.urlsafe_b64decode(b64pad(x)), byteorder='big')
y = int.from_bytes(base64.urlsafe_b64decode(b64pad(y)), byteorder='big')
new_key = ECC.construct(curve=curve or 'P-256', point_x=x, point_y=y)
pubKey = new_key.public_key().export_key(format="PEM")+"\n"
pubkeyName = "kid_"+str(kid)+"_"+str(nowtime)+".pem"
with open(pubkeyName, 'w') as test_pub_out:
test_pub_out.write(pubKey)
return pubkeyName

def genRSAPubFromJWKS(n, e, kid, nowtime):
try:
n = int.from_bytes(base64.urlsafe_b64decode(n), byteorder='big')
except:
pass
try:
n = int.from_bytes(base64.urlsafe_b64decode(n+"="), byteorder='big')
except:
pass
try:
n = int.from_bytes(base64.urlsafe_b64decode(n+"=="), byteorder='big')
except:
pass
try:
e = int.from_bytes(base64.urlsafe_b64decode(e), byteorder='big')
except:
pass
try:
e = int.from_bytes(base64.urlsafe_b64decode(e+"="), byteorder='big')
except:
pass
try:
e = int.from_bytes(base64.urlsafe_b64decode(e+"=="), byteorder='big')
except:
pass
n = int.from_bytes(base64.urlsafe_b64decode(b64pad(n)), byteorder='big')
e = int.from_bytes(base64.urlsafe_b64decode(b64pad(e)), byteorder='big')
new_key = RSA.construct((n, e))
pubKey = new_key.publickey().exportKey(format="PEM")
pubkeyName = "kid_"+str(kid)+"_"+str(nowtime)+".pem"
Expand Down Expand Up @@ -1786,8 +1667,20 @@ def runActions():
else:
cprintc("Algorithm not supported for verification", "red")
exit(1)

elif args.jwksfile:
parseJWKS(config['crypto']['jwks'])

elif args.jwksurl:
resp = requests.get(args.jwksurl)
assert resp.ok

with tempfile.NamedTemporaryFile() as tmp:
tmp.write(resp.content)
tmp.flush()
tmp.seek(0)
valid = parseJWKS(tmp.name)
exit(0 if valid else 1)
else:
cprintc("No Public Key or JWKS file provided (-pk/-jw)\n", "red")
parser.print_usage()
Expand Down Expand Up @@ -1908,8 +1801,6 @@ def printLogo():
os.rename(configFileName, path+"/old_("+config['services']['jwt_tool_version']+")_jwtconf.ini")
createConfig()
exit(1)
with open(path+"/null.txt", 'w') as nullfile:
pass
findJWT = ""

if args.request:
Expand Down