first commit

This commit is contained in:
Shautvast 2025-05-15 18:27:51 +02:00
commit 7fc534ba8d
25 changed files with 3711 additions and 0 deletions

3
.gitignore vendored Normal file
View file

@ -0,0 +1,3 @@
/target
.env
.idea

View file

@ -0,0 +1,40 @@
{
"db_name": "PostgreSQL",
"query": "SELECT id, user_id, title, body FROM posts WHERE id = $1",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int4"
},
{
"ordinal": 1,
"name": "user_id",
"type_info": "Int4"
},
{
"ordinal": 2,
"name": "title",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "body",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Int4"
]
},
"nullable": [
false,
true,
false,
false
]
},
"hash": "4643a886e01f7111e63d0136dc63882c7186de86d3a30ca3d00b280c5cdc0ed2"
}

View file

@ -0,0 +1,38 @@
{
"db_name": "PostgreSQL",
"query": "SELECT id, user_id, title, body FROM posts",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int4"
},
{
"ordinal": 1,
"name": "user_id",
"type_info": "Int4"
},
{
"ordinal": 2,
"name": "title",
"type_info": "Text"
},
{
"ordinal": 3,
"name": "body",
"type_info": "Text"
}
],
"parameters": {
"Left": []
},
"nullable": [
false,
true,
false,
false
]
},
"hash": "a65937cc250d94bc40f49bf1d8dbd426f9c4c233aa1e8ff416cd78998265126f"
}

2727
Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

27
Cargo.toml Normal file
View file

@ -0,0 +1,27 @@
[package]
name = "rustrest"
version = "0.1.0"
edition = "2024"
[dependencies]
axum = { version = "0.8", features = ["macros"] }
axum-extra = { version = "0.10", features = ["typed-header"] }
dotenvy = "0.15"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0.140"
sqlx = { version = "0.8", features = ["runtime-tokio", "tls-native-tls", "postgres", "uuid", "chrono"] }
tokio = { version = "1.45", features = ["full"] }
tracing = "0.1"
tracing-subscriber = "0.3"
anyhow = "1.0"
tower-http = { version = "0.5", features = ["cors", "trace"] }
jsonwebtoken = "9.0"
chrono = { version = "0.4", features = ["serde"] }
http = "1.0"
thiserror = "1.0"
uuid = { version = "1.5", features = ["serde", "v4"] }
# New dependencies for security
argon2 = { version = "0.5", features = ["password-hash"] }

17
README.md Normal file
View file

@ -0,0 +1,17 @@
* REST Api based on Axum 0.8, serves as a more complete example than most blogs will provide
* Postgres database using Sqlx, including migrations
* simple datamodel and api for reading posts for a blog
* Has users and roles (roles not fully implemented)
* logging
* externalized config
* /register stores the user (passwords hashed with argon2)
* /login returns a JWT token
* /posts returns all posts
* /posts/{id} returns a post
| .env |
| DATABASE_URL=postgres://postgres:...@localhost:5432/rust-axum-rest-api |
| MAX_DB_CONNECTIONS=5 |
| BIND_HOST=0.0.0.0:5001 |
| JWT_SECRET=... |
| ALLOWED_ORIGINS= |

View file

@ -0,0 +1,9 @@
-- Add migration script here
CREATE TABLE users
(
id SERIAL PRIMARY KEY,
username TEXT NOT NULL UNIQUE,
email TEXT NOT NULL UNIQUE,
password_hash TEXT,
created_at TIMESTAMP DEFAULT NOW()
)

View file

@ -0,0 +1,8 @@
-- Add migration script here
CREATE TABLE posts(
id SERIAL PRIMARY KEY,
user_id INTEGER REFERENCES users(id) ON DELETE CASCADE,
title TEXT NOT NULL,
body TEXT NOT NULL,
created_at TIMESTAMP DEFAULT NOW()
)

112
src/auth/jwt.rs Normal file
View file

@ -0,0 +1,112 @@
use axum::{
http::request::Parts,
RequestPartsExt,
extract::FromRequestParts,
};
use axum_extra::headers::Authorization;
use axum_extra::headers::authorization::Bearer;
use axum_extra::TypedHeader;
use chrono::{Duration, Utc};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tracing;
use uuid::Uuid;
use crate::services::error::AppError;
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
pub sub: String, // User ID
pub exp: i64, // Expiration time
pub iat: i64, // Issued at
pub jti: String, // JWT ID (unique identifier)
pub roles: Vec<String>, // User roles
}
impl Claims {
pub fn new(user_id: String, roles: Vec<String>, expiration: Duration) -> Self {
let now = Utc::now();
Self {
sub: user_id,
iat: now.timestamp(),
exp: (now + expiration).timestamp(),
jti: Uuid::new_v4().to_string(),
roles,
}
}
}
#[derive(Clone)]
pub struct JwtAuth {
encoding_key: EncodingKey,
decoding_key: DecodingKey,
}
impl JwtAuth {
pub fn new(secret: &[u8]) -> Self {
Self {
encoding_key: EncodingKey::from_secret(secret),
decoding_key: DecodingKey::from_secret(secret),
}
}
pub fn create_token(&self, claims: &Claims) -> Result<String, AppError> {
encode(&Header::default(), claims, &self.encoding_key)
.map_err(|_| AppError::TokenCreation)
}
pub fn verify_token(&self, token: &str) -> Result<Claims, AppError> {
// Create a validation object with default settings
let mut validation = Validation::default();
validation.validate_exp = true; // Verify expiration time
validation.leeway = 0; // No leeway for exp verification (default)
// Decode and verify the token
match decode::<Claims>(token, &self.decoding_key, &validation) {
Ok(token_data) => {
// Token is valid, return claims
Ok(token_data.claims)
}
Err(e) => {
// Log the error for debugging
tracing::error!("Token validation error: {:?}", e);
// Map jsonwebtoken errors to AppError
match e.kind() {
jsonwebtoken::errors::ErrorKind::ExpiredSignature => Err(AppError::TokenExpired),
jsonwebtoken::errors::ErrorKind::InvalidToken => Err(AppError::InvalidToken),
jsonwebtoken::errors::ErrorKind::InvalidSignature => Err(AppError::InvalidToken),
_ => Err(AppError::InvalidToken),
}
}
}
}
}
// Extractor for protected routes
impl<S> FromRequestParts<S> for Claims
where
S: Send + Sync,
{
type Rejection = AppError;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
// Extract the token from the Authorization header
let TypedHeader(Authorization(bearer)) = parts
.extract::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|_| AppError::MissingToken)?;
// Get our auth service from extensions
let jwt_auth = parts
.extensions
.get::<Arc<JwtAuth>>()
.ok_or(AppError::MissingAuthService)?;
// Verify the token
let claims = jwt_auth.verify_token(bearer.token())?;
Ok(claims)
}
}

82
src/auth/login.rs Normal file
View file

@ -0,0 +1,82 @@
use std::sync::Arc;
use axum::{extract::{State, Json, Extension}, http::StatusCode};
use chrono::Duration;
use serde::{Deserialize, Serialize};
use sqlx::Pool;
use sqlx::Postgres;
use crate::services::error::AppError;
use crate::auth::jwt::{Claims, JwtAuth};
use crate::models::user::User;
#[derive(Deserialize)]
pub struct LoginRequest {
username: String,
password: String,
}
#[derive(Serialize)]
pub struct LoginResponse {
access_token: String,
token_type: String,
expires_in: i64,
user_id: i32,
username: String,
}
pub async fn login(
State(jwt_auth): State<Arc<JwtAuth>>,
Extension(pool): Extension<Pool<Postgres>>,
Json(payload): Json<LoginRequest>,
) -> Result<Json<LoginResponse>, AppError> {
// Authentication against database
let user = User::find_by_credentials(
&payload.username,
payload.password,
&pool,
)
.await?;
// Create token with appropriate roles
let expiration = Duration::minutes(15);
let claims = Claims::new(
user.id.to_string(),
vec!["user".to_string()], // Default role - in real app, fetch from database
expiration,
);
let token = jwt_auth.create_token(&claims)?;
Ok(Json(LoginResponse {
access_token: token,
token_type: "Bearer".to_string(),
expires_in: expiration.num_seconds(),
user_id: user.id,
username: user.username,
}))
}
// Registration endpoint
#[derive(Deserialize)]
pub struct RegisterRequest {
username: String,
email: String,
password: String,
}
pub async fn register(
Extension(pool): Extension<Pool<Postgres>>,
Json(payload): Json<RegisterRequest>,
) -> Result<(StatusCode, Json<User>), AppError> {
// Create the new user
let new_user = crate::models::user::NewUser {
username: payload.username,
email: payload.email,
password: String::new(), // Placeholder
};
let user = User::create(new_user, payload.password, &pool).await?;
// Return the created user (without password)
Ok((StatusCode::CREATED, Json(user)))
}

8
src/auth/mod.rs Normal file
View file

@ -0,0 +1,8 @@
pub mod rbac;
pub mod login;
pub mod jwt;
pub mod password;
pub use login::{login, register};
pub use password::{hash_password, verify_password};
pub use jwt::JwtAuth;

24
src/auth/password.rs Normal file
View file

@ -0,0 +1,24 @@
use crate::services::error::AppError;
use argon2::password_hash::SaltString;
use argon2::password_hash::rand_core::OsRng;
use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
pub fn hash_password(password: String) -> String {
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
argon2
.hash_password(password.as_bytes(), &salt)
.unwrap()
.to_string()
}
pub fn verify_password(
stored_hash: &PasswordHash<'_>,
password: String,
) -> anyhow::Result<bool, AppError> {
let argon2 = Argon2::default();
argon2
.verify_password(password.as_bytes(), stored_hash)
.map(|_| true)
.map_err(|_| AppError::AuthenticationFailed)
}

63
src/auth/rbac.rs Normal file
View file

@ -0,0 +1,63 @@
use axum::{
extract::Request,
middleware::Next,
response::{IntoResponse, Response},
};
use std::fmt;
use crate::services::error::AppError;
use crate::auth::jwt::Claims;
#[derive(Debug, Clone, PartialEq)]
pub enum Role {
User,
Editor,
Admin,
}
impl fmt::Display for Role {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Role::User => write!(f, "user"),
Role::Editor => write!(f, "editor"),
Role::Admin => write!(f, "admin"),
}
}
}
impl From<&str> for Role {
fn from(role: &str) -> Self {
match role.to_lowercase().as_str() {
"admin" => Role::Admin,
"editor" => Role::Editor,
_ => Role::User,
}
}
}
// Simple function to check if a user has a required role
pub fn has_role(claims: &Claims, required_role: &Role) -> bool {
claims.roles
.iter()
.any(|r| Role::from(r.as_str()) == *required_role || Role::from(r.as_str()) == Role::Admin)
}
// Simple function to check if a user has any of the required roles
pub fn has_any_role(claims: &Claims, required_roles: &[Role]) -> bool {
required_roles
.iter()
.any(|required| has_role(claims, required))
}
// Middleware for role-based authorization
pub async fn require_role(required_role: Role, request: Request, next: Next) -> impl IntoResponse {
if let Some(claims) = request.extensions().get::<Claims>() {
if has_role(claims, &required_role) {
next.run(request).await
} else {
AppError::Forbidden(format!("Requires {} role", required_role)).into_response()
}
} else {
AppError::Unauthorized("Not authenticated".to_string()).into_response()
}
}

8
src/lib.rs Normal file
View file

@ -0,0 +1,8 @@
pub mod models;
pub mod services;
pub mod auth;
pub mod middleware;
// Ensure models are accessible
pub use models::post::Post;
pub use models::user::User;

68
src/main.rs Normal file
View file

@ -0,0 +1,68 @@
use axum::routing::{get, post};
use axum::{middleware, Extension, Router};
use dotenvy::dotenv;
use rustrest::services::posts::{get_post, get_posts};
use sqlx::postgres::PgPoolOptions;
use std::env;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
use tower_http::trace::TraceLayer;
use tracing::Level;
use rustrest::auth;
use rustrest::auth::jwt::{JwtAuth};
use rustrest::middleware::{audit_log, auth_middleware, security_headers};
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.with_max_level(Level::INFO)
.init();
dotenv().ok();
let url = env::var("DATABASE_URL").expect("DATABASE_URL must be set");
let max_connections = env::var("MAX_DB_CONNECTIONS")
.expect("MAX_DB_CONNECTIONS must be set")
.parse()?;
let pool = PgPoolOptions::new()
.max_connections(max_connections)
.connect(&url)
.await?;
// JWT Authentication
let jwt_secret = env::var("JWT_SECRET")
.expect("JWT_SECRET must be set")
.into_bytes();
let jwt_auth = Arc::new(JwtAuth::new(&jwt_secret));
// API routes
let api_routes = Router::new()
// Public routes
.route("/login", post(auth::login))
.route("/register", post(auth::register))
// Protected routes
.route("/posts", get(get_posts))
.route("/posts/{id}", get(get_post))
// Apply authentication middleware to all routes
.layer(middleware::from_fn_with_state(Arc::clone(&jwt_auth), auth_middleware))
.layer(middleware::from_fn(audit_log))
.layer(middleware::from_fn(security_headers)) // Security headers
.layer(TraceLayer::new_for_http()) // Request tracing
.layer(Extension(Arc::clone(&jwt_auth))) // JWT auth
.layer(Extension(pool))
.with_state(jwt_auth).into_make_service_with_connect_info::<SocketAddr>();
let bind_host = env::var("BIND_HOST").expect("BIND_HOST must be set");
let addr: SocketAddr = bind_host.parse()?;
let listener = TcpListener::bind(addr).await?;
println!("Server is running on {}", bind_host);
axum::serve(listener, api_routes).await?;
Ok(())
}

80
src/middleware/audit.rs Normal file
View file

@ -0,0 +1,80 @@
use axum::{
extract::{ConnectInfo, Request},
middleware::Next,
response::Response,
};
use std::time::Instant;
use tracing::{info, warn};
use uuid::Uuid;
// Audit logging middleware for security-relevant events
pub async fn audit_log(
ConnectInfo(addr): ConnectInfo<std::net::SocketAddr>,
request: Request,
next: Next,
) -> Response {
let start = Instant::now();
let method = request.method().clone();
let uri = request.uri().clone();
let request_id = Uuid::new_v4();
// Extract user information if available
let user_id = request
.extensions()
.get::<crate::auth::jwt::Claims>()
.map(|claims| claims.sub.clone())
.unwrap_or_else(|| "anonymous".to_string());
// Extract sensitive paths to log with more detail
let is_sensitive_path = uri.path().contains("/login") ||
uri.path().contains("/register") ||
uri.path().contains("/password");
if is_sensitive_path {
info!(
target: "AUDIT",
request_id = %request_id,
remote_addr = %addr,
method = %method,
uri = %uri,
user_id = %user_id,
"Sensitive operation initiated"
);
}
// Process the request
let response = next.run(request).await;
// Get response status for the log
let status = response.status();
let duration = start.elapsed();
// Log authentication failures and other suspicious activity
if status.is_client_error() || status.is_server_error() {
warn!(
target: "AUDIT",
request_id = %request_id,
remote_addr = %addr,
method = %method,
uri = %uri,
user_id = %user_id,
status = %status.as_u16(),
duration_ms = %duration.as_millis(),
"Request failed"
);
} else if is_sensitive_path {
info!(
target: "AUDIT",
request_id = %request_id,
remote_addr = %addr,
method = %method,
uri = %uri,
user_id = %user_id,
status = %status.as_u16(),
duration_ms = %duration.as_millis(),
"Sensitive operation completed"
);
}
response
}

View file

@ -0,0 +1,50 @@
use std::sync::Arc;
use axum::extract::{Request, State};
use axum::middleware::Next;
use axum::response::Response;
use http::{header, StatusCode};
use tracing::{error, info};
use crate::auth::JwtAuth;
pub async fn auth_middleware(
State(jwt_auth): State<Arc<JwtAuth>>,
mut request: Request,
next: Next
) -> Response {
if request.uri().path() == "/login" || request.uri().path() == "/register" {
return next.run(request).await;
}
let auth_header = request
.headers()
.get(header::AUTHORIZATION)
.and_then(|header| header.to_str().ok());
match auth_header {
Some(auth_value) if auth_value.starts_with("Bearer ") => {
// Extract the token (remove "Bearer " prefix)
let token = &auth_value[7..];
match jwt_auth.verify_token(token) {
Ok(claims) => {
info!("Authentication successful for user: {}", claims.sub);
request.extensions_mut().insert(claims);
next.run(request).await
}
Err(e) => {
error!("Token verification failed: {:?}", e);
Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body("Invalid token".into())
.unwrap()
}
}
}
_ => {
Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body("Missing or invalid Authorization header".into())
.unwrap()
}
}
}

7
src/middleware/mod.rs Normal file
View file

@ -0,0 +1,7 @@
pub mod security_headers;
mod audit;
mod auth_middleware;
pub use security_headers::security_headers;
pub use audit::audit_log;
pub use auth_middleware::auth_middleware;

View file

@ -0,0 +1,57 @@
use axum::{
extract::Request,
middleware::Next,
response::IntoResponse,
};
use http::{HeaderMap, HeaderName, HeaderValue};
// Enhanced security headers
pub async fn security_headers(request: Request, next: Next) -> impl IntoResponse {
let mut response = next.run(request).await;
let headers = response.headers_mut();
// Prevent MIME type sniffing
headers.insert(
HeaderName::from_static("x-content-type-options"),
HeaderValue::from_static("nosniff")
);
// Prevent clickjacking
headers.insert(
HeaderName::from_static("x-frame-options"),
HeaderValue::from_static("DENY")
);
// Enable XSS protections
headers.insert(
HeaderName::from_static("x-xss-protection"),
HeaderValue::from_static("1; mode=block")
);
// Force HTTPS connections
headers.insert(
HeaderName::from_static("strict-transport-security"),
HeaderValue::from_static("max-age=31536000; includeSubDomains; preload")
);
// Control referrer information
headers.insert(
HeaderName::from_static("referrer-policy"),
HeaderValue::from_static("strict-origin-when-cross-origin")
);
// Content Security Policy
headers.insert(
HeaderName::from_static("content-security-policy"),
HeaderValue::from_static("default-src 'self'; script-src 'self'; img-src 'self'; style-src 'self'; font-src 'self'; frame-ancestors 'none'; base-uri 'self'; form-action 'self'")
);
// Permissions Policy
headers.insert(
HeaderName::from_static("permissions-policy"),
HeaderValue::from_static("camera=(), microphone=(), geolocation=(), interest-cohort=()")
);
response
}

2
src/models/mod.rs Normal file
View file

@ -0,0 +1,2 @@
pub mod post;
pub mod user;

9
src/models/post.rs Normal file
View file

@ -0,0 +1,9 @@
use serde::{Deserialize, Serialize};
#[derive(Deserialize, Serialize)]
pub struct Post {
pub(crate) id: i32,
pub(crate) user_id: Option<i32>,
pub(crate) title: String,
pub(crate) body: String,
}

152
src/models/user.rs Normal file
View file

@ -0,0 +1,152 @@
use argon2::PasswordHash;
use crate::auth::{hash_password, verify_password};
use crate::services::error::AppError;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use sqlx::{Pool, Postgres, Row, postgres::PgRow};
#[derive(Debug, Serialize)]
pub struct User {
pub id: i32,
pub username: String,
pub email: String,
pub created_at: Option<DateTime<Utc>>,
}
impl<'c> sqlx::FromRow<'c, PgRow> for User {
fn from_row(row: &'c PgRow) -> Result<Self, sqlx::Error> {
let id: i32 = row.try_get("id")?;
let username: String = row.try_get("username")?;
let email: String = row.try_get("email")?;
// Handle created_at which might be missing or in a different format
let created_at: Option<DateTime<Utc>> = match row.try_get("created_at") {
Ok(dt) => Some(dt),
Err(_) => None,
};
Ok(User {
id,
username,
email,
created_at,
})
}
}
#[derive(Debug, Deserialize)]
pub struct NewUser {
pub username: String,
pub email: String,
#[serde(skip)]
pub password: String, // Plain string password - used only for signing up
}
// Validate email format
fn is_valid_email(email: &str) -> bool {
// Simple validation - should use a proper email validation library in production
email.contains('@') && email.contains('.')
}
fn is_strong_password(password: &str) -> bool {
password.len() >= 12
}
// Database functions
impl User {
// Create a user in the database, handling both old schema (without password_hash) and new schema
pub async fn create(
mut new_user: NewUser,
password: String,
pool: &Pool<Postgres>,
) -> Result<Self, AppError> {
// Validate input
if new_user.username.len() < 3 {
return Err(AppError::ValidationError("Username too short".to_string()));
}
if !is_valid_email(&new_user.email) {
return Err(AppError::ValidationError(
"Invalid email format".to_string(),
));
}
if !is_strong_password(&password) {
return Err(AppError::ValidationError(
"Password must be at least 12 characters".to_string()
));
}
new_user.password = password;
let user =
// Insert with password hash
sqlx::query_as::<_, User>(
"INSERT INTO users (username, email, password_hash) VALUES ($1, $2, $3) RETURNING id, username, email, created_at"
)
.bind(&new_user.username)
.bind(&new_user.email)
.bind(hash_password(new_user.password))
.fetch_one(pool)
.await;
match user {
Ok(user) => Ok(user),
Err(e) => {
// Handle constraint violations specifically
if let sqlx::Error::Database(ref dbe) = e {
if let Some(constraint) = dbe.constraint() {
if constraint.contains("username") || constraint.contains("email") {
return Err(AppError::ValidationError(
"Username or email already exists".to_string(),
));
}
}
}
// Other database errors become internal errors
Err(AppError::InternalServerError)
}
}
}
pub async fn find_by_credentials(
username: &str,
password: String,
pool: &Pool<Postgres>,
) -> anyhow::Result<Self, AppError> {
// Get user by username
let user = sqlx::query_as::<_, User>(
"SELECT id, username, email, created_at FROM users WHERE username = $1",
)
.bind(username)
.fetch_optional(pool)
.await
.map_err(|_| AppError::InternalServerError)?
.ok_or(AppError::AuthenticationFailed)?;
// If password hash exists, verify it
let password_hash = sqlx::query_scalar::<_, Option<String>>(
"SELECT password_hash FROM users WHERE username = $1",
)
.bind(username)
.fetch_one(pool)
.await
.map_err(|_| AppError::InternalServerError)?
.unwrap_or_default();
verify_password( &PasswordHash::new(password_hash.as_str()).expect(""),password).map(|_|user)
}
pub async fn find_by_id(id: i32, pool: &Pool<Postgres>) -> Result<Self, AppError> {
let user = sqlx::query_as::<_, User>(
"SELECT id, username, email, created_at FROM users WHERE id = $1",
)
.bind(id)
.fetch_optional(pool)
.await
.map_err(|_| AppError::InternalServerError)?
.ok_or(AppError::NotFound("User not found".to_string()))?;
Ok(user)
}
}

92
src/services/error.rs Normal file
View file

@ -0,0 +1,92 @@
// src/error.rs
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde_json::json;
use thiserror::Error;
use tracing::error;
#[derive(Error, Debug)]
pub enum AppError {
#[error("Authentication failed")]
AuthenticationFailed,
#[error("Token creation error")]
TokenCreation,
#[error("Invalid token")]
InvalidToken,
#[error("Token expired")]
TokenExpired,
#[error("Missing authentication token")]
MissingToken,
#[error("Missing auth service")]
MissingAuthService,
#[error("Internal server error")]
InternalServerError,
#[error("Not found: {0}")]
NotFound(String),
#[error("Validation error: {0}")]
ValidationError(String),
#[error("Unauthorized: {0}")]
Unauthorized(String),
#[error("Forbidden: {0}")]
Forbidden(String),
#[error("Database error: {0}")]
DatabaseError(String),
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
let (status, error_message) = match self {
AppError::AuthenticationFailed => (StatusCode::UNAUTHORIZED, self.to_string()),
AppError::TokenCreation => {
error!("Failed to create token: {}", self);
(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error".to_string())
},
AppError::InvalidToken => (StatusCode::UNAUTHORIZED, self.to_string()),
AppError::TokenExpired => (StatusCode::UNAUTHORIZED, self.to_string()),
AppError::MissingToken => (StatusCode::UNAUTHORIZED, self.to_string()),
AppError::MissingAuthService => {
error!("Auth service missing: {}", self);
(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error".to_string())
},
AppError::InternalServerError => {
error!("Internal server error: {}", self);
(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error".to_string())
},
AppError::NotFound(msg) => (StatusCode::NOT_FOUND, msg),
AppError::ValidationError(msg) => (StatusCode::BAD_REQUEST, msg),
AppError::Unauthorized(msg) => (StatusCode::UNAUTHORIZED, msg),
AppError::Forbidden(msg) => (StatusCode::FORBIDDEN, msg),
AppError::DatabaseError(msg) => {
error!("Database error: {}", msg);
(StatusCode::INTERNAL_SERVER_ERROR, "Internal server error".to_string())
},
};
// Hide internal details from response for security
let public_message = if status == StatusCode::INTERNAL_SERVER_ERROR {
"An internal error occurred. Please try again later.".to_string()
} else {
error_message
};
let body = Json(json!({
"error": public_message,
}));
(status, body).into_response()
}
}

2
src/services/mod.rs Normal file
View file

@ -0,0 +1,2 @@
pub mod posts;
pub mod error;

26
src/services/posts.rs Normal file
View file

@ -0,0 +1,26 @@
use crate::models::post::Post;
use axum::http::StatusCode;
use axum::{Extension, Json};
use axum::extract::Path;
use sqlx::{query_as, Pool, Postgres};
pub async fn get_posts(
Extension(pool): Extension<Pool<Postgres>>,
) -> Result<Json<Vec<Post>>, StatusCode> {
let posts = query_as!(Post, "SELECT id, user_id, title, body FROM posts")
.fetch_all(&pool)
.await
.map_err(|_| StatusCode::NOT_FOUND)?;
Ok(Json(posts))
}
pub async fn get_post(
Extension(pool): Extension<Pool<Postgres>>,
Path(id): Path<i32>
) -> Result<Json<Post>, StatusCode> {
let post = query_as!(Post, "SELECT id, user_id, title, body FROM posts WHERE id = $1", id)
.fetch_one(&pool)
.await
.map_err(|_| StatusCode::NOT_FOUND)?;
Ok(Json(post))
}