diff --git a/server/mergin/sync/public_api_v2_controller.py b/server/mergin/sync/public_api_v2_controller.py index 3e28aa40..47ce4b3a 100644 --- a/server/mergin/sync/public_api_v2_controller.py +++ b/server/mergin/sync/public_api_v2_controller.py @@ -49,7 +49,7 @@ from .storages.disk import move_to_tmp, save_to_file from .utils import get_device_id, get_ip, get_user_agent, get_chunk_location from .workspace import WorkspaceRole -from ..utils import parse_order_params +from ..utils import parse_order_params, get_schema_fields_map @auth_required @@ -437,11 +437,15 @@ def list_workspace_projects(workspace_id, page, per_page, order_params=None, q=N projects = projects.filter(Project.name.ilike(f"%{q}%")) if order_params: - order_by_params = parse_order_params(Project, order_params) + schema_map = get_schema_fields_map(ProjectSchemaV2) + order_by_params = parse_order_params( + Project, order_params, field_map=schema_map + ) projects = projects.order_by(*order_by_params) - result = projects.paginate(page, per_page).items - total = projects.paginate(page, per_page).total + pagination = projects.paginate(page=page, per_page=per_page) + result = pagination.items + total = pagination.total data = ProjectSchemaV2(many=True).dump(result) return jsonify(projects=data, count=total, page=page, per_page=per_page), 200 diff --git a/server/mergin/tests/test_public_api_v2.py b/server/mergin/tests/test_public_api_v2.py index e058c589..49b62046 100644 --- a/server/mergin/tests/test_public_api_v2.py +++ b/server/mergin/tests/test_public_api_v2.py @@ -643,6 +643,17 @@ def test_list_workspace_projects(client): url + f"?page={page}&per_page={per_page}&q=1&order_params=created DESC" ) assert response.json["projects"][0]["name"] == "project_10" + # using field name instead column names for sorting + p4 = Project.query.filter(Project.name == project_name).first() + p4.disk_usage = 1234567 + db.session.commit() + response = client.get(url + f"?page=1&per_page=10&order_params=size DESC") + resp_data = json.loads(response.data) + assert resp_data["projects"][0]["name"] == project_name + + # invalid order param + response = client.get(url + f"?page=1&per_page=10&order_params=invalid DESC") + assert response.status_code == 200 # no permissions to workspace user2 = add_user("user", "password") diff --git a/server/mergin/tests/test_utils.py b/server/mergin/tests/test_utils.py index bf5f4666..00b3e1c6 100644 --- a/server/mergin/tests/test_utils.py +++ b/server/mergin/tests/test_utils.py @@ -7,6 +7,7 @@ import json import pytest from flask import url_for, current_app +from marshmallow import Schema, fields from sqlalchemy import desc import os from unittest.mock import patch @@ -14,7 +15,7 @@ from pygeodiff import GeoDiff from pathlib import PureWindowsPath -from ..utils import save_diagnostic_log_file +from ..utils import save_diagnostic_log_file, get_schema_fields_map from ..sync.utils import ( is_reserved_word, @@ -297,3 +298,27 @@ def test_save_diagnostic_log_file(client, app): with open(saved_file_path, "r") as f: content = f.read() assert content == body.decode("utf-8") + + +def test_get_schema_fields_map(): + """Test that schema map correctly resolves DB attributes, keeps all fields, and ignores virtual fields.""" + + # dummy schema for testing + class TestSchema(Schema): + # standard field -> map 'name': 'name' + name = fields.String() + # aliased field -> map 'size': 'disk_usage + size = fields.Integer(attribute="disk_usage") + # virtual fields -> skip + version = fields.Function(lambda obj: "v1") + role = fields.Method("get_role") + # excluded field - set to None in schema inheritance -> skip + hidden_field = None + + schema_map = get_schema_fields_map(TestSchema) + + expected_map = { + "name": "name", + "size": "disk_usage", + } + assert schema_map == expected_map diff --git a/server/mergin/utils.py b/server/mergin/utils.py index 9acc6124..7b062770 100644 --- a/server/mergin/utils.py +++ b/server/mergin/utils.py @@ -1,6 +1,8 @@ # Copyright (C) Lutra Consulting Limited # # SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-MerginMaps-Commercial +import logging + import math from collections import namedtuple from datetime import datetime, timedelta, timezone @@ -8,11 +10,11 @@ import os from flask import current_app from flask_sqlalchemy import Model +from marshmallow import Schema, fields from pathvalidate import sanitize_filename from sqlalchemy import Column, JSON from sqlalchemy.sql.elements import UnaryExpression -from typing import Optional - +from typing import Optional, Type OrderParam = namedtuple("OrderParam", "name direction") @@ -33,7 +35,7 @@ def split_order_param(order_param: str) -> Optional[OrderParam]: def get_order_param( - cls: Model, order_param: OrderParam, json_sort: dict = None + cls: Model, order_param: OrderParam, json_sort: dict = None, field_map: dict = None ) -> Optional[UnaryExpression]: """Return order by clause parameter for SQL query @@ -43,15 +45,22 @@ def get_order_param( :type order_param: OrderParam :param json_sort: type mapping for sort by json field, e.g. '{"storage": "int"}', defaults to None :type json_sort: dict + :param field_map: mapping for translating public field names to internal DB columns, e.g. '{"size": "disk_usage"}' + :type field_map: dict """ + # translate field name to column name + db_column_name = order_param.name + if field_map and order_param.name in field_map: + db_column_name = field_map[order_param.name] # find candidate for nested json sort - if "." in order_param.name: - col, attr = order_param.name.split(".") + if "." in db_column_name: + col, attr = db_column_name.split(".") else: - col = order_param.name + col = db_column_name attr = None order_attr = cls.__table__.c.get(col, None) if not isinstance(order_attr, Column): + logging.warning("Ignoring invalid order parameter.") return # sort by key in JSON field if attr: @@ -80,7 +89,9 @@ def get_order_param( return order_attr.desc() -def parse_order_params(cls: Model, order_params: str, json_sort: dict = None): +def parse_order_params( + cls: Model, order_params: str, json_sort: dict = None, field_map: dict = None +) -> list[UnaryExpression]: """Convert order parameters in query string to list of order by clauses. :param cls: Db model class @@ -89,6 +100,8 @@ def parse_order_params(cls: Model, order_params: str, json_sort: dict = None): :type order_params: str :param json_sort: type mapping for sort by json field, e.g. '{"storage": "int"}', defaults to None :type json_sort: dict + :param field_map: mapping response fields to database column names, e.g. '{"size": "disk_usage"}' + :type field_map: dict :rtype: List[Column] """ @@ -97,7 +110,7 @@ def parse_order_params(cls: Model, order_params: str, json_sort: dict = None): order_param = split_order_param(p) if not order_param: continue - order_attr = get_order_param(cls, order_param, json_sort) + order_attr = get_order_param(cls, order_param, json_sort, field_map) if order_attr is not None: order_by_params.append(order_attr) return order_by_params @@ -135,3 +148,27 @@ def save_diagnostic_log_file(app: str, username: str, body: bytes) -> str: f.write(content) return file_name + + +def get_schema_fields_map(schema: Type[Schema]) -> dict: + """ + Creates a mapping of schema field names to corresponding DB columns. + This allows sorting by the API field name (e.g. 'size') while + actually sorting by the database column (e.g. 'disk_usage'). + """ + mapping = {} + for name, field in schema._declared_fields.items(): + # some fields could have been overridden with None to be excluded + if not field: + continue + # skip virtual fields as DB cannot sort by them + if isinstance( + field, (fields.Function, fields.Method, fields.Nested, fields.List) + ): + continue + if field.attribute: + mapping[name] = field.attribute + # keep the map complete + else: + mapping[name] = name + return mapping