From 3493eeb29e3bae68254928ed88d375d2ca75005e Mon Sep 17 00:00:00 2001 From: naskya Date: Tue, 25 Jun 2024 13:09:30 +0900 Subject: [PATCH] add database/redis connection check --- Cargo.lock | 66 +++++++++++++++++++++++++++++++++ Cargo.toml | 3 +- src/command/config.rs | 2 +- src/command/config/validate.rs | 68 ++++++++++++++++++++++++++++++++-- src/config/server.rs | 28 +++++++------- 5 files changed, 147 insertions(+), 20 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 110ccbd..e9e0fbe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -109,6 +109,17 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "async-trait" +version = "0.1.80" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + [[package]] name = "atoi" version = "2.0.0" @@ -292,6 +303,20 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" +[[package]] +name = "combine" +version = "4.6.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd" +dependencies = [ + "bytes", + "futures-core", + "memchr", + "pin-project-lite", + "tokio", + "tokio-util", +] + [[package]] name = "const-oid" version = "0.9.6" @@ -502,6 +527,7 @@ dependencies = [ "clap", "color-print", "enum-iterator", + "redis", "serde", "serde_repr", "sqlx", @@ -1192,6 +1218,27 @@ dependencies = [ "getrandom", ] +[[package]] +name = "redis" +version = "0.25.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0d7a6955c7511f60f3ba9e86c6d02b3c3f144f8c24b288d1f4e18074ab8bbec" +dependencies = [ + "async-trait", + "bytes", + "combine", + "futures-util", + "itoa", + "percent-encoding", + "pin-project-lite", + "ryu", + "sha1_smol", + "socket2", + "tokio", + "tokio-util", + "url", +] + [[package]] name = "redox_syscall" version = "0.4.1" @@ -1352,6 +1399,12 @@ dependencies = [ "digest", ] +[[package]] +name = "sha1_smol" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012" + [[package]] name = "sha2" version = "0.10.8" @@ -1788,6 +1841,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-util" +version = "0.7.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "toml" version = "0.8.14" diff --git a/Cargo.toml b/Cargo.toml index c0e64df..b951e72 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,10 +14,11 @@ chrono = "0.4" clap = { version = "4.5", features = ["derive"] } color-print = "0.3" enum-iterator = "2.1" -thiserror = "1.0" +redis = { version = "0.25", features = ["tokio-comp"] } serde = { version = "1.0", features = ["derive"] } serde_repr = "0.1" sqlx = { version = "0.7", features = ["runtime-tokio", "postgres"] } +thiserror = "1.0" tokio = { version = "1.38", features = ["full"] } toml = "0.8" url = "2.5" diff --git a/src/command/config.rs b/src/command/config.rs index 92d5c8e..3f9f6a2 100644 --- a/src/command/config.rs +++ b/src/command/config.rs @@ -36,7 +36,7 @@ pub(crate) enum ConfigError { pub(super) async fn run(command: Commands) -> Result<(), ConfigError> { match command { Commands::Update { revision } => update::run(revision).await?, - Commands::Validate { offline } => validate::run(offline)?, + Commands::Validate { offline } => validate::run(offline).await?, } Ok(()) } diff --git a/src/command/config/validate.rs b/src/command/config/validate.rs index 998c27c..b76e149 100644 --- a/src/command/config/validate.rs +++ b/src/command/config/validate.rs @@ -4,6 +4,7 @@ use super::*; use crate::config::{client, server, CLIENT_CONFIG_PATH, SERVER_CONFIG_PATH}; use color_print::cprintln; use enum_iterator::Sequence; +use sqlx::{postgres::PgConnectOptions, query, ConnectOptions}; use validator::Validate; /// Errors that can happen in `config validate` subcommand @@ -17,9 +18,15 @@ pub(crate) enum ValidationError { OutOfDate, #[error("invalid config file")] InvalidConfig, + #[error("failed to connect to database")] + Db(#[from] sqlx::Error), + #[error("failed to connect to cache server")] + CacheServerConn(#[from] redis::RedisError), + #[error("unexpected cache server response")] + CacheServer, } -pub(super) fn run(bypass_connection_checks: bool) -> Result<(), ValidationError> { +pub(super) async fn run(bypass_connection_checks: bool) -> Result<(), ValidationError> { if current_revision()?.next().is_some() { cprintln!("Please first run `fishctl config update` to update your config files."); return Err(ValidationError::OutOfDate); @@ -66,9 +73,62 @@ pub(super) fn run(bypass_connection_checks: bool) -> Result<(), ValidationError> cprintln!("Note: This command only checks the format of the config files, and its result does not guarantee the correctness of the value."); } - if !bypass_connection_checks { - todo!() + if let Err(err) = server_validation_result.and(client_validation_result) { + return Err(err); } - server_validation_result.and(client_validation_result) + let server_config = read_server_config_as::() + .expect("server config should be formally valid at this point"); + + match bypass_connection_checks { + true => check_database_connection(server_config.database) + .await + .and(check_cache_server_connection(server_config.cache_server).await), + false => Ok(()), + } +} + +async fn check_database_connection(db: server::Database) -> Result<(), ValidationError> { + let mut conn = PgConnectOptions::new() + .host(&db.host) + .port(db.port) + .username(&db.user) + .password(&db.password) + .database(&db.name) + .connect() + .await?; + + query("SELECT version()").execute(&mut conn).await?; + + Ok(()) +} + +async fn check_cache_server_connection( + cache_server: server::CacheServer, +) -> Result<(), ValidationError> { + let url = { + let mut params = vec!["redis://".to_owned()]; + + if let Some(user) = cache_server.user.as_ref() { + params.push(user.to_owned()); + } + if let Some(password) = cache_server.password.as_ref() { + params.push(format!(":{}@", password)); + } + params.push(cache_server.host); + params.push(format!(":{}", cache_server.port)); + params.push(format!("/{}", cache_server.index.unwrap_or(0))); + + params.concat() + }; + + let mut conn = redis::Client::open(url)? + .get_multiplexed_async_connection() + .await?; + let pong: String = redis::cmd("PING").query_async(&mut conn).await?; + + match pong.as_str() { + "PONG" => Ok(()), + _ => Err(ValidationError::CacheServer), + } } diff --git a/src/config/server.rs b/src/config/server.rs index d543702..de1ac16 100644 --- a/src/config/server.rs +++ b/src/config/server.rs @@ -9,7 +9,7 @@ use crate::config::{ensure_latest_revision, Revision}; use serde::{Deserialize, Serialize}; use validator::Validate; -#[derive(Deserialize, Serialize, Validate, Debug)] +#[derive(Deserialize, Serialize, Validate, Debug, Clone)] pub struct Config { #[validate(custom(function = "ensure_latest_revision"))] pub config_revision: Revision, @@ -31,7 +31,7 @@ pub struct Config { pub security: Option, } -#[derive(Deserialize, Serialize, Validate, Debug)] +#[derive(Deserialize, Serialize, Validate, Debug, Clone)] pub struct Network { pub protocol: Option, pub domain: String, @@ -48,13 +48,13 @@ pub struct Network { pub smtp_proxy: Option, } -#[derive(Deserialize, Serialize, Validate, Debug)] +#[derive(Deserialize, Serialize, Validate, Debug, Clone)] pub struct Listen { pub host: Option, pub port: u16, } -#[derive(Deserialize, Serialize, Validate, Debug)] +#[derive(Deserialize, Serialize, Validate, Debug, Clone)] pub struct Proxy { pub enabled: bool, #[validate(url)] @@ -62,7 +62,7 @@ pub struct Proxy { pub allowlist: Option>, } -#[derive(Deserialize, Serialize, Validate, Debug)] +#[derive(Deserialize, Serialize, Validate, Debug, Clone)] pub struct Info { /// Server name pub name: Option, @@ -78,7 +78,7 @@ pub struct Info { pub repository_url: Option, } -#[derive(Deserialize, Serialize, Validate, Debug)] +#[derive(Deserialize, Serialize, Validate, Debug, Clone)] pub struct Timeline { /// Whether to enable the local timeline pub local: bool, @@ -90,14 +90,14 @@ pub struct Timeline { pub guest: bool, } -#[derive(Deserialize, Serialize, Debug)] +#[derive(Deserialize, Serialize, Debug, Clone)] #[serde(rename_all = "lowercase")] pub enum HttpProtocol { Https, Http, } -#[derive(Deserialize, Serialize, Validate, Debug)] +#[derive(Deserialize, Serialize, Validate, Debug, Clone)] pub struct Database { pub host: String, pub port: u16, @@ -106,7 +106,7 @@ pub struct Database { pub name: String, } -#[derive(Deserialize, Serialize, Validate, Debug)] +#[derive(Deserialize, Serialize, Validate, Debug, Clone)] pub struct CacheServer { pub host: String, pub port: u16, @@ -117,14 +117,14 @@ pub struct CacheServer { pub prefix: Option, } -#[derive(Deserialize, Serialize, Validate, Debug)] +#[derive(Deserialize, Serialize, Validate, Debug, Clone)] pub struct Id { #[validate(range(min = 16, max = 24))] pub length: Option, pub fingerprint: Option, } -#[derive(Deserialize, Serialize, Validate, Debug)] +#[derive(Deserialize, Serialize, Validate, Debug, Clone)] pub struct File { /// Maximum file size in megabytes pub max_size: Option, @@ -132,7 +132,7 @@ pub struct File { pub cache_remote_file: Option, } -#[derive(Deserialize, Serialize, Validate, Debug)] +#[derive(Deserialize, Serialize, Validate, Debug, Clone)] pub struct Security { pub require_authorized_fetch: Option, pub private_mode: Option, @@ -142,7 +142,7 @@ pub struct Security { pub log_ip_address: Option, } -#[derive(Deserialize, Serialize, Validate, Debug)] +#[derive(Deserialize, Serialize, Validate, Debug, Clone)] pub struct CaptchaConfig { pub enabled: bool, pub kind: Captcha, @@ -150,7 +150,7 @@ pub struct CaptchaConfig { pub secret_key: String, } -#[derive(Deserialize, Serialize, Debug)] +#[derive(Deserialize, Serialize, Debug, Clone)] pub enum Captcha { #[serde(rename = "hCaptcha")] HCaptcha,