From 467430d902bdd8d76577c423ba7c221340960225 Mon Sep 17 00:00:00 2001 From: Daniel Wiesmann Date: Thu, 10 Oct 2024 17:44:41 +0100 Subject: [PATCH 1/4] Updating tools to use newer apis from langchain --- app.py | 232 +++++++-------------------------------- environment.yaml | 32 +++--- tools/geopy/distance.py | 28 ++--- tools/geopy/geocode.py | 34 +++--- tools/mercantile_tool.py | 32 +++--- tools/osmnx/geometry.py | 46 ++++---- tools/osmnx/network.py | 40 +++---- tools/stac/search.py | 56 +++++----- 8 files changed, 165 insertions(+), 335 deletions(-) diff --git a/app.py b/app.py index cf916ae..19e4031 100644 --- a/app.py +++ b/app.py @@ -6,206 +6,54 @@ from streamlit_folium import folium_static import langchain +from langchain import hub from langchain.agents import AgentType -from langchain.chat_models import ChatOpenAI +from langchain_ollama import ChatOllama from langchain.tools import Tool, DuckDuckGoSearchRun -from langchain.callbacks import ( - StreamlitCallbackHandler, - AimCallbackHandler, - get_openai_callback, -) +from langchain.callbacks import StreamlitCallbackHandler +from langchain.agents import AgentExecutor, create_react_agent, load_tools + -from tools.mercantile_tool import MercantileTool -from tools.geopy.geocode import GeopyGeocodeTool -from tools.geopy.distance import GeopyDistanceTool -from tools.osmnx.geometry import OSMnxGeometryTool -from tools.osmnx.network import OSMnxNetworkTool -from tools.stac.search import STACSearchTool +from tools.mercantile_tool import mercantile_tool +from tools.geopy.geocode import geocode_tool +from tools.geopy.distance import distance_tool +from tools.osmnx.geometry import geometry_tool +from tools.osmnx.network import network_tool +from tools.stac.search import stac_search from agents.l4m_agent import base_agent # DEBUG langchain.debug = True +duckduckgo_tool = Tool( + name="DuckDuckGo", + description="Use this tool to answer questions about current events and places. \ + Please ask targeted questions.", + func=DuckDuckGoSearchRun().run, +) -@st.cache_resource(ttl="1h") -def get_agent( - openai_api_key, agent_type=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION -): - llm = ChatOpenAI( - temperature=0, - openai_api_key=openai_api_key, - model_name="gpt-3.5-turbo-0613", - ) - # define a set of tools the agent has access to for queries - duckduckgo_tool = Tool( - name="DuckDuckGo", - description="Use this tool to answer questions about current events and places. \ - Please ask targeted questions.", - func=DuckDuckGoSearchRun().run, - ) - geocode_tool = GeopyGeocodeTool() - distance_tool = GeopyDistanceTool() - mercantile_tool = MercantileTool() - geometry_tool = OSMnxGeometryTool() - network_tool = OSMnxNetworkTool() - search_tool = STACSearchTool() - - tools = [ - duckduckgo_tool, - geocode_tool, - distance_tool, - mercantile_tool, - geometry_tool, - network_tool, - search_tool, - ] - - agent = base_agent(llm, tools, agent_type=agent_type) - return agent - - -def run_query(agent, query): - return response - - -def plot_raster(items): - st.subheader("Preview of the first item sorted by cloud cover") - selected_item = min(items, key=lambda item: item.properties["eo:cloud_cover"]) - href = selected_item.assets["rendered_preview"].href - # arr = rio.open(href).read() - - # m = folium.Map(location=[28.6, 77.7], zoom_start=6) - - # img = folium.raster_layers.ImageOverlay( - # name="Sentinel 2", - # image=arr.transpose(1, 2, 0), - # bounds=selected_item.bbox, - # opacity=0.9, - # interactive=True, - # cross_origin=False, - # zindex=1, - # ) - - # img.add_to(m) - # folium.LayerControl().add_to(m) - - # folium_static(m) - st.image(href) - - -def plot_vector(df): - st.subheader("Add the geometry to the Map") - center = df.centroid.iloc[0] - m = folium.Map(location=[center.y, center.x], zoom_start=12) - folium.GeoJson(df).add_to(m) - folium_static(m) - - -st.set_page_config(page_title="LLLLM", page_icon="🤖", layout="wide") -st.subheader("🤖 I am Geo LLM Agent!") - -if "msgs" not in st.session_state: - st.session_state.msgs = [] - -if "total_tokens" not in st.session_state: - st.session_state.total_tokens = 0 - -if "prompt_tokens" not in st.session_state: - st.session_state.prompt_tokens = 0 - -if "completion_tokens" not in st.session_state: - st.session_state.completion_tokens = 0 - -if "total_cost" not in st.session_state: - st.session_state.total_cost = 0 - -with st.sidebar: - openai_api_key = os.getenv("OPENAI_API_KEY") - if not openai_api_key: - openai_api_key = st.text_input("OpenAI API Key", type="password") - - st.subheader("OpenAI Usage") - total_tokens = st.empty() - prompt_tokens = st.empty() - completion_tokens = st.empty() - total_cost = st.empty() - - total_tokens.write(f"Total Tokens: {st.session_state.total_tokens:,.0f}") - prompt_tokens.write(f"Prompt Tokens: {st.session_state.prompt_tokens:,.0f}") - completion_tokens.write( - f"Completion Tokens: {st.session_state.completion_tokens:,.0f}" - ) - total_cost.write(f"Total Cost (USD): ${st.session_state.total_cost:,.4f}") - - -for msg in st.session_state.msgs: - with st.chat_message(name=msg["role"], avatar=msg["avatar"]): - st.markdown(msg["content"]) - -if prompt := st.chat_input("Ask me anything about the flat world..."): - with st.chat_message(name="user", avatar="🧑‍💻"): - st.markdown(prompt) - - st.session_state.msgs.append({"role": "user", "avatar": "🧑‍💻", "content": prompt}) - - if not openai_api_key: - st.info("Please add your OpenAI API key to continue.") - st.stop() - - aim_callback = AimCallbackHandler( - repo=".", - experiment_name="LLLLLM: Base Agent v0.1", - ) - - agent = get_agent(openai_api_key) - - with get_openai_callback() as cb: +llm = ChatOllama( + model="llama3.2", + temperature=0, +) +tools = [ + # duckduckgo_tool, + geocode_tool, + distance_tool, + mercantile_tool, + geometry_tool, + network_tool, + stac_search, +] +prompt = hub.pull("hwchase17/react") +agent = create_react_agent(llm, tools, prompt) +agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True) + +if prompt := st.chat_input(): + st.chat_message("user").write(prompt) + with st.chat_message("assistant"): st_callback = StreamlitCallbackHandler(st.container()) - response = agent.run(prompt, callbacks=[st_callback, aim_callback]) - - aim_callback.flush_tracker(langchain_asset=agent, reset=False, finish=True) - - # Log OpenAI stats - # print(f"Model name: {response.llm_output.get('model_name', '')}") - st.session_state.total_tokens += cb.total_tokens - st.session_state.prompt_tokens += cb.prompt_tokens - st.session_state.completion_tokens += cb.completion_tokens - st.session_state.total_cost += cb.total_cost - - total_tokens.write(f"Total Tokens: {st.session_state.total_tokens:,.0f}") - prompt_tokens.write(f"Prompt Tokens: {st.session_state.prompt_tokens:,.0f}") - completion_tokens.write( - f"Completion Tokens: {st.session_state.completion_tokens:,.0f}" + response = agent_executor.invoke( + {"input": prompt}, {"callbacks": [st_callback]} ) - total_cost.write(f"Total Cost (USD): ${st.session_state.total_cost:,.4f}") - - with st.chat_message(name="assistant", avatar="🤖"): - if type(response) == str: - content = response - st.markdown(response) - else: - tool, result = response - - match tool: - case "stac-search": - content = f"Found {len(result)} items from the catalog." - st.markdown(content) - if len(result) > 0: - plot_raster(result) - case "geometry": - content = f"Found {len(result)} geometries." - gdf = result - st.markdown(content) - plot_vector(gdf) - case "network": - content = f"Found {len(result)} network geometries." - ndf = result - st.markdown(content) - plot_vector(ndf) - case _: - content = response - st.markdown(content) - - st.session_state.msgs.append( - {"role": "assistant", "avatar": "🤖", "content": content} - ) + st.write(response["output"]) diff --git a/environment.yaml b/environment.yaml index 631315c..49757d0 100644 --- a/environment.yaml +++ b/environment.yaml @@ -2,20 +2,22 @@ name: llllm-env channels: - conda-forge dependencies: - - python=3 + - python=3.12 - pip - - osmnx=1.3.1 + - rasterio - pip: - - openai==0.27.8 - - langchain==0.0.215 - - duckduckgo-search==3.8.3 - - mercantile==1.2.1 - - geopy==2.3.0 - - ipywidgets==8.0.6 - - jupyterlab==4.0.2 - - planetary-computer==0.5.1 - - pystac-client==0.7.2 - - streamlit==1.24.1 - - streamlit-folium==0.12.0 - - watchdog==3.0.0 - - aim==3.17.5 + - langchain + - langchain-ollama + - langchain-community + - duckduckgo-search + - mercantile + - geopy + - ipywidgets + - jupyterlab + - planetary-computer + - pystac-client + - streamlit + - streamlit-folium + - watchdog + - altair + - osmnx diff --git a/tools/geopy/distance.py b/tools/geopy/distance.py index 79cbf4f..120a93b 100644 --- a/tools/geopy/distance.py +++ b/tools/geopy/distance.py @@ -1,26 +1,18 @@ -from typing import Type - from geopy.distance import distance -from pydantic import BaseModel, Field -from langchain.tools import BaseTool +from langchain.pydantic_v1 import BaseModel, Field +from langchain_core.tools import tool class GeopyDistanceInput(BaseModel): """Input for GeopyDistanceTool.""" - point_1: tuple[float, float] = Field(..., description="lat,lng of a place") - point_2: tuple[float, float] = Field(..., description="lat,lng of a place") - - -class GeopyDistanceTool(BaseTool): - """Custom tool to calculate geodesic distance between two points.""" - - name: str = "distance" - args_schema: Type[BaseModel] = GeopyDistanceInput - description: str = "Use this tool to compute distance between two points available in lat,lng format." + point_1: tuple[float, float] = Field(description="lat,lng of a place") + point_2: tuple[float, float] = Field(description="lat,lng of a place") - def _run(self, point_1: tuple[int, int], point_2: tuple[int, int]) -> float: - return ("distance", distance(point_1, point_2).km) - def _arun(self, place: str): - raise NotImplementedError +@tool("distance-tool", args_schema=GeopyDistanceInput, return_direct=False) +def distance_tool(point_1: tuple[float, float], point_2: tuple[float, float]) -> float: + """ + Custom tool to calculate geodesic distance between two points. + """ + return distance(point_1, point_2).km diff --git a/tools/geopy/geocode.py b/tools/geopy/geocode.py index 8b14198..36afb46 100644 --- a/tools/geopy/geocode.py +++ b/tools/geopy/geocode.py @@ -1,29 +1,25 @@ -from typing import Type +from typing import Tuple from geopy.geocoders import Nominatim -from pydantic import BaseModel, Field -from langchain.tools import BaseTool +from langchain.pydantic_v1 import BaseModel, Field +from langchain_core.tools import tool class GeopyGeocodeInput(BaseModel): """Input for GeopyGeocodeTool.""" - place: str = Field(..., description="name of a place") + place: str = Field(description="name of a place") -class GeopyGeocodeTool(BaseTool): - """Custom tool to perform geocoding.""" +@tool("geocode-tool", args_schema=GeopyGeocodeInput, return_direct=True) +def geocode_tool(place: str) -> Tuple[float, float]: + """ + Custom tool to perform geocoding. - name: str = "geocode" - args_schema: Type[BaseModel] = GeopyGeocodeInput - description: str = "Use this tool for geocoding." - - def _run(self, place: str) -> tuple: - locator = Nominatim(user_agent="geocode") - location = locator.geocode(place) - if location is None: - return ("geocode", "Not a recognised address in Nomatim.") - return ("geocode", (location.latitude, location.longitude)) - - def _arun(self, place: str): - raise NotImplementedError + Use this tool for geocoding an address of a place. + """ + locator = Nominatim(user_agent="geocode") + location = locator.geocode(place) + if location is None: + return ("geocode", "Not a recognised address in Nomatim.") + return location.latitude, location.longitude diff --git a/tools/mercantile_tool.py b/tools/mercantile_tool.py index 63e8bcd..e8ecf5f 100644 --- a/tools/mercantile_tool.py +++ b/tools/mercantile_tool.py @@ -1,19 +1,23 @@ -import mercantile -from langchain.tools import BaseTool +from typing import List +import mercantile +from langchain.pydantic_v1 import BaseModel, Field +from langchain_core.tools import tool +from pydantic import BaseModel, Field -class MercantileTool(BaseTool): - """Tool to perform mercantile operations.""" - name = "mercantile" - description = "use this tool to get the xyz tiles for a place. \ - To use this tool you need to provide lng,lat,zoom level of the place separated by comma." +class MercantileToolInput(BaseModel): + latitude: float = Field(description="Latitude of a location") + longitude: float = Field(description="Longitude of a location") + zoom: str = Field(description="Zoom level for mercantile") - def _run(self, query): - lng, lat, zoom = map(float, query.split(",")) - return ("mercantile", mercantile.tile(lng, lat, zoom)) - def _arun(self, query): - raise NotImplementedError( - "Mercantile tool doesn't have an async implementation." - ) +@tool("mercantile-tool", args_schema=MercantileToolInput, return_direct=True) +def mercantile_tool(latitude: float, longitude: float, zoom: int) -> mercantile.Tile: + """ + Tool to perform mercantile operations. + Use this tool to get the xyz tiles for a place. To use this tool you need to provide + lng,lat,zoom level of the place separated by comma. + """ + lng, lat, zoom = map(float, query.split(",")) + return mercantile.tile(lng, lat, zoom) diff --git a/tools/osmnx/geometry.py b/tools/osmnx/geometry.py index ecc1287..fc9754e 100644 --- a/tools/osmnx/geometry.py +++ b/tools/osmnx/geometry.py @@ -1,32 +1,26 @@ -from typing import Type, Dict +from typing import Dict, Type -import osmnx as ox import geopandas as gpd -from pydantic import BaseModel, Field -from langchain.tools import BaseTool +import osmnx as ox +from geopandas import GeoDataFrame +from langchain.pydantic_v1 import BaseModel, Field +from langchain_core.tools import tool class PlaceWithTags(BaseModel): "Name of a place on the map and tags in OSM." - - place: str = Field(..., description="name of a place on the map.") - tags: Dict[str, str] = Field(..., description="open street maps tags.") - - -class OSMnxGeometryTool(BaseTool): - """Tool to query geometries from Open Street Map (OSM).""" - - name: str = "geometry" - args_schema: Type[BaseModel] = PlaceWithTags - description: str = "Use this tool to get geometry of different features of the place like building footprints, parks, lakes, hospitals, schools etc. \ - Pass the name of the place & tags of OSM as args." - return_direct = True - - def _run(self, place: str, tags: Dict[str, str]) -> gpd.GeoDataFrame: - gdf = ox.geometries_from_place(place, tags) - gdf = gdf[gdf["geometry"].type.isin({"Polygon", "MultiPolygon"})] - gdf = gdf[["name", "geometry"]].reset_index(drop=True) - return ("geometry", gdf) - - def _arun(self, place: str): - raise NotImplementedError + place: str = Field(description="name of a place on the map.") + tags: Dict[str, str] = Field(description="open street maps tags.") + + +@tool("geometry-tool", args_schema=PlaceWithTags, return_direct=True) +def geometry_tool(place: str, tags: Dict[str, str]) -> GeoDataFrame: + """ + Tool to query geometries from Open Street Map (OSM). + Use this tool to get geometry of different features of the place + like building footprints, parks, lakes, hospitals, schools etc. + Pass the name of the place & tags of OSM as args. + """ + gdf = ox.geometries_from_place(place, tags) + gdf = gdf[gdf["geometry"].type.isin({"Polygon", "MultiPolygon"})] + return gdf[["name", "geometry"]].reset_index(drop=True) diff --git a/tools/osmnx/network.py b/tools/osmnx/network.py index 24f8375..0df9d40 100644 --- a/tools/osmnx/network.py +++ b/tools/osmnx/network.py @@ -1,34 +1,26 @@ -from typing import Type, Dict - +import geopandas as gpd import osmnx as ox +from geopandas import GeoDataFrame +from langchain.pydantic_v1 import BaseModel, Field +from langchain_core.tools import tool from osmnx import utils_graph -import geopandas as gpd -from pydantic import BaseModel, Field -from langchain.tools import BaseTool class PlaceWithNetworktype(BaseModel): "Name of a place on the map" - place: str = Field(..., description="name of a place on the map") + place: str = Field(description="name of a place on the map") network_type: str = Field( - ..., description="network type: one of walk, bike, drive or all" + description="network type: one of walk, bike, drive or all" ) -class OSMnxNetworkTool(BaseTool): - """Custom tool to query road networks from OSM.""" - - name: str = "network" - args_schema: Type[BaseModel] = PlaceWithNetworktype - description: str = "Use this tool to get road network of a place. \ - Pass the name of the place & type of road network i.e walk, bike, drive or all." - return_direct = True - - def _run(self, place: str, network_type: str) -> gpd.GeoDataFrame: - G = ox.graph_from_place(place, network_type=network_type, simplify=True) - network = utils_graph.graph_to_gdfs(G, nodes=False) - network = network[["name", "geometry"]].reset_index(drop=True) - return ("network", network) - - def _arun(self, place: str): - raise NotImplementedError +@tool("network-tool", args_schema=PlaceWithNetworktype, return_direct=True) +def network_tool(place: str, network_type: str) -> GeoDataFrame: + """ + Custom tool to query road networks from OSM. + Use this tool to get road network of a place. + Pass the name of the place & type of road network i.e walk, bike, drive or all + """ + G = ox.graph_from_place(place, network_type=network_type, simplify=True) + network = utils_graph.graph_to_gdfs(G, nodes=False) + return network[["name", "geometry"]].reset_index(drop=True) diff --git a/tools/stac/search.py b/tools/stac/search.py index ce8d285..424f190 100644 --- a/tools/stac/search.py +++ b/tools/stac/search.py @@ -1,41 +1,43 @@ -from typing import Type +from typing import List, Type -from pystac_client import Client import planetary_computer as pc -from pydantic import BaseModel, Field -from langchain.tools import BaseTool +from langchain.pydantic_v1 import BaseModel, Field +from langchain_core.tools import tool +from pystac import Item +from pystac_client import Client PC_STAC_API = "https://planetarycomputer.microsoft.com/api/stac/v1" -class PlaceWithDatetimeAndBBox(BaseModel): - "Name of a place and date." +STAC_API = "https://earth-search.aws.element84.com/v1" +COLLECTION = "sentinel-2-l2a" + - bbox: str = Field(..., description="bbox of the place") - datetime: str = Field(..., description="datetime for the stac catalog search") +class StacSearchInput(BaseModel): + latitude: tuple[float, float] = Field(description="Latitude of a location") + longitude: float = Field(description="Longitude of a location") + start: str = Field(description="Start date") + end: str = Field(description="End date") -class STACSearchTool(BaseTool): - """Tool to search for STAC items in a catalog.""" +@tool("stac-search-tool", args_schema=StacSearchInput, return_direct=True) +def stac_search(latitude: float, longitude: float, start: str, end: str) -> List[Item]: + """ + Search Sentinel-2 STAC items. - name: str = "stac-search" - args_schema: Type[BaseModel] = PlaceWithDatetimeAndBBox - description: str = "Use this tool to search for STAC items in a catalog. \ - Pass the bbox of the place & date as args." - return_direct = True + Use this tool to perform a STAC scene search for a Sentinel-2 images at a + latitude and longitude and between a start and an end date. + """ - def _run(self, bbox: str, datetime: str): - catalog = Client.open(PC_STAC_API, modifier=pc.sign_inplace) + catalog = pystac_client.Client.open(STAC_API) - search = catalog.search( - collections=["sentinel-2-l2a"], - bbox=bbox, - datetime=datetime, - max_items=10, - ) - items = search.get_all_items() + search = catalog.search( + collections=[COLLECTION], + datetime=f"{start}/{end}", + bbox=(longitude - 1e-5, latitude - 1e-5, longitude + 1e-5, latitude + 1e-5), + max_items=100, + ) - return ("stac-search", items) + items = search.get_all_items() - def _arun(self, bbox: str, datetime: str): - raise NotImplementedError + return [item for item in items] From 167fbe85cac92d40e8762d5e5b5ac6923feb777b Mon Sep 17 00:00:00 2001 From: Daniel Wiesmann Date: Fri, 11 Oct 2024 11:20:09 +0100 Subject: [PATCH 2/4] Use pydantic base model --- tools/geopy/distance.py | 2 +- tools/geopy/geocode.py | 4 ++-- tools/osmnx/geometry.py | 2 +- tools/osmnx/network.py | 2 +- tools/stac/search.py | 15 +++++++++------ 5 files changed, 14 insertions(+), 11 deletions(-) diff --git a/tools/geopy/distance.py b/tools/geopy/distance.py index 120a93b..cadf91a 100644 --- a/tools/geopy/distance.py +++ b/tools/geopy/distance.py @@ -1,6 +1,6 @@ from geopy.distance import distance -from langchain.pydantic_v1 import BaseModel, Field from langchain_core.tools import tool +from pydantic import BaseModel, Field class GeopyDistanceInput(BaseModel): diff --git a/tools/geopy/geocode.py b/tools/geopy/geocode.py index 36afb46..a90d237 100644 --- a/tools/geopy/geocode.py +++ b/tools/geopy/geocode.py @@ -1,8 +1,8 @@ from typing import Tuple from geopy.geocoders import Nominatim -from langchain.pydantic_v1 import BaseModel, Field from langchain_core.tools import tool +from pydantic import BaseModel, Field class GeopyGeocodeInput(BaseModel): @@ -21,5 +21,5 @@ def geocode_tool(place: str) -> Tuple[float, float]: locator = Nominatim(user_agent="geocode") location = locator.geocode(place) if location is None: - return ("geocode", "Not a recognised address in Nomatim.") + return 0, 0 # Null island return location.latitude, location.longitude diff --git a/tools/osmnx/geometry.py b/tools/osmnx/geometry.py index fc9754e..ff23ea7 100644 --- a/tools/osmnx/geometry.py +++ b/tools/osmnx/geometry.py @@ -3,8 +3,8 @@ import geopandas as gpd import osmnx as ox from geopandas import GeoDataFrame -from langchain.pydantic_v1 import BaseModel, Field from langchain_core.tools import tool +from pydantic import BaseModel, Field class PlaceWithTags(BaseModel): diff --git a/tools/osmnx/network.py b/tools/osmnx/network.py index 0df9d40..8cc1f26 100644 --- a/tools/osmnx/network.py +++ b/tools/osmnx/network.py @@ -1,9 +1,9 @@ import geopandas as gpd import osmnx as ox from geopandas import GeoDataFrame -from langchain.pydantic_v1 import BaseModel, Field from langchain_core.tools import tool from osmnx import utils_graph +from pydantic import BaseModel, Field class PlaceWithNetworktype(BaseModel): diff --git a/tools/stac/search.py b/tools/stac/search.py index 424f190..ec79a46 100644 --- a/tools/stac/search.py +++ b/tools/stac/search.py @@ -1,8 +1,9 @@ +from datetime import datetime from typing import List, Type import planetary_computer as pc -from langchain.pydantic_v1 import BaseModel, Field from langchain_core.tools import tool +from pydantic import BaseModel, Field from pystac import Item from pystac_client import Client @@ -14,14 +15,16 @@ class StacSearchInput(BaseModel): - latitude: tuple[float, float] = Field(description="Latitude of a location") + latitude: float = Field(description="Latitude of a location") longitude: float = Field(description="Longitude of a location") - start: str = Field(description="Start date") - end: str = Field(description="End date") + start: datetime = Field(description="Start date") + end: datetime = Field(description="End date") @tool("stac-search-tool", args_schema=StacSearchInput, return_direct=True) -def stac_search(latitude: float, longitude: float, start: str, end: str) -> List[Item]: +def stac_search( + latitude: float, longitude: float, start: datetime, end: datetime +) -> List[Item]: """ Search Sentinel-2 STAC items. @@ -33,7 +36,7 @@ def stac_search(latitude: float, longitude: float, start: str, end: str) -> List search = catalog.search( collections=[COLLECTION], - datetime=f"{start}/{end}", + datetime=f"{start.date()}/{end.date()}", bbox=(longitude - 1e-5, latitude - 1e-5, longitude + 1e-5, latitude + 1e-5), max_items=100, ) From dea5f37e5c119c9a23bf0d96693cac3c2c70699e Mon Sep 17 00:00:00 2001 From: Daniel Wiesmann Date: Fri, 11 Oct 2024 11:22:02 +0100 Subject: [PATCH 3/4] Fix stac client call --- tools/stac/search.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/stac/search.py b/tools/stac/search.py index ec79a46..222a67d 100644 --- a/tools/stac/search.py +++ b/tools/stac/search.py @@ -32,7 +32,7 @@ def stac_search( latitude and longitude and between a start and an end date. """ - catalog = pystac_client.Client.open(STAC_API) + catalog = Client.open(STAC_API) search = catalog.search( collections=[COLLECTION], From a1649e417d8bbc32442af06597522f09d2d512dd Mon Sep 17 00:00:00 2001 From: Daniel Wiesmann Date: Thu, 17 Oct 2024 08:37:50 +0100 Subject: [PATCH 4/4] Update to langgraph --- agents/l4m_agent.py | 31 ---------------- app.py | 69 ++++++++--------------------------- {agents => graphs}/.gitkeep | 0 graphs/l4m_graph.py | 72 +++++++++++++++++++++++++++++++++++++ tools/duck_tool.py | 8 +++++ tools/geopy/distance.py | 14 +++++--- tools/geopy/geocode.py | 3 +- 7 files changed, 105 insertions(+), 92 deletions(-) delete mode 100644 agents/l4m_agent.py rename {agents => graphs}/.gitkeep (100%) create mode 100644 graphs/l4m_graph.py create mode 100644 tools/duck_tool.py diff --git a/agents/l4m_agent.py b/agents/l4m_agent.py deleted file mode 100644 index 302da89..0000000 --- a/agents/l4m_agent.py +++ /dev/null @@ -1,31 +0,0 @@ -from langchain.agents import initialize_agent -from langchain.agents import AgentType -from langchain.prompts import MessagesPlaceholder -from langchain.memory import ConversationBufferMemory - - -def base_agent( - llm, tools, agent_type=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION -): - """Base agent to perform xyz slippy map tiles operations. - - llm: LLM object - tools: List of tools to use by the agent - """ - # chat_history = MessagesPlaceholder(variable_name="chat_history") - # memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True) - agent = initialize_agent( - llm=llm, - tools=tools, - agent=agent_type, - max_iterations=5, - early_stopping_method="generate", - verbose=True, - # memory=memory, - # agent_kwargs={ - # "memory_prompts": [chat_history], - # "input_variables": ["input", "agent_scratchpad", "chat_history"], - # }, - ) - print("agent initialized") - return agent diff --git a/app.py b/app.py index 19e4031..c0c688e 100644 --- a/app.py +++ b/app.py @@ -1,59 +1,20 @@ -import os - -import rasterio as rio -import folium import streamlit as st -from streamlit_folium import folium_static - -import langchain -from langchain import hub -from langchain.agents import AgentType -from langchain_ollama import ChatOllama -from langchain.tools import Tool, DuckDuckGoSearchRun -from langchain.callbacks import StreamlitCallbackHandler -from langchain.agents import AgentExecutor, create_react_agent, load_tools - - -from tools.mercantile_tool import mercantile_tool -from tools.geopy.geocode import geocode_tool -from tools.geopy.distance import distance_tool -from tools.osmnx.geometry import geometry_tool -from tools.osmnx.network import network_tool -from tools.stac.search import stac_search -from agents.l4m_agent import base_agent - -# DEBUG -langchain.debug = True - -duckduckgo_tool = Tool( - name="DuckDuckGo", - description="Use this tool to answer questions about current events and places. \ - Please ask targeted questions.", - func=DuckDuckGoSearchRun().run, -) +from langchain_core.messages import HumanMessage -llm = ChatOllama( - model="llama3.2", - temperature=0, -) -tools = [ - # duckduckgo_tool, - geocode_tool, - distance_tool, - mercantile_tool, - geometry_tool, - network_tool, - stac_search, -] -prompt = hub.pull("hwchase17/react") -agent = create_react_agent(llm, tools, prompt) -agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True) +from graphs.l4m_graph import graph if prompt := st.chat_input(): st.chat_message("user").write(prompt) - with st.chat_message("assistant"): - st_callback = StreamlitCallbackHandler(st.container()) - response = agent_executor.invoke( - {"input": prompt}, {"callbacks": [st_callback]} - ) - st.write(response["output"]) + config = {"configurable": {"thread_id": "1"}} + for chunk in graph.stream( + {"messages": [HumanMessage(content=prompt)]}, config, stream_mode="updates" + ): + # for chunk in graph.invoke( + # {"messages": [HumanMessage(content=prompt)]}, config, stream_mode="updates" + # ): + # st.markdown(chunk) + + node = "assistant" if "assistant" in chunk else "tools" + with st.chat_message(node): + for msg in chunk[node]["messages"]: + st.markdown(msg.content) diff --git a/agents/.gitkeep b/graphs/.gitkeep similarity index 100% rename from agents/.gitkeep rename to graphs/.gitkeep diff --git a/graphs/l4m_graph.py b/graphs/l4m_graph.py new file mode 100644 index 0000000..376293c --- /dev/null +++ b/graphs/l4m_graph.py @@ -0,0 +1,72 @@ +import langchain +from langchain_core.messages import SystemMessage +from langchain_ollama import ChatOllama +from langgraph.graph import START, MessagesState, StateGraph +from langgraph.prebuilt import ToolNode, tools_condition + +from tools.geopy.distance import distance_tool +from tools.geopy.geocode import geocode_tool +from tools.mercantile_tool import mercantile_tool +from tools.osmnx.geometry import geometry_tool +from tools.osmnx.network import network_tool +from tools.stac.search import stac_search + +# from tools.duck_tool import duckduckgo_tool + +# DEBUG +langchain.debug = True + +llm = ChatOllama( + model="llama3.2", + temperature=0, +) + +tools = [ + # duckduckgo_tool, + geocode_tool, + distance_tool, + mercantile_tool, + geometry_tool, + network_tool, + stac_search, +] + +# For this ipynb we set parallel tool calling to false as math generally is done sequentially, and this time we have 3 tools that can do math +# the OpenAI model specifically defaults to parallel tool calling for efficiency, see https://python.langchain.com/docs/how_to/tool_calling_parallel/ +# play around with it and see how the model behaves with math equations! +llm_with_tools = llm.bind_tools(tools, parallel_tool_calls=False) + + +# System message +sys_msg = SystemMessage( + content="You are a helpful assistant tasked with answering questions on a set of geographic inputs." + # "You are a helpful assistant tasked with performing arithmetic on a set of inputs. " + "do not use tools unless the message does not contain geographic inputs" + # "do NOT use tools unless strictly necessary to answer the question" + # " Do NOT answer the question, just reformulate it if needed and otherwise return it as is.." +) + + +# Node +def assistant(state: MessagesState): + return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]} + + +# Graph +builder = StateGraph(MessagesState) + +# Define nodes: these do the work +builder.add_node("assistant", assistant) +builder.add_node("tools", ToolNode(tools)) + +# Define edges: these determine how the control flow moves +builder.add_edge(START, "assistant") +builder.add_conditional_edges( + "assistant", + # If the latest message (result) from assistant is a tool call -> tools_condition routes to tools + # If the latest message (result) from assistant is a not a tool call -> tools_condition routes to END + tools_condition, +) +builder.add_edge("tools", "assistant") + +graph = builder.compile() diff --git a/tools/duck_tool.py b/tools/duck_tool.py new file mode 100644 index 0000000..5a24c03 --- /dev/null +++ b/tools/duck_tool.py @@ -0,0 +1,8 @@ +from langchain.tools import DuckDuckGoSearchRun, Tool + +duckduckgo_tool = Tool( + name="DuckDuckGo", + description="Use this tool to answer questions about current events and places. \ + Please ask targeted questions.", + func=DuckDuckGoSearchRun().run, +) diff --git a/tools/geopy/distance.py b/tools/geopy/distance.py index cadf91a..471d3a8 100644 --- a/tools/geopy/distance.py +++ b/tools/geopy/distance.py @@ -1,3 +1,5 @@ +from typing import Union + from geopy.distance import distance from langchain_core.tools import tool from pydantic import BaseModel, Field @@ -6,13 +8,15 @@ class GeopyDistanceInput(BaseModel): """Input for GeopyDistanceTool.""" - point_1: tuple[float, float] = Field(description="lat,lng of a place") - point_2: tuple[float, float] = Field(description="lat,lng of a place") + lat1: float = Field(description="Latitude of a first location") + lon1: float = Field(description="Longitude of a first location") + lat2: float = Field(description="Latitude of a second location") + lon2: float = Field(description="Longitude of a second location") @tool("distance-tool", args_schema=GeopyDistanceInput, return_direct=False) -def distance_tool(point_1: tuple[float, float], point_2: tuple[float, float]) -> float: +def distance_tool(lat1: float, lon1: float, lat2: float, lon2: float) -> float: """ - Custom tool to calculate geodesic distance between two points. + Tool to calculate distance in kilometers between two points. """ - return distance(point_1, point_2).km + return distance((lat1, lon1), (lat2, lon2)).km diff --git a/tools/geopy/geocode.py b/tools/geopy/geocode.py index a90d237..f113532 100644 --- a/tools/geopy/geocode.py +++ b/tools/geopy/geocode.py @@ -18,8 +18,7 @@ def geocode_tool(place: str) -> Tuple[float, float]: Use this tool for geocoding an address of a place. """ + return 30.1, 40.1 locator = Nominatim(user_agent="geocode") location = locator.geocode(place) - if location is None: - return 0, 0 # Null island return location.latitude, location.longitude