Skip to content

Commit 36cf030

Browse files
authored
Update gretel checkpoints to use temp directory (#105)
1 parent 51e6cc2 commit 36cf030

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

sdgym/synthesizers/gretel.py

+3-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import os
1+
import tempfile
22

33
import numpy as np
44
from gretel_synthetics.batch import DataFrameBatch
@@ -9,11 +9,9 @@
99
class Gretel(SingleTableBaseline):
1010
"""Class to represent Gretel's neural network model."""
1111

12-
DEFAULT_CHECKPOINT_DIR = os.path.join(os.getcwd(), 'checkpoints')
13-
1412
def __init__(self, max_lines=0, max_line_len=2048, epochs=None, vocab_size=20000,
1513
gen_lines=None, dp=False, field_delimiter=",", overwrite=True,
16-
checkpoint_dir=DEFAULT_CHECKPOINT_DIR):
14+
checkpoint_dir=None):
1715
self.max_lines = max_lines
1816
self.max_line_len = max_line_len
1917
self.epochs = epochs
@@ -22,7 +20,7 @@ def __init__(self, max_lines=0, max_line_len=2048, epochs=None, vocab_size=20000
2220
self.dp = dp
2321
self.field_delimiter = field_delimiter
2422
self.overwrite = overwrite
25-
self.checkpoint_dir = checkpoint_dir
23+
self.checkpoint_dir = checkpoint_dir or tempfile.TemporaryDirectory().name
2624

2725
def _fit_sample(self, data, metadata):
2826
config = {

0 commit comments

Comments
 (0)