from litestar.openapi import ResponseSpec
from webauthn_handlers import try_authenticate_user
from dataclasses import dataclass
from typing import Annotated, Any, cast
import typing
import re
import pyotp
from argon2.exceptions import VerifyMismatchError
from litestar import Response, post, patch, get
from litestar.connection import ASGIConnection, Request
from litestar.exceptions import (
ClientException,
NotAuthorizedException,
NotFoundException,
PermissionDeniedException,
)
from litestar.middleware import AbstractAuthenticationMiddleware, AuthenticationResult
import jwt
from argon2 import PasswordHasher
from datetime import datetime, timezone, timedelta
from litestar.params import Body, Parameter
from advanced_alchemy.extensions.litestar import SQLAlchemyPlugin
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from models.user import User
from responses import SuccessResponse
from config import config
from dto.read_dto import UserReadDTO
# Pattern used by _parse_totp_secret to validate stored TOTP secrets:
# upper-case base32 alphabet (A-Z, 2-7), optionally followed by "=" padding.
_BASE32_RE = re.compile(r"^[A-Z2-7]+=*$")
# Algorithm used both to sign (create_jwt) and to verify (verify_jwt) auth tokens.
JWT_ALGORITHM = "HS256"
# Name of the cookie carrying the JWT; AuthenticationMiddleware also accepts a
# request header with this name as a fallback.
AUTH_COOKIE_NAME = "token"
# Prefix marking a TOTP secret whose setup has not been confirmed yet
# (stored as "pending:<secret>" by totp_setup, stripped by totp_confirm).
TOTP_PENDING_PREFIX = "pending:"
# TODO: make configurable!
# NOTE(review): not referenced in this module - create_jwt/verify_jwt read
# config.jwt_config.secret instead. Check for external importers before removing.
JWT_SECRET = "secretfortesting"
# how long tokens are valid. If the time has passed, users are automatically getting logged out
# NOTE(review): not referenced in this module - login_handler uses
# config.jwt_config.validity_duration_hours instead.
JWT_VALIDITY_DURATION_HOURS = 7 * 24 # 1 week
def _normalize_totp_code(code: str | None) -> str | None:
"""
Normalize user-entered TOTP codes.
Users often type codes with spaces or hyphens (e.g. "123 456" or "123-456").
We strip whitespace and hyphens so verification works with these inputs.
"""
if not code:
return None
return re.sub(r"[\s-]", "", code) or None
def _parse_totp_secret(raw: str | None) -> tuple[str | None, bool]:
    """
    Split a stored TOTP secret into its value and its setup state.

    :param raw: the stored ``two_fa_secret`` value; it carries the
        ``TOTP_PENDING_PREFIX`` while the setup is still unconfirmed
    :return: ``(secret, pending)``. ``secret`` is the value with spaces removed
        and upper-cased, or ``None`` if it is missing or not valid base32.
        ``pending`` tells whether the pending prefix was present.
    """
    if not raw:
        return None, False
    is_pending = raw.startswith(TOTP_PENDING_PREFIX)
    body = raw.removeprefix(TOTP_PENDING_PREFIX) if is_pending else raw
    candidate = "".join(body.split(" ")).upper()
    valid = _BASE32_RE.fullmatch(candidate) is not None
    return (candidate if valid else None), is_pending
[docs]
def verify_totp(secret: str, code: str) -> bool:
    """
    Check a TOTP code against a base32 secret.

    ``config.twofa_config.totp_valid_window`` is passed to pyotp as
    ``valid_window``.

    :param secret: the base32 TOTP secret
    :param code: the (already normalized) code entered by the user
    :return: whether pyotp accepts ``code`` for ``secret``
    """
    totp = pyotp.TOTP(secret)
    window = config.twofa_config.totp_valid_window
    return totp.verify(code, valid_window=window)
[docs]
def generate_totp_secret() -> str:
    """
    Create a new random secret for TOTP enrollment.

    :return: a base32 secret produced by ``pyotp.random_base32()``
    """
    new_secret = pyotp.random_base32()
    return new_secret
[docs]
def totp_provisioning_uri(secret: str, user_email: str) -> str:
    """
    Build the provisioning URI an authenticator app imports for ``secret``.

    :param secret: the base32 TOTP secret
    :param user_email: account name shown in the authenticator app
    :return: the URI from pyotp's ``provisioning_uri``, with
        ``config.twofa_config.app_name`` as issuer
    """
    totp = pyotp.TOTP(secret)
    return totp.provisioning_uri(
        name=user_email,
        issuer_name=config.twofa_config.app_name,
    )
[docs]
@dataclass
class JwtUser:
    """
    User identity carried inside the auth JWT.

    create_jwt puts these fields into the token payload; verify_jwt rebuilds an
    instance from a decoded token.
    """

    # string form of User.id (create_jwt passes str(user.id))
    id: str
    name: str
    email: str
[docs]
class AuthenticationMiddleware(AbstractAuthenticationMiddleware):
    """
    Middleware that authenticates requests with the JWT issued at login.

    The token is taken from the ``token`` cookie (``AUTH_COOKIE_NAME``). If that
    cookie is absent, a request header of the same name is used instead, which
    lets the OpenAPI docs at ``/docs`` send the token.
    """

    async def authenticate_request(
        self, connection: ASGIConnection
    ) -> AuthenticationResult:
        """
        Resolve the ``User`` behind the token sent with ``connection``.

        :param connection: the incoming connection
        :return: the loaded user together with the raw token
        :raises NotAuthorizedException: if no token was sent, the token is not
            valid (see ``verify_jwt``), or the user it names does not exist
        """
        token = connection.cookies.get(AUTH_COOKIE_NAME)
        # fallback to headers for compatibility with OpenAPI docs at `/docs`
        if not token:
            token = connection.headers.get(AUTH_COOKIE_NAME)
        if not token:
            raise NotAuthorizedException()

        jwt_user = verify_jwt(token)
        if not jwt_user:
            raise NotAuthorizedException()

        # The JWT only carries the identity. Load the user row so that a user
        # deleted after login is rejected.
        user_id = str(jwt_user.id)
        sqlalchemy_plugin = connection.app.plugins.get(SQLAlchemyPlugin)
        db_config = sqlalchemy_plugin.config[0]
        session_maker = db_config.create_session_maker()
        async with session_maker() as db_session:
            user_query = await db_session.execute(
                select(User).where(User.id == user_id)
            )
            user = user_query.scalar_one_or_none()
        if not user:
            raise NotAuthorizedException()
        return AuthenticationResult(user=user, auth=token)
[docs]
@dataclass
class LoginRequest:
    """
    Parameters sent by the user in order to login.
    """

    # username or e-mail address; get_user_by_name_or_mail tries both
    name: str
    password: str
    # TOTP code, checked by _verify_2fa_request when the user has a confirmed
    # TOTP secret; spaces and hyphens are stripped before verification
    two_fa_code: str | None = None
    # WebAuthn assertion, passed to try_authenticate_user when the user has
    # WebAuthn set up
    webauthn_response: dict[str, Any] | None = None
[docs]
@dataclass
class ChangePasswordRequest:
    """
    Body of ``PATCH /users/password/change``.
    """

    # must match the user's current password (checked with verify_password)
    old_password: str
    new_password: str
[docs]
@dataclass
class ResetPasswordRequest:
    """
    Body of ``PATCH /users/{user_id}/password/reset`` (admin-only endpoint).
    """

    new_password: str
[docs]
@dataclass
class TwoFaRequiredResponse:
    """
    Extra data of the HTTP 403 raised by ``login_handler`` when 2FA is required
    but missing or wrong. It tells the client which 2FA methods the user can use.
    """

    user_id: str
    # True when the stored TOTP secret is valid base32 and not pending
    totp_supported: bool
    # True when the user has a WebAuthn credential (user.webauthn is not None)
    webauthn_supported: bool
[docs]
async def get_user_by_name_or_mail(db_session: AsyncSession, query: str) -> User | None:
    """
    Look up a user by username, falling back to e-mail address.

    Both comparisons are case-insensitive (``lower()`` on both sides). The
    e-mail column is only searched when no user has ``query`` as username.

    :param db_session: the session to query with
    :param query: the username or e-mail address to search for
    :return: the matching user, or ``None`` if neither lookup finds one
    """

    async def _find_by(column: Any) -> User | None:
        # NOTE(review): scalar_one_or_none raises if several rows match
        # case-insensitively - presumably names/e-mails are unique; confirm.
        result = await db_session.execute(
            select(User).where(func.lower(column) == func.lower(query))
        )
        return result.scalar_one_or_none()

    by_name = await _find_by(User.name)
    if by_name:
        return by_name
    return await _find_by(User.email)
def _verify_2fa_request(user: User, request: LoginRequest) -> bool:
    """
    Check if the user provided either

    - a valid 2FA TOTP code
    - a valid WebAuthn credential

    If the user has no active 2FA method, this returns ``True``. A TOTP secret
    that is still pending (setup started but never confirmed) is not an active
    method, so an abandoned setup cannot lock the user out. If none of the
    provided 2FA options is correct (or provided), this returns ``False``.

    :param user: the user that wants to authenticate
    :param request: the data sent by the user to confirm its identity
    :return: whether the login may proceed as far as 2FA is concerned
    """
    secret, pending = _parse_totp_secret(user.two_fa_secret)
    # A stored, non-pending value requires TOTP even if it fails to parse.
    # Failing closed is safer than silently dropping 2FA for a corrupt secret.
    totp_required = bool(user.two_fa_secret) and not pending
    if not user.webauthn and not totp_required:
        return True

    # TOTP is only verifiable for a confirmed, well-formed secret
    if secret and not pending:
        code = _normalize_totp_code(request.two_fa_code)
        if code and verify_totp(secret, code):
            return True

    # user has WebAuthn set up -> a valid assertion is accepted as well
    if user.webauthn and request.webauthn_response:
        # NOTE(review): assumes try_authenticate_user is synchronous and signals
        # failure by raising ClientException - confirm in webauthn_handlers.
        try:
            try_authenticate_user(
                str(user.id), user.webauthn, request.webauthn_response
            )
        except ClientException:
            pass  # invalid assertion: fall through to rejection
        else:
            return True
    return False
def _is_local_app_url(app_url: object) -> bool:
    """
    Tell whether the configured application URL points at the local machine.

    Accepts a bare host (``"localhost"``, ``"127.0.0.1"``) as well as a full URL
    such as ``"http://localhost:8000"``. Comparing the raw value against bare
    host names only recognized the former.

    :param app_url: the configured app URL or host name
    :return: ``True`` for localhost/loopback hosts, otherwise ``False``
    """
    from urllib.parse import urlsplit

    if not app_url:
        return False
    raw = str(app_url)
    # urlsplit only detects the host after "//", so add it to bare "host[:port]" values
    try:
        host = urlsplit(raw if "//" in raw else "//" + raw).hostname
    except ValueError:  # e.g. an unbalanced "[" in an IPv6 literal
        return False
    return host in ("localhost", "127.0.0.1", "::1")


@post(
    "/login",
    return_dto=UserReadDTO,
    responses={
        403: ResponseSpec(
            data_container=TwoFaRequiredResponse,
            description="Invalid or no 2FA options provided, although required",
        )
    },
)
async def login_handler(
    data: Annotated[LoginRequest, Body(title="Login Request")], db_session: AsyncSession
) -> Response[User]:
    """
    Login to the application.

    The user is looked up by name or e-mail and the password is checked. If
    ``_verify_2fa_request`` rejects the supplied 2FA data, an HTTP 403 error is
    raised. Its extra data (``TwoFaRequiredResponse``) lists the 2FA methods
    available to the user.

    On success a JWT is stored in the http-only ``token`` cookie.

    :param data: the login data the user entered
    :param db_session: the database session
    :return: the logged-in user (serialized with ``UserReadDTO``)
    :raises ClientException: if the username/e-mail or the password is wrong
    :raises PermissionDeniedException: if required 2FA is missing or wrong
    """
    user = await get_user_by_name_or_mail(db_session, data.name)
    if user is None or not verify_password(user.password_hash, data.password):
        raise ClientException("invalid username or password")

    # check if the provided 2FA options match
    if not _verify_2fa_request(user, data):
        totp_secret, totp_pending = _parse_totp_secret(user.two_fa_secret)
        raise PermissionDeniedException(
            "missing or wrong 2fa options",
            extra=TwoFaRequiredResponse(
                user_id=str(user.id),
                totp_supported=totp_secret is not None and not totp_pending,
                webauthn_supported=user.webauthn is not None,
            ).__dict__,
        )

    jwt_token = create_jwt(user, config.jwt_config.validity_duration_hours)
    response = Response(content=user)
    is_local = _is_local_app_url(config.twofa_config.app_url)
    response.set_cookie(
        key=AUTH_COOKIE_NAME,
        value=jwt_token,
        max_age=config.jwt_config.validity_duration_hours * 60 * 60,
        samesite="lax",
        secure=not is_local,  # https needed, except for localhost
        httponly=True,  # True, so that cookies cant be read by javascript
    )
    return response
[docs]
def create_jwt(user: User, validity_hours: int) -> str:
    """
    Issue a signed JWT identifying ``user``.

    The payload holds the ``JwtUser`` fields (id, name, email) plus an ``exp``
    claim ``validity_hours`` from now; expiry is enforced by ``jwt.decode``.
    The token is signed with ``config.jwt_config.secret`` using ``JWT_ALGORITHM``.

    :param user: the user the token is issued for
    :param validity_hours: token lifetime in hours
    :return: the encoded token
    """
    identity = JwtUser(id=str(user.id), name=user.name, email=user.email)
    expires_at = datetime.now(tz=timezone.utc) + timedelta(hours=validity_hours)
    # copy the fields instead of writing "exp" into the dataclass instance itself
    payload = {**vars(identity), "exp": expires_at}
    return jwt.encode(
        payload=payload,
        key=config.jwt_config.secret,
        algorithm=JWT_ALGORITHM,
    )
[docs]
def verify_jwt(jwt_token: str) -> JwtUser | None:
    """
    Decode and validate a JWT created by ``create_jwt``.

    :param jwt_token: the raw token from the cookie or header
    :return: the identity stored in the token, or ``None`` if the token is
        malformed, wrongly signed, uses a disallowed algorithm, is expired,
        lacks an ``exp`` claim, or does not carry the expected user fields
    """
    try:
        user_info: dict[str, Any] = jwt.decode(
            jwt_token,
            config.jwt_config.secret,
            algorithms=[JWT_ALGORITHM],
            # tokens from create_jwt always expire; refuse any that do not
            options={"require": ["exp"]},
        )
    except jwt.ExpiredSignatureError:
        # possible TODO: inform user that session has expired
        return None
    except jwt.InvalidTokenError:
        # Base class of every validation error (DecodeError, InvalidAlgorithmError,
        # ImmatureSignatureError, MissingRequiredClaimError, ...). Catching only
        # DecodeError let e.g. a token with a foreign "alg" header escape as a 500.
        return None

    # Only the identity fields belong to JwtUser; "exp" and any other claim
    # are ignored instead of breaking the constructor.
    try:
        return JwtUser(
            id=user_info["id"], name=user_info["name"], email=user_info["email"]
        )
    except KeyError:
        return None
[docs]
def hash_password(password: str) -> str:
    """
    Hash ``password`` with Argon2, using ``PasswordHasher``'s default parameters.

    :param password: the plain-text password
    :return: the encoded hash, suitable for ``verify_password``
    """
    return PasswordHasher().hash(password)
[docs]
def verify_password(password_hash: str, password: str) -> bool:
    """
    Verify the password with the given Argon2 hash.

    :param password_hash: the argon2 hash of the password
    :param password: the password to check against
    :return: whether the password is correct. Also ``False`` when the stored
        hash cannot be parsed or decoded, so a corrupt record fails the login
        instead of raising a server error.
    """
    from argon2.exceptions import VerificationError

    try:
        # ``verify`` returns True or raises; there is no False return to handle
        PasswordHasher().verify(password_hash, password)
    except VerificationError:
        # VerifyMismatchError (wrong password) and other verification failures
        return False
    except ValueError:
        # argon2's InvalidHashError subclasses ValueError: stored hash is malformed
        return False
    return True
@patch("/users/password/change")
async def change_password(
    request: Request,
    db_session: AsyncSession,
    data: ChangePasswordRequest = Body(title="Change Password Request"),
) -> SuccessResponse:
    """
    Change the password of the currently authenticated user.

    The old password must be supplied and correct. On success the new password
    is hashed and stored, and ``must_change_password`` is cleared.

    :param request: the request; ``request.user`` is the authenticated user
    :param db_session: the database session
    :param data: old and new password
    :return: a success message
    :raises NotFoundException: if the authenticated user no longer exists
    :raises NotAuthorizedException: if the old password is wrong
    """
    current_id = cast(User, request.user).id
    # request.user was loaded by the middleware in its own session; fetch the
    # row through this handler's session so the change can be committed here
    result = await db_session.execute(select(User).where(User.id == current_id))
    account = result.scalar_one_or_none()
    if not account:
        raise NotFoundException("User not found")

    if not verify_password(account.password_hash, data.old_password):
        raise NotAuthorizedException("Old password is incorrect")

    account.password_hash = hash_password(data.new_password)
    account.must_change_password = False
    await db_session.commit()
    return SuccessResponse("password successfully changed")
@patch("/users/{user_id:str}/password/reset")
async def reset_password(
    request: Request,
    db_session: AsyncSession,
    user_id: str = Parameter(),
    data: ResetPasswordRequest = Body(title="Reset Password Request"),
) -> SuccessResponse:
    """
    Reset a user's authentication: set a new password and remove all 2FA.

    Only admins may call this. An admin may reset their own account but not the
    account of another admin. ``must_change_password`` is set (``change_password``
    clears it again), the TOTP secret is cleared, and the WebAuthn credential,
    if any, is deleted.

    :param request: the request; ``request.user`` is the acting user
    :param db_session: the database session
    :param user_id: ID of the user whose password should be changed
    :param data: new password
    :return: a success message
    :raises PermissionDeniedException: if the caller is not an admin, or targets
        another admin
    :raises NotFoundException: if no user has ``user_id``
    """
    actor = cast(User, request.user)
    if not actor.is_admin:
        raise PermissionDeniedException("Not allowed")

    # load user
    user_query = await db_session.execute(select(User).where(User.id == user_id))
    user = user_query.scalar_one_or_none()
    if not user:
        raise NotFoundException("User not found")
    if user.is_admin and actor.id != user.id:
        raise PermissionDeniedException(
            "admins may not reset the password of other admins"
        )
    # Non-admins were already rejected above, so no separate "self or admin"
    # check is needed here (the former one was unreachable).

    # hash and store new password
    user.password_hash = hash_password(data.new_password)
    user.must_change_password = True
    user.two_fa_secret = None
    if user.webauthn:
        await db_session.delete(user.webauthn)
    await db_session.commit()
    return SuccessResponse("password successfully changed")
def _ensure_self_or_admin(request: Request, user_id: str) -> None:
    """
    Guard for per-user endpoints: allow the user themself or any admin.

    :param request: the request whose authenticated user is checked
    :param user_id: the user ID taken from the URL path
    :raises NotAuthorizedException: if nobody is authenticated, or the
        authenticated user is neither ``user_id`` nor an admin
    """
    current = request.user
    if not current:
        raise NotAuthorizedException()
    current = typing.cast(User, current)
    # NOTE(review): a 403 (PermissionDeniedException, as in reset_password) may
    # fit an authenticated non-owner better; 401 kept for API compatibility.
    is_self = str(current.id) == user_id
    if not (is_self or current.is_admin):
        raise NotAuthorizedException()
[docs]
@dataclass
class TotpSetupResponse:
    """
    Returned by ``totp_setup``: the newly generated secret and its provisioning URI.
    """

    # base32 secret; stored with the pending prefix until totp_confirm succeeds
    secret: str
    # URI built by totp_provisioning_uri for the secret above
    otpauth_uri: str
[docs]
@dataclass
class TotpCodeRequest:
    """
    Body carrying a TOTP code, used by ``totp_confirm`` and ``totp_disable``.
    """

    # as typed by the user; spaces and hyphens are stripped before verification
    code: str
@post("/users/{user_id:str}/2fa/totp/setup")
async def totp_setup(
    request: Request,
    db_session: AsyncSession,
    user_id: str = Parameter(),
) -> Response:
    """
    Start TOTP enrollment for a user.

    A fresh secret is generated and stored with the pending prefix. It only
    becomes the active secret once ``totp_confirm`` verifies a code for it.
    Calling this again while a setup is pending replaces the pending secret.

    :param request: the request (must be the user themself or an admin)
    :param db_session: the database session
    :param user_id: ID of the user to enroll
    :return: the new secret and its provisioning URI (``TotpSetupResponse``)
    :raises NotFoundException: if the user does not exist
    :raises ClientException: if TOTP is already enabled for the user
    """
    _ensure_self_or_admin(request, user_id)
    user_query = await db_session.execute(select(User).where(User.id == user_id))
    user = user_query.scalar_one_or_none()
    if not user:
        raise NotFoundException("User not found")

    active_secret, pending = _parse_totp_secret(user.two_fa_secret)
    # Overwriting an active secret would switch TOTP off without a code, which
    # totp_disable deliberately requires; callers must disable it first.
    if active_secret and not pending:
        raise ClientException("TOTP 2FA already enabled")

    secret = generate_totp_secret()
    user.two_fa_secret = TOTP_PENDING_PREFIX + secret
    await db_session.commit()
    uri = totp_provisioning_uri(secret, user.email)
    return Response(content=TotpSetupResponse(secret=secret, otpauth_uri=uri))
@post("/users/{user_id:str}/2fa/totp/confirm")
async def totp_confirm(
    request: Request,
    db_session: AsyncSession,
    user_id: str = Parameter(),
    data: TotpCodeRequest = Body(title="TOTP Confirm Request"),
) -> SuccessResponse:
    """
    Finish TOTP enrollment by proving possession of the pending secret.

    :param request: the request (must be the user themself or an admin)
    :param db_session: the database session
    :param user_id: ID of the user whose setup is confirmed
    :param data: a code generated from the pending secret
    :return: a success message
    :raises NotFoundException: if the user does not exist
    :raises ClientException: if no valid pending setup exists or the code is wrong
    """
    _ensure_self_or_admin(request, user_id)
    lookup = await db_session.execute(select(User).where(User.id == user_id))
    target = lookup.scalar_one_or_none()
    if not target:
        raise NotFoundException("User not found")

    secret, pending = _parse_totp_secret(target.two_fa_secret)
    if not (secret and pending):
        raise ClientException("no pending TOTP setup")

    code = _normalize_totp_code(data.code)
    if not (code and verify_totp(secret, code)):
        raise ClientException("invalid 2FA code")

    # storing the bare secret (without the pending prefix) activates TOTP
    target.two_fa_secret = secret
    await db_session.commit()
    return SuccessResponse("TOTP 2FA enabled")
@post("/users/{user_id:str}/2fa/totp/disable")
async def totp_disable(
    request: Request,
    db_session: AsyncSession,
    user_id: str = Parameter(),
    data: TotpCodeRequest = Body(title="TOTP Disable Request"),
) -> SuccessResponse:
    """
    Turn TOTP off for a user, or discard an unfinished setup.

    A pending setup is cleared without a code. An active secret is only
    removed if ``data`` carries a currently valid code for it.

    :param request: the request (must be the user themself or an admin)
    :param db_session: the database session
    :param user_id: ID of the user whose TOTP is disabled
    :param data: a current TOTP code (ignored for pending setups)
    :return: a success message describing what was done
    :raises NotFoundException: if the user does not exist
    :raises ClientException: if the code is missing or wrong for an active secret
    """
    _ensure_self_or_admin(request, user_id)
    lookup = await db_session.execute(select(User).where(User.id == user_id))
    target = lookup.scalar_one_or_none()
    if not target:
        raise NotFoundException("User not found")

    secret, pending = _parse_totp_secret(target.two_fa_secret)
    if not secret:
        return SuccessResponse("TOTP 2FA already disabled")

    if not pending:
        # an active secret may only be removed with a valid current code
        code = _normalize_totp_code(data.code)
        if not code or not verify_totp(secret, code):
            raise ClientException("invalid 2FA code")

    target.two_fa_secret = None
    await db_session.commit()
    return SuccessResponse(
        "pending TOTP setup cleared" if pending else "TOTP 2FA disabled"
    )
@dataclass
class TotpConfiguredResponse:
    """
    Returned by ``totp_is_configured``.
    """

    # True when the user has a confirmed (non-pending), valid TOTP secret
    is_configured: bool


@get("/users/{user_id:str}/2fa/totp/is_configured")
async def totp_is_configured(
    request: Request,
    db_session: AsyncSession,
    user_id: str = Parameter(),
) -> TotpConfiguredResponse:
    """
    Check whether the user has 2FA via TOTP enabled.

    A secret counts as configured only if it is valid and its setup was
    confirmed. This is the same interpretation (``_parse_totp_secret``) that
    setup, confirm, disable and login use.

    :param request: the request (must be the user themself or an admin)
    :param db_session: the database session
    :param user_id: the ID of the user to check the current TOTP state for
    :return: a JSON dict describing if TOTP is enabled
    :raises ClientException: if the user does not exist
    """
    _ensure_self_or_admin(request, user_id)
    user = await db_session.get(User, user_id)
    if not user:
        raise ClientException("user does not exist")
    secret, pending = _parse_totp_secret(user.two_fa_secret)
    return TotpConfiguredResponse(is_configured=secret is not None and not pending)
@post("/logout")
async def logout() -> Response[None]:
    """
    Log out by instructing the browser to delete the auth cookie.

    NOTE(review): the JWT itself is not revoked. Nothing in this module keeps a
    denylist, so a copied token remains usable until its ``exp``.

    :return: an empty response that clears the ``token`` cookie
    """
    reply = Response(None)
    reply.delete_cookie(key=AUTH_COOKIE_NAME)
    return reply