# Source code for poly_lithic.src.model_utils.MlflowModelGetter
import mlflow
from mlflow import MlflowClient
from mlflow.models.model import get_model_info
from poly_lithic.src.logging_utils import get_logger
from poly_lithic.src.model_utils import ModelGetterBase
logger = get_logger()

# lume_model is an optional dependency: torch-backed model support is only
# enabled when it imports cleanly; otherwise we warn once and fall back.
try:
    from lume_model.models import TorchModel, TorchModule

    LUME_MODEL_AVAILABLE = True
except ImportError:
    logger.warning("lume_model is not installed. TorchModel and TorchModule functionality will not be available.")
    LUME_MODEL_AVAILABLE = False
# [docs]  (Sphinx cross-reference marker left over from HTML extraction)
class MLflowModelGetter(ModelGetterBase):
    """Retrieve models, configs and requirements from an MLflow model registry.

    A registered model is addressed by ``model_name`` plus either a numeric
    version string or a registry alias (e.g. ``'champion'``). Helpers download
    the model artifact itself, its ``pv_mapping.yaml`` config and its
    ``requirements.txt``.
    """

    def __init__(self, config):
        """Store registry coordinates and create an MLflow client.

        Args:
            config: mapping with keys ``'model_name'`` and ``'model_version'``
                (a numeric version string or a registry alias).
        """
        model_name = config['model_name']
        model_version = config['model_version']
        logger.debug(f'MLflowModelGetter: {model_name}, {model_version}')
        self.model_name = model_name
        self.model_version = model_version
        self.client = MlflowClient()
        self.model_type = None  # set by get_model(): 'pyfunc' or 'torch'
        self.tags = None  # populated by get_tags() (provided by the base class)

    def _resolve_version(self):
        """Return the ModelVersion object for ``self.model_version``.

        Accepts either a numeric version or a registry alias such as
        ``'champion'``. NOTE: the previous implementation called
        ``int(self.model_version)`` unconditionally, so the alias branch was
        unreachable (``ValueError`` was raised first); aliases are now
        resolved to their concrete version via the registry.
        """
        try:
            int(self.model_version)
        except (TypeError, ValueError):
            # Not an integer -> treat it as a registry alias.
            aliased = self.client.get_model_version_by_alias(
                self.model_name, self.model_version
            )
            return self.client.get_model_version(self.model_name, aliased.version)
        return self.client.get_model_version(self.model_name, self.model_version)

    def get_config(self):
        """Download ``pv_mapping.yaml`` for this version; return its local path."""
        self.get_tags()
        version = self._resolve_version()
        if 'artifact_location' in self.tags:
            artifact_location = self.tags['artifact_location']
        else:
            # No explicit tag: fall back to the registered model name as the
            # artifact folder.
            artifact_location = version.name
        logger.debug(f'Artifact location: {artifact_location}')
        self.client.download_artifacts(
            version.run_id, f'{artifact_location}/pv_mapping.yaml', '.'
        )
        return f'{artifact_location}/pv_mapping.yaml'

    def get_requirements(self):
        """Download the model's ``requirements.txt``; return its local path."""
        version = self._resolve_version()
        deps = mlflow.artifacts.download_artifacts(f'{version.source}/requirements.txt')
        return deps

    def get_model(self):
        """Load the registered model and return the underlying model object.

        Supports pyfunc models exposing ``get_lume_model()`` or ``get_model()``,
        and ``mlflow.pytorch`` models wrapping a lume_model ``TorchModule``.
        Sets ``self.model_type`` to ``'pyfunc'`` or ``'torch'``.

        Raises:
            Exception: if the flavor is unsupported, or a pyfunc model exposes
                neither accessor.
            ImportError: for torch models when lume_model is not installed.
            TypeError: when a torch model is not a TorchModule/TorchModel pair.
        """
        self.get_tags()
        version = self._resolve_version()
        flavor = get_model_info(model_uri=version.source).flavors
        loader_module = flavor['python_function']['loader_module']
        logger.debug(f'Loader module: {loader_module}')
        if loader_module == 'mlflow.pyfunc.model':
            logger.debug('Loading pyfunc model')
            model_pyfunc = mlflow.pyfunc.load_model(model_uri=version.source)
            wrapped = model_pyfunc.unwrap_python_model()
            if hasattr(wrapped, 'get_lume_model'):
                logger.debug('Model has get_lume_model() method')
                model = wrapped.get_lume_model()
            elif hasattr(wrapped, 'get_model'):
                logger.debug('Model has get_model() method')
                logger.warning(
                    'get_model() suggests a non-LUME model, please check if model has an evaluate method'
                )
                model = wrapped.get_model()
            else:
                raise Exception(
                    'Model does not have get_lume_model() or get_model() method'
                )
            logger.debug(f'Model: {model}, Model type: {type(model)}')
            self.model_type = 'pyfunc'
            return model
        elif loader_module == 'mlflow.pytorch':
            if not LUME_MODEL_AVAILABLE:
                # TorchModel/TorchModule come from the optional lume_model
                # dependency; without it we would hit a bare NameError below.
                raise ImportError(
                    'lume_model is required to load mlflow.pytorch models'
                )
            # Use the module logger instead of print() for consistency.
            logger.debug('Loading torch model')
            model_torch_module = mlflow.pytorch.load_model(model_uri=version.source)
            # Explicit checks instead of assert (asserts vanish under -O).
            if not isinstance(model_torch_module, TorchModule):
                raise TypeError('Expected loaded model to be a TorchModule')
            model = model_torch_module.model
            if not isinstance(model, TorchModel):
                raise TypeError('Expected wrapped model to be a TorchModel')
            logger.debug(f'Model: {model}, Model type: {type(model)}')
            self.model_type = 'torch'
            return model
        else:
            raise Exception(f'Flavor {flavor} not supported')