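Problem description
I came across a question online: how do you restrict a function's parameter type to integers? The asker posted code that looked roughly like this: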
fn add_one<T>(lhs: T) -> T {
    lhs + 1
}
Later, he changed the code to this form:
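use std::convert::TryFrom;

fn add_one<T>(lhs: T) -> T
where
    T: Copy + std::ops::Add<Output = T> + TryFrom<i32>,
{
    // Convert the literal 1 into T; panic if T cannot represent it.
    lhs + T::try_from(1i32).unwrap_or_else(|_| panic!("1 is not representable in T"))
}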
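However, this bound only says that T can be built from an i32. It says nothing about T being an integer (f64 satisfies it as well), so it is not quite what was asked for.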
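To see what the TryFrom<i32> bound actually admits, here is a small sketch. It reuses the revised add_one with the bounds trimmed to what the body needs; the panic message and the test name are illustrative, not from the original question.

```rust
use std::convert::TryFrom;
use std::ops::Add;

// Same idea as the revised version: build the constant 1 from an i32.
fn add_one<T>(lhs: T) -> T
where
    T: Add<Output = T> + TryFrom<i32>,
{
    let one = T::try_from(1i32).unwrap_or_else(|_| panic!("1 is not representable in T"));
    lhs + one
}

#[test]
fn try_from_bound_is_not_integer_only() {
    // Integer types other than i32 are accepted.
    assert_eq!(add_one(41u8), 42u8);
    assert_eq!(add_one(-1i64), 0i64);
    // f64 implements From<i32>, and therefore TryFrom<i32>, so a float gets through too.
    assert_eq!(add_one(1.5f64), 2.5f64);
}
```

The bound describes a conversion, not a category of types, which is why a dedicated marker trait is the more direct tool.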
Solution
With C++ templates this is actually easy:
#include <type_traits>

template<typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
T addOne(T lhs) {
    return lhs + 1;
}

template<typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
T add(T lhs, T rhs) {
    return lhs + rhs;
}
Following the same idea, we need an Integer trait:
pub trait Integer: Sized + PartialOrd + Ord + PartialEq + Eq {}

macro_rules! empty_trait_impl {
    ($name:ident for $($t:ty)*) => ($(
        impl $name for $t {}
    )*)
}

empty_trait_impl!(Integer for usize u8 u16 u32 u64);
empty_trait_impl!(Integer for u128);
empty_trait_impl!(Integer for isize i8 i16 i32 i64);
empty_trait_impl!(Integer for i128);
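As a quick check of what the marker trait buys us, the sketch below gates a function on Integer. The helper `larger` and the test name are hypothetical, added only for illustration; it needs nothing beyond the ordering supertraits.

```rust
pub trait Integer: Sized + PartialOrd + Ord + PartialEq + Eq {}

macro_rules! empty_trait_impl {
    ($name:ident for $($t:ty)*) => ($(
        impl $name for $t {}
    )*)
}

empty_trait_impl!(Integer for usize u8 u16 u32 u64 u128);
empty_trait_impl!(Integer for isize i8 i16 i32 i64 i128);

// Hypothetical helper: compiles only for types that carry the Integer marker.
fn larger<T: Integer>(a: T, b: T) -> T {
    if a >= b { a } else { b }
}

#[test]
fn marker_trait_gates_integers() {
    assert_eq!(larger(3u8, 7u8), 7u8);
    assert_eq!(larger(-2i64, -5i64), -2i64);
    // larger(1.0f64, 2.0f64); // rejected at compile time: f64 does not implement Integer
}
```

The marker only restricts which types are accepted. It does not yet give the function body any arithmetic to work with.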
Then change the functions to this:
fn add_one<T: Integer>(lhs: T) -> T {
    add(lhs, 1)
}

fn add<T: Integer>(lhs: T, rhs: T) -> T {
    lhs + rhs
}
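The compiler rejects this, because the original problem is still there: Integer does not provide the Add operation. So we abstract the arithmetic operators into a trait of their own, NumOps, and add it as a bound on Integer: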
use std::ops::{Add, Div, Mul, Rem, Sub};

pub trait NumOps<Rhs = Self, Output = Self>:
    Add<Rhs, Output = Output>
    + Sub<Rhs, Output = Output>
    + Mul<Rhs, Output = Output>
    + Div<Rhs, Output = Output>
    + Rem<Rhs, Output = Output>
{
}

// Blanket impl: every type that has all five operators is NumOps.
impl<T, Rhs, Output> NumOps<Rhs, Output> for T where
    T: Add<Rhs, Output = Output>
        + Sub<Rhs, Output = Output>
        + Mul<Rhs, Output = Output>
        + Div<Rhs, Output = Output>
        + Rem<Rhs, Output = Output>
{
}

pub trait Integer: Sized + PartialOrd + Ord + PartialEq + Eq + NumOps {}
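To illustrate how the blanket impl makes the operators available through a single bound, here is a minimal sketch. The helper `div_rem` and the test name are hypothetical and only serve as a demonstration.

```rust
use std::ops::{Add, Div, Mul, Rem, Sub};

pub trait NumOps<Rhs = Self, Output = Self>:
    Add<Rhs, Output = Output>
    + Sub<Rhs, Output = Output>
    + Mul<Rhs, Output = Output>
    + Div<Rhs, Output = Output>
    + Rem<Rhs, Output = Output>
{
}

// Any type that already has the five operators gets NumOps for free.
impl<T, Rhs, Output> NumOps<Rhs, Output> for T where
    T: Add<Rhs, Output = Output>
        + Sub<Rhs, Output = Output>
        + Mul<Rhs, Output = Output>
        + Div<Rhs, Output = Output>
        + Rem<Rhs, Output = Output>
{
}

// Hypothetical helper: one bound is enough to use both `/` and `%`.
fn div_rem<T: NumOps + Copy>(a: T, b: T) -> (T, T) {
    (a / b, a % b)
}

#[test]
fn num_ops_exposes_operators() {
    assert_eq!(div_rem(17u32, 5u32), (3, 2));
    // Integer division truncates toward zero, so the remainder keeps the sign of `a`.
    assert_eq!(div_rem(-17i64, 5i64), (-3, -2));
}
```

Note that NumOps alone also admits floats. It is the combination with the Integer marker that narrows the set to integer types.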
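The compiler still refuses, because the integer literal has a concrete integer type rather than the generic type T. So we also implement a type conversion.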
#[inline]
pub fn cast<T: NumCast, U: NumCast>(n: T) -> Option<U> {
    NumCast::from(n)
}

pub trait NumCast: Sized + ToPrimitive {
    fn from<T: ToPrimitive>(n: T) -> Option<Self>;
}

pub trait ToPrimitive {
    #[inline]
    fn to_isize(&self) -> Option<isize> {
        self.to_i64().as_ref().and_then(ToPrimitive::to_isize)
    }

    #[inline]
    fn to_i8(&self) -> Option<i8> {
        self.to_i64().as_ref().and_then(ToPrimitive::to_i8)
    }

    #[inline]
    fn to_i16(&self) -> Option<i16> {
        self.to_i64().as_ref().and_then(ToPrimitive::to_i16)
    }

    #[inline]
    fn to_i32(&self) -> Option<i32> {
        self.to_i64().as_ref().and_then(ToPrimitive::to_i32)
    }

    fn to_i64(&self) -> Option<i64>;

    #[inline]
    fn to_i128(&self) -> Option<i128> {
        self.to_i64().map(From::from)
    }

    #[inline]
    fn to_usize(&self) -> Option<usize> {
        self.to_u64().as_ref().and_then(ToPrimitive::to_usize)
    }

    #[inline]
    fn to_u8(&self) -> Option<u8> {
        self.to_u64().as_ref().and_then(ToPrimitive::to_u8)
    }

    #[inline]
    fn to_u16(&self) -> Option<u16> {
        self.to_u64().as_ref().and_then(ToPrimitive::to_u16)
    }

    #[inline]
    fn to_u32(&self) -> Option<u32> {
        self.to_u64().as_ref().and_then(ToPrimitive::to_u32)
    }

    fn to_u64(&self) -> Option<u64>;

    #[inline]
    fn to_u128(&self) -> Option<u128> {
        self.to_u64().map(From::from)
    }
}
Since this means writing a lot of repetitive code, macros take care of it:
use std::mem::size_of;

macro_rules! impl_to_primitive_int_to_int {
    ($SrcT:ident : $( fn $method:ident -> $DstT:ident; )*) => {$(
        #[inline]
        fn $method(&self) -> Option<$DstT> {
            let min = $DstT::MIN as $SrcT;
            let max = $DstT::MAX as $SrcT;
            if size_of::<$SrcT>() <= size_of::<$DstT>() || (min <= *self && *self <= max) {
                Some(*self as $DstT)
            } else {
                None
            }
        }
    )*}
}

macro_rules! impl_to_primitive_int_to_uint {
    ($SrcT:ident : $( fn $method:ident -> $DstT:ident; )*) => {$(
        #[inline]
        fn $method(&self) -> Option<$DstT> {
            let max = $DstT::MAX as $SrcT;
            if 0 <= *self && (size_of::<$SrcT>() <= size_of::<$DstT>() || *self <= max) {
                Some(*self as $DstT)
            } else {
                None
            }
        }
    )*}
}

macro_rules! impl_to_primitive_int {
    ($T:ident) => {
        impl ToPrimitive for $T {
            impl_to_primitive_int_to_int! { $T:
                fn to_isize -> isize;
                fn to_i8 -> i8;
                fn to_i16 -> i16;
                fn to_i32 -> i32;
                fn to_i64 -> i64;
                fn to_i128 -> i128;
            }

            impl_to_primitive_int_to_uint! { $T:
                fn to_usize -> usize;
                fn to_u8 -> u8;
                fn to_u16 -> u16;
                fn to_u32 -> u32;
                fn to_u64 -> u64;
                fn to_u128 -> u128;
            }
        }
    };
}

impl_to_primitive_int!(isize);
impl_to_primitive_int!(i8);
impl_to_primitive_int!(i16);
impl_to_primitive_int!(i32);
impl_to_primitive_int!(i64);
impl_to_primitive_int!(i128);

macro_rules! impl_to_primitive_uint_to_int {
    ($SrcT:ident : $( fn $method:ident -> $DstT:ident; )*) => {$(
        #[inline]
        fn $method(&self) -> Option<$DstT> {
            let max = $DstT::MAX as $SrcT;
            if size_of::<$SrcT>() < size_of::<$DstT>() || *self <= max {
                Some(*self as $DstT)
            } else {
                None
            }
        }
    )*}
}

macro_rules! impl_to_primitive_uint_to_uint {
    ($SrcT:ident : $( fn $method:ident -> $DstT:ident; )*) => {$(
        #[inline]
        fn $method(&self) -> Option<$DstT> {
            let max = $DstT::MAX as $SrcT;
            if size_of::<$SrcT>() <= size_of::<$DstT>() || *self <= max {
                Some(*self as $DstT)
            } else {
                None
            }
        }
    )*}
}

macro_rules! impl_to_primitive_uint {
    ($T:ident) => {
        impl ToPrimitive for $T {
            impl_to_primitive_uint_to_int! { $T:
                fn to_isize -> isize;
                fn to_i8 -> i8;
                fn to_i16 -> i16;
                fn to_i32 -> i32;
                fn to_i64 -> i64;
                fn to_i128 -> i128;
            }

            impl_to_primitive_uint_to_uint! { $T:
                fn to_usize -> usize;
                fn to_u8 -> u8;
                fn to_u16 -> u16;
                fn to_u32 -> u32;
                fn to_u64 -> u64;
                fn to_u128 -> u128;
            }
        }
    };
}

impl_to_primitive_uint!(usize);
impl_to_primitive_uint!(u8);
impl_to_primitive_uint!(u16);
impl_to_primitive_uint!(u32);
impl_to_primitive_uint!(u64);
impl_to_primitive_uint!(u128);
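It can help to see what one macro arm turns into. The sketch below is a hand-expanded version of the signed-to-unsigned and unsigned-to-signed checks for two concrete pairs. They are written as free functions with made-up names (`i32_to_u8`, `u64_to_i64`) instead of trait methods, purely for illustration.

```rust
use std::mem::size_of;

// Hand-expanded int-to-uint check for i32 -> u8.
fn i32_to_u8(n: i32) -> Option<u8> {
    let max = u8::MAX as i32;
    // Negative values never fit; otherwise the value must not exceed u8::MAX.
    if 0 <= n && (size_of::<i32>() <= size_of::<u8>() || n <= max) {
        Some(n as u8)
    } else {
        None
    }
}

// Hand-expanded uint-to-int check for u64 -> i64.
fn u64_to_i64(n: u64) -> Option<i64> {
    let max = i64::MAX as u64;
    // Same width, so only values up to i64::MAX survive the conversion.
    if size_of::<u64>() < size_of::<i64>() || n <= max {
        Some(n as i64)
    } else {
        None
    }
}

#[test]
fn range_checks_behave_as_expected() {
    assert_eq!(i32_to_u8(255), Some(255u8));
    assert_eq!(i32_to_u8(256), None);
    assert_eq!(i32_to_u8(-1), None);
    assert_eq!(u64_to_i64(42), Some(42i64));
    assert_eq!(u64_to_i64(u64::MAX), None);
}
```

The size comparison is a compile-time constant for each pair, so the optimizer can drop the range check whenever the destination type is wide enough.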
Then add the bound to the add_one function:
fn add_one<T: Integer + NumCast>(lhs: T) -> T {
    // Convert the literal 1 into T before adding.
    lhs + cast(1).unwrap()
}

#[test]
fn test_add_one() {
    assert_eq!(add_one(10), 11);
}
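The listings above declare NumCast but do not show an implementation for the primitive types, and cast only works once one exists. The following is a reduced, self-contained sketch of how the pieces could fit together. It is not the num-traits implementation: ToPrimitive is trimmed to a single to_i64 method, the range checks are delegated to std's TryFrom, and the macro name `impl_num_cast` is made up for this example. As a consequence, values above i64::MAX are out of scope here.

```rust
use std::convert::TryFrom;
use std::ops::Add;

// Trimmed stand-in for ToPrimitive: i64 is the only pivot type.
pub trait ToPrimitive {
    fn to_i64(&self) -> Option<i64>;
}

pub trait NumCast: Sized + ToPrimitive {
    fn from<T: ToPrimitive>(n: T) -> Option<Self>;
}

// Hypothetical macro: implements both traits via std's checked conversions.
macro_rules! impl_num_cast {
    ($($t:ty)*) => ($(
        impl ToPrimitive for $t {
            fn to_i64(&self) -> Option<i64> {
                i64::try_from(*self).ok()
            }
        }
        impl NumCast for $t {
            fn from<T: ToPrimitive>(n: T) -> Option<Self> {
                // Go through i64, then range-check into the target type.
                n.to_i64().and_then(|v| <$t>::try_from(v).ok())
            }
        }
    )*)
}

impl_num_cast!(i8 i16 i32 i64 u8 u16 u32 u64);

pub fn cast<T: NumCast, U: NumCast>(n: T) -> Option<U> {
    NumCast::from(n)
}

// Add<Output = T> stands in for the Integer + NumOps bounds to keep this short.
fn add_one<T: Add<Output = T> + NumCast>(lhs: T) -> T {
    lhs + cast(1).expect("1 fits in every integer type")
}

#[test]
fn add_one_works_across_integer_types() {
    assert_eq!(add_one(10i32), 11i32);
    assert_eq!(add_one(254u8), 255u8);
    assert_eq!(add_one(-1i64), 0i64);
    // Out-of-range conversions are reported as None instead of wrapping.
    assert_eq!(cast::<i32, u8>(300), None);
}
```

Because the literal passed to cast has no other constraint, it falls back to i32, and the target type is inferred from the addition.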
All of these ideas come from num-traits: github.com/rust-num/nu…
In summary, Rust is a lot of fun.