Building a Custom Authentication Layer in Axum
Ethan Miller
Product Engineer · Leapcell

Introduction
In the world of modern web services, securing your APIs is paramount. Whether you're building a microservice or a monolithic application, ensuring that only authorized users or services can access certain endpoints is a fundamental requirement. JSON Web Tokens (JWTs) and API Keys are two prevalent methods for achieving this, offering lightweight and stateless ways to verify identity. While Axum, Rust's increasingly popular web framework, provides excellent building blocks, implementing custom authentication often requires a deeper understanding of its middleware architecture. This article will guide you through the process of crafting a reusable authentication "Layer" in Axum from the ground up, capable of handling both JWT and API Key validation. We'll explore the design principles, walk through the implementation details, and demonstrate how to integrate it seamlessly into your Axum applications.
Understanding the Core Concepts
Before we dive into the code, let's briefly define some key terms that are central to our discussion:
- Axum Request Handling: Axum handles incoming HTTP requests through a chain of services. Each service can process the request, modify it, or pass it to the next service in the chain.
- Tower Service Trait: At the heart of Axum's middleware system is the tower::Service trait. It defines a generic asynchronous operation that takes a request and returns a response.
- Tower Layer Trait: A tower::Layer is a factory for tower::Service instances. It wraps an inner service, allowing you to add logic before or after the inner service is executed. This is precisely what we'll use to inject our authentication logic.
- JSON Web Token (JWT): A self-contained, compact, and URL-safe means of representing claims to be transferred between two parties. JWTs are often used for authentication, where the server issues a token to a client after successful login, and the client then sends this token with subsequent requests to prove its identity.
- API Key: A unique identifier provided by an API provider to a consumer to access their API. API Keys are typically passed in request headers or as query parameters. While simpler, they offer less fine-grained control and are generally less secure than JWTs for user authentication without additional mechanisms.
- Authentication vs. Authorization: Authentication is about who you are (verifying identity), while authorization is about what you can do (determining access rights). Our Layer will primarily focus on authentication.
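To make the relationship between a service and a layer concrete, here is a minimal std-only sketch. MiniService, MiniLayer, Hello, Tagged, and TagLayer are hypothetical names invented for this illustration; tower's real traits are asynchronous and also include poll_ready, which is left out here.

```rust
// Simplified, synchronous stand-ins for tower's traits (hypothetical names).
trait MiniService {
    fn call(&mut self, request: &str) -> String;
}

trait MiniLayer<S> {
    type Service;
    fn layer(&self, inner: S) -> Self::Service;
}

// The innermost service, playing the role of a handler.
struct Hello;

impl MiniService for Hello {
    fn call(&mut self, request: &str) -> String {
        format!("hello {}", request)
    }
}

// A service that wraps another one and adds behaviour around it.
struct Tagged<S> {
    inner: S,
    tag: String,
}

impl<S: MiniService> MiniService for Tagged<S> {
    fn call(&mut self, request: &str) -> String {
        // Logic could run here before the inner service...
        let response = self.inner.call(request);
        // ...and here after it has produced a response.
        format!("[{}] {}", self.tag, response)
    }
}

// The layer only holds configuration and knows how to build the wrapper.
struct TagLayer {
    tag: String,
}

impl<S> MiniLayer<S> for TagLayer {
    type Service = Tagged<S>;

    fn layer(&self, inner: S) -> Self::Service {
        Tagged { inner, tag: self.tag.clone() }
    }
}
```

Calling TagLayer's layer method with Hello yields a Tagged<Hello>; the authentication layer built in this article has the same shape, with the wrapper deciding whether the inner service gets called at all.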
Building the Authentication Layer
Our authentication Layer will inspect incoming requests for either a JWT in the Authorization header or an API Key in a custom header (e.g., X-API-Key). If valid credentials are found, the request will proceed; otherwise, an unauthorized response will be returned.
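As a small illustration of the first step, the sketch below pulls the token out of an Authorization header value. The helper name bearer_token is an assumption made for this example and operates on a plain string rather than on axum's header types. It compares the scheme without regard to case, since HTTP authentication scheme names are case-insensitive.

```rust
/// Returns the token from an `Authorization: Bearer <token>` header value,
/// or `None` if the scheme is different or the token is empty.
/// (Hypothetical helper, not part of axum or tower.)
fn bearer_token(header_value: &str) -> Option<&str> {
    // Split "Bearer <token>" at the first space.
    let (scheme, rest) = header_value.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = rest.trim();
    if token.is_empty() {
        None
    } else {
        Some(token)
    }
}
```

Keeping this parsing separate from validation means the validator only ever sees the bare token, never the scheme prefix.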
The Authentication State
First, let's define a simple enum to represent the possible authentication outcomes:
    #[derive(Debug, Clone, PartialEq)]
    pub enum AuthStatus {
        Authenticated(String), // For user ID or other identifier
        Unauthenticated,
    }
This AuthStatus will be used to signal the result of our authentication attempt. When authenticated, we might store a user ID or other relevant information.
The Authentication Service
Next, we'll create our custom AuthService that implements the tower::Service trait. This service will wrap an inner service and perform the authentication logic.
    use axum::{
        extract::Request,
        http::{header::AUTHORIZATION, StatusCode},
        response::{IntoResponse, Response},
    };
    use std::{
        future::Future,
        pin::Pin,
        task::{Context, Poll},
    };
    use tower::{Layer, Service};

    // `Router::layer` requires the service to be `Clone`.
    #[derive(Clone)]
    pub struct AuthService<S> {
        inner: S,
        jwt_secret: String,
        api_key_header_name: String,
        valid_api_keys: Vec<String>,
    }

    impl<S> AuthService<S> {
        pub fn new(
            inner: S,
            jwt_secret: String,
            api_key_header_name: String,
            valid_api_keys: Vec<String>,
        ) -> Self {
            Self {
                inner,
                jwt_secret,
                api_key_header_name,
                valid_api_keys,
            }
        }

        // Helper function to validate a JWT. `token` is the header value
        // with the "Bearer " prefix already removed.
        fn validate_jwt(&self, token: &str) -> Option<String> {
            // In a real application, you'd verify the signature with
            // `self.jwt_secret` and check the claims.
            // For demonstration, let's assume a dummy validation.
            if token.starts_with("my_valid_jwt_") {
                // Extract user ID from the token (e.g., by decoding claims)
                Some("user123".to_string())
            } else {
                None
            }
        }

        // Helper function to validate an API key
        fn validate_api_key(&self, api_key: &str) -> Option<String> {
            if self.valid_api_keys.iter().any(|key| key == api_key) {
                // For API keys, we might return the key itself or an associated user/service ID
                Some(format!("api_user_{}", api_key))
            } else {
                None
            }
        }

        // Runs both checks against the request headers: JWT first, then API key.
        fn authenticate(&self, request: &Request) -> AuthStatus {
            // 1. Check for a JWT in the Authorization header
            let bearer = request
                .headers()
                .get(AUTHORIZATION)
                .and_then(|value| value.to_str().ok())
                .and_then(|value| value.strip_prefix("Bearer "));
            if let Some(user_id) = bearer.and_then(|token| self.validate_jwt(token)) {
                return AuthStatus::Authenticated(user_id);
            }

            // 2. Check for an API key if not already authenticated
            let api_key = request
                .headers()
                .get(self.api_key_header_name.as_str())
                .and_then(|value| value.to_str().ok());
            match api_key.and_then(|key| self.validate_api_key(key)) {
                Some(api_user) => AuthStatus::Authenticated(api_user),
                None => AuthStatus::Unauthenticated,
            }
        }
    }

    // Implement the tower::Service trait for AuthService
    impl<S> Service<Request> for AuthService<S>
    where
        S: Service<Request, Response = Response> + Send + 'static,
        S::Future: Send + 'static,
    {
        type Response = S::Response;
        type Error = S::Error;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            self.inner.poll_ready(cx)
        }

        fn call(&mut self, mut request: Request) -> Self::Future {
            // The checks only read headers, so they run synchronously here and
            // the returned future does not need to borrow or clone our state.
            match self.authenticate(&request) {
                AuthStatus::Authenticated(user_id) => {
                    // Store the authenticated user ID in request extensions
                    request.extensions_mut().insert(user_id);
                    // `poll_ready` was called on `self.inner`, so call that instance.
                    Box::pin(self.inner.call(request))
                }
                AuthStatus::Unauthenticated => {
                    // Return 401 Unauthorized without calling the inner service
                    Box::pin(async { Ok(StatusCode::UNAUTHORIZED.into_response()) })
                }
            }
        }
    }
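One detail worth considering for the API key check is how the comparison itself is done. An ordinary string comparison can stop at the first differing byte, which in principle leaks timing information about the expected key. The sketch below shows the general idea of a comparison that always visits every byte. The function names are assumptions made for this example, and a compiler may still optimize such a loop, so production code would more likely rely on a vetted crate such as subtle.

```rust
/// Compares two byte strings without returning early on the first mismatch.
/// (Illustrative only; the length check is not treated as secret here.)
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let mut diff = 0u8;
    for (x, y) in a.iter().zip(b.iter()) {
        // Accumulate differences instead of branching on them.
        diff |= x ^ y;
    }
    diff == 0
}

/// Checks a candidate against every configured key, always visiting all of them.
fn is_known_api_key(candidate: &str, valid_keys: &[String]) -> bool {
    let mut found = false;
    for key in valid_keys {
        found |= constant_time_eq(candidate.as_bytes(), key.as_bytes());
    }
    found
}
```

A helper like is_known_api_key could stand in for the membership test inside validate_api_key without changing the rest of the service.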
Important Note on Service::call and &mut self: The tower::Service::call method takes &mut self, but the future it returns cannot borrow from self. Any state the future needs (like jwt_secret or valid_api_keys) must therefore be cloned or moved into it, or the work must be done synchronously before the future is created. The same applies to the inner service: if you need to call it from inside an async block, clone it and use std::mem::replace so that the future receives the instance poll_ready was called on, while the fresh clone stays in self. Configuration that never changes can be wrapped in an Arc so clones stay cheap; a Mutex or RwLock is only needed if the state must be mutated across tasks. Keep in mind that Layer::layer runs when the router is built and that axum requires the resulting service to be Clone, so whatever it holds has to be shareable.
For many applications it is simpler to write the authentication logic as an async function and attach it with axum::middleware::from_fn, or to build a small service with tower::service_fn, both of which hide most of this boilerplate. Since the goal here is to show a raw tower::Layer and Service, we implement the traits by hand.
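The clone-and-swap pattern mentioned in tower's documentation can be sketched without any dependencies. ToyService and Wrapper below are hypothetical types invented for this illustration; ToyService models a service whose readiness belongs to one particular instance, so a fresh clone is not ready even if the original is.

```rust
/// Toy stand-in for a service whose readiness is per-instance (hypothetical).
#[derive(Debug)]
struct ToyService {
    ready: bool,
}

impl Clone for ToyService {
    // A fresh clone has not been polled yet, so it starts out not ready.
    fn clone(&self) -> Self {
        ToyService { ready: false }
    }
}

impl ToyService {
    fn poll_ready(&mut self) {
        self.ready = true;
    }

    fn call(&mut self, request: &str) -> Result<String, &'static str> {
        if !self.ready {
            return Err("called before poll_ready");
        }
        self.ready = false; // readiness is consumed by the call
        Ok(format!("handled {}", request))
    }
}

struct Wrapper {
    inner: ToyService,
}

impl Wrapper {
    /// Hands out the instance that was polled and leaves a fresh clone behind.
    fn take_ready_inner(&mut self) -> ToyService {
        let clone = self.inner.clone();
        std::mem::replace(&mut self.inner, clone)
    }
}
```

In a real middleware the value returned by take_ready_inner would be moved into the async block, while the wrapper keeps the unpolled clone for the next poll_ready and call cycle.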
The Authentication Layer
Now, we define our AuthLayer which creates instances of AuthService:
    // `Router::layer` requires the layer itself to be `Clone` as well.
    #[derive(Clone)]
    pub struct AuthLayer {
        jwt_secret: String,
        api_key_header_name: String,
        valid_api_keys: Vec<String>,
    }

    impl AuthLayer {
        pub fn new(
            jwt_secret: String,
            api_key_header_name: String,
            valid_api_keys: Vec<String>,
        ) -> Self {
            Self {
                jwt_secret,
                api_key_header_name,
                valid_api_keys,
            }
        }
    }

    // Implement the tower::Layer trait for AuthLayer
    impl<S> Layer<S> for AuthLayer {
        type Service = AuthService<S>;

        fn layer(&self, inner: S) -> Self::Service {
            AuthService::new(
                inner,
                self.jwt_secret.clone(),
                self.api_key_header_name.clone(),
                self.valid_api_keys.clone(),
            )
        }
    }
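The layer above copies its strings and key list every time it wraps a service. One possible alternative, sketched here under the assumption that the configuration never changes after startup, is to keep it behind an Arc so that wrapping (and later cloning) only bumps a reference count. AuthConfig, SharedAuthLayer, SharedAuthService, and the wrap method are hypothetical names for this illustration; wrap plays the role of Layer::layer.

```rust
use std::sync::Arc;

/// Immutable settings shared by every service the layer creates (hypothetical).
struct AuthConfig {
    jwt_secret: String,
    api_key_header_name: String,
    valid_api_keys: Vec<String>,
}

#[derive(Clone)]
struct SharedAuthLayer {
    config: Arc<AuthConfig>,
}

#[derive(Clone)]
struct SharedAuthService<S> {
    inner: S,
    config: Arc<AuthConfig>,
}

impl SharedAuthLayer {
    fn new(config: AuthConfig) -> Self {
        Self { config: Arc::new(config) }
    }

    // Mirrors `Layer::layer`: no strings are copied, only the pointer is shared.
    fn wrap<S>(&self, inner: S) -> SharedAuthService<S> {
        SharedAuthService {
            inner,
            config: Arc::clone(&self.config),
        }
    }
}
```

Every service produced this way reads the same AuthConfig, which also keeps the per-request cost low if the service is cloned often.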
Extracting Authenticated User ID
To make the authenticated user ID available to our handlers, we can create an Axum extractor:
    // If this lives in the same module as AuthService, drop the imports
    // that are already in scope there.
    use axum::{
        async_trait, // re-exported by axum 0.7
        extract::FromRequestParts,
        http::{request::Parts, StatusCode},
        response::{IntoResponse, Response},
    };

    pub struct AuthenticatedUser(pub String);

    #[async_trait]
    impl<S> FromRequestParts<S> for AuthenticatedUser
    where
        S: Send + Sync,
    {
        type Rejection = Response;

        async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
            // The layer stored the user ID as a plain `String` in the extensions.
            if let Some(user_id) = parts.extensions.get::<String>() {
                Ok(AuthenticatedUser(user_id.clone()))
            } else {
                // This should ideally not happen if the layer is applied correctly,
                // but it acts as a safeguard.
                Err(StatusCode::UNAUTHORIZED.into_response())
            }
        }
    }
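Request extensions are keyed by type, so storing a bare String means any other middleware that also inserts a String would replace the user ID. The std-only sketch below illustrates the effect with a tiny type-keyed map and shows how dedicated newtypes avoid the collision. TypeMap, UserId, and RequestId are hypothetical names for this illustration, not axum or http types.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;

/// Minimal type-keyed map, a hypothetical stand-in for request extensions.
#[derive(Default)]
struct TypeMap {
    slots: HashMap<TypeId, Box<dyn Any>>,
}

impl TypeMap {
    /// Stores `value` in the slot for its type, replacing any previous value.
    fn insert<T: Any>(&mut self, value: T) {
        self.slots.insert(TypeId::of::<T>(), Box::new(value));
    }

    /// Looks up the value stored for type `T`, if any.
    fn get<T: Any>(&self) -> Option<&T> {
        self.slots
            .get(&TypeId::of::<T>())
            .and_then(|boxed| boxed.downcast_ref::<T>())
    }
}

// Newtypes give each kind of value its own slot.
#[derive(Debug, PartialEq)]
struct UserId(String);

#[derive(Debug, PartialEq)]
struct RequestId(String);
```

Applied to this article's code, that would mean inserting a dedicated type in the service and reading the same type in the extractor, instead of relying on String.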
Integrating with Axum
Finally, let's see how to apply this layer to an Axum router:
    use axum::{routing::get, Router};

    // Handler that requires authentication
    async fn protected_handler(AuthenticatedUser(user_id): AuthenticatedUser) -> String {
        format!("Hello, authenticated user: {}!", user_id)
    }

    // Simple public handler
    async fn public_handler() -> &'static str {
        "This is a public endpoint."
    }

    #[tokio::main]
    async fn main() {
        let auth = AuthLayer::new(
            "super_secret_jwt_key".to_string(), // In reality, load from config/env
            "X-Api-Key".to_string(),
            vec!["my_secret_api_key".to_string(), "another_key".to_string()],
        );

        let app = Router::new()
            .route("/protected", get(protected_handler))
            // A layer only wraps routes registered before it, so adding
            // `/public` afterwards keeps that endpoint open.
            .route_layer(auth)
            .route("/public", get(public_handler));

        let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
            .await
            .unwrap();
        println!("Listening on http://127.0.0.1:3000");
        axum::serve(listener, app).await.unwrap();
    }
How it Works
- Request → AuthLayer: When an HTTP request arrives for a route wrapped by AuthLayer, it passes through our middleware before reaching the handler.
- AuthLayer::layer: When the router is built, the layer creates an AuthService instance that wraps the inner service (which might be another layer or the final handler).
- AuthService::call: The AuthService's call method is invoked for each request.
  - It checks the Authorization header for a "Bearer" token and attempts JWT validation.
  - If the JWT is missing or fails validation, it checks the X-Api-Key header for a predefined API key.
  - If either is valid, it inserts the derived user_id (or similar identifier) into the request's extensions using request.extensions_mut().insert().
  - It then calls the inner service, allowing the request to proceed to the handler.
  - If neither JWT nor API Key is valid, it returns a 401 Unauthorized response immediately, short-circuiting the request chain.
- AuthenticatedUser Extractor: In our protected_handler, we use the AuthenticatedUser extractor. This extractor retrieves the String (our user_id) from the request extensions. If it's present, the handler receives it; otherwise, the extractor returns a 401 Unauthorized from its own rejection flow, acting as a double-check.
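The decision order in the list above can be condensed into one small function. The sketch below is std-only and uses a plain HashMap with lowercase header names in place of axum's header map; Outcome, decide, and the two validator closures are assumptions made for this illustration.

```rust
use std::collections::HashMap;

#[derive(Debug, PartialEq)]
enum Outcome {
    /// Credentials were accepted; the request continues with this identity.
    Proceed { user_id: String },
    /// No valid credentials; the chain is short-circuited with this status.
    Reject { status: u16 },
}

/// Bearer token first, API key second, 401 otherwise.
/// `check_jwt` and `check_key` are caller-supplied validators (hypothetical).
fn decide(
    headers: &HashMap<String, String>,
    api_key_header: &str,
    check_jwt: impl Fn(&str) -> Option<String>,
    check_key: impl Fn(&str) -> Option<String>,
) -> Outcome {
    // Step 1: look for a bearer token and validate it.
    let from_jwt = headers
        .get("authorization")
        .and_then(|value| value.strip_prefix("Bearer "))
        .and_then(|token| check_jwt(token));

    // Step 2: fall back to the API key header only if step 1 found nothing.
    let user = from_jwt.or_else(|| {
        headers
            .get(api_key_header)
            .and_then(|key| check_key(key.as_str()))
    });

    match user {
        Some(user_id) => Outcome::Proceed { user_id },
        None => Outcome::Reject { status: 401 },
    }
}
```

Because the fallback sits inside or_else, the API key validator is never consulted when a valid token is present, which matches the order described above.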
Conclusion
By leveraging Axum's tower::Layer and tower::Service traits, we've successfully implemented a custom authentication layer capable of handling both JWT and API Key validation. This approach centralizes authentication logic, keeps handlers clean, and promotes code reusability. This robust middleware pattern is fundamental for building secure and maintainable web applications with Axum. Remember, a well-structured application relies on clearly defined boundaries and modular components, and custom layers are a powerful tool in achieving this.