-
Notifications
You must be signed in to change notification settings - Fork 98
/
dataloader.py
52 lines (42 loc) · 2.04 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
from __future__ import division
import numpy as np
import torch
import os
import logging
from torch.utils.data import DataLoader, Dataset, Sampler
# Module-level logger shared by the dataset and sampler classes in this file;
# they report array sizes and source paths through it at construction time.
logger = logging.getLogger('DeepAR.Data')
class TrainDataset(Dataset):
    """Training dataset backed by two ``.npy`` arrays loaded fully into memory.

    The arrays are read from ``train_data_<data_name>.npy`` and
    ``train_label_<data_name>.npy`` inside ``data_path``. The indexing used
    in ``__getitem__`` expects ``data`` to be 3-D; the first axis enumerates
    the samples.
    """

    def __init__(self, data_path, data_name, num_class):
        """Load the training arrays and log their size.

        Args:
            data_path: Directory containing the ``.npy`` files.
            data_name: Name fragment embedded in the file names.
            num_class: Accepted for interface compatibility; not used here.
        """
        data_file = os.path.join(data_path, f'train_data_{data_name}.npy')
        label_file = os.path.join(data_path, f'train_label_{data_name}.npy')

        self.data = np.load(data_file)
        self.label = np.load(label_file)

        sample_count = self.data.shape[0]
        self.train_len = sample_count
        logger.info(f'train_len: {sample_count}')
        logger.info(f'building datasets from {data_path}...')

    def __len__(self):
        """Return the number of samples (size of the first axis of ``data``)."""
        return self.train_len

    def __getitem__(self, index):
        """Return ``(features, code, label)`` for one sample.

        ``features`` is the sample with its last column removed, ``code`` is
        that last column's value at position 0 of the second axis converted
        to ``int`` (presumably a series/category identifier -- confirm with
        the data preparation code), and ``label`` is ``self.label[index]``.
        """
        features = self.data[index, :, :-1]
        code = int(self.data[index, 0, -1])
        target = self.label[index]
        return (features, code, target)
class TestDataset(Dataset):
    """Evaluation dataset backed by three ``.npy`` arrays loaded into memory.

    The arrays are read from ``test_data_<data_name>.npy``,
    ``test_v_<data_name>.npy`` and ``test_label_<data_name>.npy`` inside
    ``data_path``. The indexing used in ``__getitem__`` expects ``data`` to
    be 3-D; the first axis enumerates the samples.
    """

    def __init__(self, data_path, data_name, num_class):
        """Load the test arrays and log their size.

        Args:
            data_path: Directory containing the ``.npy`` files.
            data_name: Name fragment embedded in the file names.
            num_class: Accepted for interface compatibility; not used here.
        """
        data_file = os.path.join(data_path, f'test_data_{data_name}.npy')
        v_file = os.path.join(data_path, f'test_v_{data_name}.npy')
        label_file = os.path.join(data_path, f'test_label_{data_name}.npy')

        self.data = np.load(data_file)
        self.v = np.load(v_file)
        self.label = np.load(label_file)

        sample_count = self.data.shape[0]
        self.test_len = sample_count
        logger.info(f'test_len: {sample_count}')
        logger.info(f'building datasets from {data_path}...')

    def __len__(self):
        """Return the number of samples (size of the first axis of ``data``)."""
        return self.test_len

    def __getitem__(self, index):
        """Return ``(features, code, v, label)`` for one sample.

        ``features`` is the sample with its last column removed, ``code`` is
        that last column's value at position 0 of the second axis converted
        to ``int`` (presumably a series/category identifier -- confirm with
        the data preparation code), ``v`` is ``self.v[index]`` and ``label``
        is ``self.label[index]``.
        """
        features = self.data[index, :, :-1]
        code = int(self.data[index, 0, -1])
        v_row = self.v[index]
        target = self.label[index]
        return (features, code, v_row, target)
class WeightedSampler(Sampler):
    """Sampler that draws sample indices with probability proportional to ``|v[:, 0]|``.

    ``v`` is loaded from ``train_v_<data_name>.npy`` inside ``data_path``.
    Each pass over the sampler yields as many indices as ``v`` has rows.
    """

    def __init__(self, data_path, data_name, replacement=True):
        """Load ``v`` and precompute the normalised sampling weights.

        Args:
            data_path: Directory containing ``train_v_<data_name>.npy``.
            data_name: Name fragment embedded in the file name.
            replacement: Passed to ``torch.multinomial``; whether an index
                may be drawn more than once per pass.

        Raises:
            ValueError: If the first column of ``v`` is empty, sums to zero
                in absolute value, or contains non-finite entries.
        """
        v = np.load(os.path.join(data_path, f'train_v_{data_name}.npy'))
        self.weights = self._normalized_weights(v)
        logger.info(f'weights: {self.weights}')
        self.num_samples = self.weights.shape[0]
        logger.info(f'num samples: {self.num_samples}')
        self.replacement = replacement

    @staticmethod
    def _normalized_weights(v):
        """Return ``|v[:, 0]|`` scaled to sum to one, as a double tensor."""
        magnitudes = np.abs(v[:, 0])
        total = np.sum(magnitudes)
        # Dividing by a zero or non-finite total would yield NaN weights that
        # only fail later, inside torch.multinomial, with an unrelated-looking
        # error. Reject them here where the cause is still obvious.
        if not (np.isfinite(total) and total > 0):
            raise ValueError(
                f'cannot build sampling weights: sum of |v[:, 0]| is {total}; '
                'expected a positive finite value'
            )
        return torch.as_tensor(magnitudes / total, dtype=torch.double)

    def __iter__(self):
        """Yield ``num_samples`` indices drawn according to ``weights``."""
        drawn = torch.multinomial(self.weights, self.num_samples, self.replacement)
        return iter(drawn.tolist())

    def __len__(self):
        """Return the number of indices produced per pass."""
        return self.num_samples