diff --git a/src/main.rs b/src/main.rs index 9de2694..d868c00 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,8 @@ mod board; mod game; -use axum::Router; +use std::env; + +use axum::{http::Method, Router}; use board::Board; use dotenv::dotenv; use futures_util::stream::StreamExt; @@ -16,7 +18,7 @@ use socketioxide::{ }; use sqlx::PgPool; use tokio::net::TcpListener; -use tower_http::cors::CorsLayer; +use tower_http::cors::{AllowOrigin, CorsLayer}; use tracing_subscriber::FmtSubscriber; #[tokio::main] @@ -34,9 +36,27 @@ async fn main() -> Result<(), Box> { let (layer, io) = SocketIo::builder().with_state(pool).build_layer(); io.ns("/", on_connect); - let app = Router::new() - .layer(layer) - .layer(CorsLayer::very_permissive()); + // Get the allowed origins from the .env file + let allowed_origins = env::var("ALLOWED_ORIGINS").expect("ALLOWED_ORIGINS must be set"); + + // Split the origins by comma and collect them into a vector + let origins: Vec = allowed_origins + .split(',') + .map(|s| s.trim().to_string()) + .collect(); + + // Convert the vector of strings into `AllowOrigin` + let allow_origin = AllowOrigin::list(origins.iter().map(|origin| origin.parse().unwrap())); + + // Create a CORS layer + let cors = CorsLayer::new().allow_origin(allow_origin).allow_methods([ + Method::GET, + Method::POST, + Method::PUT, + Method::DELETE, + ]); + + let app = Router::new().layer(layer).layer(cors); let listener = TcpListener::bind("0.0.0.0:3000").await?; println!("listening on {}", listener.local_addr()?);