Integrate Azure AD SSO with SAML for your FastAPI Application using PySAML2

Posted on Jan 15, 2024

Recently I had to work on an Azure AD SAML integration for a FastAPI application. I looked around the internet for a how-to and ended up stitching together a solution from multiple sources. I am writing this post as a how-to on integrating Azure AD SAML SSO into a FastAPI application so that there is a single source for others to use. I am using PySAML2, and I am sure there are better ways to do this, but this is how I did it. This integration is written for a FastAPI application, but since PySAML2 is a library separate from the web framework being used, you could use this as a guide to integrate PySAML2 with Flask or another framework.

You will need to get a metadata file from your Azure AD Administrator. It will be an XML file built for your application.

Install PySAML2 as described in the docs.

Also install xmlsec1 as stated in the prerequisites. If you are using Homebrew, you can install it by running brew install libxmlsec1.
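
Before going further, it can help to confirm that the xmlsec1 binary is actually visible to Python, since PySAML2 shells out to it for signature work. The following is a minimal sketch using only the standard library; the helper name find_xmlsec1 is my own and not part of PySAML2.

```python
import shutil


def find_xmlsec1():
    """Return the full path of the xmlsec1 binary, or None if it is not on PATH."""
    return shutil.which("xmlsec1")


if __name__ == "__main__":
    path = find_xmlsec1()
    if path is None:
        print("xmlsec1 was not found on PATH; install it before configuring PySAML2")
    else:
        print(f"xmlsec1 found at {path}")
```

If the binary lives somewhere unusual, PySAML2's configuration accepts an xmlsec_binary entry where you can give the path explicitly.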

Set up the class

All the SSO work can be done using one class.


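import base64
from dataclasses import dataclass

from fastapi import Request
from saml2 import BINDING_HTTP_POST, BINDING_HTTP_REDIRECT
from saml2.client import Saml2Client
from saml2.config import Config
from saml2.saml import name_id_from_string

@dataclass
class SAMLService:
    # Class attributes: evaluated once when the class is defined, so
    # load_saml_config() (shown further down) must be defined or imported first.
    saml_config = load_saml_config()
    saml_client = Saml2Client(config=saml_config)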

    def prepare_for_authenticate(self, redirect_url=None):
        relay_state = ""
        if redirect_url:
            # b64encode returns bytes; decode so RelayState is passed as text
            relay_state = base64.b64encode(redirect_url.encode("utf-8")).decode("ascii")
        request_id, info = self.saml_client.prepare_for_authenticate(
            relay_state=relay_state
        )
        for key, value in info["headers"]:
            if key == "Location":
                sso_redirect_url = value
                return sso_redirect_url
        return None

    def process_saml_response(self, saml_response_data):
        # Parse and process the SAML response
        authn_response = self.saml_client.parse_authn_request_response(
            saml_response_data,
            BINDING_HTTP_POST,
        )
        authn_response.get_identity()
        authn_response.get_subject()
        return authn_response

    def get_user_info(self, saml_response):
        user_info = {
            "email": saml_response.ava["Email"][0],
            "first_name": saml_response.ava["Firstname"][0],
            "last_name": saml_response.ava["Lastname"][0],
        }
        return user_info

    def is_logged_in(self, sso_name_id):
        return self.saml_client.is_logged_in(name_id=name_id_from_string(sso_name_id))

    def logout(self, sso_name_id):
        # local logout. does not logout from azure
        return self.saml_client.local_logout(name_id=name_id_from_string(sso_name_id))

    def authenticate_user(self, user_attributes):
        # placeholder to verify user roles etc
        return True

class RequiresLoginException(Exception):
    def __init__(self, redirect_url: str = None):
        self.redirect_url = redirect_url
        super().__init__()

def get_saml_service():
    return SAMLService()

def is_user_logged_in(request: Request):
    saml_service = SAMLService()
    # A missing session key means the user has never logged in.
    saml_name_id = request.session.get("saml_name_id")
    if not saml_name_id or not saml_service.is_logged_in(saml_name_id):
        raise RequiresLoginException(redirect_url=str(request.url))
    return request.session["user_info"]

And the function that loads the SAML config:


def load_saml_config():
    settings = {
        "metadata": {"local": [SSO_METADATA_FILEPATH]},
        "service": {
            "sp": {
                "endpoints": {
                    "assertion_consumer_service": [
                        (SSO_ASSERTION_CONSUMER_SERVICE_URL, BINDING_HTTP_REDIRECT),
                        (SSO_ASSERTION_CONSUMER_SERVICE_URL, BINDING_HTTP_POST),
                    ],
                },
                "allow_unsolicited": True,
                "authn_requests_signed": False,
                "logout_requests_signed": True,
                "want_assertions_signed": True,
                "want_response_signed": False,
            },
        },
    }
    saml_config = Config()
    saml_config.load(settings)
    saml_config.allow_unknown_attributes = True
    saml_config.entityid = SSO_ENTITY_ID
    return saml_config

The load_saml_config function returns a PySAML2 Config instance with the settings dict loaded. You may find the Assertion Consumer Service (ACS) URL in the metadata file. SSO_ENTITY_ID is the entity ID created during setup in the Azure AD admin account.
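
The three SSO_* constants used in load_saml_config have to come from somewhere. One option, sketched here, is to read them from environment variables and fail early if any is missing. Reading from the environment is my assumption and not something the setup above requires; SSOSettings and load_sso_settings are hypothetical names.

```python
import os
from dataclasses import dataclass
from typing import Mapping

# The constant names used in load_saml_config.
REQUIRED_VARS = (
    "SSO_METADATA_FILEPATH",
    "SSO_ASSERTION_CONSUMER_SERVICE_URL",
    "SSO_ENTITY_ID",
)


@dataclass(frozen=True)
class SSOSettings:
    metadata_filepath: str
    assertion_consumer_service_url: str
    entity_id: str


def load_sso_settings(env: Mapping[str, str] = os.environ) -> SSOSettings:
    """Read the SSO settings from a mapping, raising if any value is missing or empty."""
    missing = [name for name in REQUIRED_VARS if not env.get(name)]
    if missing:
        raise RuntimeError(f"Missing SSO settings: {', '.join(missing)}")
    return SSOSettings(
        metadata_filepath=env["SSO_METADATA_FILEPATH"],
        assertion_consumer_service_url=env["SSO_ASSERTION_CONSUMER_SERVICE_URL"],
        entity_id=env["SSO_ENTITY_ID"],
    )
```

Passing the mapping as an argument keeps the function easy to test with a plain dict instead of the real environment.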

The prepare_for_authenticate method calls the method of the same name on the SAML client, which returns a request_id and an info dictionary. The headers entry in the info dict contains a Location header holding the redirect URL with the SAML request embedded. This is the URL you redirect the user to in order to log in. If you want to preserve the URL where the authentication exception originated in the application, you can encode it and pass it as the relay_state argument. It is added to the Azure redirect URL as a RelayState query parameter, and Azure AD simply returns it unchanged as the RelayState parameter alongside the SAML response.
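
To make the RelayState round trip concrete, here is a small standalone sketch of the encode and decode steps using the same base64 calls as the class and the callback route. The function names and the example path are mine, for illustration only.

```python
import base64


def encode_relay_state(redirect_url: str) -> str:
    """Encode the URL to return to as base64 text suitable for RelayState."""
    return base64.b64encode(redirect_url.encode("utf-8")).decode("ascii")


def decode_relay_state(relay_state: str) -> str:
    """Reverse encode_relay_state on the value the IdP sends back."""
    return base64.b64decode(relay_state).decode("utf-8")


if __name__ == "__main__":
    original = "/api/reports?page=2"  # hypothetical path inside the application
    encoded = encode_relay_state(original)
    print(encoded)
    print(decode_relay_state(encoded) == original)
```

One thing to keep in mind is that the SAML bindings specification limits RelayState to 80 bytes, so for long URLs it may be safer to store the URL server-side and send only a short key.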

The process_saml_response method processes the SAML response received from Azure AD after a successful login and returns the parsed response object. The SAML attributes returned from Azure, such as the user's name and email address, are available as a dictionary on its ava attribute.
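
The attribute names in ava depend on how the claims are configured in Azure AD. The get_user_info method assumes they are called Email, Firstname and Lastname, but Azure AD's default claims use long URI names instead. The sketch below works on a plain dict shaped like ava (attribute name mapped to a list of values) and tries both spellings. The helper names and the fallback to None are my assumptions.

```python
# Namespace used by Azure AD's default claim names.
AZURE_CLAIMS = "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/"

# For each output field, the attribute names to try, in order.
ATTRIBUTE_CANDIDATES = {
    "email": ("Email", AZURE_CLAIMS + "emailaddress"),
    "first_name": ("Firstname", AZURE_CLAIMS + "givenname"),
    "last_name": ("Lastname", AZURE_CLAIMS + "surname"),
}


def first_value(ava, names):
    """Return the first value of the first attribute in names that is present."""
    for name in names:
        values = ava.get(name)
        if values:
            return values[0]
    return None


def extract_user_info(ava):
    """Build the user_info dict without raising KeyError on missing attributes."""
    return {
        field: first_value(ava, names)
        for field, names in ATTRIBUTE_CANDIDATES.items()
    }
```

Fields that are absent come back as None, so the caller can decide whether a missing email should block the login.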

The other methods do what their names and comments say. The is_user_logged_in function checks whether the user is present in the SAML cache and whether that entry is still valid. If not, a RequiresLoginException is raised, which bubbles up to the FastAPI app level and is handled as described after the FastAPI routes section below.

FastAPI Routes

On the FastAPI side, you can define routes and inject the SAMLService class as a dependency using FastAPI's dependency injection feature. The starlette.status module has the HTTP status code constants used below, so import those if you want to use them.

The callback route is where the SAMLResponse comes in after a successful login in Azure AD.

We also rely on a signed session cookie from the client to store SAML attributes and user info. Since FastAPI is built on top of Starlette, you can use Starlette’s SessionMiddleware for this.
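
For reference, this is roughly how the session middleware can be registered. Treat it as a configuration sketch: the SESSION_SECRET_KEY variable name and the fallback value are placeholders of mine, and SessionMiddleware needs the itsdangerous package installed.

```python
import os

from fastapi import FastAPI
from starlette.middleware.sessions import SessionMiddleware

app = FastAPI()

app.add_middleware(
    SessionMiddleware,
    # Placeholder fallback for local experiments only; use a real secret in production.
    secret_key=os.environ.get("SESSION_SECRET_KEY", "dev-only-insecure-key"),
    # Only send the cookie over HTTPS; set to False when testing over plain http.
    https_only=True,
)
```

Note that the cookie is signed but not encrypted, so the client can read whatever is stored in request.session even though it cannot tamper with it.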


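import base64

from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import RedirectResponse, Response
from starlette.status import (
    HTTP_205_RESET_CONTENT,
    HTTP_303_SEE_OTHER,
    HTTP_400_BAD_REQUEST,
)

# Assumption: this router is mounted under the /api prefix, which matches
# the /api/login paths used in the redirects below.
router = APIRouter()

@router.get("/login")
def saml_login(
    redirect_url: str = None, saml_service: SAMLService = Depends(get_saml_service)
):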
    authn_request_url = saml_service.prepare_for_authenticate(redirect_url=redirect_url)
    if authn_request_url:
        headers = {"Cache-Control": "no-cache, no-store", "Pragma": "no-cache"}
        return RedirectResponse(url=authn_request_url, headers=headers)
    else:
        raise HTTPException(status_code=HTTP_400_BAD_REQUEST, detail="Unable to login")

@router.post("/login/callback")
async def saml_callback(
    request: Request, saml_service: SAMLService = Depends(get_saml_service)
):
    form_data = await request.form()
    saml_response_data = form_data.get("SAMLResponse")
    if not saml_response_data:
        raise HTTPException(
            status_code=HTTP_400_BAD_REQUEST, detail="SAMLResponse not found"
        )
    saml_response = saml_service.process_saml_response(saml_response_data)
    request.session["saml_attributes"] = saml_response.ava
    request.session["saml_name_id"] = str(saml_response.name_id)
    request.session["user_info"] = saml_service.get_user_info(saml_response)
    if form_data.get("RelayState"):
        return_path = base64.b64decode(form_data.get("RelayState")).decode("utf-8")
        return RedirectResponse(return_path, status_code=HTTP_303_SEE_OTHER)
    return RedirectResponse("/api/login/user-info", status_code=HTTP_303_SEE_OTHER)


@router.get("/login/user-info")
async def check_user(request: Request):
    return is_user_logged_in(request)


@router.get("/login/logout")
async def logout_user(
    request: Request, saml_service: SAMLService = Depends(get_saml_service)
):
    saml_name_id = request.session.get("saml_name_id")
    if saml_name_id:
        # Local logout only; the Azure AD session is left untouched.
        saml_service.logout(saml_name_id)
    # Clear the session keys, tolerating ones that are already missing.
    for key in ("saml_attributes", "saml_name_id", "user_info"):
        request.session.pop(key, None)
    return Response(status_code=HTTP_205_RESET_CONTENT)

Handling RequiresLoginException:

# in main.py or wherever you define the
# FastAPI app
import urllib.parse

@app.exception_handler(RequiresLoginException)
async def exception_handler(request: Request, exc: RequiresLoginException) -> Response:
    login_path = "/api/login"
    if exc.redirect_url:
        login_path = (
            f"{login_path}?redirect_url={urllib.parse.quote_plus(exc.redirect_url)}"
        )
    return RedirectResponse(url=login_path)

SAML authentication does have one drawback: there is no way to authenticate with the Identity Provider (Azure AD in this case) without redirecting the user. If the user has to be redirected for re-authentication, the request state has to be preserved, which can be done using the RelayState parameter.
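
Because RelayState travels through the browser, the value that comes back is not guaranteed to be the one you sent. It may be worth checking the decoded URL before redirecting to it, so the callback cannot be used to bounce users to another site. The following is a sketch of one such check, not part of my original integration; is_safe_return_url and the allowed_hosts parameter are hypothetical.

```python
from urllib.parse import urlsplit


def is_safe_return_url(url: str, allowed_hosts) -> bool:
    """Accept same-site relative paths, or absolute http(s) URLs on an allowed host."""
    parts = urlsplit(url)
    if parts.scheme == "" and parts.netloc == "":
        # Relative URL: require a single leading slash and no backslashes,
        # which some browsers treat like forward slashes.
        return url.startswith("/") and not url.startswith("//") and "\\" not in url
    return parts.scheme in ("http", "https") and parts.hostname in allowed_hosts


if __name__ == "__main__":
    hosts = {"app.example.com"}  # hypothetical host name of the application
    print(is_safe_return_url("https://app.example.com/api/reports", hosts))
    print(is_safe_return_url("https://evil.example/phish", hosts))
```

In the callback route, a check like this would sit between decoding RelayState and returning the RedirectResponse, falling back to the default landing page when the check fails.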