-
Notifications
You must be signed in to change notification settings - Fork 67
/
Copy pathpyconvert.jl
526 lines (467 loc) · 18 KB
/
pyconvert.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
# Priority levels for conversion rules. When rules are ordered, larger values are tried
# first; the explicit `Int32` is the default base type of `@enum`, spelled out for clarity.
@enum PyConvertPriority::Int32 begin
    # internal: rules that wrap the Python object
    PYCONVERT_PRIORITY_WRAP = 400
    # internal: array-style conversions
    PYCONVERT_PRIORITY_ARRAY = 300
    # the canonical conversion for a Python type
    PYCONVERT_PRIORITY_CANONICAL = 200
    # ordinary rules (the default)
    PYCONVERT_PRIORITY_NORMAL = 0
    # last-resort rules
    PYCONVERT_PRIORITY_FALLBACK = -100
end
# A single registered conversion rule (see `pyconvert_add_rule`).
# - `type`: the Julia type the rule can produce; it is intersected with the requested
#   target type when rules are selected for a conversion.
# - `func`: the rule implementation, called as `func(S, x::Py)` for a target type `S`.
# - `priority`: rules with a higher priority are tried first.
struct PyConvertRule
    type::Type
    func::Function
    priority::PyConvertPriority
end
# All registered rules, keyed by Python type name of the form "module:name" (as built by
# `pyconvert_typename`) or by an interface pseudo-name such as "<buffer>".
const PYCONVERT_RULES = Dict{String,Vector{PyConvertRule}}()
# Extra Python types (filled in by `init_pyconvert`, e.g. ABCs from `numbers` and
# `collections.abc`) that are merged into a type's MRO when the type subclasses them.
const PYCONVERT_EXTRATYPES = Py[]
"""
pyconvert_add_rule(tname::String, T::Type, func::Function, priority::PyConvertPriority=PYCONVERT_PRIORITY_NORMAL)
Add a new conversion rule for `pyconvert`.
### Arguments
- `tname` is a string of the form `"__module__:__qualname__"` identifying a Python type `t`,
such as `"builtins:dict"` or `"sympy.core.symbol:Symbol"`. This rule only applies to
Python objects of this type.
- `T` is a Julia type, such that this rule only applies when the target type intersects
with `T`.
- `func` is the function implementing the rule.
- `priority` determines whether to prioritise this rule above others.
When `pyconvert(R, x)` is called, all rules such that `typeintersect(T, R) != Union{}`
and `pyisinstance(x, t)` are considered. These rules are sorted first by priority,
then by the specificity of `t` (e.g. `bool` is more specific than `int` is more specific
than `object`) then by the order they were added. The rules are tried in turn until one
succeeds.
### Implementing `func`
`func` is called as `func(S, x::Py)` for some `S <: T`.
It must return one of:
- `pyconvert_return(ans)` where `ans` is the result of the conversion (and must be an `S`).
- `pyconvert_unconverted()` if the conversion was not possible (e.g. converting a `list` to
`Vector{Int}` might fail if some of the list items are not integers).
The target type `S` is never a union or the empty type, i.e. it is always a data type or
union-all.
### Priority
Most rules should have priority `PYCONVERT_PRIORITY_NORMAL` (the default) which is for any
reasonable conversion rule.
Use priority `PYCONVERT_PRIORITY_CANONICAL` for **canonical** conversion rules. Immutable
objects may be canonically converted to their corresponding Julia type, such as `int` to
`Integer`. Mutable objects **must** be converted to a wrapper type, such that the original
Python object can be retrieved. For example a `list` is canonically converted to `PyList`
and not to a `Vector`. There should not be more than one canonical conversion rule for a
given Python type.
Other priorities are reserved for internal use.
"""
function pyconvert_add_rule(
pytypename::String,
type::Type,
func::Function,
priority::PyConvertPriority = PYCONVERT_PRIORITY_NORMAL,
)
@nospecialize type func
push!(
get!(Vector{PyConvertRule}, PYCONVERT_RULES, pytypename),
PyConvertRule(type, func, priority),
)
empty!.(values(PYCONVERT_RULES_CACHE))
return
end
# Alternative ways to represent the result of conversion.
#
# Exactly one branch below is active: the literal `if true` selects the first, and the
# other two are disabled alternatives kept for reference. Every branch defines the same
# interface:
# - `pyconvert_return(x)`: report a successful conversion to `x`.
# - `pyconvert_unconverted()`: report that the conversion was not possible.
# - `pyconvert_returntype(T)`: the type a rule's return value is asserted to have when the
#   target type is `T`.
# - `pyconvert_isunconverted(r)`: true if `r` reports failure.
# - `pyconvert_result(T, r)`: extract the converted `T` from a successful `r`.
if true
    # Returns either the result or Unconverted().
    # Singleton sentinel value marking a failed conversion.
    struct Unconverted end
    @inline pyconvert_return(x) = x
    @inline pyconvert_unconverted() = Unconverted()
    # A rule returns either a `T` or the sentinel: a small Union.
    @inline pyconvert_returntype(::Type{T}) where {T} = Union{T,Unconverted}
    @inline pyconvert_isunconverted(r) = r === Unconverted()
    # A successful result is the value itself; the assertion narrows its type to `T`.
    @inline pyconvert_result(::Type{T}, r) where {T} = r::T
elseif false
    # Stores the result in PYCONVERT_RESULT.
    # This is global state, probably best avoided.
    # (Disabled.) Rules return only a Bool success flag; the value travels via the Ref.
    const PYCONVERT_RESULT = Ref{Any}(nothing)
    @inline pyconvert_return(x) = (PYCONVERT_RESULT[] = x; true)
    @inline pyconvert_unconverted() = false
    @inline pyconvert_returntype(::Type{T}) where {T} = Bool
    @inline pyconvert_isunconverted(r::Bool) = !r
    # Read the stored value, then clear the slot (presumably so it is not kept alive).
    @inline pyconvert_result(::Type{T}, r::Bool) where {T} =
        (ans = PYCONVERT_RESULT[]::T; PYCONVERT_RESULT[] = nothing; ans)
else
    # Same as the previous scheme, but with special handling for bits types.
    # This is global state, probably best avoided.
    # (Disabled.) isbits values of at most PYCONVERT_RESULT_BITSLEN bytes are written into
    # a raw byte buffer, with their concrete type recorded in PYCONVERT_RESULT_TYPE, instead
    # of being boxed in the Ref{Any}.
    const PYCONVERT_RESULT = Ref{Any}(nothing)
    const PYCONVERT_RESULT_ISBITS = Ref{Bool}(false)
    const PYCONVERT_RESULT_TYPE = Ref{Type}(Union{})
    const PYCONVERT_RESULT_BITSLEN = 1024
    const PYCONVERT_RESULT_BITS = fill(0x00, PYCONVERT_RESULT_BITSLEN)
    function pyconvert_return(x::T) where {T}
        if isbitstype(T) && sizeof(T) ≤ PYCONVERT_RESULT_BITSLEN
            # small bits value: copy its bytes into the shared buffer
            unsafe_store!(Ptr{T}(pointer(PYCONVERT_RESULT_BITS)), x)
            PYCONVERT_RESULT_ISBITS[] = true
            PYCONVERT_RESULT_TYPE[] = T
        else
            PYCONVERT_RESULT[] = x
            PYCONVERT_RESULT_ISBITS[] = false
        end
        return true
    end
    @inline pyconvert_unconverted() = false
    @inline pyconvert_returntype(::Type{T}) where {T} = Bool
    @inline pyconvert_isunconverted(r::Bool) = !r
    # NOTE(review): `r` is not inspected, and the bits paths do not reset
    # PYCONVERT_RESULT_ISBITS — callers must only call this after a successful return.
    function pyconvert_result(::Type{T}, r::Bool) where {T}
        if isbitstype(T)
            if sizeof(T) ≤ PYCONVERT_RESULT_BITSLEN
                # the target is itself a small bits type: load it directly from the buffer
                @assert PYCONVERT_RESULT_ISBITS[]
                @assert PYCONVERT_RESULT_TYPE[] == T
                return unsafe_load(Ptr{T}(pointer(PYCONVERT_RESULT_BITS)))::T
            end
        elseif PYCONVERT_RESULT_ISBITS[]
            # the target is abstract but the stored value was bits: load with its recorded type
            t = PYCONVERT_RESULT_TYPE[]
            @assert isbitstype(t)
            @assert sizeof(t) ≤ PYCONVERT_RESULT_BITSLEN
            @assert t <: T
            @assert isconcretetype(t)
            return unsafe_load(Ptr{t}(pointer(PYCONVERT_RESULT_BITS)))::T
        end
        # general case
        ans = PYCONVERT_RESULT[]::T
        PYCONVERT_RESULT[] = nothing
        return ans::T
    end
end
# Extract a successful conversion result without narrowing its type.
function pyconvert_result(r)
    return pyconvert_result(Any, r)
end

# Helper for conversion rules: succeed with `x` itself when it is already a `T`...
function pyconvert_tryconvert(::Type{T}, x::T) where {T}
    return pyconvert_return(x)
end

# ...otherwise try Julia's `convert`, treating any error it raises as "not convertible".
function pyconvert_tryconvert(::Type{T}, x) where {T}
    try
        converted = convert(T, x)::T
        return pyconvert_return(converted)
    catch
        return pyconvert_unconverted()
    end
end
# Build the "module:name" key used to look up rules in PYCONVERT_RULES for the Python
# type `t`, substituting placeholders when either attribute is missing.
function pyconvert_typename(t::Py)
    modname = pygetattr(t, "__module__", "<unknown>")
    typename = pygetattr(t, "__name__", "<name>")
    return string(modname, ":", typename)
end
# Merge the given MROs into one linearisation. This is a port of the merge() function at
# the bottom of https://www.python.org/download/releases/2.3/mro/ (C3). The vectors in
# `basemros` (and `basemros` itself) are consumed by the merge.
function _pyconvert_merge_mros!(basemros::Vector{Vector{Py}})
    pyisin(x, ys) = any(pyis(x, y) for y in ys)
    mro = Py[]
    while !isempty(basemros)
        # find the first head not contained in any tail
        ok = false
        b = PyNULL
        for bmro in basemros
            b = bmro[1]
            if all(other -> !pyisin(b, other[2:end]), basemros)
                ok = true
                break
            end
        end
        ok || error(
            "Fatal inheritance error: could not merge MROs (mro=$mro, basemros=$basemros)",
        )
        # add it to the list
        push!(mro, b)
        # remove it from consideration
        for bmro in basemros
            filter!(t -> !pyis(t, b), bmro)
        end
        # remove empty lists
        filter!(x -> !isempty(x), basemros)
    end
    return mro
end

# Push `marker` onto the name list of the last (most general) type in `mro` for which
# `pred` holds. The marker then sits immediately after that type's own name (before its
# bases) in the flattened name list. Does nothing if no type satisfies `pred`.
function _pyconvert_mark_interface!(pred, mro::AbstractVector, names::AbstractVector, marker::String)
    for i in reverse(eachindex(mro))
        if pred(mro[i])
            push!(names[i], marker)
            break
        end
    end
    return
end

# Return every registered PyConvertRule that may apply to instances of `pytype`, ordered
# by priority (highest first), then by position in the extended MRO (most specific
# first), then by registration order.
#
# The extended MRO is `pytype.__mro__` with the applicable PYCONVERT_EXTRATYPES merged in,
# plus the interface pseudo-names "<arraystruct>", "<arrayinterface>", "<array>" and
# "<buffer>".
function _pyconvert_get_rules(pytype::Py)
    pyisin(x, ys) = any(pyis(x, y) for y in ys)
    # get the MROs of all base types we are considering
    omro = collect(pytype.__mro__)
    basetypes = Py[pytype]
    # Merge a copy: the merge empties its input vectors in place, and `omro` is needed
    # intact afterwards for the MRO-preservation check. (`Vector{Py}[omro]` would store
    # `omro` itself, not a copy, when it is already a Vector{Py}.)
    basemros = Vector{Py}[copy(omro)]
    for xtype in PYCONVERT_EXTRATYPES
        # find the topmost (most general) type in the MRO of pytype that subclasses xtype
        xbase = PyNULL
        for base in omro
            if pyissubclass(base, xtype)
                xbase = base
            end
        end
        if !pyisnull(xbase)
            push!(basetypes, xtype)
            xmro = collect(xtype.__mro__)
            # start xtype's MRO at xbase so the merge places xtype just after xbase
            pyisin(xbase, xmro) || pushfirst!(xmro, xbase)
            push!(basemros, xmro)
        end
    end
    for xbase in basetypes[2:end]
        push!(basemros, [xbase])
    end
    mro = _pyconvert_merge_mros!(basemros)
    # check the original MRO is preserved
    omro_ = filter(t -> pyisin(t, omro), mro)
    @assert length(omro) == length(omro_)
    @assert all(pyis(x, y) for (x, y) in zip(omro, omro_))
    # get the names of the types in the merged MRO
    xmro = [String[pyconvert_typename(t)] for t in mro]
    # add special names corresponding to certain interfaces
    _pyconvert_mark_interface!(t -> pyhasattr(t, "__array_struct__"), mro, xmro, "<arraystruct>")
    _pyconvert_mark_interface!(t -> pyhasattr(t, "__array_interface__"), mro, xmro, "<arrayinterface>")
    _pyconvert_mark_interface!(t -> pyhasattr(t, "__array__"), mro, xmro, "<array>")
    _pyconvert_mark_interface!(t -> C.PyType_CheckBuffer(getptr(t)), mro, xmro, "<buffer>")
    # flatten to get the MRO as a list of strings
    names = String[x for xs in xmro for x in xs]
    # Get the corresponding rules. Use `get`, not `get!`: a lookup must not insert empty
    # entries into the global registry.
    rules = PyConvertRule[
        rule for tname in names for
        rule in get(PYCONVERT_RULES, tname, PyConvertRule[])
    ]
    # order the rules by priority, then by original order
    order = sort(axes(rules, 1), by = i -> (rules[i].priority, -i), rev = true)
    rules = rules[order]
    @debug "pyconvert" pytype mro = join(names, " ")
    return rules
end
# Memo table for `pyconvert_preferred_type`, keyed by Python type object.
const PYCONVERT_PREFERRED_TYPE = Dict{Py,Type}()

# Return the Julia type that instances of `pytype` preferably convert to. For subclasses
# of `int` this is `Union{Int,BigInt}`; otherwise it is the target type of the top-ranked
# rule from `_pyconvert_get_rules`. The answer is computed once per type and memoised.
# NOTE(review): `bool` subclasses `int` in Python, so it also maps to `Union{Int,BigInt}`
# here — confirm that is intended.
function pyconvert_preferred_type(pytype::Py)
    return get!(PYCONVERT_PREFERRED_TYPE, pytype) do
        if !pyissubclass(pytype, pybuiltins.int)
            _pyconvert_get_rules(pytype)[1].type
        else
            Union{Int,BigInt}
        end
    end
end
# Specialise the rules applicable to Python type `pytype` to the Julia target `type`.
# Each rule's own type is narrowed to the target and split into its union members. Empty
# members are dropped, as is any rule already covered by an earlier one with the same
# function. Returns the surviving rules as one-argument closures, in rule order.
# (Results are cached per target type and Python type by `pytryconvert`.)
function pyconvert_get_rules(type::Type, pytype::Py)
    @nospecialize type
    candidates = PyConvertRule[]
    for rule in _pyconvert_get_rules(pytype)
        narrowed = typeintersect(rule.type, type)
        for member in Utils.explode_union(narrowed)
            if member != Union{}
                push!(candidates, PyConvertRule(member, rule.func, rule.priority))
            end
        end
    end
    # A candidate is redundant if some earlier candidate (kept or not) has the same
    # function and a target type that contains this one.
    rules = PyConvertRule[]
    for (i, candidate) in enumerate(candidates)
        covered = any(
            j -> (candidates[j].func === candidate.func) && (candidate.type <: candidates[j].type),
            1:(i-1),
        )
        covered || push!(rules, candidate)
    end
    @debug "pyconvert" type rules
    return Function[pyconvert_fix(rule.type, rule.func) for rule in rules]
end
# Bind the target type of a rule: return a one-argument closure computing `func(T, x)`.
function pyconvert_fix(::Type{T}, func) where {T}
    bound(x) = func(T, x)
    return bound
end
# Cache of specialised rule closures: Julia target type `T` => (the `C.PyPtr` of the Python
# type of the value being converted, as obtained in `pytryconvert`) => rules to try.
const PYCONVERT_RULES_CACHE = Dict{Type,Dict{C.PyPtr,Vector{Function}}}()
# Return the inner cache dict for target type `T`, creating it if needed.
# Because this is `@generated`, the `get!` runs when the method body is generated, and the
# resulting Dict object is embedded as a constant in the compiled code. To invalidate the
# cache, the inner dicts must therefore be emptied in place (as `pyconvert_add_rule` does)
# rather than removed from PYCONVERT_RULES_CACHE.
@generated pyconvert_rules_cache(::Type{T}) where {T} =
    get!(Dict{C.PyPtr,Vector{Function}}, PYCONVERT_RULES_CACHE, T)
# Shortcut conversions for a handful of common target types, tried by `pytryconvert`
# before the general rule machinery. Union targets are handled by trying each member in
# turn. Returns `pyconvert_unconverted()` whenever no shortcut applies. Any result given
# here must agree with what the registered rules would produce.
function pyconvert_rule_fast(::Type{T}, x::Py) where {T}
    if T isa Union
        first_try = pyconvert_rule_fast(T.a, x)::pyconvert_returntype(T.a)
        if !pyconvert_isunconverted(first_try)
            return first_try
        end
        second_try = pyconvert_rule_fast(T.b, x)::pyconvert_returntype(T.b)
        if !pyconvert_isunconverted(second_try)
            return second_try
        end
    elseif T == Nothing || T == Missing
        if pyisnone(x)
            return pyconvert_return(T())
        end
    elseif T == Bool
        if pyisFalse(x)
            return pyconvert_return(false)
        elseif pyisTrue(x)
            return pyconvert_return(true)
        end
    elseif T == Int || T == BigInt
        if pyisint(x)
            return pyconvert_rule_int(T, x)
        end
    elseif T == Float64
        if pyisfloat(x)
            return pyconvert_return(T(pyfloat_asdouble(x)))
        end
    elseif T == ComplexF64
        if pyiscomplex(x)
            return pyconvert_return(T(pycomplex_ascomplex(x)))
        end
    elseif T == String || T == Char || T == Symbol
        if pyisstr(x)
            return pyconvert_rule_str(T, x)
        end
    elseif T == Vector{UInt8} || T == Base.CodeUnits{UInt8,String}
        if pyisbytes(x)
            return pyconvert_rule_bytes(T, x)
        end
    elseif T <: StepRange || T <: UnitRange
        if pyisrange(x)
            return pyconvert_rule_range(T, x)
        end
    end
    return pyconvert_unconverted()
end
# Attempt to convert `x_` (first wrapped as a `Py`) to a `T`.
# Returns a `pyconvert_returntype(T)`: either a successful result (extract it with
# `pyconvert_result(T, ans)`) or `pyconvert_unconverted()` if nothing succeeded.
function pytryconvert(::Type{T}, x_) where {T}
    # Convert the input to a Py
    x = Py(x_)
    # Fast path for some common target types (see `pyconvert_rule_fast`).
    # It MUST give the same results as via the slower route using rules.
    ans1 = pyconvert_rule_fast(T, x)::pyconvert_returntype(T)
    pyconvert_isunconverted(ans1) || return ans1
    # get rules from the cache, keyed on the Python type of x
    # TODO: we should hold weak references and clear the cache if types get deleted
    tptr = C.Py_Type(getptr(x))
    trules = pyconvert_rules_cache(T)
    rules = get!(trules, tptr) do
        # `t` takes a new reference to the type object (via `incref`). Release it with
        # `pydel!` even if computing the rules throws, so the reference is not leaked.
        t = pynew(incref(tptr))
        try
            pyconvert_get_rules(T, t)::Vector{Function}
        finally
            pydel!(t)
        end
    end
    # apply the rules in order; the first success wins
    for rule in rules
        ans2 = rule(x)::pyconvert_returntype(T)
        pyconvert_isunconverted(ans2) || return ans2
    end
    return pyconvert_unconverted()
end
"""
@pyconvert(T, x, [onfail])
Convert the Python object `x` to a `T`.
On failure, evaluates to `onfail`, which defaults to `return pyconvert_unconverted()` (mainly useful for writing conversion rules).
"""
macro pyconvert(T, x, onfail = :(return $pyconvert_unconverted()))
quote
T = $(esc(T))
x = $(esc(x))
ans = pytryconvert(T, x)
if pyconvert_isunconverted(ans)
$(esc(onfail))
else
pyconvert_result(T, ans)
end
end
end
export @pyconvert
"""
pyconvert(T, x, [d])
Convert the Python object `x` to a `T`.
If `d` is specified, it is returned on failure instead of throwing an error.
"""
pyconvert(::Type{T}, x) where {T} = @autopy x @pyconvert T x_ error(
"cannot convert this Python '$(pytype(x_).__name__)' to a Julia '$T'",
)
pyconvert(::Type{T}, x, d) where {T} = @autopy x @pyconvert T x_ d
export pyconvert
"""
pyconvertarg(T, x, name)
Convert the Python object `x` to a `T`.
On failure, throws a Python `TypeError` saying that the argument `name` could not be converted.
"""
pyconvertarg(::Type{T}, x, name) where {T} = @autopy x @pyconvert T x_ begin
errset(
pybuiltins.TypeError,
"Cannot convert argument '$name' to a Julia '$T', got a '$(pytype(x_).__name__)'",
)
pythrow()
end
# Register the built-in conversion rules. Registration order matters: among rules with
# equal priority for the same Python type, earlier registrations are tried first.
function init_pyconvert()
    # Abstract Python types merged into type MROs by `_pyconvert_get_rules`, so that rules
    # can be registered against names such as "numbers:Integral".
    push!(PYCONVERT_EXTRATYPES, pyimport("io" => "IOBase"))
    for (modname, typenames) in (
        "numbers" => ("Number", "Complex", "Real", "Rational", "Integral"),
        "collections.abc" => ("Iterable", "Sequence", "Set", "Mapping"),
    )
        push!(PYCONVERT_EXTRATYPES, pyimport(modname => typenames)...)
    end

    # (Python type name, Julia type, rule function), registered in this order.
    canonical_rules = Tuple{String,Type,Function}[
        ("builtins:NoneType", Nothing, pyconvert_rule_none),
        ("builtins:bool", Bool, pyconvert_rule_bool),
        ("builtins:float", Float64, pyconvert_rule_float),
        ("builtins:complex", Complex{Float64}, pyconvert_rule_complex),
        ("numbers:Integral", Integer, pyconvert_rule_int),
        ("builtins:str", String, pyconvert_rule_str),
        ("builtins:bytes", Base.CodeUnits{UInt8,String}, pyconvert_rule_bytes),
        ("builtins:range", StepRange{<:Integer,<:Integer}, pyconvert_rule_range),
        ("numbers:Rational", Rational{<:Integer}, pyconvert_rule_fraction),
        ("builtins:tuple", NamedTuple, pyconvert_rule_iterable),
        ("builtins:tuple", Tuple, pyconvert_rule_iterable),
        ("datetime:datetime", DateTime, pyconvert_rule_datetime),
        ("datetime:date", Date, pyconvert_rule_date),
        ("datetime:time", Time, pyconvert_rule_time),
        ("datetime:timedelta", Microsecond, pyconvert_rule_timedelta),
        ("builtins:BaseException", PyException, pyconvert_rule_exception),
    ]
    for (tname, jltype, rulefunc) in canonical_rules
        pyconvert_add_rule(tname, jltype, rulefunc, PYCONVERT_PRIORITY_CANONICAL)
    end

    normal_rules = Tuple{String,Type,Function}[
        ("builtins:NoneType", Missing, pyconvert_rule_none),
        ("builtins:bool", Number, pyconvert_rule_bool),
        ("numbers:Real", Number, pyconvert_rule_float),
        ("builtins:float", Nothing, pyconvert_rule_float),
        ("builtins:float", Missing, pyconvert_rule_float),
        ("numbers:Complex", Number, pyconvert_rule_complex),
        ("numbers:Integral", Number, pyconvert_rule_int),
        ("builtins:str", Symbol, pyconvert_rule_str),
        ("builtins:str", Char, pyconvert_rule_str),
        ("builtins:bytes", Vector{UInt8}, pyconvert_rule_bytes),
        ("builtins:range", UnitRange{<:Integer}, pyconvert_rule_range),
        ("numbers:Rational", Number, pyconvert_rule_fraction),
        ("collections.abc:Iterable", Vector, pyconvert_rule_iterable),
        ("collections.abc:Iterable", Tuple, pyconvert_rule_iterable),
        ("collections.abc:Iterable", Pair, pyconvert_rule_iterable),
        ("collections.abc:Iterable", Set, pyconvert_rule_iterable),
        ("collections.abc:Sequence", Vector, pyconvert_rule_iterable),
        ("collections.abc:Sequence", Tuple, pyconvert_rule_iterable),
        ("collections.abc:Set", Set, pyconvert_rule_iterable),
        ("collections.abc:Mapping", Dict, pyconvert_rule_mapping),
        ("datetime:timedelta", Millisecond, pyconvert_rule_timedelta),
        ("datetime:timedelta", Second, pyconvert_rule_timedelta),
        ("datetime:timedelta", Nanosecond, pyconvert_rule_timedelta),
    ]
    for (tname, jltype, rulefunc) in normal_rules
        pyconvert_add_rule(tname, jltype, rulefunc, PYCONVERT_PRIORITY_NORMAL)
    end

    # Last resort: any Python object can be wrapped as a `Py`.
    pyconvert_add_rule("builtins:object", Py, pyconvert_rule_object, PYCONVERT_PRIORITY_FALLBACK)
end