// erpc_analysis/algorithms/classification.rs
use log::info;
use std::sync::Arc;
use crate::db_trait::{AnalysisDatabase, AnalysisError};
use crate::models::partitions::{
ConnectedComponent, PartitionClassificationResult,
};
/// Front end for partition classification: forwards classification requests
/// to an [`AnalysisDatabase`] backend and logs summaries of the results.
pub struct PartitionClassifier {
/// Shared handle to the database backend that every `classify_by_*`
/// method delegates to.
db_client: Arc<dyn AnalysisDatabase>,
}
impl PartitionClassifier {
    /// Maximum number of identifier bytes shown in the family report before
    /// the identifier is abbreviated with a trailing `...`.
    const FAMILY_ID_DISPLAY_BYTES: usize = 12;

    /// Creates a classifier that sends every classification request to
    /// `db_client`.
    pub fn new(db_client: Arc<dyn AnalysisDatabase>) -> Self {
        Self { db_client }
    }

    /// Runs the geographic classification for `components`.
    ///
    /// The work itself is delegated to the database backend's
    /// `classify_components_by_geography`; this method only adds progress
    /// logging around the call.
    ///
    /// # Errors
    ///
    /// Propagates the [`AnalysisError`] returned by the backend call.
    pub async fn classify_by_geography(
        &self,
        components: &[ConnectedComponent],
    ) -> Result<PartitionClassificationResult, AnalysisError> {
        info!("=== Starting Geographic Classification Analysis ===");
        let result = self
            .db_client
            .classify_components_by_geography(components)
            .await?;
        info!("=== Geographic Classification Complete ===");
        info!(
            "Found {} geographic groups with {:.1}% coverage",
            result.metrics.total_groups,
            result.metrics.classification_coverage
        );
        Ok(result)
    }

    /// Runs the ASN classification for `components`.
    ///
    /// The work itself is delegated to the database backend's
    /// `classify_components_by_asn`; this method only adds progress logging
    /// around the call.
    ///
    /// # Errors
    ///
    /// Propagates the [`AnalysisError`] returned by the backend call.
    pub async fn classify_by_asn(
        &self,
        components: &[ConnectedComponent],
    ) -> Result<PartitionClassificationResult, AnalysisError> {
        info!("=== Starting ASN Classification Analysis ===");
        let result = self
            .db_client
            .classify_components_by_asn(components)
            .await?;
        info!("=== ASN Classification Complete ===");
        info!(
            "Found {} ASN groups with {:.1}% coverage",
            result.metrics.total_groups,
            result.metrics.classification_coverage
        );
        Ok(result)
    }

    /// Runs the family classification for `components`.
    ///
    /// The work itself is delegated to the database backend's
    /// `classify_components_by_family`; this method only adds progress
    /// logging around the call.
    ///
    /// # Errors
    ///
    /// Propagates the [`AnalysisError`] returned by the backend call.
    pub async fn classify_by_family(
        &self,
        components: &[ConnectedComponent],
    ) -> Result<PartitionClassificationResult, AnalysisError> {
        info!("=== Starting Family Classification Analysis ===");
        let result = self
            .db_client
            .classify_components_by_family(components)
            .await?;
        info!("=== Family Classification Complete ===");
        info!(
            "Found {} family groups with {:.1}% coverage",
            result.metrics.total_groups,
            result.metrics.classification_coverage
        );
        Ok(result)
    }

    /// Logs a summary of a geographic classification result.
    ///
    /// Prints the aggregate metrics, then the largest groups (at most
    /// `config.max_display_components` of them) ordered by relay count, and
    /// finally the share of unclassified relays when there are any.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`; the `Result` is kept so callers do
    /// not have to change.
    pub fn display_geographic_classification(
        &self,
        result: &PartitionClassificationResult,
        config: &crate::config::AnalysisSettings,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let metrics = &result.metrics;
        info!("Geographic Classification Analysis:");
        info!("Total Countries: {}", metrics.total_groups);
        info!(
            "Countries with Partitions: {}",
            metrics.groups_with_partitions
        );
        info!(
            "Classification Coverage: {:.1}%",
            metrics.classification_coverage
        );
        info!(
            "Largest Country Group: {} relays",
            metrics.largest_group_size
        );
        info!(
            "Average Country Group Size: {:.1} relays",
            metrics.average_group_size
        );
        info!("Diversity Score: {:.3}", metrics.diversity_score);
        info!(
            "Partition Correlation: {:.2}",
            metrics.partition_correlation
        );

        info!("Top Countries by Relay Count:");
        Self::log_top_groups(result, config.max_display_components, |id| {
            id.to_string()
        });

        if !result.unclassified_relays.is_empty() {
            info!(
                "Unclassified relays: {} ({:.3}% of total)",
                result.unclassified_relays.len(),
                Self::unclassified_percent(result)
            );
        }
        Ok(())
    }

    /// Logs a summary of an ASN classification result.
    ///
    /// Prints the aggregate metrics, then the largest groups (at most
    /// `config.max_display_components` of them) ordered by relay count with
    /// each identifier prefixed by `AS`, and finally the share of
    /// unclassified relays when there are any.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`; the `Result` is kept so callers do
    /// not have to change.
    pub fn display_asn_classification(
        &self,
        result: &PartitionClassificationResult,
        config: &crate::config::AnalysisSettings,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let metrics = &result.metrics;
        info!("ASN Classification Analysis:");
        info!("Total ASNs: {}", metrics.total_groups);
        info!("ASNs with Partitions: {}", metrics.groups_with_partitions);
        info!(
            "Classification Coverage: {:.1}%",
            metrics.classification_coverage
        );
        info!("Largest ASN Group: {} relays", metrics.largest_group_size);
        info!(
            "Average ASN Group Size: {:.1} relays",
            metrics.average_group_size
        );
        info!("Diversity Score: {:.3}", metrics.diversity_score);
        info!(
            "Partition Correlation: {:.2}",
            metrics.partition_correlation
        );

        info!("Top ASNs by Relay Count:");
        Self::log_top_groups(result, config.max_display_components, |id| {
            format!("AS{}", id)
        });

        if !result.unclassified_relays.is_empty() {
            info!(
                "Unclassified relays: {} ({:.3}% of total)",
                result.unclassified_relays.len(),
                Self::unclassified_percent(result)
            );
        }
        Ok(())
    }

    /// Logs a summary of a family classification result.
    ///
    /// Prints the aggregate metrics, then the largest families (at most
    /// `config.max_display_components` of them) ordered by relay count with
    /// long identifiers abbreviated, and finally the share of relays that
    /// belong to no family when there are any.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok(())`; the `Result` is kept so callers do
    /// not have to change.
    pub fn display_family_classification(
        &self,
        result: &PartitionClassificationResult,
        config: &crate::config::AnalysisSettings,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let metrics = &result.metrics;
        info!("Family Classification Analysis:");
        info!("Total Families: {}", metrics.total_groups);
        info!(
            "Families with Partitions: {}",
            metrics.groups_with_partitions
        );
        info!(
            "Classification Coverage: {:.1}%",
            metrics.classification_coverage
        );
        info!("Largest Family: {} relays", metrics.largest_group_size);
        info!(
            "Average Family Size: {:.1} relays",
            metrics.average_group_size
        );
        info!("Diversity Score: {:.3}", metrics.diversity_score);
        info!(
            "Partition Correlation: {:.2}",
            metrics.partition_correlation
        );

        info!("Top Families by Relay Count:");
        Self::log_top_groups(
            result,
            config.max_display_components,
            Self::shorten_identifier,
        );

        if !result.unclassified_relays.is_empty() {
            info!(
                "Non-family relays: {} ({:.3}% of total)",
                result.unclassified_relays.len(),
                Self::unclassified_percent(result)
            );
        }
        Ok(())
    }

    /// Logs up to `limit` groups of `result`, largest relay count first, using
    /// `render_identifier` to turn each group identifier into its display form.
    fn log_top_groups(
        result: &PartitionClassificationResult,
        limit: usize,
        render_identifier: impl Fn(&str) -> String,
    ) {
        // Sort references instead of cloning every group (and all of its
        // fingerprints) just for display. The sort is stable, so groups with
        // equal relay counts keep their original relative order.
        let mut by_size: Vec<_> = result.groups.iter().collect();
        by_size.sort_by_key(|group| {
            std::cmp::Reverse(group.relay_fingerprints.len())
        });

        for (rank, group) in by_size.iter().take(limit).enumerate() {
            info!(
                "{}. {}: {} relays, isolation: {:.1}%{}",
                rank + 1,
                render_identifier(group.identifier.as_str()),
                group.relay_fingerprints.len(),
                group.isolation_score,
                Self::partition_suffix(group.component_mapping.len())
            );
        }
    }

    /// Returns the log suffix describing how many partitions a group spans.
    fn partition_suffix(partition_count: usize) -> String {
        if partition_count > 1 {
            format!(" (across {} partitions)", partition_count)
        } else {
            " (single partition)".to_string()
        }
    }

    /// Returns the percentage (0.0 to 100.0) of relays in `result` that are
    /// unclassified, or `0.0` when `result` contains no relays at all.
    fn unclassified_percent(result: &PartitionClassificationResult) -> f64 {
        let unclassified = result.unclassified_relays.len();
        let classified: usize = result
            .groups
            .iter()
            .map(|group| group.relay_fingerprints.len())
            .sum();
        let total = classified + unclassified;
        if total > 0 {
            (unclassified as f64 / total as f64) * 100.0
        } else {
            0.0
        }
    }

    /// Abbreviates `identifier` to at most `FAMILY_ID_DISPLAY_BYTES` bytes
    /// followed by `...`; shorter identifiers are returned unchanged.
    fn shorten_identifier(identifier: &str) -> String {
        if identifier.len() <= Self::FAMILY_ID_DISPLAY_BYTES {
            return identifier.to_string();
        }
        // Slicing at a fixed byte offset panics when that offset falls inside
        // a multi-byte UTF-8 character, so back off to the nearest boundary.
        // Offset 0 is always a boundary, which guarantees termination.
        let mut end = Self::FAMILY_ID_DISPLAY_BYTES;
        while !identifier.is_char_boundary(end) {
            end -= 1;
        }
        format!("{}...", &identifier[..end])
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db_trait::mock::MockDatabase;
    use crate::models::partitions::ConnectedComponent;

    /// Builds a classifier backed by a default mock database.
    fn mock_classifier() -> PartitionClassifier {
        let backend = Arc::new(MockDatabase::new());
        PartitionClassifier::new(backend)
    }

    /// Turns string literals into owned relay fingerprints.
    fn fingerprints(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| (*name).to_owned()).collect()
    }

    /// A single fully classified component yields full coverage, and every
    /// group maps back to that component.
    #[tokio::test]
    async fn test_geographic_classification() {
        let classifier = mock_classifier();
        let input = [ConnectedComponent {
            component_id: 1,
            relay_fingerprints: fingerprints(&[
                "RELAY_A", "RELAY_B", "RELAY_C", "RELAY_D",
            ]),
            size: 4,
        }];

        let outcome = classifier.classify_by_geography(&input).await.unwrap();

        assert!(!outcome.groups.is_empty());
        assert_eq!(outcome.metrics.classification_coverage, 100.0);
        assert!(outcome
            .groups
            .iter()
            .all(|group| group.component_mapping.contains_key(&1)));
    }

    /// ASN classification of one component yields full coverage, and every
    /// group maps back to that component.
    #[tokio::test]
    async fn test_asn_classification() {
        let classifier = mock_classifier();
        let input = [ConnectedComponent {
            component_id: 42,
            relay_fingerprints: fingerprints(&[
                "RELAY_X", "RELAY_Y", "RELAY_Z",
            ]),
            size: 3,
        }];

        let outcome = classifier.classify_by_asn(&input).await.unwrap();

        assert!(!outcome.groups.is_empty());
        assert_eq!(outcome.metrics.classification_coverage, 100.0);
        assert!(outcome
            .groups
            .iter()
            .all(|group| group.component_mapping.contains_key(&42)));
    }

    /// Empty input produces no groups; with two components every isolation
    /// score is a valid percentage and each group references a known
    /// component.
    #[tokio::test]
    async fn test_isolation_score() {
        let classifier = mock_classifier();

        let empty_outcome = classifier.classify_by_geography(&[]).await.unwrap();
        assert_eq!(empty_outcome.groups.len(), 0);
        assert_eq!(empty_outcome.metrics.classification_coverage, 0.0);

        let input = [
            ConnectedComponent {
                component_id: 10,
                relay_fingerprints: fingerprints(&["AAAA"]),
                size: 1,
            },
            ConnectedComponent {
                component_id: 20,
                relay_fingerprints: fingerprints(&["BBBB", "CCCC"]),
                size: 2,
            },
        ];

        let outcome = classifier.classify_by_geography(&input).await.unwrap();

        assert!(!outcome.groups.is_empty());
        for group in &outcome.groups {
            assert!((0.0..=100.0).contains(&group.isolation_score));
            assert!(group
                .component_mapping
                .keys()
                .any(|&id| id == 10 || id == 20));
        }
    }

    /// Classified plus unclassified relays account for every input relay, and
    /// the reported coverage matches the classified share.
    #[tokio::test]
    async fn test_family_classification() {
        let classifier = mock_classifier();
        let input = [ConnectedComponent {
            component_id: 5,
            relay_fingerprints: fingerprints(&[
                "RELAY_FAM1",
                "RELAY_FAM2",
                "RELAY_FAM3",
                "RELAY_LONE",
            ]),
            size: 4,
        }];

        let outcome = classifier.classify_by_family(&input).await.unwrap();

        let classified_count: usize = outcome
            .groups
            .iter()
            .map(|group| group.relay_fingerprints.len())
            .sum();
        let unclassified_count = outcome.unclassified_relays.len();
        assert_eq!(classified_count + unclassified_count, 4);

        assert!(outcome
            .groups
            .iter()
            .all(|group| group.component_mapping.contains_key(&5)));

        let expected_coverage = (classified_count as f64 / 4.0) * 100.0;
        let coverage_gap =
            (outcome.metrics.classification_coverage - expected_coverage).abs();
        assert!(coverage_gap < 0.1);
    }

    /// A failing backend call surfaces as an `Err` from the classifier.
    #[tokio::test]
    async fn test_classification_error_handling() {
        let input = [ConnectedComponent {
            component_id: 0,
            relay_fingerprints: fingerprints(&["FAIL_RELAY"]),
            size: 1,
        }];
        let failing_backend = Arc::new(
            MockDatabase::new().fail_on("classify_components_by_geography"),
        );
        let classifier = PartitionClassifier::new(failing_backend);

        let outcome = classifier.classify_by_geography(&input).await;

        assert!(outcome.is_err(), "Should fail when database operation fails");
    }

    /// Every display method succeeds on results produced by the mock backend.
    #[tokio::test]
    async fn test_display_methods() {
        let input = [ConnectedComponent {
            component_id: 0,
            relay_fingerprints: fingerprints(&["TEST_RELAY"]),
            size: 1,
        }];
        let classifier = mock_classifier();

        let by_geography =
            classifier.classify_by_geography(&input).await.unwrap();
        let by_asn = classifier.classify_by_asn(&input).await.unwrap();
        let by_family = classifier.classify_by_family(&input).await.unwrap();

        let settings = crate::config::AnalysisSettings::default();
        classifier
            .display_geographic_classification(&by_geography, &settings)
            .unwrap();
        classifier
            .display_asn_classification(&by_asn, &settings)
            .unwrap();
        classifier
            .display_family_classification(&by_family, &settings)
            .unwrap();
    }
}