Commit

start mlp on gpu

ChenS676 committed Jul 5, 2024
1 parent 87f0340 commit 15c8e3c
Showing 6 changed files with 46 additions and 104 deletions.
Binary file removed core/model_finetuning/.create_dataset.py.swp
31 changes: 3 additions & 28 deletions core/model_finetuning/mlp.py
@@ -52,28 +52,6 @@ def __len__(self):
    def __getitem__(self, idx):
        return self.embeddings[idx], self.labels[idx]

class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.2)
        self.bn1 = nn.BatchNorm1d(hidden_size)
        self.bn2 = nn.BatchNorm1d(hidden_size)

    def forward(self, x):
        out = self.fc1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.fc2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.fc3(out)
        return out

def init_weights(m):
    if isinstance(m, nn.Linear):
@@ -86,7 +64,7 @@ def parse_args() -> argparse.Namespace:
    r"""Parses the command line arguments."""
    parser = argparse.ArgumentParser(description='GraphGym')
    parser.add_argument('--data', dest='data', type=str, required=True,
                        default='pubmed',
                        default='cora',
                        help='data name')
    parser.add_argument('--device', dest='device', required=False,
                        help='device id')
@@ -101,9 +79,7 @@ def parse_args() -> argparse.Namespace:
                        help='word embedding method')
    parser.add_argument('--score', dest='score', type=str, required=False, default='mlp_score',
                        help='decoder name')
    parser.add_argument('--decoder', dest='decoder', type=str, required=False, default='MLP',
                        help='decoder name')
    parser.add_argument('--repeat', type=int, default=3,
    parser.add_argument('--repeat', type=int, default=5,
                        help='The number of repeated jobs.')
    parser.add_argument('opts', default=None, nargs=argparse.REMAINDER,
                        help='See graphgym/config.py for remaining options.')
@@ -132,7 +108,6 @@ def project_main():
    custom_set_out_dir(cfg, args.cfg_file, cfg.wandb.name_tag)
    # torch.set_num_threads(20)
    loggers = create_logger(args.repeat)

    for run_id, seed, split_index in zip(*run_loop_settings(cfg, args)):
        print(f'run id : {run_id}')
        # Set configurations for each run TODO clean code here
@@ -152,7 +127,7 @@ def project_main():
            clf = RidgeClassifier(tol=1e-2, max_iter=10000, solver="sparse_cg")
            clf.fit(train_dataset, train_labels)
        elif args.decoder == 'MLP':
            clf = MLPClassifier(random_state=run_id, max_iter=10000).fit(train_dataset, train_labels)
            clf = MLPClassifier(random_state=run_id, max_iter=100).fit(train_dataset, train_labels)

        test_pred = clf.predict(test_dataset)
        test_acc = sum(np.asarray(test_labels) == test_pred) / len(test_labels)
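For context, the decoder path in mlp.py after this commit boils down to two scikit-learn classifiers. A minimal self-contained sketch follows; the random embeddings and labels are stand-ins, only the RidgeClassifier/MLPClassifier calls mirror the diff, and note that the lowered max_iter=100 can stop the MLP before convergence (sklearn will emit a ConvergenceWarning):

import numpy as np
from sklearn.linear_model import RidgeClassifier
from sklearn.neural_network import MLPClassifier

rng = np.random.default_rng(0)
train_dataset = rng.normal(size=(200, 64))    # stand-in node embeddings
train_labels = rng.integers(0, 7, size=200)   # stand-in class labels
test_dataset = rng.normal(size=(50, 64))
test_labels = rng.integers(0, 7, size=50)

decoder = 'MLP'  # or 'Ridge'
if decoder == 'Ridge':
    clf = RidgeClassifier(tol=1e-2, max_iter=10000, solver="sparse_cg")
    clf.fit(train_dataset, train_labels)
elif decoder == 'MLP':
    # max_iter=100 (down from 10000 in this commit) may stop early
    clf = MLPClassifier(random_state=0, max_iter=100).fit(train_dataset, train_labels)

test_pred = clf.predict(test_dataset)
test_acc = (np.asarray(test_labels) == test_pred).mean()
print(f'test acc: {test_acc:.3f}')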
37 changes: 28 additions & 9 deletions core/model_finetuning/mlp_tfidf.py
@@ -59,7 +59,7 @@ def parse_args() -> argparse.Namespace:
                        help='word embedding method')
    parser.add_argument('--score', dest='score', type=str, required=False, default='mlp_score',
                        help='decoder name')
    parser.add_argument('--repeat', type=int, default=3,
    parser.add_argument('--repeat', type=int, default=5,
                        help='The number of repeated jobs.')
    parser.add_argument('opts', default=None, nargs=argparse.REMAINDER,
                        help='See graphgym/config.py for remaining options.')
@@ -77,7 +77,23 @@ def __len__(self):
    def __getitem__(self, idx):
        return self.embeddings[idx], self.labels[idx]




class MLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(MLP, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.model(x)


def project_main(): # sourcery skip: avoid-builtin-shadow, hoist-statement-from-loop

    # process params
@@ -130,13 +146,16 @@ def project_main(): # sourcery skip: avoid-builtin-shadow, hoist-statement-from-loop
    )

    dump_cfg(cfg)
    in_channels = train_dataset.shape[1]

    model = mlp_model(in_channels,
                      cfg.decoder.hidden_channels,
                      cfg.decoder.out_channels,
                      cfg.decoder.num_layers,
                      cfg.decoder.dropout).to(cfg.device)
    # model = mlp_model(in_channels,
    #                   cfg.decoder.hidden_channels,
    #                   cfg.decoder.out_channels,
    #                   cfg.decoder.num_layers,
    #                   cfg.decoder.dropout).to(cfg.device)
    input_dim = train_dataset.shape[1]
    hidden_dim = 128
    output_dim = len(np.unique(train_labels))
    model = MLP(input_dim, hidden_dim, output_dim).to(cfg.device)

    model = model.to(cfg.device)

    print_logger.info(f"{model} on {next(model.parameters()).device}")
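Since the point of the commit is to start running the MLP decoder on GPU, here is a minimal sketch of how the new MLP class above could be trained on a CUDA device. The stand-in data, the dimensions, batch size, learning rate, and epoch count are illustrative assumptions and not taken from the diff; only hidden_dim=128 matches the code above:

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

X = torch.randn(1024, 300)        # stand-in TF-IDF features
y = torch.randint(0, 7, (1024,))  # stand-in class labels
loader = DataLoader(TensorDataset(X, y), batch_size=256, shuffle=True)

model = MLP(input_dim=300, hidden_dim=128, output_dim=7).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for epoch in range(10):
    for xb, yb in loader:
        # each batch follows the model to the GPU
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()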
54 changes: 0 additions & 54 deletions core/model_finetuning/results/gae-gae-cora-origin/gae.yaml

This file was deleted.

2 changes: 1 addition & 1 deletion core/model_finetuning/scripts/create_data.sh
@@ -1,7 +1,7 @@
#!/bin/sh
#SBATCH --time=3-00:00:00
#SBATCH --partition=cpuonly
#SBATCH --job-name=gnn_wb
#SBATCH --job-name=tfidf



(sixth changed file: a run script; file name not captured)
@@ -1,15 +1,12 @@
#!/bin/sh
#SBATCH --time=02:00:00
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --partition=accelerated
#SBATCH --job-name=cross_encoder
#SBATCH --mem=50160mb
#BATCH --cpu-per-gpu=38
#SBATCH --time=3-00:00:00
#SBATCH --partition=cpuonly
#SBATCH --job-name=gnn_wb



#SBATCH --output=log/TAG_Benchmark_%j.output
#SBATCH --error=error/TAG_Benchmark_%j.error
#SBATCH --gres=gpu:4
#SBATCH --account=hk-project-test-p0022257 # specify the project group


#SBATCH --chdir=/hkfs/work/workspace/scratch/cc7738-benchmark_tag/TAPE_chen/batch
@@ -18,10 +15,10 @@
#SBATCH --mail-type=ALL
#SBATCH --mail-user=cc7738@kit.edu

source /hkfs/home/project/hk-project-test-p0021478/cc7738/anaconda3/etc/profile.d/conda.sh
source /hkfs/home/haicore/aifb/cc7738/anaconda3/etc/profile.d/conda.sh

conda activate base
conda activate TAG_LP
conda activate ss
# <<< conda initialize <<<
module purge
module load devel/cmake/3.18
@@ -32,4 +29,9 @@ module load compiler/gnu/12
cd /hkfs/work/workspace/scratch/cc7738-benchmark_tag/TAPE_chen/core/model_finetuning


python cross_encoder.py
for data in cora arxiv_2023 pubmed ogbn-arxiv ogbn-products; do
    python mlp.py --data $data --decoder MLP
done
for data in cora arxiv_2023 pubmed ogbn-arxiv ogbn-products; do
    python mlp.py --data $data --decoder Ridge
done
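Assuming the job actually lands on a GPU node (the header requests --gres=gpu:4), a quick sanity check run inside the job would confirm the allocation is visible to PyTorch before training starts:

import torch

print(torch.cuda.is_available())   # True once the job has a GPU
print(torch.cuda.device_count())   # expected 4 under --gres=gpu:4
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # first visible device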
