Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion stream_chat/async_chat/channel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Iterable, List, Union
from typing import Any, Dict, Iterable, List, Optional, Union

from stream_chat.base.channel import ChannelInterface, add_user_id
from stream_chat.base.exceptions import StreamChannelException
Expand Down Expand Up @@ -247,3 +247,23 @@ async def update_member_partial(

payload = {"set": to_set or {}, "unset": to_unset or []}
return await self.client.patch(f"{self.url}/member/{user_id}", data=payload)

async def create_draft(self, message: Dict, user_id: str) -> StreamResponse:
payload = {"message": add_user_id(message, user_id)}
return await self.client.post(f"{self.url}/draft", data=payload)

async def delete_draft(
self, user_id: str, parent_id: Optional[str] = None
) -> StreamResponse:
params = {"user_id": user_id}
if parent_id:
params["parent_id"] = parent_id
return await self.client.delete(f"{self.url}/draft", params=params)

async def get_draft(
self, user_id: str, parent_id: Optional[str] = None
) -> StreamResponse:
params = {"user_id": user_id}
if parent_id:
params["parent_id"] = parent_id
return await self.client.get(f"{self.url}/draft", params=params)
23 changes: 23 additions & 0 deletions stream_chat/async_chat/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from stream_chat.async_chat.segment import Segment
from stream_chat.types.base import SortParam
from stream_chat.types.campaign import CampaignData, QueryCampaignsOptions
from stream_chat.types.draft import QueryDraftsFilter, QueryDraftsOptions
from stream_chat.types.segment import (
QuerySegmentsOptions,
QuerySegmentTargetsOptions,
Expand Down Expand Up @@ -825,6 +826,28 @@ async def unread_counts(self, user_id: str) -> StreamResponse:
async def unread_counts_batch(self, user_ids: List[str]) -> StreamResponse:
return await self.post("unread_batch", data={"user_ids": user_ids})

async def query_drafts(
self,
user_id: str,
filter: Optional[QueryDraftsFilter] = None,
sort: Optional[List[SortParam]] = None,
options: Optional[QueryDraftsOptions] = None,
) -> StreamResponse:
data: Dict[str, Union[str, Dict[str, Any], List[SortParam]]] = {
"user_id": user_id
}

if filter is not None:
data["filter"] = cast(dict, filter)

if sort is not None:
data["sort"] = cast(dict, sort)

if options is not None:
data.update(cast(dict, options))

return await self.post("drafts/query", data=data)

async def close(self) -> None:
await self.session.close()

Expand Down
41 changes: 40 additions & 1 deletion stream_chat/base/channel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import abc
from typing import Any, Awaitable, Dict, Iterable, List, Union
from typing import Any, Awaitable, Dict, Iterable, List, Optional, Union

from stream_chat.base.client import StreamChatInterface
from stream_chat.base.exceptions import StreamChannelException
Expand Down Expand Up @@ -488,6 +488,45 @@ def update_member_partial(
"""
pass

@abc.abstractmethod
def create_draft(
self, message: Dict, user_id: str
) -> Union[StreamResponse, Awaitable[StreamResponse]]:
"""
Creates or updates a draft message in a channel.

:param message: The message object
:param user_id: The ID of the user creating the draft
:return: The Server Response
"""
pass

@abc.abstractmethod
def delete_draft(
self, user_id: str, parent_id: Optional[str] = None
) -> Union[StreamResponse, Awaitable[StreamResponse]]:
"""
Deletes a draft message from a channel.

:param user_id: The ID of the user who owns the draft
:param parent_id: Optional ID of the parent message if this is a thread draft
:return: The Server Response
"""
pass

@abc.abstractmethod
def get_draft(
self, user_id: str, parent_id: Optional[str] = None
) -> Union[StreamResponse, Awaitable[StreamResponse]]:
"""
Retrieves a draft message from a channel.

:param user_id: The ID of the user who owns the draft
:param parent_id: Optional ID of the parent message if this is a thread draft
:return: The Server Response
"""
pass


def add_user_id(payload: Dict, user_id: str) -> Dict:
return {**payload, "user": {"id": user_id}}
11 changes: 11 additions & 0 deletions stream_chat/base/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from stream_chat.types.base import SortParam
from stream_chat.types.campaign import CampaignData, QueryCampaignsOptions
from stream_chat.types.draft import QueryDraftsFilter, QueryDraftsOptions
from stream_chat.types.segment import (
QuerySegmentsOptions,
QuerySegmentTargetsOptions,
Expand Down Expand Up @@ -1384,6 +1385,16 @@ def unread_counts_batch(
"""
pass

@abc.abstractmethod
def query_drafts(
self,
user_id: str,
filter: Optional[QueryDraftsFilter] = None,
sort: Optional[List[SortParam]] = None,
options: Optional[QueryDraftsOptions] = None,
) -> Union[StreamResponse, Awaitable[StreamResponse]]:
pass

#####################
# Private methods #
#####################
Expand Down
25 changes: 24 additions & 1 deletion stream_chat/channel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Iterable, List, Union
from typing import Any, Dict, Iterable, List, Optional, Union

from stream_chat.base.channel import ChannelInterface, add_user_id
from stream_chat.base.exceptions import StreamChannelException
Expand Down Expand Up @@ -248,3 +248,26 @@ def update_member_partial(

payload = {"set": to_set or {}, "unset": to_unset or []}
return self.client.patch(f"{self.url}/member/{user_id}", data=payload)

def create_draft(self, message: Dict, user_id: str) -> StreamResponse:
message["user_id"] = user_id
payload = {"message": message}
return self.client.post(f"{self.url}/draft", data=payload)

def delete_draft(
self, user_id: str, parent_id: Optional[str] = None
) -> StreamResponse:
params = {"user_id": user_id}
if parent_id:
params["parent_id"] = parent_id

return self.client.delete(f"{self.url}/draft", params=params)

def get_draft(
self, user_id: str, parent_id: Optional[str] = None
) -> StreamResponse:
params = {"user_id": user_id}
if parent_id:
params["parent_id"] = parent_id

return self.client.get(f"{self.url}/draft", params=params)
19 changes: 19 additions & 0 deletions stream_chat/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from stream_chat.segment import Segment
from stream_chat.types.base import SortParam
from stream_chat.types.campaign import CampaignData, QueryCampaignsOptions
from stream_chat.types.draft import QueryDraftsFilter, QueryDraftsOptions
from stream_chat.types.segment import (
QuerySegmentsOptions,
QuerySegmentTargetsOptions,
Expand Down Expand Up @@ -782,3 +783,21 @@ def unread_counts(self, user_id: str) -> StreamResponse:

def unread_counts_batch(self, user_ids: List[str]) -> StreamResponse:
return self.post("unread_batch", data={"user_ids": user_ids})

def query_drafts(
self,
user_id: str,
filter: Optional[QueryDraftsFilter] = None,
sort: Optional[List[SortParam]] = None,
options: Optional[QueryDraftsOptions] = None,
) -> StreamResponse:
data: Dict[str, Union[str, Dict[str, Any], List[SortParam]]] = {
"user_id": user_id
}
if filter is not None:
data["filter"] = cast(dict, filter)
if sort is not None:
data["sort"] = cast(dict, sort)
if options is not None:
data.update(cast(dict, options))
return self.post("drafts/query", data=data)
148 changes: 148 additions & 0 deletions stream_chat/tests/async_chat/test_draft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
import uuid
from typing import Dict

import pytest

from stream_chat.async_chat.channel import Channel
from stream_chat.async_chat.client import StreamChatAsync
from stream_chat.types.base import SortOrder


@pytest.mark.incremental
class TestDraft:
async def test_create_draft(self, channel: Channel, random_user: Dict):
draft_message = {"text": "This is a draft message"}
response = await channel.create_draft(draft_message, random_user["id"])

assert "draft" in response
assert response["draft"]["message"]["text"] == "This is a draft message"
assert response["draft"]["channel_cid"] == channel.cid

async def test_get_draft(self, channel: Channel, random_user: Dict):
# First create a draft
draft_message = {"text": "This is a draft to retrieve"}
await channel.create_draft(draft_message, random_user["id"])

# Then get the draft
response = await channel.get_draft(random_user["id"])

assert "draft" in response
assert response["draft"]["message"]["text"] == "This is a draft to retrieve"
assert response["draft"]["channel_cid"] == channel.cid

async def test_delete_draft(self, channel: Channel, random_user: Dict):
# First create a draft
draft_message = {"text": "This is a draft to delete"}
await channel.create_draft(draft_message, random_user["id"])

# Then delete the draft
await channel.delete_draft(random_user["id"])

# Verify it's deleted by trying to get it
try:
await channel.get_draft(random_user["id"])
raise AssertionError("Draft should be deleted")
except Exception:
# Expected behavior, draft should not be found
pass

async def test_thread_draft(self, channel: Channel, random_user: Dict):
# First create a parent message
msg = await channel.send_message({"text": "Parent message"}, random_user["id"])
parent_id = msg["message"]["id"]

# Create a draft reply
draft_reply = {"text": "This is a draft reply", "parent_id": parent_id}
response = await channel.create_draft(draft_reply, random_user["id"])

assert "draft" in response
assert response["draft"]["message"]["text"] == "This is a draft reply"
assert response["draft"]["parent_id"] == parent_id

# Get the draft reply
response = await channel.get_draft(random_user["id"], parent_id=parent_id)

assert "draft" in response
assert response["draft"]["message"]["text"] == "This is a draft reply"
assert response["draft"]["parent_id"] == parent_id

# Delete the draft reply
await channel.delete_draft(random_user["id"], parent_id=parent_id)

# Verify it's deleted
try:
await channel.get_draft(random_user["id"], parent_id=parent_id)
raise AssertionError("Thread draft should be deleted")
except Exception:
# Expected behavior
pass

async def test_query_drafts(
self, client: StreamChatAsync, channel: Channel, random_user: Dict
):
# Create multiple drafts in different channels
draft1 = {"text": "Draft in channel 1"}
await channel.create_draft(draft1, random_user["id"])

# Create another channel with a draft
channel2 = client.channel("messaging", str(uuid.uuid4()))
await channel2.create(random_user["id"])

draft2 = {"text": "Draft in channel 2"}
await channel2.create_draft(draft2, random_user["id"])

# Query all drafts for the user
response = await client.query_drafts(random_user["id"])

assert "drafts" in response
assert len(response["drafts"]) == 2

# Query drafts for a specific channel
response = await client.query_drafts(
random_user["id"], filter={"channel_cid": channel2.cid}
)

assert "drafts" in response
assert len(response["drafts"]) == 1
draft = response["drafts"][0]
assert draft["channel_cid"] == channel2.cid
assert draft["message"]["text"] == "Draft in channel 2"

# Query drafts with sort
response = await client.query_drafts(
random_user["id"],
sort=[{"field": "created_at", "direction": SortOrder.ASC}],
)

assert "drafts" in response
assert len(response["drafts"]) == 2
assert response["drafts"][0]["channel_cid"] == channel.cid
assert response["drafts"][1]["channel_cid"] == channel2.cid

# Query drafts with pagination
response = await client.query_drafts(
random_user["id"],
options={"limit": 1},
)

assert "drafts" in response
assert len(response["drafts"]) == 1
assert response["drafts"][0]["channel_cid"] == channel2.cid

assert response["next"] is not None

# Query drafts with pagination
response = await client.query_drafts(
random_user["id"],
options={"limit": 1, "next": response["next"]},
)

assert "drafts" in response
assert len(response["drafts"]) == 1
assert response["drafts"][0]["channel_cid"] == channel.cid

# Cleanup
try:
await channel2.delete()
except Exception:
pass
Loading