Skip to content

Commit

Permalink
Merge pull request #172 from kensho/query_keyword_sanitization
Browse files Browse the repository at this point in the history
Sanitizing property names in queries and creation commands.
  • Loading branch information
mogui committed Feb 9, 2016
2 parents 266b638 + f1840d6 commit b7535bd
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 19 deletions.
23 changes: 14 additions & 9 deletions pyorient/ogm/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def create_class(self, cls):
self.client.command(
'ALTER PROPERTY {0} DEFAULT {1}'
.format(class_prop,
PropertyEncoder.encode(prop_value.default)))
PropertyEncoder.encode_value(prop_value.default)))

self.client.command(
'ALTER PROPERTY {0} NOTNULL {1}'
Expand Down Expand Up @@ -366,7 +366,6 @@ def drop_all(self, registry):

def create_vertex(self, vertex_cls, **kwargs):
result = self.client.command(
# to_unicode(str())
to_unicode(self.create_vertex_command(vertex_cls, **kwargs)))[0]

props = result.oRecordData
Expand All @@ -379,8 +378,9 @@ def create_vertex_command(self, vertex_cls, **kwargs):
if kwargs:
db_props = Graph.props_to_db(vertex_cls, kwargs, self.strict)
set_clause = u' SET {}'.format(
u','.join(u'{}={}'.format(k,PropertyEncoder.encode(v))
for k,v in db_props.items()))
u','.join(u'{}={}'.format(
PropertyEncoder.encode_name(k), PropertyEncoder.encode_value(v))
for k, v in db_props.items()))
else:
set_clause = u''

Expand All @@ -402,16 +402,16 @@ def create_edge_command(self, edge_cls, from_vertex, to_vertex, **kwargs):
if kwargs:
db_props = Graph.props_to_db(edge_cls, kwargs, self.strict)
set_clause = u' SET {}'.format(
u','.join(u'{}={}'.format(k,PropertyEncoder.encode(v))
for k,v in db_props.items()))
u','.join(u'{}={}'.format(
PropertyEncoder.encode_name(k), PropertyEncoder.encode_value(v))
for k, v in db_props.items()))
else:
set_clause = ''

return CreateEdgeCommand(
u'CREATE EDGE {} FROM {} TO {}{}'.format(
class_name, from_vertex._id, to_vertex._id, set_clause))


def get_vertex(self, vertex_id):
record = self.client.command('SELECT FROM {}'.format(vertex_id))
return self.vertex_from_record(record[0]) if record else None
Expand All @@ -436,8 +436,9 @@ def save_element(self, element_class, props, elem_id):
if props:
db_props = Graph.props_to_db(element_class, props, self.strict)
set_clause = u' SET {}'.format(
u','.join(u'{}={}'.format(k,PropertyEncoder.encode(v))
for k,v in db_props.items()))
u','.join(u'{}={}'.format(
PropertyEncoder.encode_name(k), PropertyEncoder.encode_value(v))
for k, v in db_props.items()))
else:
set_clause = ''

Expand Down Expand Up @@ -658,6 +659,10 @@ def create_props_mapping(db_to_element):
def props_to_db(element_class, props, strict):
db_props = {}
for k, v in props.items():
# sanitize the property name -- this line
# will raise an error if the name is invalid
PropertyEncoder.encode_name(k)

if hasattr(element_class, k):
prop = getattr(element_class, k)
db_props[prop.name or k] = v
Expand Down
20 changes: 15 additions & 5 deletions pyorient/ogm/property.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .operators import Operand, ArithmeticMixin

import sys
import decimal
import datetime
import decimal
import string
import sys


class Property(Operand):
Expand Down Expand Up @@ -76,8 +77,17 @@ def __str__(self):
return 'UUID()'

class PropertyEncoder:
PROHIBITED_NAME_CHARS = set(''.join([string.whitespace, '"\'']))

@staticmethod
def encode_name(name):
for c in name:
if c in PropertyEncoder.PROHIBITED_NAME_CHARS:
raise ValueError('Prohibited character in property name: {}'.format(name))
return name

@staticmethod
def encode(value):
def encode_value(value):
if isinstance(value, decimal.Decimal):
return repr(str(value))
elif isinstance(value, datetime.datetime) or isinstance(value, datetime.date):
Expand All @@ -89,10 +99,10 @@ def encode(value):
elif value is None:
return 'null'
elif isinstance(value, list) or isinstance(value, set):
return u'[{}]'.format(u','.join([PropertyEncoder.encode(v) for v in value]))
return u'[{}]'.format(u','.join([PropertyEncoder.encode_value(v) for v in value]))
elif isinstance(value, dict):
contents = u','.join([
'{}: {}'.format(PropertyEncoder.encode(k), PropertyEncoder.encode(v))
'{}: {}'.format(PropertyEncoder.encode_value(k), PropertyEncoder.encode_value(v))
for k, v in value.items()
])
return u'{{ {} }}'.format(contents)
Expand Down
6 changes: 3 additions & 3 deletions pyorient/ogm/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def filter_string(self, expression_root):
left_str, ArgConverter.convert_to(ArgConverter.Value
, right, self))
elif op is Operator.Between:
far_right = PropertyEncoder.encode(expression_root.operands[2])
far_right = PropertyEncoder.encode_value(expression_root.operands[2])
return u'{0} BETWEEN {1} and {2}'.format(
left_str, right, far_right)
elif op is Operator.Contains:
Expand All @@ -325,7 +325,7 @@ def filter_string(self, expression_root):
left_str, self.filter_string(right))
else:
return u'{} in {}'.format(
PropertyEncoder.encode(right), left_str)
PropertyEncoder.encode_value(right), left_str)
elif op is Operator.EndsWith:
return u'{0} like \'%{1}\''.format(left_str, right)
elif op is Operator.Is:
Expand Down Expand Up @@ -412,7 +412,7 @@ def build_props(self, params, prop_names=None, for_iterator=False):
def build_wheres(self, params):
kw_filters = params.get('kw_filters')
kw_where = [u' and '.join(u'{0}={1}'
.format(k, PropertyEncoder.encode(v))
.format(PropertyEncoder.encode_name(k), PropertyEncoder.encode_value(v))
for k,v in kw_filters.items())] if kw_filters else []

filter_exp = params.get('filter')
Expand Down
4 changes: 2 additions & 2 deletions pyorient/ogm/query_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class ArgConverter(object):
@staticmethod
def convert_to(conversion, arg, for_query):
if conversion is ArgConverter.Label:
return '{}'.format(PropertyEncoder.encode(arg))
return '{}'.format(PropertyEncoder.encode_value(arg))
elif conversion is ArgConverter.Expression:
if isinstance(arg, LogicalConnective):
return '\'{}\''.format(for_query.filter_string(arg))
Expand Down Expand Up @@ -46,7 +46,7 @@ def convert_to(conversion, arg, for_query):
elif isinstance(arg, What):
return for_query.build_what(arg)
else:
return PropertyEncoder.encode(arg)
return PropertyEncoder.encode_value(arg)
elif conversion is ArgConverter.Boolean:
if isinstance(arg, What):
return for_query.build_what(arg)
Expand Down
8 changes: 8 additions & 0 deletions tests/test_ogm.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ def testGraph(self):

assert rat == queried_rat

invalid_query_args = {'name': 'rat', 'name="rat" OR 1': 1}
try:
g.animals.query(**invalid_query_args).all()
except:
pass
else:
assert False and 'Invalid params did not raise an exception!'

queried_mouse = g.query(mouse).one()
assert mouse == queried_mouse
assert mouse == g.get_vertex(mouse._id)
Expand Down

0 comments on commit b7535bd

Please # to comment.