use tor_hscrypto::{pk::HsBlindId, RevisionCounter, Subcredential};
use tor_llcrypto::cipher::aes::Aes256Ctr as Cipher;
use tor_llcrypto::d::Sha3_256 as Hash;
use tor_llcrypto::d::Shake256 as KDF;
use cipher::{KeyIvInit, StreamCipher};
use digest::{ExtendableOutput, FixedOutput, Update, XofReader};
#[cfg(any(test, feature = "hs-service"))]
use rand::{CryptoRng, Rng};
use tor_llcrypto::pk::curve25519::PublicKey;
use tor_llcrypto::pk::curve25519::StaticSecret;
use tor_llcrypto::util::ct::CtByteArray;
use zeroize::Zeroizing as Z;
/// Parameters for encrypting or decrypting one encrypted layer of an onion
/// service descriptor.
///
/// Every field is fed into the key derivation function (see `get_kdf`), so the
/// encrypting and decrypting sides must agree on all of them; any difference
/// yields different keys and decryption fails with a MAC mismatch.
pub(super) struct HsDescEncryption<'a> {
/// The blinded identity key; the first input to the KDF.
pub(super) blinded_id: &'a HsBlindId,
/// An optional nonce mixed into the KDF right after `blinded_id`.
/// When `None`, nothing is inserted at that position.
/// NOTE(review): presumably the "descriptor cookie" used for client
/// authorization — confirm against callers.
pub(super) desc_enc_nonce: Option<&'a HsDescEncNonce>,
/// The onion service subcredential, fed into the KDF.
pub(super) subcredential: &'a Subcredential,
/// The descriptor revision counter, fed into the KDF as a big-endian u64.
pub(super) revision: RevisionCounter,
/// A string constant appended last to the KDF input, so that keys derived
/// for different purposes are distinct.
pub(super) string_const: &'a [u8],
}
/// Length in bytes of the client identifier derived by
/// `build_descriptor_cookie_key`.
pub(crate) const HS_DESC_CLIENT_ID_LEN: usize = 8;
/// Length in bytes of an initialization vector used in descriptor encryption.
pub(crate) const HS_DESC_IV_LEN: usize = 16;
/// Length in bytes of an `HsDescEncNonce`.
pub(crate) const HS_DESC_ENC_NONCE_LEN: usize = 16;
/// A fixed-length nonce that can optionally be mixed into the descriptor
/// encryption KDF (via `HsDescEncryption::desc_enc_nonce`).
///
/// Constructed from, and viewable as, a `[u8; HS_DESC_ENC_NONCE_LEN]`.
/// NOTE(review): presumably this is the "descriptor cookie" used for client
/// authorization — confirm against callers.
#[derive(derive_more::AsRef, derive_more::From)]
pub(super) struct HsDescEncNonce([u8; HS_DESC_ENC_NONCE_LEN]);
/// Length in bytes of the random salt that prefixes every encrypted layer.
const SALT_LEN: usize = 16;
/// Length in bytes of the SHA3-256 MAC that suffixes every encrypted layer.
const MAC_LEN: usize = 32;
impl<'a> HsDescEncryption<'a> {
const MAC_KEY_LEN: usize = 32;
const CIPHER_KEY_LEN: usize = 32;
const IV_LEN: usize = 16;
#[cfg(any(test, feature = "hs-service"))]
pub(super) fn encrypt<R: Rng + CryptoRng>(&self, rng: &mut R, data: &[u8]) -> Vec<u8> {
let output_len = data.len() + SALT_LEN + MAC_LEN;
let mut output = Vec::with_capacity(output_len);
let salt: [u8; SALT_LEN] = rng.gen();
let (mut cipher, mut mac) = self.init(&salt);
output.extend_from_slice(&salt[..]);
output.extend_from_slice(data);
cipher.apply_keystream(&mut output[SALT_LEN..]);
mac.update(&output[SALT_LEN..]);
let mut mac_val = Default::default();
mac.finalize_into(&mut mac_val);
output.extend_from_slice(&mac_val);
debug_assert_eq!(output.len(), output_len);
output
}
pub(super) fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>, DecryptionError> {
if data.len() < SALT_LEN + MAC_LEN {
return Err(DecryptionError::default());
}
let msg_len = data.len() - SALT_LEN - MAC_LEN;
let salt = data[0..SALT_LEN]
.try_into()
.expect("Failed try_into for 16-byte array.");
let ciphertext = &data[SALT_LEN..(SALT_LEN + msg_len)];
let expected_mac = CtByteArray::from(
<[u8; MAC_LEN]>::try_from(&data[SALT_LEN + msg_len..SALT_LEN + msg_len + MAC_LEN])
.expect("Failed try_into for 32-byte array."),
);
let (mut cipher, mut mac) = self.init(&salt);
mac.update(ciphertext);
let mut received_mac = CtByteArray::from([0_u8; MAC_LEN]);
mac.finalize_into(received_mac.as_mut().into());
if received_mac != expected_mac {
return Err(DecryptionError::default());
}
let mut decrypted = ciphertext.to_vec();
cipher.apply_keystream(&mut decrypted[..]);
Ok(decrypted)
}
fn init(&self, salt: &[u8; 16]) -> (Cipher, Hash) {
let mut key_stream = self.get_kdf(salt).finalize_xof();
let mut key = Z::new([0_u8; Self::CIPHER_KEY_LEN]);
let mut iv = Z::new([0_u8; Self::IV_LEN]);
let mut mac_key = Z::new([0_u8; Self::MAC_KEY_LEN]);
key_stream.read(&mut key[..]);
key_stream.read(&mut iv[..]);
key_stream.read(&mut mac_key[..]);
let cipher = Cipher::new(key.as_ref().into(), iv.as_ref().into());
let mut mac = Hash::default();
mac.update(&(Self::MAC_KEY_LEN as u64).to_be_bytes());
mac.update(&mac_key[..]);
mac.update(&(salt.len() as u64).to_be_bytes());
mac.update(&salt[..]);
(cipher, mac)
}
fn get_kdf(&self, salt: &[u8; 16]) -> KDF {
let mut kdf = KDF::default();
kdf.update(self.blinded_id.as_ref());
if let Some(cookie) = self.desc_enc_nonce {
kdf.update(cookie.as_ref());
}
kdf.update(self.subcredential.as_ref());
kdf.update(&u64::from(self.revision).to_be_bytes());
kdf.update(salt);
kdf.update(self.string_const);
kdf
}
}
/// Error returned when an encrypted onion service descriptor layer cannot be
/// decrypted.
///
/// It carries no detail about the cause: an input that is too short and a MAC
/// mismatch (wrong keys, or corrupted data) both produce this same error.
#[non_exhaustive]
#[derive(Clone, Debug, Default, thiserror::Error)]
#[error("Unable to decrypt onion service descriptor.")]
pub struct DecryptionError {}
pub(crate) fn build_descriptor_cookie_key(
our_secret_key: &StaticSecret,
their_public_key: &PublicKey,
subcredential: &Subcredential,
) -> (CtByteArray<8>, [u8; 32]) {
let secret_seed = our_secret_key.diffie_hellman(their_public_key);
let mut kdf = KDF::default();
kdf.update(subcredential.as_ref());
kdf.update(secret_seed.as_bytes());
let mut keys = kdf.finalize_xof();
let mut client_id = CtByteArray::from([0_u8; 8]);
let mut cookie_key = [0_u8; 32];
keys.read(client_id.as_mut());
keys.read(&mut cookie_key);
(client_id, cookie_key)
}
#[cfg(test)]
mod test {
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_duration_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    use super::*;
    use tor_basic_utils::test_rng::testing_rng;

    /// Encrypt several messages (empty, short, and multi-block) with `params`.
    /// Check that each one round-trips, and that truncated or bit-flipped
    /// ciphertexts are rejected.
    fn check_roundtrip(params: &HsDescEncryption<'_>) {
        let mut rng = testing_rng();
        let bigmsg: Vec<u8> = (1..123).cycle().take(1021).collect();
        for message in [&b""[..], &b"hello world"[..], &bigmsg[..]] {
            let encrypted = params.encrypt(&mut rng, message);
            assert_eq!(encrypted.len(), message.len() + SALT_LEN + MAC_LEN);
            let decrypted = params.decrypt(&encrypted[..]).unwrap();
            assert_eq!(message, &decrypted);

            // Dropping the last byte must be detected.
            let decryption_err = params.decrypt(&encrypted[..encrypted.len() - 1]);
            assert!(decryption_err.is_err());

            // A flipped bit in the salt, at the start of the ciphertext (or the
            // MAC, for an empty message), or at the end of the MAC must be
            // detected. Each check starts from a fresh copy.
            for idx in [7, SALT_LEN, encrypted.len() - 1] {
                let mut tampered = encrypted.clone();
                tampered[idx] ^= 3;
                assert!(params.decrypt(&tampered[..]).is_err());
            }
        }
    }

    #[test]
    fn roundtrip_basics() {
        let blinded_id = [7; 32].into();
        let subcredential = [11; 32].into();
        let revision = 13.into();
        let string_const = "greetings puny humans";
        let params = HsDescEncryption {
            blinded_id: &blinded_id,
            desc_enc_nonce: None,
            subcredential: &subcredential,
            revision,
            string_const: string_const.as_bytes(),
        };
        check_roundtrip(&params);
    }

    #[test]
    fn roundtrip_with_nonce() {
        let blinded_id = [7; 32].into();
        let subcredential = [11; 32].into();
        let nonce: HsDescEncNonce = [5; HS_DESC_ENC_NONCE_LEN].into();
        let params = HsDescEncryption {
            blinded_id: &blinded_id,
            desc_enc_nonce: Some(&nonce),
            subcredential: &subcredential,
            revision: 13.into(),
            string_const: "greetings puny humans".as_bytes(),
        };
        check_roundtrip(&params);
    }

    #[test]
    fn mismatched_params() {
        let blinded_id = [7; 32].into();
        let subcredential = [11; 32].into();
        let nonce: HsDescEncNonce = [5; HS_DESC_ENC_NONCE_LEN].into();
        let params = HsDescEncryption {
            blinded_id: &blinded_id,
            desc_enc_nonce: Some(&nonce),
            subcredential: &subcredential,
            revision: 13.into(),
            string_const: "greetings puny humans".as_bytes(),
        };
        let mut rng = testing_rng();
        let encrypted = params.encrypt(&mut rng, &b"hello world"[..]);
        assert_eq!(
            params.decrypt(&encrypted[..]).unwrap(),
            b"hello world".to_vec()
        );

        // Every KDF input matters: changing any one must break decryption.
        let no_nonce = HsDescEncryption {
            desc_enc_nonce: None,
            ..params
        };
        assert!(no_nonce.decrypt(&encrypted[..]).is_err());

        let other_revision = HsDescEncryption {
            revision: 14.into(),
            ..params
        };
        assert!(other_revision.decrypt(&encrypted[..]).is_err());

        let other_const = HsDescEncryption {
            string_const: &b"other constant"[..],
            ..params
        };
        assert!(other_const.decrypt(&encrypted[..]).is_err());
    }

    #[test]
    fn too_short() {
        let blinded_id = [7; 32].into();
        let subcredential = [11; 32].into();
        let revision = 13.into();
        let string_const = "greetings puny humans";
        let params = HsDescEncryption {
            blinded_id: &blinded_id,
            desc_enc_nonce: None,
            subcredential: &subcredential,
            revision,
            string_const: string_const.as_bytes(),
        };
        assert!(params.decrypt(b"").is_err());
        assert!(params.decrypt(&[0_u8; 47]).is_err());
    }
}