TPU 監控程式庫

透過進階 TPU 監控功能,深入瞭解 Cloud TPU 硬體的效能和行為,這些功能直接建構在基礎軟體層 LibTPU 上。LibTPU 包含驅動程式、網路程式庫、XLA 編譯器和 TPU 執行階段,可與 TPU 互動,但本文重點是 TPU 監控程式庫。

TPU 監控程式庫提供以下功能:

  • 全面觀測能力:存取遙測 API 和指標套件,深入瞭解 TPU 的運作效能和特定行為。

  • 診斷工具包:提供 SDK 和指令列介面 (CLI),可對 TPU 資源進行偵錯及深入分析效能。

這些監控功能是專為客戶設計的頂層解決方案,可提供必要工具,協助您有效最佳化 TPU 工作負載。

TPU 監控程式庫會提供詳細資訊,說明機器學習工作負載在 TPU 硬體上的執行情況,協助您瞭解 TPU 使用率、找出瓶頸,以及排解效能問題。這項程式庫提供的資訊比中斷指標、有效輸送量指標和其他指標更詳細。

開始使用 TPU 監控程式庫

存取這些強大的洞察資料非常簡單。TPU 監控功能已整合至 LibTPU SDK,因此安裝 LibTPU 時,系統會一併安裝這項功能。

安裝 LibTPU

pip install libtpu

此外,LibTPU 更新會與 JAX 版本同步,也就是說,安裝最新 JAX 版本 (每月發布) 時,通常會將您固定在最新相容的 LibTPU 版本及其功能。

安裝 JAX

pip install -U "jax[tpu]"

PyTorch 使用者安裝 PyTorch/XLA 後,即可取得最新的 LibTPU 和 TPU 監控功能。

安裝 PyTorch/XLA

pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
  -f https://storage.googleapis.com/libtpu-releases/index.html \
  -f https://storage.googleapis.com/libtpu-wheels/index.html

  # Optional: if you're using custom kernels, install pallas dependencies
pip install 'torch_xla[pallas]' \
  -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
  -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

如要進一步瞭解如何安裝 PyTorch/XLA,請參閱 PyTorch/XLA GitHub 存放區中的「安裝」。

在 Python 中匯入程式庫

如要開始使用 TPU 監控程式庫,您需要在 Python 程式碼中匯入 libtpu 模組。

from libtpu.sdk import tpumonitoring

列出所有支援的功能

列出所有指標名稱和支援的功能:


from libtpu.sdk import tpumonitoring

tpumonitoring.help()
" libtpu.sdk.monitoring.help():
      List all supported functionality.

  libtpu.sdk.monitoring.list_support_metrics()
      List support metric names in the list of str format.

  libtpu.sdk.monitoring.get_metric(metric_name:str)
      Get metric data with metric name. It represents the snapshot mode.
      The metric data is a object with `description()` and `data()` methods,
      where the `description()` returns a string describe the format of data
      and data unit, `data()` returns the metric data in the list in str format.
"

支援的指標

下列程式碼範例說明如何列出所有支援的指標名稱:

from libtpu.sdk import tpumonitoring

tpumonitoring.list_supported_metrics()

["duty_cycle_pct", "tensorcore_util", "hbm_util", ...]

下表列出所有指標及其對應定義:

指標 定義 API 的指標名稱 範例值
Tensor Core 使用率 測量 TensorCore 的用量百分比,計算方式為 TensorCore 作業的占比。每 1 秒取樣 10 微秒。取樣率無法修改。 您可以透過這項指標,監控 TPU 裝置中的工作負載效率。 tensorcore_util ['1.11', '2.22', '3.33', '4.44']

# utilization percentage for accelerator ID 0-3.
任務週期百分比 過往取樣期間 (每 5 秒;可透過設定 LIBTPU_INIT_ARG 標記調整) 內,加速器積極處理作業的時間占比 (按照上一個取樣期間內執行 HLO 程式的週期記錄)。這項指標呈現 TPU 的忙碌程度。這項指標會針對每個晶片發出。 duty_cycle_pct ['10.00', '20.00', '30.00', '40.00']

# Duty cycle percentage for accelerator ID 0-3.
HBM 容量總計 這項指標會以位元組為單位回報 HBM 總容量。 hbm_capacity_total ['30000000000', '30000000000', '30000000000', '30000000000']

# Total HBM capacity in bytes that attached to accelerator ID 0-3.
HBM 容量用量 這項指標會回報過去取樣期間 (每 5 秒;可透過設定 LIBTPU_INIT_ARG 旗標調整) 的 HBM 容量用量 (以位元組為單位)。 hbm_capacity_usage ['100', '200', '300', '400']

# Capacity usage for HBM in bytes that attached to accelerator ID 0-3.
緩衝區傳輸延遲 巨量多切片流量的網路傳輸延遲時間。這項視覺化功能可協助您瞭解整體網路效能環境。 buffer_transfer_latency ["'8MB+', '2233.25', '2182.02', '3761.93', '19277.01', '53553.6'"]

# buffer size, mean, p50, p90, p99, p99.9 of network transfer latency distribution.
傳入緩衝區轉移延遲 測量透過網路接收資料緩衝區所需的時間,以微秒為單位。這項指標有助於偵測網路效能問題和潛在的網路硬體故障。 inbound_buffer_transfer_latency ["'8MB+', '4723.59', '4500.22', '6210.11', '7100.45', '8510.99'"]

# buffer size, mean, p50, p90, p99, p99.9 of inbound DCN network traffic transfer latency distribution.
主機運算延遲 以微秒為單位,測量計算縮減作業所需的時間。這項指標有助於偵測效能問題,以及參與分散式運算的個別主機上潛在的 CPU 或記憶體硬體故障。 host_compute_latency ["'8MB+', '2408.17', '2105.40', '3600.22', '4800.15', '15000.80'"]

# buffer size, mean, p50, p90, p99, p99.9 of host compute latency distribution.
高階作業執行時間分布指標 提供 HLO 編譯二進位檔執行狀態的精細效能洞察資料,可偵測迴歸並進行模型層級的偵錯。 hlo_exec_timing ["'tensorcore-0', '10.00', '10.00', '20.00', '30.00', '40.00'"]

# The HLO execution time duration distribution for CoreType-CoreID with mean, p50, p90, p95, p999.
高階最佳化工具佇列大小 監控 HLO 執行佇列大小,可追蹤等待或正在執行的已編譯 HLO 程式數量。這項指標會顯示執行管道壅塞情形,有助於找出硬體執行、驅動程式負荷或資源分配方面的效能瓶頸。 hlo_queue_size ["tensorcore-0: 1", "tensorcore-1: 2"]

# Measures queue size for CoreType-CoreID.
集體端對端延遲時間 這項指標會測量 DCN 的端對端集體延遲時間 (以微秒為單位),從主機啟動作業到所有對等互連接收輸出內容。這項指標包含主機端資料縮減,以及將輸出內容傳送至 TPU。結果是詳細說明緩衝區空間、類型,以及平均、第 50、90、95 和 99.9 個百分位數延遲時間的字串。 collective_e2e_latency ["8MB+-ALL_REDUCE, 1000, 2000, 3000, 4000, 5000", …]

# Transfer size-collective op, mean, p50, p90, p95, p999 of collective end to end latency.
傳輸層的封包往返延遲時間 gRPC 用於多重切片 TPU 流量的 TCP 連線所觀察到的最短往返時間 (RTT) 分佈。 grpc_tcp_min_round_trip_times ['27.63, 29.03, 38.52, 41.63, 52.74']

# Represents the distribution's mean, p50, p90, p95, and p99.9 percentiles in microseconds (µs).
傳輸層的處理量 gRPC 用於多節點 TPU 流量的 TCP 連線近期總處理量累積分布。 grpc_tcp_delivery_rates ['11354.89, 10986.35, 18239.55, 25718.35, 44841.55']

# Represents the distribution's mean, p50, p90, p95, and p99.9 percentiles in microseconds (µs).

讀取指標資料

如要讀取指標資料,請在呼叫 tpumonitoring.get_metric 函式時指定指標名稱。您可以將臨時指標檢查插入效能不佳的程式碼,判斷效能問題是源自軟體還是硬體。

以下程式碼範例說明如何讀取 duty_cycle 指標:

from libtpu.sdk import tpumonitoring

metric = tpumonitoring.get_metric("duty_cycle_pct")

metric.description()
"The metric provides a list of duty cycle percentages, one for each
accelerator (from accelerator_0 to accelerator_x). The duty cycle represents
the percentage of time an accelerator was actively processing during the
last sample period, indicating TPU utilization."

metric.data()
["0.00", "0.00", "0.00", "0.00"]

# accelerator_0-3

使用指標檢查 TPU 使用率

下列範例說明如何使用 TPU 監控程式庫中的指標,追蹤 TPU 使用率。

在 JAX 訓練期間監控 TPU 任務週期

情境:您正在執行 JAX 訓練指令碼,並想在整個訓練過程中監控 TPU 的 duty_cycle_pct 指標,確認 TPU 獲得有效運用。您可以在訓練期間定期記錄這項指標,追蹤 TPU 使用率。

下列程式碼範例說明如何在 JAX 訓練期間監控 TPU 負載週期:

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring
import time

 # --- Your JAX model and training setup would go here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)

def loss_fn(params, x, y):
    preds = simple_model(x)
    return jnp.mean((preds - y)**2)

def train_step(params, x, y, optimizer):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(grads, params)

key = jax.random.PRNGKey(0)
params = jnp.array([1.0, 2.0]) # Example params
optimizer = ... # Your optimizer (for example, optax.adam)
data_x = jnp.ones((10, 10))
data_y = jnp.zeros((10,))

num_epochs = 10
log_interval_steps = 2  # Log duty cycle every 2 steps

for epoch in range(num_epochs):
    for step in range(5): # Example steps per epoch

        params = train_step(params, data_x, data_y, optimizer)

        if (step + 1) % log_interval_steps == 0:
            # --- Integrate TPU Monitoring Library here to get duty_cycle ---
            duty_cycle_metric = tpumonitoring.get_metric("duty_cycle_pct")
            duty_cycle_data = duty_cycle_metric.data
            print(f"Epoch {epoch+1}, Step {step+1}: TPU Duty Cycle Data:")
            print(f"  Description: {duty_cycle_metric.description}")
            print(f"  Data: {duty_cycle_data}")
            # --- End TPU Monitoring Library Integration ---

        # --- Rest of your training loop logic ---
        time.sleep(0.1) # Simulate some computation

print("Training complete.")

執行 JAX 推論前,請先檢查 HBM 使用率

情境: 使用 JAX 模型執行推論前,請先檢查 TPU 的 HBM (高頻寬記憶體) 使用率,確認有足夠的可用記憶體,並在推論開始前取得基準測量結果。

# The following code sample shows how to check HBM utilization before JAX inference:
import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

  # --- Your JAX model and inference setup would go here ---
  # --- Example placeholder model (replace with your actual model loading/setup)---
def simple_model(x):
    return jnp.sum(x)

key = jax.random.PRNGKey(0)
params = ... # Load your trained parameters

  # Integrate the TPU Monitoring Library to get HBM utilization before inference
hbm_util_metric = tpumonitoring.get_metric("hbm_util")
hbm_util_data = hbm_util_metric.data
print("HBM Utilization Before Inference:")
print(f"  Description: {hbm_util_metric.description}")
print(f"  Data: {hbm_util_data}")
  # End TPU Monitoring Library Integration

  # Your Inference Logic
input_data = jnp.ones((1, 10)) # Example input
predictions = simple_model(input_data)
print("Inference Predictions:", predictions)

print("Inference complete.")

查看網路指標

情境: 您正在執行多主機和多配量工作負載,並想使用 SSH 連線至其中一個 GKE Pod 或 TPU,以便在工作負載執行期間查看網路指標。這些指令也可以直接併入多主機工作負載。

import jax
import jax.numpy as jnp
from libtpu.sdk import tpumonitoring

 # --- Your JAX model and training setup goes here ---
 #  --- Example placeholder model and data (replace with your actual setup)---
def simple_model(x):
    return jnp.sum(x)
     # --- details here ---

    # ==============================================================================
    # Metric 1: TCP Delivery Rate
    # ==============================================================================
    # This metric reflects the delivery rate of the TCP connection (bytes delivered / elapsed time).
    # The output is a list of strings representing latency statistics:
    # [mean, p50, p90, p95, p99.9]
    # Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in Mbps)

        # Get the metric object
        delivery_rate_metric = tpumonitoring.get_metric(metric_name_rate)

        # Print the description provided by the library
        print("Description:", delivery_rate_metric.description())

        # Print the actual data payload
        print("Data:", delivery_rate_metric.data())

    # ==============================================================================
    # Metric 2: TCP Minimum Round Trip Time (RTT)
    # ==============================================================================
    # This metric reflects the minimum RTT measured between sending a TCP packet
    # and receiving the acknowledgement.
    # The output is a list of strings representing latency statistics:
    # [mean, p50, p90, p95, p99.9]
    # Example: ['100.00', '200.00', '300.00', '400.00', '500.00'] (Values in us - microseconds)

        # Get the metric object
        min_rtt_metric = tpumonitoring.get_metric(metric_name_rtt)

        # Print the description provided by the library
        print("Description:", min_rtt_metric.description())

        # Print the actual data payload
        print("Data:", min_rtt_metric.data())

TPU 指標的重新整理頻率

TPU 指標的重新整理頻率最低為一秒。主機指標資料會以 1 Hz 的固定頻率匯出。這個匯出程序造成的延遲時間可忽略不計。LibTPU 的執行階段指標不受相同頻率限制。不過,為確保一致性,這些指標也會以 1 Hz 的頻率取樣,也就是每秒取樣一次。

TPU-Z 模組

TPU-Z 是 TPU 的遙測和偵錯設施,可提供連結至主機的所有 TPU 核心詳細執行階段狀態資訊。這項功能透過 tpuz 模組提供,該模組是 libtpu Python SDK 中 libtpu.sdk 模組的一部分。這個模組會提供每個核心狀態的快照。

TPU-Z 的主要用途是診斷分散式 TPU 工作負載中的停止或死結。您可以在主機上查詢 TPU-Z 服務,擷取每個核心的狀態,比較所有核心的程式計數器、HLO 位置和執行 ID,找出異常狀況。

libtpu.sdk 程式庫中使用 get_core_state_summary() 函式,即可顯示 TPU-Z 指標:

summary = sdk.tpuz.get_core_state_summary()

TPU-Z 指標的輸出內容會以字典形式提供。以下是單一核心的截斷範例:

{
  "host_name": "my-tpu-host-vm",
  "core_states": {
    "1": {
      "core_id": {
        "global_core_id": 1,
        "chip_id": 0,
        "core_on_chip": {
          "type": "TPU_CORE_TYPE_TENSOR_CORE",
          "index": 1
        }
      },
      "sequencer_info": [
        {
          "sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
          "sequencer_index": 0,
          "pc": 4490,
          "program_id": 3274167277388825310,
          "run_id": 3
        }
      ],
      "program_fingerprint": "b'\\xbefB\\xc6\\x1eb\\xc1#\\xd0...'",
      "queued_program_info": [],
      "error_message": ""
    }
    // ...
  }
}

如要擷取每個核心上高階最佳化工具 (HLO) 的相關資訊,請將 include_hlo_info 參數設為 True

summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)

輸出內容包含其他 HLO 資訊:

"1": {
    "core_id": {
      "global_core_id": 1,
      "chip_id": 0,
      "core_on_chip": {
        "type": "TPU_CORE_TYPE_TENSOR_CORE",
        "index": 1
      }
    },
    "sequencer_info": [
      {
        "sequencer_type": "TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER",
        "sequencer_index": 0,
        "pc": 17776,
        "tag": 3,
        "tracemark": 2147483646,
        "program_id": 3230481660274331500,
        "run_id": 81,
        "hlo_location": "HLO: fusion.11; HLO computation: main.126_spmd",
        "hlo_detailed_info": "[{\"details\":\"HloModule fusion.11, entry_computation_layout={(bf16[20>..."
      }
    ],
    "program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde\uff>",
    "launch_id": 1394130914,
    "queued_program_info": [
      {
        "run_id": 81,
        "launch_id": 1394130914,
        "program_fingerprint": "\ufffdU\ufffd4j\u7c6e\ufffd\ufffd{\u0017\ufffd\ufffdHHV\ufffdD\ufffde>"
      }
    ]
  }

TPU-Z 指標

get_core_state_summary 函式會以字典形式傳回 TPU-Z 指標,結構如下。

CurrentCoreStateSummary

CurrentCoreStateSummary 字典會提供個別 TPU 核心狀態的詳細摘要。

欄位 類型 說明
core_id dictionary TpuCoreIdentifier 字典,其中包含 TPU 核心的 ID 資訊。
sequencer_info 字典清單 SequencerInfo 字典清單,說明核心上每個定序器的狀態。
program_fingerprint 位元組 在這個核心上執行的程式指紋。
launch_id 整數 目前或最近一次執行的程式啟動 ID。
queued_program_info 字典清單 已排入執行佇列的程式字典清單。QueuedProgramInfo
error_message 字串 這個核心的任何錯誤訊息。

TpuCoreIdentifier

TpuCoreIdentifier 字典提供 TPU 系統中核心的 ID 資訊。

欄位 類型 說明
global_core_id 整數 核心的 ID。
chip_id 整數 核心所屬晶片的 ID。
core_on_chip dictionary TpuCoreOnChip 字典,說明核心的類型和在晶片上的索引。

TpuCoreOnChip

TpuCoreOnChip 字典包含特定晶片中核心屬性的相關資訊。

欄位 類型 說明
type 字串 TPU 核心的類型。例如:TPU_CORE_TYPE_TENSOR_CORE
index 整數 晶片上核心的索引。

SequencerInfo

SequencerInfo 字典包含核心上單一定序器的狀態資訊。

欄位 類型 說明
sequencer_type 字串 定序器類型。例如:TPU_SEQUENCER_TYPE_TENSOR_CORE_SEQUENCER
sequencer_index 整數 定序器的索引 (如有相同類型的多個定序器)。
pc 整數 目前的程式計數器值。
program_id 整數 與特定程式例項相關聯的 ID,該例項會啟動,以便在 TPU 核心上執行。
run_id 整數 與在 TPU 核心上執行的特定程式執行個體相關聯的執行 ID。
hlo_location 字串 高階最佳化工具位置資訊。
hlo_detailed_info 字串 詳細的高階最佳化工具資訊。

QueuedProgramInfo

QueuedProgramInfo 字典包含已排入佇列,準備在核心上執行的程式相關資訊。

欄位 類型 說明
run_id 整數 已加入佇列的程式執行 ID。
launch_id 整數 已加入佇列節目的啟動 ID。
program_fingerprint 位元組 已加入佇列節目的指紋。

搭配 JAX 使用 TPU-Z

您可以在 JAX 工作負載中透過 libtpu.sdk 程式庫存取 TPU-Z 指標。下列 Python 指令碼使用 JAX 進行高效能張量運算,同時在背景執行緒中使用 libtpu SDK 監控基礎 TPU 硬體的狀態和活動。

包括下列 Python 套件:

import jax
import jax.numpy as jnp
import time
import threading
from functools import partial
from libtpu import sdk

monitor_tpu_status 函式會使用背景執行緒,在主要應用程式執行 JAX 工作負載時,持續顯示 TPU 核心的運作狀態,做為即時診斷工具。

def monitor_tpu_status():
  """Monitors TPU status in a background thread."""

  while monitoring_active:
    try:
      summary = sdk.tpuz.get_core_state_summary(include_hlo_info=True)
      if summary and 'core_states' in summary:
        print(summary)
      else:
        print('WARNING: Call returned an empty or invalid summary.')
    except RuntimeError as e:
      print(f'FAIL: Error calling API: {e}')
    except Exception as e:
      print(f'FAIL: Unexpected error in monitor thread: {e}')

    for _ in range(MONITORING_INTERVAL_SECONDS * 2):
      if not monitoring_active:
        break
      time.sleep(0.5)
  print('✅ Monitoring thread stopped.')

transformer_block 函式會實作 Transformer 架構的完整層,這是 LLM 的基礎構成元素。

@partial(jax.jit, static_argnames=['num_heads'])
def transformer_block(params, x, num_heads=32):
  """A simplified but computationally intensive Transformer block."""
  # Multi-head Self-Attention
  qkv = jnp.dot(x, params['qkv_kernel'])
  q, k, v = jnp.array_split(qkv, 3, axis=-1)

  # Reshape for multi-head attention
  q = q.reshape(q.shape[0], q.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
  k = k.reshape(k.shape[0], k.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
  v = v.reshape(v.shape[0], v.shape[1], num_heads, -1).transpose(0, 2, 1, 3)

  # Scaled dot-product attention
  attention_scores = jnp.einsum('nhqd,nhkd->nhqk', q, k) / jnp.sqrt(q.shape[-1])
  attention_weights = jax.nn.softmax(attention_scores, axis=-1)
  attention_output = jnp.einsum('nhqk,nhvd->nhqd', attention_weights, v)
  attention_output = attention_output.transpose(0, 2, 1, 3).reshape(x.shape)

  attention_output = jnp.dot(attention_output, params['o_kernel'])

  # Residual connection and Layer Normalization 1
  h1 = x + attention_output
  h1_norm = h1 - jnp.mean(h1, axis=-1, keepdims=True)
  h1_norm = h1_norm / jnp.sqrt(
      jnp.mean(jnp.square(h1_norm), axis=-1, keepdims=True) + 1e-5
  )

  # Feed-Forward Network
  ffn_hidden = jax.nn.gelu(jnp.dot(h1_norm, params['ffn1_kernel']))
  ffn_output = jnp.dot(ffn_hidden, params['ffn2_kernel'])

  # Residual connection and Layer Normalization 2
  h2 = h1_norm + ffn_output
  h2_norm = h2 - jnp.mean(h2, axis=-1, keepdims=True)
  h2_norm = h2_norm / jnp.sqrt(
      jnp.mean(jnp.square(h2_norm), axis=-1, keepdims=True) + 1e-5
  )

  return h2_norm

main 函式會協調 JAX 運算設定、啟動背景 TPU 監控,並執行主要工作負載迴圈。

def main():
  num_devices = jax.device_count()
  print(f"Running on {num_devices} devices.")

  batch_size = 128 * num_devices
  seq_len = 512
  embed_dim = 1024
  ffn_dim = embed_dim * 4

  key = jax.random.PRNGKey(0)

  params = {
      'qkv_kernel': jax.random.normal(
          key, (embed_dim, embed_dim * 3), dtype=jnp.bfloat16
      ),
      'o_kernel': jax.random.normal(
          key, (embed_dim, embed_dim), dtype=jnp.bfloat16
      ),
      'ffn1_kernel': jax.random.normal(
          key, (embed_dim, ffn_dim), dtype=jnp.bfloat16
      ),
      'ffn2_kernel': jax.random.normal(
          key, (ffn_dim, embed_dim), dtype=jnp.bfloat16
      ),
  }
  input_data = jax.random.normal(
      key, (batch_size, seq_len, embed_dim), dtype=jnp.bfloat16
  )
  input_data = jax.device_put(input_data)
  monitor_thread = threading.Thread(target=monitor_tpu_status)
  monitor_thread.start()
  print("Starting JAX computation loop...")
  start_time = time.time()
  iterations = 0
  while time.time() - start_time < JOB_DURATION_SECONDS:
    result = transformer_block(params, input_data)
    result.block_until_ready()
    iterations += 1
    print(f'  -> Jax iteration {iterations} complete.', end='\r')

  print(f"\nCompleted {iterations} iterations in {JOB_DURATION_SECONDS} seconds.")

  global monitoring_active
  monitoring_active = False
  monitor_thread.join()

if __name__ == '__main__':
  main()

疑難排解

本節提供疑難排解資訊,協助您找出並解決使用 TPU 監控程式庫時可能遇到的問題。

缺少功能或指標

如果無法查看某些功能或指標,最常見的原因是 libtpu 版本過舊。TPU 監控程式庫的功能和指標會納入 libtpu 版本,舊版可能缺少新功能和指標。

檢查環境中執行的 libtpu 版本:

指令列:

pip show libtpu

Python:

import libtpu

print(libtpu.__version__)

如果使用的不是 libtpu最新版本,請使用下列指令更新程式庫:

pip install --upgrade libtpu