import torch
from torch import nn
import matplotlib
import matplotlib.pyplot as plt
import numpy
Weight = 0.07
Bias = 0.03
Start = 0
End = 1
Step = 0.02
X = torch.arange(Start, End, Step).unsqueeze(dim=1)
y = Weight * X + Bias
train_split = int((0.8 * len(X)))
X_train, y_train = X[:train_split], y[:train_split]
X_test, y_test = X[train_split:], y[train_split:]
def plot_predictions(train_data=X_train,
train_labels=y_train,
test_data=X_test,
test_labels=y_test,
predictions=None):
plt.figure(figsize=(10, 7))
plt.scatter(train_data, train_labels, c="b", s=4, label="Training Data")
plt.scatter(test_data, test_labels, c="g", s=4, label="Test Data")
if predictions is not None:
plt.scatter(test_data, predictions, c="r", s=4, label="predictions")
plt.legend(prop={"size": 14})
plot_predictions()
I am not getting an error, but im following a tutorial to learn pytorch and on this one its supposed to generate a graph but it just finishes and exits