Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,8 @@ lint:
@ruff check . --fix
@ruff format .
.PHONY: lint

test:
python -m coverage run -m pytest tests/ -v && \
python -m coverage report
.PHONY: test
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@ dependencies = [
"snet-contracts==1.0.1",
"lighthouseweb3~=0.1.4",
"py-multihash~=3.0",
"pydantic~=2.11",
"pydantic-settings~=2.13"
]

[tool.poetry.group.dev.dependencies]
ruff = "^0.11"
pytest = "^8.3"
coverage = "^7.13"

[tool.ruff]
line-length = 100
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ ipfshttpclient==0.4.13.2
snet-contracts==1.0.1
lighthouseweb3~=0.1.4
py-multihash~=3.0
pydantic~=2.11
pydantic-settings~=2.13
93 changes: 42 additions & 51 deletions snet/sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import google.protobuf.internal.api_implementation
from google.protobuf import symbol_database as _symbol_database

from snet.sdk.storage_provider.service_metadata import MPEServiceMetadata
from snet.sdk.registry.organization_metadata import OrganizationMetadata
from snet.sdk.registry.registry_contract import RegistryContract
from snet.sdk.registry.service_metadata import MPEServiceMetadata

with warnings.catch_warnings():
# Suppress the eth-typing package`s warnings related to some new networks
Expand All @@ -18,9 +20,6 @@
UserWarning,
)

import web3

from snet.contracts import get_contract_object
from snet.sdk.account import Account
from snet.sdk.config import config
from snet.sdk.client_lib_generator import ClientLibGenerator
Expand All @@ -34,12 +33,11 @@
PaymentStrategy,
)
from snet.sdk.service_client import ServiceClient
from snet.sdk.storage_provider.storage_provider import StorageProvider
from snet.sdk.custom_typing import ModuleName, ServiceStub
from snet.sdk.registry.storage_provider import StorageProvider
from snet.sdk.types import ModuleName, ServiceStub
from snet.sdk.utils.utils import (
bytes32_to_str,
find_file_by_keyword,
type_converter,
get_we3_object,
)

google.protobuf.internal.api_implementation.Type = lambda: "python"
Expand All @@ -55,29 +53,13 @@


class SnetSDK:
"""Base Snet SDK"""

def __init__(self):
self.web3 = web3.Web3(web3.HTTPProvider(config.ETH_RPC_ENDPOINT))

mpe_contract_address = config.MPE_CONTRACT_ADDRESS
if not mpe_contract_address:
self.mpe_contract = MPEContract(self.web3)
else:
self.mpe_contract = MPEContract(self.web3, mpe_contract_address)

registry_contract_address = config.REGISTRY_CONTRACT_ADDRESS
if registry_contract_address is None:
self.registry_contract = get_contract_object(self.web3, "Registry")
else:
self.registry_contract = get_contract_object(
self.web3, "Registry", registry_contract_address
)

self.metadata_provider = StorageProvider(self.registry_contract)

self.account = Account(self.web3, self.mpe_contract)
self.payment_channel_provider = PaymentChannelProvider(self.web3, self.mpe_contract)
self.w3 = get_we3_object()
self.mpe_contract = MPEContract()
self.registry_contract = RegistryContract()
self.storage_provider = StorageProvider(self.registry_contract)
self.payment_channel_provider = PaymentChannelProvider(self.mpe_contract)
self.account = Account()

def create_service_client(
self,
Expand All @@ -92,14 +74,14 @@

# Create and instance of the Config object,
# so we can create an instance of ClientLibGenerator
lib_generator = ClientLibGenerator(self.metadata_provider, org_id, service_id)
lib_generator = ClientLibGenerator(self.storage_provider, org_id, service_id)

# Download the proto file and generate stubs if needed
force_update = config.FORCE_UPDATE
if force_update:
lib_generator.generate_client_library()
else:
path_to_pb_files = lib_generator.protodir
path_to_pb_files = lib_generator.proto_dir
pb_2_file_name = find_file_by_keyword(
path_to_pb_files, keyword="pb2.py", exclude=["training"]
)
Expand All @@ -118,7 +100,7 @@
if payment_strategy is None:
payment_strategy = payment_strategy_type.value()

service_metadata = self.metadata_provider.enhance_service_metadata(org_id, service_id)
service_metadata = self._enhance_service_metadata(org_id, service_id)
group = self._get_service_group_details(service_metadata, group_name)

service_stubs = self.get_service_stub(lib_generator)
Expand All @@ -134,16 +116,29 @@
options,
self.mpe_contract,
self.account,
self.web3,
self.w3,
pb2_module,
self.payment_channel_provider,
lib_generator.protodir,
lib_generator.proto_dir,
lib_generator.training_added(),
)
return _service_client

def _enhance_service_metadata(self, org_id, service_id):
service_metadata = self.get_service_metadata(org_id, service_id)
org_metadata = self.get_organization_metadata(org_id)

org_group_map = {}
for group in org_metadata.groups:
org_group_map[group.group_name] = group

for group in service_metadata.groups:
group.payment = org_group_map[group.group_name].payment

return service_metadata

def get_service_stub(self, lib_generator: ClientLibGenerator) -> list[ServiceStub]:
path_to_pb_files = str(lib_generator.protodir)
path_to_pb_files = str(lib_generator.proto_dir)
module_name = self.get_module_by_keyword("pb2_grpc.py", lib_generator)
sys.path.append(path_to_pb_files)
try:
Expand All @@ -159,13 +154,18 @@
raise Exception(f"Error importing module: {e}")

def get_module_by_keyword(self, keyword: str, lib_generator: ClientLibGenerator) -> ModuleName:
path_to_pb_files = lib_generator.protodir
path_to_pb_files = lib_generator.proto_dir
file_name = find_file_by_keyword(path_to_pb_files, keyword, exclude=["training"])
module_name = os.path.splitext(file_name)[0]
return ModuleName(module_name)

def get_service_metadata(self, org_id, service_id):
return self.metadata_provider.fetch_service_metadata(org_id, service_id)
service = self.registry_contract.get_service(org_id, service_id)
return self.storage_provider.fetch_service_metadata(service.metadata_uri)

def get_organization_metadata(self, org_id: str) -> OrganizationMetadata:
org = self.registry_contract.get_org(org_id)
return self.storage_provider.fetch_org_metadata(org.metadata_uri)

def _get_first_group(self, service_metadata: MPEServiceMetadata) -> dict:
return service_metadata["groups"][0]
Expand All @@ -176,7 +176,8 @@
for group in service_metadata["groups"]:
if group["group_name"] == group_name:
return group
return {}
# TODO: configure exceptions

Check notice on line 179 in snet/sdk/__init__.py

View check run for this annotation

snet-sonarqube-app / SonarQube Code Analysis

snet/sdk/__init__.py#L179

Complete the task associated to this "TODO" comment.
raise Exception()

Check warning on line 180 in snet/sdk/__init__.py

View check run for this annotation

snet-sonarqube-app / SonarQube Code Analysis

snet/sdk/__init__.py#L180

Replace this generic exception class with a more specific one.

def _get_service_group_details(
self, service_metadata: MPEServiceMetadata, group_name: str
Expand All @@ -190,17 +191,7 @@
return self._get_group_by_group_name(service_metadata, group_name)

def get_organization_list(self) -> list:
org_list = self.registry_contract.functions.listOrganizations().call()
organization_list = []
for idx, org_id in enumerate(org_list):
organization_list.append(bytes32_to_str(org_id))
return organization_list
return self.registry_contract.list_orgs()

def get_services_list(self, org_id: str) -> list:
found, org_service_list = self.registry_contract.functions.listServicesForOrganization(
type_converter("bytes32")(org_id)
).call()
if not found:
raise Exception(f"Organization with id={org_id} doesn't exist!")
org_service_list = list(map(bytes32_to_str, org_service_list))
return org_service_list
return self.registry_contract.list_service_for_org(org_id)
53 changes: 17 additions & 36 deletions snet/sdk/account.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import json

import web3

from snet.contracts import get_contract_object

from snet.sdk.config import config
from snet.sdk.mpe.mpe_contract import MPEContract
from snet.sdk.utils.utils import get_address_from_private, normalize_private_key
from snet.sdk.utils.utils import get_address_from_private, normalize_private_key, get_we3_object

DEFAULT_GAS = 300000
TRANSACTION_TIMEOUT = 500
Expand All @@ -28,17 +26,13 @@ def __str__(self):


class Account:
def __init__(self, w3: web3.Web3, mpe_contract: MPEContract):
self.web3 = w3
self.mpe_contract = mpe_contract
def __init__(self):
self.w3 = get_we3_object()
self.mpe_address = config.MPE_CONTRACT_ADDRESS

token_contract_address = config.TOKEN_CONTRACT_ADDRESS
if not token_contract_address:
self.token_contract = get_contract_object(self.web3, "FetchToken")
else:
self.token_contract = get_contract_object(
self.web3, "FetchToken", token_contract_address
)
self.token_contract = get_contract_object(
self.w3, "FetchToken", config.TOKEN_CONTRACT_ADDRESS
)

if config.PRIVATE_KEY:
self.private_key = normalize_private_key(config.PRIVATE_KEY)
Expand All @@ -52,19 +46,19 @@ def __init__(self, w3: web3.Web3, mpe_contract: MPEContract):
self.nonce = 0

def _get_nonce(self):
nonce = self.web3.eth.get_transaction_count(self.address)
nonce = self.w3.eth.get_transaction_count(self.address)
if self.nonce >= nonce:
nonce = self.nonce + 1
self.nonce = nonce
return nonce

def _get_gas_price(self):
gas_price = self.web3.eth.gas_price
gas_price = self.w3.eth.gas_price
if gas_price <= 15000000000:
gas_price += gas_price * 1 / 3
elif gas_price > 15000000000 and gas_price <= 50000000000:
elif 15000000000 < gas_price <= 50000000000:
gas_price += gas_price * 1 / 5
elif gas_price > 50000000000 and gas_price <= 150000000000:
elif 50000000000 < gas_price <= 150000000000:
gas_price += 7000000000
elif gas_price > 150000000000:
gas_price += gas_price * 1 / 10
Expand All @@ -73,44 +67,31 @@ def _get_gas_price(self):
def _send_signed_transaction(self, contract_fn, *args):
transaction = contract_fn(*args).build_transaction(
{
"chainId": int(self.web3.net.version),
"chainId": int(self.w3.net.version),
"gas": DEFAULT_GAS,
"gasPrice": self._get_gas_price(),
"nonce": self._get_nonce(),
}
)
signed_txn = self.web3.eth.account.sign_transaction(
transaction, private_key=self.private_key
)
return self.web3.to_hex(self.web3.eth.send_raw_transaction(signed_txn.raw_transaction))
signed_txn = self.w3.eth.account.sign_transaction(transaction, private_key=self.private_key)
return self.w3.to_hex(self.w3.eth.send_raw_transaction(signed_txn.raw_transaction))

def send_transaction(self, contract_fn, *args):
txn_hash = self._send_signed_transaction(contract_fn, *args)
return self.web3.eth.wait_for_transaction_receipt(txn_hash, TRANSACTION_TIMEOUT)
return self.w3.eth.wait_for_transaction_receipt(txn_hash, TRANSACTION_TIMEOUT)

def _parse_receipt(self, receipt, event, encoder=json.JSONEncoder):
if receipt.status == 0:
raise TransactionError("Transaction failed", receipt)
else:
return json.dumps(dict(event().processReceipt(receipt)[0]["args"]), cls=encoder)

def escrow_balance(self):
return self.mpe_contract.balance(self.address)

def deposit_to_escrow_account(self, amount_in_cogs):
already_approved = self.allowance()
if amount_in_cogs > already_approved:
self.approve_transfer(amount_in_cogs)
return self.mpe_contract.deposit(self, amount_in_cogs)

def approve_transfer(self, amount_in_cogs):
return self.send_transaction(
self.token_contract.functions.approve,
self.mpe_contract.contract.address,
self.mpe_address,
amount_in_cogs,
)

def allowance(self):
return self.token_contract.functions.allowance(
self.address, self.mpe_contract.contract.address
).call()
return self.token_contract.functions.allowance(self.address, self.mpe_address).call()
Loading
Loading