use anyhow::{Context, Result};
use log::info;
use serde::Deserialize;
use std::{env, fs, path::Path};
use crate::args::Args;
/// Contents of the `primary` table of the main configuration file.
#[derive(Debug, Deserialize)]
pub struct PrimarySectionConfig {
/// When `true`, `AnalysisConfig::load_from_toml_and_env` reads the Neo4j
/// credentials from the environment; when `false` that step is skipped and
/// no Neo4j configuration is produced.
pub neo4j_allowed: bool,
}
/// Top-level shape of the main configuration TOML file, as parsed by
/// `AnalysisConfig::load_from_toml_and_env`.
#[derive(Debug, Deserialize)]
pub struct TomlRootConfig {
/// The `primary` table of the file.
pub primary: PrimarySectionConfig,
}
/// Connection details for a Neo4j database.
///
/// `Debug` is implemented by hand rather than derived so that the password
/// is never written out when this value, or any struct that contains it, is
/// formatted with `{:?}` (for example in a log line or a panic message).
#[derive(Clone)]
pub struct Neo4jConfig {
    /// Connection URI, including the scheme (for example `bolt://host:port`).
    pub uri: String,
    /// User name used to authenticate.
    pub username: String,
    /// Password used to authenticate. Redacted in `Debug` output.
    pub password: String,
}

impl std::fmt::Debug for Neo4jConfig {
    /// Formats the configuration like a derived `Debug` would, except that
    /// the password is replaced by a fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Neo4jConfig")
            .field("uri", &self.uri)
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}
/// Parameters for the Louvain community-detection step.
#[derive(Debug, Clone, Deserialize)]
pub struct LouvainConfig {
/// Iteration limit. `AnalysisConfig::validate_analysis_parameters` accepts
/// values from 1 to 100. Defaults to 10.
pub max_iterations: u32,
/// Tolerance value. `AnalysisConfig::validate_analysis_parameters` rejects
/// values that are `<= 0.0` or `> 1.0`. Defaults to 0.0001.
/// NOTE(review): presumably the convergence threshold — confirm against
/// the code that consumes it.
pub tolerance: f64,
/// Defaults to `false`.
/// NOTE(review): presumably asks the algorithm to also report the
/// intermediate community levels — confirm against the consumer.
pub include_intermediate_communities: bool,
}
impl Default for LouvainConfig {
fn default() -> Self {
Self {
max_iterations: 10,
tolerance: 0.0001,
include_intermediate_communities: false,
}
}
}
/// Parameters for the label-propagation community-detection step.
#[derive(Debug, Clone, Deserialize)]
pub struct LabelPropagationConfig {
/// Iteration limit. `AnalysisConfig::validate_analysis_parameters` accepts
/// values from 1 to 100. Defaults to 10.
pub max_iterations: u32,
}
impl Default for LabelPropagationConfig {
fn default() -> Self {
Self { max_iterations: 10 }
}
}
/// Settings for the community-detection stage, grouping the per-algorithm
/// parameters.
#[derive(Debug, Clone, Deserialize)]
pub struct CommunityDetectionConfig {
/// Number of runs. `AnalysisConfig::validate_analysis_parameters` accepts
/// values from 1 to 20. Defaults to 5.
/// NOTE(review): presumably the number of repeated runs combined into a
/// consensus result — confirm against the consumer.
pub consensus_runs: u32,
/// Louvain parameters.
pub louvain: LouvainConfig,
/// Label-propagation parameters.
pub label_propagation: LabelPropagationConfig,
}
impl Default for CommunityDetectionConfig {
fn default() -> Self {
Self {
consensus_runs: 5,
louvain: LouvainConfig::default(),
label_propagation: LabelPropagationConfig::default(),
}
}
}
/// Settings for the centrality stage. Every field is optional in the TOML
/// file. None of them is checked by
/// `AnalysisConfig::validate_analysis_parameters`.
#[derive(Debug, Clone, Deserialize)]
pub struct CentralityConfig {
/// Defaults to `Some(1000)`.
/// NOTE(review): presumably the number of nodes sampled when computing
/// betweenness — confirm against the consumer.
pub betweenness_sampling_size: Option<usize>,
/// Defaults to `Some(42)`.
/// NOTE(review): presumably the RNG seed for that sampling — confirm.
pub betweenness_sampling_seed: Option<u64>,
/// Defaults to `Some(true)`.
/// NOTE(review): presumably toggles the Wasserman-Faust variant of a
/// centrality measure — confirm against the consumer.
pub use_wasserman_faust: Option<bool>,
}
impl Default for CentralityConfig {
fn default() -> Self {
Self {
betweenness_sampling_size: Some(1000),
betweenness_sampling_seed: Some(42),
use_wasserman_faust: Some(true),
}
}
}
/// Settings for the path-analysis stage. None of these fields is checked by
/// `AnalysisConfig::validate_analysis_parameters`.
///
/// NOTE(review): the field meanings below are inferred from their names;
/// confirm each against the code that consumes this struct.
#[derive(Debug, Clone, Deserialize)]
pub struct PathAnalysisConfig {
/// Optional cap, presumably on paths reported per node pair.
/// Defaults to `Some(10)`.
pub max_paths_per_pair: Option<usize>,
/// Optional cap, presumably on the length of a single path.
/// Defaults to `Some(10)`.
pub max_path_length: Option<usize>,
/// Defaults to 5.
pub representative_nodes_per_partition: usize,
/// Defaults to 3.
pub num_top_communities: usize,
/// Defaults to 5.
pub sample_size_communities: usize,
/// Defaults to 3.
pub num_top_asn_groups: usize,
/// Defaults to 3.
pub sample_size_asn_groups: usize,
/// Defaults to 15.
pub internal_component_sample_size: usize,
}
impl Default for PathAnalysisConfig {
fn default() -> Self {
Self {
max_paths_per_pair: Some(10),
max_path_length: Some(10),
representative_nodes_per_partition: 5,
num_top_communities: 3,
sample_size_communities: 5,
num_top_asn_groups: 3,
sample_size_asn_groups: 3,
internal_component_sample_size: 15,
}
}
}
/// General analysis settings. None of these fields is checked by
/// `AnalysisConfig::validate_analysis_parameters`.
///
/// NOTE(review): the field meanings below are inferred from their names;
/// confirm each against the code that consumes this struct.
#[derive(Debug, Clone, Deserialize)]
pub struct AnalysisSettings {
/// Presumably the maximum number of components shown in output.
/// Defaults to 10.
pub max_display_components: usize,
/// Presumably toggles computing a distribution. Defaults to `true`.
pub calculate_distribution: bool,
/// Defaults to 50.0. Units and comparison direction are not visible here.
pub isolation_ratio_threshold: f64,
}
impl Default for AnalysisSettings {
fn default() -> Self {
Self {
max_display_components: 10,
calculate_distribution: true,
isolation_ratio_threshold: 50.0,
}
}
}
/// Top-level shape of the analysis-parameters TOML file read by
/// `AnalysisConfig::load_analysis_parameters`. The derived `Default` is used
/// when that file does not exist.
///
/// NOTE(review): no `#[serde(default)]` attribute is present, so presumably
/// a file that omits one of these tables fails to parse instead of falling
/// back to that table's defaults — confirm whether that is intended.
#[derive(Debug, Clone, Deserialize, Default)]
pub struct AnalysisParametersConfig {
/// Community-detection settings (the only section that is validated).
pub community_detection: CommunityDetectionConfig,
/// Centrality settings.
pub centrality: CentralityConfig,
/// Path-analysis settings.
pub path_analysis: PathAnalysisConfig,
/// General analysis settings.
pub analysis: AnalysisSettings,
}
/// Fully resolved configuration returned by
/// `AnalysisConfig::load_from_toml_and_env`.
#[derive(Debug, Clone)]
pub struct AnalysisConfig {
/// Neo4j connection details. `Some` only when the main configuration file
/// sets `neo4j_allowed = true` in its `primary` table; `None` otherwise.
pub neo4j: Option<Neo4jConfig>,
/// Analysis parameters, either read from the analysis config file or the
/// built-in defaults when that file does not exist.
pub analysis_params: AnalysisParametersConfig,
}
impl AnalysisConfig {
/// Builds the full [`AnalysisConfig`] from the main TOML file, the process
/// environment, and the analysis-parameters file.
///
/// Steps:
/// 1. Read and parse `args.config` as a [`TomlRootConfig`].
/// 2. If `primary.neo4j_allowed` is `true`, try to load
///    `config/primary/.env` (failure to load it is only reported, not
///    fatal) and then read `NEO4J_DB_ADDR`, `NEO4J_DB_USERNAME` and
///    `NEO4J_DB_PASSWORD` from the environment. Otherwise Neo4j is skipped
///    and the `neo4j` field is `None`.
/// 3. Load the analysis parameters from `args.analysis_config`.
///
/// `NEO4J_DB_ADDR` may be a bare address such as `host:port`, in which case
/// `bolt://` is prepended, or a full URI that already carries a scheme, in
/// which case it is used unchanged.
///
/// # Errors
///
/// Returns an error if the main configuration file cannot be read or
/// parsed, if Neo4j is allowed but one of the three environment variables
/// cannot be read, or if loading the analysis parameters fails.
pub fn load_from_toml_and_env(args: &Args) -> Result<Self> {
    let config_path = Path::new(&args.config);
    info!("Loading configuration from: {:?}", config_path);

    let config_file_contents = fs::read_to_string(config_path)
        .with_context(|| {
            format!(
                "Failed to read configuration file at: {:?}",
                config_path
            )
        })?;
    let toml_root: TomlRootConfig = toml::from_str(&config_file_contents)
        .with_context(|| {
            format!(
                "Failed to parse TOML from configuration file at: {:?}. \
                 Ensure it matches the expected structure.",
                config_path
            )
        })?;

    let neo4j = if toml_root.primary.neo4j_allowed {
        let env_path = "config/primary/.env";
        // A missing .env file is tolerated: the variables may already be
        // set in the process environment.
        match dotenvy::from_path(env_path) {
            Ok(_) => {
                println!(
                    "Successfully loaded environment variables from: {}",
                    env_path
                );
            }
            Err(e) => {
                println!(
                    "Warning: Could not load .env file from {}: {}. \
                     Neo4j credential loading will rely on globally set \
                     environment variables.",
                    env_path, e
                );
            }
        }

        info!(
            "Attempting to load Neo4j credentials from environment \
             variables..."
        );
        let raw_uri = env::var("NEO4J_DB_ADDR").context(
            "NEO4J_DB_ADDR not found. Ensure it is set (e.g., in \
             'config/primary/.env' or globally).",
        )?;
        // Only add the bolt scheme when the address does not already carry
        // one; prefixing unconditionally would turn `bolt://host:7687` into
        // the unusable `bolt://bolt://host:7687`.
        let uri = if raw_uri.contains("://") {
            raw_uri
        } else {
            format!("bolt://{}", raw_uri)
        };
        let username = env::var("NEO4J_DB_USERNAME").context(
            "NEO4J_DB_USERNAME not found. Ensure it is set (e.g., in \
             'config/primary/.env' or globally).",
        )?;
        let password = env::var("NEO4J_DB_PASSWORD").context(
            "NEO4J_DB_PASSWORD not found. Ensure it is set (e.g., in \
             'config/primary/.env' or globally).",
        )?;

        println!("Successfully loaded Neo4j credentials.");
        Some(Neo4jConfig {
            uri,
            username,
            password,
        })
    } else {
        println!(
            "Neo4j is not allowed in the [primary] section of the \
             configuration file. Skipping .env loading and credential \
             retrieval."
        );
        None
    };

    let analysis_params =
        Self::load_analysis_parameters(&args.analysis_config)?;

    Ok(AnalysisConfig {
        neo4j,
        analysis_params,
    })
}
fn load_analysis_parameters(
config_path: &str,
) -> Result<AnalysisParametersConfig> {
let analysis_config_path = Path::new(config_path);
if !analysis_config_path.exists() {
info!(
"Analysis config file not found at {:?}, using defaults",
analysis_config_path
);
return Ok(AnalysisParametersConfig::default());
}
info!(
"Loading analysis parameters from: {:?}",
analysis_config_path
);
let config_contents = fs::read_to_string(analysis_config_path)
.with_context(|| {
format!(
"Failed to read analysis config file at: {:?}",
analysis_config_path
)
})?;
let analysis_params: AnalysisParametersConfig =
toml::from_str(&config_contents).with_context(|| {
format!(
"Failed to parse analysis config TOML from: {:?}. \
Ensure it matches the expected structure.",
analysis_config_path
)
})?;
info!("Successfully loaded analysis parameters");
Self::validate_analysis_parameters(&analysis_params)?;
Ok(analysis_params)
}
/// Checks that the community-detection parameters are within their accepted
/// ranges.
///
/// Accepted ranges:
/// - `community_detection.louvain.max_iterations`: 1 to 100 inclusive
/// - `community_detection.louvain.tolerance`: greater than 0.0 and at most
///   1.0 (NaN is rejected)
/// - `community_detection.label_propagation.max_iterations`: 1 to 100
///   inclusive
/// - `community_detection.consensus_runs`: 1 to 20 inclusive
///
/// No other section of `params` is inspected.
///
/// # Errors
///
/// Returns an error describing the first parameter found out of range, in
/// the order listed above.
fn validate_analysis_parameters(
    params: &AnalysisParametersConfig,
) -> Result<()> {
    let cd = &params.community_detection;
    let louvain = &cd.louvain;
    let lpa = &cd.label_propagation;

    anyhow::ensure!(
        (1..=100).contains(&louvain.max_iterations),
        "Louvain max_iterations must be between 1 and 100, got: {}",
        louvain.max_iterations
    );

    // Written as a positive range test so that NaN, for which every
    // comparison is false, is rejected rather than slipping through.
    anyhow::ensure!(
        louvain.tolerance > 0.0 && louvain.tolerance <= 1.0,
        "Louvain tolerance must be between 0.0 and 1.0, got: {}",
        louvain.tolerance
    );

    anyhow::ensure!(
        (1..=100).contains(&lpa.max_iterations),
        "Label Propagation max_iterations must be between 1 and 100, \
         got: {}",
        lpa.max_iterations
    );

    anyhow::ensure!(
        (1..=20).contains(&cd.consensus_runs),
        "consensus_runs must be between 1 and 20, got: {}",
        cd.consensus_runs
    );

    Ok(())
}
}