diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py
index 3145fb511068..98f44d6926d3 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads.py
@@ -664,19 +664,15 @@ def start_bundle(self):
     self.pending_jobs = []
     self.schema_cache = {}
 
-  def process(
-      self,
-      element,
-      load_job_name_prefix,
-      pane_info=beam.DoFn.PaneInfoParam,
-      *schema_side_inputs):
+  def process(self, element, load_job_name_prefix, *schema_side_inputs):
     # Each load job is assumed to have files respecting these constraints:
     # 1. Total size of all files < 15 TB (Max size for load jobs)
     # 2. Total no. of files in a single load job < 10,000
     # This assumption means that there will always be a single load job
     # triggered for each partition of files.
     destination = element[0]
-    partition_key, files = element[1]
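+    # pane_index is attached upstream by KeyByPaneIndex before the Reshuffle.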
+    partition_key, files, pane_index = element[1]
 
     if callable(self.schema):
       schema = self.schema(destination, *schema_side_inputs)
@@ -705,7 +701,7 @@ def process(
             table_reference.datasetId,
             table_reference.tableId))
     job_name = '%s_%s_pane%s_partition%s' % (
-        load_job_name_prefix, destination_hash, pane_info.index, partition_key)
+        load_job_name_prefix, destination_hash, pane_index, partition_key)
     _LOGGER.info('Load job has %s files. Job name is %s.', len(files), job_name)
 
     create_disposition = self.create_disposition
@@ -1104,6 +1100,10 @@ def _load_data(
     # Load data using temp tables
     trigger_loads_outputs = (
         partitions_using_temp_tables
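+        # Record each element's pane index before the Reshuffle so that
+        # TriggerLoadJobs can build its job names from the original pane.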
+        | "KeyByPaneIndexWithTempTables" >> beam.ParDo(KeyByPaneIndex())
+        | "ReshuffleBeforeLoadWithTempTables" >> beam.Reshuffle()
         | "TriggerLoadJobsWithTempTables" >> beam.ParDo(
             TriggerLoadJobs(
                 schema=self.schema,
@@ -1186,6 +1186,9 @@ def _load_data(
     # Load data directly to destination table
     destination_load_job_ids_pc = (
         partitions_direct_to_destination
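+        # As above, record the pane index before the Reshuffle.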
+        | "KeyByPaneIndexWithoutTempTables" >> beam.ParDo(KeyByPaneIndex())
+        | "ReshuffleBeforeLoadWithoutTempTables" >> beam.Reshuffle()
         | "TriggerLoadJobsWithoutTempTables" >> beam.ParDo(
             TriggerLoadJobs(
                 schema=self.schema,
@@ -1313,3 +1316,10 @@ def expand(self, pcoll):
         self.DESTINATION_FILE_PAIRS: all_destination_file_pairs_pc,
         self.DESTINATION_COPY_JOBID_PAIRS: destination_copy_job_ids_pc,
     }
+
+
+class KeyByPaneIndex(beam.DoFn):
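+  """Embeds each element's pane index into its value before a Reshuffle."""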
+  def process(self, element, pane_info=beam.DoFn.PaneInfoParam):
+    destination, (partition_key, files) = element
+    return [(destination, (partition_key, files, pane_info.index))]
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
index 10453d9c8baf..35b66a0fc48c 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_file_loads_test.py
@@ -593,8 +593,8 @@ def test_wait_for_load_job_completion(self, sleep_mock):
     bq_client.jobs.Get.side_effect = [
         job_1_waiting, job_2_done, job_1_done, job_2_done
     ]
-    partition_1 = ('project:dataset.table0', (0, ['file0']))
-    partition_2 = ('project:dataset.table1', (1, ['file1']))
+    partition_1 = ('project:dataset.table0', (0, ['file0'], 0))
+    partition_2 = ('project:dataset.table1', (1, ['file1'], 0))
     bq_client.jobs.Insert.side_effect = [job_1, job_2]
     test_job_prefix = "test_job"
 
@@ -636,8 +636,8 @@ def test_one_load_job_failed_after_waiting(self, sleep_mock):
     bq_client.jobs.Get.side_effect = [
         job_1_waiting, job_2_done, job_1_error, job_2_done
     ]
-    partition_1 = ('project:dataset.table0', (0, ['file0']))
-    partition_2 = ('project:dataset.table1', (1, ['file1']))
+    partition_1 = ('project:dataset.table0', (0, ['file0'], 0))
+    partition_2 = ('project:dataset.table1', (1, ['file1'], 0))
     bq_client.jobs.Insert.side_effect = [job_1, job_2]
     test_job_prefix = "test_job"