[DRF]PrimaryKeyRelatedFieldを高速化するスニペット


はじめに

シリアライザでPrimaryKeyRelatedFieldmany=Trueにすると、pkの配列をリクエストパラメータで渡せますが、pkが大量にあるとめちゃくちゃ遅くなります
そのため原因の調査と高速化方法を考えました

原因

結論これです

ManyRelatedField.py
def to_internal_value(self, data):
    if isinstance(data, str) or not hasattr(data, '__iter__'):
        self.fail('not_a_list', input_type=type(data).__name__)
    if not self.allow_empty and len(data) == 0:
        self.fail('empty')

    return [
        self.child_relation.to_internal_value(item)
        for item in data
    ]

配列分self.child_relation.to_internal_value(item)が呼ばれてるため(to_internal_valueの中でpkをgetしている)遅くなっていました

高速化するスニペット

from rest_framework import serializers
from rest_framework.relations import MANY_RELATION_KWARGS, ManyRelatedField


class PrimaryKeyRelatedFieldEx(serializers.PrimaryKeyRelatedField):
    def __init__(self, **kwargs):
        self.queryset_response = kwargs.pop('queryset_response', False)
        super().__init__(**kwargs)

    class _ManyRelatedFieldEx(ManyRelatedField):
        def to_internal_value(self, data):
            if isinstance(data, str) or not hasattr(data, '__iter__'):
                self.fail('not_a_list', input_type=type(data).__name__)
            if not self.allow_empty and len(data) == 0:
                self.fail('empty')
            return self.child_relation.to_internal_value(data)

    @classmethod
    def many_init(cls, *args, **kwargs):
        list_kwargs = {'child_relation': cls(*args, **kwargs)}
        for key in kwargs:
            if key in MANY_RELATION_KWARGS:
                list_kwargs[key] = kwargs[key]
        return cls._ManyRelatedFieldEx(**list_kwargs)

    def to_internal_value(self, data):
        if isinstance(data, list):
            if self.pk_field is not None:
                data = self.pk_field.to_internal_value(data)
            results = self.get_queryset().filter(pk__in=data)
            # 全てのデータがあるかチェックする
            pk_list = results.values_list('pk', flat=True)
            pk_list = [str(n) for n in pk_list]
            data_list = [str(n) for n in data]
            diff = list(set(data_list) - set(list(pk_list)))
            if len(diff) > 0:
                pk_value = ', '.join(map(str, diff))
                self.fail('does_not_exist', pk_value=pk_value)
            if self.queryset_response:
                return results
            else:
                return list(results)
        else:
            return super().to_internal_value(data)

解説

  • 一つずつgetしていたのをfilterのinで取得するように変更
  • 存在しないpkが含まれていた場合のエラーメッセージは一つしか表示されなかったのを、カンマ区切りで全て表示するように変更
  • queryset_responseのパラメータを追加しています。queryset_response=Trueにするとレスポンスがquerysetになります(今まではquerysetの配列。個人的にはquerysetになってた方が使いやすいんじゃないかなと思う)