use cfg_traits::liquidity_pools::{
InboundMessageHandler, LpMessageHash, LpMessageProof, MessageHash, MessageQueue, RouterProvider,
};
use cfg_types::domain_address::{Domain, DomainAddress};
use frame_support::{
dispatch::DispatchResult,
ensure,
pallet_prelude::{Decode, Encode, TypeInfo},
};
use parity_scale_codec::MaxEncodedLen;
use sp_arithmetic::traits::{EnsureAddAssign, EnsureSub, SaturatedConversion};
use sp_runtime::DispatchError;
use sp_std::vec::Vec;
use crate::{
message::GatewayMessage, Config, Error, Event, Pallet, PendingInboundEntries, Routers,
SessionIdStore,
};
/// Pending-entry payload for an inbound message that is not a proof.
///
/// `InboundEntry::create` builds this variant payload when
/// `message.is_proof_message()` is false.
#[derive(Debug, Encode, Decode, Clone, Eq, MaxEncodedLen, PartialEq, TypeInfo)]
#[scale_info(skip_type_params(T))]
pub struct MessageEntry<T: Config> {
/// Session the entry was recorded in. `pre_dispatch_update` replaces the
/// stored entry instead of merging when this differs.
pub session_id: T::SessionId,
/// Domain address supplied together with the message at creation time.
pub domain_address: DomainAddress,
/// The inbound message carried by this entry.
pub message: T::Message,
/// Proof counter for this message: increased by `pre_dispatch_update` when
/// the same session upserts again, decreased by `create_post_voting_entry`.
pub expected_proof_count: u32,
}
/// Pending-entry payload for an inbound proof message.
///
/// `InboundEntry::create` builds this with `current_count` set to 1 when
/// `message.is_proof_message()` is true.
#[derive(Debug, Encode, Decode, Clone, Eq, MaxEncodedLen, PartialEq, TypeInfo)]
#[scale_info(skip_type_params(T))]
pub struct ProofEntry<T: Config> {
/// Session the proof was recorded in.
pub session_id: T::SessionId,
/// Number of proofs accumulated for the session; decreased by 1 in
/// `create_post_voting_entry`.
pub current_count: u32,
}
impl<T: Config> ProofEntry<T> {
    /// Returns `true` only when this proof was recorded for `session_id` and
    /// its `current_count` is greater than zero.
    pub fn has_valid_vote_for_session(&self, session_id: T::SessionId) -> bool {
        if self.session_id != session_id {
            return false;
        }

        self.current_count > 0
    }
}
/// An inbound entry kept in `PendingInboundEntries`, keyed by message hash
/// and router ID: either the message itself or a proof for it.
#[derive(Debug, Encode, Decode, Clone, Eq, MaxEncodedLen, PartialEq, TypeInfo)]
#[scale_info(skip_type_params(T))]
pub enum InboundEntry<T: Config> {
/// Entry holding the actual message.
Message(MessageEntry<T>),
/// Entry holding a proof counter.
Proof(ProofEntry<T>),
}
impl<T: Config> From<MessageEntry<T>> for InboundEntry<T> {
fn from(message_entry: MessageEntry<T>) -> Self {
Self::Message(message_entry)
}
}
impl<T: Config> From<ProofEntry<T>> for InboundEntry<T> {
fn from(proof_entry: ProofEntry<T>) -> Self {
Self::Proof(proof_entry)
}
}
impl<T: Config> InboundEntry<T> {
pub fn create(
message: T::Message,
session_id: T::SessionId,
domain_address: DomainAddress,
expected_proof_count: u32,
) -> Self {
if message.is_proof_message() {
InboundEntry::Proof(ProofEntry {
session_id,
current_count: 1,
})
} else {
InboundEntry::Message(MessageEntry {
session_id,
domain_address,
message,
expected_proof_count,
})
}
}
pub fn create_post_voting_entry(
inbound_entry: &InboundEntry<T>,
expected_proof_count: u32,
) -> Result<Option<Self>, DispatchError> {
match inbound_entry {
InboundEntry::Message(message_entry) => {
let updated_count = message_entry
.expected_proof_count
.ensure_sub(expected_proof_count)?;
if updated_count == 0 {
return Ok(None);
}
Ok(Some(
MessageEntry {
expected_proof_count: updated_count,
..message_entry.clone()
}
.into(),
))
}
InboundEntry::Proof(proof_entry) => {
let updated_count = proof_entry.current_count.ensure_sub(1)?;
if updated_count == 0 {
return Ok(None);
}
Ok(Some(
ProofEntry {
current_count: updated_count,
..proof_entry.clone()
}
.into(),
))
}
}
}
pub fn validate(&self, router_ids: &[T::RouterId], router_id: &T::RouterId) -> DispatchResult {
ensure!(
router_ids.iter().any(|x| x == router_id),
Error::<T>::UnknownRouter
);
match self {
InboundEntry::Message { .. } => {
ensure!(
router_ids.first() == Some(router_id),
Error::<T>::MessageExpectedFromFirstRouter
);
Ok(())
}
InboundEntry::Proof { .. } => {
ensure!(
router_ids.first() != Some(router_id),
Error::<T>::ProofNotExpectedFromFirstRouter
);
Ok(())
}
}
}
pub fn increment_proof_count(&mut self, session_id: T::SessionId) -> DispatchResult {
match self {
InboundEntry::Proof(proof_entry) => {
if proof_entry.session_id != session_id {
proof_entry.session_id = session_id;
proof_entry.current_count = 1;
} else {
proof_entry.current_count.ensure_add_assign(1)?;
}
Ok::<(), DispatchError>(())
}
InboundEntry::Message(_) => Err(Error::<T>::ExpectedMessageProofType.into()),
}
}
/// Merges `other` into `self` before dispatch.
///
/// For two entries of the same kind:
/// - different session IDs: `self` is replaced by `other`;
/// - same session ID: the counter of `other` (`expected_proof_count` for
///   messages, `current_count` for proofs) is added to the one of `self`.
///
/// # Errors
///
/// - `ExpectedMessageType` when `self` is a message and `other` a proof.
/// - `ExpectedMessageProofType` when `self` is a proof and `other` a message.
/// - The arithmetic error from `ensure_add_assign` on overflow.
pub fn pre_dispatch_update(&mut self, other: Self) -> DispatchResult {
    match (&mut *self, &other) {
        (InboundEntry::Message(current), InboundEntry::Message(incoming)) => {
            if current.session_id == incoming.session_id {
                current
                    .expected_proof_count
                    .ensure_add_assign(incoming.expected_proof_count)?;
            } else {
                // New session: the stored entry is superseded entirely.
                *self = other;
            }

            Ok(())
        }
        (InboundEntry::Proof(current), InboundEntry::Proof(incoming)) => {
            if current.session_id == incoming.session_id {
                current
                    .current_count
                    .ensure_add_assign(incoming.current_count)?;
            } else {
                // New session: the stored entry is superseded entirely.
                *self = other;
            }

            Ok(())
        }
        (InboundEntry::Message(_), InboundEntry::Proof(_)) => {
            Err(Error::<T>::ExpectedMessageType.into())
        }
        (InboundEntry::Proof(_), InboundEntry::Message(_)) => {
            Err(Error::<T>::ExpectedMessageProofType.into())
        }
    }
}
}
impl<T: Config> Pallet<T> {
pub(crate) fn get_router_ids_for_domain(
domain: Domain,
) -> Result<Vec<T::RouterId>, DispatchError> {
let stored_routers = Routers::<T>::get();
let all_routers_for_domain = T::RouterProvider::routers_for_domain(domain);
let res = stored_routers
.iter()
.filter(|stored_router| {
all_routers_for_domain
.iter()
.any(|available_router| *stored_router == available_router)
})
.cloned()
.collect::<Vec<_>>();
if res.is_empty() {
return Err(Error::<T>::NotEnoughRoutersForDomain.into());
}
Ok(res)
}
pub(crate) fn get_expected_proof_count(
router_ids: &[T::RouterId],
) -> Result<u32, DispatchError> {
let expected_proof_count = router_ids
.len()
.ensure_sub(1)
.map_err(|_| Error::<T>::NotEnoughRoutersForDomain)?;
Ok(expected_proof_count.saturated_into())
}
/// Inserts `new_inbound_entry` under `(message_hash, router_id)` or, when
/// an entry already exists there, merges the new one into it through
/// `InboundEntry::pre_dispatch_update`.
///
/// # Errors
///
/// Propagates any error returned by `pre_dispatch_update`.
pub(crate) fn upsert_pending_entry(
    message_hash: MessageHash,
    router_id: &T::RouterId,
    new_inbound_entry: InboundEntry<T>,
) -> DispatchResult {
    PendingInboundEntries::<T>::try_mutate(message_hash, router_id, |slot| {
        if let Some(existing_entry) = slot {
            existing_entry.pre_dispatch_update(new_inbound_entry)
        } else {
            *slot = Some(new_inbound_entry);
            Ok::<(), DispatchError>(())
        }
    })
}
pub(crate) fn execute_if_requirements_are_met(
message_hash: MessageHash,
router_ids: &[T::RouterId],
session_id: T::SessionId,
expected_proof_count: u32,
domain_address: DomainAddress,
) -> DispatchResult {
let mut message = None;
let mut votes = 0;
for router_id in router_ids {
match PendingInboundEntries::<T>::get(message_hash, router_id) {
None => return Ok(()),
Some(stored_inbound_entry) => match stored_inbound_entry {
InboundEntry::Message(message_entry)
if message_entry.session_id == session_id =>
{
message = Some(message_entry.message)
}
InboundEntry::Proof(proof_entry)
if proof_entry.has_valid_vote_for_session(session_id) =>
{
votes.ensure_add_assign(1)?;
}
_ => {}
},
};
}
if votes < expected_proof_count {
return Ok(());
}
if let Some(msg) = message {
T::InboundMessageHandler::handle(domain_address.clone(), msg)?;
Self::execute_post_voting_dispatch(message_hash, router_ids, expected_proof_count)?;
Self::deposit_event(Event::<T>::InboundMessageExecuted {
domain_address,
message_hash,
})
}
Ok(())
}
pub(crate) fn execute_post_voting_dispatch(
message_hash: MessageHash,
router_ids: &[T::RouterId],
expected_proof_count: u32,
) -> DispatchResult {
for router_id in router_ids {
PendingInboundEntries::<T>::try_mutate(message_hash, router_id, |storage_entry| {
match storage_entry {
None => {
Err::<(), DispatchError>(Error::<T>::PendingInboundEntryNotFound.into())
}
Some(stored_inbound_entry) => {
let post_dispatch_entry = InboundEntry::create_post_voting_entry(
stored_inbound_entry,
expected_proof_count,
)?;
*storage_entry = post_dispatch_entry;
Ok(())
}
}
})?;
}
Ok(())
}
/// Processes a message or proof received from `router_id`.
///
/// Steps, in order:
/// 1. resolve the routers for the domain of `domain_address` and derive the
///    expected proof count from them;
/// 2. build the matching `InboundEntry` for the current session and
///    validate it against the sending router;
/// 3. upsert the entry under `(message_hash, router_id)`;
/// 4. deposit the processing event;
/// 5. execute the message if all requirements are met.
///
/// # Errors
///
/// Propagates the first error raised by any of the steps above.
pub(crate) fn process_inbound_message(
    domain_address: DomainAddress,
    message: T::Message,
    router_id: T::RouterId,
) -> DispatchResult {
    let router_ids = Self::get_router_ids_for_domain(domain_address.domain())?;
    let session_id = SessionIdStore::<T>::get();
    let expected_proof_count = Self::get_expected_proof_count(&router_ids)?;
    let message_hash = message.get_message_hash();

    // The message is still needed for the processing event below, hence the
    // clone handed to the entry.
    let inbound_entry: InboundEntry<T> = InboundEntry::create(
        message.clone(),
        session_id,
        domain_address.clone(),
        expected_proof_count,
    );

    inbound_entry.validate(&router_ids, &router_id)?;

    Self::upsert_pending_entry(message_hash, &router_id, inbound_entry)?;

    // Last use of `message` and `router_id`: both are moved, not cloned.
    Self::deposit_processing_event(domain_address.clone(), message, message_hash, router_id);

    // Last use of `domain_address`: moved, not cloned.
    Self::execute_if_requirements_are_met(
        message_hash,
        &router_ids,
        session_id,
        expected_proof_count,
        domain_address,
    )
}
/// Deposits `InboundProofProcessed` when `message` is a proof message and
/// `InboundMessageProcessed` otherwise.
fn deposit_processing_event(
    domain_address: DomainAddress,
    message: T::Message,
    message_hash: MessageHash,
    router_id: T::RouterId,
) {
    let event = if message.is_proof_message() {
        Event::<T>::InboundProofProcessed {
            domain_address,
            message_hash,
            router_id,
        }
    } else {
        Event::<T>::InboundMessageProcessed {
            domain_address,
            message_hash,
            router_id,
        }
    };

    Self::deposit_event(event)
}
/// Queues one outbound gateway message per router of `destination`.
///
/// The first router gets `message` itself; every following router gets a
/// clone of `message.to_proof_message()`.
///
/// # Errors
///
/// Propagates errors from the router lookup and from `T::MessageQueue`.
pub(crate) fn queue_outbound_message(
    destination: Domain,
    message: T::Message,
) -> DispatchResult {
    let router_ids = Self::get_router_ids_for_domain(destination)?;
    let proof_message = message.to_proof_message();

    // Holds the actual message until the first router has taken it.
    let mut unsent_message = Some(message);

    for router_id in router_ids {
        let payload = unsent_message
            .take()
            .unwrap_or_else(|| proof_message.clone());

        T::MessageQueue::queue(GatewayMessage::<T::Message, T::RouterId>::Outbound {
            message: payload,
            router_id,
        })?;
    }

    Ok(())
}
}