from typing import Any, Dict, Iterable, List, Optional, Tuple
from collections import OrderedDict
from functools import partial, reduce  # noqa: F401  (kept so existing importers of this name keep working)

import logging
import os
import time
import torch
try:
    import pyecharts
except ImportError:
    logging.warning("Please install pyecharts first, you can install it by running 'pip install pyecharts'")
    pyecharts = None

MegaByte = 1024 * 1024


class SimpleMemState:
    """
    Overview:
        A node of a memory-usage tree. Each node stores the memory of one layer (``layer_mem``) and the \
        accumulated memory of the node plus all of its descendants (``total_mem``).
    Properties:
        ``layer_mem``, ``total_mem``
    Interfaces:
        ``add``, ``delete``, ``update_total_memory``, ``find_layer_state``, ``dump``, ``to_json``
    """

    def __init__(self, layer_name: str, layer_mem: int = 0) -> None:
        """
        Overview:
            Initialize the memory state of a model/tensors with the specific name.
        Arguments:
            - layer_name (:obj:`str`): The name of the layer.
            - layer_mem (:obj:`int`, optional): The memory usage of the layer in bytes. Defaults to 0.
        """
        self.layer_name = layer_name

        # Memory status of the current model layer.
        self._layer_mem: int = layer_mem
        # Total memory status of the model and sub-models, initialized with layer memory.
        self._total_mem: int = self._layer_mem
        # SimpleMemState of sub-models, keyed by the sub-layer name (insertion ordered).
        self.sub_model_stats = OrderedDict()

    @property
    def layer_mem(self) -> int:
        """
        Overview:
            Get the memory usage of the layer.
        Returns:
            - layer_mem (:obj:`int`): The memory usage of the layer in bytes.
        """
        return self._layer_mem

    @layer_mem.setter
    def layer_mem(self, new_layer_mem: int) -> None:
        """
        Overview:
            Set the memory usage of the layer and shift this node's total memory by the same difference. \
            Ancestors are NOT updated here; call ``update_total_memory`` on the root for that.
        Arguments:
            - new_layer_mem (:obj:`int`): The new memory usage of the layer in bytes.
        """
        diff = new_layer_mem - self._layer_mem
        self._layer_mem = new_layer_mem
        self._total_mem += diff

    @property
    def total_mem(self) -> int:
        """
        Overview:
            Get the total memory usage of the model and sub-models.
        Returns:
            - total_mem (:obj:`int`): The total memory usage of the model and sub-models in bytes.
        """
        return self._total_mem

    def add(self, layer_name: str, layer_mem: int = 0, flush: bool = True) -> None:
        """
        Overview:
            Add a layer to the memory state, creating the intermediate nodes of the dotted path when needed.
        Arguments:
            - layer_name (:obj:`str`): The dotted name of the layer, relative to this node.
            - layer_mem (:obj:`int`, optional): The memory usage of the layer in bytes. Defaults to 0.
            - flush (:obj:`Optional[bool]`): Whether to update the total memory usage. Defaults to True.
        """
        target = self.find_layer_state(layer_name.split("."), create=True)
        target.layer_mem = layer_mem

        if flush:
            self.update_total_memory()

    def delete(self, layer_name: str, flush: bool = True) -> None:
        """
        Overview:
            Delete a layer from the memory state. Deleting a layer that does not exist is a no-op.
        Arguments:
            - layer_name (:obj:`str`): The dotted name of the layer, it must contain at least two components.
            - flush (:obj:`Optional[bool]`): Whether to update the total memory usage. Defaults to True.
        """
        path = layer_name.split(".")
        assert len(path) >= 2, f"Only support deleting non-root layers, layer_name: {layer_name}"

        *parent_path, layer = path
        parent = self.find_layer_state(parent_path)

        if parent is not None and layer in parent.sub_model_stats:
            del parent.sub_model_stats[layer]

        if flush:
            self.update_total_memory()

    def update_total_memory(self) -> None:
        """
        Overview:
            Recompute the total memory usage of this node and, recursively, of all sub-models.
        """
        self._total_mem = self._layer_mem

        for stat in self.sub_model_stats.values():
            # Update sub-model status first.
            stat.update_total_memory()
            # Add sub-model total_mem to model total_mem.
            self._total_mem += stat._total_mem

    def find_layer_state(self, path: Iterable[str], create: bool = False) -> Optional["SimpleMemState"]:
        """
        Overview:
            Find the memory state of a layer by walking ``path`` from this node.
        Arguments:
            - path (:obj:`Iterable[str]`): The path components to the layer. An empty path returns this node.
            - create (:obj:`Optional[bool]`): Whether to create the layer if it doesn't exist. Defaults to False.
        Returns:
            - state (:obj:`Optional[SimpleMemState]`): The memory state of the layer, or ``None`` when the \
                layer does not exist and ``create`` is False.
        """
        current_node = self

        for _node in path:
            if _node not in current_node.sub_model_stats:
                if not create:
                    return None
                # Create a layer node.
                current_node.sub_model_stats[_node] = SimpleMemState(_node)

            current_node = current_node.sub_model_stats[_node]

        return current_node

    def dump(self, prefix: str = "") -> str:
        """
        Overview:
            Dump the memory state of the model and sub-models, one line per node in depth-first order.
        Arguments:
            - prefix (:obj:`Optional[str]`): The prefix to add to the layer names. Defaults to "".
        Returns:
            - result (:obj:`str`): The memory state information.
        """
        cur_prefix = prefix + "." + self.layer_name if prefix != "" else self.layer_name
        lines = [
            f"layer: {cur_prefix}, layer_mem: {self.layer_mem / MegaByte:.2f} MB, total_mem: {self.total_mem / MegaByte:.2f} MB\n"  # noqa
        ]
        lines.extend(sub_layer.dump(cur_prefix) for sub_layer in self.sub_model_stats.values())

        return "".join(lines)

    def to_json(self, base: int = 1024 * 1024) -> dict:
        """
        Overview:
            Convert the memory state to a JSON structure. Leaf nodes carry a ``value``; nodes with \
            sub-models carry ``children`` only.
        Arguments:
            - base (:obj:`Optional[int]`): The base value to convert the memory usage to. Defaults to 1024 * 1024, \
                which converts the memory usage to MB.
        Returns:
            - result (:obj:`dict`): The JSON structure of the memory state.
        """
        # ``base`` must be forwarded, otherwise nested layers are always reported with the default unit.
        children = [child.to_json(base) for child in self.sub_model_stats.values()]
        if not children:
            return {"name": self.layer_name, "value": self.layer_mem // base}
        return {"name": self.layer_name, "children": children}


class ActivationMemState:
    """
    Overview:
        A class to represent the memory state of activation tensors, one ``SimpleMemState`` tree per model chunk.
    Properties:
        ``total_mem``
    Interfaces:
        ``dump``, ``to_json``
    """

    def __init__(self, num_chunks: int) -> None:
        """
        Overview:
            Initialize the memory state of activation tensors.
        Arguments:
            - num_chunks (:obj:`int`): The number of chunks, multiple chunks are used in some large-scale models.
        """
        self._num_chunks = num_chunks

        # ``inited[i]`` becomes True once chunk ``i`` has been recorded by a complete forward pass.
        self.inited: List[bool] = [False for _ in range(num_chunks)]
        self.states: List[SimpleMemState] = [SimpleMemState(f"activations_{idx}") for idx in range(num_chunks)]

    @property
    def total_mem(self) -> int:
        """
        Overview:
            Get the total memory usage of the activation tensors.
        Returns:
            - total_mem (:obj:`int`): The total memory usage of the activation tensors in bytes.
        """
        return sum(state.total_mem for state in self.states)

    def dump(self, prefix: str = "") -> str:
        """
        Overview:
            Dump the memory state of the activation tensors.
        Arguments:
            - prefix (:obj:`Optional[str]`): The prefix to add to the layer names. Defaults to "".
        Returns:
            - result (:obj:`str`): The memory state information, an empty string when there is no chunk.
        """
        # ``str.join`` also handles the zero-chunk case, where ``reduce`` without an initial value raises.
        return "".join(state.dump(prefix) for state in self.states)

    def to_json(self, base: int = 1024 * 1024) -> List[dict]:
        """
        Overview:
            Convert the memory state to a JSON structure.
        Arguments:
            - base (:obj:`Optional[int]`): The base value to convert the memory usage to. Defaults to 1024 * 1024, \
                which converts the memory usage to MB.
        Returns:
            - result (:obj:`List[dict]`): The JSON structure of the memory state, one entry per chunk.
        """
        return [state.to_json(base) for state in self.states]


def _unpack_naive_wrapper(model: torch.nn.Module) -> Tuple[torch.nn.Module, int]:
    """
    Overview:
        Return the model together with its number of chunks. A ``torch.nn.ModuleList`` is treated as a list \
        of chunks, any other module counts as a single chunk.
    Arguments:
        - model (:obj:`torch.nn.Module`): The model (or ``ModuleList`` of model chunks).
    Returns:
        - result (:obj:`Tuple[torch.nn.Module, int]`): The unchanged model and the number of chunks.
    """
    num_chunks = len(model) if isinstance(model, torch.nn.ModuleList) else 1

    return model, num_chunks


class SimpleMemoryProfiler:
    """
    Overview:
        A memory profiler for a PyTorch neural network model. It estimates memory from tensor sizes \
        (``element_size() * nelement()``) and writes text records plus sunburst charts into a log folder.
    Interfaces:
        ``point``, ``step``
    """

    def __init__(
            self,
            model: torch.nn.Module,
            optimizer: torch.optim.Optimizer,
            log_folder: str,
            total_steps: int = 5,
    ):
        """
        Overview:
            Initialize the memory profiler, register the activation hooks and write the first memory record.
        Arguments:
            - model (:obj:`torch.nn.Module`): The model to profile.
            - optimizer (:obj:`torch.optim.Optimizer`): The optimizer used for training the model.
            - log_folder (:obj:`str`): The folder to write the memory state information to.
            - total_steps (:obj:`Optional[int]`): The number of steps to trace. Defaults to 5.
        """
        self._model, self._num_model_chunks = _unpack_naive_wrapper(model)
        self._optimizer = optimizer
        self._log_folder = log_folder
        self._remaining_steps = total_steps

        self._stoped = False
        self._record_start_time = time.time()

        # For activation memory state.
        self._activation_mem: int = 0
        self._activation_mem_max: int = 0
        self._activation_base_mems = ActivationMemState(self._num_model_chunks)

        # Check or create log folder
        os.makedirs(self._log_folder, exist_ok=True)

        # Register activation memory tracking hooks.
        # A ModuleList is never called itself, so hooks must go on its elements even when it holds one chunk.
        if isinstance(self._model, torch.nn.ModuleList):
            for chunk_id, model_chunk in enumerate(self._model):
                self._register_activation_trace_hooks(chunk_id, model_chunk)
        else:
            self._register_activation_trace_hooks(0, self._model)

        # Calculate static parameter cuda memory
        self._param_mem_state = SimpleMemState("param_mem")
        self._calc_tensor_memory(self._param_mem_state, self._model.named_parameters())
        # Calculate static grad cuda memory
        self._grad_mem_state = SimpleMemState("grad_mem")
        self._calc_tensor_memory(self._grad_mem_state, self._model.named_parameters(), True)
        # Calculate static optimizer state cuda memory
        self._os_params_mem_state = SimpleMemState("os_params_mem")
        self._os_state_mem_state = SimpleMemState("os_state_mem")
        self._calc_tensor_group_memory(self._os_params_mem_state, list(enumerate(self._optimizer.param_groups)))

        # Generate the first memory record
        self.point(with_options="params,grads,os_params", create=True)

    def point(self, with_options: str = "", create: bool = False) -> None:
        """
        Overview:
            Record the memory state of the model and optimizer at current point into ``memory.log``.
        Arguments:
            - with_options (:obj:`Optional[str]`): Comma separated layouts to include in the record, chosen from \
                ``params``, ``grads``, ``os_params``, ``os_state`` and ``activation_base``, or ``all``. \
                Defaults to "".
            - create (:obj:`Optional[bool]`): Whether to create a new memory record file (truncate) instead of \
                appending. Defaults to False.
        """
        now = time.time()
        file = f"{self._log_folder}/memory.log"

        if with_options == "all":
            options = ["params", "grads", "os_params", "os_state", "activation_base"]
        else:
            # Tolerate whitespace around the separators, e.g. "params, grads".
            options = [option.strip() for option in with_options.split(",")]

        total_mem = (
            self._param_mem_state.total_mem + self._grad_mem_state.total_mem + self._os_params_mem_state.total_mem +
            self._os_state_mem_state.total_mem + self._activation_mem
        ) / MegaByte

        # Generate summary information for memory state
        summary_info = (
            f"total_memory: {total_mem:.2f} MB" + "\n" +
            f"params_memory: {self._param_mem_state.total_mem / MegaByte:.2f} MB, " +
            f"grads_memory: {self._grad_mem_state.total_mem / MegaByte:.2f} MB, " +
            f"os_params_memory: {self._os_params_mem_state.total_mem / MegaByte:.2f} MB, " +
            f"os_state_memory: {self._os_state_mem_state.total_mem / MegaByte:.2f} MB, " +
            f"activation_memory: {self._activation_mem / MegaByte:.2f} MB"
        )

        # Generate layout information based on selected options
        layout_info = ""
        if "params" in options:
            layout_info += "params_layout:\n" + self._param_mem_state.dump()
        if "grads" in options:
            layout_info += "grads_layout:\n" + self._grad_mem_state.dump()
        if "os_params" in options:
            layout_info += "os_params_layout:\n" + self._os_params_mem_state.dump()
        if "os_state" in options:
            layout_info += "os_state_layout:\n" + self._os_state_mem_state.dump()
        if "activation_base" in options:
            layout_info += "activation_base_layout:\n" + self._activation_base_mems.dump()

        # Write memory state information to log file
        file_mode = "w" if create else "a"
        with open(file, file_mode, encoding="utf-8") as writer:
            writer.write(
                "Memory State:\n" + f"time: {now - self._record_start_time}\n" + "---summary---\n" + summary_info + "\n"
            )
            if layout_info != "":
                writer.write("---Layout---\n" + layout_info)
            writer.write("\n")

    def step(self) -> None:
        """
        Overview:
            Update the memory state of the optimizer state (e.g., momentum, learning rate) and record the memory \
            state. On the last traced step the full layout is dumped and the sunburst charts are rendered; \
            afterwards the profiler is stopped and further calls do nothing.
        """
        if self._stoped:
            return

        self._remaining_steps -= 1
        # ``<=`` so that a non-positive ``total_steps`` still terminates on the first call.
        if self._remaining_steps <= 0:
            self._stoped = True

        # Update os state memory usage
        self._os_state_mem_state = SimpleMemState("os_state_mem")
        self._calc_tensor_group_memory(self._os_state_mem_state, list(self._optimizer.state_dict()["state"].items()))

        if not self._stoped:
            # Do we need to print os_state_layout every time? Is it always constant?
            self.point(with_options="os_state")
            return

        # Dump memory layout
        self.point(with_options="all")
        # Generate sunburst charts. A state without sub-layers has no "children" key, fall back to an empty chart.
        self._render_sunburst_chart(self._param_mem_state.to_json().get("children", []), "params_memory_sunburst")
        self._render_sunburst_chart(self._grad_mem_state.to_json().get("children", []), "grads_memory_sunburst")
        self._render_sunburst_chart(
            [self._os_params_mem_state.to_json(),
             self._os_state_mem_state.to_json()],
            "os_memory_sunburst",
        )
        self._render_sunburst_chart(self._activation_base_mems.to_json(), "activation_memory_sunburst")
        # Generate summary sunburst chart
        summary_sunburst_data = [
            {
                "name": "params",
                "value": self._param_mem_state.total_mem // MegaByte
            },
            {
                "name": "grads",
                "value": self._grad_mem_state.total_mem // MegaByte
            },
            {
                "name": "os_params",
                "value": self._os_params_mem_state.total_mem // MegaByte
            },
            {
                "name": "os_state",
                "value": self._os_state_mem_state.total_mem // MegaByte
            },
            {
                "name": "activation",
                "value": self._activation_mem_max // MegaByte
            },
        ]

        self._render_sunburst_chart(summary_sunburst_data, "summary_sunburst")

    def _render_sunburst_chart(self, data: Any, name: str) -> None:
        """
        Overview:
            Render a sunburst chart for the memory state with pyecharts into ``<log_folder>/<name>.html``. \
            When pyecharts is not installed the chart is skipped with a warning.
        Arguments:
            - data (:obj:`Any`): The data to render.
            - name (:obj:`str`): The name of the chart.
        """
        if pyecharts is None:
            # The import fallback at the top of the file sets pyecharts to None; do not crash the training loop.
            logging.warning("pyecharts is not installed, skip rendering sunburst chart '%s'", name)
            return

        left_label = {"align": "left"}
        levels = [
            {},
            {
                "r0": "10%",
                "r": "35%",
                "itemStyle": {
                    "borderWidth": 3
                },
                "label": dict(left_label),
            },
            {
                "r0": "35%",
                "r": "55%",
                "label": dict(left_label)
            },
            {
                "r0": "55%",
                "r": "70%",
                "label": dict(left_label)
            },
            {
                "r0": "70%",
                "r": "80%",
                "label": dict(left_label)
            },
            {
                "r0": "80%",
                "r": "90%",
                "label": dict(left_label)
            },
            {
                "r0": "90%",
                "r": "92%",
                "label": {
                    "position": "outside",
                    "padding": 3,
                    "silent": False
                },
                "itemStyle": {
                    "borderWidth": 3
                },
            },
        ]

        chart = pyecharts.charts.Sunburst(init_opts=pyecharts.options.InitOpts(width="1000px", height="1000px"))
        chart.add(
            name,
            data_pair=data,
            highlight_policy="ancestor",
            radius=[0, "95%"],
            levels=levels,
        ).set_global_opts(title_opts=pyecharts.options.TitleOpts(title="CUDA Memory")
                          ).set_series_opts(label_opts=pyecharts.options.LabelOpts(formatter="{b}")
                                            ).render(f"{self._log_folder}/{name}.html")

    def _inner_activation_trace_hook(
            self,
            chunk_id: int,
            layer_name: str,
            model: Any,
            inputs: Any,
            output: torch.Tensor,
    ) -> None:
        """
        Overview:
            Hook function to trace the activation memory usage for a inner layer. The output size is recorded \
            only until the first forward pass of the chunk has completed.

        .. note::
            For more details about hook mechanism, please refer to the PyTorch documentation.

        Arguments:
            - chunk_id (:obj:`int`): The model chunk id.
            - layer_name (:obj:`str`): The name of the layer.
            - model (:obj:`Any`): The model to trace.
            - inputs (:obj:`Any`): The inputs to the layer.
            - output (:obj:`torch.Tensor`): The output tensor.
        """
        del model, inputs
        assert isinstance(output, torch.Tensor), f"Invalid output type: {type(output)}"

        if self._stoped or self._activation_base_mems.inited[chunk_id]:
            return

        # Delay updating the total_mem of activation_base_mem here, it will be handled in the forward ending hook.
        self._activation_base_mems.states[chunk_id].add(
            layer_name, output.element_size() * output.nelement(), flush=False
        )

    def _activation_trace_hook_forward(self, chunk_id: int, model: Any, inputs: Any, output: Any) -> None:
        """
        Overview:
            Hook function to trace the activation memory usage for a forward pass of a whole model chunk.

        .. note::
            For more details about hook mechanism, please refer to the PyTorch documentation.

        Arguments:
            - chunk_id (:obj:`int`): The model chunk id.
            - model (:obj:`Any`): The model to trace.
            - inputs (:obj:`Any`): The inputs to the model.
            - output (:obj:`Any`): The output of the model.
        """
        del model, inputs, output

        if self._stoped:
            return

        # Check if the activation memory has been initialized
        if not self._activation_base_mems.inited[chunk_id]:
            self._activation_base_mems.inited[chunk_id] = True
            # Update the total memory of the activation base memory state
            self._activation_base_mems.states[chunk_id].update_total_memory()
            # Set with_options to "activation_base" to include activation_base_layout in the memory dump
            with_options = "activation_base"
        else:
            with_options = ""

        # Accumulate activation memory usage for each forward pass
        self._activation_mem += self._activation_base_mems.states[chunk_id].total_mem
        self._activation_mem_max = max(self._activation_mem_max, self._activation_mem)

        # Trigger a memory record
        self.point(with_options)

    def _activation_trace_hook_backward(self, chunk_id: int, model: Any, inputs: Any, grad_outputs: Any) -> None:
        """
        Overview:
            Hook function to trace the activation memory usage for a backward pass of a whole model chunk.

        .. note::
            For more details about hook mechanism, please refer to the PyTorch documentation.

        Arguments:
            - chunk_id (:obj:`int`): The model chunk id.
            - model (:obj:`Any`): The model to trace.
            - inputs (:obj:`Any`): The inputs to the model.
            - grad_outputs (:obj:`Any`): The gradients of the outputs.
        """
        del model, inputs, grad_outputs

        if self._stoped:
            return

        # Release activation memory usage for each backward pass
        self._activation_mem -= self._activation_base_mems.states[chunk_id].total_mem

        # Trigger a memory record
        self.point()

    # Backward-compatible alias for the historical (misspelled) method name.
    _activation_tarce_hook_backward = _activation_trace_hook_backward

    def _register_activation_trace_hooks(self, chunk_id: int, model_chunk: torch.nn.Module) -> None:
        """
        Overview:
            Register activation trace hooks for the model and each leaf submodule in the model.
        Arguments:
            - chunk_id (:obj:`int`): The model chunk id.
            - model_chunk (:obj:`torch.nn.Module`): The model chunk to trace.
        """

        # Register inner activation trace hooks for each submodule in the model
        for layer_name, sub_model in model_chunk.named_modules():
            # Only leaf modules (without registered children) are traced.
            if sub_model._modules:
                continue  # TODO: in some special cases, we may need some additional configuration to correct

            sub_model.register_forward_hook(partial(self._inner_activation_trace_hook, chunk_id, layer_name))

        # Register a forward hook for the main model to track activation memory usage
        model_chunk.register_forward_hook(partial(self._activation_trace_hook_forward, chunk_id))
        # Register a backward hook for the main model to release activation memory usage
        model_chunk.register_full_backward_hook(partial(self._activation_trace_hook_backward, chunk_id))

    def _calc_tensor_memory(
            self,
            root_stat: SimpleMemState,
            named_tensors: Iterable[Tuple[str, torch.Tensor]],
            require_grad: bool = False
    ) -> None:
        """
        Overview:
            Core function to calculate the memory usage of tensors and update the memory state.
        Arguments:
            - root_stat (:obj:`SimpleMemState`): The root memory state.
            - named_tensors (:obj:`Iterable[Tuple[str, torch.Tensor]]`): Pairs of dotted tensor name and tensor, \
                e.g. the result of ``named_parameters()``.
            - require_grad (:obj:`Optional[bool]`): Whether to only count tensors with ``requires_grad`` set. \
                Defaults to False.
        """
        for name, tensor in named_tensors:
            if require_grad and not tensor.requires_grad:
                continue

            layer_stat = root_stat.find_layer_state(name.split(sep="."), create=True)
            layer_stat.layer_mem = tensor.element_size() * tensor.nelement()

        root_stat.update_total_memory()

    def _calc_tensor_group_memory(self, root_stat: SimpleMemState, tensor_groups: List[Tuple[Any, Dict[str, Any]]]):
        """
        Overview:
            Core function to calculate the memory usage of a group of tensors and update the memory state.
        Arguments:
            - root_stat (:obj:`SimpleMemState`): The root memory state.
            - tensor_groups (:obj:`List[Tuple[Any, Dict[str, Any]]]`): Pairs of group id and a dict of that \
                group's entries. Entries that are not a tensor or a list/tuple/dict of tensors are ignored.
        """

        def _normalize_helper(named_tensors: Dict[str, Any]) -> List[Tuple[str, Any]]:
            # Flatten list/tuple/dict values into "<name>.<index or key>" entries.
            res = {}

            for name, tensors in named_tensors.items():
                if isinstance(tensors, torch.Tensor):
                    res[name] = tensors
                elif isinstance(tensors, (list, tuple)):
                    for index, tensor in enumerate(tensors):
                        res[f"{name}.{index}"] = tensor
                elif isinstance(tensors, dict):
                    for subname, tensor in tensors.items():
                        res[f"{name}.{subname}"] = tensor
                else:
                    raise TypeError(f"unsupported normalize value type: {type(tensors)}")

            return list(res.items())

        def _value_check(tensor_or_tensors) -> bool:
            # Accept a tensor, or a list/tuple/dict whose values are all tensors.
            if torch.is_tensor(tensor_or_tensors):
                return True
            if isinstance(tensor_or_tensors, (list, tuple)):
                return all(torch.is_tensor(x) for x in tensor_or_tensors)
            if isinstance(tensor_or_tensors, dict):
                return all(torch.is_tensor(x) for x in tensor_or_tensors.values())
            return False

        # Calculate the memory usage of a group of tensors.
        for idx, tensors in tensor_groups:
            # Normalize the named tensors
            named_tensors = {f"{idx}.{k}": v for k, v in tensors.items() if _value_check(v)}
            # Calculate the memory usage of the tensors and update the memory state
            self._calc_tensor_memory(root_stat, _normalize_helper(named_tensors))


def get_current_device() -> torch.device:
    """
    Overview:
        Get the current PyTorch tensor device.
    Returns:
        - device (:obj:`torch.device`): ``cuda`` when CUDA is available, otherwise ``cpu``.
    """
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def multi_chunk_test():
    """
    Overview:
        A test function to demonstrate the memory profiler for a model with multiple chunks. It runs two \
        forward passes, two backward passes and one optimizer step, then lets the profiler dump its records \
        into ``./test_simple_memory_profiler_multi_chunk``.
    """

    class SimpleModel(torch.nn.Module):

        def __init__(self, skip_layer2: bool = False):
            super().__init__()
            self.layer1 = torch.nn.Linear(5120, 5120, True)
            self.layer3 = torch.nn.Linear(5120, 5120, False)

            if skip_layer2:
                self.layer2 = None
            else:
                self.layer2 = SimpleModel(skip_layer2=True)

        def forward(self, inputs: torch.Tensor) -> torch.Tensor:
            output1 = self.layer1(inputs)
            if self.layer2 is not None:
                output2 = self.layer2(output1)
            else:
                output2 = output1
            output = self.layer3(output2)

            return output

    def _simple_schedule(_num_chunks, _model_chunks, _input) -> torch.Tensor:
        # Run the chunks sequentially, feeding each chunk's output into the next one.
        if _num_chunks > 1:
            _output = _input
            for _model_chunk in _model_chunks:
                _output = _model_chunk(_output)
        else:
            _output = _model_chunks(_input)

        return _output

    # num_chunks config
    _num_chunks = 1

    # init model and optimizer
    if _num_chunks > 1:
        _chunks = [SimpleModel(skip_layer2=idx % 2 == 0) for idx in range(_num_chunks)]
        _model = torch.nn.ModuleList(_chunks).to(get_current_device())
    else:
        _model: torch.nn.Module = SimpleModel().to(get_current_device())
    _optimizer = torch.optim.Adam(_model.parameters())

    # init profiler
    profiler = SimpleMemoryProfiler(_model, _optimizer, "./test_simple_memory_profiler_multi_chunk", total_steps=1)

    _optimizer.zero_grad()

    # inputs
    x1 = torch.randn((128, 5120)).to(get_current_device())
    x2 = torch.randn((128, 5120)).to(get_current_device())
    # forward
    out1 = _simple_schedule(_num_chunks, _model, x1)
    out2 = _simple_schedule(_num_chunks, _model, x2)
    # backward
    out1.mean().backward()
    out2.mean().backward()

    _optimizer.step()

    # Update the optimizer state memory usage and record the memory state
    profiler.step()


if __name__ == "__main__":
    multi_chunk_test()