Axum validated JSON extractor

Axum, the Rust web framework, is great. It lacks one thing, however: an extractor for properly validated JSON requests.

Using the `validator` crate, we can build an Axum extractor that validates the JSON request body. An example is shown below. It targets axum 0.6, whose `FromRequest` trait still takes a request body type parameter `B`.

use async_trait::async_trait;
use axum::{
    extract::{rejection::JsonRejection, FromRequest},
    http::Request,
    Json,
};
use serde::de::DeserializeOwned;
use validator::Validate;

/// Axum extractor that deserializes a JSON request body into `T` and then
/// checks it against `T`'s `validator` rules before handing it to a handler.
///
/// The wrapped value is public, so handlers can destructure it directly,
/// e.g. `ValidatedJson(payload): ValidatedJson<MyInput>`.
#[derive(Clone, Copy, Debug, Default)]
pub struct ValidatedJson<T>(pub T);

/// Extracts the body with axum's own `Json` extractor, then validates it.
///
/// A `JsonRejection` from `Json` and a `ValidationErrors` from `validate()`
/// are both converted into the `Rejection` type, so the handler only runs
/// with a payload that deserialized successfully and passed validation.
#[async_trait]
impl<T, S, B> FromRequest<S, B> for ValidatedJson<T>
where
    B: Send + 'static,
    S: Send + Sync,
    T: DeserializeOwned + Validate,
    Json<T>: FromRequest<S, B, Rejection = JsonRejection>,
{
    type Rejection = super::error::ServerError;

    async fn from_request(req: Request<B>, state: &S) -> Result<Self, Self::Rejection> {
        // Deserialization step: `?` converts a `JsonRejection` into `Self::Rejection`.
        let payload: T = Json::<T>::from_request(req, state).await?.0;

        // Validation step: only a payload that passes every rule is returned.
        match payload.validate() {
            Ok(()) => Ok(Self(payload)),
            Err(errors) => Err(errors.into()),
        }
    }
}

As for the error type, you can use anything that implements `IntoResponse`; for example, `(StatusCode, String)` works fine. I wanted something more structured, so I created a dedicated `ServerError`.

Note: keep in mind that I post-process JSON errors (digging the underlying `serde_json` error out of axum's rejection) to get more meaningful error messages. You don’t have to do the same.

/// Errors that can occur while extracting a `ValidatedJson` request body.
///
/// Both variants forward their `Display` output unchanged to the wrapped
/// error (`#[error(transparent)]`), and both can be created with `?`/`into()`
/// thanks to `#[from]`. Conversion to an HTTP response happens in the
/// `IntoResponse` impl for this type.
#[derive(Debug, Error)]
pub enum ServerError {
    /// axum's `Json` extractor rejected the request before any validation ran.
    #[error(transparent)]
    AxumJsonRejection(#[from] JsonRejection),

    /// The body deserialized into the target type but failed one or more
    /// `validator` rules.
    #[error(transparent)]
    ValidationError(#[from] validator::ValidationErrors),
}

/// JSON body sent to the client for every `ServerError`, i.e. `{"message": "..."}`.
#[derive(Serialize)]
pub struct ErrorResponse {
    /// Human-readable description of the failure. The field name is the JSON
    /// key in the response, so renaming it changes the wire format.
    message: String,
}

impl IntoResponse for ServerError {
    /// Renders the error as an HTTP response with a JSON body `{"message": "..."}`.
    ///
    /// - Validation failures → `400 Bad Request`; the validator's multi-line
    ///   report is flattened onto a single line.
    /// - JSON syntax/data errors → `400 Bad Request` with the serde message
    ///   (see `serde_json_error_response`).
    /// - Missing `Content-Type: application/json` → `400 Bad Request`.
    /// - Every other rejection (body buffering failures, and any variant axum
    ///   adds later) uses the status code and text axum assigns to it.
    fn into_response(self) -> Response {
        let (code, msg) = match self {
            ServerError::ValidationError(errors) => {
                // `ValidationErrors` prints one line per failing field; join the
                // lines so the message stays on one line.
                let message = format!("Input validation error: [{errors}]").replace('\n', ", ");
                (StatusCode::BAD_REQUEST, message)
            }
            ServerError::AxumJsonRejection(rejection) => match rejection {
                JsonRejection::JsonDataError(err) => serde_json_error_response(err),
                JsonRejection::JsonSyntaxError(err) => serde_json_error_response(err),
                JsonRejection::MissingJsonContentType(_) => (
                    StatusCode::BAD_REQUEST,
                    "Missing `Content-Type: application/json` header".to_string(),
                ),
                // `BytesRejection` is not always a server fault: exceeding the body
                // size limit, for instance, is a client error. Rather than always
                // answering 500, defer to axum's own status and message. This arm
                // also serves as the catch-all that `#[non_exhaustive]`
                // `JsonRejection` requires.
                other => (other.status(), other.body_text()),
            },
        };

        (code, Json(ErrorResponse { message: msg })).into_response()
    }
}

// attempt to extract the inner `serde_path_to_error::Error<serde_json::Error>`,
// if that succeeds we can provide a more specific error.
//
// `Json` uses `serde_path_to_error` so the error will be wrapped in `serde_path_to_error::Error`.
fn serde_json_error_response<E>(err: E) -> (StatusCode, String)
where
    E: std::error::Error + 'static,
{
    if let Some(err) = find_error_source::<serde_path_to_error::Error<serde_json::Error>>(&err) {
        let serde_json_err = err.inner();
        (
            StatusCode::BAD_REQUEST,
            format!("Invalid JSON: {}", serde_json_err.to_string()),
        )
    } else {
        (StatusCode::BAD_REQUEST, "Unknown error".to_string())
    }
}

/// Walks `err` and its chain of `source()` errors and returns the first one
/// whose concrete type is `T`.
///
/// `err` itself is checked first, then `err.source()`, then that error's
/// source, and so on. Returns `None` if the chain ends without a match.
fn find_error_source<'a, T>(err: &'a (dyn std::error::Error + 'static)) -> Option<&'a T>
where
    T: std::error::Error + 'static,
{
    std::iter::successors(Some(err), |&current| current.source())
        .find_map(|current| current.downcast_ref::<T>())
}