diff --git a/.gitignore b/.gitignore index 6a38901..b33914d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /target .env -/test \ No newline at end of file +/test +configuration.yaml \ No newline at end of file diff --git a/configuration.example.yaml b/configuration.example.yaml new file mode 100644 index 0000000..7448501 --- /dev/null +++ b/configuration.example.yaml @@ -0,0 +1,19 @@ +server: + host: 0.0.0.0 + port: 7070 + +db: + user: "oggy" + password: "very_secure_password" + host: "127.0.0.1" + port: "5432" + dbname: "okiba" + pool: + max_size: 16 + +app: + redis_uri: "redis://127.0.0.1" + rate_limit: + time_seconds: 60 + request_count: 10 + paste_id_length: 6 diff --git a/sql/schema.sql b/sql/schema.sql index 071ce3a..71d29d5 100644 --- a/sql/schema.sql +++ b/sql/schema.sql @@ -3,6 +3,6 @@ CREATE SCHEMA bin; CREATE TABLE bin.pastes ( id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY, - paste_id VARCHAR(5) NOT NULL, + paste_id VARCHAR(6) NOT NULL, content TEXT NOT NULL ) \ No newline at end of file diff --git a/src/cfg.rs b/src/cfg.rs index 7a5db9e..f68175b 100644 --- a/src/cfg.rs +++ b/src/cfg.rs @@ -1,16 +1,37 @@ pub use config::ConfigError; +use config::{File, FileFormat}; use serde::Deserialize; -#[derive(Debug, Default, Deserialize)] -pub struct AppConfig { - pub server_addr: String, - pub pg: deadpool_postgres::Config, +#[derive(Debug, Default, Deserialize, Clone)] +pub struct Config { + pub db: deadpool_postgres::Config, + pub server: ServerConfig, + pub app: ApplicationConfig, } -impl AppConfig { +#[derive(Debug, Default, Deserialize, Clone)] +pub struct ServerConfig { + pub port: u16, + pub host: String, +} + +#[derive(Debug, Default, Deserialize, Clone)] +pub struct ApplicationConfig { + pub redis_uri: String, + pub rate_limit: RateLimiterConfig, + pub paste_id_length: u8, +} + +#[derive(Debug, Default, Deserialize, Clone)] +pub struct RateLimiterConfig { + pub time_seconds: u64, + pub request_count: u64, +} + +impl Config { pub fn from_env() -> Result { let config = config::Config::builder() - .add_source(::config::Environment::default()) + .add_source(File::new("configuration", FileFormat::Yaml)) .build()?; config.try_deserialize() } diff --git a/src/db.rs b/src/db.rs index 464f57b..2cd16f5 100644 --- a/src/db.rs +++ b/src/db.rs @@ -24,7 +24,6 @@ pub mod errors { #[derive(Display, From, Debug)] pub enum MyError { - NotFound, PGError(PGError), PGMError(PGMError), PoolError(PoolError), @@ -33,16 +32,19 @@ pub mod errors { impl ResponseError for MyError { fn error_response(&self) -> HttpResponse { - match *self { - MyError::NotFound => HttpResponse::NotFound().json(json!({ - "message": "record not found" - })), - MyError::PoolError(ref err) => HttpResponse::InternalServerError().json(json!({ - "message": err.to_string() - })), - _ => HttpResponse::InternalServerError().json(json!({ - "message": "internal server error" - })), + match self { + MyError::PoolError(ref err) => { + log::error!("{}", err.to_string()); + HttpResponse::InternalServerError().json(json!({ + "message": err.to_string() + })) + } + error => { + log::error!("{}", error.to_string()); + HttpResponse::InternalServerError().json(json!({ + "message": "internal server error" + })) + } } } } diff --git a/src/main.rs b/src/main.rs index c3e2f14..21f4ee7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,16 +10,14 @@ use actix_web::{ web::{self}, App, Error, HttpResponse, HttpServer, }; -use cfg::AppConfig; +use cfg::ApplicationConfig; +use cfg::Config; use deadpool_postgres::{Client, Pool}; use dotenv::dotenv; use rand::{distributions::Alphanumeric, Rng}; use redis::aio::ConnectionManager; use serde_json::json; -use std::{ - str::{self}, - time::Duration, -}; +use std::time::Duration; use tokio_postgres::NoTls; fn generate_endpoint(length: u8) -> String { @@ -47,14 +45,18 @@ async fn fetch_paste( } #[post("/paste")] -async fn new_paste(db_pool: web::Data, body: String) -> Result { +async fn new_paste( + db_pool: web::Data, + app_config: web::Data, + body: String, +) -> Result { let client: Client = db_pool.get().await.map_err(MyError::PoolError)?; - let mut endpoint = generate_endpoint(5); + let mut endpoint = generate_endpoint(app_config.paste_id_length); // check if endpoint already exists loop { if db::paste_id_exists(&client, &endpoint).await? { - endpoint = generate_endpoint(5) + endpoint = generate_endpoint(app_config.paste_id_length) } else { break; } @@ -70,37 +72,43 @@ async fn new_paste(db_pool: web::Data, body: String) -> Result std::io::Result<()> { dotenv().ok(); - const ADDR: (&str, u16) = ("127.0.0.1", 8080); env_logger::init_from_env(env_logger::Env::new().default_filter_or("info")); - let config = AppConfig::from_env().unwrap(); + let config = Config::from_env().unwrap(); log::info!("Connecting database"); - let pool = config.pg.create_pool(None, NoTls).unwrap(); + let pool = config.db.create_pool(None, NoTls).unwrap(); - let redis_client = - redis::Client::open("redis://127.0.0.1:6379").expect("Couldn't connect to redis database"); + let redis_client = redis::Client::open(config.app.redis_uri.clone()) + .expect("Couldn't connect to redis database"); let redis_cm = ConnectionManager::new(redis_client).await.unwrap(); let redis_backend = RedisBackend::builder(redis_cm).build(); let server = HttpServer::new(move || { - // 5 requests per 60 seconds - let input = SimpleInputFunctionBuilder::new(Duration::from_secs(60), 5) - .real_ip_key() - .build(); + let input = SimpleInputFunctionBuilder::new( + Duration::from_secs(config.app.rate_limit.time_seconds), + config.app.rate_limit.request_count, + ) + .real_ip_key() + .build(); let middleware = RateLimiter::builder(redis_backend.clone(), input) .add_headers() .build(); App::new() .app_data(web::Data::new(pool.clone())) + .app_data(web::Data::new(config.app.clone())) .wrap(middleware) .service(new_paste) .service(fetch_paste) }) - .bind(ADDR)? + .bind((&*config.server.host, config.server.port))? .run(); - log::info!("Starting the server at http://{}:{}", ADDR.0, ADDR.1); + log::info!( + "Starting the server at http://{}:{}", + config.server.host, + config.server.port + ); server.await }