-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdataloader.py
47 lines (37 loc) · 1.6 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, datasets, models
import warnings
class PMnet_data_usc(Dataset):
    """Map-style dataset pairing (city map, Tx map) inputs with a power map.

    Samples are discovered by listing ``<dir_dataset>/data/cropped/power_map/``.
    For every file name found there, the file with the same name is also read
    from ``data/cropped/city_map/`` and ``data/cropped/tx_map/``.

    Args:
        dir_dataset: Root directory containing the ``data/cropped/`` tree.
            The default ``""`` resolves that tree relative to the current
            working directory. A trailing separator is optional.
        transform: Callable applied (separately) to the stacked inputs and to
            the power map. Its results are cast to ``torch.float32``. If it is
            falsy (e.g. ``None``), no transform is applied and the raw numpy
            arrays are returned instead.
    """

    def __init__(self,
                 dir_dataset="",
                 transform=transforms.ToTensor()):
        self.dir_dataset = dir_dataset
        self.transform = transform
        # Resolve the sub-directories once rather than on every __getitem__.
        # os.path.join works whether or not dir_dataset ends with a separator;
        # plain string "+" produced a broken path when it did not.
        self.dir_buildings = os.path.join(dir_dataset, "data/cropped/city_map/")
        self.dir_Tx = os.path.join(dir_dataset, "data/cropped/tx_map/")
        self.dir_power = os.path.join(dir_dataset, "data/cropped/power_map/")
        # os.listdir returns entries in arbitrary order; sort so that the
        # index -> sample mapping is reproducible across runs and machines.
        self.png_list = sorted(os.listdir(self.dir_power))

    def __len__(self):
        """Return the number of samples (entries in the power_map directory)."""
        return len(self.png_list)

    @staticmethod
    def _read_image(directory, name):
        """Read the image file ``name`` inside ``directory`` as a numpy array."""
        return np.asarray(io.imread(os.path.join(directory, name)))

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``[inputs, power]``. ``inputs`` stacks the city map and the Tx map
            along a new last axis (axis 2). ``power`` is the power map. Both
            are passed through ``self.transform`` and cast to float32 when a
            transform is set. Otherwise both are returned as numpy arrays.
        """
        name = self.png_list[idx]
        image_buildings = self._read_image(self.dir_buildings, name)
        image_Tx = self._read_image(self.dir_Tx, name)
        image_power = self._read_image(self.dir_power, name)

        # Channel-last stack: assumes both maps are 2-D single-channel images
        # of identical size, giving an (H, W, 2) array — TODO confirm.
        inputs = np.stack([image_buildings, image_Tx], axis=2)
        # Bind power unconditionally so a falsy transform no longer leaves it
        # undefined (previously an UnboundLocalError at the return).
        power = image_power
        if self.transform:
            # NOTE(review): the transform is invoked independently on inputs
            # and target, so a *random* augmentation would not be applied
            # consistently to both.
            inputs = self.transform(inputs).type(torch.float32)
            power = self.transform(power).type(torch.float32)
        return [inputs, power]