use cfg_primitives::{AccountId, Balance};
use cfg_traits::{PreConditions, TransferAllowance};
use cfg_types::{
domain_address::DomainAddress,
locations::RestrictedTransferLocation,
tokens::{CurrencyId, FilterCurrency},
};
use frame_support::{traits::IsSubType, RuntimeDebugNoBound};
use pallet_restricted_tokens::TransferDetails;
use pallet_restricted_xtokens::TransferEffects;
use parity_scale_codec::{Decode, Encode};
use scale_info::TypeInfo;
use sp_runtime::{
traits::{Convert, DispatchInfoOf, SignedExtension, StaticLookup},
transaction_validity::{InvalidTransaction, TransactionValidityError},
DispatchError, DispatchResult, TokenError,
};
use sp_std::{boxed::Box, vec::Vec};
use staging_xcm::{
v4::{Asset, Location},
VersionedLocation,
};
pub struct PreXcmTransfer<T, C>(sp_std::marker::PhantomData<(T, C)>);
impl<
T: TransferAllowance<
AccountId,
CurrencyId = FilterCurrency,
Location = RestrictedTransferLocation,
>,
C: Convert<Location, Option<CurrencyId>>,
> PreConditions<TransferEffects<AccountId, CurrencyId, Balance>> for PreXcmTransfer<T, C>
{
type Result = DispatchResult;
fn check(t: TransferEffects<AccountId, CurrencyId, Balance>) -> Self::Result {
let currency_based_check = |sender: AccountId, destination: VersionedLocation, currency| {
amalgamate_allowance(
T::allowance(
sender.clone(),
RestrictedTransferLocation::Xcm(Box::new(destination.clone())),
FilterCurrency::Specific(currency),
),
T::allowance(
sender,
RestrictedTransferLocation::Xcm(Box::new(destination)),
FilterCurrency::All,
),
)
};
let asset_based_check = |sender, destination, asset: Asset| {
let currency =
C::convert(asset.id.0).ok_or(DispatchError::Token(TokenError::UnknownAsset))?;
currency_based_check(sender, destination, currency)
};
match t {
TransferEffects::Transfer {
sender,
destination,
currency_id,
..
} => currency_based_check(sender, destination, currency_id),
TransferEffects::TransferMultiAsset {
sender,
destination,
asset,
} => asset_based_check(sender, destination, asset),
TransferEffects::TransferWithFee {
sender,
destination,
currency_id,
..
} => currency_based_check(sender, destination, currency_id),
TransferEffects::TransferMultiAssetWithFee {
sender,
destination,
asset,
fee_asset,
} => {
asset_based_check(sender.clone(), destination.clone(), asset)?;
asset_based_check(sender, destination, fee_asset)
}
TransferEffects::TransferMultiCurrencies {
sender,
destination,
currencies,
fee,
} => {
for (currency, ..) in currencies {
currency_based_check(sender.clone(), destination.clone(), currency)?;
}
currency_based_check(sender, destination, fee.0)
}
TransferEffects::TransferMultiAssets {
sender,
destination,
assets,
fee_asset,
} => {
for asset in assets.into_inner() {
asset_based_check(sender.clone(), destination.clone().clone(), asset)?;
}
asset_based_check(sender, destination, fee_asset)
}
}
}
}
/// Transfer-allowlist precondition for local transfers described by
/// [`TransferDetails`] (`pallet_restricted_tokens`).
pub struct PreNativeTransfer<T>(sp_std::marker::PhantomData<T>);

impl<
        T: TransferAllowance<
            AccountId,
            CurrencyId = FilterCurrency,
            Location = RestrictedTransferLocation,
        >,
    > PreConditions<TransferDetails<AccountId, CurrencyId, Balance>> for PreNativeTransfer<T>
{
    type Result = bool;

    /// Returns `true` when `t.send` may transfer currency `t.id` to the local
    /// account `t.recv`. The transfer is allowed either through a
    /// currency-specific allowance or through one covering all currencies.
    /// The underlying allowance error is discarded.
    fn check(t: TransferDetails<AccountId, CurrencyId, Balance>) -> Self::Result {
        amalgamate_allowance(
            T::allowance(
                t.send.clone(),
                RestrictedTransferLocation::Local(t.recv.clone()),
                FilterCurrency::Specific(t.id),
            ),
            // `t` is owned, so the last use can move its fields instead of cloning.
            T::allowance(
                t.send,
                RestrictedTransferLocation::Local(t.recv),
                FilterCurrency::All,
            ),
        )
        .is_ok()
    }
}
/// Transfer-allowlist precondition for transfers to a [`DomainAddress`].
/// The checked value is a `(sender, receiver, currency)` tuple.
pub struct PreLpTransfer<T>(sp_std::marker::PhantomData<T>);

impl<
        T: TransferAllowance<
            AccountId,
            CurrencyId = FilterCurrency,
            Location = RestrictedTransferLocation,
        >,
    > PreConditions<(AccountId, DomainAddress, CurrencyId)> for PreLpTransfer<T>
{
    type Result = DispatchResult;

    /// Succeeds when `sender` may send `currency` to `receiver`. The transfer is
    /// allowed either through a currency-specific allowance or through one
    /// covering all currencies.
    ///
    /// # Errors
    /// The allowance error selected by `amalgamate_allowance`.
    fn check(
        (sender, receiver, currency): (AccountId, DomainAddress, CurrencyId),
    ) -> Self::Result {
        let specific_currency = T::allowance(
            sender.clone(),
            RestrictedTransferLocation::Address(receiver.clone()),
            FilterCurrency::Specific(currency),
        );
        let any_currency = T::allowance(
            sender,
            RestrictedTransferLocation::Address(receiver),
            FilterCurrency::All,
        );
        amalgamate_allowance(specific_currency, any_currency)
    }
}
/// Signed extension that applies the transfer allowlist to `pallet_balances`
/// transfers. Those transfers are also found when nested inside
/// `pallet_proxy`, `pallet_utility` and `pallet_remarks` calls.
#[derive(
    Clone, Copy, PartialOrd, Ord, PartialEq, Eq, RuntimeDebugNoBound, Encode, Decode, TypeInfo,
)]
#[scale_info(skip_type_params(T))]
pub struct PreBalanceTransferExtension<T: frame_system::Config>(sp_std::marker::PhantomData<T>);

// The extension carries no data; `new` exists for runtime construction sites.
#[allow(clippy::new_without_default)]
impl<T> PreBalanceTransferExtension<T>
where
    T: frame_system::Config<AccountId = AccountId>
        + pallet_balances::Config
        + pallet_utility::Config<RuntimeCall = <T as frame_system::Config>::RuntimeCall>
        + pallet_proxy::Config<RuntimeCall = <T as frame_system::Config>::RuntimeCall>
        + pallet_remarks::Config<RuntimeCall = <T as frame_system::Config>::RuntimeCall>
        + Sync
        + Send,
    <T as frame_system::Config>::RuntimeCall: IsSubType<pallet_balances::Call<T>>
        + IsSubType<pallet_utility::Call<T>>
        + IsSubType<pallet_proxy::Call<T>>
        + IsSubType<pallet_remarks::Call<T>>,
{
    /// Creates the (stateless) extension.
    pub fn new() -> Self {
        Self(sp_std::marker::PhantomData)
    }

    /// Collects a `(sender, receiver)` pair for every `transfer_all`,
    /// `transfer_allow_death` and `transfer_keep_alive` call contained in `call`
    /// when it is dispatched by `caller`.
    ///
    /// # Errors
    /// `InvalidTransaction::Call` if a transfer destination or a proxied `real`
    /// account cannot be resolved through `Lookup`.
    #[allow(clippy::type_complexity)]
    fn retrieve(
        caller: &T::AccountId,
        call: &<T as frame_system::Config>::RuntimeCall,
    ) -> Result<Vec<(T::AccountId, T::AccountId)>, TransactionValidityError> {
        Self::recursive_search(caller.clone(), call, |who, balance_call, checks| {
            match balance_call {
                pallet_balances::Call::transfer_all { dest, .. }
                | pallet_balances::Call::transfer_allow_death { dest, .. }
                | pallet_balances::Call::transfer_keep_alive { dest, .. } => {
                    // `balance_call` is owned, so `dest` can be handed to `lookup` directly.
                    let recv: T::AccountId = <T as frame_system::Config>::Lookup::lookup(dest)
                        .map_err(|_| TransactionValidityError::Invalid(InvalidTransaction::Call))?;
                    checks.push((who, recv));
                    Ok(())
                }
                _ => Ok(()),
            }
        })
    }

    /// Walks `call` (dispatched by `caller`) and passes every `pallet_balances`
    /// call it finds to `check`. Each one is passed together with the account
    /// that would be its origin.
    ///
    /// Descends into:
    /// - `pallet_proxy::proxy` / `proxy_announced`, with the proxied `real` account as origin;
    /// - `pallet_utility::batch` / `batch_all` / `force_batch`, keeping the current origin;
    /// - `pallet_remarks::remark`, keeping the current origin.
    ///
    /// Any other call contributes nothing.
    ///
    /// NOTE(review): other call wrappers (e.g. `pallet_utility::as_derivative`)
    /// are not inspected. Confirm this is intended.
    ///
    /// # Errors
    /// Errors returned by `check`, or `InvalidTransaction::Call` if a proxied
    /// `real` account cannot be looked up.
    // One explicit `match` per wrapping pallet is kept, hence the lint allowances.
    #[allow(clippy::type_complexity)]
    #[allow(clippy::single_match)]
    #[allow(clippy::collapsible_match)]
    fn recursive_search<F>(
        caller: T::AccountId,
        call: &<T as frame_system::Config>::RuntimeCall,
        check: F,
    ) -> Result<Vec<(T::AccountId, T::AccountId)>, TransactionValidityError>
    where
        F: Fn(
                T::AccountId,
                pallet_balances::Call<T>,
                &mut Vec<(T::AccountId, T::AccountId)>,
            ) -> Result<(), TransactionValidityError>
            + Clone,
    {
        let mut checks = Vec::new();

        if let Some(balance_call) = IsSubType::<pallet_balances::Call<T>>::is_sub_type(call) {
            check(caller, balance_call.clone(), &mut checks)?;
        } else if let Some(call) = IsSubType::<pallet_proxy::Call<T>>::is_sub_type(call) {
            match call {
                pallet_proxy::Call::<T>::proxy { real, call, .. }
                | pallet_proxy::Call::<T>::proxy_announced { real, call, .. } => {
                    // The wrapped call executes with `real` as its origin.
                    let caller = T::Lookup::lookup(real.clone())
                        .map_err(|_| TransactionValidityError::Invalid(InvalidTransaction::Call))?;
                    checks.extend(Self::recursive_search(caller, call, check)?);
                }
                _ => {}
            }
        } else if let Some(utility_call) = IsSubType::<pallet_utility::Call<T>>::is_sub_type(call)
        {
            match utility_call {
                // `force_batch` dispatches its calls with the same origin as
                // `batch`/`batch_all` and only differs in error handling. It must
                // be inspected too, otherwise it bypasses the allowlist.
                pallet_utility::Call::<T>::batch { calls: batch_calls }
                | pallet_utility::Call::<T>::batch_all { calls: batch_calls }
                | pallet_utility::Call::<T>::force_batch { calls: batch_calls } => {
                    for batch_call in batch_calls {
                        checks.extend(Self::recursive_search(
                            caller.clone(),
                            batch_call,
                            check.clone(),
                        )?);
                    }
                }
                _ => {}
            }
        } else if let Some(remarks_call) = IsSubType::<pallet_remarks::Call<T>>::is_sub_type(call)
        {
            match remarks_call {
                pallet_remarks::Call::<T>::remark {
                    call: remark_call, ..
                } => checks.extend(Self::recursive_search(caller, remark_call, check)?),
                _ => {}
            }
        }

        Ok(checks)
    }
}
impl<T> SignedExtension for PreBalanceTransferExtension<T>
where
T: frame_system::Config<AccountId = AccountId>
+ pallet_balances::Config
+ pallet_utility::Config<RuntimeCall = <T as frame_system::Config>::RuntimeCall>
+ pallet_proxy::Config<RuntimeCall = <T as frame_system::Config>::RuntimeCall>
+ pallet_remarks::Config<RuntimeCall = <T as frame_system::Config>::RuntimeCall>
+ pallet_transfer_allowlist::Config<
CurrencyId = FilterCurrency,
Location = RestrictedTransferLocation,
> + Sync
+ Send,
<T as frame_system::Config>::RuntimeCall: IsSubType<pallet_balances::Call<T>>
+ IsSubType<pallet_utility::Call<T>>
+ IsSubType<pallet_proxy::Call<T>>
+ IsSubType<pallet_remarks::Call<T>>,
{
type AccountId = T::AccountId;
type AdditionalSigned = ();
type Call = <T as frame_system::Config>::RuntimeCall;
type Pre = ();
const IDENTIFIER: &'static str = "PreBalanceTransferExtension";
fn additional_signed(&self) -> Result<Self::AdditionalSigned, TransactionValidityError> {
Ok(())
}
fn pre_dispatch(
self,
who: &Self::AccountId,
call: &Self::Call,
_: &DispatchInfoOf<Self::Call>,
_: usize,
) -> Result<Self::Pre, TransactionValidityError> {
Self::retrieve(who, call)?
.iter()
.try_for_each(|(who, recv)| {
amalgamate_allowance(
pallet_transfer_allowlist::pallet::Pallet::<T>::allowance(
who.clone(),
RestrictedTransferLocation::Local(recv.clone()),
FilterCurrency::All,
),
pallet_transfer_allowlist::pallet::Pallet::<T>::allowance(
who.clone(),
RestrictedTransferLocation::Local(recv.clone()),
FilterCurrency::Specific(CurrencyId::Native),
),
)
.map_err(|_| TransactionValidityError::Invalid(InvalidTransaction::Custom(255)))
})
}
}
/// Merges two allowance lookups into one verdict. Both lookups are for the same
/// sender and location but use different currency filters.
///
/// The transfer is allowed when either lookup yields `Ok(Some(_))`, or when
/// both yield `Ok(None)`. Presumably `Ok(None)` means no allowance is
/// configured for that filter; confirm against `TransferAllowance`.
///
/// # Errors
/// In every other case, returns `first`'s error if `first` is an `Err`, and
/// otherwise `second`'s error.
fn amalgamate_allowance(
    first: Result<Option<RestrictedTransferLocation>, DispatchError>,
    second: Result<Option<RestrictedTransferLocation>, DispatchError>,
) -> DispatchResult {
    match (first, second) {
        // An explicit allowance on either filter wins; no configuration at all is also fine.
        (Ok(Some(_)), _) | (_, Ok(Some(_))) | (Ok(None), Ok(None)) => Ok(()),
        // Remaining cases: at least one side errored; the first argument's error takes priority.
        (Err(e), _) | (_, Err(e)) => Err(e),
    }
}