Implement updating protection

This commit is contained in:
Mark Veidemanis 2023-02-22 07:20:37 +00:00
parent ba8eb69309
commit ed63085e10
Signed by: m
GPG Key ID: 5ACFCEED46C0904F
4 changed files with 55 additions and 33 deletions

View File

@ -145,13 +145,14 @@ class OANDAExchange(BaseExchange):
r = trades.TradeDetails(accountID=self.account_id, tradeID=trade_id) r = trades.TradeDetails(accountID=self.account_id, tradeID=trade_id)
return self.call(r) return self.call(r)
def update_trade(self, trade): def update_trade(self, trade_id, take_profit_price, stop_loss_price):
raise NotImplementedError data = {}
# r = orders.OrderReplace( if take_profit_price:
# accountID=self.account_id, orderID=trade.order_id, data=data data["takeProfit"] = {"price": str(take_profit_price)}
# ) if stop_loss_price:
# self.client.request(r) data["stopLoss"] = {"price": str(stop_loss_price)}
# return r.response r = trades.TradeCRCDO(accountID=self.account_id, tradeID=trade_id, data=data)
return self.call(r)
def cancel_trade(self, trade_id): def cancel_trade(self, trade_id):
raise NotImplementedError raise NotImplementedError

View File

@ -730,3 +730,22 @@ TradeCloseSchema = {
"longPositionCloseout": "orderCreateTransaction.longPositionCloseout", "longPositionCloseout": "orderCreateTransaction.longPositionCloseout",
"longOrderFillTransaction": "orderCreateTransaction.longOrderFillTransaction", "longOrderFillTransaction": "orderCreateTransaction.longOrderFillTransaction",
} }
class TradeCRCDO(BaseModel):
takeProfitOrderCancelTransaction: OrderTransaction
takeProfitOrderTransaction: OrderTransaction
stopLossOrderCancelTransaction: OrderTransaction
stopLossOrderTransaction: OrderTransaction
relatedTransactionIDs: list[str]
lastTransactionID: str
TradeCRCDOSchema = {
"takeProfitOrderCancelTransaction": "takeProfitOrderCancelTransaction",
"takeProfitOrderTransaction": "takeProfitOrderTransaction",
"stopLossOrderCancelTransaction": "stopLossOrderCancelTransaction",
"stopLossOrderTransaction": "stopLossOrderTransaction",
"relatedTransactionIDs": "relatedTransactionIDs",
"lastTransactionID": "lastTransactionID",
}

View File

@ -182,8 +182,6 @@ class ActiveManagementMixinTestCase(StrategyMixin):
"buy", self.account, self.strategy, self.account.client.get_balance(), "EUR" "buy", self.account, self.strategy, self.account.client.get_balance(), "EUR"
) )
trade_size = round(trade_size, 0) trade_size = round(trade_size, 0)
print("TRADE SIZE", trade_size)
print("TYPE", type(trade_size))
complex_trade = self.create_complex_trade("buy", trade_size, "EUR_USD", 5, 5) complex_trade = self.create_complex_trade("buy", trade_size, "EUR_USD", 5, 5)
self.open_trade(complex_trade) self.open_trade(complex_trade)
@ -192,16 +190,20 @@ class ActiveManagementMixinTestCase(StrategyMixin):
expected = { expected = {
"close": [ "close": [
{ {
"id": complex_trade.id, "id": complex_trade.order_id,
"check": "protection", "check": "protection",
} }
] ]
} }
del self.ams.checks["close"][0]["extra"] del self.ams.actions["close"][0]["extra"]
self.assertEqual(self.ams.checks, expected) self.assertEqual(self.ams.actions, expected)
self.ams.execute_actions()
trades = self.account.client.get_all_open_trades()
self.assertEqual(len(trades), 0)
def test_ams_protection_violated_adjust(self): def test_ams_protection_violated_adjust(self):
# Don't violate position size check # Don't violate position size check
@ -209,25 +211,24 @@ class ActiveManagementMixinTestCase(StrategyMixin):
"buy", self.account, self.strategy, self.account.client.get_balance(), "EUR" "buy", self.account, self.strategy, self.account.client.get_balance(), "EUR"
) )
trade_size = round(trade_size, 0) trade_size = round(trade_size, 0)
print("TRADE SIZE", trade_size)
print("TYPE", type(trade_size))
complex_trade = self.create_complex_trade("buy", trade_size, "EUR_USD", 5, 5) complex_trade = self.create_complex_trade("buy", trade_size, "EUR_USD", 5, 5)
self.open_trade(complex_trade) self.open_trade(complex_trade)
self.ams.run_checks() self.ams.run_checks()
expected = {
"adjust": [ self.assertEqual(len(self.ams.actions["adjust"]), 1)
{ expected_tp = self.ams.actions["adjust"][0]["extra"]["take_profit_price"]
"id": "21381", expected_sl = self.ams.actions["adjust"][0]["extra"]["stop_loss_price"]
"check": "protection", self.assertEqual(len(self.ams.actions["adjust"]), 1)
"extra": {
"stop_loss_price": D("1.05812"), self.ams.execute_actions()
"take_profit_price": D("1.08484"),
}, trades = self.account.client.get_all_open_trades()
} self.assertEqual(len(trades), 1)
] self.assertEqual(D(trades[0]["takeProfitOrder"]["price"]), expected_tp)
} self.assertEqual(D(trades[0]["stopLossOrder"]["price"]), expected_sl)
print("CHECKS", self.ams.actions)
self.close_trade(complex_trade)
def test_ams_asset_groups_violated(self): def test_ams_asset_groups_violated(self):
pass pass

View File

@ -86,6 +86,7 @@ class ActiveManagement(object):
for action_cast in action_cast_list: for action_cast in action_cast_list:
msg += f"ACTION: '{action}' on trade ID '{action_cast['id']}'\n" msg += f"ACTION: '{action}' on trade ID '{action_cast['id']}'\n"
msg += f"VIOLATION: '{action_cast['check']}'\n" msg += f"VIOLATION: '{action_cast['check']}'\n"
if "extra" in action_cast:
if action_cast["extra"]: if action_cast["extra"]:
extra = action_cast["extra"] extra = action_cast["extra"]
extra = ", ".join([f"{k}: {v}" for k, v in extra.items()]) extra = ", ".join([f"{k}: {v}" for k, v in extra.items()])
@ -114,7 +115,7 @@ class ActiveManagement(object):
self.strategy.account.client.close_trade(trade_id, difference, symbol) self.strategy.account.client.close_trade(trade_id, difference, symbol)
def adjust_protection(self, trade_id, new_protection): def adjust_protection(self, trade_id, new_protection):
pass # TODO self.strategy.account.client.update_trade(trade_id, **new_protection)
def bulk_adjust(self, action_cast_list): def bulk_adjust(self, action_cast_list):
for item in action_cast_list: for item in action_cast_list: