Skip to content

Allowed to overwrite resource id in serializer #1127

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 13, 2023
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ Jeppe Fihl-Pearson <jeppe@tenzer.dk>
Jerel Unruh <mail@unruhdesigns.com>
Jonas Kiefer <https://github.com/jokiefer>
Jonas Metzener <jonas.metzener@adfinis.com>
Jonathan Hiles <jonathan@hil.es>
Jonathan Senecal <contact@jonathansenecal.com>
Joseba Mendivil <git@jma.email>
Kal <kal+oss@tedspot.com>
Expand Down
17 changes: 17 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,23 @@ any parts of the framework not mentioned in the documentation should generally b
* Replaced `OrderedDict` with `dict` which is also ordered since Python 3.7.
* Compound document "include" parameter is only included in the OpenAPI schema if serializer
implements `included_serializers`.
* Allowed overwriting of resource id by defining an `id` field on the serializer.

Example:
```python
class CustomIdSerializer(serializers.Serializer):
id = serializers.CharField(source='name')
body = serializers.CharField()
```

* Allowed overwriting resource id on resource related fields by creating custom `ResourceRelatedField`.

Example:
```python
class CustomResourceRelatedField(relations.ResourceRelatedField):
def get_resource_id(self, value):
return value.name
```

### Fixed

Expand Down
38 changes: 38 additions & 0 deletions docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,44 @@ class MyModelSerializer(serializers.ModelSerializer):
# ...
```

### Overwriting the resource object's id

Per default the primary key property `pk` on the instance is used as the resource identifier.

It is possible to overwrite the resource id by defining an `id` field on the serializer like:

```python
class UserSerializer(serializers.ModelSerializer):
id = serializers.CharField(source='email')
name = serializers.CharField()

class Meta:
model = User
```

This also works on generic serializers.

In case you also use a model as a resource related field make sure to overwrite `get_resource_id` by creating a custom `ResourceRelatedField` class:

```python
class UserResourceRelatedField(ResourceRelatedField):
def get_resource_id(self, value):
return value.email

class GroupSerializer(serializers.ModelSerializer):
user = UserResourceRelatedField(queryset=User.objects)
name = serializers.CharField()

class Meta:
model = Group
```

<div class="warning">
<strong>Note:</strong>
When using different id than primary key, make sure that your view
manages it properly by overwriting `get_object`.
</div>

### Setting resource identifier object type

You may manually set resource identifier object type by using `resource_name` property on views, serializers, or
Expand Down
14 changes: 9 additions & 5 deletions rest_framework_json_api/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,17 +247,21 @@ def to_internal_value(self, data):
return super().to_internal_value(data["id"])

def to_representation(self, value):
if getattr(self, "pk_field", None) is not None:
pk = self.pk_field.to_representation(value.pk)
else:
pk = value.pk

pk = self.get_resource_id(value)
resource_type = self.get_resource_type_from_included_serializer()
if resource_type is None or not self._skip_polymorphic_optimization:
resource_type = get_resource_type_from_instance(value)

return {"type": resource_type, "id": str(pk)}

def get_resource_id(self, value):
"""
Get resource id of related field.

Per default pk of value is returned.
"""
return super().to_representation(value)

def get_resource_type_from_included_serializer(self):
"""
Check to see it this resource has a different resource_name when
Expand Down
3 changes: 1 addition & 2 deletions rest_framework_json_api/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,10 +443,9 @@ def build_json_resource_obj(
# Determine type from the instance if the underlying model is polymorphic
if force_type_resolution:
resource_name = utils.get_resource_type_from_instance(resource_instance)
resource_id = force_str(resource_instance.pk) if resource_instance else None
resource_data = {
"type": resource_name,
"id": resource_id,
"id": utils.get_resource_id(resource_instance, resource),
"attributes": cls.extract_attributes(fields, resource),
}
relationships = cls.extract_relationships(fields, resource, resource_instance)
Expand Down
13 changes: 13 additions & 0 deletions rest_framework_json_api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,19 @@ def get_resource_type_from_serializer(serializer):
)


def get_resource_id(resource_instance, resource):
"""Returns the resource identifier for a given instance (`id` takes priority over `pk`)."""
if resource and "id" in resource:
return resource["id"] and encoding.force_str(resource["id"]) or None
if resource_instance:
return (
hasattr(resource_instance, "pk")
and encoding.force_str(resource_instance.pk)
or None
)
return None


def get_included_resources(request, serializer=None):
"""Build a list of included resources."""
include_resources_param = request.query_params.get("include") if request else None
Expand Down
25 changes: 24 additions & 1 deletion tests/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@
HyperlinkedRelatedField,
SerializerMethodHyperlinkedRelatedField,
)
from rest_framework_json_api.serializers import ModelSerializer, ResourceRelatedField
from rest_framework_json_api.utils import format_link_segment
from rest_framework_json_api.views import RelationshipView
from tests.models import BasicModel
from tests.models import BasicModel, ForeignKeySource, ForeignKeyTarget
from tests.serializers import (
ForeignKeySourceSerializer,
ManyToManySourceReadOnlySerializer,
Expand Down Expand Up @@ -46,6 +47,28 @@ def test_serialize(

assert serializer.data["target"] == expected

def test_get_resource_id(self, foreign_key_target):
class CustomResourceRelatedField(ResourceRelatedField):
def get_resource_id(self, value):
return value.name

class CustomPkFieldSerializer(ModelSerializer):
target = CustomResourceRelatedField(
queryset=ForeignKeyTarget.objects, pk_field="name"
)

class Meta:
model = ForeignKeySource
fields = ("target",)

serializer = CustomPkFieldSerializer(instance={"target": foreign_key_target})
expected = {
"type": "ForeignKeyTarget",
"id": "Target",
}

assert serializer.data["target"] == expected

@pytest.mark.parametrize(
"format_type,pluralize_type,resource_type",
[
Expand Down
17 changes: 17 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
format_resource_type,
format_value,
get_related_resource_type,
get_resource_id,
get_resource_name,
get_resource_type_from_serializer,
undo_format_field_name,
Expand Down Expand Up @@ -392,6 +393,22 @@ class SerializerWithoutResourceName(serializers.Serializer):
)


@pytest.mark.parametrize(
"resource_instance, resource, expected",
[
(None, None, None),
(object(), {}, None),
(BasicModel(id=5), None, "5"),
(BasicModel(id=9), {}, "9"),
(None, {"id": 11}, "11"),
(object(), {"pk": 11}, None),
(BasicModel(id=6), {"id": 11}, "11"),
],
)
def test_get_resource_id(resource_instance, resource, expected):
assert get_resource_id(resource_instance, resource) == expected


@pytest.mark.parametrize(
"message,pointer,response,result",
[
Expand Down
67 changes: 67 additions & 0 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,50 @@ def test_patch(self, client):
}
}

@pytest.mark.urls(__name__)
def test_post_with_missing_id(self, client):
data = {
"data": {
"id": None,
"type": "custom",
"attributes": {"body": "hello"},
}
}

url = reverse("custom")

response = client.post(url, data=data)
assert response.status_code == status.HTTP_200_OK
assert response.json() == {
"data": {
"type": "custom",
"id": None,
"attributes": {"body": "hello"},
}
}

@pytest.mark.urls(__name__)
def test_patch_with_custom_id(self, client):
data = {
"data": {
"id": 2_193_102,
"type": "custom",
"attributes": {"body": "hello"},
}
}

url = reverse("custom-id")

response = client.patch(url, data=data)
assert response.status_code == status.HTTP_200_OK
assert response.json() == {
"data": {
"type": "custom",
"id": "2176ce", # get_id() -> hex
"attributes": {"body": "hello"},
}
}


# Routing setup

Expand All @@ -202,6 +246,14 @@ class CustomModelSerializer(serializers.Serializer):
id = serializers.IntegerField()


class CustomIdModelSerializer(serializers.Serializer):
id = serializers.SerializerMethodField()
body = serializers.CharField()

def get_id(self, obj):
return hex(obj.id)[2:]


class CustomAPIView(APIView):
parser_classes = [JSONParser]
renderer_classes = [JSONRenderer]
Expand All @@ -211,11 +263,26 @@ def patch(self, request, *args, **kwargs):
serializer = CustomModelSerializer(CustomModel(request.data))
return Response(status=status.HTTP_200_OK, data=serializer.data)

def post(self, request, *args, **kwargs):
serializer = CustomModelSerializer(request.data)
return Response(status=status.HTTP_200_OK, data=serializer.data)


class CustomIdAPIView(APIView):
parser_classes = [JSONParser]
renderer_classes = [JSONRenderer]
resource_name = "custom"

def patch(self, request, *args, **kwargs):
serializer = CustomIdModelSerializer(CustomModel(request.data))
return Response(status=status.HTTP_200_OK, data=serializer.data)


router = SimpleRouter()
router.register(r"basic_models", BasicModelViewSet, basename="basic-model")

urlpatterns = [
path("custom", CustomAPIView.as_view(), name="custom"),
path("custom-id", CustomIdAPIView.as_view(), name="custom-id"),
]
urlpatterns += router.urls