Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using pydrive with user credentials for authenticated download #3

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 64 additions & 7 deletions download_ffhq.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
"""Download Flickr-Faces-HQ (FFHQ) dataset to current working directory."""

import os
import re
import sys
import requests
import html
@@ -27,6 +28,8 @@
import itertools
import shutil
from collections import OrderedDict, defaultdict
from pydrive2.auth import GoogleAuth
from pydrive2.drive import GoogleDrive

PIL.ImageFile.LOAD_TRUNCATED_IMAGES = True # avoid "Decompressed Data Too Large" error

@@ -130,6 +133,50 @@ def download_file(session, file_spec, stats, chunk_size=128, num_attempts=10):
except:
pass

def pydrive_create_drive_manager(cmd_auth):
    """Authenticate with Google and return a drive client built on that session.

    Args:
        cmd_auth: When truthy, authenticate through ``GoogleAuth.CommandLineAuth``;
            otherwise through ``GoogleAuth.LocalWebserverAuth``.

    Returns:
        A ``GoogleDrive`` instance wrapping the authorized ``GoogleAuth`` object.
    """
    auth = GoogleAuth()

    # Pick the authentication flow first, then run it; only the chosen
    # method is ever looked up or called.
    authenticate = auth.CommandLineAuth if cmd_auth else auth.LocalWebserverAuth
    authenticate()

    auth.Authorize()
    print("authorized access to google drive API!")

    return GoogleDrive(auth)


def pydrive_extract_files_id(drive, link):
    """Extract the Google Drive file ID embedded in a share/download link.

    The ID is the run of characters that follows ``/d/``, ``id=`` or ``rs/``
    (e.g. ``.../file/d/<ID>/view``, ``.../uc?id=<ID>``, ``.../folders/<ID>``).
    It ends at the next ``/``, ``?``, ``&``, ``#`` or whitespace, so trailing
    query parameters such as ``&export=download`` or ``?usp=sharing`` are not
    returned as part of the ID.

    Args:
        drive: Unused; kept so existing callers do not need to change.
        link: URL string to parse.

    Returns:
        The extracted ID string, or ``None`` when ``link`` is not a string or
        contains no recognizable ID. In that case a diagnostic is printed.
    """
    try:
        # All three look-behind alternatives are three characters wide, which
        # is what Python's fixed-width look-behind requires.
        match = re.search(r"(?<=/d/|id=|rs/)[^/?&#\s]+", link)
        if match is None:
            raise ValueError("no Google Drive file ID found in link")
        return match[0]
    except (TypeError, ValueError) as error:
        # TypeError: link is not a str/bytes object (e.g. None).
        print("error : " + str(error))
        print("Link is probably invalid")
        print(link)
        return None


def pydrive_download_file(drive, spec, stats, chunk_size=128, num_attempts=10):
    """Download one file through an authenticated drive client.

    The content is first written to ``<file_path>.tmp`` and only renamed to
    ``file_path`` after a successful download. ``download_files`` treats any
    existing ``file_path`` as already done, so a partially written file must
    never be left under the final name.

    Args:
        drive: Drive client exposing ``CreateFile`` (the object returned by
            ``pydrive_create_drive_manager``).
        spec: Mapping with ``'file_url'`` (link containing the file ID) and
            ``'file_path'`` (destination path).
        stats: Mutable mapping; ``'files_done'`` is incremented by one and
            ``'bytes_done'`` by the size of the downloaded file.
        chunk_size: Unused; accepted so this function can receive the same
            keyword arguments as ``download_file``.
        num_attempts: Maximum number of download attempts.

    Raises:
        ValueError: If no file ID can be extracted from ``spec['file_url']``.
        Exception: Whatever the final failed download attempt raised.
    """
    link = spec['file_url']
    save_path = spec['file_path']

    file_id = pydrive_extract_files_id(drive, link)
    if file_id is None:
        # Retrying cannot help when the link itself is unusable.
        raise ValueError('Could not extract a Google Drive file ID from %r' % (link,))

    file_dir = os.path.dirname(save_path)
    if file_dir:
        os.makedirs(file_dir, exist_ok=True)

    tmp_path = save_path + '.tmp'
    pydrive_file = drive.CreateFile({'id': file_id})
    for attempts_left in reversed(range(num_attempts)):
        try:
            pydrive_file.GetContentFile(tmp_path)
            break
        except Exception:
            if not attempts_left:
                # Last attempt failed: drop any partial data, then re-raise.
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass
                raise

    # Publish the finished file under its real name in a single step.
    os.replace(tmp_path, save_path)

    # NOTE(review): stats is shared between download threads and these
    # read-modify-write updates are not synchronized here; confirm whether
    # the rest of the file guards stats with a lock and use it if so.
    stats['files_done'] += 1
    stats['bytes_done'] += os.stat(save_path).st_size

#----------------------------------------------------------------------------

def choose_bytes_unit(num_bytes):
@@ -152,7 +199,7 @@ def format_time(seconds):

#----------------------------------------------------------------------------

def download_files(file_specs, num_threads=32, status_delay=0.2, timing_window=50, **download_kwargs):
def download_files(file_specs, drive=None, num_threads=32, status_delay=0.2, timing_window=50, **download_kwargs):

# Determine which files to download.
done_specs = {spec['file_path']: spec for spec in file_specs if os.path.isfile(spec['file_path'])}
@@ -169,7 +216,7 @@ def download_files(file_specs, num_threads=32, status_delay=0.2, timing_window=5
exception_queue = queue.Queue()
for spec in missing_specs:
spec_queue.put(spec)
thread_kwargs = dict(spec_queue=spec_queue, exception_queue=exception_queue, stats=stats, download_kwargs=download_kwargs)
thread_kwargs = dict(spec_queue=spec_queue, exception_queue=exception_queue, stats=stats, drive=drive, download_kwargs=download_kwargs)
for _thread_idx in range(min(num_threads, len(missing_specs))):
threading.Thread(target=_download_thread, kwargs=thread_kwargs, daemon=True).start()

@@ -206,12 +253,15 @@ def download_files(file_specs, num_threads=32, status_delay=0.2, timing_window=5
except queue.Empty:
pass

def _download_thread(spec_queue, exception_queue, stats, download_kwargs):
def _download_thread(spec_queue, exception_queue, stats, drive, download_kwargs):
with requests.Session() as session:
while not spec_queue.empty():
spec = spec_queue.get()
try:
download_file(session, spec, stats, **download_kwargs)
if drive is not None:
pydrive_download_file(drive, spec, stats, **download_kwargs)
else:
download_file(session, spec, stats, **download_kwargs)
except:
exception_queue.put(sys.exc_info())

@@ -350,10 +400,15 @@ def recreate_aligned_images(json_data, dst_dir='realign1024x1024', output_size=1

#----------------------------------------------------------------------------

def run(tasks, **download_kwargs):
def run(tasks, pydrive, cmd_auth, **download_kwargs):
if pydrive:
drive = pydrive_create_drive_manager(cmd_auth)
else:
drive = None

if not os.path.isfile(json_spec['file_path']) or not os.path.isfile('LICENSE.txt'):
print('Downloading JSON metadata...')
download_files([json_spec, license_specs['json']], **download_kwargs)
download_files([json_spec, license_specs['json']], drive=drive, **download_kwargs)

print('Parsing JSON metadata...')
with open(json_spec['file_path'], 'rb') as f:
@@ -375,7 +430,7 @@ def run(tasks, **download_kwargs):
if len(specs):
print('Downloading %d files...' % len(specs))
np.random.shuffle(specs) # to make the workload more homogeneous
download_files(specs, **download_kwargs)
download_files(specs, drive=drive, **download_kwargs)

if 'align' in tasks:
recreate_aligned_images(json_data)
@@ -390,6 +445,8 @@ def run_cmdline(argv):
parser.add_argument('-t', '--thumbs', help='download 128x128 thumbnails as PNG (1.95 GB)', dest='tasks', action='append_const', const='thumbs')
parser.add_argument('-w', '--wilds', help='download in-the-wild images as PNG (955 GB)', dest='tasks', action='append_const', const='wilds')
parser.add_argument('-r', '--tfrecords', help='download multi-resolution TFRecords (273 GB)', dest='tasks', action='append_const', const='tfrecords')
parser.add_argument('--pydrive', help='use pydrive interface to download files. it overrides google drive quota limitation this requires google credentials (default: False)', action='store_true')
parser.add_argument('--cmd_auth', help='use command line google authentication when using pydrive interface (default: False)', action='store_true')
parser.add_argument('-a', '--align', help='recreate 1024x1024 images from in-the-wild images', dest='tasks', action='append_const', const='align')
parser.add_argument('--num_threads', help='number of concurrent download threads (default: 32)', type=int, default=32, metavar='NUM')
parser.add_argument('--status_delay', help='time between download status prints (default: 0.2)', type=float, default=0.2, metavar='SEC')