Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pre-commit.ci] pre-commit autoupdate #1

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,30 +1,30 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.5.0
hooks:
- id: trailing-whitespace
args: [--markdown-linebreak-ext=md]
- id: end-of-file-fixer

- repo: https://github.com/asottile/pyupgrade
rev: v2.37.3
rev: v3.15.1
hooks:
- id: pyupgrade
args: [--py39-plus]

- repo: https://github.com/pycqa/isort
rev: 5.10.1
rev: 5.13.2
hooks:
- id: isort
args: [--profile=black]

- repo: https://github.com/psf/black
rev: 22.8.0
rev: 24.2.0
hooks:
- id: black

- repo: https://github.com/PyCQA/flake8
rev: 5.0.4
rev: 7.0.0
hooks:
- id: flake8
additional_dependencies:
Expand Down
25 changes: 19 additions & 6 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ def main(
show_default=False,
),
dataset_name: str = Argument(
..., help="사용할 데이터셋의 huggingface 이름", rich_help_panel="데이터", show_default=False
...,
help="사용할 데이터셋의 huggingface 이름",
rich_help_panel="데이터",
show_default=False,
),
max_seq_length: Optional[int] = Option(
None, help=max_seq_length_desc, rich_help_panel="모델"
Expand Down Expand Up @@ -102,18 +105,28 @@ def main(
dataset_split: str = Option(
"train", help="데이터셋에서 사용할 split", rich_help_panel="데이터"
),
text_col: str = Option("text", help="데이터셋에서 text를 담은 열의 이름", rich_help_panel="데이터"),
text_col: str = Option(
"text", help="데이터셋에서 text를 담은 열의 이름", rich_help_panel="데이터"
),
use_auth_token: Optional[bool] = Option(
None, help="huggingface auth token", rich_help_panel="데이터"
),
num_workers: int = Option(
8, help="데이터 로더에서 사용할 프로세스 수, windows면 0으로 고정됨", rich_help_panel="훈련"
8,
help="데이터 로더에서 사용할 프로세스 수, windows면 0으로 고정됨",
rich_help_panel="훈련",
),
fast_dev_run: bool = Option(
False, help="훈련 테스트를 실행합니다.", rich_help_panel="훈련"
),
output_path: Optional[str] = Option(
None, help="모델을 저장할 경로", rich_help_panel="훈련"
),
fast_dev_run: bool = Option(False, help="훈련 테스트를 실행합니다.", rich_help_panel="훈련"),
output_path: Optional[str] = Option(None, help="모델을 저장할 경로", rich_help_panel="훈련"),
save_steps: int = Option(10_000, help="모델을 저장할 주기", rich_help_panel="훈련"),
wandb_name: Optional[str] = Option(None, help="wandb 이름", rich_help_panel="훈련"),
log_every_n_steps: int = Option(200, help="몇 스텝마다 로그를 남길지", rich_help_panel="훈련"),
log_every_n_steps: int = Option(
200, help="몇 스텝마다 로그를 남길지", rich_help_panel="훈련"
),
seed: int = Option(42, help="랜덤 시드", rich_help_panel="훈련"),
):
# 모델
Expand Down
6 changes: 3 additions & 3 deletions tsdae/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ def delete(self, text: str, /) -> str:

keep_or_not = np.random.rand(n) > self.p
if sum(keep_or_not) == 0:
keep_or_not[
np.random.choice(n)
] = True # guarantee that at least one word remains
keep_or_not[np.random.choice(n)] = (
True # guarantee that at least one word remains
)
tokens_processed = [tokens[i] for i in range(n) if keep_or_not[i]]
words_processed = self.kiwi.join(tokens_processed, lm_search=False)
return words_processed