
I am using the following function to flatten the network:

```python
import numpy as np

#############################################################################
# Flattening the NET
#############################################################################
def flattenNetwork(net):
    flatNet = []
    shapes = []
    for param in net.parameters():
        # Copy the parameter (weight or bias) to a NumPy array on the CPU
        arr = param.detach().cpu().numpy()
        # Remember the original shape so the vector can be split up again later
        shapes.append(arr.shape)
        # Flatten to 1-D; reshape(-1) handles parameters of any rank
        flatNet.append(arr.reshape(-1))
    # Concatenate all flattened parameters into a single 1-D vector
    finalNet = np.concatenate(flatNet)
    return finalNet, shapes
```

The function above returns all the weights as a flat 1-D NumPy array, finalNet, together with shapes, the list of parameter shapes of the network. I want to see how modifying the weights affects prediction accuracy, so I change the weights. How can I copy this modified weight vector back into the original network? Thank you.

Question from: https://stackoverflow.com/questions/65941834/pytorch-how-to-unflatten-get-back-the-network-from-flattened-network

1 Answer

There is a difference between a model's definition (its layers and its `forward` function) and its parameter configuration (the model *state*, which is easily accessible as a dictionary through `state_dict()`).
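To see what the model state actually contains, here is a minimal sketch using a small `nn.Sequential` network chosen purely for illustration (the architecture is an assumption, not the asker's network). The state dictionary maps parameter names to tensors, so it records names and shapes but says nothing about the `ReLU` or about what `forward` computes.

```python
import torch
import torch.nn as nn

# Hypothetical example network, for illustration only
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# The model state: an ordered mapping from parameter names to tensors.
# ReLU has no parameters, so only the two Linear layers (indices 0 and 2) appear.
state = net.state_dict()
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
```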

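You can extract a model's state, as you did with your flattenNetwork implementation. However, reversing this operation, i.e. rebuilding the network when all you have is the weights and the layer shapes, is not possible for pretty much any model: the shapes say nothing about which operations the forward function applies or how the layers are connected.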
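A quick way to convince yourself of this, as a sketch with two hypothetical networks: they have identical parameter shapes and are loaded with identical weights, yet they compute different functions because one applies `ReLU` and the other `Tanh` between the layers. The weights and shapes alone cannot tell the two apart.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two hypothetical networks with exactly the same parameter shapes...
net_a = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
net_b = nn.Sequential(nn.Linear(4, 3), nn.Tanh(), nn.Linear(3, 2))

# ...loaded with exactly the same weights (the state_dict keys match)
net_b.load_state_dict(net_a.state_dict())

shapes_a = [tuple(p.shape) for p in net_a.parameters()]
shapes_b = [tuple(p.shape) for p in net_b.parameters()]
print(shapes_a == shapes_b)  # True

# Same weights, same shapes, but different forward functions
x = torch.randn(5, 4)
with torch.no_grad():
    out_a = net_a(x)
    out_b = net_b(x)
print(torch.allclose(out_a, out_b))  # False
```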

Now, assuming you still have access to `net`, my advice is to work with `net.state_dict()` directly: modify it, then load the dictionary of weights back with `load_state_dict`. This way you avoid having to serialize the model's parameters yourself.
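A minimal sketch of two routes, assuming `net` is still available (a small hypothetical `nn.Sequential` stands in for it here). Option 1 follows the advice above: edit a copy of the state dict and load it back with `load_state_dict`. Option 2 addresses the literal question using `torch.nn.utils.vector_to_parameters`, which copies a flat tensor back into `net.parameters()`. Since `flattenNetwork` also flattens each parameter row-major in `net.parameters()` order, its `finalNet` should have the same layout as `parameters_to_vector` output, so a modified `finalNet` could be used in place of `flat_modified`.

```python
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters

torch.manual_seed(0)

# Hypothetical stand-in for the asker's `net`
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# Option 1: modify a copy of the state dict, then load it back.
# state_dict() returns tensors sharing storage with the parameters,
# so clone them first to keep the edit explicit.
new_state = {name: t.clone() for name, t in net.state_dict().items()}
new_state["0.weight"] += 0.01 * torch.randn_like(new_state["0.weight"])  # example perturbation
net.load_state_dict(new_state)

# Option 2: write a modified flat NumPy vector back into the parameters.
# parameters_to_vector flattens each parameter row-major, in net.parameters() order.
flat = parameters_to_vector(net.parameters()).detach().cpu().numpy()
flat_modified = flat * 0.5  # example modification of the flat vector

# torch.tensor copies the data and matches the parameters' dtype and device
ref = next(net.parameters())
vec = torch.tensor(flat_modified, dtype=ref.dtype, device=ref.device)
vector_to_parameters(vec, net.parameters())
```

After either option, `net` can be evaluated as usual to measure how the modified weights affect accuracy.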

