Creating new Trainer class - jniedzie/SVJanalysis_wiki GitHub Wiki
In order to add your architecture to the ML framework, apart from creating a new config file, you will need to implement a specialized Trainer class.
This class will be dynamically loaded by the framework (based on the path specified in the config), and some of its methods will be called during training. Because of that, the class has to follow a certain specification, so that the framework can find the required methods and call them. Methods that are mandatory are marked with "@mandatory" in their docstrings.
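This page does not show how the framework performs the loading. As a rough illustration only, here is a sketch assuming the class is imported from the configured file with Python's `importlib` and looked up under the file's name (which would explain why the class name has to match the file name), and that the mandatory members are checked before training starts. `load_trainer_class` and `MANDATORY_METHODS` are hypothetical names, not the framework's actual API.

```python
import importlib.util
from pathlib import Path

# Methods this sketch requires; `model` must additionally be a property.
MANDATORY_METHODS = ("train", "get_summary")


def load_trainer_class(path):
    """Hypothetical loader: import the file at `path` and return the class
    whose name equals the file name without the .py suffix."""
    path = Path(path)
    spec = importlib.util.spec_from_file_location(path.stem, str(path))
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # This lookup is why the class name must match the file name.
    trainer_class = getattr(module, path.stem, None)
    if trainer_class is None:
        raise AttributeError(f"{path} defines no class named {path.stem}")

    # Collect every mandatory member that is missing or of the wrong kind.
    missing = [name for name in MANDATORY_METHODS
               if not callable(getattr(trainer_class, name, None))]
    if not isinstance(getattr(trainer_class, "model", None), property):
        missing.append("model (property)")
    if missing:
        raise TypeError(f"{path.stem} is missing: {', '.join(missing)}")
    return trainer_class
```

Failing early with a list of missing members gives a clearer error than an `AttributeError` raised halfway through a training run.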
- The easiest way to do that is to copy one of the existing architectures (found in `module/architectures`), e.g. `TrainerBdt.py`, as a new Trainer class; let's call it `TrainerXYZ.py` for now.
- Inside the newly created file, change the name of the class to match the name of the file, i.e. `class TrainerXYZ`.
- The `__init__` method should take `data_processor` and `data_loader` arguments, followed by arbitrary architecture-specific arguments that will be passed from the config file. This is probably a good place to load and normalize the data and to prepare the model.
- The class has to contain a `train` method, which shouldn't take any arguments. This method should fit the model.
- There should be a property called `model`, which returns the model.
- Finally, it should have a `get_summary` method, which returns a dict with additional information to be stored in the summary file. The dict can be empty if you don't want to add any information, but the method must still be implemented.
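Putting these requirements together, a bare-bones `TrainerXYZ.py` could look like the following sketch. It is not derived from `TrainerBdt.py`: the toy `MeanThresholdModel`, the `threshold_scale` argument, and the `get_data()` / `normalize()` calls on `data_loader` and `data_processor` are hypothetical placeholders, to be replaced by a real model and by whatever those framework objects actually provide.

```python
"""TrainerXYZ.py: a minimal, hypothetical sketch of a Trainer class.

Assumed (not taken from the framework): data_loader.get_data() returns
(features, labels) with labels 0/1, and data_processor.normalize(features)
returns normalized features.
"""


class MeanThresholdModel:
    """Toy stand-in for a real model: predicts 1 when a sample's feature
    mean exceeds a threshold learned from the training data."""

    def __init__(self, threshold_scale=1.0):
        self.threshold_scale = threshold_scale
        self.threshold = None

    def fit(self, features, labels):
        # Threshold halfway between the average feature mean of each class
        # (both classes 0 and 1 must be present in the training data).
        class_means = {0: [], 1: []}
        for row, label in zip(features, labels):
            class_means[label].append(sum(row) / len(row))
        center_0 = sum(class_means[0]) / len(class_means[0])
        center_1 = sum(class_means[1]) / len(class_means[1])
        self.threshold = self.threshold_scale * (center_0 + center_1) / 2

    def predict(self, features):
        return [int(sum(row) / len(row) > self.threshold) for row in features]


class TrainerXYZ:
    """The class name matches the file name, TrainerXYZ.py."""

    def __init__(self, data_processor, data_loader, threshold_scale=1.0):
        """
        @mandatory
        threshold_scale is an example of an architecture-specific argument
        that would be passed from the config file.
        """
        self.data_processor = data_processor
        self.data_loader = data_loader
        # Load and normalize the data, then prepare the (untrained) model.
        raw_features, self.labels = data_loader.get_data()  # hypothetical call
        self.features = data_processor.normalize(raw_features)  # hypothetical call
        self._model = MeanThresholdModel(threshold_scale)

    def train(self):
        """
        @mandatory
        Fits the model. Takes no arguments.
        """
        self._model.fit(self.features, self.labels)

    @property
    def model(self):
        """
        @mandatory
        Returns the model.
        """
        return self._model

    def get_summary(self):
        """
        @mandatory
        Extra information for the summary file; may be an empty dict.
        """
        return {"threshold": self._model.threshold}
```

Keeping the model behind a read-only `model` property, with all data preparation done in `__init__`, is what lets `train()` stay argument-free as the specification requires.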
And that's it: once you implement this class and prepare the config file, you should be able to run the training with:

`python train.py -c config/my_new_arch_config.py`
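This page does not describe what `train.py` does internally. Assuming it follows the specification above, its core could look roughly like the hypothetical `run_training` helper sketched here: the config's architecture-specific arguments are forwarded to `__init__` as keyword arguments, `train()` is called without arguments, `model` is read as a property, and the `get_summary()` dict is merged into the summary. Writing the summary as JSON, and the `"trainer"` entry, are assumptions made for this sketch only.

```python
import json
from pathlib import Path


def run_training(trainer_class, data_processor, data_loader, training_params, summary_path):
    """Hypothetical driver showing the order in which the Trainer API is used.

    training_params stands in for the architecture-specific arguments read
    from the config file; they are forwarded to __init__ as keyword arguments.
    """
    trainer = trainer_class(data_processor, data_loader, **training_params)
    trainer.train()  # no arguments, as the spec requires
    model = trainer.model  # property access, not a method call

    # Common entries plus whatever the Trainer adds via get_summary().
    summary = {"trainer": trainer_class.__name__, **training_params}
    summary.update(trainer.get_summary())
    Path(summary_path).write_text(json.dumps(summary, indent=2))
    return model, summary
```

Because `get_summary()` is merged unconditionally, returning an empty dict is harmless, which is why the method can be trivial but still has to exist.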