Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account.

Handle large prompts by splitting them into smaller batches so they d… #139

Merged
merged 1 commit into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
public abstract class AbstractModel implements Generator {
private static final Logger logger = LoggerFactory.getLogger(AbstractModel.class);

private static final Integer MAX_BATCH_SIZE = Integer.getInteger("jlama.max_batch_size", 256);

public enum InferenceType {
// Used for distributed inference
INPUT_TO_EMBEDDING(true, false, false, false, false),
Expand Down Expand Up @@ -285,11 +287,12 @@ public AbstractTensor batchForward(
) {
AbstractTensor embedding = null;

//Batch prompt into groups of 1024
for (int i = 0; i < token_ids.length; i += 1024) {
int[] batch = Arrays.copyOfRange(token_ids, i, Math.min(token_ids.length, i + 1024));
//Batch prompt into groups of MAX_BATCH_SIZE
for (int i = 0; i < token_ids.length; i += MAX_BATCH_SIZE) {
int[] batch = Arrays.copyOfRange(token_ids, i, Math.min(token_ids.length, i + MAX_BATCH_SIZE));
embedding = embedInput.batchInputsToEmbeddings(batch, startPos + i);
embedding = forward(embedding, startPos + i, kvbuf, tensorReducer);
logger.debug("Batched forward pass for tokens {} to {}", i, i + batch.length);
}

return embedding;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,28 @@ public void GPT2Run() throws IOException {

@Test
public void Qwen2Run() throws IOException {
    // Skip the test when the quantized model is not present on this machine.
    String modelPrefix = "../models/Qwen_Qwen2.5-0.5B-Instruct-JQ4";
    Assume.assumeTrue(Files.exists(Paths.get(modelPrefix)));

    AbstractModel qwen2 = ModelSupport.loadModel(new File(modelPrefix), DType.F32, DType.I8);

    // Register a large number of tools so the rendered prompt grows well past a
    // single forward batch, exercising the batched prefill path added in this PR.
    int ntools = 200;
    Tool[] tools = new Tool[ntools];
    for (int i = 0; i < ntools; i++) {
        Tool tool = Tool.from(Function.builder()
            .description("some tool " + i)
            .addParameter("input", "string", "an input", true)
            .name("some-tool-" + i)
            .build());
        tools[i] = tool;
    }

    PromptContext prompt = qwen2.promptSupport().get().builder()
        .addSystemMessage("You are a helpful chatbot who writes short responses.")
        .addUserMessage("What is the capital of France?")
        .build(tools);

    // 32 * 1024 max tokens: large enough that the oversized prompt must be split
    // into multiple batches rather than processed in one pass.
    Generator.Response r = qwen2.generate(UUID.randomUUID(), prompt, 0.9f, 32 * 1024, makeOutHandler());
    logger.info("Response: {}", r);
}

Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
<!-- Build property abstractions: versions, etc -->
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<revision>0.8.3</revision>
<revision>0.8.4</revision>

<slf4j-api.version>2.0.7</slf4j-api.version>
<logback.version>1.5.6</logback.version>
Expand Down
Loading