diff --git a/stream_chat/async_chat/client.py b/stream_chat/async_chat/client.py index 67d1df2..112cb8e 100644 --- a/stream_chat/async_chat/client.py +++ b/stream_chat/async_chat/client.py @@ -29,6 +29,7 @@ SegmentType, SegmentUpdatableFields, ) +from stream_chat.types.shared_locations import SharedLocationsOptions if sys.version_info >= (3, 8): from typing import Literal @@ -859,16 +860,12 @@ async def query_drafts( data: Dict[str, Union[str, Dict[str, Any], List[SortParam]]] = { "user_id": user_id } - if filter is not None: data["filter"] = cast(dict, filter) - if sort is not None: data["sort"] = cast(dict, sort) - if options is not None: data.update(cast(dict, options)) - return await self.post("drafts/query", data=data) async def create_reminder( @@ -956,6 +953,22 @@ async def query_reminders( params["user_id"] = user_id return await self.post("reminders/query", data=params) + async def get_user_locations(self, user_id: str, **options: Any) -> StreamResponse: + params = {"user_id": user_id, **options} + return await self.get("users/live_locations", params=params) + + async def update_user_location( + self, + user_id: str, + message_id: str, + options: Optional[SharedLocationsOptions] = None, + ) -> StreamResponse: + data = {"message_id": message_id} + if options is not None: + data.update(cast(dict, options)) + params = {"user_id": user_id, **options} + return await self.put("users/live_locations", data=data, params=params) + async def close(self) -> None: await self.session.close() diff --git a/stream_chat/base/client.py b/stream_chat/base/client.py index 1865761..77f16a5 100644 --- a/stream_chat/base/client.py +++ b/stream_chat/base/client.py @@ -17,6 +17,7 @@ SegmentType, SegmentUpdatableFields, ) +from stream_chat.types.shared_locations import SharedLocationsOptions if sys.version_info >= (3, 8): from typing import Literal @@ -1505,6 +1506,27 @@ def query_reminders( """ pass + @abc.abstractmethod + def get_user_locations( + self, user_id: str, **options: Any + ) -> Union[StreamResponse, Awaitable[StreamResponse]]: + """ + Get the locations of a user. + """ + pass + + @abc.abstractmethod + def update_user_location( + self, + user_id: str, + message_id: str, + options: Optional[SharedLocationsOptions] = None, + ) -> Union[StreamResponse, Awaitable[StreamResponse]]: + """ + Update the location of a user. + """ + pass + ##################### # Private methods # ##################### diff --git a/stream_chat/client.py b/stream_chat/client.py index 7899a6a..1093b4e 100644 --- a/stream_chat/client.py +++ b/stream_chat/client.py @@ -18,6 +18,7 @@ SegmentType, SegmentUpdatableFields, ) +from stream_chat.types.shared_locations import SharedLocationsOptions if sys.version_info >= (3, 8): from typing import Literal @@ -898,3 +899,19 @@ def query_reminders( params["sort"] = sort or [{"field": "remind_at", "direction": 1}] params["user_id"] = user_id return self.post("reminders/query", data=params) + + def get_user_locations(self, user_id: str, **options: Any) -> StreamResponse: + params = {"user_id": user_id, **options} + return self.get("users/live_locations", params=params) + + def update_user_location( + self, + user_id: str, + message_id: str, + options: Optional[SharedLocationsOptions] = None, + ) -> StreamResponse: + data = {"message_id": message_id} + if options is not None: + data.update(cast(dict, options)) + params = {"user_id": user_id, **options} + return self.put("users/live_locations", data=data, params=params) diff --git a/stream_chat/tests/async_chat/test_live_locations.py b/stream_chat/tests/async_chat/test_live_locations.py new file mode 100644 index 0000000..3b1885b --- /dev/null +++ b/stream_chat/tests/async_chat/test_live_locations.py @@ -0,0 +1,84 @@ +import datetime +from typing import Dict + +import pytest + +from stream_chat.async_chat.client import StreamChatAsync + + +@pytest.mark.incremental +class TestLiveLocations: + @pytest.fixture(autouse=True) + @pytest.mark.asyncio + async def setup_channel_for_shared_locations(self, channel): + await channel.update_partial( + {"config_overrides": {"shared_locations": True}}, + ) + yield + await channel.update_partial( + {"config_overrides": {"shared_locations": False}}, + ) + + async def test_get_user_locations( + self, client: StreamChatAsync, channel, random_user: Dict + ): + # Create a message to attach location to + now = datetime.datetime.now(datetime.timezone.utc) + one_hour_later = now + datetime.timedelta(hours=1) + shared_location = { + "created_by_device_id": "test_device_id", + "latitude": 37.7749, + "longitude": -122.4194, + "end_at": one_hour_later.isoformat(), + } + + channel.send_message( + {"text": "Message with location", "shared_location": shared_location}, + random_user["id"], + ) + + # Get user locations + response = await client.get_user_locations(random_user["id"]) + + assert "active_live_locations" in response + assert isinstance(response["active_live_locations"], list) + + async def test_update_user_location( + self, client: StreamChatAsync, channel, random_user: Dict + ): + # Create a message to attach location to + now = datetime.datetime.now(datetime.timezone.utc) + one_hour_later = now + datetime.timedelta(hours=1) + shared_location = { + "created_by_device_id": "test_device_id", + "latitude": 37.7749, + "longitude": -122.4194, + "end_at": one_hour_later.isoformat(), + } + + msg = await channel.send_message( + {"text": "Message with location", "shared_location": shared_location}, + random_user["id"], + ) + message_id = msg["message"]["id"] + + # Update user location + location_data = { + "created_by_device_id": "test_device_id", + "latitude": 37.7749, + "longitude": -122.4194, + } + response = await client.update_user_location( + random_user["id"], message_id, location_data + ) + + assert response["latitude"] == location_data["latitude"] + assert response["longitude"] == location_data["longitude"] + + # Get user locations to verify + locations_response = await client.get_user_locations(random_user["id"]) + assert "active_live_locations" in locations_response + assert len(locations_response["active_live_locations"]) > 0 + location = locations_response["active_live_locations"][0] + assert location["latitude"] == location_data["latitude"] + assert location["longitude"] == location_data["longitude"] diff --git a/stream_chat/tests/test_live_locations.py b/stream_chat/tests/test_live_locations.py new file mode 100644 index 0000000..e26ca23 --- /dev/null +++ b/stream_chat/tests/test_live_locations.py @@ -0,0 +1,79 @@ +import datetime +from typing import Dict + +import pytest + +from stream_chat import StreamChat + + +@pytest.mark.incremental +class TestLiveLocations: + @pytest.fixture(autouse=True) + def setup_channel_for_shared_locations(self, channel): + channel.update_partial( + {"config_overrides": {"shared_locations": True}}, + ) + yield + channel.update_partial( + {"config_overrides": {"shared_locations": False}}, + ) + + def test_get_user_locations(self, client: StreamChat, channel, random_user: Dict): + # Create a message to attach location to + now = datetime.datetime.now(datetime.timezone.utc) + one_hour_later = now + datetime.timedelta(hours=1) + shared_location = { + "created_by_device_id": "test_device_id", + "latitude": 37.7749, + "longitude": -122.4194, + "end_at": one_hour_later.isoformat(), + } + + channel.send_message( + {"text": "Message with location", "shared_location": shared_location}, + random_user["id"], + ) + + # Get user locations + response = client.get_user_locations(random_user["id"]) + + assert "active_live_locations" in response + assert isinstance(response["active_live_locations"], list) + + def test_update_user_location(self, client: StreamChat, channel, random_user: Dict): + # Create a message to attach location to + now = datetime.datetime.now(datetime.timezone.utc) + one_hour_later = now + datetime.timedelta(hours=1) + shared_location = { + "created_by_device_id": "test_device_id", + "latitude": 37.7749, + "longitude": -122.4194, + "end_at": one_hour_later.isoformat(), + } + + msg = channel.send_message( + {"text": "Message with location", "shared_location": shared_location}, + random_user["id"], + ) + message_id = msg["message"]["id"] + + # Update user location + location_data = { + "created_by_device_id": "test_device_id", + "latitude": 37.7749, + "longitude": -122.4194, + } + response = client.update_user_location( + random_user["id"], message_id, location_data + ) + + assert response["latitude"] == location_data["latitude"] + assert response["longitude"] == location_data["longitude"] + + # Get user locations to verify + locations_response = client.get_user_locations(random_user["id"]) + assert "active_live_locations" in locations_response + assert len(locations_response["active_live_locations"]) > 0 + location = locations_response["active_live_locations"][0] + assert location["latitude"] == location_data["latitude"] + assert location["longitude"] == location_data["longitude"] diff --git a/stream_chat/types/shared_locations.py b/stream_chat/types/shared_locations.py new file mode 100644 index 0000000..1076531 --- /dev/null +++ b/stream_chat/types/shared_locations.py @@ -0,0 +1,8 @@ +from datetime import datetime +from typing import Optional, TypedDict + + +class SharedLocationsOptions(TypedDict): + longitude: Optional[int] + latitude: Optional[int] + end_at: Optional[datetime]