From c59a834c370881a744d4445f4f5d55ceb0455700 Mon Sep 17 00:00:00 2001 From: Michel Van den Bergh Date: Sat, 16 Mar 2024 18:38:43 +0000 Subject: [PATCH] Move all schemas to a single file schemas.py This reduces clutter and avoids duplicate definitions. Also: The run schema has been tightened. Various tweaks in other schemas. Requires upgrading the vtjson package. --- server/fishtest/actiondb.py | 175 +----------- server/fishtest/api.py | 91 +----- server/fishtest/rundb.py | 274 ++---------------- server/fishtest/schemas.py | 533 ++++++++++++++++++++++++++++++++++++ server/fishtest/userdb.py | 18 +- server/fishtest/views.py | 14 +- server/fishtest/workerdb.py | 13 +- server/tests/test_api.py | 2 +- server/tests/test_nn.py | 77 ++++++ server/tests/test_rundb.py | 4 +- 10 files changed, 668 insertions(+), 533 deletions(-) create mode 100644 server/fishtest/schemas.py create mode 100644 server/tests/test_nn.py diff --git a/server/fishtest/actiondb.py b/server/fishtest/actiondb.py index b2775b1e4..df0d5b01e 100644 --- a/server/fishtest/actiondb.py +++ b/server/fishtest/actiondb.py @@ -1,170 +1,10 @@ from datetime import datetime, timezone from bson.objectid import ObjectId +from fishtest.schemas import action_schema from fishtest.util import hex_print, worker_name from pymongo import DESCENDING -from vtjson import regex, union, validate - -run_id = regex(r"[a-f0-9]{24}", name="run_id") -run_name = regex(r".*-[a-f0-9]{7}", name="run_name") -short_worker_name = regex(r".*-[\d]+cores-[a-zA-Z0-9]{2,8}", name="short_worker_name") -long_worker_name = regex( - r".*-[\d]+cores-[a-zA-Z0-9]{2,8}-[a-f0-9]{4}\*?", name="long_worker_name" -) - -schema = union( - { - "_id?": ObjectId, - "time": float, - "action": "failed_task", - "username": str, - "worker": long_worker_name, - "run_id": run_id, - "run": run_name, - "task_id": int, - "message": str, - }, - { - "_id?": ObjectId, - "time": float, - "action": "crash_or_time", - "username": str, - "worker": long_worker_name, - "run_id": run_id, - "run": run_name, - "task_id": int, - "message": str, - }, - { - "_id?": ObjectId, - "time": float, - "action": "dead_task", - "username": str, - "worker": long_worker_name, - "run_id": run_id, - "run": run_name, - "task_id": int, - }, - { - "_id?": ObjectId, - "time": float, - "action": "system_event", - "username": "fishtest.system", - "message": str, - }, - { - "_id?": ObjectId, - "time": float, - "action": "new_run", - "username": str, - "run_id": run_id, - "run": run_name, - "message": str, - }, - { - "_id?": ObjectId, - "time": float, - "action": "upload_nn", - "username": str, - "nn": str, - }, - { - "_id?": ObjectId, - "time": float, - "action": "modify_run", - "username": str, - "run_id": run_id, - "run": run_name, - "message": str, - }, - { - "_id?": ObjectId, - "time": float, - "action": "delete_run", - "username": str, - "run_id": run_id, - "run": run_name, - }, - { - "_id?": ObjectId, - "time": float, - "action": "stop_run", - "username": str, - "worker": long_worker_name, - "run_id": run_id, - "run": run_name, - "task_id": int, - "message": str, - }, - { - "_id?": ObjectId, - "time": float, - "action": "stop_run", - "username": str, - "run_id": run_id, - "run": run_name, - "message": str, - }, - { - "_id?": ObjectId, - "time": float, - "action": "finished_run", - "username": str, - "run_id": run_id, - "run": run_name, - "message": str, - }, - { - "_id?": ObjectId, - "time": float, - "action": "approve_run", - "username": str, - "run_id": run_id, - "run": run_name, - }, - { - "_id?": ObjectId, - "time": float, - "action": "purge_run", - "username": str, - "run_id": run_id, - "run": run_name, - "message": str, - }, - { - "_id?": ObjectId, - "time": float, - "action": "block_user", - "username": str, - "user": str, - "message": union("blocked", "unblocked"), - }, - { - "_id?": ObjectId, - "time": float, - "action": "accept_user", - "username": str, - "user": str, - "message": "accepted", - }, - { - "_id?": ObjectId, - "time": float, - "action": "block_worker", - "username": str, - "worker": short_worker_name, - "message": union("blocked", "unblocked"), - }, - { - "_id?": ObjectId, - "time": float, - "action": "log_message", - "username": str, - "message": str, - }, -) - -del long_worker_name, run_id, run_name, short_worker_name +from vtjson import ValidationError, validate def run_name(run): @@ -376,5 +216,14 @@ def insert_action(self, **action): if "run_id" in action: action["run_id"] = str(action["run_id"]) action["time"] = datetime.now(timezone.utc).timestamp() - validate(schema, action, "action") # may raise exception + try: + validate(action_schema, action, "action") + except ValidationError: + message = f"Internal Error. Request {str(action)} does not validate" + print(message, flush=True) + self.log_message( + username="fishtest.system", + message=message, + ) + return self.actions.insert_one(action) diff --git a/server/fishtest/api.py b/server/fishtest/api.py index 9b2fb9920..43179523d 100644 --- a/server/fishtest/api.py +++ b/server/fishtest/api.py @@ -4,6 +4,7 @@ import re from datetime import datetime, timezone +from fishtest.schemas import api_access_schema, api_schema from fishtest.stats.stat_util import SPRT_elo from fishtest.util import worker_name from pyramid.httpexceptions import ( @@ -14,7 +15,7 @@ ) from pyramid.response import Response from pyramid.view import exception_view_config, view_config, view_defaults -from vtjson import ValidationError, compile, intersect, interval, lax, regex, validate +from vtjson import ValidationError, validate """ Important note @@ -33,83 +34,6 @@ WORKER_VERSION = 232 -""" -begin api_schema -""" - -run_id = regex(r"[a-f0-9]{24}", name="run_id") -uuid = regex(r"[0-9a-zA-Z]{2,8}(-[a-f0-9]{4}){3}-[a-f0-9]{12}", name="uuid") - -uint = intersect(int, interval(0, ...)) -suint = intersect(int, interval(1, ...)) -ufloat = intersect(float, interval(0.0, ...)) - - -def valid_results(R): - l, d, w = R["losses"], R["draws"], R["wins"] - R = R["pentanomial"] - return ( - l + d + w == 2 * sum(R) - and w - l == 2 * R[4] + R[3] - R[1] - 2 * R[0] - and R[3] + 2 * R[2] + R[1] >= d >= R[3] + R[1] - ) - - -def valid_spsa_results(R): - return R["wins"] + R["losses"] + R["draws"] == R["num_games"] - - -api_schema = { - "password": str, - "run_id?": run_id, - "task_id?": uint, - "pgn?": str, - "message?": str, - "worker_info": { - "uname": str, - "architecture": [str, str], - "concurrency": suint, - "max_memory": suint, - "min_threads": suint, - "username": str, - "version": uint, - "python_version": [uint, uint, uint], - "gcc_version": [uint, uint, uint], - "compiler": {"g++", "clang++"}, - "unique_key": uuid, - "modified": bool, - "near_github_api_limit": bool, - "ARCH": str, - "nps": ufloat, - }, - "spsa?": intersect( - { - "wins": uint, - "losses": uint, - "draws": uint, - "num_games": uint, - }, - valid_spsa_results, - ), - "stats?": intersect( - { - "wins": uint, - "losses": uint, - "draws": uint, - "crashes": uint, - "time_losses": uint, - "pentanomial": [uint, uint, uint, uint, uint], - }, - valid_results, - ), -} - -api_schema = compile(api_schema) - -""" -end api_schema -""" - def validate_request(request): validate(api_schema, request, "request") @@ -180,9 +104,8 @@ def validate_username_password(self, api): self.handle_error("request is not json encoded") # Is the request syntactically correct? - schema = lax({"password": str, "worker_info": {"username": str}}) try: - validate(schema, self.request_body, "request") + validate(api_access_schema, self.request_body, "request") except ValidationError as e: self.handle_error(str(e)) @@ -218,13 +141,11 @@ def validate_request(self, api): self.handle_error("Invalid run_id: {}".format(run_id)) self.__run = run - # if a task_id is present then there should be a run_id, and - # the unique_key should correspond to the unique_key of the - # task + # if a task_id is present then the unique_key should correspond + # to the unique_key of the task + if "task_id" in self.request_body: task_id = self.request_body["task_id"] - if "run_id" not in self.request_body: - self.handle_error("The request has a task_id but no run_id") if task_id < 0 or task_id >= len(run["tasks"]): self.handle_error( diff --git a/server/fishtest/rundb.py b/server/fishtest/rundb.py index 90890faf0..de0ddf623 100644 --- a/server/fishtest/rundb.py +++ b/server/fishtest/rundb.py @@ -17,6 +17,7 @@ from bson.binary import Binary from bson.objectid import ObjectId from fishtest.actiondb import ActionDb +from fishtest.schemas import nn_schema, runs_schema from fishtest.stats.stat_util import SPRT_elo from fishtest.userdb import UserDb from fishtest.util import ( @@ -35,7 +36,7 @@ ) from fishtest.workerdb import WorkerDb from pymongo import DESCENDING, MongoClient -from vtjson import ValidationError, ip_address, number, regex, union, url, validate +from vtjson import ValidationError, validate DEBUG = False @@ -43,198 +44,6 @@ last_rundb = None -# This schema only matches new runs. The old runs are not -# compatible with it. For documentation purposes it would -# also be useful to have a "universal schema" that matches -# all the runs in the db. -# To make this practical we will eventually put all schemas -# in a separate module "schemas.py". - -net_name = regex("nn-[a-f0-9]{12}.nnue", name="net_name") -tc = regex(r"([1-9]\d*/)?\d+(\.\d+)?(\+\d+(\.\d+)?)?", name="tc") -str_int = regex(r"[1-9]\d*", name="str_int") -sha = regex(r"[a-f0-9]{40}", name="sha") -country_code = regex(r"[A-Z][A-Z]", name="country_code") -run_id = regex(r"[a-f0-9]{24}", name="run_id") -uuid = regex(r"[0-9a-zA-Z]{2,8}(-[a-f0-9]{4}){3}-[a-f0-9]{12}", name="uuid") - -worker_info_schema = { - "uname": str, - "architecture": [str, str], - "concurrency": int, - "max_memory": int, - "min_threads": int, - "username": str, - "version": int, - "python_version": [int, int, int], - "gcc_version": [int, int, int], - "compiler": union("clang++", "g++"), - "unique_key": uuid, - "modified": bool, - "ARCH": str, - "nps": number, - "near_github_api_limit": bool, - "remote_addr": ip_address, - "country_code": union(country_code, "?"), -} - -results_schema = { - "wins": int, - "losses": int, - "draws": int, - "crashes": int, - "time_losses": int, - "pentanomial": [int, int, int, int, int], -} - -schema = { - "_id?": ObjectId, - "start_time": datetime, - "last_updated": datetime, - "tc_base": number, - "base_same_as_master": bool, - "rescheduled_from?": run_id, - "approved": bool, - "approver": str, - "finished": bool, - "deleted": bool, - "failed": bool, - "is_green": bool, - "is_yellow": bool, - "workers": int, - "cores": int, - "results": results_schema, - "results_info?": { - "style": str, - "info": [str, ...], - }, - "args": { - "base_tag": str, - "new_tag": str, - "base_nets": [net_name, ...], - "new_nets": [net_name, ...], - "num_games": int, - "tc": tc, - "new_tc": tc, - "book": str, - "book_depth": str_int, - "threads": int, - "resolved_base": sha, - "resolved_new": sha, - "msg_base": str, - "msg_new": str, - "base_options": str, - "new_options": str, - "info": str, - "base_signature": str_int, - "new_signature": str_int, - "username": str, - "tests_repo": url, - "auto_purge": bool, - "throughput": number, - "itp": number, - "priority": number, - "adjudication": bool, - "sprt?": { - "alpha": 0.05, - "beta": 0.05, - "elo0": number, - "elo1": number, - "elo_model": "normalized", - "state": union("", "accepted", "rejected"), - "llr": number, - "batch_size": int, - "lower_bound": -math.log(19), - "upper_bound": math.log(19), - "lost_samples?": int, - "illegal_update?": int, - "overshoot?": { - "last_update": int, - "skipped_updates": int, - "ref0": number, - "m0": number, - "sq0": number, - "ref1": number, - "m1": number, - "sq1": number, - }, - }, - "spsa?": { - "A": number, - "alpha": number, - "gamma": number, - "raw_params": str, - "iter": int, - "num_iter": int, - "params": [ - { - "name": str, - "start": number, - "min": number, - "max": number, - "c_end": number, - "r_end": number, - "c": number, - "a_end": number, - "a": number, - "theta": number, - }, - ..., - ], - "param_history?": [ - [{"theta": number, "R": number, "c": number}, ...], - ..., - ], - }, - }, - "tasks": [ - { - "num_games": int, - "active": bool, - "last_updated": datetime, - "start": int, - "residual?": number, - "residual_color?": str, - "bad?": True, - "stats": results_schema, - "worker_info": worker_info_schema, - }, - ..., - ], - "bad_tasks?": [ - { - "num_games": int, - "active": False, - "last_updated": datetime, - "start": int, - "residual": number, - "residual_color": str, - "bad": True, - "task_id": int, - "stats": results_schema, - "worker_info": worker_info_schema, - }, - ..., - ], -} - -# Avoid leaking too many things into the global scope -del ( - country_code, - ip_address, - number, - regex, - results_schema, - run_id, - sha, - str_int, - tc, - union, - url, - uuid, - worker_info_schema, -) - def get_port(): params = {} @@ -435,7 +244,7 @@ def new_run( new_run["rescheduled_from"] = rescheduled_from try: - validate(schema, new_run, "run") + validate(runs_schema, new_run, "run") except ValidationError as e: message = f"The new run object does not validate: {str(e)}" print(message, flush=True) @@ -469,19 +278,27 @@ def get_run_pgns(self, run_id): return pgns_tar return None - def upload_nn(self, userid, name, nn): - self.nndb.insert_one({"user": userid, "name": name, "downloads": 0}) - return {} - - def update_nn(self, net): - net.pop("downloads", None) - self.nndb.update_one({"name": net["name"]}, {"$set": net}) + def write_nn(self, net): + validate(nn_schema, net, "net") + self.nndb.replace_one({"name": net["name"]}, net, upsert=True) def get_nn(self, name): return self.nndb.find_one({"name": name}, {"nn": 0}) + def upload_nn(self, userid, name): + self.write_nn({"user": userid, "name": name, "downloads": 0}) + + def update_nn(self, net): + net = copy.copy(net) # avoid side effects + net.pop("downloads", None) + old_net = self.get_nn(net["name"]) + old_net.update(net) + self.write_nn(old_net) + def increment_nn_downloads(self, name): - self.nndb.update_one({"name": name}, {"$inc": {"downloads": 1}}) + net = self.get_nn(name) + net["downloads"] += 1 + self.write_nn(net) def get_nns(self, user="", network_name="", master_only=False, limit=0, skip=0): q = {} @@ -1460,8 +1277,6 @@ def count_games(d): # Return. if run_finished: - self.check_results(run, run_id, task_id) - self.stop_run(run_id) # stop run may not actually stop a run because of autopurging! if run["finished"]: @@ -1479,55 +1294,6 @@ def count_games(d): return ret - def check_results(self, run, run_id, task_id): - old = run["results"] - - # Recalculate results from all tasks in run["tasks"]. - new = self.compute_results(run) - - # Log any discrepancies between incremented and recalculated results - for s in ["wins", "losses", "draws", "crashes", "time_losses"]: - if old.get(s, -1) != new.get(s, -1): - info = "Check_results: task {}/{} {} results mismatch: {}/{}".format( - run_id, task_id, s, old.get(s, -1), new.get(s, -1) - ) - self.actiondb.log_message( - username="fishtest.system", - message=info, - ) - print(info, flush=True) - - if ( - "pentanomial" not in old - or "pentanomial" not in new - or len(old["pentanomial"]) < 5 - or len(new["pentanomial"]) < 5 - ): - info = "Check_results: task {}/{} pentanomial length results mismatch: {}/{}".format( - run_id, - task_id, - len(old.get("pentanomial", [])), - len(new.get("pentanomial", [])), - ) - self.actiondb.log_message( - username="fishtest.system", - message=info, - ) - print(info, flush=True) - else: - for i, (old_value, new_value) in enumerate( - zip(old["pentanomial"], new["pentanomial"]) - ): - if old_value != new_value: - info = "Check_results: task {}/{} pentanomial value {} results mismatch: {}/{}".format( - run_id, task_id, i, old_value, new_value - ) - self.actiondb.log_message( - username="fishtest.system", - message=info, - ) - print(info, flush=True) - def failed_task(self, run_id, task_id, message="Unknown reason"): run = self.get_run(run_id) task = run["tasks"][task_id] @@ -1577,7 +1343,7 @@ def stop_run(self, run_id): run["workers"] = 0 run["finished"] = True try: - validate(schema, run, "run") + validate(runs_schema, run, "run") except ValidationError as e: message = f"The run object {run_id} does not validate: {str(e)}" print(message, flush=True) diff --git a/server/fishtest/schemas.py b/server/fishtest/schemas.py new file mode 100644 index 000000000..7d43b5953 --- /dev/null +++ b/server/fishtest/schemas.py @@ -0,0 +1,533 @@ +# This file describes some of the data structures used by Fishtest so that they +# can be statically validated before they are processed further or written +# to the database. +# +# See https://github.com/vdbergh/vtjson for a description of the schema format. + +import copy +import math +from datetime import datetime, timezone + +from bson.objectid import ObjectId +from vtjson import ( + at_least_one_of, + at_most_one_of, + div, + email, + glob, + ifthen, + intersect, + interval, + ip_address, + keys, + lax, + number, + one_of, + quote, + regex, + set_name, + size, + union, + url, +) + +run_id = regex(r"[a-f0-9]{24}", name="run_id") +run_name = intersect(regex(r".*-[a-f0-9]{7}", name="run_name"), size(0, 23 + 1 + 7)) +action_message = intersect(str, size(0, 1024)) +worker_message = intersect(str, size(0, 500)) +short_worker_name = regex(r".*-[\d]+cores-[a-zA-Z0-9]{2,8}", name="short_worker_name") +long_worker_name = regex( + r".*-[\d]+cores-[a-zA-Z0-9]{2,8}-[a-f0-9]{4}\*?", name="long_worker_name" +) +net_name = regex("nn-[a-f0-9]{12}.nnue", name="net_name") +tc = regex(r"([1-9]\d*/)?\d+(\.\d+)?(\+\d+(\.\d+)?)?", name="tc") +str_int = regex(r"[1-9]\d*", name="str_int") +sha = regex(r"[a-f0-9]{40}", name="sha") +uuid = regex(r"[0-9a-zA-Z]{2,8}(-[a-f0-9]{4}){3}-[a-f0-9]{12}", name="uuid") +country_code = regex(r"[A-Z][A-Z]", name="country_code") +epd_file = glob("*.epd", name="epd_file") +pgn_file = glob("*.pgn", name="pgn_file") +even = div(2, name="even") + + +uint = intersect(int, interval(0, ...)) +suint = intersect(int, interval(1, ...)) +ufloat = intersect(float, interval(0.0, ...)) +unumber = intersect(number, interval(0, ...)) + +user_schema = { + "_id?": ObjectId, + "username": str, + "password": str, + "registration_time": datetime, + "pending": bool, + "blocked": bool, + "email": email, + "groups": [str, ...], + "tests_repo": union("", url), + "machine_limit": int, +} + + +worker_schema = { + "_id?": ObjectId, + "worker_name": short_worker_name, + "blocked": bool, + "message": worker_message, + "last_updated": datetime, +} + + +def first_test_before_last(x): + # Pymongo is not timezone aware. Assume dates are UTC. + f = x["first_test"]["date"].replace(tzinfo=timezone.utc) + l = x["last_test"]["date"].replace(tzinfo=timezone.utc) + if f <= l: + return True + else: + raise Exception( + f"The first test at {str(f)} is later than the last test at {str(l)}" + ) + + +nn_schema = intersect( + { + "_id?": ObjectId, + "downloads": uint, + "first_test?": {"date": datetime, "id": run_id}, + "is_master?": True, + "last_test?": {"date": datetime, "id": run_id}, + "name": net_name, + "user": str, + }, + ifthen( + at_least_one_of("first_test", "last_test"), + intersect( + keys("first_test", "last_test"), + first_test_before_last, + ), + ), + ifthen(keys("is_master"), keys("first_test")), +) + +# not yet used, not tested +contributors_schema = { + "_id": ObjectId, + "cpu_hours": unumber, + "diff": unumber, + "games": uint, + "games_per_hour": unumber, + "last_updated": datetime, + "str_last_updated": str, + "tests": uint, + "tests_repo": union(url, ""), + "username": str, +} + + +action_schema = union( + { + "_id?": ObjectId, + "time": float, + "action": "failed_task", + "username": str, + "worker": long_worker_name, + "run_id": run_id, + "run": run_name, + "task_id": int, + "message": action_message, + }, + { + "_id?": ObjectId, + "time": float, + "action": "crash_or_time", + "username": str, + "worker": long_worker_name, + "run_id": run_id, + "run": run_name, + "task_id": int, + "message": action_message, + }, + { + "_id?": ObjectId, + "time": float, + "action": "dead_task", + "username": str, + "worker": long_worker_name, + "run_id": run_id, + "run": run_name, + "task_id": int, + }, + { + "_id?": ObjectId, + "time": float, + "action": "system_event", + "username": "fishtest.system", + "message": action_message, + }, + { + "_id?": ObjectId, + "time": float, + "action": "new_run", + "username": str, + "run_id": run_id, + "run": run_name, + "message": action_message, + }, + { + "_id?": ObjectId, + "time": float, + "action": "upload_nn", + "username": str, + "nn": str, + }, + { + "_id?": ObjectId, + "time": float, + "action": "modify_run", + "username": str, + "run_id": run_id, + "run": run_name, + "message": action_message, + }, + { + "_id?": ObjectId, + "time": float, + "action": "delete_run", + "username": str, + "run_id": run_id, + "run": run_name, + }, + intersect( + { + "_id?": ObjectId, + "time": float, + "action": "stop_run", + "username": str, + "run_id": run_id, + "run": run_name, + "message": action_message, + "worker?": long_worker_name, + "task_id?": int, + }, + ifthen(at_least_one_of("worker", "task_id"), keys("worker", "task_id")), + ), + { + "_id?": ObjectId, + "time": float, + "action": "finished_run", + "username": str, + "run_id": run_id, + "run": run_name, + "message": action_message, + }, + { + "_id?": ObjectId, + "time": float, + "action": "approve_run", + "username": str, + "run_id": run_id, + "run": run_name, + }, + { + "_id?": ObjectId, + "time": float, + "action": "purge_run", + "username": str, + "run_id": run_id, + "run": run_name, + "message": action_message, + }, + { + "_id?": ObjectId, + "time": float, + "action": "block_user", + "username": str, + "user": str, + "message": union("blocked", "unblocked"), + }, + { + "_id?": ObjectId, + "time": float, + "action": "accept_user", + "username": str, + "user": str, + "message": "accepted", + }, + { + "_id?": ObjectId, + "time": float, + "action": "block_worker", + "username": str, + "worker": short_worker_name, + "message": union("blocked", "unblocked"), + }, + { + "_id?": ObjectId, + "time": float, + "action": "log_message", + "username": str, + "message": action_message, + }, +) + + +worker_info_schema_api = { + "uname": str, + "architecture": [str, str], + "concurrency": suint, + "max_memory": uint, + "min_threads": suint, + "username": str, + "version": uint, + "python_version": [uint, uint, uint], + "gcc_version": [uint, uint, uint], + "compiler": union("clang++", "g++"), + "unique_key": uuid, + "modified": bool, + "ARCH": str, + "nps": unumber, + "near_github_api_limit": bool, +} + +worker_info_schema_runs = copy.deepcopy(worker_info_schema_api) +worker_info_schema_runs.update( + {"remote_addr": ip_address, "country_code": union(country_code, "?")} +) + + +def valid_results(R): + l, d, w = R["losses"], R["draws"], R["wins"] + R = R["pentanomial"] + return ( + l + d + w == 2 * sum(R) + and w - l == 2 * R[4] + R[3] - R[1] - 2 * R[0] + and R[3] + 2 * R[2] + R[1] >= d >= R[3] + R[1] + ) + + +results_schema = intersect( + { + "wins": uint, + "losses": uint, + "draws": uint, + "crashes": uint, + "time_losses": uint, + "pentanomial": [uint, uint, uint, uint, uint], + }, + valid_results, +) + + +def valid_spsa_results(R): + return R["wins"] + R["losses"] + R["draws"] == R["num_games"] + + +api_access_schema = lax({"password": str, "worker_info": {"username": str}}) + +api_schema = intersect( + { + "password": str, + "run_id?": run_id, + "task_id?": uint, + "pgn?": str, + "message?": str, + "worker_info": worker_info_schema_api, + "spsa?": intersect( + { + "wins": uint, + "losses": uint, + "draws": uint, + "num_games": intersect(uint, even), + }, + valid_spsa_results, + ), + "stats?": results_schema, + }, + ifthen(keys("task_id"), keys("run_id")), +) + + +zero_results = { + "wins": 0, + "draws": 0, + "losses": 0, + "crashes": 0, + "time_losses": 0, + "pentanomial": 5 * [0], +} + + +if_bad_then_zero_stats_and_not_active = ifthen( + keys("bad"), lax({"active": False, "stats": quote(zero_results)}) +) + + +def final_results_must_match(run): + rr = copy.deepcopy(zero_results) + for t in run["tasks"]: + r = t["stats"] + for k in r: + if k != "pentanomial": + rr[k] += r[k] + else: + for i, p in enumerate(r["pentanomial"]): + rr[k][i] += p + if rr != run["results"]: + raise Exception( + f"The final results {run['results']} do not match the computed results {rr}" + ) + else: + return True + + +# The following schema only matches new runs. The old runs +# are not compatible with it. For documentation purposes +# it would also be useful to have a "universal schema" +# that matches all the runs in the db. + +runs_schema = intersect( + { + "_id?": ObjectId, + "start_time": datetime, + "last_updated": datetime, + "tc_base": unumber, + "base_same_as_master": bool, + "rescheduled_from?": run_id, + "approved": bool, + "approver": str, + "finished": bool, + "deleted": bool, + "failed": bool, + "is_green": bool, + "is_yellow": bool, + "workers": uint, + "cores": uint, + "results": results_schema, + "results_info?": { + "style": str, + "info": [str, ...], + }, + "args": intersect( + { + "base_tag": str, + "new_tag": str, + "base_nets": [net_name, ...], + "new_nets": [net_name, ...], + "num_games": intersect(uint, even), + "tc": tc, + "new_tc": tc, + "book": union(epd_file, pgn_file), + "book_depth": str_int, + "threads": suint, + "resolved_base": sha, + "resolved_new": sha, + "msg_base": str, + "msg_new": str, + "base_options": str, + "new_options": str, + "info": str, + "base_signature": str_int, + "new_signature": str_int, + "username": str, + "tests_repo": url, + "auto_purge": bool, + "throughput": unumber, + "itp": unumber, + "priority": number, + "adjudication": bool, + "sprt?": intersect( + { + "alpha": 0.05, + "beta": 0.05, + "elo0": number, + "elo1": number, + "elo_model": "normalized", + "state": union("", "accepted", "rejected"), + "llr": number, + "batch_size": suint, + "lower_bound": -math.log(19), + "upper_bound": math.log(19), + "lost_samples?": uint, + "illegal_update?": uint, + "overshoot?": { + "last_update": uint, + "skipped_updates": uint, + "ref0": number, + "m0": number, + "sq0": unumber, + "ref1": number, + "m1": number, + "sq1": unumber, + }, + }, + one_of("overshoot", "lost_samples"), + ), + "spsa?": { + "A": unumber, + "alpha": unumber, + "gamma": unumber, + "raw_params": str, + "iter": uint, + "num_iter": uint, + "params": [ + { + "name": str, + "start": number, + "min": number, + "max": number, + "c_end": unumber, + "r_end": unumber, + "c": unumber, + "a_end": unumber, + "a": unumber, + "theta": number, + }, + ..., + ], + "param_history?": [ + [ + {"theta": number, "R": unumber, "c": unumber}, + ..., + ], + ..., + ], + }, + }, + at_most_one_of("sprt", "spsa"), + ), + "tasks": [ + intersect( + { + "num_games": intersect(uint, even), + "active": bool, + "last_updated": datetime, + "start": uint, + "residual?": number, + "residual_color?": str, + "bad?": True, + "stats": results_schema, + "worker_info": worker_info_schema_runs, + }, + if_bad_then_zero_stats_and_not_active, + ), + ..., + ], + "bad_tasks?": [ + { + "num_games": intersect(uint, even), + "active": False, + "last_updated": datetime, + "start": uint, + "residual": number, + "residual_color": str, + "bad": True, + "task_id": uint, + "stats": results_schema, + "worker_info": worker_info_schema_runs, + }, + ..., + ], + }, + final_results_must_match, +) diff --git a/server/fishtest/userdb.py b/server/fishtest/userdb.py index 962893c4c..688bf536d 100644 --- a/server/fishtest/userdb.py +++ b/server/fishtest/userdb.py @@ -4,28 +4,16 @@ from datetime import datetime, timezone from bson.objectid import ObjectId +from fishtest.schemas import user_schema from pymongo import ASCENDING -from vtjson import ValidationError, email, union, url, validate - -schema = { - "_id?": ObjectId, - "username": str, - "password": str, - "registration_time": datetime, - "pending": bool, - "blocked": bool, - "email": email, - "groups": [str, ...], - "tests_repo": union("", url), - "machine_limit": int, -} +from vtjson import ValidationError, validate DEFAULT_MACHINE_LIMIT = 16 def validate_user(user): try: - validate(schema, user, "user") + validate(user_schema, user, "user") except ValidationError as e: print(valid, flush=True) raise diff --git a/server/fishtest/views.py b/server/fishtest/views.py index 9c4d70326..e5f3150a2 100644 --- a/server/fishtest/views.py +++ b/server/fishtest/views.py @@ -11,6 +11,7 @@ import fishtest.stats.stat_util import requests from bson.objectid import ObjectId +from fishtest.schemas import short_worker_name from fishtest.util import ( email_valid, format_bounds, @@ -25,6 +26,7 @@ from pyramid.security import forget, remember from pyramid.view import forbidden_view_config, view_config from requests.exceptions import ConnectionError, HTTPError +from vtjson import ValidationError, union, validate HTTP_TIMEOUT = 15.0 @@ -188,7 +190,15 @@ def workers(request): w["subject"] = f"Issue(s) with worker {w['worker_name']}" worker_name = request.matchdict.get("worker_name") - # TODO. Do more validation of worker names + try: + validate(union(short_worker_name, "show"), worker_name, name="worker_name") + except ValidationError as e: + request.session.flash(str(e), "error") + return { + "show_admin": False, + "show_email": is_approver, + "blocked_workers": blocked_workers, + } if len(worker_name.split("-")) != 3: return { "show_admin": False, @@ -318,7 +328,7 @@ def upload(request): request.session.flash("Network already exists", "error") return {} - request.rundb.upload_nn(request.authenticated_userid, filename, network) + request.rundb.upload_nn(request.authenticated_userid, filename) request.actiondb.upload_nn( username=request.authenticated_userid, diff --git a/server/fishtest/workerdb.py b/server/fishtest/workerdb.py index a9f838e60..59c576a71 100644 --- a/server/fishtest/workerdb.py +++ b/server/fishtest/workerdb.py @@ -1,18 +1,9 @@ from datetime import datetime, timezone from bson.objectid import ObjectId +from fishtest.schemas import worker_schema from vtjson import regex, validate -short_worker_name = regex(r".*-[\d]+cores-[a-zA-Z0-9]{2,8}", name="short_worker_name") - -schema = { - "_id?": ObjectId, - "worker_name": short_worker_name, - "blocked": bool, - "message": str, - "last_updated": datetime, -} - class WorkerDb: def __init__(self, db): @@ -44,7 +35,7 @@ def update_worker(self, worker_name, blocked=None, message=None): "message": message, "last_updated": datetime.now(timezone.utc), } - validate(schema, r, "worker") # may throw exception + validate(worker_schema, r, "worker") # may throw exception self.workers.replace_one({"worker_name": worker_name}, r, upsert=True) def get_blocked_workers(self): diff --git a/server/tests/test_api.py b/server/tests/test_api.py index 3ae264dfe..2a0a80149 100644 --- a/server/tests/test_api.py +++ b/server/tests/test_api.py @@ -27,7 +27,7 @@ def new_run(self, add_tasks=0): num_games, "10+0.01", "10+0.01", - "book", + "book.pgn", "10", 1, "", diff --git a/server/tests/test_nn.py b/server/tests/test_nn.py new file mode 100644 index 000000000..b89776e9b --- /dev/null +++ b/server/tests/test_nn.py @@ -0,0 +1,77 @@ +import unittest +from datetime import datetime, timezone + +from util import get_rundb +from vtjson import ValidationError + + +def show(mc): + exception = mc.exception + print(f"{exception.__class__.__name__}: {str(mc.exception)}") + + +class TestNN(unittest.TestCase): + def setUp(self): + self.rundb = get_rundb() + self.name = "nn-0000000000a0.nnue" + self.user = "user00" + self.first_test = datetime(2024, 1, 1) + self.last_test = datetime(2024, 3, 24) + self.last_test_old = datetime(2023, 3, 24) + self.run_id = "64e74776a170cb1f26fa3930" + + def tearDown(self): + self.rundb.nndb.delete_many({}) + + def test_nn(self): + self.rundb.upload_nn(self.user, self.name) + net = self.rundb.get_nn(self.name) + del net["_id"] + self.assertEqual(net, {"user": self.user, "name": self.name, "downloads": 0}) + self.rundb.increment_nn_downloads(self.name) + net = self.rundb.get_nn(self.name) + del net["_id"] + self.assertEqual(net, {"user": self.user, "name": self.name, "downloads": 1}) + with self.assertRaises(ValidationError) as mc: + new_net = { + "user": self.user, + "name": self.name, + "downloads": 0, + "first_test": {"date": self.first_test, "id": self.run_id}, + "is_master": True, + } + self.rundb.update_nn(new_net) + show(mc) + with self.assertRaises(ValidationError) as mc: + new_net = { + "user": self.user, + "name": self.name, + "downloads": 0, + "is_master": True, + } + self.rundb.update_nn(new_net) + show(mc) + with self.assertRaises(ValidationError) as mc: + new_net = { + "user": self.user, + "name": self.name, + "downloads": 0, + "first_test": {"date": self.first_test, "id": self.run_id}, + "is_master": True, + "last_test": {"date": self.last_test_old, "id": self.run_id}, + } + self.rundb.update_nn(new_net) + show(mc) + new_net = { + "user": self.user, + "name": self.name, + "downloads": 0, + "first_test": {"date": self.first_test, "id": self.run_id}, + "is_master": True, + "last_test": {"date": self.last_test, "id": self.run_id}, + } + self.rundb.update_nn(new_net) + net = self.rundb.get_nn(self.name) + del net["_id"] + new_net["downloads"] = 1 + self.assertEqual(net, new_net) diff --git a/server/tests/test_rundb.py b/server/tests/test_rundb.py index 01fd676f0..6fd0d0e93 100644 --- a/server/tests/test_rundb.py +++ b/server/tests/test_rundb.py @@ -60,7 +60,7 @@ def test_10_create_run(self): num_games, "10+0.01", "10+0.01", - "book", + "book.pgn", "10", 1, "", @@ -101,7 +101,7 @@ def test_10_create_run(self): num_games, "10+0.01", "10+0.01", - "book", + "book.pgn", "10", 1, "",