diff --git a/examples/class_based_views/flask_bb.py b/examples/class_based_views/flask_bb.py new file mode 100644 index 00000000..47549653 --- /dev/null +++ b/examples/class_based_views/flask_bb.py @@ -0,0 +1,47 @@ +class Login(MethodView): + decorators = [anonymous_required] + + def __init__(self, authentication_manager_factory): + self.authentication_manager_factory = authentication_manager_factory + + def form(self): + if enforce_recaptcha(limiter): + return LoginRecaptchaForm() + return LoginForm() + + def get(self): + return render_template("auth/login.html", form=self.form()) + + def post(self): + foo = request.args.get('next') + # return redirect(foo) + return redirect(request.args.get('next')) + # form = self.form() + # if form.validate_on_submit(): + # auth_manager = self.authentication_manager_factory() + # try: + # user = auth_manager.authenticate( + # identifier=form.login.data, secret=form.password.data + # ) + # login_user(user, remember=form.remember_me.data) + # return redirect_or_next(url_for("forum.index")) + # except StopAuthentication as e: + # flash(e.reason, "danger") + # except Exception: + # flash(_("Unrecoverable error while handling login")) + + # return render_template("auth/login.html", form=form) + + +def redirect_or_next(endpoint, **kwargs): + """Redirects the user back to the page they were viewing or to a specified + endpoint. Wraps Flasks :func:`Flask.redirect` function. + :param endpoint: The fallback endpoint. + """ + return redirect(request.args.get('next')) + # return redirect( + # request.args.get('next') or endpoint, **kwargs + # ) + + + diff --git a/examples/class_based_views/foo_bar.py b/examples/class_based_views/foo_bar.py new file mode 100644 index 00000000..b99ffe1b --- /dev/null +++ b/examples/class_based_views/foo_bar.py @@ -0,0 +1,23 @@ +# ipdb> class_definition.name +# 'Foo' +class Foo(): + def bar(evil_param): + command = 'echo ' + evil_param + ' >> ' + 'menu.txt' + + subprocess.call(command, shell=True) + + with open('menu.txt','r') as f: + menu = f.read() + + return render_template('command_injection.html', menu=menu) + + +def menu(param): + command = 'echo ' + param + ' >> ' + 'menu.txt' + + subprocess.call(command, shell=True) + + with open('menu.txt','r') as f: + menu = f.read() + + return render_template('command_injection.html', menu=menu) diff --git a/pyt/web_frameworks/framework_adaptor.py b/pyt/web_frameworks/framework_adaptor.py index 2bc4d7ee..d8a4c469 100644 --- a/pyt/web_frameworks/framework_adaptor.py +++ b/pyt/web_frameworks/framework_adaptor.py @@ -76,9 +76,28 @@ def find_route_functions_taint_args(self): Yields: CFG of each route function, with args marked as tainted. """ - for definition in _get_func_nodes(): - if self.is_route_function(definition.node): - yield self.get_func_cfg_with_tainted_args(definition) + for class_definition in _get_class_nodes(): + # import ipdb + # ipdb.set_trace() + # print(f'class_definition is {class_definition}') + # print(f'class_definition.name is {class_definition.name}') + + # ipdb> class_definition.module_definitions.definitions[0].name + # 'Foo' + # ipdb> class_definition.module_definitions.definitions[1].name + # 'Foo.bar' + # ipdb> class_definition.module_definitions.definitions[2].name + # 'menu' + + for definition in class_definition.module_definitions.definitions: + # print(f'definition inside class is {definition}') + # print(f'definition.name inside class is {definition.name}') + if definition.name.endswith('.get') or definition.name.endswith('.post'): + print(f'adding {definition.name}') + # print() + yield self.get_func_cfg_with_tainted_args(definition) + # if self.is_route_function(class_definition.node): + # yield self.get_func_cfg_with_tainted_args(class_definition) def run(self): """Run find_route_functions_taint_args on each CFG.""" @@ -92,3 +111,9 @@ def _get_func_nodes(): """Get all function nodes.""" return [definition for definition in project_definitions.values() if isinstance(definition.node, ast.FunctionDef)] + +def _get_class_nodes(): + """Get all function nodes.""" + return [definition for definition in project_definitions.values() + if isinstance(definition.node, ast.ClassDef)] +