-
Notifications
You must be signed in to change notification settings - Fork 131
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #398 from knshnb/chainer-graph
Chainer graph
- Loading branch information
Showing
35 changed files
with
1,372 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
73 changes: 73 additions & 0 deletions
73
chainer_chemistry/dataset/graph_dataset/base_graph_data.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import numpy | ||
import chainer | ||
|
||
|
||
class BaseGraphData(object): | ||
"""Base class of graph data """ | ||
|
||
def __init__(self, *args, **kwargs): | ||
for k, v in kwargs.items(): | ||
setattr(self, k, v) | ||
|
||
def to_device(self, device): | ||
"""Send self to `device` | ||
Args: | ||
device (chainer.backend.Device): device | ||
Returns: | ||
self sent to `device` | ||
""" | ||
for k, v in self.__dict__.items(): | ||
if isinstance(v, (numpy.ndarray)): | ||
setattr(self, k, device.send(v)) | ||
elif isinstance(v, (chainer.utils.CooMatrix)): | ||
data = device.send(v.data.array) | ||
row = device.send(v.row) | ||
col = device.send(v.col) | ||
device_coo_matrix = chainer.utils.CooMatrix( | ||
data, row, col, v.shape, order=v.order) | ||
setattr(self, k, device_coo_matrix) | ||
return self | ||
|
||
|
||
class PaddingGraphData(BaseGraphData): | ||
"""Graph data class for padding pattern | ||
Args: | ||
x (numpy.ndarray): input node feature | ||
adj (numpy.ndarray): adjacency matrix | ||
y (int or numpy.ndarray): graph or node label | ||
""" | ||
|
||
def __init__(self, x=None, adj=None, super_node=None, pos=None, y=None, | ||
**kwargs): | ||
self.x = x | ||
self.adj = adj | ||
self.super_node = super_node | ||
self.pos = pos | ||
self.y = y | ||
self.n_nodes = x.shape[0] | ||
super(PaddingGraphData, self).__init__(**kwargs) | ||
|
||
|
||
class SparseGraphData(BaseGraphData): | ||
"""Graph data class for sparse pattern | ||
Args: | ||
x (numpy.ndarray): input node feature | ||
edge_index (numpy.ndarray): sources and destinations of edges | ||
edge_attr (numpy.ndarray): attribution of edges | ||
y (int or numpy.ndarray): graph or node label | ||
""" | ||
|
||
def __init__(self, x=None, edge_index=None, edge_attr=None, | ||
pos=None, super_node=None, y=None, **kwargs): | ||
self.x = x | ||
self.edge_index = edge_index | ||
self.edge_attr = edge_attr | ||
self.pos = pos | ||
self.super_node = super_node | ||
self.y = y | ||
self.n_nodes = x.shape[0] | ||
super(SparseGraphData, self).__init__(**kwargs) |
134 changes: 134 additions & 0 deletions
134
chainer_chemistry/dataset/graph_dataset/base_graph_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
import chainer | ||
import numpy | ||
from chainer._backend import Device | ||
from chainer_chemistry.dataset.graph_dataset.base_graph_data import \ | ||
BaseGraphData | ||
from chainer_chemistry.dataset.graph_dataset.feature_converters \ | ||
import batch_with_padding, batch_without_padding, concat, shift_concat, \ | ||
concat_with_padding, shift_concat_with_padding | ||
|
||
|
||
class BaseGraphDataset(object): | ||
"""Base class of graph dataset (list of graph data)""" | ||
_pattern = '' | ||
_feature_entries = [] | ||
_feature_batch_method = [] | ||
|
||
def __init__(self, data_list, *args, **kwargs): | ||
self.data_list = data_list | ||
|
||
def register_feature(self, key, batch_method, skip_if_none=True): | ||
"""Register feature with batch method | ||
Args: | ||
key (str): name of the feature | ||
batch_method (function): batch method | ||
skip_if_none (bool, optional): If true, skip if `batch_method` is | ||
None. Defaults to True. | ||
""" | ||
if skip_if_none and getattr(self.data_list[0], key, None) is None: | ||
return | ||
self._feature_entries.append(key) | ||
self._feature_batch_method.append(batch_method) | ||
|
||
def update_feature(self, key, batch_method): | ||
"""Update batch method of the feature | ||
Args: | ||
key (str): name of the feature | ||
batch_method (function): batch method | ||
""" | ||
|
||
index = self._feature_entries.index(key) | ||
self._feature_batch_method[index] = batch_method | ||
|
||
def __len__(self): | ||
return len(self.data_list) | ||
|
||
def __getitem__(self, item): | ||
return self.data_list[item] | ||
|
||
def converter(self, batch, device=None): | ||
"""Converter | ||
Args: | ||
batch (list[BaseGraphData]): list of graph data | ||
device (int, optional): specifier of device. Defaults to None. | ||
Returns: | ||
self sent to `device` | ||
""" | ||
if not isinstance(device, Device): | ||
device = chainer.get_device(device) | ||
batch = [method(name, batch, device=device) for name, method in | ||
zip(self._feature_entries, self._feature_batch_method)] | ||
data = BaseGraphData( | ||
**{key: value for key, value in zip(self._feature_entries, batch)}) | ||
return data | ||
|
||
|
||
class PaddingGraphDataset(BaseGraphDataset): | ||
"""Graph dataset class for padding pattern""" | ||
_pattern = 'padding' | ||
|
||
def __init__(self, data_list): | ||
super(PaddingGraphDataset, self).__init__(data_list) | ||
self.register_feature('x', batch_with_padding) | ||
self.register_feature('adj', batch_with_padding) | ||
self.register_feature('super_node', batch_with_padding) | ||
self.register_feature('pos', batch_with_padding) | ||
self.register_feature('y', batch_without_padding) | ||
self.register_feature('n_nodes', batch_without_padding) | ||
|
||
|
||
class SparseGraphDataset(BaseGraphDataset): | ||
"""Graph dataset class for sparse pattern""" | ||
_pattern = 'sparse' | ||
|
||
def __init__(self, data_list): | ||
super(SparseGraphDataset, self).__init__(data_list) | ||
self.register_feature('x', concat) | ||
self.register_feature('edge_index', shift_concat) | ||
self.register_feature('edge_attr', concat) | ||
self.register_feature('super_node', concat) | ||
self.register_feature('pos', concat) | ||
self.register_feature('y', batch_without_padding) | ||
self.register_feature('n_nodes', batch_without_padding) | ||
|
||
def converter(self, batch, device=None): | ||
"""Converter | ||
add `self.batch`, which represents the index of the graph each node | ||
belongs to. | ||
Args: | ||
batch (list[BaseGraphData]): list of graph data | ||
device (int, optional): specifier of device. Defaults to None. | ||
Returns: | ||
self sent to `device` | ||
""" | ||
data = super(SparseGraphDataset, self).converter(batch, device=device) | ||
if not isinstance(device, Device): | ||
device = chainer.get_device(device) | ||
data.batch = numpy.concatenate([ | ||
numpy.full((data.x.shape[0]), i, dtype=numpy.int) | ||
for i, data in enumerate(batch) | ||
]) | ||
data.batch = device.send(data.batch) | ||
return data | ||
|
||
# for experiment | ||
# use converter for the normal use | ||
def converter_with_padding(self, batch, device=None): | ||
self.update_feature('x', concat_with_padding) | ||
self.update_feature('edge_index', shift_concat_with_padding) | ||
data = super(SparseGraphDataset, self).converter(batch, device=device) | ||
if not isinstance(device, Device): | ||
device = chainer.get_device(device) | ||
max_n_nodes = max([data.x.shape[0] for data in batch]) | ||
data.batch = numpy.concatenate([ | ||
numpy.full((max_n_nodes), i, dtype=numpy.int) | ||
for i, data in enumerate(batch) | ||
]) | ||
data.batch = device.send(data.batch) | ||
return data |
115 changes: 115 additions & 0 deletions
115
chainer_chemistry/dataset/graph_dataset/feature_converters.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import numpy | ||
from chainer.dataset.convert import _concat_arrays | ||
|
||
|
||
def batch_with_padding(name, batch, device=None, pad=0): | ||
"""Batch with padding (increase ndim by 1) | ||
Args: | ||
name (str): propaty name of graph data | ||
batch (list[BaseGraphData]): list of base graph data | ||
device (chainer.backend.Device, optional): device. Defaults to None. | ||
pad (int, optional): padding value. Defaults to 0. | ||
Returns: | ||
BaseGraphDataset: graph dataset sent to `device` | ||
""" | ||
feat = _concat_arrays( | ||
[getattr(example, name) for example in batch], pad) | ||
return device.send(feat) | ||
|
||
|
||
def batch_without_padding(name, batch, device=None): | ||
"""Batch without padding (increase ndim by 1) | ||
Args: | ||
name (str): propaty name of graph data | ||
batch (list[BaseGraphData]): list of base graph data | ||
device (chainer.backend.Device, optional): device. Defaults to None. | ||
Returns: | ||
BaseGraphDataset: graph dataset sent to `device` | ||
""" | ||
feat = _concat_arrays( | ||
[getattr(example, name) for example in batch], None) | ||
return device.send(feat) | ||
|
||
|
||
def concat_with_padding(name, batch, device=None, pad=0): | ||
"""Concat without padding (ndim does not increase) | ||
Args: | ||
name (str): propaty name of graph data | ||
batch (list[BaseGraphData]): list of base graph data | ||
device (chainer.backend.Device, optional): device. Defaults to None. | ||
pad (int, optional): padding value. Defaults to 0. | ||
Returns: | ||
BaseGraphDataset: graph dataset sent to `device` | ||
""" | ||
feat = batch_with_padding(name, batch, device=device, pad=pad) | ||
a, b = feat.shape | ||
return feat.reshape((a * b)) | ||
|
||
|
||
def concat(name, batch, device=None, axis=0): | ||
"""Concat with padding (ndim does not increase) | ||
Args: | ||
name (str): propaty name of graph data | ||
batch (list[BaseGraphData]): list of base graph data | ||
device (chainer.backend.Device, optional): device. Defaults to None. | ||
pad (int, optional): padding value. Defaults to 0. | ||
Returns: | ||
BaseGraphDataset: graph dataset sent to `device` | ||
""" | ||
feat = numpy.concatenate([getattr(data, name) for data in batch], | ||
axis=axis) | ||
return device.send(feat) | ||
|
||
|
||
def shift_concat(name, batch, device=None, shift_attr='x', shift_axis=1): | ||
"""Concat with index shift (ndim does not increase) | ||
Concatenate graphs into a big one. | ||
Used for sparse pattern batching. | ||
Args: | ||
name (str): propaty name of graph data | ||
batch (list[BaseGraphData]): list of base graph data | ||
device (chainer.backend.Device, optional): device. Defaults to None. | ||
Returns: | ||
BaseGraphDataset: graph dataset sent to `device` | ||
""" | ||
shift_index_array = numpy.cumsum( | ||
numpy.array([0] + [getattr(data, shift_attr).shape[0] | ||
for data in batch])) | ||
feat = numpy.concatenate([ | ||
getattr(data, name) + shift_index_array[i] | ||
for i, data in enumerate(batch)], axis=shift_axis) | ||
return device.send(feat) | ||
|
||
|
||
def shift_concat_with_padding(name, batch, device=None, shift_attr='x', | ||
shift_axis=1): | ||
"""Concat with index shift and padding (ndim does not increase) | ||
Concatenate graphs into a big one. | ||
Used for sparse pattern batching. | ||
Args: | ||
name (str): propaty name of graph data | ||
batch (list[BaseGraphData]): list of base graph data | ||
device (chainer.backend.Device, optional): device. Defaults to None. | ||
Returns: | ||
BaseGraphDataset: graph dataset sent to `device` | ||
""" | ||
max_n_nodes = max([data.x.shape[0] for data in batch]) | ||
shift_index_array = numpy.arange(0, len(batch) * max_n_nodes, max_n_nodes) | ||
feat = numpy.concatenate([ | ||
getattr(data, name) + shift_index_array[i] | ||
for i, data in enumerate(batch)], axis=shift_axis) | ||
return device.send(feat) |
Oops, something went wrong.