Skip to content

Commit 959f7c2

Browse files
authored
Merge pull request #873 from rohitgr7/package/pytorch_lightning
Add pytorch-lightning to Dockerfile
2 parents 5b01c83 + b46156b commit 959f7c2

File tree

2 files changed

+74
-0
lines changed

2 files changed

+74
-0
lines changed

Dockerfile

+1
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,7 @@ RUN pip install flashtext && \
443443
pip install tensorflow-datasets && \
444444
pip install pydub && \
445445
pip install pydegensac && \
446+
pip install pytorch-lightning && \
446447
/tmp/clean-layer.sh
447448

448449
# Tesseract and some associated utility packages

tests/test_pytorch_lightning.py

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import unittest
2+
3+
import torch
4+
import torch.nn.functional as F
5+
from torch.utils.data import DataLoader, TensorDataset
6+
7+
import pytorch_lightning as pl
8+
from pytorch_lightning.metrics.functional import to_onehot
9+
10+
11+
class LitDataModule(pl.LightningDataModule):
12+
13+
def __init__(self, batch_size=16):
14+
super().__init__()
15+
16+
self.batch_size = batch_size
17+
18+
def setup(self, stage=None):
19+
X_train = torch.rand(100, 1, 28, 28).float()
20+
y_train = to_onehot(torch.randint(0, 10, size=(100,)), num_classes=10).float()
21+
X_valid = torch.rand(20, 1, 28, 28)
22+
y_valid = to_onehot(torch.randint(0, 10, size=(20,)), num_classes=10).float()
23+
24+
self.train_ds = TensorDataset(X_train, y_train)
25+
self.valid_ds = TensorDataset(X_valid, y_valid)
26+
27+
def train_dataloader(self):
28+
return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)
29+
30+
def val_dataloader(self):
31+
return DataLoader(self.valid_ds, batch_size=self.batch_size, shuffle=False)
32+
33+
34+
class LitClassifier(pl.LightningModule):
35+
36+
def __init__(self):
37+
super().__init__()
38+
self.l1 = torch.nn.Linear(28 * 28, 10)
39+
40+
def forward(self, x):
41+
return torch.relu(self.l1(x.view(x.size(0), -1)))
42+
43+
def training_step(self, batch, batch_idx):
44+
x, y = batch
45+
y_hat = self(x)
46+
loss = F.binary_cross_entropy_with_logits(y_hat, y)
47+
result = pl.TrainResult(loss)
48+
result.log('train_loss', loss, on_epoch=True)
49+
return result
50+
51+
def validation_step(self, batch, batch_idx):
52+
x, y = batch
53+
y_hat = self(x)
54+
loss = F.binary_cross_entropy_with_logits(y_hat, y)
55+
result = pl.EvalResult(checkpoint_on=loss)
56+
result.log('val_loss', loss)
57+
return result
58+
59+
def configure_optimizers(self):
60+
return torch.optim.Adam(self.parameters(), lr=0.02)
61+
62+
63+
class TestPytorchLightning(unittest.TestCase):
64+
65+
def test_version(self):
66+
self.assertIsNotNone(pl.__version__)
67+
68+
def test_mnist(self):
69+
dm = LitDataModule()
70+
model = LitClassifier()
71+
trainer = pl.Trainer(gpus=None, max_epochs=1)
72+
result = trainer.fit(model, datamodule=dm)
73+
self.assertTrue(result)

0 commit comments

Comments
 (0)