본문 바로가기

프로그래밍

ALL_ABOUT_TORCH - 1 / Custom Data Training

ALL_ABOUT_TORCH

사용자 입장에서 파이토치 라이브러리를 사용하면서 필요한 부분에 대하여 포스팅 한 글 입니다. 

Custom Data Training

이번 시간에는 Custom Data에 대해서 훈련 시키는 코드를 작성해보겠습니다. 

데이터의 형식은 npy로 csv, tsv등의 데이터 형식에 대해서는 먼저 numpy ndarray로 변환시키고 진행하시면 됩니다. 


1. 먼저 필요한 라이브러리를 import 해줍니다. 

import torch
import torch.nn as nn
import numpy as np 
import torch.nn.functional as F 

 

2. 사용자가 원하는 모델 또한 customize하였습니다. 

class fxnnxc(nn.Module):
    def __init__(self):
        super(fxnnxc, self).__init__()
        self.linear1 = nn.Linear(28*28, 256)
        self.linear2 = nn.Linear(256, 64)
        self.linear3 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = self.linear3(x)
        return x
        
model = fxnnxc()

 

3. 데이터 

  1) 먼저 npy파일을 읽습니다. (넘파이 객체를 가지고 있다면 패스)

  2) 다음으로 넘파이를 torch.tensor 로 변환해줍니다. 

  3) X,y 쌍을  TensorDataSet에 보관합니다. 

  4) Dataset에서 batch로 데이터를 전달해줄 DataLoader를 사용합니다.  

#--------------------------------------------- file read
x_train = np.load("mnist_train.npy")
x_test = np.load("mnist_test.npy")
y_train = np.load("mnist_train_target.npy")
y_test = np.load("mnist_test_target.npy")
#--------------------------------------------- numpy to tensor
 # X를 long으로 하면 loss 계산할 때 에러
 # Y를 float으로 하면 loss 계산할 때 에러  
x_train  = torch.from_numpy(x_train).float()      
x_test   = torch.from_numpy(x_test).float()
y_train  = torch.from_numpy(y_train).long()        
y_test   = torch.from_numpy(y_test).long()

#--------------------------------------------- data to dataset
train_dataset = torch.utils.data.TensorDataset(x_train, y_train)
test_dataset  = torch.utils.data.TensorDataset(x_test,  y_test)

#--------------------------------------------- dataset to dataloader 
train_loader = torch.utils.data.DataLoader(train_dataset,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=2)
                                    
test_loader = torch.utils.data.DataLoader(train_dataset,
                                          batch_size=4,
                                          shuffle=True,
                                          num_workers=2)                                 
                                    

numpy를 tensor로 변환하는 부분에서 타입을 float, long으로 정확하게 입력해줘야 합니다. 

타입을 다른 것으로 할 경우, critetion에서 loss를 계산할 때, 에러를 반환할 수 있습니다. 

4. Loss, Optimizer 

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)

훈련을 진행하면서 loss를 계산하는 방법과 optimize 방법을 정하는 부분입니다. 

 

5. 훈련

EPOCH = 1

for t in range(EPOCH): # EPOCH
    for i, (sample, target) in enumerate(train_loader): #BATCH
    	#----- 이부분은 필요하지 않습니다. 
        # 데이터에 샘플의 dimension을 바꿔줬습니다. 
        # X 데이터가 28,28 --> 768 
        sample = sample.view(sample.size()[0], -1)   
        #-----------------------------------------------------
        
        y = model(sample)
        loss = criterion(y, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if i % 1000 == 99:
            print(t, loss.item())        

이중 for 문으로 구성되어 있는데, 

첫 번째 for 문은 epoch을 나타내며 그 수만큼 데이터를 반복합니다. 

두 번쨰 for 문은 데이터를 minibatch로 반복하는 부분입니다. 이 부분에서 모든 데이터로 훈련을 진행합니다. 

6. 테스트

# Test

correct = 0
total = 0
with torch.no_grad(): 
    for data in test_loader:
        images, labels = data
        images = images.view(images.size()[0], -1)   
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        
print(f'Accuracy of the test images: {100 * correct / total}')

test데이터에 대해서 마찬가지로 test_loader를 사용하여 batch로 예측을 진행합니다. 

 

7. 결론

파이토치 버전에 대해서 해당 코드가 제대로 작동하지 않을 수 있습니다. 

APPENDIX에서 해당 코드가 돌아가는 버전을 적어놨습니다. 

전체 코드는 GITHUB에 올려뒀습니다. 

 

github.com/fxnnxc/all_about_torch

 

fxnnxc/all_about_torch

Contribute to fxnnxc/all_about_torch development by creating an account on GitHub.

github.com


APPENDIX

계발환경

python 3.6.8
pytorch 1.7.0
numpy 1.19.2