Skip to content

Commit

Permalink
bugfix: fix JIT compilation of prefill kernels (#536)
Browse files Browse the repository at this point in the history
Some bugs were introduced in #534; this PR fixes these issues.
  • Loading branch information
yzh119 authored Oct 18, 2024
1 parent de25d76 commit 425040c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
3 changes: 2 additions & 1 deletion python/flashinfer/jit/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,8 +269,9 @@ def gen_batch_prefill_cu(*args) -> Tuple[str, pathlib.Path]:
os.makedirs(gen_directory)
uri = get_batch_prefill_uri(*args)
file_name = f"{uri}.cu"
path = gen_directory / file_name,
path = gen_directory / file_name
write_if_different(
path,
get_batch_prefill_cu_str(*args),
)
return uri, path
Expand Down
4 changes: 2 additions & 2 deletions python/flashinfer/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def compile_single_prefill_module(
*args,
verbose: bool = False,
):
uri, path = get_single_prefill_uri(*args)
uri, path = gen_single_prefill_cu(*args)
return load_cuda_ops(
uri, [path],
verbose=verbose,
Expand All @@ -66,7 +66,7 @@ def compile_batch_prefill_module(
*args,
verbose: bool = False,
):
uri, path = get_batch_prefill_uri(*args)
uri, path = gen_batch_prefill_cu(*args)
return load_cuda_ops(
uri, [path],
verbose=verbose,
Expand Down

0 comments on commit 425040c

Please sign in to comment.