Skip to content

Commit e34ca0a

Browse files
committed
Greedy support, closes #77
1 parent 060e6ea commit e34ca0a

File tree

3 files changed

+27
-3
lines changed

3 files changed

+27
-3
lines changed

revolt/ext/commands/command.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,20 @@ async def handle_origin(cls, context: Context[ClientT_Co_D], origin: Any, annota
148148
elif origin is Annotated:
149149
annotated_args = get_args(annotation)
150150

151-
if origin := get_origin(annotated_args[0]):
152-
return await cls.handle_origin(context, origin, annotated_args[1], arg)
151+
if annotated_args[1] == "_revolt_greedy_marker":
152+
real_annotation = get_args(annotated_args[0])[0]
153+
converted_args: list[Any] = []
154+
155+
converted_args.append(await cls.convert_argument(arg, real_annotation, context))
156+
157+
for arg in context.view:
158+
try:
159+
converted_args.append(await cls.convert_argument(arg, real_annotation, context))
160+
except:
161+
context.view.undo()
162+
break
163+
164+
return converted_args
153165
else:
154166
return await cls.convert_argument(arg, annotated_args[1], context)
155167

revolt/ext/commands/converters.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
if TYPE_CHECKING:
1414
from .client import CommandsClient
1515

16-
__all__: tuple[str, ...] = ("bool_converter", "category_converter", "channel_converter", "user_converter", "member_converter", "IntConverter", "BoolConverter", "CategoryConverter", "UserConverter", "MemberConverter", "ChannelConverter")
16+
T = TypeVar("T")
17+
18+
__all__: tuple[str, ...] = ("bool_converter", "category_converter", "channel_converter", "user_converter", "member_converter", "IntConverter", "BoolConverter", "CategoryConverter", "UserConverter", "MemberConverter", "ChannelConverter", "Greedy")
1719

1820
channel_regex: re.Pattern[str] = re.compile("<#([A-z0-9]{26})>")
1921
user_regex: re.Pattern[str] = re.compile("<@([A-z0-9]{26})>")
@@ -120,3 +122,5 @@ def int_converter(arg: str, context: Context[ClientT]) -> int:
120122
UserConverter = Annotated[User, user_converter]
121123
MemberConverter = Annotated[Member, member_converter]
122124
ChannelConverter = Annotated[Channel, channel_converter]
125+
126+
Greedy = Annotated[list[T], "_revolt_greedy_marker"]

revolt/ext/commands/view.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from typing import Iterator
2+
from typing_extensions import Self
3+
24
from .errors import NoClosingQuote
35

46

@@ -52,3 +54,9 @@ def get_next_word(self) -> str:
5254
self.temp = output
5355

5456
return output
57+
58+
def __iter__(self) -> Self:
59+
return self
60+
61+
def __next__(self) -> str:
62+
return self.get_next_word()

0 commit comments

Comments
 (0)