diff --git a/src/ops/constant.rs b/src/ops/constant.rs index 308378e3a..bd8970b1e 100644 --- a/src/ops/constant.rs +++ b/src/ops/constant.rs @@ -76,6 +76,11 @@ impl Const { .unzip(); Self::new(Value::tuple(values), Type::new_tuple(types)).unwrap() } + + /// For a Const holding a CustomConst, extract the CustomConst by downcasting. + pub fn get_custom_value(&self) -> Option<&T> { + self.value().get_custom_value() + } } impl OpName for Const { @@ -123,7 +128,7 @@ mod test { prelude::{ConstUsize, USIZE_T}, ExtensionId, ExtensionSet, }, - std_extensions::arithmetic::float_types::FLOAT64_TYPE, + std_extensions::arithmetic::float_types::{ConstF64, FLOAT64_TYPE}, type_row, types::test::test_registry, types::type_param::TypeArg, @@ -197,11 +202,23 @@ mod test { tuple_ty.check_type(&tuple_val2), Err(ConstTypeError::ValueCheckFail(ty, tv2)) => ty == tuple_ty && tv2 == tuple_val2 ); - let tuple_val3 = Value::tuple([int_value, serialized_float(3.3), serialized_float(2.0)]); + let tuple_val3 = Value::tuple([ + int_value.clone(), + serialized_float(3.3), + serialized_float(2.0), + ]); assert_eq!( tuple_ty.check_type(&tuple_val3), Err(ConstTypeError::TupleWrongLength) ); + + let op = Const::new(int_value, USIZE_T).unwrap(); + + assert_eq!(op.get_custom_value(), Some(&ConstUsize::new(257))); + let try_float: Option<&ConstF64> = op.get_custom_value(); + assert!(try_float.is_none()); + let try_usize: Option<&ConstUsize> = tuple_val.get_custom_value(); + assert!(try_usize.is_none()); } #[test] diff --git a/src/values.rs b/src/values.rs index 987e8a5be..808cf883e 100644 --- a/src/values.rs +++ b/src/values.rs @@ -127,6 +127,18 @@ impl Value { val: PrimValue::Extension { c: (Box::new(c),) }, } } + + /// For a Const holding a CustomConst, extract the CustomConst by downcasting. + pub fn get_custom_value(&self) -> Option<&T> { + if let Value::Prim { + val: PrimValue::Extension { c: (custom,) }, + } = self + { + custom.downcast_ref() + } else { + None + } + } } impl From for Value {