From 15096be55344c48b5f6145929cefc770001c1476 Mon Sep 17 00:00:00 2001 From: Mark Veidemanis Date: Tue, 28 Dec 2021 12:50:19 +0000 Subject: [PATCH] Implement class-based IRC commands --- handler/app.py | 3 +- handler/commands.py | 120 ++++++++++++++++++++++++++++++++++++++++++++ handler/irc.py | 107 +++++++++++++++------------------------ 3 files changed, 162 insertions(+), 68 deletions(-) create mode 100644 handler/commands.py diff --git a/handler/app.py b/handler/app.py index 07a79be..0f5105e 100755 --- a/handler/app.py +++ b/handler/app.py @@ -70,9 +70,10 @@ if __name__ == "__main__": # Define Transactions tx = Transactions() - # Pass Agora and IRC to Transactions + # Pass Agora and IRC to Transactions and Transactions to IRC tx.set_agora(agora) tx.set_irc(irc) + irc.set_tx(tx) # Define WebApp webapp = WebApp() diff --git a/handler/commands.py b/handler/commands.py new file mode 100644 index 0000000..ab8a575 --- /dev/null +++ b/handler/commands.py @@ -0,0 +1,120 @@ +# Other library imports +from json import dumps + + +class IRCCommands(object): + class trades(object): + name = "trades" + authed = True + + @staticmethod + def run(cmd, spl, length, authed, msg, agora, revolut, tx): + """ + Get details of open trades and post on IRC. + """ + trades = agora.dashboard(send_irc=False) + if not trades: + msg("No open trades.") + for trade_id in trades: + msg(trade_id) + + class create(object): + name = "create" + authed = True + + @staticmethod + def run(cmd, spl, length, authed, msg, agora, revolut, tx): + """ + Post an ad on AgoraDesk with the given country and currency code. + """ + posted = agora.create_ad(spl[1], spl[2]) + if posted["success"]: + msg(f"{posted['response']['data']['message']}: {posted['response']['data']['ad_id']}") + else: + msg(posted["response"]["data"]["message"]) + + class messages(object): + name = "messages" + authed = True + + @staticmethod + def run(cmd, spl, length, authed, msg, agora, revolut, tx): + """ + Get all messages for all open trades or a given trade. + """ + if length == 1: + messages = agora.get_all_messages() + if not messages: + msg("No messages.") + for message_id in messages: + for message in messages[message_id]: + msg(f"{message_id}: {message}") + msg("---") + + elif length == 2: + messages = agora.get_messages(spl[1], send_irc=False) + if not messages: + msg("No messages.") + for message in messages: + msg(f"{spl[1]}: {message}") + + class dist(object): + name = "dist" + authed = True + + @staticmethod + def run(cmd, spl, length, authed, msg, agora, revolut, tx): + # Distribute out our ad to all countries in the config + rtrn = agora.dist_countries() + msg(dumps(rtrn)) + + class find(object): + name = "find" + authed = True + + @staticmethod + def run(cmd, spl, length, authed, msg, agora, revolut, tx): + """ + Find a transaction received by Revolut with the given reference and amount. + """ + try: + int(spl[2]) + except ValueError: + msg("Amount is not an integer.") + rtrn = tx.find_tx(spl[1], spl[2]) + if rtrn == "AMOUNT_INVALID": + msg("Reference found but amount invalid.") + elif not rtrn: + msg("Reference not found.") + else: + return dumps(rtrn) + + class accounts(object): + name = "accounts" + authed = True + + @staticmethod + def run(cmd, spl, length, authed, msg, agora, revolut, tx): + accounts = revolut.accounts() + for account in accounts: + if account["balance"] > 0: + msg(f"{account['name']} {account['currency']}: {account['balance']}") + + class total(object): + name = "total" + authed = True + + @staticmethod + def run(cmd, spl, length, authed, msg, agora, revolut, tx): + total_usd = revolut.get_total_usd() + if not total_usd: + msg("Error getting total balance.") + msg(f"Total: {round(total_usd, 2)}USD") + + class ping(object): + name = "ping" + authed = False + + @staticmethod + def run(cmd, spl, length, authed, msg): + msg("Pong!") diff --git a/handler/irc.py b/handler/irc.py index b2298d1..98d8080 100644 --- a/handler/irc.py +++ b/handler/irc.py @@ -3,11 +3,9 @@ from twisted.logger import Logger from twisted.words.protocols import irc from twisted.internet import protocol, reactor, ssl -# Other library imports -from json import dumps - # Project imports from settings import settings +from commands import IRCCommands class IRCBot(irc.IRCClient): @@ -18,6 +16,9 @@ class IRCBot(irc.IRCClient): :type log: Logger """ self.log = log + self.cmd = IRCCommands() + # Parse the commands into "commandname": "commandclass" + self.cmdhash = {getattr(self.cmd, x).name: x for x in dir(self.cmd) if not x.startswith("_")} self.nickname = settings.IRC.Nick self.password = settings.IRC.Pass self.realname = self.nickname @@ -42,6 +43,9 @@ class IRCBot(irc.IRCClient): def set_revolut(self, revolut): self.revolut = revolut + def set_tx(self, tx): + self.tx = tx + def parse(self, user, host, channel, msg): """ Simple handler for IRC commands. @@ -59,71 +63,32 @@ class IRCBot(irc.IRCClient): cmd = spl[0] - if cmd == "trades" and host in self.admins: - # Get details of open trades and post on IRC - trades = self.agora.dashboard(send_irc=False) - if not trades: - self.msg(channel, "No open trades.") - for trade_id in trades: - self.msg(channel, trade_id) - - elif cmd == "create" and host in self.admins and len(spl) == 3: - # Post an ad on AgoraDesk with the given country and currency code - posted = self.agora.create_ad(spl[1], spl[2]) - if posted["success"]: - self.msg(channel, f"{posted['response']['data']['message']}: {posted['response']['data']['ad_id']}") + # Check if user is authenticated + authed = host in self.admins + if cmd in self.cmdhash: + # Get the class name of the referenced command + cmdname = self.cmdhash[cmd] + # Get the class name + obj = getattr(self.cmd, cmdname) + + def msgl(x): + self.msg(channel, x) + + # Check if the command required authentication + if obj.authed: + if host in self.admins: + obj.run(cmd, spl, len(spl), authed, msgl, self.agora, self.revolut, self.tx) + else: + # Handle authentication here instead of in the command module for security + self.msg(channel, "Access denied.") else: - self.msg(channel, posted["response"]["data"]["message"]) - - elif cmd == "messages" and host in self.admins and len(spl) == 1: - # Get all messages for all open trades - messages = self.agora.get_all_messages() - if not messages: - self.msg(channel, "No messages.") - for message_id in messages: - for message in messages[message_id]: - self.msg(channel, f"{message_id}: {message}") - self.msg(channel, "---") - # self.msg(channel, dumps(messages)) - - elif cmd == "messages" and host in self.admins and len(spl) == 2: - # Get all messages for a given trade - messages = self.agora.get_messages(spl[1], send_irc=False) - if not messages: - self.msg(channel, "No messages.") - for message in messages: - self.msg(channel, f"{spl[1]}: {message}") - - elif cmd == "dist" and host in self.admins: - # Distribute out our ad to all countries in the config - rtrn = self.agora.dist_countries() - self.msg(channel, dumps(rtrn)) - - elif cmd == "find" and host in self.admins and len(spl) == 3: - # Find a transaction received by Revolut with the given reference and amount - try: - int(spl[2]) - except ValueError: - self.msg(channel, "Amount is not an integer.") - rtrn = self.tx.find_tx(spl[1], spl[2]) - if rtrn == "AMOUNT_INVALID": - self.msg(channel, "Reference found but amount invalid.") - elif not rtrn: - self.msg(channel, "Reference not found.") - else: - return dumps(rtrn) - - elif cmd == "accounts" and host in self.admins: - accounts = self.revolut.accounts() - for account in accounts: - if account["balance"] > 0: - self.msg(channel, f"{account['name']} {account['currency']}: {account['balance']}") - - elif cmd == "total" and host in self.admins: - total_usd = self.revolut.get_total_usd() - if not total_usd: - self.msg(channel, "Error getting total balance.") - self.msg(channel, f"Total: {round(total_usd, 2)}USD") + # Run an unauthenticated command, without passing through secure library calls + obj.run(cmd, spl, len(spl), authed, msgl) + return + self.msg(channel, "Command not found.") + if authed: + # Give user command hints if they are authenticated + self.msg(channel, f"Commands loaded: {', '.join(self.cmdhash.keys())}") def signedOn(self): """ @@ -170,6 +135,10 @@ class IRCBot(irc.IRCClient): if len(msg) > 1: if msg.split()[0] != "!": self.parse(user, host, channel, msg[1:]) + elif host in self.admins and channel == nick: + if len(msg) > 0: + if msg.split()[0] != "!": + self.parse(user, host, channel, msg) def noticed(self, user, channel, msg): """ @@ -197,6 +166,9 @@ class IRCBotFactory(protocol.ClientFactory): def set_revolut(self, revolut): self.revolut = revolut + def set_tx(self, tx): + self.tx = tx + def buildProtocol(self, addr): """ Custom override for the Twisted buildProtocol so we can access the Protocol instance. @@ -207,6 +179,7 @@ class IRCBotFactory(protocol.ClientFactory): self.client = prcol self.client.set_agora(self.agora) self.client.set_revolut(self.revolut) + self.client.set_tx(self.tx) return prcol def clientConnectionLost(self, connector, reason):