Skip to content

Commit f6bf576

Browse files
Add support for virtual workflow threads (#2297)
Add support for virtual workflow threads
1 parent 37081cc commit f6bf576

File tree

27 files changed

+902
-72
lines changed

27 files changed

+902
-72
lines changed

.github/workflows/ci.yml

+10-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ jobs:
6060
- name: Set up Java
6161
uses: actions/setup-java@v4
6262
with:
63-
java-version: "11"
63+
java-version: |
64+
21
65+
11
6466
distribution: "temurin"
6567

6668
- name: Set up Gradle
@@ -79,6 +81,13 @@ jobs:
7981
USE_DOCKER_SERVICE: true
8082
run: ./gradlew --no-daemon test -x checkLicenseMain -x checkLicenses -x spotlessCheck -x spotlessApply -x spotlessJava
8183

84+
- name: Run virtual thread tests
85+
env:
86+
USER: unittest
87+
TEMPORAL_SERVICE_ADDRESS: localhost:7233
88+
USE_DOCKER_SERVICE: true
89+
run: ./gradlew --no-daemon :temporal-sdk:virtualThreadTests -x checkLicenseMain -x checkLicenses -x spotlessCheck -x spotlessApply -x spotlessJava
90+
8291
- name: Publish Test Report
8392
uses: mikepenz/action-junit-report@v4
8493
if: success() || failure() # always run even if the previous step fails

temporal-sdk/build.gradle

+105
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,36 @@ dependencies {
3131
testImplementation group: 'ch.qos.logback', name: 'logback-classic', version: "${logbackVersion}"
3232
}
3333

34+
// Temporal SDK supports Java 8 or later so to support virtual threads
35+
// we need to compile the code with Java 21 and package it in a multi-release jar.
36+
sourceSets {
37+
java21 {
38+
java {
39+
srcDirs = ['src/main/java21']
40+
}
41+
}
42+
}
43+
44+
dependencies {
45+
java21Implementation files(sourceSets.main.output.classesDirs) { builtBy compileJava }
46+
}
47+
48+
tasks.named('compileJava21Java') {
49+
javaCompiler = javaToolchains.compilerFor {
50+
languageVersion = JavaLanguageVersion.of(21)
51+
}
52+
options.release = 21
53+
}
54+
55+
jar {
56+
into('META-INF/versions/21') {
57+
from sourceSets.java21.output
58+
}
59+
manifest.attributes(
60+
'Multi-Release': 'true'
61+
)
62+
}
63+
3464
task registerNamespace(type: JavaExec) {
3565
getMainClass().set('io.temporal.internal.docker.RegisterTestNamespace')
3666
classpath = sourceSets.test.runtimeClasspath
@@ -49,4 +79,79 @@ task testResourceIndependent(type: Test) {
4979
includeCategories 'io.temporal.worker.IndependentResourceBasedTests'
5080
maxParallelForks = 1
5181
}
82+
}
83+
84+
// To test the virtual thread support we need to run a separate test suite with Java 21
85+
testing {
86+
suites {
87+
// Common setup for all test suites
88+
configureEach {
89+
useJUnit(junitVersion)
90+
dependencies {
91+
implementation project()
92+
implementation "ch.qos.logback:logback-classic:${logbackVersion}"
93+
implementation project(':temporal-testing')
94+
95+
implementation "junit:junit:${junitVersion}"
96+
implementation "org.mockito:mockito-core:${mockitoVersion}"
97+
implementation 'pl.pragmatists:JUnitParams:1.1.1'
98+
implementation("com.jayway.jsonpath:json-path:$jsonPathVersion"){
99+
exclude group: 'org.slf4j', module: 'slf4j-api'
100+
}
101+
}
102+
targets {
103+
all {
104+
testTask.configure {
105+
testLogging {
106+
events 'passed', 'skipped', 'failed'
107+
exceptionFormat 'full'
108+
// Uncomment the following line if you want to see test logs in gradlew run.
109+
showStandardStreams true
110+
}
111+
}
112+
}
113+
}
114+
}
115+
116+
virtualThreadTests(JvmTestSuite) {
117+
targets {
118+
all {
119+
testTask.configure {
120+
javaLauncher = javaToolchains.launcherFor {
121+
languageVersion = JavaLanguageVersion.of(21)
122+
}
123+
shouldRunAfter(test)
124+
}
125+
}
126+
}
127+
}
128+
129+
// Run the same test as the normal test task with virtual threads
130+
testsWithVirtualThreads(JvmTestSuite) {
131+
// Use the same source and resources as the main test set
132+
sources {
133+
java {
134+
srcDirs = ['src/test/java']
135+
}
136+
resources {
137+
srcDirs = ["src/test/resources"]
138+
}
139+
}
140+
141+
targets {
142+
all {
143+
testTask.configure {
144+
javaLauncher = javaToolchains.launcherFor {
145+
languageVersion = JavaLanguageVersion.of(21)
146+
}
147+
environment("USE_VIRTUAL_THREADS", "false")
148+
}
149+
}
150+
}
151+
}
152+
}
153+
}
154+
155+
tasks.named('check') {
156+
dependsOn(testing.suites.virtualThreadTests)
52157
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Copyright (C) 2022 Temporal Technologies, Inc. All Rights Reserved.
3+
*
4+
* Copyright (C) 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
5+
*
6+
* Modifications copyright (C) 2017 Uber Technologies, Inc.
7+
*
8+
* Licensed under the Apache License, Version 2.0 (the "License");
9+
* you may not use this material except in compliance with the License.
10+
* You may obtain a copy of the License at
11+
*
12+
* http://www.apache.org/licenses/LICENSE-2.0
13+
*
14+
* Unless required by applicable law or agreed to in writing, software
15+
* distributed under the License is distributed on an "AS IS" BASIS,
16+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
* See the License for the specific language governing permissions and
18+
* limitations under the License.
19+
*/
20+
21+
package io.temporal.internal.task;
22+
23+
/**
24+
* Function interface for {@link VirtualThreadDelegate#newVirtualThreadExecutor(ThreadConfigurator)}
25+
* called for every thread created.
26+
*/
27+
@FunctionalInterface
28+
public interface ThreadConfigurator {
29+
/** Invoked for every thread created by {@link VirtualThreadDelegate#newVirtualThreadExecutor}. */
30+
void configure(Thread t);
31+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Copyright (C) 2022 Temporal Technologies, Inc. All Rights Reserved.
3+
*
4+
* Copyright (C) 2012-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
5+
*
6+
* Modifications copyright (C) 2017 Uber Technologies, Inc.
7+
*
8+
* Licensed under the Apache License, Version 2.0 (the "License");
9+
* you may not use this material except in compliance with the License.
10+
* You may obtain a copy of the License at
11+
*
12+
* http://www.apache.org/licenses/LICENSE-2.0
13+
*
14+
* Unless required by applicable law or agreed to in writing, software
15+
* distributed under the License is distributed on an "AS IS" BASIS,
16+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
* See the License for the specific language governing permissions and
18+
* limitations under the License.
19+
*/
20+
21+
package io.temporal.internal.task;
22+
23+
import java.util.concurrent.ExecutorService;
24+
25+
/**
26+
* Internal delegate for virtual thread handling on JDK 21. This is a dummy version for reachability
27+
* on JDK <21.
28+
*/
29+
public final class VirtualThreadDelegate {
30+
public static ExecutorService newVirtualThreadExecutor(ThreadConfigurator configurator) {
31+
throw new UnsupportedOperationException("Virtual threads not supported on JDK <21");
32+
}
33+
34+
private VirtualThreadDelegate() {}
35+
}

temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,8 @@ public boolean start() {
104104
new TaskHandlerImpl(handler),
105105
pollerOptions,
106106
slotSupplier.maximumSlots().orElse(Integer.MAX_VALUE),
107-
true);
107+
true,
108+
options.isUsingVirtualThreads());
108109
poller =
109110
new Poller<>(
110111
options.getIdentity(),

temporal-sdk/src/main/java/io/temporal/internal/worker/LocalActivityWorker.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,8 @@ public boolean start() {
688688
new AttemptTaskHandlerImpl(handler),
689689
pollerOptions,
690690
slotSupplier.maximumSlots().orElse(Integer.MAX_VALUE),
691-
false);
691+
false,
692+
options.isUsingVirtualThreads());
692693

693694
this.workerMetricsScope.counter(MetricsType.WORKER_START_COUNTER).inc(1);
694695
this.slotQueue.start();

temporal-sdk/src/main/java/io/temporal/internal/worker/NexusWorker.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ public boolean start() {
102102
new TaskHandlerImpl(handler),
103103
pollerOptions,
104104
slotSupplier.maximumSlots().orElse(Integer.MAX_VALUE),
105-
true);
105+
true,
106+
options.isUsingVirtualThreads());
106107
poller =
107108
new Poller<>(
108109
options.getIdentity(),

temporal-sdk/src/main/java/io/temporal/internal/worker/PollTaskExecutor.java

+35-22
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222

2323
import com.google.common.base.Preconditions;
2424
import io.temporal.internal.logging.LoggerTag;
25+
import io.temporal.internal.task.VirtualThreadDelegate;
2526
import java.util.Objects;
2627
import java.util.concurrent.*;
28+
import java.util.concurrent.atomic.AtomicInteger;
2729
import javax.annotation.Nonnull;
2830
import org.slf4j.MDC;
2931

@@ -41,7 +43,7 @@ public interface TaskHandler<TT> {
4143
private final TaskHandler<T> handler;
4244
private final PollerOptions pollerOptions;
4345

44-
private final ThreadPoolExecutor taskExecutor;
46+
private final ExecutorService taskExecutor;
4547
private final String pollThreadNamePrefix;
4648

4749
PollTaskExecutor(
@@ -51,35 +53,46 @@ public interface TaskHandler<TT> {
5153
@Nonnull TaskHandler<T> handler,
5254
@Nonnull PollerOptions pollerOptions,
5355
int workerTaskSlots,
54-
boolean synchronousQueue) {
56+
boolean synchronousQueue,
57+
boolean useVirtualThreads) {
5558
this.namespace = Objects.requireNonNull(namespace);
5659
this.taskQueue = Objects.requireNonNull(taskQueue);
5760
this.identity = Objects.requireNonNull(identity);
5861
this.handler = Objects.requireNonNull(handler);
5962
this.pollerOptions = Objects.requireNonNull(pollerOptions);
6063

61-
this.taskExecutor =
62-
new ThreadPoolExecutor(
63-
// for SynchronousQueue we can afford to set it to 0, because the queue is always full
64-
// or empty
65-
// for LinkedBlockingQueue we have to set slots to workerTaskSlots to avoid situation
66-
// when the queue grows, but the amount of threads is not, because the queue is not (and
67-
// never) full
68-
synchronousQueue ? 0 : workerTaskSlots,
69-
workerTaskSlots,
70-
10,
71-
TimeUnit.SECONDS,
72-
synchronousQueue ? new SynchronousQueue<>() : new LinkedBlockingQueue<>());
73-
this.taskExecutor.allowCoreThreadTimeOut(true);
74-
7564
this.pollThreadNamePrefix =
7665
pollerOptions.getPollThreadNamePrefix().replaceFirst("Poller", "Executor");
77-
78-
this.taskExecutor.setThreadFactory(
79-
new ExecutorThreadFactory(
80-
pollerOptions.getPollThreadNamePrefix().replaceFirst("Poller", "Executor"),
81-
pollerOptions.getUncaughtExceptionHandler()));
82-
this.taskExecutor.setRejectedExecutionHandler(new BlockCallerPolicy());
66+
// If virtual threads are enabled, we use a virtual thread executor.
67+
if (useVirtualThreads) {
68+
AtomicInteger threadIndex = new AtomicInteger();
69+
this.taskExecutor =
70+
VirtualThreadDelegate.newVirtualThreadExecutor(
71+
(t) -> {
72+
t.setName(this.pollThreadNamePrefix + ": " + threadIndex.incrementAndGet());
73+
t.setUncaughtExceptionHandler(pollerOptions.getUncaughtExceptionHandler());
74+
});
75+
} else {
76+
ThreadPoolExecutor threadPoolTaskExecutor =
77+
new ThreadPoolExecutor(
78+
// for SynchronousQueue we can afford to set it to 0, because the queue is always full
79+
// or empty
80+
// for LinkedBlockingQueue we have to set slots to workerTaskSlots to avoid situation
81+
// when the queue grows, but the amount of threads is not, because the queue is not
82+
// (and
83+
// never) full
84+
synchronousQueue ? 0 : workerTaskSlots,
85+
workerTaskSlots,
86+
10,
87+
TimeUnit.SECONDS,
88+
synchronousQueue ? new SynchronousQueue<>() : new LinkedBlockingQueue<>());
89+
threadPoolTaskExecutor.allowCoreThreadTimeOut(true);
90+
threadPoolTaskExecutor.setThreadFactory(
91+
new ExecutorThreadFactory(
92+
this.pollThreadNamePrefix, pollerOptions.getUncaughtExceptionHandler()));
93+
threadPoolTaskExecutor.setRejectedExecutionHandler(new BlockCallerPolicy());
94+
this.taskExecutor = threadPoolTaskExecutor;
95+
}
8396
}
8497

8598
@Override

temporal-sdk/src/main/java/io/temporal/internal/worker/Poller.java

+31-15
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,12 @@
2525
import io.grpc.StatusRuntimeException;
2626
import io.temporal.internal.BackoffThrottler;
2727
import io.temporal.internal.common.GrpcUtils;
28+
import io.temporal.internal.task.VirtualThreadDelegate;
2829
import io.temporal.worker.MetricsType;
2930
import java.time.Duration;
3031
import java.util.Objects;
3132
import java.util.concurrent.*;
33+
import java.util.concurrent.atomic.AtomicInteger;
3234
import java.util.concurrent.atomic.AtomicReference;
3335
import org.slf4j.Logger;
3436
import org.slf4j.LoggerFactory;
@@ -57,7 +59,7 @@ interface ThrowingRunnable {
5759
private final PollTask<T> pollTask;
5860
private final PollerOptions pollerOptions;
5961
private static final Logger log = LoggerFactory.getLogger(Poller.class);
60-
private ThreadPoolExecutor pollExecutor;
62+
private ExecutorService pollExecutor;
6163
private final Scope workerMetricsScope;
6264

6365
private final AtomicReference<CountDownLatch> suspendLatch = new AtomicReference<>();
@@ -97,20 +99,34 @@ public boolean start() {
9799
pollerOptions.getMaximumPollRatePerSecond(),
98100
pollerOptions.getMaximumPollRateIntervalMilliseconds());
99101
}
100-
101-
// It is important to pass blocking queue of at least options.getPollThreadCount() capacity. As
102-
// task enqueues next task the buffering is needed to queue task until the previous one releases
103-
// a thread.
104-
pollExecutor =
105-
new ThreadPoolExecutor(
106-
pollerOptions.getPollThreadCount(),
107-
pollerOptions.getPollThreadCount(),
108-
1,
109-
TimeUnit.SECONDS,
110-
new ArrayBlockingQueue<>(pollerOptions.getPollThreadCount()));
111-
pollExecutor.setThreadFactory(
112-
new ExecutorThreadFactory(
113-
pollerOptions.getPollThreadNamePrefix(), pollerOptions.getUncaughtExceptionHandler()));
102+
// If virtual threads are enabled, we use a virtual thread executor.
103+
if (pollerOptions.isUsingVirtualThreads()) {
104+
AtomicInteger threadIndex = new AtomicInteger();
105+
pollExecutor =
106+
VirtualThreadDelegate.newVirtualThreadExecutor(
107+
(t) -> {
108+
// TODO: Consider using a more descriptive name for the thread.
109+
t.setName(
110+
pollerOptions.getPollThreadNamePrefix() + ": " + threadIndex.incrementAndGet());
111+
t.setUncaughtExceptionHandler(uncaughtExceptionHandler);
112+
});
113+
} else {
114+
// It is important to pass blocking queue of at least options.getPollThreadCount() capacity.
115+
// As task enqueues next task the buffering is needed to queue task until the previous one
116+
// releases a thread.
117+
ThreadPoolExecutor threadPoolPoller =
118+
new ThreadPoolExecutor(
119+
pollerOptions.getPollThreadCount(),
120+
pollerOptions.getPollThreadCount(),
121+
1,
122+
TimeUnit.SECONDS,
123+
new ArrayBlockingQueue<>(pollerOptions.getPollThreadCount()));
124+
threadPoolPoller.setThreadFactory(
125+
new ExecutorThreadFactory(
126+
pollerOptions.getPollThreadNamePrefix(),
127+
pollerOptions.getUncaughtExceptionHandler()));
128+
pollExecutor = threadPoolPoller;
129+
}
114130

115131
for (int i = 0; i < pollerOptions.getPollThreadCount(); i++) {
116132
pollExecutor.execute(new PollLoopTask(new PollExecutionTask()));

0 commit comments

Comments
 (0)