first commit
This commit is contained in:
commit
7fc534ba8d
25 changed files with 3711 additions and 0 deletions
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
/target
|
||||
.env
|
||||
.idea
|
||||
40
.sqlx/query-4643a886e01f7111e63d0136dc63882c7186de86d3a30ca3d00b280c5cdc0ed2.json
generated
Normal file
40
.sqlx/query-4643a886e01f7111e63d0136dc63882c7186de86d3a30ca3d00b280c5cdc0ed2.json
generated
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "SELECT id, user_id, title, body FROM posts WHERE id = $1",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "user_id",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "title",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "body",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int4"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "4643a886e01f7111e63d0136dc63882c7186de86d3a30ca3d00b280c5cdc0ed2"
|
||||
}
|
||||
38
.sqlx/query-a65937cc250d94bc40f49bf1d8dbd426f9c4c233aa1e8ff416cd78998265126f.json
generated
Normal file
38
.sqlx/query-a65937cc250d94bc40f49bf1d8dbd426f9c4c233aa1e8ff416cd78998265126f.json
generated
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "SELECT id, user_id, title, body FROM posts",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "user_id",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "title",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "body",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": []
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "a65937cc250d94bc40f49bf1d8dbd426f9c4c233aa1e8ff416cd78998265126f"
|
||||
}
|
||||
2727
Cargo.lock
generated
Normal file
2727
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
27
Cargo.toml
Normal file
27
Cargo.toml
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
[package]
|
||||
name = "rustrest"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
axum = { version = "0.8", features = ["macros"] }
|
||||
axum-extra = { version = "0.10", features = ["typed-header"] }
|
||||
dotenvy = "0.15"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0.140"
|
||||
sqlx = { version = "0.8", features = ["runtime-tokio", "tls-native-tls", "postgres", "uuid", "chrono"] }
|
||||
tokio = { version = "1.45", features = ["full"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = "0.3"
|
||||
anyhow = "1.0"
|
||||
|
||||
tower-http = { version = "0.5", features = ["cors", "trace"] }
|
||||
jsonwebtoken = "9.0"
|
||||
chrono = { version = "0.4", features = ["serde"] }
|
||||
http = "1.0"
|
||||
thiserror = "1.0"
|
||||
uuid = { version = "1.5", features = ["serde", "v4"] }
|
||||
|
||||
# New dependencies for security
|
||||
argon2 = { version = "0.5", features = ["password-hash"] }
|
||||
|
||||
17
README.md
Normal file
17
README.md
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
* REST Api based on Axum 0.8, serves as a more complete example than most blogs will provide
|
||||
* Postgres database using Sqlx, including migrations
|
||||
* simple datamodel and api for reading posts for a blog
|
||||
* Has users and roles (roles not fully implemented)
|
||||
* logging
|
||||
* externalized config
|
||||
* /register stores the user (passwords hashed with argon2)
|
||||
* /login returns a JWT token
|
||||
* /posts returns all posts
|
||||
* /posts/{id} returns a post
|
||||
|
||||
| .env |
|
||||
| DATABASE_URL=postgres://postgres:...@localhost:5432/rust-axum-rest-api |
|
||||
| MAX_DB_CONNECTIONS=5 |
|
||||
| BIND_HOST=0.0.0.0:5001 |
|
||||
| JWT_SECRET=... |
|
||||
| ALLOWED_ORIGINS= |
|
||||
9
migrations/20250514121601_create_users_table.sql
Normal file
9
migrations/20250514121601_create_users_table.sql
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
-- Add migration script here
|
||||
CREATE TABLE users
|
||||
(
|
||||
id SERIAL PRIMARY KEY,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
email TEXT NOT NULL UNIQUE,
|
||||
password_hash TEXT,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
)
|
||||
8
migrations/20250514121730_create_posts_table.sql
Normal file
8
migrations/20250514121730_create_posts_table.sql
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
-- Add migration script here
|
||||
CREATE TABLE posts(
|
||||
id SERIAL PRIMARY KEY,
|
||||
user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
|
||||
title TEXT NOT NULL,
|
||||
body TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT NOW()
|
||||
)
|
||||
112
src/auth/jwt.rs
Normal file
112
src/auth/jwt.rs
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
use axum::{
|
||||
http::request::Parts,
|
||||
RequestPartsExt,
|
||||
extract::FromRequestParts,
|
||||
};
|
||||
use axum_extra::headers::Authorization;
|
||||
use axum_extra::headers::authorization::Bearer;
|
||||
use axum_extra::TypedHeader;
|
||||
use chrono::{Duration, Utc};
|
||||
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use tracing;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::services::error::AppError;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Claims {
|
||||
pub sub: String, // User ID
|
||||
pub exp: i64, // Expiration time
|
||||
pub iat: i64, // Issued at
|
||||
pub jti: String, // JWT ID (unique identifier)
|
||||
pub roles: Vec<String>, // User roles
|
||||
}
|
||||
|
||||
impl Claims {
|
||||
pub fn new(user_id: String, roles: Vec<String>, expiration: Duration) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
sub: user_id,
|
||||
iat: now.timestamp(),
|
||||
exp: (now + expiration).timestamp(),
|
||||
jti: Uuid::new_v4().to_string(),
|
||||
roles,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct JwtAuth {
|
||||
encoding_key: EncodingKey,
|
||||
decoding_key: DecodingKey,
|
||||
}
|
||||
|
||||
impl JwtAuth {
|
||||
pub fn new(secret: &[u8]) -> Self {
|
||||
Self {
|
||||
encoding_key: EncodingKey::from_secret(secret),
|
||||
decoding_key: DecodingKey::from_secret(secret),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_token(&self, claims: &Claims) -> Result<String, AppError> {
|
||||
encode(&Header::default(), claims, &self.encoding_key)
|
||||
.map_err(|_| AppError::TokenCreation)
|
||||
}
|
||||
|
||||
pub fn verify_token(&self, token: &str) -> Result<Claims, AppError> {
|
||||
// Create a validation object with default settings
|
||||
let mut validation = Validation::default();
|
||||
validation.validate_exp = true; // Verify expiration time
|
||||
validation.leeway = 0; // No leeway for exp verification (default)
|
||||
|
||||
// Decode and verify the token
|
||||
match decode::<Claims>(token, &self.decoding_key, &validation) {
|
||||
Ok(token_data) => {
|
||||
// Token is valid, return claims
|
||||
Ok(token_data.claims)
|
||||
}
|
||||
Err(e) => {
|
||||
// Log the error for debugging
|
||||
tracing::error!("Token validation error: {:?}", e);
|
||||
|
||||
// Map jsonwebtoken errors to AppError
|
||||
match e.kind() {
|
||||
jsonwebtoken::errors::ErrorKind::ExpiredSignature => Err(AppError::TokenExpired),
|
||||
jsonwebtoken::errors::ErrorKind::InvalidToken => Err(AppError::InvalidToken),
|
||||
jsonwebtoken::errors::ErrorKind::InvalidSignature => Err(AppError::InvalidToken),
|
||||
_ => Err(AppError::InvalidToken),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extractor for protected routes
|
||||
impl<S> FromRequestParts<S> for Claims
|
||||
where
|
||||
S: Send + Sync,
|
||||
{
|
||||
type Rejection = AppError;
|
||||
|
||||
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||
// Extract the token from the Authorization header
|
||||
let TypedHeader(Authorization(bearer)) = parts
|
||||
.extract::<TypedHeader<Authorization<Bearer>>>()
|
||||
.await
|
||||
.map_err(|_| AppError::MissingToken)?;
|
||||
|
||||
// Get our auth service from extensions
|
||||
let jwt_auth = parts
|
||||
.extensions
|
||||
.get::<Arc<JwtAuth>>()
|
||||
.ok_or(AppError::MissingAuthService)?;
|
||||
|
||||
// Verify the token
|
||||
let claims = jwt_auth.verify_token(bearer.token())?;
|
||||
|
||||
Ok(claims)
|
||||
}
|
||||
}
|
||||
82
src/auth/login.rs
Normal file
82
src/auth/login.rs
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
use std::sync::Arc;
|
||||
use axum::{extract::{State, Json, Extension}, http::StatusCode};
|
||||
use chrono::Duration;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::Pool;
|
||||
use sqlx::Postgres;
|
||||
|
||||
use crate::services::error::AppError;
|
||||
use crate::auth::jwt::{Claims, JwtAuth};
|
||||
use crate::models::user::User;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct LoginRequest {
|
||||
username: String,
|
||||
password: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct LoginResponse {
|
||||
access_token: String,
|
||||
token_type: String,
|
||||
expires_in: i64,
|
||||
user_id: i32,
|
||||
username: String,
|
||||
}
|
||||
|
||||
pub async fn login(
|
||||
State(jwt_auth): State<Arc<JwtAuth>>,
|
||||
Extension(pool): Extension<Pool<Postgres>>,
|
||||
Json(payload): Json<LoginRequest>,
|
||||
) -> Result<Json<LoginResponse>, AppError> {
|
||||
// Authentication against database
|
||||
let user = User::find_by_credentials(
|
||||
&payload.username,
|
||||
payload.password,
|
||||
&pool,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Create token with appropriate roles
|
||||
let expiration = Duration::minutes(15);
|
||||
let claims = Claims::new(
|
||||
user.id.to_string(),
|
||||
vec!["user".to_string()], // Default role - in real app, fetch from database
|
||||
expiration,
|
||||
);
|
||||
|
||||
let token = jwt_auth.create_token(&claims)?;
|
||||
|
||||
Ok(Json(LoginResponse {
|
||||
access_token: token,
|
||||
token_type: "Bearer".to_string(),
|
||||
expires_in: expiration.num_seconds(),
|
||||
user_id: user.id,
|
||||
username: user.username,
|
||||
}))
|
||||
}
|
||||
|
||||
// Registration endpoint
|
||||
#[derive(Deserialize)]
|
||||
pub struct RegisterRequest {
|
||||
username: String,
|
||||
email: String,
|
||||
password: String,
|
||||
}
|
||||
|
||||
pub async fn register(
|
||||
Extension(pool): Extension<Pool<Postgres>>,
|
||||
Json(payload): Json<RegisterRequest>,
|
||||
) -> Result<(StatusCode, Json<User>), AppError> {
|
||||
// Create the new user
|
||||
let new_user = crate::models::user::NewUser {
|
||||
username: payload.username,
|
||||
email: payload.email,
|
||||
password: String::new(), // Placeholder
|
||||
};
|
||||
|
||||
let user = User::create(new_user, payload.password, &pool).await?;
|
||||
|
||||
// Return the created user (without password)
|
||||
Ok((StatusCode::CREATED, Json(user)))
|
||||
}
|
||||
8
src/auth/mod.rs
Normal file
8
src/auth/mod.rs
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
pub mod rbac;
|
||||
pub mod login;
|
||||
pub mod jwt;
|
||||
pub mod password;
|
||||
|
||||
pub use login::{login, register};
|
||||
pub use password::{hash_password, verify_password};
|
||||
pub use jwt::JwtAuth;
|
||||
24
src/auth/password.rs
Normal file
24
src/auth/password.rs
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
use crate::services::error::AppError;
|
||||
use argon2::password_hash::SaltString;
|
||||
use argon2::password_hash::rand_core::OsRng;
|
||||
use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
|
||||
|
||||
pub fn hash_password(password: String) -> String {
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
let argon2 = Argon2::default();
|
||||
argon2
|
||||
.hash_password(password.as_bytes(), &salt)
|
||||
.unwrap()
|
||||
.to_string()
|
||||
}
|
||||
|
||||
pub fn verify_password(
|
||||
stored_hash: &PasswordHash<'_>,
|
||||
password: String,
|
||||
) -> anyhow::Result<bool, AppError> {
|
||||
let argon2 = Argon2::default();
|
||||
argon2
|
||||
.verify_password(password.as_bytes(), stored_hash)
|
||||
.map(|_| true)
|
||||
.map_err(|_| AppError::AuthenticationFailed)
|
||||
}
|
||||
63
src/auth/rbac.rs
Normal file
63
src/auth/rbac.rs
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
use axum::{
|
||||
extract::Request,
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use std::fmt;
|
||||
|
||||
use crate::services::error::AppError;
|
||||
use crate::auth::jwt::Claims;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum Role {
|
||||
User,
|
||||
Editor,
|
||||
Admin,
|
||||
}
|
||||
|
||||
impl fmt::Display for Role {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Role::User => write!(f, "user"),
|
||||
Role::Editor => write!(f, "editor"),
|
||||
Role::Admin => write!(f, "admin"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for Role {
|
||||
fn from(role: &str) -> Self {
|
||||
match role.to_lowercase().as_str() {
|
||||
"admin" => Role::Admin,
|
||||
"editor" => Role::Editor,
|
||||
_ => Role::User,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Simple function to check if a user has a required role
|
||||
pub fn has_role(claims: &Claims, required_role: &Role) -> bool {
|
||||
claims.roles
|
||||
.iter()
|
||||
.any(|r| Role::from(r.as_str()) == *required_role || Role::from(r.as_str()) == Role::Admin)
|
||||
}
|
||||
|
||||
// Simple function to check if a user has any of the required roles
|
||||
pub fn has_any_role(claims: &Claims, required_roles: &[Role]) -> bool {
|
||||
required_roles
|
||||
.iter()
|
||||
.any(|required| has_role(claims, required))
|
||||
}
|
||||
|
||||
// Middleware for role-based authorization
|
||||
pub async fn require_role(required_role: Role, request: Request, next: Next) -> impl IntoResponse {
|
||||
if let Some(claims) = request.extensions().get::<Claims>() {
|
||||
if has_role(claims, &required_role) {
|
||||
next.run(request).await
|
||||
} else {
|
||||
AppError::Forbidden(format!("Requires {} role", required_role)).into_response()
|
||||
}
|
||||
} else {
|
||||
AppError::Unauthorized("Not authenticated".to_string()).into_response()
|
||||
}
|
||||
}
|
||||
8
src/lib.rs
Normal file
8
src/lib.rs
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
pub mod models;
|
||||
pub mod services;
|
||||
pub mod auth;
|
||||
pub mod middleware;
|
||||
|
||||
// Ensure models are accessible
|
||||
pub use models::post::Post;
|
||||
pub use models::user::User;
|
||||
68
src/main.rs
Normal file
68
src/main.rs
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
use axum::routing::{get, post};
|
||||
use axum::{middleware, Extension, Router};
|
||||
use dotenvy::dotenv;
|
||||
use rustrest::services::posts::{get_post, get_posts};
|
||||
use sqlx::postgres::PgPoolOptions;
|
||||
use std::env;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use tokio::net::TcpListener;
|
||||
use tower_http::trace::TraceLayer;
|
||||
use tracing::Level;
|
||||
|
||||
use rustrest::auth;
|
||||
use rustrest::auth::jwt::{JwtAuth};
|
||||
use rustrest::middleware::{audit_log, auth_middleware, security_headers};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.with_max_level(Level::INFO)
|
||||
.init();
|
||||
|
||||
dotenv().ok();
|
||||
|
||||
let url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");
|
||||
let max_connections = env::var("MAX_DB_CONNECTIONS")
|
||||
.expect("MAX_DB_CONNECTIONS must be set")
|
||||
.parse()?;
|
||||
let pool = PgPoolOptions::new()
|
||||
.max_connections(max_connections)
|
||||
.connect(&url)
|
||||
.await?;
|
||||
|
||||
// JWT Authentication
|
||||
let jwt_secret = env::var("JWT_SECRET")
|
||||
.expect("JWT_SECRET must be set")
|
||||
.into_bytes();
|
||||
let jwt_auth = Arc::new(JwtAuth::new(&jwt_secret));
|
||||
|
||||
// API routes
|
||||
let api_routes = Router::new()
|
||||
// Public routes
|
||||
.route("/login", post(auth::login))
|
||||
.route("/register", post(auth::register))
|
||||
|
||||
// Protected routes
|
||||
.route("/posts", get(get_posts))
|
||||
.route("/posts/{id}", get(get_post))
|
||||
|
||||
// Apply authentication middleware to all routes
|
||||
.layer(middleware::from_fn_with_state(Arc::clone(&jwt_auth), auth_middleware))
|
||||
.layer(middleware::from_fn(audit_log))
|
||||
.layer(middleware::from_fn(security_headers)) // Security headers
|
||||
.layer(TraceLayer::new_for_http()) // Request tracing
|
||||
.layer(Extension(Arc::clone(&jwt_auth))) // JWT auth
|
||||
.layer(Extension(pool))
|
||||
.with_state(jwt_auth).into_make_service_with_connect_info::<SocketAddr>();
|
||||
|
||||
let bind_host = env::var("BIND_HOST").expect("BIND_HOST must be set");
|
||||
let addr: SocketAddr = bind_host.parse()?;
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
|
||||
println!("Server is running on {}", bind_host);
|
||||
|
||||
axum::serve(listener, api_routes).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
80
src/middleware/audit.rs
Normal file
80
src/middleware/audit.rs
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
use axum::{
|
||||
extract::{ConnectInfo, Request},
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use std::time::Instant;
|
||||
use tracing::{info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
// Audit logging middleware for security-relevant events
|
||||
pub async fn audit_log(
|
||||
ConnectInfo(addr): ConnectInfo<std::net::SocketAddr>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let start = Instant::now();
|
||||
let method = request.method().clone();
|
||||
let uri = request.uri().clone();
|
||||
let request_id = Uuid::new_v4();
|
||||
|
||||
// Extract user information if available
|
||||
let user_id = request
|
||||
.extensions()
|
||||
.get::<crate::auth::jwt::Claims>()
|
||||
.map(|claims| claims.sub.clone())
|
||||
.unwrap_or_else(|| "anonymous".to_string());
|
||||
|
||||
// Extract sensitive paths to log with more detail
|
||||
let is_sensitive_path = uri.path().contains("/login") ||
|
||||
uri.path().contains("/register") ||
|
||||
uri.path().contains("/password");
|
||||
|
||||
if is_sensitive_path {
|
||||
info!(
|
||||
target: "AUDIT",
|
||||
request_id = %request_id,
|
||||
remote_addr = %addr,
|
||||
method = %method,
|
||||
uri = %uri,
|
||||
user_id = %user_id,
|
||||
"Sensitive operation initiated"
|
||||
);
|
||||
}
|
||||
|
||||
// Process the request
|
||||
let response = next.run(request).await;
|
||||
|
||||
// Get response status for the log
|
||||
let status = response.status();
|
||||
let duration = start.elapsed();
|
||||
|
||||
// Log authentication failures and other suspicious activity
|
||||
if status.is_client_error() || status.is_server_error() {
|
||||
warn!(
|
||||
target: "AUDIT",
|
||||
request_id = %request_id,
|
||||
remote_addr = %addr,
|
||||
method = %method,
|
||||
uri = %uri,
|
||||
user_id = %user_id,
|
||||
status = %status.as_u16(),
|
||||
duration_ms = %duration.as_millis(),
|
||||
"Request failed"
|
||||
);
|
||||
} else if is_sensitive_path {
|
||||
info!(
|
||||
target: "AUDIT",
|
||||
request_id = %request_id,
|
||||
remote_addr = %addr,
|
||||
method = %method,
|
||||
uri = %uri,
|
||||
user_id = %user_id,
|
||||
status = %status.as_u16(),
|
||||
duration_ms = %duration.as_millis(),
|
||||
"Sensitive operation completed"
|
||||
);
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
50
src/middleware/auth_middleware.rs
Normal file
50
src/middleware/auth_middleware.rs
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
use std::sync::Arc;
|
||||
use axum::extract::{Request, State};
|
||||
use axum::middleware::Next;
|
||||
use axum::response::Response;
|
||||
use http::{header, StatusCode};
|
||||
use tracing::{error, info};
|
||||
use crate::auth::JwtAuth;
|
||||
|
||||
pub async fn auth_middleware(
|
||||
State(jwt_auth): State<Arc<JwtAuth>>,
|
||||
mut request: Request,
|
||||
next: Next
|
||||
) -> Response {
|
||||
if request.uri().path() == "/login" || request.uri().path() == "/register" {
|
||||
return next.run(request).await;
|
||||
}
|
||||
|
||||
let auth_header = request
|
||||
.headers()
|
||||
.get(header::AUTHORIZATION)
|
||||
.and_then(|header| header.to_str().ok());
|
||||
|
||||
match auth_header {
|
||||
Some(auth_value) if auth_value.starts_with("Bearer ") => {
|
||||
// Extract the token (remove "Bearer " prefix)
|
||||
let token = &auth_value[7..];
|
||||
|
||||
match jwt_auth.verify_token(token) {
|
||||
Ok(claims) => {
|
||||
info!("Authentication successful for user: {}", claims.sub);
|
||||
request.extensions_mut().insert(claims);
|
||||
next.run(request).await
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Token verification failed: {:?}", e);
|
||||
Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.body("Invalid token".into())
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.body("Missing or invalid Authorization header".into())
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
7
src/middleware/mod.rs
Normal file
7
src/middleware/mod.rs
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
pub mod security_headers;
|
||||
mod audit;
|
||||
mod auth_middleware;
|
||||
|
||||
pub use security_headers::security_headers;
|
||||
pub use audit::audit_log;
|
||||
pub use auth_middleware::auth_middleware;
|
||||
57
src/middleware/security_headers.rs
Normal file
57
src/middleware/security_headers.rs
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
use axum::{
|
||||
extract::Request,
|
||||
middleware::Next,
|
||||
response::IntoResponse,
|
||||
};
|
||||
use http::{HeaderMap, HeaderName, HeaderValue};
|
||||
|
||||
// Enhanced security headers
|
||||
pub async fn security_headers(request: Request, next: Next) -> impl IntoResponse {
|
||||
let mut response = next.run(request).await;
|
||||
|
||||
let headers = response.headers_mut();
|
||||
|
||||
// Prevent MIME type sniffing
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-content-type-options"),
|
||||
HeaderValue::from_static("nosniff")
|
||||
);
|
||||
|
||||
// Prevent clickjacking
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-frame-options"),
|
||||
HeaderValue::from_static("DENY")
|
||||
);
|
||||
|
||||
// Enable XSS protections
|
||||
headers.insert(
|
||||
HeaderName::from_static("x-xss-protection"),
|
||||
HeaderValue::from_static("1; mode=block")
|
||||
);
|
||||
|
||||
// Force HTTPS connections
|
||||
headers.insert(
|
||||
HeaderName::from_static("strict-transport-security"),
|
||||
HeaderValue::from_static("max-age=31536000; includeSubDomains; preload")
|
||||
);
|
||||
|
||||
// Control referrer information
|
||||
headers.insert(
|
||||
HeaderName::from_static("referrer-policy"),
|
||||
HeaderValue::from_static("strict-origin-when-cross-origin")
|
||||
);
|
||||
|
||||
// Content Security Policy
|
||||
headers.insert(
|
||||
HeaderName::from_static("content-security-policy"),
|
||||
HeaderValue::from_static("default-src 'self'; script-src 'self'; img-src 'self'; style-src 'self'; font-src 'self'; frame-ancestors 'none'; base-uri 'self'; form-action 'self'")
|
||||
);
|
||||
|
||||
// Permissions Policy
|
||||
headers.insert(
|
||||
HeaderName::from_static("permissions-policy"),
|
||||
HeaderValue::from_static("camera=(), microphone=(), geolocation=(), interest-cohort=()")
|
||||
);
|
||||
|
||||
response
|
||||
}
|
||||
2
src/models/mod.rs
Normal file
2
src/models/mod.rs
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
pub mod post;
|
||||
pub mod user;
|
||||
9
src/models/post.rs
Normal file
9
src/models/post.rs
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
pub struct Post {
|
||||
pub(crate) id: i32,
|
||||
pub(crate) user_id: Option<i32>,
|
||||
pub(crate) title: String,
|
||||
pub(crate) body: String,
|
||||
}
|
||||
152
src/models/user.rs
Normal file
152
src/models/user.rs
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
use argon2::PasswordHash;
|
||||
use crate::auth::{hash_password, verify_password};
|
||||
use crate::services::error::AppError;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::{Pool, Postgres, Row, postgres::PgRow};
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct User {
|
||||
pub id: i32,
|
||||
pub username: String,
|
||||
pub email: String,
|
||||
pub created_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl<'c> sqlx::FromRow<'c, PgRow> for User {
|
||||
fn from_row(row: &'c PgRow) -> Result<Self, sqlx::Error> {
|
||||
let id: i32 = row.try_get("id")?;
|
||||
let username: String = row.try_get("username")?;
|
||||
let email: String = row.try_get("email")?;
|
||||
|
||||
// Handle created_at which might be missing or in a different format
|
||||
let created_at: Option<DateTime<Utc>> = match row.try_get("created_at") {
|
||||
Ok(dt) => Some(dt),
|
||||
Err(_) => None,
|
||||
};
|
||||
|
||||
Ok(User {
|
||||
id,
|
||||
username,
|
||||
email,
|
||||
created_at,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct NewUser {
|
||||
pub username: String,
|
||||
pub email: String,
|
||||
#[serde(skip)]
|
||||
pub password: String, // Plain string password - used only for signing up
|
||||
}
|
||||
|
||||
// Validate email format
|
||||
fn is_valid_email(email: &str) -> bool {
|
||||
// Simple validation - should use a proper email validation library in production
|
||||
email.contains('@') && email.contains('.')
|
||||
}
|
||||
|
||||
fn is_strong_password(password: &str) -> bool {
|
||||
password.len() >= 12
|
||||
}
|
||||
|
||||
// Database functions
|
||||
impl User {
|
||||
// Create a user in the database, handling both old schema (without password_hash) and new schema
|
||||
pub async fn create(
|
||||
mut new_user: NewUser,
|
||||
password: String,
|
||||
pool: &Pool<Postgres>,
|
||||
) -> Result<Self, AppError> {
|
||||
// Validate input
|
||||
if new_user.username.len() < 3 {
|
||||
return Err(AppError::ValidationError("Username too short".to_string()));
|
||||
}
|
||||
|
||||
if !is_valid_email(&new_user.email) {
|
||||
return Err(AppError::ValidationError(
|
||||
"Invalid email format".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
if !is_strong_password(&password) {
|
||||
return Err(AppError::ValidationError(
|
||||
"Password must be at least 12 characters".to_string()
|
||||
));
|
||||
}
|
||||
|
||||
new_user.password = password;
|
||||
|
||||
let user =
|
||||
// Insert with password hash
|
||||
sqlx::query_as::<_, User>(
|
||||
"INSERT INTO users (username, email, password_hash) VALUES ($1, $2, $3) RETURNING id, username, email, created_at"
|
||||
)
|
||||
.bind(&new_user.username)
|
||||
.bind(&new_user.email)
|
||||
.bind(hash_password(new_user.password))
|
||||
.fetch_one(pool)
|
||||
.await;
|
||||
|
||||
match user {
|
||||
Ok(user) => Ok(user),
|
||||
Err(e) => {
|
||||
// Handle constraint violations specifically
|
||||
if let sqlx::Error::Database(ref dbe) = e {
|
||||
if let Some(constraint) = dbe.constraint() {
|
||||
if constraint.contains("username") || constraint.contains("email") {
|
||||
return Err(AppError::ValidationError(
|
||||
"Username or email already exists".to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Other database errors become internal errors
|
||||
Err(AppError::InternalServerError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn find_by_credentials(
|
||||
username: &str,
|
||||
password: String,
|
||||
pool: &Pool<Postgres>,
|
||||
) -> anyhow::Result<Self, AppError> {
|
||||
// Get user by username
|
||||
let user = sqlx::query_as::<_, User>(
|
||||
"SELECT id, username, email, created_at FROM users WHERE username = $1",
|
||||
)
|
||||
.bind(username)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.map_err(|_| AppError::InternalServerError)?
|
||||
.ok_or(AppError::AuthenticationFailed)?;
|
||||
|
||||
// If password hash exists, verify it
|
||||
let password_hash = sqlx::query_scalar::<_, Option<String>>(
|
||||
"SELECT password_hash FROM users WHERE username = $1",
|
||||
)
|
||||
.bind(username)
|
||||
.fetch_one(pool)
|
||||
.await
|
||||
.map_err(|_| AppError::InternalServerError)?
|
||||
.unwrap_or_default();
|
||||
verify_password( &PasswordHash::new(password_hash.as_str()).expect(""),password).map(|_|user)
|
||||
}
|
||||
|
||||
pub async fn find_by_id(id: i32, pool: &Pool<Postgres>) -> Result<Self, AppError> {
|
||||
let user = sqlx::query_as::<_, User>(
|
||||
"SELECT id, username, email, created_at FROM users WHERE id = $1",
|
||||
)
|
||||
.bind(id)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
.map_err(|_| AppError::InternalServerError)?
|
||||
.ok_or(AppError::NotFound("User not found".to_string()))?;
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
}
|
||||
92
src/services/error.rs
Normal file
92
src/services/error.rs
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
// src/error.rs
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use serde_json::json;
|
||||
use thiserror::Error;
|
||||
use tracing::error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum AppError {
|
||||
#[error("Authentication failed")]
|
||||
AuthenticationFailed,
|
||||
|
||||
#[error("Token creation error")]
|
||||
TokenCreation,
|
||||
|
||||
#[error("Invalid token")]
|
||||
InvalidToken,
|
||||
|
||||
#[error("Token expired")]
|
||||
TokenExpired,
|
||||
|
||||
#[error("Missing authentication token")]
|
||||
MissingToken,
|
||||
|
||||
#[error("Missing auth service")]
|
||||
MissingAuthService,
|
||||
|
||||
#[error("Internal server error")]
|
||||
InternalServerError,
|
||||
|
||||
#[error("Not found: {0}")]
|
||||
NotFound(String),
|
||||
|
||||
#[error("Validation error: {0}")]
|
||||
ValidationError(String),
|
||||
|
||||
#[error("Unauthorized: {0}")]
|
||||
Unauthorized(String),
|
||||
|
||||
#[error("Forbidden: {0}")]
|
||||
Forbidden(String),
|
||||
|
||||
#[error("Database error: {0}")]
|
||||
DatabaseError(String),
|
||||
}
|
||||
|
||||
impl IntoResponse for AppError {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, error_message) = match self {
|
||||
AppError::AuthenticationFailed => (StatusCode::UNAUTHORIZED, self.to_string()),
|
||||
AppError::TokenCreation => {
|
||||
error!("Failed to create token: {}", self);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error".to_string())
|
||||
},
|
||||
AppError::InvalidToken => (StatusCode::UNAUTHORIZED, self.to_string()),
|
||||
AppError::TokenExpired => (StatusCode::UNAUTHORIZED, self.to_string()),
|
||||
AppError::MissingToken => (StatusCode::UNAUTHORIZED, self.to_string()),
|
||||
AppError::MissingAuthService => {
|
||||
error!("Auth service missing: {}", self);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error".to_string())
|
||||
},
|
||||
AppError::InternalServerError => {
|
||||
error!("Internal server error: {}", self);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error".to_string())
|
||||
},
|
||||
AppError::NotFound(msg) => (StatusCode::NOT_FOUND, msg),
|
||||
AppError::ValidationError(msg) => (StatusCode::BAD_REQUEST, msg),
|
||||
AppError::Unauthorized(msg) => (StatusCode::UNAUTHORIZED, msg),
|
||||
AppError::Forbidden(msg) => (StatusCode::FORBIDDEN, msg),
|
||||
AppError::DatabaseError(msg) => {
|
||||
error!("Database error: {}", msg);
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error".to_string())
|
||||
},
|
||||
};
|
||||
|
||||
// Hide internal details from response for security
|
||||
let public_message = if status == StatusCode::INTERNAL_SERVER_ERROR {
|
||||
"An internal error occurred. Please try again later.".to_string()
|
||||
} else {
|
||||
error_message
|
||||
};
|
||||
|
||||
let body = Json(json!({
|
||||
"error": public_message,
|
||||
}));
|
||||
|
||||
(status, body).into_response()
|
||||
}
|
||||
}
|
||||
2
src/services/mod.rs
Normal file
2
src/services/mod.rs
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
pub mod posts;
|
||||
pub mod error;
|
||||
26
src/services/posts.rs
Normal file
26
src/services/posts.rs
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
use crate::models::post::Post;
|
||||
use axum::http::StatusCode;
|
||||
use axum::{Extension, Json};
|
||||
use axum::extract::Path;
|
||||
use sqlx::{query_as, Pool, Postgres};
|
||||
|
||||
pub async fn get_posts(
|
||||
Extension(pool): Extension<Pool<Postgres>>,
|
||||
) -> Result<Json<Vec<Post>>, StatusCode> {
|
||||
let posts = query_as!(Post, "SELECT id, user_id, title, body FROM posts")
|
||||
.fetch_all(&pool)
|
||||
.await
|
||||
.map_err(|_| StatusCode::NOT_FOUND)?;
|
||||
Ok(Json(posts))
|
||||
}
|
||||
|
||||
pub async fn get_post(
|
||||
Extension(pool): Extension<Pool<Postgres>>,
|
||||
Path(id): Path<i32>
|
||||
) -> Result<Json<Post>, StatusCode> {
|
||||
let post = query_as!(Post, "SELECT id, user_id, title, body FROM posts WHERE id = $1", id)
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.map_err(|_| StatusCode::NOT_FOUND)?;
|
||||
Ok(Json(post))
|
||||
}
|
||||
Loading…
Add table
Reference in a new issue