diff --git a/code2seq/preprocessing/astminer_to_code2seq.py b/code2seq/preprocessing/astminer_to_code2seq.py index 790440e..4d478b1 100644 --- a/code2seq/preprocessing/astminer_to_code2seq.py +++ b/code2seq/preprocessing/astminer_to_code2seq.py @@ -10,7 +10,10 @@ def _get_id2value_from_csv(path_: str) -> Dict[str, str]: - return dict(numpy.genfromtxt(path_, delimiter=",", dtype=(str, str))[1:]) + with open(path_, "r") as f: + lines = f.read().strip().split("\n")[1:] + parsed_lines = [line.split(",", maxsplit=1) for line in lines] + return {k: v for k, v in parsed_lines} def preprocess_csv(data_folder: str, dataset_name: str, holdout_name: str, is_shuffled: bool): diff --git a/scripts/download_data.sh b/scripts/download_data.sh index 699a435..560886d 100755 --- a/scripts/download_data.sh +++ b/scripts/download_data.sh @@ -9,7 +9,7 @@ LOAD_SPLITTED=false DATA_DIR=./data POJ_DOWNLOAD_SCRIPT=./scripts/download_poj.sh CODEFORCES_DOWNLOAD_SCRIPT=./scripts/download_codeforces.sh -ASTMINER_PATH=../astminer/build/shadow/lib-0.*.jar +ASTMINER_PATH=../astminer/build/shadow/astminer.jar SPLIT_SCRIPT=./scripts/split_dataset.sh function is_int(){