part2: input validation

This commit is contained in:
Sander Hautvast 2022-05-30 18:31:12 +02:00
parent d395c39200
commit 9a7789a85a
4 changed files with 163 additions and 5 deletions

87
Cargo.lock generated
View file

@ -13,6 +13,15 @@ dependencies = [
"version_check", "version_check",
] ]
[[package]]
name = "aho-corasick"
version = "0.7.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e37cfd5e7657ada45f742d6e99ca5788580b5c529dc78faf11ece6dc702656f"
dependencies = [
"memchr",
]
[[package]] [[package]]
name = "ansi_term" name = "ansi_term"
version = "0.12.1" version = "0.12.1"
@ -524,6 +533,12 @@ dependencies = [
"unicode-normalization", "unicode-normalization",
] ]
[[package]]
name = "if_chain"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb56e1aa765b4b4f3aadfab769793b7087bb03a4ea4920644a6d238e2df5b9ed"
[[package]] [[package]]
name = "indexmap" name = "indexmap"
version = "1.8.1" version = "1.8.1"
@ -870,6 +885,30 @@ version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872"
[[package]]
name = "proc-macro-error"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
dependencies = [
"proc-macro-error-attr",
"proc-macro2",
"quote",
"syn",
"version_check",
]
[[package]]
name = "proc-macro-error-attr"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
dependencies = [
"proc-macro2",
"quote",
"version_check",
]
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.39" version = "1.0.39"
@ -944,6 +983,8 @@ version = "1.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d83f127d94bdbcda4c8cc2e50f6f84f4b611f69c902699ca385a39c3a75f9ff1" checksum = "d83f127d94bdbcda4c8cc2e50f6f84f4b611f69c902699ca385a39c3a75f9ff1"
dependencies = [ dependencies = [
"aho-corasick",
"memchr",
"regex-syntax", "regex-syntax",
] ]
@ -975,13 +1016,17 @@ dependencies = [
name = "rust_for_life" name = "rust_for_life"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"async-trait",
"axum", "axum",
"chrono", "chrono",
"http-body",
"serde", "serde",
"sqlx", "sqlx",
"thiserror",
"tokio", "tokio",
"tracing", "tracing",
"tracing-subscriber", "tracing-subscriber",
"validator",
] ]
[[package]] [[package]]
@ -1565,6 +1610,48 @@ dependencies = [
"percent-encoding", "percent-encoding",
] ]
[[package]]
name = "validator"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f07b0a1390e01c0fc35ebb26b28ced33c9a3808f7f9fbe94d3cc01e233bfeed5"
dependencies = [
"idna",
"lazy_static",
"regex",
"serde",
"serde_derive",
"serde_json",
"url",
"validator_derive",
]
[[package]]
name = "validator_derive"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea7ed5e8cf2b6bdd64a6c4ce851da25388a89327b17b88424ceced6bd5017923"
dependencies = [
"if_chain",
"lazy_static",
"proc-macro-error",
"proc-macro2",
"quote",
"regex",
"syn",
"validator_types",
]
[[package]]
name = "validator_types"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2ddf34293296847abfc1493b15c6e2f5d3cd19f57ad7d22673bf4c6278da329"
dependencies = [
"proc-macro2",
"syn",
]
[[package]] [[package]]
name = "valuable" name = "valuable"
version = "0.1.0" version = "0.1.0"

View file

@ -11,3 +11,7 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }
serde = "1.0" serde = "1.0"
sqlx = { version = "0.5.13", features = ["postgres", "runtime-tokio-native-tls", "chrono"] } sqlx = { version = "0.5.13", features = ["postgres", "runtime-tokio-native-tls", "chrono"] }
chrono = {version = "0.4", features = ["serde"]} chrono = {version = "0.4", features = ["serde"]}
validator = { version = "0.15", features = ["derive"] }
thiserror = "1.0.29"
http-body = "0.4.3"
async-trait = "0.1"

1
curl.txt Normal file
View file

@ -0,0 +1 @@
curl http://localhost:3000/entries -X POST -d '{"created":"2022-05-30T17:09:00.000000Z", "title":"", "author":"", "text": ""}' -v -H "Content-Type:application/json"

View file

@ -15,12 +15,17 @@
use std::{net::SocketAddr, time::Duration}; use std::{net::SocketAddr, time::Duration};
use axum::{extract::Extension, http::StatusCode, Json, Router, routing::get}; use axum::{http::StatusCode, Json, response::{IntoResponse, Response}, Router, routing::get, BoxError};
use axum::extract::{Extension, FromRequest, RequestParts, Json as ExtractJson};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde::de::DeserializeOwned;
use sqlx::postgres::{PgPool, PgPoolOptions}; use sqlx::postgres::{PgPool, PgPoolOptions};
use tracing::{debug, Level}; use tracing::{debug, Level};
use tracing_subscriber::FmtSubscriber; use tracing_subscriber::FmtSubscriber;
use thiserror::Error;
use validator::Validate;
use async_trait::async_trait;
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
@ -47,7 +52,7 @@ async fn main() {
} }
let app = Router::new() let app = Router::new()
.route("/entries", get(get_blogs)) .route("/entries", get(get_blogs).post(add_blog))
.layer(Extension(pool)); .layer(Extension(pool));
let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
@ -70,6 +75,21 @@ async fn get_blogs(Extension(pool): Extension<PgPool>) -> Result<Json<Vec<BlogEn
.map_err(internal_error) .map_err(internal_error)
} }
async fn add_blog(Extension(pool): Extension<PgPool>, ValidatedJson(blog): ValidatedJson<BlogEntry>) -> Result<Json<String>, (StatusCode, String)> {
debug!("handling BlogEntries request");
sqlx::query("insert into blog_entry (created, title, author, text) values ($1, $2, $3, $4)")
.bind(blog.created)
.bind(blog.title)
.bind(blog.author)
.bind(blog.text)
.execute(&pool)
.await
.map_err(internal_error)?;
Ok(Json("created".to_owned()))
}
/// Utility function for mapping any error into a `500 Internal Server Error` response. /// Utility function for mapping any error into a `500 Internal Server Error` response.
fn internal_error<E>(err: E) -> (StatusCode, String) fn internal_error<E>(err: E) -> (StatusCode, String)
where where
@ -78,10 +98,56 @@ fn internal_error<E>(err: E) -> (StatusCode, String)
(StatusCode::INTERNAL_SERVER_ERROR, err.to_string()) (StatusCode::INTERNAL_SERVER_ERROR, err.to_string())
} }
#[derive(Serialize, Deserialize, Clone, Debug, sqlx::FromRow)] #[derive(Serialize, Deserialize, Clone, Debug, sqlx::FromRow, Validate)]
struct BlogEntry { struct BlogEntry {
created: DateTime<Utc>, created: DateTime<Utc>,
#[validate(length(min = 10, max = 100, message = "Title length must be between 10 and 100"))]
title: String, title: String,
#[validate(email(message = "author must be a valid email address"))]
author: String, author: String,
#[validate(length(min = 10, message = "text length must be at least 10"))]
text: String, text: String,
} }
#[derive(Debug, Clone, Copy, Default)]
pub struct ValidatedJson<T>(pub T);
#[async_trait]
impl<T, B> FromRequest<B> for ValidatedJson<T>
where
T: DeserializeOwned + Validate,
B: http_body::Body + Send,
B::Data: Send,
B::Error: Into<BoxError>,
{
type Rejection = ServerError;
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
let ExtractJson(value) = ExtractJson::<T>::from_request(req).await?;
value.validate()?;
Ok(ValidatedJson(value))
}
}
#[derive(Debug, Error)]
pub enum ServerError {
#[error(transparent)]
ValidationError(#[from] validator::ValidationErrors),
#[error(transparent)]
AxumFormRejection(#[from] axum::extract::rejection::JsonRejection),
}
impl IntoResponse for ServerError {
fn into_response(self) -> Response {
match self {
ServerError::ValidationError(_) => {
let message = format!("Input validation error: [{:?}]", self).replace('\n', ", ");
(StatusCode::BAD_REQUEST, message)
}
ServerError::AxumFormRejection(_) => (StatusCode::BAD_REQUEST, self.to_string()),
}
.into_response()
}
}