add database/redis connection check

This commit is contained in:
naskya 2024-06-25 13:09:30 +09:00
parent e1d6b77079
commit 3493eeb29e
Signed by: naskya
GPG key ID: 712D413B3A9FED5C
5 changed files with 147 additions and 20 deletions

66
Cargo.lock generated
View file

@ -109,6 +109,17 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "async-trait"
version = "0.1.80"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6fa2087f2753a7da8cc1c0dbfcf89579dd57458e36769de5ac750b4671737ca"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
]
[[package]]
name = "atoi"
version = "2.0.0"
@ -292,6 +303,20 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b6a852b24ab71dffc585bcb46eaf7959d175cb865a7152e35b348d1b2960422"
[[package]]
name = "combine"
version = "4.6.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd"
dependencies = [
"bytes",
"futures-core",
"memchr",
"pin-project-lite",
"tokio",
"tokio-util",
]
[[package]]
name = "const-oid"
version = "0.9.6"
@ -502,6 +527,7 @@ dependencies = [
"clap",
"color-print",
"enum-iterator",
"redis",
"serde",
"serde_repr",
"sqlx",
@ -1192,6 +1218,27 @@ dependencies = [
"getrandom",
]
[[package]]
name = "redis"
version = "0.25.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e0d7a6955c7511f60f3ba9e86c6d02b3c3f144f8c24b288d1f4e18074ab8bbec"
dependencies = [
"async-trait",
"bytes",
"combine",
"futures-util",
"itoa",
"percent-encoding",
"pin-project-lite",
"ryu",
"sha1_smol",
"socket2",
"tokio",
"tokio-util",
"url",
]
[[package]]
name = "redox_syscall"
version = "0.4.1"
@ -1352,6 +1399,12 @@ dependencies = [
"digest",
]
[[package]]
name = "sha1_smol"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012"
[[package]]
name = "sha2"
version = "0.10.8"
@ -1788,6 +1841,19 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-util"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9cf6b47b3771c49ac75ad09a6162f53ad4b8088b76ac60e8ec1455b31a189fe1"
dependencies = [
"bytes",
"futures-core",
"futures-sink",
"pin-project-lite",
"tokio",
]
[[package]]
name = "toml"
version = "0.8.14"

View file

@ -14,10 +14,11 @@ chrono = "0.4"
clap = { version = "4.5", features = ["derive"] }
color-print = "0.3"
enum-iterator = "2.1"
thiserror = "1.0"
redis = { version = "0.25", features = ["tokio-comp"] }
serde = { version = "1.0", features = ["derive"] }
serde_repr = "0.1"
sqlx = { version = "0.7", features = ["runtime-tokio", "postgres"] }
thiserror = "1.0"
tokio = { version = "1.38", features = ["full"] }
toml = "0.8"
url = "2.5"

View file

@ -36,7 +36,7 @@ pub(crate) enum ConfigError {
pub(super) async fn run(command: Commands) -> Result<(), ConfigError> {
match command {
Commands::Update { revision } => update::run(revision).await?,
Commands::Validate { offline } => validate::run(offline)?,
Commands::Validate { offline } => validate::run(offline).await?,
}
Ok(())
}

View file

@ -4,6 +4,7 @@ use super::*;
use crate::config::{client, server, CLIENT_CONFIG_PATH, SERVER_CONFIG_PATH};
use color_print::cprintln;
use enum_iterator::Sequence;
use sqlx::{postgres::PgConnectOptions, query, ConnectOptions};
use validator::Validate;
/// Errors that can happen in `config validate` subcommand
@ -17,9 +18,15 @@ pub(crate) enum ValidationError {
OutOfDate,
#[error("invalid config file")]
InvalidConfig,
#[error("failed to connect to database")]
Db(#[from] sqlx::Error),
#[error("failed to connect to cache server")]
CacheServerConn(#[from] redis::RedisError),
#[error("unexpected cache server response")]
CacheServer,
}
pub(super) fn run(bypass_connection_checks: bool) -> Result<(), ValidationError> {
pub(super) async fn run(bypass_connection_checks: bool) -> Result<(), ValidationError> {
if current_revision()?.next().is_some() {
cprintln!("Please first run `<bold>fishctl config update</>` to update your config files.");
return Err(ValidationError::OutOfDate);
@ -66,9 +73,62 @@ pub(super) fn run(bypass_connection_checks: bool) -> Result<(), ValidationError>
cprintln!("<bold>Note:</> This command only checks the format of the config files, and its result does not guarantee the correctness of the value.");
}
if !bypass_connection_checks {
todo!()
if let Err(err) = server_validation_result.and(client_validation_result) {
return Err(err);
}
server_validation_result.and(client_validation_result)
let server_config = read_server_config_as::<server::Config>()
.expect("server config should be formally valid at this point");
match bypass_connection_checks {
true => check_database_connection(server_config.database)
.await
.and(check_cache_server_connection(server_config.cache_server).await),
false => Ok(()),
}
}
async fn check_database_connection(db: server::Database) -> Result<(), ValidationError> {
let mut conn = PgConnectOptions::new()
.host(&db.host)
.port(db.port)
.username(&db.user)
.password(&db.password)
.database(&db.name)
.connect()
.await?;
query("SELECT version()").execute(&mut conn).await?;
Ok(())
}
async fn check_cache_server_connection(
cache_server: server::CacheServer,
) -> Result<(), ValidationError> {
let url = {
let mut params = vec!["redis://".to_owned()];
if let Some(user) = cache_server.user.as_ref() {
params.push(user.to_owned());
}
if let Some(password) = cache_server.password.as_ref() {
params.push(format!(":{}@", password));
}
params.push(cache_server.host);
params.push(format!(":{}", cache_server.port));
params.push(format!("/{}", cache_server.index.unwrap_or(0)));
params.concat()
};
let mut conn = redis::Client::open(url)?
.get_multiplexed_async_connection()
.await?;
let pong: String = redis::cmd("PING").query_async(&mut conn).await?;
match pong.as_str() {
"PONG" => Ok(()),
_ => Err(ValidationError::CacheServer),
}
}

View file

@ -9,7 +9,7 @@ use crate::config::{ensure_latest_revision, Revision};
use serde::{Deserialize, Serialize};
use validator::Validate;
#[derive(Deserialize, Serialize, Validate, Debug)]
#[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct Config {
#[validate(custom(function = "ensure_latest_revision"))]
pub config_revision: Revision,
@ -31,7 +31,7 @@ pub struct Config {
pub security: Option<Security>,
}
#[derive(Deserialize, Serialize, Validate, Debug)]
#[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct Network {
pub protocol: Option<HttpProtocol>,
pub domain: String,
@ -48,13 +48,13 @@ pub struct Network {
pub smtp_proxy: Option<Proxy>,
}
#[derive(Deserialize, Serialize, Validate, Debug)]
#[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct Listen {
pub host: Option<String>,
pub port: u16,
}
#[derive(Deserialize, Serialize, Validate, Debug)]
#[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct Proxy {
pub enabled: bool,
#[validate(url)]
@ -62,7 +62,7 @@ pub struct Proxy {
pub allowlist: Option<Vec<String>>,
}
#[derive(Deserialize, Serialize, Validate, Debug)]
#[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct Info {
/// Server name
pub name: Option<String>,
@ -78,7 +78,7 @@ pub struct Info {
pub repository_url: Option<String>,
}
#[derive(Deserialize, Serialize, Validate, Debug)]
#[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct Timeline {
/// Whether to enable the local timeline
pub local: bool,
@ -90,14 +90,14 @@ pub struct Timeline {
pub guest: bool,
}
#[derive(Deserialize, Serialize, Debug)]
#[derive(Deserialize, Serialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum HttpProtocol {
Https,
Http,
}
#[derive(Deserialize, Serialize, Validate, Debug)]
#[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct Database {
pub host: String,
pub port: u16,
@ -106,7 +106,7 @@ pub struct Database {
pub name: String,
}
#[derive(Deserialize, Serialize, Validate, Debug)]
#[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct CacheServer {
pub host: String,
pub port: u16,
@ -117,14 +117,14 @@ pub struct CacheServer {
pub prefix: Option<String>,
}
#[derive(Deserialize, Serialize, Validate, Debug)]
#[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct Id {
#[validate(range(min = 16, max = 24))]
pub length: Option<u8>,
pub fingerprint: Option<String>,
}
#[derive(Deserialize, Serialize, Validate, Debug)]
#[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct File {
/// Maximum file size in megabytes
pub max_size: Option<u64>,
@ -132,7 +132,7 @@ pub struct File {
pub cache_remote_file: Option<bool>,
}
#[derive(Deserialize, Serialize, Validate, Debug)]
#[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct Security {
pub require_authorized_fetch: Option<bool>,
pub private_mode: Option<bool>,
@ -142,7 +142,7 @@ pub struct Security {
pub log_ip_address: Option<bool>,
}
#[derive(Deserialize, Serialize, Validate, Debug)]
#[derive(Deserialize, Serialize, Validate, Debug, Clone)]
pub struct CaptchaConfig {
pub enabled: bool,
pub kind: Captcha,
@ -150,7 +150,7 @@ pub struct CaptchaConfig {
pub secret_key: String,
}
#[derive(Deserialize, Serialize, Debug)]
#[derive(Deserialize, Serialize, Debug, Clone)]
pub enum Captcha {
#[serde(rename = "hCaptcha")]
HCaptcha,