From f38568ea7d694d62e6ed40fa333807122d9bd152 Mon Sep 17 00:00:00 2001 From: Ali Momen Sani Date: Mon, 24 Mar 2025 17:58:44 +0100 Subject: [PATCH] add draft messages --- stream_chat/async_chat/channel.py | 22 ++- stream_chat/async_chat/client.py | 23 ++++ stream_chat/base/channel.py | 41 +++++- stream_chat/base/client.py | 11 ++ stream_chat/channel.py | 25 +++- stream_chat/client.py | 19 +++ stream_chat/tests/async_chat/test_draft.py | 148 +++++++++++++++++++++ stream_chat/tests/test_draft.py | 148 +++++++++++++++++++++ stream_chat/types/draft.py | 14 ++ 9 files changed, 448 insertions(+), 3 deletions(-) create mode 100644 stream_chat/tests/async_chat/test_draft.py create mode 100644 stream_chat/tests/test_draft.py create mode 100644 stream_chat/types/draft.py diff --git a/stream_chat/async_chat/channel.py b/stream_chat/async_chat/channel.py index 6c54c00b..a66fb7f9 100644 --- a/stream_chat/async_chat/channel.py +++ b/stream_chat/async_chat/channel.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, Iterable, List, Union +from typing import Any, Dict, Iterable, List, Optional, Union from stream_chat.base.channel import ChannelInterface, add_user_id from stream_chat.base.exceptions import StreamChannelException @@ -247,3 +247,23 @@ async def update_member_partial( payload = {"set": to_set or {}, "unset": to_unset or []} return await self.client.patch(f"{self.url}/member/{user_id}", data=payload) + + async def create_draft(self, message: Dict, user_id: str) -> StreamResponse: + payload = {"message": add_user_id(message, user_id)} + return await self.client.post(f"{self.url}/draft", data=payload) + + async def delete_draft( + self, user_id: str, parent_id: Optional[str] = None + ) -> StreamResponse: + params = {"user_id": user_id} + if parent_id: + params["parent_id"] = parent_id + return await self.client.delete(f"{self.url}/draft", params=params) + + async def get_draft( + self, user_id: str, parent_id: Optional[str] = None + ) -> StreamResponse: + params = {"user_id": user_id} + if parent_id: + params["parent_id"] = parent_id + return await self.client.get(f"{self.url}/draft", params=params) diff --git a/stream_chat/async_chat/client.py b/stream_chat/async_chat/client.py index 44b3763d..f76987f7 100644 --- a/stream_chat/async_chat/client.py +++ b/stream_chat/async_chat/client.py @@ -21,6 +21,7 @@ from stream_chat.async_chat.segment import Segment from stream_chat.types.base import SortParam from stream_chat.types.campaign import CampaignData, QueryCampaignsOptions +from stream_chat.types.draft import QueryDraftsFilter, QueryDraftsOptions from stream_chat.types.segment import ( QuerySegmentsOptions, QuerySegmentTargetsOptions, @@ -825,6 +826,28 @@ async def unread_counts(self, user_id: str) -> StreamResponse: async def unread_counts_batch(self, user_ids: List[str]) -> StreamResponse: return await self.post("unread_batch", data={"user_ids": user_ids}) + async def query_drafts( + self, + user_id: str, + filter: Optional[QueryDraftsFilter] = None, + sort: Optional[List[SortParam]] = None, + options: Optional[QueryDraftsOptions] = None, + ) -> StreamResponse: + data: Dict[str, Union[str, Dict[str, Any], List[SortParam]]] = { + "user_id": user_id + } + + if filter is not None: + data["filter"] = cast(dict, filter) + + if sort is not None: + data["sort"] = cast(dict, sort) + + if options is not None: + data.update(cast(dict, options)) + + return await self.post("drafts/query", data=data) + async def close(self) -> None: await self.session.close() diff --git a/stream_chat/base/channel.py b/stream_chat/base/channel.py index 8bde162e..8e300a57 100644 --- a/stream_chat/base/channel.py +++ b/stream_chat/base/channel.py @@ -1,5 +1,5 @@ import abc -from typing import Any, Awaitable, Dict, Iterable, List, Union +from typing import Any, Awaitable, Dict, Iterable, List, Optional, Union from stream_chat.base.client import StreamChatInterface from stream_chat.base.exceptions import StreamChannelException @@ -488,6 +488,45 @@ def update_member_partial( """ pass + @abc.abstractmethod + def create_draft( + self, message: Dict, user_id: str + ) -> Union[StreamResponse, Awaitable[StreamResponse]]: + """ + Creates or updates a draft message in a channel. + + :param message: The message object + :param user_id: The ID of the user creating the draft + :return: The Server Response + """ + pass + + @abc.abstractmethod + def delete_draft( + self, user_id: str, parent_id: Optional[str] = None + ) -> Union[StreamResponse, Awaitable[StreamResponse]]: + """ + Deletes a draft message from a channel. + + :param user_id: The ID of the user who owns the draft + :param parent_id: Optional ID of the parent message if this is a thread draft + :return: The Server Response + """ + pass + + @abc.abstractmethod + def get_draft( + self, user_id: str, parent_id: Optional[str] = None + ) -> Union[StreamResponse, Awaitable[StreamResponse]]: + """ + Retrieves a draft message from a channel. + + :param user_id: The ID of the user who owns the draft + :param parent_id: Optional ID of the parent message if this is a thread draft + :return: The Server Response + """ + pass + def add_user_id(payload: Dict, user_id: str) -> Dict: return {**payload, "user": {"id": user_id}} diff --git a/stream_chat/base/client.py b/stream_chat/base/client.py index c2e0557f..1246eb95 100644 --- a/stream_chat/base/client.py +++ b/stream_chat/base/client.py @@ -9,6 +9,7 @@ from stream_chat.types.base import SortParam from stream_chat.types.campaign import CampaignData, QueryCampaignsOptions +from stream_chat.types.draft import QueryDraftsFilter, QueryDraftsOptions from stream_chat.types.segment import ( QuerySegmentsOptions, QuerySegmentTargetsOptions, @@ -1384,6 +1385,16 @@ def unread_counts_batch( """ pass + @abc.abstractmethod + def query_drafts( + self, + user_id: str, + filter: Optional[QueryDraftsFilter] = None, + sort: Optional[List[SortParam]] = None, + options: Optional[QueryDraftsOptions] = None, + ) -> Union[StreamResponse, Awaitable[StreamResponse]]: + pass + ##################### # Private methods # ##################### diff --git a/stream_chat/channel.py b/stream_chat/channel.py index a9a3df6b..23738112 100644 --- a/stream_chat/channel.py +++ b/stream_chat/channel.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, Iterable, List, Union +from typing import Any, Dict, Iterable, List, Optional, Union from stream_chat.base.channel import ChannelInterface, add_user_id from stream_chat.base.exceptions import StreamChannelException @@ -248,3 +248,26 @@ def update_member_partial( payload = {"set": to_set or {}, "unset": to_unset or []} return self.client.patch(f"{self.url}/member/{user_id}", data=payload) + + def create_draft(self, message: Dict, user_id: str) -> StreamResponse: + message["user_id"] = user_id + payload = {"message": message} + return self.client.post(f"{self.url}/draft", data=payload) + + def delete_draft( + self, user_id: str, parent_id: Optional[str] = None + ) -> StreamResponse: + params = {"user_id": user_id} + if parent_id: + params["parent_id"] = parent_id + + return self.client.delete(f"{self.url}/draft", params=params) + + def get_draft( + self, user_id: str, parent_id: Optional[str] = None + ) -> StreamResponse: + params = {"user_id": user_id} + if parent_id: + params["parent_id"] = parent_id + + return self.client.get(f"{self.url}/draft", params=params) diff --git a/stream_chat/client.py b/stream_chat/client.py index 920f4715..48f073fa 100644 --- a/stream_chat/client.py +++ b/stream_chat/client.py @@ -10,6 +10,7 @@ from stream_chat.segment import Segment from stream_chat.types.base import SortParam from stream_chat.types.campaign import CampaignData, QueryCampaignsOptions +from stream_chat.types.draft import QueryDraftsFilter, QueryDraftsOptions from stream_chat.types.segment import ( QuerySegmentsOptions, QuerySegmentTargetsOptions, @@ -782,3 +783,21 @@ def unread_counts(self, user_id: str) -> StreamResponse: def unread_counts_batch(self, user_ids: List[str]) -> StreamResponse: return self.post("unread_batch", data={"user_ids": user_ids}) + + def query_drafts( + self, + user_id: str, + filter: Optional[QueryDraftsFilter] = None, + sort: Optional[List[SortParam]] = None, + options: Optional[QueryDraftsOptions] = None, + ) -> StreamResponse: + data: Dict[str, Union[str, Dict[str, Any], List[SortParam]]] = { + "user_id": user_id + } + if filter is not None: + data["filter"] = cast(dict, filter) + if sort is not None: + data["sort"] = cast(dict, sort) + if options is not None: + data.update(cast(dict, options)) + return self.post("drafts/query", data=data) diff --git a/stream_chat/tests/async_chat/test_draft.py b/stream_chat/tests/async_chat/test_draft.py new file mode 100644 index 00000000..cd8a0112 --- /dev/null +++ b/stream_chat/tests/async_chat/test_draft.py @@ -0,0 +1,148 @@ +import uuid +from typing import Dict + +import pytest + +from stream_chat.async_chat.channel import Channel +from stream_chat.async_chat.client import StreamChatAsync +from stream_chat.types.base import SortOrder + + +@pytest.mark.incremental +class TestDraft: + async def test_create_draft(self, channel: Channel, random_user: Dict): + draft_message = {"text": "This is a draft message"} + response = await channel.create_draft(draft_message, random_user["id"]) + + assert "draft" in response + assert response["draft"]["message"]["text"] == "This is a draft message" + assert response["draft"]["channel_cid"] == channel.cid + + async def test_get_draft(self, channel: Channel, random_user: Dict): + # First create a draft + draft_message = {"text": "This is a draft to retrieve"} + await channel.create_draft(draft_message, random_user["id"]) + + # Then get the draft + response = await channel.get_draft(random_user["id"]) + + assert "draft" in response + assert response["draft"]["message"]["text"] == "This is a draft to retrieve" + assert response["draft"]["channel_cid"] == channel.cid + + async def test_delete_draft(self, channel: Channel, random_user: Dict): + # First create a draft + draft_message = {"text": "This is a draft to delete"} + await channel.create_draft(draft_message, random_user["id"]) + + # Then delete the draft + await channel.delete_draft(random_user["id"]) + + # Verify it's deleted by trying to get it + try: + await channel.get_draft(random_user["id"]) + raise AssertionError("Draft should be deleted") + except Exception: + # Expected behavior, draft should not be found + pass + + async def test_thread_draft(self, channel: Channel, random_user: Dict): + # First create a parent message + msg = await channel.send_message({"text": "Parent message"}, random_user["id"]) + parent_id = msg["message"]["id"] + + # Create a draft reply + draft_reply = {"text": "This is a draft reply", "parent_id": parent_id} + response = await channel.create_draft(draft_reply, random_user["id"]) + + assert "draft" in response + assert response["draft"]["message"]["text"] == "This is a draft reply" + assert response["draft"]["parent_id"] == parent_id + + # Get the draft reply + response = await channel.get_draft(random_user["id"], parent_id=parent_id) + + assert "draft" in response + assert response["draft"]["message"]["text"] == "This is a draft reply" + assert response["draft"]["parent_id"] == parent_id + + # Delete the draft reply + await channel.delete_draft(random_user["id"], parent_id=parent_id) + + # Verify it's deleted + try: + await channel.get_draft(random_user["id"], parent_id=parent_id) + raise AssertionError("Thread draft should be deleted") + except Exception: + # Expected behavior + pass + + async def test_query_drafts( + self, client: StreamChatAsync, channel: Channel, random_user: Dict + ): + # Create multiple drafts in different channels + draft1 = {"text": "Draft in channel 1"} + await channel.create_draft(draft1, random_user["id"]) + + # Create another channel with a draft + channel2 = client.channel("messaging", str(uuid.uuid4())) + await channel2.create(random_user["id"]) + + draft2 = {"text": "Draft in channel 2"} + await channel2.create_draft(draft2, random_user["id"]) + + # Query all drafts for the user + response = await client.query_drafts(random_user["id"]) + + assert "drafts" in response + assert len(response["drafts"]) == 2 + + # Query drafts for a specific channel + response = await client.query_drafts( + random_user["id"], filter={"channel_cid": channel2.cid} + ) + + assert "drafts" in response + assert len(response["drafts"]) == 1 + draft = response["drafts"][0] + assert draft["channel_cid"] == channel2.cid + assert draft["message"]["text"] == "Draft in channel 2" + + # Query drafts with sort + response = await client.query_drafts( + random_user["id"], + sort=[{"field": "created_at", "direction": SortOrder.ASC}], + ) + + assert "drafts" in response + assert len(response["drafts"]) == 2 + assert response["drafts"][0]["channel_cid"] == channel.cid + assert response["drafts"][1]["channel_cid"] == channel2.cid + + # Query drafts with pagination + response = await client.query_drafts( + random_user["id"], + options={"limit": 1}, + ) + + assert "drafts" in response + assert len(response["drafts"]) == 1 + assert response["drafts"][0]["channel_cid"] == channel2.cid + + assert response["next"] is not None + + # Query drafts with pagination + response = await client.query_drafts( + random_user["id"], + options={"limit": 1, "next": response["next"]}, + ) + + assert "drafts" in response + assert len(response["drafts"]) == 1 + assert response["drafts"][0]["channel_cid"] == channel.cid + + # Cleanup + try: + await channel2.delete() + except Exception: + pass diff --git a/stream_chat/tests/test_draft.py b/stream_chat/tests/test_draft.py new file mode 100644 index 00000000..5d0d9888 --- /dev/null +++ b/stream_chat/tests/test_draft.py @@ -0,0 +1,148 @@ +import uuid +from typing import Dict + +import pytest + +from stream_chat import StreamChat +from stream_chat.channel import Channel +from stream_chat.types.base import SortOrder + + +@pytest.mark.incremental +class TestDraft: + def test_create_draft(self, channel: Channel, random_user: Dict): + draft_message = {"text": "This is a draft message"} + response = channel.create_draft(draft_message, random_user["id"]) + + assert "draft" in response + assert response["draft"]["message"]["text"] == "This is a draft message" + assert response["draft"]["channel_cid"] == channel.cid + + def test_get_draft(self, channel: Channel, random_user: Dict): + # First create a draft + draft_message = {"text": "This is a draft to retrieve"} + channel.create_draft(draft_message, random_user["id"]) + + # Then get the draft + response = channel.get_draft(random_user["id"]) + + assert "draft" in response + assert response["draft"]["message"]["text"] == "This is a draft to retrieve" + assert response["draft"]["channel_cid"] == channel.cid + + def test_delete_draft(self, channel: Channel, random_user: Dict): + # First create a draft + draft_message = {"text": "This is a draft to delete"} + channel.create_draft(draft_message, random_user["id"]) + + # Then delete the draft + channel.delete_draft(random_user["id"]) + + # Verify it's deleted by trying to get it + try: + channel.get_draft(random_user["id"]) + raise AssertionError("Draft should be deleted") + except Exception: + # Expected behavior, draft should not be found + pass + + def test_thread_draft(self, channel: Channel, random_user: Dict): + # First create a parent message + msg = channel.send_message({"text": "Parent message"}, random_user["id"]) + parent_id = msg["message"]["id"] + + # Create a draft reply + draft_reply = {"text": "This is a draft reply", "parent_id": parent_id} + response = channel.create_draft(draft_reply, random_user["id"]) + + assert "draft" in response + assert response["draft"]["message"]["text"] == "This is a draft reply" + assert response["draft"]["parent_id"] == parent_id + + # Get the draft reply + response = channel.get_draft(random_user["id"], parent_id=parent_id) + + assert "draft" in response + assert response["draft"]["message"]["text"] == "This is a draft reply" + assert response["draft"]["parent_id"] == parent_id + + # Delete the draft reply + channel.delete_draft(random_user["id"], parent_id=parent_id) + + # Verify it's deleted + try: + channel.get_draft(random_user["id"], parent_id=parent_id) + raise AssertionError("Thread draft should be deleted") + except Exception: + # Expected behavior + pass + + def test_query_drafts( + self, client: StreamChat, channel: Channel, random_user: Dict + ): + # Create multiple drafts in different channels + draft1 = {"text": "Draft in channel 1"} + channel.create_draft(draft1, random_user["id"]) + + # Create another channel with a draft + channel2 = client.channel("messaging", str(uuid.uuid4())) + channel2.create(random_user["id"]) + + draft2 = {"text": "Draft in channel 2"} + channel2.create_draft(draft2, random_user["id"]) + + # Query all drafts for the user + response = client.query_drafts(random_user["id"]) + + assert "drafts" in response + assert len(response["drafts"]) == 2 + + # Query drafts for a specific channel + response = client.query_drafts( + random_user["id"], filter={"channel_cid": channel2.cid} + ) + + assert "drafts" in response + assert len(response["drafts"]) == 1 + draft = response["drafts"][0] + assert draft["channel_cid"] == channel2.cid + assert draft["message"]["text"] == "Draft in channel 2" + + # Query drafts with sort + response = client.query_drafts( + random_user["id"], + sort=[{"field": "created_at", "direction": SortOrder.ASC}], + ) + + assert "drafts" in response + assert len(response["drafts"]) == 2 + assert response["drafts"][0]["channel_cid"] == channel.cid + assert response["drafts"][1]["channel_cid"] == channel2.cid + + # Query drafts with pagination + response = client.query_drafts( + random_user["id"], + options={"limit": 1}, + ) + + assert "drafts" in response + assert len(response["drafts"]) == 1 + assert response["drafts"][0]["channel_cid"] == channel2.cid + + assert response["next"] is not None + + # Query drafts with pagination + response = client.query_drafts( + random_user["id"], + options={"limit": 1, "next": response["next"]}, + ) + + assert "drafts" in response + assert len(response["drafts"]) == 1 + assert response["drafts"][0]["channel_cid"] == channel.cid + + # cleanup + try: + channel2.delete() + except Exception: + pass diff --git a/stream_chat/types/draft.py b/stream_chat/types/draft.py new file mode 100644 index 00000000..94c863e9 --- /dev/null +++ b/stream_chat/types/draft.py @@ -0,0 +1,14 @@ +from datetime import datetime +from typing import Optional, TypedDict + +from stream_chat.types.base import Pager + + +class QueryDraftsFilter(TypedDict): + channel_cid: Optional[str] + parent_id: Optional[str] + created_at: Optional[datetime] + + +class QueryDraftsOptions(Pager): + pass