"""Trainer registry package.

Importing this package imports every public sibling module / subpackage so
that whatever registration they perform into ``base.train_classes`` has
happened before ``get_class`` / ``get_trainer`` are called.
"""
import importlib
import os

from model.base import BaseModel

from . import base

# Automatically import any python files / subpackages in this directory.
package_dir = os.path.dirname(__file__)
# Kept for backward compatibility; imports below resolve against
# ``__package__`` instead so this also works when the package is nested
# (e.g. ``proj.trainer``), not only when it is importable as a top-level name.
package = os.path.basename(package_dir)

# Sorted so import (and therefore registration) order is deterministic;
# os.listdir returns entries in arbitrary order.
for _entry in sorted(os.listdir(package_dir)):
    _path = os.path.join(package_dir, _entry)
    # Skip private/dunder modules (incl. __init__, __pycache__) and dotfiles.
    if _entry.startswith(('_', '.')):
        continue
    if _entry.endswith('.py') or os.path.isdir(_path):
        _module_name = _entry[:-3] if _entry.endswith('.py') else _entry
        importlib.import_module(f'.{_module_name}', __package__)


def get_class(class_name: str) -> type:
    """Return the trainer class registered under ``class_name``.

    Args:
        class_name: Key to look up in ``base.train_classes``.

    Returns:
        The registered class object.

    Raises:
        KeyError: If no trainer is registered under ``class_name``; the
            message lists the names that are registered.
    """
    try:
        return base.train_classes[class_name]
    except KeyError:
        known = ', '.join(map(str, base.train_classes))
        raise KeyError(
            f'Unknown trainer class {class_name!r}; registered: [{known}]'
        ) from None


def get_trainer(model: BaseModel, **kwargs) -> base.Train:
    """Instantiate the trainer named by ``model.trainer`` for ``model``.

    Args:
        model: Model whose ``trainer`` attribute is used as the registry key
            (presumably a class-name string — confirm against BaseModel).
        **kwargs: Forwarded unchanged to the trainer class constructor.

    Returns:
        ``train_class(model, **kwargs)`` for the looked-up class.

    Raises:
        KeyError: If ``model.trainer`` is not a registered trainer name.
    """
    train_class = get_class(model.trainer)
    return train_class(model, **kwargs)