# predict.toml
# Template configuration for running a trained TweetyNet model on WAV files.
# Replace every <...> placeholder below with a real path or name before use.
# Dataset preparation: where the input audio lives and where the prepared
# artifacts are written. Values wrapped in <...> are placeholders to replace.
[PREP]
# Directory holding the WAV files the model will be run on.
data_dir = "</path/to/the/WAV-files/to/run/the/model/on>"
# Directory that receives the CSV file organizing the prepared data.
output_dir = "</path/to/where/the/organizing/CSV-file/should/be/stored>"
# Format of the audio files found in data_dir.
audio_format = "wav"
# Directory that receives the spectrogram files.
spect_output_dir = "</path/to/where/the/spectrograms/should/be/stored>"
# Spectrogram settings.
# NOTE(review): these presumably have to match the settings used when the
# model was trained -- confirm before changing any of them.
[SPECT_PARAMS]
# Transform and its geometry.
# NOTE(review): fft_size / step_size are presumably in samples and
# freq_cutoffs in Hz (low, high) -- confirm against the consumer.
transform_type = "log_spect"
fft_size = 1024
step_size = 480
freq_cutoffs = [750, 18000]

# Key names used for the spectrogram, its frequency bins, its time bins and
# the source audio path.
# NOTE(review): presumably the names under which these are stored in each
# spectrogram file -- confirm.
spect_key = "s"
freqbins_key = "f"
timebins_key = "t"
audio_path_key = "audio_path"
# Data loading settings.
[DATALOADER]
# NOTE(review): presumably the width, in spectrogram time bins, of each window
# fed to the network -- confirm against the consumer.
window_size = 800
# Prediction run: which model to apply, the trained artifacts it loads, and
# where the predictions are written. Values wrapped in <...> are placeholders;
# text outside the brackets (e.g. "labelmap.json") is the expected file name.
[PREDICT]
# The name listed here matches the [TweetyNetModel.network] table in this file.
models = ["TweetyNetModel"]

# Trained artifacts located in the model directory.
checkpoint_path = "</path/to/the/model/TweetyNet/checkpoints/max-val-acc-checkpoint.pt>"
labelmap_path = "</path/to/the/model/>labelmap.json"
spect_scaler_path = "</path/to/the/model/>StandardizeSpect"

# Runtime settings.
# NOTE(review): "cuda" presumably requires a CUDA-capable GPU -- confirm which
# other values the consumer accepts.
device = "cuda"
num_workers = 2
batch_size = 128

# Results: directory and file name of the CSV file holding the predictions.
output_dir = "</dir/where/the/resulting/CSV-file/should/be/stored>"
annot_csv_filename = "<name of the CSV-file with the predictions>.csv"
# Network options for the "TweetyNetModel" entry listed in `models` under
# [PREDICT].
[TweetyNetModel.network]
# NOTE(review): presumably has to equal the hidden size the checkpoint was
# trained with -- confirm.
hidden_size = 512