Source code for poly_lithic.src.transformers.BaseTransformers

import time
import numpy as np
import sympy as sp
from poly_lithic.src.logging_utils.make_logger import get_logger
from poly_lithic.src.transformers.BaseTransformer import BaseTransformer

logger = get_logger()


class SimpleTransformer(BaseTransformer):
    """Evaluate sympy formulas over the latest values of a set of input symbols.

    Each output listed in ``config['variables']`` is defined by a formula string
    written in terms of the names in ``config['symbols']``. Formulas are parsed
    once with sympy and compiled to numpy-backed callables with
    ``sympy.lambdify``. Whenever every input has received a value, all outputs
    are recomputed and stored in ``latest_transformed``.
    """

    def __init__(self, config):
        """
        config: dict
            dictionary containing the following keys:
            - variables: dict
                mapping of output name -> dict with the key:
                - formula: str
                    formula to be used for transformation
            - symbols: list
                list of input names that may appear in the formulas

        Colons in names are replaced by underscores before parsing, because
        ':' is not valid inside a sympy symbol / Python identifier.

        Raises:
            Exception: if any formula cannot be parsed by sympy.
        """
        pv_mapping = config['variables']
        self.input_list = config['symbols']
        logger.debug('Initializing SimpleTransformer')
        logger.debug(f'PV Mapping: {pv_mapping}')
        logger.debug(f'Symbol List: {self.input_list}')
        self.pv_mapping = pv_mapping

        # Parse (and thereby validate) every formula exactly once.
        self.formulas = {
            key: self.__validate_formulas(value['formula'])
            for key, value in self.pv_mapping.items()
        }

        self.latest_input = {symbol: None for symbol in self.input_list}
        self.latest_transformed = {key: 0 for key in self.pv_mapping}
        self.updated = False
        # Duration in seconds of the last transform(); None until one has run.
        self.handler_time = None

        # Parameter names of the compiled callables, in input_list order; this
        # order must match the positional arguments built in transform().
        input_list_renamed = [symbol.replace(':', '_') for symbol in self.input_list]
        self.lambdified_formulas = {
            key: sp.lambdify(input_list_renamed, formula, modules='numpy')
            for key, formula in self.formulas.items()
        }

    def __validate_formulas(self, formula: str):
        """Parse ``formula`` (with ':' mapped to '_') and return the sympy expression.

        Raises:
            Exception: if sympy cannot parse the formula.
        """
        try:
            return sp.sympify(formula.replace(':', '_'))
        except Exception as e:
            raise Exception(f'Invalid formula: {formula}: {e}') from e

    def handler(self, pv_name, value):
        """Record a new input value and recompute outputs once all inputs are set.

        Args:
            pv_name: input name; ignored (debug log only) if not in ``input_list``.
            value: dict whose ``'value'`` entry is a real scalar (Python or numpy
                int/float) or an array-like (ndarray, list, tuple) of numbers.
                Scalars are stored as ``float``, array-likes as float ndarrays.

        Raises:
            Exception: if the value has an unsupported type, or if transform() fails.
        """
        if pv_name not in self.input_list:
            logger.debug(f'PV name {pv_name} not in input list')
            return

        try:
            raw = value['value']
            if isinstance(raw, (int, float, np.integer, np.floating)):
                converted = float(raw)
            elif isinstance(raw, (np.ndarray, list, tuple)):
                converted = np.array(raw).astype(float)
            else:
                raise Exception(
                    f'Invalid type for value: {value}, type: {type(value["value"])}'
                )
        except Exception as e:
            logger.error(f'Error converting value to float: {e}')
            raise

        self.latest_input[pv_name] = converted

        # Wait until every input has been seen at least once.
        if any(current is None for current in self.latest_input.values()):
            return

        try:
            time_start = time.perf_counter()
            self.transform()
            self.handler_time = time.perf_counter() - time_start
        except Exception as e:
            logger.error(f'Error transforming: {e}')
            raise

    def transform(self):
        """Evaluate every compiled formula on the current inputs.

        Array results whose last axis has length 1 are squeezed. Scalar results,
        including 0-d arrays, are converted to ``float``. On success,
        ``latest_transformed`` is updated and ``updated`` is set to True.

        Raises:
            Exception: any error raised while evaluating a formula (logged first).
        """
        # Positional arguments in the same order as the lambdified parameters.
        args = [self.latest_input[symbol] for symbol in self.input_list]

        transformed = {}
        for key in self.pv_mapping:
            try:
                result = self.lambdified_formulas[key](*args)
                if isinstance(result, np.ndarray):
                    if result.ndim == 0:
                        # 0-d arrays have no last axis; treat them as scalars.
                        result = float(result)
                    elif result.shape[-1] == 1:
                        result = result.squeeze()
                else:
                    result = float(result)
            except Exception as e:
                logger.error(f'Error transforming: {e}')
                raise
            transformed[key] = result

        self.latest_transformed.update(transformed)
        self.updated = True
class CAImageTransfomer(BaseTransformer):
    """Input only image transformation.

    Reassembles flat image waveforms into 2-D numpy arrays. Each entry in
    ``config['variables']`` names three input channels:
    - ``img_ch``: the flat pixel data;
    - ``img_x_ch``: the x size;
    - ``img_y_ch``: the y size.

    An optional ``unfold`` setting (default ``'row_major'``) selects the data
    layout. Only ``'column_major'`` is treated specially; any other value is
    reshaped in row-major (C) order.
    """

    def __init__(self, config) -> None:
        self.img = config['variables']
        self.img_list = list(self.img.keys())
        # For each image name `key`:
        # - key, key+'_x', key+'_y' -> input channel names;
        # - key+'_unfolding' -> unfolding mode string.
        self.variables = {}
        self.input_list = []
        for key, value in self.img.items():
            self.variables[key] = value['img_ch']
            self.variables[key + '_x'] = value['img_x_ch']
            self.variables[key + '_y'] = value['img_y_ch']
            self.variables[key + '_unfolding'] = value.get('unfold', 'row_major')
            self.input_list.extend(
                [value['img_ch'], value['img_x_ch'], value['img_y_ch']]
            )
        self.latest_input = {symbol: None for symbol in self.input_list}
        self.latest_transformed = {key: 0 for key in self.variables}
        self.handler_time = None
        self.updated = False

    def handler(self, variable_name: str, value: dict):
        """Store ``value['value']`` for an input channel; transform once all are set.

        Names that are not input channels of this transformer are ignored, so
        they cannot block or pollute the readiness check.

        Raises:
            Exception: any error raised while storing the value or transforming
                (logged first).
        """
        logger.debug(f'CAImageTransfomer handler for {variable_name}')
        if variable_name not in self.latest_input:
            logger.debug(f'Variable {variable_name} not in input list')
            return
        try:
            self.latest_input[variable_name] = value['value']
            if all(current is not None for current in self.latest_input.values()):
                time_start = time.perf_counter()
                self.transform()
                self.handler_time = time.perf_counter() - time_start
            else:
                logger.debug('Not all values are present')
        except Exception as e:
            logger.error(f'Error transforming: {e}')
            raise

    def transform(self):
        """Reshape each flat image into a 2-D array and publish it.

        The reshape target is ``(rows, columns) = (int(y), int(x))``.
        ``'column_major'`` data is reshaped with Fortran order and then
        transposed. An image that fails to reshape is logged and skipped;
        the other images are still published and ``updated`` is set to True.
        """
        logger.debug('Transforming')
        transformed = {}
        for key in self.img_list:
            column_major = self.variables[key + '_unfolding'] == 'column_major'
            try:
                # Note the order: numpy shapes are (rows, columns), i.e. (y, x).
                shape = (
                    int(self.latest_input[self.variables[key + '_y']]),
                    int(self.latest_input[self.variables[key + '_x']]),
                )
                image = np.array(self.latest_input[self.variables[key]]).reshape(
                    shape, order='F' if column_major else 'C'
                )
                if column_major:
                    # NOTE(review): the transpose yields shape (x, y); confirm
                    # this orientation is what downstream consumers expect.
                    image = image.T
            except Exception as e:
                logger.error(f'Error transforming {key}: {e}')
                continue
            transformed[key] = image

        self.latest_transformed.update(transformed)
        self.updated = True
class PassThroughTransformer(BaseTransformer):
    """Forward input values to outputs unchanged.

    ``config['variables']`` maps each output name to the input name whose
    latest value it carries. Outputs are refreshed once every input has
    received a non-None value.
    """

    def __init__(self, config):
        # config['variables'] is a dictionary of output: input pairs
        pv_mapping = config['variables']
        self.pv_mapping = pv_mapping
        self.input_list = list(pv_mapping.values())
        self.latest_input = {input_name: None for input_name in pv_mapping.values()}
        self.latest_transformed = {output_name: None for output_name in pv_mapping}
        self.updated = False
        self.handler_time = 0

    def handler(self, pv_name, value):
        """Store ``value['value']`` for ``pv_name``; copy to outputs once all inputs are set.

        Names that are not inputs of this transformer are ignored, so they cannot
        block the readiness check. ``handler_time`` records the call duration in
        seconds.
        """
        time_start = time.perf_counter()
        logger.debug(f'PassThroughTransformer handler for {pv_name}')
        if pv_name in self.latest_input:
            self.latest_input[pv_name] = value['value']
            if all(current is not None for current in self.latest_input.values()):
                self.transform()  # sets self.updated
        else:
            logger.debug(f'PV name {pv_name} not in input list')
        self.handler_time = time.perf_counter() - time_start

    def transform(self):
        """Copy each mapped input value to its output and mark the result updated."""
        logger.debug('Transforming')
        # Outputs share the input objects (no copy), so input and output shapes
        # are identical by construction.
        for output_name, input_name in self.pv_mapping.items():
            self.latest_transformed[output_name] = self.latest_input[input_name]
        self.updated = True