【AI】Pytorchモデルから重み、バイアスを取得する方法

こんにちは、ヒガシです。

 

このページでは、Pytorchを使って構築したモデルから、重みとバイアスを取り出す方法をご紹介していきます。

 

それではさっそくやっていきましょう!

 

スポンサーリンク

Pytorchでモデルを構築する

まずはモデルがないと始まりませんので、MNIST画像の分類問題を解くモデルを適当に作成してみます。

 

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchsummary import summary
import numpy as np

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# ニューラルネットワークの定義
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 512) # 入力画像サイズ: 28x28, 出力サイズ: 512
        self.fc2 = nn.Linear(512, 512)     # 入力サイズ: 512, 出力サイズ: 512
        self.fc3 = nn.Linear(512, 10)      # 入力サイズ: 512, 出力サイズ: 10 (0~9の数字分類)

    def forward(self, x):
        x = x.view(-1, 28 * 28)   # 画像を1次元に平坦化
        x = F.relu(self.fc1(x))   # ReLU活性化関数を適用
        x = F.relu(self.fc2(x))   # ReLU活性化関数を適用
        x = self.fc3(x)           # 出力層への入力
        return x

net = Net().to(device)
summary(net,(1,28,28))

 

こいつを実行すると以下の結果が得られました。

torchsummaryでモデルを確認した結果

問題なくモデル定義できていますね。

 

スポンサーリンク

Pytrochモデルから重み、バイアスを取り出す方法

それでは本題であるPytorchのモデルから重みとバイアスを取り出してみましょう。

 

以下がそのコードです。

#パラメーターの取得
key_list = []
param_list =[]
for key, param in net.state_dict().items():
    key_list.append(key)
    param_list.append(net.state_dict()[key].cpu().numpy())
    
#重み、バイアスの名前を表示する
print(key_list)
for i in range(len(key_list)):
    #パラメーターのサイズを表示する
    print(key_list[i], param_list[i].shape)
#パラメータを表示してみる
print(param_list[0])

 

こいつを実行すると以下が出力されました。

Pytorchモデルから重みとバイアスを取得した結果

 

こちらも問題なく出力できていそうですね。

 

スポンサーリンク

パラメータの数を数えてみる

モデルを作成した際にtorchsummaryで確認したところ、モデル内パラメータは全部で669706個あるとされました。

torchsummaryでモデルを確認した結果

先ほど出力されたモデルのサイズから本当にこうなっているか確認してみましょう。

 

以下で確認してみます。

単純にパラメータの縦、横をかけて足していっているだけです。

#パラメータの総数を数える
num_param = 0
for j in range(0,len(key_list),2):
    num_param += param_list[j].shape[0] * param_list[j].shape[1]
    num_param += param_list[j + 1].shape[0]
print(num_param)

 

問題なく669706が出力されました。

 

スポンサーリンク

おわりに

ということで今回はPytorchモデルから重みとバイアスを取り出す方法をご紹介しました。

ライブラリが使えない環境下にモデルを置く際などにぜひご活用ください。

 

このブログでは、このようなAIスキルを多数紹介しています。

ぜひ他のページもご覧ください。

過去記事一覧

コメント

タイトルとURLをコピーしました