-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdataset.ts
71 lines (60 loc) · 2.38 KB
/
dataset.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
/**
* Nothing specific to GPT here, only a generic character-level
* dataset wrapper.
*
 * A helper wrapper on top of any txt-file-based character-level
 * dataset. It loads an arbitrary txt file, treats each letter as a token,
* splits the characters into training and testing batches,
* encodes/decodes letters to indices and vice versa.
*/
import * as tf from '@tensorflow/tfjs'
import { Dataset, DatasetGetBatchParams, DatasetParams } from './types'
// Creates a character-level dataset, where each letter is a token.
/**
 * Creates a character-level dataset, where each letter is a token.
 *
 * Loads text either from `textSourceURL` (fetched at call time) or from the
 * inline `textSource` string, builds a sorted character vocabulary, encodes
 * the text into an int32 tensor, splits it 90/10 into train/validation, and
 * exposes random batch sampling plus encode/decode helpers.
 *
 * @param args - dataset options: `textSourceURL` (optional URL of a txt file
 *   to fetch), `textSource` (inline text, used when no URL is given) and
 *   `maskZero` (default true: indices start at 1 so 0 stays free for masking).
 * @returns the Dataset wrapper (vocabulary, batch sampler, codecs, dispose).
 * @throws Error when the fetch returns a non-2xx response, or when a batch is
 *   requested with a blockSize larger than the chosen split.
 */
export async function CharDataset(args: DatasetParams): Promise<Dataset> {
  const { textSourceURL, textSource = '', maskZero = true } = args
  // Whether to use 0-based or 1-based (0 is for masking) index.
  const indexShift = maskZero ? 1 : 0
  let text: string = textSource
  if (textSourceURL) {
    const response = await fetch(textSourceURL)
    // Fail loudly rather than silently training on an error-page body.
    if (!response.ok) {
      throw new Error(`Failed to fetch ${textSourceURL}: ${response.status} ${response.statusText}`)
    }
    text = await response.text()
  }
  const textSize: number = text.length
  // Sorted unique characters form the vocabulary.
  const chars: string[] = Array.from(new Set(text)).sort()
  const vocabSize: number = chars.length
  // Data encoders/decoders: char <-> (possibly shifted) integer index.
  const stoi = Object.fromEntries(chars.map((ch, i) => [ch, i + indexShift]))
  const itos = Object.fromEntries(chars.map((ch, i) => [i + indexShift, ch]))
  const encode = (s: string) => s.split('').map((c) => stoi[c])
  const decode = (a: number[]) => a.map((i) => itos[i]).join('')
  // Train and test splits: first 90% for training, last 10% for validation.
  const data = tf.tensor(encode(text), [textSize], 'int32')
  const dataSize: number = data.shape[0]
  const n = Math.floor(0.9 * textSize)
  const trainData: tf.Tensor = data.slice(0, n)
  const valData: tf.Tensor = data.slice(n)
  // Samples a random batch of contiguous blocks; y is x shifted right by one
  // position (next-character prediction targets).
  const getBatch = (batchArgs: DatasetGetBatchParams) =>
    tf.tidy(() => {
      const { split, blockSize, batchSize } = batchArgs
      const source = split === 'train' ? trainData : valData
      const maxval = source.shape[0] - blockSize
      // Guard: a block (plus its shifted target) must fit inside the split,
      // otherwise gather would read out of range.
      if (maxval <= 0) {
        throw new Error(`blockSize ${blockSize} is too large for the '${split}' split of size ${source.shape[0]}`)
      }
      const ix = tf.randomUniform([batchSize], 0, maxval, 'int32') // (B)
      const ranges = tf.range(0, blockSize, 1, 'int32').expandDims(0) // (1,T)
      const indices = ix.expandDims(1).add(ranges) // (B,T)
      const x = tf.gather(source, indices) // (B,T)
      const y = tf.gather(source, indices.add(tf.scalar(1, 'int32'))) // (B,T)
      return { x, y }
    })
  // Releases the tensors held by this dataset (train/val views and backing data).
  const dispose = () => {
    data.dispose()
    trainData.dispose()
    valData.dispose()
  }
  return {
    textSourceURL,
    vocabSize,
    vocabulary: chars,
    dataSize,
    text,
    getBatch,
    encode,
    decode,
    dispose,
  }
}