Runner
This is part of Cogment Enterprise, AI Redefined's commercial offering.
General usage
asyncio
Using this module requires Cogment Python SDK >= 2.8.0. It uses Python's asyncio library and, as such, should be run in an asyncio.Task.
This documentation assumes some familiarity with the asyncio library of Python (see Python documentation).
E.g.
import asyncio
asyncio.run(MyMainFunction())
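Below is a slightly fuller skeleton of the same pattern. The coroutine names are hypothetical, and the body only stands in for real work, such as creating a TrialRunner and running batches. Note that asyncio.run() already executes the coroutine it is given inside an asyncio.Task. Explicit asyncio.create_task() calls are therefore only needed to run several jobs concurrently.

```python
import asyncio


async def run_experiment(name: str) -> str:
    # Placeholder for real work, e.g. creating a TrialRunner and running a batch.
    await asyncio.sleep(0)
    return f"{name} done"


async def my_main() -> list:
    # asyncio.run() already runs my_main() inside a Task; extra tasks are only
    # needed to run several jobs at the same time.
    tasks = [asyncio.create_task(run_experiment(n)) for n in ("batch", "training")]
    return await asyncio.gather(*tasks)


if __name__ == "__main__":
    print(asyncio.run(my_main()))
```

asyncio.gather() returns results in the order the tasks were given, regardless of which finishes first.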
Logging
This module uses the cogment.enterprise logger, and the default log level is INFO. E.g. to change the log level to WARNING:
import cogment_enterprise
import logging
logging.getLogger("cogment.enterprise").setLevel(logging.WARNING)
Or set the environment variable COGMENT_ENTERPRISE_LOG_LEVEL to one of the values: off, error, warning, info, debug, trace.
The logging works the same as Cogment Python SDK logging (see Cogment Python SDK documentation).
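The logger is a standard Python logging.Logger, so the usual logging tools apply. The sketch below uses logging.basicConfig() to attach a timestamped handler. That call only has an effect if the root logger has no handlers yet. The sketch then makes only the cogment.enterprise logger verbose. The format string is an arbitrary choice. Setting COGMENT_ENTERPRISE_LOG_LEVEL in the shell before starting the script is the alternative that requires no code.

```python
import logging

# Timestamped output on stderr; the root logger itself stays at WARNING.
logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s %(name)s %(levelname)s: %(message)s",
)

# Only the Cogment Enterprise logger (and any child loggers) become verbose.
enterprise_logger = logging.getLogger("cogment.enterprise")
enterprise_logger.setLevel(logging.DEBUG)
```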
Trial Specifications
This module is designed to work without any trial specifications (i.e. cog_settings), but it is easier to use if the specifications are available.
If the specifications are not provided, some internal object deserializations will not happen (e.g. sample.observation), and special serialized versions will have to be used instead (e.g. sample.observation_serialized).
Helper functions are provided for deserializing the various defined objects in the specifications (see below).
Objects normally received as google.protobuf.Any will still be deserialized to such an object, because this does not depend on the specification of the trial.
Top-level import
The main module of the Runner SDK is cogment_enterprise.runner, and most enterprise scripts will start by creating a cogment_enterprise.runner.TrialRunner.
Utilities and Constants
cogment_enterprise.runner.BATCH_ID_PROPERTY
This is the name of the trial property where the batch ID is stored. Each trial started by a batch will have this property.
batch_id = trial_parameters.properties[cogment_enterprise.runner.BATCH_ID_PROPERTY]
cogment_enterprise.runner.BATCH_TRIAL_INDEX_PROPERTY
This is the name of the trial property where the index of the trial in the batch is stored. Each trial started by a batch will have this property.
trial_index_in_batch = trial_parameters.properties[cogment_enterprise.runner.BATCH_TRIAL_INDEX_PROPERTY]
cogment_enterprise.runner.BATCH_LAST_TRIAL_PROPERTY
This is the name of the trial property that will be set on the last trial of the batch. The property value is empty; its presence indicates that this is the last trial of the batch. Only one trial in a batch may have this property.
Note that there may not be a trial with this property if the batch was stopped prematurely.
last_trial = cogment_enterprise.runner.BATCH_LAST_TRIAL_PROPERTY in trial_parameters.properties
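These three properties are usually read together. The sketch below groups them into one helper, describe_batch_trial, which is a hypothetical name. It assumes that trial properties are a mapping of strings to strings, as in the Cogment Python SDK, so the trial index is converted with int(). The property names are passed as arguments so the sketch runs standalone. In real code they would be the cogment_enterprise.runner constants above.

```python
from typing import Mapping, NamedTuple, Optional


class BatchTrialProperties(NamedTuple):
    batch_id: str
    trial_index: int
    is_last: bool


def describe_batch_trial(
    properties: Mapping[str, str],
    batch_id_key: str,
    trial_index_key: str,
    last_trial_key: str,
) -> Optional[BatchTrialProperties]:
    """Return the batch information of a trial, or None if no batch started it."""
    if batch_id_key not in properties:
        return None
    return BatchTrialProperties(
        batch_id=properties[batch_id_key],
        # Property values are assumed to be strings, so convert the index.
        trial_index=int(properties[trial_index_key]),
        # Only the presence of the key matters; its value is empty.
        is_last=last_trial_key in properties,
    )
```

Typical usage would pass trial_parameters.properties together with BATCH_ID_PROPERTY, BATCH_TRIAL_INDEX_PROPERTY and BATCH_LAST_TRIAL_PROPERTY.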
cogment_enterprise.runner.deserialize_action(serialized_data, actor_class, cog_settings)
Function to deserialize raw data into a Python class instance.
The data can only be deserialized by knowing the protobuf message it represents.
It can be done manually if one knows the protobuf message represented.
This function simplifies deserialization of messages related to a Cogment project with the trial spec module cog_settings.
Parameters:
serialized_data: bytes - Raw data received.
actor_class: str - Name of the class of the actor to which this data relates. This information is necessary to find the proper message type in the spec.
cog_settings: module - Specification module associated with the trial to which the data relates.
Return: protobuf class instance - Action from an actor of type actor_class. The class of the action is defined as the action space for the specific actor class in the section actor_classes:action:space in the spec file (e.g. cog_settings).
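If the protobuf message class is known, the manual route mentioned above amounts to two steps: instantiate that class, then call ParseFromString(), which every generated protobuf message class provides. A minimal sketch follows. The helper name deserialize_manually is hypothetical.

```python
from typing import Type, TypeVar

MessageT = TypeVar("MessageT")


def deserialize_manually(serialized_data: bytes, message_class: Type[MessageT]) -> MessageT:
    """Deserialize raw bytes when the protobuf message class is already known."""
    message = message_class()
    # Generated protobuf classes fill the instance in place from the raw bytes.
    message.ParseFromString(serialized_data)
    return message
```

Suppose an actor class has a hypothetical action space my_project_pb2.MyAction. Then deserialize_manually(raw_bytes, my_project_pb2.MyAction) should give the same result as deserialize_action(raw_bytes, actor_class, cog_settings) for that actor class.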
cogment_enterprise.runner.deserialize_actor_observation(serialized_data, actor_class, cog_settings)
Function to deserialize raw data into a Python class instance.
The data can only be deserialized by knowing the protobuf message it represents.
It can be done manually if one knows the protobuf message represented.
This function simplifies deserialization of messages related to a Cogment project with the trial spec module cog_settings.
Parameters:
serialized_data: bytes - Raw data received.
actor_class: str - Name of the class of the actor to which this data relates. This information is necessary to find the proper message type in the spec.
cog_settings: module - Specification module associated with the trial to which the data relates.
Return: protobuf class instance - Observation for an actor of type actor_class. The class of the observation is defined as the observation space for the specific actor class in the section actor_classes:observation:space in the spec file (e.g. cog_settings).
cogment_enterprise.runner.deserialize_actor_config(serialized_data, actor_class, cog_settings)
Function to deserialize raw data into a Python class instance.
The data can only be deserialized by knowing the protobuf message it represents.
It can be done manually if one knows the protobuf message represented.
This function simplifies deserialization of messages related to a Cogment project with the trial spec module cog_settings.
Parameters:
serialized_data: bytes - Raw data received.
actor_class: str - Name of the class of the actor to which this data relates. This information is necessary to find the proper message type in the spec.
cog_settings: module - Specification module associated with the trial to which the data relates.
Return: protobuf class instance - Config for an actor of type actor_class. The class of the config is defined as the config type for the specific actor class in the section actor_classes:config_type in the spec file (e.g. cog_settings).
cogment_enterprise.runner.deserialize_environment_config(serialized_data, cog_settings)
Function to deserialize raw data into a Python class instance.
The data can only be deserialized by knowing the protobuf message it represents.
It can be done manually if one knows the protobuf message represented.
This function simplifies deserialization of messages related to a Cogment project with the trial spec module cog_settings.
Parameters:
serialized_data: bytes - Raw data received.
cog_settings: module - Specification module associated with the trial to which the data relates.
Return: protobuf class instance - Config for the environment. The class of the config is defined as the config type in the section environment:config_type in the spec file (e.g. cog_settings).
cogment_enterprise.runner.deserialize_trial_config(serialized_data, cog_settings)
Function to deserialize raw data into a Python class instance.
The data can only be deserialized by knowing the protobuf message it represents.
It can be done manually if one knows the protobuf message represented.
This function simplifies deserialization of messages related to a Cogment project with the trial spec module cog_settings.
Parameters:
serialized_data: bytes - Raw data received.
cog_settings: module - Specification module associated with the trial to which the data relates.
Return: protobuf class instance - Config for the trial. The class of the config is defined as the config type in the section trial:config_type in the spec file (e.g. cog_settings).
class cogment_enterprise.runner.TrialRunner
__init__(self, user_id, cog_settings=None, asyncio_loop=None, directory_endpoint=None, directory_auth_token=None, orchestrator_endpoint=None, datastore_endpoint=None, model_registry=None)
Parameters:
user_id: str - Identifier for the user of this context.
cog_settings: module - Settings module associated with the trials that will be run (cog_settings namespace).
asyncio_loop: asyncio event loop - For special-purpose implementations.
directory_endpoint: Endpoint instance - gRPC endpoint (i.e. starting with "grpc://") to access the directory. The directory will be used to query discovery endpoints and to register the services for discovery. If no endpoint is provided, the environment variable COGMENT_DIRECTORY_ENDPOINT will be checked and, if it exists, used as the URL of a basic endpoint.
directory_auth_token: str - Authentication token for access to the directory. This token will be registered with the services, and must match the registered tokens when querying the directory. If no token is provided, the environment variable COGMENT_DIRECTORY_AUTHENTICATION_TOKEN will be checked and, if it exists, used as the token.
orchestrator_endpoint: Endpoint instance - Details of the connection to the Orchestrator. If not provided, the directory will be queried. Only needed for running batches, not for training.
datastore_endpoint: Endpoint instance - Details of the connection to the Datastore. If not provided, the directory will be queried if necessary. This will be used as the datalog endpoint of the trials started by the batch, and as the source of samples for training.
model_registry_endpoint: Endpoint instance - Details of the connection to the Model Registry. If not provided, the directory will be queried if necessary. Only needed for training, not for running batches.
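If directory_endpoint and directory_auth_token are left out, the TrialRunner falls back on the COGMENT_DIRECTORY_ENDPOINT and COGMENT_DIRECTORY_AUTHENTICATION_TOKEN environment variables. A script can check these up front so that it fails early with a clear message. The helper below is a hypothetical sketch of such a check and is not part of the SDK. It treats a missing token as acceptable, which is an assumption that depends on how the directory is deployed.

```python
import os
from typing import Mapping, Optional, Tuple


def directory_settings_from_env(
    environ: Mapping[str, str] = os.environ,
) -> Tuple[str, Optional[str]]:
    """Return the (endpoint, token) pair a TrialRunner would pick up from the environment."""
    endpoint = environ.get("COGMENT_DIRECTORY_ENDPOINT")
    if not endpoint:
        raise RuntimeError("COGMENT_DIRECTORY_ENDPOINT is not set and no directory_endpoint was given")
    if not endpoint.startswith("grpc://"):
        raise ValueError(f"directory endpoint must start with 'grpc://', got {endpoint!r}")
    # Optional here (assumption): whether a token is required depends on the deployment.
    token = environ.get("COGMENT_DIRECTORY_AUTHENTICATION_TOKEN")
    return endpoint, token
```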
async get_controller(self)
Returns the Controller used by the TrialRunner.
Parameters: None
Return: cogment.Controller instance - An instance of the cogment.Controller class used to manage trials.
async get_datastore(self)
Returns the Datastore used by the TrialRunner.
Parameters: None
Return: cogment.Datastore instance - Datastore.
async get_model_registry(self)
Returns the Model Registry used by the TrialRunner.
Parameters: None
Return: cogment.ModelRegistry instance - Model Registry.
async run_simple_batch(self, nb_trials, nb_parallel_trials=1, id=None, pre_trial_callback=None, post_trial_callback=None)
Method to start a batch of trials.
Parameters:
nb_trials: int - The number of trials to run.
nb_parallel_trials: int - The number of trials to run in parallel. Must be <= nb_trials.
id: str - ID of the batch. This will be added to the properties of the trials started by the batch. It should be unique in the Datastore; otherwise trial IDs could clash, and the BatchTrainer could end up using a mix of trials from different batches. If None, an ID will be chosen by the system (Unix epoch in nanoseconds).
pre_trial_callback: async func(BatchTrialInfo instance) -> cogment.TrialParameters - This callback function will be called before each new trial is started. If None, the parameters for the trials will come from the Orchestrator defaults and pre-trial hooks (see Cogment Orchestrator documentation). In that case, the BatchTrainer cannot work with this batch, because the necessary trial properties cannot be set.
post_trial_callback: async func(BatchTrialInfo instance) - This callback function will be called after the end of each trial. If None, no call will happen at the end of trials.
Return: TrialBatch instance - An instance of the TrialBatch class.
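A reused batch ID can mix trials from different batches. It can therefore help to build IDs that are both readable and unlikely to repeat. The hypothetical helper below uses the same time base as the system default (Unix epoch in nanoseconds) and prefixes it with an experiment name.

```python
import time


def make_batch_id(experiment: str) -> str:
    # Same time base as the default ID (Unix epoch in nanoseconds), prefixed for readability.
    return f"{experiment}-{time.time_ns()}"
```

For example, inside an async function: batch = await runner.run_simple_batch(100, 4, make_batch_id("cartpole"), my_pre_trial_callback).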
async run_simple_training(self, batch, sampler_callback, actor_names=[])
Method to start training on a batch of trials.
Parameters:
batch: TrialBatch instance - The batch to train on. This will be used to identify the trials (from the trial properties) that are part of the batch, and to retrieve only the samples from these trials.
sampler_callback: async func(cogment.DatastoreSample, cogment.TrialParameters, cogment.ModelRegistry) -> bool - This callback function will be called for every batch sample retrieved.
Return: BatchTrainer instance - An instance of the BatchTrainer class.
class TrialBatch
Class to run a batch of related trials.
pause(self)
Method to pause the running of the batch. It stops any new trial from starting, but does not stop currently running trials. Even if all running trials end, the batch is not considered done until it is resumed, stopped or terminated.
Parameters: None
Return: None
resume(self)
Method to restart a batch that was paused. It resumes the starting of new trials in the batch.
Parameters: None
Return: None
stop(self)
Method to stop the batch. It stops new trials from starting, and currently running trials continue to their normal end. Once all trials have ended, the batch is done.
Parameters: None
Return: None
terminate(self, hard=False)
Method to terminate the batch. It stops new trials from starting, and currently running trials are terminated. The batch is then considered done.
Parameters:
hard: bool - If True, the trials are sent a "hard" terminate; otherwise they are sent a "soft" terminate (see Controller.terminate_trial in the Cogment Python SDK documentation).
Return: None
is_running(self)
Method to inquire whether the batch is done or not.
Parameters: None
Return: bool - True if the batch is still running, i.e. there are still trials running or it is paused. False otherwise.
nb_trials_run(self)
Method to inquire the number of trials run so far.
Parameters: None
Return: int - Number of trials that were run (and ended) so far in the batch.
async wait(self, timeout)
Method to wait for the batch to be done. The batch will end normally when all trials have run and ended. The batch can also be stopped, terminated, or encounter an error to become done.
Parameters:
timeout: float - Maximum time to wait, in seconds.
Return: bool - True if the batch ended normally with the last trial tagged as such. False otherwise. None if timed out.
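Because wait() returns None on timeout, it can be called in a loop to report progress while the batch runs. A sketch follows. The function name wait_with_progress and the 30-second default are arbitrary choices, and batch is the TrialBatch returned by run_simple_batch.

```python
import logging
from typing import Optional


async def wait_with_progress(batch, poll_timeout: float = 30.0) -> bool:
    """Wait for a TrialBatch to be done, logging progress at each timeout."""
    while True:
        result: Optional[bool] = await batch.wait(poll_timeout)
        if result is not None:
            # True: ended normally with the last trial tagged; False: stopped, terminated or failed.
            return result
        logging.info("Batch still running: %d trials done so far", batch.nb_trials_run())
```

A paused batch is not done, so this loop keeps waiting until the batch is resumed, stopped or terminated.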
class BatchTrainer
Class to help train a model on a specific batch of trials.
terminate(self)
Method to terminate training. The callback task will be cancelled.
Parameters: None
Return: None
async stop(self)
Method to stop training.
Stops retrieving new samples from the Datastore and waits for all samples already queued to be processed by the callback (or for the callback to return False).
Parameters: None
Return: None
is_running(self)
Method to inquire whether the training is done or not.
Parameters: None
Return: bool - True if the training is still running, i.e. samples are still being retrieved from the trials and sent to the callback. False otherwise.
async wait(self, timeout)
Method to wait for the training to be done. The training will end normally when all samples of the batch have been processed. The training can also be stopped, or encounter an error to become done.
Parameters:
timeout: float - Maximum time to wait, in seconds.
Return: bool - True if all available samples were processed. False otherwise. None if timed out.
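wait() can also enforce a time budget on training. If it times out, stop() lets samples that are already queued be processed before training ends. The sketch below assumes trainer is the BatchTrainer returned by run_simple_training. The function name is hypothetical.

```python
async def train_with_budget(trainer, budget_seconds: float) -> bool:
    """Train until all samples are processed or the time budget runs out."""
    result = await trainer.wait(budget_seconds)
    if result is None:
        # Timed out: stop fetching new samples, but let queued ones reach the callback.
        await trainer.stop()
        return False
    return result
```

Calling terminate() instead of stop() would cancel the callback task immediately rather than draining the queue.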
class BatchTrialInfo
batch_id: str - ID of the batch.
trial_index: int - The index of the trial in the batch. Generally the order in which the trials were started, and unique in the batch (in the range [0, nb_trials)).
trial_info: cogment.DatastoreTrialInfo - The running trial information. May not always be present.
Callbacks
Use
These functions are passed to the TrialRunner.run_simple_batch or TrialRunner.run_simple_training methods. They will be called at specific times to request information from the user or to provide information to the user.
They can be defined and used in a number of ways.
Here we take the pre_trial_callback as an example, but the other callbacks are similar, except for their parameters and return values:
async def my_pre_trial_callback(info: BatchTrialInfo):
    trial_params = cogment.TrialParameters()
    # ... Fill in the parameters here
    return trial_params

# Inside an async function, with runner being a TrialRunner instance:
batch = await runner.run_simple_batch(1, 1, None, my_pre_trial_callback)
Sometimes it is more convenient for the callback to be a method of a class, in order to reuse data between calls or share data with other parts of the program (or other callbacks). In this case it could look like this:
class MyBatchData:
    async def my_pre_trial_callback(self, info: BatchTrialInfo):
        trial_params = cogment.TrialParameters()
        # ... Fill in the parameters here
        return trial_params

my_data = MyBatchData()
batch = await runner.run_simple_batch(1, 1, None, my_data.my_pre_trial_callback)
Although it is rare, using a class to share data may sometimes be inconvenient. In that case, the Python functools module can be used:
import functools

async def my_function(my_data, info: BatchTrialInfo):
    trial_params = cogment.TrialParameters()
    # ... Fill in the parameters here, using my_data
    return trial_params

shared_data = {}  # ... any data to share with the callback
actual_callback = functools.partial(my_function, shared_data)
batch = await runner.run_simple_batch(1, 1, None, actual_callback)
Pre-Trial Callback
This function is passed to the TrialRunner.run_simple_batch method and will be called before each trial is started, to define the trial parameters. It is an asyncio coroutine.
e.g.:
import cogment

async def my_pre_trial_callback(info: BatchTrialInfo) -> cogment.TrialParameters:
    trial_params = cogment.TrialParameters()
    # ... Fill in the parameters here
    return trial_params
The parameter received is a BatchTrialInfo instance that is only partially filled (i.e. it does not contain a trial_info).
The function must create an instance of cogment.TrialParameters, fill in the necessary parameters for the trial and return it.
The trial ID will automatically be created using the batch ID and the trial index.
Once the trial parameters are received by the TrialBatch, some data will be added and some will be overwritten. These are the attributes changed in the received TrialParameters before it is passed to Cogment:
properties: Some properties will be added to the existing properties (see Utilities and Constants). If property names clash, the user properties will be overwritten. In general, do not start property names with an underscore, to prevent such clashes.
datalog_endpoint: This attribute of the trial parameters will be overwritten with the datastore_endpoint argument provided to the TrialRunner. If datastore_endpoint was not provided, or it was None, the directory will be used to find an appropriate datastore. The same datastore must be used by both the TrialBatch (as a datalog) and the BatchTrainer (as a datastore), so the endpoint should resolve to the same datastore locally and at the Orchestrator (i.e. ideally use the same directory).
datalog_exclude_fields: This attribute will be reset (i.e. no fields will be excluded from the datalog).
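As an illustration, the pre-trial callback below gives every trial its own property derived from trial_index, which makes trials easy to tell apart later in the Datastore. It is a sketch under several assumptions. The class name and the property name "seed" are hypothetical. The TrialParameters factory is injected so that the sketch runs standalone; in real code it would be cogment.TrialParameters, constructed as in the examples above. The properties dictionary is copied and reassigned because it is not certain whether mutating it in place is reflected in the parameters. The property name deliberately avoids a leading underscore, following the advice above.

```python
from typing import Any, Callable


class SeededPreTrial:
    """Hypothetical pre-trial callback giving each trial a reproducible 'seed' property."""

    def __init__(self, make_trial_parameters: Callable[[], Any], base_seed: int = 0):
        # In real code, e.g.: SeededPreTrial(cogment.TrialParameters)
        self._make_trial_parameters = make_trial_parameters
        self._base_seed = base_seed

    async def __call__(self, info) -> Any:
        trial_params = self._make_trial_parameters()
        # ... Fill in the other parameters (actors, environment, etc.) here.
        # Copy, update and reassign the properties (values are strings).
        properties = dict(trial_params.properties or {})
        properties["seed"] = str(self._base_seed + info.trial_index)
        trial_params.properties = properties
        return trial_params
```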
Post-Trial Callback
This function is passed to the TrialRunner.run_simple_batch method and will be called after a trial has ended. It is an asyncio coroutine.
e.g.:
async def my_post_trial_callback(info: BatchTrialInfo):
    # ... Do cleanup, tracking, etc.
    pass
The parameter received is a BatchTrialInfo instance.
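The post-trial callback runs once per ended trial, which makes it a natural place for bookkeeping. The sketch below uses the class-based style described under Use to record which trial indices have ended. The class name is hypothetical, and the code relies only on the batch_id and trial_index attributes of BatchTrialInfo.

```python
import logging


class TrialTracker:
    """Hypothetical holder of a post-trial callback that records ended trials."""

    def __init__(self) -> None:
        self.ended_indices: set = set()

    async def on_trial_end(self, info) -> None:
        # trial_info may be absent, so only batch_id and trial_index are used here.
        self.ended_indices.add(info.trial_index)
        logging.info(
            "Batch %s: trial %d ended (%d so far)",
            info.batch_id, info.trial_index, len(self.ended_indices),
        )
```

It would be passed as tracker.on_trial_end for the post_trial_callback argument of run_simple_batch.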
Sampler Callback
This function is passed to the TrialRunner.run_simple_training method and will be called for each sample of the trials run in the batch. It runs asynchronously from the trials themselves, and uses the Cogment Datastore to retrieve the samples. It is an asyncio coroutine.
e.g.:
async def my_sampler_callback(sample, trial_parameters, model_registry) -> bool:
    # ... Train model
    continue_training = True
    return continue_training
The parameters received are:
sample: cogment.DatastoreSample - The sample that was received, with all the data necessary for training.
trial_parameters: cogment.TrialParameters - The parameters of the trial the sample came from.
model_registry: cogment.ModelRegistry - A common model registry for the whole batch being trained. The TrialRunner argument model_registry_endpoint (or the directory, if it was not provided) is used to retrieve this model registry.
The expected return value is a bool. If True, the training will continue normally. If False, the sampler callback will stop being called, and the BatchTrainer will stop.
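Returning False stops the BatchTrainer, so the return value can enforce a sample budget. The sketch below counts samples and stops after a fixed number. The class name and the max_samples parameter are hypothetical, and the actual training step is left as a comment.

```python
class LimitedSampler:
    """Hypothetical sampler callback that stops training after max_samples samples."""

    def __init__(self, max_samples: int):
        self.max_samples = max_samples
        self.nb_samples = 0

    async def __call__(self, sample, trial_parameters, model_registry) -> bool:
        self.nb_samples += 1
        # ... Train the model on `sample` here.
        # Returning False tells the BatchTrainer to stop calling this callback.
        return self.nb_samples < self.max_samples
```

For example, inside an async function: trainer = await runner.run_simple_training(batch, LimitedSampler(10_000)).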