From 5be8abf64543d7d24f95cddc13333ca5468a07e0 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 26 Jun 2024 17:04:37 -0500 Subject: [PATCH] C expression casting logic: refactor, add some types --- loopy/target/c/codegen/expression.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/loopy/target/c/codegen/expression.py b/loopy/target/c/codegen/expression.py index 04fad17c8..d6793de4a 100644 --- a/loopy/target/c/codegen/expression.py +++ b/loopy/target/c/codegen/expression.py @@ -21,6 +21,7 @@ """ +from typing import Optional import numpy as np from pymbolic.mapper import RecursiveMapper, IdentityMapper @@ -109,7 +110,7 @@ def find_array(self, expr): return ary - def wrap_in_typecast(self, actual_type, needed_type, s): + def wrap_in_typecast(self, actual_type: LoopyType, needed_type: LoopyType, s): if actual_type != needed_type: registry = self.codegen_state.ast_builder.target.get_dtype_registry() cast = var("(%s) " % registry.dtype_to_ctype(needed_type)) @@ -117,13 +118,15 @@ def wrap_in_typecast(self, actual_type, needed_type, s): return s - def rec(self, expr, type_context=None, needed_type=None): - if needed_type is None: - return RecursiveMapper.rec(self, expr, type_context) + def rec(self, expr, type_context=None, needed_type: Optional[LoopyType] = None): + result = RecursiveMapper.rec(self, expr, type_context) - return self.wrap_in_typecast( - self.infer_type(expr), needed_type, - RecursiveMapper.rec(self, expr, type_context)) + if needed_type is None: + return result + else: + return self.wrap_in_typecast( + self.infer_type(expr), needed_type, + result) def __call__(self, expr, prec=None, type_context=None, needed_dtype=None): if prec is None: