add database/redis connection check

This commit is contained in:
naskya 2024-06-25 13:09:30 +09:00
parent e1d6b77079
commit 3493eeb29e
Signed by: naskya
GPG key ID: 712D413B3A9FED5C
5 changed files with 147 additions and 20 deletions

66
Cargo.lock generated
View file

@ -109,6 +109,17 @@ dependencies = [
"windows-sys 0.52.0", "windows-sys 0.52.0",
] ]
[[package]]
name = "async-trait"
version = "0.1.80"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
]
[[package]] [[package]]
name = "atoi" name = "atoi"
version = "2.0.0" version = "2.0.0"
@ -292,6 +303,20 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422" checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422"
[[package]]
name = "combine"
version = "4.6.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd"
dependencies = [
"bytes",
"futures-core",
"memchr",
"pin-project-lite",
"tokio",
"tokio-util",
]
[[package]] [[package]]
name = "const-oid" name = "const-oid"
version = "0.9.6" version = "0.9.6"
@ -502,6 +527,7 @@ dependencies = [
"clap", "clap",
"color-print", "color-print",
"enum-iterator", "enum-iterator",
"redis",
"serde", "serde",
"serde_repr", "serde_repr",
"sqlx", "sqlx",
@ -1192,6 +1218,27 @@ dependencies = [
"getrandom", "getrandom",
] ]
[[package]]
name = "redis"
version = "0.25.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0d7a6955c7511f60f3ba9e86c6d02b3c3f144f8c24b288d1f4e18074ab8bbec"
dependencies = [
"async-trait",
"bytes",
"combine",
"futures-util",
"itoa",
"percent-encoding",
"pin-project-lite",
"ryu",
"sha1_smol",
"socket2",
"tokio",
"tokio-util",
"url",
]
[[package]] [[package]]
name = "redox_syscall" name = "redox_syscall"
version = "0.4.1" version = "0.4.1"
@ -1352,6 +1399,12 @@ dependencies = [
"digest", "digest",
] ]
[[package]]
name = "sha1_smol"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012"
[[package]] [[package]]
name = "sha2" name = "sha2"
version = "0.10.8" version = "0.10.8"
@ -1788,6 +1841,19 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "tokio-util"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1"
dependencies = [
"bytes",
"futures-core",
"futures-sink",
"pin-project-lite",
"tokio",
]
[[package]] [[package]]
name = "toml" name = "toml"
version = "0.8.14" version = "0.8.14"

View file

@ -14,10 +14,11 @@ chrono = "0.4"
clap = { version = "4.5", features = ["derive"] } clap = { version = "4.5", features = ["derive"] }
color-print = "0.3" color-print = "0.3"
enum-iterator = "2.1" enum-iterator = "2.1"
thiserror = "1.0" redis = { version = "0.25", features = ["tokio-comp"] }
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_repr = "0.1" serde_repr = "0.1"
sqlx = { version = "0.7", features = ["runtime-tokio", "postgres"] } sqlx = { version = "0.7", features = ["runtime-tokio", "postgres"] }
thiserror = "1.0"
tokio = { version = "1.38", features = ["full"] } tokio = { version = "1.38", features = ["full"] }
toml = "0.8" toml = "0.8"
url = "2.5" url = "2.5"

View file

@ -36,7 +36,7 @@ pub(crate) enum ConfigError {
pub(super) async fn run(command: Commands) -> Result<(), ConfigError> { pub(super) async fn run(command: Commands) -> Result<(), ConfigError> {
match command { match command {
Commands::Update { revision } => update::run(revision).await?, Commands::Update { revision } => update::run(revision).await?,
Commands::Validate { offline } => validate::run(offline)?, Commands::Validate { offline } => validate::run(offline).await?,
} }
Ok(()) Ok(())
} }

View file

@ -4,6 +4,7 @@ use super::*;
use crate::config::{client, server, CLIENT_CONFIG_PATH, SERVER_CONFIG_PATH}; use crate::config::{client, server, CLIENT_CONFIG_PATH, SERVER_CONFIG_PATH};
use color_print::cprintln; use color_print::cprintln;
use enum_iterator::Sequence; use enum_iterator::Sequence;
use sqlx::{postgres::PgConnectOptions, query, ConnectOptions};
use validator::Validate; use validator::Validate;
/// Errors that can happen in `config validate` subcommand /// Errors that can happen in `config validate` subcommand
@ -17,9 +18,15 @@ pub(crate) enum ValidationError {
OutOfDate, OutOfDate,
#[error("invalid config file")] #[error("invalid config file")]
InvalidConfig, InvalidConfig,
#[error("failed to connect to database")]
Db(#[from] sqlx::Error),
#[error("failed to connect to cache server")]
CacheServerConn(#[from] redis::RedisError),
#[error("unexpected cache server response")]
CacheServer,
} }
pub(super) fn run(bypass_connection_checks: bool) -> Result<(), ValidationError> { pub(super) async fn run(bypass_connection_checks: bool) -> Result<(), ValidationError> {
if current_revision()?.next().is_some() { if current_revision()?.next().is_some() {
cprintln!("Please first run `<bold>fishctl config update</>` to update your config files."); cprintln!("Please first run `<bold>fishctl config update</>` to update your config files.");
return Err(ValidationError::OutOfDate); return Err(ValidationError::OutOfDate);
@ -66,9 +73,62 @@ pub(super) fn run(bypass_connection_checks: bool) -> Result<(), ValidationError>
cprintln!("<bold>Note:</> This command only checks the format of the config files, and its result does not guarantee the correctness of the value."); cprintln!("<bold>Note:</> This command only checks the format of the config files, and its result does not guarantee the correctness of the value.");
} }
if !bypass_connection_checks { if let Err(err) = server_validation_result.and(client_validation_result) {
todo!() return Err(err);
} }
server_validation_result.and(client_validation_result) let server_config = read_server_config_as::<server::Config>()
.expect("server config should be formally valid at this point");
match bypass_connection_checks {
true => check_database_connection(server_config.database)
.await
.and(check_cache_server_connection(server_config.cache_server).await),
false => Ok(()),
}
}
async fn check_database_connection(db: server::Database) -> Result<(), ValidationError> {
let mut conn = PgConnectOptions::new()
.host(&db.host)
.port(db.port)
.username(&db.user)
.password(&db.password)
.database(&db.name)
.connect()
.await?;
query("SELECT version()").execute(&mut conn).await?;
Ok(())
}
async fn check_cache_server_connection(
cache_server: server::CacheServer,
) -> Result<(), ValidationError> {
let url = {
let mut params = vec!["redis://".to_owned()];
if let Some(user) = cache_server.user.as_ref() {
params.push(user.to_owned());
}
if let Some(password) = cache_server.password.as_ref() {
params.push(format!(":{}@", password));
}
params.push(cache_server.host);
params.push(format!(":{}", cache_server.port));
params.push(format!("/{}", cache_server.index.unwrap_or(0)));
params.concat()
};
let mut conn = redis::Client::open(url)?
.get_multiplexed_async_connection()
.await?;
let pong: String = redis::cmd("PING").query_async(&mut conn).await?;
match pong.as_str() {
"PONG" => Ok(()),
_ => Err(ValidationError::CacheServer),
}
} }

View file

@ -9,7 +9,7 @@ use crate::config::{ensure_latest_revision, Revision};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use validator::Validate; use validator::Validate;
#[derive(Deserialize, Serialize, Validate, Debug)] #[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct Config { pub struct Config {
#[validate(custom(function = "ensure_latest_revision"))] #[validate(custom(function = "ensure_latest_revision"))]
pub config_revision: Revision, pub config_revision: Revision,
@ -31,7 +31,7 @@ pub struct Config {
pub security: Option<Security>, pub security: Option<Security>,
} }
#[derive(Deserialize, Serialize, Validate, Debug)] #[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct Network { pub struct Network {
pub protocol: Option<HttpProtocol>, pub protocol: Option<HttpProtocol>,
pub domain: String, pub domain: String,
@ -48,13 +48,13 @@ pub struct Network {
pub smtp_proxy: Option<Proxy>, pub smtp_proxy: Option<Proxy>,
} }
#[derive(Deserialize, Serialize, Validate, Debug)] #[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct Listen { pub struct Listen {
pub host: Option<String>, pub host: Option<String>,
pub port: u16, pub port: u16,
} }
#[derive(Deserialize, Serialize, Validate, Debug)] #[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct Proxy { pub struct Proxy {
pub enabled: bool, pub enabled: bool,
#[validate(url)] #[validate(url)]
@ -62,7 +62,7 @@ pub struct Proxy {
pub allowlist: Option<Vec<String>>, pub allowlist: Option<Vec<String>>,
} }
#[derive(Deserialize, Serialize, Validate, Debug)] #[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct Info { pub struct Info {
/// Server name /// Server name
pub name: Option<String>, pub name: Option<String>,
@ -78,7 +78,7 @@ pub struct Info {
pub repository_url: Option<String>, pub repository_url: Option<String>,
} }
#[derive(Deserialize, Serialize, Validate, Debug)] #[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct Timeline { pub struct Timeline {
/// Whether to enable the local timeline /// Whether to enable the local timeline
pub local: bool, pub local: bool,
@ -90,14 +90,14 @@ pub struct Timeline {
pub guest: bool, pub guest: bool,
} }
#[derive(Deserialize, Serialize, Debug)] #[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
pub enum HttpProtocol { pub enum HttpProtocol {
Https, Https,
Http, Http,
} }
#[derive(Deserialize, Serialize, Validate, Debug)] #[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct Database { pub struct Database {
pub host: String, pub host: String,
pub port: u16, pub port: u16,
@ -106,7 +106,7 @@ pub struct Database {
pub name: String, pub name: String,
} }
#[derive(Deserialize, Serialize, Validate, Debug)] #[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct CacheServer { pub struct CacheServer {
pub host: String, pub host: String,
pub port: u16, pub port: u16,
@ -117,14 +117,14 @@ pub struct CacheServer {
pub prefix: Option<String>, pub prefix: Option<String>,
} }
#[derive(Deserialize, Serialize, Validate, Debug)] #[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct Id { pub struct Id {
#[validate(range(min = 16, max = 24))] #[validate(range(min = 16, max = 24))]
pub length: Option<u8>, pub length: Option<u8>,
pub fingerprint: Option<String>, pub fingerprint: Option<String>,
} }
#[derive(Deserialize, Serialize, Validate, Debug)] #[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct File { pub struct File {
/// Maximum file size in megabytes /// Maximum file size in megabytes
pub max_size: Option<u64>, pub max_size: Option<u64>,
@ -132,7 +132,7 @@ pub struct File {
pub cache_remote_file: Option<bool>, pub cache_remote_file: Option<bool>,
} }
#[derive(Deserialize, Serialize, Validate, Debug)] #[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct Security { pub struct Security {
pub require_authorized_fetch: Option<bool>, pub require_authorized_fetch: Option<bool>,
pub private_mode: Option<bool>, pub private_mode: Option<bool>,
@ -142,7 +142,7 @@ pub struct Security {
pub log_ip_address: Option<bool>, pub log_ip_address: Option<bool>,
} }
#[derive(Deserialize, Serialize, Validate, Debug)] #[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct CaptchaConfig { pub struct CaptchaConfig {
pub enabled: bool, pub enabled: bool,
pub kind: Captcha, pub kind: Captcha,
@ -150,7 +150,7 @@ pub struct CaptchaConfig {
pub secret_key: String, pub secret_key: String,
} }
#[derive(Deserialize, Serialize, Debug)] #[derive(Deserialize, Serialize, Debug, Clone)]
pub enum Captcha { pub enum Captcha {
#[serde(rename = "hCaptcha")] #[serde(rename = "hCaptcha")]
HCaptcha, HCaptcha,