-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathgenerate.py
38 lines (32 loc) · 1.27 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""
This module generates the data for the sorting task.
"""
import os
import numpy as np
from config.generate import get_config
from utils import RANDOM_SEED
def _write_sample(file, values):
    """Write one ``<unsorted list>|<sorted list>`` line for *values* to *file*."""
    # ``tolist()`` turns the NumPy scalars into plain Python ints. Without it,
    # NumPy >= 2.0 renders each element as ``np.int64(3)`` inside ``str(list)``,
    # which corrupts the data-file format.
    lst = values.tolist()
    file.write(str(lst) + "|" + str(sorted(lst)) + "\n")


def generate(name, size, max_val, min_length, max_length, ewc):
    """Generate samples for the sorting task and write them to ``data/<name>.txt``.

    The first line of the file is the header ``size|max_val|max_length``.
    Every following line is one sample: ``<unsorted list>|<sorted list>``.
    Values are drawn uniformly from ``[2, max_val)``.

    Args:
        name: Base name of the output file (``.txt`` is appended).
        size: Number of samples to write per group.
        max_val: Exclusive upper bound for the generated values.
        min_length: Smallest length of a random-length sample.
        max_length: Upper bound on sample length. It is inclusive for the
            fixed-length samples written when ``ewc`` is set, but exclusive
            for the random-length samples (``np.random.randint`` excludes
            its upper bound), so ``min_length`` must be below ``max_length``.
        ewc: If truthy, first write ``size`` samples for every length from 2
            through ``max_length`` before the random-length samples.
    """
    # Seed the global NumPy RNG so the generated file is reproducible.
    np.random.seed(RANDOM_SEED)
    # ``exist_ok`` avoids the check-then-create race of isdir() + mkdir().
    os.makedirs("data", exist_ok=True)
    with open("data/" + name + ".txt", mode='w') as file:
        file.write("|".join([str(size), str(max_val), str(max_length)]) + "\n")
        if ewc:
            for length in range(2, max_length + 1):
                for i in range(size):
                    if i % 10000 == 0:
                        print(length, i)  # progress indicator
                    _write_sample(file, np.random.randint(2, max_val, length))
        for i in range(size):
            if i % 10000 == 0:
                print(i)  # progress indicator
            # Draw the length first, then the values, to keep the RNG call
            # order (and therefore the seeded output) stable.
            length = np.random.randint(min_length, max_length)
            _write_sample(file, np.random.randint(2, max_val, length))
if __name__ == "__main__":
    # Script entry point: read the generation settings and build the dataset.
    # get_config() returns a pair; only the first element is used here.
    args, _unparsed = get_config()
    generate(
        name=args.name,
        size=args.size,
        max_val=args.max_val,
        min_length=args.min_length,
        max_length=args.max_length,
        ewc=args.ewc,
    )