# https://towardsdatascience.com/linear-regression-with-pytorch-eb6dedead817

import numpy as np
import torch
import matplotlib.pyplot as plt

class linearRegression(torch.nn.Module):
    """A single affine layer (``y = x @ W.T + b``) wrapped as a module.

    Args:
        inputSize: number of input features (size of the last dim of ``x``).
        outputSize: number of output features.
    """

    def __init__(self, inputSize, outputSize):
        super().__init__()
        # The only trainable parameters: a weight of shape
        # (outputSize, inputSize) and a bias of shape (outputSize,).
        self.linear = torch.nn.Linear(inputSize, outputSize)

    def forward(self, x):
        """Apply the affine map to ``x`` and return the result."""
        return self.linear(x)
    
    
# Synthetic training data: x = 0..10 and y = 2*x + 1, stored as float32
# column vectors of shape (11, 1) -- the 2-D (N, features) layout that
# torch.nn.Linear consumes.
x_values = list(range(11))
x_train = np.array(x_values, dtype=np.float32).reshape(-1, 1)
print(x_train)

y_values = [2 * i + 1 for i in x_values]
y_train = np.array(y_values, dtype=np.float32).reshape(-1, 1)
print(y_train)

# Hyperparameters: one input feature, one output target, plain SGD.
inputDim, outputDim = 1, 1
learningRate = 0.01
epochs = 100

# Model, mean-squared-error loss, and an SGD optimizer over the model's parameters.
model = linearRegression(inputSize=inputDim, outputSize=outputDim)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learningRate)

# The training data never changes, so convert it to tensors once rather than
# on every epoch. torch.from_numpy shares memory with the numpy arrays (no copy).
inputs = torch.from_numpy(x_train)
labels = torch.from_numpy(y_train)

# Full-batch gradient descent: one forward/backward/step per epoch.
for epoch in range(epochs):
    optimizer.zero_grad()  # clear gradients left over from the previous step
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    print(f'epoch {epoch}, loss {loss.item()}')

# Run the trained model on the training inputs. Autograd is disabled here, so
# the output has requires_grad=False and converts to numpy directly; the
# legacy `.data` accessor is not needed.
with torch.no_grad():
    predicted = model(torch.from_numpy(x_train)).numpy()
print(predicted)

# Plot the ground-truth points against the model's predictions. Plotting does
# not involve autograd, so it lives outside the no_grad context.
plt.clf()
plt.plot(x_train, y_train, 'go', label='True Data', alpha=0.5)
plt.plot(x_train, predicted, '--', label='Predictions', alpha=0.5)
plt.legend(loc='best')
plt.show()
    
               
                      
                      
