class JwtAuthMiddleware(BaseMiddleware):
    """Middleware that authenticates a connection from a ``token`` query parameter.

    The token is read from the connection's query string, validated with
    ``UntypedToken``, decoded with ``jwt_decode`` and the resulting user is
    stored in ``scope["user"]`` before the wrapped application is called.
    """

    def __init__(self, inner):
        """Store the wrapped application.

        Args:
            inner: The next application/middleware in the stack.
        """
        self.inner = inner

    async def __call__(self, scope, receive, send):
        """Authenticate the connection and delegate to the wrapped application.

        Args:
            scope: Connection scope mapping; ``query_string`` (bytes) is read
                from it and ``user`` is written to it on success.
            receive: Receive callable, passed through unchanged.
            send: Send callable, passed through unchanged.

        Returns:
            Whatever ``BaseMiddleware.__call__`` returns when the token is
            valid; ``None`` (without calling the wrapped application) when the
            token is missing, blank or invalid.
        """
        # Close old database connections to prevent usage of timed out connections
        close_old_connections()

        # Get the token. ``query_string`` may be absent or None, and
        # ``parse_qs`` omits blank values, so a missing/empty token simply
        # yields no entry here instead of raising KeyError.
        raw_query = scope.get("query_string") or b""
        token_values = parse_qs(raw_query.decode("utf8")).get("token")
        if not token_values:
            # No token supplied: treat it like an invalid token rather than
            # crashing the middleware with an unhandled exception.
            print("Missing token in query string")
            return None
        token = token_values[0]

        # Try to authenticate the user
        try:
            # This will automatically validate the token and raise an error
            # if the token is invalid
            UntypedToken(token)
        except (InvalidToken, TokenError) as e:
            # Token is invalid
            print(e)
            return None
        else:
            # The token is valid, decode it
            decoded_data = jwt_decode(token, settings.SECRET_KEY, algorithms=["HS256"])
            # Look up the user from the decoded payload
            # NOTE(review): assumes get_user resolves the user from the
            # token's claims -- confirm against its definition.
            scope["user"] = await get_user(validated_token=decoded_data)
        return await super().__call__(scope, receive, send)