SHGAT Class
The SHGAT class is the main orchestrator for SuperHyperGraph Attention Networks. It handles graph construction, multi-level message passing, K-head scoring, and parameter persistence. Published on JSR.
Factory functions
Use factory functions instead of the constructor directly. They handle adaptive head selection, graph building, and level computation automatically.
createSHGAT()
Create a SHGAT instance from unified Node objects. This is the recommended API.
```typescript
import { createSHGAT, type Node } from "@casys/shgat";

const nodes: Node[] = [
  { id: "psql_query", embedding: queryEmb, children: [], level: 0 },
  { id: "psql_exec", embedding: execEmb, children: [], level: 0 },
  { id: "database", embedding: dbEmb, children: ["psql_query", "psql_exec"], level: 1 },
];

const model = createSHGAT(nodes);
// or with config overrides:
const model2 = createSHGAT(nodes, { numHeads: 8, learningRate: 0.01 });
```

| Parameter | Type | Description |
|---|---|---|
| nodes | Node[] | Array of nodes. Leaves have children: [], composites list their children. |
| config | Partial&lt;SHGATConfig&gt; | Optional configuration overrides. |
Returns: SHGAT — a fully initialized instance with all nodes registered, levels computed, and indices built.
SHGAT class
Constructor

```typescript
import { SHGAT } from "@casys/shgat";

const shgat = new SHGAT(config?: Partial<SHGATConfig>);
```

Creates a bare SHGAT instance. You must register nodes manually and call finalizeNodes() before scoring. In most cases, use createSHGAT() instead.
Graph management
| Method | Signature | Description |
|---|---|---|
| registerNode() | (node: Node) => void | Register a unified node. Call finalizeNodes() after all registrations. |
| finalizeNodes() | () => void | Rebuild indices and compute hierarchy levels. Call once after registering all nodes. |
| registerTool() | (node: ToolNode) => void | Register a tool (leaf). Deprecated — use registerNode() with children: []. |
| registerCapability() | (node: CapabilityNode) => void | Register a capability (hyperedge). Deprecated — use registerNode() with children. |
| buildFromData() | (tools, capabilities) => void | Batch register tools and capabilities from raw data. |
| hasToolNode() | (id: string) => boolean | Check if a tool node exists. |
| hasCapabilityNode() | (id: string) => boolean | Check if a capability node exists. |
| getToolCount() | () => number | Number of registered tools. |
| getCapabilityCount() | () => number | Number of registered capabilities. |
| getToolIds() | () => string[] | All registered tool IDs. |
| getCapabilityIds() | () => string[] | All registered capability IDs. |
Co-occurrence
| Method | Signature | Description |
|---|---|---|
| setCooccurrenceData() | (data: CooccurrenceEntry[]) => void | Set V-to-V co-occurrence edges for tool embedding enrichment. |
| getToolIndexMap() | () => Map&lt;string, number&gt; | Get tool ID to index mapping (used by co-occurrence loader). |
Scoring methods
scoreNodes()

The main scoring function. Runs a tensor-native forward pass (multi-level message passing), then K-head attention scoring. All tensor operations stay on the native backend until the final result conversion.
```typescript
const ranked = model.scoreNodes(intentEmbedding);

// Filter by level
const toolScores = model.scoreNodes(intentEmbedding, 0); // leaves only
const capScores = model.scoreNodes(intentEmbedding, 1); // composites only
```

| Parameter | Type | Description |
|---|---|---|
| intentEmbedding | number[] | User intent embedding (1024-dim BGE-M3). |
| level | number \| undefined | Optional level filter. 0 = leaves, 1 = composites. Omit to score all. |

Returns: NodeScore[] — sorted by score descending.
```typescript
interface NodeScore {
  nodeId: string;
  score: number;
  headScores: number[];
  level: number;
}
```

scoreLeaves() / scoreComposites()

Convenience wrappers around scoreNodes().
| Method | Equivalent | Description |
|---|---|---|
| scoreLeaves(intent) | scoreNodes(intent, 0) | Score leaf nodes (tools) only. |
| scoreComposites(intent, level?) | scoreNodes(intent, level ?? 1) | Score composite nodes at a given level. |
Legacy scoring
| Method | Signature | Description |
|---|---|---|
| scoreAllCapabilities() | (intent: number[], contextToolIds?: string[]) => AttentionResult[] | Deprecated. Use scoreNodes(intent, 1). |
| scoreAllTools() | (intent: number[], contextToolIds?: string[]) => Array&lt;{ toolId, score, headScores }&gt; | Deprecated. Use scoreNodes(intent, 0). |
```typescript
interface AttentionResult {
  capabilityId: string;
  score: number;
  headScores: number[];
  toolAttention: number[];
}
```

Other scoring

| Method | Signature | Description |
|---|---|---|
| predictPathSuccess() | (intent: number[], path: string[]) => number | Predict success probability for a tool execution path. |
| computeAttention() | (intent, contextEmbeddings, capId, contextCapIds?) => AttentionResult | Compute detailed attention for a single capability. |
Message passing
forward()

Execute multi-level message passing (V->E->…->V). Results are cached until the graph changes.

```typescript
const { H, E, cache } = model.forward();
// H: number[][] — enriched tool embeddings
// E: number[][] — enriched capability embeddings
// cache: ForwardCache — intermediate results
```

Training

AutogradTrainer (recommended)

The SHGAT class itself does not perform training. Use AutogradTrainer from the training module, which provides TensorFlow.js automatic differentiation.
```typescript
import {
  AutogradTrainer,
  DEFAULT_TRAINER_CONFIG,
  type TrainingMetrics,
} from "@casys/shgat";

const trainer = new AutogradTrainer({
  ...DEFAULT_TRAINER_CONFIG,
  numHeads: 16,
  embeddingDim: 1024,
  learningRate: 0.05,
});

// Set embeddings from your SHGAT graph
trainer.setNodeEmbeddings(embeddings);

// Train on a batch
const metrics: TrainingMetrics = trainer.trainBatch(examples);
console.log(`Loss: ${metrics.loss}, Accuracy: ${metrics.accuracy}`);
```

Deprecated training functions

These functions are exported for backward compatibility but throw errors at runtime. Use AutogradTrainer instead.
| Function | Status |
|---|---|
| trainSHGATOnEpisodes() | Deprecated. Throws. |
| trainSHGATOnEpisodesKHead() | Deprecated. Throws. |
| trainSHGATOnExecution() | Deprecated. Throws. |
| shgat.trainBatchV1KHeadBatched() | Deprecated. Throws. |
PER buffer
Prioritized Experience Replay for sample-efficient training.

```typescript
import { PERBuffer, annealBeta, annealTemperature } from "@casys/shgat";

const buffer = new PERBuffer(maxSize);
buffer.add(example, priority);

const { samples, weights } = buffer.sample(batchSize);
const beta = annealBeta(epoch, maxEpochs);
```
Persistence

exportParams()

Serialize all learned parameters to a JSON-compatible object for storage.
```typescript
const serialized = model.exportParams();
// Store in database, file, etc.
await Deno.writeTextFile("params.json", JSON.stringify(serialized));
```

Returns: Record&lt;string, unknown&gt; — contains config, head parameters, level parameters, V2V parameters, and fusion weights.
importParams()
Restore parameters from a previously exported object.

```typescript
const data = JSON.parse(await Deno.readTextFile("params.json"));
model.importParams(data);
```

| Parameter | Type | Description |
|---|---|---|
| serialized | Record&lt;string, unknown&gt; | Output from exportParams(). |
Importing parameters invalidates cached tensor parameters. They are recreated lazily on the next scoreNodes() call.
Resource management
dispose()

Free GPU memory used by tensor parameters. Call when the SHGAT instance is no longer needed.

```typescript
model.dispose();
```

After calling dispose(), the instance remains usable — tensor parameters are recreated lazily on the next scoreNodes() call. Still, call dispose() once an instance is retired to avoid GPU memory leaks.
Configuration
SHGATConfig

```typescript
interface SHGATConfig {
  // Architecture
  numHeads: number;      // Attention heads (4-16, adaptive). Default: 16
  hiddenDim: number;     // Hidden dim for scoring. Default: 1024
  headDim: number;       // Dim per head (hiddenDim / numHeads). Default: 64
  embeddingDim: number;  // Embedding dim (BGE-M3: 1024). Default: 1024
  numLayers: number;     // Message passing layers. Default: 2
  mlpHiddenDim: number;  // Fusion MLP hidden size. Default: 32

  // Training
  learningRate: number;     // Default: 0.05
  batchSize: number;        // Default: 32
  maxContextLength: number; // Max recent tools in context. Default: 5

  // Buffer management
  maxBufferSize: number;        // PER buffer cap. Default: 50_000
  minTracesForTraining: number; // Cold start threshold. Default: 100

  // Regularization
  dropout: number;        // Default: 0.1
  l2Lambda: number;       // L2 weight. Default: 0.0001
  leakyReluSlope: number; // Default: 0.2
  depthDecay: number;     // Recursive depth decay. Default: 0.8

  // Dimension preservation
  preserveDim?: boolean;           // Keep 1024-dim through MP. Default: true
  preserveDimResidual?: number;    // Residual blend weight. Default: 0.3
  preserveDimResiduals?: number[]; // Per-level residual weights

  // Multi-location residuals
  v2vResidual?: number;      // V2V phase residual. Default: 0
  downwardResidual?: number; // Downward phase residual. Default: 0

  // Gradient scaling
  mpLearningRateScale?: number; // LR multiplier for MP params. Default: 1

  // Projection head
  useProjectionHead?: boolean;    // Default: false
  projectionHiddenDim?: number;   // Default: 256
  projectionOutputDim?: number;   // Default: 256
  projectionBlendAlpha?: number;  // Default: 0.5
  projectionTemperature?: number; // Default: 0.07
}
```

DEFAULT_SHGAT_CONFIG

The default configuration object. Uses 16 heads for optimal performance (16 heads x 64 dim = 1024, matching BGE-M3 embeddings).
```typescript
import { DEFAULT_SHGAT_CONFIG } from "@casys/shgat";
```

getAdaptiveConfig()

Returns adaptive configuration based on trace count. Deprecated — createSHGAT() uses getAdaptiveHeadsByGraphSize() automatically.
```typescript
import { getAdaptiveConfig } from "@casys/shgat";

const overrides = getAdaptiveConfig(traceCount);
// Always returns { numHeads: 16, hiddenDim: 1024, headDim: 64, mlpHiddenDim: 32 }
```

TrainingExample

The input format for training data.
```typescript
interface TrainingExample {
  intentEmbedding: number[];     // 1024-dim intent embedding
  contextTools: string[];        // Active tool IDs in session
  candidateId: string;           // Positive capability ID (was executed)
  outcome: number;               // 1 = success, 0 = failure
  negativeCapIds?: string[];     // Negative IDs for contrastive learning
  allNegativesSorted?: string[]; // All negatives sorted hard-to-easy (for curriculum)
}
```

Logging

```typescript
import { setLogger, resetLogger, getLogger, type Logger } from "@casys/shgat";
```
```typescript
// Custom logger
setLogger({
  info: (msg) => myLogger.info(msg),
  warn: (msg) => myLogger.warn(msg),
  error: (msg) => myLogger.error(msg),
  debug: (msg) => myLogger.debug(msg),
});

// Reset to default console logger
resetLogger();
```