Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

I would like to add train and test functions to the KAN class #44

Open
riteshshergill opened this issue Jun 6, 2024 · 0 comments
Open

Comments

@riteshshergill
Copy link

I can't seem to open a branch to raise a pull request, so I'm adding the code here:

def train_model(self, model, trainloader, valloader, optimizer, scheduler, criterion, device, epochs):
    """Train ``model`` for ``epochs`` epochs, validating after each epoch.

    Each epoch runs one optimisation pass over ``trainloader`` (inside a tqdm
    progress bar that shows the batch loss, batch accuracy and current
    learning rate). It then evaluates on ``valloader`` without gradients,
    steps ``scheduler`` once and prints the validation loss and accuracy.

    Every batch of images is reshaped to ``(-1, 28 * 28)`` before being fed
    to the model, so the inputs are assumed to hold 28x28 values per sample
    (e.g. MNIST) -- TODO confirm for other datasets.

    Args:
        self: Unused; present so the function can be attached to a class.
        model: Network to train; moved to ``device`` in place.
        trainloader: Iterable of ``(images, labels)`` training batches.
        valloader: Iterable of ``(images, labels)`` validation batches.
        optimizer: Optimizer over ``model``'s parameters; its first param
            group's ``'lr'`` is displayed in the progress bar.
        scheduler: Learning-rate scheduler stepped once per epoch, or
            ``None`` to skip scheduling.
        criterion: Loss called as ``criterion(output, labels)``.
        device: Device the model and every batch are moved to.
        epochs: Number of training epochs.
    """
    model.to(device)
    for epoch in range(epochs):
        # Train
        model.train()
        with tqdm(trainloader) as pbar:
            for images, labels in pbar:
                images = images.reshape(-1, 28 * 28).to(device)
                # Transfer once; reused for both the loss and the accuracy.
                labels = labels.to(device)
                optimizer.zero_grad()
                output = model(images)
                loss = criterion(output, labels)
                loss.backward()
                optimizer.step()
                accuracy = (output.argmax(dim=1) == labels).float().mean()
                pbar.set_postfix(loss=loss.item(), accuracy=accuracy.item(), lr=optimizer.param_groups[0]['lr'])

        # Validation: accumulate per-sample totals so that a short final
        # batch is weighted correctly, instead of averaging batch means.
        model.eval()
        val_loss_sum = 0.0
        val_correct = 0
        val_seen = 0
        with torch.no_grad():
            for images, labels in valloader:
                images = images.reshape(-1, 28 * 28).to(device)
                labels = labels.to(device)
                output = model(images)
                batch_size = labels.size(0)
                # NOTE(review): assumes criterion averages over the batch
                # (the default 'mean' reduction); scaling by batch size
                # recovers the batch's summed loss.
                val_loss_sum += criterion(output, labels).item() * batch_size
                val_correct += (output.argmax(dim=1) == labels).sum().item()
                val_seen += batch_size
        # An empty validation loader reports NaN instead of dividing by zero.
        val_loss = val_loss_sum / val_seen if val_seen else float("nan")
        val_accuracy = val_correct / val_seen if val_seen else float("nan")

        # Update learning rate
        if scheduler is not None:
            scheduler.step()

        print(
            f"Epoch {epoch + 1}, Val Loss: {val_loss}, Val Accuracy: {val_accuracy}"
        )

def test_model(self, model, testloader, device, num_samples=10):
    model.to(device)
    model.eval()
    predictions = []
    ground_truths = []
    images_to_show = []

    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for i, (images, labels) in enumerate(testloader):
            images = images.view(-1, 28 * 28).to(device)
            output = model(images)
            predictions.extend(output.argmax(dim=1).cpu().numpy())
            ground_truths.extend(labels.cpu().numpy())
            images_to_show.extend(images.view(-1, 28, 28).cpu().numpy())

            if len(predictions) >= num_samples:
                break

    # Print the predictions for the specified number of samples
    for i in range(num_samples):
        print(f"Ground Truth: {ground_truths[i]}, Prediction: {predictions[i]}")

    return predictions[:num_samples], ground_truths[:num_samples], images_to_show[:num_samples]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant