본문 바로가기
잡지식 저장고/Pytorch

torch.nn.Parameter 에 관해서

by Slate_Knowledge 2021. 4. 7.
728x90

Pytorch 에는 Parameter라는 모듈이 있는데, 얘는 레이어가 아니라 말 그대로 파라미터 값만을 가지고 있는 놈이다.

class Actor(nn.Module):
    def __init__(self, num_inputs, num_outputs, continuous=True, shared=False):
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(num_inputs, hp.hidden)
        self.fc2 = nn.Linear(hp.hidden, hp.hidden)
        self.fc3 = nn.Linear(hp.hidden, num_outputs)
        self.fc3.weight.data.mul_(0.1)
        self.fc3.bias.data.mul_(0.0)
        self.continuous = continuous
        if self.continuous:
            self.logstd = nn.Parameter(torch.zeros(1, num_outputs))

        self.shared = shared
        if self.shared:
            self.fc3_critic = nn.Linear(hp.hidden, 1)
            self.fc3_critic.weight.data.mul_(0.1)
            self.fc3_critic.bias.data.mul_(0.0)

    def forward(self, x, critic=False):
        x = F.gelu(self.fc1(x))
        x = F.gelu(self.fc2(x))
        if critic :
            value = self.fc3_critic(x)
            return value
        else:
            mu = self.fc3(x)
        #logstd = torch.zeros_like(mu)
            if self.continuous:
                logstd = self.logstd.expand_as(mu)
                std = torch.exp(logstd)
                return mu, std, logstd
            else:
                return F.softmax(mu, dim=-1), None, None

위 예시(RL에서의 Actor 모델)에서는 logstd 라는 파라미터를 가지는데 이는 학습가능한 파라미터로서의 action 로그 표준편차이다. 

요놈을 모델이 가지는 initial knowledge(업데이트 가능한) 같은 요소로도 사용 가능할 것 같다.

예제코드:

import torch
class dummy(torch.nn.Module):
  def __init__(self):
    super(dummy, self).__init__()
    self.m = torch.nn.Linear(10,1)
    self.par = torch.nn.Parameter(torch.ones(10))
   
  def forward(self, x):
    x = torch.mul(x, self.par)
    return self.m(x)

x = torch.randn(10)
'''
ex) tensor([ 0.3600, -0.1501, -1.3687,  0.3313, -0.8489,  0.2747,  1.3481, -0.2329,
        -0.9797, -1.2622])
'''
model = dummy()
optimizer = torch.optim.SGD(model.parameters(), 0.1)
out = model(x)
loss = torch.sum((out-target)**2)
print(model.par)
'''
Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], requires_grad=True)
'''
loss.backward()
optimizer.step()
print(model.par)
'''
Parameter containing:
tensor([0.9717, 1.0124, 0.8922, 0.9711, 0.9394, 1.0182, 0.8684, 1.0105, 1.0106,
        1.0173], requires_grad=True)
'''
728x90
반응형

댓글