How to re-use trained weights from a serialized state dict that was not built with your torch.nn.Module - SteampunkIslande/CNNAutoencoder GitHub Wiki

Have you ever built a PyTorch model, found a very similar architecture on GitHub, and realized that the pretrained serialized model didn't match your architecture?

Well, fear not, for I came across this nice solution while playing around with this AutoEncoder.

With this tutorial, I will take you through all the steps from creating your own module to running it using pre-trained weights from another compatible serialized module.

Prerequisite: Build your module

If you already know PyTorch, you can skip to the next step.

Otherwise, you can use the UNet module definition found in this repository. For the rest of this tutorial, I will assume that you already know what a module is.

Problem: Trying to load_state_dict on your module from a serialized module with a similar architecture

First, it is important to understand what serialization really is. PyTorch provides pickle-based data serialization, which makes saving and loading your module's state really easy.

Basically, serialization is the operation of writing any instance that supports serialization into a binary file that you can load later on into an instance with the exact same properties. This is called deserialization.
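As a minimal sketch of this round trip (using a toy nn.Linear instead of a full model, and an in-memory buffer instead of a file on disk):

```python
import io

import torch
import torch.nn as nn

# Serialize: write the module's state_dict (an OrderedDict of tensors)
# to a buffer here; a real path such as "model.pth" works the same way.
model = nn.Linear(4, 2)
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# Deserialize: load the pickled dict back and restore it into a fresh,
# identically shaped instance.
buffer.seek(0)
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(buffer))
```

After this, restored holds exactly the same weights and biases as model.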

So, back to your module, let's try to load your model's state_dict using a file that contains weights and biases from another module.

Below is some sample code (with UNet and a file you can find here):

import torch

ae = UNet()
ae.load_state_dict(torch.load("some/path/to/a/saved/model.pth", map_location=lambda storage, loc: storage))
ae.eval()

Of course, you can replace UNet with any torch.nn.Module you wish to load the trained weights into.

And this is where it can fail. If your model's layers have the exact same names as the pre-trained model's, it will work just fine. However, if for some reason you have a slightly different architecture, the call will fail.

Here is an example with my code: load_state_dict : keys don't match

As you can see, every key of my model has a corresponding key in the serialized model I'm trying to load. However, Python has no way to tell that two keys with similar names actually refer to the same layers. Pretty frustrating, huh?
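To see the failure concretely, here is a minimal sketch with two toy modules whose layers are functionally identical but named differently (conv1 and first_conv are made-up names for illustration, not layers from the actual UNet):

```python
import torch.nn as nn

class SourceNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, 3)       # keys: "conv1.weight", "conv1.bias"

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.first_conv = nn.Conv2d(1, 8, 3)  # keys: "first_conv.weight", ...

# The layers have identical shapes, but the keys don't match, so
# load_state_dict raises a RuntimeError listing missing/unexpected keys.
try:
    MyNet().load_state_dict(SourceNet().state_dict())
except RuntimeError as error:
    print(error)
```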

Proposed solution: the workaround

Note that this solution is very experimental, and I am not responsible for any harm it may or may not cause to your project. That being said, there is no real danger other than it simply not working...

Anyway, the solution is pretty basic.

Step one: load the serialized dict containing the weights you wish to transfer

import torch

model_state_dict = torch.load("path/to/pretrained/weights.pth", map_location=lambda storage, loc: storage)
print(model_state_dict.keys())

If the serialized file only contains one ordered dict with the weights and biases, then you're off to step two. Otherwise, if the serialized data is a dict containing more info, you may find the saved weights at a given key of that dict. If so, go to step two once you have the ordered dict containing the weights.
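For instance, many checkpoints are saved as a wrapper dict; "state_dict" below is a common key name, but it varies from project to project (this sketch simulates such a checkpoint in memory rather than loading a real file):

```python
import io

import torch
import torch.nn as nn

# Simulate a checkpoint saved as a wrapper dict with extra training info.
net = nn.Linear(4, 2)
buffer = io.BytesIO()
torch.save({"epoch": 10, "state_dict": net.state_dict()}, buffer)

buffer.seek(0)
checkpoint = torch.load(buffer, map_location=lambda storage, loc: storage)

# Unwrap if needed; otherwise the checkpoint already is the ordered dict.
if "state_dict" in checkpoint:
    model_state_dict = checkpoint["state_dict"]
else:
    model_state_dict = checkpoint
print(model_state_dict.keys())
```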

Step two: get your model's layer names

model = UNet()
print(model.state_dict().keys())

Now you can see where your module matches the serialized one. If you're lucky enough, it's a perfect match!
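Before matching anything by hand, a small helper (my own addition, not part of this repository) can show how far apart the two key sets actually are:

```python
def compareKeys(serialized_dict, destination_dict):
    """Return (matched, only_in_serialized, only_in_destination) key sets."""
    serialized = set(serialized_dict.keys())
    destination = set(destination_dict.keys())
    matched = serialized & destination
    print("exact matches:       ", sorted(matched))
    print("only in serialized:  ", sorted(serialized - destination))
    print("only in your module: ", sorted(destination - serialized))
    return matched, serialized - destination, destination - serialized

# With real models, call compareKeys(model_state_dict, model.state_dict());
# plain dicts with made-up keys stand in for state_dicts here.
compareKeys({"conv1.weight": 0, "conv1.bias": 0},
            {"first_conv.weight": 0, "conv1.bias": 0})
```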

Let's put all this together in the following gists:

Step three: manually match your weights with the serialized ones

from itertools import zip_longest

def saveOrderedDicts(serialized_dict, destination_dict, path):
    """Write both key lists side by side into a two-column CSV file."""
    with open(path, "wt") as f:
        for k1, k2 in zip_longest(serialized_dict.keys(), destination_dict.keys()):
            f.write(f"{k1},{k2}\n")

If you already ran the previous few lines of code inside a Python console, you can call this function with model_state_dict and model.state_dict() as serialized_dict and destination_dict respectively. Run it with a valid CSV file name and it will output the serialized layers' names in the first column, and the names of your module's layers in the second column.

You can now open the CSV file in the software of your choice (whether LibreOffice or Excel). Make sure that layer names with the same purpose are on the same line. You may have to drag some layer names from line to line; just make sure that you keep things consistent. That is, the first column only contains names from the serialized state_dict's keys, and the second column only contains names from your module's layers. Also, you shouldn't leave any line empty.
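Hand-editing is error-prone, so a small sanity check can save headaches; this helper (again my own addition) fails loudly if any row doesn't have exactly two non-empty cells:

```python
def checkMatchingCsv(path):
    """Raise ValueError if any row lacks exactly two non-empty cells."""
    with open(path) as f:
        for lineno, line in enumerate(f, start=1):
            cells = line.rstrip("\n").split(",")
            if len(cells) != 2 or not all(cells):
                raise ValueError(f"line {lineno} is malformed: {line!r}")
```

Run it on your edited file before moving on; if it raises, go back and fix the reported line.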

Once this is done, you can go to the next step.

Step four: Translation!

If you made it to this step, congratulations! In this section, you will create a powerful tool to apply transfer learning even though you don't have the same network architecture as the one the pretrained weights come from.

Now that there is a perfect match between the serialized layer names and yours, it is time to create a translation dict (literally). Here is a simple gist for that (I know you could have figured it out yourself, but this wouldn't be a tutorial anymore, right?)

def csvToTranslationDict(path):
    """Build a dict mapping serialized layer names to your module's layer names."""
    result = {}
    with open(path) as f:
        for line in f:
            k, v = line.strip().split(",")
            result[k] = v
    return result

All this does is create a dictionary whose keys are the layer names from your serialized file and whose values are the corresponding layer names in your module.

Since this dict perfectly maps every serialized layer name to a layer name in your module, the translation is, as you can see, pretty straightforward:

def translateStateDicts(translation_dict, origin_dict):
    """Rename every key of origin_dict according to translation_dict."""
    return {translation_dict[k]: v for k, v in origin_dict.items()}

To use it, just call

translation = csvToTranslationDict("path/to/your/matching.csv")
my_model_dict = translateStateDicts(translation, model_state_dict)
model.load_state_dict(my_model_dict)

If you followed the previous steps as mentioned above, then you should see a very satisfying message from the interpreter:

Final thoughts

Now that you can use someone else's pretrained weights, you have the workflow to do it yourself. The only challenging part (and it can be really tedious or tricky) is matching every single layer of your module with every layer of the serialized one. I tested this only with UNet (as my module) and these pretrained weights, so it might not work for other combinations of models and pretrained weights.

Nevertheless, I really hope this tutorial was helpful to you. If so, don't hesitate to give feedback and share this page; I would really appreciate it!

P.S.: The code snippets used in this tutorial are of course available as gists, as well as a file in this repository.