Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
352 changes: 352 additions & 0 deletions deployment/address_book.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,352 @@
package deployment

import (
"fmt"
"strings"
"sync"

"golang.org/x/exp/maps"

"github.com/Masterminds/semver/v3"
"github.com/ethereum/go-ethereum/common"
"github.com/pkg/errors"
chainsel "github.com/smartcontractkit/chain-selectors"
)

var (
ErrInvalidChainSelector = errors.New("invalid chain selector")
ErrInvalidAddress = errors.New("invalid address")
ErrChainNotFound = errors.New("chain not found")
)

// ContractType is a simple string type for identifying contract types.
type ContractType string

func (ct ContractType) String() string {
return string(ct)
}

type TypeAndVersion struct {
Type ContractType `json:"Type"`
Version semver.Version `json:"Version"`
Labels LabelSet `json:"Labels,omitempty"`
}

func (tv TypeAndVersion) String() string {
if len(tv.Labels) == 0 {
return fmt.Sprintf("%s %s", tv.Type, tv.Version.String())
}

// Use the LabelSet's String method for sorted labels
sortedLabels := tv.Labels.String()
return fmt.Sprintf("%s %s %s",
tv.Type,
tv.Version.String(),
sortedLabels,
)
}

func (tv TypeAndVersion) Equal(other TypeAndVersion) bool {
// Compare Type
if tv.Type != other.Type {
return false
}
// Compare Versions
if !tv.Version.Equal(&other.Version) {
return false
}
// Compare Labels
return tv.Labels.Equal(other.Labels)
}

func MustTypeAndVersionFromString(s string) TypeAndVersion {
tv, err := TypeAndVersionFromString(s)
if err != nil {
panic(err)
}
return tv
}

// Note this will become useful for validation. When we want
// to assert an onchain call to typeAndVersion yields whats expected.
func TypeAndVersionFromString(s string) (TypeAndVersion, error) {
parts := strings.Fields(s) // Ignores consecutive spaces
if len(parts) < 2 {
return TypeAndVersion{}, fmt.Errorf("invalid type and version string: %s", s)
}
v, err := semver.NewVersion(parts[1])
if err != nil {
return TypeAndVersion{}, err
}
labels := make(LabelSet)
if len(parts) > 2 {
labels = NewLabelSet(parts[2:]...)
}
return TypeAndVersion{
Type: ContractType(parts[0]),
Version: *v,
Labels: labels,
}, nil
}

func NewTypeAndVersion(t ContractType, v semver.Version) TypeAndVersion {
return TypeAndVersion{
Type: t,
Version: v,
Labels: make(LabelSet), // empty set,
}
}

// AddressBook is a simple interface for storing and retrieving contract addresses across
// chains. It is family agnostic as the keys are chain selectors.
// We store rather than derive typeAndVersion as some contracts do not support it.
// For ethereum addresses are always stored in EIP55 format.
type AddressBook interface {
Save(chainSelector uint64, address string, tv TypeAndVersion) error
Addresses() (map[uint64]map[string]TypeAndVersion, error)
AddressesForChain(chain uint64) (map[string]TypeAndVersion, error)
// Allows for merging address books (e.g. new deployments with existing ones)
Merge(other AddressBook) error
Remove(ab AddressBook) error
}

type AddressesByChain map[uint64]map[string]TypeAndVersion

type AddressBookMap struct {
addressesByChain AddressesByChain
mtx sync.RWMutex
}

// Save will save an address for a given chain selector. It will error if there is a conflicting existing address.
func (m *AddressBookMap) save(chainSelector uint64, address string, typeAndVersion TypeAndVersion) error {
family, err := chainsel.GetSelectorFamily(chainSelector)
if err != nil {
return errors.Wrapf(ErrInvalidChainSelector, "chain selector %d", chainSelector)
}
if family == chainsel.FamilyEVM {
if address == "" || address == common.HexToAddress("0x0").Hex() {
return errors.Wrap(ErrInvalidAddress, "address cannot be empty")
}
if common.IsHexAddress(address) {
// IMPORTANT: WE ALWAYS STANDARDIZE ETHEREUM ADDRESS STRINGS TO EIP55
address = common.HexToAddress(address).Hex()
} else {
return errors.Wrapf(ErrInvalidAddress, "address %s is not a valid Ethereum address, only Ethereum addresses supported for EVM chains", address)
}
}

// TODO NONEVM-960: Add validation for non-EVM chain addresses

if typeAndVersion.Type == "" {
return errors.New("type cannot be empty")
}

if _, exists := m.addressesByChain[chainSelector]; !exists {
// First time chain add, create map
m.addressesByChain[chainSelector] = make(map[string]TypeAndVersion)
}
if _, exists := m.addressesByChain[chainSelector][address]; exists {
return fmt.Errorf("address %s already exists for chain %d", address, chainSelector)
}
m.addressesByChain[chainSelector][address] = typeAndVersion
return nil
}

// Save will save an address for a given chain selector. It will error if there is a conflicting existing address.
// thread safety version of the save method
func (m *AddressBookMap) Save(chainSelector uint64, address string, typeAndVersion TypeAndVersion) error {
m.mtx.Lock()
defer m.mtx.Unlock()
return m.save(chainSelector, address, typeAndVersion)
}

func (m *AddressBookMap) Addresses() (map[uint64]map[string]TypeAndVersion, error) {
m.mtx.RLock()
defer m.mtx.RUnlock()

// maps are mutable and pass via a pointer
// creating a copy of the map to prevent concurrency
// read and changes outside object-bound
return m.cloneAddresses(m.addressesByChain), nil
}

func (m *AddressBookMap) AddressesForChain(chainSelector uint64) (map[string]TypeAndVersion, error) {
_, err := chainsel.GetChainIDFromSelector(chainSelector)
if err != nil {
return nil, errors.Wrapf(ErrInvalidChainSelector, "chain selector %d", chainSelector)
}

m.mtx.RLock()
defer m.mtx.RUnlock()

if _, exists := m.addressesByChain[chainSelector]; !exists {
return nil, errors.Wrapf(ErrChainNotFound, "chain selector %d", chainSelector)
}

// maps are mutable and pass via a pointer
// creating a copy of the map to prevent concurrency
// read and changes outside object-bound
return maps.Clone(m.addressesByChain[chainSelector]), nil
}

// Merge will merge the addresses from another address book into this one.
// It will error on any existing addresses.
func (m *AddressBookMap) Merge(ab AddressBook) error {
addresses, err := ab.Addresses()
if err != nil {
return err
}

m.mtx.Lock()
defer m.mtx.Unlock()

for chainSelector, chainAddresses := range addresses {
for address, typeAndVersion := range chainAddresses {
if err := m.save(chainSelector, address, typeAndVersion); err != nil {
return err
}
}
}
return nil
}

// Remove removes the address book addresses specified via the argument from the AddressBookMap.
// Errors if all the addresses in the given address book are not contained in the AddressBookMap.
func (m *AddressBookMap) Remove(ab AddressBook) error {
addresses, err := ab.Addresses()
if err != nil {
return err
}

m.mtx.Lock()
defer m.mtx.Unlock()

// State of m.addressesByChain storage must not be changed in case of an error
// need to do double iteration over the address book. First validation, second actual deletion
for chainSelector, chainAddresses := range addresses {
for address := range chainAddresses {
if _, exists := m.addressesByChain[chainSelector][address]; !exists {
return errors.New("AddressBookMap does not contain address from the given address book")
}
}
}

for chainSelector, chainAddresses := range addresses {
for address := range chainAddresses {
delete(m.addressesByChain[chainSelector], address)
}
}

return nil
}

// cloneAddresses creates a deep copy of map[uint64]map[string]TypeAndVersion object
func (m *AddressBookMap) cloneAddresses(input map[uint64]map[string]TypeAndVersion) map[uint64]map[string]TypeAndVersion {
result := make(map[uint64]map[string]TypeAndVersion)
for chainSelector, chainAddresses := range input {
result[chainSelector] = maps.Clone(chainAddresses)
}
return result
}

// TODO: Maybe could add an environment argument
// which would ensure only mainnet/testnet chain selectors are used
// for further safety?
func NewMemoryAddressBook() *AddressBookMap {
return &AddressBookMap{
addressesByChain: make(map[uint64]map[string]TypeAndVersion),
}
}

func NewMemoryAddressBookFromMap(addressesByChain map[uint64]map[string]TypeAndVersion) *AddressBookMap {
return &AddressBookMap{
addressesByChain: addressesByChain,
}
}

// SearchAddressBook search an address book for a given chain and contract type and return the first matching address.
func SearchAddressBook(ab AddressBook, chain uint64, typ ContractType) (string, error) {
addrs, err := ab.AddressesForChain(chain)
if err != nil {
return "", err
}

for addr, tv := range addrs {
if tv.Type == typ {
return addr, nil
}
}

return "", errors.New("not found")
}

func AddressBookContains(ab AddressBook, chain uint64, addrToFind string) (bool, error) {
addrs, err := ab.AddressesForChain(chain)
if err != nil {
return false, err
}

for addr := range addrs {
if addr == addrToFind {
return true, nil
}
}

return false, nil
}

type typeVersionKey struct {
Type ContractType
Version string
Labels string // store labels in a canonical form (comma-joined sorted list)
}

func tvKey(tv TypeAndVersion) typeVersionKey {
sortedLabels := tv.Labels.String()
return typeVersionKey{
Type: tv.Type,
Version: tv.Version.String(),
Labels: sortedLabels,
}
}

// EnsureDeduped ensures that each contract in the bundle only appears once
// in the address map. It returns an error if there are more than one instance of a contract.
// Returns true if every value in the bundle is found once, false otherwise.
func EnsureDeduped(addrs map[string]TypeAndVersion, bundle []TypeAndVersion) (bool, error) {
var (
grouped = toTypeAndVersionMap(addrs)
found = make([]TypeAndVersion, 0)
)
for _, btv := range bundle {
key := tvKey(btv)
matched, ok := grouped[key]
if ok {
found = append(found, btv)
}
if len(matched) > 1 {
return false, fmt.Errorf("found more than one instance of contract %s v%s (labels=%s)",
key.Type, key.Version, key.Labels)
}
}

// Indicate if each TypeAndVersion in the bundle is found at least once
return len(found) == len(bundle), nil
}

// toTypeAndVersionMap groups contract addresses by unique TypeAndVersion.
func toTypeAndVersionMap(addrs map[string]TypeAndVersion) map[typeVersionKey][]string {
tvkMap := make(map[typeVersionKey][]string)
for k, v := range addrs {
tvkMap[tvKey(v)] = append(tvkMap[tvKey(v)], k)
}
return tvkMap
}

// AddLabel adds a string to the LabelSet in the TypeAndVersion.
func (tv *TypeAndVersion) AddLabel(label string) {
if tv.Labels == nil {
tv.Labels = make(LabelSet)
}
tv.Labels.Add(label)
}
Loading