728x90
Pytorch 에는 Parameter라는 모듈이 있는데, 얘는 레이어가 아니라 말 그대로 파라미터 값만을 가지고 있는 놈이다.
class Actor(nn.Module):
def __init__(self, num_inputs, num_outputs, continuous=True, shared=False):
self.num_inputs = num_inputs
self.num_outputs = num_outputs
super(Actor, self).__init__()
self.fc1 = nn.Linear(num_inputs, hp.hidden)
self.fc2 = nn.Linear(hp.hidden, hp.hidden)
self.fc3 = nn.Linear(hp.hidden, num_outputs)
self.fc3.weight.data.mul_(0.1)
self.fc3.bias.data.mul_(0.0)
self.continuous = continuous
if self.continuous:
self.logstd = nn.Parameter(torch.zeros(1, num_outputs))
self.shared = shared
if self.shared:
self.fc3_critic = nn.Linear(hp.hidden, 1)
self.fc3_critic.weight.data.mul_(0.1)
self.fc3_critic.bias.data.mul_(0.0)
def forward(self, x, critic=False):
x = F.gelu(self.fc1(x))
x = F.gelu(self.fc2(x))
if critic :
value = self.fc3_critic(x)
return value
else:
mu = self.fc3(x)
#logstd = torch.zeros_like(mu)
if self.continuous:
logstd = self.logstd.expand_as(mu)
std = torch.exp(logstd)
return mu, std, logstd
else:
return F.softmax(mu, dim=-1), None, None
위 예시(RL에서의 Actor 모델)에서는 logstd 라는 파라미터를 가지는데 이는 학습가능한 파라미터로서의 action 로그 표준편차이다.
요놈을 모델이 가지는 initial knowledge(업데이트 가능한) 같은 요소로도 사용 가능할 것 같다.
예제코드:
import torch
class dummy(torch.nn.Module):
def __init__(self):
super(dummy, self).__init__()
self.m = torch.nn.Linear(10,1)
self.par = torch.nn.Parameter(torch.ones(10))
def forward(self, x):
x = torch.mul(x, self.par)
return self.m(x)
x = torch.randn(10)
'''
ex) tensor([ 0.3600, -0.1501, -1.3687, 0.3313, -0.8489, 0.2747, 1.3481, -0.2329,
-0.9797, -1.2622])
'''
model = dummy()
optimizer = torch.optim.SGD(model.parameters(), 0.1)
out = model(x)
loss = torch.sum((out-target)**2)
print(model.par)
'''
Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], requires_grad=True)
'''
loss.backward()
optimizer.step()
print(model.par)
'''
Parameter containing:
tensor([0.9717, 1.0124, 0.8922, 0.9711, 0.9394, 1.0182, 0.8684, 1.0105, 1.0106,
1.0173], requires_grad=True)
'''
728x90
반응형
'잡지식 저장고 > Pytorch' 카테고리의 다른 글
Pytorch-TensorRT Dataparallel Inference 직접 만들기 (1) | 2023.04.28 |
---|---|
detectron2 에서 Faster R-CNN RPN에 GradCAM 붙이기 (0) | 2021.03.30 |
댓글