hbutils.random.state

Overview:

Random state and seed management.

The following random sources are supported by default: Python's native default random instance (random._inst), torch, numpy and faker (default random instance). Other random sources can be registered with register_random_source() and register_random_instance().

register_random_source

hbutils.random.state.register_random_source(name: str, seed: Callable[[int], None], getstate: Callable[[], T], setstate: Callable[[T], None])
Overview:

Register a random source by providing its name, seed function, state-getting function and state-setting function.

Parameters:
  • name – Name of random source.

  • seed – Seed function, format: seed(x).

  • getstate – State getting function, format: getstate().

  • setstate – State setting function, format: setstate(state).

Examples::
>>> import random
>>> from hbutils.random import global_seed, register_random_source
>>>
>>> rnd = random.Random()  # custom random object
>>> global_seed(0)  # try to use same seed
>>> rnd.random()
0.4563765178328746
>>> global_seed(0)
>>> rnd.random()  # not same
0.06875325446897462
>>>
>>> register_random_source('custom_random', rnd.seed, rnd.getstate, rnd.setstate)
>>> global_seed(0)  # try again
>>> rnd.random()
0.8444218515250481
>>> global_seed(0)
>>> rnd.random()  # the same
0.8444218515250481

register_random_instance

hbutils.random.state.register_random_instance(name: str, rnd: random.Random)
Overview:

Register a custom random instance.

Parameters:
  • name – Name of random source.

  • rnd – Custom random instance; must be an instance of random.Random.

Examples::
>>> import random
>>> from hbutils.random import global_seed, register_random_instance
>>>
>>> rnd = random.Random()  # custom random object
>>> global_seed(0)  # try to use same seed
>>> rnd.random()
0.48936053503964005
>>> global_seed(0)
>>> rnd.random()  # not same
0.4113361070387721
>>>
>>> register_random_instance('custom_random', rnd)
>>> global_seed(0)  # try again
>>> rnd.random()
0.8444218515250481
>>> global_seed(0)
>>> rnd.random()  # the same
0.8444218515250481
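
The "not same" results before registration follow from how the standard library works: a standalone random.Random object carries its own state, separate from the module-level default instance. The short sketch below uses only the standard library (no hbutils calls) to show that seeding the default instance leaves a custom instance untouched. The constructor seed 12345 is arbitrary.

```python
import random

rnd = random.Random(12345)  # independent generator with its own state
before = rnd.getstate()

random.seed(0)              # reseeds only the module-level default instance

after = rnd.getstate()
print(before == after)  # True: rnd was not affected by random.seed(0)
```

This is why a custom instance has to be registered before a global seed can reach it.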

get_global_state

hbutils.random.state.get_global_state() → Dict[str, T]
Overview:

Get the states of all registered random sources.

Returns:

Global states.

Examples::
>>> import random
>>> import numpy as np
>>> import torch
>>> from faker import Faker
>>>
>>> from hbutils.random import get_global_state, set_global_state
>>>
>>> _ = random.randint(0, 100)  # just do something
>>> _ = random.random()
>>> _ = torch.randn(2, 3)
>>> _ = np.random.randn(2, 3)
>>>
>>> states = get_global_state()
>>> random.randint(0, 100)  # first time's result
99
>>> random.random()
0.4656250864192085
>>> torch.randn(2, 3)
tensor([[ 0.8886, -0.3602,  1.3071],
        [-0.0187, -0.5980, -0.5469]])
>>> np.random.randn(2, 3)
array([[ 1.24249156,  0.71018699, -0.53496231],
       [ 0.78748336, -0.01407442, -0.6607438 ]])
>>> Faker().sentence(5)
'New pass crime most.'
>>>
>>> set_global_state(states)
>>> random.randint(0, 100)  # same as the first time
99
>>> random.random()
0.4656250864192085
>>> torch.randn(2, 3)
tensor([[ 0.8886, -0.3602,  1.3071],
        [-0.0187, -0.5980, -0.5469]])
>>> np.random.randn(2, 3)
array([[ 1.24249156,  0.71018699, -0.53496231],
       [ 0.78748336, -0.01407442, -0.6607438 ]])
>>> Faker().sentence(5)
'New pass crime most.'
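
As a rough illustration of the snapshot-and-restore pattern shown above, here is a standard-library-only sketch that handles a single source. The helper names get_states_sketch and set_states_sketch and the dictionary key 'native' are hypothetical. The keys that hbutils actually uses are not stated in this document.

```python
import random
from typing import Any, Dict, Mapping


def get_states_sketch() -> Dict[str, Any]:
    """Snapshot the state of each known source under a name."""
    return {'native': random.getstate()}  # 'native' is an assumed key


def set_states_sketch(states: Mapping[str, Any]) -> None:
    """Restore each known source from a previously taken snapshot."""
    random.setstate(states['native'])


states = get_states_sketch()
first = [random.randint(0, 100), random.random()]

set_states_sketch(states)
second = [random.randint(0, 100), random.random()]
print(first == second)  # True: the same values are drawn after restoring
```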

set_global_state

hbutils.random.state.set_global_state(states: Mapping[str, T])
Overview:

Set the states of registered random sources.

Parameters:

states – States to be set.

Examples::

See get_global_state().

keep_global_state

hbutils.random.state.keep_global_state()
Overview:

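Context manager that keeps the states of all registered random sources: the states are saved on entry and restored on exit, so random calls made inside the block do not change the sequence seen afterwards.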

Examples::
>>> import torch
>>> from hbutils.random import global_seed, keep_global_state
>>>
>>> global_seed(0)
>>> torch.randn(2, 3)  # before value 1
tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986]])
>>> torch.randn(2, 3)  # after value 1
tensor([[ 0.4033,  0.8380, -0.7193],
        [-0.4033, -0.5966,  0.1820]])
>>>
>>> global_seed(0)
>>> torch.randn(2, 3)  # before value 2, same as 1
tensor([[ 1.5410, -0.2934, -2.1788],
        [ 0.5684, -1.0845, -1.3986]])
>>> with keep_global_state():  # do anything you want here
...     _ = torch.randn(100, 200, 2)
...     _ = torch.randint(20, 30, (30, 40))
>>> torch.randn(2, 3)  # after value 2, same as 1
tensor([[ 0.4033,  0.8380, -0.7193],
        [-0.4033, -0.5966,  0.1820]])
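
The behaviour in the example can be approximated with a small context manager. The sketch below covers only Python's default random instance and uses a hypothetical name, keep_state_sketch. Restoring inside a finally block, so that the state is also restored when the body raises, is a design choice of this sketch and is not confirmed for hbutils.

```python
import random
from contextlib import contextmanager


@contextmanager
def keep_state_sketch():
    """Save the default generator's state on entry, restore it on exit."""
    state = random.getstate()
    try:
        yield
    finally:
        random.setstate(state)  # runs even if the body raises


random.seed(0)
_ = random.random()          # before value 1
expected = random.random()   # after value 1

random.seed(0)
_ = random.random()          # before value 2, same as 1
with keep_state_sketch():    # draws in here do not advance the outer sequence
    for _ in range(1000):
        random.random()
actual = random.random()     # after value 2
print(actual == expected)  # True
```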

global_seed

hbutils.random.state.global_seed(seed: int)
Overview:

Set the seed for all registered random sources.

Parameters:

seed – Random seed.

Examples::

See keep_global_state() and register_random_instance().

seedable_func

hbutils.random.state.seedable_func(func)
Overview:

Wrap a function so that it accepts a seed.

Parameters:

func – Function to be decorated; the wrapped function accepts an additional seed argument.

Examples::
>>> import torch
>>> from hbutils.random import seedable_func
>>>
>>> @seedable_func
... def get_random_value(mean, std):
...     return torch.randn((2, 3)) * std + mean
>>>
>>> get_random_value(2, 3)
tensor([[-0.0844,  5.2530,  3.4248],
        [ 4.4923,  0.0492,  3.6731]])
>>> get_random_value(2, 3)  # not the same
tensor([[2.7600, 2.5135, 0.6484],
        [1.2459, 0.1020, 2.5905]])
>>>
>>> get_random_value(2, 3, seed=0)
tensor([[ 6.6230,  1.1197, -4.5364],
        [ 3.7053, -1.2536, -2.1958]])
>>> get_random_value(2, 3, seed=1)
tensor([[3.9841, 2.8008, 2.1850],
        [3.8640, 0.6443, 1.5016]])
>>> get_random_value(2, 3, seed=0)  # repeatable
tensor([[ 6.6230,  1.1197, -4.5364],
        [ 3.7053, -1.2536, -2.1958]])
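
One way to build a decorator with this kind of interface is sketched below using only the standard library. The name seedable_sketch is hypothetical, and the sketch seeds only Python's default random instance. Restoring the previous state after a seeded call is an assumption of this sketch. Whether seedable_func does the same is not stated here.

```python
import functools
import random


def seedable_sketch(func):
    """Add an optional keyword-only `seed` argument to `func`."""
    @functools.wraps(func)
    def wrapper(*args, seed=None, **kwargs):
        if seed is None:
            return func(*args, **kwargs)   # unseeded: behave as before
        state = random.getstate()
        try:
            random.seed(seed)              # seeded: make the call repeatable
            return func(*args, **kwargs)
        finally:
            random.setstate(state)         # assumption: leave outer state intact
    return wrapper


@seedable_sketch
def get_random_value(mean, std):
    return random.gauss(mean, std)


a = get_random_value(2, 3, seed=0)
b = get_random_value(2, 3, seed=1)
c = get_random_value(2, 3, seed=0)  # repeatable
print(a == c, a != b)  # True True
```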