From 2bfebea8ba886a40c6b91eaad74860373c4eca09 Mon Sep 17 00:00:00 2001 From: drbh Date: Fri, 21 Mar 2025 11:06:13 -0400 Subject: [PATCH 1/3] feat: generate docs from build file and source --- README.md | 1 + docs/documentation-generator.md | 74 ++++++ scripts/gen-docs.py | 431 ++++++++++++++++++++++++++++++++ scripts/init-kernel.py | 5 + 4 files changed, 511 insertions(+) create mode 100644 docs/documentation-generator.md create mode 100644 scripts/gen-docs.py diff --git a/README.md b/README.md index 388c46b..6ffe4f1 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,7 @@ the activation folder. - [Building kernels with Nix](./docs/nix.md) - [Local kernel development](docs/local-dev.md) (IDE integration) - [Why Nix?](./docs/why-nix.md) +- [Documentation Generator](./docs/documentation-generator.md) ## Credits diff --git a/docs/documentation-generator.md b/docs/documentation-generator.md new file mode 100644 index 0000000..89f5a7c --- /dev/null +++ b/docs/documentation-generator.md @@ -0,0 +1,74 @@ +# Documentation Generator + +This tool helps you generate consistent documentation for your CUDA kernel projects built with kernel-builder. It analyzes source files and build configuration to create structured Markdown documentation. + +## Features + +- Extracts documentation from CUDA, C++, and header files +- Parses build configuration from build.toml +- Identifies kernels and function definitions with their parameters +- Generates Markdown documentation with a table of contents +- Includes build settings and project information + +## Usage + +```bash +# Generate documentation for a kernel project +python scripts/gen-docs.py /path/to/your/kernel/project + +# Specify custom output directory (default is "docs") +python scripts/gen-docs.py /path/to/your/kernel/project --output custom-docs +``` + +## Comment Format + +For best results, document your kernels and functions using the following format: + +```cpp +/** + * This is a description of the kernel or function. + * + * Any additional details about the implementation or usage can go here. + */ +__global__ void my_kernel(float* input, float* output, int size) { + // kernel implementation +} +``` + +The generator will extract this documentation and include it in the generated files. + +## Output Structure + +The generated documentation includes: + +1. **Table of Contents** - Navigation for all documentation sections +2. **Project Overview** - Basic information about the project (from build.toml if available) +3. **Build Configuration** - Settings from build.toml including: + - Kernel definitions + - CUDA capabilities + - Source files + - Dependencies +4. **API Documentation** - Documentation for each source file: + - Function and kernel signatures + - Parameter tables with types + - Documentation comments + +## Example + +Given a project with the following structure: + +``` +my_kernel/ +├── build.toml +├── kernel.cu +└── torch-ext/ + └── ... +``` + +Running the documentation generator: + +```bash +python scripts/gen-docs.py my_kernel +``` + +Will produce `my_kernel.md` in the `my_kernel/docs` directory. \ No newline at end of file diff --git a/scripts/gen-docs.py b/scripts/gen-docs.py new file mode 100644 index 0000000..f3477ba --- /dev/null +++ b/scripts/gen-docs.py @@ -0,0 +1,431 @@ +# /// script +# dependencies = [ +# "toml", +# ] +# /// +import os +import re +import toml +import argparse +from pathlib import Path +from datetime import datetime +from typing import Dict, List, Optional, Any, Tuple + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Generate documentation for kernel projects" + ) + parser.add_argument("project_path", help="Path to the kernel project root") + parser.add_argument("--output", "-o", help="Output directory", default="docs") + parser.add_argument( + "--toc", + "-t", + help="Include table of contents", + action="store_true", + default=True, + ) + return parser.parse_args() + + +def parse_build_config(project_path: str) -> Optional[Dict[str, Any]]: + """Parse the build.toml configuration file.""" + build_config_path = os.path.join(project_path, "build.toml") + if not os.path.exists(build_config_path): + print(f"Error: build.toml not found at {build_config_path}") + return None + + try: + config = toml.load(build_config_path) + print(f"Successfully parsed build configuration from {build_config_path}") + return config + except Exception as e: + print(f"Error parsing build.toml: {e}") + return None + + +def extract_function_signature(content: str, func_name: str) -> str: + """Extract the full function signature for better documentation.""" + # Look for the function definition + pattern = rf"(?:__global__ void|void|template\s*<.*?>\s*__global__ void|template\s*<.*?>\s*void)\s+{func_name}\s*\([^\)]*\)" + match = re.search(pattern, content, re.DOTALL) + if match: + signature = match.group(0).strip() + # Remove trailing comments on each line + signature_lines = signature.split("\n") + signature_lines = [line.split("//")[0].strip() for line in signature_lines] + signature = "\n".join(signature_lines) + # Clean up the signature + signature = re.sub(r"\s+", " ", signature) + return signature + return func_name + + +def extract_function_params(content: str, func_name: str) -> List[Dict[str, str]]: + """Extract function parameters with their types.""" + pattern = rf"(?:__global__ void|void|template\s*<.*?>\s*__global__ void|template\s*<.*?>\s*void)\s+{func_name}\s*\(([^\)]*)\)" + match = re.search(pattern, content, re.DOTALL) + params = [] + + if match: + param_str = match.group(1).strip() + # Remove trailing comments on each line + param_str_lines = param_str.split("\n") + param_str_lines = [line.split("//")[0].strip() for line in param_str_lines] + param_str = "\n".join(param_str_lines) + + if param_str: + # Split by commas, but handle nested template parameters + param_parts = [] + current_part = "" + template_depth = 0 + + for char in param_str: + if char == "," and template_depth == 0: + param_parts.append(current_part.strip()) + current_part = "" + else: + if char == "<": + template_depth += 1 + elif char == ">": + template_depth -= 1 + current_part += char + + if current_part: + param_parts.append(current_part.strip()) + + for part in param_parts: + # Extract type and name + parts = part.split() + if len(parts) >= 2: + param_type = " ".join(parts[:-1]) + param_name = parts[-1].rstrip(",") + # Clean param name (remove pointers/references from name) + param_name = param_name.lstrip("*&") + params.append({"type": param_type, "name": param_name}) + + return params + + +def parse_kernel_config(config: Dict[str, Any]) -> Dict[str, Any]: + """Extract kernel configuration from the build.toml.""" + kernel_config = {} + + # Extract kernels section + for key, value in config.items(): + if key == "kernel": + kernel_config = value + break + + return kernel_config + + +def extract_kernel_docs( + project_path: str, kernel_config: Dict[str, Any] +) -> List[Dict[str, Any]]: + """Extract documentation from kernel source files.""" + kernel_info = [] + source_files = [] + + # Collect source files from kernel config first + for kernel_name, kernel_data in kernel_config.items(): + if "src" in kernel_data and isinstance(kernel_data["src"], list): + for src_file in kernel_data["src"]: + file_path = os.path.join(project_path, src_file) + if os.path.exists(file_path): + source_files.append(Path(file_path)) + + # Also collect all other source files + for ext in [".cu", ".h", ".cpp", ".cuh"]: + extra_files = [ + p for p in Path(project_path).glob(f"**/*{ext}") if p not in source_files + ] + source_files.extend(extra_files) + + source_files = sorted(set(source_files)) + + for source_file in source_files: + rel_path = os.path.relpath(source_file, project_path) + file_info = {"file": rel_path, "functions": []} + + try: + with open(source_file, "r", encoding="utf-8") as f: + content = f.read() + + # Extract kernel declarations with comments + kernel_pattern = r"/\*\*\s*(.*?)\s*\*/\s*(?:__global__ void|template\s*<.*?>\s*__global__ void)\s+(\w+)" + function_pattern = r"/\*\*\s*(.*?)\s*\*/\s*(?:template\s*<.*?>\s*)?(?:inline\s+)?(?:__host__\s+)?(?:__device__\s+)?(?:\w+\s+)+(\w+)\s*\(" + + # Extract kernels without comments too + simple_kernel_pattern = r"__global__\s+void\s+(\w+)\s*\(" + simple_function_pattern = r"(?:__host__|__device__|__host__\s+__device__|__device__\s+__host__|void)\s+(\w+)\s*\(" + + # Process kernels with comments + for match in re.finditer(kernel_pattern, content, re.DOTALL): + comment = match.group(1).strip().replace("*", "").strip() + name = match.group(2) + signature = extract_function_signature(content, name) + params = extract_function_params(content, name) + + file_info["functions"].append( + { + "name": name, + "type": "kernel", + "doc": comment, + "signature": signature, + "params": params, + } + ) + + # Process regular functions with comments + for match in re.finditer(function_pattern, content, re.DOTALL): + comment = match.group(1).strip().replace("*", "").strip() + name = match.group(2) + # Skip if already processed as kernel + if any(f["name"] == name for f in file_info["functions"]): + continue + signature = extract_function_signature(content, name) + params = extract_function_params(content, name) + + file_info["functions"].append( + { + "name": name, + "type": "function", + "doc": comment, + "signature": signature, + "params": params, + } + ) + + # Process kernels without comments + for match in re.finditer(simple_kernel_pattern, content): + name = match.group(1) + # Skip if already processed + if any(f["name"] == name for f in file_info["functions"]): + continue + signature = extract_function_signature(content, name) + params = extract_function_params(content, name) + + file_info["functions"].append( + { + "name": name, + "type": "kernel", + "doc": "", + "signature": signature, + "params": params, + } + ) + + # Process simple functions without comments + for match in re.finditer(simple_function_pattern, content): + name = match.group(1) + # Skip if already processed + if any(f["name"] == name for f in file_info["functions"]): + continue + signature = extract_function_signature(content, name) + params = extract_function_params(content, name) + + file_info["functions"].append( + { + "name": name, + "type": "function", + "doc": "", + "signature": signature, + "params": params, + } + ) + + # Only add files that have documented functions + if file_info["functions"]: + kernel_info.append(file_info) + + except Exception as e: + print(f"Error processing {source_file}: {e}") + + return kernel_info + + +def generate_toc(sections): + """Generate a table of contents from section headers.""" + toc = ["## Table of Contents", ""] + for section in sections: + indent = " " * (section["level"] - 1) + link = ( + section["title"] + .lower() + .replace(" ", "-") + .replace(".", "") + .replace("(", "") + .replace(")", "") + .replace(":", "") + ) + toc.append(f"{indent}- [{section['title']}](#{link})") + return toc + + +def format_parameter_table(params): + """Format function parameters as a markdown table.""" + if not params: + return "" + + table = ["| Parameter | Type |", "|-----------|------|"] + for param in params: + table.append(f"| `{param['name']}` | `{param['type']}` |") + + return "\n".join(table) + + +def generate_markdown( + project_path: str, + config: Dict[str, Any], + kernel_info: List[Dict[str, Any]], + include_toc: bool = True, +) -> str: + """Generate markdown documentation from parsed information.""" + project_name = os.path.basename(os.path.abspath(project_path)) + sections = [] + + # Extract project name from config if available + if config and "general" in config and "name" in config["general"]: + project_name = config["general"]["name"] + + # Start with title and metadata + lines = [ + f"# `{project_name}` Documentation", + "", + f"*Generated on {datetime.now().strftime('%Y-%m-%d')}*", + "", + ] + + # Add project overview + sections.append({"title": "Project Overview", "level": 2}) + lines.extend(["## Project Overview", ""]) + + # If we have a description in the config, use it + if config and "general" in config and "description" in config.get("general", {}): + lines.append(config["general"]["description"]) + else: + lines.append(f"{project_name} is a CUDA kernel project.") + lines.append("") + + # Add configuration details + if config: + sections.append({"title": "Build Configuration", "level": 2}) + lines.append("## Build Configuration") + lines.append("") + + # Extract kernel configurations + kernel_configs = parse_kernel_config(config) + if kernel_configs: + sections.append({"title": "Kernels", "level": 3}) + lines.append("### Kernels") + lines.append("") + + for kernel_name, kernel_data in kernel_configs.items(): + lines.append(f"#### {kernel_name}") + lines.append("") + + if "cuda-capabilities" in kernel_data: + capabilities = ", ".join(kernel_data["cuda-capabilities"]) + lines.append("**CUDA Capabilities:**") + lines.append(f"- `[{capabilities}]`") + lines.append("") + + if "src" in kernel_data: + lines.append("**Source Files:**") + for src_file in kernel_data["src"]: + lines.append(f"- `{src_file}`") + lines.append("") + + if "depends" in kernel_data: + lines.append("**Dependencies:**") + for dep in kernel_data["depends"]: + lines.append(f"- `{dep}`") + lines.append("") + + # Add source documentation + if kernel_info: + sections.append({"title": "API Documentation", "level": 2}) + lines.append("## API Documentation") + lines.append("") + + for file_info in kernel_info: + sections.append({"title": file_info["file"], "level": 3}) + lines.append(f"### {file_info['file']}") + lines.append("") + + # Add functions and kernels + for func in sorted(file_info["functions"], key=lambda x: x["name"]): + func_type = "Kernel" if func["type"] == "kernel" else "Function" + sections.append({"title": f"{func['name']} ({func_type})", "level": 4}) + lines.append(f"#### {func['name']} ({func_type})") + lines.append("") + + # Add function signature in code block + lines.append("```cpp") + lines.append(func["signature"]) + lines.append("```") + lines.append("") + + # Add documentation if it exists + if func["doc"]: + lines.append(func["doc"]) + lines.append("") + + # Add parameter table if not in short mode + if func["params"]: + lines.append("**Parameters:**") + lines.append("") + lines.append(format_parameter_table(func["params"])) + lines.append("") + + # Insert table of contents after the title if requested + if include_toc: + toc = generate_toc(sections) + lines = lines[:3] + [""] + toc + [""] + lines[3:] + + return "\n".join(lines) + + +def main(): + """Main function to run the documentation generator.""" + args = parse_args() + project_path = args.project_path + + # Parse build configuration + config = parse_build_config(project_path) + + # Parse kernel configs from build.toml + kernel_config = parse_kernel_config(config) if config else {} + + # Extract kernel documentation + kernel_info = extract_kernel_docs(project_path, kernel_config) + + # Generate documentation + markdown_content = generate_markdown( + project_path, + config, + kernel_info, + include_toc=args.toc, + ) + + # Create output directory if it doesn't exist + output_dir = os.path.join(project_path, args.output) + os.makedirs(output_dir, exist_ok=True) + + # Write output file + project_name = os.path.basename(os.path.abspath(project_path)) + if config and "general" in config and "name" in config["general"]: + project_name = config["general"]["name"] + + output_path = os.path.join(output_dir, f"{project_name}.md") + + with open(output_path, "w", encoding="utf-8") as f: + f.write(markdown_content) + + print(f"Documentation generated at {output_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/init-kernel.py b/scripts/init-kernel.py index 3fe7bce..6a0a214 100644 --- a/scripts/init-kernel.py +++ b/scripts/init-kernel.py @@ -163,6 +163,11 @@ def main(): f" {Colors.YELLOW}{1}.{Colors.ENDC} {Colors.BOLD}pytest -vv tests/{Colors.ENDC}" ) + print(f"\n{Colors.CYAN}{Colors.BOLD}Generate documentation{Colors.ENDC}") + print( + f" {Colors.YELLOW}{1}.{Colors.ENDC} {Colors.BOLD}uv run scripts/gen-docs.py ./{Colors.ENDC}" + ) + print("") From 20b295eec22c5af06788c218247e692b7aa5e6d4 Mon Sep 17 00:00:00 2001 From: drbh Date: Mon, 24 Mar 2025 11:01:41 -0400 Subject: [PATCH 2/3] feat: prefer libclang parsing over regex --- scripts/gen-docs.py | 337 +++++++++++++++++++------------------------- 1 file changed, 144 insertions(+), 193 deletions(-) diff --git a/scripts/gen-docs.py b/scripts/gen-docs.py index f3477ba..491f1e7 100644 --- a/scripts/gen-docs.py +++ b/scripts/gen-docs.py @@ -1,124 +1,95 @@ # /// script # dependencies = [ # "toml", +# "clang", # ] # /// +from clang.cindex import Config +import clang.cindex +from clang.cindex import CursorKind +from typing import Dict, List, Optional, Any, Tuple import os -import re -import toml -import argparse from pathlib import Path +import argparse from datetime import datetime -from typing import Dict, List, Optional, Any, Tuple +import toml -def parse_args(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser( - description="Generate documentation for kernel projects" - ) - parser.add_argument("project_path", help="Path to the kernel project root") - parser.add_argument("--output", "-o", help="Output directory", default="docs") - parser.add_argument( - "--toc", - "-t", - help="Include table of contents", - action="store_true", - default=True, - ) - return parser.parse_args() +Config.set_library_file("/Library/Developer/CommandLineTools/usr/lib/libclang.dylib") +Config.set_compatibility_check(False) -def parse_build_config(project_path: str) -> Optional[Dict[str, Any]]: - """Parse the build.toml configuration file.""" - build_config_path = os.path.join(project_path, "build.toml") - if not os.path.exists(build_config_path): - print(f"Error: build.toml not found at {build_config_path}") - return None +def get_function_declarations(file_path): + """ + Extract all function declarations from a C++ file. - try: - config = toml.load(build_config_path) - print(f"Successfully parsed build configuration from {build_config_path}") - return config - except Exception as e: - print(f"Error parsing build.toml: {e}") - return None + Args: + file_path (str): Path to the C++ file + Returns: + list: List of dictionaries containing function information + """ + # Initialize clang index + index = clang.cindex.Index.create() -def extract_function_signature(content: str, func_name: str) -> str: - """Extract the full function signature for better documentation.""" - # Look for the function definition - pattern = rf"(?:__global__ void|void|template\s*<.*?>\s*__global__ void|template\s*<.*?>\s*void)\s+{func_name}\s*\([^\)]*\)" - match = re.search(pattern, content, re.DOTALL) - if match: - signature = match.group(0).strip() - # Remove trailing comments on each line - signature_lines = signature.split("\n") - signature_lines = [line.split("//")[0].strip() for line in signature_lines] - signature = "\n".join(signature_lines) - # Clean up the signature - signature = re.sub(r"\s+", " ", signature) - return signature - return func_name - - -def extract_function_params(content: str, func_name: str) -> List[Dict[str, str]]: - """Extract function parameters with their types.""" - pattern = rf"(?:__global__ void|void|template\s*<.*?>\s*__global__ void|template\s*<.*?>\s*void)\s+{func_name}\s*\(([^\)]*)\)" - match = re.search(pattern, content, re.DOTALL) - params = [] - - if match: - param_str = match.group(1).strip() - # Remove trailing comments on each line - param_str_lines = param_str.split("\n") - param_str_lines = [line.split("//")[0].strip() for line in param_str_lines] - param_str = "\n".join(param_str_lines) - - if param_str: - # Split by commas, but handle nested template parameters - param_parts = [] - current_part = "" - template_depth = 0 - - for char in param_str: - if char == "," and template_depth == 0: - param_parts.append(current_part.strip()) - current_part = "" - else: - if char == "<": - template_depth += 1 - elif char == ">": - template_depth -= 1 - current_part += char - - if current_part: - param_parts.append(current_part.strip()) - - for part in param_parts: - # Extract type and name - parts = part.split() - if len(parts) >= 2: - param_type = " ".join(parts[:-1]) - param_name = parts[-1].rstrip(",") - # Clean param name (remove pointers/references from name) - param_name = param_name.lstrip("*&") - params.append({"type": param_type, "name": param_name}) - - return params + # Parse the file + translation_unit = index.parse(file_path) + # Check for parsing errors + if not translation_unit: + print(f"Error parsing {file_path}") + return [] -def parse_kernel_config(config: Dict[str, Any]) -> Dict[str, Any]: - """Extract kernel configuration from the build.toml.""" - kernel_config = {} + functions = [] - # Extract kernels section - for key, value in config.items(): - if key == "kernel": - kernel_config = value - break + # Helper function to recursively traverse the AST + def traverse_ast(cursor, parent=None): + # Check if the cursor represents a function declaration + if cursor.kind == CursorKind.FUNCTION_DECL: + # Get function return type + return_type = cursor.type.get_result().spelling - return kernel_config + # Get function name + func_name = cursor.spelling + + # Get function parameters + params = [] + for param in cursor.get_arguments(): + params.append({"name": param.spelling, "type": param.type.spelling}) + + # Get function location + location = cursor.location + file_path = location.file.name if location.file else "Unknown" + line = location.line + column = location.column + + # Check if the function has a body + has_body = any( + c.kind == CursorKind.COMPOUND_STMT for c in cursor.get_children() + ) + + # Determine if it's a declaration or definition + func_type = "definition" if has_body else "declaration" + + # Add function info to our list + functions.append( + { + "name": func_name, + "return_type": return_type, + "parameters": params, + "location": {"file": file_path, "line": line, "column": column}, + "type": func_type, + } + ) + + # Recursively process children + for child in cursor.get_children(): + traverse_ast(child, cursor) + + # Start traversing from the translation unit cursor + traverse_ast(translation_unit.cursor) + + return functions def extract_kernel_docs( @@ -138,9 +109,14 @@ def extract_kernel_docs( # Also collect all other source files for ext in [".cu", ".h", ".cpp", ".cuh"]: - extra_files = [ - p for p in Path(project_path).glob(f"**/*{ext}") if p not in source_files - ] + extra_files = [] + for p in Path(project_path).glob(f"**/*{ext}"): + # avoid adding `torch-ext` files + if "torch-ext" in str(p): + continue + if p not in source_files: + extra_files.append(p) + source_files.extend(extra_files) source_files = sorted(set(source_files)) @@ -149,101 +125,76 @@ def extract_kernel_docs( rel_path = os.path.relpath(source_file, project_path) file_info = {"file": rel_path, "functions": []} - try: - with open(source_file, "r", encoding="utf-8") as f: - content = f.read() - - # Extract kernel declarations with comments - kernel_pattern = r"/\*\*\s*(.*?)\s*\*/\s*(?:__global__ void|template\s*<.*?>\s*__global__ void)\s+(\w+)" - function_pattern = r"/\*\*\s*(.*?)\s*\*/\s*(?:template\s*<.*?>\s*)?(?:inline\s+)?(?:__host__\s+)?(?:__device__\s+)?(?:\w+\s+)+(\w+)\s*\(" - - # Extract kernels without comments too - simple_kernel_pattern = r"__global__\s+void\s+(\w+)\s*\(" - simple_function_pattern = r"(?:__host__|__device__|__host__\s+__device__|__device__\s+__host__|void)\s+(\w+)\s*\(" - - # Process kernels with comments - for match in re.finditer(kernel_pattern, content, re.DOTALL): - comment = match.group(1).strip().replace("*", "").strip() - name = match.group(2) - signature = extract_function_signature(content, name) - params = extract_function_params(content, name) - - file_info["functions"].append( - { - "name": name, - "type": "kernel", - "doc": comment, - "signature": signature, - "params": params, - } + functions = get_function_declarations(source_file) + for func in functions: + signature = f"{func['return_type']} {func['name']}(" + signature += ( + ", ".join([f"{p['type']} {p['name']}" for p in func["parameters"]]) + + ")" + ) + file_info["functions"].append( + dict( + name=func["name"], + type=func["type"], + doc="", + signature=signature, + params=[ + dict(name=p["name"], type=p["type"], doc="") + for p in func["parameters"] + ], ) + ) - # Process regular functions with comments - for match in re.finditer(function_pattern, content, re.DOTALL): - comment = match.group(1).strip().replace("*", "").strip() - name = match.group(2) - # Skip if already processed as kernel - if any(f["name"] == name for f in file_info["functions"]): - continue - signature = extract_function_signature(content, name) - params = extract_function_params(content, name) - - file_info["functions"].append( - { - "name": name, - "type": "function", - "doc": comment, - "signature": signature, - "params": params, - } - ) + if len(file_info["functions"]) > 0: + kernel_info.append(file_info) - # Process kernels without comments - for match in re.finditer(simple_kernel_pattern, content): - name = match.group(1) - # Skip if already processed - if any(f["name"] == name for f in file_info["functions"]): - continue - signature = extract_function_signature(content, name) - params = extract_function_params(content, name) - - file_info["functions"].append( - { - "name": name, - "type": "kernel", - "doc": "", - "signature": signature, - "params": params, - } - ) + return kernel_info - # Process simple functions without comments - for match in re.finditer(simple_function_pattern, content): - name = match.group(1) - # Skip if already processed - if any(f["name"] == name for f in file_info["functions"]): - continue - signature = extract_function_signature(content, name) - params = extract_function_params(content, name) - - file_info["functions"].append( - { - "name": name, - "type": "function", - "doc": "", - "signature": signature, - "params": params, - } - ) - # Only add files that have documented functions - if file_info["functions"]: - kernel_info.append(file_info) +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description="Generate documentation for kernel projects" + ) + parser.add_argument("project_path", help="Path to the kernel project root") + parser.add_argument("--output", "-o", help="Output directory", default="docs") + parser.add_argument( + "--toc", + "-t", + help="Include table of contents", + action="store_true", + default=True, + ) + return parser.parse_args() - except Exception as e: - print(f"Error processing {source_file}: {e}") - return kernel_info +def parse_kernel_config(config: Dict[str, Any]) -> Dict[str, Any]: + """Extract kernel configuration from the build.toml.""" + kernel_config = {} + + # Extract kernels section + for key, value in config.items(): + if key == "kernel": + kernel_config = value + break + + return kernel_config + + +def parse_build_config(project_path: str) -> Optional[Dict[str, Any]]: + """Parse the build.toml configuration file.""" + build_config_path = os.path.join(project_path, "build.toml") + if not os.path.exists(build_config_path): + print(f"Error: build.toml not found at {build_config_path}") + return None + + try: + config = toml.load(build_config_path) + print(f"Successfully parsed build configuration from {build_config_path}") + return config + except Exception as e: + print(f"Error parsing build.toml: {e}") + return None def generate_toc(sections): From 5be3b1e28faa8508ff3ea08586100971e9197b78 Mon Sep 17 00:00:00 2001 From: drbh Date: Tue, 25 Mar 2025 15:36:33 -0400 Subject: [PATCH 3/3] feat: simplify doc gen by only parsing TORCH_LIBRARY_EXPAND defs --- scripts/gen-docs.py | 477 +++++++++++--------------------------------- 1 file changed, 119 insertions(+), 358 deletions(-) diff --git a/scripts/gen-docs.py b/scripts/gen-docs.py index 491f1e7..ac7bc5f 100644 --- a/scripts/gen-docs.py +++ b/scripts/gen-docs.py @@ -1,381 +1,142 @@ # /// script -# dependencies = [ -# "toml", -# "clang", -# ] +# dependencies = ["toml", "minijinja"] # /// -from clang.cindex import Config -import clang.cindex -from clang.cindex import CursorKind -from typing import Dict, List, Optional, Any, Tuple -import os -from pathlib import Path -import argparse -from datetime import datetime import toml +import sys +import os +import glob +import re +from minijinja import Environment + + +def extract_ops(file_path): + ops = [] + with open(file_path, "r") as f: + content = f.read() + # Find TORCH_LIBRARY_EXPAND blocks + lib_blocks = re.findall( + r"TORCH_LIBRARY_EXPAND\([^)]+\)([^}]+)}", content, re.DOTALL + ) + for block in lib_blocks: + # Extract ops.def lines + op_defs = re.findall(r"ops\.def\(\"([^\"]+)\"", block) + ops.extend(op_defs) + return ops -Config.set_library_file("/Library/Developer/CommandLineTools/usr/lib/libclang.dylib") -Config.set_compatibility_check(False) - - -def get_function_declarations(file_path): - """ - Extract all function declarations from a C++ file. - - Args: - file_path (str): Path to the C++ file - - Returns: - list: List of dictionaries containing function information - """ - # Initialize clang index - index = clang.cindex.Index.create() - - # Parse the file - translation_unit = index.parse(file_path) - - # Check for parsing errors - if not translation_unit: - print(f"Error parsing {file_path}") - return [] - - functions = [] - - # Helper function to recursively traverse the AST - def traverse_ast(cursor, parent=None): - # Check if the cursor represents a function declaration - if cursor.kind == CursorKind.FUNCTION_DECL: - # Get function return type - return_type = cursor.type.get_result().spelling - - # Get function name - func_name = cursor.spelling - - # Get function parameters - params = [] - for param in cursor.get_arguments(): - params.append({"name": param.spelling, "type": param.type.spelling}) - - # Get function location - location = cursor.location - file_path = location.file.name if location.file else "Unknown" - line = location.line - column = location.column - - # Check if the function has a body - has_body = any( - c.kind == CursorKind.COMPOUND_STMT for c in cursor.get_children() - ) - - # Determine if it's a declaration or definition - func_type = "definition" if has_body else "declaration" - - # Add function info to our list - functions.append( - { - "name": func_name, - "return_type": return_type, - "parameters": params, - "location": {"file": file_path, "line": line, "column": column}, - "type": func_type, - } - ) - - # Recursively process children - for child in cursor.get_children(): - traverse_ast(child, cursor) - - # Start traversing from the translation unit cursor - traverse_ast(translation_unit.cursor) - - return functions - - -def extract_kernel_docs( - project_path: str, kernel_config: Dict[str, Any] -) -> List[Dict[str, Any]]: - """Extract documentation from kernel source files.""" - kernel_info = [] - source_files = [] - - # Collect source files from kernel config first - for kernel_name, kernel_data in kernel_config.items(): - if "src" in kernel_data and isinstance(kernel_data["src"], list): - for src_file in kernel_data["src"]: - file_path = os.path.join(project_path, src_file) - if os.path.exists(file_path): - source_files.append(Path(file_path)) - - # Also collect all other source files - for ext in [".cu", ".h", ".cpp", ".cuh"]: - extra_files = [] - for p in Path(project_path).glob(f"**/*{ext}"): - # avoid adding `torch-ext` files - if "torch-ext" in str(p): - continue - if p not in source_files: - extra_files.append(p) - - source_files.extend(extra_files) - - source_files = sorted(set(source_files)) - - for source_file in source_files: - rel_path = os.path.relpath(source_file, project_path) - file_info = {"file": rel_path, "functions": []} - - functions = get_function_declarations(source_file) - for func in functions: - signature = f"{func['return_type']} {func['name']}(" - signature += ( - ", ".join([f"{p['type']} {p['name']}" for p in func["parameters"]]) - + ")" - ) - file_info["functions"].append( - dict( - name=func["name"], - type=func["type"], - doc="", - signature=signature, - params=[ - dict(name=p["name"], type=p["type"], doc="") - for p in func["parameters"] - ], - ) - ) - - if len(file_info["functions"]) > 0: - kernel_info.append(file_info) - - return kernel_info - - -def parse_args(): - """Parse command line arguments.""" - parser = argparse.ArgumentParser( - description="Generate documentation for kernel projects" - ) - parser.add_argument("project_path", help="Path to the kernel project root") - parser.add_argument("--output", "-o", help="Output directory", default="docs") - parser.add_argument( - "--toc", - "-t", - help="Include table of contents", - action="store_true", - default=True, - ) - return parser.parse_args() - - -def parse_kernel_config(config: Dict[str, Any]) -> Dict[str, Any]: - """Extract kernel configuration from the build.toml.""" - kernel_config = {} - - # Extract kernels section - for key, value in config.items(): - if key == "kernel": - kernel_config = value - break - - return kernel_config +def main(): + if len(sys.argv) < 2: + print("Usage: script.py ") + return + project_dir = sys.argv[1] -def parse_build_config(project_path: str) -> Optional[Dict[str, Any]]: - """Parse the build.toml configuration file.""" - build_config_path = os.path.join(project_path, "build.toml") + # Read build.toml or return if not found + build_config_path = os.path.join(project_dir, "build.toml") if not os.path.exists(build_config_path): print(f"Error: build.toml not found at {build_config_path}") - return None - + return try: - config = toml.load(build_config_path) + build_config = toml.load(build_config_path) print(f"Successfully parsed build configuration from {build_config_path}") - return config + except Exception as e: print(f"Error parsing build.toml: {e}") - return None - - -def generate_toc(sections): - """Generate a table of contents from section headers.""" - toc = ["## Table of Contents", ""] - for section in sections: - indent = " " * (section["level"] - 1) - link = ( - section["title"] - .lower() - .replace(" ", "-") - .replace(".", "") - .replace("(", "") - .replace(")", "") - .replace(":", "") + return + + if not build_config: + return + + # Get all kernels from the config + config_kernels = build_config.get("kernel", {}) + + kernels = [] + for kernel_name, kernel_info in config_kernels.items(): + kernels.append( + { + "name": kernel_name, + "cuda-capabilities": kernel_info.get("cuda-capabilities", []), + "src": kernel_info.get("src", []), + "dependencies": kernel_info.get("depends", []), + } ) - toc.append(f"{indent}- [{section['title']}](#{link})") - return toc - - -def format_parameter_table(params): - """Format function parameters as a markdown table.""" - if not params: - return "" - - table = ["| Parameter | Type |", "|-----------|------|"] - for param in params: - table.append(f"| `{param['name']}` | `{param['type']}` |") - - return "\n".join(table) - - -def generate_markdown( - project_path: str, - config: Dict[str, Any], - kernel_info: List[Dict[str, Any]], - include_toc: bool = True, -) -> str: - """Generate markdown documentation from parsed information.""" - project_name = os.path.basename(os.path.abspath(project_path)) - sections = [] - - # Extract project name from config if available - if config and "general" in config and "name" in config["general"]: - project_name = config["general"]["name"] - - # Start with title and metadata - lines = [ - f"# `{project_name}` Documentation", - "", - f"*Generated on {datetime.now().strftime('%Y-%m-%d')}*", - "", - ] - - # Add project overview - sections.append({"title": "Project Overview", "level": 2}) - lines.extend(["## Project Overview", ""]) - - # If we have a description in the config, use it - if config and "general" in config and "description" in config.get("general", {}): - lines.append(config["general"]["description"]) - else: - lines.append(f"{project_name} is a CUDA kernel project.") - lines.append("") - - # Add configuration details - if config: - sections.append({"title": "Build Configuration", "level": 2}) - lines.append("## Build Configuration") - lines.append("") - - # Extract kernel configurations - kernel_configs = parse_kernel_config(config) - if kernel_configs: - sections.append({"title": "Kernels", "level": 3}) - lines.append("### Kernels") - lines.append("") - - for kernel_name, kernel_data in kernel_configs.items(): - lines.append(f"#### {kernel_name}") - lines.append("") - - if "cuda-capabilities" in kernel_data: - capabilities = ", ".join(kernel_data["cuda-capabilities"]) - lines.append("**CUDA Capabilities:**") - lines.append(f"- `[{capabilities}]`") - lines.append("") - - if "src" in kernel_data: - lines.append("**Source Files:**") - for src_file in kernel_data["src"]: - lines.append(f"- `{src_file}`") - lines.append("") - - if "depends" in kernel_data: - lines.append("**Dependencies:**") - for dep in kernel_data["depends"]: - lines.append(f"- `{dep}`") - lines.append("") - - # Add source documentation - if kernel_info: - sections.append({"title": "API Documentation", "level": 2}) - lines.append("## API Documentation") - lines.append("") - for file_info in kernel_info: - sections.append({"title": file_info["file"], "level": 3}) - lines.append(f"### {file_info['file']}") - lines.append("") - - # Add functions and kernels - for func in sorted(file_info["functions"], key=lambda x: x["name"]): - func_type = "Kernel" if func["type"] == "kernel" else "Function" - sections.append({"title": f"{func['name']} ({func_type})", "level": 4}) - lines.append(f"#### {func['name']} ({func_type})") - lines.append("") - - # Add function signature in code block - lines.append("```cpp") - lines.append(func["signature"]) - lines.append("```") - lines.append("") - - # Add documentation if it exists - if func["doc"]: - lines.append(func["doc"]) - lines.append("") - - # Add parameter table if not in short mode - if func["params"]: - lines.append("**Parameters:**") - lines.append("") - lines.append(format_parameter_table(func["params"])) - lines.append("") - - # Insert table of contents after the title if requested - if include_toc: - toc = generate_toc(sections) - lines = lines[:3] + [""] + toc + [""] + lines[3:] - - return "\n".join(lines) - - -def main(): - """Main function to run the documentation generator.""" - args = parse_args() - project_path = args.project_path - - # Parse build configuration - config = parse_build_config(project_path) - - # Parse kernel configs from build.toml - kernel_config = parse_kernel_config(config) if config else {} - - # Extract kernel documentation - kernel_info = extract_kernel_docs(project_path, kernel_config) - - # Generate documentation - markdown_content = generate_markdown( - project_path, - config, - kernel_info, - include_toc=args.toc, + # Find torch-ext directory and extract all ops + ops = [] + torch_ext_dirs = glob.glob( + os.path.join(project_dir, "**/torch-ext"), recursive=True ) - - # Create output directory if it doesn't exist - output_dir = os.path.join(project_path, args.output) - os.makedirs(output_dir, exist_ok=True) + for torch_ext_dir in torch_ext_dirs: + for file in glob.glob(os.path.join(torch_ext_dir, "**/*.cpp"), recursive=True): + for op in extract_ops(file): + relative_file = "../" + os.path.relpath(file, project_dir) + ops.append(dict(op=op, file=relative_file)) + + # Prepare template data + template_data = { + "project_name": os.path.basename(project_dir), + "build_config": {"kernels": kernels}, + "ops": ops, + } + + # Render template + env = Environment( + templates={ + "doc_template": """ +# `{{ project_name }}` Documentation + +> __Generated on 2025-03-25__ + +## Table of Contents + +- [Project Overview](#project-overview) +- [Build Configuration](#build-configuration) + - [Kernels](#kernels){% for kernel_info in build_config.get("kernels", {}) %}\n - [{{ kernel_info.get("name") }}](#{{ kernel_info.get("name") }}){% endfor %} + - [Operations](#operations){% for op in ops %}\n - [{{ op.get("op") }}]({{ op.get("file") }}){% endfor %} + +## Project Overview +{{ project_name }} is a CUDA kernel project. + +## Build Configuration + +### Kernels +{% for kernel_info in build_config.get("kernels", {}) %} +#### {{ kernel_info.get("name") }} + +**CUDA Capabilities:** +- `{{ kernel_info.get("cuda-capabilities", []) }}` + +**Source Files:** +{% for source in kernel_info.get("src", []) %} +- `{{ source }}` +{% endfor %} + +**Dependencies:** +{% for dep in kernel_info.get("dependencies", []) %} +- `{{ dep }}` +{% endfor %} +{% endfor %} + +### operations +{% for op in ops %} +```cpp +{{ op.get("op") }} +``` +[defined]({{ op.get("file") }}) + +{% endfor %} +""" + } + ) + output = env.render_template("doc_template", **template_data) # Write output file - project_name = os.path.basename(os.path.abspath(project_path)) - if config and "general" in config and "name" in config["general"]: - project_name = config["general"]["name"] - + project_name = os.path.basename(os.path.abspath(project_dir)) + output_dir = os.path.join(project_dir, "docs") + os.makedirs(output_dir, exist_ok=True) output_path = os.path.join(output_dir, f"{project_name}.md") - with open(output_path, "w", encoding="utf-8") as f: - f.write(markdown_content) - - print(f"Documentation generated at {output_path}") + f.write(output) if __name__ == "__main__":