Rust 怎么限定参数类型是整型

314 阅读2分钟

问题描述

在网上看到一个问题: 怎么限定一个函数的参数类型必须是整型. 提问的人给出了他的代码, 大概长这样:

// Does not compile: `T` has no trait bounds, so the compiler cannot prove
// that `T` supports `+`, nor that the literal `1` can be added to a `T`.
fn add_one<T>(lhs: T) -> T {
  lhs + 1
}

后来, 他又把这段代码改成这种形式

use std::convert::TryFrom;

/// Adds one to `lhs`.
///
/// `T` must support `+` and be constructible from an `i32` via `TryFrom`.
/// Note that this bound does not restrict `T` to integers: any type that
/// implements `TryFrom<i32>` is accepted (e.g. `f64`, through `From<i32>`).
///
/// # Panics
///
/// Panics if `1i32` cannot be converted into `T`.
fn add_one<T>(lhs: T) -> T
  where T: Copy + std::ops::Add<Output = T> + TryFrom<i32>
{
  // There is no sensible fallback value if `1` cannot be represented in `T`,
  // so a failed conversion is treated as a bug instead of being papered over.
  lhs + T::try_from(1i32).unwrap_or_else(|_| panic!("1 is not representable in the target type"))
}

不过这样写跟预期有些差别: 常量 1 被写死成了 i32, 而 `TryFrom<i32>` 只要求 T 能从 i32 转换而来, 并没有把 T 限定为整型 (例如 f64 通过 `From<i32>` 也满足这个约束)

解决方案

如果用 C++ 的模板来做, 其实很容易: 借助 `std::enable_if_t` 和 `std::is_integral_v`, 让模板只对整型生效

// Returns lhs + 1. The defaulted second template parameter removes this
// template from overload resolution unless T is an integral type (SFINAE).
template<typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
T addOne(T lhs) {
  // Narrow integral types are promoted to int for the addition, so the
  // conversion back to T is spelled out.
  return static_cast<T>(lhs + 1);
}

// Returns the sum of two values of the same integral type T; for any other T
// the defaulted template parameter fails to substitute and the overload drops out.
template<typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
T add(T lhs, T rhs) {
  const auto sum = lhs + rhs;  // may be promoted to int for narrow T
  return static_cast<T>(sum);
}

按照这种思路, 在 Rust 里应该先定义一个 `Integer` trait, 再用宏为所有原生整型实现它

/// Marker trait for totally ordered, sized types (intended for the primitive
/// integers). `Ord` already brings `PartialOrd`, `Eq` and `PartialEq` along.
pub trait Integer: Ord + Sized {}

// Emits an empty `impl $name for T {}` for each type `T` listed after `for`;
// meant for marker traits (such as `Integer`) that have no items to define.
macro_rules! empty_trait_impl {
  ($name:ident for $($t:ty)*) => {
    $( impl $name for $t {} )*
  };
}

// Register every primitive integer type as `Integer`: unsigned first, then signed.
empty_trait_impl!(Integer for u8 u16 u32 u64 u128 usize);
empty_trait_impl!(Integer for i8 i16 i32 i64 i128 isize);

然后把函数改成这样

// Does not compile yet: the literal `1` is not known to be a `T`, so it cannot
// be passed as `add`'s `rhs: T` parameter.
fn add_one<T: Integer>(lhs: T) -> T {
  add(lhs, 1)
}

// Does not compile yet: `Integer` (as declared so far) has no `Add` bound, so
// `lhs + rhs` is not available for a generic `T`.
fn add<T: Integer>(lhs: T, rhs: T) -> T {
  lhs + rhs
}

然后编译器报错了, 问题依旧: Integer 并没有要求实现 Add, 所以泛型 T 不能做 `lhs + rhs`. 那就把加、减、乘、除和取余这几种运算抽象成一个 NumOps trait, 同时给 Integer 加上 NumOps 约束

/// The arithmetic operators `+ - * / %`, all taking a right-hand side of type
/// `Rhs` and producing `Output` (both default to `Self`).
///
/// The operator traits are referenced through `std::ops` so this snippet
/// compiles without a separate `use` line.
pub trait NumOps<Rhs = Self, Output = Self>:
  std::ops::Add<Rhs, Output = Output>
  + std::ops::Sub<Rhs, Output = Output>
  + std::ops::Mul<Rhs, Output = Output>
  + std::ops::Div<Rhs, Output = Output>
  + std::ops::Rem<Rhs, Output = Output>
{
}

/// Blanket implementation: every type that already supports all five
/// operators gets `NumOps` automatically, so it is never implemented by hand.
impl<T, Rhs, Output> NumOps<Rhs, Output> for T where
  T: std::ops::Add<Rhs, Output = Output>
    + std::ops::Sub<Rhs, Output = Output>
    + std::ops::Mul<Rhs, Output = Output>
    + std::ops::Div<Rhs, Output = Output>
    + std::ops::Rem<Rhs, Output = Output>
{
}

/// Integer marker trait: closed under `+ - * / %` (via `NumOps`) and totally
/// ordered (`Ord` already implies `Eq`, `PartialEq` and `PartialOrd`).
pub trait Integer: NumOps + Ord + Sized {}

然而编译器还是不买账, 因为整数字面量 (比如 1) 的类型是具体的整型 (默认是 i32), 而不是泛型 T. 那就再实现一个数值类型转换, 把字面量转换成 T.

#[inline]
pub fn cast<T: NumCast, U: NumCast>(n: T) -> Option<U> {
  NumCast::from(n)
}

/// Fallible construction of `Self` from any value that implements `ToPrimitive`.
pub trait NumCast: Sized + ToPrimitive {
  /// Builds a `Self` from `n`, or returns `None` if the conversion fails.
  fn from<N: ToPrimitive>(n: N) -> Option<Self>;
}

pub trait ToPrimitive {
  #[inline]
  fn to_isize(&self) -> Option<isize> {
    self.to_i64().as_ref().and_then(ToPrimitive::to_isize)
  }

  #[inline]
  fn to_i8(&self) -> Option<i8> {
    self.to_i64().as_ref().and_then(ToPrimitive::to_i8)
  }

  #[inline]
  fn to_i16(&self) -> Option<i16> {
    self.to_i64().as_ref().and_then(ToPrimitive::to_i16)
  }

  #[inline]
  fn to_i32(&self) -> Option<i32> {
    self.to_i64().as_ref().and_then(ToPrimitive::to_i32)
  }

  fn to_i64(&self) -> Option<i64>;

  #[inline]
  fn to_i128(&self) -> Option<i128> {
    self.to_i64().map(From::from)
  }

  #[inline]
  fn to_usize(&self) -> Option<usize> {
    self.to_u64().as_ref().and_then(ToPrimitive::to_usize)
  }

  #[inline]
  fn to_u8(&self) -> Option<u8> {
    self.to_u64().as_ref().and_then(ToPrimitive::to_u8)
  }

  #[inline]
  fn to_u16(&self) -> Option<u16> {
    self.to_u64().as_ref().and_then(ToPrimitive::to_u16)
  }

  #[inline]
  fn to_u32(&self) -> Option<u32> {
    self.to_u64().as_ref().and_then(ToPrimitive::to_u32)
  }

  fn to_u64(&self) -> Option<u64>;

  #[inline]
  fn to_u128(&self) -> Option<u128> {
    self.to_u64().map(From::from)
  }
}

因为要为每种整型编写大量重复的实现代码, 所以用宏来生成

use std::mem::size_of;

// Generates checked `&self -> Option<$DstT>` methods converting the signed
// type `$SrcT` into signed targets.
// - If the target is at least as wide as the source, every value fits.
// - Otherwise the value must lie within the target's MIN..=MAX bounds, which
//   are cast into the (wider) source type for the comparison.
macro_rules! impl_to_primitive_int_to_int {
  ($SrcT:ident : $( fn $method:ident -> $DstT:ident; )*) => {$(
    #[inline]
    fn $method(&self) -> Option<$DstT> {
      let value = *self;
      let fits_by_width = size_of::<$SrcT>() <= size_of::<$DstT>();
      let lo = $DstT::MIN as $SrcT;
      let hi = $DstT::MAX as $SrcT;
      if !fits_by_width && (value < lo || value > hi) {
        return None;
      }
      Some(value as $DstT)
    }
  )*}
}

// Generates checked `&self -> Option<$DstT>` methods converting the signed
// type `$SrcT` into unsigned targets.
// - Negative values never convert.
// - When the source is strictly wider than the target, the value must also
//   not exceed the target's MAX.
macro_rules! impl_to_primitive_int_to_uint {
  ($SrcT:ident : $( fn $method:ident -> $DstT:ident; )*) => {$(
    #[inline]
    fn $method(&self) -> Option<$DstT> {
      let value = *self;
      if value < 0 {
        return None;
      }
      let hi = $DstT::MAX as $SrcT;
      if size_of::<$SrcT>() > size_of::<$DstT>() && value > hi {
        return None;
      }
      Some(value as $DstT)
    }
  )*}
}

// Implements `ToPrimitive` for the signed primitive integer `$Src` by
// generating all twelve conversion methods explicitly:
// - unsigned targets through `impl_to_primitive_int_to_uint!`;
// - signed targets through `impl_to_primitive_int_to_int!`.
macro_rules! impl_to_primitive_int {
  ($Src:ident) => {
    impl ToPrimitive for $Src {
      impl_to_primitive_int_to_uint! { $Src:
        fn to_usize -> usize;
        fn to_u8 -> u8;
        fn to_u16 -> u16;
        fn to_u32 -> u32;
        fn to_u64 -> u64;
        fn to_u128 -> u128;
      }

      impl_to_primitive_int_to_int! { $Src:
        fn to_isize -> isize;
        fn to_i8 -> i8;
        fn to_i16 -> i16;
        fn to_i32 -> i32;
        fn to_i64 -> i64;
        fn to_i128 -> i128;
      }
    }
  };
}

// Signed primitives, narrowest to widest, followed by pointer-sized `isize`.
impl_to_primitive_int!(i8);
impl_to_primitive_int!(i16);
impl_to_primitive_int!(i32);
impl_to_primitive_int!(i64);
impl_to_primitive_int!(i128);
impl_to_primitive_int!(isize);

// Generates checked `&self -> Option<$DstT>` methods converting the unsigned
// type `$SrcT` into signed targets.
// - A strictly wider target holds every value.
// - Otherwise (equal or narrower width) the value must not exceed the
//   target's MAX.
macro_rules! impl_to_primitive_uint_to_int {
  ($SrcT:ident : $( fn $method:ident -> $DstT:ident; )*) => {$(
    #[inline]
    fn $method(&self) -> Option<$DstT> {
      let value = *self;
      let fits = size_of::<$SrcT>() < size_of::<$DstT>()
        || value <= $DstT::MAX as $SrcT;
      fits.then(|| value as $DstT)
    }
  )*}
}

// Generates checked `&self -> Option<$DstT>` methods converting the unsigned
// type `$SrcT` into unsigned targets. Only a strictly narrower target needs a
// check: the value must not exceed the target's MAX.
macro_rules! impl_to_primitive_uint_to_uint {
  ($SrcT:ident : $( fn $method:ident -> $DstT:ident; )*) => {$(
    #[inline]
    fn $method(&self) -> Option<$DstT> {
      let value = *self;
      if size_of::<$SrcT>() > size_of::<$DstT>() {
        let hi = $DstT::MAX as $SrcT;
        if value > hi {
          return None;
        }
      }
      Some(value as $DstT)
    }
  )*}
}

// Implements `ToPrimitive` for the unsigned primitive integer `$Src` by
// generating all twelve conversion methods explicitly:
// - unsigned targets through `impl_to_primitive_uint_to_uint!`;
// - signed targets through `impl_to_primitive_uint_to_int!`.
macro_rules! impl_to_primitive_uint {
  ($Src:ident) => {
    impl ToPrimitive for $Src {
      impl_to_primitive_uint_to_uint! { $Src:
        fn to_usize -> usize;
        fn to_u8 -> u8;
        fn to_u16 -> u16;
        fn to_u32 -> u32;
        fn to_u64 -> u64;
        fn to_u128 -> u128;
      }

      impl_to_primitive_uint_to_int! { $Src:
        fn to_isize -> isize;
        fn to_i8 -> i8;
        fn to_i16 -> i16;
        fn to_i32 -> i32;
        fn to_i64 -> i64;
        fn to_i128 -> i128;
      }
    }
  };
}

// Unsigned primitives, narrowest to widest, followed by pointer-sized `usize`.
impl_to_primitive_uint!(u8);
impl_to_primitive_uint!(u16);
impl_to_primitive_uint!(u32);
impl_to_primitive_uint!(u64);
impl_to_primitive_uint!(u128);
impl_to_primitive_uint!(usize);

然后给 add_one 函数加上约束. 注意上面只定义了 NumCast trait, 还需要为各个原生整型实现它 (num-traits 中由 `impl_num_cast!` 宏完成), 否则下面的代码无法编译:

/// Returns `lhs + 1` for any `T` that is an `Integer` and a `NumCast`.
///
/// The constant `1` is converted into `T` with `cast`.
/// NOTE(review): this requires `NumCast` to be implemented for the primitive
/// integer types (num-traits does this with its `impl_num_cast!` macro); that
/// impl is not shown here, so confirm it exists.
///
/// # Panics
///
/// Panics if `cast` reports that `1` cannot be converted into `T`.
fn add_one<T: Integer + NumCast>(lhs: T) -> T {
  lhs + cast(1).expect("1 is representable in every primitive integer type")
}

// Smoke test: incrementing 10 must yield 11.
#[test]
fn test_add_one() {
  let result = add_one(10);
  assert_eq!(result, 11);
}

全部思路来自 num-traits: https://github.com/rust-num/num-traits

总结, Rust 真好玩