# Source code for poly_lithic.src.model_utils.LocalModelGetter
# not implemented yet warning
import importlib.util
from poly_lithic.src.logging_utils import get_logger
from poly_lithic.src.model_utils import ModelGetterBase
logger = get_logger()
[docs]
class LocalModelGetter(ModelGetterBase):
    """Load a model from a Python source file on the local filesystem.

    The file at ``config['model_path']`` is executed as a module, the
    attribute named by ``config['model_factory_class']`` is looked up on it
    and instantiated with no arguments, and the result of the factory's
    ``get_model()`` call is handed back to the caller.
    """

    def __init__(self, config):
        """Record where the model lives and which requirements it declares.

        Args:
            config: Mapping supporting ``[]`` and ``.get``. Required keys are
                ``'model_path'`` (path of the Python file to load) and
                ``'model_factory_class'`` (name of the factory attribute in
                that file). ``'requirements'`` is optional and defaults to
                ``None``.

        Raises:
            KeyError: If ``'model_path'`` or ``'model_factory_class'`` is
                missing from ``config``.
        """
        # NOTE(review): ModelGetterBase.__init__ is not invoked here --
        # confirm that the base class needs no initialisation.
        self.model_module_path = config['model_path']
        self.model_class_name = config['model_factory_class']
        self.model_type = 'local'
        self.requirements = config.get('requirements', None)
        logger.debug(
            f"LocalModelGetter initialized with model_module_path: {self.model_module_path}, "
            f"model_class_name: {self.model_class_name}, requirements: {self.requirements}"
        )

    def get_model(self):
        """Import the configured file and build the model via its factory.

        The file is executed on every call; the resulting module object is
        not stored in ``sys.modules`` by this method.

        Returns:
            Whatever ``get_model()`` of the factory instance returns.

        Raises:
            ImportError: If no import loader can be determined for
                ``self.model_module_path``.
            AttributeError: If the loaded module has no attribute named
                ``self.model_class_name``.
        """
        spec = importlib.util.spec_from_file_location(
            'model_module', self.model_module_path
        )
        # spec_from_file_location returns None when no loader recognises the
        # path (for instance an unexpected file extension). Fail with a
        # message naming the path instead of an opaque NoneType error.
        if spec is None or spec.loader is None:
            raise ImportError(
                f"Cannot load model module from '{self.model_module_path}': "
                'no import loader is available for this path'
            )

        model_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(model_module)

        try:
            model_factory_class = getattr(model_module, self.model_class_name)
        except AttributeError as err:
            # Same exception type as before, but say which name was missing
            # from which file.
            raise AttributeError(
                f"Model factory class '{self.model_class_name}' not found in "
                f"'{self.model_module_path}'"
            ) from err

        # The factory is constructed without arguments and asked for the model.
        model_factory = model_factory_class()
        return model_factory.get_model()

    def get_requirements(self):
        """Return the ``'requirements'`` value from the config, or ``None``."""
        return self.requirements