Skip to content

Commit

Permalink
Refactored provider module
Browse files Browse the repository at this point in the history
  • Loading branch information
bsrinivas8687 committed Nov 28, 2022
1 parent 7d0681c commit d08a729
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 65 deletions.
2 changes: 1 addition & 1 deletion x/provider/keeper/alias.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
)

func (k *Keeper) FundCommunityPool(ctx sdk.Context, from sdk.AccAddress, coin sdk.Coin) error {
if coin.IsZero() {
if !coin.IsPositive() {
return nil
}

Expand Down
20 changes: 9 additions & 11 deletions x/provider/keeper/msg_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,25 @@ func NewMsgServiceServer(keeper Keeper) types.MsgServiceServer {
func (k *msgServer) MsgRegister(c context.Context, msg *types.MsgRegisterRequest) (*types.MsgRegisterResponse, error) {
ctx := sdk.UnwrapSDKContext(c)

msgFrom, err := sdk.AccAddressFromBech32(msg.From)
fromAddr, err := sdk.AccAddressFromBech32(msg.From)
if err != nil {
return nil, err
}

_, found := k.GetProvider(ctx, msgFrom.Bytes())
_, found := k.GetProvider(ctx, fromAddr.Bytes())
if found {
return nil, types.ErrorDuplicateProvider
}

deposit := k.Deposit(ctx)
if deposit.IsPositive() {
if err := k.FundCommunityPool(ctx, msgFrom, deposit); err != nil {
return nil, err
}
if err := k.FundCommunityPool(ctx, fromAddr, deposit); err != nil {
return nil, err
}

var (
provAddress = hubtypes.ProvAddress(msgFrom.Bytes())
provider = types.Provider{
Address: provAddress.String(),
provAddr = hubtypes.ProvAddress(fromAddr.Bytes())
provider = types.Provider{
Address: provAddr.String(),
Name: msg.Name,
Identity: msg.Identity,
Website: msg.Website,
Expand All @@ -65,12 +63,12 @@ func (k *msgServer) MsgRegister(c context.Context, msg *types.MsgRegisterRequest
func (k *msgServer) MsgUpdate(c context.Context, msg *types.MsgUpdateRequest) (*types.MsgUpdateResponse, error) {
ctx := sdk.UnwrapSDKContext(c)

msgFrom, err := hubtypes.ProvAddressFromBech32(msg.From)
fromAddr, err := hubtypes.ProvAddressFromBech32(msg.From)
if err != nil {
return nil, err
}

provider, found := k.GetProvider(ctx, msgFrom)
provider, found := k.GetProvider(ctx, fromAddr)
if !found {
return nil, types.ErrorProviderDoesNotExist
}
Expand Down
2 changes: 1 addition & 1 deletion x/provider/simulation/genesis.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func RandomizedGenesisState(state *module.SimulationState) *types.GenesisState {
func(r *rand.Rand) {
deposit = sdk.NewInt64Coin(
sdk.DefaultBondDenom,
r.Int63n(MaxDepositAmount),
r.Int63n(MaxInt),
)
},
)
Expand Down
15 changes: 12 additions & 3 deletions x/provider/simulation/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ import (
)

const (
MaxDepositAmount = 1 << 18
MaxInt = 1 << 18
MaxInt = 1 << 18
)

func ParamChanges(_ *rand.Rand) []simulationtypes.ParamChange {
Expand All @@ -23,7 +22,17 @@ func ParamChanges(_ *rand.Rand) []simulationtypes.ParamChange {
func(r *rand.Rand) string {
return sdk.NewInt64Coin(
sdk.DefaultBondDenom,
r.Int63n(MaxDepositAmount),
r.Int63n(MaxInt),
).String()
},
),
simulation.NewSimParamChange(
types.ModuleName,
string(types.KeyStakingShare),
func(r *rand.Rand) string {
return sdk.NewDecWithPrec(
r.Int63n(MaxInt),
6,
).String()
},
),
Expand Down
92 changes: 48 additions & 44 deletions x/provider/types/params.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,11 @@ var (
)

func (m *Params) Validate() error {
if m.Deposit.IsNegative() {
return fmt.Errorf("deposit cannot be negative")
}
if !m.Deposit.IsValid() {
return fmt.Errorf("invalid deposit %s", m.Deposit)
if err := validateDeposit(m.Deposit); err != nil {
return err
}
if m.StakingShare.IsNegative() {
return fmt.Errorf("staking_share cannot be negative")
}
if m.StakingShare.GT(sdk.NewDec(1)) {
return fmt.Errorf("staking_share cannot be greater than 1")
if err := validateStakingShare(m.StakingShare); err != nil {
return err
}

return nil
Expand All @@ -41,42 +35,14 @@ func (m *Params) Validate() error {
func (m *Params) ParamSetPairs() params.ParamSetPairs {
return params.ParamSetPairs{
{
Key: KeyDeposit,
Value: &m.Deposit,
ValidatorFn: func(v interface{}) error {
value, ok := v.(sdk.Coin)
if !ok {
return fmt.Errorf("invalid parameter type %T", v)
}

if value.IsNegative() {
return fmt.Errorf("deposit cannot be negative")
}
if !value.IsValid() {
return fmt.Errorf("invalid deposit %s", value)
}

return nil
},
Key: KeyDeposit,
Value: &m.Deposit,
ValidatorFn: validateDeposit,
},
{
Key: KeyStakingShare,
Value: &m.StakingShare,
ValidatorFn: func(v interface{}) error {
value, ok := v.(sdk.Dec)
if !ok {
return fmt.Errorf("invalid parameter type %T", v)
}

if value.IsNegative() {
return fmt.Errorf("staking_share cannot be negative")
}
if value.GT(sdk.NewDec(1)) {
return fmt.Errorf("staking_share cannot be greater than 1")
}

return nil
},
Key: KeyStakingShare,
Value: &m.StakingShare,
ValidatorFn: validateStakingShare,
},
}
}
Expand All @@ -98,3 +64,41 @@ func DefaultParams() Params {
func ParamsKeyTable() params.KeyTable {
return params.NewKeyTable().RegisterParamSet(&Params{})
}

func validateDeposit(v interface{}) error {
value, ok := v.(sdk.Coin)
if !ok {
return fmt.Errorf("invalid parameter type %T", v)
}

if value.IsNil() {
return fmt.Errorf("deposit cannot be nil")
}
if value.IsNegative() {
return fmt.Errorf("deposit cannot be negative")
}
if !value.IsValid() {
return fmt.Errorf("invalid deposit %s", value)
}

return nil
}

func validateStakingShare(v interface{}) error {
value, ok := v.(sdk.Dec)
if !ok {
return fmt.Errorf("invalid parameter type %T", v)
}

if value.IsNil() {
return fmt.Errorf("staking_share cannot be nil")
}
if value.IsNegative() {
return fmt.Errorf("staking_share cannot be negative")
}
if value.GT(sdk.NewDec(1)) {
return fmt.Errorf("staking_share cannot be greater than 1")
}

return nil
}
69 changes: 64 additions & 5 deletions x/provider/types/params_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ import (

func TestParams_Validate(t *testing.T) {
type fields struct {
Deposit sdk.Coin
Deposit sdk.Coin
StakingShare sdk.Dec
}
tests := []struct {
name string
Expand All @@ -25,7 +26,14 @@ func TestParams_Validate(t *testing.T) {
{
"invalid denom deposit",
fields{
Deposit: sdk.Coin{Denom: "0", Amount: sdk.NewInt(1000)},
Deposit: sdk.Coin{Denom: "o", Amount: sdk.NewInt(1000)},
},
true,
},
{
"empty amount deposit",
fields{
Deposit: sdk.Coin{Denom: "one", Amount: sdk.Int{}},
},
true,
},
Expand All @@ -39,22 +47,73 @@ func TestParams_Validate(t *testing.T) {
{
"zero amount deposit",
fields{
Deposit: sdk.Coin{Denom: "one", Amount: sdk.NewInt(0)},
Deposit: sdk.Coin{Denom: "one", Amount: sdk.NewInt(0)},
StakingShare: sdk.NewDec(0),
},
false,
},
{
"positive amount deposit",
fields{
Deposit: sdk.Coin{Denom: "one", Amount: sdk.NewInt(1000)},
Deposit: sdk.Coin{Denom: "one", Amount: sdk.NewInt(1000)},
StakingShare: sdk.NewDec(0),
},
false,
},
{
"empty staking share",
fields{
Deposit: sdk.Coin{Denom: "one", Amount: sdk.NewInt(1000)},
StakingShare: sdk.Dec{},
},
true,
},
{
"less than 0 staking share",
fields{
Deposit: sdk.Coin{Denom: "one", Amount: sdk.NewInt(1000)},
StakingShare: sdk.NewDec(-1),
},
true,
},
{
"equals to 0 staking share",
fields{
Deposit: sdk.Coin{Denom: "one", Amount: sdk.NewInt(1000)},
StakingShare: sdk.NewDec(0),
},
false,
},
{
"less than 1 staking share",
fields{
Deposit: sdk.Coin{Denom: "one", Amount: sdk.NewInt(1000)},
StakingShare: sdk.NewDecWithPrec(1, 1),
},
false,
},
{
"equals to 1 staking share",
fields{
Deposit: sdk.Coin{Denom: "one", Amount: sdk.NewInt(1000)},
StakingShare: sdk.NewDec(1),
},
false,
},
{
"greater than 1 staking share",
fields{
Deposit: sdk.Coin{Denom: "one", Amount: sdk.NewInt(1000)},
StakingShare: sdk.NewDec(2),
},
true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Params{
Deposit: tt.fields.Deposit,
Deposit: tt.fields.Deposit,
StakingShare: tt.fields.StakingShare,
}
if err := p.Validate(); (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
Expand Down

0 comments on commit d08a729

Please sign in to comment.