from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, Union
from pydantic import (
BaseModel,
Field,
field_validator,
computed_field,
)
import time
from poly_lithic.src.logging_utils import get_logger
from poly_lithic.src.transformers import BaseTransformer
from poly_lithic.src.interfaces import BaseInterface
from poly_lithic.src.model_utils import registered_model_getters
import os
# from deepdiff import DeepDiff
import hashlib
logger = get_logger()
import cProfile
def profileit(func=None, *, threshold: float = 0.3):
    """Decorator that runs every call of ``func`` under ``cProfile``.

    When a single call takes longer than ``threshold`` seconds, the collected
    profile is written to ``<func.__name__>.profile`` (relative path, so the
    current working directory) for inspection with ``pstats``; faster calls
    discard their profile.

    Usable bare (``@profileit``) or with options (``@profileit(threshold=1.0)``).

    Args:
        func: the function to wrap; omit it to get a configured decorator.
        threshold: minimum call duration in seconds that triggers a dump.

    Returns:
        The wrapped function (or a decorator when ``func`` is None).
    """
    import functools  # local import keeps the module import block untouched

    def decorate(fn):
        datafn = fn.__name__ + '.profile'  # Name the data file sensibly

        # wraps() keeps __name__/__doc__ of fn, so decorated methods stay
        # identifiable in logs and introspection
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            prof = cProfile.Profile()
            start_time = time.perf_counter()  # monotonic, high resolution
            retval = prof.runcall(fn, *args, **kwargs)
            if time.perf_counter() - start_time > threshold:
                prof.dump_stats(datafn)
            return retval

        return wrapper

    if func is None:
        return decorate
    return decorate(func)
class Message(BaseModel):
    """A unit of data exchanged between observers through the MessageBroker.

    ``value`` maps variable names to structs of the form
    ``{"value": ..., "timestamp": <int|float>, "metadata": <dict>}``; only the
    ``"value"`` entry is mandatory. One message may carry several structs,
    e.g. ``{name1: struct1, name2: struct2}``.
    """

    # a one-element list is accepted and unwrapped to its string by check_topic
    topic: Union[str, list[str]]
    source: str
    value: dict = Field(default_factory=dict)
    timestamp: float = Field(default_factory=time.time)
    # optional
    allow_unsafe: Optional[bool] = False

    @field_validator('topic')
    @classmethod
    def check_topic(cls, topic):
        """Normalise ``topic`` to a single string.

        Raises:
            ValueError: if topic is neither a str nor a list, or is a list that
                does not contain exactly one element.
        """
        if isinstance(topic, str):
            return topic
        if isinstance(topic, list):
            if len(topic) != 1:
                raise ValueError('topic list must contain one element')
            return topic[0]
        raise ValueError('topic must be a string or list of strings')

    @field_validator('value')
    @classmethod
    def check_value(cls, value):
        """Check that every entry of ``value`` is a well-formed struct.

        Raises:
            ValueError: if a struct is not a dict, lacks ``"value"``, has a
                non-numeric ``"timestamp"`` or a non-dict ``"metadata"``.
        """
        if not isinstance(value, dict):
            # NOTE(review): effectively unreachable. This is an 'after'
            # validator, so pydantic should already have rejected non-dict
            # input against the ``dict`` annotation. A class-level lookup also
            # cannot see an ``allow_unsafe`` passed to the instance being
            # validated. Honouring ``allow_unsafe`` would need a before-mode
            # model validator. ``getattr`` avoids an AttributeError if this
            # ever becomes reachable.
            if getattr(cls, 'allow_unsafe', False):
                logger.warning(f'allowing unsafe value {value}')
                return {'value': value}
            raise ValueError('value must be a dictionary')
        for struct in value.values():
            if not isinstance(struct, dict):
                raise ValueError('struct must be a dictionary')
            if 'value' not in struct:
                raise ValueError('struct must contain a value')
            if 'timestamp' in struct and not isinstance(
                struct['timestamp'], (int, float)
            ):
                raise ValueError('timestamp must be an int or float')
            if 'metadata' in struct and not isinstance(struct['metadata'], dict):
                raise ValueError('metadata must be a dictionary')
        return value

    @computed_field
    def keys(self) -> list[str]:
        """Variable names carried by this message, in insertion order."""
        return list(self.value.keys())

    @computed_field
    def values(self) -> list[Any]:
        """Structs carried by this message, in insertion order."""
        return list(self.value.values())

    @computed_field
    def uid(self) -> str:
        """return a unique id for the message

        The id is an md5 hex digest of ``value`` only (topic, source and
        timestamp are ignored). It does not depend on key order, and struct
        entries are compared through ``str()`` of their values.
        """
        # A sorted text form is used rather than str(frozenset(...)). Set
        # iteration order depends on per-process string hash randomisation
        # and on insertion order, so it is not a stable canonical form.
        canonical = sorted(
            repr((key, sorted(repr((k, str(v))) for k, v in struct.items())))
            for key, struct in self.value.items()
        )
        # md5 is used as a fingerprint, not for security
        return hashlib.md5(
            repr(canonical).encode(), usedforsecurity=False
        ).hexdigest()

    def __str__(self):
        return f'Message(topic={self.topic}, source={self.source}, value={self.value}, timestamp={self.timestamp})'

    def __repr__(self):
        return self.__str__()

    def __eq__(self, value):
        """Messages are equal when topic, source, timestamp and value match.

        Objects without those attributes (e.g. None) yield NotImplemented
        instead of raising AttributeError.
        """
        try:
            return (
                self.topic == value.topic
                and self.source == value.source
                and self.timestamp == value.timestamp
                and self.value == value.value
            )
        except AttributeError:
            return NotImplemented
class Observer(ABC):
    """Abstract subscriber that the MessageBroker calls for each message on
    the topics it is attached to.

    The broker queues whatever ``update`` returns (a single message or a list
    of messages) and ignores ``None``.
    """

    @abstractmethod
    def update(self, message: Message) -> Message:
        """Handle ``message``; implementations should return resulting message(s)."""
        ...
class MessageBroker:
    """Route messages to the observers attached to their topic.

    Results returned by observers are appended to ``self.queue`` and delivered
    later by ``parse_queue``. Per-observer update times are accumulated and
    logged, then reset, once more than one second has passed since the last
    report.
    """

    def __init__(self):
        """initialize the message broker"""
        self._observers: Dict[str, list[Observer]] = {}
        # per-observer total update time (ms) and call count in the current window
        self._stats = {}
        self._stats_cnt = {}
        self.queue = []
        # start of the current stats window (time.time() seconds)
        self.last_update = time.time()

    @staticmethod
    def _as_topic_list(topic) -> list:
        """Normalise a single topic or a list of topics to a list."""
        return topic if isinstance(topic, list) else [topic]

    def attach(self, observer: Observer, topic: str | list[str]) -> None:
        """add observer to topic, or to each topic of a list"""
        logger.debug(f'attaching {observer} to {topic}')
        for t in self._as_topic_list(topic):
            self._observers.setdefault(t, []).append(observer)

    def detach(self, observer: Observer, topic: str | list[str]) -> None:
        """remove observer from topic, we will probably never use this

        Topics that have never been attached to are skipped, for both a single
        topic and a list of topics. ``list.remove`` raises ValueError if the
        observer is not attached to an existing topic.
        """
        for t in self._as_topic_list(topic):
            if t in self._observers:
                self._observers[t].remove(observer)

    # @profileit  (uncomment to dump notify.profile for slow calls)
    def notify(self, message: Message) -> None:
        """notify all observers of a message

        Each observer's ``update`` result is queued. A list result is queued
        element by element and None is ignored. If no observer was ever
        attached to the topic, an error is logged.
        """
        if message.topic not in self._observers:
            logger.error(f'no observers for {message.topic}')
            return
        for observer in self._observers[message.topic]:
            logger.debug(f'notifying {observer}')
            start = time.perf_counter()
            result = observer.update(message)
            elapsed_ms = (time.perf_counter() - start) * 1000
            name = str(observer)
            self._stats[name] = self._stats.get(name, 0) + elapsed_ms
            self._stats_cnt[name] = self._stats_cnt.get(name, 0) + 1
            if result is None:
                continue
            if isinstance(result, list):
                self.queue.extend(result)
            else:
                self.queue.append(result)
        self._maybe_report_stats()

    def _maybe_report_stats(self) -> None:
        """Log and reset the timing stats once more than 1 s has elapsed."""
        now = time.time()
        window = now - self.last_update
        if window <= 1:
            return
        self.last_update = now
        avg_ms = {k: v / self._stats_cnt[k] for k, v in self._stats.items()}
        logger.debug(
            'mean update time per observer:\n\t'
            + '\n\t'.join(f'{k}: {v:.2f}ms' for k, v in avg_ms.items())
        )
        sum_time = sum(self._stats.values())
        # real time factor = time spent inside observers / wall time of the
        # window; the window can exceed 1 s when messages are sparse
        logger.info(
            f'real time factor: {sum_time / 1000 / window:.2f} must be less than 1, time spent updating this cycle : {sum_time:.2f}ms'
        )
        self._stats = {}
        self._stats_cnt = {}

    def get_stats(self):
        """Return the accumulated per-observer update time (ms) of the current window."""
        return self._stats

    def get_all(self) -> None:
        """Broadcast a 'get_all' message from 'clock' so observers can publish a full read."""
        refresh_msg = Message(
            topic='get_all', source='clock', value={'dummy': {'value': 1}}
        )
        self.notify(refresh_msg)
        return None

    def parse_queue(self):
        """Deliver every message queued before this call, in FIFO order.

        Messages queued while delivering stay in the queue for the next call.
        If a delivery raises, that message and the ones after it stay queued.
        """
        snapshot = list(self.queue)
        processed = 0
        try:
            for message in snapshot:
                self.notify(message)
                processed += 1
        finally:
            # notify() only appends, so the processed messages are still the
            # first `processed` entries; one slice delete replaces per-item
            # list.remove (O(n^2) and dependent on Message equality)
            del self.queue[:processed]
        logger.debug(f'queue length: {len(self.queue)}')
class InterfaceObserver(Observer):
    """Observer bridging a BaseInterface and the message broker.

    A 'get_all' message triggers a read of every variable in
    ``interface.variable_list``. Any other message is written to the interface
    with ``put_many`` when the ``PUBLISH`` environment variable equals 'True'.
    """

    def __init__(self, interface: BaseInterface, topic: str, sanitise: bool = True):
        """wraps around the interface.put_many method"""
        self.interface: BaseInterface = interface
        self.topic: str = topic
        self.sanitise = sanitise  # stored; not read within this class
        # messages from the previous get_all, used to suppress unchanged reads
        self.last_get_all = None

    def update(self, message: Message) -> list[Message] | None:
        """Handle a broker message.

        For topic 'get_all', returns the freshly read messages, or None when
        every uid matches the previous read. For any other topic, writes
        ``message.value`` to the interface (if publishing is enabled) and
        returns None.
        """
        if message.topic == 'get_all':
            messages = self.get_all()
            if self.last_get_all is not None:
                previous_uids = {msg.uid for msg in self.last_get_all}
                if all(m.uid in previous_uids for m in messages):
                    logger.debug('no diff')
                    return None
            self.last_get_all = messages
            return messages
        logger.debug(f'updating {self}')
        # an unset PUBLISH is treated like 'False' instead of raising KeyError
        if os.environ.get('PUBLISH') == 'True':
            self.interface.put_many(message.value)
        else:
            logger.warning(
                'PUBLISH is set to False, this will not publish to the interface'
            )
        return None

    def get(self, message: Message) -> list[Message]:
        """get a single variable from the interface

        Returns one message per key of ``message``, each holding the
        (key, value) pair returned by ``interface.get``.
        """
        messages = []
        for name in message.keys:
            key, value = self.interface.get(name)
            messages.append(
                Message(topic=self.topic, source=str(self), value={key: value})
            )
        return messages

    def get_all(self) -> list[Message]:
        """get all variables from the interface based on internal variable list

        Variables whose value is None are left out. Always returns a
        one-element list.
        """
        # NOTE(review): the get_many result is discarded and every variable
        # is read again with get(); kept in case get_many primes an
        # interface-side cache -- confirm before removing.
        self.interface.get_many(self.interface.variable_list)
        output_dict = {}
        for name in self.interface.variable_list:
            key, value = self.interface.get(name)
            if value is not None:
                output_dict[key] = value
        return [Message(topic=self.topic, source=str(self), value=output_dict)]

    def get_many(self, message: Message) -> list[Message]:
        """get many variables from the interface

        Returns one message per entry of the values mapping returned by
        ``interface.get_many``.
        """
        _keys, values = self.interface.get_many(message.value)
        return [
            Message(topic=self.topic, source=str(self), value={key: value})
            for key, value in values.items()
        ]

    def put(self, message: Message) -> None:
        """put a single variable into the interface, one call per key"""
        if not isinstance(message.value, dict):
            raise ValueError('message value must be a dictionary')
        for key, value in zip(message.keys, message.values):
            self.interface.put(key, value)

    def put_many(self, message: Message) -> None:
        """put many variables into the interface in one call"""
        if not isinstance(message.value, dict):
            raise ValueError('message value must be a dictionary')
        self.interface.put_many(message.value)
class MockModel:
    """Stand-in model returned by ModelObserver for config type 'mock'."""

    def __init__(self):
        """placeholder for model"""

    def evaluate(self, value):
        """Ignore ``value`` and return a fixed sentinel prediction."""
        sentinel = -99999999999
        prediction = {'value': sentinel}
        return {'not_initialized': prediction}
class ModelObserver(Observer):
    """Observer that runs ``model.evaluate`` on each incoming message."""

    def __init__(
        self,
        model=None,
        config=None,
        topic: str = 'model',
        unpack_input: bool = True,
        pack_output: bool = True,
    ):
        """wraps around the model.evaluate method

        Args:
            model: object exposing ``.evaluate(dict)``; takes precedence over config.
            config: dict with a 'type' key ('mock' or 'MlflowModelGetter') and,
                for MlflowModelGetter, an 'args' entry; used only when model is None.
            topic: topic of the messages this observer emits.
            unpack_input: pass ``{name: struct['value']}`` instead of the raw structs.
            pack_output: wrap each prediction as ``{'value': prediction}``.

        Raises:
            ValueError: if neither model nor config is given, the config type
                is unknown, loading yields None, or the model has no ``evaluate``.
        """
        self.model = model
        self.topic = topic
        self.config = config
        self.unpack_input = unpack_input
        self.pack_output = pack_output
        if self.model is None:
            if self.config is None:
                raise ValueError('model must be provided or a config to load a model')
            self.model = self.__get_model()
        # validate provided and config-loaded models alike
        if not hasattr(self.model, 'evaluate'):
            raise ValueError('model must have a .evaluate() method')

    def __get_model(self):
        """load the model from the config"""
        model_type = self.config['type']
        if model_type == 'mock':
            return MockModel()
        if model_type == 'MlflowModelGetter':
            # legacy name well make it consistent across the board in the future
            model_getter = registered_model_getters['mlflow'](self.config['args'])
            model = model_getter.get_model()
            if model is None:
                raise ValueError('model is None')
            return model
        raise ValueError(f'model type not recognised: {model_type}')

    def update(self, message: Message) -> list[Message]:
        """Evaluate the model on ``message.value``.

        Returns a one-element list with a Message on ``self.topic`` carrying
        the (optionally packed) prediction.
        """
        logger.debug(f'updating {self}')
        if self.unpack_input:
            inputs = {name: struct['value'] for name, struct in message.value.items()}
        else:
            inputs = message.value
        pred = self.model.evaluate(inputs)
        if self.pack_output:
            output = {name: {'value': result} for name, result in pred.items()}
        else:
            output = pred
        return [Message(topic=self.topic, source=str(self), value=output)]
# class GenericObserver(Observer):
# def __init__(self, callback):
# """wraps around the callback method, a catch all observer"""
# self.callback = callback
# def update(self, message: Message) -> None:
# self.callback(message)