From 85c64efc78ef1349983be2125616c5f6b859fe2e Mon Sep 17 00:00:00 2001 From: Mark Veidemanis Date: Sat, 20 May 2023 13:41:49 +0100 Subject: [PATCH] Validate addresses and enable payments --- Dockerfile | 2 +- core/lib/money.py | 5 +- core/management/commands/scheduling.py | 23 +- core/util/validation.py | 742 +++++++++++++++++++++++++ requirements.txt | 2 + 5 files changed, 765 insertions(+), 9 deletions(-) create mode 100644 core/util/validation.py diff --git a/Dockerfile b/Dockerfile index 1f98816..cb5680c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,5 +1,5 @@ # syntax=docker/dockerfile:1 -FROM python:3 +FROM python:3.10 ARG OPERATION RUN useradd -d /code xf diff --git a/core/lib/money.py b/core/lib/money.py index 74ea149..3cd5ae8 100644 --- a/core/lib/money.py +++ b/core/lib/money.py @@ -87,7 +87,10 @@ class Money(object): cast["xtype"] = msgtype # cast["user_id"] = self.instance.user.id # cast["platform_id"] = self.instance.id - await self.es.index(index=settings.ELASTICSEARCH_INDEX, body=cast) + try: + await self.es.index(index=settings.ELASTICSEARCH_INDEX, body=cast) + except RuntimeError: + log.warning("Could not write to ES") async def lookup_rates(self, platform, ads, rates=None): """ diff --git a/core/management/commands/scheduling.py b/core/management/commands/scheduling.py index f5fcc0e..aef0a67 100644 --- a/core/management/commands/scheduling.py +++ b/core/management/commands/scheduling.py @@ -17,6 +17,7 @@ from core.models import ( Requisition, ) from core.util import logs +from core.util.validation import Validation log = logs.get_logger("scheduling") @@ -100,10 +101,20 @@ async def withdrawal_job(group=None): raise Exception("You can only have one platform per group") platform = group.platforms.first() - # run = await AgoraClient(platform) + run = await AgoraClient(platform) otp_code = TOTP(platform.otp_token).now() for wallet, pay_list_iter in pay_list.items(): + print("WALLET ITER", wallet) + if not Validation.is_address("xmr", wallet.address): + print("NOT VALID", wallet.address) + await sendmsg( + group.user, + f"Invalid XMR address: {wallet.address}", + title="Invalid XMR address", + ) + continue + for amount, reason in pay_list_iter: print("ITER", wallet, pay_list_iter) print("ITER SENT", wallet, amount, reason) @@ -121,9 +132,8 @@ async def withdrawal_job(group=None): print("CAST AMOUNT", cast["amount"]) print("CAST OTP TRUNCATED BY 2", cast["otp"][-2]) - # TODO: UNCOMMENT - # sent = await run.call("wallet_send_xmr", **cast) - # print("SENT", sent) + sent = await run.call("wallet_send_xmr", **cast) + print("SENT", sent) payout = Payout.objects.create( # noqa user=group.user, @@ -132,9 +142,8 @@ async def withdrawal_job(group=None): description=reason, ) - # TODO: UNCOMMENT - # payout.response = sent - # payout.save() + payout.response = sent + payout.save() async def aggregator_job(): diff --git a/core/util/validation.py b/core/util/validation.py new file mode 100644 index 0000000..938eef5 --- /dev/null +++ b/core/util/validation.py @@ -0,0 +1,742 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# flake8: noqa + +# ==================================== +# Cryptocurrency Validation Functions +# BTC, LTC, XMR +# +# Code modified from: +# +# Base58 decoding: https://github.com/keis/base58 +# P2PKH validation: http://bit.ly/2DSVAXc +# Bech32 Validation: http://bit.ly/2Eaw40N +# XMR Validation: https://github.com/monero-project +# ==================================== + +import os +import sys + +if os.path.exists(os.getcwd() + "\\venv"): + sys.path.append(os.getcwd() + "\\venv\\Lib\\site-packages") +import operator as _oper +import re +import struct +from binascii import hexlify, unhexlify +from decimal import Decimal + +import base58 +import sha3 + +# --------------------- Global Variables -------------------- # + + +_ADDR_REGEX = re.compile( + r"^[123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz]{95}$" +) +_IADDR_REGEX = re.compile( + r"^[123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz]{106}$" +) +_str_types = (str, bytes) +__alphabet = [ + ord(s) for s in "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz" +] +__UINT64MAX = 2**64 +__encodedBlockSizes = [0, 2, 3, 5, 6, 7, 9, 10, 11] +__fullBlockSize = 8 +__fullEncodedBlockSize = 11 +indexbytes = _oper.getitem +intlist2bytes = bytes +int2byte = _oper.methodcaller("to_bytes", 1, "big") +b = 256 +q = 2**255 - 19 +l = 2**252 + 27742317777372353535851937790883648493 # noqa +PICONERO = Decimal("0.000000000001") +EMPTY_KEY = "0" * 64 +_integer_types = (int,) + + +# ----------------- Global Functions ----------------- @ + + +def to_atomic(amount): + """Convert Monero decimal to atomic integer of piconero.""" + if not isinstance(amount, (Decimal, float) + _integer_types): + raise ValueError( + "Amount '{}' doesn't have numeric type. Only Decimal, int, long and " + "float (not recommended) are accepted as amounts." + ) + return int(amount * 10**12) + + +def from_atomic(amount): + """Convert atomic integer of piconero to Monero decimal.""" + return (Decimal(amount) * PICONERO).quantize(PICONERO) + + +def as_monero(amount): + """Return the amount rounded to maximal Monero precision.""" + return Decimal(amount).quantize(PICONERO) + + +def _hexToBin(hex_): + if len(hex_) % 2 != 0: + raise ValueError("Hex string has invalid length: %d" % len(hex_)) + return [int(hex_[i : i + 2], 16) for i in range(0, len(hex_), 2)] # noqa + + +def _binToHex(bin_): + return "".join("%02x" % int(b) for b in bin_) + + +def _uint8be_to_64(data): + if not (1 <= len(data) <= 8): + raise ValueError("Invalid input length: %d" % len(data)) + + res = 0 + for b in data: + res = res << 8 | b + return res + + +def _uint64_to_8be(num, size): + if size < 1 or size > 8: + raise ValueError("Invalid input length: %d" % size) + res = [0] * size + + twopow8 = 2**8 + for i in range(size - 1, -1, -1): + res[i] = num % twopow8 + num = num // twopow8 + + return res + + +def xmr_base58_encode_block(data, buf, index): + l_data = len(data) + + if l_data < 1 or l_data > __fullEncodedBlockSize: + raise ValueError("Invalid block length: %d" % l_data) + + num = _uint8be_to_64(data) + i = __encodedBlockSizes[l_data] - 1 + + while num > 0: + remainder = num % 58 + num = num // 58 + buf[index + i] = __alphabet[remainder] + i -= 1 + + return buf + + +def xmr_base58_encode(hex): + """Encode hexadecimal string as base58 (ex: encoding a Monero address).""" + data = _hexToBin(hex) + l_data = len(data) + + if l_data == 0: + return "" + + full_block_count = l_data // __fullBlockSize + last_block_size = l_data % __fullBlockSize + res_size = ( + full_block_count * __fullEncodedBlockSize + __encodedBlockSizes[last_block_size] + ) + + res = bytearray([__alphabet[0]] * res_size) + + for i in range(full_block_count): + res = xmr_base58_encode_block( + data[(i * __fullBlockSize) : (i * __fullBlockSize + __fullBlockSize)], + res, + i * __fullEncodedBlockSize, + ) + + if last_block_size > 0: + res = xmr_base58_encode_block( + data[ + (full_block_count * __fullBlockSize) : ( + full_block_count * __fullBlockSize + last_block_size + ) + ], + res, + full_block_count * __fullEncodedBlockSize, + ) + + return bytes(res).decode("ascii") + + +def xmr_base58_decode_block(data, buf, index): + l_data = len(data) + + if l_data < 1 or l_data > __fullEncodedBlockSize: + raise ValueError("Invalid block length: %d" % l_data) + + res_size = __encodedBlockSizes.index(l_data) + if res_size <= 0: + raise ValueError("Invalid block size: %d" % res_size) + + res_num = 0 + order = 1 + for i in range(l_data - 1, -1, -1): + digit = __alphabet.index(data[i]) + if digit < 0: + raise ValueError("Invalid symbol: %s" % data[i]) + + product = order * digit + res_num + if product > __UINT64MAX: + raise ValueError( + "Overflow: %d * %d + %d = %d" % (order, digit, res_num, product) + ) + + res_num = product + order = order * 58 + + if res_size < __fullBlockSize and 2 ** (8 * res_size) <= res_num: + raise ValueError("Overflow: %d doesn't fit in %d bit(s)" % (res_num, res_size)) + + tmp_buf = _uint64_to_8be(res_num, res_size) + buf[index : index + len(tmp_buf)] = tmp_buf + + return buf + + +def xmr_base58_decode(enc): + """Decode a base58 string (ex: a Monero address) into hexidecimal form.""" + enc = bytearray(enc, encoding="ascii") + l_enc = len(enc) + + if l_enc == 0: + return "" + + full_block_count = l_enc // __fullEncodedBlockSize + last_block_size = l_enc % __fullEncodedBlockSize + try: + last_block_decoded_size = __encodedBlockSizes.index(last_block_size) + except ValueError: + raise ValueError("Invalid encoded length: %d" % l_enc) + + data_size = full_block_count * __fullBlockSize + last_block_decoded_size + + data = bytearray(data_size) + for i in range(full_block_count): + data = xmr_base58_decode_block( + enc[ + (i * __fullEncodedBlockSize) : ( + i * __fullEncodedBlockSize + __fullEncodedBlockSize + ) + ], + data, + i * __fullBlockSize, + ) + + if last_block_size > 0: + data = xmr_base58_decode_block( + enc[ + (full_block_count * __fullEncodedBlockSize) : ( + full_block_count * __fullEncodedBlockSize + last_block_size + ) + ], + data, + full_block_count * __fullBlockSize, + ) + + return _binToHex(data) + + +def expmod(b, e, m): + if e == 0: + return 1 + t = expmod(b, e // 2, m) ** 2 % m + if e & 1: + t = (t * b) % m + return t + + +def inv(x): + return expmod(x, q - 2, q) + + +d = -121665 * inv(121666) +I = expmod(2, (q - 1) // 4, q) + + +def xrecover(y): + xx = (y * y - 1) * inv(d * y * y + 1) + x = expmod(xx, (q + 3) // 8, q) + if (x * x - xx) % q != 0: + x = (x * I) % q + if x % 2 != 0: + x = q - x + return x + + +def compress(P): + zinv = inv(P[2]) + return (P[0] * zinv % q, P[1] * zinv % q) + + +def decompress(P): + return (P[0], P[1], 1, P[0] * P[1] % q) + + +By = 4 * inv(5) +Bx = xrecover(By) +B = [Bx % q, By % q] + + +def edwards(P, Q): + x1 = P[0] + y1 = P[1] + x2 = Q[0] + y2 = Q[1] + x3 = (x1 * y2 + x2 * y1) * inv(1 + d * x1 * x2 * y1 * y2) + y3 = (y1 * y2 + x1 * x2) * inv(1 - d * x1 * x2 * y1 * y2) + return [x3 % q, y3 % q] + + +def add(P, Q): + A = (P[1] - P[0]) * (Q[1] - Q[0]) % q + B = (P[1] + P[0]) * (Q[1] + Q[0]) % q + C = 2 * P[3] * Q[3] * d % q + D = 2 * P[2] * Q[2] % q + E = B - A + F = D - C + G = D + C + H = B + A + return (E * F, G * H, F * G, E * H) + + +def add_compressed(P, Q): + return compress(add(decompress(P), decompress(Q))) + + +def scalarmult(P, e): + if e == 0: + return [0, 1] + Q = scalarmult(P, e // 2) + Q = edwards(Q, Q) + if e & 1: + Q = edwards(Q, P) + return Q + + +def encodeint(y): + bits = [(y >> i) & 1 for i in range(b)] + return b"".join( + [int2byte(sum([bits[i * 8 + j] << j for j in range(8)])) for i in range(b // 8)] + ) + + +def encodepoint(P): + x = P[0] + y = P[1] + bits = [(y >> i) & 1 for i in range(b - 1)] + [x & 1] + return b"".join( + [int2byte(sum([bits[i * 8 + j] << j for j in range(8)])) for i in range(b // 8)] + ) + + +def bit(h, i): + return (indexbytes(h, i // 8) >> (i % 8)) & 1 + + +def isoncurve(P): + x = P[0] + y = P[1] + return (-x * x + y * y - 1 - d * x * x * y * y) % q == 0 + + +def decodeint(s): + return sum(2**i * bit(s, i) for i in range(0, b)) + + +def decodepoint(s): + y = sum(2**i * bit(s, i) for i in range(0, b - 1)) + x = xrecover(y) + if x & 1 != bit(s, b - 1): + x = q - x + P = [x, y] + if not isoncurve(P): + raise Exception("decoding point that is not on curve") + return P + + +def public_from_secret(k): + keyInt = decodeint(k) + aB = scalarmult(B, keyInt) + return encodepoint(aB) + + +def public_from_secret_hex(hk): + return hexlify(public_from_secret(unhexlify(hk))).decode() + + +def bech32_decode(bech): + charset = "qpzry9x8gf2tvdw0s3jn54khce6mua7l" + if (any(ord(x) < 33 or ord(x) > 126 for x in bech)) or ( + bech.lower() != bech and bech.upper() != bech + ): + return False + bech = bech.lower() + pos = bech.rfind("1") + if pos < 1 or pos + 7 > len(bech) or len(bech) > 90: + return False + if not all(x in charset for x in bech[pos + 1 :]): + return False + hrp = bech[:pos] + data = [charset.find(x) for x in bech[pos + 1 :]] + if not bech32_verify_checksum(hrp, data): + return False + return True + + +def bech32_polymod(values): + generator = [0x3B6A57B2, 0x26508E6D, 0x1EA119FA, 0x3D4233DD, 0x2A1462B3] + chk = 1 + for value in values: + top = chk >> 25 + chk = (chk & 0x1FFFFFF) << 5 ^ value + for i in range(5): + chk ^= generator[i] if ((top >> i) & 1) else 0 + return chk + + +def bech32_hrp_expand(hrp): + return [ord(x) >> 5 for x in hrp] + [0] + [ord(x) & 31 for x in hrp] + + +def bech32_verify_checksum(hrp, data): + return bech32_polymod(bech32_hrp_expand(hrp) + data) == 1 + + +def hextobin(hexstr): + if (hexstr.length % 2) is not 0: + return False + res = list() + index = 0 + for char in hexstr: + res[index] = int(hexstr[(index * 2) : (index * 2 + 2)]) + index += 1 + return res + + +# ------------------ Helper Classes ---------------- # + + +class BaseAddress(object): + label = None + + def __init__(self, addr, label=None): + addr = str(addr) + if not _ADDR_REGEX.match(addr): + raise ValueError( + "Address must be 95 characters long base58-encoded string, " + "is {addr} ({len} chars length)".format(addr=addr, len=len(addr)) + ) + self._decode(addr) + self.label = label or self.label + + def is_mainnet(self): + """Returns `True` if the address belongs to mainnet. + :rtype: bool + """ + return self._decoded[0] == self._valid_netbytes[0] + + def is_testnet(self): + """Returns `True` if the address belongs to testnet. + :rtype: bool + """ + return self._decoded[0] == self._valid_netbytes[1] + + def is_stagenet(self): + """Returns `True` if the address belongs to stagenet. + :rtype: bool + """ + return self._decoded[0] == self._valid_netbytes[2] + + def _decode(self, address): + self._decoded = bytearray(unhexlify(xmr_base58_decode(address))) + checksum = self._decoded[-4:] + if checksum != sha3.keccak_256(self._decoded[:-4]).digest()[:4]: + raise ValueError("Invalid checksum in address {}".format(address)) + if self._decoded[0] not in self._valid_netbytes: + raise ValueError( + "Invalid address netbyte {nb}. Allowed values are: {allowed}".format( + nb=self._decoded[0], + allowed=", ".join(map(lambda b: "%02x" % b, self._valid_netbytes)), + ) + ) + + def __repr__(self): + return xmr_base58_encode(hexlify(self._decoded)) + + def __eq__(self, other): + if isinstance(other, BaseAddress): + return str(self) == str(other) + if isinstance(other, _str_types): + return str(self) == other + return super(BaseAddress, self).__eq__(other) + + def __hash__(self): + return hash(str(self)) + + +class Address(BaseAddress): + """Monero address. + Address of this class is the master address for a :class:`Wallet `. + :param address: a Monero address as string-like object + :param label: a label for the address (defaults to `None`) + """ + + _valid_netbytes = (18, 53, 24) + # NOTE: _valid_netbytes order is (mainnet, testnet, stagenet) + + def view_key(self): + """Returns public view key. + :rtype: str + """ + return hexlify(self._decoded[33:65]).decode() + + def spend_key(self): + """Returns public spend key. + :rtype: str + """ + return hexlify(self._decoded[1:33]).decode() + + def check_private_view_key(self, key): + """Checks if private view key matches this address. + :rtype: bool + """ + return public_from_secret_hex(key) == self.view_key() + + def check_private_spend_key(self, key): + """Checks if private spend key matches this address. + :rtype: bool + """ + return public_from_secret_hex(key) == self.spend_key() + + def with_payment_id(self, payment_id=0): + """Integrates payment id into the address. + :param payment_id: int, hexadecimal string or :class:`PaymentID ` + (max 64-bit long) + :rtype: `IntegratedAddress` + :raises: `TypeError` if the payment id is too long + """ + payment_id = PaymentID(payment_id) + if not payment_id.is_short(): + raise TypeError( + "Payment ID {0} has more than 64 bits and cannot be integrated".format( + payment_id + ) + ) + prefix = 54 if self.is_testnet() else 25 if self.is_stagenet() else 19 + data = ( + bytearray([prefix]) + + self._decoded[1:65] + + struct.pack(">Q", int(payment_id)) + ) + checksum = bytearray(sha3.keccak_256(data).digest()[:4]) + return IntegratedAddress(xmr_base58_encode(hexlify(data + checksum))) + + +class SubAddress(BaseAddress): + """Monero subaddress. + Any type of address which is not the master one for a wallet. + """ + + _valid_netbytes = (42, 63, 36) + # NOTE: _valid_netbytes order is (mainnet, testnet, stagenet) + + def with_payment_id(self, _): + raise TypeError("SubAddress cannot be integrated with payment ID") + + +class IntegratedAddress(Address): + """Monero integrated address. + A master address integrated with payment id (short one, max 64 bit). + """ + + _valid_netbytes = (19, 54, 25) + # NOTE: _valid_netbytes order is (mainnet, testnet, stagenet) + + def __init__(self, address): + address = str(address) + if not _IADDR_REGEX.match(address): + raise ValueError( + "Integrated address must be 106 characters long base58-encoded string, " + "is {addr} ({len} chars length)".format(addr=address, len=len(address)) + ) + self._decode(address) + + def payment_id(self): + """Returns the integrated payment id. + :rtype: :class:`PaymentID ` + """ + return PaymentID(hexlify(self._decoded[65:-4]).decode()) + + def base_address(self): + """Returns the base address without payment id. + :rtype: :class:`Address` + """ + prefix = 53 if self.is_testnet() else 24 if self.is_stagenet() else 18 + data = bytearray([prefix]) + self._decoded[1:65] + checksum = sha3.keccak_256(data).digest()[:4] + return Address(xmr_base58_encode(hexlify(data + checksum))) + + +class PaymentID(object): + """ + A class that validates Monero payment ID. + + Payment IDs can be used as str or int across the module, however this class + offers validation as well as simple conversion and comparison to those two + primitive types. + + :param payment_id: the payment ID as integer or hexadecimal string + """ + + _payment_id = None + + def __init__(self, payment_id): + if isinstance(payment_id, PaymentID): + payment_id = int(payment_id) + if isinstance(payment_id, _str_types): + payment_id = int(payment_id, 16) + elif not isinstance(payment_id, _integer_types): + raise TypeError( + "payment_id must be either int or hexadecimal str or bytes, " + "is {0}".format(type(payment_id)) + ) + if payment_id.bit_length() > 256: + raise ValueError( + "payment_id {0} is more than 256 bits long".format(payment_id) + ) + self._payment_id = payment_id + + def is_short(self): + """Returns True if payment ID is short enough to be included + in :class:`IntegratedAddress `.""" + return self._payment_id.bit_length() <= 64 + + def __repr__(self): + if self.is_short(): + return "{:016x}".format(self._payment_id) + return "{:064x}".format(self._payment_id) + + def __int__(self): + return self._payment_id + + def __eq__(self, other): + if isinstance(other, PaymentID): + return int(self) == int(other) + elif isinstance(other, _integer_types): + return int(self) == other + elif isinstance(other, _str_types): + return str(self) == other + return super(PaymentID, self).__eq__(other) + + +# ------------------ Validation Class ----------------- # + + +class Validation: + @staticmethod + def is_btc_chain(chain): + chain = chain.lower() + chains = ["main", "testnet"] + if chain in chains: + return True + return False + + @staticmethod + def is_xmr_chain(chain): + chain = chain.lower() + chains = ["mainnet", "testnet", "stagenet"] + if chain in chains: + return True + return False + + @staticmethod + def is_coin_ticker(coin): + coin = coin.lower() + coins = ["btc", "ltc", "xmr"] + if coin in coins: + return True + return False + + @staticmethod + def is_coin_name(coin): + coin = coin.lower() + coins = ["bitcoin", "litecoin", "monero"] + if coin in coins: + return True + return False + + @staticmethod + def is_address(coin, address): + coin = coin.lower() + if not Validation.is_coin_ticker(coin): + return False + address = address.strip() + if coin == "btc": + return Validation.is_btc_address(address) + if coin == "ltc": + return Validation.is_ltc_address(address) + if coin == "xmr": + return Validation.is_xmr_address(address) + return False + + @staticmethod + def is_btc_address(address): # Level 4 Validation + if address[0] == "1": # P2PKH Address + return base58.b58decode_check(address) + elif address[0] == "3": # P2SH Address + return base58.b58decode_check(address) + elif address[:3] == "bc1": # Bech32 Addresses (Segwit) + return bech32_decode(address) + else: + return False + + @staticmethod + def is_ltc_address(address): # Level 4 Validation + if len(address) > 43 or len(address) < 26: + return False + if address[0] == "L": # Legacy Non-P2SH Address + return base58.b58decode_check(address) + elif address[0] == "3": # P2SH Address - Deprecated + return False + elif address[0] == "M": # P2SH Address + return base58.b58decode_check(address) + elif address[:4] == "ltc1": # P2WPKH Bech32 (Segwit) + return bech32_decode(address) + return False + + @staticmethod + def is_xmr_address(address, label=None): # Level 4 Validation + addr = str(address) + charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + for char in address: + if char not in charset: + return False + if len(address) not in [95, 106]: + return False + if _ADDR_REGEX.match(addr): + try: + netbyte = bytearray(unhexlify(xmr_base58_decode(addr)))[0] + if netbyte in Address._valid_netbytes: + Address(addr, label=label) + return True + elif netbyte in SubAddress._valid_netbytes: + SubAddress(addr, label=label) + return True + except Exception: + return False + elif _IADDR_REGEX.match(addr): + try: + IntegratedAddress(addr) + return True + except Exception: + return False + return False diff --git a/requirements.txt b/requirements.txt index e2946f8..b29efa8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -35,3 +35,5 @@ aiohttp[speedups] elasticsearch[async] uvloop arrow +pysha3 +base58